Add return hidden state in the native API (#3897)

Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
Qiaolin Yu
2025-02-27 01:06:54 -05:00
committed by GitHub
parent 71ed01833d
commit d6898dd253
9 changed files with 112 additions and 34 deletions

View File

@@ -37,6 +37,9 @@ class SamplingBatchInfo:
# Whether any request has custom logit processor
has_custom_logit_processor: bool
# Whether any request needs to return hidden states
return_hidden_states: bool
# Bias Tensors
vocab_size: int
grammars: Optional[List] = None
@@ -91,6 +94,9 @@ class SamplingBatchInfo:
and any(r.custom_logit_processor for r in reqs) # then check the requests.
)
# Check if any request needs to return hidden states
return_hidden_states = any(r.sampling_params.return_hidden_states for r in reqs)
if has_custom_logit_processor:
# Merge the same type of custom logit processors together
processor_dict = {}
@@ -130,6 +136,7 @@ class SamplingBatchInfo:
device=device,
custom_params=custom_params,
custom_logit_processor=merged_custom_logit_processor,
return_hidden_states=return_hidden_states,
)
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
@@ -336,6 +343,10 @@ class SamplingBatchInfo:
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other), self.device
)
# Merge the return hidden states flag
self.return_hidden_states |= other.return_hidden_states
# Merge the custom logit processors and custom params lists
if self.has_custom_logit_processor or other.has_custom_logit_processor:
# Merge the custom logit processors

View File

@@ -48,6 +48,7 @@ class SamplingParams:
no_stop_trim: bool = False,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
return_hidden_states: bool = False,
custom_params: Optional[Dict[str, Any]] = None,
) -> None:
self.temperature = temperature
@@ -72,6 +73,7 @@ class SamplingParams:
self.json_schema = json_schema
self.ebnf = ebnf
self.no_stop_trim = no_stop_trim
self.return_hidden_states = return_hidden_states
self.custom_params = custom_params
# Process some special cases