[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
145
vllm/distributed/device_communicators/cpu_communicator.py
Normal file
145
vllm/distributed/device_communicators/cpu_communicator.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
|
||||
class CpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def __init__(self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
self.dist_module = torch.distributed
|
||||
|
||||
if (current_platform.get_cpu_architecture()
|
||||
== CpuArchEnum.X86) and hasattr(
|
||||
torch.ops._C,
|
||||
"init_shm_manager") and unique_name.startswith("tp"):
|
||||
self.dist_module = _CPUSHMDistributed(self)
|
||||
|
||||
def all_reduce(self, input_):
|
||||
self.dist_module.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def gather(self,
|
||||
input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
dim: int = -1) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
NOTE: We assume that the input tensor is on the same device across
|
||||
all the ranks.
|
||||
NOTE: `dst` is the local rank of the destination rank.
|
||||
"""
|
||||
world_size = self.world_size
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Allocate output tensor.
|
||||
if self.rank_in_group == dst:
|
||||
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
else:
|
||||
gather_list = None
|
||||
|
||||
# Gather.
|
||||
self.dist_module.gather(input_,
|
||||
gather_list,
|
||||
dst=self.ranks[dst],
|
||||
group=self.device_group)
|
||||
|
||||
if self.rank_in_group == dst:
|
||||
output_tensor = torch.cat(gather_list, dim=dim)
|
||||
else:
|
||||
output_tensor = None
|
||||
return output_tensor
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# NOTE: we have to use concat-style all-gather here,
|
||||
# stack-style all-gather has compatibility issues with
|
||||
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
||||
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(output_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
# All-gather.
|
||||
self.dist_module.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(self.world_size *
|
||||
input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
return output_tensor
|
||||
|
||||
|
||||
class _CPUSHMDistributed:
|
||||
|
||||
def __init__(self, communicator: CpuCommunicator):
|
||||
instance_identifier = os.environ["VLLM_DIST_IDENT"]
|
||||
unique_name = communicator.unique_name
|
||||
instance_identifier = f"{instance_identifier}-{unique_name}"
|
||||
self.communicator = communicator
|
||||
|
||||
group_ranks = [str(rank) for rank in self.communicator.ranks]
|
||||
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
|
||||
self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
|
||||
|
||||
self.handle = self._init_cpu_shm()
|
||||
|
||||
def _init_cpu_shm(self) -> int:
|
||||
handle = torch.ops._C.init_shm_manager(
|
||||
self.group_name,
|
||||
self.communicator.world_size,
|
||||
self.communicator.rank,
|
||||
)
|
||||
torch.distributed.barrier(self.communicator.device_group)
|
||||
torch.ops._C.join_shm_manager(
|
||||
handle,
|
||||
self.group_name,
|
||||
)
|
||||
torch.distributed.barrier(self.communicator.device_group)
|
||||
|
||||
return handle
|
||||
|
||||
def all_reduce(self,
|
||||
input: torch.Tensor,
|
||||
group: Optional[ProcessGroup] = None) -> None:
|
||||
torch.ops._C.shm_allreduce(self.handle, input)
|
||||
|
||||
def gather(self,
|
||||
input: torch.Tensor,
|
||||
gather_list: Optional[list[torch.Tensor]],
|
||||
dst: int = -1,
|
||||
group: Optional[ProcessGroup] = None) -> None:
|
||||
# Note: different from the torch gather, here we use local dst rank.
|
||||
torch.ops._C.shm_gather(self.handle, input, gather_list,
|
||||
torch.distributed.get_group_rank(group, dst))
|
||||
|
||||
def all_gather_into_tensor(self,
|
||||
output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
group: Optional[ProcessGroup] = None) -> None:
|
||||
torch.ops._C.shm_all_gather(self.handle, input, output)
|
||||
Reference in New Issue
Block a user