diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py
index 6bc975c04..b44f338e5 100644
--- a/python/sglang/srt/entrypoints/openai/protocol.py
+++ b/python/sglang/srt/entrypoints/openai/protocol.py
@@ -196,6 +196,9 @@ class CompletionRequest(BaseModel):
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None
 
+    # For request id
+    rid: Optional[Union[List[str], str]] = None
+
     @field_validator("max_tokens")
     @classmethod
     def validate_max_tokens_positive(cls, v):
@@ -430,8 +433,8 @@ class ChatCompletionRequest(BaseModel):
     stream_reasoning: bool = True
     chat_template_kwargs: Optional[Dict] = None
 
-    # The request id.
-    rid: Optional[str] = None
+    # For request id
+    rid: Optional[Union[List[str], str]] = None
 
     # For PD disaggregation
     bootstrap_host: Optional[str] = None
@@ -529,7 +532,7 @@ class EmbeddingRequest(BaseModel):
     user: Optional[str] = None
 
     # The request id.
-    rid: Optional[str] = None
+    rid: Optional[Union[List[str], str]] = None
 
 
 class EmbeddingObject(BaseModel):
diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py
index 9857bcd9e..b7d974e47 100644
--- a/python/sglang/srt/entrypoints/openai/serving_chat.py
+++ b/python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -95,6 +95,7 @@ class OpenAIServingChat(OpenAIServingBase):
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
             return_hidden_states=request.return_hidden_states,
+            rid=request.rid,
         )
 
         return adapted_request, request
diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py
index 3db881641..992787132 100644
--- a/python/sglang/srt/entrypoints/openai/serving_completions.py
+++ b/python/sglang/srt/entrypoints/openai/serving_completions.py
@@ -87,6 +87,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
             return_hidden_states=request.return_hidden_states,
+            rid=request.rid,
         )
 
         return adapted_request, request
diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py
index 4f2db1dbe..b9ac4559f 100644
--- a/python/sglang/srt/entrypoints/openai/serving_embedding.py
+++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py
@@ -119,6 +119,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
 
         adapted_request = EmbeddingReqInput(
             **prompt_kwargs,
+            rid=request.rid,
         )
 
         return adapted_request, request
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 0fcb38227..9451827a9 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -319,8 +319,16 @@ class GenerateReqInput:
         """Normalize request IDs for batch processing."""
         if self.rid is None:
             self.rid = [uuid.uuid4().hex for _ in range(num)]
-        elif not isinstance(self.rid, list):
-            raise ValueError("The rid should be a list for batch processing.")
+        elif isinstance(self.rid, str):
+            new_rids = [f"{self.rid}_{i}" for i in range(num)]
+            self.rid = new_rids
+        elif isinstance(self.rid, list):
+            if len(self.rid) != num:
+                raise ValueError(
+                    "The specified rids length mismatch with the batch_size for batch processing."
+                )
+        else:
+            raise ValueError("The rid should be a string or a list of strings.")
 
     def _normalize_logprob_params(self, num):
         """Normalize logprob-related parameters for batch processing."""
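
For reference, a minimal standalone sketch of the batch rid handling introduced in the `io_struct.py` hunk. The `normalize_rid` helper below is illustrative only (it is not the actual `GenerateReqInput._normalize_rid` method), but it follows the same branches: no rid generates fresh UUIDs, a single string is fanned out with an index suffix, and an explicit list must match the batch size.

```python
import uuid
from typing import List, Optional, Union


def normalize_rid(rid: Optional[Union[List[str], str]], num: int) -> List[str]:
    """Illustrative stand-in for the new batch rid normalization."""
    if rid is None:
        # No rid supplied: generate one fresh id per request in the batch.
        return [uuid.uuid4().hex for _ in range(num)]
    if isinstance(rid, str):
        # A single string rid is expanded to "<rid>_0", "<rid>_1", ...
        return [f"{rid}_{i}" for i in range(num)]
    if isinstance(rid, list):
        # An explicit list must have exactly one rid per request in the batch.
        if len(rid) != num:
            raise ValueError(
                "The specified rids length mismatch with the batch_size for batch processing."
            )
        return rid
    raise ValueError("The rid should be a string or a list of strings.")


# Example: a batch of 3 prompts with rid="my-batch" yields
# ["my-batch_0", "my-batch_1", "my-batch_2"].
print(normalize_rid("my-batch", 3))
```

Since `CompletionRequest`, `ChatCompletionRequest`, and `EmbeddingRequest` now all accept the widened `rid` field and the serving layers forward it into the adapted request, a client-supplied rid (single string or per-prompt list) flows through to `GenerateReqInput` / `EmbeddingReqInput` and is normalized as sketched above.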