[Feature] Add sampler custom logits processor (#2396)
Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
python/sglang/srt/sampling/custom_logit_processor.py
@@ -0,0 +1,38 @@
+import json
+from abc import ABC, abstractmethod
+from functools import lru_cache
+from typing import Any, Dict, List, Optional
+
+import dill
+import torch
+
+
+@lru_cache(maxsize=None)
+def _cache_from_str(json_str: str):
+    """Deserialize a json string to a Callable object.
+    This function is cached to avoid redundant deserialization.
+    """
+    data = json.loads(json_str)
+    return dill.loads(bytes.fromhex(data["callable"]))
+
+
+class CustomLogitProcessor(ABC):
+    """Abstract base class for callable functions."""
+
+    @abstractmethod
+    def __call__(
+        self,
+        logits: torch.Tensor,
+        custom_param_list: Optional[List[Dict[str, Any]]] = None,
+    ) -> torch.Tensor:
+        """Define the callable behavior."""
+        raise NotImplementedError
+
+    def to_str(self) -> str:
+        """Serialize the callable function to a JSON-compatible string."""
+        return json.dumps({"callable": dill.dumps(self).hex()})
+
+    @classmethod
+    def from_str(cls, json_str: str):
+        """Deserialize a callable function from a JSON string."""
+        return _cache_from_str(json_str)
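
Note (not part of the commit): a minimal sketch of how a concrete processor could be written against the base class above. The subclass name and the target_token_id key are made up for illustration; only CustomLogitProcessor, to_str, and from_str come from the file added in this commit.

# Hypothetical example: force each request to sample a fixed token id
# taken from its custom_params entry.
import torch

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class ForceTokenLogitProcessor(CustomLogitProcessor):
    def __call__(self, logits, custom_param_list=None):
        if custom_param_list is None:
            return logits
        for i, params in enumerate(custom_param_list):
            if params and "target_token_id" in params:
                # Mask all tokens except the requested one for this row.
                logits[i, :] = float("-inf")
                logits[i, params["target_token_id"]] = 0.0
        return logits


# Round trip through the dill/hex serialization defined above.
serialized = ForceTokenLogitProcessor().to_str()
restored = CustomLogitProcessor.from_str(serialized)
out = restored(torch.full((2, 8), 0.5), [{"target_token_id": 3}, None])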
@@ -3,7 +3,7 @@ from __future__ import annotations
 import dataclasses
 import logging
 import threading
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

 import torch

@@ -14,6 +14,7 @@ if is_cuda:
     from sgl_kernel import sampling_scaling_penalties

 import sglang.srt.sampling.penaltylib as penaltylib
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

 logger = logging.getLogger(__name__)

@@ -36,6 +37,9 @@ class SamplingBatchInfo:
     # Dispatch in CUDA graph
     need_min_p_sampling: bool

+    # Whether any request has custom logit processor
+    has_custom_logit_processor: bool
+
     # Bias Tensors
     vocab_size: int
     grammars: Optional[List] = None
@@ -52,6 +56,14 @@ class SamplingBatchInfo:
     # Device
     device: str = "cuda"

+    # Custom Parameters
+    custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
+
+    # Custom Logit Processor
+    custom_logit_processor: Optional[
+        Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
+    ] = None
+
     @classmethod
     def from_schedule_batch(
         cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
@@ -76,6 +88,36 @@ class SamplingBatchInfo:
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         ).to(device, non_blocking=True)

+        # Check if any request has custom logit processor
+        has_custom_logit_processor = any(r.custom_logit_processor for r in reqs)
+
+        if has_custom_logit_processor:
+            # Merge the same type of custom logit processors together
+            processor_dict = {}
+            for i, r in enumerate(reqs):
+                if r.custom_logit_processor is None:
+                    continue
+                processor_str = r.custom_logit_processor
+                if processor_str not in processor_dict:
+                    processor_dict[processor_str] = []
+                processor_dict[processor_str].append(i)
+
+            merged_custom_logit_processor = {
+                hash(processor_str): (
+                    # The deserialized custom logit processor object
+                    CustomLogitProcessor.from_str(processor_str),
+                    # The mask tensor for the requests that use this custom logit processor
+                    torch.zeros(len(reqs), dtype=torch.bool)
+                    .scatter_(0, torch.tensor(true_indices), True)
+                    .to(device, non_blocking=True),
+                )
+                for processor_str, true_indices in processor_dict.items()
+            }
+            custom_params = [r.sampling_params.custom_params for r in reqs]
+        else:
+            merged_custom_logit_processor = None
+            custom_params = None
+
         ret = cls(
             temperatures=temperatures,
             top_ps=top_ps,
@@ -83,8 +125,11 @@ class SamplingBatchInfo:
             min_ps=min_ps,
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
+            has_custom_logit_processor=has_custom_logit_processor,
             vocab_size=vocab_size,
             device=device,
+            custom_params=custom_params,
+            custom_logit_processor=merged_custom_logit_processor,
         )
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.

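Note: from_schedule_batch above deserializes each distinct processor string once and builds one boolean mask per processor via scatter_. A standalone toy run of the same grouping pattern (the request strings and indices below are made up):

import torch

# Four toy requests: 0 and 2 share serialized processor "A", 3 uses "B",
# 1 has no custom processor.
processor_strs = ["A", None, "A", "B"]

processor_dict = {}
for i, s in enumerate(processor_strs):
    if s is None:
        continue
    processor_dict.setdefault(s, []).append(i)

masks = {
    s: torch.zeros(len(processor_strs), dtype=torch.bool).scatter_(
        0, torch.tensor(indices), True
    )
    for s, indices in processor_dict.items()
}
# masks["A"] -> tensor([ True, False,  True, False])
# masks["B"] -> tensor([False, False, False,  True])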
@@ -184,6 +229,8 @@ class SamplingBatchInfo:

     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+        if self.has_custom_logit_processor:
+            self._filter_batch_custom_logit_processor(unfinished_indices, new_indices)

         for item in [
             "temperatures",
@@ -196,6 +243,26 @@ class SamplingBatchInfo:
             if value is not None:  # logit_bias can be None
                 setattr(self, item, value[new_indices])

+    def _filter_batch_custom_logit_processor(
+        self, unfinished_indices: List[int], new_indices: torch.Tensor
+    ):
+        """Filter the custom logit processor and custom params"""
+        if not self.custom_logit_processor:
+            return
+        self.custom_logit_processor = {
+            k: (p, mask[new_indices])
+            for k, (p, mask) in self.custom_logit_processor.items()
+            if any(
+                mask[new_indices]
+            )  # ignore the custom logit processor whose mask is all False
+        }
+        self.custom_params = [self.custom_params[i] for i in unfinished_indices]
+
+        if len(self) == 0:
+            self.custom_logit_processor = None
+            self.custom_params = None
+            self.has_custom_logit_processor = False
+
     @staticmethod
     def merge_bias_tensor(
         lhs: torch.Tensor,
@@ -221,6 +288,39 @@ class SamplingBatchInfo:

         return None

+    @staticmethod
+    def merge_custom_logit_processor(
+        lhs: Optional[Dict[str, torch.Tensor]],
+        rhs: Optional[Dict[str, torch.Tensor]],
+        bs1: int,
+        bs2: int,
+        device: str,
+    ):
+        if lhs is None and rhs is None:
+            return None
+        lhs, rhs = lhs or {}, rhs or {}
+
+        keys = set(lhs.keys()).union(set(rhs.keys()))
+        merged_dict = {}
+
+        for k in keys:
+            # Get the logit processor object
+            processor = lhs[k][0] if k in lhs else rhs[k][0]
+            # Get and merge the mask tensors from the two dicts
+            left_mask = (
+                lhs[k][1]
+                if k in lhs
+                else torch.zeros(bs1, dtype=torch.bool, device=device)
+            )
+            right_mask = (
+                rhs[k][1]
+                if k in rhs
+                else torch.zeros(bs2, dtype=torch.bool, device=device)
+            )
+            merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
+
+        return merged_dict
+
     def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

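Note: merge_custom_logit_processor concatenates the per-processor masks and zero-fills the side of the batch that does not contain a given processor. A toy check of that semantics (keys and batch sizes are made up; plain strings stand in for processor objects):

import torch

# Left batch (bs1=2) has processor key 1 active on request 0; right batch (bs2=3)
# has processor key 2 active on requests 1 and 2.
lhs = {1: ("proc_a", torch.tensor([True, False]))}
rhs = {2: ("proc_b", torch.tensor([False, True, True]))}
bs1, bs2 = 2, 3

merged = {}
for k in set(lhs) | set(rhs):
    processor = lhs[k][0] if k in lhs else rhs[k][0]
    left_mask = lhs[k][1] if k in lhs else torch.zeros(bs1, dtype=torch.bool)
    right_mask = rhs[k][1] if k in rhs else torch.zeros(bs2, dtype=torch.bool)
    merged[k] = (processor, torch.cat([left_mask, right_mask]))

# merged[1][1] -> tensor([ True, False, False, False, False])
# merged[2][1] -> tensor([False, False, False,  True,  True])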
@@ -240,6 +340,26 @@ class SamplingBatchInfo:
         )
         self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling

+        # Merge the custom logit processors and custom params lists
+        if self.has_custom_logit_processor or other.has_custom_logit_processor:
+            # Merge the custom logit processors
+            self.custom_logit_processor = (
+                SamplingBatchInfo.merge_custom_logit_processor(
+                    self.custom_logit_processor,
+                    other.custom_logit_processor,
+                    len(self),
+                    len(other),
+                    self.device,
+                )
+            )
+            # Merge the custom params lists
+            self.custom_params = self.custom_params or [None] * len(self)
+            other.custom_params = other.custom_params or [None] * len(other)
+            self.custom_params.extend(other.custom_params)
+
+            # Set the flag to True if any of the two has custom logit processor
+            self.has_custom_logit_processor = True
+
     def apply_logits_bias(self, logits: torch.Tensor):
         # Apply logit_bias
         if self.logit_bias is not None:
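
Note: the hunks above only build and maintain the (processor, mask) dictionary; the place where it is applied to the logits is not part of this excerpt. A rough sketch of how a sampler could consume it, assuming sampling_info is a SamplingBatchInfo instance (the helper name and the row-slicing strategy are assumptions, not this commit's code):

import torch


def apply_custom_logit_processors(logits: torch.Tensor, sampling_info) -> torch.Tensor:
    """Hypothetical consumer of SamplingBatchInfo.custom_logit_processor."""
    if not sampling_info.has_custom_logit_processor:
        return logits
    for processor, mask in sampling_info.custom_logit_processor.values():
        # Select the rows (and their custom_params) handled by this processor.
        indices = mask.nonzero(as_tuple=True)[0]
        params = [sampling_info.custom_params[i] for i in indices.tolist()]
        logits[indices] = processor(logits[indices], params)
    return logits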
@@ -13,7 +13,7 @@
 # ==============================================================================
 """Sampling parameters for text generation."""

-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 _SAMPLING_EPS = 1e-6

@@ -48,6 +48,7 @@ class SamplingParams:
         no_stop_trim: bool = False,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
+        custom_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
@@ -71,6 +72,7 @@ class SamplingParams:
         self.json_schema = json_schema
         self.ebnf = ebnf
         self.no_stop_trim = no_stop_trim
+        self.custom_params = custom_params

         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
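
Note: with the new custom_params field, per-request parameters can be attached to SamplingParams. A request-side sketch (values are hypothetical; the plumbing that carries the serialized processor string from the API layer into the schedule batch is not part of the hunks shown above):

from sglang.srt.sampling.sampling_params import SamplingParams

# Parameters that a custom logit processor could read from custom_param_list.
params = SamplingParams(
    temperature=0.7,
    max_new_tokens=32,
    custom_params={"target_token_id": 3},
)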