[feature] kv transfer support of ascend npu (#7795)
Co-authored-by: liupeng <liupeng374@huawei.com>
This commit is contained in:
@@ -604,32 +604,49 @@ class AscendTokenToKVPool(MHATokenToKVPool):
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
# [size, head_num, head_dim] for each layer
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.k_buffer = [
|
||||
torch.zeros(
|
||||
(
|
||||
self.size // self.page_size + 1,
|
||||
self.page_size,
|
||||
self.head_num,
|
||||
self.head_dim,
|
||||
),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
self.v_buffer = [
|
||||
torch.zeros(
|
||||
(
|
||||
self.size // self.page_size + 1,
|
||||
self.page_size,
|
||||
self.head_num,
|
||||
self.head_dim,
|
||||
),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
# Continuous memory improves the efficiency of Ascend`s transmission backend,
|
||||
# while other backends remain unchanged.
|
||||
self.kv_buffer = torch.zeros(
|
||||
(
|
||||
2,
|
||||
self.layer_num,
|
||||
self.size // self.page_size + 1,
|
||||
self.page_size,
|
||||
self.head_num,
|
||||
self.head_dim,
|
||||
),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.k_buffer = self.kv_buffer[0]
|
||||
self.v_buffer = self.kv_buffer[1]
|
||||
|
||||
# for disagg
|
||||
def get_contiguous_buf_infos(self):
|
||||
# layer_num x [seq_len, head_num, head_dim]
|
||||
# layer_num x [page_num, page_size, head_num, head_dim]
|
||||
kv_data_ptrs = [
|
||||
self.get_key_buffer(i).data_ptr()
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
] + [
|
||||
self.get_value_buffer(i).data_ptr()
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
]
|
||||
kv_data_lens = [
|
||||
self.get_key_buffer(i).nbytes
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
] + [
|
||||
self.get_value_buffer(i).nbytes
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
]
|
||||
kv_item_lens = [
|
||||
self.get_key_buffer(i)[0].nbytes
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
] + [
|
||||
self.get_value_buffer(i)[0].nbytes
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
]
|
||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
@@ -969,18 +986,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
||||
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.kv_buffer = [
|
||||
torch.zeros(
|
||||
(
|
||||
self.size // self.page_size + 1,
|
||||
self.page_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
self.kv_buffer = torch.zeros(
|
||||
(
|
||||
layer_num,
|
||||
self.size // self.page_size + 1,
|
||||
self.page_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.layer_transfer_counter = None
|
||||
|
||||
@@ -990,6 +1005,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
||||
)
|
||||
self.mem_usage = kv_size / GB
|
||||
|
||||
# for disagg
|
||||
def get_contiguous_buf_infos(self):
|
||||
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
|
||||
kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
|
||||
kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
|
||||
kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
|
||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
layer: RadixAttention,
|
||||
|
||||
Reference in New Issue
Block a user