register apply_repetition_penalties_ in custom_op (#110)

* fix qwen2_vl for 0.11.0

* register apply_repetition_penalties_ in custom_op

---------

Co-authored-by: luochencheng <luochencheng@baidu.com>
This commit is contained in:
roger-lcc
2026-01-13 20:22:14 +08:00
committed by GitHub
parent fb424acca7
commit 37cc307322

View File

@@ -1887,3 +1887,39 @@ def _fake_quant2d(
quant2d.register_fake(_fake_quant2d)
##################################################
# --------------- penalties -----------------
##################################################
@custom_op("_C::apply_repetition_penalties_", mutates_args=())
def apply_repetition_penalties_(
logits: torch.Tensor,
prompt_mask: torch.Tensor,
output_mask: torch.Tensor,
repetition_penalties: torch.Tensor
) -> None:
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
1, logits.size(1))
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
logits *= scaling
@impl("_C::apply_repetition_penalties_", "CUDA")
def apply_repetition_penalties_(
logits: torch.Tensor,
prompt_mask: torch.Tensor,
output_mask: torch.Tensor,
repetition_penalties: torch.Tensor
) -> None:
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
1, logits.size(1))
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
logits *= scaling