Refactor: Move return_hidden_states to the generate input (#3985)
Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
This commit is contained in:
@@ -69,11 +69,15 @@ class GenerateReqInput:
|
||||
|
||||
# Session info for continual prompting
|
||||
session_params: Optional[Union[List[Dict], Dict]] = None
|
||||
|
||||
# Custom logit processor for advanced sampling control. Must be a serialized instance
|
||||
# of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
|
||||
# Use the processor's `to_str()` method to generate the serialized string.
|
||||
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
|
||||
|
||||
# Whether to return hidden states
|
||||
return_hidden_states: bool = False
|
||||
|
||||
def normalize_batch_and_arguments(self):
|
||||
if (
|
||||
self.text is None and self.input_ids is None and self.input_embeds is None
|
||||
@@ -218,6 +222,7 @@ class GenerateReqInput:
|
||||
if self.custom_logit_processor is not None
|
||||
else None
|
||||
),
|
||||
return_hidden_states=self.return_hidden_states,
|
||||
)
|
||||
|
||||
|
||||
@@ -255,6 +260,9 @@ class TokenizedGenerateReqInput:
|
||||
# Use the processor's `to_str()` method to generate the serialized string.
|
||||
custom_logit_processor: Optional[str] = None
|
||||
|
||||
# Whether to return hidden states
|
||||
return_hidden_states: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingReqInput:
|
||||
|
||||
@@ -236,6 +236,7 @@ class Req:
|
||||
input_embeds: Optional[List[List[float]]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
custom_logit_processor: Optional[str] = None,
|
||||
return_hidden_states: bool = False,
|
||||
eos_token_ids: Optional[Set[int]] = None,
|
||||
):
|
||||
# Input and output info
|
||||
@@ -256,7 +257,9 @@ class Req:
|
||||
|
||||
# Sampling info
|
||||
self.sampling_params = sampling_params
|
||||
|
||||
self.custom_logit_processor = custom_logit_processor
|
||||
self.return_hidden_states = return_hidden_states
|
||||
|
||||
# Memory pool info
|
||||
self.req_pool_idx = None
|
||||
@@ -608,6 +611,9 @@ class ScheduleBatch:
|
||||
# Enable custom logit processor
|
||||
enable_custom_logit_processor: bool = False
|
||||
|
||||
# Whether to return hidden states
|
||||
return_hidden_states: bool = False
|
||||
|
||||
@classmethod
|
||||
def init_new(
|
||||
cls,
|
||||
@@ -619,6 +625,7 @@ class ScheduleBatch:
|
||||
enable_overlap: bool,
|
||||
spec_algorithm: SpeculativeAlgorithm,
|
||||
enable_custom_logit_processor: bool,
|
||||
return_hidden_states: bool = False,
|
||||
):
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
@@ -633,6 +640,7 @@ class ScheduleBatch:
|
||||
device=req_to_token_pool.device,
|
||||
spec_algorithm=spec_algorithm,
|
||||
enable_custom_logit_processor=enable_custom_logit_processor,
|
||||
return_hidden_states=return_hidden_states,
|
||||
)
|
||||
|
||||
def batch_size(self):
|
||||
@@ -1153,6 +1161,7 @@ class ScheduleBatch:
|
||||
self.return_logprob |= other.return_logprob
|
||||
self.has_stream |= other.has_stream
|
||||
self.has_grammar |= other.has_grammar
|
||||
self.return_hidden_states |= other.return_hidden_states
|
||||
|
||||
if self.spec_info:
|
||||
self.spec_info.merge_batch(other.spec_info)
|
||||
@@ -1201,7 +1210,7 @@ class ScheduleBatch:
|
||||
spec_info=self.spec_info,
|
||||
capture_hidden_mode=(
|
||||
CaptureHiddenMode.FULL
|
||||
if self.sampling_info.return_hidden_states
|
||||
if self.return_hidden_states
|
||||
else (
|
||||
getattr(
|
||||
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
||||
|
||||
@@ -631,6 +631,7 @@ class Scheduler:
|
||||
lora_path=recv_req.lora_path,
|
||||
input_embeds=recv_req.input_embeds,
|
||||
custom_logit_processor=custom_logit_processor,
|
||||
return_hidden_states=recv_req.return_hidden_states,
|
||||
eos_token_ids=self.model_config.hf_eos_token_id,
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
@@ -947,9 +948,11 @@ class Scheduler:
|
||||
if self.running_batch is not None
|
||||
else set([])
|
||||
)
|
||||
|
||||
return_hidden_states = False
|
||||
# Get requests from the waiting queue to a new prefill batch
|
||||
for req in self.waiting_queue:
|
||||
if req.return_hidden_states:
|
||||
return_hidden_states = True
|
||||
if (
|
||||
self.lora_paths
|
||||
and len(
|
||||
@@ -1035,6 +1038,7 @@ class Scheduler:
|
||||
self.enable_overlap,
|
||||
self.spec_algorithm,
|
||||
self.server_args.enable_custom_logit_processor,
|
||||
return_hidden_states,
|
||||
)
|
||||
new_batch.prepare_for_extend()
|
||||
|
||||
@@ -1226,7 +1230,7 @@ class Scheduler:
|
||||
i, req, logprob_pt, next_token_ids, logits_output
|
||||
)
|
||||
if (
|
||||
req.sampling_params.return_hidden_states
|
||||
req.return_hidden_states
|
||||
and logits_output.hidden_states is not None
|
||||
):
|
||||
req.hidden_states.append(
|
||||
@@ -1333,10 +1337,7 @@ class Scheduler:
|
||||
logits_output.next_token_top_logprobs_idx[i]
|
||||
)
|
||||
|
||||
if (
|
||||
req.sampling_params.return_hidden_states
|
||||
and logits_output.hidden_states is not None
|
||||
):
|
||||
if req.return_hidden_states and logits_output.hidden_states is not None:
|
||||
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
|
||||
|
||||
if req.grammar is not None:
|
||||
@@ -1462,10 +1463,7 @@ class Scheduler:
|
||||
completion_tokens = []
|
||||
cached_tokens = []
|
||||
spec_verify_ct = []
|
||||
return_hidden_states = any(
|
||||
req.sampling_params.return_hidden_states for req in reqs
|
||||
)
|
||||
output_hidden_states = [] if return_hidden_states else None
|
||||
output_hidden_states = None
|
||||
|
||||
if return_logprob:
|
||||
input_token_logprobs_val = []
|
||||
@@ -1532,7 +1530,9 @@ class Scheduler:
|
||||
output_top_logprobs_val.append(req.output_top_logprobs_val)
|
||||
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
|
||||
|
||||
if req.sampling_params.return_hidden_states:
|
||||
if req.return_hidden_states:
|
||||
if output_hidden_states is None:
|
||||
output_hidden_states = []
|
||||
output_hidden_states.append(req.hidden_states)
|
||||
|
||||
# Send to detokenizer
|
||||
|
||||
@@ -383,6 +383,7 @@ class TokenizerManager:
|
||||
input_embeds=input_embeds,
|
||||
session_params=session_params,
|
||||
custom_logit_processor=obj.custom_logit_processor,
|
||||
return_hidden_states=obj.return_hidden_states,
|
||||
)
|
||||
elif isinstance(obj, EmbeddingReqInput):
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
|
||||
Reference in New Issue
Block a user