Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
@@ -1,13 +1,11 @@
|
||||
from .orchestrator import BatchedPenalizerOrchestrator
|
||||
from .penalizers.frequency_penalty import BatchedFrequencyPenalizer
|
||||
from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer
|
||||
from .penalizers.presence_penalty import BatchedPresencePenalizer
|
||||
from .penalizers.repetition_penalty import BatchedRepetitionPenalizer
|
||||
from sglang.srt.sampling.penaltylib.frequency_penalty import BatchedFrequencyPenalizer
|
||||
from sglang.srt.sampling.penaltylib.min_new_tokens import BatchedMinNewTokensPenalizer
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import BatchedPenalizerOrchestrator
|
||||
from sglang.srt.sampling.penaltylib.presence_penalty import BatchedPresencePenalizer
|
||||
|
||||
__all__ = [
|
||||
"BatchedFrequencyPenalizer",
|
||||
"BatchedMinNewTokensPenalizer",
|
||||
"BatchedPresencePenalizer",
|
||||
"BatchedRepetitionPenalizer",
|
||||
"BatchedPenalizerOrchestrator",
|
||||
]
|
||||
|
||||
66
python/sglang/srt/sampling/penaltylib/frequency_penalty.py
Normal file
66
python/sglang/srt/sampling/penaltylib/frequency_penalty.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import (
|
||||
BatchedPenalizerOrchestrator,
|
||||
_BatchedPenalizer,
|
||||
)
|
||||
|
||||
|
||||
class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
||||
"""
|
||||
Frequency penalizer penalizes tokens based on their frequency in the output.
|
||||
"""
|
||||
|
||||
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
||||
self.orchestrator = orchestrator
|
||||
self._is_prepared = False
|
||||
|
||||
def _is_required(self) -> bool:
|
||||
return any(
|
||||
req.sampling_params.frequency_penalty != 0.0
|
||||
for req in self.orchestrator.reqs()
|
||||
)
|
||||
|
||||
def _prepare(self):
|
||||
self.cumulated_frequency_penalties = torch.zeros(
|
||||
(len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
|
||||
self.frequency_penalties = (
|
||||
torch.tensor(
|
||||
data=[
|
||||
req.sampling_params.frequency_penalty
|
||||
for req in self.orchestrator.reqs()
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
).unsqueeze_(1)
|
||||
|
||||
def _cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||
self.cumulated_frequency_penalties.scatter_add_(
|
||||
dim=1,
|
||||
index=output_ids.unsqueeze(1),
|
||||
src=self.frequency_penalties,
|
||||
)
|
||||
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
logits.sub_(self.cumulated_frequency_penalties)
|
||||
|
||||
def _filter(self, keep_indices: torch.Tensor):
|
||||
self.frequency_penalties = self.frequency_penalties[keep_indices]
|
||||
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
|
||||
keep_indices
|
||||
]
|
||||
|
||||
def _merge(self, their: "BatchedFrequencyPenalizer"):
|
||||
print(f"{self.frequency_penalties.shape=}, {their.frequency_penalties.shape=}")
|
||||
self.frequency_penalties = torch.cat(
|
||||
[self.frequency_penalties, their.frequency_penalties], dim=0
|
||||
)
|
||||
self.cumulated_frequency_penalties = torch.cat(
|
||||
[self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
|
||||
dim=0,
|
||||
)
|
||||
@@ -1,8 +1,9 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import (
|
||||
BatchedPenalizerOrchestrator,
|
||||
_BatchedPenalizer,
|
||||
)
|
||||
|
||||
|
||||
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
||||
@@ -10,9 +11,9 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
||||
Min new tokens penalizer penalizes tokens based on the length of the output.
|
||||
"""
|
||||
|
||||
min_new_tokens: torch.Tensor = None
|
||||
stop_token_penalties: torch.Tensor = None
|
||||
len_output_tokens: torch.Tensor = None
|
||||
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
||||
self.orchestrator = orchestrator
|
||||
self._is_prepared = False
|
||||
|
||||
def _is_required(self) -> bool:
|
||||
return any(
|
||||
@@ -47,7 +48,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
||||
padding_value=self.orchestrator.vocab_size,
|
||||
)
|
||||
self.stop_token_penalties = torch.zeros(
|
||||
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
|
||||
size=(len(self.orchestrator.reqs()), self.orchestrator.vocab_size + 1),
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
).scatter_add_(
|
||||
@@ -64,31 +65,22 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
||||
]
|
||||
|
||||
self.len_output_tokens = torch.zeros(
|
||||
size=(self.orchestrator.batch_size(), 1),
|
||||
size=(len(self.orchestrator.reqs()), 1),
|
||||
dtype=torch.int32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
self.min_new_tokens = None
|
||||
self.stop_token_penalties = None
|
||||
self.len_output_tokens = None
|
||||
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
pass
|
||||
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
def _cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||
self.len_output_tokens += 1
|
||||
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
def _apply(self, logits: torch.Tensor):
|
||||
mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
|
||||
logits[mask] += self.stop_token_penalties[mask]
|
||||
return logits
|
||||
|
||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||
self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
|
||||
self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
|
||||
self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
|
||||
def _filter(self, keep_indices: torch.Tensor):
|
||||
self.min_new_tokens = self.min_new_tokens[keep_indices]
|
||||
self.stop_token_penalties = self.stop_token_penalties[keep_indices]
|
||||
self.len_output_tokens = self.len_output_tokens[keep_indices]
|
||||
|
||||
def _merge(self, their: "BatchedMinNewTokensPenalizer"):
|
||||
self.min_new_tokens = torch.cat(
|
||||
@@ -1,35 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import dataclasses
|
||||
from typing import List, Set, Type, Union
|
||||
from typing import TYPE_CHECKING, Set, Type
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _ReqLike:
|
||||
origin_input_ids: List[int]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _BatchLike:
|
||||
reqs: List[_ReqLike]
|
||||
|
||||
def batch_size(self):
|
||||
return len(self.reqs)
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
|
||||
|
||||
class BatchedPenalizerOrchestrator:
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
batch: _BatchLike,
|
||||
device: str,
|
||||
Penalizers: Set[Type["_BatchedPenalizer"]],
|
||||
batch: ScheduleBatch,
|
||||
penalizers: Set[Type["_BatchedPenalizer"]],
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.batch = batch
|
||||
self.device = device
|
||||
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
|
||||
self.device = batch.device
|
||||
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}
|
||||
|
||||
is_required = False
|
||||
for penalizer in self.penalizers.values():
|
||||
@@ -37,31 +27,9 @@ class BatchedPenalizerOrchestrator:
|
||||
is_required |= pen_is_required
|
||||
self.is_required = is_required
|
||||
|
||||
input_ids = [
|
||||
torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
|
||||
for req in self.reqs()
|
||||
]
|
||||
if self.is_required:
|
||||
self.cumulate_input_tokens(input_ids=input_ids)
|
||||
|
||||
def reqs(self):
|
||||
return self.batch.reqs
|
||||
|
||||
def batch_size(self):
|
||||
return self.batch.batch_size()
|
||||
|
||||
def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
|
||||
"""
|
||||
Feed the input tokens to the penalizers.
|
||||
|
||||
Args:
|
||||
input_ids (List[torch.Tensor]): The input tokens.
|
||||
"""
|
||||
token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
|
||||
|
||||
for penalizer in self.penalizers.values():
|
||||
penalizer.cumulate_input_tokens(input_ids=token_ids)
|
||||
|
||||
def cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||
"""
|
||||
Feed the output tokens to the penalizers.
|
||||
@@ -69,13 +37,8 @@ class BatchedPenalizerOrchestrator:
|
||||
Args:
|
||||
output_ids (torch.Tensor): The output tokens.
|
||||
"""
|
||||
if not self.is_required:
|
||||
return
|
||||
|
||||
token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
|
||||
|
||||
for penalizer in self.penalizers.values():
|
||||
penalizer.cumulate_output_tokens(output_ids=token_ids)
|
||||
penalizer.cumulate_output_tokens(output_ids=output_ids)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
@@ -88,48 +51,33 @@ class BatchedPenalizerOrchestrator:
|
||||
Returns:
|
||||
torch.Tensor: The logits after applying the penalizers.
|
||||
"""
|
||||
if not self.is_required:
|
||||
return
|
||||
|
||||
for penalizer in self.penalizers.values():
|
||||
logits = penalizer.apply(logits)
|
||||
penalizer.apply(logits)
|
||||
|
||||
return logits
|
||||
|
||||
def filter(
|
||||
self,
|
||||
indices_to_keep: List[int],
|
||||
indices_tensor_to_keep: torch.Tensor = None,
|
||||
):
|
||||
def filter(self, keep_indices: torch.Tensor):
|
||||
"""
|
||||
Filter the penalizers based on the indices to keep in the batch.
|
||||
|
||||
Args:
|
||||
indices_to_keep (List[int]): List of indices to keep in the batch.
|
||||
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
|
||||
keep_indices (torch.Tensor): Tensor of indices to keep in the batch.
|
||||
"""
|
||||
if not self.is_required:
|
||||
return
|
||||
|
||||
empty_indices = len(indices_to_keep) == 0
|
||||
if len(keep_indices) == 0:
|
||||
self.is_required = False
|
||||
for penalizer in self.penalizers.values():
|
||||
penalizer.teardown()
|
||||
return
|
||||
|
||||
is_required = False
|
||||
for penalizer in self.penalizers.values():
|
||||
tmp_is_required = penalizer.is_required()
|
||||
is_required = is_required or tmp_is_required
|
||||
if not tmp_is_required or empty_indices:
|
||||
penalizer.teardown()
|
||||
is_required |= tmp_is_required
|
||||
if tmp_is_required:
|
||||
penalizer.filter(keep_indices=keep_indices)
|
||||
else:
|
||||
# create tensor index only when it's needed
|
||||
if indices_tensor_to_keep is None:
|
||||
indices_tensor_to_keep = torch.tensor(
|
||||
indices_to_keep, dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
penalizer.filter(
|
||||
indices_to_keep=indices_to_keep,
|
||||
indices_tensor_to_keep=indices_tensor_to_keep,
|
||||
)
|
||||
penalizer.teardown()
|
||||
self.is_required = is_required
|
||||
|
||||
def merge(self, their: "BatchedPenalizerOrchestrator"):
|
||||
@@ -146,75 +94,9 @@ class BatchedPenalizerOrchestrator:
|
||||
if not self.is_required and not their.is_required:
|
||||
return
|
||||
|
||||
self.is_required |= their.is_required
|
||||
for Penalizer, their_penalizer in their.penalizers.items():
|
||||
if Penalizer not in self.penalizers:
|
||||
raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
|
||||
|
||||
self.penalizers[Penalizer].merge(their_penalizer)
|
||||
|
||||
|
||||
class _TokenIDs:
|
||||
"""
|
||||
A class that wraps token IDs to provide additional utility functions to penalizers.
|
||||
|
||||
Attributes:
|
||||
orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
|
||||
token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
|
||||
cached_counts (torch.Tensor): The cached occurrence count tensor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
orchestrator: BatchedPenalizerOrchestrator,
|
||||
token_ids: Union[torch.Tensor, List[torch.Tensor]],
|
||||
):
|
||||
self.orchestrator = orchestrator
|
||||
self.token_ids = token_ids
|
||||
self.cached_counts = None
|
||||
|
||||
def occurrence_count(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The occurrence count tensor.
|
||||
"""
|
||||
if self.cached_counts is not None:
|
||||
return self.cached_counts
|
||||
|
||||
token_ids = self.token_ids
|
||||
|
||||
if isinstance(token_ids, list):
|
||||
# TODO: optimize this part
|
||||
padded_token_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
sequences=token_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.orchestrator.vocab_size,
|
||||
)
|
||||
self.cached_counts = torch.zeros(
|
||||
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
|
||||
dtype=torch.int64,
|
||||
device=self.orchestrator.device,
|
||||
).scatter_add_(
|
||||
dim=1,
|
||||
index=padded_token_ids,
|
||||
src=torch.ones_like(padded_token_ids),
|
||||
)[
|
||||
:, : self.orchestrator.vocab_size
|
||||
]
|
||||
else:
|
||||
# TODO: optimize this part. We do not need to create this big tensor every time.
|
||||
# We can directly apply the results on the logits.
|
||||
self.cached_counts = torch.zeros(
|
||||
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
self.cached_counts[
|
||||
torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
|
||||
] = 1
|
||||
|
||||
return self.cached_counts
|
||||
self.is_required = True
|
||||
for penalizer, their_penalizer in their.penalizers.items():
|
||||
self.penalizers[penalizer].merge(their_penalizer)
|
||||
|
||||
|
||||
class _BatchedPenalizer(abc.ABC):
|
||||
@@ -222,10 +104,6 @@ class _BatchedPenalizer(abc.ABC):
|
||||
An abstract class for a batched penalizer.
|
||||
"""
|
||||
|
||||
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
||||
self.orchestrator = orchestrator
|
||||
self._is_prepared = False
|
||||
|
||||
def is_prepared(self) -> bool:
|
||||
return self._is_prepared
|
||||
|
||||
@@ -233,51 +111,40 @@ class _BatchedPenalizer(abc.ABC):
|
||||
return self._is_required()
|
||||
|
||||
def prepare(self):
|
||||
if not self.is_prepared():
|
||||
if not self._is_prepared:
|
||||
self._prepare()
|
||||
self._is_prepared = True
|
||||
|
||||
def prepare_if_required(self):
|
||||
if self.is_required():
|
||||
if self._is_required():
|
||||
self.prepare()
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def teardown(self):
|
||||
if self.is_prepared():
|
||||
self._teardown()
|
||||
self._is_prepared = False
|
||||
self._is_prepared = False
|
||||
|
||||
def cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
if not self.is_prepared():
|
||||
return
|
||||
|
||||
self._cumulate_input_tokens(input_ids=input_ids)
|
||||
|
||||
def cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
if not self.is_prepared():
|
||||
def cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||
if not self._is_prepared:
|
||||
return
|
||||
|
||||
self._cumulate_output_tokens(output_ids=output_ids)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if not self.is_prepared():
|
||||
return logits
|
||||
|
||||
return self._apply(logits=logits)
|
||||
|
||||
def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||
if not self.is_prepared():
|
||||
if not self._is_prepared:
|
||||
return
|
||||
|
||||
self._filter(
|
||||
indices_to_keep=indices_to_keep,
|
||||
indices_tensor_to_keep=indices_tensor_to_keep,
|
||||
)
|
||||
self._apply(logits=logits)
|
||||
|
||||
def filter(self, keep_indices: torch.Tensor):
|
||||
if not self._is_prepared:
|
||||
return
|
||||
|
||||
self._filter(keep_indices=keep_indices)
|
||||
|
||||
def merge(self, their: "_BatchedPenalizer"):
|
||||
if not self.is_prepared() and not their.is_prepared():
|
||||
if not self._is_prepared and not their._is_prepared:
|
||||
return
|
||||
|
||||
self.prepare()
|
||||
@@ -300,23 +167,7 @@ class _BatchedPenalizer(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _teardown(self):
|
||||
"""
|
||||
Tear down the penalizer.
|
||||
Usually, this is where the penalizer frees its tensors.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
"""
|
||||
Cumulate the input tokens.
|
||||
Orchestrator will call this function to feed the input tokens to the penalizer.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
def _cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||
"""
|
||||
Cumulate the output tokens.
|
||||
Orchestrator will call this function to feed the output tokens to the penalizer.
|
||||
@@ -332,7 +183,7 @@ class _BatchedPenalizer(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||
def _filter(self, keep_indices: torch.Tensor):
|
||||
"""
|
||||
Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
|
||||
"""
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
|
||||
|
||||
class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
||||
"""
|
||||
Frequency penalizer penalizes tokens based on their frequency in the output.
|
||||
"""
|
||||
|
||||
frequency_penalties: torch.Tensor = None
|
||||
cumulated_frequency_penalties: torch.Tensor = None
|
||||
|
||||
def _is_required(self) -> bool:
|
||||
return any(
|
||||
req.sampling_params.frequency_penalty != 0.0
|
||||
for req in self.orchestrator.reqs()
|
||||
)
|
||||
|
||||
def _prepare(self):
|
||||
self.cumulated_frequency_penalties = (
|
||||
torch.tensor(
|
||||
data=[0.0 for _ in self.orchestrator.reqs()],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.repeat(1, self.orchestrator.vocab_size)
|
||||
)
|
||||
|
||||
self.frequency_penalties = (
|
||||
torch.tensor(
|
||||
data=[
|
||||
req.sampling_params.frequency_penalty
|
||||
for req in self.orchestrator.reqs()
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.expand_as(self.cumulated_frequency_penalties)
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
self.frequency_penalties = None
|
||||
self.cumulated_frequency_penalties = None
|
||||
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
pass
|
||||
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
self.cumulated_frequency_penalties += (
|
||||
self.frequency_penalties * output_ids.occurrence_count()
|
||||
)
|
||||
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
logits -= self.cumulated_frequency_penalties
|
||||
return logits
|
||||
|
||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||
self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
|
||||
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
|
||||
indices_tensor_to_keep
|
||||
]
|
||||
|
||||
def _merge(self, their: "BatchedFrequencyPenalizer"):
|
||||
self.frequency_penalties = torch.cat(
|
||||
[self.frequency_penalties, their.frequency_penalties], dim=0
|
||||
)
|
||||
self.cumulated_frequency_penalties = torch.cat(
|
||||
[self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
|
||||
dim=0,
|
||||
)
|
||||
@@ -1,74 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
|
||||
|
||||
class BatchedPresencePenalizer(_BatchedPenalizer):
|
||||
"""
|
||||
Presence penalizer penalizes tokens based on their presence in the output.
|
||||
"""
|
||||
|
||||
presence_penalties: torch.Tensor = None
|
||||
cumulated_presence_penalties: torch.Tensor = None
|
||||
|
||||
def _is_required(self) -> bool:
|
||||
return any(
|
||||
req.sampling_params.presence_penalty != 0.0
|
||||
for req in self.orchestrator.reqs()
|
||||
)
|
||||
|
||||
def _prepare(self):
|
||||
self.cumulated_presence_penalties = (
|
||||
torch.tensor(
|
||||
data=[0.0 for _ in self.orchestrator.reqs()],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.repeat(1, self.orchestrator.vocab_size)
|
||||
)
|
||||
|
||||
self.presence_penalties = (
|
||||
torch.tensor(
|
||||
data=[
|
||||
req.sampling_params.presence_penalty
|
||||
for req in self.orchestrator.reqs()
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.expand_as(self.cumulated_presence_penalties)
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
self.presence_penalties = None
|
||||
self.cumulated_presence_penalties = None
|
||||
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
pass
|
||||
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
mask = output_ids.occurrence_count() > 0
|
||||
self.cumulated_presence_penalties[mask] = self.presence_penalties[mask]
|
||||
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
logits -= self.cumulated_presence_penalties
|
||||
return logits
|
||||
|
||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||
self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
|
||||
self.cumulated_presence_penalties = self.cumulated_presence_penalties[
|
||||
indices_tensor_to_keep
|
||||
]
|
||||
|
||||
def _merge(self, their: "BatchedPresencePenalizer"):
|
||||
self.presence_penalties = torch.cat(
|
||||
[self.presence_penalties, their.presence_penalties], dim=0
|
||||
)
|
||||
self.cumulated_presence_penalties = torch.cat(
|
||||
[self.cumulated_presence_penalties, their.cumulated_presence_penalties],
|
||||
dim=0,
|
||||
)
|
||||
@@ -1,85 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
from sglang.srt.utils import get_compiler_backend
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def apply_scaling_penalties(logits, scaling_penalties):
|
||||
logits[:] = torch.where(
|
||||
logits > 0,
|
||||
logits / scaling_penalties,
|
||||
logits * scaling_penalties,
|
||||
)
|
||||
|
||||
|
||||
class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
||||
"""
|
||||
Repetition penalizer penalizes tokens based on their repetition in the input and output.
|
||||
"""
|
||||
|
||||
repetition_penalties: torch.Tensor = None
|
||||
cumulated_repetition_penalties: torch.Tensor = None
|
||||
|
||||
def _is_required(self) -> bool:
|
||||
return any(
|
||||
req.sampling_params.repetition_penalty != 1.0
|
||||
for req in self.orchestrator.reqs()
|
||||
)
|
||||
|
||||
def _prepare(self):
|
||||
self.cumulated_repetition_penalties = (
|
||||
torch.tensor(
|
||||
data=[1.0 for _ in self.orchestrator.reqs()],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.repeat(1, self.orchestrator.vocab_size)
|
||||
)
|
||||
|
||||
self.repetition_penalties = (
|
||||
torch.tensor(
|
||||
data=[
|
||||
req.sampling_params.repetition_penalty
|
||||
for req in self.orchestrator.reqs()
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
.unsqueeze_(1)
|
||||
.expand_as(self.cumulated_repetition_penalties)
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
self.repetition_penalties = None
|
||||
self.cumulated_repetition_penalties = None
|
||||
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
mask = input_ids.occurrence_count() > 0
|
||||
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
|
||||
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
mask = output_ids.occurrence_count() > 0
|
||||
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
|
||||
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
|
||||
return logits
|
||||
|
||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
|
||||
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
|
||||
indices_tensor_to_keep
|
||||
]
|
||||
|
||||
def _merge(self, their: "BatchedRepetitionPenalizer"):
|
||||
self.repetition_penalties = torch.cat(
|
||||
[self.repetition_penalties, their.repetition_penalties], dim=0
|
||||
)
|
||||
self.cumulated_repetition_penalties = torch.cat(
|
||||
[self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
|
||||
dim=0,
|
||||
)
|
||||
66
python/sglang/srt/sampling/penaltylib/presence_penalty.py
Normal file
66
python/sglang/srt/sampling/penaltylib/presence_penalty.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import (
|
||||
BatchedPenalizerOrchestrator,
|
||||
_BatchedPenalizer,
|
||||
)
|
||||
|
||||
|
||||
class BatchedPresencePenalizer(_BatchedPenalizer):
|
||||
"""
|
||||
Presence penalizer penalizes tokens based on their presence in the output.
|
||||
"""
|
||||
|
||||
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
||||
self.orchestrator = orchestrator
|
||||
self._is_prepared = False
|
||||
|
||||
def _is_required(self) -> bool:
|
||||
return any(
|
||||
req.sampling_params.presence_penalty != 0.0
|
||||
for req in self.orchestrator.reqs()
|
||||
)
|
||||
|
||||
def _prepare(self):
|
||||
self.cumulated_presence_penalties = torch.zeros(
|
||||
(len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
|
||||
self.presence_penalties = (
|
||||
torch.tensor(
|
||||
data=[
|
||||
req.sampling_params.presence_penalty
|
||||
for req in self.orchestrator.reqs()
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
).unsqueeze_(1)
|
||||
|
||||
def _cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||
self.cumulated_presence_penalties.scatter_(
|
||||
dim=1,
|
||||
index=output_ids.unsqueeze(1),
|
||||
src=self.presence_penalties,
|
||||
)
|
||||
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
logits.sub_(self.cumulated_presence_penalties)
|
||||
|
||||
def _filter(self, keep_indices: torch.Tensor):
|
||||
self.presence_penalties = self.presence_penalties[keep_indices]
|
||||
self.cumulated_presence_penalties = self.cumulated_presence_penalties[
|
||||
keep_indices
|
||||
]
|
||||
|
||||
def _merge(self, their: "BatchedPresencePenalizer"):
|
||||
print(f"{self.presence_penalties.shape=}, {their.presence_penalties.shape=}")
|
||||
self.presence_penalties = torch.cat(
|
||||
[self.presence_penalties, their.presence_penalties], dim=0
|
||||
)
|
||||
self.cumulated_presence_penalties = torch.cat(
|
||||
[self.cumulated_presence_penalties, their.cumulated_presence_penalties],
|
||||
dim=0,
|
||||
)
|
||||
@@ -9,9 +9,6 @@ import torch
|
||||
|
||||
import sglang.srt.sampling.penaltylib as penaltylib
|
||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||
from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
|
||||
apply_scaling_penalties,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -22,49 +19,45 @@ if TYPE_CHECKING:
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SamplingBatchInfo:
|
||||
# Batched sampling params
|
||||
# Basic batched sampling params
|
||||
temperatures: torch.Tensor
|
||||
top_ps: torch.Tensor
|
||||
top_ks: torch.Tensor
|
||||
min_ps: torch.Tensor
|
||||
|
||||
# All requests use greedy sampling
|
||||
# Whether all requests use greedy sampling
|
||||
is_all_greedy: bool
|
||||
|
||||
# Dispatch in CUDA graph
|
||||
# Whether any request needs min_p sampling
|
||||
need_min_p_sampling: bool
|
||||
|
||||
# Whether any request has custom logit processor
|
||||
has_custom_logit_processor: bool
|
||||
|
||||
# Bias Tensors
|
||||
# Masking tensors for grammar-guided structured outputs
|
||||
vocab_size: int
|
||||
grammars: Optional[List] = None
|
||||
sampling_info_done: Optional[threading.Event] = None
|
||||
logit_bias: torch.Tensor = None
|
||||
vocab_mask: Optional[torch.Tensor] = None
|
||||
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
||||
apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
||||
|
||||
# An event used for overlap schedule
|
||||
sampling_info_done: Optional[threading.Event] = None
|
||||
|
||||
# Penalizer
|
||||
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
||||
linear_penalties: Optional[torch.Tensor] = None
|
||||
scaling_penalties: Optional[torch.Tensor] = None
|
||||
linear_penalty: torch.Tensor = None
|
||||
|
||||
# Device
|
||||
device: str = "cuda"
|
||||
|
||||
# Custom Parameters
|
||||
# Whether any request has custom logit processor
|
||||
has_custom_logit_processor: bool = False
|
||||
# Custom parameters
|
||||
custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
|
||||
|
||||
# Custom Logit Processor
|
||||
# Custom logit processor
|
||||
custom_logit_processor: Optional[
|
||||
Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
|
||||
] = None
|
||||
|
||||
# Device
|
||||
device: str = "cuda"
|
||||
|
||||
@classmethod
|
||||
def from_schedule_batch(
|
||||
cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
|
||||
):
|
||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||
reqs = batch.reqs
|
||||
device = batch.device
|
||||
temperatures = (
|
||||
@@ -118,106 +111,60 @@ class SamplingBatchInfo:
|
||||
merged_custom_logit_processor = None
|
||||
custom_params = None
|
||||
|
||||
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
||||
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
||||
# should not add hefty computation overhead other than simple checks.
|
||||
#
|
||||
# While we can choose not to even create the class instances if they are not required, this
|
||||
# could add additional complexity to the {ScheduleBatch} class, especially we need to
|
||||
# handle {filter_batch()} and {merge_batch()} cases as well.
|
||||
penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||
vocab_size=vocab_size,
|
||||
batch=batch,
|
||||
penalizers={
|
||||
penaltylib.BatchedFrequencyPenalizer,
|
||||
penaltylib.BatchedMinNewTokensPenalizer,
|
||||
penaltylib.BatchedPresencePenalizer,
|
||||
},
|
||||
)
|
||||
|
||||
ret = cls(
|
||||
temperatures=temperatures,
|
||||
top_ps=top_ps,
|
||||
top_ks=top_ks,
|
||||
min_ps=min_ps,
|
||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
|
||||
has_custom_logit_processor=has_custom_logit_processor,
|
||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||
vocab_size=vocab_size,
|
||||
device=device,
|
||||
penalizer_orchestrator=penalizer_orchestrator,
|
||||
has_custom_logit_processor=has_custom_logit_processor,
|
||||
custom_params=custom_params,
|
||||
custom_logit_processor=merged_custom_logit_processor,
|
||||
device=device,
|
||||
)
|
||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||
|
||||
if enable_overlap_schedule:
|
||||
# TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
|
||||
# so it is kind of tricky to make it work with overlap scheduler.
|
||||
# It requires correcly updating the penalty logits before the sampling and syncing the events.
|
||||
# We will support them later.
|
||||
penalizers = {
|
||||
penaltylib.BatchedMinNewTokensPenalizer,
|
||||
}
|
||||
if (
|
||||
any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
|
||||
or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
|
||||
or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
|
||||
):
|
||||
logger.warning(
|
||||
"frequency_penalty, presence_penalty, and repetition_penalty are not supported "
|
||||
"when using the default overlap scheduler. They will be ignored. "
|
||||
"Please add `--disable-overlap` when launching the server if you need these features. "
|
||||
"The speed will be slower in that case."
|
||||
)
|
||||
else:
|
||||
penalizers = {
|
||||
penaltylib.BatchedFrequencyPenalizer,
|
||||
penaltylib.BatchedMinNewTokensPenalizer,
|
||||
penaltylib.BatchedPresencePenalizer,
|
||||
penaltylib.BatchedRepetitionPenalizer,
|
||||
}
|
||||
|
||||
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
||||
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
||||
# should not add hefty computation overhead other than simple checks.
|
||||
#
|
||||
# While we choose not to even create the class instances if they are not required, this
|
||||
# could add additional complexity to the {ScheduleBatch} class, especially we need to
|
||||
# handle {filter_batch()} and {merge_batch()} cases as well.
|
||||
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||
vocab_size=vocab_size,
|
||||
batch=batch,
|
||||
device=batch.device,
|
||||
Penalizers=penalizers,
|
||||
)
|
||||
|
||||
# Handle logit bias but only allocate when needed
|
||||
ret.logit_bias = None
|
||||
|
||||
return ret
|
||||
|
||||
def __len__(self):
|
||||
return len(self.temperatures)
|
||||
|
||||
def update_penalties(self):
|
||||
self.scaling_penalties = None
|
||||
self.linear_penalties = None
|
||||
|
||||
for penalizer in self.penalizer_orchestrator.penalizers.values():
|
||||
if not penalizer.is_prepared():
|
||||
continue
|
||||
|
||||
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
|
||||
self.scaling_penalties = penalizer.cumulated_repetition_penalties
|
||||
else:
|
||||
if self.linear_penalties is None:
|
||||
bs = self.penalizer_orchestrator.batch.batch_size()
|
||||
self.linear_penalties = torch.zeros(
|
||||
(bs, self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||
|
||||
def update_regex_vocab_mask(self):
|
||||
if not self.grammars:
|
||||
self.vocab_mask = None
|
||||
self.apply_mask = None
|
||||
self.apply_mask_func = None
|
||||
return
|
||||
|
||||
# find a grammar from the list
|
||||
# Find a grammar from the list
|
||||
first_grammar = next(grammar for grammar in self.grammars if grammar)
|
||||
|
||||
# maybe we can reuse the existing mask?
|
||||
# TODO(lianmin): Maybe we can reuse the existing mask?
|
||||
self.vocab_mask = first_grammar.allocate_vocab_mask(
|
||||
vocab_size=self.vocab_size,
|
||||
batch_size=len(self.temperatures),
|
||||
device=self.device,
|
||||
)
|
||||
self.apply_mask = first_grammar.apply_vocab_mask # force to use static method
|
||||
self.apply_mask_func = (
|
||||
first_grammar.apply_vocab_mask
|
||||
) # force to use static method
|
||||
|
||||
# Apply the mask
|
||||
for i, grammar in enumerate(self.grammars):
|
||||
@@ -227,35 +174,56 @@ class SamplingBatchInfo:
|
||||
# Move the mask to the device if needed
|
||||
self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)
|
||||
|
||||
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
||||
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
||||
def update_penalties(self):
|
||||
if self.penalizer_orchestrator.is_required:
|
||||
self.linear_penalty = torch.zeros(
|
||||
(len(self.temperatures), self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.temperatures.device,
|
||||
)
|
||||
self.penalizer_orchestrator.apply(self.linear_penalty)
|
||||
else:
|
||||
self.linear_penalty = None
|
||||
|
||||
def apply_logits_bias(self, logits: torch.Tensor):
|
||||
if self.linear_penalty is not None:
|
||||
# Used in the overlap mode
|
||||
logits.add_(self.linear_penalty)
|
||||
|
||||
if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
|
||||
# Used in the non-overlap mode
|
||||
self.penalizer_orchestrator.apply(logits)
|
||||
|
||||
if self.vocab_mask is not None:
|
||||
self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)
|
||||
|
||||
def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
|
||||
self.penalizer_orchestrator.filter(keep_indices_device)
|
||||
|
||||
if self.has_custom_logit_processor:
|
||||
self._filter_batch_custom_logit_processor(unfinished_indices, new_indices)
|
||||
self._filter_batch_custom_logit_processor(keep_indices, keep_indices_device)
|
||||
|
||||
for item in [
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
"top_ks",
|
||||
"min_ps",
|
||||
"logit_bias",
|
||||
]:
|
||||
value = getattr(self, item, None)
|
||||
if value is not None: # logit_bias can be None
|
||||
setattr(self, item, value[new_indices])
|
||||
setattr(self, item, value[keep_indices_device])
|
||||
|
||||
def _filter_batch_custom_logit_processor(
|
||||
self, unfinished_indices: List[int], new_indices: torch.Tensor
|
||||
self, keep_indices: List[int], keep_indices_device: torch.Tensor
|
||||
):
|
||||
"""Filter the custom logit processor and custom params"""
|
||||
|
||||
self.custom_logit_processor = {
|
||||
k: (p, mask[new_indices])
|
||||
k: (p, mask[keep_indices_device])
|
||||
for k, (p, mask) in self.custom_logit_processor.items()
|
||||
if any(
|
||||
mask[new_indices]
|
||||
if torch.any(
|
||||
mask[keep_indices_device]
|
||||
) # ignore the custom logit processor whose mask is all False
|
||||
}
|
||||
self.custom_params = [self.custom_params[i] for i in unfinished_indices]
|
||||
self.custom_params = [self.custom_params[i] for i in keep_indices]
|
||||
|
||||
# If the custom logit processor is an empty dict, set the flag to False,
|
||||
# and set the custom logit processor and custom params to None.
|
||||
@@ -264,31 +232,6 @@ class SamplingBatchInfo:
|
||||
self.custom_params = None
|
||||
self.has_custom_logit_processor = False
|
||||
|
||||
@staticmethod
|
||||
def merge_bias_tensor(
|
||||
lhs: torch.Tensor,
|
||||
rhs: torch.Tensor,
|
||||
bs1: int,
|
||||
bs2: int,
|
||||
device: str,
|
||||
default: int = 0,
|
||||
):
|
||||
# bias tensor can be None
|
||||
if lhs is not None or rhs is not None:
|
||||
shape, dtype = None, None
|
||||
if lhs is not None:
|
||||
shape, dtype = lhs.shape[1:], lhs.dtype
|
||||
else:
|
||||
shape, dtype = rhs.shape[1:], rhs.dtype
|
||||
with torch.dtype(dtype):
|
||||
if lhs is None:
|
||||
lhs = torch.empty((bs1, *shape), device=device).fill_(default)
|
||||
if rhs is None:
|
||||
rhs = torch.empty((bs2, *shape), device=device).fill_(default)
|
||||
return torch.cat([lhs, rhs])
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def merge_custom_logit_processor(
|
||||
lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
|
||||
@@ -332,11 +275,6 @@ class SamplingBatchInfo:
|
||||
def merge_batch(self, other: "SamplingBatchInfo"):
|
||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||
|
||||
# Merge the logit bias tensor
|
||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||
)
|
||||
|
||||
# Merge the custom logit processors and custom params lists
|
||||
if self.has_custom_logit_processor or other.has_custom_logit_processor:
|
||||
# Merge the custom logit processors
|
||||
@@ -370,22 +308,5 @@ class SamplingBatchInfo:
|
||||
other_val = getattr(other, item, None)
|
||||
setattr(self, item, torch.concat([self_val, other_val]))
|
||||
|
||||
self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
|
||||
self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
|
||||
|
||||
def apply_logits_bias(self, logits: torch.Tensor):
|
||||
# Apply logit_bias
|
||||
if self.logit_bias is not None:
|
||||
logits.add_(self.logit_bias)
|
||||
|
||||
# min-token, presence, frequency
|
||||
if self.linear_penalties is not None:
|
||||
logits.add_(self.linear_penalties)
|
||||
|
||||
# repetition
|
||||
if self.scaling_penalties is not None:
|
||||
apply_scaling_penalties(logits, self.scaling_penalties)
|
||||
|
||||
# Apply regex vocab_mask
|
||||
if self.vocab_mask is not None:
|
||||
self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
|
||||
self.is_all_greedy |= other.is_all_greedy
|
||||
self.need_min_p_sampling |= other.need_min_p_sampling
|
||||
|
||||
Reference in New Issue
Block a user