From 85d2365d337ca81eb353645bca15a199cc348847 Mon Sep 17 00:00:00 2001 From: Qiaolin Yu Date: Thu, 13 Mar 2025 17:54:06 -0400 Subject: [PATCH] Fix the output of hidden states after HTTP requests (#4269) --- .../hidden_states/hidden_states_engine.py | 22 ++++++++++++++++++- .../hidden_states/hidden_states_server.py | 20 +++++++++++++---- python/sglang/srt/managers/schedule_batch.py | 2 +- .../scheduler_output_processor_mixin.py | 5 ++++- test/srt/test_hidden_states.py | 7 ++++-- 5 files changed, 47 insertions(+), 9 deletions(-) diff --git a/examples/runtime/hidden_states/hidden_states_engine.py b/examples/runtime/hidden_states/hidden_states_engine.py index 8c6747a91..8af883ab1 100644 --- a/examples/runtime/hidden_states/hidden_states_engine.py +++ b/examples/runtime/hidden_states/hidden_states_engine.py @@ -7,6 +7,8 @@ the cuda graph will be recaptured, which might lead to a performance hit. So avoid getting hidden states and completions alternately. """ +import torch + import sglang as sgl @@ -31,11 +33,29 @@ def main(): outputs = llm.generate( prompts, sampling_params=sampling_params, return_hidden_states=True ) + + llm.shutdown() + for prompt, output in zip(prompts, outputs): + for i in range(len(output["meta_info"]["hidden_states"])): + output["meta_info"]["hidden_states"][i] = torch.tensor( + output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16 + ) print("===============================") print( - f"Prompt: {prompt}\nGenerated text: {output['text']}\nPrompt_Tokens: {output['meta_info']['prompt_tokens']}\tCompletion_tokens: {output['meta_info']['completion_tokens']}\nHidden states: {[i.shape for i in output['meta_info']['hidden_states']]}" + f"Prompt: {prompt}\n" + f"Generated text: {output['text']}\n" + f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t" + f"Completion_tokens: {output['meta_info']['completion_tokens']}" ) + print("Hidden states: ") + hidden_states = torch.cat( + [ + i.unsqueeze(0) if len(i.shape) == 1 else i + for i in output["meta_info"]["hidden_states"] + ] + ) + print(hidden_states) print() diff --git a/examples/runtime/hidden_states/hidden_states_server.py b/examples/runtime/hidden_states/hidden_states_server.py index a198c7c23..39b4e464e 100644 --- a/examples/runtime/hidden_states/hidden_states_server.py +++ b/examples/runtime/hidden_states/hidden_states_server.py @@ -9,6 +9,7 @@ So avoid getting hidden states and completions alternately. """ import requests +import torch from sglang.test.test_utils import is_in_ci from sglang.utils import print_highlight, terminate_process, wait_for_server @@ -50,20 +51,31 @@ def main(): json=json_data, ) + terminate_process(server_process) + outputs = response.json() for prompt, output in zip(prompts, outputs): + for i in range(len(output["meta_info"]["hidden_states"])): + output["meta_info"]["hidden_states"][i] = torch.tensor( + output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16 + ) print("===============================") print( f"Prompt: {prompt}\n" f"Generated text: {output['text']}\n" f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t" - f"Completion_tokens: {output['meta_info']['completion_tokens']}\n" - f"Hidden states: {output['meta_info']['hidden_states']}" + f"Completion_tokens: {output['meta_info']['completion_tokens']}" ) + print("Hidden states: ") + hidden_states = torch.cat( + [ + i.unsqueeze(0) if len(i.shape) == 1 else i + for i in output["meta_info"]["hidden_states"] + ] + ) + print(hidden_states) print() - terminate_process(server_process) - if __name__ == "__main__": main() diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 1f00bd646..27f87d8a2 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -361,7 +361,7 @@ class Req: ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = ( self.output_token_ids_logprobs_idx ) = None - self.hidden_states = [] + self.hidden_states: List[List[float]] = [] # Embedding (return values) self.embedding = None diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index e83dc0646..13158d937 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -111,6 +111,7 @@ class SchedulerOutputProcessorMixin: ] .cpu() .clone() + .tolist() ) if req.grammar is not None: @@ -245,7 +246,9 @@ class SchedulerOutputProcessorMixin: ) if req.return_hidden_states and logits_output.hidden_states is not None: - req.hidden_states.append(logits_output.hidden_states[i].cpu().clone()) + req.hidden_states.append( + logits_output.hidden_states[i].cpu().clone().tolist() + ) if req.grammar is not None and batch.spec_algorithm.is_none(): req.grammar.accept_token(next_token_id) diff --git a/test/srt/test_hidden_states.py b/test/srt/test_hidden_states.py index 4c28b3139..87676c0ad 100644 --- a/test/srt/test_hidden_states.py +++ b/test/srt/test_hidden_states.py @@ -33,8 +33,11 @@ class TestHiddenState(unittest.TestCase): for output in outputs: self.assertEqual(len(output["meta_info"]["hidden_states"]), 8) - for hidden_state in output["meta_info"]["hidden_states"]: - self.assertIsInstance(hidden_state, torch.Tensor) + for i in range(len(output["meta_info"]["hidden_states"])): + assert isinstance(output["meta_info"]["hidden_states"][i], list) + output["meta_info"]["hidden_states"][i] = torch.tensor( + output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16 + ) # Checks that splicing of the batch was done correctly self.assertGreater( outputs[1]["meta_info"]["hidden_states"][0].shape[0],