Upgrade to vllm 0.17.0 corex v4.1 overlay

This commit is contained in:
2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -3,14 +3,13 @@
from typing import Any
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
from vllm.utils.import_utils import has_deep_ep, has_mori
from .base_device_communicator import All2AllManagerBase, Cache
@@ -32,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
debugging.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def naive_multicast(
self,
@@ -139,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
all-gather (dispatch) and reduce-scatter (combine).
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def dispatch_router_logits(
self,
@@ -235,107 +234,17 @@ class AgRsAll2AllManager(All2AllManagerBase):
pass
class PPLXAll2AllManager(All2AllManagerBase):
"""
All2All communication based on PPLX kernels.
"""
def __init__(self, cpu_group):
assert has_pplx(), (
"pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
" to install pplx_kernels."
)
super().__init__(cpu_group)
if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init,
)
logger.debug(
"Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
self.rank,
self.world_size,
)
uid = (
nvshmem_get_unique_id()
if self.rank == 0
else nvshmem_alloc_empty_unique_id()
)
dist.broadcast(
uid,
src=dist.get_process_group_ranks(self.cpu_group)[0],
group=self.cpu_group,
)
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, self.rank, self.world_size)
self.handle_cache = Cache()
def get_handle(self, kwargs):
import pplx_kernels as pplx # type: ignore[import-not-found]
return self.handle_cache.get_or_create(
kwargs,
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
)
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
if self.internode:
from pplx_kernels.nvshmem import (
nvshmem_finalize, # type: ignore[import-not-found]
)
logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
class DeepEPAll2AllManagerBase(All2AllManagerBase):
"""
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_deep_ep(), (
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
" to install DeepEP kernels."
) # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)
self.handle_cache = Cache()
# This is the DeepEP default. Stick to it till we can establish
@@ -373,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
raise NotImplementedError
def destroy(self):
pass
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
self.handle_cache._cache.clear()
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
@@ -381,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests.
@@ -405,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank,
explicitly_destroy=True,
)
def get_handle(self, kwargs):
@@ -438,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP Low-Latency kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(
self,
@@ -476,8 +389,9 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank,
allow_nvlink_for_low_latency_mode=True,
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
# allow_nvlink_for_low_latency_mode=True,
# allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
explicitly_destroy=True,
)
def get_handle(self, kwargs):
@@ -509,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
rank: int
world_size: int
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_flashinfer_all2all(), (
"flashinfer all2all module not found. Please install/check flashinfer"
) # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)
logger.debug(
"Initialize for flashinfer All2All rank=%d, world size=%d",
self.rank,

View File

@@ -27,6 +27,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
logger = init_logger(__name__)
KiB = 1024
MiB = 1024 * 1024
# Max size for each world size in case symmetric memory is available
# For different SM architectures
@@ -60,17 +61,44 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
},
}
# NCCL symmetric memory allreduce configuration based on H100 and GB200 benchmarks.
# PyNCCL-symm outperforms custom_AR for small and large tensor sizes,
# while custom_AR wins for mid-range sizes.
#
# Benchmark results (8 GPUs):
# 2K - 16K: PyNCCL-symm wins (1.35x - 1.48x faster)
# 32K - 64K: custom_AR wins
# 128K - 1G: PyNCCL-symm wins (1.12x - 6.14x faster)
#
# Benchmark results (4 GPUs):
# 2K - 16K: PyNCCL-symm wins (1.21x - 1.30x faster)
# 32K - 256K: custom_AR wins (1.07x - 1.35x faster)
# 512K - 1G: PyNCCL-symm wins (1.10x - 2.32x faster)
#
# The config defines ranges where custom_AR is preferred (symm_mem disabled).
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
"min_world_size": 4,
"thresholds": {
4: 2 * MiB, # 2 MB
8: 1 * MiB, # 1 MB
# Ranges where custom_AR outperforms NCCL symm_mem: (lower_bound, upper_bound)
# NCCL symm_mem will NOT be used for sizes in range: lower < size < upper
"custom_ar_preferred_ranges": {
4: (16 * KiB, 512 * KiB), # custom_AR wins for 32K-256K
8: (16 * KiB, 128 * KiB), # custom_AR wins for 32K-64K
},
"always_use_above_world_size": 8, # Always use symm mem for world_size > 8
}
def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) -> bool:
"""
Determine if NCCL symmetric memory allreduce should be used.
Based on H100 and GB200 benchmarks, NCCL symm_mem is preferred for:
- Small tensors (≤16K): Lower latency than custom_AR
- Large tensors (≥128K for 8 GPUs, ≥512K for 4 GPUs): Better bandwidth
Custom_AR is preferred for mid-range sizes where its P2P approach
has lower overhead than the symm_mem copy-in/copy-out pattern.
"""
from vllm.distributed.device_communicators.pynccl_allocator import (
is_symmetric_memory_enabled,
)
@@ -80,11 +108,20 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
if not is_symmetric_memory_enabled():
return False
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
return False
threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
if threshold is not None and input_tensor.nbytes >= threshold:
return True
tensor_size = input_tensor.nbytes
custom_ar_range = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["custom_ar_preferred_ranges"].get(
world_size
)
if custom_ar_range is not None:
lower_bound, upper_bound = custom_ar_range
# Use symm_mem for small sizes (≤ lower_bound) and large sizes (≥ upper_bound)
# Use custom_AR (not symm_mem) for mid-range sizes
return tensor_size <= lower_bound or tensor_size >= upper_bound
return world_size > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"]

View File

@@ -30,8 +30,9 @@ class All2AllManagerBase:
rank: int
world_size: int
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
self.cpu_group = cpu_group
self.tcp_store_group = tcp_store_group
# compute some common properties
from vllm.distributed.parallel_state import (
@@ -48,12 +49,17 @@ class All2AllManagerBase:
# when we create this object
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.rank = cpu_group.rank()
self.world_size = cpu_group.size()
# all2all communication often has separate implementations for
# intra-node and inter-node communication
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
if tcp_store_group is None:
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
else:
self.internode = not all(
in_the_same_node_as(tcp_store_group, source_rank=0)
)
def get_handle(self, kwargs):
# get a handle for the all2all communication,
@@ -122,17 +128,36 @@ class DeviceCommunicatorBase:
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
global_ranks: list[int] | None = None,
global_world_size: int | None = None,
):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
self.unique_name = unique_name
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
# Check if this is a stateless process group
from torch.distributed.distributed_c10d import _world
is_stateless = _world.pg_map.get(cpu_group, None) is None
if is_stateless:
# For stateless groups, we can't use torch.distributed methods
self.rank = cpu_group.rank()
self.world_size = cpu_group.size()
assert global_ranks is not None
assert global_world_size is not None
self.ranks = global_ranks
self.global_rank = self.ranks[self.rank]
self.global_world_size = global_world_size
self.rank_in_group = self.rank
else:
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
use_ep = False
all2all_backend = None
@@ -146,7 +171,7 @@ class DeviceCommunicatorBase:
use_ep = config.parallel_config.data_parallel_size > 1
all2all_backend = config.parallel_config.all2all_backend
self.is_ep_communicator = "ep" in unique_name
self.is_ep_communicator = unique_name.split(":")[0] == "ep"
self.use_all2all = self.is_ep_communicator and use_ep
self.all2all_backend = all2all_backend
self.all2all_manager: All2AllManagerBase | None = None
@@ -175,9 +200,7 @@ class DeviceCommunicatorBase:
group=self.device_group,
async_op=True)
else:
torch.distributed.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
output_tensor = output_tensor.movedim(0, dim)
@@ -263,10 +286,9 @@ class DeviceCommunicatorBase:
group=self.device_group,
async_op=True)
else:
torch.distributed.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
torch.distributed.gather(
input_, gather_list, dst=self.ranks[dst], group=self.device_group
)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
@@ -292,6 +314,13 @@ class DeviceCommunicatorBase:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
"""Broadcast a tensor from source rank to all ranks."""
if self.world_size == 1:
return tensor
torch.distributed.broadcast(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self):
pass
@@ -360,3 +389,6 @@ class DeviceCommunicatorBase:
This is a no-op in the base class.
"""
return hidden_states
def batch_isend_irecv(self, p2p_ops: list):
raise NotImplementedError

View File

@@ -35,8 +35,15 @@ class CpuCommunicator(DeviceCommunicatorBase):
)
and hasattr(torch.ops._C, "init_shm_manager")
and (unique_name.startswith("tp") or unique_name.startswith("pp"))
and self._all_group_ranks_share_shm_group_name()
):
self.dist_module = _CPUSHMDistributed(self)
elif unique_name.startswith("tp") or unique_name.startswith("pp"):
logger.info(
"CPU SHM communicator disabled for group %s: ranks do not share "
"the same SHM group name, falling back to torch.distributed.",
unique_name,
)
if self.use_all2all:
if self.all2all_backend != "naive": # type: ignore[has-type]
@@ -52,6 +59,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
def _all_group_ranks_share_shm_group_name(self) -> bool:
"""
CPUSHM requires all ranks in this group to agree on one SHM group name.
This is a lightweight consistency check for VLLM_DIST_IDENT/name inputs.
"""
local_name = _CPUSHMDistributed.make_group_name(self)
names: list[str] = [""] * self.world_size
torch.distributed.all_gather_object(
names,
local_name,
group=self.device_group,
)
return len(set(names)) == 1
def all_reduce(self, input_):
self.dist_module.all_reduce(input_, group=self.device_group)
return input_
@@ -193,16 +214,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
class _CPUSHMDistributed:
def __init__(self, communicator: CpuCommunicator):
self.communicator = communicator
self.group_name = self.make_group_name(communicator)
self.handle = self._init_cpu_shm()
@staticmethod
def make_group_name(communicator: CpuCommunicator) -> str:
instance_identifier = os.environ["VLLM_DIST_IDENT"]
unique_name = communicator.unique_name
instance_identifier = f"{instance_identifier}-{unique_name}"
self.communicator = communicator
group_ranks = [str(rank) for rank in self.communicator.ranks]
group_ranks = [str(rank) for rank in communicator.ranks]
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
self.handle = self._init_cpu_shm()
return f"{instance_identifier}-{shm_group_identifier}-cpushm"
def _init_cpu_shm(self) -> int:
thread_num_tensor = torch.tensor(

View File

@@ -16,6 +16,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
from vllm.logger import init_logger
from vllm.platforms import current_platform
from ..utils import StatelessProcessGroup
from .base_device_communicator import DeviceCommunicatorBase
import ixformer.distributed as ixfd
import os
@@ -29,8 +30,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
global_ranks: list[int] | None = None,
global_world_size: int | None = None,
tcp_store_group: StatelessProcessGroup | None = None,
):
super().__init__(cpu_group, device, device_group, unique_name)
super().__init__(
cpu_group,
device,
device_group,
unique_name,
global_ranks,
global_world_size,
)
if "tp" not in unique_name:
# custom allreduce or torch symm mem can be used only by tp
use_custom_allreduce = False
@@ -46,8 +57,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem
self.use_flashinfer_allreduce = use_flashinfer_allreduce
self.use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM",None) not in ["1", "Y", "y"]
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce,
@@ -64,7 +76,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.pynccl_comm: PyNcclCommunicator | None = None
if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
group=self.cpu_group if tcp_store_group is None else tcp_store_group,
device=self.device,
)
if is_symmetric_memory_enabled():
@@ -109,23 +121,27 @@ class CudaCommunicator(DeviceCommunicatorBase):
if self.all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
self.all2all_manager = NaiveAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "allgather_reducescatter":
from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
elif self.all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
self.all2all_manager = AgRsAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "deepep_high_throughput":
from .all2all import DeepEPHTAll2AllManager
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
self.all2all_manager = DeepEPHTAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "deepep_low_latency":
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
self.all2all_manager = DeepEPLLAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "mori":
from .all2all import MoriAll2AllManager
@@ -133,7 +149,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
elif self.all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
self.all2all_manager = FlashInferAllToAllManager(
self.cpu_group, tcp_store_group
)
else:
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
@@ -188,27 +206,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
return out
if self.world_size == 1:
return input_
if self.use_vllm_comm:
# torch.ops.ixf_ops.vllm_all_reduce(input_, async_op=True)
ixfd.all_reduce(input_, group=self.device_group, async_op=True)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
pynccl_comm = self.pynccl_comm
if pynccl_comm is None or pynccl_comm.disabled:
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
world_size = self.world_size
@@ -230,10 +233,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
)
# pynccl_comm.reduce_scatter(output, input_tensor)
torch.distributed.reduce_scatter_tensor(output,
input_tensor,
group=self.device_group)
# Perform reduce-scatter operation
ixfd.reduce_scatter_tensor(output,input_tensor,group=self.device_group, async_op=True)
# Reshape before returning
return output.movedim(0, dim).contiguous()
@@ -278,12 +279,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm
# if pynccl_comm is not None and not pynccl_comm.disabled:
# pynccl_comm.send(tensor, dst)
# else:
# torch.distributed.send(tensor, self.ranks[dst], self.device_group)
if self.use_vllm_comm:
ixfd.send(tensor, self.ranks[dst], self.device_group)
else:
@@ -298,17 +293,24 @@ class CudaCommunicator(DeviceCommunicatorBase):
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
# pynccl_comm = self.pynccl_comm
# if pynccl_comm is not None and not pynccl_comm.disabled:
# pynccl_comm.recv(tensor, src)
# else:
# torch.distributed.recv(tensor, self.ranks[src], self.device_group)
if self.use_vllm_comm:
ixfd.recv(tensor, self.ranks[src], self.device_group)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
"""Broadcast a tensor from source rank to all ranks."""
if self.world_size == 1:
return tensor
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.broadcast(tensor, src)
return tensor
else:
raise ValueError("No PyNCCL communicator found")
def destroy(self):
if self.pynccl_comm is not None:
self.pynccl_comm = None
@@ -319,7 +321,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.fi_ar_comm = None
if self.all2all_manager is not None:
self.all2all_manager.destroy()
self.all2all_manager = None
self.all2all_manager = None # type: ignore[assignment]
def all_gatherv(
self,
@@ -372,7 +374,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
extra_residual:torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
@@ -409,16 +410,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
# return self.all2all_manager.dispatch(
# hidden_states,
# topk_weights,
# topk_ids,
# is_sequence_parallel,
# extra_tensors=extra_tensors,
# )
hidden_states, extra_residual, router_logits = self.all2all_manager.dispatch(
hidden_states, extra_residual, router_logits)
return hidden_states, extra_residual, router_logits
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
@@ -432,3 +430,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
hidden_states,
is_sequence_parallel,
)
def batch_isend_irecv(self, p2p_ops: list):
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.batch_isend_irecv(p2p_ops)
else:
raise ValueError("No PyNCCL communicator found")

View File

@@ -312,10 +312,19 @@ class PyNcclCommunicator:
)
if stream is None:
stream = current_stream()
if tensor.dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
]:
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
else:
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
self.nccl.ncclSend(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
nccl_dtype,
dst,
self.comm,
cudaStream_t(stream.cuda_stream),
@@ -330,10 +339,19 @@ class PyNcclCommunicator:
)
if stream is None:
stream = current_stream()
if tensor.dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
]:
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
else:
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
self.nccl.ncclRecv(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
nccl_dtype,
src,
self.comm,
cudaStream_t(stream.cuda_stream),
@@ -384,3 +402,17 @@ class PyNcclCommunicator:
def deregister_comm_window(self, window):
return self.nccl.ncclCommWindowDeregister(self.comm, window)
def batch_isend_irecv(self, p2p_ops: list, stream=None):
if self.disabled:
return
if stream is None:
stream = current_stream()
self.group_start()
for op in p2p_ops:
if op.op is torch.distributed.isend:
self.send(op.tensor, op.group_peer, stream)
elif op.op is torch.distributed.irecv:
self.recv(op.tensor, op.group_peer, stream)
self.group_end()