register apply_repetition_penalties_ in custom_op (#110)
* fix qwen2_vl for 0.11.0 * register apply_repetition_penalties_ in custom_op --------- Co-authored-by: luochencheng <luochencheng@baidu.com>
This commit is contained in:
@@ -1887,3 +1887,39 @@ def _fake_quant2d(
|
|||||||
|
|
||||||
|
|
||||||
# Register _fake_quant2d as the fake (meta) implementation for the quant2d
# custom op, so tracing (e.g. torch.compile / FakeTensor) can infer output
# shapes and dtypes without executing the real kernel.
quant2d.register_fake(_fake_quant2d)
||||||
|
|
||||||
|
|
||||||
|
##################################################
# --------------- penalties -----------------
##################################################
|
@custom_op("_C::apply_repetition_penalties_", mutates_args=())
|
||||||
|
def apply_repetition_penalties_(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
prompt_mask: torch.Tensor,
|
||||||
|
output_mask: torch.Tensor,
|
||||||
|
repetition_penalties: torch.Tensor
|
||||||
|
) -> None:
|
||||||
|
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
|
||||||
|
1, logits.size(1))
|
||||||
|
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
|
||||||
|
penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
|
||||||
|
1.0)
|
||||||
|
# If logits are positive, divide by penalty, otherwise multiply by penalty.
|
||||||
|
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
|
||||||
|
logits *= scaling
|
||||||
|
|
||||||
|
@impl("_C::apply_repetition_penalties_", "CUDA")
def apply_repetition_penalties_(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> None:
    """CUDA registration of the repetition-penalty op.

    Same pure-torch fallback as the default implementation: scales
    ``logits`` in place by the per-sequence ``repetition_penalties`` at
    every vocab position seen in the prompt or in the output so far.
    """
    vocab_size = logits.size(1)
    # Expand the per-sequence penalty to one value per vocab entry.
    per_token = repetition_penalties.unsqueeze(dim=1).repeat(1, vocab_size)
    # Tokens never seen get a 1.0 penalty, which leaves their logit alone.
    seen = prompt_mask | output_mask
    penalties = torch.where(seen, per_token, 1.0)
    # Divide positive logits by the penalty and multiply negative ones,
    # so the adjustment always reduces the token's likelihood.
    scaling = torch.where(logits > 0, penalties.reciprocal(), penalties)
    logits *= scaling
|
||||||
|
|||||||
Reference in New Issue
Block a user