diff --git a/tests/ut/distributed/test_distributed_tensor_parallel.py b/tests/ut/distributed/test_distributed_tensor_parallel.py deleted file mode 100644 index 48a88fa..0000000 --- a/tests/ut/distributed/test_distributed_tensor_parallel.py +++ /dev/null @@ -1,139 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. - -import importlib - -import pytest -import torch -from pytest_mock import MockerFixture - -from tests.ut.base import PytestBase -from vllm_ascend.distributed.tensor_parallel import ( - _gather_along_first_dim, _gather_along_last_dim, - _reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim, - all_to_all_hp2sp, all_to_all_sp2hp) - - -class TestDistributedCommunication(PytestBase): - - @pytest.fixture(autouse=True) - def context(self, mocker: MockerFixture): - mocker.patch("torch.npu.current_device", return_value="cpu") - mocker.patch("torch.distributed.get_world_size", return_value=4) - - mocker.patch("torch.distributed.get_rank", return_value=0) - - @pytest.mark.parametrize("world_size, test_tensor, expected", - [(1, torch.randn(8, 16), (8, 16)), - (4, torch.randn(8, 16), (32, 16))]) - def test_gather_along_first_dim(self, test_tensor, expected, world_size, - mocker: MockerFixture): - """test _gather_along_first_dim""" - mocker.patch("torch.distributed.get_world_size", - 
return_value=world_size) - - result = _gather_along_first_dim(test_tensor, mocker.MagicMock()) - - assert result.shape == expected - - @pytest.mark.parametrize("test_tensor, output_split_sizes, expected", [ - (torch.randn(8, 16), [5, 10, 15, 2], (32, 16)), - ]) - def test_gather_along_first_dim_unequal_split(self, test_tensor, expected, - output_split_sizes, - mocker: MockerFixture): - """test _gather_along_first_dim""" - - result = _gather_along_first_dim(test_tensor, mocker.MagicMock(), - output_split_sizes) - - assert result.shape == expected - - @pytest.mark.parametrize("world_size, test_tensor, expected", - [(1, torch.randn(8, 16, 32), (8, 16, 32)), - (4, torch.randn(8, 16, 32), (8, 16, 32 * 4))]) - def test_gather_along_last_dim(self, test_tensor, expected, world_size, - mocker: MockerFixture): - """test _gather_along_last_dim""" - mocker.patch("torch.distributed.get_world_size", - return_value=world_size) - - result = _gather_along_last_dim(test_tensor, mocker.MagicMock()) - - assert result.shape == expected - - @pytest.mark.parametrize("input_shape,expected_shape", [ - ((32, 16), (8, 16)), - ((40, 10), (10, 10)), - ]) - def test_reduce_scatter_along_first_dim(self, input_shape, expected_shape, - mocker: MockerFixture): - input_tensor = torch.randn(*input_shape) - result = _reduce_scatter_along_first_dim(input_tensor, - mocker.MagicMock()) - assert result.shape == expected_shape - - @pytest.mark.parametrize("input_shape,expected_shape", [ - ((8, 16, 32), (8, 16, 8)), - ]) - def test_reduce_scatter_along_last_dim(self, input_shape, expected_shape, - mocker: MockerFixture): - input_tensor = torch.randn(*input_shape) - result = _reduce_scatter_along_last_dim(input_tensor, - mocker.MagicMock()) - assert result.shape == expected_shape - - @pytest.mark.parametrize("func,input_shape,expected_shape", [ - ("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32), - (8, 16, 128)), - ("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)), - 
("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32), - (8, 16, 8)), - ("gather_from_sequence_parallel_region", (8, 16), (32, 16)), - ]) - def test_wrapper_functions(self, func, input_shape, expected_shape, - mocker: MockerFixture): - """test wrapper funcs""" - mod = importlib.import_module( - 'vllm_ascend.distributed.tensor_parallel') - globals = mod.__dict__ - test_func = globals[func] - input_tensor = torch.randn(*input_shape) - result = test_func(input_tensor, mocker.MagicMock()) - assert result.shape == expected_shape - - @pytest.mark.parametrize( - "input_shape,output_shape", - [ - ((8, 16), (32, 4)), # [num_tokens/TP, H] -> [num_tokens, H/TP] - ]) - def test_all_to_all_sp2hp(self, input_shape, output_shape, - mocker: MockerFixture): - input_tensor = torch.randn(*input_shape) - result = all_to_all_sp2hp(input_tensor, mocker.MagicMock()) - assert result.shape == output_shape - - @pytest.mark.parametrize( - "input_shape,output_shape", - [ - ((32, 4), (8, 16)), # [num_tokens, H/TP] -> [num_tokens/TP, H] - ]) - def test_all_to_all_hp2sp(self, input_shape, output_shape, - mocker: MockerFixture): - input_tensor = torch.randn(*input_shape) - result = all_to_all_hp2sp(input_tensor, mocker.MagicMock()) - assert result.shape == output_shape diff --git a/tests/ut/ops/test_comm_utils.py b/tests/ut/ops/test_comm_utils.py new file mode 100644 index 0000000..5b4071c --- /dev/null +++ b/tests/ut/ops/test_comm_utils.py @@ -0,0 +1,98 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + +import pytest +import torch +from pytest_mock import MockerFixture + +from tests.ut.base import PytestBase +from vllm_ascend.ops.moe.comm_utils import ( + _gather_along_first_dim, async_all_to_all, + gather_from_sequence_parallel_region) + + +class TestDistributedCommunication(PytestBase): + + @pytest.fixture(autouse=True) + def context(self, mocker: MockerFixture): + mocker.patch("torch.npu.current_device", return_value="cpu") + mocker.patch("torch.distributed.get_world_size", return_value=4) + + mocker.patch("torch.distributed.get_rank", return_value=0) + + @pytest.mark.parametrize( + "input_tensor, output_split_sizes, input_split_sizes", + [(torch.randn(8, 16), [2, 2, 2, 2], [2, 2, 2, 2]), + (torch.randn(16, 32), None, None)]) + def test_async_all_to_all(self, input_tensor, output_split_sizes, + input_split_sizes, mocker: MockerFixture): + """Test async_all_to_all""" + mock_group = mocker.MagicMock() + mocker.patch("torch.distributed.all_to_all_single", + return_value=mocker.MagicMock()) + + _, a2a_out, handle = async_all_to_all(input_tensor, output_split_sizes, + input_split_sizes, mock_group) + + # Check if the output tensor is created properly + if output_split_sizes is None: + assert a2a_out.shape == input_tensor.shape + else: + total_output_size = sum(output_split_sizes) + expected_shape = [total_output_size] + list( + input_tensor.size())[1:] + assert a2a_out.shape == torch.Size(expected_shape) + + # Ensure handle is returned from async operation + assert handle is not None + assert 
isinstance(handle, mocker.MagicMock) + + @pytest.mark.parametrize("world_size, test_tensor, expected", + [(1, torch.randn(8, 16), (8, 16)), + (4, torch.randn(8, 16), (32, 16))]) + def test_gather_along_first_dim(self, test_tensor, expected, world_size, + mocker: MockerFixture): + """Test _gather_along_first_dim""" + mocker.patch("torch.distributed.get_world_size", + return_value=world_size) + + result = _gather_along_first_dim(test_tensor, mocker.MagicMock()) + + assert result.shape == expected + + @pytest.mark.parametrize("input_tensor, output_split_sizes", + [(torch.randn(8, 16), None), + (torch.randn(8, 16), [2, 2, 2, 2])]) + def test_gather_from_sequence_parallel_region(self, input_tensor, + output_split_sizes, + mocker: MockerFixture): + """Test gather_from_sequence_parallel_region""" + mock_group = mocker.MagicMock() + + result = gather_from_sequence_parallel_region(input_tensor, mock_group, + output_split_sizes) + + # If output_split_sizes is not provided, result should have expanded first dimension by world size + if output_split_sizes is None: + expected_shape = [input_tensor.shape[0] * 4] + list( + input_tensor.shape[1:]) + assert result.shape == torch.Size(expected_shape) + else: + # If output_split_sizes is provided, result shape is dictated by sum of output_split_sizes + expected_shape = [sum(output_split_sizes)] + list( + input_tensor.shape[1:]) + assert result.shape == torch.Size(expected_shape) diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index 1416d41..4273d26 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -348,7 +348,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase): self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16) # Mock async_all_to_all - patcher6 = patch('vllm_ascend.ops.comm_utils.async_all_to_all') + patcher6 = patch('vllm_ascend.ops.moe.comm_utils.async_all_to_all') self.mock_async_all_to_all = patcher6.start() 
self.addCleanup(patcher6.stop) self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16), diff --git a/vllm_ascend/distributed/tensor_parallel.py b/vllm_ascend/distributed/tensor_parallel.py deleted file mode 100644 index 3fff0a7..0000000 --- a/vllm_ascend/distributed/tensor_parallel.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapts from: Megatron/megatron/core/tensor_parallel/mappings.py. -# This file is a part of the vllm-ascend project. -import torch - - -def _gather_along_first_dim(input_, group, output_split_sizes=None): - """Gather tensors and concatenate along the first dimension. - - Args: - input_tensor (torch.Tensor): - A tensor to be gathered. - output_split_sizes (List[int], optional): - A list specifying the sizes of the output splits along the first dimension. - If None, equal splitting is assumed. Default: None. - - Returns: - torch.Tensor: Gathered tensor. - """ - world_size = torch.distributed.get_world_size(group) - # Bypass the function if we are using only 1 GPU. 
- if world_size == 1: - return input_ - - dim_size = list(input_.size()) - if output_split_sizes is None: - dim_size[0] = dim_size[0] * world_size - - output = torch.empty(dim_size, - dtype=input_.dtype, - device=torch.npu.current_device()) - torch.distributed.all_gather_into_tensor(output, - input_.contiguous(), - group=group) - else: - dim_size[0] = sum(output_split_sizes) - output = torch.empty(dim_size, - dtype=input_.dtype, - device=torch.npu.current_device()) - output_tensor_list = list( - torch.split(output, output_split_sizes, dim=0)) - torch.distributed.all_gather(output_tensor_list, input_, group=group) - - return output - - -def _gather_along_last_dim(input_, group): - """Gather tensors and concatenate along the last dimension.""" - - world_size = torch.distributed.get_world_size(group) - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - dim_size = list(input_.size()) - dim_size[0] = dim_size[0] * world_size - - output = torch.empty(dim_size, - dtype=input_.dtype, - device=torch.npu.current_device()) - torch.distributed.all_gather_into_tensor(output, - input_.contiguous(), - group=group) - tensor_list = output.chunk(world_size, dim=0) - output = torch.cat(tensor_list, dim=-1).contiguous() - - return output - - -def _reduce_scatter_along_first_dim(input_, - group, - input_split_sizes=None, - use_global_buffer=False): - """Reduce-scatter the input tensor across model parallel group. - - Args: - input_ (torch.Tensor): The input tensor to be reduce-scattered. - input_split_sizes (List[int], optional): A list specifying the sizes of - the input splits along the first dimension for each rank. If None, - equal splitting is assumed. Default: None. - """ - world_size = torch.distributed.get_world_size(group) - # Bypass the function if we are using only 1 GPU. 
- if world_size == 1: - return input_ - - if input_split_sizes is None: - dim_size = list(input_.size()) - assert ( - dim_size[0] % world_size == 0 - ), "First dimension of the tensor should be divisible by tensor parallel size" - - dim_size[0] = dim_size[0] // world_size - - output = torch.empty(dim_size, - dtype=input_.dtype, - device=torch.npu.current_device()) - torch.distributed.reduce_scatter_tensor(output, - input_.contiguous(), - group=group) - else: - rank = torch.distributed.get_rank(group) - input_tensor_list = list(torch.split(input_, input_split_sizes, dim=0)) - - output = torch.empty_like(input_tensor_list[rank]) - torch.distributed.reduce_scatter(output, - input_tensor_list, - group=group) - return output - - -def _reduce_scatter_along_last_dim(input_, group): - """Reduce-scatter tensors on the last dimension.""" - world_size = torch.distributed.get_world_size(group) - target_shape = list(input_.size()) - target_shape[-1] = target_shape[-1] // world_size - input_ = input_.reshape(-1, input_.shape[-1]) - split_tensors = torch.split(input_, - split_size_or_sections=input_.shape[-1] // - world_size, - dim=1) - concat_tensor = torch.cat(split_tensors, dim=0) - output = _reduce_scatter_along_first_dim(concat_tensor, - group).reshape(target_shape) - return output - - -def all_gather_last_dim_from_tensor_parallel_region(input_, group): - """Wrapper for autograd function: forward: AG, backward RS """ - return _gather_along_last_dim(input_, group) - - -def reduce_scatter_to_sequence_parallel_region(input_, - group, - input_split_sizes=None): - """Wrapper for autograd function: forward: RS, backward AG """ - return _reduce_scatter_along_first_dim(input_, group, input_split_sizes) - - -def reduce_scatter_last_dim_to_tensor_parallel_region(input_, group): - """Wrapper for autograd function: forward: RS, backward AG: AG """ - return _reduce_scatter_along_last_dim(input_, group) - - -def gather_from_sequence_parallel_region( - input_, - group, - 
output_split_sizes=None, -): - """Wrapper for autograd function: forward: AG, backward: RS """ - return _gather_along_first_dim(input_, group, output_split_sizes) - - -def all_to_all(group, input, output_split_sizes=None, input_split_sizes=None): - world_size = torch.distributed.get_world_size(group=group) - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input - - input = input.contiguous() - if output_split_sizes is None: - # Equal split (all2all) - output = torch.empty_like(input) - else: - # Unequal split (all2all-v) - output = input.new_empty( - size=[sum(output_split_sizes)] + list(input.size()[1:]), - dtype=input.dtype, - device=torch.npu.current_device(), - ) - torch.distributed.all_to_all_single( - output, - input, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - ) - return output - - -def all_to_all_sp2hp(input_, group): - """ - Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape - [num_tokens/TP, H] to [num_tokens, H/TP]. - - Args: - input_ (torch.Tensor): - The input tensor which has been distributed along the sequence - dimension. - - Returns: - torch.Tensor: The output tensor with shape [num_tokens, H/TP]. - - """ - if group is None: - return input_ - world_size = torch.distributed.get_world_size(group=group) - tp_group = group - input_ = input_.reshape(-1, input_.shape[-1]) - split_tensors = torch.split(input_, - split_size_or_sections=input_.shape[-1] // - world_size, - dim=1) - concat_tensor = torch.cat(split_tensors, dim=0) - output = all_to_all(tp_group, concat_tensor) - return output - - -def all_to_all_hp2sp(input_, group): - """ - Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape - [num_tokens, H/TP] to [num_tokens/TP, H]. - - Args: - input_ (torch.Tensor): - The input tensor which has been distributed along the hidden - dimension. 
- - Returns: - torch.Tensor: The output tensor with shape [num_tokens/TP, H]. - """ - if group is None: - return input_ - world_size = torch.distributed.get_world_size(group=group) - input_ = input_.reshape(-1, input_.shape[-1]) - tp_group = group - input_exchanged = all_to_all(tp_group, input_) - input_reshaped = input_exchanged.reshape(-1, input_exchanged.shape[-1]) - split_tensors = torch.split( - input_reshaped, - split_size_or_sections=input_reshaped.shape[0] // world_size, - dim=0) - output = torch.cat(split_tensors, dim=-1) - return output diff --git a/vllm_ascend/ops/comm_utils.py b/vllm_ascend/ops/moe/comm_utils.py similarity index 55% rename from vllm_ascend/ops/comm_utils.py rename to vllm_ascend/ops/moe/comm_utils.py index e893049..b8952a9 100644 --- a/vllm_ascend/ops/comm_utils.py +++ b/vllm_ascend/ops/moe/comm_utils.py @@ -1,5 +1,7 @@ +# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. # Copyright 2023 The vLLM team. +# This file is a part of the vllm-ascend project. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# This file is a part of the vllm-ascend project. +# import torch import torch.distributed import torch.distributed as dist @@ -60,3 +62,52 @@ def async_all_to_all(input_, group=group, async_op=True) return input_, a2a_out, handle + + +def _gather_along_first_dim(input_, group, output_split_sizes=None): + """Gather tensors and concatenate along the first dimension. + + Args: + input_ (torch.Tensor): + A tensor to be gathered. + output_split_sizes (List[int], optional): + A list specifying the sizes of the output splits along the first dimension. 
+ If None, equal splitting is assumed. Default: None. + + Returns: + torch.Tensor: Gathered tensor. + """ + world_size = torch.distributed.get_world_size(group) + # Nothing to gather when the group has a single rank. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + if output_split_sizes is None: + dim_size[0] = dim_size[0] * world_size + + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + torch.distributed.all_gather_into_tensor(output, + input_.contiguous(), + group=group) + else: + dim_size[0] = sum(output_split_sizes) + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + output_tensor_list = list( + torch.split(output, output_split_sizes, dim=0)) + torch.distributed.all_gather(output_tensor_list, input_, group=group) + + return output + + +def gather_from_sequence_parallel_region( + input_, + group, + output_split_sizes=None, +): + """All-gather input_ along the first dimension across group (forward only).""" + return _gather_along_first_dim(input_, group, output_split_sizes) \ No newline at end of file diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py index 72a1c34..f3aba2b 100644 --- a/vllm_ascend/ops/moe/token_dispatcher.py +++ b/vllm_ascend/ops/moe/token_dispatcher.py @@ -30,9 +30,8 @@ from vllm.distributed.parallel_state import get_ep_group import vllm_ascend.envs as envs_ascend from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.distributed.tensor_parallel import \ - gather_from_sequence_parallel_region -from vllm_ascend.ops.comm_utils import async_all_to_all +from vllm_ascend.ops.moe.comm_utils import ( + async_all_to_all, gather_from_sequence_parallel_region) from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version _Dispatchers: Dict[str, Any] = {}