[Feature] support aclgraph for model runner v2 (#7110)

### What this PR does / why we need it?
This PR adds aclgraph support for model runner v2; see RFC
#5208. It contains the following changes:
- adapt to the newest commit of the vLLM main branch.
- provide a unified extra-forward-context interface for both model
runner v1 and model runner v2 (see the sketch after this list).
- implement graph mode for the main model.
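
To make the second change concrete, here is a minimal sketch of what such a unified extra forward context can look like. Only the `capturing` and `is_draft_model` flags are taken from the diff below; the class name, defaults, and setter helper are illustrative assumptions, not the actual `vllm_ascend.ascend_forward_context` implementation:

```python
# Minimal sketch, not the real vllm_ascend code: a process-wide extra
# context that both model runners can read without fetching a per-call
# ForwardContext object.
from dataclasses import dataclass


@dataclass
class ExtraForwardContext:
    # True while an ACL graph is being captured, so attention backends
    # record graph parameters instead of launching kernels eagerly.
    capturing: bool = False
    # True when the forward pass belongs to a speculative draft model,
    # which keeps its own set of graph parameters.
    is_draft_model: bool = False


# Single shared instance; hot paths read flags as `_EXTRA_CTX.capturing`.
_EXTRA_CTX = ExtraForwardContext()


def set_extra_forward_context(capturing: bool = False,
                              is_draft_model: bool = False) -> None:
    """Hypothetical helper: set the shared flags before a forward pass."""
    _EXTRA_CTX.capturing = capturing
    _EXTRA_CTX.is_draft_model = is_draft_model
```

With a single shared object, attention backends can branch on `_EXTRA_CTX.is_draft_model` and `_EXTRA_CTX.capturing` directly in both runners, instead of the `getattr(forward_context, "capturing", False)` fallback that model runner v2 needed before this change.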

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main: 4034c3d32e

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Ronald authored 2026-03-13 09:11:46 +08:00, committed by GitHub
parent 1f71da80eb
commit c980e68d40
52 changed files with 840 additions and 309 deletions


@@ -23,7 +23,6 @@ import torch
import torch_npu
import vllm.envs as envs_vllm
from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import ( # type: ignore
AttentionBackend,
@@ -40,6 +39,7 @@ from vllm.v1.attention.backends.registry import ( # type: ignore
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.context_parallel.common_cp import AscendMetadataForDecode, AscendMetadataForPrefill
from vllm_ascend.attention.utils import (
@@ -392,7 +392,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
):
if using_paged_attention(num_tokens, vllm_config):
# Paged Attention update logic
-if forward_context.is_draft_model:
+if _EXTRA_CTX.is_draft_model:
graph_params = get_draft_graph_params()
else:
graph_params = get_graph_params()
@@ -444,7 +444,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
event.record(update_stream)
else:
# FIA update logic
-if forward_context.is_draft_model:
+if _EXTRA_CTX.is_draft_model:
graph_params = get_draft_graph_params()
attn_metadata = draft_attn_metadatas
attn_keys = list(attn_metadata[0].keys())
@@ -462,7 +462,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
num_layers = len(attn_keys)
if num_layers == 0:
return
-if forward_context.is_draft_model:
+if _EXTRA_CTX.is_draft_model:
attn_keys = attn_keys * (len(graph_params.attn_params[num_tokens]) // num_layers)
attn_count = 0
with torch.npu.stream(update_stream):
@@ -488,7 +488,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
softmax_lse,
) = param
-if forward_context.is_draft_model:
+if _EXTRA_CTX.is_draft_model:
draft_step = attn_count // num_layers
seq_lens = attn_metadata[draft_step][key].seq_lens_list
actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
@@ -535,8 +535,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
-forward_context = get_forward_context()
-if forward_context.is_draft_model:
+if _EXTRA_CTX.is_draft_model:
graph_params = get_draft_graph_params()
else:
graph_params = get_graph_params()
@@ -563,7 +562,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
sparse_mode=3,
scale=self.scale,
)
-if forward_context.is_draft_model:
+if _EXTRA_CTX.is_draft_model:
update_draft_graph_params_workspaces(num_tokens, workspace)
else:
update_graph_params_workspaces(num_tokens, workspace)
@@ -625,9 +624,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
output: torch.Tensor | None = None,
):
graph_params = get_graph_params()
-forward_context: ForwardContext = get_forward_context()
num_tokens = query.shape[0]
-if forward_context.capturing:
+if _EXTRA_CTX.capturing:
# Get workspace from cache or calculate it if not present.
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
@@ -761,11 +759,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_metadata: AscendMetadata,
output: torch.Tensor,
):
-forward_context: ForwardContext = get_forward_context()
# we inherit ForwardContext in model runner v2, when enable model
# runner v2, there is not capturing attribute in forward_context,
# just use getattr to avoid attribute error.
-if getattr(forward_context, "capturing", False):
+if _EXTRA_CTX.capturing:
attn_output, num_tokens = self.full_graph_fia(query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]
return output
@@ -841,8 +838,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_metadata: AscendMetadata,
output: torch.Tensor | None = None,
) -> torch.Tensor:
-forward_context: ForwardContext = get_forward_context()
-if forward_context.capturing:
+if _EXTRA_CTX.capturing:
return self.full_graph_pa(query, attn_metadata, output)
torch_npu._npu_paged_attention(
query=query,