From ee6f79c44a85720fd3564edafd3a7b2ac4adbf82 Mon Sep 17 00:00:00 2001
From: yangqinghao-cmss
Date: Sat, 9 Aug 2025 08:26:04 +0800
Subject: [PATCH] Add ut for test_communicator.py (#2293)

### What this PR does / why we need it?
Add ut for test_communicator.py

- vLLM version: v0.10.0
- vLLM main:
  https://github.com/vllm-project/vllm/commit/e5ebeeba531755a78f68413e88a23d061404f3e3

Signed-off-by: yangqinghao-cmss
---
 tests/ut/distributed/test_communicator.py | 155 ++++++++++++++++++++++
 1 file changed, 155 insertions(+)
 create mode 100644 tests/ut/distributed/test_communicator.py

diff --git a/tests/ut/distributed/test_communicator.py b/tests/ut/distributed/test_communicator.py
new file mode 100644
index 0000000..880cb24
--- /dev/null
+++ b/tests/ut/distributed/test_communicator.py
@@ -0,0 +1,155 @@
+import unittest
+from unittest.mock import MagicMock, Mock, patch
+
+import torch
+import torch.distributed as dist
+
+from vllm_ascend.distributed.communicator import NPUCommunicator
+
+
+class TestNPUCommunicator(unittest.TestCase):
+
+    @patch("vllm.config.get_current_vllm_config", return_value=None)
+    @patch("torch.npu.current_device", return_value=MagicMock())
+    @patch("torch.npu.set_device", return_value=MagicMock())
+    @patch("torch.distributed.get_process_group_ranks",
+           return_value={
+               0: 0,
+               1: 1
+           })
+    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
+    @patch("torch.distributed.is_initialized", return_value=True)
+    @patch("torch.distributed.get_rank", return_value=1)
+    @patch("torch.distributed.is_initialized", return_value=True)
+    @patch("torch.distributed.get_backend", return_value="hccl")
+    @patch("torch.distributed.get_rank", return_value=1)
+    @patch("torch.distributed.get_world_size", return_value=2)
+    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
+    @patch("torch.npu.device")
+    def test_all_to_all_with_sizes(self, *_):
+
+        def patched_all_to_all(output_tensor_list,
+                               input_tensor_list,
+                               group=None,
+                               async_op=False):
+            output_tensor_list[:] = ([
+                torch.tensor([10, 20]),
+                torch.tensor([50, 60])
+            ])
+
+        torch.distributed.all_to_all = patched_all_to_all
+
+        scatter_sizes = [2, 2]
+        gather_sizes = [2, 2]
+        input_ = torch.tensor([10, 20, 30, 40])
+
+        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
+
+        output = comm.all_to_all(input_,
+                                 scatter_sizes=scatter_sizes,
+                                 gather_sizes=gather_sizes)
+
+        assert output.tolist() == [10, 20, 50, 60]
+
+    @patch("vllm.config.get_current_vllm_config", return_value=None)
+    @patch("torch.npu.current_device", return_value=MagicMock())
+    @patch("torch.npu.set_device", return_value=MagicMock())
+    @patch("torch.distributed.get_process_group_ranks",
+           return_value={
+               0: 0,
+               1: 1
+           })
+    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
+    @patch("torch.distributed.is_initialized", return_value=True)
+    @patch("torch.distributed.get_rank", return_value=1)
+    @patch("torch.distributed.is_initialized", return_value=True)
+    @patch("torch.distributed.get_backend", return_value="hccl")
+    @patch("torch.distributed.get_rank", return_value=1)
+    @patch("torch.distributed.get_world_size", return_value=2)
+    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
+    @patch("torch.npu.device")
+    def test_all_to_all_without_sizes(self, *_):
+
+        def patched_all_to_all(output_tensor_list,
+                               input_tensor_list,
+                               group=None,
+                               async_op=False):
+            output_tensor_list[:] = ([
+                torch.tensor([[10, 20]]),
+                torch.tensor([[50, 60]])
+            ])
+
+        torch.distributed.all_to_all = patched_all_to_all
+
+        input_ = torch.tensor([[10, 20], [30, 40]])
+
+        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
+        output = comm.all_to_all(input_, scatter_dim=0, gather_dim=0)
+
+        assert output.tolist() == [[10, 20], [50, 60]]
+
+    @patch("vllm.config.get_current_vllm_config", return_value=None)
+    @patch("torch.npu.current_device", return_value=MagicMock())
+    @patch("torch.npu.set_device", return_value=MagicMock())
+    @patch("torch.distributed.get_process_group_ranks",
+           return_value={
+               0: 0,
+               1: 1
+           })
+    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
+    @patch("torch.distributed.is_initialized", return_value=True)
+    @patch("torch.distributed.get_rank", return_value=1)
+    @patch("torch.distributed.is_initialized", return_value=True)
+    @patch("torch.distributed.get_backend", return_value="hccl")
+    @patch("torch.distributed.get_rank", return_value=1)
+    @patch("torch.distributed.get_world_size", return_value=2)
+    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
+    @patch("torch.npu.device")
+    def test_dispatch(self, *_):
+        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
+        comm.all2all_manager = Mock()
+        hidden_states = torch.randn(2, 4, 8)
+        router_logits = torch.randn(2, 4, 2)
+
+        mock_dispatch_result = (torch.randn(2, 4, 8), torch.randn(2, 4, 2))
+        comm.all2all_manager.dispatch.return_value = mock_dispatch_result
+
+        result_hidden, result_logits = comm.dispatch(hidden_states,
+                                                     router_logits)
+
+        assert torch.allclose(result_hidden, mock_dispatch_result[0])
+        assert torch.allclose(result_logits, mock_dispatch_result[1])
+
+        comm.all2all_manager.dispatch.assert_called_once_with(
+            hidden_states, router_logits)
+
+    @patch("vllm.config.get_current_vllm_config", return_value=None)
+    @patch("torch.npu.current_device", return_value=MagicMock())
+    @patch("torch.npu.set_device", return_value=MagicMock())
+    @patch("torch.distributed.get_process_group_ranks",
+           return_value={
+               0: 0,
+               1: 1
+           })
+    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
+    @patch("torch.distributed.is_initialized", return_value=True)
+    @patch("torch.distributed.get_rank", return_value=1)
+    @patch("torch.distributed.is_initialized", return_value=True)
+    @patch("torch.distributed.get_backend", return_value="hccl")
+    @patch("torch.distributed.get_rank", return_value=1)
+    @patch("torch.distributed.get_world_size", return_value=2)
+    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
+    @patch("torch.npu.device")
+    def test_combine(self, *_):
+        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
+        comm.all2all_manager = Mock()
+        hidden_states = torch.randn(2, 4, 8)
+
+        mock_combine_result = torch.randn(2, 4, 8)
+        comm.all2all_manager.combine.return_value = mock_combine_result
+
+        result = comm.combine(hidden_states)
+
+        assert torch.allclose(result, mock_combine_result)
+
+        comm.all2all_manager.combine.assert_called_once_with(hidden_states)
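
A note on the stacked @patch decorators in these tests: torch.distributed.get_rank, torch.distributed.is_initialized, and torch.distributed.get_process_group_ranks are each patched twice with different return values. Stacked patches are entered top-down when the test runs, so the decorator nearest the def is applied last and its return value is the one the code under test observes (here, get_process_group_ranks yields [0, 1], not the dict). A minimal, self-contained sketch of that ordering, using os.getcwd as a stand-in target rather than anything from this patch:

import os
from unittest.mock import patch

@patch("os.getcwd", return_value="outer")  # entered first at call time
@patch("os.getcwd", return_value="inner")  # nearest the def, entered last
def demo(mock_inner, mock_outer):
    # The patch entered last overwrites the earlier one, so the body
    # sees the "inner" stub; mocks are passed bottom-up as arguments.
    return os.getcwd()

assert demo() == "inner"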
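
Both all_to_all tests also replace torch.distributed.all_to_all by plain assignment, which leaves the stub installed for any test that runs later in the same process. A possible scoped alternative, sketched with unittest.mock.patch.object (the stub and tensor values mirror the first test; the trailing ... is a placeholder, not code from the patch):

import torch
import torch.distributed
from unittest.mock import patch

def patched_all_to_all(output_tensor_list, input_tensor_list,
                       group=None, async_op=False):
    # Write fixed per-rank outputs in place of the real exchange.
    output_tensor_list[:] = [torch.tensor([10, 20]), torch.tensor([50, 60])]

# patch.object restores the original torch.distributed.all_to_all
# automatically when the with-block exits.
with patch.object(torch.distributed, "all_to_all", new=patched_all_to_all):
    ...  # construct NPUCommunicator and exercise comm.all_to_all here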