[CP&SP] Integrate FIA operator in mla_cp._forward_decode (#5641)
### What this PR does / why we need it?
Replace the `npu_multi_head_latent_attention` operator with the FIA (fused infer attention) operator `npu_fused_infer_attention_score` in `_forward_decode` of `mla_cp.py`.
Adjust `mla_attn_dpc_pcp` in `acl_graph.py` accordingly.
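For reviewers unfamiliar with the operator, a minimal sketch of the decode-path call shape this PR moves to is shown below. It assumes `torch_npu` with `npu_fused_infer_attention_score` on an Ascend NPU; the keyword arguments mirror the `common_kwargs` dict assembled in the diff, while the wrapper function, its parameter names, and the comments are illustrative rather than copied from the implementation.

```python
import torch
import torch_npu  # requires an Ascend NPU environment


def fia_decode_sketch(
    q_nope: torch.Tensor,       # query, latent ("nope") part
    q_pe: torch.Tensor,         # query, rotary ("rope") part
    k_nope: torch.Tensor,       # paged latent KV cache (also passed as value)
    k_pe: torch.Tensor,         # rope part of the cached keys
    block_table: torch.Tensor,
    cp_seq_len: list[int],      # per-request KV lengths held by this cp/dcp rank
    num_heads: int,
    num_kv_heads: int,
    scale: float,
    block_size: int,
):
    # Mirrors common_kwargs in the diff; softmax_lse_flag=True makes the op
    # return the log-sum-exp alongside the output so partial results from
    # context-parallel ranks can be merged afterwards.
    return torch_npu.npu_fused_infer_attention_score(
        q_nope,
        k_nope,
        k_nope,
        query_rope=q_pe,
        key_rope=k_pe,
        num_heads=num_heads,
        num_key_value_heads=num_kv_heads,
        input_layout="BNSD",    # the diff switches to "TND" when speculative decoding is on
        atten_mask=None,
        sparse_mode=0,
        scale=scale,
        antiquant_mode=0,
        antiquant_scale=None,
        block_table=block_table,
        block_size=block_size,
        actual_seq_lengths=None,
        actual_seq_lengths_kv=cp_seq_len,
        softmax_lse_flag=True,
    )  # -> (attn_output, softmax_lse)
```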
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Signed-off-by: tongyuzhou <t00886357@china.huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: tongyuzhou <t00886357@china.huawei.com>
@@ -14,6 +14,8 @@ from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec

from vllm_ascend.attention.attention_v1 import AscendAttentionState

# isort: off
from vllm_ascend.attention.mla_v1 import (
    AscendMLADecodeMetadata,
@@ -244,8 +246,12 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
self.batch_seq_mask_buf[: batch_seq_mask.shape[0]].copy_(batch_seq_mask, non_blocking=True)
batch_seq_mask = self.batch_seq_mask_buf[: batch_seq_mask.shape[0]]
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
decode_metadata.cp_seq_len = cp_seq_len
decode_metadata.cp_seq_len = cp_seq_len.tolist()
decode_metadata.batch_seq_mask = batch_seq_mask

actual_seq_lengths_q = torch.arange(self.num_decodes_flatten) + 1
decode_metadata.actual_seq_lengths_q = actual_seq_lengths_q

return decode_metadata
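A note on the `actual_seq_lengths_q` value introduced above: as I understand the FIA operator's `TND` layout, `actual_seq_lengths` carries cumulative query lengths, so treating every flattened decode token as a length-1 query sequence yields the running offsets `[1, 2, ..., n]`. A tiny illustration with plain `torch` (the batch size is hypothetical):

```python
import torch

# Hypothetical number of flattened decode tokens in a batch.
num_decodes_flatten = 4

# Each decode token is its own length-1 query sequence, so the cumulative
# lengths are simply 1, 2, ..., n.
actual_seq_lengths_q = torch.arange(num_decodes_flatten) + 1
print(actual_seq_lengths_q.tolist())  # [1, 2, 3, 4]
```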
@@ -535,18 +541,53 @@ class AscendMlaCPImpl(AscendMLAImpl):
    num_heads = self.num_heads * self.dcp_size
else:
    num_heads = self.num_heads

k_nope = k_nope.view(-1, block_size, self.num_kv_heads, self.kv_lora_rank)
k_pe = k_pe.view(-1, block_size, self.num_kv_heads, self.qk_rope_head_dim)
q_nope = q_nope.view(num_tokens, num_heads, -1)
q_pe = q_pe.view(num_tokens, num_heads, -1)
# use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
seq_len = decode_meta.cp_seq_len
k_nope = k_nope.view(-1, self.num_kv_heads, block_size, self.kv_lora_rank)
k_pe = k_pe.view(-1, self.num_kv_heads, block_size, self.qk_rope_head_dim)

actual_seq_lengths = None
input_layout = "BNSD"

if (
    attn_metadata.attn_state
    in [
        AscendAttentionState.SpecDecoding,
        AscendAttentionState.ChunkedPrefill,
        AscendAttentionState.DecodeOnly,
    ]
    and self.speculative_config is not None
):
    input_layout = "TND"
    # TODO: If the driver is upgraded later, the contiguous function can be deleted.
    q_nope = q_nope.view(num_tokens, num_heads, -1).contiguous()
    q_pe = q_pe.view(num_tokens, num_heads, -1)
    sparse_mode = 3
    spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
    actual_seq_lengths = decode_meta.actual_seq_lengths_q
else:
    q_nope = q_nope.view(num_tokens, num_heads, 1, -1).contiguous()
    q_pe = q_pe.view(num_tokens, num_heads, 1, -1)
    sparse_mode = 0
    spec_attn_mask = None

common_kwargs = {
    "return_lse": True,
    "calc_type": "calc_type_ring",
    "query_rope": q_pe,
    "key_rope": k_pe,
    "num_heads": num_heads,
    "num_key_value_heads": self.num_kv_heads,
    "input_layout": input_layout,
    "atten_mask": spec_attn_mask,
    "sparse_mode": sparse_mode,
    "scale": self.scale,
    "antiquant_mode": 0,
    "antiquant_scale": None,
    "block_table": decode_meta.block_table,
    "block_size": block_size,
    "actual_seq_lengths": actual_seq_lengths,
    "actual_seq_lengths_kv": decode_meta.cp_seq_len,
    "softmax_lse_flag": True,
}

forward_context: ForwardContext = get_forward_context()
if forward_context.is_draft_model:
    graph_params = get_draft_graph_params()
@@ -560,72 +601,58 @@ class AscendMlaCPImpl(AscendMLAImpl):
    graph_params.events[num_tokens].append(event)
    workspace = graph_params.workspaces.get(num_tokens)
    if workspace is None:
        workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace(
        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
            q_nope,
            q_pe,
            k_nope,
            k_pe,
            decode_meta.block_table,
            seq_len,
            num_heads,
            self.scale,
            self.num_kv_heads,
            k_nope,
            **common_kwargs,
        )
        update_graph_params_workspaces(num_tokens, workspace)
    attn_output = torch.empty_like(q_nope)
    softmax_lse = torch.empty((num_tokens, num_heads, 1), dtype=q_nope.dtype, device=q_nope.device)
    if input_layout == "BNSD":
        softmax_lse = torch.empty((num_tokens, num_heads, 1, 1), dtype=torch.float, device=q_nope.device)
    else:
        softmax_lse = torch.empty((num_tokens, num_heads, 1), dtype=torch.float, device=q_nope.device)

    graph_params.attn_params[num_tokens].append(
        (
            weak_ref_tensors(q_nope),
            weak_ref_tensors(q_pe),
            weak_ref_tensors(k_nope),
            weak_ref_tensors(q_pe),
            weak_ref_tensors(k_pe),
            decode_meta.block_table,
            seq_len,
            num_heads,
            self.scale,
            self.num_kv_heads,
            input_layout,
            weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None else None,
            sparse_mode,
            self.scale,
            weak_ref_tensors(decode_meta.block_table),
            block_size,
            actual_seq_lengths,
            decode_meta.cp_seq_len,
            weak_ref_tensors(attn_output),
            weak_ref_tensors(softmax_lse),
        )
    )
    torch.npu.graph_task_group_begin(stream)
    torch_npu.atb.npu_multi_head_latent_attention(
        q_nope,
        q_pe,
        k_nope,
        k_pe,
        decode_meta.block_table,
        seq_len,
        num_heads,
        self.scale,
        self.num_kv_heads,
        **common_kwargs,
        workspace=workspace,
        output=attn_output,
        lse=softmax_lse,
    torch_npu.npu_fused_infer_attention_score.out(
        q_nope, k_nope, k_nope, **common_kwargs, workspace=workspace, out=[attn_output, softmax_lse]
    )
    handle = torch.npu.graph_task_group_end(stream)
    graph_params.handles[num_tokens].append(handle)
else:
    attn_output = torch.empty_like(q_nope)
    softmax_lse = torch.empty((num_tokens, num_heads, 1), dtype=q_nope.dtype, device=q_nope.device)
    torch_npu.atb.npu_multi_head_latent_attention(
    attn_output, softmax_lse = torch_npu.npu_fused_infer_attention_score(
        q_nope,
        q_pe,
        k_nope,
        k_pe,
        decode_meta.block_table,
        seq_len,
        num_heads,
        self.scale,
        self.num_kv_heads,
        return_lse=True,
        calc_type="calc_type_ring",
        output=attn_output,
        lse=softmax_lse,
        k_nope,
        **common_kwargs,
    )
if input_layout == "BNSD":
    B_attn, N_attn, S, D = attn_output.shape
    B_lse, N_lse, Q_S, _ = softmax_lse.shape

    attn_output = attn_output.permute(0, 2, 1, 3).reshape(B_attn * S, N_attn, D)
    softmax_lse = softmax_lse.permute(0, 2, 1, 3).reshape(B_lse * Q_S, N_lse, 1)

# Update out&lse
attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse, decode_meta.batch_seq_mask)
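Why the LSE output matters here: with prefill/decode context parallelism each rank attends over only its slice of the KV cache, so the per-rank partial outputs have to be merged using their softmax log-sum-exp values (the `_process_attn_out_lse` step above, fed by `softmax_lse_flag=True`). The sketch below shows the generic merge of two partial results; it is not the actual `_process_attn_out_lse` implementation, which also handles the cross-rank exchange and applies `batch_seq_mask`, presumably to drop ranks whose clamped `cp_seq_len` held no real tokens for a request.

```python
import torch


def merge_partial_attention(
    out_a: torch.Tensor, lse_a: torch.Tensor,   # (tokens, heads, dim), (tokens, heads, 1)
    out_b: torch.Tensor, lse_b: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Log-sum-exp weighted merge of two partial attention results.

    Illustrative only; the real reduction in mla_cp.py is not reproduced here.
    """
    lse_max = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse_max)
    w_b = torch.exp(lse_b - lse_max)
    denom = w_a + w_b
    merged_out = (out_a * w_a + out_b * w_b) / denom
    merged_lse = lse_max + torch.log(denom)
    return merged_out, merged_lse
```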