Improve code style of sampler (#1168)
This commit is contained in:
136
python/sglang/srt/sampling/sampling_batch_info.py
Normal file
136
python/sglang/srt/sampling/sampling_batch_info.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import torch
|
||||
|
||||
import sglang.srt.sampling.penaltylib as penaltylib
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SamplingBatchInfo:
|
||||
# Basic Info
|
||||
vocab_size: int
|
||||
|
||||
# Batched sampling params
|
||||
temperatures: torch.Tensor = None
|
||||
top_ps: torch.Tensor = None
|
||||
top_ks: torch.Tensor = None
|
||||
min_ps: torch.Tensor = None
|
||||
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
|
||||
logit_bias: torch.Tensor = None
|
||||
vocab_mask: torch.Tensor = None
|
||||
|
||||
@classmethod
|
||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||
device = "cuda"
|
||||
reqs = batch.reqs
|
||||
ret = cls(vocab_size=vocab_size)
|
||||
|
||||
ret.temperatures = torch.tensor(
|
||||
[r.sampling_params.temperature for r in reqs],
|
||||
dtype=torch.float,
|
||||
device=device,
|
||||
).view(-1, 1)
|
||||
ret.top_ps = torch.tensor(
|
||||
[r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
|
||||
)
|
||||
ret.top_ks = torch.tensor(
|
||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
|
||||
)
|
||||
ret.min_ps = torch.tensor(
|
||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
|
||||
)
|
||||
|
||||
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
||||
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
||||
# should not add hefty computation overhead other than simple checks.
|
||||
#
|
||||
# While we choose not to even create the class instances if they are not required, this
|
||||
# could add additional complexity to the {ScheduleBatch} class, especially we need to
|
||||
# handle {filter_batch()} and {merge()} cases as well.
|
||||
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||
vocab_size=vocab_size,
|
||||
batch=batch,
|
||||
device=device,
|
||||
Penalizers={
|
||||
penaltylib.BatchedFrequencyPenalizer,
|
||||
penaltylib.BatchedMinNewTokensPenalizer,
|
||||
penaltylib.BatchedPresencePenalizer,
|
||||
penaltylib.BatchedRepetitionPenalizer,
|
||||
},
|
||||
)
|
||||
|
||||
# Handle logit bias but only allocate when needed
|
||||
ret.logit_bias = None
|
||||
|
||||
ret.update_regex_vocab_mask(batch)
|
||||
|
||||
return ret
|
||||
|
||||
def update_regex_vocab_mask(self, batch: ScheduleBatch):
|
||||
bs, reqs = batch.batch_size(), batch.reqs
|
||||
device = "cuda"
|
||||
has_regex = any(req.regex_fsm is not None for req in reqs)
|
||||
|
||||
# Reset the vocab mask
|
||||
self.vocab_mask = None
|
||||
|
||||
if has_regex:
|
||||
for i, req in enumerate(reqs):
|
||||
if req.regex_fsm is not None:
|
||||
if self.vocab_mask is None:
|
||||
self.vocab_mask = torch.zeros(
|
||||
bs, self.vocab_size, dtype=torch.bool, device=device
|
||||
)
|
||||
self.vocab_mask[i][
|
||||
req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
|
||||
] = 1
|
||||
|
||||
def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
||||
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
||||
|
||||
for item in [
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
"top_ks",
|
||||
"min_ps",
|
||||
"logit_bias",
|
||||
]:
|
||||
self_val = getattr(self, item, None)
|
||||
if self_val is not None: # logit_bias can be None
|
||||
setattr(self, item, self_val[new_indices])
|
||||
|
||||
def merge(self, other: "SamplingBatchInfo"):
|
||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||
|
||||
for item in [
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
"top_ks",
|
||||
"min_ps",
|
||||
]:
|
||||
self_val = getattr(self, item, None)
|
||||
other_val = getattr(other, item, None)
|
||||
setattr(self, item, torch.concat([self_val, other_val]))
|
||||
|
||||
# logit_bias can be None
|
||||
if self.logit_bias is not None or other.logit_bias is not None:
|
||||
vocab_size = (
|
||||
self.logit_bias.shape[1]
|
||||
if self.logit_bias is not None
|
||||
else other.logit_bias.shape[1]
|
||||
)
|
||||
if self.logit_bias is None:
|
||||
self.logit_bias = torch.zeros(
|
||||
(len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
if other.logit_bias is None:
|
||||
other.logit_bias = torch.zeros(
|
||||
(len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
|
||||
143
python/sglang/srt/sampling/sampling_params.py
Normal file
143
python/sglang/srt/sampling/sampling_params.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""Sampling parameters for text generation."""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
_SAMPLING_EPS = 1e-6
|
||||
|
||||
|
||||
class SamplingParams:
|
||||
def __init__(
|
||||
self,
|
||||
max_new_tokens: int = 128,
|
||||
min_new_tokens: int = 0,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = [],
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
min_p: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
repetition_penalty: float = 1.0,
|
||||
ignore_eos: bool = False,
|
||||
skip_special_tokens: bool = True,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
n: int = 1,
|
||||
) -> None:
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.min_p = min_p
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
self.repetition_penalty = repetition_penalty
|
||||
self.stop_strs = stop
|
||||
self.stop_token_ids = {*stop_token_ids}
|
||||
self.max_new_tokens = max_new_tokens
|
||||
self.min_new_tokens = min_new_tokens
|
||||
self.ignore_eos = ignore_eos
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||
self.regex = regex
|
||||
self.n = n
|
||||
|
||||
# Process some special cases
|
||||
if self.temperature < _SAMPLING_EPS:
|
||||
self.temperature = 1.0
|
||||
self.top_k = 1
|
||||
if self.top_k == -1:
|
||||
self.top_k = 1 << 30 # whole vocabulary
|
||||
|
||||
def verify(self):
|
||||
if self.temperature < 0.0:
|
||||
raise ValueError(
|
||||
f"temperature must be non-negative, got {self.temperature}."
|
||||
)
|
||||
if not 0.0 < self.top_p <= 1.0:
|
||||
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
|
||||
if not 0.0 <= self.min_p <= 1.0:
|
||||
raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
|
||||
if self.top_k < -1 or self.top_k == 0:
|
||||
raise ValueError(
|
||||
f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
|
||||
)
|
||||
if not -2.0 <= self.frequency_penalty <= 2.0:
|
||||
raise ValueError(
|
||||
"frequency_penalty must be in [-2, 2], got "
|
||||
f"{self.frequency_penalty}."
|
||||
)
|
||||
if not -2.0 <= self.presence_penalty <= 2.0:
|
||||
raise ValueError(
|
||||
"presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
|
||||
)
|
||||
if not 0.0 <= self.repetition_penalty <= 2.0:
|
||||
raise ValueError(
|
||||
"repetition_penalty must be in (0, 2], got "
|
||||
f"{self.repetition_penalty}."
|
||||
)
|
||||
if not 0 <= self.min_new_tokens:
|
||||
raise ValueError(
|
||||
f"min_new_tokens must be in (0, max_new_tokens], got "
|
||||
f"{self.min_new_tokens}."
|
||||
)
|
||||
if self.max_new_tokens is not None:
|
||||
if self.max_new_tokens < 0:
|
||||
raise ValueError(
|
||||
f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
|
||||
)
|
||||
if not self.min_new_tokens <= self.max_new_tokens:
|
||||
raise ValueError(
|
||||
f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
|
||||
f"{self.min_new_tokens}."
|
||||
)
|
||||
|
||||
def normalize(self, tokenizer):
|
||||
# Process stop strings
|
||||
if self.stop_strs is None:
|
||||
self.stop_strs = []
|
||||
if self.stop_token_ids is None:
|
||||
self.stop_str_max_len = 0
|
||||
else:
|
||||
self.stop_str_max_len = 1
|
||||
else:
|
||||
if isinstance(self.stop_strs, str):
|
||||
self.stop_strs = [self.stop_strs]
|
||||
|
||||
stop_str_max_len = 0
|
||||
for stop_str in self.stop_strs:
|
||||
if tokenizer is not None:
|
||||
stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
|
||||
stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
|
||||
else:
|
||||
stop_str_max_len = max(stop_str_max_len, len(stop_str))
|
||||
self.stop_str_max_len = stop_str_max_len
|
||||
|
||||
def to_srt_kwargs(self):
|
||||
return {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"stop": self.stop_strs,
|
||||
"stop_token_ids": list(self.stop_token_ids),
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"ignore_eos": self.ignore_eos,
|
||||
"regex": self.regex,
|
||||
}
|
||||
Reference in New Issue
Block a user