# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.logger import init_logger

from .base_device_communicator import DeviceCommunicatorBase

logger = init_logger(__name__)


class XpuCommunicator(DeviceCommunicatorBase):
    """Device communicator for XPU built on ``torch.distributed`` collectives.

    When all2all is enabled, only the ``naive`` all2all backend is used on
    XPU; any other configured backend is replaced by ``naive`` with a warning.
    """

    def __init__(
        self,
        cpu_group: ProcessGroup,
        device: torch.device | None = None,
        device_group: ProcessGroup | None = None,
        unique_name: str = "",
    ):
        """Initialise the base communicator and, if enabled, the all2all manager.

        Args:
            cpu_group: Process group handed to the base class and to the
                naive all2all manager.
            device: Forwarded unchanged to the base class.
            device_group: Forwarded unchanged to the base class.
            unique_name: Forwarded unchanged to the base class.
        """
        super().__init__(cpu_group, device, device_group, unique_name)

        # Nothing more to set up unless all2all communication is requested.
        if not self.use_all2all:
            return

        if self.all2all_backend != "naive":
            logger.warning(
                "`%s` all2all manager is not supported on XPU. "
                "Falling back to `naive` all2all manager for XPU.",
                self.all2all_backend,
            )
            self.all2all_backend = "naive"

        if self.all2all_backend == "naive":
            # Imported lazily, only when the manager is actually needed.
            from .all2all import NaiveAll2AllManager

            self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
            logger.info("Using naive all2all manager.")

    def all_reduce(self, input_) -> torch.Tensor:
        """All-reduce ``input_`` in place over the device group and return it."""
        dist.all_reduce(input_, group=self.device_group)
        return input_

    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
    ) -> torch.Tensor | None:
        """Gather ``input_`` from every rank, joined along ``dim``, onto ``dst``.

        Every rank takes part in the underlying all-gather; only the rank
        whose ``rank_in_group`` equals ``dst`` keeps the result.

        Args:
            input_: Local tensor contributed by this rank.
            dst: Rank (within the group) that receives the gathered tensor.
            dim: Dimension along which the per-rank tensors are joined;
                negative values count from the last dimension.

        Returns:
            On ``dst``, a tensor shaped like ``input_`` except that dimension
            ``dim`` is ``world_size`` times larger; ``None`` on other ranks.
        """
        ndim = input_.dim()
        assert -ndim <= dim < ndim, (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
        if dim < 0:
            # Normalise to a non-negative index so it can be used for slicing.
            dim += ndim

        # For xpu path, gather doesn't work properly together with ray
        # cluster so we use all_gather instead for now.
        shape = input_.size()
        gathered = torch.empty(
            (self.world_size, *shape), dtype=input_.dtype, device=input_.device
        )
        dist.all_gather_into_tensor(gathered, input_, group=self.device_group)

        # Non-destination ranks took part in the collective but drop the data.
        if self.rank_in_group != dst:
            return None

        # Move the rank axis next to `dim`, then fold the two axes together.
        gathered = gathered.movedim(0, dim)
        merged_shape = (
            *shape[:dim],
            self.world_size * shape[dim],
            *shape[dim + 1 :],
        )
        return gathered.reshape(merged_shape)

    def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
        """Broadcast ``input_`` from ``src`` over the device group, in place."""
        dist.broadcast(input_, src=src, group=self.device_group)

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Dispatch hidden states and router logits via the all2all manager.

        Asserts that an all2all manager has been configured, then returns the
        pair produced by its ``dispatch`` call.
        """
        assert self.all2all_manager is not None
        dispatched_states, dispatched_logits = self.all2all_manager.dispatch(
            hidden_states, router_logits, is_sequence_parallel
        )
        return dispatched_states, dispatched_logits

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """Combine hidden states via the all2all manager.

        Asserts that an all2all manager has been configured, then returns the
        result of its ``combine`` call.
        """
        assert self.all2all_manager is not None
        return self.all2all_manager.combine(hidden_states, is_sequence_parallel)