2025-08-29 11:41:21 +08:00
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
from vllm.config import ParallelConfig
|
|
|
|
|
|
|
|
|
|
from vllm_ascend.distributed.parallel_state import (
|
2025-11-10 11:01:45 +08:00
|
|
|
_FLASHCOMM2_ODP, _FLASHCOMM2_OTP, _LMTP, _MC2, _OTP, _P_TP,
|
|
|
|
|
destroy_ascend_model_parallel, get_flashcomm2_odp_group,
|
|
|
|
|
get_flashcomm2_otp_group, get_lmhead_tp_group, get_mc2_group,
|
|
|
|
|
get_otp_group, get_p_tp_group, init_ascend_model_parallel)
|
2025-08-29 11:41:21 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def parallel_config():
    """Provide a ``ParallelConfig`` with DP=2, TP=4 and PP=2.

    The product of the three sizes is 16, the same value that the
    ``mock_distributed`` fixture patches in for
    ``torch.distributed.get_world_size``.
    """
    parallel_sizes = {
        "data_parallel_size": 2,
        "tensor_parallel_size": 4,
        "pipeline_parallel_size": 2,
    }
    return ParallelConfig(**parallel_sizes)
|
2025-08-29 11:41:21 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_distributed():
    """Patch the distributed runtime so no real process group is required.

    While the fixture is active:

    * ``torch.distributed.is_initialized`` returns ``True``,
      ``get_world_size`` returns ``16`` and ``get_backend`` returns
      ``'nccl'``;
    * ``get_world_group`` and ``get_tp_group``, as looked up from
      ``vllm_ascend.distributed.parallel_state``, are mocks whose return
      values expose ``local_rank == 0`` / a mock ``device_group`` and
      ``world_size == 4`` respectively.

    Yields nothing; all patches are reverted when the test finishes.
    """
    # Patchers are only constructed here; nothing is replaced until the
    # ``with`` statements below enter them, in the order listed.
    dist_initialized = patch('torch.distributed.is_initialized',
                             return_value=True)
    dist_world_size = patch('torch.distributed.get_world_size',
                            return_value=16)
    dist_backend = patch('torch.distributed.get_backend',
                         return_value='nccl')
    world_group_patcher = patch(
        'vllm_ascend.distributed.parallel_state.get_world_group')
    tp_group_patcher = patch(
        'vllm_ascend.distributed.parallel_state.get_tp_group')

    with dist_initialized, dist_world_size, dist_backend:
        with world_group_patcher as world_group_mock, \
             tp_group_patcher as tp_group_mock:
            world_group = world_group_mock.return_value
            world_group.local_rank = 0
            world_group.device_group = MagicMock()
            tp_group_mock.return_value.world_size = 4
            yield
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_init_ascend_model_parallel(mock_distributed, parallel_config):
    """Check that the Ascend parallel groups are created and then torn down.

    With the configuration lookups and ``init_model_parallel_group`` patched,
    the test calls ``init_ascend_model_parallel`` and asserts that every group
    getter returns a non-``None`` object. It then calls
    ``destroy_ascend_model_parallel`` and asserts that the corresponding
    module-level globals in ``parallel_state`` are ``None`` again.

    Args:
        mock_distributed: Fixture that patches ``torch.distributed`` and the
            world / TP group lookups.
        parallel_config: Fixture supplying the ``ParallelConfig`` under test.
    """
    # Imported as a module object so the teardown assertions below read the
    # *current* value of each global instead of an import-time snapshot.
    import vllm_ascend.distributed.parallel_state as parallel_state

    mock_ascend_config = MagicMock()
    mock_ascend_config.finegrained_tp_config.lmhead_tensor_parallel_size = 2
    mock_ascend_config.finegrained_tp_config.oproj_tensor_parallel_size = 2
    mock_ascend_config.finegrained_tp_config.embedding_tensor_parallel_size = 2
    mock_ascend_config.finegrained_tp_config.mlp_tensor_parallel_size = 2
    mock_ascend_config.flashcomm2_oproj_tensor_parallel_size = 2
    mock_ascend_config.pd_tp_ratio = 2
    mock_ascend_config.num_head_replica = 0
    mock_ascend_config.pd_head_ratio = 2

    mock_vllm_config = MagicMock()
    mock_vllm_config.kv_transfer_config.is_kv_producer = True

    mock_envs_ascend = MagicMock()
    mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
    mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0

    with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
         patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
         patch('vllm_ascend.distributed.parallel_state.get_current_vllm_config', return_value=mock_vllm_config), \
         patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config), \
         patch('vllm_ascend.utils.envs_ascend', new=mock_envs_ascend), \
         patch('vllm_ascend.utils.get_ascend_config', return_value=mock_ascend_config):
        init_ascend_model_parallel(parallel_config)

        mc2_group = get_mc2_group()
        lmheadtp_group = get_lmhead_tp_group()
        otp_group = get_otp_group()
        flashcomm2_otp_group = get_flashcomm2_otp_group()
        flashcomm2_odp_group = get_flashcomm2_odp_group()
        p_tp_group = get_p_tp_group()

        assert mc2_group is not None
        assert otp_group is not None
        assert flashcomm2_otp_group is not None
        assert flashcomm2_odp_group is not None
        assert lmheadtp_group is not None
        assert p_tp_group is not None

        destroy_ascend_model_parallel()

        # Names bound with ``from parallel_state import _MC2, ...`` keep the
        # value they had when this test module was imported, so asserting on
        # them would pass even if teardown did nothing. Look the globals up on
        # the module itself to verify they were actually reset.
        reset_globals = (
            "_MC2",
            "_LMTP",
            "_OTP",
            "_FLASHCOMM2_OTP",
            "_FLASHCOMM2_ODP",
            "_P_TP",
        )
        for global_name in reset_globals:
            assert getattr(parallel_state, global_name) is None, (
                f"{global_name} was not reset by "
                "destroy_ascend_model_parallel()")
|