[Aclgraph] Update compilation config in check_and_update_config (#2540)

### What this PR does / why we need it?
This PR updates the compilation config in `check_and_update_config`: `compilation_config.cudagraph_mode` is now derived from `compilation_config.level`, so the two settings stay consistent.
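
For context, the defaulting logic added to `check_and_update_config` is roughly the following (a minimal sketch; the standalone helper name is illustrative, the real code operates on `vllm_config.compilation_config` in place):

```python
from vllm.config import CompilationLevel
from vllm.config.compilation import CUDAGraphMode


def default_cudagraph_mode(compilation_config):
    # Only fill in a default when the user did not set cudagraph_mode explicitly.
    if compilation_config.cudagraph_mode is None:
        if compilation_config.level == CompilationLevel.PIECEWISE:
            # PIECEWISE compilation -> capture ACL graphs piecewise.
            compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
        else:
            # Every other level falls back to eager execution without graph capture.
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
```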

It also sets `compilation_config.cudagraph_num_of_warmups = 1` when V1 is enabled, since the warmup count is also used in torchair graph mode. This fixes
https://github.com/vllm-project/vllm-ascend/issues/2523

It also fixes the bug where the aclgraph runtime mode was always `NONE` when running the forward pass in aclgraph mode.
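
On the runner side, the fix is to ask the aclgraph dispatcher for the runtime mode of each batch instead of leaving it at its default. A sketch following the names used in the diff (the import path and the standalone helper are assumptions):

```python
from vllm.forward_context import BatchDescriptor  # assumed import path


def resolve_aclgraph_mode(aclgraph_dispatcher, num_input_tokens: int):
    # Describe the batch that is about to run forward.
    batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                       uniform_decode=False)
    # The dispatcher returns the runtime mode (e.g. PIECEWISE or NONE) plus the
    # descriptor of the matching captured graph; both are passed into the
    # forward context so captured ACL graphs are actually replayed.
    return aclgraph_dispatcher.dispatch(batch_descriptor)
```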

### How was this patch tested?
CI passed with newly added and existing tests.


- vLLM version: v0.10.1.1
- vLLM main: f58675bfb3

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Commit: a9e78a3299 (parent f22077daa6)
Author: Mengqing Cao
Date: 2025-08-27 09:30:25 +08:00
3 changed files with 118 additions and 34 deletions


@@ -3,11 +3,11 @@ import unittest
from datetime import timedelta
from unittest.mock import MagicMock, patch
import pytest
import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import PrefixStore
from vllm.config import CompilationLevel
from vllm.config.compilation import CUDAGraphMode
from vllm.platforms import PlatformEnum
from tests.ut.base import TestBase
@@ -28,6 +28,7 @@ class TestNPUPlatform(TestBase):
self.mock_vllm_config.scheduler_config = MagicMock()
self.mock_vllm_config.speculative_config = None
self.mock_vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False
self.mock_vllm_config.compilation_config.cudagraph_mode = None
self.mock_ascend_config = MagicMock()
self.mock_ascend_config.torchair_graph_config.enabled = False
@@ -269,8 +270,6 @@ class TestNPUPlatform(TestBase):
self.platform.check_and_update_config(self.mock_vllm_config)
self.assertTrue("Model config is missing" in cm.output[0])
@pytest.mark.skip(
reason="TODO: revert me when the occasional failed is fixed")
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
@@ -290,6 +289,10 @@ class TestNPUPlatform(TestBase):
self.mock_vllm_config.compilation_config.level,
CompilationLevel.NO_COMPILATION,
)
self.assertEqual(
self.mock_vllm_config.compilation_config.cudagraph_mode,
CUDAGraphMode.NONE,
)
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@@ -310,6 +313,64 @@ class TestNPUPlatform(TestBase):
self.mock_vllm_config.compilation_config.level,
CompilationLevel.NO_COMPILATION,
)
self.assertEqual(
self.mock_vllm_config.compilation_config.cudagraph_mode,
CUDAGraphMode.NONE,
)
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_unsupported_cudagraph_mode(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
mock_init_ascend.return_value = self.mock_ascend_config
self.mock_vllm_config.model_config.enforce_eager = False
self.mock_vllm_config.compilation_config.cudagraph_mode = CUDAGraphMode.FULL
with self.assertLogs(logger="vllm", level="INFO") as cm:
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(self.mock_vllm_config)
self.assertTrue(
"cudagraph_mode is not support on NPU. falling back to NONE" in
cm.output[0])
self.assertEqual(
self.mock_vllm_config.compilation_config.level,
CompilationLevel.NO_COMPILATION,
)
self.assertEqual(
self.mock_vllm_config.compilation_config.cudagraph_mode,
CUDAGraphMode.NONE,
)
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_disable_aclgraph_when_ray_enabled(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
mock_init_ascend.return_value = self.mock_ascend_config
self.mock_vllm_config.model_config.enforce_eager = False
self.mock_vllm_config.compilation_config.level = CompilationLevel.PIECEWISE
self.mock_vllm_config.parallel_config.distributed_executor_backend = "ray"
with self.assertLogs(logger="vllm", level="WARNING") as cm:
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(self.mock_vllm_config)
print(30 * "=", f"cm.output: {cm.output}")
self.assertTrue(
"Ray distributed executor backend is not compatible with ACL Graph mode"
in cm.output[0])
self.assertEqual(
self.mock_vllm_config.compilation_config.level,
CompilationLevel.NO_COMPILATION,
)
self.assertEqual(
self.mock_vllm_config.compilation_config.cudagraph_mode,
CUDAGraphMode.NONE,
)
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@@ -331,6 +392,10 @@ class TestNPUPlatform(TestBase):
self.mock_vllm_config.compilation_config.level,
CompilationLevel.NO_COMPILATION,
)
self.assertEqual(
self.mock_vllm_config.compilation_config.cudagraph_mode,
CUDAGraphMode.NONE,
)
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")


@@ -140,52 +140,65 @@ class NPUPlatform(Platform):
check_ascend_config(vllm_config, enforce_eager)
from vllm.config.compilation import CUDAGraphMode
# TODO(cmq): update the post init in vllmconfig
# if cudagraph_mode is not explicitly set by users, set default value
if envs_vllm.VLLM_USE_V1 and compilation_config.level \
== CompilationLevel.PIECEWISE:
compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else:
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
vllm_config._set_cudagraph_sizes()
# TODO(cmq): update the compilation level config to be determined by CUDAGraphMode
if enforce_eager or compilation_config.level == CompilationLevel.NO_COMPILATION:
if enforce_eager:
logger.info("Compilation disabled, using eager mode by default")
compilation_config.level = CompilationLevel.NO_COMPILATION
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
elif compilation_config.level != CompilationLevel.PIECEWISE:
logger.warning(
"NPU does not support %s compilation level. Setting level to NO_COMPILATION",
compilation_config.level)
compilation_config.level = CompilationLevel.NO_COMPILATION
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
elif ascend_config.torchair_graph_config.enabled:
compilation_config.cudagraph_num_of_warmups = 1
if compilation_config.cudagraph_mode is None:
# if cudagraph_mode is not explicitly set by users, set default value
if compilation_config.level == CompilationLevel.PIECEWISE:
compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
elif compilation_config.level not in [
CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
]:
logger.warning(
"NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
compilation_config.level)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
else:
logger.warning(
"compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE"
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# set CUDAGraphMode to NONE when torchair is enabled, no matter what compilation_config.level is.
if ascend_config.torchair_graph_config.enabled:
logger.info(
"Torchair compilation enabled on NPU. Setting level to NO_COMPILATION"
"Torchair compilation enabled on NPU. Setting CUDAGraphMode to NONE"
)
compilation_config.level = CompilationLevel.NO_COMPILATION
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
elif parallel_config.distributed_executor_backend == "ray":
if parallel_config.distributed_executor_backend == "ray":
logger.warning(
"Ray distributed executor backend is not compatible with ACL Graph mode "
"right now. Setting level to NO_COMPILATION")
compilation_config.level = CompilationLevel.NO_COMPILATION
"right now. Setting CUDAGraphMode to NONE")
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
else:
# set cudagraph sizes before extending `compilation_config.splitting_ops`
vllm_config._set_cudagraph_sizes()
if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
compilation_config.level = CompilationLevel.NO_COMPILATION
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.info(
"PIECEWISE compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")
if envs_vllm.VLLM_USE_V1 and \
compilation_config.level == CompilationLevel.PIECEWISE:
compilation_config.set_splitting_ops_for_v1()
assert compilation_config.level == CompilationLevel.PIECEWISE, \
"When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE"
compilation_config.set_splitting_ops_for_v1()
compilation_config.use_inductor = False
compilation_config.splitting_ops.extend(
["vllm.unified_ascend_attention_with_output"])
update_aclgraph_sizes(vllm_config)
compilation_config.cudagraph_num_of_warmups = 1
else:
logger.info(
"%s cudagraph_mode is not support on NPU. falling back to NONE",
compilation_config.cudagraph_mode)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
compilation_config.level = CompilationLevel.NO_COMPILATION
if parallel_config and parallel_config.worker_cls == "auto":
if ascend_config.torchair_graph_config.enabled:


@@ -1660,6 +1660,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
moe_comm_method = (self.moe_comm_method
if num_input_tokens <= self.mc2_tokens_capacity else
self.fallback_moe_comm_method)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=False)
aclgraph_runtime_mode, batch_descriptor = \
self.aclgraph_dispatcher.dispatch(batch_descriptor)
# Run forward pass
with ProfileExecuteDuration().capture_async("forward"):
@@ -1671,6 +1675,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
with_prefill=self.with_prefill,
reserved_mc2_mask=self.reserved_mc2_mask,
moe_comm_method=moe_comm_method,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor,
num_actual_tokens=scheduler_output.
total_num_scheduled_tokens):
self.maybe_setup_kv_connector(scheduler_output)