[refactor] refactor model runner capture model (#5230)
### What this PR does / why we need it?
Refactor the `capture_model` method in model_runner to directly reuse
the method from vLLM.
Currently, most of the logic in the capture_model method is similar to
that in the vllm code. Directly using the vllm method can reduce the
maintenance cost of the vllm-ascend code. Modify as follows:
1、refactor capture_model function, directly inheriting community methods
2、refactor initialize_aclgraph_capture function, move to
initialize_attn_backend
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -44,9 +44,6 @@ from vllm_ascend.utils import weak_ref_tensors
|
||||
|
||||
|
||||
class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.ALWAYS
|
||||
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
# Does this backend/builder reorder the batch?
|
||||
# If not, set this to None. Otherwise set it to the query
|
||||
@@ -72,6 +69,16 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
self.dcp_rank = get_decode_context_model_parallel_rank(
|
||||
) if self.dcp_size > 1 else 0
|
||||
|
||||
@classmethod
|
||||
def get_cudagraph_support(
|
||||
cls: type["AscendAttentionCPMetadataBuilder"],
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
) -> AttentionCGSupport:
|
||||
# Explicit override in case the underlying builder specialized this getter.
|
||||
# @override omitted only because of mypy limitation due to type variable.
|
||||
return AttentionCGSupport.ALWAYS
|
||||
|
||||
def _get_chunked_req_mask(self, local_context_lens_allranks) -> List[bool]:
|
||||
"""
|
||||
given 4-d list [req][pcp][dcp], return:
|
||||
|
||||
@@ -182,9 +182,6 @@ class AscendMetadata:
|
||||
|
||||
|
||||
class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.ALWAYS
|
||||
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
# Does this backend/builder reorder the batch?
|
||||
# If not, set this to None. Otherwise set it to the query
|
||||
@@ -220,6 +217,16 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
|
||||
|
||||
@classmethod
|
||||
def get_cudagraph_support(
|
||||
cls: type["AscendAttentionMetadataBuilder"],
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
) -> AttentionCGSupport:
|
||||
# Explicit override in case the underlying builder specialized this getter.
|
||||
# @override omitted only because of mypy limitation due to type variable.
|
||||
return AttentionCGSupport.ALWAYS
|
||||
|
||||
def reorder_batch(self, input_batch,
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import ClassVar, Optional, Tuple, TypeVar
|
||||
from typing import Optional, Tuple, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -12,7 +12,7 @@ from vllm.distributed import (get_dcp_group,
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import MLAAttentionSpec
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
|
||||
|
||||
# isort: off
|
||||
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
|
||||
@@ -37,9 +37,6 @@ M = TypeVar("M", bound=AscendMLAMetadata)
|
||||
|
||||
|
||||
class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_BATCH
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
@@ -74,6 +71,16 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
|
||||
@classmethod
|
||||
def get_cudagraph_support(
|
||||
cls: type["AscendMlaCPMetadataBuilder"],
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
) -> AttentionCGSupport:
|
||||
# Explicit override in case the underlying builder specialized this getter.
|
||||
# @override omitted only because of mypy limitation due to type variable.
|
||||
return AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def set_num_actual_tokens(
|
||||
self,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
|
||||
TypeVar)
|
||||
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -15,7 +14,7 @@ from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.utils.math_utils import cdiv, round_down
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import MLAAttentionSpec
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
|
||||
|
||||
from vllm_ascend import envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
@@ -182,9 +181,6 @@ M = TypeVar("M", bound=AscendMLAMetadata)
|
||||
|
||||
|
||||
class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_BATCH
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
@@ -263,6 +259,16 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
self.query_lens: torch.Tensor = None
|
||||
self.seq_lens: torch.Tensor = None
|
||||
|
||||
@classmethod
|
||||
def get_cudagraph_support(
|
||||
cls: type["AscendMLAMetadataBuilder"],
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
) -> AttentionCGSupport:
|
||||
# Explicit override in case the underlying builder specialized this getter.
|
||||
# @override omitted only because of mypy limitation due to type variable.
|
||||
return AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def reorder_batch(self, input_batch: "NPUInputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
# We now want to reorder the batch so that the "decode" requests are at
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.linear import (ReplicatedLinear,
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from vllm_ascend import envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
@@ -113,9 +114,6 @@ M = TypeVar("M", bound=AscendSFAMetadata)
|
||||
|
||||
|
||||
class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
@@ -159,6 +157,16 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
== CUDAGraphMode.FULL_DECODE_ONLY
|
||||
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
|
||||
|
||||
@classmethod
|
||||
def get_cudagraph_support(
|
||||
cls: type["AscendSFAMetadataBuilder"],
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
) -> AttentionCGSupport:
|
||||
# Explicit override in case the underlying builder specialized this getter.
|
||||
# @override omitted only because of mypy limitation due to type variable.
|
||||
return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
|
||||
def reorder_batch(self, input_batch: "NPUInputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
# No need to reorder for Ascend SFA
|
||||
|
||||
Reference in New Issue
Block a user