diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index 72cbf60f5..1fd007e86 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -103,7 +103,7 @@ jobs:
           bash scripts/ci_install_dependency.sh

       - name: Run test
-        timeout-minutes: 25
+        timeout-minutes: 30
         run: |
           RANGE=${{ matrix.range }}
           range_begin=${RANGE%-*}
diff --git a/docs/backend/offline_engine_api.ipynb b/docs/backend/offline_engine_api.ipynb
index 58d24ac3f..3a2700856 100644
--- a/docs/backend/offline_engine_api.ipynb
+++ b/docs/backend/offline_engine_api.ipynb
@@ -179,6 +179,59 @@
     "asyncio.run(main())"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm.shutdown()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Return Hidden States"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sglang as sgl\n",
+    "\n",
+    "llm = sgl.Engine(\n",
+    "    model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\", return_hidden_states=True\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prompts = [\n",
+    "    \"Hello, my name is\",\n",
+    "    \"The president of the United States is\",\n",
+    "    \"The capital of France is\",\n",
+    "    \"The future of AI is\",\n",
+    "]\n",
+    "\n",
+    "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95, \"max_new_tokens\": 10}\n",
+    "\n",
+    "outputs = llm.generate(prompts, sampling_params=sampling_params)\n",
+    "for prompt, output in zip(prompts, outputs):\n",
+    "    print(\"===============================\")\n",
+    "    print(\n",
+    "        f\"Prompt: {prompt}\\nGenerated text: {output['text']}\\nPrompt_Tokens: {output['meta_info']['prompt_tokens']}\\tCompletion_tokens: {output['meta_info']['completion_tokens']}\\nHidden states: {[i.shape for i in output['meta_info']['hidden_states']]}\"\n",
+    "    )\n",
+    "    print()"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index a8ded73bc..2708e5874 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -210,6 +210,7 @@ class DetokenizerManager:
                     input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
                     output_top_logprobs_val=recv_obj.output_top_logprobs_val,
                     output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
+                    output_hidden_states=recv_obj.output_hidden_states,
                 )
             )

diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index f7419d04f..67225cf84 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -371,6 +371,8 @@ class BatchTokenIDOut:
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
+
+    output_hidden_states: List[List[float]]


 @dataclass
 class BatchStrOut:
@@ -397,6 +399,8 @@ class BatchStrOut:
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
+
+    output_hidden_states: List[List[float]]


 @dataclass
 class BatchEmbeddingOut:
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index f22d3d5fe..ecac38656 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -315,6 +315,7 @@ class Req:
         self.output_token_logprobs_val = self.output_token_logprobs_idx = (
             self.output_top_logprobs_val
         ) = self.output_top_logprobs_idx = None
+        self.hidden_states = []

         # Logprobs (internal values)
         # The tokens is prefilled but need to be considered as decode tokens
@@ -604,6 +605,9 @@ class ScheduleBatch:
     # Enable custom logit processor
     enable_custom_logit_processor: bool = False

+    # Return hidden states
+    return_hidden_states: bool = False
+
     @classmethod
     def init_new(
         cls,
@@ -615,6 +619,7 @@ class ScheduleBatch:
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
+        return_hidden_states: bool = False,
     ):
         return cls(
             reqs=reqs,
@@ -629,6 +634,7 @@ class ScheduleBatch:
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
+            return_hidden_states=return_hidden_states,
         )

     def batch_size(self):
@@ -1196,9 +1202,15 @@ class ScheduleBatch:
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
             capture_hidden_mode=(
-                getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
-                if self.spec_info
-                else CaptureHiddenMode.NULL
+                CaptureHiddenMode.FULL
+                if self.return_hidden_states
+                else (
+                    getattr(
+                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
+                    )
+                    if self.spec_info
+                    else CaptureHiddenMode.NULL
+                )
             ),
         )

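The `schedule_batch.py` change above is the switch point: when the server is launched with `return_hidden_states`, forward batches are built with `CaptureHiddenMode.FULL`, and only otherwise does the mode fall back to whatever speculative decoding requested. Below is a minimal sketch of that precedence, assuming `CaptureHiddenMode` is the three-valued enum referenced in the diff (`NULL`/`LAST`/`FULL`); the standalone `resolve_capture_hidden_mode` helper is hypothetical and exists only to make the selection order explicit.

```python
from enum import Enum, auto
from typing import Optional


class CaptureHiddenMode(Enum):
    NULL = auto()  # capture no hidden states
    LAST = auto()  # capture only the last position (speculative decoding)
    FULL = auto()  # capture hidden states for every position in the batch


def resolve_capture_hidden_mode(
    return_hidden_states: bool, spec_info: Optional[object]
) -> CaptureHiddenMode:
    # The user-facing flag wins: FULL tells the model runner to keep
    # hidden states for the whole forward batch.
    if return_hidden_states:
        return CaptureHiddenMode.FULL
    # Otherwise defer to speculative decoding, which may request its own mode.
    if spec_info is not None:
        return getattr(spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
    return CaptureHiddenMode.NULL
```

The same precedence appears twice in this patch: once when building a `ForwardBatch` in `schedule_batch.py` and once when capturing CUDA graphs in `cuda_graph_runner.py`, so graph replay keeps the hidden states too.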
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 79d4db114..9c00c8b25 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -997,6 +997,7 @@ class Scheduler:
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
+            self.server_args.return_hidden_states,
         )
         new_batch.prepare_for_extend()

@@ -1156,6 +1157,8 @@ class Scheduler:
                 logits_output.input_token_logprobs.tolist()
             )

+        hidden_state_offset = 0
+
         # Check finish conditions
         logprob_pt = 0
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
@@ -1182,6 +1185,21 @@ class Scheduler:
                     i, req, logprob_pt, next_token_ids, logits_output
                 )

+            if (
+                self.server_args.return_hidden_states
+                and logits_output.hidden_states is not None
+            ):
+                req.hidden_states.append(
+                    logits_output.hidden_states[
+                        hidden_state_offset : (
+                            hidden_state_offset := hidden_state_offset
+                            + len(req.origin_input_ids)
+                        )
+                    ]
+                    .cpu()
+                    .clone()
+                )
+
             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)
                 req.grammar.finished = req.finished()
@@ -1275,6 +1293,12 @@ class Scheduler:
                     logits_output.next_token_top_logprobs_idx[i]
                 )

+            if (
+                self.server_args.return_hidden_states
+                and logits_output.hidden_states is not None
+            ):
+                req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
+
             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)
                 req.grammar.finished = req.finished()
@@ -1398,6 +1422,7 @@ class Scheduler:
             completion_tokens = []
             cached_tokens = []
             spec_verify_ct = []
+            hidden_states = []

             if return_logprob:
                 input_token_logprobs_val = []
@@ -1464,6 +1489,8 @@ class Scheduler:
                     output_top_logprobs_val.append(req.output_top_logprobs_val)
                     output_top_logprobs_idx.append(req.output_top_logprobs_idx)

+                hidden_states.append(req.hidden_states)
+
             # Send to detokenizer
             if rids:
                 self.send_to_detokenizer.send_pyobj(
@@ -1490,6 +1517,7 @@ class Scheduler:
                         input_top_logprobs_idx,
                         output_top_logprobs_val,
                         output_top_logprobs_idx,
+                        hidden_states,
                     )
                 )
         else:  # embedding or reward model
@@ -1553,6 +1581,7 @@ class Scheduler:
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
+            self.server_args.return_hidden_states,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
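In the prefill path above, `logits_output.hidden_states` holds the hidden states of every request in the batch stacked along the token dimension, so each request's slice is carved out by advancing `hidden_state_offset` with a walrus expression; in the decode path, row `i` simply belongs to request `i`. A self-contained sketch of the prefill splitting, where `fake_hidden` and `prompt_lens` are made-up stand-ins for `logits_output.hidden_states` and `len(req.origin_input_ids)`:

```python
import torch

prompt_lens = [3, 8]  # two requests with different prompt lengths
fake_hidden = torch.randn(sum(prompt_lens), 16)  # [total_prompt_tokens, hidden_size]

per_request = []
offset = 0
for n in prompt_lens:
    # Same idiom as the diff: the walrus expression advances the offset
    # while the slice is being taken, yielding exactly this request's rows.
    per_request.append(fake_hidden[offset : (offset := offset + n)].cpu().clone())

assert [t.shape[0] for t in per_request] == prompt_lens
```

The `.cpu().clone()` matters: the slice is a view into a (possibly GPU-resident) batch tensor, and cloning on CPU lets the batch tensor be freed while the request holds onto its own copy.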
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 53e1f4eda..90da2f103 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -796,6 +796,12 @@ class TokenizerManager:
                     }
                 )

+            if (
+                hasattr(recv_obj, "output_hidden_states")
+                and len(recv_obj.output_hidden_states[i]) > 0
+            ):
+                meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
+
             if isinstance(recv_obj, BatchStrOut):
                 out_dict = {
                     "text": recv_obj.output_strs[i],
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 961b0bbdc..26d8d5748 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -156,6 +156,10 @@ class TpModelWorkerClient:
                 logits_output.input_token_logprobs = (
                     logits_output.input_token_logprobs.to("cpu", non_blocking=True)
                 )
+            if logits_output.hidden_states is not None:
+                logits_output.hidden_states = logits_output.hidden_states.to(
+                    "cpu", non_blocking=True
+                )
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_done.record()

diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index fc892ddf0..249bf82bd 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -349,7 +349,13 @@ class CudaGraphRunner:
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=(
-                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+                CaptureHiddenMode.FULL
+                if self.model_runner.server_args.return_hidden_states
+                else (
+                    spec_info.capture_hidden_mode
+                    if spec_info
+                    else CaptureHiddenMode.NULL
+                )
             ),
         )

diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 9f67c2dba..89f2c1b5b 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -160,6 +160,7 @@ class ServerArgs:
     delete_ckpt_after_loading: bool = False
     enable_memory_saver: bool = False
     allow_auto_truncate: bool = False
+    return_hidden_states: bool = False

     # Custom logit processor
     enable_custom_logit_processor: bool = False
@@ -896,6 +897,11 @@ class ServerArgs:
             action="store_true",
             help="Enable users to pass custom logit processors to the server (disabled by default for security)",
         )
+        parser.add_argument(
+            "--return-hidden-states",
+            action="store_true",
+            help="Return hidden states in the response.",
+        )
         # Function Calling
         parser.add_argument(
             "--tool-call-parser",
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 039fde96a..d263bc113 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -46,6 +46,7 @@ suites = {
         "test_torchao.py",
         "test_triton_attention_kernels.py",
        "test_triton_attention_backend.py",
+        "test_hidden_states.py",
         "test_update_weights_from_disk.py",
         "test_update_weights_from_tensor.py",
         "test_vision_chunked_prefill.py",
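The new test below exercises this path end to end. For reference, a minimal offline use of the feature, mirroring the notebook cells added earlier in this patch; the model path and the last-token pooling choice are illustrative, not prescribed, and the tensor layout follows what the test asserts (entries arrive as `torch.Tensor`s, the first covering the whole prompt):

```python
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
    return_hidden_states=True,
)
out = llm.generate(
    ["The capital of France is"],
    sampling_params={"temperature": 0, "max_new_tokens": 4},
)[0]

# One tensor per scheduling step: the first is [prompt_len, hidden_size],
# later ones cover one generated token each.
prompt_states = out["meta_info"]["hidden_states"][0]
print(prompt_states[-1].shape)  # last prompt token, a common pooling choice

llm.shutdown()
```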
diff --git a/test/srt/test_hidden_states.py b/test/srt/test_hidden_states.py
new file mode 100644
index 000000000..5b17ebbf0
--- /dev/null
+++ b/test/srt/test_hidden_states.py
@@ -0,0 +1,77 @@
+import unittest
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+import sglang as sgl
+from sglang.test.test_utils import is_in_ci
+
+
+class TestHiddenState(unittest.TestCase):
+    def test_return_hidden_states(self):
+        prompts = ["Today is", "Today is a sunny day and I like"]
+        model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
+        input_ids = tokenizer(prompts).input_ids
+
+        sampling_params = {"temperature": 0, "max_new_tokens": 8}
+
+        engine = sgl.Engine(
+            model_path=model_path,
+            random_seed=42,
+            return_hidden_states=True,
+            skip_tokenizer_init=True,
+        )
+        outputs = engine.generate(input_ids=input_ids, sampling_params=sampling_params)
+        engine.shutdown()
+
+        for output in outputs:
+            self.assertEqual(len(output["meta_info"]["hidden_states"]), 8)
+            for hidden_state in output["meta_info"]["hidden_states"]:
+                self.assertIsInstance(hidden_state, torch.Tensor)
+        # Checks that splicing of the batch was done correctly
+        self.assertGreater(
+            outputs[1]["meta_info"]["hidden_states"][0].shape[0],
+            outputs[0]["meta_info"]["hidden_states"][0].shape[0],
+        )
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path, torch_dtype=torch.bfloat16, device_map="cuda"
+        )
+
+        for input_id, output in zip(input_ids, outputs):
+            with torch.inference_mode():
+                hf_out = model(
+                    torch.tensor(
+                        [input_id + output["token_ids"][:-1]], device=model.device
+                    ),
+                    output_hidden_states=True,
+                )
+            print("=== HF Hiddens ===")
+            print(hf_out["hidden_states"][-1][0])
+            sg_hidden_states = torch.cat(
+                [
+                    i.unsqueeze(0) if len(i.shape) == 1 else i
+                    for i in output["meta_info"]["hidden_states"]
+                ]
+            ).to("cuda")
+            print("=== SRT Hiddens ===")
+            print(sg_hidden_states)
+
+            print(
+                f"Max diff: {torch.max(torch.abs(hf_out['hidden_states'][-1][0] - sg_hidden_states))}"
+            )
+
+            atol = 0.8 if is_in_ci() else 0.4
+            self.assertTrue(
+                torch.allclose(
+                    hf_out["hidden_states"][-1][0],
+                    sg_hidden_states,
+                    atol=atol,
+                    rtol=0,
+                )
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
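One subtlety the test's `torch.cat` handles: the prefill step contributes a 2-D tensor (`[prompt_len, hidden_size]`) while each decode step contributes a 1-D one (`[hidden_size]`), so stitching a full `[seq_len, hidden_size]` matrix requires unsqueezing the 1-D entries first. A tiny sketch of just that step, with `steps` as a made-up stand-in for `output["meta_info"]["hidden_states"]`:

```python
import torch

hidden_size = 16
# One prefill entry ([prompt_len, hidden]) followed by three decode entries ([hidden]).
steps = [torch.randn(5, hidden_size)] + [torch.randn(hidden_size) for _ in range(3)]

stacked = torch.cat([t.unsqueeze(0) if t.dim() == 1 else t for t in steps])
assert stacked.shape == (5 + 3, hidden_size)
```

The loose `atol` in the comparison (0.8 in CI, 0.4 locally, with `rtol=0`) reflects that bfloat16 activations produced by different kernels are only expected to agree approximately, not bit-for-bit.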