Refactor: Move return_hidden_states to the generate input (#3985)

Author: Qiaolin Yu
Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
Date: 2025-03-01 20:51:29 -05:00
Committed by: GitHub
parent 18bb216c28
commit 40782f05d7
12 changed files with 54 additions and 44 deletions

View File

@@ -123,6 +123,7 @@ class Engine:
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[str], str]] = None,
+ return_hidden_states: bool = False,
stream: bool = False,
) -> Union[Dict, Iterator[Dict]]:
"""
@@ -144,6 +145,7 @@ class Engine:
lora_path=lora_path,
modalities=modalities_list,
custom_logit_processor=custom_logit_processor,
+ return_hidden_states=return_hidden_states,
stream=stream,
)
loop = asyncio.get_event_loop()
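For callers, hidden-state capture is now requested per call instead of via the sampling parameters. A minimal sketch of the new call site, assuming a local offline Engine; the model path is a placeholder, and surfacing the captured states under the output's meta_info is an assumption based on how sglang returns other per-request metadata, not something this diff confirms:

```python
# Hedged sketch of the new per-call flag. The model path is a placeholder and
# the "hidden_states" key under meta_info is an assumption, not shown in this diff.
import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.2-1B-Instruct")
out = engine.generate(
    prompt="The capital of France is",
    sampling_params={"temperature": 0.0, "max_new_tokens": 8},
    return_hidden_states=True,  # moved here from SamplingParams by this commit
)
print(out["text"])
hidden = out["meta_info"].get("hidden_states")  # assumed location of the captured states
```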

View File

@@ -69,11 +69,15 @@ class GenerateReqInput:
# Session info for continual prompting
session_params: Optional[Union[List[Dict], Dict]] = None
# Custom logit processor for advanced sampling control. Must be a serialized instance
# of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
# Use the processor's `to_str()` method to generate the serialized string.
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
+ # Whether to return hidden states
+ return_hidden_states: bool = False
def normalize_batch_and_arguments(self):
if (
self.text is None and self.input_ids is None and self.input_embeds is None
@@ -218,6 +222,7 @@ class GenerateReqInput:
if self.custom_logit_processor is not None
else None
),
+ return_hidden_states=self.return_hidden_states,
)
@@ -255,6 +260,9 @@ class TokenizedGenerateReqInput:
# Use the processor's `to_str()` method to generate the serialized string.
custom_logit_processor: Optional[str] = None
+ # Whether to return hidden states
+ return_hidden_states: bool = False
@dataclass
class EmbeddingReqInput:
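The flag now travels on the request object itself, so every hop from the tokenizer manager to the scheduler can read it without consulting sampling parameters. A small sketch of constructing the request directly; the import path follows sglang's managers module layout and is an assumption here:

```python
# Sketch: the flag is a first-class field of the request, parallel to
# `stream` and `custom_logit_processor`, rather than a sampling parameter.
from sglang.srt.managers.io_struct import GenerateReqInput  # assumed path

req = GenerateReqInput(
    text="Hello, world",
    sampling_params={"max_new_tokens": 4},
    return_hidden_states=True,
)
req.normalize_batch_and_arguments()  # normalizes single-vs-batch fields, as in the diff
```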

View File

@@ -236,6 +236,7 @@ class Req:
input_embeds: Optional[List[List[float]]] = None,
session_id: Optional[str] = None,
custom_logit_processor: Optional[str] = None,
+ return_hidden_states: bool = False,
eos_token_ids: Optional[Set[int]] = None,
):
# Input and output info
@@ -256,7 +257,9 @@ class Req:
# Sampling info
self.sampling_params = sampling_params
self.custom_logit_processor = custom_logit_processor
+ self.return_hidden_states = return_hidden_states
# Memory pool info
self.req_pool_idx = None
@@ -608,6 +611,9 @@ class ScheduleBatch:
# Enable custom logit processor
enable_custom_logit_processor: bool = False
+ # Whether to return hidden states
+ return_hidden_states: bool = False
@classmethod
def init_new(
cls,
@@ -619,6 +625,7 @@ class ScheduleBatch:
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
+ return_hidden_states: bool = False,
):
return cls(
reqs=reqs,
@@ -633,6 +640,7 @@ class ScheduleBatch:
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
+ return_hidden_states=return_hidden_states,
)
def batch_size(self):
@@ -1153,6 +1161,7 @@ class ScheduleBatch:
self.return_logprob |= other.return_logprob
self.has_stream |= other.has_stream
self.has_grammar |= other.has_grammar
+ self.return_hidden_states |= other.return_hidden_states
if self.spec_info:
self.spec_info.merge_batch(other.spec_info)
@@ -1201,7 +1210,7 @@ class ScheduleBatch:
spec_info=self.spec_info,
capture_hidden_mode=(
CaptureHiddenMode.FULL
- if self.sampling_info.return_hidden_states
+ if self.return_hidden_states
else (
getattr(
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL

View File

@@ -631,6 +631,7 @@ class Scheduler:
lora_path=recv_req.lora_path,
input_embeds=recv_req.input_embeds,
custom_logit_processor=custom_logit_processor,
+ return_hidden_states=recv_req.return_hidden_states,
eos_token_ids=self.model_config.hf_eos_token_id,
)
req.tokenizer = self.tokenizer
@@ -947,9 +948,11 @@ class Scheduler:
if self.running_batch is not None
else set([])
)
+ return_hidden_states = False
# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
+ if req.return_hidden_states:
+     return_hidden_states = True
if (
self.lora_paths
and len(
@@ -1035,6 +1038,7 @@ class Scheduler:
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
+ return_hidden_states,
)
new_batch.prepare_for_extend()
@@ -1226,7 +1230,7 @@ class Scheduler:
i, req, logprob_pt, next_token_ids, logits_output
)
if (
- req.sampling_params.return_hidden_states
+ req.return_hidden_states
and logits_output.hidden_states is not None
):
req.hidden_states.append(
@@ -1333,10 +1337,7 @@ class Scheduler:
logits_output.next_token_top_logprobs_idx[i]
)
- if (
-     req.sampling_params.return_hidden_states
-     and logits_output.hidden_states is not None
- ):
+ if req.return_hidden_states and logits_output.hidden_states is not None:
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
if req.grammar is not None:
@@ -1462,10 +1463,7 @@ class Scheduler:
completion_tokens = []
cached_tokens = []
spec_verify_ct = []
- return_hidden_states = any(
-     req.sampling_params.return_hidden_states for req in reqs
- )
- output_hidden_states = [] if return_hidden_states else None
+ output_hidden_states = None
if return_logprob:
input_token_logprobs_val = []
@@ -1532,7 +1530,9 @@ class Scheduler:
output_top_logprobs_val.append(req.output_top_logprobs_val)
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
- if req.sampling_params.return_hidden_states:
+ if req.return_hidden_states:
+     if output_hidden_states is None:
+         output_hidden_states = []
output_hidden_states.append(req.hidden_states)
# Send to detokenizer
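Two details of the scheduler change are worth spelling out: the output list is materialized lazily, so batches with no interested request pay nothing, and each hidden-state tensor is copied off-device (`.cpu().clone()` above), so the result does not alias a GPU buffer that later decode steps, including CUDA-graph replays, will overwrite. The lazy-aggregation shape, restated on its own:

```python
# Restatement of the lazy aggregation adopted above: allocate the list only
# once a request that asked for hidden states is actually seen.
output_hidden_states = None
for req in reqs:  # `reqs` as in the surrounding scheduler loop
    if req.return_hidden_states:
        if output_hidden_states is None:
            output_hidden_states = []
        output_hidden_states.append(req.hidden_states)
```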

View File

@@ -383,6 +383,7 @@ class TokenizerManager:
input_embeds=input_embeds,
session_params=session_params,
custom_logit_processor=obj.custom_logit_processor,
+ return_hidden_states=obj.return_hidden_states,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(

View File

@@ -408,13 +408,13 @@ class CudaGraphRunner:
)
# If the capture_hidden_mode changes, we need to recapture the graph
if (
- forward_batch.sampling_info.return_hidden_states
+ forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
and self.capture_hidden_mode != CaptureHiddenMode.FULL
):
self.capture_hidden_mode = CaptureHiddenMode.FULL
self.capture()
elif (
- not forward_batch.sampling_info.return_hidden_states
+ forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL
and self.capture_hidden_mode != hidden_mode_from_spec_info
):
self.capture_hidden_mode = hidden_mode_from_spec_info
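The CUDA-graph runner now keys recapture off the batch's resolved capture_hidden_mode instead of reaching into sampling_info. Since a captured graph bakes in exactly which tensors it records, moving between FULL capture and the speculative-decoding default forces a re-record. A hedged restatement of the decision, not the runner's actual method:

```python
# Illustrative decision helper mirroring the two branches in the diff above.
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode  # assumed path

def target_capture_mode(batch_mode, graph_mode, spec_mode):
    if batch_mode == CaptureHiddenMode.FULL and graph_mode != CaptureHiddenMode.FULL:
        return CaptureHiddenMode.FULL  # upgrade: re-record, capturing all hidden states
    if batch_mode != CaptureHiddenMode.FULL and graph_mode != spec_mode:
        return spec_mode  # fall back to what speculative decoding (if any) needs
    return graph_mode  # already consistent: reuse the existing graph
```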

View File

@@ -37,9 +37,6 @@ class SamplingBatchInfo:
# Whether any request has custom logit processor
has_custom_logit_processor: bool
- # Whether any request needs to return hidden states
- return_hidden_states: bool
# Bias Tensors
vocab_size: int
grammars: Optional[List] = None
@@ -94,9 +91,6 @@ class SamplingBatchInfo:
and any(r.custom_logit_processor for r in reqs) # then check the requests.
)
- # Check if any request needs to return hidden states
- return_hidden_states = any(r.sampling_params.return_hidden_states for r in reqs)
if has_custom_logit_processor:
# Merge the same type of custom logit processors together
processor_dict = {}
@@ -136,7 +130,6 @@ class SamplingBatchInfo:
device=device,
custom_params=custom_params,
custom_logit_processor=merged_custom_logit_processor,
- return_hidden_states=return_hidden_states,
)
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
@@ -344,9 +337,6 @@ class SamplingBatchInfo:
self.logit_bias, other.logit_bias, len(self), len(other), self.device
)
- # Merge the return hidden states flag
- self.return_hidden_states |= other.return_hidden_states
# Merge the custom logit processors and custom params lists
if self.has_custom_logit_processor or other.has_custom_logit_processor:
# Merge the custom logit processors

View File

@@ -49,7 +49,6 @@ class SamplingParams:
no_stop_trim: bool = False,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
- return_hidden_states: bool = False,
custom_params: Optional[Dict[str, Any]] = None,
) -> None:
self.temperature = temperature
@@ -75,7 +74,6 @@ class SamplingParams:
self.ebnf = ebnf
self.structural_tag = structural_tag
self.no_stop_trim = no_stop_trim
- self.return_hidden_states = return_hidden_states
self.custom_params = custom_params
# Process some special cases
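For anyone upgrading across this commit, the migration is mechanical: the flag leaves SamplingParams and becomes an argument of the generate input. A hedged before/after, assuming the common pattern of passing sampling params as a dict that is forwarded into the SamplingParams constructor:

```python
# Before this commit: the flag was a SamplingParams field, so a dict entry
# like this was forwarded into SamplingParams(**...). No longer accepted.
out = engine.generate(prompt, sampling_params={"return_hidden_states": True})

# After this commit: top-level argument on the generate input.
out = engine.generate(prompt, sampling_params={}, return_hidden_states=True)
```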