Add a new event loop (#1677)
This commit is contained in:
@@ -38,12 +38,16 @@ class ChunkCache(BasePrefixCache):
|
||||
max_prefix_len = len(key)
|
||||
return entry.value[:max_prefix_len], entry
|
||||
|
||||
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
||||
def cache_finished_req(
|
||||
self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
|
||||
):
|
||||
if token_ids is None:
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||
token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
|
||||
else:
|
||||
token_id_len = len(token_ids)
|
||||
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
req.req_pool_idx, : token_id_len + free_delta
|
||||
]
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
self.token_to_kv_pool.free(kv_indices)
|
||||
@@ -53,10 +57,12 @@ class ChunkCache(BasePrefixCache):
|
||||
|
||||
def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
||||
if token_ids is None:
|
||||
token_ids = req.fill_ids
|
||||
token_id_len = len(req.fill_ids)
|
||||
else:
|
||||
token_id_len = len(token_ids)
|
||||
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
req.req_pool_idx, :token_id_len
|
||||
]
|
||||
|
||||
if req.rid not in self.entries:
|
||||
|
||||
@@ -97,22 +97,38 @@ class RadixCache(BasePrefixCache):
|
||||
value = [x for x in key]
|
||||
return self._insert_helper(self.root_node, key, value)
|
||||
|
||||
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
||||
def cache_finished_req(
|
||||
self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
|
||||
):
|
||||
"""Cache request when it finishes."""
|
||||
if self.disable:
|
||||
if token_ids is None:
|
||||
token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
|
||||
else:
|
||||
token_ids_len = len(token_ids)
|
||||
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : token_ids_len + free_delta
|
||||
]
|
||||
self.token_to_kv_pool.free(kv_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
return
|
||||
|
||||
if token_ids is None:
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
]
|
||||
|
||||
if self.disable:
|
||||
self.token_to_kv_pool.free(kv_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
return
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
new_prefix_len = self.insert(token_ids, kv_indices.clone())
|
||||
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
|
||||
if free_delta:
|
||||
self.token_to_kv_pool.free(
|
||||
self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, len(token_ids) : len(token_ids) + 1
|
||||
]
|
||||
)
|
||||
|
||||
# Remove req slot release the cache lock
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
|
||||
Reference in New Issue
Block a user