[PD] Add get_contiguous_buf_infos interface for MLATokenToKVPool (#5204)
This commit is contained in:
@@ -442,6 +442,19 @@ class MLATokenToKVPool(KVCache):
|
|||||||
|
|
||||||
self.layer_transfer_counter = None
|
self.layer_transfer_counter = None
|
||||||
|
|
||||||
|
# for disaggregation (PD)
|
||||||
|
def get_contiguous_buf_infos(self):
    """Gather pointer and size metadata for every per-layer KV buffer.

    Each of the three returned lists has ``2 * self.layer_num`` entries:
    first one entry per layer taken from ``get_key_buffer(layer)``, then
    one entry per layer taken from ``get_value_buffer(layer)``, both in
    ascending layer order.

    Returns:
        A tuple ``(kv_data_ptrs, kv_data_lens, kv_item_lens)`` where
        ``kv_data_ptrs`` holds each buffer's ``data_ptr()``,
        ``kv_data_lens`` holds each buffer's ``nbytes``, and
        ``kv_item_lens`` holds the ``nbytes`` of each buffer's first
        element (``buffer[0]``).
    """

    def collect(measure):
        # Buffers are looked up again for every list, keys before values,
        # so the sequence of accessor calls stays the same for each list.
        layers = range(self.layer_num)
        from_keys = [measure(self.get_key_buffer(idx)) for idx in layers]
        from_values = [measure(self.get_value_buffer(idx)) for idx in layers]
        return from_keys + from_values

    kv_data_ptrs = collect(lambda buf: buf.data_ptr())
    kv_data_lens = collect(lambda buf: buf.nbytes)
    kv_item_lens = collect(lambda buf: buf[0].nbytes)
    return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||||
|
|
||||||
def get_key_buffer(self, layer_id: int):
|
def get_key_buffer(self, layer_id: int):
|
||||||
if self.layer_transfer_counter is not None:
|
if self.layer_transfer_counter is not None:
|
||||||
self.layer_transfer_counter.wait_until(layer_id)
|
self.layer_transfer_counter.wait_until(layer_id)
|
||||||
|
|||||||
Reference in New Issue
Block a user