Files
xc-llm-ascend/tests/ut/distributed/test_parallel_state.py
Levi 0a62e671fb [Feat] flashcomm_v2 optim solution (#3232)
### What this PR does / why we need it?
Adds support for the generalized FlashComm2 optimization. FlashComm2 reduces
communication overhead, decreases RMSNorm computation, and saves one
AllGather step by replacing the AllReduce operations in the Attention module
with pre-AlltoAll and post-AllGather operations (when used in combination
with FlashComm1). The feature is enabled during the Prefill phase and
is recommended to be used together with FlashComm1; it delivers broad
performance improvements, especially in long-sequence scenarios with
large tensor parallelism (TP) configurations. Benchmark tests show that,
under a TP16DP1 configuration, it improves the prefill performance of
the DeepSeek model by 8% on top of FlashComm1.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: zzhxx <2783294813@qq.com>
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: zzhxx <2783294813@qq.com>
2025-11-10 11:01:45 +08:00

76 lines
3.3 KiB
Python

from unittest.mock import MagicMock, patch
import pytest
from vllm.config import ParallelConfig
from vllm_ascend.distributed.parallel_state import (
_FLASHCOMM2_ODP, _FLASHCOMM2_OTP, _LMTP, _MC2, _OTP, _P_TP,
destroy_ascend_model_parallel, get_flashcomm2_odp_group,
get_flashcomm2_otp_group, get_lmhead_tp_group, get_mc2_group,
get_otp_group, get_p_tp_group, init_ascend_model_parallel)
@pytest.fixture
def parallel_config():
    """Provide a ``ParallelConfig`` whose data, tensor and pipeline parallel
    sizes are all set to 2."""
    parallel_sizes = {
        "data_parallel_size": 2,
        "tensor_parallel_size": 2,
        "pipeline_parallel_size": 2,
    }
    return ParallelConfig(**parallel_sizes)
@pytest.fixture
def mock_distributed():
    """Patch ``torch.distributed`` queries and the group getters used by
    ``vllm_ascend.distributed.parallel_state`` for the duration of a test.

    While the fixture is active:
      * ``torch.distributed.is_initialized`` returns ``True``,
        ``get_world_size`` returns 8 and ``get_backend`` returns ``'nccl'``;
      * ``get_world_group()`` yields a mock with ``local_rank == 0`` and a
        ``MagicMock`` ``device_group``;
      * ``get_tp_group()`` / ``get_dp_group()`` yield mocks reporting world
        sizes of 4 and 2 respectively.
    """
    # Build the patchers first; nothing is replaced until they are entered.
    torch_is_initialized = patch('torch.distributed.is_initialized',
                                 return_value=True)
    torch_world_size = patch('torch.distributed.get_world_size',
                             return_value=8)
    torch_backend = patch('torch.distributed.get_backend',
                          return_value='nccl')
    world_group_patcher = patch(
        'vllm_ascend.distributed.parallel_state.get_world_group')
    tp_group_patcher = patch(
        'vllm_ascend.distributed.parallel_state.get_tp_group')
    dp_group_patcher = patch(
        'vllm_ascend.distributed.parallel_state.get_dp_group')

    # Enter in the same order as the original single ``with`` statement.
    with torch_is_initialized, torch_world_size, torch_backend:
        with world_group_patcher as world_getter:
            with tp_group_patcher as tp_getter:
                with dp_group_patcher as dp_getter:
                    world_group = world_getter.return_value
                    world_group.local_rank = 0
                    world_group.device_group = MagicMock()

                    tp_getter.return_value.world_size = 4
                    dp_getter.return_value.world_size = 2

                    yield
def test_init_ascend_model_parallel(mock_distributed, parallel_config):
    """Verify that the Ascend parallel groups are created and then torn down.

    The test feeds mocked Ascend/vLLM configuration objects into
    ``init_ascend_model_parallel``, checks that every group accessor returns
    a non-``None`` group, calls ``destroy_ascend_model_parallel`` and finally
    checks that the module-level group globals have been reset to ``None``.

    Args:
        mock_distributed: Fixture that patches ``torch.distributed`` and the
            world/TP/DP group getters (used only for its side effects).
        parallel_config: Fixture supplying the ``ParallelConfig`` passed to
            ``init_ascend_model_parallel``.
    """
    # Imported here so the teardown checks read the module's *current*
    # globals. Names bound with ``from ... import _MC2`` are snapshots taken
    # at import time; they never change afterwards, so asserting on them
    # cannot detect a teardown that failed to reset the globals.
    import vllm_ascend.distributed.parallel_state as parallel_state

    mock_ascend_config = MagicMock()
    mock_ascend_config.lmhead_tensor_parallel_size = 2
    mock_ascend_config.oproj_tensor_parallel_size = 2
    mock_ascend_config.flashcomm2_oproj_tensor_parallel_size = 2
    mock_ascend_config.pd_tp_ratio = 2
    mock_ascend_config.num_head_replica = 0
    mock_ascend_config.pd_head_ratio = 2

    mock_vllm_config = MagicMock()
    mock_vllm_config.kv_transfer_config.is_kv_producer = True

    mock_envs_ascend = MagicMock()
    mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
    mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0

    with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
         patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
         patch('vllm_ascend.distributed.parallel_state.get_current_vllm_config', return_value=mock_vllm_config), \
         patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config), \
         patch('vllm_ascend.utils.envs_ascend', new=mock_envs_ascend), \
         patch('vllm_ascend.utils.get_ascend_config', return_value=mock_ascend_config):
        init_ascend_model_parallel(parallel_config)

        # Fetch every group first (same order as before), then assert, so a
        # getter that raises is reported before any assertion message.
        mc2_group = get_mc2_group()
        lmheadtp_group = get_lmhead_tp_group()
        otp_group = get_otp_group()
        flashcomm2_otp_group = get_flashcomm2_otp_group()
        flashcomm2_odp_group = get_flashcomm2_odp_group()
        p_tp_group = get_p_tp_group()

        assert mc2_group is not None, "MC2 group was not initialised"
        assert otp_group is not None, "OTP group was not initialised"
        assert flashcomm2_otp_group is not None, \
            "FlashComm2 OTP group was not initialised"
        assert flashcomm2_odp_group is not None, \
            "FlashComm2 ODP group was not initialised"
        assert lmheadtp_group is not None, "LM-head TP group was not initialised"
        assert p_tp_group is not None, "P_TP group was not initialised"

        # Teardown runs while the config patches are still active, matching
        # the conditions under which the groups were created.
        destroy_ascend_model_parallel()

        # Look the globals up on the module object at assertion time.
        for global_name in ("_MC2", "_LMTP", "_OTP", "_FLASHCOMM2_OTP",
                            "_FLASHCOMM2_ODP", "_P_TP"):
            assert getattr(parallel_state, global_name) is None, (
                f"{global_name} was not reset by "
                "destroy_ascend_model_parallel()")