Files
sglang/python/sglang/srt/sampling/penaltylib/orchestrator.py

358 lines
11 KiB
Python

import abc
import dataclasses
import typing
import torch
@dataclasses.dataclass
class _ReqLike:
    """Minimal request shape the orchestrator reads: just the prompt token IDs."""

    # Prompt token IDs for one request, as a tensor or a plain list of ints.
    origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]
@dataclasses.dataclass
class _BatchLike:
    """Minimal batch shape the orchestrator reads: an ordered list of requests."""

    # Requests in batch order.
    reqs: typing.List[_ReqLike]

    def batch_size(self):
        """Return the number of requests in the batch."""
        num_reqs = len(self.reqs)
        return num_reqs
class BatchedPenalizerOrchestrator:
    """
    Owns one instance of each penalizer class for a batch and fans every
    operation (token feeding, logit application, filtering, merging) out to them.

    Attributes:
        batch: the batch whose ``reqs`` the penalizers are built from.
        device: device on which this class creates new index / token tensors.
        vocab_size: vocabulary size the penalizers operate over.
        penalizers: one penalizer instance per penalizer class, keyed by that class.
    """

    batch: "_BatchLike"
    device: str
    vocab_size: int
    penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]

    def __init__(
        self,
        vocab_size: int,
        batch: "_BatchLike",
        device: str,
        Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]],
    ):
        """
        Instantiate every penalizer class, prepare those that report they are
        required, then feed each request's ``origin_input_ids`` to them.
        """
        self.vocab_size = vocab_size
        self.batch = batch
        self.device = device
        self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}

        for penalizer in self.penalizers.values():
            penalizer.prepare_if_required()

        self.cumulate_input_tokens(
            input_ids=[req.origin_input_ids for req in self.reqs()]
        )

    def reqs(self):
        """Return the requests of the underlying batch."""
        return self.batch.reqs

    def batch_size(self):
        """Return the size of the underlying batch."""
        return self.batch.batch_size()

    def _any_prepared(self) -> bool:
        """Return True if at least one penalizer currently holds prepared state."""
        return any(penalizer.is_prepared() for penalizer in self.penalizers.values())

    def cumulate_input_tokens(
        self,
        input_ids: typing.Union[
            torch.Tensor, typing.List[torch.Tensor], typing.List[typing.List[int]]
        ],
    ):
        """
        Feed the input tokens to the penalizers.

        Args:
            input_ids: the input tokens, one sequence per request.
        """
        # Wrapping in _TokenIDs may copy every sequence into a tensor on
        # self.device. Unprepared penalizers ignore the input anyway, and an
        # empty batch has nothing to feed (and would break _TokenIDs), so skip.
        if len(input_ids) == 0 or not self._any_prepared():
            return
        token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
        for penalizer in self.penalizers.values():
            penalizer.cumulate_input_tokens(input_ids=token_ids)

    def cumulate_output_tokens(
        self,
        output_ids: typing.Union[
            torch.Tensor, typing.List[torch.Tensor], typing.List[typing.List[int]]
        ],
    ):
        """
        Feed the output tokens to the penalizers.

        Args:
            output_ids: the output tokens, either one sequence per request or a
                tensor holding one token per request.
        """
        # Same short-circuit as cumulate_input_tokens.
        if len(output_ids) == 0 or not self._any_prepared():
            return
        token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
        for penalizer in self.penalizers.values():
            penalizer.cumulate_output_tokens(output_ids=token_ids)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply the penalizers to the logits, in dict order.

        Note that penalizers may modify the logits in-place.

        Args:
            logits (torch.Tensor): The logits to apply the penalizers to.

        Returns:
            torch.Tensor: The logits after applying the penalizers.
        """
        for penalizer in self.penalizers.values():
            logits = penalizer.apply(logits)
        return logits

    def filter(
        self,
        indices_to_keep: typing.List[int],
        indices_tensor_to_keep: torch.Tensor = None,
    ):
        """
        Filter the penalizers based on the indices to keep in the batch.

        A penalizer is torn down instead of filtered when it reports it is not
        required, or when no index is kept.

        Args:
            indices_to_keep (typing.List[int]): List of indices to keep in the batch.
            indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep
                in the batch. If not None, it is used instead of converting
                indices_to_keep to a tensor.
        """
        empty_indices = len(indices_to_keep) == 0
        for penalizer in self.penalizers.values():
            if not penalizer.is_required() or empty_indices:
                penalizer.teardown()
                continue
            # Build the index tensor lazily, at most once, only if some
            # penalizer actually needs it.
            if indices_tensor_to_keep is None:
                indices_tensor_to_keep = torch.tensor(
                    indices_to_keep, dtype=torch.int32, device=self.device
                )
            penalizer.filter(
                indices_to_keep=indices_to_keep,
                indices_tensor_to_keep=indices_tensor_to_keep,
            )

    def merge(self, their: "BatchedPenalizerOrchestrator"):
        """
        Merge the penalizers of another orchestrator into this one.

        This **must** be called _before_ self.batch.reqs is updated (merged with
        their requests). An unprepared penalizer has to be prepared (creating
        tensors, etc.) before merging, and that step needs the original,
        pre-merge batch.reqs.

        Args:
            their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.

        Raises:
            ValueError: if the vocab sizes differ, or the two orchestrators do not
                hold the same penalizer classes. All checks run before any
                penalizer is modified, so a failed merge leaves both sides intact.
        """
        if self.vocab_size != their.vocab_size:
            raise ValueError(
                f"vocab_size mismatch: {self.vocab_size} != {their.vocab_size}"
            )

        for Penalizer in their.penalizers:
            if Penalizer not in self.penalizers:
                raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
        # A penalizer present only on our side would stay sized for our batch
        # alone after the batches are combined; reject that instead.
        for Penalizer in self.penalizers:
            if Penalizer not in their.penalizers:
                raise ValueError(
                    f"Penalizer {Penalizer} not found in their.penalizers"
                )

        for Penalizer, their_penalizer in their.penalizers.items():
            self.penalizers[Penalizer].merge(their_penalizer)
class _TokenIDs:
    """
    Wraps a batch of token IDs to provide shared utility functions to penalizers.

    Attributes:
        orchestrator (BatchedPenalizerOrchestrator): The orchestrator these token
            IDs belong to; supplies ``device``, ``vocab_size`` and ``batch_size()``.
        token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): Either a
            single tensor (one token per request) or a list of 1-D tensors (one
            sequence per request).
        cached_counts (torch.Tensor): Lazily computed result of ``occurrence_count``.
    """

    orchestrator: "BatchedPenalizerOrchestrator"
    token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
    cached_counts: torch.Tensor = None

    def __init__(
        self,
        orchestrator: "BatchedPenalizerOrchestrator",
        token_ids: typing.Union[
            torch.Tensor, typing.List[torch.Tensor], typing.List[typing.List[int]]
        ],
    ):
        self.orchestrator = orchestrator
        # Lists of Python ints become int64 tensors on the orchestrator's device.
        # Tensors (or lists of tensors) are kept as-is. The length check avoids
        # indexing element 0 of an empty batch.
        if len(token_ids) > 0 and not isinstance(token_ids[0], torch.Tensor):
            token_ids = [
                torch.tensor(
                    data=ids, dtype=torch.int64, device=self.orchestrator.device
                )
                for ids in token_ids
            ]
        self.token_ids = token_ids

    def occurrence_count(self) -> torch.Tensor:
        """
        Count how often each token appears in each request's sequence.

        The result is cached, so repeated calls return the same tensor.

        Returns:
            torch.Tensor: An int64 tensor of shape (batch_size, vocab_size) whose
            element [i, t] is the number of times token t appears for request i.
        """
        if self.cached_counts is not None:
            return self.cached_counts

        vocab_size = self.orchestrator.vocab_size
        batch_size = self.orchestrator.batch_size()
        device = self.orchestrator.device

        token_ids = self.token_ids
        if len(token_ids) == 0:
            # Nothing to count; pad_sequence would reject an empty input.
            self.cached_counts = torch.zeros(
                size=(batch_size, vocab_size), dtype=torch.int64, device=device
            )
            return self.cached_counts

        if isinstance(token_ids, torch.Tensor):
            # One token per request: make each row a length-1 sequence.
            token_ids = token_ids.unsqueeze(1)

        # Pad with an out-of-vocab id (vocab_size). Its hits land in an extra
        # column that is sliced away below.
        padded_token_ids = torch.nn.utils.rnn.pad_sequence(
            sequences=token_ids,
            batch_first=True,
            padding_value=vocab_size,
        )
        # scatter_add_ needs an int64 index. Cast after padding so both the
        # tensor path and the list-of-tensors path are covered.
        if padded_token_ids.dtype != torch.int64:
            padded_token_ids = padded_token_ids.to(torch.int64)

        counts = torch.zeros(
            size=(batch_size, vocab_size + 1),
            dtype=torch.int64,
            device=device,
        )
        counts.scatter_add_(
            dim=1,
            index=padded_token_ids,
            src=torch.ones_like(padded_token_ids),
        )
        self.cached_counts = counts[:, :vocab_size]
        return self.cached_counts
class _BatchedPenalizer(abc.ABC):
    """
    Abstract base for a penalizer that operates on a whole batch at once.

    The public methods implement a prepared / unprepared lifecycle and forward
    the real work to the ``_``-prefixed abstract hooks that subclasses provide.
    While unprepared, token feeding and ``filter`` do nothing, and ``apply``
    returns the logits untouched.
    """

    orchestrator: BatchedPenalizerOrchestrator
    # Class-level default; prepare()/teardown() store the per-instance value.
    _is_prepared: bool = False

    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
        self.orchestrator = orchestrator

    def is_prepared(self) -> bool:
        """Whether prepare() has run since construction or the last teardown()."""
        return self._is_prepared

    def is_required(self) -> bool:
        """Delegate to the subclass's _is_required() hook."""
        return self._is_required()

    def prepare(self):
        """Run _prepare() and mark as prepared; no-op if already prepared."""
        if self.is_prepared():
            return
        self._prepare()
        self._is_prepared = True

    def prepare_if_required(self):
        """Prepare only when is_required() reports true."""
        if not self.is_required():
            return
        self.prepare()

    def teardown(self):
        """Run _teardown() and mark as unprepared; no-op if not prepared."""
        if not self.is_prepared():
            return
        self._teardown()
        self._is_prepared = False

    def cumulate_input_tokens(self, input_ids: _TokenIDs):
        """Forward input tokens to _cumulate_input_tokens() when prepared."""
        if self.is_prepared():
            self._cumulate_input_tokens(input_ids=input_ids)

    def cumulate_output_tokens(self, output_ids: _TokenIDs):
        """Forward output tokens to _cumulate_output_tokens() when prepared."""
        if self.is_prepared():
            self._cumulate_output_tokens(output_ids=output_ids)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Return _apply(logits) when prepared, else the logits unchanged."""
        return self._apply(logits=logits) if self.is_prepared() else logits

    def filter(
        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
    ):
        """Forward the kept indices to _filter() when prepared."""
        if self.is_prepared():
            self._filter(
                indices_to_keep=indices_to_keep,
                indices_tensor_to_keep=indices_tensor_to_keep,
            )

    def merge(self, their: "_BatchedPenalizer"):
        """
        Merge another penalizer of the same kind into this one.

        Does nothing when neither side is prepared. Otherwise, both sides are
        prepared first so that _merge() always sees initialized state.
        """
        if self.is_prepared() or their.is_prepared():
            self.prepare()
            their.prepare()
            self._merge(their)

    @abc.abstractmethod
    def _is_required(self) -> bool:
        """
        Report whether the penalizer needs to be prepared for the current batch.
        """

    @abc.abstractmethod
    def _prepare(self):
        """
        Build the penalizer's state; typically where its tensors are created.
        """

    @abc.abstractmethod
    def _teardown(self):
        """
        Release the penalizer's state; typically where its tensors are freed.
        """

    @abc.abstractmethod
    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
        """
        Accumulate input tokens; called by the orchestrator with the batch's input IDs.
        """

    @abc.abstractmethod
    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
        """
        Accumulate output tokens; called by the orchestrator with newly produced IDs.
        """

    @abc.abstractmethod
    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply the penalty to the logits and return them; in-place edits are allowed.
        """

    @abc.abstractmethod
    def _filter(
        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
    ):
        """
        Keep only the given batch rows in the penalizer's tensors / underlying data.
        """

    @abc.abstractmethod
    def _merge(self, their: "_BatchedPenalizer"):
        """
        Combine another penalizer's state into this one.
        """