Refactor: Move return_hidden_states to the generate input (#3985)
Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
This commit is contained in:
@@ -631,6 +631,7 @@ class Scheduler:
|
||||
lora_path=recv_req.lora_path,
|
||||
input_embeds=recv_req.input_embeds,
|
||||
custom_logit_processor=custom_logit_processor,
|
||||
return_hidden_states=recv_req.return_hidden_states,
|
||||
eos_token_ids=self.model_config.hf_eos_token_id,
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
@@ -947,9 +948,11 @@ class Scheduler:
|
||||
if self.running_batch is not None
|
||||
else set([])
|
||||
)
|
||||
|
||||
return_hidden_states = False
|
||||
# Get requests from the waiting queue to a new prefill batch
|
||||
for req in self.waiting_queue:
|
||||
if req.return_hidden_states:
|
||||
return_hidden_states = True
|
||||
if (
|
||||
self.lora_paths
|
||||
and len(
|
||||
@@ -1035,6 +1038,7 @@ class Scheduler:
|
||||
self.enable_overlap,
|
||||
self.spec_algorithm,
|
||||
self.server_args.enable_custom_logit_processor,
|
||||
return_hidden_states,
|
||||
)
|
||||
new_batch.prepare_for_extend()
|
||||
|
||||
@@ -1226,7 +1230,7 @@ class Scheduler:
|
||||
i, req, logprob_pt, next_token_ids, logits_output
|
||||
)
|
||||
if (
|
||||
req.sampling_params.return_hidden_states
|
||||
req.return_hidden_states
|
||||
and logits_output.hidden_states is not None
|
||||
):
|
||||
req.hidden_states.append(
|
||||
@@ -1333,10 +1337,7 @@ class Scheduler:
|
||||
logits_output.next_token_top_logprobs_idx[i]
|
||||
)
|
||||
|
||||
if (
|
||||
req.sampling_params.return_hidden_states
|
||||
and logits_output.hidden_states is not None
|
||||
):
|
||||
if req.return_hidden_states and logits_output.hidden_states is not None:
|
||||
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
|
||||
|
||||
if req.grammar is not None:
|
||||
@@ -1462,10 +1463,7 @@ class Scheduler:
|
||||
completion_tokens = []
|
||||
cached_tokens = []
|
||||
spec_verify_ct = []
|
||||
return_hidden_states = any(
|
||||
req.sampling_params.return_hidden_states for req in reqs
|
||||
)
|
||||
output_hidden_states = [] if return_hidden_states else None
|
||||
output_hidden_states = None
|
||||
|
||||
if return_logprob:
|
||||
input_token_logprobs_val = []
|
||||
@@ -1532,7 +1530,9 @@ class Scheduler:
|
||||
output_top_logprobs_val.append(req.output_top_logprobs_val)
|
||||
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
|
||||
|
||||
if req.sampling_params.return_hidden_states:
|
||||
if req.return_hidden_states:
|
||||
if output_hidden_states is None:
|
||||
output_hidden_states = []
|
||||
output_hidden_states.append(req.hidden_states)
|
||||
|
||||
# Send to detokenizer
|
||||
|
||||
Reference in New Issue
Block a user