Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -3,14 +3,13 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import get_dp_group, get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.flashinfer import has_flashinfer_all2all
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori
|
||||
|
||||
from .base_device_communicator import All2AllManagerBase, Cache
|
||||
|
||||
@@ -32,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
debugging.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def naive_multicast(
|
||||
self,
|
||||
@@ -139,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
all-gather (dispatch) and reduce-scatter (combine).
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
@@ -235,107 +234,17 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
pass
|
||||
|
||||
|
||||
class PPLXAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on PPLX kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
assert has_pplx(), (
|
||||
"pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
|
||||
" to install pplx_kernels."
|
||||
)
|
||||
super().__init__(cpu_group)
|
||||
|
||||
if self.internode:
|
||||
# inter-node communication needs nvshmem,
|
||||
# intra-node communication uses p2p mapping directly
|
||||
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
|
||||
nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_get_unique_id,
|
||||
nvshmem_init,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
|
||||
self.rank,
|
||||
self.world_size,
|
||||
)
|
||||
uid = (
|
||||
nvshmem_get_unique_id()
|
||||
if self.rank == 0
|
||||
else nvshmem_alloc_empty_unique_id()
|
||||
)
|
||||
dist.broadcast(
|
||||
uid,
|
||||
src=dist.get_process_group_ranks(self.cpu_group)[0],
|
||||
group=self.cpu_group,
|
||||
)
|
||||
logger.debug("PPLX NVSHMEM UID = %s", uid)
|
||||
nvshmem_init(uid, self.rank, self.world_size)
|
||||
|
||||
self.handle_cache = Cache()
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
import pplx_kernels as pplx # type: ignore[import-not-found]
|
||||
|
||||
return self.handle_cache.get_or_create(
|
||||
kwargs,
|
||||
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
|
||||
)
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
with self.handle_cache._lock:
|
||||
for _, handle in self.handle_cache._cache.items():
|
||||
handle.destroy()
|
||||
|
||||
if self.internode:
|
||||
from pplx_kernels.nvshmem import (
|
||||
nvshmem_finalize, # type: ignore[import-not-found]
|
||||
)
|
||||
|
||||
logger.debug("PPLX NVSHMEM finalize")
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
assert has_deep_ep(), (
|
||||
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
|
||||
" to install DeepEP kernels."
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
self.handle_cache = Cache()
|
||||
|
||||
# This is the DeepEP default. Stick to it till we can establish
|
||||
@@ -373,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
with self.handle_cache._lock:
|
||||
for _, handle in self.handle_cache._cache.items():
|
||||
handle.destroy()
|
||||
self.handle_cache._cache.clear()
|
||||
|
||||
|
||||
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
@@ -381,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def _make_all2all_kwargs(self) -> dict[Any, Any]:
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
@@ -405,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
@@ -438,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
All2All communication based on DeepEP Low-Latency kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def _make_all2all_kwargs(
|
||||
self,
|
||||
@@ -476,8 +389,9 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
allow_nvlink_for_low_latency_mode=True,
|
||||
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
|
||||
# allow_nvlink_for_low_latency_mode=True,
|
||||
# allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
@@ -509,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
assert has_flashinfer_all2all(), (
|
||||
"flashinfer all2all module not found. Please install/check flashinfer"
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
logger.debug(
|
||||
"Initialize for flashinfer All2All rank=%d, world size=%d",
|
||||
self.rank,
|
||||
|
||||
@@ -27,6 +27,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
KiB = 1024
|
||||
MiB = 1024 * 1024
|
||||
# Max size for each world size in case symmetric memory is available
|
||||
# For different SM architectures
|
||||
@@ -60,17 +61,44 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
|
||||
},
|
||||
}
|
||||
|
||||
# NCCL symmetric memory allreduce configuration based on H100 and GB200 benchmarks.
|
||||
# PyNCCL-symm outperforms custom_AR for small and large tensor sizes,
|
||||
# while custom_AR wins for mid-range sizes.
|
||||
#
|
||||
# Benchmark results (8 GPUs):
|
||||
# 2K - 16K: PyNCCL-symm wins (1.35x - 1.48x faster)
|
||||
# 32K - 64K: custom_AR wins
|
||||
# 128K - 1G: PyNCCL-symm wins (1.12x - 6.14x faster)
|
||||
#
|
||||
# Benchmark results (4 GPUs):
|
||||
# 2K - 16K: PyNCCL-symm wins (1.21x - 1.30x faster)
|
||||
# 32K - 256K: custom_AR wins (1.07x - 1.35x faster)
|
||||
# 512K - 1G: PyNCCL-symm wins (1.10x - 2.32x faster)
|
||||
#
|
||||
# The config defines ranges where custom_AR is preferred (symm_mem disabled).
|
||||
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
|
||||
"min_world_size": 4,
|
||||
"thresholds": {
|
||||
4: 2 * MiB, # 2 MB
|
||||
8: 1 * MiB, # 1 MB
|
||||
# Ranges where custom_AR outperforms NCCL symm_mem: (lower_bound, upper_bound)
|
||||
# NCCL symm_mem will NOT be used for sizes in range: lower < size < upper
|
||||
"custom_ar_preferred_ranges": {
|
||||
4: (16 * KiB, 512 * KiB), # custom_AR wins for 32K-256K
|
||||
8: (16 * KiB, 128 * KiB), # custom_AR wins for 32K-64K
|
||||
},
|
||||
"always_use_above_world_size": 8, # Always use symm mem for world_size > 8
|
||||
}
|
||||
|
||||
|
||||
def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) -> bool:
|
||||
"""
|
||||
Determine if NCCL symmetric memory allreduce should be used.
|
||||
|
||||
Based on H100 and GB200 benchmarks, NCCL symm_mem is preferred for:
|
||||
- Small tensors (≤16K): Lower latency than custom_AR
|
||||
- Large tensors (≥128K for 8 GPUs, ≥512K for 4 GPUs): Better bandwidth
|
||||
|
||||
Custom_AR is preferred for mid-range sizes where its P2P approach
|
||||
has lower overhead than the symm_mem copy-in/copy-out pattern.
|
||||
"""
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
is_symmetric_memory_enabled,
|
||||
)
|
||||
@@ -80,11 +108,20 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
|
||||
|
||||
if not is_symmetric_memory_enabled():
|
||||
return False
|
||||
|
||||
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
|
||||
return False
|
||||
threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
|
||||
if threshold is not None and input_tensor.nbytes >= threshold:
|
||||
return True
|
||||
|
||||
tensor_size = input_tensor.nbytes
|
||||
custom_ar_range = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["custom_ar_preferred_ranges"].get(
|
||||
world_size
|
||||
)
|
||||
|
||||
if custom_ar_range is not None:
|
||||
lower_bound, upper_bound = custom_ar_range
|
||||
# Use symm_mem for small sizes (≤ lower_bound) and large sizes (≥ upper_bound)
|
||||
# Use custom_AR (not symm_mem) for mid-range sizes
|
||||
return tensor_size <= lower_bound or tensor_size >= upper_bound
|
||||
return world_size > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"]
|
||||
|
||||
|
||||
|
||||
@@ -30,8 +30,9 @@ class All2AllManagerBase:
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
self.cpu_group = cpu_group
|
||||
self.tcp_store_group = tcp_store_group
|
||||
|
||||
# compute some common properties
|
||||
from vllm.distributed.parallel_state import (
|
||||
@@ -48,12 +49,17 @@ class All2AllManagerBase:
|
||||
# when we create this object
|
||||
self.dp_rank = self.dp_group.rank_in_group
|
||||
self.dp_world_size = self.dp_group.world_size
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
self.rank = cpu_group.rank()
|
||||
self.world_size = cpu_group.size()
|
||||
|
||||
# all2all communication often has separate implementations for
|
||||
# intra-node and inter-node communication
|
||||
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
|
||||
if tcp_store_group is None:
|
||||
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
|
||||
else:
|
||||
self.internode = not all(
|
||||
in_the_same_node_as(tcp_store_group, source_rank=0)
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
# get a handle for the all2all communication,
|
||||
@@ -122,17 +128,36 @@ class DeviceCommunicatorBase:
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
global_ranks: list[int] | None = None,
|
||||
global_world_size: int | None = None,
|
||||
):
|
||||
self.device = device or torch.device("cpu")
|
||||
self.cpu_group = cpu_group
|
||||
self.device_group = device_group
|
||||
self.unique_name = unique_name
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
self.ranks = dist.get_process_group_ranks(cpu_group)
|
||||
self.global_rank = dist.get_rank()
|
||||
self.global_world_size = dist.get_world_size()
|
||||
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
|
||||
|
||||
# Check if this is a stateless process group
|
||||
from torch.distributed.distributed_c10d import _world
|
||||
|
||||
is_stateless = _world.pg_map.get(cpu_group, None) is None
|
||||
|
||||
if is_stateless:
|
||||
# For stateless groups, we can't use torch.distributed methods
|
||||
self.rank = cpu_group.rank()
|
||||
self.world_size = cpu_group.size()
|
||||
assert global_ranks is not None
|
||||
assert global_world_size is not None
|
||||
self.ranks = global_ranks
|
||||
self.global_rank = self.ranks[self.rank]
|
||||
self.global_world_size = global_world_size
|
||||
self.rank_in_group = self.rank
|
||||
else:
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
self.ranks = dist.get_process_group_ranks(cpu_group)
|
||||
self.global_rank = dist.get_rank()
|
||||
self.global_world_size = dist.get_world_size()
|
||||
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
|
||||
|
||||
use_ep = False
|
||||
all2all_backend = None
|
||||
@@ -146,7 +171,7 @@ class DeviceCommunicatorBase:
|
||||
use_ep = config.parallel_config.data_parallel_size > 1
|
||||
all2all_backend = config.parallel_config.all2all_backend
|
||||
|
||||
self.is_ep_communicator = "ep" in unique_name
|
||||
self.is_ep_communicator = unique_name.split(":")[0] == "ep"
|
||||
self.use_all2all = self.is_ep_communicator and use_ep
|
||||
self.all2all_backend = all2all_backend
|
||||
self.all2all_manager: All2AllManagerBase | None = None
|
||||
@@ -175,9 +200,7 @@ class DeviceCommunicatorBase:
|
||||
group=self.device_group,
|
||||
async_op=True)
|
||||
else:
|
||||
torch.distributed.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
@@ -263,10 +286,9 @@ class DeviceCommunicatorBase:
|
||||
group=self.device_group,
|
||||
async_op=True)
|
||||
else:
|
||||
torch.distributed.gather(input_,
|
||||
gather_list,
|
||||
dst=self.ranks[dst],
|
||||
group=self.device_group)
|
||||
torch.distributed.gather(
|
||||
input_, gather_list, dst=self.ranks[dst], group=self.device_group
|
||||
)
|
||||
if self.rank_in_group == dst:
|
||||
output_tensor = torch.cat(gather_list, dim=dim)
|
||||
else:
|
||||
@@ -292,6 +314,13 @@ class DeviceCommunicatorBase:
|
||||
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast a tensor from source rank to all ranks.

    Args:
        tensor: Tensor handed to ``torch.distributed.broadcast``.
        src: Index into ``self.ranks`` selecting the source rank.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    # With a single rank there is nobody to broadcast to, so the
    # collective is skipped entirely.
    if self.world_size != 1:
        source_rank = self.ranks[src]
        torch.distributed.broadcast(tensor, source_rank, self.device_group)
    return tensor
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
@@ -360,3 +389,6 @@ class DeviceCommunicatorBase:
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
return hidden_states
|
||||
|
||||
def batch_isend_irecv(self, p2p_ops: list):
    """Placeholder for batched point-to-point operations.

    This implementation performs no communication and always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError()
|
||||
|
||||
@@ -35,8 +35,15 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
)
|
||||
and hasattr(torch.ops._C, "init_shm_manager")
|
||||
and (unique_name.startswith("tp") or unique_name.startswith("pp"))
|
||||
and self._all_group_ranks_share_shm_group_name()
|
||||
):
|
||||
self.dist_module = _CPUSHMDistributed(self)
|
||||
elif unique_name.startswith("tp") or unique_name.startswith("pp"):
|
||||
logger.info(
|
||||
"CPU SHM communicator disabled for group %s: ranks do not share "
|
||||
"the same SHM group name, falling back to torch.distributed.",
|
||||
unique_name,
|
||||
)
|
||||
|
||||
if self.use_all2all:
|
||||
if self.all2all_backend != "naive": # type: ignore[has-type]
|
||||
@@ -52,6 +59,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
logger.info("Using naive all2all manager.")
|
||||
|
||||
def _all_group_ranks_share_shm_group_name(self) -> bool:
    """
    Report whether every rank in this group computed the same SHM group name.

    CPUSHM requires all ranks in the group to agree on one SHM group name,
    so this gathers each rank's locally computed name over
    ``self.device_group`` and checks that exactly one distinct value exists.
    It is a lightweight consistency check for VLLM_DIST_IDENT/name inputs.
    """
    name_on_this_rank = _CPUSHMDistributed.make_group_name(self)
    # One slot per rank; all_gather_object fills the list in place.
    gathered: list[str] = [""] * self.world_size
    torch.distributed.all_gather_object(
        gathered,
        name_on_this_rank,
        group=self.device_group,
    )
    distinct_names = set(gathered)
    return len(distinct_names) == 1
|
||||
|
||||
def all_reduce(self, input_):
|
||||
self.dist_module.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
@@ -193,16 +214,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
class _CPUSHMDistributed:
|
||||
def __init__(self, communicator: CpuCommunicator):
|
||||
self.communicator = communicator
|
||||
|
||||
self.group_name = self.make_group_name(communicator)
|
||||
|
||||
self.handle = self._init_cpu_shm()
|
||||
|
||||
@staticmethod
|
||||
def make_group_name(communicator: CpuCommunicator) -> str:
|
||||
instance_identifier = os.environ["VLLM_DIST_IDENT"]
|
||||
unique_name = communicator.unique_name
|
||||
instance_identifier = f"{instance_identifier}-{unique_name}"
|
||||
self.communicator = communicator
|
||||
|
||||
group_ranks = [str(rank) for rank in self.communicator.ranks]
|
||||
group_ranks = [str(rank) for rank in communicator.ranks]
|
||||
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
|
||||
self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
|
||||
|
||||
self.handle = self._init_cpu_shm()
|
||||
return f"{instance_identifier}-{shm_group_identifier}-cpushm"
|
||||
|
||||
def _init_cpu_shm(self) -> int:
|
||||
thread_num_tensor = torch.tensor(
|
||||
|
||||
@@ -16,6 +16,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import StatelessProcessGroup
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
import ixformer.distributed as ixfd
|
||||
import os
|
||||
@@ -29,8 +30,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
global_ranks: list[int] | None = None,
|
||||
global_world_size: int | None = None,
|
||||
tcp_store_group: StatelessProcessGroup | None = None,
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
super().__init__(
|
||||
cpu_group,
|
||||
device,
|
||||
device_group,
|
||||
unique_name,
|
||||
global_ranks,
|
||||
global_world_size,
|
||||
)
|
||||
if "tp" not in unique_name:
|
||||
# custom allreduce or torch symm mem can be used only by tp
|
||||
use_custom_allreduce = False
|
||||
@@ -46,8 +57,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
self.use_torch_symm_mem = use_torch_symm_mem
|
||||
self.use_flashinfer_allreduce = use_flashinfer_allreduce
|
||||
|
||||
|
||||
self.use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM",None) not in ["1", "Y", "y"]
|
||||
|
||||
# lazy import to avoid documentation build error
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import (
|
||||
CustomAllreduce,
|
||||
@@ -64,7 +76,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.pynccl_comm: PyNcclCommunicator | None = None
|
||||
if self.world_size > 1:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
group=self.cpu_group,
|
||||
group=self.cpu_group if tcp_store_group is None else tcp_store_group,
|
||||
device=self.device,
|
||||
)
|
||||
if is_symmetric_memory_enabled():
|
||||
@@ -109,23 +121,27 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
if self.all2all_backend == "naive":
|
||||
from .all2all import NaiveAll2AllManager
|
||||
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = NaiveAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "allgather_reducescatter":
|
||||
from .all2all import AgRsAll2AllManager
|
||||
|
||||
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
|
||||
elif self.all2all_backend == "pplx":
|
||||
from .all2all import PPLXAll2AllManager
|
||||
|
||||
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = AgRsAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "deepep_high_throughput":
|
||||
from .all2all import DeepEPHTAll2AllManager
|
||||
|
||||
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = DeepEPHTAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "deepep_low_latency":
|
||||
from .all2all import DeepEPLLAll2AllManager
|
||||
|
||||
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = DeepEPLLAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "mori":
|
||||
from .all2all import MoriAll2AllManager
|
||||
|
||||
@@ -133,7 +149,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
elif self.all2all_backend == "flashinfer_all2allv":
|
||||
from .all2all import FlashInferAllToAllManager
|
||||
|
||||
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
|
||||
self.all2all_manager = FlashInferAllToAllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
|
||||
|
||||
@@ -188,27 +206,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
return out
|
||||
if self.world_size == 1:
|
||||
return input_
|
||||
|
||||
if self.use_vllm_comm:
|
||||
# torch.ops.ixf_ops.vllm_all_reduce(input_, async_op=True)
|
||||
ixfd.all_reduce(input_, group=self.device_group, async_op=True)
|
||||
else:
|
||||
torch.distributed.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is None or pynccl_comm.disabled:
|
||||
out = input_.clone()
|
||||
torch.distributed.all_reduce(out, group=self.device_group)
|
||||
return out
|
||||
assert pynccl_comm is not None
|
||||
out = pynccl_comm.all_reduce(input_)
|
||||
if out is None:
|
||||
# fall back to the default all-reduce using PyTorch.
|
||||
# this usually happens during testing.
|
||||
# when we run the model, allreduce only happens for the TP
|
||||
# group, where we always have either custom allreduce or pynccl.
|
||||
out = input_.clone()
|
||||
torch.distributed.all_reduce(out, group=self.device_group)
|
||||
return out
|
||||
|
||||
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
|
||||
world_size = self.world_size
|
||||
@@ -230,10 +233,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
|
||||
)
|
||||
|
||||
# pynccl_comm.reduce_scatter(output, input_tensor)
|
||||
torch.distributed.reduce_scatter_tensor(output,
|
||||
input_tensor,
|
||||
group=self.device_group)
|
||||
# Perform reduce-scatter operation
|
||||
ixfd.reduce_scatter_tensor(output,input_tensor,group=self.device_group, async_op=True)
|
||||
|
||||
# Reshape before returning
|
||||
return output.movedim(0, dim).contiguous()
|
||||
@@ -278,12 +279,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
|
||||
pynccl_comm = self.pynccl_comm
|
||||
# if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
# pynccl_comm.send(tensor, dst)
|
||||
# else:
|
||||
# torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
||||
if self.use_vllm_comm:
|
||||
ixfd.send(tensor, self.ranks[dst], self.device_group)
|
||||
else:
|
||||
@@ -298,17 +293,24 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
|
||||
tensor = torch.empty(size, dtype=dtype, device=self.device)
|
||||
# pynccl_comm = self.pynccl_comm
|
||||
# if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
# pynccl_comm.recv(tensor, src)
|
||||
# else:
|
||||
# torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
if self.use_vllm_comm:
|
||||
ixfd.recv(tensor, self.ranks[src], self.device_group)
|
||||
else:
|
||||
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast a tensor from source rank to all ranks.

    Args:
        tensor: Tensor passed to the PyNCCL communicator's ``broadcast``.
        src: Source rank, forwarded unchanged to the PyNCCL communicator.

    Returns:
        The same ``tensor`` object that was passed in.

    Raises:
        ValueError: If the group has more than one rank and
            ``self.pynccl_comm`` is ``None`` or disabled.
    """
    # Nothing to exchange in a single-rank group.
    if self.world_size == 1:
        return tensor

    comm = self.pynccl_comm
    if comm is None or comm.disabled:
        raise ValueError("No PyNCCL communicator found")

    comm.broadcast(tensor, src)
    return tensor
|
||||
|
||||
def destroy(self):
|
||||
if self.pynccl_comm is not None:
|
||||
self.pynccl_comm = None
|
||||
@@ -319,7 +321,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.fi_ar_comm = None
|
||||
if self.all2all_manager is not None:
|
||||
self.all2all_manager.destroy()
|
||||
self.all2all_manager = None
|
||||
self.all2all_manager = None # type: ignore[assignment]
|
||||
|
||||
def all_gatherv(
|
||||
self,
|
||||
@@ -372,7 +374,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
extra_residual:torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
@@ -409,16 +410,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
# return self.all2all_manager.dispatch(
|
||||
# hidden_states,
|
||||
# topk_weights,
|
||||
# topk_ids,
|
||||
# is_sequence_parallel,
|
||||
# extra_tensors=extra_tensors,
|
||||
# )
|
||||
hidden_states, extra_residual, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, extra_residual, router_logits)
|
||||
return hidden_states, extra_residual, router_logits
|
||||
return self.all2all_manager.dispatch(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel,
|
||||
extra_tensors=extra_tensors,
|
||||
)
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
@@ -432,3 +430,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
hidden_states,
|
||||
is_sequence_parallel,
|
||||
)
|
||||
|
||||
def batch_isend_irecv(self, p2p_ops: list):
    """Forward a batch of point-to-point ops to the PyNCCL communicator.

    Args:
        p2p_ops: List of operations, passed unchanged to
            ``self.pynccl_comm.batch_isend_irecv``.

    Raises:
        ValueError: If ``self.pynccl_comm`` is ``None`` or disabled.
    """
    comm = self.pynccl_comm
    if comm is None or comm.disabled:
        raise ValueError("No PyNCCL communicator found")
    comm.batch_isend_irecv(p2p_ops)
|
||||
|
||||
@@ -312,10 +312,19 @@ class PyNcclCommunicator:
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if tensor.dtype in [
|
||||
torch.float8_e5m2,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
torch.float8_e5m2fnuz,
|
||||
]:
|
||||
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
|
||||
else:
|
||||
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
|
||||
self.nccl.ncclSend(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
nccl_dtype,
|
||||
dst,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
@@ -330,10 +339,19 @@ class PyNcclCommunicator:
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if tensor.dtype in [
|
||||
torch.float8_e5m2,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
torch.float8_e5m2fnuz,
|
||||
]:
|
||||
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
|
||||
else:
|
||||
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
|
||||
self.nccl.ncclRecv(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
nccl_dtype,
|
||||
src,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
@@ -384,3 +402,17 @@ class PyNcclCommunicator:
|
||||
|
||||
def deregister_comm_window(self, window):
|
||||
return self.nccl.ncclCommWindowDeregister(self.comm, window)
|
||||
|
||||
def batch_isend_irecv(self, p2p_ops: list, stream=None):
    """Issue a batch of send/recv operations inside one group.

    Args:
        p2p_ops: Operations to issue. Each item must expose ``op``,
            ``tensor`` and ``group_peer``. Items whose ``op`` is
            ``torch.distributed.isend`` go to ``self.send``; items whose
            ``op`` is ``torch.distributed.irecv`` go to ``self.recv``; any
            other item is skipped.
        stream: Stream forwarded to ``send``/``recv``. When ``None``, the
            result of ``current_stream()`` is used.

    Returns:
        ``None``. Does nothing when the communicator is disabled.
    """
    if self.disabled:
        return
    if stream is None:
        stream = current_stream()

    self.group_start()
    try:
        for op in p2p_ops:
            if op.op is torch.distributed.isend:
                self.send(op.tensor, op.group_peer, stream)
            elif op.op is torch.distributed.irecv:
                self.recv(op.tensor, op.group_peer, stream)
    finally:
        # Always pair group_start() with group_end(), even if a send/recv
        # raises, so the communicator is not left with an open group.
        self.group_end()
|
||||
|
||||
Reference in New Issue
Block a user