Improve code style of sampler (#1168)

This commit is contained in:
Liangsheng Yin
2024-08-21 16:48:24 -07:00
committed by GitHub
parent ac1b74fa85
commit 83e23c69b3
10 changed files with 268 additions and 194 deletions

View File

@@ -0,0 +1,136 @@
from __future__ import annotations
import dataclasses
from typing import TYPE_CHECKING, List
import torch
import sglang.srt.sampling.penaltylib as penaltylib
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
@dataclasses.dataclass
class SamplingBatchInfo:
    """Batched sampling parameters for one scheduled batch of requests.

    Every tensor holds one row/entry per request, in the same order as
    ``batch.reqs`` at construction time.
    """

    # Basic info
    vocab_size: int

    # Batched sampling params (one entry per request)
    temperatures: torch.Tensor = None  # float, shape (bs, 1)
    top_ps: torch.Tensor = None  # float, shape (bs,)
    top_ks: torch.Tensor = None  # int, shape (bs,)
    min_ps: torch.Tensor = None  # float, shape (bs,)

    # Applies the frequency/min-new-tokens/presence/repetition penalties in batch
    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None

    # Per-request additive bias over the vocabulary; None when no request uses it
    logit_bias: torch.Tensor = None
    # Bool mask of tokens allowed by the regex FSM; None when no request uses a regex
    vocab_mask: torch.Tensor = None

    @classmethod
    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
        """Build a SamplingBatchInfo from the requests of a ScheduleBatch."""
        device = "cuda"
        reqs = batch.reqs
        ret = cls(vocab_size=vocab_size)

        ret.temperatures = torch.tensor(
            [r.sampling_params.temperature for r in reqs],
            dtype=torch.float,
            device=device,
        ).view(-1, 1)
        ret.top_ps = torch.tensor(
            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
        )
        ret.top_ks = torch.tensor(
            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
        )
        ret.min_ps = torch.tensor(
            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
        )

        # Each penalizer will do nothing if it evaluates itself as not required by
        # looking at the sampling_params of the requests (see {_is_required()} of
        # each penalizer). So this should not add hefty computation overhead other
        # than simple checks.
        #
        # While we could choose not to even create the class instances if they are
        # not required, this would add additional complexity to the {ScheduleBatch}
        # class, especially since we need to handle the {filter_batch()} and
        # {merge()} cases as well.
        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
            vocab_size=vocab_size,
            batch=batch,
            device=device,
            Penalizers={
                penaltylib.BatchedFrequencyPenalizer,
                penaltylib.BatchedMinNewTokensPenalizer,
                penaltylib.BatchedPresencePenalizer,
                penaltylib.BatchedRepetitionPenalizer,
            },
        )

        # Handle logit bias, but only allocate when needed
        ret.logit_bias = None

        ret.update_regex_vocab_mask(batch)

        return ret

    def update_regex_vocab_mask(self, batch: ScheduleBatch):
        """Recompute ``self.vocab_mask`` from each request's regex FSM state.

        The mask is lazily allocated: it stays None unless at least one request
        is regex-constrained.
        """
        bs, reqs = batch.batch_size(), batch.reqs
        device = "cuda"
        has_regex = any(req.regex_fsm is not None for req in reqs)

        # Reset the vocab mask
        self.vocab_mask = None

        if has_regex:
            for i, req in enumerate(reqs):
                if req.regex_fsm is not None:
                    if self.vocab_mask is None:
                        self.vocab_mask = torch.zeros(
                            bs, self.vocab_size, dtype=torch.bool, device=device
                        )
                    # Mark the tokens the FSM allows from the current state
                    self.vocab_mask[i][
                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
                    ] = 1

    def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
        """Keep only the rows belonging to the still-unfinished requests."""
        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
            "logit_bias",
        ]:
            self_val = getattr(self, item, None)
            if self_val is not None:  # logit_bias can be None
                setattr(self, item, self_val[new_indices])

    def merge(self, other: "SamplingBatchInfo"):
        """Concatenate another batch's rows after this one's, in place."""
        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

        # Capture the per-side batch sizes BEFORE concatenation; they are needed
        # to materialize a missing logit_bias below.
        self_bs = self.temperatures.shape[0]
        other_bs = other.temperatures.shape[0]

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
        ]:
            self_val = getattr(self, item, None)
            other_val = getattr(other, item, None)
            setattr(self, item, torch.concat([self_val, other_val]))

        # logit_bias can be None on either side; materialize zeros for the side
        # that lacks it before concatenating.
        if self.logit_bias is not None or other.logit_bias is not None:
            # Bug fix: the old code read len(self.reqs)/len(other.reqs), but
            # SamplingBatchInfo has no `reqs` attribute (AttributeError whenever
            # exactly one side carried a logit_bias). Derive the batch sizes from
            # the batched tensors, and the device/width from the existing bias.
            src = self.logit_bias if self.logit_bias is not None else other.logit_bias
            vocab_size = src.shape[1]
            if self.logit_bias is None:
                self.logit_bias = torch.zeros(
                    (self_bs, vocab_size), dtype=torch.float32, device=src.device
                )
            if other.logit_bias is None:
                other.logit_bias = torch.zeros(
                    (other_bs, vocab_size), dtype=torch.float32, device=src.device
                )
            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

View File

@@ -0,0 +1,143 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Sampling parameters for text generation."""
from typing import List, Optional, Union
_SAMPLING_EPS = 1e-6  # temperatures below this are treated as greedy sampling


class SamplingParams:
    """Sampling parameters for text generation.

    Parameter ranges are checked by :meth:`verify`; stop strings are
    preprocessed by :meth:`normalize` before use.
    """

    def __init__(
        self,
        max_new_tokens: int = 128,
        min_new_tokens: int = 0,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
        ignore_eos: bool = False,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        regex: Optional[str] = None,
        n: int = 1,
    ) -> None:
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.min_p = min_p
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        self.repetition_penalty = repetition_penalty
        self.stop_strs = stop
        # Bug fix: the old signature used a mutable default argument ([]), and
        # an explicit None — allowed by the Optional annotation — crashed in
        # `{*stop_token_ids}`. Normalize both cases to an empty set.
        self.stop_token_ids = set(stop_token_ids) if stop_token_ids else set()
        self.max_new_tokens = max_new_tokens
        self.min_new_tokens = min_new_tokens
        self.ignore_eos = ignore_eos
        self.skip_special_tokens = skip_special_tokens
        self.spaces_between_special_tokens = spaces_between_special_tokens
        self.regex = regex
        self.n = n

        # Process some special cases
        if self.temperature < _SAMPLING_EPS:
            # (Near-)zero temperature means greedy sampling: force top_k = 1
            # and use a neutral temperature to avoid dividing by ~zero.
            self.temperature = 1.0
            self.top_k = 1
        if self.top_k == -1:
            self.top_k = 1 << 30  # whole vocabulary

    def verify(self):
        """Validate parameter ranges; raise ValueError on the first violation."""
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}."
            )
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError(
                f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
            )
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError(
                "frequency_penalty must be in [-2, 2], got "
                f"{self.frequency_penalty}."
            )
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError(
                "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
            )
        if not 0.0 <= self.repetition_penalty <= 2.0:
            # Message fixed to match the check above (old text said "(0, 2]").
            raise ValueError(
                "repetition_penalty must be in [0, 2], got "
                f"{self.repetition_penalty}."
            )
        if not 0 <= self.min_new_tokens:
            # Message fixed to match the check above (old text said
            # "(0, max_new_tokens]").
            raise ValueError(
                f"min_new_tokens must be non-negative, got {self.min_new_tokens}."
            )
        if self.max_new_tokens is not None:
            if self.max_new_tokens < 0:
                raise ValueError(
                    f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
                )
            if not self.min_new_tokens <= self.max_new_tokens:
                raise ValueError(
                    f"min_new_tokens must not exceed max_new_tokens "
                    f"({self.max_new_tokens}), got {self.min_new_tokens}."
                )

    def normalize(self, tokenizer):
        """Preprocess stop strings.

        Wraps a single stop string into a list and computes
        ``self.stop_str_max_len``: the length of the longest stop string, in
        tokens when a tokenizer is given, otherwise in characters.
        """
        # Process stop strings
        if self.stop_strs is None:
            self.stop_strs = []
            # NOTE(review): stop_token_ids is always a set after __init__, so
            # this branch always sets 1; possibly `if not self.stop_token_ids`
            # was intended — kept as-is to preserve behavior.
            if self.stop_token_ids is None:
                self.stop_str_max_len = 0
            else:
                self.stop_str_max_len = 1
        else:
            if isinstance(self.stop_strs, str):
                self.stop_strs = [self.stop_strs]

            stop_str_max_len = 0
            for stop_str in self.stop_strs:
                if tokenizer is not None:
                    stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
                    stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
                else:
                    stop_str_max_len = max(stop_str_max_len, len(stop_str))
            self.stop_str_max_len = stop_str_max_len

    def to_srt_kwargs(self):
        """Export a subset of the parameters as kwargs for the SRT endpoint.

        Note: min_p, repetition_penalty, min_new_tokens, and the
        tokenizer-related flags are intentionally not included here.
        """
        return {
            "max_new_tokens": self.max_new_tokens,
            "stop": self.stop_strs,
            "stop_token_ids": list(self.stop_token_ids),
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "ignore_eos": self.ignore_eos,
            "regex": self.regex,
        }