Support chunked prefill when radix cache is disabled (#811)
This commit is contained in:
@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""Base cache class."""
|
||||
"""Base tool cache for constrained decoding tools."""
|
||||
|
||||
import time
|
||||
|
||||
|
||||
class BaseCache:
|
||||
class BaseToolCache:
|
||||
def __init__(self, enable=True):
|
||||
self.enable = enable
|
||||
self.reset()
|
||||
@@ -16,10 +16,10 @@ limitations under the License.
|
||||
"""Cache for the compressed finite state machine."""
|
||||
|
||||
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
|
||||
from sglang.srt.constrained.base_cache import BaseCache
|
||||
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
||||
|
||||
|
||||
class FSMCache(BaseCache):
|
||||
class FSMCache(BaseToolCache):
|
||||
def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
|
||||
super().__init__(enable=enable)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from sglang.srt.constrained import (
|
||||
make_byte_level_fsm,
|
||||
make_deterministic_fsm,
|
||||
)
|
||||
from sglang.srt.constrained.base_cache import BaseCache
|
||||
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
||||
|
||||
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||
|
||||
@@ -151,7 +151,7 @@ class JumpForwardMap:
|
||||
)
|
||||
|
||||
|
||||
class JumpForwardCache(BaseCache):
|
||||
class JumpForwardCache(BaseToolCache):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.constrained import RegexGuide
|
||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
|
||||
@@ -486,15 +487,33 @@ class Batch:
|
||||
req = self.reqs[idx]
|
||||
retracted_reqs.append(req)
|
||||
|
||||
# TODO: apply more fine-grained retraction
|
||||
last_uncached_pos = len(req.prefix_indices)
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req_pool_indices_cpu[idx]
|
||||
][last_uncached_pos : seq_lens_cpu[idx]]
|
||||
self.token_to_kv_pool.free(token_indices)
|
||||
if isinstance(self.tree_cache, ChunkCache):
|
||||
# ChunkCache does not have eviction
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req_pool_indices_cpu[idx]
|
||||
][: seq_lens_cpu[idx]]
|
||||
self.token_to_kv_pool.free(token_indices)
|
||||
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
|
||||
del self.tree_cache.entries[req.rid]
|
||||
else:
|
||||
# TODO: apply more fine-grained retraction
|
||||
last_uncached_pos = len(req.prefix_indices)
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req_pool_indices_cpu[idx]
|
||||
][last_uncached_pos : seq_lens_cpu[idx]]
|
||||
self.token_to_kv_pool.free(token_indices)
|
||||
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
|
||||
|
||||
# release the last node
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
# release the last node
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
|
||||
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
||||
residual_size = (
|
||||
len(sorted_indices) * global_config.retract_decode_steps
|
||||
- self.token_to_kv_pool.available_size()
|
||||
)
|
||||
residual_size = max(0, residual_size)
|
||||
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
|
||||
|
||||
req.prefix_indices = None
|
||||
req.last_node = None
|
||||
@@ -575,6 +594,7 @@ class Batch:
|
||||
if req_pool_indices_cpu is None:
|
||||
req_pool_indices_cpu = self.req_pool_indices.tolist()
|
||||
self.tree_cache.cache_req(
|
||||
rid=req.rid,
|
||||
token_ids=cur_all_ids,
|
||||
last_uncached_pos=len(req.prefix_indices),
|
||||
req_pool_idx=req_pool_indices_cpu[i],
|
||||
|
||||
@@ -43,6 +43,7 @@ from sglang.srt.managers.schedule_batch import (
|
||||
ForwardMode,
|
||||
Req,
|
||||
)
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
@@ -144,11 +145,20 @@ class ModelTpServer:
|
||||
)
|
||||
|
||||
# Init cache
|
||||
self.tree_cache = RadixCache(
|
||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
if (
|
||||
server_args.chunked_prefill_size is not None
|
||||
and server_args.disable_radix_cache
|
||||
):
|
||||
self.tree_cache = ChunkCache(
|
||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||
)
|
||||
else:
|
||||
self.tree_cache = RadixCache(
|
||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
||||
self.scheduler = PolicyScheduler(
|
||||
self.schedule_policy,
|
||||
@@ -354,7 +364,10 @@ class ModelTpServer:
|
||||
# Compute matched prefix length
|
||||
for req in self.waiting_queue:
|
||||
req.input_ids = req.origin_input_ids + req.output_ids
|
||||
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
|
||||
prefix_indices, last_node = self.tree_cache.match_prefix(
|
||||
rid=req.rid,
|
||||
key=req.input_ids,
|
||||
)
|
||||
if req.return_logprob:
|
||||
prefix_indices = prefix_indices[: req.logprob_start_len]
|
||||
req.extend_input_len = len(req.input_ids) - len(prefix_indices)
|
||||
@@ -614,6 +627,7 @@ class ModelTpServer:
|
||||
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
|
||||
for i, req in enumerate(batch.reqs):
|
||||
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
|
||||
rid=req.rid,
|
||||
token_ids=tuple(req.input_ids),
|
||||
last_uncached_pos=len(req.prefix_indices),
|
||||
req_pool_idx=req_pool_indices_cpu[i],
|
||||
@@ -771,6 +785,7 @@ class ModelTpServer:
|
||||
for i in finished_indices:
|
||||
req = batch.reqs[i]
|
||||
self.tree_cache.cache_req(
|
||||
rid=req.rid,
|
||||
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
|
||||
last_uncached_pos=len(req.prefix_indices),
|
||||
req_pool_idx=req_pool_indices_cpu[i],
|
||||
|
||||
43
python/sglang/srt/mem_cache/base_cache.py
Normal file
43
python/sglang/srt/mem_cache/base_cache.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BasePrefixCache(ABC):
|
||||
"""Cache can be indexed by either rid or key."""
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def match_prefix(self, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_req(self, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def evict(self, num_tokens, evict_callback):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def inc_lock_ref(self, node):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def dec_lock_ref(self, node):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def evictable_size(self):
|
||||
pass
|
||||
|
||||
def total_size(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def pretty_print(self):
|
||||
raise NotImplementedError
|
||||
60
python/sglang/srt/mem_cache/chunk_cache.py
Normal file
60
python/sglang/srt/mem_cache/chunk_cache.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||
|
||||
from sglang.srt.mem_cache.base_cache import BasePrefixCache
|
||||
|
||||
|
||||
class ChunkCacheEntry:
|
||||
def __init__(self, rid, value):
|
||||
self.rid = rid
|
||||
self.value = value
|
||||
|
||||
|
||||
class ChunkCache(BasePrefixCache):
|
||||
def __init__(self, req_to_token_pool, token_to_kv_pool):
|
||||
self.disable = True
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.entries = {}
|
||||
|
||||
def match_prefix(self, rid, **kwargs):
|
||||
if rid not in self.entries:
|
||||
return [], None
|
||||
|
||||
entry = self.entries[rid]
|
||||
return entry.value, entry
|
||||
|
||||
def cache_req(
|
||||
self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
|
||||
):
|
||||
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
|
||||
if del_in_memory_pool:
|
||||
assert rid in self.entries
|
||||
self.req_to_token_pool.free(req_pool_idx)
|
||||
self.token_to_kv_pool.free(indices)
|
||||
return
|
||||
|
||||
if rid not in self.entries:
|
||||
self.entries[rid] = ChunkCacheEntry(rid, indices)
|
||||
|
||||
entry = self.entries[rid]
|
||||
entry.value = indices
|
||||
return indices, entry
|
||||
|
||||
def insert(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def evict(self, num_tokens, evict_callback):
|
||||
pass
|
||||
|
||||
def inc_lock_ref(self, node):
|
||||
return 0
|
||||
|
||||
def dec_lock_ref(self, node):
|
||||
return 0
|
||||
|
||||
def evictable_size(self):
|
||||
return 0
|
||||
@@ -23,6 +23,8 @@ from collections import defaultdict
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.base_cache import BasePrefixCache
|
||||
|
||||
|
||||
class TreeNode:
|
||||
def __init__(self):
|
||||
@@ -46,7 +48,7 @@ def _key_match(key0, key1):
|
||||
return i
|
||||
|
||||
|
||||
class RadixCache:
|
||||
class RadixCache(BasePrefixCache):
|
||||
def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
@@ -62,7 +64,7 @@ class RadixCache:
|
||||
self.root_node.lock_ref = 1
|
||||
self.evictable_size_ = 0
|
||||
|
||||
def match_prefix(self, key):
|
||||
def match_prefix(self, key, **kwargs):
|
||||
if self.disable:
|
||||
return [], self.root_node
|
||||
|
||||
@@ -90,6 +92,7 @@ class RadixCache:
|
||||
req_pool_idx,
|
||||
del_in_memory_pool=True,
|
||||
old_last_node=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Insert the request into radix cache
|
||||
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
|
||||
|
||||
@@ -419,10 +419,6 @@ class ServerArgs:
|
||||
self.dp_size > 1 and self.node_rank is not None
|
||||
), "multi-node data parallel is not supported"
|
||||
|
||||
assert not (
|
||||
self.chunked_prefill_size is not None and self.disable_radix_cache
|
||||
), "chunked prefill is not supported with radix cache disabled currently"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PortArgs:
|
||||
|
||||
Reference in New Issue
Block a user