Fix the output of hidden states after HTTP requests (#4269)
This commit is contained in:
@@ -7,6 +7,8 @@ the cuda graph will be recaptured, which might lead to a performance hit.
|
||||
So avoid getting hidden states and completions alternately.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
@@ -31,11 +33,29 @@ def main():
|
||||
outputs = llm.generate(
|
||||
prompts, sampling_params=sampling_params, return_hidden_states=True
|
||||
)
|
||||
|
||||
llm.shutdown()
|
||||
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
for i in range(len(output["meta_info"]["hidden_states"])):
|
||||
output["meta_info"]["hidden_states"][i] = torch.tensor(
|
||||
output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
|
||||
)
|
||||
print("===============================")
|
||||
print(
|
||||
f"Prompt: {prompt}\nGenerated text: {output['text']}\nPrompt_Tokens: {output['meta_info']['prompt_tokens']}\tCompletion_tokens: {output['meta_info']['completion_tokens']}\nHidden states: {[i.shape for i in output['meta_info']['hidden_states']]}"
|
||||
f"Prompt: {prompt}\n"
|
||||
f"Generated text: {output['text']}\n"
|
||||
f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t"
|
||||
f"Completion_tokens: {output['meta_info']['completion_tokens']}"
|
||||
)
|
||||
print("Hidden states: ")
|
||||
hidden_states = torch.cat(
|
||||
[
|
||||
i.unsqueeze(0) if len(i.shape) == 1 else i
|
||||
for i in output["meta_info"]["hidden_states"]
|
||||
]
|
||||
)
|
||||
print(hidden_states)
|
||||
print()
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ So avoid getting hidden states and completions alternately.
|
||||
"""
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from sglang.test.test_utils import is_in_ci
|
||||
from sglang.utils import print_highlight, terminate_process, wait_for_server
|
||||
@@ -50,20 +51,31 @@ def main():
|
||||
json=json_data,
|
||||
)
|
||||
|
||||
terminate_process(server_process)
|
||||
|
||||
outputs = response.json()
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
for i in range(len(output["meta_info"]["hidden_states"])):
|
||||
output["meta_info"]["hidden_states"][i] = torch.tensor(
|
||||
output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
|
||||
)
|
||||
print("===============================")
|
||||
print(
|
||||
f"Prompt: {prompt}\n"
|
||||
f"Generated text: {output['text']}\n"
|
||||
f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t"
|
||||
f"Completion_tokens: {output['meta_info']['completion_tokens']}\n"
|
||||
f"Hidden states: {output['meta_info']['hidden_states']}"
|
||||
f"Completion_tokens: {output['meta_info']['completion_tokens']}"
|
||||
)
|
||||
print("Hidden states: ")
|
||||
hidden_states = torch.cat(
|
||||
[
|
||||
i.unsqueeze(0) if len(i.shape) == 1 else i
|
||||
for i in output["meta_info"]["hidden_states"]
|
||||
]
|
||||
)
|
||||
print(hidden_states)
|
||||
print()
|
||||
|
||||
terminate_process(server_process)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user