diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py
index 0ff49f135..ec6ce7f99 100644
--- a/python/sglang/srt/entrypoints/openai/protocol.py
+++ b/python/sglang/srt/entrypoints/openai/protocol.py
@@ -228,6 +228,10 @@ class CompletionRequest(BaseModel):
     # For request id
     rid: Optional[Union[List[str], str]] = None
+    # Extra key for classifying the request (e.g. cache_salt)
+    extra_key: Optional[Union[List[str], str]] = None
+    # Cache salt for request caching
+    cache_salt: Optional[Union[List[str], str]] = None
     # Priority for the request
     priority: Optional[int] = None
@@ -545,6 +549,10 @@ class ChatCompletionRequest(BaseModel):
     # For request id
     rid: Optional[Union[List[str], str]] = None
+    # Extra key for classifying the request (e.g. cache_salt)
+    extra_key: Optional[Union[List[str], str]] = None
+    # Cache salt for request caching
+    cache_salt: Optional[Union[List[str], str]] = None
     # Priority for the request
     priority: Optional[int] = None
@@ -778,6 +786,13 @@ class ResponsesRequest(BaseModel):
         description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.",
     )
     priority: int = Field(default=0, description="Request priority")
+    extra_key: Optional[str] = Field(
+        default=None,
+        description="Extra key for classifying the request (e.g. cache_salt)",
+    )
+    cache_salt: Optional[str] = Field(
+        default=None, description="Cache salt for request caching"
+    )
 
     # SGLang-specific sampling parameters
     frequency_penalty: float = 0.0
diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py
index 5bc505108..f0038bdea 100644
--- a/python/sglang/srt/entrypoints/openai/serving_base.py
+++ b/python/sglang/srt/entrypoints/openai/serving_base.py
@@ -86,6 +86,19 @@ class OpenAIServingBase(ABC):
         return f"{self._request_id_prefix()}{uuid.uuid4().hex}"
 
+    def _compute_extra_key(self, request: OpenAIServingRequest) -> Optional[str]:
+        """Compute the final extra_key by concatenating cache_salt and extra_key if both are provided."""
+        parts = []
+        for key in ["cache_salt", "extra_key"]:
+            value = getattr(request, key, None)
+            if value:
+                if not isinstance(value, str):
+                    raise TypeError(
+                        f"Value of {key} must be a string, but got {type(value).__name__}"
+                    )
+                parts.append(value)
+        return "".join(parts) if parts else None
+
     @abstractmethod
     def _convert_to_internal_request(
         self,
diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py
index 8bd57fc9e..df40ebead 100644
--- a/python/sglang/srt/entrypoints/openai/serving_chat.py
+++ b/python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -149,6 +149,7 @@ class OpenAIServingChat(OpenAIServingBase):
             bootstrap_room=request.bootstrap_room,
             return_hidden_states=request.return_hidden_states,
             rid=request.rid,
+            extra_key=self._compute_extra_key(request),
             priority=request.priority,
             customer_labels=customer_labels,
         )
diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py
index 6aa4fe19e..e394b733b 100644
--- a/python/sglang/srt/entrypoints/openai/serving_completions.py
+++ b/python/sglang/srt/entrypoints/openai/serving_completions.py
@@ -107,6 +107,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
             bootstrap_room=request.bootstrap_room,
             return_hidden_states=request.return_hidden_states,
             rid=request.rid,
+            extra_key=self._compute_extra_key(request),
             priority=request.priority,
             customer_labels=customer_labels,
         )
diff --git a/python/sglang/srt/entrypoints/openai/serving_responses.py b/python/sglang/srt/entrypoints/openai/serving_responses.py
index 3f7619678..5e965e3bb 100644
--- a/python/sglang/srt/entrypoints/openai/serving_responses.py
+++ b/python/sglang/srt/entrypoints/openai/serving_responses.py
@@ -245,6 +245,7 @@ class OpenAIServingResponses(OpenAIServingChat):
             sampling_params=sampling_params,
             stream=request.stream,
             rid=request.request_id,
+            extra_key=self._compute_extra_key(request),
             background=request.background,
         )
@@ -1250,6 +1251,7 @@ class OpenAIServingResponses(OpenAIServingChat):
             sampling_params=sampling_params,
             stream=adapted_request.stream,
             rid=request_id,
+            extra_key=adapted_request.extra_key,
             return_logprob=adapted_request.return_logprob,
             logprob_start_len=adapted_request.logprob_start_len,
             top_logprobs_num=adapted_request.top_logprobs_num,
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 86cfcf945..5d42fde0d 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -84,6 +84,8 @@ class GenerateReqInput:
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
     rid: Optional[Union[List[str], str]] = None
+    # Extra key for classifying the request (e.g. cache_salt)
+    extra_key: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
     # If return logprobs, the start location in the prompt for returning logprobs.
@@ -606,6 +608,9 @@ class TokenizedGenerateReqInput:
     # Priority for the request
     priority: Optional[int] = None
 
+    # Extra key for classifying the request (e.g. cache_salt)
+    extra_key: Optional[str] = None
+
     # Image gen grpc migration
     return_bytes: bool = False
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index f46e160cd..3a3a6b06b 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -491,7 +491,7 @@ class Req:
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
 
-        # extra key for classifying the request (e.g. lora_id, cache_salt)
+        # extra key for classifying the request (e.g. cache_salt)
         if lora_id is not None:
             extra_key = (
                 extra_key or ""
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 3d901ceb5..dd341aa3a 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -750,6 +750,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 return_hidden_states=obj.return_hidden_states,
                 data_parallel_rank=obj.data_parallel_rank,
                 priority=obj.priority,
+                extra_key=obj.extra_key,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
diff --git a/test/srt/openai_server/features/test_cache_report.py b/test/srt/openai_server/features/test_cache_report.py
index 999111a2e..939556993 100644
--- a/test/srt/openai_server/features/test_cache_report.py
+++ b/test/srt/openai_server/features/test_cache_report.py
@@ -207,6 +207,84 @@ class TestCacheReport(CustomTestCase):
         # asyncio.run(run_test())
 
+    def test_cache_salt_effectiveness(self):
+        print("=" * 100)
+        print("Testing cache_salt effectiveness")
+
+        # Use a unique message to avoid interference with other tests
+        test_message = "What is the capital of Japan?"
+
+        # First request with cache_salt "salt1"
+        response1 = self.client.chat.completions.create(
+            model=self.model,
+            messages=[{"role": "user", "content": test_message}],
+            temperature=0,
+            max_tokens=10,
+            extra_body={"cache_salt": "salt1"},
+        )
+        cached_tokens_1_first = int(response1.usage.prompt_tokens_details.cached_tokens)
+        prompt_tokens_1 = int(response1.usage.prompt_tokens)
+        print(
+            f"First request with salt1 - cached_tokens: {cached_tokens_1_first}, prompt_tokens: {prompt_tokens_1}"
+        )
+
+        # Second request with same cache_salt "salt1" - should get cache hit
+        response2 = self.client.chat.completions.create(
+            model=self.model,
+            messages=[{"role": "user", "content": test_message}],
+            temperature=0,
+            max_tokens=10,
+            extra_body={"cache_salt": "salt1"},
+        )
+        cached_tokens_1_second = int(
+            response2.usage.prompt_tokens_details.cached_tokens
+        )
+        print(
+            f"Second request with salt1 - cached_tokens: {cached_tokens_1_second}, prompt_tokens: {prompt_tokens_1}"
+        )
+
+        # Verify cache hit for same salt
+        assert (
+            cached_tokens_1_second > cached_tokens_1_first
+        ), "Should have cache hit with same cache_salt"
+        assert (
+            cached_tokens_1_second == prompt_tokens_1 - 1
+        ), "Should cache all prompt tokens except the last one"
+
+        # Third request with different cache_salt "salt2" - should not get cache hit
+        response3 = self.client.chat.completions.create(
+            model=self.model,
+            messages=[{"role": "user", "content": test_message}],
+            temperature=0,
+            max_tokens=10,
+            extra_body={"cache_salt": "salt2"},
+        )
+        cached_tokens_2_first = int(response3.usage.prompt_tokens_details.cached_tokens)
+        print(f"First request with salt2 - cached_tokens: {cached_tokens_2_first}")
+
+        # Verify no cache hit for different salt (should be similar to first request with salt1)
+        assert (
+            cached_tokens_2_first <= cached_tokens_1_first + self.min_cached
+        ), "Different cache_salt should not share cache"
+
+        # Fourth request with same cache_salt "salt2" - should now get cache hit
+        response4 = self.client.chat.completions.create(
+            model=self.model,
+            messages=[{"role": "user", "content": test_message}],
+            temperature=0,
+            max_tokens=10,
+            extra_body={"cache_salt": "salt2"},
+        )
+        cached_tokens_2_second = int(
+            response4.usage.prompt_tokens_details.cached_tokens
+        )
+        print(f"Second request with salt2 - cached_tokens: {cached_tokens_2_second}")
+
+        # Verify cache hit for salt2
+        assert (
+            cached_tokens_2_second == cached_tokens_2_first
+        ), "Should have cache hit with same cache_salt for salt2"
+
 
 
 if __name__ == "__main__":
     unittest.main()
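
Below is a minimal client-side sketch of how the new fields would be exercised against a running OpenAI-compatible sglang server. It is illustrative only and not part of the patch: the base URL, API key, and model name are placeholders, and it assumes the server reports cached prompt tokens (as the server launched by test_cache_report.py does). The expected cache behavior mirrors what test_cache_salt_effectiveness asserts above.

# Usage sketch (placeholders: base_url, api_key, model name).
# Requests that share a cache_salt may reuse each other's prefix cache;
# requests with a different salt are isolated from it.
import openai

client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")


def cached_tokens(extra_body):
    # Send the same short prompt each time so prefix-cache reuse is easy to observe.
    resp = client.chat.completions.create(
        model="default",  # placeholder model name
        messages=[{"role": "user", "content": "What is the capital of Japan?"}],
        temperature=0,
        max_tokens=10,
        extra_body=extra_body,
    )
    # prompt_tokens_details.cached_tokens is only populated when the server
    # reports cache usage (e.g. launched with cache reporting enabled).
    return resp.usage.prompt_tokens_details.cached_tokens


print(cached_tokens({"cache_salt": "tenant-a"}))  # first request: little or no cache hit
print(cached_tokens({"cache_salt": "tenant-a"}))  # same salt: prefix cache reused
print(cached_tokens({"cache_salt": "tenant-b"}))  # different salt: isolated from tenant-a

# When both fields are set, the server-side _compute_extra_key helper concatenates
# them (cache_salt first) into the single extra_key forwarded to the scheduler.
print(cached_tokens({"cache_salt": "tenant-a", "extra_key": "session-1"}))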