[Feat] Return hidden states (experimental) (#3364)

Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
Jackmin801
2025-02-10 15:54:37 -08:00
committed by GitHub
parent 2f47d710ae
commit 5f0e7de339
12 changed files with 204 additions and 5 deletions

View File

@@ -210,6 +210,7 @@ class DetokenizerManager:
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
output_hidden_states=recv_obj.output_hidden_states,
)
)

View File

@@ -371,6 +371,8 @@ class BatchTokenIDOut:
output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List]
output_hidden_states: List[List[float]]
@dataclass
class BatchStrOut:
@@ -397,6 +399,8 @@ class BatchStrOut:
output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List]
output_hidden_states: List[List[float]]
@dataclass
class BatchEmbeddingOut:

View File

@@ -315,6 +315,7 @@ class Req:
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
self.output_top_logprobs_val
) = self.output_top_logprobs_idx = None
self.hidden_states = []
# Logprobs (internal values)
# The tokens are prefilled but need to be considered as decode tokens
@@ -604,6 +605,9 @@ class ScheduleBatch:
# Enable custom logit processor
enable_custom_logit_processor: bool = False
# Return hidden states
return_hidden_states: bool = False
@classmethod
def init_new(
cls,
@@ -615,6 +619,7 @@ class ScheduleBatch:
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
return_hidden_states: bool = False,
):
return cls(
reqs=reqs,
@@ -629,6 +634,7 @@ class ScheduleBatch:
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
return_hidden_states=return_hidden_states,
)
def batch_size(self):
@@ -1196,9 +1202,15 @@ class ScheduleBatch:
spec_algorithm=self.spec_algorithm,
spec_info=self.spec_info,
capture_hidden_mode=(
getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
if self.spec_info
else CaptureHiddenMode.NULL
CaptureHiddenMode.FULL
if self.return_hidden_states
else (
getattr(
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
)
if self.spec_info
else CaptureHiddenMode.NULL
)
),
)

View File

@@ -997,6 +997,7 @@ class Scheduler:
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
self.server_args.return_hidden_states,
)
new_batch.prepare_for_extend()
@@ -1156,6 +1157,8 @@ class Scheduler:
logits_output.input_token_logprobs.tolist()
)
hidden_state_offset = 0
# Check finish conditions
logprob_pt = 0
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
@@ -1182,6 +1185,21 @@ class Scheduler:
i, req, logprob_pt, next_token_ids, logits_output
)
if (
self.server_args.return_hidden_states
and logits_output.hidden_states is not None
):
req.hidden_states.append(
logits_output.hidden_states[
hidden_state_offset : (
hidden_state_offset := hidden_state_offset
+ len(req.origin_input_ids)
)
]
.cpu()
.clone()
)
if req.grammar is not None:
req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished()
@@ -1275,6 +1293,12 @@ class Scheduler:
logits_output.next_token_top_logprobs_idx[i]
)
if (
self.server_args.return_hidden_states
and logits_output.hidden_states is not None
):
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
if req.grammar is not None:
req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished()
@@ -1398,6 +1422,7 @@ class Scheduler:
completion_tokens = []
cached_tokens = []
spec_verify_ct = []
hidden_states = []
if return_logprob:
input_token_logprobs_val = []
@@ -1464,6 +1489,8 @@ class Scheduler:
output_top_logprobs_val.append(req.output_top_logprobs_val)
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
hidden_states.append(req.hidden_states)
# Send to detokenizer
if rids:
self.send_to_detokenizer.send_pyobj(
@@ -1490,6 +1517,7 @@ class Scheduler:
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
hidden_states,
)
)
else: # embedding or reward model
@@ -1553,6 +1581,7 @@ class Scheduler:
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
self.server_args.return_hidden_states,
)
idle_batch.prepare_for_idle()
return idle_batch

View File

@@ -796,6 +796,12 @@ class TokenizerManager:
}
)
if (
hasattr(recv_obj, "output_hidden_states")
and len(recv_obj.output_hidden_states[i]) > 0
):
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],

View File

@@ -156,6 +156,10 @@ class TpModelWorkerClient:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
)
if logits_output.hidden_states is not None:
logits_output.hidden_states = logits_output.hidden_states.to(
"cpu", non_blocking=True
)
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
copy_done.record()

View File

@@ -349,7 +349,13 @@ class CudaGraphRunner:
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=(
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
CaptureHiddenMode.FULL
if self.model_runner.server_args.return_hidden_states
else (
spec_info.capture_hidden_mode
if spec_info
else CaptureHiddenMode.NULL
)
),
)

View File

@@ -160,6 +160,7 @@ class ServerArgs:
delete_ckpt_after_loading: bool = False
enable_memory_saver: bool = False
allow_auto_truncate: bool = False
return_hidden_states: bool = False
# Custom logit processor
enable_custom_logit_processor: bool = False
@@ -896,6 +897,11 @@ class ServerArgs:
action="store_true",
help="Enable users to pass custom logit processors to the server (disabled by default for security)",
)
parser.add_argument(
"--return-hidden-states",
action="store_true",
help="Return hidden states in the response.",
)
# Function Calling
parser.add_argument(
"--tool-call-parser",