[new feat] Ascend backend support for FIA fusion kernel (#8328)
Co-authored-by: Even Zhou <even.y.zhou@outlook.com>
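What the diff below adds (summarized from the changes themselves): an optional fused-infer-attention (FIA) path for the Ascend attention backend. When the ASCEND_USE_FIA flag is set, prefill and decode go through torch.ops.npu.npu_fused_infer_attention_score with a 2048x2048 upper-triangular boolean mask (fia_mask, sparse_mode=3) instead of the _npu_flash_attention_qlens / _npu_paged_attention kernels. The Ascend MLA KV pool is also split from a single kv_buffer into separate k_buffer (kv_lora_rank wide) and v_buffer (qk_rope_head_dim wide) buffers written with npu_scatter_nd_update_, and the DeepSeek V2 attention dispatch now picks MHA for plain extend and MLA otherwise on the ascend backend.

A minimal sketch of how the new path is switched on; this is an assumed usage pattern, not part of the commit, and it presumes get_bool_env_var treats "true"/"1" as truthy:

import os

# Hypothetical usage sketch: opt in to the FIA kernels before the model runner is created.
os.environ["ASCEND_USE_FIA"] = "true"  # read by get_bool_env_var("ASCEND_USE_FIA", "False")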
@@ -12,11 +12,16 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import get_bool_env_var

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner

import os

import numpy as np


@dataclass
class ForwardMetadata:
@@ -54,7 +59,6 @@ class AscendAttnBackend(AttentionBackend):
        super().__init__()
        self.forward_metadata = None
        self.device = model_runner.device
        self.gen_attention_mask(128, model_runner.dtype)
        self.page_size = model_runner.page_size
        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
        if self.use_mla:
@@ -65,6 +69,17 @@ class AscendAttnBackend(AttentionBackend):
        self.max_context_len = model_runner.model_config.context_len
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.graph_mode = False
        self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False")
        if not self.use_fia:
            self.gen_attention_mask(128, model_runner.dtype)
        mask_length = 2048
        self.fia_mask = ~torch.tril(
            torch.ones(
                (mask_length, mask_length),
                dtype=torch.bool,
                device=model_runner.device,
            )
        )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""
@@ -81,6 +96,9 @@ class AscendAttnBackend(AttentionBackend):
                forward_batch.extend_seq_lens.cpu().int()
            )
        self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
        self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
            forward_batch.extend_seq_lens_cpu
        )

        self.graph_mode = False

@@ -151,71 +169,89 @@ class AscendAttnBackend(AttentionBackend):
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

        if not self.use_mla:
            query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
            output = torch.empty(
                (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
                dtype=query.dtype,
                device=query.device,
            )
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer, forward_batch.out_cache_loc, k, v
                )

            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

            if self.use_fia:
                """FIA will support multi-bs in a later version of CANN"""
                q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
                attn_output = torch.empty(
                    (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
                    device=q.device,
                    dtype=q.dtype,
                )
                q_len_offset = 0
                for q_len in forward_batch.extend_seq_lens_cpu:
                    attn_output[q_len_offset : q_len_offset + q_len] = (
                        torch.ops.npu.npu_fused_infer_attention_score(
                            q[None, q_len_offset : q_len_offset + q_len],
                            k[None, q_len_offset : q_len_offset + q_len],
                            v[None, q_len_offset : q_len_offset + q_len],
                            num_heads=layer.tp_q_head_num,
                            num_key_value_heads=layer.tp_k_head_num,
                            input_layout="BSND",  # TODO: TND does not support q_heads != k_heads
                            atten_mask=self.fia_mask.unsqueeze(0),
                            sparse_mode=3,
                            scale=layer.scaling,
                            next_tokens=0,
                        )[0]
                    )
                    q_len_offset += q_len
                attn_output = attn_output.view(
                    -1, layer.tp_q_head_num * layer.v_head_dim
                )

            torch_npu._npu_flash_attention_qlens(
                query=query,
                key_cache=k_cache,
                value_cache=v_cache,
                mask=self.mask,
                block_table=self.forward_metadata.block_tables,
                seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
                context_lens=self.forward_metadata.seq_lens_cpu_int,
                scale_value=layer.scaling,
                num_heads=layer.tp_q_head_num,
                num_kv_heads=layer.tp_k_head_num,
                out=output,
            )
            return output
        else:
            if layer.qk_head_dim != layer.v_head_dim:
                o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
            else:
                o = torch.empty_like(q)
            query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
            attn_output = torch.empty(
                (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
                dtype=query.dtype,
                device=query.device,
            )

            use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
            torch_npu._npu_flash_attention_qlens(
                query=query,
                key_cache=k_cache,
                value_cache=v_cache,
                mask=self.mask,
                block_table=self.forward_metadata.block_tables,
                seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
                context_lens=self.forward_metadata.seq_lens_cpu_int,
                scale_value=layer.scaling,
                num_heads=layer.tp_q_head_num,
                num_kv_heads=layer.tp_k_head_num,
                out=attn_output,
            )
        else:
            assert (
                layer.qk_head_dim != layer.v_head_dim
            ), "FIA only supports qk_head_dim != v_head_dim"
            q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
            k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)

            q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
            o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)

            causal = True
            if (
                layer.is_cross_attention
                or layer.attn_type == AttentionType.ENCODER_ONLY
            ):
                causal = False

            self.native_attn._run_sdpa_forward_extend(
                q_,
                o_,
                k_cache.view(
                    -1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
                ),
                v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
                forward_batch.req_to_token_pool.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.extend_prefix_lens,
                forward_batch.extend_seq_lens,
                scaling=layer.scaling,
                enable_gqa=use_gqa,
                causal=causal,
            attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
                q_nope,
                k_nope,
                v,
                query_rope=q_rope,
                key_rope=k_rope,
                num_heads=layer.tp_q_head_num,
                input_layout="TND",
                atten_mask=self.fia_mask,
                sparse_mode=3,
                actual_seq_lengths=self.forward_metadata.seq_lens_list_cumsum,
                actual_seq_lengths_kv=self.forward_metadata.seq_lens_list_cumsum,
                scale=layer.scaling,
                next_tokens=0,
            )
            return o

        return attn_output

    def forward_decode(
        self,
@@ -224,13 +260,17 @@ class AscendAttnBackend(AttentionBackend):
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        save_kv_cache: bool = False,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
    ):
        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )
        if not self.use_mla:
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer, forward_batch.out_cache_loc, k, v
                )
            num_tokens = q.shape[0]
            if self.graph_mode:
                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
                    layer.layer_id
@@ -239,7 +279,6 @@ class AscendAttnBackend(AttentionBackend):
                    layer.layer_id
                ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
                query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
                num_tokens = query.shape[0]
                workspace = (
                    torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                        query,
@@ -254,7 +293,7 @@ class AscendAttnBackend(AttentionBackend):
                        actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
                    )
                )
                output = torch.empty(
                attn_output = torch.empty(
                    (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
                    dtype=q.dtype,
                    device=q.device,
@@ -272,61 +311,129 @@ class AscendAttnBackend(AttentionBackend):
                    scale=layer.scaling,
                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
                    workspace=workspace,
                    out=[output, softmax_lse],
                    out=[attn_output, softmax_lse],
                )
            else:
                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
                    layer.layer_id
                )
                if self.use_fia:
                    attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
                        q.view(
                            forward_batch.batch_size,
                            -1,
                            layer.tp_q_head_num,
                            layer.qk_head_dim,
                        ),
                        k_cache.view(
                            -1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
                        ),
                        v_cache.view(
                            -1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
                        ),
                        num_heads=layer.tp_q_head_num,
                        num_key_value_heads=layer.tp_k_head_num,
                        input_layout="BSND",
                        atten_mask=None,
                        block_size=self.page_size,
                        block_table=self.forward_metadata.block_tables,
                        actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
                        scale=layer.scaling,
                    )
                else:
                    query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
                    attn_output = torch.empty(
                        (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
                        dtype=query.dtype,
                        device=query.device,
                    )

                    query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
                    num_tokens = query.shape[0]
                    output = torch.empty(
                        (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
                        dtype=query.dtype,
                        device=query.device,
                    torch_npu._npu_paged_attention(
                        query=query,
                        key_cache=k_cache,
                        value_cache=v_cache,
                        num_heads=layer.tp_q_head_num,
                        num_kv_heads=layer.tp_k_head_num,
                        scale_value=layer.scaling,
                        block_table=self.forward_metadata.block_tables,
                        context_lens=self.forward_metadata.seq_lens_cpu_int,
                        out=attn_output,
                    )
            return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
        else:
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer, forward_batch.out_cache_loc, k, k_rope
                )
            num_tokens = q.shape[0]
            kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

            torch_npu._npu_paged_attention(
                query=query,
                key_cache=k_cache,
                value_cache=v_cache,
            if (self.graph_mode or self.use_fia) and (
                layer.tp_q_head_num // layer.tp_k_head_num
            ) >= 8:
                """layer.tp_q_head_num // layer.tp_k_head_num < 8 will be supported in a later version of CANN"""
                kv_c = kv_c.view(
                    -1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
                )
                k_pe = k_pe.view(
                    -1, self.page_size, layer.tp_k_head_num * self.qk_rope_head_dim
                )
                q = q.view(
                    forward_batch.batch_size, -1, layer.tp_q_head_num, self.kv_lora_rank
                )
                q_rope = q_rope.view(
                    forward_batch.batch_size,
                    -1,
                    layer.tp_q_head_num,
                    self.qk_rope_head_dim,
                )
                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
                    q,
                    kv_c,
                    kv_c,
                    query_rope=q_rope,
                    key_rope=k_pe,
                    num_heads=layer.tp_q_head_num,
                    num_key_value_heads=layer.tp_k_head_num,
                    input_layout="BSND",
                    atten_mask=None,
                    sparse_mode=0,
                    scale=layer.scaling,
                    antiquant_mode=0,
                    antiquant_scale=None,
                    block_table=self.forward_metadata.block_tables,
                    block_size=self.page_size,
                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
                )
            else:
                assert (
                    self.graph_mode == False
                )  # _npu_paged_attention_mla does not support graph mode
                q = torch.cat([q, q_rope], dim=-1)
                query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                kv_c_and_k_pe_cache = torch.cat([kv_c, k_pe], dim=-1)
                kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
                    -1,
                    self.page_size,
                    layer.tp_k_head_num,
                    self.kv_lora_rank + self.qk_rope_head_dim,
                )
                attn_output = torch.empty(
                    [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
                    dtype=q.dtype,
                    device=q.device,
                )
                torch_npu._npu_paged_attention_mla(
                    query=query,
                    key_cache=kv_c_and_k_pe_cache,
                    num_kv_heads=layer.tp_k_head_num,
                    num_heads=layer.tp_q_head_num,
                    scale_value=layer.scaling,
                    block_table=self.forward_metadata.block_tables,
                    context_lens=self.forward_metadata.seq_lens_cpu_int,
                    out=output,
                    mla_vheadsize=self.kv_lora_rank,
                    out=attn_output,
                )
            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
        else:
            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
            num_tokens = query.shape[0]
            kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
                layer.layer_id
            )
            kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
                -1,
                self.page_size,
                layer.tp_k_head_num,
                self.kv_lora_rank + self.qk_rope_head_dim,
            )

            attn_output = torch.empty(
                [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
                dtype=q.dtype,
                device=q.device,
            )
            torch_npu._npu_paged_attention_mla(
                query=query,
                key_cache=kv_c_and_k_pe_cache,
                num_kv_heads=layer.tp_k_head_num,
                num_heads=layer.tp_q_head_num,
                scale_value=layer.scaling,
                block_table=self.forward_metadata.block_tables,
                context_lens=self.forward_metadata.seq_lens_cpu_int,
                mla_vheadsize=self.kv_lora_rank,
                out=attn_output,
            )
            return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)

@@ -304,7 +304,7 @@ class TopK(CustomOp):
        global_num_experts = router_logits.shape[-1]

        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
        if global_num_experts == 256 and self.topk_config.renormalize is False:
        if global_num_experts == 256 and self.topk_config.renormalize is True:

            routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
            router_logits = router_logits.to(torch.float32)

@@ -36,12 +36,15 @@ import triton.language as tl

from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2

logger = logging.getLogger(__name__)

GB = 1024 * 1024 * 1024
_is_cuda = is_cuda()
_is_npu = is_npu()
if _is_npu:
    import torch_npu


class ReqToTokenPool:
@@ -624,8 +627,6 @@ class AscendTokenToKVPool(MHATokenToKVPool):
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)

        import torch_npu

        torch_npu._npu_reshape_and_cache(
            key=cache_k,
            value=cache_v,
@@ -912,12 +913,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):

        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.kv_buffer = torch.zeros(
            self.k_buffer = torch.zeros(
                (
                    layer_num,
                    self.size // self.page_size + 1,
                    self.page_size,
                    self.kv_lora_rank + self.qk_rope_head_dim,
                    self.kv_lora_rank,
                ),
                dtype=self.store_dtype,
                device=self.device,
            )
            self.v_buffer = torch.zeros(
                (
                    layer_num,
                    self.size // self.page_size + 1,
                    self.page_size,
                    self.qk_rope_head_dim,
                ),
                dtype=self.store_dtype,
                device=self.device,
@@ -931,12 +942,52 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
            )
        self.mem_usage = kv_size / GB

    def get_kv_size_bytes(self):
        assert hasattr(self, "k_buffer")
        assert hasattr(self, "v_buffer")
        kv_size_bytes = 0
        for k_cache in self.k_buffer:
            kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
        for v_cache in self.v_buffer:
            kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
        return kv_size_bytes

    def get_kv_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
        return (
            self.k_buffer[layer_id - self.start_layer],
            self.v_buffer[layer_id - self.start_layer],
        )

    def get_key_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

        if self.store_dtype != self.dtype:
            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.k_buffer[layer_id - self.start_layer]

    def get_value_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

        if self.store_dtype != self.dtype:
            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.v_buffer[layer_id - self.start_layer]

    # for disagg
    def get_contiguous_buf_infos(self):
        # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
        kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
        kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
        kv_data_ptrs = [self.k_buffer[i].data_ptr() for i in range(self.layer_num)] + [
            self.v_buffer[i].data_ptr() for i in range(self.layer_num)
        ]
        kv_data_lens = [self.k_buffer[i].nbytes for i in range(self.layer_num)] + [
            self.v_buffer[i].nbytes for i in range(self.layer_num)
        ]
        kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
            self.v_buffer[i][0].nbytes for i in range(self.layer_num)
        ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def set_kv_buffer(
@@ -953,14 +1004,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)

        import torch_npu
        if cache_v is None:
            cache_k, cache_v = cache_k.split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )

        torch_npu._npu_reshape_and_cache_siso(
            key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
            key_cache=self.kv_buffer[layer_id - self.start_layer].view(
                -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
        torch_npu.npu_scatter_nd_update_(
            self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank),
            loc.view(-1, 1),
            cache_k.view(-1, 1, self.kv_lora_rank),
        )
        torch_npu.npu_scatter_nd_update_(
            self.v_buffer[layer_id - self.start_layer].view(
                -1, 1, self.qk_rope_head_dim
            ),
            slot_indices=loc,
            loc.view(-1, 1),
            cache_v.view(-1, 1, self.qk_rope_head_dim),
        )

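The npu_scatter_nd_update_ calls above write the latent part and the rope part of each MLA cache entry into the new k_buffer and v_buffer. A small self-contained sketch of the split that feeds them when cache_v is None (the sizes below are illustrative only, not taken from this commit):

import torch

kv_lora_rank, qk_rope_head_dim = 512, 64  # illustrative sizes only
cache_k = torch.randn(8, kv_lora_rank + qk_rope_head_dim)
# Mirrors cache_k.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) above.
kv_c, k_pe = cache_k.split([kv_lora_rank, qk_rope_head_dim], dim=-1)
assert kv_c.shape == (8, kv_lora_rank)      # destined for k_buffer
assert k_pe.shape == (8, qk_rope_head_dim)  # destined for v_buffer
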
@@ -994,7 +994,14 @@ class DeepseekV2AttentionMLA(nn.Module):
        self.current_attention_backend = attention_backend

        if attention_backend == "ascend":
            return AttnForwardMethod.MLA
            if (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
            ):
                return AttnForwardMethod.MHA
            else:
                return AttnForwardMethod.MLA
        elif (
            attention_backend == "flashinfer"
            or attention_backend == "fa3"
@@ -1292,6 +1299,7 @@ class DeepseekV2AttentionMLA(nn.Module):
            or self.current_attention_backend == "flashinfer"
            or self.current_attention_backend == "cutlass_mla"
            or self.current_attention_backend == "trtllm_mla"
            or self.current_attention_backend == "ascend"
        ):
            extra_args = {}
            if self._fuse_rope_for_trtllm_mla(forward_batch):
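
For reference, a standalone sketch of the fia_mask built in AscendAttnBackend.__init__, shrunk to 4x4 so the pattern is visible; ~torch.tril(...) marks every position above the diagonal, i.e. the future positions a causal kernel must skip:

import torch

mask_length = 4  # the backend uses 2048
fia_mask = ~torch.tril(torch.ones((mask_length, mask_length), dtype=torch.bool))
print(fia_mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])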