diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py index 2e1661b..58cba6d 100644 --- a/tests/ut/ops/test_fused_ops.py +++ b/tests/ut/ops/test_fused_ops.py @@ -139,7 +139,6 @@ def mock_dist_env(mocker: MockerFixture): patch('torch.distributed.all_gather'), \ patch('torch.distributed.all_to_all_single'), \ patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \ - patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \ patch('vllm.model_executor.layers.fused_moe.config.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.get_ascend_config', diff --git a/tests/ut/torchair/ops/test_torchair_fused_moe.py b/tests/ut/torchair/ops/test_torchair_fused_moe.py index 0f0353d..ec8627c 100644 --- a/tests/ut/torchair/ops/test_torchair_fused_moe.py +++ b/tests/ut/torchair/ops/test_torchair_fused_moe.py @@ -66,8 +66,6 @@ def mock_dist_env(mocker: MockerFixture): patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \ patch('vllm_ascend.torchair.ops.torchair_fused_moe.tensor_model_parallel_all_reduce', return_value=torch.randn(5, 32)), \ - patch('vllm_ascend.torchair.ops.torchair_fused_moe.data_parallel_reduce_scatter', - return_value=torch.randn(5, 32)), \ patch('vllm.model_executor.layers.fused_moe.config.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_config', diff --git a/vllm_ascend/distributed/communication_op.py b/vllm_ascend/distributed/communication_op.py deleted file mode 100644 index 2e475f5..0000000 --- a/vllm_ascend/distributed/communication_op.py +++ /dev/null @@ -1,25 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# - -import torch -from vllm.distributed.parallel_state import get_dp_group - - -def data_parallel_reduce_scatter(input_: torch.Tensor, - dim: int = -1) -> torch.Tensor: - """Reduce-Scatter the input tensor across data parallel group.""" - return get_dp_group().reduce_scatter(input_, dim) diff --git a/vllm_ascend/distributed/moe_comm_method.py b/vllm_ascend/distributed/moe_comm_method.py index aa9bae8..5c6d8c6 100644 --- a/vllm_ascend/distributed/moe_comm_method.py +++ b/vllm_ascend/distributed/moe_comm_method.py @@ -7,12 +7,11 @@ import torch.nn as nn import torch_npu from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_dp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe import FusedMoEConfig -from vllm_ascend.distributed.communication_op import \ - data_parallel_reduce_scatter from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version @@ -147,7 +146,7 @@ class AllGatherCommImpl(MoECommMethod): When TP size > 1, all-reduce the hidden states to get the final output. """ if self.moe_config.dp_size > 1: - hidden_states = data_parallel_reduce_scatter(hidden_states, dim=0) + hidden_states = get_dp_group().reduce_scatter(hidden_states, 0) hidden_states = hidden_states[:self.num_tokens] if reduce_results and (self.moe_config.tp_size > 1 diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 11c4ec5..e5b4dff 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -40,8 +40,6 @@ from vllm.model_executor.layers.quantization.base_config import \ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState -from vllm_ascend.distributed.communication_op import \ - data_parallel_reduce_scatter from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.layers.experts_selector import select_experts @@ -537,8 +535,8 @@ class AscendFusedMoE(FusedMoE): final_hidden_states = final_hidden_states[start:end, :] dispose_tensor(e_hidden_states) elif fused_moe_state == FusedMoEState.AllGather: - final_hidden_states = data_parallel_reduce_scatter( - e_hidden_states, dim=0) + final_hidden_states = get_dp_group().reduce_scatter( + e_hidden_states, 0) final_hidden_states = final_hidden_states[:num_tokens] dispose_tensor(e_hidden_states) else: diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py index bd2be21..1bab215 100644 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py @@ -40,8 +40,6 @@ from vllm.model_executor.layers.quantization.base_config import \ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState -from vllm_ascend.distributed.communication_op import \ - data_parallel_reduce_scatter from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.sequence_parallel import MetadataForPadding @@ -1269,8 +1267,8 @@ class TorchairAscendFusedMoE(FusedMoE): final_hidden_states = final_hidden_states[start:end, :] dispose_tensor(e_hidden_states) elif fused_moe_state == FusedMoEState.AllGather: - final_hidden_states = data_parallel_reduce_scatter( - e_hidden_states, dim=0) + final_hidden_states = get_dp_group().reduce_scatter( + e_hidden_states, 0) final_hidden_states = final_hidden_states[:num_tokens] dispose_tensor(e_hidden_states) else: