Unify the memory pool API and TP worker API (#1724)

This commit is contained in:
Lianmin Zheng
2024-10-19 23:19:26 -07:00
committed by GitHub
parent 95946271af
commit 59cbf47626
8 changed files with 87 additions and 25 deletions

View File

@@ -56,6 +56,12 @@ class ReqToTokenPool:
def clear(self):
    """Reset the pool: mark every slot index in [0, self.size) as free."""
    all_slots = range(self.size)
    self.free_slots = [*all_slots]
def write(self, indices, values):
    """Store *values* into the request-to-token mapping at *indices*.

    A thin wrapper over item assignment on ``self.req_to_token`` so that
    callers go through a uniform pool API rather than indexing directly.
    """
    table = self.req_to_token
    table[indices] = values
def get_write_records(self):
    """Return recorded writes; this pool tracks none, so always None."""
    records = None
    return records
class BaseTokenToKVPool:
"""A memory pool that maps a token to its kv cache locations"""
@@ -68,12 +74,12 @@ class BaseTokenToKVPool:
):
self.size = size
self.dtype = dtype
self.device = device
if dtype == torch.float8_e5m2:
# NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
self.device = device
self.free_slots = None
self.is_not_in_free_group = True

View File

@@ -145,9 +145,10 @@ class RadixCache(BasePrefixCache):
# The prefix indices could be updated, reuse it
new_indices, new_last_node = self.match_prefix(token_ids)
assert len(new_indices) == len(token_ids)
self.req_to_token_pool.req_to_token[
req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
] = new_indices[len(req.prefix_indices) :]
self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :],
)
self.dec_lock_ref(req.last_node)
self.inc_lock_ref(new_last_node)