[Feat] Return hidden states (experimental) (#3364)

Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
Jackmin801
2025-02-10 15:54:37 -08:00
committed by GitHub
parent 2f47d710ae
commit 5f0e7de339
12 changed files with 204 additions and 5 deletions

View File

@@ -315,6 +315,7 @@ class Req:
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
self.output_top_logprobs_val
) = self.output_top_logprobs_idx = None
self.hidden_states = []
# Logprobs (internal values)
# The tokens are prefilled but need to be considered as decode tokens
@@ -604,6 +605,9 @@ class ScheduleBatch:
# Enable custom logit processor
enable_custom_logit_processor: bool = False
# Return hidden states
return_hidden_states: bool = False
@classmethod
def init_new(
cls,
@@ -615,6 +619,7 @@ class ScheduleBatch:
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
return_hidden_states: bool = False,
):
return cls(
reqs=reqs,
@@ -629,6 +634,7 @@ class ScheduleBatch:
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
return_hidden_states=return_hidden_states,
)
def batch_size(self):
@@ -1196,9 +1202,15 @@ class ScheduleBatch:
spec_algorithm=self.spec_algorithm,
spec_info=self.spec_info,
capture_hidden_mode=(
getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
if self.spec_info
else CaptureHiddenMode.NULL
CaptureHiddenMode.FULL
if self.return_hidden_states
else (
getattr(
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
)
if self.spec_info
else CaptureHiddenMode.NULL
)
),
)