[Feature] Support mamba radix cache v0 (#11214)
Co-authored-by: hanming-lu <hanming@x.ai>
Co-authored-by: hzh0425 <hzh0425@apache.org>
Co-authored-by: thalahors <ericalcaide1@gmail.com>
This commit is contained in:
@@ -65,7 +65,8 @@ from sglang.srt.mem_cache.common import (
|
||||
alloc_for_extend,
|
||||
alloc_token_slots,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
||||
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
|
||||
from sglang.srt.mem_cache.radix_cache import RadixKey
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
|
||||
@@ -522,6 +523,7 @@ class Req:
|
||||
|
||||
# Memory pool info
|
||||
self.req_pool_idx: Optional[int] = None
|
||||
self.mamba_pool_idx: Optional[torch.Tensor] = None # shape (1)
|
||||
|
||||
# Check finish
|
||||
self.tokenizer = None
|
||||
@@ -727,7 +729,12 @@ class Req:
|
||||
self.last_host_node,
|
||||
self.host_hit_length,
|
||||
) = tree_cache.match_prefix(
|
||||
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key)
|
||||
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
|
||||
**(
|
||||
{"req": self, "cow_mamba": True}
|
||||
if isinstance(tree_cache, MambaRadixCache)
|
||||
else {}
|
||||
),
|
||||
)
|
||||
self.last_matched_prefix_len = len(self.prefix_indices)
|
||||
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
||||
@@ -877,6 +884,7 @@ class Req:
|
||||
self.extend_logprob_start_len = 0
|
||||
self.is_chunked = 0
|
||||
self.req_pool_idx = None
|
||||
self.mamba_pool_idx = None
|
||||
self.already_computed = 0
|
||||
|
||||
def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
|
||||
@@ -1071,6 +1079,27 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
def is_empty(self):
    """Return True when ``self.reqs`` contains no entries, False otherwise."""
    req_count = len(self.reqs)
    return req_count == 0
|
||||
|
||||
def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
    """Allocate ``num_reqs`` slots from ``self.req_to_token_pool``.

    When the pool is a ``HybridReqToTokenPool``, the free size of its
    ``mamba_pool`` is checked first. If that is smaller than ``num_reqs``
    and ``self.tree_cache`` is a ``MambaRadixCache``, the cache is asked to
    evict the shortfall via ``evict_mamba`` before allocating; ``reqs`` is
    forwarded to ``alloc`` only on this hybrid path.

    Args:
        num_reqs: Number of request slots to allocate.
        reqs: Requests passed through to the hybrid pool's ``alloc``.

    Returns:
        Whatever the pool's ``alloc`` call returned (never ``None``).

    Raises:
        RuntimeError: If the pool's ``alloc`` call returns ``None``.
    """
    pool = self.req_to_token_pool

    if not isinstance(pool, HybridReqToTokenPool):
        # Plain pool: no mamba state to make room for.
        slots = pool.alloc(num_reqs)
    else:
        free_mamba = pool.mamba_pool.available_size()
        if free_mamba < num_reqs:
            cache = self.tree_cache
            if cache is not None and isinstance(cache, MambaRadixCache):
                # Ask the cache to evict only as many entries as are missing.
                shortfall = max(0, num_reqs - free_mamba)
                cache.evict_mamba(shortfall)
        slots = pool.alloc(num_reqs, reqs)

    if slots is not None:
        return slots

    raise RuntimeError(
        "alloc_req_slots runs out of memory. "
        "Please set a smaller number for `--max-running-requests`. "
        f"{self.req_to_token_pool.available_size()=}, "
        f"{num_reqs=}, "
    )
|
||||
|
||||
def allocate_for_eagle_v2(self):
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
|
||||
|
||||
Reference in New Issue
Block a user