diff --git a/vllm_kunlun/vllm_utils_wrapper.py b/vllm_kunlun/vllm_utils_wrapper.py index 02f980f..cac5cf4 100644 --- a/vllm_kunlun/vllm_utils_wrapper.py +++ b/vllm_kunlun/vllm_utils_wrapper.py @@ -1887,3 +1887,39 @@ def _fake_quant2d( quant2d.register_fake(_fake_quant2d) + + +################################################## +# --------------- penalties ----------------- +################################################## +@custom_op("_C::apply_repetition_penalties_", mutates_args=()) +def apply_repetition_penalties_( + logits: torch.Tensor, + prompt_mask: torch.Tensor, + output_mask: torch.Tensor, + repetition_penalties: torch.Tensor +) -> None: + repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( + 1, logits.size(1)) + # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. + penalties = torch.where(prompt_mask | output_mask, repetition_penalties, + 1.0) + # If logits are positive, divide by penalty, otherwise multiply by penalty. + scaling = torch.where(logits > 0, 1.0 / penalties, penalties) + logits *= scaling + +@impl("_C::apply_repetition_penalties_", "CUDA") +def apply_repetition_penalties_( + logits: torch.Tensor, + prompt_mask: torch.Tensor, + output_mask: torch.Tensor, + repetition_penalties: torch.Tensor +) -> None: + repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( + 1, logits.size(1)) + # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. + penalties = torch.where(prompt_mask | output_mask, repetition_penalties, + 1.0) + # If logits are positive, divide by penalty, otherwise multiply by penalty. + scaling = torch.where(logits > 0, 1.0 / penalties, penalties) + logits *= scaling