Refactor: Move return_hidden_states to the generate input (#3985)

Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
This commit is contained in:
Qiaolin Yu
2025-03-01 20:51:29 -05:00
committed by GitHub
parent 18bb216c28
commit 40782f05d7
12 changed files with 54 additions and 44 deletions

View File

@@ -631,6 +631,7 @@ class Scheduler:
lora_path=recv_req.lora_path,
input_embeds=recv_req.input_embeds,
custom_logit_processor=custom_logit_processor,
return_hidden_states=recv_req.return_hidden_states,
eos_token_ids=self.model_config.hf_eos_token_id,
)
req.tokenizer = self.tokenizer
@@ -947,9 +948,11 @@ class Scheduler:
if self.running_batch is not None
else set([])
)
return_hidden_states = False
# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if req.return_hidden_states:
return_hidden_states = True
if (
self.lora_paths
and len(
@@ -1035,6 +1038,7 @@ class Scheduler:
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
return_hidden_states,
)
new_batch.prepare_for_extend()
@@ -1226,7 +1230,7 @@ class Scheduler:
i, req, logprob_pt, next_token_ids, logits_output
)
if (
req.sampling_params.return_hidden_states
req.return_hidden_states
and logits_output.hidden_states is not None
):
req.hidden_states.append(
@@ -1333,10 +1337,7 @@ class Scheduler:
logits_output.next_token_top_logprobs_idx[i]
)
if (
req.sampling_params.return_hidden_states
and logits_output.hidden_states is not None
):
if req.return_hidden_states and logits_output.hidden_states is not None:
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
if req.grammar is not None:
@@ -1462,10 +1463,7 @@ class Scheduler:
completion_tokens = []
cached_tokens = []
spec_verify_ct = []
return_hidden_states = any(
req.sampling_params.return_hidden_states for req in reqs
)
output_hidden_states = [] if return_hidden_states else None
output_hidden_states = None
if return_logprob:
input_token_logprobs_val = []
@@ -1532,7 +1530,9 @@ class Scheduler:
output_top_logprobs_val.append(req.output_top_logprobs_val)
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
if req.sampling_params.return_hidden_states:
if req.return_hidden_states:
if output_hidden_states is None:
output_hidden_states = []
output_hidden_states.append(req.hidden_states)
# Send to detokenizer