From cfe77e83aeda343274c0488b93e2263bee44a860 Mon Sep 17 00:00:00 2001 From: lilinsiman Date: Tue, 26 Aug 2025 12:39:21 +0800 Subject: [PATCH] [Bugfix] Support Qwen3-MOE on aclgraph mode in sizes capture and add new ut (#2511) [Bugfix] Support Qwen3-MOE on aclgraph mode in sizes capture and add new ut What this PR does / why we need it? This PR fixes the sizes-capture and stream errors that occur when using ACLGraph with the Qwen3-30B MoE model, and adds a new unit test. Does this PR introduce any user-facing change? no How was this patch tested? ut - vLLM version: v0.10.1.1 - vLLM main: https://github.com/vllm-project/vllm/commit/6fad29b11b3680c44782cd6e5fe555779d620d6c Signed-off-by: lilinsiman --- tests/e2e/multicard/test_qwen3_moe.py | 35 ++++++++++++++++++++ tests/ut/test_utils.py | 6 ++++ vllm_ascend/utils.py | 46 +++++++++++++++++++++++---- 3 files changed, 80 insertions(+), 7 deletions(-) diff --git a/tests/e2e/multicard/test_qwen3_moe.py b/tests/e2e/multicard/test_qwen3_moe.py index dcac7a8..5dfe36a 100644 --- a/tests/e2e/multicard/test_qwen3_moe.py +++ b/tests/e2e/multicard/test_qwen3_moe.py @@ -21,6 +21,8 @@ Run `pytest tests/e2e/multicard/test_qwen3_moe.py`. 
""" +import os + from modelscope import snapshot_download # type: ignore from tests.e2e.conftest import VllmRunner @@ -72,3 +74,36 @@ def test_models_distributed_Qwen3_MOE_W8A8(): enforce_eager=False, ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) + + +def test_models_distributed_Qwen3_MOE_TP2_WITH_ACLGRAPH_AIV(): + os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV' + example_prompts = [ + "Hello, my name is", + ] + dtype = "auto" + max_tokens = 5 + with VllmRunner( + "Qwen/Qwen3-30B-A3B", + dtype=dtype, + tensor_parallel_size=2, + enforce_eager=False, + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) + + +def test_models_distributed_Qwen3_MOE_TP2_WITH_ACLGRAPH(): + if 'HCCL_OP_EXPANSION_MODE' in os.environ: + del os.environ['HCCL_OP_EXPANSION_MODE'] + example_prompts = [ + "Hello, my name is", + ] + dtype = "auto" + max_tokens = 5 + with VllmRunner( + "Qwen/Qwen3-30B-A3B", + dtype=dtype, + tensor_parallel_size=2, + enforce_eager=False, + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) \ No newline at end of file diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py index 73eca32..8166e17 100644 --- a/tests/ut/test_utils.py +++ b/tests/ut/test_utils.py @@ -255,6 +255,9 @@ class TestUtils(TestBase): parallel_config=test_parallel_config, ) utils.update_aclgraph_sizes(test_vllm_config) + os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV' + utils.update_aclgraph_sizes(test_vllm_config) + del os.environ['HCCL_OP_EXPANSION_MODE'] self.assertEqual( 147, len(test_vllm_config.compilation_config.cudagraph_capture_sizes)) @@ -267,6 +270,9 @@ class TestUtils(TestBase): parallel_config=test_parallel_config, ) utils.update_aclgraph_sizes(test_vllm_config) + os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV' + utils.update_aclgraph_sizes(test_vllm_config) + del os.environ['HCCL_OP_EXPANSION_MODE'] self.assertEqual( 3, len(test_vllm_config.compilation_config.cudagraph_capture_sizes)) diff --git 
a/vllm_ascend/utils.py b/vllm_ascend/utils.py index cd1e118..3d7ed29 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -20,6 +20,7 @@ import atexit import functools import math +import os from contextlib import contextmanager from enum import Enum from threading import Lock @@ -303,16 +304,47 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config # TODO: Find out whether we need to take into account the pp_size - parallel_factor = 1 + sum(size > 1 for size in [ - parallel_config.data_parallel_size_local, + num_comm_groups = sum(size > 1 for size in [ + parallel_config.data_parallel_size, parallel_config.tensor_parallel_size, ]) - # Calculate maximum supported batch sizes considering model architecture - max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE / - (num_hidden_layers + 1) / parallel_factor) - logger.info("Calculated maximum supported batch sizes for ACL graph: %s", - max_num_batch_sizes) + if os.getenv("HCCL_OP_EXPANSION_MODE") == 'AIV': + # TODO: Find out whether we need to take into account the pp_size + parallel_factor = 1 + num_comm_groups + int( + parallel_config.enable_expert_parallel) + # Calculate maximum supported batch sizes considering model architecture on the A2 Hardware Device + # Assume the following case: + # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4, + # According to the formula, max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19 + max_num_batch_sizes = math.floor( + MAX_CAPTURE_SIZE / (num_hidden_layers + 1) / parallel_factor) + logger.info( + "Calculated maximum supported batch sizes for ACL graph: %s", + max_num_batch_sizes) + else: + # The above describes an empirical formula applicable to the A2 hardware. + # Under this configuration, HCCL employs the FFTS+ method for execution unfolding, + # which adds only 1 concurrent stream without consuming collective communication execution unfolding streams. 
+ # On A3 hardware, HCCL defaults to the AICPU method. + # This approach may additionally allocate up to rank_size (max 16) - 1 streams per collective communication domain on the device (worst case). + # Using the default collective communication unfolding method on A3 will lead to a significant reduction in the maximum supported sizes. + # Therefore, the calculation formula has been modified as follows: + # Assume the following case: + # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4, + # According to the formula, max_num_batch_sizes = math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12 + max_num_batch_sizes = math.floor( + (MAX_CAPTURE_SIZE - num_comm_groups * 40) / + (num_hidden_layers + 1) / (1 + num_comm_groups * 2)) + logger.info( + "Calculated maximum supported batch sizes for ACL graph: %s", + max_num_batch_sizes) + logger.warning( + "Currently, communication is performed using FFTS+ method, which reduces " + "the number of available streams and, as a result, limits the range of runtime " + "shapes that can be handled. To both improve communication performance and " + "increase the number of supported shapes, set HCCL_OP_EXPANSION_MODE=AIV." + ) # If original sizes exceed maximum, sample a representative subset if max_num_batch_sizes < len(original_sizes):