diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 5c75d888b..bc58f4ee5 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -246,6 +246,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val, output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx, output_hidden_states=recv_obj.output_hidden_states, + placeholder_tokens_idx=None, + placeholder_tokens_val=None, ) def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq): @@ -257,6 +259,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): prompt_tokens=recv_obj.prompt_tokens, completion_tokens=recv_obj.completion_tokens, cached_tokens=recv_obj.cached_tokens, + placeholder_tokens_idx=None, + placeholder_tokens_val=None, ) def handle_freeze_gc_req(self, recv_req: FreezeGCReq): diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 753b2f828..06f3dfc99 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -121,6 +121,7 @@ class GenerateReqInput: bootstrap_host: Optional[Union[List[str], str]] = None bootstrap_port: Optional[Union[List[Optional[int]], int]] = None bootstrap_room: Optional[Union[List[int], int]] = None + bootstrap_pair_key: Optional[Union[List[str], str]] = None # For data parallel rank routing data_parallel_rank: Optional[int] = None @@ -128,6 +129,15 @@ class GenerateReqInput: # For background responses (OpenAI responses API) background: bool = False + # Conversation id used for tracking requests + conversation_id: Optional[str] = None + + # Label for the request + label: Optional[str] = None + + # Image gen grpc migration + return_bytes: bool = False + def contains_mm_input(self) -> bool: return ( has_valid_data(self.image_data) @@ -258,6 +268,7 @@ class GenerateReqInput: self._normalize_sampling_params(num) self._normalize_logprob_params(num) self._normalize_custom_logit_processor(num) + self._normalize_bootstrap_params(num) def _expand_inputs(self, num): """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling.""" @@ -297,6 +308,11 @@ class GenerateReqInput: self.image_data = [[self.image_data]] * num self.modalities = ["image"] * num elif isinstance(self.image_data, list): + # Handle empty list case - treat as no images + if len(self.image_data) == 0: + self.image_data = [None] * num + return + if len(self.image_data) != self.batch_size: raise ValueError( "The length of image_data should be equal to the batch size." @@ -421,6 +437,40 @@ class GenerateReqInput: "Cannot use list custom_logit_processor with parallel_sample_num > 1" ) + def _normalize_bootstrap_params(self, num): + """Normalize bootstrap parameters for batch processing.""" + # Normalize bootstrap_host + if self.bootstrap_host is None: + self.bootstrap_host = [None] * num + elif not isinstance(self.bootstrap_host, list): + self.bootstrap_host = [self.bootstrap_host] * num + elif isinstance(self.bootstrap_host, list): + self.bootstrap_host = self.bootstrap_host * self.parallel_sample_num + + # Normalize bootstrap_port + if self.bootstrap_port is None: + self.bootstrap_port = [None] * num + elif not isinstance(self.bootstrap_port, list): + self.bootstrap_port = [self.bootstrap_port] * num + elif isinstance(self.bootstrap_port, list): + self.bootstrap_port = self.bootstrap_port * self.parallel_sample_num + + # Normalize bootstrap_room + if self.bootstrap_room is None: + self.bootstrap_room = [None] * num + elif not isinstance(self.bootstrap_room, list): + self.bootstrap_room = [self.bootstrap_room + i for i in range(num)] + elif isinstance(self.bootstrap_room, list): + self.bootstrap_room = self.bootstrap_room * self.parallel_sample_num + + # Normalize bootstrap_pair_key + if self.bootstrap_pair_key is None: + self.bootstrap_pair_key = [None] * num + elif not isinstance(self.bootstrap_pair_key, list): + self.bootstrap_pair_key = [self.bootstrap_pair_key] * num + elif isinstance(self.bootstrap_pair_key, list): + self.bootstrap_pair_key = self.bootstrap_pair_key * self.parallel_sample_num + def _validate_session_params(self): """Validate that session parameters are properly formatted.""" if self.session_params is not None: @@ -453,7 +503,13 @@ class GenerateReqInput: return_text_in_logprobs=self.return_text_in_logprobs, stream=self.stream, log_metrics=self.log_metrics, + return_hidden_states=( + self.return_hidden_states[i] + if isinstance(self.return_hidden_states, list) + else self.return_hidden_states + ), modalities=self.modalities[i] if self.modalities else None, + session_params=self.session_params, lora_path=self.lora_path[i] if self.lora_path is not None else None, lora_id=self.lora_id[i] if self.lora_id is not None else None, custom_logit_processor=( @@ -461,11 +517,6 @@ class GenerateReqInput: if self.custom_logit_processor is not None else None ), - return_hidden_states=( - self.return_hidden_states[i] - if isinstance(self.return_hidden_states, list) - else self.return_hidden_states - ), # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list bootstrap_host=( self.bootstrap_host[i] if self.bootstrap_host is not None else None @@ -476,9 +527,17 @@ class GenerateReqInput: bootstrap_room=( self.bootstrap_room[i] if self.bootstrap_room is not None else None ), + bootstrap_pair_key=( + self.bootstrap_pair_key[i] + if self.bootstrap_pair_key is not None + else None + ), data_parallel_rank=( self.data_parallel_rank if self.data_parallel_rank is not None else None ), + conversation_id=self.conversation_id, + label=self.label, + return_bytes=self.return_bytes, ) @@ -504,27 +563,28 @@ class TokenizedGenerateReqInput: token_ids_logprob: List[int] # Whether to stream output stream: bool + # Whether to return hidden states + return_hidden_states: bool = False - # LoRA related - lora_id: Optional[str] = None # None means just use the base model # The input embeds input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None # Session info for continual prompting session_params: Optional[SessionParams] = None + # LoRA related + lora_id: Optional[str] = None # None means just use the base model + # Custom logit processor for advanced sampling control. Must be a serialized instance # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py # Use the processor's `to_str()` method to generate the serialized string. custom_logit_processor: Optional[str] = None - # Whether to return hidden states - return_hidden_states: bool = False - # For disaggregated inference bootstrap_host: Optional[str] = None bootstrap_port: Optional[int] = None bootstrap_room: Optional[int] = None + bootstrap_pair_key: Optional[str] = None # For data parallel rank routing data_parallel_rank: Optional[int] = None @@ -532,6 +592,12 @@ class TokenizedGenerateReqInput: # For dp balance dp_balance_id: int = -1 + # Label for the request + label: Optional[str] = None + + # Image gen grpc migration + return_bytes: bool = False + @dataclass class BatchTokenizedGenerateReqInput: @@ -738,9 +804,26 @@ class BatchTokenIDOut: # Hidden states output_hidden_states: List[List[float]] + # The information of placeholder tokens (e.g., image token) + # idx is the index of the token in the prompt after expansion. + # val is the length of padded tokens after expansion. + placeholder_tokens_idx: List[Optional[List[int]]] + placeholder_tokens_val: List[Optional[List[int]]] + @dataclass class BatchMultimodalDecodeReq: + decoded_ids: List[int] + input_token_logprobs_val: List[float] + input_token_logprobs_idx: List[int] + output_token_logprobs_val: List[float] + output_token_logprobs_idx: List[int] + read_offsets: List[int] + skip_special_tokens: List[bool] + spaces_between_special_tokens: List[bool] + image_resolutions: List[List[int]] + resize_image_resolutions: List[List[int]] + # The request id rids: List[str] finished_reasons: List[BaseFinishReason] @@ -750,6 +833,12 @@ class BatchMultimodalDecodeReq: completion_tokens: List[int] cached_tokens: List[int] + # Placeholder token info + placeholder_tokens_idx: List[Optional[List[int]]] + placeholder_tokens_val: List[Optional[List[int]]] + + return_bytes: bool = False + @dataclass class BatchStrOut: @@ -785,6 +874,9 @@ class BatchStrOut: # Hidden states output_hidden_states: List[List[float]] + placeholder_tokens_idx: List[Optional[List[int]]] + placeholder_tokens_val: List[Optional[List[int]]] + @dataclass class BatchMultimodalOut: @@ -792,14 +884,26 @@ class BatchMultimodalOut: rids: List[str] # The finish reason finished_reasons: List[dict] + decoded_ids: List[List[int]] # The outputs - outputs: List[List[Dict]] + outputs: Union[List[str | bytes], List[List[Dict]]] + + # probability values for input tokens and output tokens + input_token_logprobs_val: List[List[float]] + input_token_logprobs_idx: List[List[int]] + output_token_logprobs_val: List[List[float]] + output_token_logprobs_idx: List[List[int]] # Token counts prompt_tokens: List[int] completion_tokens: List[int] cached_tokens: List[int] + placeholder_tokens_idx: List[Optional[List[int]]] + placeholder_tokens_val: List[Optional[List[int]]] + + return_bytes: List[bool] + @dataclass class BatchEmbeddingOut: @@ -812,6 +916,9 @@ class BatchEmbeddingOut: # Token counts prompt_tokens: List[int] cached_tokens: List[int] + # Placeholder token info + placeholder_tokens_idx: List[Optional[List[int]]] + placeholder_tokens_val: List[Optional[List[int]]] @dataclass @@ -844,6 +951,12 @@ class UpdateWeightFromDiskReqInput: abort_all_requests: bool = False # Optional: Update weight version along with weights weight_version: Optional[str] = None + # Whether to update weights asynchronously + is_async: bool = False + # Whether to empty torch cache + torch_empty_cache: bool = False + # Whether to keep the scheduler paused after weight update + keep_pause: bool = False @dataclass @@ -983,6 +1096,7 @@ class AbortReq: abort_all: bool = False # The finished reason data finished_reason: Optional[Dict[str, Any]] = None + abort_reason: Optional[str] = None # used in MultiTokenzierManager mode rids: Optional[Union[List[str], str]] = None @@ -1061,6 +1175,7 @@ class ConfigureLoggingReq: log_requests_level: Optional[int] = None dump_requests_folder: Optional[str] = None dump_requests_threshold: Optional[int] = None + crash_dump_folder: Optional[str] = None @dataclass diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py index 4ab2e6a6f..0aadfba2c 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -195,6 +195,8 @@ def _handle_output_by_index(output, i): if output.output_hidden_states else None ), + placeholder_tokens_idx=None, + placeholder_tokens_val=None, ) elif isinstance(output, BatchEmbeddingOut): new_output = BatchEmbeddingOut( @@ -211,6 +213,8 @@ def _handle_output_by_index(output, i): cached_tokens=( [output.cached_tokens[i]] if len(output.cached_tokens) > i else None ), + placeholder_tokens_idx=None, + placeholder_tokens_val=None, ) elif isinstance(output, BatchStrOut): new_output = BatchStrOut( @@ -307,6 +311,8 @@ def _handle_output_by_index(output, i): if output.output_hidden_states else None ), + placeholder_tokens_idx=None, + placeholder_tokens_val=None, ) elif isinstance(output, BatchMultimodalOut): new_output = BatchMultimodalOut( @@ -328,6 +334,8 @@ def _handle_output_by_index(output, i): cached_tokens=( [output.cached_tokens[i]] if len(output.cached_tokens) > i else None ), + placeholder_tokens_idx=None, + placeholder_tokens_val=None, ) else: new_output = output diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index c6205a094..d931759bb 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -700,6 +700,8 @@ class SchedulerOutputProcessorMixin: output_token_ids_logprobs_val, output_token_ids_logprobs_idx, output_hidden_states, + placeholder_tokens_idx=None, + placeholder_tokens_val=None, ) ) @@ -719,6 +721,12 @@ class SchedulerOutputProcessorMixin: cached_tokens.append(req.cached_tokens) self.send_to_detokenizer.send_pyobj( BatchEmbeddingOut( - rids, finished_reasons, embeddings, prompt_tokens, cached_tokens + rids, + finished_reasons, + embeddings, + prompt_tokens, + cached_tokens, + placeholder_tokens_idx=None, + placeholder_tokens_val=None, ) )