From d6898dd2534a3faa801a7166062722288cc5c92a Mon Sep 17 00:00:00 2001
From: Qiaolin Yu
Date: Thu, 27 Feb 2025 01:06:54 -0500
Subject: [PATCH] Add return hidden state in the native API (#3897)

Co-authored-by: Beichen-Ma
Co-authored-by: Chayenne
---
 docs/backend/sampling_params.md               |  1 +
 examples/runtime/engine/hidden_states.py      | 12 +++-
 python/sglang/srt/managers/schedule_batch.py  |  7 +--
 python/sglang/srt/managers/scheduler.py       | 14 ++---
 .../srt/model_executor/cuda_graph_runner.py   | 34 +++++++----
 .../srt/sampling/sampling_batch_info.py       | 11 ++++
 python/sglang/srt/sampling/sampling_params.py |  2 +
 python/sglang/srt/server_args.py              |  6 --
 test/srt/test_hidden_states.py                | 59 ++++++++++++++++++-
 9 files changed, 112 insertions(+), 34 deletions(-)

diff --git a/docs/backend/sampling_params.md b/docs/backend/sampling_params.md
index 91df324f4..ef8c8bb54 100644
--- a/docs/backend/sampling_params.md
+++ b/docs/backend/sampling_params.md
@@ -55,6 +55,7 @@ Please refer to our dedicated guide on [constrained decoding](https://docs.sglan
 * `ignore_eos`: Don't stop generation when EOS token is sampled.
 * `skip_special_tokens`: Remove special tokens during decoding.
 * `custom_params`: Used when employing `CustomLogitProcessor`. For usage see below.
+* `return_hidden_states`: Whether to return the hidden states of the model. Note that each time this parameter changes between requests, the CUDA graph is recaptured, which may cause a performance hit. See the [examples](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/hidden_states.py) for more information.

 ### Custom Logit Processor

diff --git a/examples/runtime/engine/hidden_states.py b/examples/runtime/engine/hidden_states.py
index 50ec15151..9c7b89b74 100644
--- a/examples/runtime/engine/hidden_states.py
+++ b/examples/runtime/engine/hidden_states.py
@@ -2,7 +2,9 @@
 Usage:
 python hidden_states.py

-Note that we are actively working on moving return_hidden_states to the sampling_params.
+Note that each time the `return_hidden_states` parameter changes between requests,
+the CUDA graph is recaptured, which may cause a performance hit.
+So avoid alternating between requests that return hidden states and requests that do not.
 """

 import sglang as sgl
@@ -18,10 +20,14 @@ def main():
     # Create an LLM.
llm = sgl.Engine( model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct", - return_hidden_states=True, ) - sampling_params = {"temperature": 0.8, "top_p": 0.95, "max_new_tokens": 10} + sampling_params = { + "temperature": 0.8, + "top_p": 0.95, + "max_new_tokens": 10, + "return_hidden_states": True, + } outputs = llm.generate(prompts, sampling_params=sampling_params) for prompt, output in zip(prompts, outputs): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index f4ffed10b..ea7280485 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -607,9 +607,6 @@ class ScheduleBatch: # Enable custom logit processor enable_custom_logit_processor: bool = False - # Return hidden states - return_hidden_states: bool = False - @classmethod def init_new( cls, @@ -621,7 +618,6 @@ class ScheduleBatch: enable_overlap: bool, spec_algorithm: SpeculativeAlgorithm, enable_custom_logit_processor: bool, - return_hidden_states: bool = False, ): return cls( reqs=reqs, @@ -636,7 +632,6 @@ class ScheduleBatch: device=req_to_token_pool.device, spec_algorithm=spec_algorithm, enable_custom_logit_processor=enable_custom_logit_processor, - return_hidden_states=return_hidden_states, ) def batch_size(self): @@ -1205,7 +1200,7 @@ class ScheduleBatch: spec_info=self.spec_info, capture_hidden_mode=( CaptureHiddenMode.FULL - if self.return_hidden_states + if self.sampling_info.return_hidden_states else ( getattr( self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e4a141a9c..ea6d67518 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1030,7 +1030,6 @@ class Scheduler: self.enable_overlap, self.spec_algorithm, self.server_args.enable_custom_logit_processor, - self.server_args.return_hidden_states, ) new_batch.prepare_for_extend() @@ -1221,9 +1220,8 @@ class Scheduler: logprob_pt += self.add_logprob_return_values( i, req, logprob_pt, next_token_ids, logits_output ) - if ( - self.server_args.return_hidden_states + req.sampling_params.return_hidden_states and logits_output.hidden_states is not None ): req.hidden_states.append( @@ -1331,7 +1329,7 @@ class Scheduler: ) if ( - self.server_args.return_hidden_states + req.sampling_params.return_hidden_states and logits_output.hidden_states is not None ): req.hidden_states.append(logits_output.hidden_states[i].cpu().clone()) @@ -1459,7 +1457,10 @@ class Scheduler: completion_tokens = [] cached_tokens = [] spec_verify_ct = [] - output_hidden_states = [] if self.server_args.return_hidden_states else None + return_hidden_states = any( + req.sampling_params.return_hidden_states for req in reqs + ) + output_hidden_states = [] if return_hidden_states else None if return_logprob: input_token_logprobs_val = [] @@ -1526,7 +1527,7 @@ class Scheduler: output_top_logprobs_val.append(req.output_top_logprobs_val) output_top_logprobs_idx.append(req.output_top_logprobs_idx) - if self.server_args.return_hidden_states: + if req.sampling_params.return_hidden_states: output_hidden_states.append(req.hidden_states) # Send to detokenizer @@ -1619,7 +1620,6 @@ class Scheduler: self.enable_overlap, self.spec_algorithm, self.server_args.enable_custom_logit_processor, - self.server_args.return_hidden_states, ) idle_batch.prepare_for_idle() return idle_batch diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py 
b/python/sglang/srt/model_executor/cuda_graph_runner.py index d3f2e5146..e8877e1f8 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -120,7 +120,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): if max(capture_bs) > model_runner.req_to_token_pool.size: # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests - # is very samll. We add more values here to make sure we capture the maximum bs. + # is very small. We add more values here to make sure we capture the maximum bs. capture_bs = list( sorted( set( @@ -175,6 +175,7 @@ class CudaGraphRunner: # Batch sizes to capture self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) self.capture_forward_mode = ForwardMode.DECODE + self.capture_hidden_mode = CaptureHiddenMode.NULL self.num_tokens_per_bs = 1 if model_runner.spec_algorithm.is_eagle(): if self.model_runner.is_draft_worker: @@ -335,6 +336,10 @@ class CudaGraphRunner: gathered_buffer = None spec_info = self.get_spec_info(num_tokens) + if self.capture_hidden_mode != CaptureHiddenMode.FULL: + self.capture_hidden_mode = ( + spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL + ) forward_batch = ForwardBatch( forward_mode=self.capture_forward_mode, @@ -355,15 +360,7 @@ class CudaGraphRunner: mrope_positions=mrope_positions, spec_algorithm=self.model_runner.spec_algorithm, spec_info=spec_info, - capture_hidden_mode=( - CaptureHiddenMode.FULL - if self.model_runner.server_args.return_hidden_states - else ( - spec_info.capture_hidden_mode - if spec_info - else CaptureHiddenMode.NULL - ) - ), + capture_hidden_mode=self.capture_hidden_mode, ) # Attention backend @@ -406,6 +403,23 @@ class CudaGraphRunner: def replay(self, forward_batch: ForwardBatch): assert forward_batch.out_cache_loc is not None + hidden_mode_from_spec_info = getattr( + forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL + ) + # If the capture_hidden_mode changes, we need to recapture the graph + if ( + forward_batch.sampling_info.return_hidden_states + and self.capture_hidden_mode != CaptureHiddenMode.FULL + ): + self.capture_hidden_mode = CaptureHiddenMode.FULL + self.capture() + elif ( + not forward_batch.sampling_info.return_hidden_states + and self.capture_hidden_mode != hidden_mode_from_spec_info + ): + self.capture_hidden_mode = hidden_mode_from_spec_info + self.capture() + raw_bs = forward_batch.batch_size raw_num_token = raw_bs * self.num_tokens_per_bs diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 9521a34f4..6297e1fe0 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -37,6 +37,9 @@ class SamplingBatchInfo: # Whether any request has custom logit processor has_custom_logit_processor: bool + # Whether any request needs to return hidden states + return_hidden_states: bool + # Bias Tensors vocab_size: int grammars: Optional[List] = None @@ -91,6 +94,9 @@ class SamplingBatchInfo: and any(r.custom_logit_processor for r in reqs) # then check the requests. 
) + # Check if any request needs to return hidden states + return_hidden_states = any(r.sampling_params.return_hidden_states for r in reqs) + if has_custom_logit_processor: # Merge the same type of custom logit processors together processor_dict = {} @@ -130,6 +136,7 @@ class SamplingBatchInfo: device=device, custom_params=custom_params, custom_logit_processor=merged_custom_logit_processor, + return_hidden_states=return_hidden_states, ) # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge. @@ -336,6 +343,10 @@ class SamplingBatchInfo: self.logit_bias = SamplingBatchInfo.merge_bias_tensor( self.logit_bias, other.logit_bias, len(self), len(other), self.device ) + + # Merge the return hidden states flag + self.return_hidden_states |= other.return_hidden_states + # Merge the custom logit processors and custom params lists if self.has_custom_logit_processor or other.has_custom_logit_processor: # Merge the custom logit processors diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py index 2224fb091..0280f2be7 100644 --- a/python/sglang/srt/sampling/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -48,6 +48,7 @@ class SamplingParams: no_stop_trim: bool = False, ignore_eos: bool = False, skip_special_tokens: bool = True, + return_hidden_states: bool = False, custom_params: Optional[Dict[str, Any]] = None, ) -> None: self.temperature = temperature @@ -72,6 +73,7 @@ class SamplingParams: self.json_schema = json_schema self.ebnf = ebnf self.no_stop_trim = no_stop_trim + self.return_hidden_states = return_hidden_states self.custom_params = custom_params # Process some special cases diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b3edd1f5b..fd2188dcc 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -162,7 +162,6 @@ class ServerArgs: delete_ckpt_after_loading: bool = False enable_memory_saver: bool = False allow_auto_truncate: bool = False - return_hidden_states: bool = False enable_custom_logit_processor: bool = False tool_call_parser: str = None enable_hierarchical_cache: bool = False @@ -917,11 +916,6 @@ class ServerArgs: action="store_true", help="Enable users to pass custom logit processors to the server (disabled by default for security)", ) - parser.add_argument( - "--return-hidden-states", - action="store_true", - help="Return hidden states in the response.", - ) parser.add_argument( "--tool-call-parser", type=str, diff --git a/test/srt/test_hidden_states.py b/test/srt/test_hidden_states.py index 219c04693..5e39ea607 100644 --- a/test/srt/test_hidden_states.py +++ b/test/srt/test_hidden_states.py @@ -14,12 +14,15 @@ class TestHiddenState(unittest.TestCase): tokenizer = AutoTokenizer.from_pretrained(model_path) input_ids = tokenizer(prompts).input_ids - sampling_params = {"temperature": 0, "max_new_tokens": 8} + sampling_params = { + "temperature": 0, + "max_new_tokens": 8, + "return_hidden_states": True, + } engine = sgl.Engine( model_path=model_path, random_seed=42, - return_hidden_states=True, skip_tokenizer_init=True, ) outputs = engine.generate(input_ids=input_ids, sampling_params=sampling_params) @@ -72,6 +75,58 @@ class TestHiddenState(unittest.TestCase): ) ) + def test_repeatedly_changes_hidden_states(self): + prompts = ["Today is", "Today is a sunny day and I like"] + model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct" + tokenizer = AutoTokenizer.from_pretrained(model_path) + input_ids = 
tokenizer(prompts).input_ids
+
+        params_with_hidden_states = {
+            "temperature": 0,
+            "max_new_tokens": 8,
+            "return_hidden_states": True,
+        }
+
+        params_without_hidden_states = {
+            "temperature": 0,
+            "max_new_tokens": 8,
+            "return_hidden_states": False,
+        }
+
+        engine = sgl.Engine(
+            model_path=model_path,
+            random_seed=42,
+            skip_tokenizer_init=True,
+        )
+        outputs_with_hidden_first_round = engine.generate(
+            input_ids=input_ids, sampling_params=params_with_hidden_states
+        )
+        outputs_without_hidden = engine.generate(
+            input_ids=input_ids, sampling_params=params_without_hidden_states
+        )
+
+        outputs_with_hidden_last_round = engine.generate(
+            input_ids=input_ids, sampling_params=params_with_hidden_states
+        )
+        engine.shutdown()
+
+        for (
+            output_with_hidden_first_round,
+            output_without_hidden,
+            output_with_hidden_last_round,
+        ) in zip(
+            outputs_with_hidden_first_round,
+            outputs_without_hidden,
+            outputs_with_hidden_last_round,
+        ):
+            self.assertEqual(
+                len(output_with_hidden_first_round["meta_info"]["hidden_states"]), 8
+            )
+            self.assertNotIn("hidden_states", output_without_hidden["meta_info"])
+            self.assertEqual(
+                len(output_with_hidden_last_round["meta_info"]["hidden_states"]), 8
+            )
+

 if __name__ == "__main__":
     unittest.main()
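For reference, a minimal sketch of the per-request usage this patch enables, mirroring the example and test above. The model path and prompt are placeholders, and the `meta_info` layout is taken from the assertions in test_hidden_states.py:

import sglang as sgl


def main():
    # The Engine no longer takes a return_hidden_states argument; the flag now
    # travels with each request's sampling params.
    llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")

    sampling_params = {
        "temperature": 0,
        "max_new_tokens": 8,
        "return_hidden_states": True,
    }
    outputs = llm.generate(["Today is"], sampling_params=sampling_params)

    for output in outputs:
        # Per the test assertions above, hidden states are exposed under
        # meta_info, one entry per generated token.
        print(len(output["meta_info"]["hidden_states"]))

    llm.shutdown()


if __name__ == "__main__":
    main()

Because toggling the flag triggers a CUDA graph recapture, group requests with the same `return_hidden_states` value together where possible.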