Unify the memory pool API and tp worker API (#1724)

This commit is contained in:
Lianmin Zheng
2024-10-19 23:19:26 -07:00
committed by GitHub
parent 95946271af
commit 59cbf47626
8 changed files with 87 additions and 25 deletions

View File

@@ -56,6 +56,12 @@ class ReqToTokenPool:
def clear(self):
    """Reset ``free_slots`` to the full index list ``0 .. size - 1``.

    Any previous contents of ``free_slots`` are discarded and replaced
    by a brand-new list built from ``self.size``.
    """
    # Unpacking the range builds the same fresh list as list(range(...)).
    self.free_slots = [*range(self.size)]
def write(self, indices, values):
    """Store ``values`` into ``req_to_token`` at positions ``indices``.

    The assignment is delegated to the item-assignment of whatever
    object ``self.req_to_token`` is.
    NOTE(review): presumably a torch tensor, in which case tensor
    indexing/broadcasting rules apply -- confirm against the constructor.
    """
    table = self.req_to_token
    table[indices] = values
def get_write_records(self):
    """Return the write records of this pool, which is always ``None``.

    NOTE(review): presumably this exists so the pool exposes the same
    method as other pool implementations that do track writes -- confirm
    against the callers.
    """
    return None
class BaseTokenToKVPool:
"""A memory pool that maps a token to its kv cache locations"""
@@ -68,12 +74,12 @@ class BaseTokenToKVPool:
):
self.size = size
self.dtype = dtype
self.device = device
if dtype == torch.float8_e5m2:
# NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
self.device = device
self.free_slots = None
self.is_not_in_free_group = True