feat: add cache_salt support to request (#10718)
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
@@ -228,6 +228,10 @@ class CompletionRequest(BaseModel):
|
||||
|
||||
# For request id
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Extra key for classifying the request (e.g. cache_salt)
|
||||
extra_key: Optional[Union[List[str], str]] = None
|
||||
# Cache salt for request caching
|
||||
cache_salt: Optional[Union[List[str], str]] = None
|
||||
# Priority for the request
|
||||
priority: Optional[int] = None
|
||||
|
||||
@@ -545,6 +549,10 @@ class ChatCompletionRequest(BaseModel):
|
||||
|
||||
# For request id
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Extra key for classifying the request (e.g. cache_salt)
|
||||
extra_key: Optional[Union[List[str], str]] = None
|
||||
# Cache salt for request caching
|
||||
cache_salt: Optional[Union[List[str], str]] = None
|
||||
# Priority for the request
|
||||
priority: Optional[int] = None
|
||||
|
||||
@@ -778,6 +786,13 @@ class ResponsesRequest(BaseModel):
|
||||
description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.",
|
||||
)
|
||||
priority: int = Field(default=0, description="Request priority")
|
||||
extra_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Extra key for classifying the request (e.g. cache_salt)",
|
||||
)
|
||||
cache_salt: Optional[str] = Field(
|
||||
default=None, description="Cache salt for request caching"
|
||||
)
|
||||
|
||||
# SGLang-specific sampling parameters
|
||||
frequency_penalty: float = 0.0
|
||||
|
||||
@@ -86,6 +86,19 @@ class OpenAIServingBase(ABC):
|
||||
|
||||
return f"{self._request_id_prefix()}{uuid.uuid4().hex}"
|
||||
|
||||
def _compute_extra_key(self, request: OpenAIServingRequest) -> Optional[str]:
|
||||
"""Compute the final extra_key by concatenating cache_salt and extra_key if both are provided."""
|
||||
parts = []
|
||||
for key in ["cache_salt", "extra_key"]:
|
||||
value = getattr(request, key, None)
|
||||
if value:
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(
|
||||
f"Value of {key} must be a string, but got {type(value).__name__}"
|
||||
)
|
||||
parts.append(value)
|
||||
return "".join(parts) if parts else None
|
||||
|
||||
@abstractmethod
|
||||
def _convert_to_internal_request(
|
||||
self,
|
||||
|
||||
@@ -149,6 +149,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
return_hidden_states=request.return_hidden_states,
|
||||
rid=request.rid,
|
||||
extra_key=self._compute_extra_key(request),
|
||||
priority=request.priority,
|
||||
customer_labels=customer_labels,
|
||||
)
|
||||
|
||||
@@ -107,6 +107,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
return_hidden_states=request.return_hidden_states,
|
||||
rid=request.rid,
|
||||
extra_key=self._compute_extra_key(request),
|
||||
priority=request.priority,
|
||||
customer_labels=customer_labels,
|
||||
)
|
||||
|
||||
@@ -245,6 +245,7 @@ class OpenAIServingResponses(OpenAIServingChat):
|
||||
sampling_params=sampling_params,
|
||||
stream=request.stream,
|
||||
rid=request.request_id,
|
||||
extra_key=self._compute_extra_key(request),
|
||||
background=request.background,
|
||||
)
|
||||
|
||||
@@ -1250,6 +1251,7 @@ class OpenAIServingResponses(OpenAIServingChat):
|
||||
sampling_params=sampling_params,
|
||||
stream=adapted_request.stream,
|
||||
rid=request_id,
|
||||
extra_key=adapted_request.extra_key,
|
||||
return_logprob=adapted_request.return_logprob,
|
||||
logprob_start_len=adapted_request.logprob_start_len,
|
||||
top_logprobs_num=adapted_request.top_logprobs_num,
|
||||
|
||||
@@ -84,6 +84,8 @@ class GenerateReqInput:
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
||||
# The request id.
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Extra key for classifying the request (e.g. cache_salt)
|
||||
extra_key: Optional[Union[List[str], str]] = None
|
||||
# Whether to return logprobs.
|
||||
return_logprob: Optional[Union[List[bool], bool]] = None
|
||||
# If return logprobs, the start location in the prompt for returning logprobs.
|
||||
@@ -606,6 +608,9 @@ class TokenizedGenerateReqInput:
|
||||
# Priority for the request
|
||||
priority: Optional[int] = None
|
||||
|
||||
# Extra key for classifying the request (e.g. cache_salt)
|
||||
extra_key: Optional[str] = None
|
||||
|
||||
# Image gen grpc migration
|
||||
return_bytes: bool = False
|
||||
|
||||
|
||||
@@ -491,7 +491,7 @@ class Req:
|
||||
self.custom_logit_processor = custom_logit_processor
|
||||
self.return_hidden_states = return_hidden_states
|
||||
|
||||
# extra key for classifying the request (e.g. lora_id, cache_salt)
|
||||
# extra key for classifying the request (e.g. cache_salt)
|
||||
if lora_id is not None:
|
||||
extra_key = (
|
||||
extra_key or ""
|
||||
|
||||
@@ -750,6 +750,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
return_hidden_states=obj.return_hidden_states,
|
||||
data_parallel_rank=obj.data_parallel_rank,
|
||||
priority=obj.priority,
|
||||
extra_key=obj.extra_key,
|
||||
)
|
||||
elif isinstance(obj, EmbeddingReqInput):
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
|
||||
@@ -207,6 +207,84 @@ class TestCacheReport(CustomTestCase):
|
||||
|
||||
# asyncio.run(run_test())
|
||||
|
||||
def test_cache_salt_effectiveness(self):
    """Verify that cache_salt partitions prefix-cache entries.

    Two requests sharing a cache_salt should hit the radix cache on the
    repeat, while the same prompt under a different salt must start cold.
    """
    print("=" * 100)
    print("Testing cache_salt effectiveness")

    # Unique prompt so earlier tests cannot pre-populate the cache.
    test_message = "What is the capital of Japan?"

    def send_with_salt(salt):
        # Issue one deterministic chat request tagged with the given salt
        # and report (cached_tokens, prompt_tokens) from the usage block.
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": test_message}],
            temperature=0,
            max_tokens=10,
            extra_body={"cache_salt": salt},
        )
        cached = int(completion.usage.prompt_tokens_details.cached_tokens)
        total = int(completion.usage.prompt_tokens)
        return cached, total

    # First request with cache_salt "salt1": cold cache expected.
    cached_tokens_1_first, prompt_tokens_1 = send_with_salt("salt1")
    print(
        f"First request with salt1 - cached_tokens: {cached_tokens_1_first}, prompt_tokens: {prompt_tokens_1}"
    )

    # Second request with the same salt: should reuse the cached prefix.
    cached_tokens_1_second, _ = send_with_salt("salt1")
    print(
        f"Second request with salt1 - cached_tokens: {cached_tokens_1_second}, prompt_tokens: {prompt_tokens_1}"
    )

    assert (
        cached_tokens_1_second > cached_tokens_1_first
    ), "Should have cache hit with same cache_salt"
    assert (
        cached_tokens_1_second == prompt_tokens_1 - 1
    ), "Should cache all prompt tokens except the last one"

    # Third request under a different salt: must not see salt1's prefix.
    cached_tokens_2_first, _ = send_with_salt("salt2")
    print(f"First request with salt2 - cached_tokens: {cached_tokens_2_first}")

    assert (
        cached_tokens_2_first <= cached_tokens_1_first + self.min_cached
    ), "Different cache_salt should not share cache"

    # Fourth request repeats salt2: now the salt2 prefix should be warm.
    cached_tokens_2_second, _ = send_with_salt("salt2")
    print(f"Second request with salt2 - cached_tokens: {cached_tokens_2_second}")

    assert (
        cached_tokens_2_second == cached_tokens_2_first
    ), "Should have cache hit with same cache_salt for salt2"
|
||||
|
||||
|
||||
# Allow running this test module directly from the command line.
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user