[feature] kv transfer support of ascend npu (#7795)

Co-authored-by: liupeng <liupeng374@huawei.com>
This commit is contained in:
ronnie_zheng
2025-07-11 10:07:51 +03:00
committed by GitHub
parent 615553079d
commit 86044712c6
10 changed files with 267 additions and 53 deletions

View File

@@ -604,32 +604,49 @@ class AscendTokenToKVPool(MHATokenToKVPool):
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.k_buffer = [
torch.zeros(
(
self.size // self.page_size + 1,
self.page_size,
self.head_num,
self.head_dim,
),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
self.v_buffer = [
torch.zeros(
(
self.size // self.page_size + 1,
self.page_size,
self.head_num,
self.head_dim,
),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
# Continuous memory improves the efficiency of Ascend`s transmission backend,
# while other backends remain unchanged.
self.kv_buffer = torch.zeros(
(
2,
self.layer_num,
self.size // self.page_size + 1,
self.page_size,
self.head_num,
self.head_dim,
),
dtype=self.store_dtype,
device=self.device,
)
self.k_buffer = self.kv_buffer[0]
self.v_buffer = self.kv_buffer[1]
# for disagg
def get_contiguous_buf_infos(self):
# layer_num x [seq_len, head_num, head_dim]
# layer_num x [page_num, page_size, head_num, head_dim]
kv_data_ptrs = [
self.get_key_buffer(i).data_ptr()
for i in range(self.start_layer, self.start_layer + self.layer_num)
] + [
self.get_value_buffer(i).data_ptr()
for i in range(self.start_layer, self.start_layer + self.layer_num)
]
kv_data_lens = [
self.get_key_buffer(i).nbytes
for i in range(self.start_layer, self.start_layer + self.layer_num)
] + [
self.get_value_buffer(i).nbytes
for i in range(self.start_layer, self.start_layer + self.layer_num)
]
kv_item_lens = [
self.get_key_buffer(i)[0].nbytes
for i in range(self.start_layer, self.start_layer + self.layer_num)
] + [
self.get_value_buffer(i)[0].nbytes
for i in range(self.start_layer, self.start_layer + self.layer_num)
]
return kv_data_ptrs, kv_data_lens, kv_item_lens
def set_kv_buffer(
self,
@@ -969,18 +986,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.zeros(
(
self.size // self.page_size + 1,
self.page_size,
self.kv_lora_rank + self.qk_rope_head_dim,
),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(layer_num)
]
self.kv_buffer = torch.zeros(
(
layer_num,
self.size // self.page_size + 1,
self.page_size,
self.kv_lora_rank + self.qk_rope_head_dim,
),
dtype=self.store_dtype,
device=self.device,
)
self.layer_transfer_counter = None
@@ -990,6 +1005,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
)
self.mem_usage = kv_size / GB
# for disagg
def get_contiguous_buf_infos(self):
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
return kv_data_ptrs, kv_data_lens, kv_item_lens
def set_kv_buffer(
self,
layer: RadixAttention,