v1.0
This commit is contained in:
95
distributed/device_communicators/xpu_communicator.py
Normal file
95
distributed/device_communicators/xpu_communicator.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XpuCommunicator(DeviceCommunicatorBase):
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
if self.use_all2all:
|
||||
if self.all2all_backend != "naive":
|
||||
logger.warning(
|
||||
"`%s` all2all manager is not supported on XPU. "
|
||||
"Falling back to `naive` all2all manager for XPU.",
|
||||
self.all2all_backend,
|
||||
)
|
||||
self.all2all_backend = "naive"
|
||||
if self.all2all_backend == "naive":
|
||||
from .all2all import NaiveAll2AllManager
|
||||
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
logger.info("Using naive all2all manager.")
|
||||
|
||||
def all_reduce(self, input_) -> torch.Tensor:
|
||||
dist.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def gather(
|
||||
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> torch.Tensor | None:
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
||||
)
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
# For xpu path, gather doesn't work properly together with ray
|
||||
# cluster so we use all_gather instead for now.
|
||||
input_size = input_.size()
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(
|
||||
(self.world_size,) + input_size, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
# All-gather.
|
||||
dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
|
||||
if self.rank_in_group == dst:
|
||||
# Reshape
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(
|
||||
input_size[:dim]
|
||||
+ (self.world_size * input_size[dim],)
|
||||
+ input_size[dim + 1 :]
|
||||
)
|
||||
else:
|
||||
output_tensor = None
|
||||
return output_tensor
|
||||
|
||||
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
|
||||
dist.broadcast(input_, src=src, group=self.device_group)
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, router_logits, is_sequence_parallel
|
||||
)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states = self.all2all_manager.combine(
|
||||
hidden_states, is_sequence_parallel
|
||||
)
|
||||
return hidden_states
|
||||
Reference in New Issue
Block a user