[Feat] Return hidden states (experimental) (#3364)
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
2
.github/workflows/pr-test.yml
vendored
2
.github/workflows/pr-test.yml
vendored
@@ -103,7 +103,7 @@ jobs:
|
|||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
|
|
||||||
- name: Run test
|
- name: Run test
|
||||||
timeout-minutes: 25
|
timeout-minutes: 30
|
||||||
run: |
|
run: |
|
||||||
RANGE=${{ matrix.range }}
|
RANGE=${{ matrix.range }}
|
||||||
range_begin=${RANGE%-*}
|
range_begin=${RANGE%-*}
|
||||||
|
|||||||
@@ -179,6 +179,59 @@
|
|||||||
"asyncio.run(main())"
|
"asyncio.run(main())"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"llm.shutdown()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Return Hidden States"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import sglang as sgl\n",
|
||||||
|
"\n",
|
||||||
|
"llm = sgl.Engine(\n",
|
||||||
|
" model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\", return_hidden_states=True\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"prompts = [\n",
|
||||||
|
" \"Hello, my name is\",\n",
|
||||||
|
" \"The president of the United States is\",\n",
|
||||||
|
" \"The capital of France is\",\n",
|
||||||
|
" \"The future of AI is\",\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95, \"max_new_tokens\": 10}\n",
|
||||||
|
"\n",
|
||||||
|
"outputs = llm.generate(prompts, sampling_params=sampling_params)\n",
|
||||||
|
"for prompt, output in zip(prompts, outputs):\n",
|
||||||
|
" print(\"===============================\")\n",
|
||||||
|
" print(\n",
|
||||||
|
" f\"Prompt: {prompt}\\nGenerated text: {output['text']}\\nPrompt_Tokens: {output['meta_info']['prompt_tokens']}\\tCompletion_tokens: {output['meta_info']['completion_tokens']}\\nHidden states: {[i.shape for i in output['meta_info']['hidden_states']]}\"\n",
|
||||||
|
" )\n",
|
||||||
|
" print()"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
|||||||
@@ -210,6 +210,7 @@ class DetokenizerManager:
|
|||||||
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
|
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
|
||||||
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
|
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
|
||||||
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
|
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
|
||||||
|
output_hidden_states=recv_obj.output_hidden_states,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -371,6 +371,8 @@ class BatchTokenIDOut:
|
|||||||
output_top_logprobs_val: List[List]
|
output_top_logprobs_val: List[List]
|
||||||
output_top_logprobs_idx: List[List]
|
output_top_logprobs_idx: List[List]
|
||||||
|
|
||||||
|
output_hidden_states: List[List[float]]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchStrOut:
|
class BatchStrOut:
|
||||||
@@ -397,6 +399,8 @@ class BatchStrOut:
|
|||||||
output_top_logprobs_val: List[List]
|
output_top_logprobs_val: List[List]
|
||||||
output_top_logprobs_idx: List[List]
|
output_top_logprobs_idx: List[List]
|
||||||
|
|
||||||
|
output_hidden_states: List[List[float]]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchEmbeddingOut:
|
class BatchEmbeddingOut:
|
||||||
|
|||||||
@@ -315,6 +315,7 @@ class Req:
|
|||||||
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
|
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
|
||||||
self.output_top_logprobs_val
|
self.output_top_logprobs_val
|
||||||
) = self.output_top_logprobs_idx = None
|
) = self.output_top_logprobs_idx = None
|
||||||
|
self.hidden_states = []
|
||||||
|
|
||||||
# Logprobs (internal values)
|
# Logprobs (internal values)
|
||||||
# The tokens is prefilled but need to be considered as decode tokens
|
# The tokens is prefilled but need to be considered as decode tokens
|
||||||
@@ -604,6 +605,9 @@ class ScheduleBatch:
|
|||||||
# Enable custom logit processor
|
# Enable custom logit processor
|
||||||
enable_custom_logit_processor: bool = False
|
enable_custom_logit_processor: bool = False
|
||||||
|
|
||||||
|
# Return hidden states
|
||||||
|
return_hidden_states: bool = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_new(
|
def init_new(
|
||||||
cls,
|
cls,
|
||||||
@@ -615,6 +619,7 @@ class ScheduleBatch:
|
|||||||
enable_overlap: bool,
|
enable_overlap: bool,
|
||||||
spec_algorithm: SpeculativeAlgorithm,
|
spec_algorithm: SpeculativeAlgorithm,
|
||||||
enable_custom_logit_processor: bool,
|
enable_custom_logit_processor: bool,
|
||||||
|
return_hidden_states: bool = False,
|
||||||
):
|
):
|
||||||
return cls(
|
return cls(
|
||||||
reqs=reqs,
|
reqs=reqs,
|
||||||
@@ -629,6 +634,7 @@ class ScheduleBatch:
|
|||||||
device=req_to_token_pool.device,
|
device=req_to_token_pool.device,
|
||||||
spec_algorithm=spec_algorithm,
|
spec_algorithm=spec_algorithm,
|
||||||
enable_custom_logit_processor=enable_custom_logit_processor,
|
enable_custom_logit_processor=enable_custom_logit_processor,
|
||||||
|
return_hidden_states=return_hidden_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
@@ -1196,9 +1202,15 @@ class ScheduleBatch:
|
|||||||
spec_algorithm=self.spec_algorithm,
|
spec_algorithm=self.spec_algorithm,
|
||||||
spec_info=self.spec_info,
|
spec_info=self.spec_info,
|
||||||
capture_hidden_mode=(
|
capture_hidden_mode=(
|
||||||
getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
|
CaptureHiddenMode.FULL
|
||||||
if self.spec_info
|
if self.return_hidden_states
|
||||||
else CaptureHiddenMode.NULL
|
else (
|
||||||
|
getattr(
|
||||||
|
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
||||||
|
)
|
||||||
|
if self.spec_info
|
||||||
|
else CaptureHiddenMode.NULL
|
||||||
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -997,6 +997,7 @@ class Scheduler:
|
|||||||
self.enable_overlap,
|
self.enable_overlap,
|
||||||
self.spec_algorithm,
|
self.spec_algorithm,
|
||||||
self.server_args.enable_custom_logit_processor,
|
self.server_args.enable_custom_logit_processor,
|
||||||
|
self.server_args.return_hidden_states,
|
||||||
)
|
)
|
||||||
new_batch.prepare_for_extend()
|
new_batch.prepare_for_extend()
|
||||||
|
|
||||||
@@ -1156,6 +1157,8 @@ class Scheduler:
|
|||||||
logits_output.input_token_logprobs.tolist()
|
logits_output.input_token_logprobs.tolist()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
hidden_state_offset = 0
|
||||||
|
|
||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
logprob_pt = 0
|
logprob_pt = 0
|
||||||
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||||
@@ -1182,6 +1185,21 @@ class Scheduler:
|
|||||||
i, req, logprob_pt, next_token_ids, logits_output
|
i, req, logprob_pt, next_token_ids, logits_output
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.server_args.return_hidden_states
|
||||||
|
and logits_output.hidden_states is not None
|
||||||
|
):
|
||||||
|
req.hidden_states.append(
|
||||||
|
logits_output.hidden_states[
|
||||||
|
hidden_state_offset : (
|
||||||
|
hidden_state_offset := hidden_state_offset
|
||||||
|
+ len(req.origin_input_ids)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
.cpu()
|
||||||
|
.clone()
|
||||||
|
)
|
||||||
|
|
||||||
if req.grammar is not None:
|
if req.grammar is not None:
|
||||||
req.grammar.accept_token(next_token_id)
|
req.grammar.accept_token(next_token_id)
|
||||||
req.grammar.finished = req.finished()
|
req.grammar.finished = req.finished()
|
||||||
@@ -1275,6 +1293,12 @@ class Scheduler:
|
|||||||
logits_output.next_token_top_logprobs_idx[i]
|
logits_output.next_token_top_logprobs_idx[i]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.server_args.return_hidden_states
|
||||||
|
and logits_output.hidden_states is not None
|
||||||
|
):
|
||||||
|
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
|
||||||
|
|
||||||
if req.grammar is not None:
|
if req.grammar is not None:
|
||||||
req.grammar.accept_token(next_token_id)
|
req.grammar.accept_token(next_token_id)
|
||||||
req.grammar.finished = req.finished()
|
req.grammar.finished = req.finished()
|
||||||
@@ -1398,6 +1422,7 @@ class Scheduler:
|
|||||||
completion_tokens = []
|
completion_tokens = []
|
||||||
cached_tokens = []
|
cached_tokens = []
|
||||||
spec_verify_ct = []
|
spec_verify_ct = []
|
||||||
|
hidden_states = []
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
input_token_logprobs_val = []
|
input_token_logprobs_val = []
|
||||||
@@ -1464,6 +1489,8 @@ class Scheduler:
|
|||||||
output_top_logprobs_val.append(req.output_top_logprobs_val)
|
output_top_logprobs_val.append(req.output_top_logprobs_val)
|
||||||
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
|
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
|
||||||
|
|
||||||
|
hidden_states.append(req.hidden_states)
|
||||||
|
|
||||||
# Send to detokenizer
|
# Send to detokenizer
|
||||||
if rids:
|
if rids:
|
||||||
self.send_to_detokenizer.send_pyobj(
|
self.send_to_detokenizer.send_pyobj(
|
||||||
@@ -1490,6 +1517,7 @@ class Scheduler:
|
|||||||
input_top_logprobs_idx,
|
input_top_logprobs_idx,
|
||||||
output_top_logprobs_val,
|
output_top_logprobs_val,
|
||||||
output_top_logprobs_idx,
|
output_top_logprobs_idx,
|
||||||
|
hidden_states,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
@@ -1553,6 +1581,7 @@ class Scheduler:
|
|||||||
self.enable_overlap,
|
self.enable_overlap,
|
||||||
self.spec_algorithm,
|
self.spec_algorithm,
|
||||||
self.server_args.enable_custom_logit_processor,
|
self.server_args.enable_custom_logit_processor,
|
||||||
|
self.server_args.return_hidden_states,
|
||||||
)
|
)
|
||||||
idle_batch.prepare_for_idle()
|
idle_batch.prepare_for_idle()
|
||||||
return idle_batch
|
return idle_batch
|
||||||
|
|||||||
@@ -796,6 +796,12 @@ class TokenizerManager:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
hasattr(recv_obj, "output_hidden_states")
|
||||||
|
and len(recv_obj.output_hidden_states[i]) > 0
|
||||||
|
):
|
||||||
|
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
|
||||||
|
|
||||||
if isinstance(recv_obj, BatchStrOut):
|
if isinstance(recv_obj, BatchStrOut):
|
||||||
out_dict = {
|
out_dict = {
|
||||||
"text": recv_obj.output_strs[i],
|
"text": recv_obj.output_strs[i],
|
||||||
|
|||||||
@@ -156,6 +156,10 @@ class TpModelWorkerClient:
|
|||||||
logits_output.input_token_logprobs = (
|
logits_output.input_token_logprobs = (
|
||||||
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
|
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
|
||||||
)
|
)
|
||||||
|
if logits_output.hidden_states is not None:
|
||||||
|
logits_output.hidden_states = logits_output.hidden_states.to(
|
||||||
|
"cpu", non_blocking=True
|
||||||
|
)
|
||||||
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
||||||
copy_done.record()
|
copy_done.record()
|
||||||
|
|
||||||
|
|||||||
@@ -349,7 +349,13 @@ class CudaGraphRunner:
|
|||||||
spec_algorithm=self.model_runner.spec_algorithm,
|
spec_algorithm=self.model_runner.spec_algorithm,
|
||||||
spec_info=spec_info,
|
spec_info=spec_info,
|
||||||
capture_hidden_mode=(
|
capture_hidden_mode=(
|
||||||
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
|
CaptureHiddenMode.FULL
|
||||||
|
if self.model_runner.server_args.return_hidden_states
|
||||||
|
else (
|
||||||
|
spec_info.capture_hidden_mode
|
||||||
|
if spec_info
|
||||||
|
else CaptureHiddenMode.NULL
|
||||||
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -160,6 +160,7 @@ class ServerArgs:
|
|||||||
delete_ckpt_after_loading: bool = False
|
delete_ckpt_after_loading: bool = False
|
||||||
enable_memory_saver: bool = False
|
enable_memory_saver: bool = False
|
||||||
allow_auto_truncate: bool = False
|
allow_auto_truncate: bool = False
|
||||||
|
return_hidden_states: bool = False
|
||||||
|
|
||||||
# Custom logit processor
|
# Custom logit processor
|
||||||
enable_custom_logit_processor: bool = False
|
enable_custom_logit_processor: bool = False
|
||||||
@@ -896,6 +897,11 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable users to pass custom logit processors to the server (disabled by default for security)",
|
help="Enable users to pass custom logit processors to the server (disabled by default for security)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--return-hidden-states",
|
||||||
|
action="store_true",
|
||||||
|
help="Return hidden states in the response.",
|
||||||
|
)
|
||||||
# Function Calling
|
# Function Calling
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tool-call-parser",
|
"--tool-call-parser",
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ suites = {
|
|||||||
"test_torchao.py",
|
"test_torchao.py",
|
||||||
"test_triton_attention_kernels.py",
|
"test_triton_attention_kernels.py",
|
||||||
"test_triton_attention_backend.py",
|
"test_triton_attention_backend.py",
|
||||||
|
"test_hidden_states.py",
|
||||||
"test_update_weights_from_disk.py",
|
"test_update_weights_from_disk.py",
|
||||||
"test_update_weights_from_tensor.py",
|
"test_update_weights_from_tensor.py",
|
||||||
"test_vision_chunked_prefill.py",
|
"test_vision_chunked_prefill.py",
|
||||||
|
|||||||
77
test/srt/test_hidden_states.py
Normal file
77
test/srt/test_hidden_states.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.test.test_utils import is_in_ci
|
||||||
|
|
||||||
|
|
||||||
|
class TestHiddenState(unittest.TestCase):
    """End-to-end check of the engine's ``return_hidden_states`` option."""

    def test_return_hidden_states(self):
        """Compare hidden states returned by the engine with a Hugging Face run.

        The test generates a few tokens for two prompts of different lengths
        with ``return_hidden_states=True``, checks the structure of the
        returned ``meta_info["hidden_states"]`` lists, then replays the same
        token sequences through the Hugging Face model and requires the
        last-layer hidden states to agree within an absolute tolerance.
        """
        prompts = ["Today is", "Today is a sunny day and I like"]
        model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
        # Single source of truth for both the sampling params and the
        # expected number of hidden-state entries asserted below.
        max_new_tokens = 8

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        input_ids = tokenizer(prompts).input_ids

        sampling_params = {"temperature": 0, "max_new_tokens": max_new_tokens}

        engine = sgl.Engine(
            model_path=model_path,
            random_seed=42,
            return_hidden_states=True,
            skip_tokenizer_init=True,
        )
        try:
            outputs = engine.generate(
                input_ids=input_ids, sampling_params=sampling_params
            )
        finally:
            # Always shut the engine down, even when generation raises, so a
            # failing test does not leave the engine running.
            engine.shutdown()

        for output in outputs:
            hidden_states = output["meta_info"]["hidden_states"]
            self.assertEqual(len(hidden_states), max_new_tokens)
            for hidden_state in hidden_states:
                self.assertIsInstance(hidden_state, torch.Tensor)

        # Checks that splicing of the batch was done correctly: the first
        # entry of the longer prompt must span more rows than the first entry
        # of the shorter prompt.
        self.assertGreater(
            outputs[1]["meta_info"]["hidden_states"][0].shape[0],
            outputs[0]["meta_info"]["hidden_states"][0].shape[0],
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda"
        )

        # Looser tolerance in CI than on a local run.
        atol = 0.8 if is_in_ci() else 0.4

        for input_id, output in zip(input_ids, outputs):
            with torch.inference_mode():
                # Feed the prompt plus every generated token except the last
                # one, so the sequence length matches the concatenated engine
                # hidden states compared below.
                hf_out = model(
                    torch.tensor(
                        [input_id + output["token_ids"][:-1]], device=model.device
                    ),
                    output_hidden_states=True,
                )

            # Last-layer hidden states of the single sequence in the batch.
            hf_hidden_states = hf_out["hidden_states"][-1][0]
            print("=== HF Hiddens ===")
            print(hf_hidden_states)

            # Entries may be 1-D (a single position) or 2-D (several
            # positions); promote 1-D entries to one row before concatenating.
            sg_hidden_states = torch.cat(
                [
                    i.unsqueeze(0) if len(i.shape) == 1 else i
                    for i in output["meta_info"]["hidden_states"]
                ]
            ).to(hf_hidden_states.device)
            print("=== SRT Hiddens ===")
            print(sg_hidden_states)

            max_diff = torch.max(torch.abs(hf_hidden_states - sg_hidden_states))
            print(f"Max diff: {max_diff}")

            self.assertTrue(
                torch.allclose(
                    hf_hidden_states,
                    sg_hidden_states,
                    atol=atol,
                    rtol=0,
                ),
                msg=f"hidden states differ: max diff {max_diff} exceeds atol {atol}",
            )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user