Support chunked prefill when radix cache is disabled (#811)
This commit is contained in:
@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
"""Base cache class."""
|
"""Base tool cache for constrained decoding tools."""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
class BaseCache:
|
class BaseToolCache:
|
||||||
def __init__(self, enable=True):
|
def __init__(self, enable=True):
|
||||||
self.enable = enable
|
self.enable = enable
|
||||||
self.reset()
|
self.reset()
|
||||||
@@ -16,10 +16,10 @@ limitations under the License.
|
|||||||
"""Cache for the compressed finite state machine."""
|
"""Cache for the compressed finite state machine."""
|
||||||
|
|
||||||
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
|
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
|
||||||
from sglang.srt.constrained.base_cache import BaseCache
|
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
||||||
|
|
||||||
|
|
||||||
class FSMCache(BaseCache):
|
class FSMCache(BaseToolCache):
|
||||||
def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
|
def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
|
||||||
super().__init__(enable=enable)
|
super().__init__(enable=enable)
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from sglang.srt.constrained import (
|
|||||||
make_byte_level_fsm,
|
make_byte_level_fsm,
|
||||||
make_deterministic_fsm,
|
make_deterministic_fsm,
|
||||||
)
|
)
|
||||||
from sglang.srt.constrained.base_cache import BaseCache
|
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
||||||
|
|
||||||
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||||
|
|
||||||
@@ -151,7 +151,7 @@ class JumpForwardMap:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class JumpForwardCache(BaseCache):
|
class JumpForwardCache(BaseToolCache):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
|||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.constrained import RegexGuide
|
from sglang.srt.constrained import RegexGuide
|
||||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||||
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||||
|
|
||||||
@@ -486,16 +487,34 @@ class Batch:
|
|||||||
req = self.reqs[idx]
|
req = self.reqs[idx]
|
||||||
retracted_reqs.append(req)
|
retracted_reqs.append(req)
|
||||||
|
|
||||||
|
if isinstance(self.tree_cache, ChunkCache):
|
||||||
|
# ChunkCache does not have eviction
|
||||||
|
token_indices = self.req_to_token_pool.req_to_token[
|
||||||
|
req_pool_indices_cpu[idx]
|
||||||
|
][: seq_lens_cpu[idx]]
|
||||||
|
self.token_to_kv_pool.free(token_indices)
|
||||||
|
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
|
||||||
|
del self.tree_cache.entries[req.rid]
|
||||||
|
else:
|
||||||
# TODO: apply more fine-grained retraction
|
# TODO: apply more fine-grained retraction
|
||||||
last_uncached_pos = len(req.prefix_indices)
|
last_uncached_pos = len(req.prefix_indices)
|
||||||
token_indices = self.req_to_token_pool.req_to_token[
|
token_indices = self.req_to_token_pool.req_to_token[
|
||||||
req_pool_indices_cpu[idx]
|
req_pool_indices_cpu[idx]
|
||||||
][last_uncached_pos : seq_lens_cpu[idx]]
|
][last_uncached_pos : seq_lens_cpu[idx]]
|
||||||
self.token_to_kv_pool.free(token_indices)
|
self.token_to_kv_pool.free(token_indices)
|
||||||
|
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
|
||||||
|
|
||||||
# release the last node
|
# release the last node
|
||||||
self.tree_cache.dec_lock_ref(req.last_node)
|
self.tree_cache.dec_lock_ref(req.last_node)
|
||||||
|
|
||||||
|
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
||||||
|
residual_size = (
|
||||||
|
len(sorted_indices) * global_config.retract_decode_steps
|
||||||
|
- self.token_to_kv_pool.available_size()
|
||||||
|
)
|
||||||
|
residual_size = max(0, residual_size)
|
||||||
|
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
|
||||||
|
|
||||||
req.prefix_indices = None
|
req.prefix_indices = None
|
||||||
req.last_node = None
|
req.last_node = None
|
||||||
req.extend_input_len = 0
|
req.extend_input_len = 0
|
||||||
@@ -575,6 +594,7 @@ class Batch:
|
|||||||
if req_pool_indices_cpu is None:
|
if req_pool_indices_cpu is None:
|
||||||
req_pool_indices_cpu = self.req_pool_indices.tolist()
|
req_pool_indices_cpu = self.req_pool_indices.tolist()
|
||||||
self.tree_cache.cache_req(
|
self.tree_cache.cache_req(
|
||||||
|
rid=req.rid,
|
||||||
token_ids=cur_all_ids,
|
token_ids=cur_all_ids,
|
||||||
last_uncached_pos=len(req.prefix_indices),
|
last_uncached_pos=len(req.prefix_indices),
|
||||||
req_pool_idx=req_pool_indices_cpu[i],
|
req_pool_idx=req_pool_indices_cpu[i],
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ from sglang.srt.managers.schedule_batch import (
|
|||||||
ForwardMode,
|
ForwardMode,
|
||||||
Req,
|
Req,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||||
from sglang.srt.model_config import ModelConfig
|
from sglang.srt.model_config import ModelConfig
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
@@ -144,6 +145,15 @@ class ModelTpServer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Init cache
|
# Init cache
|
||||||
|
if (
|
||||||
|
server_args.chunked_prefill_size is not None
|
||||||
|
and server_args.disable_radix_cache
|
||||||
|
):
|
||||||
|
self.tree_cache = ChunkCache(
|
||||||
|
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||||
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||||
|
)
|
||||||
|
else:
|
||||||
self.tree_cache = RadixCache(
|
self.tree_cache = RadixCache(
|
||||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||||
@@ -354,7 +364,10 @@ class ModelTpServer:
|
|||||||
# Compute matched prefix length
|
# Compute matched prefix length
|
||||||
for req in self.waiting_queue:
|
for req in self.waiting_queue:
|
||||||
req.input_ids = req.origin_input_ids + req.output_ids
|
req.input_ids = req.origin_input_ids + req.output_ids
|
||||||
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
|
prefix_indices, last_node = self.tree_cache.match_prefix(
|
||||||
|
rid=req.rid,
|
||||||
|
key=req.input_ids,
|
||||||
|
)
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
prefix_indices = prefix_indices[: req.logprob_start_len]
|
prefix_indices = prefix_indices[: req.logprob_start_len]
|
||||||
req.extend_input_len = len(req.input_ids) - len(prefix_indices)
|
req.extend_input_len = len(req.input_ids) - len(prefix_indices)
|
||||||
@@ -614,6 +627,7 @@ class ModelTpServer:
|
|||||||
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
|
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
|
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
|
||||||
|
rid=req.rid,
|
||||||
token_ids=tuple(req.input_ids),
|
token_ids=tuple(req.input_ids),
|
||||||
last_uncached_pos=len(req.prefix_indices),
|
last_uncached_pos=len(req.prefix_indices),
|
||||||
req_pool_idx=req_pool_indices_cpu[i],
|
req_pool_idx=req_pool_indices_cpu[i],
|
||||||
@@ -771,6 +785,7 @@ class ModelTpServer:
|
|||||||
for i in finished_indices:
|
for i in finished_indices:
|
||||||
req = batch.reqs[i]
|
req = batch.reqs[i]
|
||||||
self.tree_cache.cache_req(
|
self.tree_cache.cache_req(
|
||||||
|
rid=req.rid,
|
||||||
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
|
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
|
||||||
last_uncached_pos=len(req.prefix_indices),
|
last_uncached_pos=len(req.prefix_indices),
|
||||||
req_pool_idx=req_pool_indices_cpu[i],
|
req_pool_idx=req_pool_indices_cpu[i],
|
||||||
|
|||||||
43
python/sglang/srt/mem_cache/base_cache.py
Normal file
43
python/sglang/srt/mem_cache/base_cache.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class BasePrefixCache(ABC):
|
||||||
|
"""Cache can be indexed by either rid or key."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def match_prefix(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def insert(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cache_req(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def evict(self, num_tokens, evict_callback):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def inc_lock_ref(self, node):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def dec_lock_ref(self, node):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def evictable_size(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def total_size(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def pretty_print(self):
|
||||||
|
raise NotImplementedError
|
||||||
60
python/sglang/srt/mem_cache/chunk_cache.py
Normal file
60
python/sglang/srt/mem_cache/chunk_cache.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||||
|
|
||||||
|
from sglang.srt.mem_cache.base_cache import BasePrefixCache
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkCacheEntry:
|
||||||
|
def __init__(self, rid, value):
|
||||||
|
self.rid = rid
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkCache(BasePrefixCache):
|
||||||
|
def __init__(self, req_to_token_pool, token_to_kv_pool):
|
||||||
|
self.disable = True
|
||||||
|
self.req_to_token_pool = req_to_token_pool
|
||||||
|
self.token_to_kv_pool = token_to_kv_pool
|
||||||
|
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.entries = {}
|
||||||
|
|
||||||
|
def match_prefix(self, rid, **kwargs):
|
||||||
|
if rid not in self.entries:
|
||||||
|
return [], None
|
||||||
|
|
||||||
|
entry = self.entries[rid]
|
||||||
|
return entry.value, entry
|
||||||
|
|
||||||
|
def cache_req(
|
||||||
|
self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
|
||||||
|
):
|
||||||
|
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
|
||||||
|
if del_in_memory_pool:
|
||||||
|
assert rid in self.entries
|
||||||
|
self.req_to_token_pool.free(req_pool_idx)
|
||||||
|
self.token_to_kv_pool.free(indices)
|
||||||
|
return
|
||||||
|
|
||||||
|
if rid not in self.entries:
|
||||||
|
self.entries[rid] = ChunkCacheEntry(rid, indices)
|
||||||
|
|
||||||
|
entry = self.entries[rid]
|
||||||
|
entry.value = indices
|
||||||
|
return indices, entry
|
||||||
|
|
||||||
|
def insert(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def evict(self, num_tokens, evict_callback):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def inc_lock_ref(self, node):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def dec_lock_ref(self, node):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def evictable_size(self):
|
||||||
|
return 0
|
||||||
@@ -23,6 +23,8 @@ from collections import defaultdict
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.mem_cache.base_cache import BasePrefixCache
|
||||||
|
|
||||||
|
|
||||||
class TreeNode:
|
class TreeNode:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -46,7 +48,7 @@ def _key_match(key0, key1):
|
|||||||
return i
|
return i
|
||||||
|
|
||||||
|
|
||||||
class RadixCache:
|
class RadixCache(BasePrefixCache):
|
||||||
def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
|
def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
|
||||||
self.req_to_token_pool = req_to_token_pool
|
self.req_to_token_pool = req_to_token_pool
|
||||||
self.token_to_kv_pool = token_to_kv_pool
|
self.token_to_kv_pool = token_to_kv_pool
|
||||||
@@ -62,7 +64,7 @@ class RadixCache:
|
|||||||
self.root_node.lock_ref = 1
|
self.root_node.lock_ref = 1
|
||||||
self.evictable_size_ = 0
|
self.evictable_size_ = 0
|
||||||
|
|
||||||
def match_prefix(self, key):
|
def match_prefix(self, key, **kwargs):
|
||||||
if self.disable:
|
if self.disable:
|
||||||
return [], self.root_node
|
return [], self.root_node
|
||||||
|
|
||||||
@@ -90,6 +92,7 @@ class RadixCache:
|
|||||||
req_pool_idx,
|
req_pool_idx,
|
||||||
del_in_memory_pool=True,
|
del_in_memory_pool=True,
|
||||||
old_last_node=None,
|
old_last_node=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
# Insert the request into radix cache
|
# Insert the request into radix cache
|
||||||
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
|
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
|
||||||
|
|||||||
@@ -419,10 +419,6 @@ class ServerArgs:
|
|||||||
self.dp_size > 1 and self.node_rank is not None
|
self.dp_size > 1 and self.node_rank is not None
|
||||||
), "multi-node data parallel is not supported"
|
), "multi-node data parallel is not supported"
|
||||||
|
|
||||||
assert not (
|
|
||||||
self.chunked_prefill_size is not None and self.disable_radix_cache
|
|
||||||
), "chunked prefill is not supported with radix cache disabled currently"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class PortArgs:
|
class PortArgs:
|
||||||
|
|||||||
Reference in New Issue
Block a user