Fix the race condition in overlap mode (#1712)

This commit is contained in:
Lianmin Zheng
2024-10-19 06:50:56 -07:00
committed by GitHub
parent 3db43d1b08
commit 769bf11c05
6 changed files with 21 additions and 38 deletions

View File

@@ -38,16 +38,14 @@ class ChunkCache(BasePrefixCache):
max_prefix_len = len(key)
return entry.value[:max_prefix_len], entry
def cache_finished_req(
self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
):
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
if token_ids is None:
token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
else:
token_id_len = len(token_ids)
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : token_id_len + free_delta
req.req_pool_idx, :token_id_len
]
self.req_to_token_pool.free(req.req_pool_idx)
self.token_to_kv_pool.free(kv_indices)

View File

@@ -97,9 +97,7 @@ class RadixCache(BasePrefixCache):
value = [x for x in key]
return self._insert_helper(self.root_node, key, value)
def cache_finished_req(
self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
):
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
"""Cache request when it finishes."""
if self.disable:
if token_ids is None:
@@ -108,7 +106,7 @@ class RadixCache(BasePrefixCache):
token_ids_len = len(token_ids)
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : token_ids_len + free_delta
req.req_pool_idx, :token_ids_len
]
self.token_to_kv_pool.free(kv_indices)
self.req_to_token_pool.free(req.req_pool_idx)
@@ -123,12 +121,6 @@ class RadixCache(BasePrefixCache):
# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone())
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
if free_delta:
self.token_to_kv_pool.free(
self.req_to_token_pool.req_to_token[
req.req_pool_idx, len(token_ids) : len(token_ids) + 1
]
)
# Remove req slot release the cache lock
self.req_to_token_pool.free(req.req_pool_idx)