#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import torch
import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase,
)


class NPUCommunicator(DeviceCommunicatorBase):
    """Device communicator that runs collectives on the current NPU device."""

    def __init__(
        self,
        cpu_group: dist.ProcessGroup,
        device: torch.device | None = None,
        device_group: dist.ProcessGroup | None = None,
        unique_name: str = "",
    ) -> None:
        """Initialise the base communicator, then bind to the current NPU.

        Args:
            cpu_group: Process group forwarded to the base class.
            device: Device forwarded to the base class; ``self.device`` is
                overwritten afterwards with the current NPU device.
            device_group: Process group used by :meth:`all_to_all`.
            unique_name: Name forwarded to the base class.
        """
        super().__init__(cpu_group, device, device_group, unique_name)
        # TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
        # init device according to rank
        self.device = torch.npu.current_device()

    def all_to_all(
        self,
        input_: torch.Tensor,
        scatter_dim: int = 0,
        gather_dim: int = -1,
        scatter_sizes: list[int] | None = None,
        gather_sizes: list[int] | None = None,
    ) -> torch.Tensor:
        """Exchange slices of ``input_`` between all ranks of the device group.

        ``input_`` is cut into per-rank pieces along ``scatter_dim``, the
        pieces are exchanged with ``dist.all_to_all`` over
        ``self.device_group``, and the received pieces are concatenated along
        ``gather_dim``.

        Args:
            input_: Tensor to redistribute.
            scatter_dim: Dimension along which ``input_`` is split. A negative
                value is offset by ``input_.dim()``.
            gather_dim: Dimension along which received pieces are
                concatenated. A negative value is offset by ``input_.dim()``.
            scatter_sizes: Split sizes along ``scatter_dim``, passed to
                ``torch.split``. Only used when ``gather_sizes`` is also given.
            gather_sizes: Size along ``gather_dim`` of the piece received from
                each rank. Only used when ``scatter_sizes`` is also given.

        Returns:
            A contiguous tensor built by concatenating the received pieces
            along ``gather_dim``.
        """
        ndim = input_.dim()
        scatter_dim = scatter_dim + ndim if scatter_dim < 0 else scatter_dim
        gather_dim = gather_dim + ndim if gather_dim < 0 else gather_dim

        uneven = scatter_sizes is not None and gather_sizes is not None
        if uneven:
            pieces = torch.split(input_, scatter_sizes, scatter_dim)
            send_chunks = [piece.contiguous() for piece in pieces]

            # Receive buffers share the shape of this rank's own chunk except
            # along gather_dim, where each peer contributes gather_sizes[peer].
            own_shape = send_chunks[self.rank].size()

            def _recv_shape(peer: int) -> list[int]:
                shape = list(own_shape)
                shape[gather_dim] = gather_sizes[peer]
                return shape

            recv_chunks = [
                torch.empty(
                    _recv_shape(peer),
                    dtype=input_.dtype,
                    device=input_.device,
                )
                for peer in range(self.world_size)
            ]
        else:
            # tensor_split with an integer section count always yields exactly
            # world_size pieces, so one receive buffer per piece is enough.
            pieces = torch.tensor_split(input_, self.world_size, scatter_dim)
            send_chunks = [piece.contiguous() for piece in pieces]
            recv_chunks = [torch.empty_like(chunk) for chunk in send_chunks]

        dist.all_to_all(recv_chunks, send_chunks, group=self.device_group)
        return torch.cat(recv_chunks, dim=gather_dim).contiguous()