[feat]dcp pcp support aclgraph (#3731)

### What this PR does / why we need it?
DCP and PCP now support full aclgraph capture, including the MLA `attention_v1` backend.
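
With PCP/DCP, each rank attends over only its shard of the KV cache, and a shard that returns both its partial output and the log-sum-exp (LSE) of its scores can be merged exactly with the other shards; that is why the kernel below is launched with `softmax_lse_flag=True`. A minimal sketch of such a merge (editor's illustration with assumed shapes; `merge_partial_attention` is a hypothetical helper, not code from this PR):

```python
import torch

def merge_partial_attention(outs: list[torch.Tensor],
                            lses: list[torch.Tensor]) -> torch.Tensor:
    """Combine per-shard attention outputs using their log-sum-exp terms.

    outs[i]: [tokens, heads, head_size] partial output from KV shard i
    lses[i]: [tokens, heads, 1] log-sum-exp of shard i's attention scores
    """
    out = torch.stack(outs)                   # [shards, tokens, heads, head_size]
    lse = torch.stack(lses)                   # [shards, tokens, heads, 1]
    global_lse = torch.logsumexp(lse, dim=0)  # normalizer over all shards
    weights = torch.exp(lse - global_lse)     # per-shard softmax correction
    return (weights * out).sum(dim=0)         # exact full-context attention
```

Each shard's weight `exp(lse_i - lse)` rescales its partial softmax so the sum behaves as if attention had run over the full context.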

- vLLM version: v0.11.0rc3
- vLLM main: c9461e05a4

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Author: weiguihua2
Date: 2025-10-27 09:58:23 +08:00 (committed by GitHub)
Parent: 8ab8111fde
Commit: 4312a92a4f
5 changed files with 414 additions and 68 deletions


```diff
@@ -865,26 +865,81 @@ class AscendAttentionBackendImpl(AttentionImpl):
         num_heads = self.num_heads
         # 1. Compute out&lse by "npu_fused_infer_attention_score"
-        attn_out, attn_lse = torch.ops.npu.npu_fused_infer_attention_score(
-            query.view(query.shape[0], 1, query.shape[1], query.shape[2]),
-            # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
-            self.key_cache.view(self.key_cache.shape[0],
-                                self.key_cache.shape[1], -1),
-            self.value_cache.view(self.key_cache.shape[0],
-                                  self.key_cache.shape[1], -1),
-            num_heads=num_heads,
-            num_key_value_heads=self.num_kv_heads,
-            input_layout="BSND",
-            atten_mask=None,
-            scale=self.scale,
-            antiquant_mode=0,
-            antiquant_scale=None,
-            softmax_lse_flag=True,
-            block_table=attn_metadata.block_tables,
-            block_size=self.key_cache.shape[1],
-            actual_seq_lengths_kv=attn_metadata.decode_meta.
-            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
-        )
+        # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
+        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
+        k_nope = self.key_cache.view(self.key_cache.shape[0],
+                                     self.key_cache.shape[1], -1)
+        value = self.value_cache.view(self.key_cache.shape[0],
+                                      self.key_cache.shape[1], -1)
+        # Arguments shared by the eager call, the workspace query and the
+        # captured launch.
+        common_kwargs = {
+            'num_heads': num_heads,
+            'num_key_value_heads': self.num_kv_heads,
+            'input_layout': "BSND",
+            'atten_mask': None,
+            'scale': self.scale,
+            'antiquant_mode': 0,
+            'antiquant_scale': None,
+            'softmax_lse_flag': True,
+            'block_table': attn_metadata.block_tables,
+            'block_size': self.key_cache.shape[1],
+            'actual_seq_lengths_kv':
+            attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[
+                :, self.pcp_rank, self.dcp_rank],
+        }
+        graph_params = get_graph_params()
+        forward_context: ForwardContext = get_forward_context()
+        num_tokens = query.shape[0]
+        if forward_context.capturing:
+            # Replay patches the captured kernel with fresh parameters and
+            # then records this external event to unblock the graph.
+            stream = torch_npu.npu.current_stream()
+            event = torch.npu.ExternalEvent()
+            event.wait(stream)
+            event.reset(stream)
+            graph_params.events[num_tokens].append(event)
+            # Query the maximum workspace once per captured batch size and
+            # reuse it across replays.
+            workspace = graph_params.workspaces.get(num_tokens)
+            if workspace is None:
+                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                    q_nope, k_nope, value, **common_kwargs)
+                update_graph_params_workspaces(num_tokens,
+                                               weak_ref_tensors(workspace))
+            attn_out = torch.empty_like(q_nope)
+            attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
+                                   dtype=torch.float,
+                                   device=q_nope.device)
+            # Stash weak references to every launch argument so the replay
+            # path can re-issue the kernel without keeping tensors alive.
+            graph_params.attn_params[num_tokens].append(
+                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                 weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
+                 self.scale, attn_metadata.block_tables,
+                 self.key_cache.shape[1],
+                 attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[
+                     :, self.pcp_rank, self.dcp_rank],
+                 weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
+                 self.pcp_rank, self.dcp_rank, self.dcp_size))
+            torch.npu.graph_task_group_begin(stream)
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                k_nope,
+                value,
+                **common_kwargs,
+                workspace=workspace,
+                out=[attn_out, attn_lse])
+            handle = torch.npu.graph_task_group_end(stream)
+            graph_params.handles[num_tokens].append(handle)
+        else:
+            attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
+                q_nope, k_nope, value, **common_kwargs)
         attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
                                  attn_out.shape[3])
```