[refactor] refactor model runner capture model (#5230)

### What this PR does / why we need it?
Refactor the `capture_model` method in model_runner to directly reuse
the method from vLLM.

Currently, most of the logic in the `capture_model` method duplicates the corresponding vLLM code. Reusing the vLLM method directly reduces the maintenance cost of the vllm-ascend code. The changes are as follows:
1. Refactor the `capture_model` function to directly inherit the community method (see the sketch below).
2. Refactor the `initialize_aclgraph_capture` function, moving its logic into `initialize_attn_backend`.
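
The resulting override is small; below is a minimal sketch of the intended structure. The `GPUModelRunner` import path is an assumption, and the two context managers are reduced to no-ops here — in the actual change they remap `torch.cuda` symbols onto `torch.npu` and swap the parent module's `graph_capture` helper (see the model_runner diff below).

```python
from contextlib import contextmanager

# Import path assumed for illustration; the diff only shows
# `class NPUModelRunner(GPUModelRunner)`.
from vllm.v1.worker.gpu_model_runner import GPUModelRunner


@contextmanager
def _torch_cuda_wrapper():
    # Real version: maps torch.cuda.Event / streams / synchronize /
    # mem_get_info onto their torch.npu equivalents for the capture.
    yield


@contextmanager
def _replace_gpu_model_runner_function_wrapper(target_module_name: str):
    # Real version: temporarily points the parent module's `graph_capture`
    # at the Ascend implementation.
    yield


class NPUModelRunner(GPUModelRunner):

    def capture_model(self) -> None:
        # Instead of maintaining a copied capture loop, patch the
        # CUDA-specific pieces and let the community implementation
        # drive ACL graph capture.
        parent_module_name = self.__class__.__base__.__module__
        with _torch_cuda_wrapper(), _replace_gpu_model_runner_function_wrapper(
                parent_module_name):
            super().capture_model()
```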

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

- vLLM version: release/v0.13.0
- vLLM main: ad32e3e19c

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
weiguihua2
2025-12-30 08:32:14 +08:00
committed by GitHub
parent 5e96f94d2a
commit 15d73f248e
10 changed files with 142 additions and 254 deletions

View File

@@ -54,10 +54,10 @@ def test_aclgraph_mem_use(model: str, max_tokens: int) -> None:
return wrapped
original_capture = NPUModelRunner._capture_model
original_capture = NPUModelRunner.capture_model
with patch.object(NPUModelRunner,
'_capture_model',
'capture_model',
new=capture_model_wrapper(original_capture)):
prompts = [
"Hello, my name is", "The president of the United States is",
@@ -73,7 +73,7 @@ def test_aclgraph_mem_use(model: str, max_tokens: int) -> None:
vllm_model = VllmRunner(snapshot_download(model))
_ = vllm_model.generate(prompts, sampling_params)
assert capture_called.value == 1, "_capture_model was not called during test"
assert capture_called.value == 1, "capture_model was not called during test"
assert capture_mem_before.value != -1, "capture_mem_before not set"
assert capture_mem_after.value != -1, "capture_mem_after not set"
@@ -93,7 +93,7 @@ def test_aclgraph_mem_use(model: str, max_tokens: int) -> None:
max_capture_mem_gib = baseline_capture_mem * capture_mem_tolerance
max_mem_expected = max_capture_mem_gib * (1024**3)
assert mem_used_by_capture < max_mem_expected, (
f"_capture_model used more memory than expected. "
f"capture_model used more memory than expected. "
f"Used: {mem_used_by_capture / (1024**3):.2f} GiB, "
f"Expected: < {max_capture_mem_gib:.2f} GiB")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = 'spawn'

View File

@@ -2,7 +2,6 @@ import sys
from unittest.mock import MagicMock, patch
import torch
from vllm.v1.attention.backends.utils import AttentionCGSupport
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -98,7 +97,6 @@ class TestAscendSFAMetadataBuilder(TestBase):
vllm_config=vllm_config,
device=device)
assert builder.aclgraph_support == AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
assert builder.device == device
assert builder.vllm_config == vllm_config

View File

@@ -44,9 +44,6 @@ from vllm_ascend.utils import weak_ref_tensors
class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.ALWAYS
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
@@ -72,6 +69,16 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
@classmethod
def get_cudagraph_support(
cls: type["AscendAttentionCPMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.ALWAYS
def _get_chunked_req_mask(self, local_context_lens_allranks) -> List[bool]:
"""
given 4-d list [req][pcp][dcp], return:

View File

@@ -182,9 +182,6 @@ class AscendMetadata:
class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.ALWAYS
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
@@ -220,6 +217,16 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
scheduler_config = vllm_config.scheduler_config
self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
@classmethod
def get_cudagraph_support(
cls: type["AscendAttentionMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.ALWAYS
def reorder_batch(self, input_batch,
scheduler_output: "SchedulerOutput") -> bool:
return False

View File

@@ -1,4 +1,4 @@
from typing import ClassVar, Optional, Tuple, TypeVar
from typing import Optional, Tuple, TypeVar
import numpy as np
import torch
@@ -12,7 +12,7 @@ from vllm.distributed import (get_dcp_group,
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import MLAAttentionSpec
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
# isort: off
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
@@ -37,9 +37,6 @@ M = TypeVar("M", bound=AscendMLAMetadata)
class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
@@ -74,6 +71,16 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
dtype=torch.uint8,
device=device)
@classmethod
def get_cudagraph_support(
cls: type["AscendMlaCPMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.UNIFORM_BATCH
def set_num_actual_tokens(
self,
common_attn_metadata: AscendCommonAttentionMetadata,

View File

@@ -1,6 +1,5 @@
from dataclasses import dataclass
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
TypeVar)
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
import numpy as np
import torch
@@ -15,7 +14,7 @@ from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import MLAAttentionSpec
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
@@ -182,9 +181,6 @@ M = TypeVar("M", bound=AscendMLAMetadata)
class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
@@ -263,6 +259,16 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
self.query_lens: torch.Tensor = None
self.seq_lens: torch.Tensor = None
@classmethod
def get_cudagraph_support(
cls: type["AscendMLAMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.UNIFORM_BATCH
def reorder_batch(self, input_batch: "NPUInputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# We now want to reorder the batch so that the "decode" requests are at

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple, Type, TypeVar
from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar
import torch
import torch_npu
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.linear import (ReplicatedLinear,
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
@@ -113,9 +114,6 @@ M = TypeVar("M", bound=AscendSFAMetadata)
class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
@@ -159,6 +157,16 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
== CUDAGraphMode.FULL_DECODE_ONLY
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
@classmethod
def get_cudagraph_support(
cls: type["AscendSFAMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
def reorder_batch(self, input_batch: "NPUInputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# No need to reorder for Ascend SFA

View File

@@ -26,6 +26,8 @@ from vllm.platforms import Platform, PlatformEnum
# todo: please remove it when solve cuda hard code in vllm
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"
# todo: please remove it once controlling garbage collection during CUDA graph capture is supported.
os.environ["VLLM_ENABLE_CUDAGRAPH_GC"] = "1"
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.utils import refresh_block_size
@@ -244,6 +246,12 @@ class NPUPlatform(Platform):
data_parallel_size,
)
compilation_config.use_inductor = False
# NOTE: Theoretically, vllm::mla_forward should also be added to the attention ops here.
# However, because the worker process is created with the spawn method, the class attribute
# holding the attention ops is transmitted with its pre-modification value, so it does not
# pick up this change. As a result, when piecewise mode and splitting ops are configured
# simultaneously but splitting_ops does not contain vllm::mla_forward, the configuration
# problem is not caught by the upfront assert.
compilation_config.splitting_ops.extend(["vllm::mla_forward"])
update_aclgraph_sizes(vllm_config)
ascend_config.enable_npugraph_ex = False

View File

@@ -18,7 +18,7 @@
#
import math
import time
import sys
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import copy, deepcopy
@@ -27,16 +27,12 @@ from multiprocessing import Manager
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
import numpy as np
import regex as re
import torch
import torch.distributed as dist
import torch.nn as nn
from tqdm import tqdm # type: ignore
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.layer import Attention, MLAAttention
from vllm.attention.selector import get_attn_backend
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed import (get_tensor_model_parallel_world_size,
@@ -46,8 +42,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group,
get_pcp_group, get_pp_group,
get_tp_group,
is_global_first_rank)
get_tp_group)
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -58,8 +53,7 @@ from vllm.utils.import_utils import LazyLoader
from vllm.utils.math_utils import cdiv
from vllm.utils.mem_utils import DeviceMemoryProfiler
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
CommonAttentionMetadata)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import (AttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec, KVCacheConfig,
@@ -1972,13 +1966,18 @@ class NPUModelRunner(GPUModelRunner):
self,
num_tokens: int,
with_prefill: bool = False,
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
cudagraph_runtime_mode: Optional[CUDAGraphMode] = None,
force_attention: bool = False,
uniform_decode: bool = False,
is_profile: bool = False,
allow_microbatching: bool = True,
skip_eplb: bool = False,
remove_lora: bool = True,
activate_lora: bool = False,
is_graph_capturing: bool = False,
) -> torch.Tensor:
# only support eager mode and piecewise graph now
assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in {
assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
}
# In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs.
@@ -2054,15 +2053,15 @@ class NPUModelRunner(GPUModelRunner):
num_scheduled_tokens = num_scheduled_tokens.repeat(num_reqs_padded)
# filter out the valid batch descriptor
if aclgraph_runtime_mode is not None:
if cudagraph_runtime_mode is not None:
# we allow forcing NONE when the dispatcher disagrees to support
# warm ups for aclgraph capture
if aclgraph_runtime_mode != CUDAGraphMode.NONE and aclgraph_runtime_mode != _ag_mode:
if cudagraph_runtime_mode != CUDAGraphMode.NONE and cudagraph_runtime_mode != _ag_mode:
raise ValueError(
f"Aclgraph runtime mode mismatch at dummy_run. "
f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}.")
f"Expected {_ag_mode}, but got {cudagraph_runtime_mode}.")
else:
aclgraph_runtime_mode = _ag_mode
cudagraph_runtime_mode = _ag_mode
# TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup
# and not supported in ASCEND now. We could remove it in the future.
@@ -2071,7 +2070,7 @@ class NPUModelRunner(GPUModelRunner):
num_reqs=num_reqs_padded,
num_tokens=num_tokens_padded,
max_query_len=max_query_len,
aclgraph_runtime_mode=aclgraph_runtime_mode,
aclgraph_runtime_mode=cudagraph_runtime_mode,
force_attention=force_attention,
num_scheduled_tokens=num_scheduled_tokens,
)
@@ -2147,7 +2146,7 @@ class NPUModelRunner(GPUModelRunner):
num_tokens_across_dp=num_tokens_across_dp,
in_profile_run=is_profile,
num_actual_tokens=0,
aclgraph_runtime_mode=aclgraph_runtime_mode,
aclgraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
model_instance=self.model):
hidden_states = self._generate_dummy_run_hidden_states(
@@ -2161,7 +2160,7 @@ class NPUModelRunner(GPUModelRunner):
with_prefill=with_prefill,
num_reqs=num_reqs_padded,
num_tokens_across_dp=num_tokens_across_dp,
aclgraph_runtime_mode=aclgraph_runtime_mode,
aclgraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
dummy_compute_logits=dummy_drafter_compute_logits,
in_graph_capturing=not force_attention,
@@ -2677,7 +2676,8 @@ class NPUModelRunner(GPUModelRunner):
def get_attn_backends_for_group(
kv_cache_group_spec: KVCacheGroupSpec,
) -> dict[AttentionGroupKey, list[str]]:
) -> tuple[dict[AttentionGroupKey, list[str]],
set[type[AttentionBackend]]]:
layers = get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase,
kv_cache_group_spec.layer_names)
@@ -2699,10 +2699,14 @@ class NPUModelRunner(GPUModelRunner):
attn_backends[key] = AttentionGroupKey(attn_backend,
layer_kv_cache_spec)
attn_backend_layers[key].append(layer_name)
return {
attn_backends[k]: v
for k, v in attn_backend_layers.items()
}
return (
{
attn_backends[k]: v
for k, v in attn_backend_layers.items()
},
set(group_key.attn_backend
for group_key in attn_backends.values()),
)
def create_attn_groups(attn_backends_map: dict[AttentionBackend,
list[str]],
@@ -2723,11 +2727,21 @@ class NPUModelRunner(GPUModelRunner):
attn_groups.append(attn_group)
return attn_groups
attention_backend_maps = []
attention_backend_list = []
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
attention_backend_maps.append(attn_backends[0])
attention_backend_list.append(attn_backends[1])
self._check_and_update_cudagraph_mode(attention_backend_list,
kv_cache_config.kv_cache_groups)
for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
attn_backends = get_attn_backends_for_group( # type: ignore
kv_cache_group_spec)
self.attn_groups.append(create_attn_groups(attn_backends, i))
self.attn_groups.append(create_attn_groups(attn_backends[0], i))
# Calculate reorder batch threshold (if needed)
self.calculate_reorder_batch_threshold()
@@ -2855,214 +2869,26 @@ class NPUModelRunner(GPUModelRunner):
return kv_cache_spec
def initialize_aclgraph_capture(self) -> None:
min_ag_support = AttentionCGSupport.ALWAYS
min_ag_builder_name = None
for attn_group in self._attn_group_iterator():
builder = attn_group.get_metadata_builder()
graph_support = None
if hasattr(builder, 'aclgraph_support'):
graph_support = builder.aclgraph_support.value
builder_aclgraph = builder.aclgraph_support
else:
graph_support = builder._cudagraph_support.value
builder_aclgraph = builder._cudagraph_support
if graph_support < min_ag_support.value:
min_ag_support = builder_aclgraph
min_ag_builder_name = builder.__class__.__name__
# This is an imitation of compilation_config.splitting_ops_contain_attention()
splitting_ops_contain_attention = (
self.compilation_config.splitting_ops is not None
and all(op in self.compilation_config.splitting_ops for op in [
"vllm.mla_forward",
]))
# Flexible resolve the aclgraph mode
aclgraph_mode = self.compilation_config.cudagraph_mode
# check graph for mixed batch is supported
if aclgraph_mode.mixed_mode() == CUDAGraphMode.FULL \
and min_ag_support != AttentionCGSupport.ALWAYS:
msg = (f"ACLGraphMode.{aclgraph_mode.name} is not supported "
f"with {min_ag_builder_name} backend (support: "
f"{min_ag_support})")
if min_ag_support == AttentionCGSupport.NEVER:
# if not supported any full graphs, just raise it.
msg += "; please try cudagraph_mode=PIECEWISE, and "\
"make sure compilation level is piecewise"
raise ValueError(msg)
# attempt to resolve the full graph related mode
if splitting_ops_contain_attention:
msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
aclgraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_AND_PIECEWISE)
else:
msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
aclgraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_DECODE_ONLY)
logger.warning(msg)
# double check that we can support full graph if they are requested
# even after automatic downgrades
if aclgraph_mode.has_full_cudagraphs() \
and min_ag_support == AttentionCGSupport.NEVER:
raise ValueError(f"CUDAGraphMode.{aclgraph_mode.name} is not "
f"supported with {min_ag_builder_name} backend ("
f"support:{min_ag_support}) "
"; please try cudagraph_mode=PIECEWISE, "
"and make sure compilation level is piecewise")
if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL
and aclgraph_mode.separate_routine()
and self.uniform_decode_query_len > 1):
self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
self.uniform_decode_query_len,
self.parallel_config.tensor_parallel_size)
capture_sizes = self.compilation_config.cudagraph_capture_sizes
self.cudagraph_batch_sizes = (capture_sizes
if capture_sizes is not None else [])
def _check_and_update_cudagraph_mode(
self,
attention_backends: list[set[type[AttentionBackend]]],
kv_cache_groups: list[KVCacheGroupSpec],
) -> None:
super()._check_and_update_cudagraph_mode(attention_backends,
kv_cache_groups)
# NOTE: Since aclgraph_batch_sizes cannot be determined until here,
# we set the graph params right before initializing the keys.
set_graph_params(self.cudagraph_batch_sizes)
if self.speculative_config:
set_draft_graph_params(self.cudagraph_batch_sizes)
self.cudagraph_dispatcher.initialize_cudagraph_keys(
self.compilation_config.cudagraph_mode,
self.uniform_decode_query_len)
def _capture_aclgraphs(self, compilation_cases: list[int],
aclgraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool):
assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
aclgraph_runtime_mode in [CUDAGraphMode.FULL,
CUDAGraphMode.PIECEWISE]
# Only rank 0 should print progress bar during capture
if is_global_first_rank():
logger.info(
"Starting to capture ACL graphs for cases: %s, "
"mode: %s, uniform_decode: %s", compilation_cases,
aclgraph_runtime_mode.name, uniform_decode)
compilation_cases = tqdm(
compilation_cases,
disable=not self.load_config.use_tqdm_on_load,
desc="Capturing ACL graphs ({}, {})".format(
"decode" if uniform_decode else "mixed prefill-decode",
aclgraph_runtime_mode.name))
force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL)
# When the kv cache spec is empty, PiecewiseBackend is not initialized, and
# compilation_case=1 will cause the dynamic shape position to be incorrectly derived.
if not self.get_kv_cache_spec():
self._dummy_run(2,
aclgraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=uniform_decode)
# We skip EPLB here since we don't want to record dummy metrics
for num_tokens in compilation_cases:
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE`is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
self._dummy_run(num_tokens,
aclgraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=uniform_decode)
self._dummy_run(num_tokens,
aclgraph_runtime_mode=aclgraph_runtime_mode,
force_attention=force_attention,
uniform_decode=uniform_decode)
def _capture_model(self):
if not self.use_aclgraph:
logger.warning(
"Skipping ACL graph capture. To turn on ACL graph capture, "
"ensure `aclraph_mode` was not manually set to `NONE`")
return
else:
self.initialize_aclgraph_capture()
set_cudagraph_capturing_enabled(True)
# Trigger ACL graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
with graph_capture(device=self.device):
aclgraph_mode = self.compilation_config.cudagraph_mode
if aclgraph_mode.mixed_mode() != CUDAGraphMode.NONE:
aclgraph_runtime_mode = aclgraph_mode.mixed_mode()
# make sure we capture the largest batch size first
compilation_cases = list(reversed(self.cudagraph_batch_sizes))
try:
self._capture_aclgraphs(
compilation_cases,
aclgraph_runtime_mode=aclgraph_runtime_mode,
uniform_decode=False)
except Exception as e:
error_msg = str(e)
error_code = '0x7020023'
pattern = r'retCode=([^,\s\.]+)'
match = re.search(pattern, error_msg)
if match:
retCode = match.group(1)
# Determine whether the error message is caused by stream capture failure.
if match and retCode == error_code:
logger.error(
f"ACLgraph sizes capture fail: {type(e).__name__}:\n"
"ACLgraph has insufficient available streams to capture the configured number of sizes. "
"Please verify both the availability of adequate streams and the appropriateness of the configured size count.\n\n"
"Recommended solutions:\n"
"1. Manually configure the compilation_config parameter "
"with a reduced set of sizes: '{\"cudagraph_capture_sizes\":[size1, size2, size3, ...]}'.\n"
"2. Utilize ACLgraph's full graph mode as an alternative to the piece-wise approach.\n\n"
f"{str(e)}")
raise
if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
aclgraph_mode.separate_routine():
max_num_tokens = self.scheduler_config.max_num_seqs * \
self.uniform_decode_query_len
decode_cudagraph_batch_sizes = [
x for x in self.cudagraph_batch_sizes if
x <= max_num_tokens and x >= self.uniform_decode_query_len
]
compilation_cases_decode = list(
reversed(decode_cudagraph_batch_sizes))
self._capture_aclgraphs(
compilation_cases=compilation_cases_decode,
aclgraph_runtime_mode=CUDAGraphMode.FULL,
uniform_decode=True)
# Disable aclgraph capturing globally, so any unexpected aclgraph
# capturing will be detected and raise an error after here.
# Note: We don't put it into graph_capture context manager because
# we may doing lazy capturing in future that still allows capturing
# after here.
set_cudagraph_capturing_enabled(False)
if self.use_aclgraph:
set_graph_params(self.cudagraph_batch_sizes)
if self.speculative_config:
set_draft_graph_params(self.cudagraph_batch_sizes)
def capture_model(self) -> None:
compilation_counter.num_gpu_runner_capture_triggers += 1
start_time = time.perf_counter()
start_free_npu_memory = torch.npu.mem_get_info()[0]
self._capture_model()
end_time = time.perf_counter()
end_free_npu_memory = torch.npu.mem_get_info()[0]
elapsed_time = end_time - start_time
npu_graph_size = start_free_npu_memory - end_free_npu_memory
# This usually takes 5~20 seconds.
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, npu_graph_size / (1 << 30))
parent_module_name = self.__class__.__base__.__module__
with _torch_cuda_wrapper(), _replace_gpu_model_runner_function_wrapper(
parent_module_name):
super().capture_model()
def _update_tokens_for_pcp(self, tokens):
num_reqs = self.input_batch.num_reqs
@@ -3473,6 +3299,8 @@ def _torch_cuda_wrapper():
torch.cuda.default_stream = torch.npu.default_stream
torch.cuda.current_stream = torch.npu.current_stream
torch.cuda.stream = torch.npu.stream
torch.cuda.synchronize = torch.npu.synchronize
torch.cuda.mem_get_info = torch.npu.mem_get_info
yield
except Exception:
torch.cuda.Event = _EventPlaceholder
@@ -3480,6 +3308,8 @@ def _torch_cuda_wrapper():
torch.cuda.default_stream = _StreamPlaceholder
torch.cuda.current_stream = _StreamPlaceholder
torch.cuda.stream = _StreamPlaceholder
torch.cuda.synchronize = _StreamPlaceholder
torch.cuda.mem_get_info = _StreamPlaceholder
finally:
# if anything goes wrong, just patch it with a placeholder
torch.cuda.Event = _EventPlaceholder
@@ -3487,3 +3317,16 @@ def _torch_cuda_wrapper():
torch.cuda.default_stream = torch.npu.default_stream
torch.cuda.current_stream = torch.npu.current_stream
torch.cuda.stream = torch.npu.stream
torch.cuda.synchronize = torch.npu.synchronize
torch.cuda.mem_get_info = torch.npu.mem_get_info
# TODO: This method will be removed subsequently and implemented in platform.
@contextmanager
def _replace_gpu_model_runner_function_wrapper(target_module_name):
try:
target_module = sys.modules[target_module_name]
setattr(target_module, "graph_capture", graph_capture)
yield
finally:
setattr(target_module, "graph_capture", graph_capture)

View File

@@ -46,6 +46,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, ModelRunnerOutput)
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
@@ -231,6 +232,9 @@ class NPUWorker(WorkerBase):
# in ray scenario. see https://github.com/vllm-project/vllm/pull/26845
# for more details
self.device = self._init_device()
# Initialize workspace manager
num_ubatches = 1
init_workspace_manager(self.device, num_ubatches)
# Init ModelRunner here, so that we have access to self.device.
if self.use_v2_model_runner:
logger.warning(