[Feat] Return hidden states (experimental) (#3364)
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
@@ -315,6 +315,7 @@ class Req:
|
||||
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
|
||||
self.output_top_logprobs_val
|
||||
) = self.output_top_logprobs_idx = None
|
||||
self.hidden_states = []
|
||||
|
||||
# Logprobs (internal values)
|
||||
# The tokens are prefilled but need to be considered as decode tokens
|
||||
@@ -604,6 +605,9 @@ class ScheduleBatch:
|
||||
# Enable custom logit processor
|
||||
enable_custom_logit_processor: bool = False
|
||||
|
||||
# Return hidden states
|
||||
return_hidden_states: bool = False
|
||||
|
||||
@classmethod
|
||||
def init_new(
|
||||
cls,
|
||||
@@ -615,6 +619,7 @@ class ScheduleBatch:
|
||||
enable_overlap: bool,
|
||||
spec_algorithm: SpeculativeAlgorithm,
|
||||
enable_custom_logit_processor: bool,
|
||||
return_hidden_states: bool = False,
|
||||
):
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
@@ -629,6 +634,7 @@ class ScheduleBatch:
|
||||
device=req_to_token_pool.device,
|
||||
spec_algorithm=spec_algorithm,
|
||||
enable_custom_logit_processor=enable_custom_logit_processor,
|
||||
return_hidden_states=return_hidden_states,
|
||||
)
|
||||
|
||||
def batch_size(self):
|
||||
@@ -1196,9 +1202,15 @@ class ScheduleBatch:
|
||||
spec_algorithm=self.spec_algorithm,
|
||||
spec_info=self.spec_info,
|
||||
capture_hidden_mode=(
|
||||
getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
|
||||
if self.spec_info
|
||||
else CaptureHiddenMode.NULL
|
||||
CaptureHiddenMode.FULL
|
||||
if self.return_hidden_states
|
||||
else (
|
||||
getattr(
|
||||
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
||||
)
|
||||
if self.spec_info
|
||||
else CaptureHiddenMode.NULL
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user