xc-llm-ascend/tests/ut/distributed/test_parallel_state.py
wangyanhui-cmss e5eea64b66 [CI/UT] Add ut for parallel_state.py (#1460)
### What this PR does / why we need it?
Add unit tests for parallel_state.py.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
python -m unittest test_parallel_state.py
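
These tests exercise vllm_ascend.distributed.parallel_state by patching its module-level group handles (_EP, _ETP) instead of creating real process groups. As a reference for that pattern, here is a minimal, self-contained sketch; the fake_state namespace and its _GROUP attribute are hypothetical stand-ins, not vllm-ascend APIs:

import types
import unittest
from unittest.mock import MagicMock, patch

# Hypothetical stand-in for a module that keeps a global group handle,
# analogous to _EP / _ETP in vllm_ascend.distributed.parallel_state.
fake_state = types.SimpleNamespace(_GROUP=None)


class TestPatchPattern(unittest.TestCase):

    @patch.object(fake_state, '_GROUP', new_callable=MagicMock)
    def test_handle_is_mocked(self, mock_group):
        # Inside the test, the global is replaced by the MagicMock.
        assert fake_state._GROUP is mock_group

    def test_handle_restored_afterwards(self):
        # Once the patch exits, the original value (None) is restored.
        assert fake_state._GROUP is None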

---------

Signed-off-by: wangyanhui-cmss <wangyanhui_yewu@cmss.chinamobile.com>
2025-06-26 19:03:27 +08:00


import unittest
from unittest.mock import MagicMock, patch

import pytest
from vllm.distributed.parallel_state import GroupCoordinator

import vllm_ascend
from vllm_ascend.distributed.parallel_state import (
    destory_ascend_model_parallel, get_ep_group, get_etp_group,
    init_ascend_model_parallel, model_parallel_initialized)


class TestParallelState(unittest.TestCase):

    @patch('vllm_ascend.distributed.parallel_state._EP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_get_ep_group_when_initialized(self, mock_ep):
        # Act
        result = get_ep_group()
        # Assert
        assert isinstance(result, GroupCoordinator)

    @patch('vllm_ascend.distributed.parallel_state._EP', None)
    def test_get_ep_group_when_not_initialized(self):
        # Act & Assert
        with pytest.raises(AssertionError) as excinfo:
            get_ep_group()
        assert "expert model parallel group is not initialized" in str(
            excinfo.value)

    @patch('vllm_ascend.distributed.parallel_state._ETP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_get_etp_group_when_initialized(self, mock_etp):
        # Act
        result = get_etp_group()
        # Assert
        assert isinstance(result, GroupCoordinator)

    @patch('vllm_ascend.distributed.parallel_state._ETP', None)
    def test_get_etp_group_when_not_initialized(self):
        # Act & Assert
        with pytest.raises(AssertionError) as excinfo:
            get_etp_group()
        assert "expert tensor parallel group is not initialized" in str(
            excinfo.value)

    @patch('vllm_ascend.distributed.parallel_state._ETP', None)
    @patch('vllm_ascend.distributed.parallel_state._EP', None)
    def test_model_parallel_initialized_when_both_none(self):
        # Act & Assert
        assert not model_parallel_initialized()

    @patch('vllm_ascend.distributed.parallel_state._ETP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm_ascend.distributed.parallel_state._EP', None)
    def test_model_parallel_initialized_when_ep_none(self, mock_etp):
        # Act & Assert
        assert not model_parallel_initialized()

    @patch('vllm_ascend.distributed.parallel_state._ETP', None)
    @patch('vllm_ascend.distributed.parallel_state._EP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_model_parallel_initialized_when_etp_none(self, mock_ep):
        # Act & Assert
        assert not model_parallel_initialized()

    @patch('vllm_ascend.distributed.parallel_state._ETP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm_ascend.distributed.parallel_state._EP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_model_parallel_initialized_when_etp_initialized(
            self, mock_ep, mock_etp):
        # Act & Assert
        assert model_parallel_initialized()

    @patch('vllm_ascend.distributed.parallel_state._ETP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm_ascend.distributed.parallel_state._EP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_destroy_when_both_exist(self, mock_ep, mock_etp):
        # Act
        destory_ascend_model_parallel()
        # Assert
        mock_ep.destroy.assert_called_once()
        mock_etp.destroy.assert_called_once()
        assert vllm_ascend.distributed.parallel_state._ETP is None
        assert vllm_ascend.distributed.parallel_state._EP is None

    @patch('vllm_ascend.distributed.parallel_state._ETP', None)
    @patch('vllm_ascend.distributed.parallel_state._EP',
           new_callable=lambda: MagicMock())
    def test_destory_ascend_model_parallel_when_etp_none(self, mock_ep):
        # Act
        destory_ascend_model_parallel()
        # Assert
        mock_ep.destroy.assert_called_once()
        assert vllm_ascend.distributed.parallel_state._EP is None
        assert vllm_ascend.distributed.parallel_state._ETP is None

    @patch('vllm_ascend.distributed.parallel_state._ETP',
           new_callable=lambda: MagicMock())
    @patch('vllm_ascend.distributed.parallel_state._EP', None)
    def test_destory_ascend_model_parallel_when_ep_none(self, mock_etp):
        # Act
        destory_ascend_model_parallel()
        # Assert
        mock_etp.destroy.assert_called_once()
        assert vllm_ascend.distributed.parallel_state._ETP is None
        assert vllm_ascend.distributed.parallel_state._EP is None

    @patch('vllm_ascend.distributed.parallel_state._ETP', None)
    @patch('vllm_ascend.distributed.parallel_state._EP', None)
    def test_destory_ascend_model_parallel_when_both_none(self):
        # Act
        destory_ascend_model_parallel()
        # Assert
        assert vllm_ascend.distributed.parallel_state._ETP is None
        assert vllm_ascend.distributed.parallel_state._EP is None

    @patch('torch.distributed.is_initialized', return_value=True)
    @patch('torch.distributed.get_world_size', return_value=8)
    @patch('vllm_ascend.distributed.parallel_state.get_world_group',
           return_value=MagicMock(device_group='npu:0', local_rank=0))
    @patch('torch.distributed.get_backend', return_value='hccl')
    @patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group')
    @patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized',
           return_value=False)
    def test_init_ascend_model_parallel_normal_case(
            self, mock_mp_init, mock_init_group, mock_get_backend,
            mock_world_group, mock_get_world_size, mock_is_init):
        """Test normal initialization with default parameters"""
        # Act
        init_ascend_model_parallel()
        # Assert
        mock_init_group.assert_any_call([[0, 1, 2, 3, 4, 5, 6, 7]],
                                        0,
                                        'hccl',
                                        group_name="ep")
        mock_init_group.assert_any_call([[0]], 0, 'hccl', group_name="etp")
        self.assertIsNotNone(vllm_ascend.distributed.parallel_state._EP)
        self.assertIsNotNone(vllm_ascend.distributed.parallel_state._ETP)

    @patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized',
           return_value=True)
    def test_init_ascend_model_parallel_skip_if_initialized(
            self, mock_mp_init):
        """Test skipping when model parallel already initialized"""
        with patch.object(vllm_ascend.distributed.parallel_state,
                          '_EP') as mock_ep, patch.object(
                              vllm_ascend.distributed.parallel_state,
                              '_ETP') as mock_etp:
            # Act
            init_ascend_model_parallel()
            # Assert
            mock_ep.assert_not_called()
            mock_etp.assert_not_called()

    @patch('torch.distributed.is_initialized', return_value=False)
    def test_init_ascend_model_parallel_assert_dist_not_init(
            self, mock_is_init):
        """Test assertion when distributed not initialized"""
        # Act & Assert
        with self.assertRaises(AssertionError):
            init_ascend_model_parallel()

    @patch('torch.distributed.is_initialized', return_value=True)
    @patch('torch.distributed.get_world_size', return_value=8)
    @patch('vllm_ascend.distributed.parallel_state.get_world_group',
           return_value=MagicMock(device_group='npu:0', local_rank=1))
    @patch('torch.distributed.get_backend', return_value='hccl')
    @patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group')
    @patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized',
           return_value=False)
    def test_init_ascend_model_parallel_custom_params(
            self, mock_mp_init, mock_init_group, mock_get_backend,
            mock_world_group, mock_get_world_size, mock_is_init):
        """Test initialization with custom parallel sizes"""
        # Act
        init_ascend_model_parallel(expert_parallel_size=2,
                                   expert_tensor_parallel_size=4,
                                   world_size=8,
                                   backend='hccl')
        # Assert
        mock_init_group.assert_any_call([[0, 4], [1, 5], [2, 6], [3, 7]],
                                        1,
                                        'hccl',
                                        group_name="ep")
        mock_init_group.assert_any_call([[0, 1, 2, 3], [4, 5, 6, 7]],
                                        1,
                                        'hccl',
                                        group_name="etp")