Unify the memory pool api and tp worker API (#1724)
This commit is contained in:
@@ -56,6 +56,12 @@ class ReqToTokenPool:
|
||||
def clear(self):
|
||||
self.free_slots = list(range(self.size))
|
||||
|
||||
def write(self, indices, values):
|
||||
self.req_to_token[indices] = values
|
||||
|
||||
def get_write_records(self):
|
||||
return None
|
||||
|
||||
|
||||
class BaseTokenToKVPool:
|
||||
"""A memory pool that maps a token to its kv cache locations"""
|
||||
@@ -68,12 +74,12 @@ class BaseTokenToKVPool:
|
||||
):
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
if dtype == torch.float8_e5m2:
|
||||
# NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
|
||||
self.store_dtype = torch.uint8
|
||||
else:
|
||||
self.store_dtype = dtype
|
||||
self.device = device
|
||||
|
||||
self.free_slots = None
|
||||
self.is_not_in_free_group = True
|
||||
|
||||
Reference in New Issue
Block a user