[Test] Add ut for files in /distributed (#1951)
### What this PR does / why we need it?
Add some ut for files in folder /distributed
### Does this PR introduce _any_ user-facing change?
No
- vLLM version: v0.9.2
- vLLM main:
107111a859
Signed-off-by: lwq <liwenquan5@huawei.com>
Co-authored-by: lwq <liwenquan5@huawei.com>
This commit is contained in:
84
tests/ut/distributed/device_communicators/test_pyhccl.py
Normal file
84
tests/ut/distributed/device_communicators/test_pyhccl.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from vllm.distributed.utils import StatelessProcessGroup
|
||||||
|
|
||||||
|
from tests.ut.base import TestBase
|
||||||
|
from vllm_ascend.distributed.device_communicators.pyhccl import \
|
||||||
|
PyHcclCommunicator
|
||||||
|
|
||||||
|
|
||||||
|
class MockHcclLib:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MockUniqueId:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestPyHcclCommunicator(TestBase):
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"RANK": "0", "WORLD_SIZE": "1"})
|
||||||
|
def test_world_size_1_return_early(self):
|
||||||
|
comm = PyHcclCommunicator(
|
||||||
|
group=StatelessProcessGroup(0, 1, None, None),
|
||||||
|
device="npu:0",
|
||||||
|
)
|
||||||
|
self.assertTrue(comm.disabled)
|
||||||
|
self.assertFalse(comm.available)
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"RANK": "0", "WORLD_SIZE": "2"})
|
||||||
|
def test_load_hccl_fail(self):
|
||||||
|
comm = PyHcclCommunicator(group=StatelessProcessGroup(
|
||||||
|
0, 2, None, None),
|
||||||
|
device="npu:0",
|
||||||
|
library_path="/not/exist/path/libhccl.so")
|
||||||
|
self.assertTrue(comm.disabled)
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"vllm_ascend.distributed.device_communicators.pyhccl_wrapper.HCCLLibrary",
|
||||||
|
MockHcclLib)
|
||||||
|
@patch(
|
||||||
|
"vllm_ascend.distributed.device_communicators.pyhccl_wrapper.hcclUniqueId",
|
||||||
|
MockUniqueId)
|
||||||
|
@patch("torch.npu.device")
|
||||||
|
@patch("vllm_ascend.utils.current_stream",
|
||||||
|
return_value=MagicMock(npu_stream=5678))
|
||||||
|
def test_stateless_group(self, *_):
|
||||||
|
group = StatelessProcessGroup(rank=3,
|
||||||
|
world_size=4,
|
||||||
|
store=None,
|
||||||
|
socket=None)
|
||||||
|
|
||||||
|
comm = PyHcclCommunicator(group=group, device=3)
|
||||||
|
|
||||||
|
self.assertEqual(comm.rank, 3)
|
||||||
|
self.assertEqual(comm.world_size, 4)
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"RANK": "1", "WORLD_SIZE": "2"})
|
||||||
|
@patch(
|
||||||
|
"vllm_ascend.distributed.device_communicators.pyhccl_wrapper.HCCLLibrary",
|
||||||
|
MockHcclLib)
|
||||||
|
@patch(
|
||||||
|
"vllm_ascend.distributed.device_communicators.pyhccl_wrapper.hcclUniqueId",
|
||||||
|
MockUniqueId)
|
||||||
|
@patch("torch.distributed.is_initialized", return_value=True)
|
||||||
|
@patch("torch.distributed.get_backend", return_value="nccl")
|
||||||
|
@patch("torch.distributed.get_rank", return_value=1)
|
||||||
|
@patch("torch.distributed.get_world_size", return_value=2)
|
||||||
|
@patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
|
||||||
|
@patch("torch.distributed.broadcast")
|
||||||
|
@patch("torch.npu.device")
|
||||||
|
@patch("vllm_ascend.utils.current_stream",
|
||||||
|
return_value=MagicMock(npu_stream=1234))
|
||||||
|
def test_multi_gpu_pg_torch(
|
||||||
|
self,
|
||||||
|
*_,
|
||||||
|
):
|
||||||
|
fake_pg = MagicMock()
|
||||||
|
comm = PyHcclCommunicator(group=fake_pg, device="npu:1")
|
||||||
|
|
||||||
|
self.assertEqual(comm.rank, 1)
|
||||||
|
self.assertEqual(comm.world_size, 2)
|
||||||
|
self.assertFalse(comm.available)
|
||||||
|
self.assertTrue(comm.disabled)
|
||||||
173
tests/ut/distributed/device_communicators/test_pyhccl_wrapper.py
Normal file
173
tests/ut/distributed/device_communicators/test_pyhccl_wrapper.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.distributed import ReduceOp
|
||||||
|
|
||||||
|
from tests.ut.base import TestBase
|
||||||
|
from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
|
||||||
|
Function, HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t,
|
||||||
|
hcclDataType_t, hcclDataTypeEnum, hcclRedOp_t, hcclRedOpTypeEnum,
|
||||||
|
hcclResult_t, hcclUniqueId)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHcclUniqueId(TestBase):
|
||||||
|
|
||||||
|
def test_construct(self):
|
||||||
|
uid = hcclUniqueId()
|
||||||
|
uid.internal[0] = 12
|
||||||
|
self.assertEqual(len(uid.internal), 4108)
|
||||||
|
self.assertEqual(uid.internal[0], 12)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHcclDataTypeEnum(TestBase):
|
||||||
|
|
||||||
|
def test_torch_dtype_mapping(self):
|
||||||
|
expected = {
|
||||||
|
torch.int8: hcclDataTypeEnum.hcclInt8,
|
||||||
|
torch.uint8: hcclDataTypeEnum.hcclUint8,
|
||||||
|
torch.int32: hcclDataTypeEnum.hcclInt32,
|
||||||
|
torch.int64: hcclDataTypeEnum.hcclInt64,
|
||||||
|
torch.float16: hcclDataTypeEnum.hcclFloat16,
|
||||||
|
torch.float32: hcclDataTypeEnum.hcclFloat32,
|
||||||
|
torch.float64: hcclDataTypeEnum.hcclFloat64,
|
||||||
|
torch.bfloat16: hcclDataTypeEnum.hcclBfloat16,
|
||||||
|
}
|
||||||
|
|
||||||
|
for torch_dtype, expected_enum in expected.items():
|
||||||
|
with self.subTest(torch_dtype=torch_dtype):
|
||||||
|
self.assertEqual(hcclDataTypeEnum.from_torch(torch_dtype),
|
||||||
|
expected_enum)
|
||||||
|
|
||||||
|
def test_unsupported_dtype_raises(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
hcclDataTypeEnum.from_torch(torch.complex64)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHcclRedOpTypeEnum(TestBase):
|
||||||
|
|
||||||
|
def test_torch_reduce_op_mapping(self):
|
||||||
|
expected = {
|
||||||
|
ReduceOp.SUM: hcclRedOpTypeEnum.hcclSum,
|
||||||
|
ReduceOp.PRODUCT: hcclRedOpTypeEnum.hcclProd,
|
||||||
|
ReduceOp.MAX: hcclRedOpTypeEnum.hcclMax,
|
||||||
|
ReduceOp.MIN: hcclRedOpTypeEnum.hcclMin,
|
||||||
|
}
|
||||||
|
|
||||||
|
for torch_op, expected_enum in expected.items():
|
||||||
|
with self.subTest(torch_op=torch_op):
|
||||||
|
self.assertEqual(hcclRedOpTypeEnum.from_torch(torch_op),
|
||||||
|
expected_enum)
|
||||||
|
|
||||||
|
def test_unsupported_op_raises(self):
|
||||||
|
unsupported_op = "NOT_EXIST"
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
hcclRedOpTypeEnum.from_torch(unsupported_op)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFunction(TestBase):
|
||||||
|
|
||||||
|
def test_construct_with_valid_args(self):
|
||||||
|
func = Function(name="foo", restype=int, argtypes=[int, str, float])
|
||||||
|
self.assertEqual(func.name, "foo")
|
||||||
|
self.assertIs(func.restype, int)
|
||||||
|
self.assertEqual(func.argtypes, [int, str, float])
|
||||||
|
|
||||||
|
|
||||||
|
class TestHCLLLibrary(TestBase):
|
||||||
|
|
||||||
|
def test_init_with_nonexistent_so(self):
|
||||||
|
fake_path = "/definitely/not/exist/libhccl.so"
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
HCCLLibrary(fake_path)
|
||||||
|
|
||||||
|
def test_hccl_get_error_string(self):
|
||||||
|
lib = MagicMock(sepc=HCCLLibrary)
|
||||||
|
mock_fn = MagicMock()
|
||||||
|
mock_fn.return_value = "HCCL internal error"
|
||||||
|
lib.hcclGetErrorString = mock_fn
|
||||||
|
|
||||||
|
result = hcclResult_t(1)
|
||||||
|
msg = lib.hcclGetErrorString(result)
|
||||||
|
self.assertEqual(msg, "HCCL internal error")
|
||||||
|
mock_fn.assert_called_once()
|
||||||
|
|
||||||
|
def test_hccl_check(self):
|
||||||
|
lib = HCCLLibrary.__new__(HCCLLibrary)
|
||||||
|
mock_fn = MagicMock()
|
||||||
|
mock_fn.return_value = "fake error"
|
||||||
|
lib.hcclGetErrorString = mock_fn
|
||||||
|
result = hcclResult_t(123)
|
||||||
|
with self.assertRaises(RuntimeError) as cm:
|
||||||
|
lib.HCCL_CHECK(result)
|
||||||
|
|
||||||
|
self.assertEqual(str(cm.exception), "HCCL error: fake error")
|
||||||
|
|
||||||
|
@patch.object(HCCLLibrary, "HCCL_CHECK")
|
||||||
|
def test_hccl_get_uniqueId(self, mock_HCCL_CHECK):
|
||||||
|
lib = HCCLLibrary.__new__(HCCLLibrary)
|
||||||
|
lib._funcs = {"HcclGetRootInfo": MagicMock(return_value=0)}
|
||||||
|
unique_id = lib.hcclGetUniqueId()
|
||||||
|
self.assertIsInstance(unique_id, hcclUniqueId)
|
||||||
|
lib._funcs["HcclGetRootInfo"].assert_called_once()
|
||||||
|
mock_HCCL_CHECK.assert_called_once_with(0)
|
||||||
|
|
||||||
|
@patch.object(HCCLLibrary, "HCCL_CHECK")
|
||||||
|
def test_hccl_comm_initRank(self, mock_hccl_check):
|
||||||
|
lib = HCCLLibrary.__new__(HCCLLibrary)
|
||||||
|
lib._funcs = {"HcclCommInitRootInfo": MagicMock(return_value=0)}
|
||||||
|
|
||||||
|
world_size = 4
|
||||||
|
unique_id = hcclUniqueId()
|
||||||
|
rank = 1
|
||||||
|
|
||||||
|
comm = lib.hcclCommInitRank(world_size, unique_id, rank)
|
||||||
|
self.assertIsInstance(comm, hcclComm_t)
|
||||||
|
lib._funcs["HcclCommInitRootInfo"].assert_called_once()
|
||||||
|
mock_hccl_check.assert_called_once_with(0)
|
||||||
|
|
||||||
|
@patch.object(HCCLLibrary, "HCCL_CHECK")
|
||||||
|
def test_hccl_all_reduce(self, mock_hccl_check):
|
||||||
|
|
||||||
|
lib = HCCLLibrary.__new__(HCCLLibrary)
|
||||||
|
lib._funcs = {"HcclAllReduce": MagicMock(return_value=0)}
|
||||||
|
sendbuff = buffer_type()
|
||||||
|
recvbuff = buffer_type()
|
||||||
|
count = 10
|
||||||
|
datatype = hcclDataType_t(1)
|
||||||
|
op = hcclRedOp_t(0)
|
||||||
|
comm = hcclComm_t()
|
||||||
|
stream = aclrtStream_t()
|
||||||
|
|
||||||
|
lib.hcclAllReduce(sendbuff, recvbuff, count, datatype, op, comm,
|
||||||
|
stream)
|
||||||
|
|
||||||
|
lib._funcs["HcclAllReduce"].assert_called_once_with(
|
||||||
|
sendbuff, recvbuff, count, datatype, op, comm, stream)
|
||||||
|
mock_hccl_check.assert_called_once_with(0)
|
||||||
|
|
||||||
|
@patch.object(HCCLLibrary, "HCCL_CHECK")
|
||||||
|
def test_hccl_broad_cast(self, mock_hccl_check):
|
||||||
|
|
||||||
|
lib = HCCLLibrary.__new__(HCCLLibrary)
|
||||||
|
lib._funcs = {"HcclBroadcast": MagicMock(return_value=0)}
|
||||||
|
buff = buffer_type()
|
||||||
|
count = 10
|
||||||
|
datatype = 1
|
||||||
|
root = 0
|
||||||
|
comm = hcclComm_t()
|
||||||
|
stream = aclrtStream_t()
|
||||||
|
|
||||||
|
lib.hcclBroadcast(buff, count, datatype, root, comm, stream)
|
||||||
|
|
||||||
|
lib._funcs["HcclBroadcast"].assert_called_once_with(
|
||||||
|
buff, count, datatype, root, comm, stream)
|
||||||
|
mock_hccl_check.assert_called_once_with(0)
|
||||||
|
|
||||||
|
@patch.object(HCCLLibrary, "HCCL_CHECK")
|
||||||
|
def test_hcclCommDestroy_success(self, mock_hccl_check):
|
||||||
|
lib = HCCLLibrary.__new__(HCCLLibrary)
|
||||||
|
lib._funcs = {"HcclCommDestroy": MagicMock(return_value=0)}
|
||||||
|
comm = hcclComm_t()
|
||||||
|
lib.hcclCommDestroy(comm)
|
||||||
|
lib._funcs["HcclCommDestroy"].assert_called_once_with(comm)
|
||||||
|
mock_hccl_check.assert_called_once_with(0)
|
||||||
Reference in New Issue
Block a user