[Feature] Add sampler custom logits processor (#2396)
Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
This commit is contained in:
@@ -22,6 +22,7 @@ from enum import Enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
|
||||
|
||||
@@ -69,6 +70,8 @@ class GenerateReqInput:
|
||||
|
||||
# Session info for continual prompting
|
||||
session_params: Optional[Union[List[Dict], Dict]] = None
|
||||
# Custom logit processor (serialized function)
|
||||
custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
|
||||
def normalize_batch_and_arguments(self):
|
||||
if (
|
||||
@@ -183,6 +186,13 @@ class GenerateReqInput:
|
||||
else:
|
||||
assert self.parallel_sample_num == 1
|
||||
|
||||
if self.custom_logit_processor is None:
|
||||
self.custom_logit_processor = [None] * num
|
||||
elif not isinstance(self.custom_logit_processor, list):
|
||||
self.custom_logit_processor = [self.custom_logit_processor] * num
|
||||
else:
|
||||
assert self.parallel_sample_num == 1
|
||||
|
||||
def regenerate_rid(self):
|
||||
self.rid = uuid.uuid4().hex
|
||||
return self.rid
|
||||
@@ -202,6 +212,11 @@ class GenerateReqInput:
|
||||
log_metrics=self.log_metrics,
|
||||
modalities=self.modalities[i] if self.modalities else None,
|
||||
lora_path=self.lora_path[i] if self.lora_path is not None else None,
|
||||
custom_logit_processor=(
|
||||
self.custom_logit_processor[i]
|
||||
if self.custom_logit_processor is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -234,6 +249,10 @@ class TokenizedGenerateReqInput:
|
||||
# Session info for continual prompting
|
||||
session_params: Optional[SessionParams] = None
|
||||
|
||||
# Custom logit processor (serialized function)
|
||||
# TODO(hpguo): Add an example and update the docstring here.
|
||||
custom_logit_processor: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingReqInput:
|
||||
|
||||
@@ -232,6 +232,7 @@ class Req:
|
||||
lora_path: Optional[str] = None,
|
||||
input_embeds: Optional[List[List[float]]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
custom_logit_processor: Optional[str] = None,
|
||||
eos_token_ids: Optional[Set[int]] = None,
|
||||
):
|
||||
# Input and output info
|
||||
@@ -252,6 +253,7 @@ class Req:
|
||||
# Sampling info
|
||||
self.sampling_params = sampling_params
|
||||
self.lora_path = lora_path
|
||||
self.custom_logit_processor = custom_logit_processor
|
||||
|
||||
# Memory pool info
|
||||
self.req_pool_idx = None
|
||||
|
||||
@@ -614,6 +614,19 @@ class Scheduler:
|
||||
fake_input_ids = [1] * seq_length
|
||||
recv_req.input_ids = fake_input_ids
|
||||
|
||||
# Handle custom logit processor passed to the request
|
||||
custom_logit_processor = recv_req.custom_logit_processor
|
||||
if (
|
||||
not self.server_args.enable_custom_logit_processor
|
||||
and custom_logit_processor is not None
|
||||
):
|
||||
logger.warning(
|
||||
"The SGLang server is not configured to enable custom logit processor. "
|
||||
"The custom logit processor passed in will be ignored. "
|
||||
"Please set --enable-custom-logit-processor to enable this feature."
|
||||
)
|
||||
custom_logit_processor = None
|
||||
|
||||
req = Req(
|
||||
recv_req.rid,
|
||||
recv_req.input_text,
|
||||
@@ -624,6 +637,7 @@ class Scheduler:
|
||||
stream=recv_req.stream,
|
||||
lora_path=recv_req.lora_path,
|
||||
input_embeds=recv_req.input_embeds,
|
||||
custom_logit_processor=custom_logit_processor,
|
||||
eos_token_ids=self.model_config.hf_eos_token_id,
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
@@ -131,6 +131,7 @@ class Session:
|
||||
sampling_params=req.sampling_params,
|
||||
lora_path=req.lora_path,
|
||||
session_id=self.session_id,
|
||||
custom_logit_processor=req.custom_logit_processor,
|
||||
)
|
||||
if last_req is not None:
|
||||
new_req.image_inputs = last_req.image_inputs
|
||||
|
||||
@@ -381,6 +381,7 @@ class TokenizerManager:
|
||||
lora_path=obj.lora_path,
|
||||
input_embeds=input_embeds,
|
||||
session_params=session_params,
|
||||
custom_logit_processor=obj.custom_logit_processor,
|
||||
)
|
||||
elif isinstance(obj, EmbeddingReqInput):
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
|
||||
Reference in New Issue
Block a user