diff --git a/python/sglang/srt/constrained/base_cache.py b/python/sglang/srt/constrained/base_tool_cache.py similarity index 96% rename from python/sglang/srt/constrained/base_cache.py rename to python/sglang/srt/constrained/base_tool_cache.py index 7d6b20469..4cbb6bd22 100644 --- a/python/sglang/srt/constrained/base_cache.py +++ b/python/sglang/srt/constrained/base_tool_cache.py @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. """ -"""Base cache class.""" +"""Base tool cache for constrained decoding tools.""" import time -class BaseCache: +class BaseToolCache: def __init__(self, enable=True): self.enable = enable self.reset() diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py index db0344275..6df6bec51 100644 --- a/python/sglang/srt/constrained/fsm_cache.py +++ b/python/sglang/srt/constrained/fsm_cache.py @@ -16,10 +16,10 @@ limitations under the License. """Cache for the compressed finite state machine.""" from sglang.srt.constrained import RegexGuide, TransformerTokenizer -from sglang.srt.constrained.base_cache import BaseCache +from sglang.srt.constrained.base_tool_cache import BaseToolCache -class FSMCache(BaseCache): +class FSMCache(BaseToolCache): def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True): super().__init__(enable=enable) diff --git a/python/sglang/srt/constrained/jump_forward.py b/python/sglang/srt/constrained/jump_forward.py index 30c316a84..7b694318e 100644 --- a/python/sglang/srt/constrained/jump_forward.py +++ b/python/sglang/srt/constrained/jump_forward.py @@ -30,7 +30,7 @@ from sglang.srt.constrained import ( make_byte_level_fsm, make_deterministic_fsm, ) -from sglang.srt.constrained.base_cache import BaseCache +from sglang.srt.constrained.base_tool_cache import BaseToolCache IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" @@ -151,7 +151,7 @@ class JumpForwardMap: ) -class JumpForwardCache(BaseCache): +class JumpForwardCache(BaseToolCache): def __init__(self): super().__init__() diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 157cfd778..88d33db5a 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -28,6 +28,7 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs from sglang.global_config import global_config from sglang.srt.constrained import RegexGuide from sglang.srt.constrained.jump_forward import JumpForwardMap +from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.mem_cache.radix_cache import RadixCache @@ -486,15 +487,33 @@ class Batch: req = self.reqs[idx] retracted_reqs.append(req) - # TODO: apply more fine-grained retraction - last_uncached_pos = len(req.prefix_indices) - token_indices = self.req_to_token_pool.req_to_token[ - req_pool_indices_cpu[idx] - ][last_uncached_pos : seq_lens_cpu[idx]] - self.token_to_kv_pool.free(token_indices) + if isinstance(self.tree_cache, ChunkCache): + # ChunkCache does not have eviction + token_indices = self.req_to_token_pool.req_to_token[ + req_pool_indices_cpu[idx] + ][: seq_lens_cpu[idx]] + self.token_to_kv_pool.free(token_indices) + self.req_to_token_pool.free(int(req_pool_indices_cpu[idx])) + del self.tree_cache.entries[req.rid] + else: + # TODO: apply more fine-grained retraction + last_uncached_pos = len(req.prefix_indices) + token_indices = self.req_to_token_pool.req_to_token[ + req_pool_indices_cpu[idx] + ][last_uncached_pos : seq_lens_cpu[idx]] + self.token_to_kv_pool.free(token_indices) + self.req_to_token_pool.free(int(req_pool_indices_cpu[idx])) - # release the last node - self.tree_cache.dec_lock_ref(req.last_node) + # release the last node + self.tree_cache.dec_lock_ref(req.last_node) + + # NOTE(lsyin): we should use the newly evictable memory instantly. + residual_size = ( + len(sorted_indices) * global_config.retract_decode_steps + - self.token_to_kv_pool.available_size() + ) + residual_size = max(0, residual_size) + self.tree_cache.evict(residual_size, self.token_to_kv_pool.free) req.prefix_indices = None req.last_node = None @@ -575,6 +594,7 @@ class Batch: if req_pool_indices_cpu is None: req_pool_indices_cpu = self.req_pool_indices.tolist() self.tree_cache.cache_req( + rid=req.rid, token_ids=cur_all_ids, last_uncached_pos=len(req.prefix_indices), req_pool_idx=req_pool_indices_cpu[i], diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 1f024501e..33acf98e8 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -43,6 +43,7 @@ from sglang.srt.managers.schedule_batch import ( ForwardMode, Req, ) +from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.model_config import ModelConfig from sglang.srt.model_executor.model_runner import ModelRunner @@ -144,11 +145,20 @@ class ModelTpServer: ) # Init cache - self.tree_cache = RadixCache( - req_to_token_pool=self.model_runner.req_to_token_pool, - token_to_kv_pool=self.model_runner.token_to_kv_pool, - disable=server_args.disable_radix_cache, - ) + if ( + server_args.chunked_prefill_size is not None + and server_args.disable_radix_cache + ): + self.tree_cache = ChunkCache( + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + ) + else: + self.tree_cache = RadixCache( + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + disable=server_args.disable_radix_cache, + ) self.tree_cache_metrics = {"total": 0, "hit": 0} self.scheduler = PolicyScheduler( self.schedule_policy, @@ -354,7 +364,10 @@ class ModelTpServer: # Compute matched prefix length for req in self.waiting_queue: req.input_ids = req.origin_input_ids + req.output_ids - prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids) + prefix_indices, last_node = self.tree_cache.match_prefix( + rid=req.rid, + key=req.input_ids, + ) if req.return_logprob: prefix_indices = prefix_indices[: req.logprob_start_len] req.extend_input_len = len(req.input_ids) - len(prefix_indices) @@ -614,6 +627,7 @@ class ModelTpServer: req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy() for i, req in enumerate(batch.reqs): new_prefix_indices, new_last_node = self.tree_cache.cache_req( + rid=req.rid, token_ids=tuple(req.input_ids), last_uncached_pos=len(req.prefix_indices), req_pool_idx=req_pool_indices_cpu[i], @@ -771,6 +785,7 @@ class ModelTpServer: for i in finished_indices: req = batch.reqs[i] self.tree_cache.cache_req( + rid=req.rid, token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1], last_uncached_pos=len(req.prefix_indices), req_pool_idx=req_pool_indices_cpu[i], diff --git a/python/sglang/srt/mem_cache/base_cache.py b/python/sglang/srt/mem_cache/base_cache.py new file mode 100644 index 000000000..fe7e0b23a --- /dev/null +++ b/python/sglang/srt/mem_cache/base_cache.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod + + +class BasePrefixCache(ABC): + """Cache can be indexed by either rid or key.""" + + @abstractmethod + def reset(self): + pass + + @abstractmethod + def match_prefix(self, **kwargs): + pass + + @abstractmethod + def insert(self, **kwargs): + pass + + @abstractmethod + def cache_req(self, **kwargs): + pass + + @abstractmethod + def evict(self, num_tokens, evict_callback): + pass + + @abstractmethod + def inc_lock_ref(self, node): + pass + + @abstractmethod + def dec_lock_ref(self, node): + pass + + @abstractmethod + def evictable_size(self): + pass + + def total_size(self): + raise NotImplementedError + + def pretty_print(self): + raise NotImplementedError diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py new file mode 100644 index 000000000..3509bd1cd --- /dev/null +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -0,0 +1,60 @@ +"""Cache for chunked prefill, used when RadixCache is disabled.""" + +from sglang.srt.mem_cache.base_cache import BasePrefixCache + + +class ChunkCacheEntry: + def __init__(self, rid, value): + self.rid = rid + self.value = value + + +class ChunkCache(BasePrefixCache): + def __init__(self, req_to_token_pool, token_to_kv_pool): + self.disable = True + self.req_to_token_pool = req_to_token_pool + self.token_to_kv_pool = token_to_kv_pool + + self.reset() + + def reset(self): + self.entries = {} + + def match_prefix(self, rid, **kwargs): + if rid not in self.entries: + return [], None + + entry = self.entries[rid] + return entry.value, entry + + def cache_req( + self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs + ): + indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)] + if del_in_memory_pool: + assert rid in self.entries + self.req_to_token_pool.free(req_pool_idx) + self.token_to_kv_pool.free(indices) + return + + if rid not in self.entries: + self.entries[rid] = ChunkCacheEntry(rid, indices) + + entry = self.entries[rid] + entry.value = indices + return indices, entry + + def insert(self): + raise NotImplementedError + + def evict(self, num_tokens, evict_callback): + pass + + def inc_lock_ref(self, node): + return 0 + + def dec_lock_ref(self, node): + return 0 + + def evictable_size(self): + return 0 diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 8cff4e114..c6fc3191b 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -23,6 +23,8 @@ from collections import defaultdict import torch +from sglang.srt.mem_cache.base_cache import BasePrefixCache + class TreeNode: def __init__(self): @@ -46,7 +48,7 @@ def _key_match(key0, key1): return i -class RadixCache: +class RadixCache(BasePrefixCache): def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False): self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool = token_to_kv_pool @@ -62,7 +64,7 @@ class RadixCache: self.root_node.lock_ref = 1 self.evictable_size_ = 0 - def match_prefix(self, key): + def match_prefix(self, key, **kwargs): if self.disable: return [], self.root_node @@ -90,6 +92,7 @@ class RadixCache: req_pool_idx, del_in_memory_pool=True, old_last_node=None, + **kwargs, ): # Insert the request into radix cache indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)] diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ab4a350cf..794cc6993 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -419,10 +419,6 @@ class ServerArgs: self.dp_size > 1 and self.node_rank is not None ), "multi-node data parallel is not supported" - assert not ( - self.chunked_prefill_size is not None and self.disable_radix_cache - ), "chunked prefill is not supported with radix cache disabled currently" - @dataclasses.dataclass class PortArgs: