[refactor] refactor model runner capture model (#5230)

### What this PR does / why we need it?
Refactor the `capture_model` method in model_runner to directly reuse
the method from vLLM.

Currently, most of the logic in the `capture_model` method is similar to
the corresponding vLLM code. Using the vLLM method directly reduces the
maintenance cost of the vllm-ascend code. The changes are as follows:
1. Refactor the `capture_model` function to directly inherit the upstream (community) method.
2. Refactor the `initialize_aclgraph_capture` function by moving its logic into
`initialize_attn_backend`.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
weiguihua2
2025-12-30 08:32:14 +08:00
committed by GitHub
parent 5e96f94d2a
commit 15d73f248e
10 changed files with 142 additions and 254 deletions

View File

@@ -54,10 +54,10 @@ def test_aclgraph_mem_use(model: str, max_tokens: int) -> None:
return wrapped return wrapped
original_capture = NPUModelRunner._capture_model original_capture = NPUModelRunner.capture_model
with patch.object(NPUModelRunner, with patch.object(NPUModelRunner,
'_capture_model', 'capture_model',
new=capture_model_wrapper(original_capture)): new=capture_model_wrapper(original_capture)):
prompts = [ prompts = [
"Hello, my name is", "The president of the United States is", "Hello, my name is", "The president of the United States is",
@@ -73,7 +73,7 @@ def test_aclgraph_mem_use(model: str, max_tokens: int) -> None:
vllm_model = VllmRunner(snapshot_download(model)) vllm_model = VllmRunner(snapshot_download(model))
_ = vllm_model.generate(prompts, sampling_params) _ = vllm_model.generate(prompts, sampling_params)
assert capture_called.value == 1, "_capture_model was not called during test" assert capture_called.value == 1, "capture_model was not called during test"
assert capture_mem_before.value != -1, "capture_mem_before not set" assert capture_mem_before.value != -1, "capture_mem_before not set"
assert capture_mem_after.value != -1, "capture_mem_after not set" assert capture_mem_after.value != -1, "capture_mem_after not set"
@@ -93,7 +93,7 @@ def test_aclgraph_mem_use(model: str, max_tokens: int) -> None:
max_capture_mem_gib = baseline_capture_mem * capture_mem_tolerance max_capture_mem_gib = baseline_capture_mem * capture_mem_tolerance
max_mem_expected = max_capture_mem_gib * (1024**3) max_mem_expected = max_capture_mem_gib * (1024**3)
assert mem_used_by_capture < max_mem_expected, ( assert mem_used_by_capture < max_mem_expected, (
f"_capture_model used more memory than expected. " f"capture_model used more memory than expected. "
f"Used: {mem_used_by_capture / (1024**3):.2f} GiB, " f"Used: {mem_used_by_capture / (1024**3):.2f} GiB, "
f"Expected: < {max_capture_mem_gib:.2f} GiB") f"Expected: < {max_capture_mem_gib:.2f} GiB")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = 'spawn' os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = 'spawn'

View File

@@ -2,7 +2,6 @@ import sys
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import torch import torch
from vllm.v1.attention.backends.utils import AttentionCGSupport
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -98,7 +97,6 @@ class TestAscendSFAMetadataBuilder(TestBase):
vllm_config=vllm_config, vllm_config=vllm_config,
device=device) device=device)
assert builder.aclgraph_support == AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
assert builder.device == device assert builder.device == device
assert builder.vllm_config == vllm_config assert builder.vllm_config == vllm_config

View File

@@ -44,9 +44,6 @@ from vllm_ascend.utils import weak_ref_tensors
class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.ALWAYS
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE # AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
# Does this backend/builder reorder the batch? # Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query # If not, set this to None. Otherwise set it to the query
@@ -72,6 +69,16 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
self.dcp_rank = get_decode_context_model_parallel_rank( self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0 ) if self.dcp_size > 1 else 0
@classmethod
def get_cudagraph_support(
cls: type["AscendAttentionCPMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.ALWAYS
def _get_chunked_req_mask(self, local_context_lens_allranks) -> List[bool]: def _get_chunked_req_mask(self, local_context_lens_allranks) -> List[bool]:
""" """
given 4-d list [req][pcp][dcp], return: given 4-d list [req][pcp][dcp], return:

View File

@@ -182,9 +182,6 @@ class AscendMetadata:
class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.ALWAYS
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE # AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
# Does this backend/builder reorder the batch? # Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query # If not, set this to None. Otherwise set it to the query
@@ -220,6 +217,16 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
@classmethod
def get_cudagraph_support(
cls: type["AscendAttentionMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.ALWAYS
def reorder_batch(self, input_batch, def reorder_batch(self, input_batch,
scheduler_output: "SchedulerOutput") -> bool: scheduler_output: "SchedulerOutput") -> bool:
return False return False

View File

@@ -1,4 +1,4 @@
from typing import ClassVar, Optional, Tuple, TypeVar from typing import Optional, Tuple, TypeVar
import numpy as np import numpy as np
import torch import torch
@@ -12,7 +12,7 @@ from vllm.distributed import (get_dcp_group,
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import MLAAttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
# isort: off # isort: off
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata, from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
@@ -37,9 +37,6 @@ M = TypeVar("M", bound=AscendMLAMetadata)
class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
""" """
NOTE: Please read the comment at the top of the file before trying to NOTE: Please read the comment at the top of the file before trying to
understand this class understand this class
@@ -74,6 +71,16 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
dtype=torch.uint8, dtype=torch.uint8,
device=device) device=device)
@classmethod
def get_cudagraph_support(
cls: type["AscendMlaCPMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.UNIFORM_BATCH
def set_num_actual_tokens( def set_num_actual_tokens(
self, self,
common_attn_metadata: AscendCommonAttentionMetadata, common_attn_metadata: AscendCommonAttentionMetadata,

View File

@@ -1,6 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type, from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
TypeVar)
import numpy as np import numpy as np
import torch import torch
@@ -15,7 +14,7 @@ from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.utils.math_utils import cdiv, round_down from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import MLAAttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
from vllm_ascend import envs from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
@@ -182,9 +181,6 @@ M = TypeVar("M", bound=AscendMLAMetadata)
class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
""" """
NOTE: Please read the comment at the top of the file before trying to NOTE: Please read the comment at the top of the file before trying to
understand this class understand this class
@@ -263,6 +259,16 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
self.query_lens: torch.Tensor = None self.query_lens: torch.Tensor = None
self.seq_lens: torch.Tensor = None self.seq_lens: torch.Tensor = None
@classmethod
def get_cudagraph_support(
cls: type["AscendMLAMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.UNIFORM_BATCH
def reorder_batch(self, input_batch: "NPUInputBatch", def reorder_batch(self, input_batch: "NPUInputBatch",
scheduler_output: "SchedulerOutput") -> bool: scheduler_output: "SchedulerOutput") -> bool:
# We now want to reorder the batch so that the "decode" requests are at # We now want to reorder the batch so that the "decode" requests are at

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple, Type, TypeVar from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar
import torch import torch
import torch_npu import torch_npu
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.linear import (ReplicatedLinear,
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend import envs from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
@@ -113,9 +114,6 @@ M = TypeVar("M", bound=AscendSFAMetadata)
class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
""" """
NOTE: Please read the comment at the top of the file before trying to NOTE: Please read the comment at the top of the file before trying to
understand this class understand this class
@@ -159,6 +157,16 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
== CUDAGraphMode.FULL_DECODE_ONLY == CUDAGraphMode.FULL_DECODE_ONLY
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1." ), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
@classmethod
def get_cudagraph_support(
cls: type["AscendSFAMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
def reorder_batch(self, input_batch: "NPUInputBatch", def reorder_batch(self, input_batch: "NPUInputBatch",
scheduler_output: "SchedulerOutput") -> bool: scheduler_output: "SchedulerOutput") -> bool:
# No need to reorder for Ascend SFA # No need to reorder for Ascend SFA

View File

@@ -26,6 +26,8 @@ from vllm.platforms import Platform, PlatformEnum
# todo: please remove it when solve cuda hard code in vllm # todo: please remove it when solve cuda hard code in vllm
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1" os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"
# TODO: remove this once controlling garbage collection during CUDA graph capture is supported.
os.environ["VLLM_ENABLE_CUDAGRAPH_GC"] = "1"
from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.utils import refresh_block_size from vllm_ascend.utils import refresh_block_size
@@ -244,6 +246,12 @@ class NPUPlatform(Platform):
data_parallel_size, data_parallel_size,
) )
compilation_config.use_inductor = False compilation_config.use_inductor = False
# NOTE: In theory, vllm::mla_forward should also be added to the attention ops.
# However, because worker processes are created in spawn mode, the class attribute
# holding the attention ops is transmitted with its pre-modification value, so it is
# left unmodified here.
# As a result, when piecewise mode and splitting ops are both configured and the
# splitting ops do not contain vllm::mla_forward, the misconfiguration will not be
# caught in advance by the assert.
compilation_config.splitting_ops.extend(["vllm::mla_forward"]) compilation_config.splitting_ops.extend(["vllm::mla_forward"])
update_aclgraph_sizes(vllm_config) update_aclgraph_sizes(vllm_config)
ascend_config.enable_npugraph_ex = False ascend_config.enable_npugraph_ex = False

View File

@@ -18,7 +18,7 @@
# #
import math import math
import time import sys
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from copy import copy, deepcopy from copy import copy, deepcopy
@@ -27,16 +27,12 @@ from multiprocessing import Manager
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
import numpy as np import numpy as np
import regex as re
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from tqdm import tqdm # type: ignore
from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.layer import Attention, MLAAttention from vllm.attention.layer import Attention, MLAAttention
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig, from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config) get_layers_from_vllm_config)
from vllm.distributed import (get_tensor_model_parallel_world_size, from vllm.distributed import (get_tensor_model_parallel_world_size,
@@ -46,8 +42,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group) has_kv_transfer_group)
from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group, from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group,
get_pcp_group, get_pp_group, get_pcp_group, get_pp_group,
get_tp_group, get_tp_group)
is_global_first_rank)
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import logger from vllm.logger import logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -58,8 +53,7 @@ from vllm.utils.import_utils import LazyLoader
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.mem_utils import DeviceMemoryProfiler from vllm.utils.mem_utils import DeviceMemoryProfiler
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (AttentionCGSupport, from vllm.v1.attention.backends.utils import CommonAttentionMetadata
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import (AttentionSpec, from vllm.v1.kv_cache_interface import (AttentionSpec,
EncoderOnlyAttentionSpec, EncoderOnlyAttentionSpec,
FullAttentionSpec, KVCacheConfig, FullAttentionSpec, KVCacheConfig,
@@ -1972,13 +1966,18 @@ class NPUModelRunner(GPUModelRunner):
self, self,
num_tokens: int, num_tokens: int,
with_prefill: bool = False, with_prefill: bool = False,
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None, cudagraph_runtime_mode: Optional[CUDAGraphMode] = None,
force_attention: bool = False, force_attention: bool = False,
uniform_decode: bool = False, uniform_decode: bool = False,
is_profile: bool = False, is_profile: bool = False,
allow_microbatching: bool = True,
skip_eplb: bool = False,
remove_lora: bool = True,
activate_lora: bool = False,
is_graph_capturing: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
# only support eager mode and piecewise graph now # only support eager mode and piecewise graph now
assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in { assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
} }
# In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs. # In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs.
@@ -2054,15 +2053,15 @@ class NPUModelRunner(GPUModelRunner):
num_scheduled_tokens = num_scheduled_tokens.repeat(num_reqs_padded) num_scheduled_tokens = num_scheduled_tokens.repeat(num_reqs_padded)
# filter out the valid batch descriptor # filter out the valid batch descriptor
if aclgraph_runtime_mode is not None: if cudagraph_runtime_mode is not None:
# we allow forcing NONE when the dispatcher disagrees to support # we allow forcing NONE when the dispatcher disagrees to support
# warm ups for aclgraph capture # warm ups for aclgraph capture
if aclgraph_runtime_mode != CUDAGraphMode.NONE and aclgraph_runtime_mode != _ag_mode: if cudagraph_runtime_mode != CUDAGraphMode.NONE and cudagraph_runtime_mode != _ag_mode:
raise ValueError( raise ValueError(
f"Aclgraph runtime mode mismatch at dummy_run. " f"Aclgraph runtime mode mismatch at dummy_run. "
f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}.") f"Expected {_ag_mode}, but got {cudagraph_runtime_mode}.")
else: else:
aclgraph_runtime_mode = _ag_mode cudagraph_runtime_mode = _ag_mode
# TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup
# and not supported in ASCEND now. We could remove it in the future. # and not supported in ASCEND now. We could remove it in the future.
@@ -2071,7 +2070,7 @@ class NPUModelRunner(GPUModelRunner):
num_reqs=num_reqs_padded, num_reqs=num_reqs_padded,
num_tokens=num_tokens_padded, num_tokens=num_tokens_padded,
max_query_len=max_query_len, max_query_len=max_query_len,
aclgraph_runtime_mode=aclgraph_runtime_mode, aclgraph_runtime_mode=cudagraph_runtime_mode,
force_attention=force_attention, force_attention=force_attention,
num_scheduled_tokens=num_scheduled_tokens, num_scheduled_tokens=num_scheduled_tokens,
) )
@@ -2147,7 +2146,7 @@ class NPUModelRunner(GPUModelRunner):
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
in_profile_run=is_profile, in_profile_run=is_profile,
num_actual_tokens=0, num_actual_tokens=0,
aclgraph_runtime_mode=aclgraph_runtime_mode, aclgraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
model_instance=self.model): model_instance=self.model):
hidden_states = self._generate_dummy_run_hidden_states( hidden_states = self._generate_dummy_run_hidden_states(
@@ -2161,7 +2160,7 @@ class NPUModelRunner(GPUModelRunner):
with_prefill=with_prefill, with_prefill=with_prefill,
num_reqs=num_reqs_padded, num_reqs=num_reqs_padded,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
aclgraph_runtime_mode=aclgraph_runtime_mode, aclgraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
dummy_compute_logits=dummy_drafter_compute_logits, dummy_compute_logits=dummy_drafter_compute_logits,
in_graph_capturing=not force_attention, in_graph_capturing=not force_attention,
@@ -2677,7 +2676,8 @@ class NPUModelRunner(GPUModelRunner):
def get_attn_backends_for_group( def get_attn_backends_for_group(
kv_cache_group_spec: KVCacheGroupSpec, kv_cache_group_spec: KVCacheGroupSpec,
) -> dict[AttentionGroupKey, list[str]]: ) -> tuple[dict[AttentionGroupKey, list[str]],
set[type[AttentionBackend]]]:
layers = get_layers_from_vllm_config( layers = get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase, self.vllm_config, AttentionLayerBase,
kv_cache_group_spec.layer_names) kv_cache_group_spec.layer_names)
@@ -2699,10 +2699,14 @@ class NPUModelRunner(GPUModelRunner):
attn_backends[key] = AttentionGroupKey(attn_backend, attn_backends[key] = AttentionGroupKey(attn_backend,
layer_kv_cache_spec) layer_kv_cache_spec)
attn_backend_layers[key].append(layer_name) attn_backend_layers[key].append(layer_name)
return { return (
attn_backends[k]: v {
for k, v in attn_backend_layers.items() attn_backends[k]: v
} for k, v in attn_backend_layers.items()
},
set(group_key.attn_backend
for group_key in attn_backends.values()),
)
def create_attn_groups(attn_backends_map: dict[AttentionBackend, def create_attn_groups(attn_backends_map: dict[AttentionBackend,
list[str]], list[str]],
@@ -2723,11 +2727,21 @@ class NPUModelRunner(GPUModelRunner):
attn_groups.append(attn_group) attn_groups.append(attn_group)
return attn_groups return attn_groups
attention_backend_maps = []
attention_backend_list = []
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
attention_backend_maps.append(attn_backends[0])
attention_backend_list.append(attn_backends[1])
self._check_and_update_cudagraph_mode(attention_backend_list,
kv_cache_config.kv_cache_groups)
for i, kv_cache_group_spec in enumerate( for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups): kv_cache_config.kv_cache_groups):
attn_backends = get_attn_backends_for_group( # type: ignore attn_backends = get_attn_backends_for_group( # type: ignore
kv_cache_group_spec) kv_cache_group_spec)
self.attn_groups.append(create_attn_groups(attn_backends, i)) self.attn_groups.append(create_attn_groups(attn_backends[0], i))
# Calculate reorder batch threshold (if needed) # Calculate reorder batch threshold (if needed)
self.calculate_reorder_batch_threshold() self.calculate_reorder_batch_threshold()
@@ -2855,214 +2869,26 @@ class NPUModelRunner(GPUModelRunner):
return kv_cache_spec return kv_cache_spec
def initialize_aclgraph_capture(self) -> None: def _check_and_update_cudagraph_mode(
min_ag_support = AttentionCGSupport.ALWAYS self,
min_ag_builder_name = None attention_backends: list[set[type[AttentionBackend]]],
kv_cache_groups: list[KVCacheGroupSpec],
for attn_group in self._attn_group_iterator(): ) -> None:
builder = attn_group.get_metadata_builder() super()._check_and_update_cudagraph_mode(attention_backends,
graph_support = None kv_cache_groups)
if hasattr(builder, 'aclgraph_support'):
graph_support = builder.aclgraph_support.value
builder_aclgraph = builder.aclgraph_support
else:
graph_support = builder._cudagraph_support.value
builder_aclgraph = builder._cudagraph_support
if graph_support < min_ag_support.value:
min_ag_support = builder_aclgraph
min_ag_builder_name = builder.__class__.__name__
# This is an imitation of compilation_config.splitting_ops_contain_attention()
splitting_ops_contain_attention = (
self.compilation_config.splitting_ops is not None
and all(op in self.compilation_config.splitting_ops for op in [
"vllm.mla_forward",
]))
# Flexible resolve the aclgraph mode
aclgraph_mode = self.compilation_config.cudagraph_mode
# check graph for mixed batch is supported
if aclgraph_mode.mixed_mode() == CUDAGraphMode.FULL \
and min_ag_support != AttentionCGSupport.ALWAYS:
msg = (f"ACLGraphMode.{aclgraph_mode.name} is not supported "
f"with {min_ag_builder_name} backend (support: "
f"{min_ag_support})")
if min_ag_support == AttentionCGSupport.NEVER:
# if not supported any full graphs, just raise it.
msg += "; please try cudagraph_mode=PIECEWISE, and "\
"make sure compilation level is piecewise"
raise ValueError(msg)
# attempt to resolve the full graph related mode
if splitting_ops_contain_attention:
msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
aclgraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_AND_PIECEWISE)
else:
msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
aclgraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_DECODE_ONLY)
logger.warning(msg)
# double check that we can support full graph if they are requested
# even after automatic downgrades
if aclgraph_mode.has_full_cudagraphs() \
and min_ag_support == AttentionCGSupport.NEVER:
raise ValueError(f"CUDAGraphMode.{aclgraph_mode.name} is not "
f"supported with {min_ag_builder_name} backend ("
f"support:{min_ag_support}) "
"; please try cudagraph_mode=PIECEWISE, "
"and make sure compilation level is piecewise")
if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL
and aclgraph_mode.separate_routine()
and self.uniform_decode_query_len > 1):
self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
self.uniform_decode_query_len,
self.parallel_config.tensor_parallel_size)
capture_sizes = self.compilation_config.cudagraph_capture_sizes
self.cudagraph_batch_sizes = (capture_sizes
if capture_sizes is not None else [])
# NOTE: Since aclgraph_batch_sizes cannot be determined until here, # NOTE: Since aclgraph_batch_sizes cannot be determined until here,
# we set the graph params right before initializing the keys. # we set the graph params right before initializing the keys.
set_graph_params(self.cudagraph_batch_sizes) if self.use_aclgraph:
if self.speculative_config: set_graph_params(self.cudagraph_batch_sizes)
set_draft_graph_params(self.cudagraph_batch_sizes) if self.speculative_config:
set_draft_graph_params(self.cudagraph_batch_sizes)
self.cudagraph_dispatcher.initialize_cudagraph_keys(
self.compilation_config.cudagraph_mode,
self.uniform_decode_query_len)
def _capture_aclgraphs(self, compilation_cases: list[int],
aclgraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool):
assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
aclgraph_runtime_mode in [CUDAGraphMode.FULL,
CUDAGraphMode.PIECEWISE]
# Only rank 0 should print progress bar during capture
if is_global_first_rank():
logger.info(
"Starting to capture ACL graphs for cases: %s, "
"mode: %s, uniform_decode: %s", compilation_cases,
aclgraph_runtime_mode.name, uniform_decode)
compilation_cases = tqdm(
compilation_cases,
disable=not self.load_config.use_tqdm_on_load,
desc="Capturing ACL graphs ({}, {})".format(
"decode" if uniform_decode else "mixed prefill-decode",
aclgraph_runtime_mode.name))
force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL)
# When the kv cache spec is empty, PiecewiseBackend is not initialized, and
# compilation_case=1 will cause the dynamic shape position to be incorrectly derived.
if not self.get_kv_cache_spec():
self._dummy_run(2,
aclgraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=uniform_decode)
# We skip EPLB here since we don't want to record dummy metrics
for num_tokens in compilation_cases:
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE`is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
self._dummy_run(num_tokens,
aclgraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=uniform_decode)
self._dummy_run(num_tokens,
aclgraph_runtime_mode=aclgraph_runtime_mode,
force_attention=force_attention,
uniform_decode=uniform_decode)
def _capture_model(self):
    """Capture ACL graphs for the configured batch sizes.

    Returns early with a warning when ``self.use_aclgraph`` is falsy.
    Otherwise it initializes ACL graph capture, enables global capturing,
    captures the mixed prefill-decode graphs (when the mixed mode is not
    ``NONE``) and then the uniform-decode FULL graphs (when the decode mode
    is ``FULL`` and ``separate_routine()`` is true), and finally disables
    global capturing again.

    Raises:
        Exception: Anything raised by the mixed-mode capture is re-raised
            unchanged; a hint is logged first when the error text carries
            the stream-capture failure code.
    """
    if not self.use_aclgraph:
        logger.warning(
            "Skipping ACL graph capture. To turn on ACL graph capture, "
            "ensure `aclgraph_mode` was not manually set to `NONE`")
        return
    self.initialize_aclgraph_capture()

    set_cudagraph_capturing_enabled(True)
    # Trigger ACL graph capture for specific shapes.
    # Capture the large shapes first so that the smaller shapes
    # can reuse the memory pool allocated for the large shapes.
    with graph_capture(device=self.device):
        aclgraph_mode = self.compilation_config.cudagraph_mode
        if aclgraph_mode.mixed_mode() != CUDAGraphMode.NONE:
            aclgraph_runtime_mode = aclgraph_mode.mixed_mode()
            # make sure we capture the largest batch size first
            compilation_cases = list(reversed(self.cudagraph_batch_sizes))
            try:
                self._capture_aclgraphs(
                    compilation_cases,
                    aclgraph_runtime_mode=aclgraph_runtime_mode,
                    uniform_decode=False)
            except Exception as e:
                error_msg = str(e)
                error_code = '0x7020023'
                pattern = r'retCode=([^,\s\.]+)'
                match = re.search(pattern, error_msg)
                # Bind the code unconditionally so the comparison below never
                # depends on short-circuiting around an unbound name.
                ret_code = match.group(1) if match else None
                # Determine whether the error message is caused by stream
                # capture failure.
                if ret_code == error_code:
                    logger.error(
                        f"ACLgraph sizes capture fail: {type(e).__name__}:\n"
                        "ACLgraph has insufficient available streams to capture the configured number of sizes. "
                        "Please verify both the availability of adequate streams and the appropriateness of the configured size count.\n\n"
                        "Recommended solutions:\n"
                        "1. Manually configure the compilation_config parameter "
                        "with a reduced set of sizes: '{\"cudagraph_capture_sizes\":[size1, size2, size3, ...]}'.\n"
                        "2. Utilize ACLgraph's full graph mode as an alternative to the piece-wise approach.\n\n"
                        f"{error_msg}")
                # Always propagate: the log above is only a diagnostic hint.
                raise

        if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
                aclgraph_mode.separate_routine():
            # Only sizes that a uniform-decode batch can actually reach are
            # captured for the decode-only routine.
            max_num_tokens = self.scheduler_config.max_num_seqs * \
                self.uniform_decode_query_len
            decode_cudagraph_batch_sizes = [
                x for x in self.cudagraph_batch_sizes
                if self.uniform_decode_query_len <= x <= max_num_tokens
            ]
            compilation_cases_decode = list(
                reversed(decode_cudagraph_batch_sizes))
            self._capture_aclgraphs(
                compilation_cases=compilation_cases_decode,
                aclgraph_runtime_mode=CUDAGraphMode.FULL,
                uniform_decode=True)

    # Disable aclgraph capturing globally, so any unexpected aclgraph
    # capturing will be detected and raise an error after here.
    # Note: We don't put it into graph_capture context manager because
    # we may doing lazy capturing in future that still allows capturing
    # after here.
    set_cudagraph_capturing_enabled(False)
def capture_model(self) -> None: def capture_model(self) -> None:
parent_module_name = self.__class__.__base__.__module__
compilation_counter.num_gpu_runner_capture_triggers += 1 with _torch_cuda_wrapper(), _replace_gpu_model_runner_function_wrapper(
parent_module_name):
start_time = time.perf_counter() super().capture_model()
start_free_npu_memory = torch.npu.mem_get_info()[0]
self._capture_model()
end_time = time.perf_counter()
end_free_npu_memory = torch.npu.mem_get_info()[0]
elapsed_time = end_time - start_time
npu_graph_size = start_free_npu_memory - end_free_npu_memory
# This usually takes 5~20 seconds.
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, npu_graph_size / (1 << 30))
def _update_tokens_for_pcp(self, tokens): def _update_tokens_for_pcp(self, tokens):
num_reqs = self.input_batch.num_reqs num_reqs = self.input_batch.num_reqs
@@ -3473,6 +3299,8 @@ def _torch_cuda_wrapper():
torch.cuda.default_stream = torch.npu.default_stream torch.cuda.default_stream = torch.npu.default_stream
torch.cuda.current_stream = torch.npu.current_stream torch.cuda.current_stream = torch.npu.current_stream
torch.cuda.stream = torch.npu.stream torch.cuda.stream = torch.npu.stream
torch.cuda.synchronize = torch.npu.synchronize
torch.cuda.mem_get_info = torch.npu.mem_get_info
yield yield
except Exception: except Exception:
torch.cuda.Event = _EventPlaceholder torch.cuda.Event = _EventPlaceholder
@@ -3480,6 +3308,8 @@ def _torch_cuda_wrapper():
torch.cuda.default_stream = _StreamPlaceholder torch.cuda.default_stream = _StreamPlaceholder
torch.cuda.current_stream = _StreamPlaceholder torch.cuda.current_stream = _StreamPlaceholder
torch.cuda.stream = _StreamPlaceholder torch.cuda.stream = _StreamPlaceholder
torch.cuda.synchronize = _StreamPlaceholder
torch.cuda.mem_get_info = _StreamPlaceholder
finally: finally:
# if anything goes wrong, just patch it with a placeholder # if anything goes wrong, just patch it with a placeholder
torch.cuda.Event = _EventPlaceholder torch.cuda.Event = _EventPlaceholder
@@ -3487,3 +3317,16 @@ def _torch_cuda_wrapper():
torch.cuda.default_stream = torch.npu.default_stream torch.cuda.default_stream = torch.npu.default_stream
torch.cuda.current_stream = torch.npu.current_stream torch.cuda.current_stream = torch.npu.current_stream
torch.cuda.stream = torch.npu.stream torch.cuda.stream = torch.npu.stream
torch.cuda.synchronize = torch.npu.synchronize
torch.cuda.mem_get_info = torch.npu.mem_get_info
# TODO: This method will be removed subsequently and implemented in platform.
@contextmanager
def _replace_gpu_model_runner_function_wrapper(target_module_name):
try:
target_module = sys.modules[target_module_name]
setattr(target_module, "graph_capture", graph_capture)
yield
finally:
setattr(target_module, "graph_capture", graph_capture)

View File

@@ -46,6 +46,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, ModelRunnerOutput) DraftTokenIds, ModelRunnerOutput)
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager
import vllm_ascend.envs as envs_ascend import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
@@ -231,6 +232,9 @@ class NPUWorker(WorkerBase):
# in ray scenario. see https://github.com/vllm-project/vllm/pull/26845 # in ray scenario. see https://github.com/vllm-project/vllm/pull/26845
# for more details # for more details
self.device = self._init_device() self.device = self._init_device()
# Initialize workspace manager
num_ubatches = 1
init_workspace_manager(self.device, num_ubatches)
# Init ModelRunner here, so that we have access to self.device. # Init ModelRunner here, so that we have access to self.device.
if self.use_v2_model_runner: if self.use_v2_model_runner:
logger.warning( logger.warning(