From 37cc307322c84a94e5afa1758beb3daa97d2583b Mon Sep 17 00:00:00 2001
From: roger-lcc <58332996+roger-lcc@users.noreply.github.com>
Date: Tue, 13 Jan 2026 20:22:14 +0800
Subject: [PATCH] register apply_repetition_penalties_ in custom_op (#110)

* fix qwen2_vl for 0.11.0

* register apply_repetition_penalties_ in custom_op

---------

Co-authored-by: luochencheng
---
 vllm_kunlun/vllm_utils_wrapper.py | 36 +++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/vllm_kunlun/vllm_utils_wrapper.py b/vllm_kunlun/vllm_utils_wrapper.py
index 02f980f..cac5cf4 100644
--- a/vllm_kunlun/vllm_utils_wrapper.py
+++ b/vllm_kunlun/vllm_utils_wrapper.py
@@ -1887,3 +1887,39 @@ def _fake_quant2d(
 
 
 quant2d.register_fake(_fake_quant2d)
+
+
+##################################################
+# --------------- penalties -----------------
+##################################################
+@custom_op("_C::apply_repetition_penalties_", mutates_args=["logits"])
+def apply_repetition_penalties_(
+    logits: torch.Tensor,
+    prompt_mask: torch.Tensor,
+    output_mask: torch.Tensor,
+    repetition_penalties: torch.Tensor
+) -> None:
+    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
+        1, logits.size(1))
+    # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
+    penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
+                            1.0)
+    # If logits are positive, divide by penalty, otherwise multiply by penalty.
+    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
+    logits *= scaling
+
+@impl("_C::apply_repetition_penalties_", "CUDA")
+def apply_repetition_penalties_(
+    logits: torch.Tensor,
+    prompt_mask: torch.Tensor,
+    output_mask: torch.Tensor,
+    repetition_penalties: torch.Tensor
+) -> None:
+    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
+        1, logits.size(1))
+    # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
+    penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
+                            1.0)
+    # If logits are positive, divide by penalty, otherwise multiply by penalty.
+    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
+    logits *= scaling