From 34571ea5ae69529758edf75f0252f86ccb4c7184 Mon Sep 17 00:00:00 2001 From: SunnyLee151064 <83354342+SunnyLee151064@users.noreply.github.com> Date: Thu, 24 Jul 2025 10:36:11 +0800 Subject: [PATCH] [Test] Add ut for files in /distributed (#1951) ### What this PR does / why we need it? Add some ut for files in folder /distributed ### Does this PR introduce _any_ user-facing change? No - vLLM version: v0.9.2 - vLLM main: https://github.com/vllm-project/vllm/commit/107111a85991cbe9c8d16b697e048d91aa4135f6 Signed-off-by: lwq Co-authored-by: lwq --- .../device_communicators/test_pyhccl.py | 84 +++++++++ .../test_pyhccl_wrapper.py | 173 ++++++++++++++++++ 2 files changed, 257 insertions(+) create mode 100644 tests/ut/distributed/device_communicators/test_pyhccl.py create mode 100644 tests/ut/distributed/device_communicators/test_pyhccl_wrapper.py diff --git a/tests/ut/distributed/device_communicators/test_pyhccl.py b/tests/ut/distributed/device_communicators/test_pyhccl.py new file mode 100644 index 0000000..16eb095 --- /dev/null +++ b/tests/ut/distributed/device_communicators/test_pyhccl.py @@ -0,0 +1,84 @@ +import os +from unittest.mock import MagicMock, patch + +from vllm.distributed.utils import StatelessProcessGroup + +from tests.ut.base import TestBase +from vllm_ascend.distributed.device_communicators.pyhccl import \ + PyHcclCommunicator + + +class MockHcclLib: + pass + + +class MockUniqueId: + pass + + +class TestPyHcclCommunicator(TestBase): + + @patch.dict(os.environ, {"RANK": "0", "WORLD_SIZE": "1"}) + def test_world_size_1_return_early(self): + comm = PyHcclCommunicator( + group=StatelessProcessGroup(0, 1, None, None), + device="npu:0", + ) + self.assertTrue(comm.disabled) + self.assertFalse(comm.available) + + @patch.dict(os.environ, {"RANK": "0", "WORLD_SIZE": "2"}) + def test_load_hccl_fail(self): + comm = PyHcclCommunicator(group=StatelessProcessGroup( + 0, 2, None, None), + device="npu:0", + library_path="/not/exist/path/libhccl.so") + self.assertTrue(comm.disabled) + + @patch( + "vllm_ascend.distributed.device_communicators.pyhccl_wrapper.HCCLLibrary", + MockHcclLib) + @patch( + "vllm_ascend.distributed.device_communicators.pyhccl_wrapper.hcclUniqueId", + MockUniqueId) + @patch("torch.npu.device") + @patch("vllm_ascend.utils.current_stream", + return_value=MagicMock(npu_stream=5678)) + def test_stateless_group(self, *_): + group = StatelessProcessGroup(rank=3, + world_size=4, + store=None, + socket=None) + + comm = PyHcclCommunicator(group=group, device=3) + + self.assertEqual(comm.rank, 3) + self.assertEqual(comm.world_size, 4) + + @patch.dict(os.environ, {"RANK": "1", "WORLD_SIZE": "2"}) + @patch( + "vllm_ascend.distributed.device_communicators.pyhccl_wrapper.HCCLLibrary", + MockHcclLib) + @patch( + "vllm_ascend.distributed.device_communicators.pyhccl_wrapper.hcclUniqueId", + MockUniqueId) + @patch("torch.distributed.is_initialized", return_value=True) + @patch("torch.distributed.get_backend", return_value="nccl") + @patch("torch.distributed.get_rank", return_value=1) + @patch("torch.distributed.get_world_size", return_value=2) + @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1]) + @patch("torch.distributed.broadcast") + @patch("torch.npu.device") + @patch("vllm_ascend.utils.current_stream", + return_value=MagicMock(npu_stream=1234)) + def test_multi_gpu_pg_torch( + self, + *_, + ): + fake_pg = MagicMock() + comm = PyHcclCommunicator(group=fake_pg, device="npu:1") + + self.assertEqual(comm.rank, 1) + self.assertEqual(comm.world_size, 2) + self.assertFalse(comm.available) + self.assertTrue(comm.disabled) diff --git a/tests/ut/distributed/device_communicators/test_pyhccl_wrapper.py b/tests/ut/distributed/device_communicators/test_pyhccl_wrapper.py new file mode 100644 index 0000000..ff90512 --- /dev/null +++ b/tests/ut/distributed/device_communicators/test_pyhccl_wrapper.py @@ -0,0 +1,173 @@ +from unittest.mock import MagicMock, patch + +import torch +from torch.distributed import ReduceOp + +from tests.ut.base import TestBase +from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import ( + Function, HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, + hcclDataType_t, hcclDataTypeEnum, hcclRedOp_t, hcclRedOpTypeEnum, + hcclResult_t, hcclUniqueId) + + +class TestHcclUniqueId(TestBase): + + def test_construct(self): + uid = hcclUniqueId() + uid.internal[0] = 12 + self.assertEqual(len(uid.internal), 4108) + self.assertEqual(uid.internal[0], 12) + + +class TestHcclDataTypeEnum(TestBase): + + def test_torch_dtype_mapping(self): + expected = { + torch.int8: hcclDataTypeEnum.hcclInt8, + torch.uint8: hcclDataTypeEnum.hcclUint8, + torch.int32: hcclDataTypeEnum.hcclInt32, + torch.int64: hcclDataTypeEnum.hcclInt64, + torch.float16: hcclDataTypeEnum.hcclFloat16, + torch.float32: hcclDataTypeEnum.hcclFloat32, + torch.float64: hcclDataTypeEnum.hcclFloat64, + torch.bfloat16: hcclDataTypeEnum.hcclBfloat16, + } + + for torch_dtype, expected_enum in expected.items(): + with self.subTest(torch_dtype=torch_dtype): + self.assertEqual(hcclDataTypeEnum.from_torch(torch_dtype), + expected_enum) + + def test_unsupported_dtype_raises(self): + with self.assertRaises(ValueError): + hcclDataTypeEnum.from_torch(torch.complex64) + + +class TestHcclRedOpTypeEnum(TestBase): + + def test_torch_reduce_op_mapping(self): + expected = { + ReduceOp.SUM: hcclRedOpTypeEnum.hcclSum, + ReduceOp.PRODUCT: hcclRedOpTypeEnum.hcclProd, + ReduceOp.MAX: hcclRedOpTypeEnum.hcclMax, + ReduceOp.MIN: hcclRedOpTypeEnum.hcclMin, + } + + for torch_op, expected_enum in expected.items(): + with self.subTest(torch_op=torch_op): + self.assertEqual(hcclRedOpTypeEnum.from_torch(torch_op), + expected_enum) + + def test_unsupported_op_raises(self): + unsupported_op = "NOT_EXIST" + with self.assertRaises(ValueError): + hcclRedOpTypeEnum.from_torch(unsupported_op) + + +class TestFunction(TestBase): + + def test_construct_with_valid_args(self): + func = Function(name="foo", restype=int, argtypes=[int, str, float]) + self.assertEqual(func.name, "foo") + self.assertIs(func.restype, int) + self.assertEqual(func.argtypes, [int, str, float]) + + +class TestHCLLLibrary(TestBase): + + def test_init_with_nonexistent_so(self): + fake_path = "/definitely/not/exist/libhccl.so" + with self.assertRaises(OSError): + HCCLLibrary(fake_path) + + def test_hccl_get_error_string(self): + lib = MagicMock(sepc=HCCLLibrary) + mock_fn = MagicMock() + mock_fn.return_value = "HCCL internal error" + lib.hcclGetErrorString = mock_fn + + result = hcclResult_t(1) + msg = lib.hcclGetErrorString(result) + self.assertEqual(msg, "HCCL internal error") + mock_fn.assert_called_once() + + def test_hccl_check(self): + lib = HCCLLibrary.__new__(HCCLLibrary) + mock_fn = MagicMock() + mock_fn.return_value = "fake error" + lib.hcclGetErrorString = mock_fn + result = hcclResult_t(123) + with self.assertRaises(RuntimeError) as cm: + lib.HCCL_CHECK(result) + + self.assertEqual(str(cm.exception), "HCCL error: fake error") + + @patch.object(HCCLLibrary, "HCCL_CHECK") + def test_hccl_get_uniqueId(self, mock_HCCL_CHECK): + lib = HCCLLibrary.__new__(HCCLLibrary) + lib._funcs = {"HcclGetRootInfo": MagicMock(return_value=0)} + unique_id = lib.hcclGetUniqueId() + self.assertIsInstance(unique_id, hcclUniqueId) + lib._funcs["HcclGetRootInfo"].assert_called_once() + mock_HCCL_CHECK.assert_called_once_with(0) + + @patch.object(HCCLLibrary, "HCCL_CHECK") + def test_hccl_comm_initRank(self, mock_hccl_check): + lib = HCCLLibrary.__new__(HCCLLibrary) + lib._funcs = {"HcclCommInitRootInfo": MagicMock(return_value=0)} + + world_size = 4 + unique_id = hcclUniqueId() + rank = 1 + + comm = lib.hcclCommInitRank(world_size, unique_id, rank) + self.assertIsInstance(comm, hcclComm_t) + lib._funcs["HcclCommInitRootInfo"].assert_called_once() + mock_hccl_check.assert_called_once_with(0) + + @patch.object(HCCLLibrary, "HCCL_CHECK") + def test_hccl_all_reduce(self, mock_hccl_check): + + lib = HCCLLibrary.__new__(HCCLLibrary) + lib._funcs = {"HcclAllReduce": MagicMock(return_value=0)} + sendbuff = buffer_type() + recvbuff = buffer_type() + count = 10 + datatype = hcclDataType_t(1) + op = hcclRedOp_t(0) + comm = hcclComm_t() + stream = aclrtStream_t() + + lib.hcclAllReduce(sendbuff, recvbuff, count, datatype, op, comm, + stream) + + lib._funcs["HcclAllReduce"].assert_called_once_with( + sendbuff, recvbuff, count, datatype, op, comm, stream) + mock_hccl_check.assert_called_once_with(0) + + @patch.object(HCCLLibrary, "HCCL_CHECK") + def test_hccl_broad_cast(self, mock_hccl_check): + + lib = HCCLLibrary.__new__(HCCLLibrary) + lib._funcs = {"HcclBroadcast": MagicMock(return_value=0)} + buff = buffer_type() + count = 10 + datatype = 1 + root = 0 + comm = hcclComm_t() + stream = aclrtStream_t() + + lib.hcclBroadcast(buff, count, datatype, root, comm, stream) + + lib._funcs["HcclBroadcast"].assert_called_once_with( + buff, count, datatype, root, comm, stream) + mock_hccl_check.assert_called_once_with(0) + + @patch.object(HCCLLibrary, "HCCL_CHECK") + def test_hcclCommDestroy_success(self, mock_hccl_check): + lib = HCCLLibrary.__new__(HCCLLibrary) + lib._funcs = {"HcclCommDestroy": MagicMock(return_value=0)} + comm = hcclComm_t() + lib.hcclCommDestroy(comm) + lib._funcs["HcclCommDestroy"].assert_called_once_with(comm) + mock_hccl_check.assert_called_once_with(0)