Ascend attention backend(PA&MLA) (#7722)
Co-authored-by: Maksim <makcum888e@mail.ru>
Co-authored-by: VDV1985 <vladdv85@mail.ru>
This commit is contained in:
@@ -540,3 +540,164 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
)
|
||||
self.is_not_in_free_group = True
|
||||
self.free_group = []
|
||||
|
||||
|
||||
def alloc_extend_kernel_ascend(
    prefix_lens,
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
    device,
):
    """Assign KV-cache slot indices for the newly extended tokens of a batch.

    For request ``i``, tokens ``[prefix_lens[i], seq_lens[i])`` receive slots in
    three parts, mirroring the triton ``alloc_extend_kernel``:
      1. the tail of the partially filled last prefix page,
      2. whole new pages drawn from ``free_pages``,
      3. a trailing partial new page.

    Args:
        prefix_lens: (bs,) number of already-cached tokens per request.
        seq_lens: (bs,) total token count per request after the extend.
        last_loc: (bs,) slot index of the last prefix token per request.
        free_pages: (num_free,) indices of free pages to draw from.
        out_indices: flat int tensor of length ``sum(seq_lens - prefix_lens)``;
            written in place with the allocated slot indices.
        page_size: number of token slots per page.
        device: device for the scratch ``arange`` tensor.

    Returns:
        (bs,) tensor with the number of new pages consumed per request.
    """
    extend_lens = seq_lens - prefix_lens
    # Per-request output ranges inside out_indices.
    end_pos = torch.cumsum(extend_lens, 0)
    start_pos = end_pos - extend_lens
    # Pages newly needed per request: ceil(seq/page) - ceil(prefix/page).
    num_new_pages = (seq_lens + page_size - 1) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    # Of those, how many end up completely full.
    num_full_new_pages = (seq_lens) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    # 1 when the request also needs a trailing partial page, else 0.
    need_page = num_new_pages - num_full_new_pages
    # Per-request ranges inside free_pages.
    end_new_pages = torch.cumsum(num_new_pages, 0)
    start_new_pages = end_new_pages - num_new_pages
    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
    for i in range(len(prefix_lens)):
        # Part 1: fill the tail of the last (partially used) prefix page.
        num1 = (
            min(
                seq_lens[i],
                (prefix_lens[i] + page_size - 1) // page_size * page_size,
            )
            - prefix_lens[i]
        )
        if num1:
            out_indices[start_pos[i] : start_pos[i] + num1] = (
                last_loc[i] + 1 + pos_in_page[:num1].view(-1)
            )

        # Bug fix: when the whole extension fits inside the prefix's last
        # page, parts 2/3 must be skipped (the triton kernel returns early
        # here: `if pre_len + num_part1 == seq_len: return`). Otherwise part 3
        # would write indices derived from a page that was never allocated to
        # this request, clobbering previously written slots.
        if prefix_lens[i] + num1 == seq_lens[i]:
            continue

        # Part 2: fill whole new pages.
        num2 = (
            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
        ) * page_size
        if num2:
            pages = (
                free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
                * page_size
            )
            out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
                pages.view(-1, 1) + pos_in_page.view(1, -1)
            ).view(-1)

        # Part 3: fill the trailing partial new page.
        num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
        if num3:
            out_indices[end_pos[i] - num3 : end_pos[i]] = (
                free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
            ).view(-1)
    return num_new_pages
|
||||
|
||||
|
||||
def alloc_decode_kernel_ascend(
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
):
    """Assign one KV-cache slot per request for a decode step.

    A request needs a fresh page exactly when its new length starts a new
    page, i.e. ceil(seq/page) > ceil((seq-1)/page); otherwise the token goes
    into the slot right after ``last_loc``.

    Args:
        seq_lens: (bs,) sequence length per request including the new token.
        last_loc: (bs,) slot index of the previous token per request.
        free_pages: (num_free,) indices of free pages to draw from.
        out_indices: (bs,) output tensor, written in place.
        page_size: number of token slots per page.

    Returns:
        (bs,) tensor with the number of new pages consumed per request (0/1).
    """
    ceil_after = (seq_lens + page_size - 1) // page_size
    ceil_before = (seq_lens + page_size - 2) // page_size  # ceil((seq-1)/page)
    new_page_counts = ceil_after - ceil_before
    # Per-request offsets into free_pages.
    page_ends = torch.cumsum(new_page_counts, 0)
    page_starts = page_ends - new_page_counts
    for req, opens_page in enumerate(new_page_counts):
        if opens_page:
            # First slot of the freshly taken page.
            out_indices[req] = free_pages[page_starts[req]] * page_size
        else:
            # Continue within the current page.
            out_indices[req] = last_loc[req] + 1
    return new_page_counts
|
||||
|
||||
|
||||
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
    """Paged token-to-KV-pool allocator for Ascend NPUs.

    Uses the host-side (Python/torch) ``alloc_*_kernel_ascend`` helpers
    instead of triton kernels to compute slot assignments.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        super().__init__(size, page_size, dtype, device, kvcache)
        # Scratch holder; replaced by the per-request new-page counts that
        # the alloc kernels return.
        self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        """Allocate slots for all extend tokens of the batch.

        Returns a flat int32 tensor of ``extend_num_tokens`` slot indices, or
        ``None`` when the free pool cannot satisfy the request.
        """
        if self.debug_mode:
            # The slot after the last prefix token must line up with the
            # prefix length within its page.
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int32, device=self.device
        )

        # The kernel both fills out_indices and returns the per-request
        # number of newly consumed pages.
        self.ret_values = alloc_extend_kernel_ascend(
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
            self.device,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        num_new_pages = self.ret_values.sum()
        if num_new_pages > len(self.free_pages):
            # Out of pages; caller is expected to evict and retry.
            return None

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        """Allocate one slot per request for a decode step.

        Returns a (bs,) int32 tensor of slot indices, or ``None`` when the
        free pool cannot satisfy the request.
        """
        if self.debug_mode:
            # last_loc points at the token before the one being decoded, so
            # (last_loc + 2) lines up with the new sequence length.
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        bs = len(seq_lens)
        out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)

        self.ret_values = alloc_decode_kernel_ascend(
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        num_new_pages = self.ret_values.sum()
        if num_new_pages > len(self.free_pages):
            return None

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def clear(self):
        super().clear()
        # The host-side kernels index with int32; keep free_pages in int32.
        self.free_pages = self.free_pages.to(torch.int32)
|
||||
|
||||
@@ -568,6 +568,76 @@ class SWAKVPool(KVCache):
|
||||
)
|
||||
|
||||
|
||||
class AscendTokenToKVPool(MHATokenToKVPool):
    """MHA token-to-KV pool for Ascend NPUs.

    Overrides buffer creation to use a page-shaped layout and writes entries
    through torch_npu's fused reshape-and-cache op.
    """

    def _create_buffers(self):
        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            # Per layer: [size // page_size + 1, page_size, head_num, head_dim].
            # The extra leading page (padded slot 0) is used for writing dummy
            # outputs from padded tokens.
            self.k_buffer = [
                torch.zeros(
                    (
                        self.size // self.page_size + 1,
                        self.page_size,
                        self.head_num,
                        self.head_dim,
                    ),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]
            self.v_buffer = [
                torch.zeros(
                    (
                        self.size // self.page_size + 1,
                        self.page_size,
                        self.head_num,
                        self.head_dim,
                    ),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
    ):
        """Store K/V vectors for the tokens at slot indices ``loc`` into the
        buffers of ``layer``.

        NOTE(review): when a scale is given, ``div_`` mutates the caller's
        cache_k/cache_v in place — confirm callers do not reuse those tensors.
        """
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            # Rescale (if requested) and convert to the pool's logical dtype.
            if k_scale is not None:
                cache_k.div_(k_scale)
            if v_scale is not None:
                cache_v.div_(v_scale)
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)

        if self.store_dtype != self.dtype:
            # Bit-cast view to the storage dtype (store_dtype may differ from
            # the logical dtype for quantized storage).
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)

        # Deferred import: torch_npu is only present on Ascend hosts.
        import torch_npu

        # Fused scatter of (key, value) rows into the paged caches.
        torch_npu._npu_reshape_and_cache(
            key=cache_k,
            value=cache_v,
            key_cache=self.k_buffer[layer_id].view(
                -1, self.page_size, self.head_num, self.head_dim
            ),
            value_cache=self.v_buffer[layer_id].view(
                -1, self.page_size, self.head_num, self.head_dim
            ),
            slot_indices=loc,
        )
|
||||
|
||||
|
||||
@triton.jit
|
||||
def set_mla_kv_buffer_kernel(
|
||||
kv_buffer_ptr,
|
||||
@@ -820,6 +890,84 @@ class MLATokenToKVPool(KVCache):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
    """MLA (latent-attention) paged KV pool for Ascend NPUs.

    Stores one combined vector of ``kv_lora_rank + qk_rope_head_dim`` per
    token, laid out page-wise, and writes entries through torch_npu's fused
    single-input reshape-and-cache op.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        # Deliberately skip MLATokenToKVPool.__init__ (which allocates its own
        # buffers) and run the grandparent KVCache initializer instead.
        super(MLATokenToKVPool, self).__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )

        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim

        self.custom_mem_pool = None

        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            # Per layer: [size // page_size + 1, page_size,
            # kv_lora_rank + qk_rope_head_dim]. The extra leading page (padded
            # slot 0) is used for writing dummy outputs from padded tokens.
            self.kv_buffer = [
                torch.zeros(
                    (
                        self.size // self.page_size + 1,
                        self.page_size,
                        self.kv_lora_rank + self.qk_rope_head_dim,
                    ),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(layer_num)
            ]

        self.layer_transfer_counter = None

        kv_size = self.get_kv_size_bytes()
        logger.info(
            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
        )
        self.mem_usage = kv_size / GB

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        """Write the combined latent KV rows in ``cache_k`` into the paged
        buffer of ``layer`` at slot indices ``loc``.

        ``cache_v`` is accepted for interface compatibility but unused: MLA
        stores a single combined vector per token.
        """
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)

        if self.store_dtype != self.dtype:
            # Bug fix: original read bare ``store_dtype`` (NameError); the
            # attribute is ``self.store_dtype``.
            cache_k = cache_k.view(self.store_dtype)

        # Deferred import: torch_npu is only present on Ascend hosts.
        import torch_npu

        torch_npu._npu_reshape_and_cache_siso(
            key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
            key_cache=self.kv_buffer[layer_id - self.start_layer].view(
                -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
            ),
            slot_indices=loc,
        )
|
||||
|
||||
|
||||
class DoubleSparseTokenToKVPool(KVCache):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user