[feat]dcp pcp support aclgraph (#3731)

### What this PR does / why we need it?
DCP and PCP now support full aclgraph mode, including MLA attention_v1.
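At replay time the runner re-binds each captured attention op with fresh decode metadata. A minimal sketch of how the new helpers might be dispatched, assuming `use_mla` / `dcp_pcp_enabled` flags and an `update_attn_params` counterpart that are not part of this change:

```python
# Illustrative only: picking the replay-time update helper.
# `use_mla`, `dcp_pcp_enabled` and `update_attn_params` are assumed names;
# the other helpers are the ones added or extended by this commit.
def select_update_fn(use_mla: bool, dcp_pcp_enabled: bool):
    if use_mla:
        return (update_mla_attn_dcp_pcp_params
                if dcp_pcp_enabled else update_mla_attn_params)
    return (update_attn_dcp_pcp_params
            if dcp_pcp_enabled else update_attn_params)
```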

- vLLM version: v0.11.0rc3
- vLLM main: c9461e05a4

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
weiguihua2 committed 2025-10-27 09:58:23 +08:00 (committed by GitHub)
parent 8ab8111fde
commit 4312a92a4f
5 changed files with 414 additions and 68 deletions


@@ -7,6 +7,7 @@ from dataclasses import dataclass
from typing import Any, Callable, Optional
from unittest.mock import patch
import numpy as np
import torch
import torch_npu
import vllm.envs as envs
@@ -300,6 +301,105 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
        event.record(update_stream)


def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
    graph_params = get_graph_params()
    # FIXME: Behold! We are using a temporary hack here to update the args
    # for each layer's attention op in the graph.
    for key, param, handle, event in zip(
            forward_context.attn_metadata,
            graph_params.attn_params[runtime_shape],
            graph_params.handles[runtime_shape],
            graph_params.events[runtime_shape],
    ):
        (q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table,
         block_size, actual_seq_lengths_kv, attn_output, softmax_lse, cp_rank,
         dcp_rank, dcp_size) = param
        actual_seq_lengths_kv = forward_context.attn_metadata[
            key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
                                                            dcp_rank]
        pad_length = runtime_shape - len(actual_seq_lengths_kv)
        pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
        actual_seq_lengths_kv = np.concatenate(
            [actual_seq_lengths_kv, pad_tensor])
        if dcp_size > 1:
            num_heads = num_heads * dcp_size

        with torch.npu.stream(update_stream):
            torch.npu.graph_task_update_begin(update_stream, handle)
            torch_npu.npu_fused_infer_attention_score.out(
                q_nope,
                k_nope,
                value,
                num_heads=num_heads,
                num_key_value_heads=num_kv_heads,
                input_layout="BSND",
                atten_mask=None,
                scale=scale,
                antiquant_mode=0,
                antiquant_scale=None,
                softmax_lse_flag=True,
                block_table=block_table,
                block_size=block_size,
                actual_seq_lengths_kv=actual_seq_lengths_kv,
                workspace=graph_params.workspaces.get(runtime_shape),
                out=[attn_output, softmax_lse])
            torch.npu.graph_task_update_end(update_stream)
        event.record(update_stream)
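Since the ACL graph is captured for a fixed batch size, the per-rank KV lengths read from `num_computed_tokens_of_pcp_dcp` are zero-padded up to `runtime_shape` before being re-bound. A toy illustration of that padding, with made-up lengths:

```python
import numpy as np

# Made-up values: graph captured for 8 decode slots, 5 live requests on this
# (cp_rank, dcp_rank) shard.
runtime_shape = 8
actual_seq_lengths_kv = np.array([17, 4, 96, 33, 1], dtype=np.int32)

pad = np.zeros(runtime_shape - len(actual_seq_lengths_kv),
               dtype=actual_seq_lengths_kv.dtype)
print(np.concatenate([actual_seq_lengths_kv, pad]))
# [17  4 96 33  1  0  0  0] -- trailing zeros belong to inactive graph slots
```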
def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
                                   runtime_shape, speculative_config):
    graph_params = get_graph_params()
    # FIXME: Behold! We are using a temporary hack here to update the args
    # for each layer's attention op in the graph.
    for key, param, handle, event in zip(
            forward_context.attn_metadata,
            graph_params.attn_params[runtime_shape],
            graph_params.handles[runtime_shape],
            graph_params.events[runtime_shape],
    ):
        (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, scale,
         num_kv_heads, attn_output, softmax_lse) = param
        decode_meta = forward_context.attn_metadata[key].decode
        seq_len = decode_meta.cp_seq_len

        if speculative_config and speculative_config.method == "deepseek_mtp":
            spec_multiple = speculative_config.num_speculative_tokens + 1
            seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
                                       len(seq_len))
        else:
            pad_length = runtime_shape - len(seq_len)
            pad_tensor = torch.zeros(pad_length,
                                     dtype=seq_len.dtype,
                                     device=seq_len.device)
            seq_len = torch.cat([seq_len, pad_tensor], dim=0)

        with torch.npu.stream(update_stream):
            torch.npu.graph_task_update_begin(update_stream, handle)
            torch_npu.atb.npu_multi_head_latent_attention(
                q_nope,
                q_pe,
                k_nope,
                k_pe,
                block_table,
                seq_len,
                num_heads,
                scale,
                num_kv_heads,
                return_lse=True,
                calc_type="calc_type_ring",
                workspace=graph_params.workspaces.get(runtime_shape),
                output=attn_output,
                lse=softmax_lse)
            torch.npu.graph_task_update_end(update_stream)
        event.record(update_stream)
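These helpers only re-bind arguments that were stashed when the graph was captured. A hypothetical capture-time counterpart (the helper itself is an assumption; only the `GraphParams` fields it fills appear in this file) could look like:

```python
# Hypothetical capture-time counterpart to the update_* helpers above; only
# the GraphParams fields (attn_params, handles, events, workspaces) come from
# this file.
def stash_attn_call_for_replay(graph_params, runtime_shape, param, handle,
                               event, workspace=None):
    graph_params.attn_params.setdefault(runtime_shape, []).append(param)
    graph_params.handles.setdefault(runtime_shape, []).append(handle)
    graph_params.events.setdefault(runtime_shape, []).append(event)
    if workspace is not None:
        graph_params.workspaces[runtime_shape] = workspace
```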
@dataclass
class GraphParams:
    events: dict[int, list[torch.npu.ExternalEvent]]