Unify the memory pool api and tp worker API (#1724)
@@ -56,6 +56,12 @@ class ReqToTokenPool:
     def clear(self):
         self.free_slots = list(range(self.size))
 
+    def write(self, indices, values):
+        self.req_to_token[indices] = values
+
+    def get_write_records(self):
+        return None
+
 
 class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""
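
The two methods above give ReqToTokenPool a single write entry point instead of callers indexing req_to_token directly. Below is a minimal, self-contained sketch of how that interface could be exercised; the ToyReqToTokenPool class, sizes, and values are illustrative stand-ins and not the actual sglang module.

import torch

# Illustrative stand-in for the pool; not the real sglang class.
class ToyReqToTokenPool:
    def __init__(self, size, max_context_len):
        self.size = size
        self.req_to_token = torch.zeros((size, max_context_len), dtype=torch.int32)
        self.free_slots = list(range(size))

    def clear(self):
        self.free_slots = list(range(self.size))

    def write(self, indices, values):
        # `indices` can be anything usable as a tensor index,
        # e.g. a (row, slice) tuple, mirroring the method added above.
        self.req_to_token[indices] = values

    def get_write_records(self):
        # The base pool keeps no write log; a recording variant
        # could override this to return buffered writes.
        return None

pool = ToyReqToTokenPool(size=4, max_context_len=16)
pool.write((0, slice(0, 3)), torch.tensor([11, 12, 13], dtype=torch.int32))
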
@@ -68,12 +74,12 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
+        self.device = device
         if dtype == torch.float8_e5m2:
             # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
-        self.device = device
 
         self.free_slots = None
         self.is_not_in_free_group = True
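
The store_dtype indirection above works because torch.float8_e5m2 and torch.uint8 are both one byte wide, so kv values can be kept in a uint8 buffer and reinterpreted with Tensor.view on the way in and out. A small illustrative sketch (assuming a PyTorch build with float8 support; the buffer name and shapes are made up):

import torch

dtype = torch.float8_e5m2
store_dtype = torch.uint8 if dtype == torch.float8_e5m2 else dtype

kv_buffer = torch.zeros(8, dtype=store_dtype)      # backing storage kept as uint8
values = torch.randn(3).to(dtype)                  # kv values in float8
kv_buffer[2:5] = values.view(store_dtype)          # index_put works on the uint8 view
read_back = kv_buffer[2:5].view(dtype)             # reinterpret the bytes as float8 again
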
@@ -145,9 +145,10 @@ class RadixCache(BasePrefixCache):
             # The prefix indices could be updated, reuse it
             new_indices, new_last_node = self.match_prefix(token_ids)
             assert len(new_indices) == len(token_ids)
-            self.req_to_token_pool.req_to_token[
-                req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
-            ] = new_indices[len(req.prefix_indices) :]
+            self.req_to_token_pool.write(
+                (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
+                new_indices[len(req.prefix_indices) :],
+            )
 
             self.dec_lock_ref(req.last_node)
             self.inc_lock_ref(new_last_node)
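
The replacement above routes the update through ReqToTokenPool.write instead of mutating req_to_token directly, passing the index as a (row, slice) tuple. A hypothetical standalone check of why the two forms assign the same elements (tensor shapes and names below are illustrative):

import torch

req_to_token = torch.zeros((2, 8), dtype=torch.int32)  # toy req -> token map
row, prefix_len = 0, 3                                  # stand-ins for req.req_pool_idx / len(req.prefix_indices)
new_indices = torch.arange(6, dtype=torch.int32)

# Old form: direct tensor indexing with a slice expression.
req_to_token[row, prefix_len : len(new_indices)] = new_indices[prefix_len:]

# New form: the same index spelled as a (row, slice) tuple, which is what
# gets passed to ReqToTokenPool.write() in the hunk above.
req_to_token[(row, slice(prefix_len, len(new_indices)))] = new_indices[prefix_len:]
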