Add return hidden state in the native API (#3897)
Co-authored-by: Beichen-Ma <mabeichen12@gmail.com> Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
@@ -37,6 +37,9 @@ class SamplingBatchInfo:
|
||||
# Whether any request has custom logit processor
|
||||
has_custom_logit_processor: bool
|
||||
|
||||
# Whether any request needs to return hidden states
|
||||
return_hidden_states: bool
|
||||
|
||||
# Bias Tensors
|
||||
vocab_size: int
|
||||
grammars: Optional[List] = None
|
||||
@@ -91,6 +94,9 @@ class SamplingBatchInfo:
|
||||
and any(r.custom_logit_processor for r in reqs) # then check the requests.
|
||||
)
|
||||
|
||||
# Check if any request needs to return hidden states
|
||||
return_hidden_states = any(r.sampling_params.return_hidden_states for r in reqs)
|
||||
|
||||
if has_custom_logit_processor:
|
||||
# Merge the same type of custom logit processors together
|
||||
processor_dict = {}
|
||||
@@ -130,6 +136,7 @@ class SamplingBatchInfo:
|
||||
device=device,
|
||||
custom_params=custom_params,
|
||||
custom_logit_processor=merged_custom_logit_processor,
|
||||
return_hidden_states=return_hidden_states,
|
||||
)
|
||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||
|
||||
@@ -336,6 +343,10 @@ class SamplingBatchInfo:
|
||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||
)
|
||||
|
||||
# Merge the return hidden states flag
|
||||
self.return_hidden_states |= other.return_hidden_states
|
||||
|
||||
# Merge the custom logit processors and custom params lists
|
||||
if self.has_custom_logit_processor or other.has_custom_logit_processor:
|
||||
# Merge the custom logit processors
|
||||
|
||||
@@ -48,6 +48,7 @@ class SamplingParams:
|
||||
no_stop_trim: bool = False,
|
||||
ignore_eos: bool = False,
|
||||
skip_special_tokens: bool = True,
|
||||
return_hidden_states: bool = False,
|
||||
custom_params: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.temperature = temperature
|
||||
@@ -72,6 +73,7 @@ class SamplingParams:
|
||||
self.json_schema = json_schema
|
||||
self.ebnf = ebnf
|
||||
self.no_stop_trim = no_stop_trim
|
||||
self.return_hidden_states = return_hidden_states
|
||||
self.custom_params = custom_params
|
||||
|
||||
# Process some special cases
|
||||
|
||||
Reference in New Issue
Block a user