Add return hidden state in the native API (#3897)

Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
Qiaolin Yu
2025-02-27 01:06:54 -05:00
committed by GitHub
parent 71ed01833d
commit d6898dd253
9 changed files with 112 additions and 34 deletions

View File

@@ -607,9 +607,6 @@ class ScheduleBatch:
# Enable custom logit processor
enable_custom_logit_processor: bool = False
# Return hidden states
return_hidden_states: bool = False
@classmethod
def init_new(
cls,
@@ -621,7 +618,6 @@ class ScheduleBatch:
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
return_hidden_states: bool = False,
):
return cls(
reqs=reqs,
@@ -636,7 +632,6 @@ class ScheduleBatch:
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
return_hidden_states=return_hidden_states,
)
def batch_size(self):
@@ -1205,7 +1200,7 @@ class ScheduleBatch:
spec_info=self.spec_info,
capture_hidden_mode=(
CaptureHiddenMode.FULL
if self.return_hidden_states
if self.sampling_info.return_hidden_states
else (
getattr(
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL