Files
xc-llm-ascend/vllm_ascend/ops/moe/comm_utils.py
wuweiqiang24 9615dea3a7 Refactor tensor_parallel and comm_utils (#2814)
### What this PR does / why we need it?
1. Move ops/comm_utils to ops/moe/comm_utils
2. Move distributed/tensor_parallel/gather_from_sequence_parallel_region
to ops/moe/comm_utils
3. Delete distributed/tensor_parallel

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
e2e & ut

- vLLM version: main
- vLLM main:
a1213fae5f

---------

Signed-off-by: wuweiqiang24 <1005334931@qq.com>
Signed-off-by: wuweiqiang24 <wuweiqiang11@huawei.com>
2025-09-11 21:26:36 +08:00

113 lines
4.0 KiB
Python

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torch.distributed
import torch.distributed as dist
import torch_npu
# Side stream on which async_all_to_all issues its collective when the caller
# passes an event. Stays None until the first such call, then is created once
# and reused by later calls.
COMM_STREAM = None
def async_all_to_all(input_,
                     output_split_sizes,
                     input_split_sizes,
                     group,
                     event=None):
    """Launch a non-blocking all-to-all over ``group`` and return its handle.

    Args:
        input_: Tensor to scatter; a contiguous version of it is sent.
        output_split_sizes: Per-rank receive sizes along dim 0, or None for
            an equal split (the receive buffer then mirrors ``input_``).
        input_split_sizes: Per-rank send sizes along dim 0, forwarded as-is
            to ``dist.all_to_all_single``.
        group: Process group to communicate over.
        event: Optional object exposing ``wait()``. When truthy, the
            collective is issued on the module-level ``COMM_STREAM`` after
            waiting on this event; otherwise it is issued on the current
            stream.

    Returns:
        Tuple ``(input_, a2a_out, handle)``: the untouched input, the receive
        buffer, and the object returned by ``dist.all_to_all_single`` with
        ``async_op=True``.
    """
    global COMM_STREAM

    # Allocate the receive buffer.
    if output_split_sizes is not None:
        # Unequal split (all2all-v): dim 0 is the total of what peers send.
        recv_shape = [sum(output_split_sizes), *input_.size()[1:]]
        a2a_out = input_.new_empty(
            size=recv_shape,
            dtype=input_.dtype,
            device=torch.npu.current_device(),
        )
    else:
        # Equal split (all2all): same layout as the input.
        a2a_out = torch.empty_like(input_)

    def _launch():
        # Kept as a closure so ``input_.contiguous()`` and the collective are
        # evaluated under whichever stream context is active at call time.
        return dist.all_to_all_single(
            a2a_out,
            input_.contiguous(),
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
            async_op=True,
        )

    if not event:
        return input_, a2a_out, _launch()

    # Multi-stream path: create the shared side stream on first use.
    if COMM_STREAM is None:
        COMM_STREAM = torch_npu.npu.Stream(device=torch.npu.current_device())
    with torch_npu.npu.stream(COMM_STREAM):
        # Make the side stream wait for the caller's event before sending.
        event.wait()
        handle = _launch()
    return input_, a2a_out, handle
def _gather_along_first_dim(input_, group, output_split_sizes=None):
    """Gather tensors from the ranks of ``group`` and concatenate along dim 0.

    Args:
        input_ (torch.Tensor):
            This rank's tensor to be gathered.
        group:
            Process group to gather over.
        output_split_sizes (List[int], optional):
            Sizes of the output splits along the first dimension, one entry
            per gathered tensor. If None, equal splitting is assumed and the
            output's first dimension is ``input_.size(0) * world_size``.
            Default: None.

    Returns:
        torch.Tensor: The gathered tensor. When the group has a single rank,
        ``input_`` itself is returned unchanged.
    """
    world_size = torch.distributed.get_world_size(group)
    # Bypass the collective entirely when the group has only one rank.
    if world_size == 1:
        return input_

    # Both collectives below are given a contiguous send buffer. Previously
    # only the equal-split path did this, so a non-contiguous input was handled
    # differently depending on the split mode. This is a no-op for inputs that
    # are already contiguous.
    send_buf = input_.contiguous()
    device = torch.npu.current_device()
    dim_size = list(input_.size())

    if output_split_sizes is None:
        # Equal split: every rank contributes the same number of rows.
        dim_size[0] = dim_size[0] * world_size
        output = torch.empty(dim_size, dtype=input_.dtype, device=device)
        torch.distributed.all_gather_into_tensor(output,
                                                 send_buf,
                                                 group=group)
    else:
        # Unequal split: the output holds the sum of all per-rank row counts.
        dim_size[0] = sum(output_split_sizes)
        output = torch.empty(dim_size, dtype=input_.dtype, device=device)
        # Split ``output`` into per-rank chunks so the gather fills the final
        # buffer directly instead of going through temporaries.
        output_tensor_list = list(
            torch.split(output, output_split_sizes, dim=0))
        torch.distributed.all_gather(output_tensor_list,
                                     send_buf,
                                     group=group)
    return output
def gather_from_sequence_parallel_region(
    input_,
    group,
    output_split_sizes=None,
):
    """All-gather ``input_`` across ``group`` along the first dimension.

    Public entry point that forwards its arguments unchanged to
    ``_gather_along_first_dim`` and returns that function's result. The call
    is made directly; no autograd function is involved here.
    """
    gathered = _gather_along_first_dim(
        input_,
        group,
        output_split_sizes=output_split_sizes,
    )
    return gathered