Refactor: Move return_hidden_states to the generate input (#3985)

Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
This commit is contained in:
Qiaolin Yu
2025-03-01 20:51:29 -05:00
committed by GitHub
parent 18bb216c28
commit 40782f05d7
12 changed files with 54 additions and 44 deletions

View File

@@ -236,6 +236,7 @@ class Req:
input_embeds: Optional[List[List[float]]] = None,
session_id: Optional[str] = None,
custom_logit_processor: Optional[str] = None,
return_hidden_states: bool = False,
eos_token_ids: Optional[Set[int]] = None,
):
# Input and output info
@@ -256,7 +257,9 @@ class Req:
# Sampling info
self.sampling_params = sampling_params
self.custom_logit_processor = custom_logit_processor
self.return_hidden_states = return_hidden_states
# Memory pool info
self.req_pool_idx = None
@@ -608,6 +611,9 @@ class ScheduleBatch:
# Enable custom logit processor
enable_custom_logit_processor: bool = False
# Whether to return hidden states
return_hidden_states: bool = False
@classmethod
def init_new(
cls,
@@ -619,6 +625,7 @@ class ScheduleBatch:
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
return_hidden_states: bool = False,
):
return cls(
reqs=reqs,
@@ -633,6 +640,7 @@ class ScheduleBatch:
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
return_hidden_states=return_hidden_states,
)
def batch_size(self):
@@ -1153,6 +1161,7 @@ class ScheduleBatch:
self.return_logprob |= other.return_logprob
self.has_stream |= other.has_stream
self.has_grammar |= other.has_grammar
self.return_hidden_states |= other.return_hidden_states
if self.spec_info:
self.spec_info.merge_batch(other.spec_info)
@@ -1201,7 +1210,7 @@ class ScheduleBatch:
spec_info=self.spec_info,
capture_hidden_mode=(
CaptureHiddenMode.FULL
if self.sampling_info.return_hidden_states
if self.return_hidden_states
else (
getattr(
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL