feat: add cache_salt support to request (#10718)
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
@@ -228,6 +228,10 @@ class CompletionRequest(BaseModel):
|
||||
|
||||
# For request id
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Extra key for classifying the request (e.g. cache_salt)
|
||||
extra_key: Optional[Union[List[str], str]] = None
|
||||
# Cache salt for request caching
|
||||
cache_salt: Optional[Union[List[str], str]] = None
|
||||
# Priority for the request
|
||||
priority: Optional[int] = None
|
||||
|
||||
@@ -545,6 +549,10 @@ class ChatCompletionRequest(BaseModel):
|
||||
|
||||
# For request id
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Extra key for classifying the request (e.g. cache_salt)
|
||||
extra_key: Optional[Union[List[str], str]] = None
|
||||
# Cache salt for request caching
|
||||
cache_salt: Optional[Union[List[str], str]] = None
|
||||
# Priority for the request
|
||||
priority: Optional[int] = None
|
||||
|
||||
@@ -778,6 +786,13 @@ class ResponsesRequest(BaseModel):
|
||||
description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.",
|
||||
)
|
||||
priority: int = Field(default=0, description="Request priority")
|
||||
extra_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Extra key for classifying the request (e.g. cache_salt)",
|
||||
)
|
||||
cache_salt: Optional[str] = Field(
|
||||
default=None, description="Cache salt for request caching"
|
||||
)
|
||||
|
||||
# SGLang-specific sampling parameters
|
||||
frequency_penalty: float = 0.0
|
||||
|
||||
@@ -86,6 +86,19 @@ class OpenAIServingBase(ABC):
|
||||
|
||||
return f"{self._request_id_prefix()}{uuid.uuid4().hex}"
|
||||
|
||||
def _compute_extra_key(self, request: OpenAIServingRequest) -> Optional[str]:
    """Build the final extra_key by joining cache_salt and extra_key (in that order).

    Either field may be absent or falsy on the request; such fields are skipped.
    Returns None when neither field contributes anything.

    Raises:
        TypeError: if a present field is not a plain string (the request models
            also allow List[str]; per-item keys are not supported here).
    """
    pieces = []
    for field_name in ("cache_salt", "extra_key"):
        field_value = getattr(request, field_name, None)
        # Falsy values (None, "") contribute nothing to the combined key.
        if not field_value:
            continue
        if not isinstance(field_value, str):
            raise TypeError(
                f"Value of {field_name} must be a string, but got {type(field_value).__name__}"
            )
        pieces.append(field_value)
    if not pieces:
        return None
    return "".join(pieces)
|
||||
|
||||
@abstractmethod
|
||||
def _convert_to_internal_request(
|
||||
self,
|
||||
|
||||
@@ -149,6 +149,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
return_hidden_states=request.return_hidden_states,
|
||||
rid=request.rid,
|
||||
extra_key=self._compute_extra_key(request),
|
||||
priority=request.priority,
|
||||
customer_labels=customer_labels,
|
||||
)
|
||||
|
||||
@@ -107,6 +107,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
return_hidden_states=request.return_hidden_states,
|
||||
rid=request.rid,
|
||||
extra_key=self._compute_extra_key(request),
|
||||
priority=request.priority,
|
||||
customer_labels=customer_labels,
|
||||
)
|
||||
|
||||
@@ -245,6 +245,7 @@ class OpenAIServingResponses(OpenAIServingChat):
|
||||
sampling_params=sampling_params,
|
||||
stream=request.stream,
|
||||
rid=request.request_id,
|
||||
extra_key=self._compute_extra_key(request),
|
||||
background=request.background,
|
||||
)
|
||||
|
||||
@@ -1250,6 +1251,7 @@ class OpenAIServingResponses(OpenAIServingChat):
|
||||
sampling_params=sampling_params,
|
||||
stream=adapted_request.stream,
|
||||
rid=request_id,
|
||||
extra_key=adapted_request.extra_key,
|
||||
return_logprob=adapted_request.return_logprob,
|
||||
logprob_start_len=adapted_request.logprob_start_len,
|
||||
top_logprobs_num=adapted_request.top_logprobs_num,
|
||||
|
||||
@@ -84,6 +84,8 @@ class GenerateReqInput:
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
||||
# The request id.
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Extra key for classifying the request (e.g. cache_salt)
|
||||
extra_key: Optional[Union[List[str], str]] = None
|
||||
# Whether to return logprobs.
|
||||
return_logprob: Optional[Union[List[bool], bool]] = None
|
||||
# If return logprobs, the start location in the prompt for returning logprobs.
|
||||
@@ -606,6 +608,9 @@ class TokenizedGenerateReqInput:
|
||||
# Priority for the request
|
||||
priority: Optional[int] = None
|
||||
|
||||
# Extra key for classifying the request (e.g. cache_salt)
|
||||
extra_key: Optional[str] = None
|
||||
|
||||
# Image gen grpc migration
|
||||
return_bytes: bool = False
|
||||
|
||||
|
||||
@@ -491,7 +491,7 @@ class Req:
|
||||
self.custom_logit_processor = custom_logit_processor
|
||||
self.return_hidden_states = return_hidden_states
|
||||
|
||||
# extra key for classifying the request (e.g. lora_id, cache_salt)
|
||||
# extra key for classifying the request (e.g. cache_salt)
|
||||
if lora_id is not None:
|
||||
extra_key = (
|
||||
extra_key or ""
|
||||
|
||||
@@ -750,6 +750,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
return_hidden_states=obj.return_hidden_states,
|
||||
data_parallel_rank=obj.data_parallel_rank,
|
||||
priority=obj.priority,
|
||||
extra_key=obj.extra_key,
|
||||
)
|
||||
elif isinstance(obj, EmbeddingReqInput):
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
|
||||
Reference in New Issue
Block a user