[OAI] patch origin request_id logic (#7508)
This commit is contained in:
@@ -196,6 +196,9 @@ class CompletionRequest(BaseModel):
|
||||
bootstrap_port: Optional[int] = None
|
||||
bootstrap_room: Optional[int] = None
|
||||
|
||||
# For request id
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
|
||||
@field_validator("max_tokens")
|
||||
@classmethod
|
||||
def validate_max_tokens_positive(cls, v):
|
||||
@@ -430,8 +433,8 @@ class ChatCompletionRequest(BaseModel):
|
||||
stream_reasoning: bool = True
|
||||
chat_template_kwargs: Optional[Dict] = None
|
||||
|
||||
# The request id.
|
||||
rid: Optional[str] = None
|
||||
# For request id
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
|
||||
# For PD disaggregation
|
||||
bootstrap_host: Optional[str] = None
|
||||
@@ -529,7 +532,7 @@ class EmbeddingRequest(BaseModel):
|
||||
user: Optional[str] = None
|
||||
|
||||
# The request id.
|
||||
rid: Optional[str] = None
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
|
||||
|
||||
class EmbeddingObject(BaseModel):
|
||||
|
||||
@@ -95,6 +95,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
bootstrap_port=request.bootstrap_port,
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
return_hidden_states=request.return_hidden_states,
|
||||
rid=request.rid,
|
||||
)
|
||||
|
||||
return adapted_request, request
|
||||
|
||||
@@ -87,6 +87,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
bootstrap_port=request.bootstrap_port,
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
return_hidden_states=request.return_hidden_states,
|
||||
rid=request.rid,
|
||||
)
|
||||
|
||||
return adapted_request, request
|
||||
|
||||
@@ -119,6 +119,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
|
||||
|
||||
adapted_request = EmbeddingReqInput(
|
||||
**prompt_kwargs,
|
||||
rid=request.rid,
|
||||
)
|
||||
|
||||
return adapted_request, request
|
||||
|
||||
@@ -319,8 +319,16 @@ class GenerateReqInput:
|
||||
"""Normalize request IDs for batch processing."""
|
||||
if self.rid is None:
|
||||
self.rid = [uuid.uuid4().hex for _ in range(num)]
|
||||
elif not isinstance(self.rid, list):
|
||||
raise ValueError("The rid should be a list for batch processing.")
|
||||
elif isinstance(self.rid, str):
|
||||
new_rids = [f"{self.rid}_{i}" for i in range(num)]
|
||||
self.rid = new_rids
|
||||
elif isinstance(self.rid, list):
|
||||
if len(self.rid) != num:
|
||||
raise ValueError(
|
||||
"The specified rids length mismatch with the batch_size for batch processing."
|
||||
)
|
||||
else:
|
||||
raise ValueError("The rid should be a string or a list of strings.")
|
||||
|
||||
def _normalize_logprob_params(self, num):
|
||||
"""Normalize logprob-related parameters for batch processing."""
|
||||
|
||||
Reference in New Issue
Block a user