register apply_repetition_penalties_ in custom_op (#110)
* fix qwen2_vl for 0.11.0 * register apply_repetition_penalties_ in custom_op --------- Co-authored-by: luochencheng <luochencheng@baidu.com>
This commit is contained in:
@@ -1887,3 +1887,39 @@ def _fake_quant2d(
|
||||
|
||||
|
||||
quant2d.register_fake(_fake_quant2d)
|
||||
|
||||
|
||||
##################################################
|
||||
# --------------- penalties -----------------
|
||||
##################################################
|
||||
@custom_op("_C::apply_repetition_penalties_", mutates_args=())
|
||||
def apply_repetition_penalties_(
|
||||
logits: torch.Tensor,
|
||||
prompt_mask: torch.Tensor,
|
||||
output_mask: torch.Tensor,
|
||||
repetition_penalties: torch.Tensor
|
||||
) -> None:
|
||||
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
|
||||
1, logits.size(1))
|
||||
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
|
||||
penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
|
||||
1.0)
|
||||
# If logits are positive, divide by penalty, otherwise multiply by penalty.
|
||||
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
|
||||
logits *= scaling
|
||||
|
||||
@impl("_C::apply_repetition_penalties_", "CUDA")
|
||||
def apply_repetition_penalties_(
|
||||
logits: torch.Tensor,
|
||||
prompt_mask: torch.Tensor,
|
||||
output_mask: torch.Tensor,
|
||||
repetition_penalties: torch.Tensor
|
||||
) -> None:
|
||||
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
|
||||
1, logits.size(1))
|
||||
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
|
||||
penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
|
||||
1.0)
|
||||
# If logits are positive, divide by penalty, otherwise multiply by penalty.
|
||||
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
|
||||
logits *= scaling
|
||||
|
||||
Reference in New Issue
Block a user