Simplify logits penalizer (#2086)
This commit is contained in:
@@ -1,40 +1,34 @@
|
||||
import abc
|
||||
import dataclasses
|
||||
import typing
|
||||
from typing import List, Set, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _ReqLike:
|
||||
origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]
|
||||
origin_input_ids: List[int]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _BatchLike:
|
||||
reqs: typing.List[_ReqLike]
|
||||
reqs: List[_ReqLike]
|
||||
|
||||
def batch_size(self):
|
||||
return len(self.reqs)
|
||||
|
||||
|
||||
class BatchedPenalizerOrchestrator:
|
||||
batch: _BatchLike
|
||||
device: str
|
||||
vocab_size: int
|
||||
penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
batch: _BatchLike,
|
||||
device: str,
|
||||
Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]],
|
||||
Penalizers: Set[Type["_BatchedPenalizer"]],
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.batch = batch
|
||||
self.device = device
|
||||
|
||||
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
|
||||
|
||||
is_required = False
|
||||
@@ -43,10 +37,12 @@ class BatchedPenalizerOrchestrator:
|
||||
is_required |= pen_is_required
|
||||
self.is_required = is_required
|
||||
|
||||
input_ids = [
|
||||
torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
|
||||
for req in self.reqs()
|
||||
]
|
||||
if self.is_required:
|
||||
self.cumulate_input_tokens(
|
||||
input_ids=[req.origin_input_ids for req in self.reqs()]
|
||||
)
|
||||
self.cumulate_input_tokens(input_ids=input_ids)
|
||||
|
||||
def reqs(self):
|
||||
return self.batch.reqs
|
||||
@@ -54,34 +50,24 @@ class BatchedPenalizerOrchestrator:
|
||||
def batch_size(self):
|
||||
return self.batch.batch_size()
|
||||
|
||||
def cumulate_input_tokens(
|
||||
self,
|
||||
input_ids: typing.Union[
|
||||
typing.List[torch.Tensor], typing.List[typing.List[int]]
|
||||
],
|
||||
):
|
||||
def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
|
||||
"""
|
||||
Feed the input tokens to the penalizers.
|
||||
|
||||
Args:
|
||||
input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.
|
||||
input_ids (List[torch.Tensor]): The input tokens.
|
||||
"""
|
||||
token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
|
||||
|
||||
for penalizer in self.penalizers.values():
|
||||
penalizer.cumulate_input_tokens(input_ids=token_ids)
|
||||
|
||||
def cumulate_output_tokens(
|
||||
self,
|
||||
output_ids: typing.Union[
|
||||
typing.List[torch.Tensor], typing.List[typing.List[int]]
|
||||
],
|
||||
):
|
||||
def cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||
"""
|
||||
Feed the output tokens to the penalizers.
|
||||
|
||||
Args:
|
||||
output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
|
||||
output_ids (torch.Tensor): The output tokens.
|
||||
"""
|
||||
if not self.is_required:
|
||||
return
|
||||
@@ -112,14 +98,14 @@ class BatchedPenalizerOrchestrator:
|
||||
|
||||
def filter(
|
||||
self,
|
||||
indices_to_keep: typing.List[int],
|
||||
indices_to_keep: List[int],
|
||||
indices_tensor_to_keep: torch.Tensor = None,
|
||||
):
|
||||
"""
|
||||
Filter the penalizers based on the indices to keep in the batch.
|
||||
|
||||
Args:
|
||||
indices_to_keep (typing.List[int]): List of indices to keep in the batch.
|
||||
indices_to_keep (List[int]): List of indices to keep in the batch.
|
||||
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
|
||||
"""
|
||||
if not self.is_required:
|
||||
@@ -174,32 +160,18 @@ class _TokenIDs:
|
||||
|
||||
Attributes:
|
||||
orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
|
||||
token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.
|
||||
token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
|
||||
cached_counts (torch.Tensor): The cached occurrence count tensor.
|
||||
"""
|
||||
|
||||
orchestrator: BatchedPenalizerOrchestrator
|
||||
token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
|
||||
cached_counts: torch.Tensor = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
orchestrator: BatchedPenalizerOrchestrator,
|
||||
token_ids: typing.Union[
|
||||
typing.List[torch.Tensor], typing.List[typing.List[int]]
|
||||
],
|
||||
token_ids: Union[torch.Tensor, List[torch.Tensor]],
|
||||
):
|
||||
self.orchestrator = orchestrator
|
||||
|
||||
if not isinstance(token_ids[0], torch.Tensor):
|
||||
token_ids = [
|
||||
torch.tensor(
|
||||
data=ids, dtype=torch.int64, device=self.orchestrator.device
|
||||
)
|
||||
for ids in token_ids
|
||||
]
|
||||
|
||||
self.token_ids = token_ids
|
||||
self.cached_counts = None
|
||||
|
||||
def occurrence_count(self) -> torch.Tensor:
|
||||
"""
|
||||
@@ -213,30 +185,34 @@ class _TokenIDs:
|
||||
|
||||
token_ids = self.token_ids
|
||||
|
||||
if isinstance(token_ids, torch.Tensor):
|
||||
token_ids = token_ids.unsqueeze(1)
|
||||
|
||||
# needs to be long to be used as index in scatter_add
|
||||
if token_ids.dtype != torch.int64:
|
||||
token_ids = token_ids.to(torch.int64)
|
||||
|
||||
padded_token_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=token_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.orchestrator.vocab_size,
|
||||
)
|
||||
|
||||
self.cached_counts = torch.zeros(
|
||||
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
|
||||
dtype=torch.int64,
|
||||
device=self.orchestrator.device,
|
||||
).scatter_add_(
|
||||
dim=1,
|
||||
index=padded_token_ids,
|
||||
src=torch.ones_like(padded_token_ids),
|
||||
)[
|
||||
:, : self.orchestrator.vocab_size
|
||||
]
|
||||
if isinstance(token_ids, list):
|
||||
# TODO: optimize this part
|
||||
padded_token_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=token_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.orchestrator.vocab_size,
|
||||
)
|
||||
self.cached_counts = torch.zeros(
|
||||
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
|
||||
dtype=torch.int64,
|
||||
device=self.orchestrator.device,
|
||||
).scatter_add_(
|
||||
dim=1,
|
||||
index=padded_token_ids,
|
||||
src=torch.ones_like(padded_token_ids),
|
||||
)[
|
||||
:, : self.orchestrator.vocab_size
|
||||
]
|
||||
else:
|
||||
# TODO: optimize this part. We do not need to create this big tensor every time.
|
||||
# We can directly apply the results on the logits.
|
||||
self.cached_counts = torch.zeros(
|
||||
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
self.cached_counts[
|
||||
torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
|
||||
] = 1
|
||||
|
||||
return self.cached_counts
|
||||
|
||||
@@ -246,11 +222,9 @@ class _BatchedPenalizer(abc.ABC):
|
||||
An abstract class for a batched penalizer.
|
||||
"""
|
||||
|
||||
orchestrator: BatchedPenalizerOrchestrator
|
||||
_is_prepared: bool = False
|
||||
|
||||
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
||||
self.orchestrator = orchestrator
|
||||
self._is_prepared = False
|
||||
|
||||
def is_prepared(self) -> bool:
|
||||
return self._is_prepared
|
||||
@@ -293,9 +267,7 @@ class _BatchedPenalizer(abc.ABC):
|
||||
|
||||
return self._apply(logits=logits)
|
||||
|
||||
def filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||
if not self.is_prepared():
|
||||
return
|
||||
|
||||
@@ -360,9 +332,7 @@ class _BatchedPenalizer(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||
"""
|
||||
Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
|
||||
"""
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import typing
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from ..orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
|
||||
|
||||
class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
||||
@@ -44,9 +44,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
del self.frequency_penalties
|
||||
del self.cumulated_frequency_penalties
|
||||
|
||||
self.frequency_penalties = None
|
||||
self.cumulated_frequency_penalties = None
|
||||
|
||||
@@ -62,9 +59,7 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
||||
logits -= self.cumulated_frequency_penalties
|
||||
return logits
|
||||
|
||||
def _filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||
self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
|
||||
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
|
||||
indices_tensor_to_keep
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import typing
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from ..orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
|
||||
|
||||
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
||||
@@ -70,10 +70,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
del self.min_new_tokens
|
||||
del self.stop_token_penalties
|
||||
del self.len_output_tokens
|
||||
|
||||
self.min_new_tokens = None
|
||||
self.stop_token_penalties = None
|
||||
self.len_output_tokens = None
|
||||
@@ -89,9 +85,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
||||
logits[mask] += self.stop_token_penalties[mask]
|
||||
return logits
|
||||
|
||||
def _filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||
self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
|
||||
self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
|
||||
self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import typing
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from ..orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
|
||||
|
||||
class BatchedPresencePenalizer(_BatchedPenalizer):
|
||||
@@ -44,9 +44,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
del self.presence_penalties
|
||||
del self.cumulated_presence_penalties
|
||||
|
||||
self.presence_penalties = None
|
||||
self.cumulated_presence_penalties = None
|
||||
|
||||
@@ -61,9 +58,7 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
|
||||
logits -= self.cumulated_presence_penalties
|
||||
return logits
|
||||
|
||||
def _filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||
self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
|
||||
self.cumulated_presence_penalties = self.cumulated_presence_penalties[
|
||||
indices_tensor_to_keep
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import typing
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from ..orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
|
||||
|
||||
class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
||||
@@ -44,9 +44,6 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
del self.repetition_penalties
|
||||
del self.cumulated_repetition_penalties
|
||||
|
||||
self.repetition_penalties = None
|
||||
self.cumulated_repetition_penalties = None
|
||||
|
||||
@@ -65,9 +62,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
||||
logits * self.cumulated_repetition_penalties,
|
||||
)
|
||||
|
||||
def _filter(
|
||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
||||
):
|
||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
|
||||
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
|
||||
indices_tensor_to_keep
|
||||
|
||||
@@ -27,10 +27,10 @@ class SamplingBatchInfo:
|
||||
|
||||
# Bias Tensors
|
||||
vocab_size: int
|
||||
grammars: Optional[List] = None
|
||||
logit_bias: torch.Tensor = None
|
||||
vocab_mask: Optional[torch.Tensor] = None
|
||||
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
||||
grammars: Optional[List] = None
|
||||
|
||||
# Penalizer
|
||||
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
||||
@@ -211,25 +211,3 @@ class SamplingBatchInfo:
|
||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||
)
|
||||
|
||||
def copy(self):
|
||||
return SamplingBatchInfo(
|
||||
temperatures=self.temperatures,
|
||||
top_ps=self.top_ps,
|
||||
top_ks=self.top_ks,
|
||||
min_ps=self.min_ps,
|
||||
is_all_greedy=self.is_all_greedy,
|
||||
need_min_p_sampling=self.need_min_p_sampling,
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def to(self, device: str):
|
||||
for item in [
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
"top_ks",
|
||||
"min_ps",
|
||||
]:
|
||||
value = getattr(self, item)
|
||||
setattr(self, item, value.to(device, non_blocking=True))
|
||||
|
||||
@@ -24,7 +24,6 @@ class SamplingParams:
|
||||
def __init__(
|
||||
self,
|
||||
max_new_tokens: int = 128,
|
||||
min_new_tokens: int = 0,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
temperature: float = 1.0,
|
||||
@@ -34,6 +33,7 @@ class SamplingParams:
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
repetition_penalty: float = 1.0,
|
||||
min_new_tokens: int = 0,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
n: int = 1,
|
||||
|
||||
Reference in New Issue
Block a user