#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from unittest.mock import MagicMock, patch

import pytest
from vllm.distributed.parallel_state import GroupCoordinator

import vllm_ascend
from tests.ut.base import TestBase
from vllm_ascend.distributed.parallel_state import (
    destory_ascend_model_parallel, get_ep_group, get_etp_group,
    init_ascend_model_parallel, model_parallel_initialized)


class TestParallelState(TestBase):

    @patch('vllm_ascend.distributed.parallel_state._EP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_get_ep_group_when_initialized(self, mock_ep):
        # Act
        result = get_ep_group()

        # Assert
        assert isinstance(result, GroupCoordinator)

    @patch('vllm_ascend.distributed.parallel_state._EP', None)
    def test_get_ep_group_when_not_initialized(self):
        # Act & Assert
        with pytest.raises(AssertionError) as excinfo:
            get_ep_group()
        assert "expert model parallel group is not initialized" in str(
            excinfo.value)

    @patch('vllm_ascend.distributed.parallel_state._ETP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_get_etp_group_when_initialized(self, mock_etp):
        # Act
        result = get_etp_group()

        # Assert
        assert isinstance(result, GroupCoordinator)

    @patch('vllm_ascend.distributed.parallel_state._ETP', None)
    def test_get_etp_group_when_not_initialized(self):
        # Act & Assert
        with pytest.raises(AssertionError) as excinfo:
            get_etp_group()
        assert "expert tensor parallel group is not initialized" in str(
            excinfo.value)

    @patch('vllm_ascend.distributed.parallel_state._ETP', None)
    @patch('vllm_ascend.distributed.parallel_state._EP', None)
    def test_model_parallel_initialized_when_both_none(self):
        # Act & Assert
        assert not model_parallel_initialized()

    @patch('vllm_ascend.distributed.parallel_state._ETP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm_ascend.distributed.parallel_state._EP', None)
    def test_model_parallel_initialized_when_ep_none(self, mock_etp):
        # Act & Assert
        assert not model_parallel_initialized()

    @patch('vllm_ascend.distributed.parallel_state._ETP', None)
    @patch('vllm_ascend.distributed.parallel_state._EP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_model_parallel_initialized_when_etp_none(self, mock_ep):
        # Act & Assert
        assert not model_parallel_initialized()

    @patch('vllm_ascend.distributed.parallel_state._ETP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm_ascend.distributed.parallel_state._EP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_model_parallel_initialized_when_etp_initialized(
            self, mock_ep, mock_etp):
        # Act & Assert
        assert model_parallel_initialized()

    @patch('vllm_ascend.distributed.parallel_state._ETP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm_ascend.distributed.parallel_state._EP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_destroy_when_both_exist(self, mock_ep, mock_etp):
        # Act
        destory_ascend_model_parallel()

        # Assert
        mock_ep.destroy.assert_called_once()
        mock_etp.destroy.assert_called_once()
        assert vllm_ascend.distributed.parallel_state._ETP is None
        assert vllm_ascend.distributed.parallel_state._EP is None

    @patch('vllm_ascend.distributed.parallel_state._ETP', None)
    @patch('vllm_ascend.distributed.parallel_state._EP',
           new_callable=lambda: MagicMock())
    def test_destory_ascend_model_parallel_when_etp_none(self, mock_ep):
        # Act
        destory_ascend_model_parallel()

        # Assert
        mock_ep.destroy.assert_called_once()
        assert vllm_ascend.distributed.parallel_state._EP is None
        assert vllm_ascend.distributed.parallel_state._ETP is None

    @patch('vllm_ascend.distributed.parallel_state._ETP',
           new_callable=lambda: MagicMock())
    @patch('vllm_ascend.distributed.parallel_state._EP', None)
    def test_destory_ascend_model_parallel_when_ep_none(self, mock_etp):
        # Act
        destory_ascend_model_parallel()

        # Assert
        mock_etp.destroy.assert_called_once()
        assert vllm_ascend.distributed.parallel_state._ETP is None
        assert vllm_ascend.distributed.parallel_state._EP is None

    @patch('vllm_ascend.distributed.parallel_state._ETP', None)
    @patch('vllm_ascend.distributed.parallel_state._EP', None)
    def test_destory_ascend_model_parallel_when_both_none(self):
        # Act
        destory_ascend_model_parallel()

        # Assert
        assert vllm_ascend.distributed.parallel_state._ETP is None
        assert vllm_ascend.distributed.parallel_state._EP is None

    @patch('torch.distributed.is_initialized', return_value=True)
    @patch('torch.distributed.get_world_size', return_value=8)
    @patch('vllm_ascend.distributed.parallel_state.get_world_group',
           return_value=MagicMock(device_group='npu:0', local_rank=0))
    @patch('torch.distributed.get_backend', return_value='hccl')
    @patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group')
    @patch(
        'vllm_ascend.distributed.parallel_state.model_parallel_initialized',
        return_value=False)
    def test_init_ascend_model_parallel_normal_case(
            self, mock_mp_init, mock_init_group, mock_get_backend,
            mock_world_group, mock_get_world_size, mock_is_init):
        """Test normal initialization with default parameters"""
        # Act
        init_ascend_model_parallel()

        # Assert
        mock_init_group.assert_any_call([[0, 1, 2, 3, 4, 5, 6, 7]],
                                        0,
                                        'hccl',
                                        group_name="ep")
        mock_init_group.assert_any_call([[0]], 0, 'hccl', group_name="etp")
        self.assertIsNotNone(vllm_ascend.distributed.parallel_state._EP)
        self.assertIsNotNone(vllm_ascend.distributed.parallel_state._ETP)

    @patch(
        'vllm_ascend.distributed.parallel_state.model_parallel_initialized',
        return_value=True)
    def test_init_ascend_model_parallel_skip_if_initialized(
            self, mock_mp_init):
        """Test skipping when model parallel already initialized"""
        with patch.object(vllm_ascend.distributed.parallel_state,
                          '_EP') as mock_ep, patch.object(
                              vllm_ascend.distributed.parallel_state,
                              '_ETP') as mock_etp:
            # Act
            init_ascend_model_parallel()

            # Assert
            mock_ep.assert_not_called()
            mock_etp.assert_not_called()

    @patch('torch.distributed.is_initialized', return_value=False)
    def test_init_ascend_model_parallel_assert_dist_not_init(
            self, mock_is_init):
        """Test assertion when distributed not initialized"""
        # Act & Assert
        with self.assertRaises(AssertionError):
            init_ascend_model_parallel()

    @patch('torch.distributed.is_initialized', return_value=True)
    @patch('torch.distributed.get_world_size', return_value=8)
    @patch('vllm_ascend.distributed.parallel_state.get_world_group',
           return_value=MagicMock(device_group='npu:0', local_rank=1))
    @patch('torch.distributed.get_backend', return_value='hccl')
    @patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group')
    @patch(
        'vllm_ascend.distributed.parallel_state.model_parallel_initialized',
        return_value=False)
    def test_init_ascend_model_parallel_custom_params(
            self, mock_mp_init, mock_init_group, mock_get_backend,
            mock_world_group, mock_get_world_size, mock_is_init):
        """Test initialization with custom parallel sizes"""
        # Act
        init_ascend_model_parallel(expert_parallel_size=2,
                                   expert_tensor_parallel_size=4,
                                   world_size=8,
                                   backend='hccl')

        # Assert
        # With world_size=8, expert_parallel_size=2 and
        # expert_tensor_parallel_size=4, the eight ranks are split into
        # four EP groups of 2 and two ETP groups of 4.
        mock_init_group.assert_any_call([[0, 4], [1, 5], [2, 6], [3, 7]],
                                        1,
                                        'hccl',
                                        group_name="ep")
        mock_init_group.assert_any_call([[0, 1, 2, 3], [4, 5, 6, 7]],
                                        1,
                                        'hccl',
                                        group_name="etp")