v0.10.1rc1
tests/ut/distributed/device_communicators/test_pyhccl.py (new file, 84 lines)
import os
from unittest.mock import MagicMock, patch

from vllm.distributed.utils import StatelessProcessGroup

from tests.ut.base import TestBase
from vllm_ascend.distributed.device_communicators.pyhccl import \
    PyHcclCommunicator


class MockHcclLib:
    pass


class MockUniqueId:
    pass


class TestPyHcclCommunicator(TestBase):

    @patch.dict(os.environ, {"RANK": "0", "WORLD_SIZE": "1"})
    def test_world_size_1_return_early(self):
        comm = PyHcclCommunicator(
            group=StatelessProcessGroup(0, 1, None, None),
            device="npu:0",
        )
        self.assertTrue(comm.disabled)
        self.assertFalse(comm.available)

    @patch.dict(os.environ, {"RANK": "0", "WORLD_SIZE": "2"})
    def test_load_hccl_fail(self):
        comm = PyHcclCommunicator(
            group=StatelessProcessGroup(0, 2, None, None),
            device="npu:0",
            library_path="/not/exist/path/libhccl.so")
        self.assertTrue(comm.disabled)

    @patch(
        "vllm_ascend.distributed.device_communicators.pyhccl_wrapper.HCCLLibrary",
        MockHcclLib)
    @patch(
        "vllm_ascend.distributed.device_communicators.pyhccl_wrapper.hcclUniqueId",
        MockUniqueId)
    @patch("torch.npu.device")
    @patch("vllm_ascend.utils.current_stream",
           return_value=MagicMock(npu_stream=5678))
    def test_stateless_group(self, *_):
        group = StatelessProcessGroup(rank=3,
                                      world_size=4,
                                      store=None,
                                      socket=None)

        comm = PyHcclCommunicator(group=group, device=3)

        self.assertEqual(comm.rank, 3)
        self.assertEqual(comm.world_size, 4)

    @patch.dict(os.environ, {"RANK": "1", "WORLD_SIZE": "2"})
    @patch(
        "vllm_ascend.distributed.device_communicators.pyhccl_wrapper.HCCLLibrary",
        MockHcclLib)
    @patch(
        "vllm_ascend.distributed.device_communicators.pyhccl_wrapper.hcclUniqueId",
        MockUniqueId)
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_backend", return_value="nccl")
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.get_world_size", return_value=2)
    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
    @patch("torch.distributed.broadcast")
    @patch("torch.npu.device")
    @patch("vllm_ascend.utils.current_stream",
           return_value=MagicMock(npu_stream=1234))
    def test_multi_gpu_pg_torch(self, *_):
        fake_pg = MagicMock()
        comm = PyHcclCommunicator(group=fake_pg, device="npu:1")

        self.assertEqual(comm.rank, 1)
        self.assertEqual(comm.world_size, 2)
        self.assertFalse(comm.available)
        self.assertTrue(comm.disabled)
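
The first two tests pin down the communicator's fallback contract: with WORLD_SIZE=1, or when libhccl.so cannot be loaded from the given library_path, PyHcclCommunicator still constructs but flags itself as disabled and unavailable. A minimal caller-side sketch of relying on that contract (illustrative only, not part of the diff; it uses only the constructor arguments and flags exercised above):

import os

from vllm.distributed.utils import StatelessProcessGroup
from vllm_ascend.distributed.device_communicators.pyhccl import \
    PyHcclCommunicator

# Build a communicator from the usual torchrun-style environment variables.
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
group = StatelessProcessGroup(rank, world_size, None, None)
comm = PyHcclCommunicator(group=group, device=f"npu:{rank}")

# Single-rank runs and a missing HCCL library do not raise; they only flip
# these flags, so callers can fall back to a no-op path.
use_pyhccl = comm.available and not comm.disabled
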
tests/ut/distributed/device_communicators/test_pyhccl_wrapper.py (new file, 173 lines)
from unittest.mock import MagicMock, patch

import torch
from torch.distributed import ReduceOp

from tests.ut.base import TestBase
from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
    Function, HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t,
    hcclDataType_t, hcclDataTypeEnum, hcclRedOp_t, hcclRedOpTypeEnum,
    hcclResult_t, hcclUniqueId)


class TestHcclUniqueId(TestBase):

    def test_construct(self):
        uid = hcclUniqueId()
        uid.internal[0] = 12
        self.assertEqual(len(uid.internal), 4108)
        self.assertEqual(uid.internal[0], 12)


class TestHcclDataTypeEnum(TestBase):

    def test_torch_dtype_mapping(self):
        expected = {
            torch.int8: hcclDataTypeEnum.hcclInt8,
            torch.uint8: hcclDataTypeEnum.hcclUint8,
            torch.int32: hcclDataTypeEnum.hcclInt32,
            torch.int64: hcclDataTypeEnum.hcclInt64,
            torch.float16: hcclDataTypeEnum.hcclFloat16,
            torch.float32: hcclDataTypeEnum.hcclFloat32,
            torch.float64: hcclDataTypeEnum.hcclFloat64,
            torch.bfloat16: hcclDataTypeEnum.hcclBfloat16,
        }

        for torch_dtype, expected_enum in expected.items():
            with self.subTest(torch_dtype=torch_dtype):
                self.assertEqual(hcclDataTypeEnum.from_torch(torch_dtype),
                                 expected_enum)

    def test_unsupported_dtype_raises(self):
        with self.assertRaises(ValueError):
            hcclDataTypeEnum.from_torch(torch.complex64)


class TestHcclRedOpTypeEnum(TestBase):

    def test_torch_reduce_op_mapping(self):
        expected = {
            ReduceOp.SUM: hcclRedOpTypeEnum.hcclSum,
            ReduceOp.PRODUCT: hcclRedOpTypeEnum.hcclProd,
            ReduceOp.MAX: hcclRedOpTypeEnum.hcclMax,
            ReduceOp.MIN: hcclRedOpTypeEnum.hcclMin,
        }

        for torch_op, expected_enum in expected.items():
            with self.subTest(torch_op=torch_op):
                self.assertEqual(hcclRedOpTypeEnum.from_torch(torch_op),
                                 expected_enum)

    def test_unsupported_op_raises(self):
        unsupported_op = "NOT_EXIST"
        with self.assertRaises(ValueError):
            hcclRedOpTypeEnum.from_torch(unsupported_op)


class TestFunction(TestBase):

    def test_construct_with_valid_args(self):
        func = Function(name="foo", restype=int, argtypes=[int, str, float])
        self.assertEqual(func.name, "foo")
        self.assertIs(func.restype, int)
        self.assertEqual(func.argtypes, [int, str, float])


class TestHCCLLibrary(TestBase):

    def test_init_with_nonexistent_so(self):
        fake_path = "/definitely/not/exist/libhccl.so"
        with self.assertRaises(OSError):
            HCCLLibrary(fake_path)

    def test_hccl_get_error_string(self):
        lib = MagicMock(spec=HCCLLibrary)
        mock_fn = MagicMock()
        mock_fn.return_value = "HCCL internal error"
        lib.hcclGetErrorString = mock_fn

        result = hcclResult_t(1)
        msg = lib.hcclGetErrorString(result)
        self.assertEqual(msg, "HCCL internal error")
        mock_fn.assert_called_once()

    def test_hccl_check(self):
        lib = HCCLLibrary.__new__(HCCLLibrary)
        mock_fn = MagicMock()
        mock_fn.return_value = "fake error"
        lib.hcclGetErrorString = mock_fn
        result = hcclResult_t(123)
        with self.assertRaises(RuntimeError) as cm:
            lib.HCCL_CHECK(result)

        self.assertEqual(str(cm.exception), "HCCL error: fake error")

    @patch.object(HCCLLibrary, "HCCL_CHECK")
    def test_hccl_get_uniqueId(self, mock_HCCL_CHECK):
        lib = HCCLLibrary.__new__(HCCLLibrary)
        lib._funcs = {"HcclGetRootInfo": MagicMock(return_value=0)}
        unique_id = lib.hcclGetUniqueId()
        self.assertIsInstance(unique_id, hcclUniqueId)
        lib._funcs["HcclGetRootInfo"].assert_called_once()
        mock_HCCL_CHECK.assert_called_once_with(0)

    @patch.object(HCCLLibrary, "HCCL_CHECK")
    def test_hccl_comm_initRank(self, mock_hccl_check):
        lib = HCCLLibrary.__new__(HCCLLibrary)
        lib._funcs = {"HcclCommInitRootInfo": MagicMock(return_value=0)}

        world_size = 4
        unique_id = hcclUniqueId()
        rank = 1

        comm = lib.hcclCommInitRank(world_size, unique_id, rank)
        self.assertIsInstance(comm, hcclComm_t)
        lib._funcs["HcclCommInitRootInfo"].assert_called_once()
        mock_hccl_check.assert_called_once_with(0)

    @patch.object(HCCLLibrary, "HCCL_CHECK")
    def test_hccl_all_reduce(self, mock_hccl_check):
        lib = HCCLLibrary.__new__(HCCLLibrary)
        lib._funcs = {"HcclAllReduce": MagicMock(return_value=0)}
        sendbuff = buffer_type()
        recvbuff = buffer_type()
        count = 10
        datatype = hcclDataType_t(1)
        op = hcclRedOp_t(0)
        comm = hcclComm_t()
        stream = aclrtStream_t()

        lib.hcclAllReduce(sendbuff, recvbuff, count, datatype, op, comm,
                          stream)

        lib._funcs["HcclAllReduce"].assert_called_once_with(
            sendbuff, recvbuff, count, datatype, op, comm, stream)
        mock_hccl_check.assert_called_once_with(0)

    @patch.object(HCCLLibrary, "HCCL_CHECK")
    def test_hccl_broadcast(self, mock_hccl_check):
        lib = HCCLLibrary.__new__(HCCLLibrary)
        lib._funcs = {"HcclBroadcast": MagicMock(return_value=0)}
        buff = buffer_type()
        count = 10
        datatype = 1
        root = 0
        comm = hcclComm_t()
        stream = aclrtStream_t()

        lib.hcclBroadcast(buff, count, datatype, root, comm, stream)

        lib._funcs["HcclBroadcast"].assert_called_once_with(
            buff, count, datatype, root, comm, stream)
        mock_hccl_check.assert_called_once_with(0)

    @patch.object(HCCLLibrary, "HCCL_CHECK")
    def test_hcclCommDestroy_success(self, mock_hccl_check):
        lib = HCCLLibrary.__new__(HCCLLibrary)
        lib._funcs = {"HcclCommDestroy": MagicMock(return_value=0)}
        comm = hcclComm_t()
        lib.hcclCommDestroy(comm)
        lib._funcs["HcclCommDestroy"].assert_called_once_with(comm)
        mock_hccl_check.assert_called_once_with(0)
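
The enum tests above fix the translation layer between torch-level types and the integer codes HCCL expects. A short usage sketch of those two helpers as exercised by the tests (illustrative, not part of the diff):

import torch
from torch.distributed import ReduceOp

from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
    hcclDataTypeEnum, hcclRedOpTypeEnum)

# Both helpers map framework-level types to HCCL codes and raise ValueError
# for anything outside the supported set (e.g. torch.complex64).
dtype_code = hcclDataTypeEnum.from_torch(torch.float16)  # hcclFloat16
op_code = hcclRedOpTypeEnum.from_torch(ReduceOp.SUM)     # hcclSum
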
tests/ut/distributed/test_communicator.py (new file, 89 lines)
import unittest
from unittest.mock import MagicMock, patch

import torch
import torch.distributed as dist

from vllm_ascend.distributed.communicator import NPUCommunicator


class TestNPUCommunicator(unittest.TestCase):

    @patch("vllm.config.get_current_vllm_config", return_value=None)
    @patch("torch.npu.current_device", return_value=MagicMock())
    @patch("torch.npu.set_device", return_value=MagicMock())
    @patch("torch.distributed.get_process_group_ranks",
           return_value={
               0: 0,
               1: 1
           })
    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_backend", return_value="hccl")
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.get_world_size", return_value=2)
    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
    @patch("torch.npu.device")
    def test_all_to_all_with_sizes(self, *_):

        def patched_all_to_all(output_tensor_list,
                               input_tensor_list,
                               group=None,
                               async_op=False):
            output_tensor_list[:] = [
                torch.tensor([10, 20]),
                torch.tensor([50, 60])
            ]

        torch.distributed.all_to_all = patched_all_to_all

        scatter_sizes = [2, 2]
        gather_sizes = [2, 2]
        input_ = torch.tensor([10, 20, 30, 40])

        comm = NPUCommunicator(cpu_group=dist.group.WORLD)

        output = comm.all_to_all(input_,
                                 scatter_sizes=scatter_sizes,
                                 gather_sizes=gather_sizes)

        assert output.tolist() == [10, 20, 50, 60]

    @patch("vllm.config.get_current_vllm_config", return_value=None)
    @patch("torch.npu.current_device", return_value=MagicMock())
    @patch("torch.npu.set_device", return_value=MagicMock())
    @patch("torch.distributed.get_process_group_ranks",
           return_value={
               0: 0,
               1: 1
           })
    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_backend", return_value="hccl")
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.get_world_size", return_value=2)
    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
    @patch("torch.npu.device")
    def test_all_to_all_without_sizes(self, *_):

        def patched_all_to_all(output_tensor_list,
                               input_tensor_list,
                               group=None,
                               async_op=False):
            output_tensor_list[:] = [
                torch.tensor([[10, 20]]),
                torch.tensor([[50, 60]])
            ]

        torch.distributed.all_to_all = patched_all_to_all

        input_ = torch.tensor([[10, 20], [30, 40]])

        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
        output = comm.all_to_all(input_, scatter_dim=0, gather_dim=0)

        assert output.tolist() == [[10, 20], [50, 60]]
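
Both tests replace torch.distributed.all_to_all with a stub, so what they really check is the split/concat bookkeeping around it. A small sketch of the semantics asserted by test_all_to_all_with_sizes, using only plain torch ops (this mirrors the expected values in the test, not the actual NPUCommunicator implementation):

import torch

# Rank 1 of 2: the input is split by scatter_sizes into one chunk per rank,
# the chunks are exchanged, and the received chunks are concatenated.
input_ = torch.tensor([10, 20, 30, 40])
scatter_sizes = [2, 2]
send_chunks = list(torch.split(input_, scatter_sizes))          # [10, 20], [30, 40]
recv_chunks = [torch.tensor([10, 20]), torch.tensor([50, 60])]  # as returned by the patched all_to_all
output = torch.cat(recv_chunks)
assert output.tolist() == [10, 20, 50, 60]
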
tests/ut/distributed/test_distributed_tensor_parallel.py (new file, 139 lines)
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

import importlib

import pytest
import torch
from pytest_mock import MockerFixture

from tests.ut.base import PytestBase
from vllm_ascend.distributed.tensor_parallel import (
    _gather_along_first_dim, _gather_along_last_dim,
    _reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim,
    all_to_all_hp2sp, all_to_all_sp2hp)


class TestDistributedCommunication(PytestBase):

    @pytest.fixture(autouse=True)
    def context(self, mocker: MockerFixture):
        mocker.patch("torch.npu.current_device", return_value="cpu")
        mocker.patch("torch.distributed.get_world_size", return_value=4)
        mocker.patch("torch.distributed.get_rank", return_value=0)

    @pytest.mark.parametrize("world_size, test_tensor, expected",
                             [(1, torch.randn(8, 16), (8, 16)),
                              (4, torch.randn(8, 16), (32, 16))])
    def test_gather_along_first_dim(self, test_tensor, expected, world_size,
                                    mocker: MockerFixture):
        """test _gather_along_first_dim"""
        mocker.patch("torch.distributed.get_world_size",
                     return_value=world_size)

        result = _gather_along_first_dim(test_tensor, mocker.MagicMock())

        assert result.shape == expected

    @pytest.mark.parametrize("test_tensor, output_split_sizes, expected", [
        (torch.randn(8, 16), [5, 10, 15, 2], (32, 16)),
    ])
    def test_gather_along_first_dim_unequal_split(self, test_tensor, expected,
                                                  output_split_sizes,
                                                  mocker: MockerFixture):
        """test _gather_along_first_dim"""
        result = _gather_along_first_dim(test_tensor, mocker.MagicMock(),
                                         output_split_sizes)

        assert result.shape == expected

    @pytest.mark.parametrize("world_size, test_tensor, expected",
                             [(1, torch.randn(8, 16, 32), (8, 16, 32)),
                              (4, torch.randn(8, 16, 32), (8, 16, 32 * 4))])
    def test_gather_along_last_dim(self, test_tensor, expected, world_size,
                                   mocker: MockerFixture):
        """test _gather_along_last_dim"""
        mocker.patch("torch.distributed.get_world_size",
                     return_value=world_size)

        result = _gather_along_last_dim(test_tensor, mocker.MagicMock())

        assert result.shape == expected

    @pytest.mark.parametrize("input_shape,expected_shape", [
        ((32, 16), (8, 16)),
        ((40, 10), (10, 10)),
    ])
    def test_reduce_scatter_along_first_dim(self, input_shape, expected_shape,
                                            mocker: MockerFixture):
        input_tensor = torch.randn(*input_shape)
        result = _reduce_scatter_along_first_dim(input_tensor,
                                                 mocker.MagicMock())
        assert result.shape == expected_shape

    @pytest.mark.parametrize("input_shape,expected_shape", [
        ((8, 16, 32), (8, 16, 8)),
    ])
    def test_reduce_scatter_along_last_dim(self, input_shape, expected_shape,
                                           mocker: MockerFixture):
        input_tensor = torch.randn(*input_shape)
        result = _reduce_scatter_along_last_dim(input_tensor,
                                                mocker.MagicMock())
        assert result.shape == expected_shape

    @pytest.mark.parametrize("func,input_shape,expected_shape", [
        ("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32),
         (8, 16, 128)),
        ("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)),
        ("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32),
         (8, 16, 8)),
        ("gather_from_sequence_parallel_region", (8, 16), (32, 16)),
    ])
    def test_wrapper_functions(self, func, input_shape, expected_shape,
                               mocker: MockerFixture):
        """test wrapper funcs"""
        mod = importlib.import_module(
            'vllm_ascend.distributed.tensor_parallel')
        globals = mod.__dict__
        test_func = globals[func]
        input_tensor = torch.randn(*input_shape)
        result = test_func(input_tensor, mocker.MagicMock())
        assert result.shape == expected_shape

    @pytest.mark.parametrize(
        "input_shape,output_shape",
        [
            ((8, 16), (32, 4)),  # [num_tokens/TP, H] -> [num_tokens, H/TP]
        ])
    def test_all_to_all_sp2hp(self, input_shape, output_shape,
                              mocker: MockerFixture):
        input_tensor = torch.randn(*input_shape)
        result = all_to_all_sp2hp(input_tensor, mocker.MagicMock())
        assert result.shape == output_shape

    @pytest.mark.parametrize(
        "input_shape,output_shape",
        [
            ((32, 4), (8, 16)),  # [num_tokens, H/TP] -> [num_tokens/TP, H]
        ])
    def test_all_to_all_hp2sp(self, input_shape, output_shape,
                              mocker: MockerFixture):
        input_tensor = torch.randn(*input_shape)
        result = all_to_all_hp2sp(input_tensor, mocker.MagicMock())
        assert result.shape == output_shape
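
The parametrized shape expectations all reduce to the same bookkeeping with the world size fixed at 4 by the autouse fixture: gathering along a dimension multiplies it by TP, reduce-scatter divides it, and the sp2hp/hp2sp all-to-alls move a factor of TP between the token and hidden dimensions. A tiny sketch of that arithmetic (illustrative only, not code from the diff):

# TP stands in for the mocked torch.distributed.get_world_size() == 4.
TP = 4

def gather_first_dim(shape):          # (8, 16)  -> (32, 16)
    return (shape[0] * TP, *shape[1:])

def reduce_scatter_first_dim(shape):  # (32, 16) -> (8, 16)
    return (shape[0] // TP, *shape[1:])

def sp2hp(shape):                     # (8, 16)  -> (32, 4)
    return (shape[0] * TP, shape[1] // TP)

assert gather_first_dim((8, 16)) == (32, 16)
assert reduce_scatter_first_dim((32, 16)) == (8, 16)
assert sp2hp((8, 16)) == (32, 4)
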
tests/ut/distributed/test_parallel_state.py (new file, 44 lines)
from unittest.mock import MagicMock, patch

import pytest
from vllm.config import ParallelConfig

from vllm_ascend.distributed.parallel_state import (
    _LMTP, _MC2, destroy_ascend_model_parallel, get_lmhead_tp_group,
    get_mc2_group, init_ascend_model_parallel)


@pytest.fixture
def parallel_config():
    return ParallelConfig(data_parallel_size=2,
                          tensor_parallel_size=2,
                          pipeline_parallel_size=2)


@pytest.fixture
def mock_distributed():
    with patch('torch.distributed.is_initialized', return_value=True), \
            patch('torch.distributed.get_world_size', return_value=8), \
            patch('torch.distributed.get_backend', return_value='nccl'), \
            patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group:
        mock_group.return_value.local_rank = 0
        mock_group.return_value.device_group = MagicMock()
        yield


def test_init_ascend_model_parallel(mock_distributed, parallel_config):
    mock_ascend_config = MagicMock()
    mock_ascend_config.lmhead_tensor_parallel_size = 2
    with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
            patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
            patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
        init_ascend_model_parallel(parallel_config)

        mc2_group = get_mc2_group()
        assert mc2_group is not None
        lmheadtp_group = get_lmhead_tp_group()
        assert lmheadtp_group is not None

        destroy_ascend_model_parallel()
        assert _MC2 is None
        assert _LMTP is None