Fix the output of hidden states after HTTP requests (#4269)

This commit is contained in:
Qiaolin Yu
2025-03-13 17:54:06 -04:00
committed by GitHub
parent 5fe79605a8
commit 85d2365d33
5 changed files with 47 additions and 9 deletions

View File

@@ -7,6 +7,8 @@ the cuda graph will be recaptured, which might lead to a performance hit.
So avoid getting hidden states and completions alternately.
"""
import torch
import sglang as sgl
@@ -31,11 +33,29 @@ def main():
outputs = llm.generate(
prompts, sampling_params=sampling_params, return_hidden_states=True
)
llm.shutdown()
for prompt, output in zip(prompts, outputs):
for i in range(len(output["meta_info"]["hidden_states"])):
output["meta_info"]["hidden_states"][i] = torch.tensor(
output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
)
print("===============================")
print(
f"Prompt: {prompt}\nGenerated text: {output['text']}\nPrompt_Tokens: {output['meta_info']['prompt_tokens']}\tCompletion_tokens: {output['meta_info']['completion_tokens']}\nHidden states: {[i.shape for i in output['meta_info']['hidden_states']]}"
f"Prompt: {prompt}\n"
f"Generated text: {output['text']}\n"
f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t"
f"Completion_tokens: {output['meta_info']['completion_tokens']}"
)
print("Hidden states: ")
hidden_states = torch.cat(
[
i.unsqueeze(0) if len(i.shape) == 1 else i
for i in output["meta_info"]["hidden_states"]
]
)
print(hidden_states)
print()

View File

@@ -9,6 +9,7 @@ So avoid getting hidden states and completions alternately.
"""
import requests
import torch
from sglang.test.test_utils import is_in_ci
from sglang.utils import print_highlight, terminate_process, wait_for_server
@@ -50,20 +51,31 @@ def main():
json=json_data,
)
terminate_process(server_process)
outputs = response.json()
for prompt, output in zip(prompts, outputs):
for i in range(len(output["meta_info"]["hidden_states"])):
output["meta_info"]["hidden_states"][i] = torch.tensor(
output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
)
print("===============================")
print(
f"Prompt: {prompt}\n"
f"Generated text: {output['text']}\n"
f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t"
f"Completion_tokens: {output['meta_info']['completion_tokens']}\n"
f"Hidden states: {output['meta_info']['hidden_states']}"
f"Completion_tokens: {output['meta_info']['completion_tokens']}"
)
print("Hidden states: ")
hidden_states = torch.cat(
[
i.unsqueeze(0) if len(i.shape) == 1 else i
for i in output["meta_info"]["hidden_states"]
]
)
print(hidden_states)
print()
terminate_process(server_process)
if __name__ == "__main__":
main()