Fix mamba radix cache eviction logic in alloc_req_slots (#11616)
Signed-off-by: rogeryoungh <rogeryoungh@foxmail.com>
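Summary of the change (as reflected in the hunks below): the duplicated Mamba eviction logic in ScheduleBatch.alloc_req_slots is removed, and request-slot allocation now goes through the shared module-level alloc_req_slots helper, which gains a tree_cache parameter. When the request-to-token pool is a HybridReqToTokenPool and its Mamba pool has fewer free slots than the batch needs, the helper evicts just enough entries from the MambaRadixCache before allocating; alloc_for_extend passes batch.tree_cache through to the helper so the eviction happens at every call site.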
@@ -66,7 +66,7 @@ from sglang.srt.mem_cache.common import (
     evict_from_tree_cache,
 )
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
-from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import RadixKey
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
@@ -1080,27 +1080,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def is_empty(self):
         return len(self.reqs) == 0
 
-    def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
-        if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
-            mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
-            if mamba_available_size < num_reqs:
-                if self.tree_cache is not None and isinstance(
-                    self.tree_cache, MambaRadixCache
-                ):
-                    mamba_num = max(0, num_reqs - mamba_available_size)
-                    self.tree_cache.evict_mamba(mamba_num)
-            req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
-        else:
-            req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
-        if req_pool_indices is None:
-            raise RuntimeError(
-                "alloc_req_slots runs out of memory. "
-                "Please set a smaller number for `--max-running-requests`. "
-                f"{self.req_to_token_pool.available_size()=}, "
-                f"{num_reqs=}, "
-            )
-        return req_pool_indices
-
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
         self.encoder_cached = []
@@ -10,6 +10,7 @@ import triton.language as tl
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
+from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import support_triton
@@ -292,9 +293,15 @@ def alloc_req_slots(
     req_to_token_pool: ReqToTokenPool,
     num_reqs: int,
     reqs: list[Req] | None,
+    tree_cache: BasePrefixCache | None,
 ) -> list[int]:
     """Allocate request slots from the pool."""
     if isinstance(req_to_token_pool, HybridReqToTokenPool):
+        mamba_available_size = req_to_token_pool.mamba_pool.available_size()
+        if mamba_available_size < num_reqs:
+            if tree_cache is not None and isinstance(tree_cache, MambaRadixCache):
+                mamba_num = max(0, num_reqs - mamba_available_size)
+                tree_cache.evict_mamba(mamba_num)
         req_pool_indices = req_to_token_pool.alloc(num_reqs, reqs)
     else:
         req_pool_indices = req_to_token_pool.alloc(num_reqs)
@@ -337,7 +344,9 @@ def alloc_for_extend(
     extend_lens_device = extend_lens_cpu.to(batch.device, non_blocking=True)
 
     # Allocate req slots
-    req_pool_indices = alloc_req_slots(batch.req_to_token_pool, bs, batch.reqs)
+    req_pool_indices = alloc_req_slots(
+        batch.req_to_token_pool, bs, batch.reqs, batch.tree_cache
+    )
     req_pool_indices_cpu = torch.tensor(req_pool_indices, dtype=torch.int64)
     req_pool_indices_device = req_pool_indices_cpu.to(batch.device, non_blocking=True)
 
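For readers following the change outside the SGLang tree, the code below is a minimal, self-contained sketch of the evict-then-allocate pattern the updated helper implements. FakeMambaPool, FakeMambaCache, and alloc_req_slots_sketch are invented stand-ins for illustration only; they are not the real HybridReqToTokenPool, MambaRadixCache, or alloc_req_slots APIs, and only the control flow mirrors the diff above.

# Minimal, runnable sketch of the evict-then-allocate pattern from this change.
# The classes here are simplified stand-ins, NOT the real SGLang classes.
from typing import List, Optional


class FakeMambaPool:
    """Stand-in for the Mamba state pool; tracks only a free-slot count."""

    def __init__(self, free_slots: int):
        self._free = free_slots
        self._next_slot = 0

    def available_size(self) -> int:
        return self._free

    def release(self, count: int) -> None:
        # Called when cached entries are evicted and their slots become free.
        self._free += count

    def alloc(self, num_reqs: int) -> Optional[List[int]]:
        if self._free < num_reqs:
            return None
        self._free -= num_reqs
        slots = list(range(self._next_slot, self._next_slot + num_reqs))
        self._next_slot += num_reqs
        return slots


class FakeMambaCache:
    """Stand-in for a radix cache holding evictable cached Mamba states."""

    def __init__(self, pool: FakeMambaPool, cached_entries: int):
        self._pool = pool
        self._cached = cached_entries

    def evict_mamba(self, num: int) -> None:
        # Evict up to `num` cached entries and return their slots to the pool.
        evicted = min(num, self._cached)
        self._cached -= evicted
        self._pool.release(evicted)


def alloc_req_slots_sketch(
    pool: FakeMambaPool, tree_cache: Optional[FakeMambaCache], num_reqs: int
) -> List[int]:
    # Same control flow as the fixed helper: if the pool cannot host num_reqs
    # new requests, evict just enough cached entries before allocating.
    available = pool.available_size()
    if available < num_reqs and tree_cache is not None:
        tree_cache.evict_mamba(max(0, num_reqs - available))
    slots = pool.alloc(num_reqs)
    if slots is None:
        raise RuntimeError("alloc_req_slots runs out of memory.")
    return slots


if __name__ == "__main__":
    pool = FakeMambaPool(free_slots=1)
    cache = FakeMambaCache(pool, cached_entries=4)
    # Needs 3 slots but only 1 is free: 2 cached entries are evicted first.
    print(alloc_req_slots_sketch(pool, cache, num_reqs=3))  # -> [0, 1, 2]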