Refactor kv cache free (#11351)
This commit is contained in:
@@ -611,8 +611,8 @@ class DecodeTransferQueue:
|
||||
self.scheduler.stream_output(
|
||||
[decode_req.req], decode_req.req.return_logprob
|
||||
)
|
||||
# unlock the kv cache or it will have memory leak
|
||||
self.tree_cache.cache_finished_req(decode_req.req)
|
||||
# release pre-allocated kv cache, but don't insert into the tree since it's failed
|
||||
self.tree_cache.cache_finished_req(decode_req.req, is_insert=False)
|
||||
indices_to_remove.add(i)
|
||||
if self.scheduler.enable_metrics:
|
||||
self.scheduler.metrics_collector.increment_transfer_failed_reqs()
|
||||
|
||||
@@ -64,6 +64,7 @@ from sglang.srt.mem_cache.common import (
|
||||
alloc_for_decode,
|
||||
alloc_for_extend,
|
||||
alloc_token_slots,
|
||||
evict_from_tree_cache,
|
||||
)
|
||||
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
||||
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
|
||||
@@ -1406,7 +1407,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
* self.token_to_kv_pool_allocator.page_size
|
||||
)
|
||||
|
||||
self._evict_tree_cache_if_needed(num_tokens)
|
||||
evict_from_tree_cache(self.tree_cache, num_tokens)
|
||||
return self._is_available_size_sufficient(num_tokens)
|
||||
|
||||
def retract_decode(self, server_args: ServerArgs):
|
||||
@@ -1454,6 +1455,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
idx = sorted_indices.pop()
|
||||
req = self.reqs[idx]
|
||||
retracted_reqs.append(req)
|
||||
# release memory and don't insert into the tree because we need the space instantly
|
||||
self.release_req(idx, len(sorted_indices), server_args)
|
||||
|
||||
if len(retracted_reqs) == 0:
|
||||
@@ -1478,39 +1480,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
|
||||
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
|
||||
req = self.reqs[idx]
|
||||
seq_lens_cpu = self.seq_lens_cpu.numpy()
|
||||
|
||||
if server_args.disaggregation_mode == "decode":
|
||||
req.offload_kv_cache(
|
||||
self.req_to_token_pool, self.token_to_kv_pool_allocator
|
||||
)
|
||||
if isinstance(self.tree_cache, ChunkCache):
|
||||
# ChunkCache does not have eviction
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : seq_lens_cpu[idx]
|
||||
]
|
||||
self.token_to_kv_pool_allocator.free(token_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
else:
|
||||
# TODO: apply more fine-grained retraction
|
||||
last_uncached_pos = (
|
||||
len(req.prefix_indices) // server_args.page_size
|
||||
) * server_args.page_size
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
|
||||
]
|
||||
self.token_to_kv_pool_allocator.free(token_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
|
||||
# release the last node
|
||||
if self.is_hybrid:
|
||||
self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
|
||||
else:
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
|
||||
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
||||
num_tokens = remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
|
||||
self._evict_tree_cache_if_needed(num_tokens)
|
||||
# TODO (csy): for preempted requests, we may want to insert into the tree
|
||||
self.tree_cache.cache_finished_req(req, is_insert=False)
|
||||
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
||||
num_tokens = remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
|
||||
evict_from_tree_cache(self.tree_cache, num_tokens)
|
||||
|
||||
req.reset_for_retract()
|
||||
|
||||
@@ -1808,24 +1787,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
enable_overlap=self.enable_overlap,
|
||||
)
|
||||
|
||||
def _evict_tree_cache_if_needed(self, num_tokens: int):
|
||||
if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
|
||||
return
|
||||
|
||||
if self.is_hybrid:
|
||||
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
|
||||
swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
|
||||
|
||||
if full_available_size < num_tokens or swa_available_size < num_tokens:
|
||||
if self.tree_cache is not None:
|
||||
full_num_tokens = max(0, num_tokens - full_available_size)
|
||||
swa_num_tokens = max(0, num_tokens - swa_available_size)
|
||||
self.tree_cache.evict(full_num_tokens, swa_num_tokens)
|
||||
else:
|
||||
if self.token_to_kv_pool_allocator.available_size() < num_tokens:
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.evict(num_tokens)
|
||||
|
||||
def _is_available_size_sufficient(self, num_tokens: int) -> bool:
|
||||
if self.is_hybrid:
|
||||
return (
|
||||
|
||||
@@ -40,7 +40,7 @@ class BasePrefixCache(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_finished_req(self, req: Req, **kwargs):
|
||||
def cache_finished_req(self, req: Req, is_insert: bool = True, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -49,7 +49,7 @@ class ChunkCache(BasePrefixCache):
|
||||
last_host_node=None,
|
||||
)
|
||||
|
||||
def cache_finished_req(self, req: Req, insert: bool = True):
|
||||
def cache_finished_req(self, req: Req, is_insert: bool = True):
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx,
|
||||
# For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
|
||||
|
||||
@@ -330,18 +330,18 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
return self._insert_helper(self.root_node, key, value)
|
||||
|
||||
def cache_finished_req(self, req: Req):
|
||||
def cache_finished_req(self, req: Req, is_insert: bool = True):
|
||||
"""Cache request when it finishes."""
|
||||
all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
||||
if self.disable:
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
|
||||
req.req_pool_idx, :all_token_len
|
||||
]
|
||||
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
return
|
||||
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||
all_token_len = len(token_ids)
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
|
||||
# For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
|
||||
# So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
|
||||
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
||||
@@ -354,12 +354,9 @@ class RadixCache(BasePrefixCache):
|
||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
||||
dtype=torch.int64, copy=True
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||
else:
|
||||
page_aligned_len = actual_kv_len
|
||||
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||
if self.is_eagle:
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||
|
||||
page_aligned_token_len = (
|
||||
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
||||
@@ -372,11 +369,22 @@ class RadixCache(BasePrefixCache):
|
||||
old_prefix_len -= 1
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
new_prefix_len = self.insert(
|
||||
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
|
||||
if is_insert:
|
||||
new_prefix_len = self.insert(
|
||||
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
)
|
||||
# Free the duplicates that were already in the tree
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[old_prefix_len:new_prefix_len]
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[old_prefix_len:page_aligned_len]
|
||||
)
|
||||
|
||||
# free the unaligned tail
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||
|
||||
# Remove req slot release the cache lock
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
|
||||
@@ -151,32 +151,37 @@ class RadixCacheCpp(BasePrefixCache):
|
||||
def total_size(self):
|
||||
return self.tree.total_size()
|
||||
|
||||
def cache_finished_req(self, req: Req):
|
||||
def cache_finished_req(self, req: Req, is_insert: bool = True):
|
||||
"""Cache request when it finishes."""
|
||||
assert req.req_pool_idx is not None
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||
all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
|
||||
overall_len = len(token_ids) # prefill + decode
|
||||
kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len]
|
||||
|
||||
# NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
|
||||
# it will automatically align them, but length of them should be equal
|
||||
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
|
||||
new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
|
||||
page_aligned_overall_len = overall_len // self.page_size * self.page_size
|
||||
|
||||
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
|
||||
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
|
||||
|
||||
# KVCache between old & new is newly generated, but already exists in the pool
|
||||
# we need to free this newly generated kv indices
|
||||
if old_prefix_len < new_prefix_len:
|
||||
self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
|
||||
if is_insert:
|
||||
new_prefix_len = self._insert(
|
||||
RadixKey(token_ids, req.extra_key), kv_indices
|
||||
)
|
||||
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
|
||||
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
|
||||
# Free duplicates that were already in the pool
|
||||
if old_prefix_len < new_prefix_len:
|
||||
self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
|
||||
else:
|
||||
self.token_to_kv_pool.free(
|
||||
kv_indices[old_prefix_len:page_aligned_overall_len]
|
||||
)
|
||||
|
||||
# need to free the unaligned part, since it cannot be inserted into the radix tree
|
||||
if self.page_size != 1 and ( # unaligned tail only exists when page_size > 1
|
||||
(unaligned_len := overall_len % self.page_size) > 0
|
||||
):
|
||||
if page_aligned_overall_len < overall_len:
|
||||
# NOTE: sglang PagedAllocator support unaligned free (which will automatically align it)
|
||||
self.token_to_kv_pool.free(kv_indices[overall_len - unaligned_len :])
|
||||
self.token_to_kv_pool.free(kv_indices[page_aligned_overall_len:])
|
||||
|
||||
# Remove req slot release the cache lock
|
||||
self.dec_lock_ref(req.last_node)
|
||||
|
||||
@@ -217,10 +217,12 @@ class LMCRadixCache(RadixCache):
|
||||
|
||||
return base_res
|
||||
|
||||
def cache_finished_req(self, req: "Req") -> None: # type: ignore[override]
|
||||
def cache_finished_req(self, req: "Req", is_insert: bool = True) -> None: # type: ignore[override]
|
||||
"""On request completion, insert device KV into radix and store to LMCache."""
|
||||
|
||||
super().cache_finished_req(req)
|
||||
super().cache_finished_req(req, is_insert=is_insert)
|
||||
if not is_insert:
|
||||
return
|
||||
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
|
||||
@@ -427,19 +427,18 @@ class SWARadixCache(BasePrefixCache):
|
||||
|
||||
return self._insert_helper(self.root_node, key, value, prev_prefix_len)
|
||||
|
||||
def cache_finished_req(self, req: Req) -> None:
|
||||
def cache_finished_req(self, req: Req, is_insert: bool = True) -> None:
|
||||
"""Cache request when it finishes."""
|
||||
all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
||||
if self.disable:
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx,
|
||||
: len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
|
||||
req.req_pool_idx, :all_token_len
|
||||
]
|
||||
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
return
|
||||
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||
all_token_len = len(token_ids)
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
|
||||
# For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
|
||||
# So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
|
||||
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
||||
@@ -452,7 +451,6 @@ class SWARadixCache(BasePrefixCache):
|
||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
||||
dtype=torch.int64, copy=True
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||
else:
|
||||
page_aligned_len = actual_kv_len
|
||||
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||
@@ -472,11 +470,19 @@ class SWARadixCache(BasePrefixCache):
|
||||
# Radix Cache takes one ref in memory pool
|
||||
# insert the token_ids and kv_indices into the radix tree
|
||||
# Note: the insert function already frees the overlapped kv_indices
|
||||
new_prefix_len = self.insert(
|
||||
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
old_prefix_len,
|
||||
)
|
||||
if is_insert:
|
||||
new_prefix_len = self.insert(
|
||||
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
old_prefix_len,
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[old_prefix_len:page_aligned_len]
|
||||
)
|
||||
|
||||
# free the unaligned tail
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||
|
||||
# Remove req slot release the cache lock
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
|
||||
Reference in New Issue
Block a user