From 6c973361fc2eba5d3faa9b6b496b4b9fec4dc784 Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Thu, 28 Aug 2025 14:08:31 +0800 Subject: [PATCH] [Bugfix] Fix aclgraph not enabled by default (#2590) ### What this PR does / why we need it? Since vLLM sets `cudagraph_mode` to `NONE` before calling `check_and_update_config` in the post-init of `VllmConfig` (https://github.com/vllm-project/vllm/blob/5da4f5d857933329aaca779e3a81f1385c84e34a/vllm/config/__init__.py#L3630), `cudagraph_mode` is never `None` when the platform hook runs, so the `is None` check never takes effect. We therefore remove this check for now and will add it back once the related adaptation in vLLM is done. This is part of https://github.com/vllm-project/vllm-ascend/pull/2577. The e2e test for applying replay will be added after the CI refactor is done. ### How was this patch tested? CI passed with the existing tests. - vLLM version: v0.10.1.1 - vLLM main: https://github.com/vllm-project/vllm/commit/f48a9af8924ea617a964b1158acc142b64843edb Signed-off-by: MengqingCao --- tests/ut/test_platform.py | 3 +++ vllm_ascend/compilation/acl_graph.py | 5 ++-- vllm_ascend/platform.py | 34 ++++++++++++++-------------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index 551f1d0..bd07602 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -3,6 +3,7 @@ import unittest from datetime import timedelta from unittest.mock import MagicMock, patch +import pytest import torch from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import PrefixStore @@ -318,6 +319,8 @@ class TestNPUPlatform(TestBase): CUDAGraphMode.NONE, ) + @pytest.mark.skip( + "Revert me when vllm support setting cudagraph_mode on oot platform") @patch("vllm_ascend.utils.is_310p", return_value=False) @patch("vllm_ascend.ascend_config.check_ascend_config") @patch("vllm_ascend.ascend_config.init_ascend_config") diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 6f187e2..f8dfc24 100644 --- 
a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -13,12 +13,10 @@ from vllm.compilation.cuda_graph import CUDAGraphOptions from vllm.compilation.monitor import validate_cudagraph_capturing_enabled from vllm.config import CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor, get_forward_context -from vllm.logger import init_logger +from vllm.logger import logger from vllm.platforms import current_platform from vllm.utils import weak_ref_tensors -logger = init_logger(__name__) - @dataclasses.dataclass class ACLGraphEntry: @@ -182,5 +180,6 @@ class ACLGraphWrapper: f"during replay. Expected {entry.input_addresses}, " f"got {new_input_addresses}") + logger.info_once("Replaying aclgraph") entry.aclgraph.replay() return entry.output diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index b8e1039..afc0e6b 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -146,23 +146,23 @@ class NPUPlatform(Platform): compilation_config.cudagraph_num_of_warmups = 1 - if compilation_config.cudagraph_mode is None: - # if cudagraph_mode is not explicitly set by users, set default value - if compilation_config.level == CompilationLevel.PIECEWISE: - compilation_config.cudagraph_mode = \ - CUDAGraphMode.PIECEWISE - elif compilation_config.level not in [ - CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE - ]: - logger.warning( - "NPU does not support %s compilation level. 
Setting CUDAGraphMode to NONE", - compilation_config.level) - compilation_config.cudagraph_mode = CUDAGraphMode.NONE - else: - logger.warning( - "compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE" - ) - compilation_config.cudagraph_mode = CUDAGraphMode.NONE + # TODO: make vllm support oot platform to set `compilation_config.cudagraph_mode` + # if cudagraph_mode is not explicitly set by users, set default value + if compilation_config.level == CompilationLevel.PIECEWISE: + compilation_config.cudagraph_mode = \ + CUDAGraphMode.PIECEWISE + elif compilation_config.level not in [ + CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE + ]: + logger.warning( + "NPU does not support %s compilation level. Setting CUDAGraphMode to NONE", + compilation_config.level) + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + else: + logger.warning( + "compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE" + ) + compilation_config.cudagraph_mode = CUDAGraphMode.NONE # set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is. if ascend_config.torchair_graph_config.enabled: