Fix the output of hidden states after HTTP requests (#4269)

This commit is contained in:
Qiaolin Yu
2025-03-13 17:54:06 -04:00
committed by GitHub
parent 5fe79605a8
commit 85d2365d33
5 changed files with 47 additions and 9 deletions

View File

@@ -7,6 +7,8 @@ the cuda graph will be recaptured, which might lead to a performance hit.
So avoid getting hidden states and completions alternately. So avoid getting hidden states and completions alternately.
""" """
import torch
import sglang as sgl import sglang as sgl
@@ -31,11 +33,29 @@ def main():
outputs = llm.generate( outputs = llm.generate(
prompts, sampling_params=sampling_params, return_hidden_states=True prompts, sampling_params=sampling_params, return_hidden_states=True
) )
llm.shutdown()
for prompt, output in zip(prompts, outputs): for prompt, output in zip(prompts, outputs):
for i in range(len(output["meta_info"]["hidden_states"])):
output["meta_info"]["hidden_states"][i] = torch.tensor(
output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
)
print("===============================") print("===============================")
print( print(
f"Prompt: {prompt}\nGenerated text: {output['text']}\nPrompt_Tokens: {output['meta_info']['prompt_tokens']}\tCompletion_tokens: {output['meta_info']['completion_tokens']}\nHidden states: {[i.shape for i in output['meta_info']['hidden_states']]}" f"Prompt: {prompt}\n"
f"Generated text: {output['text']}\n"
f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t"
f"Completion_tokens: {output['meta_info']['completion_tokens']}"
) )
print("Hidden states: ")
hidden_states = torch.cat(
[
i.unsqueeze(0) if len(i.shape) == 1 else i
for i in output["meta_info"]["hidden_states"]
]
)
print(hidden_states)
print() print()

View File

@@ -9,6 +9,7 @@ So avoid getting hidden states and completions alternately.
""" """
import requests import requests
import torch
from sglang.test.test_utils import is_in_ci from sglang.test.test_utils import is_in_ci
from sglang.utils import print_highlight, terminate_process, wait_for_server from sglang.utils import print_highlight, terminate_process, wait_for_server
@@ -50,20 +51,31 @@ def main():
json=json_data, json=json_data,
) )
terminate_process(server_process)
outputs = response.json() outputs = response.json()
for prompt, output in zip(prompts, outputs): for prompt, output in zip(prompts, outputs):
for i in range(len(output["meta_info"]["hidden_states"])):
output["meta_info"]["hidden_states"][i] = torch.tensor(
output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
)
print("===============================") print("===============================")
print( print(
f"Prompt: {prompt}\n" f"Prompt: {prompt}\n"
f"Generated text: {output['text']}\n" f"Generated text: {output['text']}\n"
f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t" f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t"
f"Completion_tokens: {output['meta_info']['completion_tokens']}\n" f"Completion_tokens: {output['meta_info']['completion_tokens']}"
f"Hidden states: {output['meta_info']['hidden_states']}"
) )
print("Hidden states: ")
hidden_states = torch.cat(
[
i.unsqueeze(0) if len(i.shape) == 1 else i
for i in output["meta_info"]["hidden_states"]
]
)
print(hidden_states)
print() print()
terminate_process(server_process)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -361,7 +361,7 @@ class Req:
) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = ( ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
self.output_token_ids_logprobs_idx self.output_token_ids_logprobs_idx
) = None ) = None
self.hidden_states = [] self.hidden_states: List[List[float]] = []
# Embedding (return values) # Embedding (return values)
self.embedding = None self.embedding = None

View File

@@ -111,6 +111,7 @@ class SchedulerOutputProcessorMixin:
] ]
.cpu() .cpu()
.clone() .clone()
.tolist()
) )
if req.grammar is not None: if req.grammar is not None:
@@ -245,7 +246,9 @@ class SchedulerOutputProcessorMixin:
) )
if req.return_hidden_states and logits_output.hidden_states is not None: if req.return_hidden_states and logits_output.hidden_states is not None:
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone()) req.hidden_states.append(
logits_output.hidden_states[i].cpu().clone().tolist()
)
if req.grammar is not None and batch.spec_algorithm.is_none(): if req.grammar is not None and batch.spec_algorithm.is_none():
req.grammar.accept_token(next_token_id) req.grammar.accept_token(next_token_id)

View File

@@ -33,8 +33,11 @@ class TestHiddenState(unittest.TestCase):
for output in outputs: for output in outputs:
self.assertEqual(len(output["meta_info"]["hidden_states"]), 8) self.assertEqual(len(output["meta_info"]["hidden_states"]), 8)
for hidden_state in output["meta_info"]["hidden_states"]: for i in range(len(output["meta_info"]["hidden_states"])):
self.assertIsInstance(hidden_state, torch.Tensor) assert isinstance(output["meta_info"]["hidden_states"][i], list)
output["meta_info"]["hidden_states"][i] = torch.tensor(
output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
)
# Checks that splicing of the batch was done correctly # Checks that splicing of the batch was done correctly
self.assertGreater( self.assertGreater(
outputs[1]["meta_info"]["hidden_states"][0].shape[0], outputs[1]["meta_info"]["hidden_states"][0].shape[0],