# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from .base_device_communicator import DeviceCommunicatorBase


class XpuCommunicator(DeviceCommunicatorBase):

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        dist.all_reduce(input_, group=self.device_group)
        return input_

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        # On the XPU path, gather does not work correctly together with a
        # Ray cluster, so fall back to all_gather for now.
        input_size = input_.size()
        # Allocate an output tensor with a leading world_size dimension to
        # hold every rank's contribution.
        output_tensor = torch.empty((self.world_size, ) + input_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # Gather each rank's input into the leading dimension.
        dist.all_gather_into_tensor(output_tensor,
                                    input_,
                                    group=self.device_group)
        if self.rank_in_group == dst:
            # Move the gathered dimension into place and flatten it into
            # `dim`, which concatenates the per-rank tensors along `dim`.
            # E.g. with world_size == 2, per-rank shape (3, 4) and dim == 1:
            # (2, 3, 4) -> movedim -> (3, 2, 4) -> reshape -> (3, 8).
            output_tensor = output_tensor.movedim(0, dim)
            output_tensor = output_tensor.reshape(input_size[:dim] +
                                                  (self.world_size *
                                                   input_size[dim], ) +
                                                  input_size[dim + 1:])
        else:
            # Mirror torch.distributed.gather semantics: only dst gets data.
            output_tensor = None
        return output_tensor