Refactor kv cache free (#11351)
This commit is contained in:
@@ -611,8 +611,8 @@ class DecodeTransferQueue:
|
|||||||
self.scheduler.stream_output(
|
self.scheduler.stream_output(
|
||||||
[decode_req.req], decode_req.req.return_logprob
|
[decode_req.req], decode_req.req.return_logprob
|
||||||
)
|
)
|
||||||
# unlock the kv cache or it will have memory leak
|
# release pre-allocated kv cache, but don't insert into the tree since it's failed
|
||||||
self.tree_cache.cache_finished_req(decode_req.req)
|
self.tree_cache.cache_finished_req(decode_req.req, is_insert=False)
|
||||||
indices_to_remove.add(i)
|
indices_to_remove.add(i)
|
||||||
if self.scheduler.enable_metrics:
|
if self.scheduler.enable_metrics:
|
||||||
self.scheduler.metrics_collector.increment_transfer_failed_reqs()
|
self.scheduler.metrics_collector.increment_transfer_failed_reqs()
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ from sglang.srt.mem_cache.common import (
|
|||||||
alloc_for_decode,
|
alloc_for_decode,
|
||||||
alloc_for_extend,
|
alloc_for_extend,
|
||||||
alloc_token_slots,
|
alloc_token_slots,
|
||||||
|
evict_from_tree_cache,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
|
||||||
@@ -1406,7 +1407,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
* self.token_to_kv_pool_allocator.page_size
|
* self.token_to_kv_pool_allocator.page_size
|
||||||
)
|
)
|
||||||
|
|
||||||
self._evict_tree_cache_if_needed(num_tokens)
|
evict_from_tree_cache(self.tree_cache, num_tokens)
|
||||||
return self._is_available_size_sufficient(num_tokens)
|
return self._is_available_size_sufficient(num_tokens)
|
||||||
|
|
||||||
def retract_decode(self, server_args: ServerArgs):
|
def retract_decode(self, server_args: ServerArgs):
|
||||||
@@ -1454,6 +1455,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
idx = sorted_indices.pop()
|
idx = sorted_indices.pop()
|
||||||
req = self.reqs[idx]
|
req = self.reqs[idx]
|
||||||
retracted_reqs.append(req)
|
retracted_reqs.append(req)
|
||||||
|
# release memory and don't insert into the tree because we need the space instantly
|
||||||
self.release_req(idx, len(sorted_indices), server_args)
|
self.release_req(idx, len(sorted_indices), server_args)
|
||||||
|
|
||||||
if len(retracted_reqs) == 0:
|
if len(retracted_reqs) == 0:
|
||||||
@@ -1478,39 +1480,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
|
|
||||||
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
|
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
|
||||||
req = self.reqs[idx]
|
req = self.reqs[idx]
|
||||||
seq_lens_cpu = self.seq_lens_cpu.numpy()
|
|
||||||
|
|
||||||
if server_args.disaggregation_mode == "decode":
|
if server_args.disaggregation_mode == "decode":
|
||||||
req.offload_kv_cache(
|
req.offload_kv_cache(
|
||||||
self.req_to_token_pool, self.token_to_kv_pool_allocator
|
self.req_to_token_pool, self.token_to_kv_pool_allocator
|
||||||
)
|
)
|
||||||
if isinstance(self.tree_cache, ChunkCache):
|
# TODO (csy): for preempted requests, we may want to insert into the tree
|
||||||
# ChunkCache does not have eviction
|
self.tree_cache.cache_finished_req(req, is_insert=False)
|
||||||
token_indices = self.req_to_token_pool.req_to_token[
|
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
||||||
req.req_pool_idx, : seq_lens_cpu[idx]
|
num_tokens = remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
|
||||||
]
|
evict_from_tree_cache(self.tree_cache, num_tokens)
|
||||||
self.token_to_kv_pool_allocator.free(token_indices)
|
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
|
||||||
else:
|
|
||||||
# TODO: apply more fine-grained retraction
|
|
||||||
last_uncached_pos = (
|
|
||||||
len(req.prefix_indices) // server_args.page_size
|
|
||||||
) * server_args.page_size
|
|
||||||
token_indices = self.req_to_token_pool.req_to_token[
|
|
||||||
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
|
|
||||||
]
|
|
||||||
self.token_to_kv_pool_allocator.free(token_indices)
|
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
|
||||||
|
|
||||||
# release the last node
|
|
||||||
if self.is_hybrid:
|
|
||||||
self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
|
|
||||||
else:
|
|
||||||
self.tree_cache.dec_lock_ref(req.last_node)
|
|
||||||
|
|
||||||
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
|
||||||
num_tokens = remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
|
|
||||||
self._evict_tree_cache_if_needed(num_tokens)
|
|
||||||
|
|
||||||
req.reset_for_retract()
|
req.reset_for_retract()
|
||||||
|
|
||||||
@@ -1808,24 +1787,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
enable_overlap=self.enable_overlap,
|
enable_overlap=self.enable_overlap,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _evict_tree_cache_if_needed(self, num_tokens: int):
|
|
||||||
if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.is_hybrid:
|
|
||||||
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
|
|
||||||
swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
|
|
||||||
|
|
||||||
if full_available_size < num_tokens or swa_available_size < num_tokens:
|
|
||||||
if self.tree_cache is not None:
|
|
||||||
full_num_tokens = max(0, num_tokens - full_available_size)
|
|
||||||
swa_num_tokens = max(0, num_tokens - swa_available_size)
|
|
||||||
self.tree_cache.evict(full_num_tokens, swa_num_tokens)
|
|
||||||
else:
|
|
||||||
if self.token_to_kv_pool_allocator.available_size() < num_tokens:
|
|
||||||
if self.tree_cache is not None:
|
|
||||||
self.tree_cache.evict(num_tokens)
|
|
||||||
|
|
||||||
def _is_available_size_sufficient(self, num_tokens: int) -> bool:
|
def _is_available_size_sufficient(self, num_tokens: int) -> bool:
|
||||||
if self.is_hybrid:
|
if self.is_hybrid:
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class BasePrefixCache(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def cache_finished_req(self, req: Req, **kwargs):
|
def cache_finished_req(self, req: Req, is_insert: bool = True, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ class ChunkCache(BasePrefixCache):
|
|||||||
last_host_node=None,
|
last_host_node=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def cache_finished_req(self, req: Req, insert: bool = True):
|
def cache_finished_req(self, req: Req, is_insert: bool = True):
|
||||||
kv_indices = self.req_to_token_pool.req_to_token[
|
kv_indices = self.req_to_token_pool.req_to_token[
|
||||||
req.req_pool_idx,
|
req.req_pool_idx,
|
||||||
# For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
|
# For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
|
||||||
|
|||||||
@@ -330,18 +330,18 @@ class RadixCache(BasePrefixCache):
|
|||||||
|
|
||||||
return self._insert_helper(self.root_node, key, value)
|
return self._insert_helper(self.root_node, key, value)
|
||||||
|
|
||||||
def cache_finished_req(self, req: Req):
|
def cache_finished_req(self, req: Req, is_insert: bool = True):
|
||||||
"""Cache request when it finishes."""
|
"""Cache request when it finishes."""
|
||||||
|
all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
||||||
if self.disable:
|
if self.disable:
|
||||||
kv_indices = self.req_to_token_pool.req_to_token[
|
kv_indices = self.req_to_token_pool.req_to_token[
|
||||||
req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
|
req.req_pool_idx, :all_token_len
|
||||||
]
|
]
|
||||||
self.token_to_kv_pool_allocator.free(kv_indices)
|
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
return
|
return
|
||||||
|
|
||||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
|
||||||
all_token_len = len(token_ids)
|
|
||||||
# For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
|
# For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
|
||||||
# So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
|
# So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
|
||||||
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
||||||
@@ -354,12 +354,9 @@ class RadixCache(BasePrefixCache):
|
|||||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
||||||
dtype=torch.int64, copy=True
|
dtype=torch.int64, copy=True
|
||||||
)
|
)
|
||||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
|
||||||
else:
|
else:
|
||||||
page_aligned_len = actual_kv_len
|
page_aligned_len = actual_kv_len
|
||||||
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||||
if self.is_eagle:
|
|
||||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
|
||||||
|
|
||||||
page_aligned_token_len = (
|
page_aligned_token_len = (
|
||||||
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
||||||
@@ -372,11 +369,22 @@ class RadixCache(BasePrefixCache):
|
|||||||
old_prefix_len -= 1
|
old_prefix_len -= 1
|
||||||
|
|
||||||
# Radix Cache takes one ref in memory pool
|
# Radix Cache takes one ref in memory pool
|
||||||
new_prefix_len = self.insert(
|
if is_insert:
|
||||||
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
|
new_prefix_len = self.insert(
|
||||||
page_aligned_kv_indices,
|
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
|
||||||
)
|
page_aligned_kv_indices,
|
||||||
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
|
)
|
||||||
|
# Free the duplicates that were already in the tree
|
||||||
|
self.token_to_kv_pool_allocator.free(
|
||||||
|
kv_indices[old_prefix_len:new_prefix_len]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.token_to_kv_pool_allocator.free(
|
||||||
|
kv_indices[old_prefix_len:page_aligned_len]
|
||||||
|
)
|
||||||
|
|
||||||
|
# free the unaligned tail
|
||||||
|
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||||
|
|
||||||
# Remove req slot release the cache lock
|
# Remove req slot release the cache lock
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
|
|||||||
@@ -151,32 +151,37 @@ class RadixCacheCpp(BasePrefixCache):
|
|||||||
def total_size(self):
|
def total_size(self):
|
||||||
return self.tree.total_size()
|
return self.tree.total_size()
|
||||||
|
|
||||||
def cache_finished_req(self, req: Req):
|
def cache_finished_req(self, req: Req, is_insert: bool = True):
|
||||||
"""Cache request when it finishes."""
|
"""Cache request when it finishes."""
|
||||||
assert req.req_pool_idx is not None
|
assert req.req_pool_idx is not None
|
||||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
||||||
|
token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
|
||||||
overall_len = len(token_ids) # prefill + decode
|
overall_len = len(token_ids) # prefill + decode
|
||||||
kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len]
|
kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len]
|
||||||
|
|
||||||
# NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
|
# NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
|
||||||
# it will automatically align them, but length of them should be equal
|
# it will automatically align them, but length of them should be equal
|
||||||
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
|
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
|
||||||
new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
|
page_aligned_overall_len = overall_len // self.page_size * self.page_size
|
||||||
|
|
||||||
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
|
if is_insert:
|
||||||
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
|
new_prefix_len = self._insert(
|
||||||
|
RadixKey(token_ids, req.extra_key), kv_indices
|
||||||
# KVCache between old & new is newly generated, but already exists in the pool
|
)
|
||||||
# we need to free this newly generated kv indices
|
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
|
||||||
if old_prefix_len < new_prefix_len:
|
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
|
||||||
self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
|
# Free duplicates that were already in the pool
|
||||||
|
if old_prefix_len < new_prefix_len:
|
||||||
|
self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
|
||||||
|
else:
|
||||||
|
self.token_to_kv_pool.free(
|
||||||
|
kv_indices[old_prefix_len:page_aligned_overall_len]
|
||||||
|
)
|
||||||
|
|
||||||
# need to free the unaligned part, since it cannot be inserted into the radix tree
|
# need to free the unaligned part, since it cannot be inserted into the radix tree
|
||||||
if self.page_size != 1 and ( # unaligned tail only exists when page_size > 1
|
if page_aligned_overall_len < overall_len:
|
||||||
(unaligned_len := overall_len % self.page_size) > 0
|
|
||||||
):
|
|
||||||
# NOTE: sglang PagedAllocator support unaligned free (which will automatically align it)
|
# NOTE: sglang PagedAllocator support unaligned free (which will automatically align it)
|
||||||
self.token_to_kv_pool.free(kv_indices[overall_len - unaligned_len :])
|
self.token_to_kv_pool.free(kv_indices[page_aligned_overall_len:])
|
||||||
|
|
||||||
# Remove req slot release the cache lock
|
# Remove req slot release the cache lock
|
||||||
self.dec_lock_ref(req.last_node)
|
self.dec_lock_ref(req.last_node)
|
||||||
|
|||||||
@@ -217,10 +217,12 @@ class LMCRadixCache(RadixCache):
|
|||||||
|
|
||||||
return base_res
|
return base_res
|
||||||
|
|
||||||
def cache_finished_req(self, req: "Req") -> None: # type: ignore[override]
|
def cache_finished_req(self, req: "Req", is_insert: bool = True) -> None: # type: ignore[override]
|
||||||
"""On request completion, insert device KV into radix and store to LMCache."""
|
"""On request completion, insert device KV into radix and store to LMCache."""
|
||||||
|
|
||||||
super().cache_finished_req(req)
|
super().cache_finished_req(req, is_insert=is_insert)
|
||||||
|
if not is_insert:
|
||||||
|
return
|
||||||
|
|
||||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||||
kv_indices = self.req_to_token_pool.req_to_token[
|
kv_indices = self.req_to_token_pool.req_to_token[
|
||||||
|
|||||||
@@ -427,19 +427,18 @@ class SWARadixCache(BasePrefixCache):
|
|||||||
|
|
||||||
return self._insert_helper(self.root_node, key, value, prev_prefix_len)
|
return self._insert_helper(self.root_node, key, value, prev_prefix_len)
|
||||||
|
|
||||||
def cache_finished_req(self, req: Req) -> None:
|
def cache_finished_req(self, req: Req, is_insert: bool = True) -> None:
|
||||||
"""Cache request when it finishes."""
|
"""Cache request when it finishes."""
|
||||||
|
all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
||||||
if self.disable:
|
if self.disable:
|
||||||
kv_indices = self.req_to_token_pool.req_to_token[
|
kv_indices = self.req_to_token_pool.req_to_token[
|
||||||
req.req_pool_idx,
|
req.req_pool_idx, :all_token_len
|
||||||
: len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
|
|
||||||
]
|
]
|
||||||
self.token_to_kv_pool_allocator.free(kv_indices)
|
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
return
|
return
|
||||||
|
|
||||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
|
||||||
all_token_len = len(token_ids)
|
|
||||||
# For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
|
# For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
|
||||||
# So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
|
# So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
|
||||||
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
||||||
@@ -452,7 +451,6 @@ class SWARadixCache(BasePrefixCache):
|
|||||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
||||||
dtype=torch.int64, copy=True
|
dtype=torch.int64, copy=True
|
||||||
)
|
)
|
||||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
|
||||||
else:
|
else:
|
||||||
page_aligned_len = actual_kv_len
|
page_aligned_len = actual_kv_len
|
||||||
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||||
@@ -472,11 +470,19 @@ class SWARadixCache(BasePrefixCache):
|
|||||||
# Radix Cache takes one ref in memory pool
|
# Radix Cache takes one ref in memory pool
|
||||||
# insert the token_ids and kv_indices into the radix tree
|
# insert the token_ids and kv_indices into the radix tree
|
||||||
# Note: the insert function already frees the overlapped kv_indices
|
# Note: the insert function already frees the overlapped kv_indices
|
||||||
new_prefix_len = self.insert(
|
if is_insert:
|
||||||
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
|
new_prefix_len = self.insert(
|
||||||
page_aligned_kv_indices,
|
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
|
||||||
old_prefix_len,
|
page_aligned_kv_indices,
|
||||||
)
|
old_prefix_len,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.token_to_kv_pool_allocator.free(
|
||||||
|
kv_indices[old_prefix_len:page_aligned_len]
|
||||||
|
)
|
||||||
|
|
||||||
|
# free the unaligned tail
|
||||||
|
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||||
|
|
||||||
# Remove req slot release the cache lock
|
# Remove req slot release the cache lock
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
|
|||||||
Reference in New Issue
Block a user