[Feat] Return hidden states (experimental) (#3364)
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
@@ -210,6 +210,7 @@ class DetokenizerManager:
|
||||
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
|
||||
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
|
||||
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
|
||||
output_hidden_states=recv_obj.output_hidden_states,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -371,6 +371,8 @@ class BatchTokenIDOut:
|
||||
output_top_logprobs_val: List[List]
|
||||
output_top_logprobs_idx: List[List]
|
||||
|
||||
output_hidden_states: List[List[float]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchStrOut:
|
||||
@@ -397,6 +399,8 @@ class BatchStrOut:
|
||||
output_top_logprobs_val: List[List]
|
||||
output_top_logprobs_idx: List[List]
|
||||
|
||||
output_hidden_states: List[List[float]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchEmbeddingOut:
|
||||
|
||||
@@ -315,6 +315,7 @@ class Req:
|
||||
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
|
||||
self.output_top_logprobs_val
|
||||
) = self.output_top_logprobs_idx = None
|
||||
self.hidden_states = []
|
||||
|
||||
# Logprobs (internal values)
|
||||
# These tokens are prefilled but need to be treated as decode tokens
|
||||
@@ -604,6 +605,9 @@ class ScheduleBatch:
|
||||
# Enable custom logit processor
|
||||
enable_custom_logit_processor: bool = False
|
||||
|
||||
# Return hidden states
|
||||
return_hidden_states: bool = False
|
||||
|
||||
@classmethod
|
||||
def init_new(
|
||||
cls,
|
||||
@@ -615,6 +619,7 @@ class ScheduleBatch:
|
||||
enable_overlap: bool,
|
||||
spec_algorithm: SpeculativeAlgorithm,
|
||||
enable_custom_logit_processor: bool,
|
||||
return_hidden_states: bool = False,
|
||||
):
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
@@ -629,6 +634,7 @@ class ScheduleBatch:
|
||||
device=req_to_token_pool.device,
|
||||
spec_algorithm=spec_algorithm,
|
||||
enable_custom_logit_processor=enable_custom_logit_processor,
|
||||
return_hidden_states=return_hidden_states,
|
||||
)
|
||||
|
||||
def batch_size(self):
|
||||
@@ -1196,9 +1202,15 @@ class ScheduleBatch:
|
||||
spec_algorithm=self.spec_algorithm,
|
||||
spec_info=self.spec_info,
|
||||
capture_hidden_mode=(
|
||||
getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
|
||||
if self.spec_info
|
||||
else CaptureHiddenMode.NULL
|
||||
CaptureHiddenMode.FULL
|
||||
if self.return_hidden_states
|
||||
else (
|
||||
getattr(
|
||||
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
||||
)
|
||||
if self.spec_info
|
||||
else CaptureHiddenMode.NULL
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -997,6 +997,7 @@ class Scheduler:
|
||||
self.enable_overlap,
|
||||
self.spec_algorithm,
|
||||
self.server_args.enable_custom_logit_processor,
|
||||
self.server_args.return_hidden_states,
|
||||
)
|
||||
new_batch.prepare_for_extend()
|
||||
|
||||
@@ -1156,6 +1157,8 @@ class Scheduler:
|
||||
logits_output.input_token_logprobs.tolist()
|
||||
)
|
||||
|
||||
hidden_state_offset = 0
|
||||
|
||||
# Check finish conditions
|
||||
logprob_pt = 0
|
||||
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||
@@ -1182,6 +1185,21 @@ class Scheduler:
|
||||
i, req, logprob_pt, next_token_ids, logits_output
|
||||
)
|
||||
|
||||
if (
|
||||
self.server_args.return_hidden_states
|
||||
and logits_output.hidden_states is not None
|
||||
):
|
||||
req.hidden_states.append(
|
||||
logits_output.hidden_states[
|
||||
hidden_state_offset : (
|
||||
hidden_state_offset := hidden_state_offset
|
||||
+ len(req.origin_input_ids)
|
||||
)
|
||||
]
|
||||
.cpu()
|
||||
.clone()
|
||||
)
|
||||
|
||||
if req.grammar is not None:
|
||||
req.grammar.accept_token(next_token_id)
|
||||
req.grammar.finished = req.finished()
|
||||
@@ -1275,6 +1293,12 @@ class Scheduler:
|
||||
logits_output.next_token_top_logprobs_idx[i]
|
||||
)
|
||||
|
||||
if (
|
||||
self.server_args.return_hidden_states
|
||||
and logits_output.hidden_states is not None
|
||||
):
|
||||
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
|
||||
|
||||
if req.grammar is not None:
|
||||
req.grammar.accept_token(next_token_id)
|
||||
req.grammar.finished = req.finished()
|
||||
@@ -1398,6 +1422,7 @@ class Scheduler:
|
||||
completion_tokens = []
|
||||
cached_tokens = []
|
||||
spec_verify_ct = []
|
||||
hidden_states = []
|
||||
|
||||
if return_logprob:
|
||||
input_token_logprobs_val = []
|
||||
@@ -1464,6 +1489,8 @@ class Scheduler:
|
||||
output_top_logprobs_val.append(req.output_top_logprobs_val)
|
||||
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
|
||||
|
||||
hidden_states.append(req.hidden_states)
|
||||
|
||||
# Send to detokenizer
|
||||
if rids:
|
||||
self.send_to_detokenizer.send_pyobj(
|
||||
@@ -1490,6 +1517,7 @@ class Scheduler:
|
||||
input_top_logprobs_idx,
|
||||
output_top_logprobs_val,
|
||||
output_top_logprobs_idx,
|
||||
hidden_states,
|
||||
)
|
||||
)
|
||||
else: # embedding or reward model
|
||||
@@ -1553,6 +1581,7 @@ class Scheduler:
|
||||
self.enable_overlap,
|
||||
self.spec_algorithm,
|
||||
self.server_args.enable_custom_logit_processor,
|
||||
self.server_args.return_hidden_states,
|
||||
)
|
||||
idle_batch.prepare_for_idle()
|
||||
return idle_batch
|
||||
|
||||
@@ -796,6 +796,12 @@ class TokenizerManager:
|
||||
}
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(recv_obj, "output_hidden_states")
|
||||
and len(recv_obj.output_hidden_states[i]) > 0
|
||||
):
|
||||
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
|
||||
|
||||
if isinstance(recv_obj, BatchStrOut):
|
||||
out_dict = {
|
||||
"text": recv_obj.output_strs[i],
|
||||
|
||||
@@ -156,6 +156,10 @@ class TpModelWorkerClient:
|
||||
logits_output.input_token_logprobs = (
|
||||
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
|
||||
)
|
||||
if logits_output.hidden_states is not None:
|
||||
logits_output.hidden_states = logits_output.hidden_states.to(
|
||||
"cpu", non_blocking=True
|
||||
)
|
||||
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
||||
copy_done.record()
|
||||
|
||||
|
||||
@@ -349,7 +349,13 @@ class CudaGraphRunner:
|
||||
spec_algorithm=self.model_runner.spec_algorithm,
|
||||
spec_info=spec_info,
|
||||
capture_hidden_mode=(
|
||||
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
|
||||
CaptureHiddenMode.FULL
|
||||
if self.model_runner.server_args.return_hidden_states
|
||||
else (
|
||||
spec_info.capture_hidden_mode
|
||||
if spec_info
|
||||
else CaptureHiddenMode.NULL
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -160,6 +160,7 @@ class ServerArgs:
|
||||
delete_ckpt_after_loading: bool = False
|
||||
enable_memory_saver: bool = False
|
||||
allow_auto_truncate: bool = False
|
||||
return_hidden_states: bool = False
|
||||
|
||||
# Custom logit processor
|
||||
enable_custom_logit_processor: bool = False
|
||||
@@ -896,6 +897,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Enable users to pass custom logit processors to the server (disabled by default for security)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--return-hidden-states",
|
||||
action="store_true",
|
||||
help="Return hidden states in the response.",
|
||||
)
|
||||
# Function Calling
|
||||
parser.add_argument(
|
||||
"--tool-call-parser",
|
||||
|
||||
Reference in New Issue
Block a user