[Bugfix]Support Qwen3-MOE on aclgraph mode in sizes capture and add new ut (#2511)

[Bugfix]Support Qwen3-MOE on aclgraph mode in sizes capture and add new
ut

What this PR does / why we need it?
This PR solves the problem of sizes capture and stream error caused by
using ACLgraph on the Qwen3-30B MOE model.
Add new ut.

Does this PR introduce any user-facing change?
no

How was this patch tested?
ut

- vLLM version: v0.10.1.1
- vLLM main:
6fad29b11b

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
lilinsiman
2025-08-26 12:39:21 +08:00
committed by GitHub
parent b3fdd78a6b
commit cfe77e83ae
3 changed files with 80 additions and 7 deletions

View File

@@ -21,6 +21,8 @@
Run `pytest tests/e2e/multicard/test_qwen3_moe.py`. Run `pytest tests/e2e/multicard/test_qwen3_moe.py`.
""" """
import os
from modelscope import snapshot_download # type: ignore from modelscope import snapshot_download # type: ignore
from tests.e2e.conftest import VllmRunner from tests.e2e.conftest import VllmRunner
@@ -72,3 +74,36 @@ def test_models_distributed_Qwen3_MOE_W8A8():
enforce_eager=False, enforce_eager=False,
) as vllm_model: ) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens) vllm_model.generate_greedy(example_prompts, max_tokens)
def test_models_distributed_Qwen3_MOE_TP2_WITH_ACLGRAPH_AIV():
os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV'
example_prompts = [
"Hello, my name is",
]
dtype = "auto"
max_tokens = 5
with VllmRunner(
"Qwen/Qwen3-30B-A3B",
dtype=dtype,
tensor_parallel_size=2,
enforce_eager=False,
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
def test_models_distributed_Qwen3_MOE_TP2_WITH_ACLGRAPH():
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
del os.environ['HCCL_OP_EXPANSION_MODE']
example_prompts = [
"Hello, my name is",
]
dtype = "auto"
max_tokens = 5
with VllmRunner(
"Qwen/Qwen3-30B-A3B",
dtype=dtype,
tensor_parallel_size=2,
enforce_eager=False,
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@@ -255,6 +255,9 @@ class TestUtils(TestBase):
parallel_config=test_parallel_config, parallel_config=test_parallel_config,
) )
utils.update_aclgraph_sizes(test_vllm_config) utils.update_aclgraph_sizes(test_vllm_config)
os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV'
utils.update_aclgraph_sizes(test_vllm_config)
del os.environ['HCCL_OP_EXPANSION_MODE']
self.assertEqual( self.assertEqual(
147, 147,
len(test_vllm_config.compilation_config.cudagraph_capture_sizes)) len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
@@ -267,6 +270,9 @@ class TestUtils(TestBase):
parallel_config=test_parallel_config, parallel_config=test_parallel_config,
) )
utils.update_aclgraph_sizes(test_vllm_config) utils.update_aclgraph_sizes(test_vllm_config)
os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV'
utils.update_aclgraph_sizes(test_vllm_config)
del os.environ['HCCL_OP_EXPANSION_MODE']
self.assertEqual( self.assertEqual(
3, 3,
len(test_vllm_config.compilation_config.cudagraph_capture_sizes)) len(test_vllm_config.compilation_config.cudagraph_capture_sizes))

View File

@@ -20,6 +20,7 @@
import atexit import atexit
import functools import functools
import math import math
import os
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum from enum import Enum
from threading import Lock from threading import Lock
@@ -303,16 +304,47 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
# TODO: Find out whether we need to take into account the pp_size # TODO: Find out whether we need to take into account the pp_size
parallel_factor = 1 + sum(size > 1 for size in [ num_comm_groups = sum(size > 1 for size in [
parallel_config.data_parallel_size_local, parallel_config.data_parallel_size,
parallel_config.tensor_parallel_size, parallel_config.tensor_parallel_size,
]) ])
# Calculate maximum supported batch sizes considering model architecture if os.getenv("HCCL_OP_EXPANSION_MODE") == 'AIV':
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE / # TODO: Find out whether we need to take into account the pp_size
(num_hidden_layers + 1) / parallel_factor) parallel_factor = 1 + num_comm_groups + int(
logger.info("Calculated maximum supported batch sizes for ACL graph: %s", parallel_config.enable_expert_parallel)
max_num_batch_sizes) # Calculate maximum supported batch sizes considering model architecture on the A2 Hardware Device
# Assume the following case:
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
# According to the formula, max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19
max_num_batch_sizes = math.floor(
MAX_CAPTURE_SIZE / (num_hidden_layers + 1) / parallel_factor)
logger.info(
"Calculated maximum supported batch sizes for ACL graph: %s",
max_num_batch_sizes)
else:
# The above describes an empirical formula applicable to the A2 hardware.
# Under this configuration, HCCL employs the FFTS+ method for execution unfolding,
# which adds only 1 concurrent stream without consuming collective communication execution unfolding streams.
# On A3 hardware, HCCL defaults to the AICPU method.
# This approach may additionally allocate up to rank_size (max 16) - 1 streams per collective communication domain on the device (worst case).
# Using the default collective communication unfolding method on A3 will lead to a significant reduction in the maximum supported sizes.
# Therefore, the calculation formula has been modified as follows:
# Assume the following case:
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
# According to the formula, max_num_batch_sizes = math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12
max_num_batch_sizes = math.floor(
(MAX_CAPTURE_SIZE - num_comm_groups * 40) /
(num_hidden_layers + 1) / (1 + num_comm_groups * 2))
logger.info(
"Calculated maximum supported batch sizes for ACL graph: %s",
max_num_batch_sizes)
logger.warning(
"Currently, communication is performed using FFTS+ method, which reduces "
"the number of available streams and, as a result, limits the range of runtime "
"shapes that can be handled. To both improve communication performance and "
"increase the number of supported shapes, set HCCL_OP_EXPANSION_MODE=AIV."
)
# If original sizes exceed maximum, sample a representative subset # If original sizes exceed maximum, sample a representative subset
if max_num_batch_sizes < len(original_sizes): if max_num_batch_sizes < len(original_sizes):