Ascend attention backend(PA&MLA) (#7722)

Co-authored-by: Maksim <makcum888e@mail.ru>
Co-authored-by: VDV1985 <vladdv85@mail.ru>
Author: ronnie_zheng
Date: 2025-07-03 19:23:19 +03:00
Committed by: GitHub
Commit: 1e0e549766 (parent b58226510f)
17 changed files with 842 additions and 16 deletions


@@ -540,3 +540,164 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
        )
        self.is_not_in_free_group = True
        self.free_group = []

def alloc_extend_kernel_ascend(
    prefix_lens,
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
    device,
):
    extend_lens = seq_lens - prefix_lens
    end_pos = torch.cumsum(extend_lens, 0)
    start_pos = end_pos - extend_lens
    num_new_pages = (seq_lens + page_size - 1) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    num_full_new_pages = seq_lens // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    need_page = num_new_pages - num_full_new_pages
    end_new_pages = torch.cumsum(num_new_pages, 0)
    start_new_pages = end_new_pages - num_new_pages
    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
    for i in range(len(prefix_lens)):
        # Part 1: fill the unfilled tail of the prefix's last partial page.
        num1 = (
            min(
                seq_lens[i],
                (prefix_lens[i] + page_size - 1) // page_size * page_size,
            )
            - prefix_lens[i]
        )
        if num1:
            out_indices[start_pos[i] : start_pos[i] + num1] = (
                last_loc[i] + 1 + pos_in_page[:num1].view(-1)
            )

        # Part 2: fill the newly allocated full pages.
        num2 = (
            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
        ) * page_size
        if num2:
            pages = (
                free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
                * page_size
            )
            out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
                pages.view(-1, 1) + pos_in_page.view(1, -1)
            ).view(-1)

        # Part 3: fill the trailing partial page at the end of the sequence.
        num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
        if num3:
            out_indices[end_pos[i] - num3 : end_pos[i]] = (
                free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
            ).view(-1)
    return num_new_pages
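
For reference, the kernel fills out_indices per request in three segments: the tail of the prefix's last partial page, the newly allocated full pages, and a newly allocated trailing partial page. The walkthrough below is an editorial sketch (not part of this commit) that runs the kernel on CPU tensors with made-up values, assuming the function above is in scope; free_pages holds hypothetical free page ids.

import torch

# Editorial walkthrough: page_size = 4, one request with a 6-token prefix whose
# last token sits in slot 21 (page 5, offset 1), extended to 13 tokens.
page_size = 4
prefix_lens = torch.tensor([6])
seq_lens = torch.tensor([13])
last_loc = torch.tensor([21])
free_pages = torch.tensor([7, 9, 11], dtype=torch.int32)  # hypothetical free page ids
out_indices = torch.empty(13 - 6, dtype=torch.int32)

num_new_pages = alloc_extend_kernel_ascend(
    prefix_lens, seq_lens, last_loc, free_pages, out_indices, page_size, "cpu"
)
print(out_indices.tolist())    # [22, 23, 28, 29, 30, 31, 36]
print(num_new_pages.tolist())  # [2] -> pages 7 and 9 are consumed by this request
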

def alloc_decode_kernel_ascend(
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
):
    num_new_pages = (seq_lens + page_size - 1) // page_size - (
        seq_lens - 1 + page_size - 1
    ) // page_size
    end_new_pages = torch.cumsum(num_new_pages, 0)
    start_new_pages = end_new_pages - num_new_pages
    for i in range(len(seq_lens)):
        if num_new_pages[i]:
            # The decoded token crosses a page boundary: start a fresh page.
            out_indices[i] = free_pages[start_new_pages[i]] * page_size
        else:
            # The decoded token still fits into the current page.
            out_indices[i] = last_loc[i] + 1
    return num_new_pages
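
Decode allocates at most one slot per request: a fresh page when the new token crosses a page boundary, otherwise the slot right after the previous one. A small editorial sketch with made-up values (not part of this commit), assuming the function above is in scope:

import torch

# Editorial sketch: page_size = 4. Request 0 grows from 8 to 9 tokens and
# crosses a page boundary; request 1 grows from 6 to 7 tokens and stays in its
# current page.
page_size = 4
seq_lens = torch.tensor([9, 7])    # lengths including the token being decoded
last_loc = torch.tensor([19, 33])  # slots of the previous last tokens
free_pages = torch.tensor([6], dtype=torch.int32)
out_indices = torch.empty(2, dtype=torch.int32)

num_new_pages = alloc_decode_kernel_ascend(
    seq_lens, last_loc, free_pages, out_indices, page_size
)
print(out_indices.tolist())    # [24, 34] -> slot 24 opens page 6; slot 34 follows slot 33
print(num_new_pages.tolist())  # [1, 0]
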

class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        super().__init__(size, page_size, dtype, device, kvcache)
        self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        if self.debug_mode:
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        bs = len(prefix_lens)
        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int32, device=self.device
        )

        self.ret_values = alloc_extend_kernel_ascend(
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
            self.device,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        num_new_pages = self.ret_values.sum()
        if num_new_pages > len(self.free_pages):
            return None

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        if self.debug_mode:
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        bs = len(seq_lens)
        out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)

        self.ret_values = alloc_decode_kernel_ascend(
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        num_new_pages = self.ret_values.sum()
        if num_new_pages > len(self.free_pages):
            return None

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def clear(self):
        super().clear()
        self.free_pages = self.free_pages.to(torch.int32)
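
The subclass keeps the contract of PagedTokenToKVPoolAllocator: alloc_extend and alloc_decode return per-token slot indices, or None when the free-page pool cannot cover the request. A minimal editorial sketch of the intended call pattern follows; the helper and all of its arguments are hypothetical placeholders for objects the scheduler already holds.

# Editorial sketch of the call pattern (not part of this commit). `allocator` is
# an AscendPagedTokenToKVPoolAllocator, `kv_pool` the AscendTokenToKVPool it
# manages, `layer` a RadixAttention module, and the rest are batch tensors.
def write_extend_tokens(
    allocator, kv_pool, layer, prefix_lens, seq_lens, last_loc,
    extend_num_tokens, cache_k, cache_v,
):
    out_indices = allocator.alloc_extend(
        prefix_lens, seq_lens, last_loc, extend_num_tokens
    )
    if out_indices is None:
        # Not enough free pages: the caller must evict cache or retract requests.
        return None
    # Scatter the new K/V vectors into the paged buffers at the allocated slots.
    kv_pool.set_kv_buffer(layer, out_indices, cache_k, cache_v)
    return out_indices
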


@@ -568,6 +568,76 @@ class SWAKVPool(KVCache):
)

class AscendTokenToKVPool(MHATokenToKVPool):

    def _create_buffers(self):
        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            # Page-major layout: [size // page_size + 1, page_size, head_num, head_dim]
            # for each layer. The padded page 0 (slot 0) is used for writing dummy
            # outputs from padded tokens.
            self.k_buffer = [
                torch.zeros(
                    (
                        self.size // self.page_size + 1,
                        self.page_size,
                        self.head_num,
                        self.head_dim,
                    ),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]
            self.v_buffer = [
                torch.zeros(
                    (
                        self.size // self.page_size + 1,
                        self.page_size,
                        self.head_num,
                        self.head_dim,
                    ),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
    ):
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            if k_scale is not None:
                cache_k.div_(k_scale)
            if v_scale is not None:
                cache_v.div_(v_scale)
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)

        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)

        import torch_npu

        torch_npu._npu_reshape_and_cache(
            key=cache_k,
            value=cache_v,
            key_cache=self.k_buffer[layer_id].view(
                -1, self.page_size, self.head_num, self.head_dim
            ),
            value_cache=self.v_buffer[layer_id].view(
                -1, self.page_size, self.head_num, self.head_dim
            ),
            slot_indices=loc,
        )
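
The K/V buffers are page-major, and the slot indices produced by the allocator encode page_id * page_size + offset_in_page. As an editorial reference (not part of this commit), the scatter that torch_npu._npu_reshape_and_cache is assumed to perform here can be written in plain PyTorch as:

def reference_reshape_and_cache(key, value, key_cache, value_cache, slot_indices, page_size):
    # key / value:             [num_tokens, head_num, head_dim]
    # key_cache / value_cache: [num_pages + 1, page_size, head_num, head_dim]
    # slot_indices:            page_id * page_size + offset_in_page, per token
    pages = slot_indices // page_size
    offsets = slot_indices % page_size
    key_cache[pages, offsets] = key
    value_cache[pages, offsets] = value
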
@triton.jit
def set_mla_kv_buffer_kernel(
    kv_buffer_ptr,
@@ -820,6 +890,84 @@ class MLATokenToKVPool(KVCache):
torch.cuda.synchronize()

class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        # Bypass MLATokenToKVPool.__init__ and call KVCache.__init__ directly;
        # the kv_buffer is allocated below in a page-major layout instead.
        super(MLATokenToKVPool, self).__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )

        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.custom_mem_pool = None

        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.kv_buffer = [
                torch.zeros(
                    (
                        self.size // self.page_size + 1,
                        self.page_size,
                        self.kv_lora_rank + self.qk_rope_head_dim,
                    ),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(layer_num)
            ]

        self.layer_transfer_counter = None

        kv_size = self.get_kv_size_bytes()
        logger.info(
            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
        )
        self.mem_usage = kv_size / GB

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)

        import torch_npu

        torch_npu._npu_reshape_and_cache_siso(
            key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
            key_cache=self.kv_buffer[layer_id - self.start_layer].view(
                -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
            ),
            slot_indices=loc,
        )
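
For MLA, each token stores a single combined vector of width kv_lora_rank + qk_rope_head_dim (the compressed latent KV followed by the rotary key component), so only cache_k is written and no separate V cache exists. An editorial sketch (not part of this commit) of how a reader would split one layer's paged buffer back into the two parts, following the layout used by MLATokenToKVPool:

def split_mla_kv_buffer(kv_buffer_layer, kv_lora_rank, qk_rope_head_dim):
    # kv_buffer_layer: [num_pages + 1, page_size, kv_lora_rank + qk_rope_head_dim]
    flat = kv_buffer_layer.view(-1, kv_lora_rank + qk_rope_head_dim)  # one row per slot
    kv_c = flat[:, :kv_lora_rank]    # compressed latent KV
    k_rope = flat[:, kv_lora_rank:]  # rotary positional key component
    return kv_c, k_rope
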

class DoubleSparseTokenToKVPool(KVCache):
    def __init__(
        self,