Refactors radix cache for extra key support (#10317)

Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
Xinyuan Tong
2025-09-21 11:16:16 -07:00
committed by GitHub
parent fc3e542009
commit 12d6cf18f0
13 changed files with 821 additions and 574 deletions

View File

@@ -61,8 +61,8 @@ from sglang.srt.mem_cache.allocator import (
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
@@ -457,6 +457,7 @@ class Req:
vocab_size: Optional[int] = None,
priority: Optional[int] = None,
metrics_collector: Optional[SchedulerMetricsCollector] = None,
extra_key: Optional[str] = None,
):
# Input and output info
self.rid = rid
@@ -489,6 +490,14 @@ class Req:
self.sampling_params = sampling_params
self.custom_logit_processor = custom_logit_processor
self.return_hidden_states = return_hidden_states
# extra key for classifying the request (e.g. lora_id, cache_salt)
if lora_id is not None:
extra_key = (
extra_key or ""
) + lora_id # lora_id is concatenated to the extra key
self.extra_key = extra_key
self.lora_id = lora_id
# Memory pool info
@@ -679,26 +688,16 @@ class Req:
):
self.fill_ids = self.origin_input_ids + self.output_ids
if tree_cache is not None:
if isinstance(tree_cache, LoRARadixCache):
(
self.prefix_indices,
self.last_node,
self.last_host_node,
self.host_hit_length,
) = tree_cache.match_prefix_with_lora_id(
key=LoRAKey(
lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids()
),
)
else:
(
self.prefix_indices,
self.last_node,
self.last_host_node,
self.host_hit_length,
) = tree_cache.match_prefix(
key=self.adjust_max_prefix_ids(),
)
(
self.prefix_indices,
self.last_node,
self.last_host_node,
self.host_hit_length,
) = tree_cache.match_prefix(
key=RadixKey(
token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
),
)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
def adjust_max_prefix_ids(self):