From 40782f05d7b94e171f39381fa4c606d952568e60 Mon Sep 17 00:00:00 2001 From: Qiaolin Yu Date: Sat, 1 Mar 2025 20:51:29 -0500 Subject: [PATCH] Refactor: Move return_hidden_states to the generate input (#3985) Co-authored-by: Beichen-Ma --- docs/backend/native_api.ipynb | 2 +- docs/backend/sampling_params.md | 3 +- examples/runtime/engine/hidden_states.py | 5 ++-- python/sglang/srt/entrypoints/engine.py | 2 ++ python/sglang/srt/managers/io_struct.py | 8 ++++++ python/sglang/srt/managers/schedule_batch.py | 11 +++++++- python/sglang/srt/managers/scheduler.py | 22 +++++++-------- .../sglang/srt/managers/tokenizer_manager.py | 1 + .../srt/model_executor/cuda_graph_runner.py | 4 +-- .../srt/sampling/sampling_batch_info.py | 10 ------- python/sglang/srt/sampling/sampling_params.py | 2 -- test/srt/test_hidden_states.py | 28 ++++++++++--------- 12 files changed, 54 insertions(+), 44 deletions(-) diff --git a/docs/backend/native_api.ipynb b/docs/backend/native_api.ipynb index 4e0daa0df..cf76871a7 100644 --- a/docs/backend/native_api.ipynb +++ b/docs/backend/native_api.ipynb @@ -57,7 +57,7 @@ "metadata": {}, "source": [ "## Generate (text generation model)\n", - "Generate completions. This is similar to the `/v1/completions` in OpenAI API. Detailed parameters can be found in the [sampling parameters](../references/sampling_params.md)." + "Generate completions. This is similar to the `/v1/completions` in OpenAI API. Detailed parameters can be found in the [sampling parameters](https://docs.sglang.ai/backend/sampling_params.html)." ] }, { diff --git a/docs/backend/sampling_params.md b/docs/backend/sampling_params.md index 3d3d9f9e3..662f423bb 100644 --- a/docs/backend/sampling_params.md +++ b/docs/backend/sampling_params.md @@ -17,6 +17,7 @@ The `/generate` endpoint accepts the following parameters in JSON format. For in * `stream`: Whether to stream the output. `bool = False` * `lora_path`: Path to LoRA weights. `Optional[Union[List[Optional[str]], Optional[str]]] = None` * `custom_logit_processor`: Custom logit processor for advanced sampling control. For usage see below. `Optional[Union[List[Optional[str]], str]] = None` +* `return_hidden_states`: Whether to return hidden states of the model. Note that each time it changes, the cuda graph will be recaptured, which might lead to a performance hit. See the [examples](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/hidden_states.py) for more information. `bool = False` ## Sampling params @@ -55,8 +56,6 @@ Please refer to our dedicated guide on [constrained decoding](https://docs.sglan * `ignore_eos`: Don't stop generation when EOS token is sampled. `bool = False` * `skip_special_tokens`: Remove special tokens during decoding. `bool = True` * `custom_params`: Used when employing `CustomLogitProcessor`. For usage see below. `Optional[List[Optional[Dict[str, Any]]]] = None` -* `return_hidden_states`: Whether to return hidden states of the model. Note that each time it changes, the cuda graph will be recaptured, which might lead to a performance hit. See the [examples](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/hidden_states.py) for more information. `bool = False` - ### Custom Logit Processor diff --git a/examples/runtime/engine/hidden_states.py b/examples/runtime/engine/hidden_states.py index 9c7b89b74..8c6747a91 100644 --- a/examples/runtime/engine/hidden_states.py +++ b/examples/runtime/engine/hidden_states.py @@ -26,10 +26,11 @@ def main(): "temperature": 0.8, "top_p": 0.95, "max_new_tokens": 10, - "return_hidden_states": True, } - outputs = llm.generate(prompts, sampling_params=sampling_params) + outputs = llm.generate( + prompts, sampling_params=sampling_params, return_hidden_states=True + ) for prompt, output in zip(prompts, outputs): print("===============================") print( diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index b85d93d58..68bdf2cba 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -123,6 +123,7 @@ class Engine: top_logprobs_num: Optional[Union[List[int], int]] = None, lora_path: Optional[List[Optional[str]]] = None, custom_logit_processor: Optional[Union[List[str], str]] = None, + return_hidden_states: bool = False, stream: bool = False, ) -> Union[Dict, Iterator[Dict]]: """ @@ -144,6 +145,7 @@ class Engine: lora_path=lora_path, modalities=modalities_list, custom_logit_processor=custom_logit_processor, + return_hidden_states=return_hidden_states, stream=stream, ) loop = asyncio.get_event_loop() diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index fb7cc53ce..e105ba943 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -69,11 +69,15 @@ class GenerateReqInput: # Session info for continual prompting session_params: Optional[Union[List[Dict], Dict]] = None + # Custom logit processor for advanced sampling control. Must be a serialized instance # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py # Use the processor's `to_str()` method to generate the serialized string. custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None + # Whether to return hidden states + return_hidden_states: bool = False + def normalize_batch_and_arguments(self): if ( self.text is None and self.input_ids is None and self.input_embeds is None @@ -218,6 +222,7 @@ class GenerateReqInput: if self.custom_logit_processor is not None else None ), + return_hidden_states=self.return_hidden_states, ) @@ -255,6 +260,9 @@ class TokenizedGenerateReqInput: # Use the processor's `to_str()` method to generate the serialized string. custom_logit_processor: Optional[str] = None + # Whether to return hidden states + return_hidden_states: bool = False + @dataclass class EmbeddingReqInput: diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a0db44c71..3d34cefb1 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -236,6 +236,7 @@ class Req: input_embeds: Optional[List[List[float]]] = None, session_id: Optional[str] = None, custom_logit_processor: Optional[str] = None, + return_hidden_states: bool = False, eos_token_ids: Optional[Set[int]] = None, ): # Input and output info @@ -256,7 +257,9 @@ class Req: # Sampling info self.sampling_params = sampling_params + self.custom_logit_processor = custom_logit_processor + self.return_hidden_states = return_hidden_states # Memory pool info self.req_pool_idx = None @@ -608,6 +611,9 @@ class ScheduleBatch: # Enable custom logit processor enable_custom_logit_processor: bool = False + # Whether to return hidden states + return_hidden_states: bool = False + @classmethod def init_new( cls, @@ -619,6 +625,7 @@ class ScheduleBatch: enable_overlap: bool, spec_algorithm: SpeculativeAlgorithm, enable_custom_logit_processor: bool, + return_hidden_states: bool = False, ): return cls( reqs=reqs, @@ -633,6 +640,7 @@ class ScheduleBatch: device=req_to_token_pool.device, spec_algorithm=spec_algorithm, enable_custom_logit_processor=enable_custom_logit_processor, + return_hidden_states=return_hidden_states, ) def batch_size(self): @@ -1153,6 +1161,7 @@ class ScheduleBatch: self.return_logprob |= other.return_logprob self.has_stream |= other.has_stream self.has_grammar |= other.has_grammar + self.return_hidden_states |= other.return_hidden_states if self.spec_info: self.spec_info.merge_batch(other.spec_info) @@ -1201,7 +1210,7 @@ class ScheduleBatch: spec_info=self.spec_info, capture_hidden_mode=( CaptureHiddenMode.FULL - if self.sampling_info.return_hidden_states + if self.return_hidden_states else ( getattr( self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e94f6fc96..1de73137f 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -631,6 +631,7 @@ class Scheduler: lora_path=recv_req.lora_path, input_embeds=recv_req.input_embeds, custom_logit_processor=custom_logit_processor, + return_hidden_states=recv_req.return_hidden_states, eos_token_ids=self.model_config.hf_eos_token_id, ) req.tokenizer = self.tokenizer @@ -947,9 +948,11 @@ class Scheduler: if self.running_batch is not None else set([]) ) - + return_hidden_states = False # Get requests from the waiting queue to a new prefill batch for req in self.waiting_queue: + if req.return_hidden_states: + return_hidden_states = True if ( self.lora_paths and len( @@ -1035,6 +1038,7 @@ class Scheduler: self.enable_overlap, self.spec_algorithm, self.server_args.enable_custom_logit_processor, + return_hidden_states, ) new_batch.prepare_for_extend() @@ -1226,7 +1230,7 @@ class Scheduler: i, req, logprob_pt, next_token_ids, logits_output ) if ( - req.sampling_params.return_hidden_states + req.return_hidden_states and logits_output.hidden_states is not None ): req.hidden_states.append( @@ -1333,10 +1337,7 @@ class Scheduler: logits_output.next_token_top_logprobs_idx[i] ) - if ( - req.sampling_params.return_hidden_states - and logits_output.hidden_states is not None - ): + if req.return_hidden_states and logits_output.hidden_states is not None: req.hidden_states.append(logits_output.hidden_states[i].cpu().clone()) if req.grammar is not None: @@ -1462,10 +1463,7 @@ class Scheduler: completion_tokens = [] cached_tokens = [] spec_verify_ct = [] - return_hidden_states = any( - req.sampling_params.return_hidden_states for req in reqs - ) - output_hidden_states = [] if return_hidden_states else None + output_hidden_states = None if return_logprob: input_token_logprobs_val = [] @@ -1532,7 +1530,9 @@ class Scheduler: output_top_logprobs_val.append(req.output_top_logprobs_val) output_top_logprobs_idx.append(req.output_top_logprobs_idx) - if req.sampling_params.return_hidden_states: + if req.return_hidden_states: + if output_hidden_states is None: + output_hidden_states = [] output_hidden_states.append(req.hidden_states) # Send to detokenizer diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 289a690f6..40348edc0 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -383,6 +383,7 @@ class TokenizerManager: input_embeds=input_embeds, session_params=session_params, custom_logit_processor=obj.custom_logit_processor, + return_hidden_states=obj.return_hidden_states, ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index e8877e1f8..dd1b0da94 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -408,13 +408,13 @@ class CudaGraphRunner: ) # If the capture_hidden_mode changes, we need to recapture the graph if ( - forward_batch.sampling_info.return_hidden_states + forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL and self.capture_hidden_mode != CaptureHiddenMode.FULL ): self.capture_hidden_mode = CaptureHiddenMode.FULL self.capture() elif ( - not forward_batch.sampling_info.return_hidden_states + forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL and self.capture_hidden_mode != hidden_mode_from_spec_info ): self.capture_hidden_mode = hidden_mode_from_spec_info diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 6297e1fe0..393e713e9 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -37,9 +37,6 @@ class SamplingBatchInfo: # Whether any request has custom logit processor has_custom_logit_processor: bool - # Whether any request needs to return hidden states - return_hidden_states: bool - # Bias Tensors vocab_size: int grammars: Optional[List] = None @@ -94,9 +91,6 @@ class SamplingBatchInfo: and any(r.custom_logit_processor for r in reqs) # then check the requests. ) - # Check if any request needs to return hidden states - return_hidden_states = any(r.sampling_params.return_hidden_states for r in reqs) - if has_custom_logit_processor: # Merge the same type of custom logit processors together processor_dict = {} @@ -136,7 +130,6 @@ class SamplingBatchInfo: device=device, custom_params=custom_params, custom_logit_processor=merged_custom_logit_processor, - return_hidden_states=return_hidden_states, ) # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge. @@ -344,9 +337,6 @@ class SamplingBatchInfo: self.logit_bias, other.logit_bias, len(self), len(other), self.device ) - # Merge the return hidden states flag - self.return_hidden_states |= other.return_hidden_states - # Merge the custom logit processors and custom params lists if self.has_custom_logit_processor or other.has_custom_logit_processor: # Merge the custom logit processors diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py index a478be2ce..fa0ccaf37 100644 --- a/python/sglang/srt/sampling/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -49,7 +49,6 @@ class SamplingParams: no_stop_trim: bool = False, ignore_eos: bool = False, skip_special_tokens: bool = True, - return_hidden_states: bool = False, custom_params: Optional[Dict[str, Any]] = None, ) -> None: self.temperature = temperature @@ -75,7 +74,6 @@ class SamplingParams: self.ebnf = ebnf self.structural_tag = structural_tag self.no_stop_trim = no_stop_trim - self.return_hidden_states = return_hidden_states self.custom_params = custom_params # Process some special cases diff --git a/test/srt/test_hidden_states.py b/test/srt/test_hidden_states.py index 5e39ea607..0deec49d7 100644 --- a/test/srt/test_hidden_states.py +++ b/test/srt/test_hidden_states.py @@ -17,7 +17,6 @@ class TestHiddenState(unittest.TestCase): sampling_params = { "temperature": 0, "max_new_tokens": 8, - "return_hidden_states": True, } engine = sgl.Engine( @@ -25,7 +24,11 @@ class TestHiddenState(unittest.TestCase): random_seed=42, skip_tokenizer_init=True, ) - outputs = engine.generate(input_ids=input_ids, sampling_params=sampling_params) + outputs = engine.generate( + input_ids=input_ids, + sampling_params=sampling_params, + return_hidden_states=True, + ) engine.shutdown() for output in outputs: @@ -81,16 +84,9 @@ class TestHiddenState(unittest.TestCase): tokenizer = AutoTokenizer.from_pretrained(model_path) input_ids = tokenizer(prompts).input_ids - sample_completion = { + sampling_params = { "temperature": 0, "max_new_tokens": 8, - "return_hidden_states": True, - } - - sample_hidden_state = { - "temperature": 0, - "max_new_tokens": 8, - "return_hidden_states": False, } engine = sgl.Engine( @@ -99,14 +95,20 @@ class TestHiddenState(unittest.TestCase): skip_tokenizer_init=True, ) outputs_completion_first_round = engine.generate( - input_ids=input_ids, sampling_params=sample_completion + input_ids=input_ids, + sampling_params=sampling_params, + return_hidden_states=True, ) outputs_hidden_state = engine.generate( - input_ids=input_ids, sampling_params=sample_hidden_state + input_ids=input_ids, + sampling_params=sampling_params, + return_hidden_states=False, ) outputs_completion_last_round = engine.generate( - input_ids=input_ids, sampling_params=sample_completion + input_ids=input_ids, + sampling_params=sampling_params, + return_hidden_states=True, ) engine.shutdown()