[feat]dcp pcp support aclgraph (#3731)
### What this PR does / why we need it?
DCP and PCP now support full ACL graph capture, including MLA attention_v1.
- vLLM version: v0.11.0rc3
- vLLM main:
c9461e05a4
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -7,6 +7,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm.envs as envs
|
||||
@@ -300,6 +301,105 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||
event.record(update_stream)
|
||||
|
||||
|
||||
def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
    """Refresh captured attention-op arguments for DCP/PCP decode graphs.

    For every attention layer captured in the ACL graph at this
    ``runtime_shape``, re-issues ``npu_fused_infer_attention_score`` on
    ``update_stream`` inside a graph-task-update window so the replayed
    graph picks up the current per-request KV sequence lengths.

    Args:
        update_stream: NPU stream used to patch the captured graph tasks.
        forward_context: holds per-layer ``attn_metadata`` with the fresh
            decode metadata for this step.
        runtime_shape: captured (padded) batch size of the graph.
    """
    graph_params = get_graph_params()
    # FIXME: Behold! We are using a temporary hack here to update the args
    # for each layer's attention op in the graph.
    per_layer = zip(
        forward_context.attn_metadata,
        graph_params.attn_params[runtime_shape],
        graph_params.handles[runtime_shape],
        graph_params.events[runtime_shape],
    )
    for key, param, handle, event in per_layer:
        (q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table,
         block_size, actual_seq_lengths_kv, attn_output, softmax_lse, cp_rank,
         dcp_rank, dcp_size) = param

        # The captured lengths are stale; re-read this (cp, dcp) rank's
        # computed-token counts from the current step's metadata.
        decode_meta = forward_context.attn_metadata[key].decode_meta
        fresh_lengths = decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
                                                                   dcp_rank]
        # Zero-pad on the right up to the graph's captured batch size.
        # (np.pad with default mode fills with zeros and keeps the dtype.)
        tail = runtime_shape - len(fresh_lengths)
        actual_seq_lengths_kv = np.pad(fresh_lengths, (0, tail))

        if dcp_size > 1:
            # Heads are sharded across DCP ranks; the kernel wants the
            # global head count.
            num_heads = num_heads * dcp_size

        with torch.npu.stream(update_stream):
            torch.npu.graph_task_update_begin(update_stream, handle)
            torch_npu.npu_fused_infer_attention_score.out(
                q_nope,
                k_nope,
                value,
                num_heads=num_heads,
                num_key_value_heads=num_kv_heads,
                input_layout="BSND",
                atten_mask=None,
                scale=scale,
                antiquant_mode=0,
                antiquant_scale=None,
                softmax_lse_flag=True,
                block_table=block_table,
                block_size=block_size,
                actual_seq_lengths_kv=actual_seq_lengths_kv,
                workspace=graph_params.workspaces.get(runtime_shape),
                out=[attn_output, softmax_lse])
            torch.npu.graph_task_update_end(update_stream)

        event.record(update_stream)
|
||||
|
||||
|
||||
def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
                                   runtime_shape, speculative_config):
    """Refresh captured MLA attention-op arguments for DCP/PCP decode graphs.

    Re-issues ``npu_multi_head_latent_attention`` for each captured layer
    inside a graph-task-update window on ``update_stream``, feeding in the
    current context-parallel sequence lengths (padded to the graph's
    captured batch size).

    Args:
        update_stream: NPU stream used to patch the captured graph tasks.
        forward_context: holds per-layer ``attn_metadata`` with the fresh
            decode metadata for this step.
        runtime_shape: captured (padded) batch size of the graph.
        speculative_config: when its method is "deepseek_mtp", requests are
            grouped in (num_speculative_tokens + 1)-sized steps, which
            changes how ``seq_len`` is padded.
    """
    graph_params = get_graph_params()
    # FIXME: Behold! We are using a temporary hack here to update the args
    # for each layer's attention op in the graph.
    per_layer = zip(
        forward_context.attn_metadata,
        graph_params.attn_params[runtime_shape],
        graph_params.handles[runtime_shape],
        graph_params.events[runtime_shape],
    )
    for key, param, handle, event in per_layer:
        (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, scale,
         num_kv_heads, attn_output, softmax_lse) = param

        decode_meta = forward_context.attn_metadata[key].decode
        seq_len = decode_meta.cp_seq_len

        is_mtp = (speculative_config
                  and speculative_config.method == "deepseek_mtp")
        if is_mtp:
            # NOTE(review): this branch extends seq_len with list
            # concatenation while the else branch uses torch.cat —
            # presumably cp_seq_len is a plain list under MTP; verify
            # against the metadata builder.
            steps = speculative_config.num_speculative_tokens + 1
            seq_len = seq_len + [0] * (runtime_shape // steps - len(seq_len))
        else:
            # Zero-pad the tensor on the right up to the captured size.
            tail = runtime_shape - len(seq_len)
            zero_pad = torch.zeros(tail,
                                   dtype=seq_len.dtype,
                                   device=seq_len.device)
            seq_len = torch.cat([seq_len, zero_pad], dim=0)

        with torch.npu.stream(update_stream):
            torch.npu.graph_task_update_begin(update_stream, handle)
            torch_npu.atb.npu_multi_head_latent_attention(
                q_nope,
                q_pe,
                k_nope,
                k_pe,
                block_table,
                seq_len,
                num_heads,
                scale,
                num_kv_heads,
                return_lse=True,
                calc_type="calc_type_ring",
                workspace=graph_params.workspaces.get(runtime_shape),
                output=attn_output,
                lse=softmax_lse)
            torch.npu.graph_task_update_end(update_stream)

        event.record(update_stream)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphParams:
|
||||
events: dict[int, list[torch.npu.ExternalEvent]]
|
||||
|
||||
Reference in New Issue
Block a user