Files
sglang/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
2025-01-25 21:48:58 -08:00

86 lines
2.9 KiB
Python

from typing import List
import torch
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
from sglang.srt.utils import get_compiler_backend
@torch.compile(dynamic=True, backend=get_compiler_backend())
def apply_scaling_penalties(logits, scaling_penalties):
logits[:] = torch.where(
logits > 0,
logits / scaling_penalties,
logits * scaling_penalties,
)
class BatchedRepetitionPenalizer(_BatchedPenalizer):
"""
Repetition penalizer penalizes tokens based on their repetition in the input and output.
"""
repetition_penalties: torch.Tensor = None
cumulated_repetition_penalties: torch.Tensor = None
def _is_required(self) -> bool:
return any(
req.sampling_params.repetition_penalty != 1.0
for req in self.orchestrator.reqs()
)
def _prepare(self):
self.cumulated_repetition_penalties = (
torch.tensor(
data=[1.0 for _ in self.orchestrator.reqs()],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.repeat(1, self.orchestrator.vocab_size)
)
self.repetition_penalties = (
torch.tensor(
data=[
req.sampling_params.repetition_penalty
for req in self.orchestrator.reqs()
],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.expand_as(self.cumulated_repetition_penalties)
)
def _teardown(self):
self.repetition_penalties = None
self.cumulated_repetition_penalties = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
mask = input_ids.occurrence_count() > 0
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
mask = output_ids.occurrence_count() > 0
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
return logits
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
indices_tensor_to_keep
]
def _merge(self, their: "BatchedRepetitionPenalizer"):
self.repetition_penalties = torch.cat(
[self.repetition_penalties, their.repetition_penalties], dim=0
)
self.cumulated_repetition_penalties = torch.cat(
[self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
dim=0,
)