Refactor: Move return_hidden_states to the generate input (#3985)
Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
This commit is contained in:
@@ -236,6 +236,7 @@ class Req:
|
||||
input_embeds: Optional[List[List[float]]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
custom_logit_processor: Optional[str] = None,
|
||||
return_hidden_states: bool = False,
|
||||
eos_token_ids: Optional[Set[int]] = None,
|
||||
):
|
||||
# Input and output info
|
||||
@@ -256,7 +257,9 @@ class Req:
|
||||
|
||||
# Sampling info
|
||||
self.sampling_params = sampling_params
|
||||
|
||||
self.custom_logit_processor = custom_logit_processor
|
||||
self.return_hidden_states = return_hidden_states
|
||||
|
||||
# Memory pool info
|
||||
self.req_pool_idx = None
|
||||
@@ -608,6 +611,9 @@ class ScheduleBatch:
|
||||
# Enable custom logit processor
|
||||
enable_custom_logit_processor: bool = False
|
||||
|
||||
# Whether to return hidden states
|
||||
return_hidden_states: bool = False
|
||||
|
||||
@classmethod
|
||||
def init_new(
|
||||
cls,
|
||||
@@ -619,6 +625,7 @@ class ScheduleBatch:
|
||||
enable_overlap: bool,
|
||||
spec_algorithm: SpeculativeAlgorithm,
|
||||
enable_custom_logit_processor: bool,
|
||||
return_hidden_states: bool = False,
|
||||
):
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
@@ -633,6 +640,7 @@ class ScheduleBatch:
|
||||
device=req_to_token_pool.device,
|
||||
spec_algorithm=spec_algorithm,
|
||||
enable_custom_logit_processor=enable_custom_logit_processor,
|
||||
return_hidden_states=return_hidden_states,
|
||||
)
|
||||
|
||||
def batch_size(self):
|
||||
@@ -1153,6 +1161,7 @@ class ScheduleBatch:
|
||||
self.return_logprob |= other.return_logprob
|
||||
self.has_stream |= other.has_stream
|
||||
self.has_grammar |= other.has_grammar
|
||||
self.return_hidden_states |= other.return_hidden_states
|
||||
|
||||
if self.spec_info:
|
||||
self.spec_info.merge_batch(other.spec_info)
|
||||
@@ -1201,7 +1210,7 @@ class ScheduleBatch:
|
||||
spec_info=self.spec_info,
|
||||
capture_hidden_mode=(
|
||||
CaptureHiddenMode.FULL
|
||||
if self.sampling_info.return_hidden_states
|
||||
if self.return_hidden_states
|
||||
else (
|
||||
getattr(
|
||||
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
||||
|
||||
Reference in New Issue
Block a user