Add return hidden state in the native API (#3897)
Co-authored-by: Beichen-Ma <mabeichen12@gmail.com> Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
@@ -607,9 +607,6 @@ class ScheduleBatch:
|
||||
# Enable custom logit processor
|
||||
enable_custom_logit_processor: bool = False
|
||||
|
||||
# Return hidden states
|
||||
return_hidden_states: bool = False
|
||||
|
||||
@classmethod
|
||||
def init_new(
|
||||
cls,
|
||||
@@ -621,7 +618,6 @@ class ScheduleBatch:
|
||||
enable_overlap: bool,
|
||||
spec_algorithm: SpeculativeAlgorithm,
|
||||
enable_custom_logit_processor: bool,
|
||||
return_hidden_states: bool = False,
|
||||
):
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
@@ -636,7 +632,6 @@ class ScheduleBatch:
|
||||
device=req_to_token_pool.device,
|
||||
spec_algorithm=spec_algorithm,
|
||||
enable_custom_logit_processor=enable_custom_logit_processor,
|
||||
return_hidden_states=return_hidden_states,
|
||||
)
|
||||
|
||||
def batch_size(self):
|
||||
@@ -1205,7 +1200,7 @@ class ScheduleBatch:
|
||||
spec_info=self.spec_info,
|
||||
capture_hidden_mode=(
|
||||
CaptureHiddenMode.FULL
|
||||
if self.return_hidden_states
|
||||
if self.sampling_info.return_hidden_states
|
||||
else (
|
||||
getattr(
|
||||
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
||||
|
||||
@@ -1030,7 +1030,6 @@ class Scheduler:
|
||||
self.enable_overlap,
|
||||
self.spec_algorithm,
|
||||
self.server_args.enable_custom_logit_processor,
|
||||
self.server_args.return_hidden_states,
|
||||
)
|
||||
new_batch.prepare_for_extend()
|
||||
|
||||
@@ -1221,9 +1220,8 @@ class Scheduler:
|
||||
logprob_pt += self.add_logprob_return_values(
|
||||
i, req, logprob_pt, next_token_ids, logits_output
|
||||
)
|
||||
|
||||
if (
|
||||
self.server_args.return_hidden_states
|
||||
req.sampling_params.return_hidden_states
|
||||
and logits_output.hidden_states is not None
|
||||
):
|
||||
req.hidden_states.append(
|
||||
@@ -1331,7 +1329,7 @@ class Scheduler:
|
||||
)
|
||||
|
||||
if (
|
||||
self.server_args.return_hidden_states
|
||||
req.sampling_params.return_hidden_states
|
||||
and logits_output.hidden_states is not None
|
||||
):
|
||||
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
|
||||
@@ -1459,7 +1457,10 @@ class Scheduler:
|
||||
completion_tokens = []
|
||||
cached_tokens = []
|
||||
spec_verify_ct = []
|
||||
output_hidden_states = [] if self.server_args.return_hidden_states else None
|
||||
return_hidden_states = any(
|
||||
req.sampling_params.return_hidden_states for req in reqs
|
||||
)
|
||||
output_hidden_states = [] if return_hidden_states else None
|
||||
|
||||
if return_logprob:
|
||||
input_token_logprobs_val = []
|
||||
@@ -1526,7 +1527,7 @@ class Scheduler:
|
||||
output_top_logprobs_val.append(req.output_top_logprobs_val)
|
||||
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
|
||||
|
||||
if self.server_args.return_hidden_states:
|
||||
if req.sampling_params.return_hidden_states:
|
||||
output_hidden_states.append(req.hidden_states)
|
||||
|
||||
# Send to detokenizer
|
||||
@@ -1619,7 +1620,6 @@ class Scheduler:
|
||||
self.enable_overlap,
|
||||
self.spec_algorithm,
|
||||
self.server_args.enable_custom_logit_processor,
|
||||
self.server_args.return_hidden_states,
|
||||
)
|
||||
idle_batch.prepare_for_idle()
|
||||
return idle_batch
|
||||
|
||||
Reference in New Issue
Block a user