[Feature] Add sampler custom logits processor (#2396)
Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
This commit is contained in:
@@ -1,11 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
|
from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
|
||||||
|
|
||||||
@@ -35,6 +36,10 @@ class Sampler(nn.Module):
|
|||||||
):
|
):
|
||||||
logits = logits_output.next_token_logits
|
logits = logits_output.next_token_logits
|
||||||
|
|
||||||
|
# Apply the custom logit processors if registered in the sampling info.
|
||||||
|
if sampling_info.has_custom_logit_processor:
|
||||||
|
self._apply_custom_logit_processor(logits, sampling_info)
|
||||||
|
|
||||||
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
|
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
|
||||||
logger.warning("Detected errors during sampling! NaN in the logits.")
|
logger.warning("Detected errors during sampling! NaN in the logits.")
|
||||||
logits = torch.where(
|
logits = torch.where(
|
||||||
@@ -121,6 +126,29 @@ class Sampler(nn.Module):
|
|||||||
|
|
||||||
return batch_next_token_ids
|
return batch_next_token_ids
|
||||||
|
|
||||||
|
def _apply_custom_logit_processor(
    self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
):
    """Apply the registered custom logit processors to ``logits`` in place.

    Args:
        logits: Next-token logits of the batch. Rows are selected with the
            per-processor batch masks and overwritten with the processor
            output.
        sampling_batch_info: Provides ``custom_logit_processor`` (a mapping
            whose values are ``(processor, batch_mask)`` pairs) and
            ``custom_params`` (indexed by position in the batch).
    """
    # Only the (processor, mask) pairs are needed; the dict keys are unused.
    for processor, batch_mask in sampling_batch_info.custom_logit_processor.values():
        # Convert the selected row indices to plain Python ints once, rather
        # than indexing the ``custom_params`` list with 0-dim tensors one
        # element at a time.
        batch_indices = batch_mask.nonzero(as_tuple=True)[0].tolist()

        # Run the processor on the masked rows and write the result back.
        logits[batch_mask] = processor(
            logits[batch_mask],
            [sampling_batch_info.custom_params[i] for i in batch_indices],
        )

        # Lazy %-formatting: the message is only built when DEBUG is enabled.
        logger.debug(
            "Custom logit processor %s is applied.", processor.__class__.__name__
        )
|
||||||
|
|
||||||
|
|
||||||
def top_k_top_p_min_p_sampling_from_probs_torch(
|
def top_k_top_p_min_p_sampling_from_probs_torch(
|
||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from enum import Enum
|
|||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||||
|
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
|
|
||||||
|
|
||||||
@@ -69,6 +70,8 @@ class GenerateReqInput:
|
|||||||
|
|
||||||
# Session info for continual prompting
|
# Session info for continual prompting
|
||||||
session_params: Optional[Union[List[Dict], Dict]] = None
|
session_params: Optional[Union[List[Dict], Dict]] = None
|
||||||
|
# Custom logit processor (serialized function)
|
||||||
|
custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||||
|
|
||||||
def normalize_batch_and_arguments(self):
|
def normalize_batch_and_arguments(self):
|
||||||
if (
|
if (
|
||||||
@@ -183,6 +186,13 @@ class GenerateReqInput:
|
|||||||
else:
|
else:
|
||||||
assert self.parallel_sample_num == 1
|
assert self.parallel_sample_num == 1
|
||||||
|
|
||||||
|
if self.custom_logit_processor is None:
|
||||||
|
self.custom_logit_processor = [None] * num
|
||||||
|
elif not isinstance(self.custom_logit_processor, list):
|
||||||
|
self.custom_logit_processor = [self.custom_logit_processor] * num
|
||||||
|
else:
|
||||||
|
assert self.parallel_sample_num == 1
|
||||||
|
|
||||||
def regenerate_rid(self):
|
def regenerate_rid(self):
|
||||||
self.rid = uuid.uuid4().hex
|
self.rid = uuid.uuid4().hex
|
||||||
return self.rid
|
return self.rid
|
||||||
@@ -202,6 +212,11 @@ class GenerateReqInput:
|
|||||||
log_metrics=self.log_metrics,
|
log_metrics=self.log_metrics,
|
||||||
modalities=self.modalities[i] if self.modalities else None,
|
modalities=self.modalities[i] if self.modalities else None,
|
||||||
lora_path=self.lora_path[i] if self.lora_path is not None else None,
|
lora_path=self.lora_path[i] if self.lora_path is not None else None,
|
||||||
|
custom_logit_processor=(
|
||||||
|
self.custom_logit_processor[i]
|
||||||
|
if self.custom_logit_processor is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -234,6 +249,10 @@ class TokenizedGenerateReqInput:
|
|||||||
# Session info for continual prompting
|
# Session info for continual prompting
|
||||||
session_params: Optional[SessionParams] = None
|
session_params: Optional[SessionParams] = None
|
||||||
|
|
||||||
|
# Custom logit processor (serialized function)
|
||||||
|
# TODO (hpguo): Add an example and update doc string here
|
||||||
|
custom_logit_processor: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingReqInput:
|
class EmbeddingReqInput:
|
||||||
|
|||||||
@@ -232,6 +232,7 @@ class Req:
|
|||||||
lora_path: Optional[str] = None,
|
lora_path: Optional[str] = None,
|
||||||
input_embeds: Optional[List[List[float]]] = None,
|
input_embeds: Optional[List[List[float]]] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
|
custom_logit_processor: Optional[str] = None,
|
||||||
eos_token_ids: Optional[Set[int]] = None,
|
eos_token_ids: Optional[Set[int]] = None,
|
||||||
):
|
):
|
||||||
# Input and output info
|
# Input and output info
|
||||||
@@ -252,6 +253,7 @@ class Req:
|
|||||||
# Sampling info
|
# Sampling info
|
||||||
self.sampling_params = sampling_params
|
self.sampling_params = sampling_params
|
||||||
self.lora_path = lora_path
|
self.lora_path = lora_path
|
||||||
|
self.custom_logit_processor = custom_logit_processor
|
||||||
|
|
||||||
# Memory pool info
|
# Memory pool info
|
||||||
self.req_pool_idx = None
|
self.req_pool_idx = None
|
||||||
|
|||||||
@@ -614,6 +614,19 @@ class Scheduler:
|
|||||||
fake_input_ids = [1] * seq_length
|
fake_input_ids = [1] * seq_length
|
||||||
recv_req.input_ids = fake_input_ids
|
recv_req.input_ids = fake_input_ids
|
||||||
|
|
||||||
|
# Handle custom logit processor passed to the request.
# When the server was not started with the feature enabled, the processor is
# dropped (with a warning) instead of being forwarded to the request.
custom_logit_processor = recv_req.custom_logit_processor
if (
    not self.server_args.enable_custom_logit_processor
    and custom_logit_processor is not None
):
    # Adjacent string literals are concatenated verbatim, so each sentence
    # carries its own trailing space. The flag name matches the CLI argument
    # registered by ServerArgs (--enable-custom-logit-processor).
    logger.warning(
        "The SGLang server is not configured to enable custom logit processor. "
        "The custom logit processor passed in will be ignored. "
        "Please set --enable-custom-logit-processor to enable this feature."
    )
    custom_logit_processor = None
|
||||||
|
|
||||||
req = Req(
|
req = Req(
|
||||||
recv_req.rid,
|
recv_req.rid,
|
||||||
recv_req.input_text,
|
recv_req.input_text,
|
||||||
@@ -624,6 +637,7 @@ class Scheduler:
|
|||||||
stream=recv_req.stream,
|
stream=recv_req.stream,
|
||||||
lora_path=recv_req.lora_path,
|
lora_path=recv_req.lora_path,
|
||||||
input_embeds=recv_req.input_embeds,
|
input_embeds=recv_req.input_embeds,
|
||||||
|
custom_logit_processor=custom_logit_processor,
|
||||||
eos_token_ids=self.model_config.hf_eos_token_id,
|
eos_token_ids=self.model_config.hf_eos_token_id,
|
||||||
)
|
)
|
||||||
req.tokenizer = self.tokenizer
|
req.tokenizer = self.tokenizer
|
||||||
|
|||||||
@@ -131,6 +131,7 @@ class Session:
|
|||||||
sampling_params=req.sampling_params,
|
sampling_params=req.sampling_params,
|
||||||
lora_path=req.lora_path,
|
lora_path=req.lora_path,
|
||||||
session_id=self.session_id,
|
session_id=self.session_id,
|
||||||
|
custom_logit_processor=req.custom_logit_processor,
|
||||||
)
|
)
|
||||||
if last_req is not None:
|
if last_req is not None:
|
||||||
new_req.image_inputs = last_req.image_inputs
|
new_req.image_inputs = last_req.image_inputs
|
||||||
|
|||||||
@@ -381,6 +381,7 @@ class TokenizerManager:
|
|||||||
lora_path=obj.lora_path,
|
lora_path=obj.lora_path,
|
||||||
input_embeds=input_embeds,
|
input_embeds=input_embeds,
|
||||||
session_params=session_params,
|
session_params=session_params,
|
||||||
|
custom_logit_processor=obj.custom_logit_processor,
|
||||||
)
|
)
|
||||||
elif isinstance(obj, EmbeddingReqInput):
|
elif isinstance(obj, EmbeddingReqInput):
|
||||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||||
|
|||||||
38
python/sglang/srt/sampling/custom_logit_processor.py
Normal file
38
python/sglang/srt/sampling/custom_logit_processor.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import json
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import dill
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# The cache is bounded because the keys are request-supplied strings: an
# unbounded cache would keep every distinct payload (and its deserialized
# object) alive for the lifetime of the process. An evicted entry is simply
# deserialized again on its next use.
@lru_cache(maxsize=1024)
def _cache_from_str(json_str: str):
    """Deserialize a JSON string into the callable object it encodes.

    The input is expected to be a JSON object whose ``"callable"`` entry is
    the hex encoding of a ``dill`` payload (the format written by
    ``CustomLogitProcessor.to_str``). Results are cached per input string to
    avoid redundant deserialization.

    Raises:
        KeyError: If the JSON object has no ``"callable"`` entry.
        ValueError: If the string is not valid JSON or the entry is not
            valid hex.
    """
    data = json.loads(json_str)
    # SECURITY: dill.loads executes arbitrary code embedded in the payload.
    # Only call this on strings from trusted clients.
    return dill.loads(bytes.fromhex(data["callable"]))
|
||||||
|
|
||||||
|
|
||||||
|
class CustomLogitProcessor(ABC):
    """Interface for user-defined logit processors.

    Subclasses implement ``__call__``. An instance can be turned into a
    JSON string with ``to_str`` and rebuilt from it with ``from_str``.
    """

    @abstractmethod
    def __call__(
        self,
        logits: torch.Tensor,
        custom_param_list: Optional[List[Dict[str, Any]]] = None,
    ) -> torch.Tensor:
        """Process ``logits`` and return the result; subclasses must override."""
        raise NotImplementedError

    def to_str(self) -> str:
        """Return a JSON string holding the hex-encoded dill dump of ``self``."""
        payload = {"callable": dill.dumps(self).hex()}
        return json.dumps(payload)

    @classmethod
    def from_str(cls, json_str: str):
        """Rebuild an object from a string produced by ``to_str``.

        Deserialization is delegated to the cached module-level helper.
        """
        restored = _cache_from_str(json_str)
        return restored
|
||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import TYPE_CHECKING, Callable, List, Optional
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -14,6 +14,7 @@ if is_cuda:
|
|||||||
from sgl_kernel import sampling_scaling_penalties
|
from sgl_kernel import sampling_scaling_penalties
|
||||||
|
|
||||||
import sglang.srt.sampling.penaltylib as penaltylib
|
import sglang.srt.sampling.penaltylib as penaltylib
|
||||||
|
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -36,6 +37,9 @@ class SamplingBatchInfo:
|
|||||||
# Dispatch in CUDA graph
|
# Dispatch in CUDA graph
|
||||||
need_min_p_sampling: bool
|
need_min_p_sampling: bool
|
||||||
|
|
||||||
|
# Whether any request has custom logit processor
|
||||||
|
has_custom_logit_processor: bool
|
||||||
|
|
||||||
# Bias Tensors
|
# Bias Tensors
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
grammars: Optional[List] = None
|
grammars: Optional[List] = None
|
||||||
@@ -52,6 +56,14 @@ class SamplingBatchInfo:
|
|||||||
# Device
|
# Device
|
||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
|
|
||||||
|
# Custom Parameters
|
||||||
|
custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
|
||||||
|
|
||||||
|
# Custom Logit Processor
|
||||||
|
custom_logit_processor: Optional[
|
||||||
|
Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
|
||||||
|
] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_schedule_batch(
|
def from_schedule_batch(
|
||||||
cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
|
cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
|
||||||
@@ -76,6 +88,36 @@ class SamplingBatchInfo:
|
|||||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||||
).to(device, non_blocking=True)
|
).to(device, non_blocking=True)
|
||||||
|
|
||||||
|
# Check if any request has custom logit processor
|
||||||
|
has_custom_logit_processor = any(r.custom_logit_processor for r in reqs)
|
||||||
|
|
||||||
|
if has_custom_logit_processor:
|
||||||
|
# Merge the same type of custom logit processors together
|
||||||
|
processor_dict = {}
|
||||||
|
for i, r in enumerate(reqs):
|
||||||
|
if r.custom_logit_processor is None:
|
||||||
|
continue
|
||||||
|
processor_str = r.custom_logit_processor
|
||||||
|
if processor_str not in processor_dict:
|
||||||
|
processor_dict[processor_str] = []
|
||||||
|
processor_dict[processor_str].append(i)
|
||||||
|
|
||||||
|
merged_custom_logit_processor = {
|
||||||
|
hash(processor_str): (
|
||||||
|
# The deserialized custom logit processor object
|
||||||
|
CustomLogitProcessor.from_str(processor_str),
|
||||||
|
# The mask tensor for the requests that use this custom logit processor
|
||||||
|
torch.zeros(len(reqs), dtype=torch.bool)
|
||||||
|
.scatter_(0, torch.tensor(true_indices), True)
|
||||||
|
.to(device, non_blocking=True),
|
||||||
|
)
|
||||||
|
for processor_str, true_indices in processor_dict.items()
|
||||||
|
}
|
||||||
|
custom_params = [r.sampling_params.custom_params for r in reqs]
|
||||||
|
else:
|
||||||
|
merged_custom_logit_processor = None
|
||||||
|
custom_params = None
|
||||||
|
|
||||||
ret = cls(
|
ret = cls(
|
||||||
temperatures=temperatures,
|
temperatures=temperatures,
|
||||||
top_ps=top_ps,
|
top_ps=top_ps,
|
||||||
@@ -83,8 +125,11 @@ class SamplingBatchInfo:
|
|||||||
min_ps=min_ps,
|
min_ps=min_ps,
|
||||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||||
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
|
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
|
||||||
|
has_custom_logit_processor=has_custom_logit_processor,
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
device=device,
|
device=device,
|
||||||
|
custom_params=custom_params,
|
||||||
|
custom_logit_processor=merged_custom_logit_processor,
|
||||||
)
|
)
|
||||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||||
|
|
||||||
@@ -184,6 +229,8 @@ class SamplingBatchInfo:
|
|||||||
|
|
||||||
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
||||||
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
||||||
|
if self.has_custom_logit_processor:
|
||||||
|
self._filter_batch_custom_logit_processor(unfinished_indices, new_indices)
|
||||||
|
|
||||||
for item in [
|
for item in [
|
||||||
"temperatures",
|
"temperatures",
|
||||||
@@ -196,6 +243,26 @@ class SamplingBatchInfo:
|
|||||||
if value is not None: # logit_bias can be None
|
if value is not None: # logit_bias can be None
|
||||||
setattr(self, item, value[new_indices])
|
setattr(self, item, value[new_indices])
|
||||||
|
|
||||||
|
def _filter_batch_custom_logit_processor(
|
||||||
|
self, unfinished_indices: List[int], new_indices: torch.Tensor
|
||||||
|
):
|
||||||
|
"""Filter the custom logit processor and custom params"""
|
||||||
|
if not self.custom_logit_processor:
|
||||||
|
return
|
||||||
|
self.custom_logit_processor = {
|
||||||
|
k: (p, mask[new_indices])
|
||||||
|
for k, (p, mask) in self.custom_logit_processor.items()
|
||||||
|
if any(
|
||||||
|
mask[new_indices]
|
||||||
|
) # ignore the custom logit processor whose mask is all False
|
||||||
|
}
|
||||||
|
self.custom_params = [self.custom_params[i] for i in unfinished_indices]
|
||||||
|
|
||||||
|
if len(self) == 0:
|
||||||
|
self.custom_logit_processor = None
|
||||||
|
self.custom_params = None
|
||||||
|
self.has_custom_logit_processor = False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def merge_bias_tensor(
|
def merge_bias_tensor(
|
||||||
lhs: torch.Tensor,
|
lhs: torch.Tensor,
|
||||||
@@ -221,6 +288,39 @@ class SamplingBatchInfo:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def merge_custom_logit_processor(
|
||||||
|
lhs: Optional[Dict[str, torch.Tensor]],
|
||||||
|
rhs: Optional[Dict[str, torch.Tensor]],
|
||||||
|
bs1: int,
|
||||||
|
bs2: int,
|
||||||
|
device: str,
|
||||||
|
):
|
||||||
|
if lhs is None and rhs is None:
|
||||||
|
return None
|
||||||
|
lhs, rhs = lhs or {}, rhs or {}
|
||||||
|
|
||||||
|
keys = set(lhs.keys()).union(set(rhs.keys()))
|
||||||
|
merged_dict = {}
|
||||||
|
|
||||||
|
for k in keys:
|
||||||
|
# Get the logit processor object
|
||||||
|
processor = lhs[k][0] if k in lhs else rhs[k][0]
|
||||||
|
# Get and merge the mask tensors from the two dicts
|
||||||
|
left_mask = (
|
||||||
|
lhs[k][1]
|
||||||
|
if k in lhs
|
||||||
|
else torch.zeros(bs1, dtype=torch.bool, device=device)
|
||||||
|
)
|
||||||
|
right_mask = (
|
||||||
|
rhs[k][1]
|
||||||
|
if k in rhs
|
||||||
|
else torch.zeros(bs2, dtype=torch.bool, device=device)
|
||||||
|
)
|
||||||
|
merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
|
||||||
|
|
||||||
|
return merged_dict
|
||||||
|
|
||||||
def merge_batch(self, other: "SamplingBatchInfo"):
|
def merge_batch(self, other: "SamplingBatchInfo"):
|
||||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||||
|
|
||||||
@@ -240,6 +340,26 @@ class SamplingBatchInfo:
|
|||||||
)
|
)
|
||||||
self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
|
self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
|
||||||
|
|
||||||
|
# Merge the custom logit processors and custom params lists
|
||||||
|
if self.has_custom_logit_processor or other.has_custom_logit_processor:
|
||||||
|
# Merge the custom logit processors
|
||||||
|
self.custom_logit_processor = (
|
||||||
|
SamplingBatchInfo.merge_custom_logit_processor(
|
||||||
|
self.custom_logit_processor,
|
||||||
|
other.custom_logit_processor,
|
||||||
|
len(self),
|
||||||
|
len(other),
|
||||||
|
self.device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Merge the custom params lists
|
||||||
|
self.custom_params = self.custom_params or [None] * len(self)
|
||||||
|
other.custom_params = other.custom_params or [None] * len(other)
|
||||||
|
self.custom_params.extend(other.custom_params)
|
||||||
|
|
||||||
|
# Set the flag to True if any of the two has custom logit processor
|
||||||
|
self.has_custom_logit_processor = True
|
||||||
|
|
||||||
def apply_logits_bias(self, logits: torch.Tensor):
|
def apply_logits_bias(self, logits: torch.Tensor):
|
||||||
# Apply logit_bias
|
# Apply logit_bias
|
||||||
if self.logit_bias is not None:
|
if self.logit_bias is not None:
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Sampling parameters for text generation."""
|
"""Sampling parameters for text generation."""
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
_SAMPLING_EPS = 1e-6
|
_SAMPLING_EPS = 1e-6
|
||||||
|
|
||||||
@@ -48,6 +48,7 @@ class SamplingParams:
|
|||||||
no_stop_trim: bool = False,
|
no_stop_trim: bool = False,
|
||||||
ignore_eos: bool = False,
|
ignore_eos: bool = False,
|
||||||
skip_special_tokens: bool = True,
|
skip_special_tokens: bool = True,
|
||||||
|
custom_params: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.top_p = top_p
|
self.top_p = top_p
|
||||||
@@ -71,6 +72,7 @@ class SamplingParams:
|
|||||||
self.json_schema = json_schema
|
self.json_schema = json_schema
|
||||||
self.ebnf = ebnf
|
self.ebnf = ebnf
|
||||||
self.no_stop_trim = no_stop_trim
|
self.no_stop_trim = no_stop_trim
|
||||||
|
self.custom_params = custom_params
|
||||||
|
|
||||||
# Process some special cases
|
# Process some special cases
|
||||||
if self.temperature < _SAMPLING_EPS:
|
if self.temperature < _SAMPLING_EPS:
|
||||||
|
|||||||
@@ -773,6 +773,7 @@ class Engine:
|
|||||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||||
lora_path: Optional[List[Optional[str]]] = None,
|
lora_path: Optional[List[Optional[str]]] = None,
|
||||||
|
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
obj = GenerateReqInput(
|
obj = GenerateReqInput(
|
||||||
@@ -784,6 +785,7 @@ class Engine:
|
|||||||
top_logprobs_num=top_logprobs_num,
|
top_logprobs_num=top_logprobs_num,
|
||||||
lora_path=lora_path,
|
lora_path=lora_path,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
custom_logit_processor=custom_logit_processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
# get the current event loop
|
# get the current event loop
|
||||||
@@ -824,6 +826,7 @@ class Engine:
|
|||||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||||
lora_path: Optional[List[Optional[str]]] = None,
|
lora_path: Optional[List[Optional[str]]] = None,
|
||||||
|
custom_logit_processor: Optional[Union[str, List[str]]] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
obj = GenerateReqInput(
|
obj = GenerateReqInput(
|
||||||
@@ -835,6 +838,7 @@ class Engine:
|
|||||||
top_logprobs_num=top_logprobs_num,
|
top_logprobs_num=top_logprobs_num,
|
||||||
lora_path=lora_path,
|
lora_path=lora_path,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
custom_logit_processor=custom_logit_processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
ret = await generate_request(obj, None)
|
ret = await generate_request(obj, None)
|
||||||
|
|||||||
@@ -159,6 +159,9 @@ class ServerArgs:
|
|||||||
enable_memory_saver: bool = False
|
enable_memory_saver: bool = False
|
||||||
allow_auto_truncate: bool = False
|
allow_auto_truncate: bool = False
|
||||||
|
|
||||||
|
# Custom logit processor
|
||||||
|
enable_custom_logit_processor: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Set missing default values
|
# Set missing default values
|
||||||
if self.tokenizer_path is None:
|
if self.tokenizer_path is None:
|
||||||
@@ -865,6 +868,11 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
|
help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-custom-logit-processor",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable users to pass custom logit processors to the server (disabled by default for security)",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
|
|||||||
@@ -5,10 +5,12 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import unittest
|
import unittest
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
@@ -24,7 +26,10 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
cls.process = popen_launch_server(
|
cls.process = popen_launch_server(
|
||||||
cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=("--enable-custom-logit-processor",),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -248,6 +253,62 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(all(x is not None for x in logprobs))
|
self.assertTrue(all(x is not None for x in logprobs))
|
||||||
|
|
||||||
|
def run_custom_logit_processor(self, target_token_id: int):
    """Check a custom logit processor driven by per-request custom params.

    Posts one ``/generate`` request carrying a serialized processor that
    forces ``target_token_id``, then asserts every sampled token equals it.

    Args:
        target_token_id: Token id the processor forces the model to sample.
    """
    custom_params = {"token_id": target_token_id}

    class DeterministicLogitProcessor(CustomLogitProcessor):
        """A dummy logit processor that changes the logits to always
        sample the given token id.
        """

        def __call__(self, logits, custom_param_list):
            assert logits.shape[0] == len(custom_param_list)
            key = "token_id"

            for i, param_dict in enumerate(custom_param_list):
                # Mask all other tokens
                logits[i, :] = -float("inf")
                # Assign highest probability to the specified token
                logits[i, param_dict[key]] = 0.0
            return logits

    prompts = "Question: Is Paris the Capital of France? Answer:"

    # Build the payload in one literal. The previous shallow ``dict.copy()``
    # shared the nested ``sampling_params`` dict with the base payload, so
    # adding ``custom_params`` mutated the base as well.
    custom_json = {
        "text": prompts,
        "sampling_params": {"temperature": 0.0, "custom_params": custom_params},
        "return_logprob": True,
        "custom_logit_processor": DeterministicLogitProcessor().to_str(),
    }

    custom_response = requests.post(
        self.base_url + "/generate",
        json=custom_json,
    ).json()

    output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
    sampled_tokens = [x[1] for x in output_token_logprobs]

    # ``all`` over an empty list is True; require at least one sampled token
    # so the check below cannot pass vacuously.
    self.assertTrue(sampled_tokens)
    # The logit processor should always sample the given token as the logits is deterministic.
    self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens))
|
||||||
|
|
||||||
|
def test_custom_logit_processor(self):
    """Run the custom logit processor check for one request."""
    forced_token_id = 5
    self.run_custom_logit_processor(target_token_id=forced_token_id)
|
||||||
|
|
||||||
|
def test_custom_logit_processor_batch(self):
    """Run the custom logit processor check for many concurrent requests."""
    target_token_ids = list(range(32))
    with ThreadPoolExecutor(len(target_token_ids)) as executor:
        # Submit every request first, then collect results in submission
        # order so a failure in any worker is re-raised here.
        pending = [
            executor.submit(self.run_custom_logit_processor, token_id)
            for token_id in target_token_ids
        ]
        for future in pending:
            future.result()
|
||||||
|
|
||||||
def test_get_server_info(self):
|
def test_get_server_info(self):
|
||||||
response = requests.get(self.base_url + "/get_server_info")
|
response = requests.get(self.base_url + "/get_server_info")
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
|
|||||||
Reference in New Issue
Block a user