[Feat][Graph]Support FULL_DECEDE_ONLY mode for MLA models (#3125)
### What this PR does / why we need it?
Adds support for capturing the Multi-Layer Attention (MLA) decode
operation into an ACL graph. This improves performance by compiling the
attention kernel for single-token decoding.
Key changes include:
- Implementing the graph capture logic for the MLA kernel, including
workspace management and parameter updates.
- Modifying the rotary embedding (RoPE) handling to use pre-allocated
tensors, which is a requirement for graph capture.
- Adding a `build_for_graph_capture` method to the MLA metadata builder
to create dummy metadata during the graph compilation phase.
Known issues:
- Currently, MTP is not supported in FULL_DECEDE_ONLY mode -- we're
working on a fix
- We are preparing to remove update_mla_attn_params with
auto_dispatch_capture
### Does this PR introduce _any_ user-facing change?
compilation_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
### How was this patch tested?
- vLLM version: v0.11.0
---------
Signed-off-by: panchao-hub <315134829@qq.com>
Signed-off-by: p00465316 <panchao13@huawei.com>
Co-authored-by: p00465316 <panchao13@huawei.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -229,6 +229,52 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
||||
event.record(update_stream)
|
||||
|
||||
|
||||
def update_mla_attn_params(update_stream, forward_context, runtime_shape):
|
||||
graph_params = get_graph_params()
|
||||
# FIXME: Behold! We are using a temporary hack here to update the args
|
||||
# for each layer's attention op in the graph.
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
|
||||
spec_attn_mask, sparse_mode, scale, block_table, block_size,
|
||||
seq_lens_list, actual_seq_lengths, workspace, attn_output,
|
||||
softmax_lse) = param
|
||||
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
|
||||
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
|
||||
len(seq_lens_list))
|
||||
|
||||
with torch.npu.stream(update_stream):
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
q_nope,
|
||||
k_nope,
|
||||
k_nope,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_pe,
|
||||
num_heads=num_heads,
|
||||
num_key_value_heads=num_kv_heads,
|
||||
input_layout=input_layout,
|
||||
atten_mask=spec_attn_mask,
|
||||
sparse_mode=sparse_mode,
|
||||
scale=scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=block_table,
|
||||
block_size=block_size,
|
||||
actual_seq_lengths_kv=seq_lens_list,
|
||||
actual_seq_lengths=actual_seq_lengths,
|
||||
workspace=workspace,
|
||||
out=[attn_output, softmax_lse])
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphParams:
|
||||
events: dict[int, list[torch.npu.ExternalEvent]]
|
||||
|
||||
Reference in New Issue
Block a user