Simplify logits penalizer (#2086)

This commit is contained in:
Lianmin Zheng
2024-11-18 17:48:28 -08:00
committed by GitHub
parent 3b44bbeecf
commit b110453802
18 changed files with 125 additions and 190 deletions

View File

@@ -1,40 +1,34 @@
import abc
import dataclasses
import typing
from typing import List, Set, Type, Union
import torch
@dataclasses.dataclass
class _ReqLike:
origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]
origin_input_ids: List[int]
@dataclasses.dataclass
class _BatchLike:
reqs: typing.List[_ReqLike]
reqs: List[_ReqLike]
def batch_size(self):
return len(self.reqs)
class BatchedPenalizerOrchestrator:
batch: _BatchLike
device: str
vocab_size: int
penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]
def __init__(
self,
vocab_size: int,
batch: _BatchLike,
device: str,
Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]],
Penalizers: Set[Type["_BatchedPenalizer"]],
):
self.vocab_size = vocab_size
self.batch = batch
self.device = device
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
is_required = False
@@ -43,10 +37,12 @@ class BatchedPenalizerOrchestrator:
is_required |= pen_is_required
self.is_required = is_required
input_ids = [
torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
for req in self.reqs()
]
if self.is_required:
self.cumulate_input_tokens(
input_ids=[req.origin_input_ids for req in self.reqs()]
)
self.cumulate_input_tokens(input_ids=input_ids)
def reqs(self):
return self.batch.reqs
@@ -54,34 +50,24 @@ class BatchedPenalizerOrchestrator:
def batch_size(self):
return self.batch.batch_size()
def cumulate_input_tokens(
self,
input_ids: typing.Union[
typing.List[torch.Tensor], typing.List[typing.List[int]]
],
):
def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
"""
Feed the input tokens to the penalizers.
Args:
input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.
input_ids (List[torch.Tensor]): The input tokens.
"""
token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
for penalizer in self.penalizers.values():
penalizer.cumulate_input_tokens(input_ids=token_ids)
def cumulate_output_tokens(
self,
output_ids: typing.Union[
typing.List[torch.Tensor], typing.List[typing.List[int]]
],
):
def cumulate_output_tokens(self, output_ids: torch.Tensor):
"""
Feed the output tokens to the penalizers.
Args:
output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
output_ids (torch.Tensor): The output tokens.
"""
if not self.is_required:
return
@@ -112,14 +98,14 @@ class BatchedPenalizerOrchestrator:
def filter(
self,
indices_to_keep: typing.List[int],
indices_to_keep: List[int],
indices_tensor_to_keep: torch.Tensor = None,
):
"""
Filter the penalizers based on the indices to keep in the batch.
Args:
indices_to_keep (typing.List[int]): List of indices to keep in the batch.
indices_to_keep (List[int]): List of indices to keep in the batch.
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
"""
if not self.is_required:
@@ -174,32 +160,18 @@ class _TokenIDs:
Attributes:
orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.
token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
cached_counts (torch.Tensor): The cached occurrence count tensor.
"""
orchestrator: BatchedPenalizerOrchestrator
token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
cached_counts: torch.Tensor = None
def __init__(
self,
orchestrator: BatchedPenalizerOrchestrator,
token_ids: typing.Union[
typing.List[torch.Tensor], typing.List[typing.List[int]]
],
token_ids: Union[torch.Tensor, List[torch.Tensor]],
):
self.orchestrator = orchestrator
if not isinstance(token_ids[0], torch.Tensor):
token_ids = [
torch.tensor(
data=ids, dtype=torch.int64, device=self.orchestrator.device
)
for ids in token_ids
]
self.token_ids = token_ids
self.cached_counts = None
def occurrence_count(self) -> torch.Tensor:
"""
@@ -213,30 +185,34 @@ class _TokenIDs:
token_ids = self.token_ids
if isinstance(token_ids, torch.Tensor):
token_ids = token_ids.unsqueeze(1)
# needs to be long to be used as index in scatter_add
if token_ids.dtype != torch.int64:
token_ids = token_ids.to(torch.int64)
padded_token_ids = torch.nn.utils.rnn.pad_sequence(
sequences=token_ids,
batch_first=True,
padding_value=self.orchestrator.vocab_size,
)
self.cached_counts = torch.zeros(
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
dtype=torch.int64,
device=self.orchestrator.device,
).scatter_add_(
dim=1,
index=padded_token_ids,
src=torch.ones_like(padded_token_ids),
)[
:, : self.orchestrator.vocab_size
]
if isinstance(token_ids, list):
# TODO: optimize this part
padded_token_ids = torch.nn.utils.rnn.pad_sequence(
sequences=token_ids,
batch_first=True,
padding_value=self.orchestrator.vocab_size,
)
self.cached_counts = torch.zeros(
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
dtype=torch.int64,
device=self.orchestrator.device,
).scatter_add_(
dim=1,
index=padded_token_ids,
src=torch.ones_like(padded_token_ids),
)[
:, : self.orchestrator.vocab_size
]
else:
# TODO: optimize this part. We do not need to create this big tensor every time.
# We can directly apply the results on the logits.
self.cached_counts = torch.zeros(
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
device=self.orchestrator.device,
)
self.cached_counts[
torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
] = 1
return self.cached_counts
@@ -246,11 +222,9 @@ class _BatchedPenalizer(abc.ABC):
An abstract class for a batched penalizer.
"""
orchestrator: BatchedPenalizerOrchestrator
_is_prepared: bool = False
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
self.orchestrator = orchestrator
self._is_prepared = False
def is_prepared(self) -> bool:
return self._is_prepared
@@ -293,9 +267,7 @@ class _BatchedPenalizer(abc.ABC):
return self._apply(logits=logits)
def filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
if not self.is_prepared():
return
@@ -360,9 +332,7 @@ class _BatchedPenalizer(abc.ABC):
pass
@abc.abstractmethod
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
"""
Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
"""

View File

@@ -1,8 +1,8 @@
import typing
from typing import List
import torch
from ..orchestrator import _BatchedPenalizer, _TokenIDs
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
class BatchedFrequencyPenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
)
def _teardown(self):
del self.frequency_penalties
del self.cumulated_frequency_penalties
self.frequency_penalties = None
self.cumulated_frequency_penalties = None
@@ -62,9 +59,7 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
logits -= self.cumulated_frequency_penalties
return logits
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
indices_tensor_to_keep

View File

@@ -1,8 +1,8 @@
import typing
from typing import List
import torch
from ..orchestrator import _BatchedPenalizer, _TokenIDs
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
@@ -70,10 +70,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
)
def _teardown(self):
del self.min_new_tokens
del self.stop_token_penalties
del self.len_output_tokens
self.min_new_tokens = None
self.stop_token_penalties = None
self.len_output_tokens = None
@@ -89,9 +85,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
logits[mask] += self.stop_token_penalties[mask]
return logits
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]

View File

@@ -1,8 +1,8 @@
import typing
from typing import List
import torch
from ..orchestrator import _BatchedPenalizer, _TokenIDs
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
class BatchedPresencePenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
)
def _teardown(self):
del self.presence_penalties
del self.cumulated_presence_penalties
self.presence_penalties = None
self.cumulated_presence_penalties = None
@@ -61,9 +58,7 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
logits -= self.cumulated_presence_penalties
return logits
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
self.cumulated_presence_penalties = self.cumulated_presence_penalties[
indices_tensor_to_keep

View File

@@ -1,8 +1,8 @@
import typing
from typing import List
import torch
from ..orchestrator import _BatchedPenalizer, _TokenIDs
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
)
def _teardown(self):
del self.repetition_penalties
del self.cumulated_repetition_penalties
self.repetition_penalties = None
self.cumulated_repetition_penalties = None
@@ -65,9 +62,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
logits * self.cumulated_repetition_penalties,
)
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
indices_tensor_to_keep

View File

@@ -27,10 +27,10 @@ class SamplingBatchInfo:
# Bias Tensors
vocab_size: int
grammars: Optional[List] = None
logit_bias: torch.Tensor = None
vocab_mask: Optional[torch.Tensor] = None
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
grammars: Optional[List] = None
# Penalizer
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -211,25 +211,3 @@ class SamplingBatchInfo:
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other), self.device
)
def copy(self):
return SamplingBatchInfo(
temperatures=self.temperatures,
top_ps=self.top_ps,
top_ks=self.top_ks,
min_ps=self.min_ps,
is_all_greedy=self.is_all_greedy,
need_min_p_sampling=self.need_min_p_sampling,
vocab_size=self.vocab_size,
device=self.device,
)
def to(self, device: str):
for item in [
"temperatures",
"top_ps",
"top_ks",
"min_ps",
]:
value = getattr(self, item)
setattr(self, item, value.to(device, non_blocking=True))

View File

@@ -24,7 +24,6 @@ class SamplingParams:
def __init__(
self,
max_new_tokens: int = 128,
min_new_tokens: int = 0,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0,
@@ -34,6 +33,7 @@ class SamplingParams:
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
repetition_penalty: float = 1.0,
min_new_tokens: int = 0,
spaces_between_special_tokens: bool = True,
regex: Optional[str] = None,
n: int = 1,