[feat]dcp pcp support aclgraph (#3731)
### What this PR does / why we need it?
Adds full ACL graph (aclgraph) support for DCP (decode context parallel) and PCP (prefill context parallel), including the MLA and `attention_v1` backends; a sketch of the LSE-based merge this relies on follows below.
- vLLM version: v0.11.0rc3
- vLLM main: c9461e05a4
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
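For background: context-parallel decode computes attention over each rank's KV shard and returns both the partial output and its log-sum-exp (which is why the hunk below passes `softmax_lse_flag=True`); the per-rank pairs are then merged into the exact full-KV result. A minimal PyTorch sketch of that merge, with a hypothetical helper name and assumed `[tokens, heads, ...]` shapes (not the kernel this PR uses):

```python
import torch

# Hypothetical helper; outs[i] is [tokens, heads, head_dim] over rank i's
# KV shard, lses[i] is [tokens, heads, 1] (shapes are illustrative only).
def merge_partial_attention(outs: list[torch.Tensor],
                            lses: list[torch.Tensor]) -> torch.Tensor:
    out = torch.stack(outs)                   # [ranks, tokens, heads, dim]
    lse = torch.stack(lses)                   # [ranks, tokens, heads, 1]
    global_lse = torch.logsumexp(lse, dim=0)  # softmax denominator over all shards
    # Rescale each shard's output by its share of the global denominator.
    return (torch.exp(lse - global_lse) * out).sum(dim=0)
```

Given the `(attn_out, attn_lse)` pairs gathered across the DCP/PCP group, `merge_partial_attention([out_r0, out_r1], [lse_r0, lse_r1])` reproduces single-device attention over the full KV exactly, which is why the fused kernel must emit the LSE alongside the output.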
```diff
@@ -865,26 +865,81 @@ class AscendAttentionBackendImpl(AttentionImpl):
         num_heads = self.num_heads
 
         # 1. Compute out&lse by "npu_fused_infer_attention_score"
-        attn_out, attn_lse = torch.ops.npu.npu_fused_infer_attention_score(
-            query.view(query.shape[0], 1, query.shape[1], query.shape[2]),
-            # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
-            self.key_cache.view(self.key_cache.shape[0],
-                                self.key_cache.shape[1], -1),
-            self.value_cache.view(self.key_cache.shape[0],
-                                  self.key_cache.shape[1], -1),
-            num_heads=num_heads,
-            num_key_value_heads=self.num_kv_heads,
-            input_layout="BSND",
-            atten_mask=None,
-            scale=self.scale,
-            antiquant_mode=0,
-            antiquant_scale=None,
-            softmax_lse_flag=True,
-            block_table=attn_metadata.block_tables,
-            block_size=self.key_cache.shape[1],
-            actual_seq_lengths_kv=attn_metadata.decode_meta.
-            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
-        )
+        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
+        # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
+        k_nope = self.key_cache.view(self.key_cache.shape[0],
+                                     self.key_cache.shape[1], -1)
+        value = self.value_cache.view(self.key_cache.shape[0],
+                                      self.key_cache.shape[1], -1)
+        common_kwargs = {
+            'num_heads':
+            num_heads,
+            'num_key_value_heads':
+            self.num_kv_heads,
+            'input_layout':
+            "BSND",
+            'atten_mask':
+            None,
+            'scale':
+            self.scale,
+            'antiquant_mode':
+            0,
+            'antiquant_scale':
+            None,
+            'softmax_lse_flag':
+            True,
+            'block_table':
+            attn_metadata.block_tables,
+            'block_size':
+            self.key_cache.shape[1],
+            "actual_seq_lengths_kv":
+            attn_metadata.decode_meta.
+            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
+        }
+        graph_params = get_graph_params()
+        forward_context: ForwardContext = get_forward_context()
+        num_tokens = query.shape[0]
+        if forward_context.capturing:
+            stream = torch_npu.npu.current_stream()
+
+            event = torch.npu.ExternalEvent()
+            event.wait(stream)
+            event.reset(stream)
+            graph_params.events[num_tokens].append(event)
+
+            workspace = graph_params.workspaces.get(num_tokens)
+            if workspace is None:
+                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                    q_nope, k_nope, value, **common_kwargs)
+                update_graph_params_workspaces(num_tokens,
+                                               weak_ref_tensors(workspace))
+            attn_out = torch.empty_like(q_nope)
+            attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
+                                   dtype=torch.float,
+                                   device=q_nope.device)
+
+            graph_params.attn_params[num_tokens].append(
+                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                 weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
+                 self.scale, attn_metadata.block_tables,
+                 self.key_cache.shape[1], attn_metadata.decode_meta.
+                 num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
+                                                self.dcp_rank],
+                 weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
+                 self.pcp_rank, self.dcp_rank, self.dcp_size))
+            torch.npu.graph_task_group_begin(stream)
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                k_nope,
+                value,
+                **common_kwargs,
+                workspace=workspace,
+                out=[attn_out, attn_lse])
+            handle = torch.npu.graph_task_group_end(stream)
+            graph_params.handles[num_tokens].append(handle)
+        else:
+            attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
+                q_nope, k_nope, value, **common_kwargs)
+
+        attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
+                                 attn_out.shape[3])
```
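What the capture branch above sets up: for each graph bucket (`num_tokens`) it stashes weak references to the kernel's inputs and outputs in `graph_params.attn_params`, records the task-group `handle`, and makes the captured stream wait on an `ExternalEvent`. At replay time, a helper can then re-issue the kernel with the step's real `actual_seq_lengths_kv` and release that event. A hedged sketch of such a replay-side update, assuming torch_npu's task-group update API (`graph_task_update_begin`/`graph_task_update_end`); the helper name and `update_stream` plumbing are illustrative, not code from this PR:

```python
import torch
import torch_npu

# Hypothetical replay-side helper; unpacks the 14-tuple stashed during capture.
def update_attn_params_sketch(update_stream, graph_params, num_tokens):
    with torch.npu.stream(update_stream):
        for handle, event, params in zip(graph_params.handles[num_tokens],
                                         graph_params.events[num_tokens],
                                         graph_params.attn_params[num_tokens]):
            (q_nope, k_nope, value, num_heads, num_kv_heads, scale,
             block_tables, block_size, seq_lens, attn_out, attn_lse,
             pcp_rank, dcp_rank, dcp_size) = params
            # Point the captured task group at this step's parameters ...
            torch.npu.graph_task_update_begin(update_stream, handle)
            torch_npu.npu_fused_infer_attention_score.out(
                q_nope, k_nope, value,
                num_heads=num_heads,
                num_key_value_heads=num_kv_heads,
                input_layout="BSND",
                atten_mask=None,
                scale=scale,
                antiquant_mode=0,
                antiquant_scale=None,
                softmax_lse_flag=True,
                block_table=block_tables,
                block_size=block_size,
                actual_seq_lengths_kv=seq_lens,
                workspace=graph_params.workspaces.get(num_tokens),
                out=[attn_out, attn_lse])
            torch.npu.graph_task_update_end(update_stream)
            # ... then release the graph's wait so replay can proceed.
            event.record(update_stream)
```

The pre-sized `workspace` and the `empty_like` output tensors exist so the replayed graph always reads and writes the same storage; only the index and length arguments change between steps.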