[refactor] refactor model runner capture model (#5230)

### What this PR does / why we need it?
Refactor the `capture_model` method in model_runner to directly reuse
the method from vLLM.

Currently, most of the logic in the capture_model method is similar to
that in the vllm code. Directly using the vllm method can reduce the
maintenance cost of the vllm-ascend code. Modify as follows:
1、refactor capture_model function, directly inheriting community methods
2、refactor initialize_aclgraph_capture function, move to
initialize_attn_backend

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
weiguihua2
2025-12-30 08:32:14 +08:00
committed by GitHub
parent 5e96f94d2a
commit 15d73f248e
10 changed files with 142 additions and 254 deletions

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple, Type, TypeVar
from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar
import torch
import torch_npu
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.linear import (ReplicatedLinear,
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
@@ -113,9 +114,6 @@ M = TypeVar("M", bound=AscendSFAMetadata)
class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
@@ -159,6 +157,16 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
== CUDAGraphMode.FULL_DECODE_ONLY
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
@classmethod
def get_cudagraph_support(
cls: type["AscendSFAMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
def reorder_batch(self, input_batch: "NPUInputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# No need to reorder for Ascend SFA