Refactor: Move return_hidden_states to the generate input (#3985)

Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
This commit is contained in:
Qiaolin Yu
2025-03-01 20:51:29 -05:00
committed by GitHub
parent 18bb216c28
commit 40782f05d7
12 changed files with 54 additions and 44 deletions

View File

@@ -37,9 +37,6 @@ class SamplingBatchInfo:
# Whether any request has custom logit processor
has_custom_logit_processor: bool
# Whether any request needs to return hidden states
return_hidden_states: bool
# Bias Tensors
vocab_size: int
grammars: Optional[List] = None
@@ -94,9 +91,6 @@ class SamplingBatchInfo:
and any(r.custom_logit_processor for r in reqs) # then check the requests.
)
# Check if any request needs to return hidden states
return_hidden_states = any(r.sampling_params.return_hidden_states for r in reqs)
if has_custom_logit_processor:
# Merge the same type of custom logit processors together
processor_dict = {}
@@ -136,7 +130,6 @@ class SamplingBatchInfo:
device=device,
custom_params=custom_params,
custom_logit_processor=merged_custom_logit_processor,
return_hidden_states=return_hidden_states,
)
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
@@ -344,9 +337,6 @@ class SamplingBatchInfo:
self.logit_bias, other.logit_bias, len(self), len(other), self.device
)
# Merge the return hidden states flag
self.return_hidden_states |= other.return_hidden_states
# Merge the custom logit processors and custom params lists
if self.has_custom_logit_processor or other.has_custom_logit_processor:
# Merge the custom logit processors

View File

@@ -49,7 +49,6 @@ class SamplingParams:
no_stop_trim: bool = False,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
return_hidden_states: bool = False,
custom_params: Optional[Dict[str, Any]] = None,
) -> None:
self.temperature = temperature
@@ -75,7 +74,6 @@ class SamplingParams:
self.ebnf = ebnf
self.structural_tag = structural_tag
self.no_stop_trim = no_stop_trim
self.return_hidden_states = return_hidden_states
self.custom_params = custom_params
# Process some special cases