[Feature] support aclgraph for model runner v2 (#7110)

### What this PR does / why we need it?
This PR aims to support aclgraph for model runner v2, please see RFC
#5208. The PR contains these modifications:
- adapt to the newest commit of the vLLM main branch.
- provide a unified interface for extra forward context, shared by model
runner v1 and model runner v2.
- implement graph mode for the main model.

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
Ronald
2026-03-13 09:11:46 +08:00
committed by GitHub
parent 1f71da80eb
commit c980e68d40
52 changed files with 840 additions and 309 deletions

View File

@@ -23,8 +23,8 @@ from typing import Any
import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, KVCacheConfig
from vllm.v1.worker.utils import AttentionGroup
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -43,7 +43,7 @@ def get_attn_mask_builder(device: torch.device):
def build_attn_metadata(
*,
attn_metadata_builders: list[AttentionMetadataBuilder],
attn_groups: list[list[AttentionGroup]],
num_reqs: int,
num_tokens: int,
query_start_loc_gpu: torch.Tensor,
@@ -54,6 +54,7 @@ def build_attn_metadata(
block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
dcp_local_seq_lens: torch.Tensor | None = None,
# extra attributes for ascend npus.
seq_lens_np: np.ndarray | None = None,
num_computed_tokens_cpu: torch.Tensor | None = None,
@@ -72,9 +73,6 @@ def build_attn_metadata(
if seq_lens_np is None:
seq_lens_np = np.full(num_reqs, max_seq_len, dtype=np.int32)
seq_lens_cpu = torch.from_numpy(seq_lens_np)[:num_reqs]
# torch_npu._reshape_and_cache operator requires slot_mappings to
# be torch.int32.
slot_mappings = slot_mappings.to(torch.int32)
attn_metadata: dict[str, Any] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups
@@ -100,13 +98,14 @@ def build_attn_metadata(
max_seq_len=max_seq_len,
)
attn_metadata_builder = attn_metadata_builders[i]
metadata = attn_metadata_builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata, # type: ignore
)
for layer_name in kv_cache_spec.layer_names:
attn_metadata[layer_name] = metadata
for attn_group in attn_groups[i]:
attn_metadata_builder = attn_group.get_metadata_builder(0)
metadata = attn_metadata_builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = metadata
return attn_metadata