Fix the output of hidden states after HTTP requests (#4269)
This commit is contained in:
@@ -7,6 +7,8 @@ the cuda graph will be recaptured, which might lead to a performance hit.
|
|||||||
So avoid getting hidden states and completions alternately.
|
So avoid getting hidden states and completions alternately.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
|
|
||||||
|
|
||||||
@@ -31,11 +33,29 @@ def main():
|
|||||||
outputs = llm.generate(
|
outputs = llm.generate(
|
||||||
prompts, sampling_params=sampling_params, return_hidden_states=True
|
prompts, sampling_params=sampling_params, return_hidden_states=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
llm.shutdown()
|
||||||
|
|
||||||
for prompt, output in zip(prompts, outputs):
|
for prompt, output in zip(prompts, outputs):
|
||||||
|
for i in range(len(output["meta_info"]["hidden_states"])):
|
||||||
|
output["meta_info"]["hidden_states"][i] = torch.tensor(
|
||||||
|
output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
|
||||||
|
)
|
||||||
print("===============================")
|
print("===============================")
|
||||||
print(
|
print(
|
||||||
f"Prompt: {prompt}\nGenerated text: {output['text']}\nPrompt_Tokens: {output['meta_info']['prompt_tokens']}\tCompletion_tokens: {output['meta_info']['completion_tokens']}\nHidden states: {[i.shape for i in output['meta_info']['hidden_states']]}"
|
f"Prompt: {prompt}\n"
|
||||||
|
f"Generated text: {output['text']}\n"
|
||||||
|
f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t"
|
||||||
|
f"Completion_tokens: {output['meta_info']['completion_tokens']}"
|
||||||
)
|
)
|
||||||
|
print("Hidden states: ")
|
||||||
|
hidden_states = torch.cat(
|
||||||
|
[
|
||||||
|
i.unsqueeze(0) if len(i.shape) == 1 else i
|
||||||
|
for i in output["meta_info"]["hidden_states"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
print(hidden_states)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ So avoid getting hidden states and completions alternately.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
import torch
|
||||||
|
|
||||||
from sglang.test.test_utils import is_in_ci
|
from sglang.test.test_utils import is_in_ci
|
||||||
from sglang.utils import print_highlight, terminate_process, wait_for_server
|
from sglang.utils import print_highlight, terminate_process, wait_for_server
|
||||||
@@ -50,20 +51,31 @@ def main():
|
|||||||
json=json_data,
|
json=json_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
terminate_process(server_process)
|
||||||
|
|
||||||
outputs = response.json()
|
outputs = response.json()
|
||||||
for prompt, output in zip(prompts, outputs):
|
for prompt, output in zip(prompts, outputs):
|
||||||
|
for i in range(len(output["meta_info"]["hidden_states"])):
|
||||||
|
output["meta_info"]["hidden_states"][i] = torch.tensor(
|
||||||
|
output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
|
||||||
|
)
|
||||||
print("===============================")
|
print("===============================")
|
||||||
print(
|
print(
|
||||||
f"Prompt: {prompt}\n"
|
f"Prompt: {prompt}\n"
|
||||||
f"Generated text: {output['text']}\n"
|
f"Generated text: {output['text']}\n"
|
||||||
f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t"
|
f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t"
|
||||||
f"Completion_tokens: {output['meta_info']['completion_tokens']}\n"
|
f"Completion_tokens: {output['meta_info']['completion_tokens']}"
|
||||||
f"Hidden states: {output['meta_info']['hidden_states']}"
|
|
||||||
)
|
)
|
||||||
|
print("Hidden states: ")
|
||||||
|
hidden_states = torch.cat(
|
||||||
|
[
|
||||||
|
i.unsqueeze(0) if len(i.shape) == 1 else i
|
||||||
|
for i in output["meta_info"]["hidden_states"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
print(hidden_states)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
terminate_process(server_process)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -361,7 +361,7 @@ class Req:
|
|||||||
) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
|
) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
|
||||||
self.output_token_ids_logprobs_idx
|
self.output_token_ids_logprobs_idx
|
||||||
) = None
|
) = None
|
||||||
self.hidden_states = []
|
self.hidden_states: List[List[float]] = []
|
||||||
|
|
||||||
# Embedding (return values)
|
# Embedding (return values)
|
||||||
self.embedding = None
|
self.embedding = None
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ class SchedulerOutputProcessorMixin:
|
|||||||
]
|
]
|
||||||
.cpu()
|
.cpu()
|
||||||
.clone()
|
.clone()
|
||||||
|
.tolist()
|
||||||
)
|
)
|
||||||
|
|
||||||
if req.grammar is not None:
|
if req.grammar is not None:
|
||||||
@@ -245,7 +246,9 @@ class SchedulerOutputProcessorMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if req.return_hidden_states and logits_output.hidden_states is not None:
|
if req.return_hidden_states and logits_output.hidden_states is not None:
|
||||||
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
|
req.hidden_states.append(
|
||||||
|
logits_output.hidden_states[i].cpu().clone().tolist()
|
||||||
|
)
|
||||||
|
|
||||||
if req.grammar is not None and batch.spec_algorithm.is_none():
|
if req.grammar is not None and batch.spec_algorithm.is_none():
|
||||||
req.grammar.accept_token(next_token_id)
|
req.grammar.accept_token(next_token_id)
|
||||||
|
|||||||
@@ -33,8 +33,11 @@ class TestHiddenState(unittest.TestCase):
|
|||||||
|
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
self.assertEqual(len(output["meta_info"]["hidden_states"]), 8)
|
self.assertEqual(len(output["meta_info"]["hidden_states"]), 8)
|
||||||
for hidden_state in output["meta_info"]["hidden_states"]:
|
for i in range(len(output["meta_info"]["hidden_states"])):
|
||||||
self.assertIsInstance(hidden_state, torch.Tensor)
|
assert isinstance(output["meta_info"]["hidden_states"][i], list)
|
||||||
|
output["meta_info"]["hidden_states"][i] = torch.tensor(
|
||||||
|
output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
|
||||||
|
)
|
||||||
# Checks that splicing of the batch was done correctly
|
# Checks that splicing of the batch was done correctly
|
||||||
self.assertGreater(
|
self.assertGreater(
|
||||||
outputs[1]["meta_info"]["hidden_states"][0].shape[0],
|
outputs[1]["meta_info"]["hidden_states"][0].shape[0],
|
||||||
|
|||||||
Reference in New Issue
Block a user