Refactor: Move return_hidden_states to the generate input (#3985)
Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
This commit is contained in:
@@ -37,9 +37,6 @@ class SamplingBatchInfo:
|
||||
# Whether any request has custom logit processor
|
||||
has_custom_logit_processor: bool
|
||||
|
||||
# Whether any request needs to return hidden states
|
||||
return_hidden_states: bool
|
||||
|
||||
# Bias Tensors
|
||||
vocab_size: int
|
||||
grammars: Optional[List] = None
|
||||
@@ -94,9 +91,6 @@ class SamplingBatchInfo:
|
||||
and any(r.custom_logit_processor for r in reqs) # then check the requests.
|
||||
)
|
||||
|
||||
# Check if any request needs to return hidden states
|
||||
return_hidden_states = any(r.sampling_params.return_hidden_states for r in reqs)
|
||||
|
||||
if has_custom_logit_processor:
|
||||
# Merge the same type of custom logit processors together
|
||||
processor_dict = {}
|
||||
@@ -136,7 +130,6 @@ class SamplingBatchInfo:
|
||||
device=device,
|
||||
custom_params=custom_params,
|
||||
custom_logit_processor=merged_custom_logit_processor,
|
||||
return_hidden_states=return_hidden_states,
|
||||
)
|
||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||
|
||||
@@ -344,9 +337,6 @@ class SamplingBatchInfo:
|
||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||
)
|
||||
|
||||
# Merge the return hidden states flag
|
||||
self.return_hidden_states |= other.return_hidden_states
|
||||
|
||||
# Merge the custom logit processors and custom params lists
|
||||
if self.has_custom_logit_processor or other.has_custom_logit_processor:
|
||||
# Merge the custom logit processors
|
||||
|
||||
@@ -49,7 +49,6 @@ class SamplingParams:
|
||||
no_stop_trim: bool = False,
|
||||
ignore_eos: bool = False,
|
||||
skip_special_tokens: bool = True,
|
||||
return_hidden_states: bool = False,
|
||||
custom_params: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.temperature = temperature
|
||||
@@ -75,7 +74,6 @@ class SamplingParams:
|
||||
self.ebnf = ebnf
|
||||
self.structural_tag = structural_tag
|
||||
self.no_stop_trim = no_stop_trim
|
||||
self.return_hidden_states = return_hidden_states
|
||||
self.custom_params = custom_params
|
||||
|
||||
# Process some special cases
|
||||
|
||||
Reference in New Issue
Block a user