diff --git a/vllm_ascend/distributed/communication_op.py b/vllm_ascend/distributed/communication_op.py
new file mode 100644
index 0000000..2e475f5
--- /dev/null
+++ b/vllm_ascend/distributed/communication_op.py
@@ -0,0 +1,39 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+import torch
+from vllm.distributed.parallel_state import get_dp_group
+
+
+def data_parallel_reduce_scatter(input_: torch.Tensor,
+                                 dim: int = -1) -> torch.Tensor:
+    """Reduce-scatter ``input_`` across the data parallel group.
+
+    Forwards to the ``reduce_scatter`` method of the group returned by
+    ``get_dp_group()``.
+
+    Args:
+        input_: Tensor handed to the collective.
+        dim: Dimension forwarded to ``reduce_scatter``; defaults to -1.
+
+    Returns:
+        The tensor returned by the group's ``reduce_scatter``.
+        NOTE(review): reduction op and output shape are decided by that
+        method and are not visible here -- confirm against vLLM.
+    """
+    dp_group = get_dp_group()
+    return dp_group.reduce_scatter(input_, dim)
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 0197cb3..7647536 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -39,6 +39,8 @@ from vllm.model_executor.layers.quantization.base_config import \ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.distributed.communication_op import \ + data_parallel_reduce_scatter from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.utils import (FusedMoEState, dispose_tensor, @@ -1342,11 +1344,8 @@ class AscendFusedMoE(FusedMoE): final_hidden_states = final_hidden_states[start:end, :] dispose_tensor(e_hidden_states) elif
fused_moe_state == FusedMoEState.AllGather: - final_hidden_states = dist._functional_collectives.reduce_scatter_tensor( - e_hidden_states, - "sum", - scatter_dim=0, - group=get_dp_group().device_group) + final_hidden_states = data_parallel_reduce_scatter( + e_hidden_states, dim=0) final_hidden_states = final_hidden_states[:num_tokens] dispose_tensor(e_hidden_states) else: