Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -3,14 +3,13 @@
from typing import Any
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
from vllm.utils.import_utils import has_deep_ep, has_mori
from .base_device_communicator import All2AllManagerBase, Cache
@@ -32,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
debugging.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def naive_multicast(
self,
@@ -139,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
all-gather (dispatch) and reduce-scatter (combine).
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def dispatch_router_logits(
self,
@@ -235,107 +234,17 @@ class AgRsAll2AllManager(All2AllManagerBase):
pass
class PPLXAll2AllManager(All2AllManagerBase):
"""
All2All communication based on PPLX kernels.
"""
def __init__(self, cpu_group):
assert has_pplx(), (
"pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
" to install pplx_kernels."
)
super().__init__(cpu_group)
if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init,
)
logger.debug(
"Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
self.rank,
self.world_size,
)
uid = (
nvshmem_get_unique_id()
if self.rank == 0
else nvshmem_alloc_empty_unique_id()
)
dist.broadcast(
uid,
src=dist.get_process_group_ranks(self.cpu_group)[0],
group=self.cpu_group,
)
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, self.rank, self.world_size)
self.handle_cache = Cache()
def get_handle(self, kwargs):
import pplx_kernels as pplx # type: ignore[import-not-found]
return self.handle_cache.get_or_create(
kwargs,
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
)
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
if self.internode:
from pplx_kernels.nvshmem import (
nvshmem_finalize, # type: ignore[import-not-found]
)
logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
class DeepEPAll2AllManagerBase(All2AllManagerBase):
"""
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_deep_ep(), (
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
" to install DeepEP kernels."
) # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)
self.handle_cache = Cache()
# This is the DeepEP default. Stick to it till we can establish
@@ -373,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
raise NotImplementedError
def destroy(self):
pass
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
self.handle_cache._cache.clear()
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
@@ -381,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests.
@@ -405,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank,
explicitly_destroy=True,
)
def get_handle(self, kwargs):
@@ -438,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP Low-Latency kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(
self,
@@ -476,8 +389,9 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank,
allow_nvlink_for_low_latency_mode=True,
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
# allow_nvlink_for_low_latency_mode=True,
# allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
explicitly_destroy=True,
)
def get_handle(self, kwargs):
@@ -509,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
rank: int
world_size: int
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_flashinfer_all2all(), (
"flashinfer all2all module not found. Please install/check flashinfer"
) # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)
logger.debug(
"Initialize for flashinfer All2All rank=%d, world size=%d",
self.rank,

View File

@@ -27,6 +27,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
logger = init_logger(__name__)
KiB = 1024
MiB = 1024 * 1024
# Max size for each world size in case symmetric memory is available
# For different SM architectures
@@ -60,17 +61,44 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
},
}
# NCCL symmetric memory allreduce configuration based on H100 and GB200 benchmarks.
# PyNCCL-symm outperforms custom_AR for small and large tensor sizes,
# while custom_AR wins for mid-range sizes.
#
# Benchmark results (8 GPUs):
# 2K - 16K: PyNCCL-symm wins (1.35x - 1.48x faster)
# 32K - 64K: custom_AR wins
# 128K - 1G: PyNCCL-symm wins (1.12x - 6.14x faster)
#
# Benchmark results (4 GPUs):
# 2K - 16K: PyNCCL-symm wins (1.21x - 1.30x faster)
# 32K - 256K: custom_AR wins (1.07x - 1.35x faster)
# 512K - 1G: PyNCCL-symm wins (1.10x - 2.32x faster)
#
# The config defines ranges where custom_AR is preferred (symm_mem disabled).
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
"min_world_size": 4,
"thresholds": {
4: 2 * MiB, # 2 MB
8: 1 * MiB, # 1 MB
# Ranges where custom_AR outperforms NCCL symm_mem: (lower_bound, upper_bound)
# NCCL symm_mem will NOT be used for sizes in range: lower < size < upper
"custom_ar_preferred_ranges": {
4: (16 * KiB, 512 * KiB), # custom_AR wins for 32K-256K
8: (16 * KiB, 128 * KiB), # custom_AR wins for 32K-64K
},
"always_use_above_world_size": 8, # Always use symm mem for world_size > 8
}
def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) -> bool:
"""
Determine if NCCL symmetric memory allreduce should be used.
Based on H100 and GB200 benchmarks, NCCL symm_mem is preferred for:
- Small tensors (≤16K): Lower latency than custom_AR
- Large tensors (≥128K for 8 GPUs, ≥512K for 4 GPUs): Better bandwidth
Custom_AR is preferred for mid-range sizes where its P2P approach
has lower overhead than the symm_mem copy-in/copy-out pattern.
"""
from vllm.distributed.device_communicators.pynccl_allocator import (
is_symmetric_memory_enabled,
)
@@ -80,11 +108,20 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
if not is_symmetric_memory_enabled():
return False
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
return False
threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
if threshold is not None and input_tensor.nbytes >= threshold:
return True
tensor_size = input_tensor.nbytes
custom_ar_range = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["custom_ar_preferred_ranges"].get(
world_size
)
if custom_ar_range is not None:
lower_bound, upper_bound = custom_ar_range
# Use symm_mem for small sizes (≤ lower_bound) and large sizes (≥ upper_bound)
# Use custom_AR (not symm_mem) for mid-range sizes
return tensor_size <= lower_bound or tensor_size >= upper_bound
return world_size > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"]
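For reference, the new range-based check for an 8-GPU group reduces to the sketch below (sizes are illustrative; the helper simply mirrors the custom_ar_preferred_ranges lookup above):

KiB = 1024

def uses_symm_mem(nbytes: int, lower: int = 16 * KiB, upper: int = 128 * KiB) -> bool:
    # symm_mem for small (<= lower) and large (>= upper) tensors,
    # custom_AR for the mid-range window in between
    return nbytes <= lower or nbytes >= upper

assert uses_symm_mem(8 * KiB)       # small: PyNCCL-symm wins
assert not uses_symm_mem(64 * KiB)  # mid-range: custom_AR wins
assert uses_symm_mem(1024 * KiB)    # large: PyNCCL-symm wins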

View File

@@ -30,8 +30,9 @@ class All2AllManagerBase:
rank: int
world_size: int
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
self.cpu_group = cpu_group
self.tcp_store_group = tcp_store_group
# compute some common properties
from vllm.distributed.parallel_state import (
@@ -48,12 +49,17 @@ class All2AllManagerBase:
# when we create this object
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.rank = cpu_group.rank()
self.world_size = cpu_group.size()
# all2all communication often has separate implementations for
# intra-node and inter-node communication
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
if tcp_store_group is None:
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
else:
self.internode = not all(
in_the_same_node_as(tcp_store_group, source_rank=0)
)
def get_handle(self, kwargs):
# get a handle for the all2all communication,
@@ -122,17 +128,36 @@ class DeviceCommunicatorBase:
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
global_ranks: list[int] | None = None,
global_world_size: int | None = None,
):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
self.unique_name = unique_name
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
# Check if this is a stateless process group
from torch.distributed.distributed_c10d import _world
is_stateless = _world.pg_map.get(cpu_group, None) is None
if is_stateless:
# For stateless groups, we can't use torch.distributed methods
self.rank = cpu_group.rank()
self.world_size = cpu_group.size()
assert global_ranks is not None
assert global_world_size is not None
self.ranks = global_ranks
self.global_rank = self.ranks[self.rank]
self.global_world_size = global_world_size
self.rank_in_group = self.rank
else:
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
use_ep = False
all2all_backend = None
@@ -146,7 +171,7 @@ class DeviceCommunicatorBase:
use_ep = config.parallel_config.data_parallel_size > 1
all2all_backend = config.parallel_config.all2all_backend
self.is_ep_communicator = "ep" in unique_name
self.is_ep_communicator = unique_name.split(":")[0] == "ep"
self.use_all2all = self.is_ep_communicator and use_ep
self.all2all_backend = all2all_backend
self.all2all_manager: All2AllManagerBase | None = None
@@ -175,9 +200,7 @@ class DeviceCommunicatorBase:
group=self.device_group,
async_op=True)
else:
torch.distributed.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
output_tensor = output_tensor.movedim(0, dim)
@@ -263,10 +286,9 @@ class DeviceCommunicatorBase:
group=self.device_group,
async_op=True)
else:
torch.distributed.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
torch.distributed.gather(
input_, gather_list, dst=self.ranks[dst], group=self.device_group
)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
@@ -292,6 +314,13 @@ class DeviceCommunicatorBase:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
"""Broadcast a tensor from source rank to all ranks."""
if self.world_size == 1:
return tensor
torch.distributed.broadcast(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self):
pass
@@ -360,3 +389,6 @@ class DeviceCommunicatorBase:
This is a no-op in the base class.
"""
return hidden_states
def batch_isend_irecv(self, p2p_ops: list):
raise NotImplementedError

View File

@@ -35,8 +35,15 @@ class CpuCommunicator(DeviceCommunicatorBase):
)
and hasattr(torch.ops._C, "init_shm_manager")
and (unique_name.startswith("tp") or unique_name.startswith("pp"))
and self._all_group_ranks_share_shm_group_name()
):
self.dist_module = _CPUSHMDistributed(self)
elif unique_name.startswith("tp") or unique_name.startswith("pp"):
logger.info(
"CPU SHM communicator disabled for group %s: ranks do not share "
"the same SHM group name, falling back to torch.distributed.",
unique_name,
)
if self.use_all2all:
if self.all2all_backend != "naive": # type: ignore[has-type]
@@ -52,6 +59,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
def _all_group_ranks_share_shm_group_name(self) -> bool:
"""
CPUSHM requires all ranks in this group to agree on one SHM group name.
This is a lightweight consistency check for VLLM_DIST_IDENT/name inputs.
"""
local_name = _CPUSHMDistributed.make_group_name(self)
names: list[str] = [""] * self.world_size
torch.distributed.all_gather_object(
names,
local_name,
group=self.device_group,
)
return len(set(names)) == 1
def all_reduce(self, input_):
self.dist_module.all_reduce(input_, group=self.device_group)
return input_
@@ -193,16 +214,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
class _CPUSHMDistributed:
def __init__(self, communicator: CpuCommunicator):
self.communicator = communicator
self.group_name = self.make_group_name(communicator)
self.handle = self._init_cpu_shm()
@staticmethod
def make_group_name(communicator: CpuCommunicator) -> str:
instance_identifier = os.environ["VLLM_DIST_IDENT"]
unique_name = communicator.unique_name
instance_identifier = f"{instance_identifier}-{unique_name}"
self.communicator = communicator
group_ranks = [str(rank) for rank in self.communicator.ranks]
group_ranks = [str(rank) for rank in communicator.ranks]
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
self.handle = self._init_cpu_shm()
return f"{instance_identifier}-{shm_group_identifier}-cpushm"
def _init_cpu_shm(self) -> int:
thread_num_tensor = torch.tensor(

View File

@@ -16,6 +16,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
from vllm.logger import init_logger
from vllm.platforms import current_platform
from ..utils import StatelessProcessGroup
from .base_device_communicator import DeviceCommunicatorBase
import ixformer.distributed as ixfd
import os
@@ -29,8 +30,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
global_ranks: list[int] | None = None,
global_world_size: int | None = None,
tcp_store_group: StatelessProcessGroup | None = None,
):
super().__init__(cpu_group, device, device_group, unique_name)
super().__init__(
cpu_group,
device,
device_group,
unique_name,
global_ranks,
global_world_size,
)
if "tp" not in unique_name:
# custom allreduce or torch symm mem can be used only by tp
use_custom_allreduce = False
@@ -46,8 +57,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem
self.use_flashinfer_allreduce = use_flashinfer_allreduce
self.use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM",None) not in ["1", "Y", "y"]
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce,
@@ -64,7 +76,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.pynccl_comm: PyNcclCommunicator | None = None
if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
group=self.cpu_group if tcp_store_group is None else tcp_store_group,
device=self.device,
)
if is_symmetric_memory_enabled():
@@ -109,23 +121,27 @@ class CudaCommunicator(DeviceCommunicatorBase):
if self.all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
self.all2all_manager = NaiveAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "allgather_reducescatter":
from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
elif self.all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
self.all2all_manager = AgRsAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "deepep_high_throughput":
from .all2all import DeepEPHTAll2AllManager
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
self.all2all_manager = DeepEPHTAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "deepep_low_latency":
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
self.all2all_manager = DeepEPLLAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "mori":
from .all2all import MoriAll2AllManager
@@ -133,7 +149,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
elif self.all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
self.all2all_manager = FlashInferAllToAllManager(
self.cpu_group, tcp_store_group
)
else:
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
@@ -188,27 +206,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
return out
if self.world_size == 1:
return input_
if self.use_vllm_comm:
# torch.ops.ixf_ops.vllm_all_reduce(input_, async_op=True)
ixfd.all_reduce(input_, group=self.device_group, async_op=True)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
pynccl_comm = self.pynccl_comm
if pynccl_comm is None or pynccl_comm.disabled:
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
world_size = self.world_size
@@ -230,10 +233,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
)
# pynccl_comm.reduce_scatter(output, input_tensor)
torch.distributed.reduce_scatter_tensor(output,
input_tensor,
group=self.device_group)
# Perform reduce-scatter operation
ixfd.reduce_scatter_tensor(output,input_tensor,group=self.device_group, async_op=True)
# Reshape before returning
return output.movedim(0, dim).contiguous()
@@ -278,12 +279,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm
# if pynccl_comm is not None and not pynccl_comm.disabled:
# pynccl_comm.send(tensor, dst)
# else:
# torch.distributed.send(tensor, self.ranks[dst], self.device_group)
if self.use_vllm_comm:
ixfd.send(tensor, self.ranks[dst], self.device_group)
else:
@@ -298,17 +293,24 @@ class CudaCommunicator(DeviceCommunicatorBase):
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
# pynccl_comm = self.pynccl_comm
# if pynccl_comm is not None and not pynccl_comm.disabled:
# pynccl_comm.recv(tensor, src)
# else:
# torch.distributed.recv(tensor, self.ranks[src], self.device_group)
if self.use_vllm_comm:
ixfd.recv(tensor, self.ranks[src], self.device_group)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
"""Broadcast a tensor from source rank to all ranks."""
if self.world_size == 1:
return tensor
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.broadcast(tensor, src)
return tensor
else:
raise ValueError("No PyNCCL communicator found")
def destroy(self):
if self.pynccl_comm is not None:
self.pynccl_comm = None
@@ -319,7 +321,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.fi_ar_comm = None
if self.all2all_manager is not None:
self.all2all_manager.destroy()
self.all2all_manager = None
self.all2all_manager = None # type: ignore[assignment]
def all_gatherv(
self,
@@ -372,7 +374,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
extra_residual:torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
@@ -409,16 +410,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
# return self.all2all_manager.dispatch(
# hidden_states,
# topk_weights,
# topk_ids,
# is_sequence_parallel,
# extra_tensors=extra_tensors,
# )
hidden_states, extra_residual, router_logits = self.all2all_manager.dispatch(
hidden_states, extra_residual, router_logits)
return hidden_states, extra_residual, router_logits
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
@@ -432,3 +430,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
hidden_states,
is_sequence_parallel,
)
def batch_isend_irecv(self, p2p_ops: list):
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.batch_isend_irecv(p2p_ops)
else:
raise ValueError("No PyNCCL communicator found")

View File

@@ -312,10 +312,19 @@ class PyNcclCommunicator:
)
if stream is None:
stream = current_stream()
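# NCCL has no FP8 data types; since every FP8 format is one byte per element,
# FP8 tensors are sent and received as uint8 with the same element count.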
if tensor.dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
]:
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
else:
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
self.nccl.ncclSend(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
nccl_dtype,
dst,
self.comm,
cudaStream_t(stream.cuda_stream),
@@ -330,10 +339,19 @@ class PyNcclCommunicator:
)
if stream is None:
stream = current_stream()
if tensor.dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
]:
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
else:
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
self.nccl.ncclRecv(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
nccl_dtype,
src,
self.comm,
cudaStream_t(stream.cuda_stream),
@@ -384,3 +402,17 @@ class PyNcclCommunicator:
def deregister_comm_window(self, window):
return self.nccl.ncclCommWindowDeregister(self.comm, window)
def batch_isend_irecv(self, p2p_ops: list, stream=None):
if self.disabled:
return
if stream is None:
stream = current_stream()
self.group_start()
for op in p2p_ops:
if op.op is torch.distributed.isend:
self.send(op.tensor, op.group_peer, stream)
elif op.op is torch.distributed.irecv:
self.recv(op.tensor, op.group_peer, stream)
self.group_end()

View File

@@ -0,0 +1,529 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import gc
import weakref
from collections.abc import Iterable, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import P2POp
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.wrapper import reset_compile_wrapper
from vllm.config import (
CompilationMode,
set_current_vllm_config,
)
from vllm.distributed import (
get_dp_group,
get_ep_group,
get_pcp_group,
get_tp_group,
)
from vllm.distributed.elastic_ep.standby_state import (
create_standby_groups,
get_standby_dp_group,
get_standby_ep_group,
pop_standby_groups,
)
from vllm.distributed.parallel_state import (
_replace_active_groups,
prepare_communication_buffer_for_model,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.workspace import lock_workspace, unlock_workspace
logger = init_logger(__name__)
def batch_transfer_weights(
model: nn.Module,
is_sender: bool,
peer_rank: int,
dp_group: StatelessGroupCoordinator,
expert_weights: Sequence[Iterable[torch.Tensor]],
) -> None:
device_comm = dp_group.device_communicator
if device_comm is None:
raise ValueError("No device communicator found")
expert_weights_set = set()
for weight_group in expert_weights:
for weight in weight_group:
expert_weights_set.add(weight.data_ptr())
state_dict = model.state_dict()
all_params = []
for name, param in state_dict.items():
if name.endswith("expert_map"):
continue
if param.data_ptr() not in expert_weights_set:
all_params.append(param.data)
assert len(all_params) > 0
p2p_ops = []
for param in all_params:
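# Construct P2POp via object.__new__ to bypass P2POp.__init__, which resolves
# the peer against a registered torch.distributed process group; the fields are
# filled in manually so the ops can target the stateless DP group.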
op = object.__new__(P2POp)
if is_sender:
op.op = torch.distributed.isend
op.tensor = param
else:
op.op = torch.distributed.irecv
op.tensor = param
op.group_peer = peer_rank
p2p_ops.append(op)
device_comm.batch_isend_irecv(p2p_ops)
def broadcast_expert_mapping(
physical_to_logical: torch.Tensor | None,
num_local_physical_experts: int | None,
num_logical_experts: int | None,
dp_group: StatelessGroupCoordinator,
device: torch.device,
src_rank: int = 0,
) -> tuple[torch.Tensor, int, int]:
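# Tensor shape and scalar metadata are broadcast over the CPU TCP-store group;
# the expert-mapping tensor itself is broadcast over the device group afterwards.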
if dp_group.rank_in_group == src_rank:
assert physical_to_logical is not None
assert num_local_physical_experts is not None
assert num_logical_experts is not None
assert physical_to_logical.dtype == torch.int64
shape_tensor = torch.tensor(
list(physical_to_logical.shape), dtype=torch.int64, device="cpu"
)
metadata_tensor = torch.tensor(
[num_local_physical_experts, num_logical_experts],
dtype=torch.int64,
device="cpu",
)
else:
shape_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
metadata_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
shape_tensor = dp_group.tcp_store_group.broadcast(shape_tensor, src_rank)
metadata_tensor = dp_group.tcp_store_group.broadcast(metadata_tensor, src_rank)
if dp_group.rank_in_group != src_rank:
assert device is not None
physical_to_logical = torch.empty(
tuple(shape_tensor.tolist()),
dtype=torch.int64,
device=device,
)
assert physical_to_logical is not None
physical_to_logical = dp_group.broadcast(physical_to_logical, src_rank)
num_local_physical_experts = int(metadata_tensor[0].item())
num_logical_experts = int(metadata_tensor[1].item())
return physical_to_logical, num_local_physical_experts, num_logical_experts
class ElasticEPScalingExecutor:
def __init__(self, worker):
self.worker_ref = weakref.ref(worker)
self.reconfig_request = None
@property
def worker(self):
worker = self.worker_ref()
if worker is None:
raise RuntimeError("Worker has been garbage collected")
return worker
def execute(self, execute_method: str, *args, **kwargs):
method = getattr(self, execute_method, None)
if method is None:
raise ValueError(f"Unknown execute method: {execute_method}")
return method(*args, **kwargs)
def create_standby_groups(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
self.reconfig_request = reconfig_request
new_dp_size = reconfig_request.new_data_parallel_size
world_size = self.worker.vllm_config.parallel_config.world_size
new_world_size_across_dp = world_size * new_dp_size
updated_config = copy.copy(self.worker.vllm_config)
updated_config.parallel_config = copy.deepcopy(
self.worker.vllm_config.parallel_config
)
updated_config.parallel_config.data_parallel_size = new_dp_size
with set_current_vllm_config(updated_config):
create_standby_groups(
new_dp_size=new_dp_size,
new_world_size_across_dp=new_world_size_across_dp,
master_ip=reconfig_request.new_data_parallel_master_ip,
world_group_ports=reconfig_request.new_stateless_world_group_port_list,
dp_group_ports=reconfig_request.new_stateless_dp_group_port_list,
ep_group_ports=reconfig_request.new_stateless_ep_group_port_list,
eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list,
)
self.worker.model_runner.eep_eplb_suppressed = True
standby_ep_group = get_standby_ep_group()
assert standby_ep_group is not None
if standby_ep_group.rank == 0:
logger.info("[Elastic EP] EPLB disabled during elastic scaling transition")
def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None:
standby_dp_group = get_standby_dp_group()
assert standby_dp_group is not None
# Broadcast old_dp_size to all workers in standby group
if standby_dp_group.rank_in_group < old_dp_size:
old_dp_size_tensor = torch.tensor(
[old_dp_size], dtype=torch.int64, device="cpu"
)
else:
old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast(
old_dp_size_tensor, 0
)
num_new_workers = new_dp_size - old_dp_size
dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
# Sender-receiver pairing: with k = num_new_workers // old_dp_size, the first
# (num_new_workers % old_dp_size) senders each get k + 1 contiguous receivers
# and the remaining senders each get k receivers.
num_dst_per_sender = num_new_workers // old_dp_size
remainder = num_new_workers % old_dp_size
if dp_rank < remainder:
recv_begin = dp_rank * (num_dst_per_sender + 1)
recv_end = recv_begin + num_dst_per_sender + 1
else:
recv_begin = (
remainder * (num_dst_per_sender + 1)
+ (dp_rank - remainder) * num_dst_per_sender
)
recv_end = recv_begin + num_dst_per_sender
ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end))
model = self.worker.model_runner.get_model()
for new_worker_rank in sorted(ranks_to_send):
batch_transfer_weights(
model=model,
is_sender=True,
peer_rank=new_worker_rank,
dp_group=standby_dp_group,
expert_weights=model.expert_weights,
)
torch.cuda.synchronize()
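# Worked example of the pairing above (hypothetical sizes): old_dp_size = 3,
# new_dp_size = 8 -> num_new_workers = 5, k = 1, remainder = 2, so sender
# dp_rank 0 -> receivers [3, 4], dp_rank 1 -> [5, 6], dp_rank 2 -> [7];
# receive_weights() below inverts the same arithmetic to find its sender.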
def broadcast_expert_mapping(self) -> None:
standby_dp_group = get_standby_dp_group()
assert standby_dp_group is not None
model_config = self.worker.model_runner.model_config
eplb_state = self.worker.model_runner.eplb_state
assert eplb_state is not None
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
physical_to_logical = eplb_model_state.physical_to_logical_map
num_physical_experts = physical_to_logical.shape[1]
num_local_physical_experts = num_physical_experts // get_ep_group().world_size
num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
broadcast_expert_mapping(
physical_to_logical=physical_to_logical,
num_local_physical_experts=num_local_physical_experts,
num_logical_experts=num_logical_experts,
dp_group=standby_dp_group,
src_rank=0,
device=self.worker.device,
)
def switch_and_remove(self) -> None:
_replace_active_groups(world=None, dp=None, ep=None, eplb=None, node_count=None)
def switch_and_prepare(self) -> None:
old_dp_size = get_dp_group().world_size
old_ep_size = get_ep_group().world_size
_replace_active_groups(**pop_standby_groups())
parallel_config = self.worker.vllm_config.parallel_config
reconfig_request = self.reconfig_request
assert reconfig_request is not None
new_dp_size = reconfig_request.new_data_parallel_size
new_ep_size = get_ep_group().world_size
parallel_config.data_parallel_size = new_dp_size
if (
reconfig_request.new_data_parallel_rank
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
if (
reconfig_request.new_data_parallel_rank_local
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank_local = (
reconfig_request.new_data_parallel_rank_local
)
parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
# Reconfigure MoE modules with new EP size
moe_modules = [
module
for module in self.worker.model_runner.model.modules()
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
]
num_local_experts = moe_modules[0].moe_config.num_local_experts
assert all(
module.moe_config.num_local_experts == num_local_experts
for module in moe_modules
), "All MoE modules must have the same number of experts"
for module in moe_modules:
module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts
tp_size = get_tp_group().world_size
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
sp_size = tp_size if is_sequence_parallel else 1
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=tp_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
sp_size_=sp_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config
# Update EPLB state
eplb_state = self.worker.model_runner.eplb_state
assert eplb_state is not None
model_config = self.worker.model_runner.model_config
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
num_physical_experts = num_local_experts * new_ep_size
num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
parallel_config.eplb_config.num_redundant_experts = (
num_physical_experts - num_logical_experts
)
old_physical_to_logical = eplb_model_state.physical_to_logical_map
num_moe_layers = old_physical_to_logical.shape[0]
num_local_experts = eplb_model_state.expert_load_pass.shape[1] // old_ep_size
if new_dp_size > old_dp_size:
expanded_physical_to_logical = torch.full(
(num_moe_layers, num_local_experts * new_ep_size),
-1,
dtype=old_physical_to_logical.dtype,
device=old_physical_to_logical.device,
)
expanded_physical_to_logical[:, : num_local_experts * old_ep_size] = (
old_physical_to_logical
)
eplb_model_state.physical_to_logical_map = expanded_physical_to_logical
old_num_physical_experts = eplb_model_state.expert_load_pass.shape[1]
pad_size = num_physical_experts - old_num_physical_experts
if new_dp_size > old_dp_size:
assert pad_size > 0
expanded_expert_load_pass = F.pad(
eplb_model_state.expert_load_pass, (0, pad_size), value=0
)
expanded_expert_load_window = F.pad(
eplb_model_state.expert_load_window, (0, pad_size), value=0
)
eplb_model_state.expert_load_pass = expanded_expert_load_pass
eplb_model_state.expert_load_window = expanded_expert_load_window
eplb_state.num_valid_physical_experts = old_num_physical_experts
else:
assert pad_size < 0
eplb_model_state.expert_load_pass = eplb_model_state.expert_load_pass[
:, :num_physical_experts
]
eplb_model_state.expert_load_window = eplb_model_state.expert_load_window[
:, :, :num_physical_experts
]
eplb_state.num_valid_physical_experts = num_physical_experts
model = self.worker.model_runner.get_model()
model.expert_weights = []
with set_current_vllm_config(self.worker.vllm_config):
model.set_eplb_state(
eplb_model_state.expert_load_pass,
eplb_model_state.logical_to_physical_map,
eplb_model_state.logical_replica_count,
)
model.update_physical_experts_metadata(
num_physical_experts=num_physical_experts,
num_local_physical_experts=num_local_experts,
)
# Force re-creation of the modular kernel (and all2all manager)
# for the new EP size by resetting quant_method to base
for module in moe_modules:
if hasattr(module.quant_method, "old_quant_method"):
module.quant_method = module.quant_method.old_quant_method
module.runner = module._init_runner()
prepare_communication_buffer_for_model(self.worker.model_runner.model)
if (
self.worker.vllm_config.compilation_config.mode
== CompilationMode.STOCK_TORCH_COMPILE
):
# NOTE(yongji): with stock torch.compile, compilation is triggered
# during GPUModelRunner's load_model().
# TODO(yongji): check whether we need to re-trigger torch.compile here;
# any changes to tensor shapes during execution should already be
# handled internally by torch.compile.
backend = self.worker.vllm_config.compilation_config.init_backend(
self.worker.vllm_config
)
compilation_counter.stock_torch_compile_count += 1
self.worker.model_runner.model.compile(fullgraph=True, backend=backend)
# release all previously captured CUDA graphs
if isinstance(self.worker.model_runner.model, CUDAGraphWrapper):
wrapper = self.worker.model_runner.model
wrapper.concrete_cudagraph_entries = {}
elif isinstance(self.worker.model_runner.model, UBatchWrapper):
raise RuntimeError("DBO is not yet supported in elastic EP")
multi_block_table = self.worker.model_runner.input_batch.block_table
saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = []
for bt in multi_block_table.block_tables:
saved_block_tables.append(
(bt.block_table.gpu.clone(), bt.block_table.cpu.clone())
)
multi_block_table.clear()
# reset the compile wrapper
torch.compiler.reset()
with set_current_vllm_config(self.worker.vllm_config):
reset_compile_wrapper(self.worker.model_runner.get_model())
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
unlock_workspace()
self.worker.compile_or_warm_up_model()
lock_workspace()
for bt, (saved_gpu, saved_cpu) in zip(
multi_block_table.block_tables, saved_block_tables
):
bt.block_table.gpu.copy_(saved_gpu)
bt.block_table.cpu.copy_(saved_cpu)
def perform_eplb_reshuffle(self, new_dp_size: int | None = None) -> None:
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Starting expert resharding...")
eplb_state = self.worker.model_runner.eplb_state
assert eplb_state is not None
model_config = self.worker.model_runner.model_config
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
is_async_enabled = eplb_state.is_async
eplb_state.is_async = False
if new_dp_size is None:
eplb_state.rearrange()
else:
# scale down
parallel_config = self.worker.vllm_config.parallel_config
tp_size = parallel_config.tensor_parallel_size
old_ep_size = parallel_config.data_parallel_size * tp_size
new_ep_size = new_dp_size * tp_size
rank_mapping = {
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
for old_ep_rank in range(old_ep_size)
}
eplb_state.rearrange(rank_mapping=rank_mapping)
# NOTE(yongji): check whether we need to synchronize here
torch.cuda.synchronize()
# reset expert_rearrangement_step to ensure all ranks are synchronized
eplb_state.expert_rearrangement_step = 0
eplb_state.num_valid_physical_experts = (
eplb_model_state.physical_to_logical_map.shape[1]
)
eplb_state.is_async = is_async_enabled
self.worker.model_runner.eep_eplb_suppressed = False
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed")
def receive_weights(self) -> None:
dp_group = get_dp_group()
assert isinstance(dp_group, StatelessGroupCoordinator)
new_dp_size = dp_group.world_size
dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
# Receive old_dp_size broadcasted during transfer_weights
old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
old_dp_size_tensor = dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0)
old_dp_size = int(old_dp_size_tensor[0].item())
# Calculate which existing worker will send to this new worker
num_new_workers = new_dp_size - old_dp_size
new_worker_idx = dp_rank - old_dp_size
num_dst_per_sender = num_new_workers // old_dp_size
remainder = num_new_workers % old_dp_size
if new_worker_idx < remainder * (num_dst_per_sender + 1):
sender_rank = new_worker_idx // (num_dst_per_sender + 1)
else:
sender_rank = (
remainder
+ (new_worker_idx - remainder * (num_dst_per_sender + 1))
// num_dst_per_sender
)
model = self.worker.model_runner.get_model()
batch_transfer_weights(
model=model,
is_sender=False,
peer_rank=sender_rank,
dp_group=dp_group,
expert_weights=model.expert_weights,
)
torch.cuda.synchronize()
def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]:
dp_group = get_dp_group()
assert isinstance(dp_group, StatelessGroupCoordinator)
physical_to_logical, num_local_physical_experts, num_logical_experts = (
broadcast_expert_mapping(
physical_to_logical=None,
num_local_physical_experts=None,
num_logical_experts=None,
dp_group=dp_group,
src_rank=0,
device=self.worker.device,
)
)
num_moe_layers = physical_to_logical.shape[0]
new_dp_size = get_dp_group().world_size
tp_size = self.worker.vllm_config.parallel_config.tensor_parallel_size
new_ep_size = new_dp_size * tp_size
expanded_physical_to_logical = torch.full(
(num_moe_layers, num_local_physical_experts * new_ep_size),
-1,
dtype=physical_to_logical.dtype,
device=physical_to_logical.device,
)
old_num_physical_experts = physical_to_logical.shape[1]
expanded_physical_to_logical[:, :old_num_physical_experts] = physical_to_logical
return (
expanded_physical_to_logical,
num_logical_experts,
old_num_physical_experts,
)
def prepare_new_worker(self) -> None:
with set_current_vllm_config(self.worker.vllm_config):
prepare_communication_buffer_for_model(self.worker.model_runner.get_model())

View File

@@ -0,0 +1,563 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import time
import weakref
from datetime import timedelta
from typing import TYPE_CHECKING, Literal
import torch.distributed
from vllm.config import ParallelConfig
from vllm.distributed import (
sched_yield,
stateless_destroy_torch_distributed_process_group,
)
from vllm.logger import init_logger
from vllm.v1.engine import (
EEPNotificationType,
ReconfigureDistributedRequest,
ReconfigureRankType,
)
from vllm.v1.engine.core import DPEngineCoreProc
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.v1.executor.abstract import Executor
logger = init_logger(__name__)
WorkerType = Literal["existing", "new", "removing"]
class ScaleUpExistingEngineState(enum.IntEnum):
WAIT_NEW_CORE_ENGINES_INIT = 0
CREATE_STANDBY_GROUPS = 1
TRANSFER_EXPERT_MAPPING = 2
WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT = 3
TRANSFER_WEIGHTS = 4
SYNC_KV_CACHE_MEMORY_SIZE = 5
SWITCH_AND_PREPARE = 6
EPLB_RESHUFFLE = 7
COMPLETE = 8
class ScaleUpNewEngineState(enum.IntEnum):
PREPARE = 0
EPLB_RESHUFFLE = 1
COMPLETE = 2
class ScaleDownRemainingEngineState(enum.IntEnum):
PREPARE = 0
EPLB_RESHUFFLE = 1
SWITCH_AND_PREPARE = 2
COMPLETE = 3
class ScaleDownRemovingEngineState(enum.IntEnum):
PREPARE = 0
EPLB_RESHUFFLE = 1
COMPLETE = 2
class _BarrierTimeoutError(RuntimeError):
"""
Raised when the first stage of the two-staged TCPStore-based barrier
(used to synchronize the execution of all engines in the DP group)
times out.
"""
class ElasticEPScalingState:
def __init__(
self,
model_executor: "Executor",
engine_core: "DPEngineCoreProc",
vllm_config: "VllmConfig",
new_parallel_config: ParallelConfig,
worker_type: WorkerType,
scale_type: Literal["scale_up", "scale_down"],
reconfig_request: ReconfigureDistributedRequest | None = None,
):
self.model_executor_ref = weakref.ref(model_executor)
self.engine_core_ref = weakref.ref(engine_core)
self.vllm_config = vllm_config
self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None
self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None
self.new_parallel_config: ParallelConfig = new_parallel_config
self.new_dp_group: torch.distributed.ProcessGroup | None = (
self.engine_core.dp_group if worker_type == "new" else None
)
self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None
self.worker_type = worker_type
self.scale_type = scale_type
self.reconfig_request = reconfig_request
if scale_type == "scale_up":
self.state = (
ScaleUpNewEngineState.PREPARE
if worker_type == "new"
else ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
)
else:
self.state = (
ScaleDownRemovingEngineState.PREPARE
if worker_type == "removing"
else ScaleDownRemainingEngineState.PREPARE
)
@property
def model_executor(self) -> "Executor":
model_executor = self.model_executor_ref()
if model_executor is None:
raise RuntimeError("Model executor has been garbage collected")
return model_executor
@property
def engine_core(self) -> "DPEngineCoreProc":
engine_core = self.engine_core_ref()
if engine_core is None:
raise RuntimeError("Engine core has been garbage collected")
return engine_core
def progress(self) -> bool:
if self.scale_type == "scale_up":
return (
self._progress_new_engine()
if self.worker_type == "new"
else self._progress_existing_engine()
)
return (
self._progress_removing_engine()
if self.worker_type == "removing"
else self._progress_remaining_engine()
)
def _execute_tcp_store_barrier(
self, dp_store, group_rank, group_size, barrier_id, timeout=None
):
arrival_key = f"arrival_{barrier_id}_{group_rank}"
dp_store.set(arrival_key, b"1")
start_time = time.time()
processes_arrived: set[int] = set()
while len(processes_arrived) < group_size:
if (
timeout is not None
and time.time() - start_time > timeout.total_seconds()
):
raise _BarrierTimeoutError(
f"Barrier timed out after {timeout.total_seconds()} seconds"
)
for i in range(group_size):
if i in processes_arrived:
continue
key = f"arrival_{barrier_id}_{i}"
present = dp_store.check([key])
if present:
processes_arrived.add(i)
if len(processes_arrived) < group_size:
sched_yield()
def _staged_barrier(self, use_new_group: bool, barrier_name: str) -> bool:
"""
Execute a two-staged barrier to synchronize all engines in the DP group.
Some DP EngineCores may receive the reconfiguration notification
later than others and have already proceeded to the next engine step
(model forward) in the busy loop.
In this case, the EngineCores that have already reached the
reconfiguration point should skip it for now and execute model forward
for one more step, so that all EngineCores are synchronized on the
next step.
The two-staged barrier achieves this: the first time an EngineCore
executes the barrier, a timeout before the barrier completes means
some EngineCores have already entered the engine step. The EngineCores
that timed out then proceed to the engine step themselves and
synchronize with the others on the next step using a barrier without
a timeout.
"""
dp_store = self.new_dp_store if use_new_group else self.old_dp_store
dp_group = self.new_dp_group if use_new_group else self.old_dp_group
assert dp_group is not None
group_rank = dp_group.rank()
group_size = dp_group.size()
barrier_id = f"eep_barrier_{barrier_name}"
sync_key = f"{barrier_id}_sync"
# TODO(yongji): figure out appropriate timeout for the barrier
timeout = None if dp_store.check([sync_key]) else timedelta(seconds=5)
try:
self._execute_tcp_store_barrier(
dp_store, group_rank, group_size, barrier_id, timeout=timeout
)
torch.distributed.barrier(dp_group)
if group_rank == 0:
dp_store.delete_key(sync_key)
for i in range(group_size):
dp_store.delete_key(f"arrival_{barrier_id}_{i}")
return True
except _BarrierTimeoutError as e:
if timeout is None:
raise RuntimeError("Unexpected timeout encountered") from e
dp_store.compare_set(sync_key, "", b"1")
return False
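A minimal single-process sketch of the first barrier stage, using only the plain torch.distributed.TCPStore API (host, port, and key names here are hypothetical; a real deployment runs one store client per rank):

import torch.distributed as dist

# Each rank publishes an arrival key, then polls until every rank's key is
# visible; rank 0 deletes the keys once the barrier has completed.
store = dist.TCPStore("127.0.0.1", 29555, world_size=1, is_master=True)
group_rank, group_size, barrier_id = 0, 1, "eep_barrier_demo"

store.set(f"arrival_{barrier_id}_{group_rank}", b"1")
arrived = {i for i in range(group_size) if store.check([f"arrival_{barrier_id}_{i}"])}
assert arrived == set(range(group_size))  # all ranks have arrived

for i in range(group_size):
    store.delete_key(f"arrival_{barrier_id}_{i}")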
def _progress_existing_engine(self) -> bool:
state = self.state
if state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT:
return False
elif state == ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS:
# NOTE(yongji): wait for all existing workers to receive the request
if (
int(self.old_dp_store.get("eep_barrier_engine_count"))
< self.old_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=False, barrier_name="create_standby_groups"
):
return False
if self.old_dp_group.rank() == 0:
self.old_dp_store.delete_key("eep_barrier_engine_count")
self._create_standby_groups()
self.state = ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING
return True
elif state == ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING:
self._transfer_expert_mapping()
self.state = ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
return True
elif state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT:
return False
elif state == ScaleUpExistingEngineState.TRANSFER_WEIGHTS:
if (
int(self.old_dp_store.get("eep_barrier_engine_count"))
< self.old_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=False, barrier_name="transfer_weights"
):
return False
if self.old_dp_group.rank() == 0:
self.old_dp_store.delete_key("eep_barrier_engine_count")
self._transfer_weights()
self.state = ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE
return True
elif state == ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE:
self._sync_kv_cache_memory_size()
self.state = ScaleUpExistingEngineState.SWITCH_AND_PREPARE
return True
elif state == ScaleUpExistingEngineState.SWITCH_AND_PREPARE:
self._switch_and_prepare()
self.state = ScaleUpExistingEngineState.EPLB_RESHUFFLE
self.new_dp_store.add("eep_barrier_engine_count", 1)
return True
elif state == ScaleUpExistingEngineState.EPLB_RESHUFFLE:
assert self.new_dp_group is not None
if (
int(self.new_dp_store.get("eep_barrier_engine_count"))
< self.new_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=True, barrier_name="eplb_reshuffle"
):
return False
if self.new_dp_group.rank() == 0:
self.new_dp_store.delete_key("eep_barrier_engine_count")
self._eplb_reshuffle()
self.state = ScaleUpExistingEngineState.COMPLETE
self._update_parallel_config()
return True
else:
assert self.state == ScaleUpExistingEngineState.COMPLETE
return True
def _progress_new_engine(self) -> bool:
state = self.state
assert self.new_dp_group is not None
if state == ScaleUpNewEngineState.PREPARE:
tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu")
torch.distributed.all_reduce(
tensor,
op=torch.distributed.ReduceOp.MAX,
group=self.new_dp_group,
)
data = tensor.tolist()
self.engine_core.engines_running = bool(data[0])
self.engine_core.current_wave = int(data[1])
self.engine_core.step_counter = int(data[2])
self.state = ScaleUpNewEngineState.EPLB_RESHUFFLE
self.new_dp_store.add("eep_barrier_engine_count", 1)
return True
elif state == ScaleUpNewEngineState.EPLB_RESHUFFLE:
if (
int(self.new_dp_store.get("eep_barrier_engine_count"))
< self.new_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=True, barrier_name="eplb_reshuffle"
):
return False
assert self.new_dp_group.rank() > 0
self._eplb_reshuffle()
self.state = ScaleUpNewEngineState.COMPLETE
return True
else:
assert self.state == ScaleUpNewEngineState.COMPLETE
return True
def _progress_remaining_engine(self) -> bool:
state = self.state
if state == ScaleDownRemainingEngineState.PREPARE:
self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE
self.old_dp_store.add("eep_barrier_engine_count", 1)
return True
elif state == ScaleDownRemainingEngineState.EPLB_RESHUFFLE:
if (
int(self.old_dp_store.get("eep_barrier_engine_count"))
< self.old_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=False, barrier_name="eplb_reshuffle"
):
return False
if self.old_dp_group.rank() == 0:
self.old_dp_store.delete_key("eep_barrier_engine_count")
self._eplb_reshuffle_before_scale_down()
self.state = ScaleDownRemainingEngineState.SWITCH_AND_PREPARE
# NOTE(yongji): currently, after EPLB reshuffle
# that redistributes experts to remaining workers, workers
# to be removed will immediately initiate shutdown;
# existing workers can no longer execute forward steps using
# the old setup. In the future, we may keep
# the removing workers alive a bit longer,
# e.g., to drain in-batch requests.
self._create_standby_groups()
self._switch_and_prepare()
self._update_parallel_config()
self.state = ScaleDownRemainingEngineState.COMPLETE
return True
else:
assert self.state == ScaleDownRemainingEngineState.COMPLETE
return True
def _progress_removing_engine(self) -> bool:
state = self.state
if state == ScaleDownRemovingEngineState.PREPARE:
self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE
self.old_dp_store.add("eep_barrier_engine_count", 1)
return True
if state == ScaleDownRemovingEngineState.EPLB_RESHUFFLE:
if (
int(self.old_dp_store.get("eep_barrier_engine_count"))
< self.old_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=False, barrier_name="eplb_reshuffle"
):
return False
assert self.old_dp_group.rank() > 0
self._eplb_reshuffle_before_scale_down()
self._switch_and_remove()
self.state = ScaleDownRemovingEngineState.COMPLETE
self.engine_core._eep_send_engine_core_notification(
EEPNotificationType.SHUTDOWN_COMPLETE
)
self.engine_core.shutdown()
return True
else:
assert self.state == ScaleDownRemovingEngineState.COMPLETE
return True
def handle_notification(self, notification_type: EEPNotificationType):
assert self.worker_type != "new"
if (
notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY
and self.state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
):
self.old_dp_store.add("eep_barrier_engine_count", 1)
self.state = ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS
elif (
notification_type == EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
and self.state
== ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
):
self.old_dp_store.add("eep_barrier_engine_count", 1)
self.state = ScaleUpExistingEngineState.TRANSFER_WEIGHTS
def is_complete(self) -> bool:
if self.scale_type == "scale_up":
return (
self.state == ScaleUpNewEngineState.COMPLETE
if self.worker_type == "new"
else self.state == ScaleUpExistingEngineState.COMPLETE
)
return (
self.state == ScaleDownRemovingEngineState.COMPLETE
if self.worker_type == "removing"
else self.state == ScaleDownRemainingEngineState.COMPLETE
)
def _create_standby_groups(self):
self.new_dp_group, self.new_dp_store = (
self.new_parallel_config.stateless_init_dp_group(return_store=True)
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("create_standby_groups", self.reconfig_request)
)
if self.old_dp_group.rank() == 0:
logger.info("[Elastic EP] Created standby communication groups")
def _transfer_weights(self):
assert self.reconfig_request is not None
old_dp_size = self.old_dp_group.size()
new_dp_size = self.reconfig_request.new_data_parallel_size
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("transfer_weights", old_dp_size, new_dp_size)
)
if self.old_dp_group.rank() == 0:
logger.info("[Elastic EP] Transferred weights to new workers")
def _transfer_expert_mapping(self):
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("broadcast_expert_mapping",)
)
if self.old_dp_group.rank() == 0:
logger.info("[Elastic EP] Broadcasted expert mapping to new workers")
def _sync_kv_cache_memory_size(self):
assert self.engine_core.available_gpu_memory_for_kv_cache > 0
assert self.new_dp_group is not None
ParallelConfig.sync_kv_cache_memory_size(
self.new_dp_group,
self.engine_core.available_gpu_memory_for_kv_cache,
)
if self.old_dp_group.rank() == 0:
logger.info("[Elastic EP] Synced KV cache memory size to new workers")
def _switch_and_prepare(self):
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("switch_and_prepare",)
)
old_dp_group = self.old_dp_group
stateless_destroy_torch_distributed_process_group(old_dp_group)
assert self.new_dp_group is not None
new_dp_group = self.new_dp_group
self.engine_core.dp_group = new_dp_group
self.engine_core.dp_rank = new_dp_group.rank()
self.engine_core.dp_store = self.new_dp_store
engines_running = int(self.engine_core.engines_running)
current_wave = self.engine_core.current_wave
step_counter = self.engine_core.step_counter
tensor = torch.tensor(
[engines_running, current_wave, step_counter],
dtype=torch.int32,
device="cpu",
)
torch.distributed.all_reduce(
tensor, op=torch.distributed.ReduceOp.MAX, group=new_dp_group
)
data = tensor.tolist()
self.engine_core.engines_running = bool(data[0])
self.engine_core.current_wave = int(data[1])
self.engine_core.step_counter = int(data[2])
if new_dp_group.rank() == 0:
self.engine_core._eep_send_engine_core_notification(
EEPNotificationType.RECONFIGURE_FINISHED
)
logger.info("[Elastic EP] Switched to new setup")
def _eplb_reshuffle(self):
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("perform_eplb_reshuffle",)
)
assert self.new_dp_group is not None
if self.new_dp_group.rank() == 0:
logger.info("[Elastic EP] EPLB reshuffle completed")
def _eplb_reshuffle_before_scale_down(self):
assert self.reconfig_request is not None
self.model_executor.collective_rpc(
"elastic_ep_execute",
args=(
"perform_eplb_reshuffle",
self.reconfig_request.new_data_parallel_size,
),
)
if self.old_dp_group.rank() == 0:
logger.info("[Elastic EP] EPLB reshuffle completed")
def _switch_and_remove(self):
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("switch_and_remove",)
)
def _update_parallel_config(self):
assert self.reconfig_request is not None
reconfig_request = self.reconfig_request
parallel_config = self.vllm_config.parallel_config
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
if (
reconfig_request.new_data_parallel_rank
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
if (
reconfig_request.new_data_parallel_rank_local
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank_local = (
reconfig_request.new_data_parallel_rank_local
)
parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
parallel_config._data_parallel_master_port_list = (
reconfig_request.new_data_parallel_master_port_list
)
parallel_config._stateless_world_group_port_list = (
reconfig_request.new_stateless_world_group_port_list
)
parallel_config._stateless_dp_group_port_list = (
reconfig_request.new_stateless_dp_group_port_list
)
parallel_config._stateless_ep_group_port_list = (
reconfig_request.new_stateless_ep_group_port_list
)
parallel_config._stateless_eplb_group_port_list = (
reconfig_request.new_stateless_eplb_group_port_list
)

View File

@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.parallel_state import (
_init_stateless_group,
_node_count,
get_pp_group,
get_tp_group,
get_world_group,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
_STANDBY_WORLD: StatelessGroupCoordinator | None = None
_STANDBY_WORLD_NODE_COUNT: int | None = None
_STANDBY_DP: StatelessGroupCoordinator | None = None
_STANDBY_EP: StatelessGroupCoordinator | None = None
_STANDBY_EPLB: StatelessGroupCoordinator | None = None
def get_standby_dp_group() -> StatelessGroupCoordinator | None:
return _STANDBY_DP
def get_standby_ep_group() -> StatelessGroupCoordinator | None:
return _STANDBY_EP
def get_standby_eplb_group() -> StatelessGroupCoordinator | None:
return _STANDBY_EPLB
def get_standby_world_group() -> StatelessGroupCoordinator | None:
return _STANDBY_WORLD
def create_standby_groups(
new_dp_size: int,
new_world_size_across_dp: int,
master_ip: str,
world_group_ports: list[list[int]],
dp_group_ports: list[list[int]],
ep_group_ports: list[list[int]],
eplb_group_ports: list[list[int]] | None = None,
backend: str | None = None,
) -> None:
global \
_STANDBY_WORLD, \
_STANDBY_WORLD_NODE_COUNT, \
_STANDBY_DP, \
_STANDBY_EP, \
_STANDBY_EPLB
assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size
world_group = get_world_group()
assert isinstance(world_group, StatelessGroupCoordinator)
backend = backend or world_group.backend
standby_world_ranks = [list(range(new_world_size_across_dp))]
_STANDBY_WORLD = _init_stateless_group(
standby_world_ranks,
"world",
world_group_ports,
master_ip,
backend,
use_device_communicator=False,
)
_STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group)
tp_size = get_tp_group().world_size
pp_size = get_pp_group().world_size
all_ranks = torch.arange(new_world_size_across_dp).reshape(
-1, new_dp_size, pp_size, tp_size
)
standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0)
standby_dp_ranks = [x.tolist() for x in standby_dp_ranks]
_STANDBY_DP = _init_stateless_group(
standby_dp_ranks, "dp", dp_group_ports, master_ip, backend
)
standby_ep_ranks = (
all_ranks.transpose(1, 2).reshape(-1, new_dp_size * tp_size).unbind(0)
)
standby_ep_ranks = [x.tolist() for x in standby_ep_ranks]
_STANDBY_EP = _init_stateless_group(
standby_ep_ranks, "ep", ep_group_ports, master_ip, backend
)
if eplb_group_ports is not None:
_STANDBY_EPLB = _init_stateless_group(
standby_ep_ranks, "eplb", eplb_group_ports, master_ip, backend
)
def pop_standby_groups() -> dict:
"""Return all standby groups and clear the standby state."""
global \
_STANDBY_WORLD, \
_STANDBY_WORLD_NODE_COUNT, \
_STANDBY_DP, \
_STANDBY_EP, \
_STANDBY_EPLB
result = dict(
world=_STANDBY_WORLD,
dp=_STANDBY_DP,
ep=_STANDBY_EP,
eplb=_STANDBY_EPLB,
node_count=_STANDBY_WORLD_NODE_COUNT,
)
_STANDBY_WORLD = None
_STANDBY_WORLD_NODE_COUNT = None
_STANDBY_DP = None
_STANDBY_EP = None
_STANDBY_EPLB = None
return result
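# Illustrative lifecycle sketch only (the sizes, ports, and master IP below are
# made up; real values come from the reconfigure request handled by the elastic
# EP engine). It shows the intended order: build the standby groups first, then
# hand them over atomically via pop_standby_groups().
def _example_standby_lifecycle() -> dict:
    create_standby_groups(
        new_dp_size=4,
        new_world_size_across_dp=4 * torch.distributed.get_world_size(),
        master_ip="10.0.0.1",
        world_group_ports=[[29600]],
        dp_group_ports=[[29610]],
        ep_group_ports=[[29620]],
    )
    # ... weights / expert mappings are transferred over the standby groups ...
    return pop_standby_groups()  # clears the module-level standby state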

View File

@@ -24,7 +24,6 @@ logger = init_logger(__name__)
def start_async_worker(
state: "EplbState",
rank_mapping: dict[int, int] | None = None,
is_profile: bool = False,
) -> threading.Thread:
eplb_group = get_eplb_group().device_group
@@ -45,7 +44,6 @@ def start_async_worker(
eplb_group=eplb_group,
cuda_stream=cuda_stream,
is_profile=is_profile,
rank_mapping=rank_mapping,
)
)
except Exception as exc: # pragma: no cover - diagnostic path
@@ -107,7 +105,6 @@ async def transfer_run_periodically(
eplb_group: ProcessGroup,
cuda_stream: torch.cuda.Stream,
is_profile: bool = False,
rank_mapping: dict[int, int] | None = None,
) -> None:
while True:
await asyncio.to_thread(state.rearrange_event.wait)
@@ -176,7 +173,6 @@ async def transfer_run_periodically(
ep_group=eplb_group,
is_profile=is_profile,
cuda_stream=cuda_stream,
rank_mapping=rank_mapping,
)
event = torch.cuda.Event(blocking=False)
cuda_stream.record_event(event)

View File

@@ -40,6 +40,7 @@ from vllm.distributed.parallel_state import (
get_node_count,
in_the_same_node_as,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts
@@ -159,7 +160,7 @@ class EplbModelState:
NOTE: The expert_load_view now records load for all physical experts
rather than just local experts. This ensures consistent load statistics
across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
across different dispatch methods (naive all-to-all, DeepEP).
The recorded load will be multiplied by dp_size when using naive all-to-all
due to each DP rank contributing the same token set to the calculation.
See:
@@ -302,6 +303,14 @@ class EplbState:
"""
CUDA device index for the async EPLB worker thread.
"""
self.num_valid_physical_experts: int = 0
"""
Number of valid physical experts.
This is the number of physical experts that are
actually mapped to logical experts. In elastic EP,
newly started EP ranks may not have physical experts
mapped yet.
"""
if self.device.type == "cuda":
self.cuda_device_index = self.device.index
if self.cuda_device_index is None and torch.cuda.is_available():
@@ -367,9 +376,6 @@ class EplbState:
self,
model: MixtureOfExperts,
model_config: ModelConfig,
global_expert_load: torch.Tensor | None = None,
old_global_expert_indices: torch.Tensor | None = None,
rank_mapping: dict[int, int] | None = None,
):
"""
Build the initial EPLB state.
@@ -462,75 +468,15 @@ class EplbState:
)
self.expert_rearrangement_step_interval = eplb_step_interval
# Set the policy based on the selected eplb algorithm type.
policy_type = self.parallel_config.eplb_config.policy
self.policy = EPLB_POLICIES[policy_type]
logger.debug("Selected EPLB policy: %s", policy_type)
if global_expert_load is not None:
ep_group = get_ep_group().device_group
assert global_expert_load.shape == (
model.num_moe_layers,
model.num_logical_experts,
)
assert global_expert_load.dtype == torch.int64
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
num_nodes = get_node_count()
num_gpus = ep_group.size()
if num_gpus % num_nodes != 0:
num_nodes = 1
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
f"{num_gpus=}, {num_nodes=}"
)
# Get new expert mappings
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = self.policy.rebalance_experts(
global_expert_load,
num_replicas,
num_groups,
num_nodes,
num_gpus,
)
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= logical_to_physical_map.shape[-1]
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(0, logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1,
)
physical_to_logical_map = new_physical_to_logical_map.to(self.device)
logical_to_physical_map.copy_(new_logical_to_physical_map)
logical_replica_count.copy_(new_logical_replica_count)
else:
new_physical_to_logical_map = None
new_logical_to_physical_map = None
new_logical_replica_count = None
model.set_eplb_state(
expert_load_pass,
logical_to_physical_map,
logical_replica_count,
)
if global_expert_load is not None:
rearrange_expert_weights_inplace(
old_global_expert_indices,
new_physical_to_logical_map,
model.expert_weights,
ep_group,
False,
rank_mapping,
)
self.expert_rearrangement_step = 0
expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]
@@ -561,11 +507,12 @@ class EplbState:
recv_dst_rows=np.array([]),
),
cuda_device_index=self.cuda_device_index,
new_physical_to_logical_map=new_physical_to_logical_map,
new_logical_to_physical_map=new_logical_to_physical_map,
new_logical_replica_count=new_logical_replica_count,
new_physical_to_logical_map=None,
new_logical_to_physical_map=None,
new_logical_replica_count=None,
)
self.model_states[model_config.compute_hash()] = model_state
self.num_valid_physical_experts = model.num_physical_experts
def step(
self,
@@ -696,8 +643,6 @@ class EplbState:
def rearrange(
self,
is_profile: bool = False,
execute_shuffle: bool = True,
global_expert_loads: list[torch.Tensor] | None = None,
rank_mapping: dict[int, int] | None = None,
) -> torch.Tensor | None:
"""
@@ -707,12 +652,6 @@ class EplbState:
is_profile (bool): If `True`, perform a dummy rearrangement.
This is used in `profile_run` to reserve enough memory,
no memory movement will be performed. Default is False.
execute_shuffle (bool): If `True`, execute the shuffle
in elastic expert parallel (EEP). Default is True.
global_expert_loads (list[torch.Tensor] | None): The global expert
loads when scaling is done in EEP.
List of expert loads for the main and drafter
(when spec decode is used) models.
rank_mapping (dict[int, int] | None): The rank mapping
when scaling is done in EEP.
"""
@@ -734,67 +673,34 @@ class EplbState:
"(profile)" if is_profile else "",
)
if global_expert_loads is None:
# Map the physical expert load to global logical experts
global_expert_load_windows = []
if not execute_shuffle:
num_models = torch.tensor(
[len(self.model_states)], dtype=torch.int32, device="cpu"
)
torch.distributed.broadcast(
num_models, group=get_ep_group().cpu_group, group_src=0
)
for eplb_model_state in self.model_states.values():
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
eplb_model_state.model.num_moe_layers,
eplb_model_state.model.num_logical_experts,
dtype=eplb_model_state.expert_load_window.dtype,
device=eplb_model_state.expert_load_window.device,
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=eplb_model_state.physical_to_logical_map.unsqueeze(0)
.expand_as(eplb_model_state.expert_load_window)
.long(),
src=eplb_model_state.expert_load_window,
)
if not execute_shuffle:
metadata = torch.tensor(
[
eplb_model_state.model.num_moe_layers,
eplb_model_state.model.num_logical_experts,
eplb_model_state.physical_to_logical_map.shape[1],
],
dtype=torch.int32,
device="cpu",
)
torch.distributed.broadcast(
metadata, group=get_ep_group().cpu_group, group_src=0
)
global_expert_load_window = logical_expert_load_window.sum(dim=0)
global_expert_load_windows.append(global_expert_load_window)
# Perform all-reduce to get the expert load across all ranks for each model
global_expert_load_windows = self._allreduce_list(
global_expert_load_windows
# Map the physical expert load to global logical experts
global_expert_load_windows = []
for eplb_model_state in self.model_states.values():
expert_load_window = eplb_model_state.expert_load_window[
:, :, : self.num_valid_physical_experts
]
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
eplb_model_state.model.num_moe_layers,
eplb_model_state.model.num_logical_experts,
dtype=eplb_model_state.expert_load_window.dtype,
device=eplb_model_state.expert_load_window.device,
)
if not execute_shuffle:
for eplb_model_state, global_expert_load_window in zip(
self.model_states.values(), global_expert_load_windows
):
# (num_moe_layers, old_num_physical_experts)
old_global_expert_indices = eplb_model_state.physical_to_logical_map
torch.distributed.broadcast(
old_global_expert_indices, group=ep_group, group_src=0
)
if not execute_shuffle:
return global_expert_load_windows
else:
assert execute_shuffle
global_expert_load_windows = global_expert_loads
logical_expert_load_window.scatter_add_(
dim=-1,
index=eplb_model_state.physical_to_logical_map[
:, : self.num_valid_physical_experts
]
.unsqueeze(0)
.expand_as(expert_load_window)
.long(),
src=expert_load_window,
)
global_expert_load_window = logical_expert_load_window.sum(dim=0)
global_expert_load_windows.append(global_expert_load_window)
# Perform all-reduce to get the expert load across all ranks for each model
global_expert_load_windows = self._allreduce_list(global_expert_load_windows)
# TODO(bowen): Treat differently for prefill and decode nodes
eplb_model_state = next(iter(self.model_states.values()))
@@ -806,8 +712,10 @@ class EplbState:
# NOTE(yongji): scale down, we need to rebalance the experts on
# remaining GPUs, transfer the experts while we haven't shutdown
# the GPUs to be released.
cpu_group = get_ep_group().cpu_group
num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
coordinator = get_ep_group()
assert isinstance(coordinator, StatelessGroupCoordinator)
tcp_store_group = coordinator.tcp_store_group
num_nodes = _node_count_with_rank_mapping(tcp_store_group, rank_mapping)
num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values())
num_replicas = (
num_replicas // ep_group.size() * num_gpus
@@ -933,7 +841,6 @@ class EplbState:
if self.async_worker is None:
self.async_worker = start_async_worker(
self,
rank_mapping=rank_mapping,
is_profile=is_profile,
)
@@ -1089,83 +996,6 @@ class EplbState:
model_state.new_logical_to_physical_map = None
model_state.new_logical_replica_count = None
@staticmethod
def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""
Receive the expert load and old placement from the master rank.
"""
ep_group = get_ep_group()
num_models = torch.empty(1, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(num_models, group=ep_group.cpu_group, group_src=0)
num_models = num_models.item()
global_expert_loads = []
old_global_expert_indices_per_model = []
for _ in range(num_models):
metadata = torch.empty(3, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0)
num_moe_layers, num_logical_experts, num_old_physical_experts = (
metadata.tolist()
)
global_expert_load = torch.zeros(
(num_moe_layers, num_logical_experts),
dtype=torch.int64,
device=ep_group.device,
)
all_reduce(global_expert_load, group=ep_group.device_group)
old_global_expert_indices = torch.empty(
(num_moe_layers, num_old_physical_experts),
dtype=torch.int64,
device=ep_group.device,
)
torch.distributed.broadcast(
old_global_expert_indices,
group=ep_group.device_group,
group_src=0,
)
global_expert_loads.append(global_expert_load)
old_global_expert_indices_per_model.append(old_global_expert_indices)
return global_expert_loads, old_global_expert_indices_per_model
@classmethod
def get_eep_state(
cls, parallel_config: ParallelConfig
) -> tuple[
list[torch.Tensor] | None,
list[torch.Tensor] | None,
dict[int, int] | None,
]:
num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(
num_local_physical_experts,
group=get_ep_group().cpu_group,
group_src=0,
)
num_local_physical_experts = int(num_local_physical_experts.item())
new_ep_size = get_ep_group().world_size
global_expert_loads, old_global_expert_indices_per_model = (
EplbState.recv_state()
)
# EP configuration for all models has to be the same, and so does the EPLB config
num_logical_experts = global_expert_loads[0].shape[1]
parallel_config.eplb_config.num_redundant_experts = (
num_local_physical_experts * new_ep_size - num_logical_experts
)
assert (
old_global_expert_indices_per_model[0].shape[1] % num_local_physical_experts
== 0
)
old_ep_size = (
old_global_expert_indices_per_model[0].shape[1]
// num_local_physical_experts
)
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
return (
global_expert_loads,
old_global_expert_indices_per_model,
rank_mapping,
)
def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
"""
All-reduce a list of tensors.
@@ -1203,6 +1033,60 @@ class EplbState:
load_pass_list.append(eplb_model_state.expert_load_pass.clone())
return self._allreduce_list(load_pass_list)
@classmethod
def from_mapping(
cls,
model: MixtureOfExperts,
model_config: ModelConfig,
device: torch.device,
parallel_config: ParallelConfig,
expanded_physical_to_logical: torch.Tensor,
num_valid_physical_experts: int,
) -> "EplbState":
eplb_state = cls(
parallel_config=parallel_config,
device=device,
)
eplb_state.add_model(
model=model,
model_config=model_config,
)
eplb_state.num_valid_physical_experts = num_valid_physical_experts
num_moe_layers = expanded_physical_to_logical.shape[0]
num_physical_experts = expanded_physical_to_logical.shape[1]
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
eplb_model_state.physical_to_logical_map.copy_(expanded_physical_to_logical)
logical_to_physical_map = torch.full(
(
num_moe_layers,
model.num_logical_experts,
eplb_model_state.logical_to_physical_map.shape[2],
),
-1,
dtype=torch.int64,
)
logical_replica_count = torch.zeros(
(num_moe_layers, model.num_logical_experts),
dtype=torch.int64,
)
expanded_physical_to_logical_numpy = expanded_physical_to_logical.cpu().numpy()
for layer_idx in range(num_moe_layers):
for phys_idx in range(num_physical_experts):
logical_idx = expanded_physical_to_logical_numpy[layer_idx, phys_idx]
if logical_idx >= 0:
replica_idx = logical_replica_count[layer_idx, logical_idx]
logical_to_physical_map[layer_idx, logical_idx, replica_idx] = (
phys_idx
)
logical_replica_count[layer_idx, logical_idx] += 1
logical_to_physical_map = logical_to_physical_map.to(device)
logical_replica_count = logical_replica_count.to(device)
eplb_model_state.logical_to_physical_map.copy_(logical_to_physical_map)
eplb_model_state.logical_replica_count.copy_(logical_replica_count)
return eplb_state
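# Small worked example (values are illustrative, not from this commit) of the
# inversion performed above: one MoE layer, 4 physical slots, 2 logical experts,
# with -1 marking a slot that is not yet mapped (e.g. a freshly added EP rank).
import torch
phys_to_logical = torch.tensor([[0, 1, 0, -1]])
logical_to_physical = torch.full((1, 2, 2), -1, dtype=torch.int64)
replica_count = torch.zeros((1, 2), dtype=torch.int64)
for layer in range(phys_to_logical.shape[0]):
    for phys in range(phys_to_logical.shape[1]):
        logical = int(phys_to_logical[layer, phys])
        if logical >= 0:
            r = int(replica_count[layer, logical])
            logical_to_physical[layer, logical, r] = phys
            replica_count[layer, logical] += 1
# logical_to_physical -> [[[0, 2], [1, -1]]]; replica_count -> [[2, 1]]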
@dataclass
class EplbLayerState:

View File

@@ -19,6 +19,8 @@ from torch.distributed import (
get_global_rank,
)
from vllm.distributed.parallel_state import get_ep_group
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger
logger = init_logger(__name__)
@@ -249,10 +251,18 @@ def move_to_buffer(
b[dst].copy_(w[src_local], non_blocking=True)
p2p_ops: list[P2POp] = []
if isinstance(get_ep_group(), StatelessGroupCoordinator):
ep_group = get_ep_group()
is_stateless = True
else:
is_stateless = False
# Pre-compute global ranks mapping
# Pre-compute global ranks mapping (only needed for non-stateless groups)
ep_size = ep_group.size()
rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
if not is_stateless:
rank_to_global = {
rank: get_global_rank(ep_group, rank) for rank in range(ep_size)
}
# 2. Post sends
if send_count > 0:
@@ -284,15 +294,23 @@ def move_to_buffer(
if recver_pos < len(ranks_to_recv):
recv_ranks.append(ranks_to_recv[recver_pos])
for dst in recv_ranks:
dst_global = rank_to_global[dst]
p2p_ops += [
P2POp(
torch.distributed.isend,
w[src],
dst_global,
)
for w in expert_weights
]
if is_stateless:
for w in expert_weights:
op = object.__new__(P2POp)
op.op = torch.distributed.isend
op.tensor = w[src]
op.group_peer = dst
p2p_ops.append(op)
else:
dst_global = rank_to_global[dst]
p2p_ops += [
P2POp(
torch.distributed.isend,
w[src],
dst_global,
)
for w in expert_weights
]
# 3. Post recvs
if recv_count > 0:
@@ -321,26 +339,40 @@ def move_to_buffer(
src = ranks_to_send[recver_pos // num_dst_per_sender]
else:
src = ranks_to_send[recver_pos - remainder_start]
src_global = rank_to_global[src]
p2p_ops += [
P2POp(
torch.distributed.irecv,
b[dst],
src_global,
)
for b in expert_weights_buffers
]
if is_stateless:
for b in expert_weights_buffers:
op = object.__new__(P2POp)
op.op = torch.distributed.irecv
op.tensor = b[dst]
op.group_peer = src
p2p_ops.append(op)
else:
src_global = rank_to_global[src]
p2p_ops += [
P2POp(
torch.distributed.irecv,
b[dst],
src_global,
)
for b in expert_weights_buffers
]
# 4. Execute the P2P operations. The real communication happens here.
if p2p_ops and cuda_stream is not None:
with torch.cuda.stream(cuda_stream):
if is_stateless:
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
else:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
elif p2p_ops:
if is_stateless:
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
else:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
elif p2p_ops:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
# wait for the communication to finish
return (
is_unchanged,

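# Condensed sketch of the stateless P2POp construction used above (buffer and
# peer names are hypothetical). P2POp's constructor validates the group against
# torch.distributed's registry, which a stateless group is not part of, so the
# ops are built via object.__new__ and the peer is given as a rank *within* the
# group (group_peer); the stateless device communicator then executes them with
# its own batch_isend_irecv.
import torch.distributed as dist
from torch.distributed import P2POp

def _stateless_p2p_ops(send_buf, recv_buf, peer_rank):
    ops = []
    for op_fn, tensor in ((dist.isend, send_buf), (dist.irecv, recv_buf)):
        op = object.__new__(P2POp)  # skip __init__ and its group validation
        op.op = op_fn
        op.tensor = tensor
        op.group_peer = peer_rank
        ops.append(op)
    return ops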
View File

@@ -209,6 +209,10 @@ class KVConnectorKVEvents(ABC):
def clear_events(self) -> None:
raise NotImplementedError
def merge(self, other: "KVConnectorKVEvents") -> "KVConnectorKVEvents":
self.add_events(other.get_all_events())
return self
class EventPublisher(ABC):
"""Lightweight publisher for EventBatch batches with data parallelism

View File

@@ -149,6 +149,12 @@ KVConnectorFactory.register_connector(
"ExampleConnector",
)
KVConnectorFactory.register_connector(
"ExampleHiddenStatesConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.example_hidden_states_connector",
"ExampleHiddenStatesConnector",
)
KVConnectorFactory.register_connector(
"P2pNcclConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector",

View File

@@ -413,7 +413,20 @@ class TpKVTopology:
f"by local tensor parallel size {self.tp_size}."
)
# P TP > D TP case, return the ratio as negative
return -remote_tp_size // self.tp_size
return remote_tp_size // self.tp_size
def pp_ratio(
self,
remote_pp_size: int,
) -> int:
"""
Calculate the pipeline parallel ratio between local and remote PP.
"""
assert self.pp_size % remote_pp_size == 0 or remote_pp_size % self.pp_size == 0, (
    f"Local pipeline parallel size {self.pp_size} is not divisible "
    f"by remote pipeline parallel size {remote_pp_size} or vice versa."
)
if self.pp_size % remote_pp_size == 0:
    return self.pp_size // remote_pp_size
return remote_pp_size // self.pp_size
def block_size_ratio(
self,
@@ -457,6 +470,7 @@ class TpKVTopology:
def get_target_remote_ranks(
self,
remote_tp_size: int,
remote_pp_size: int,
) -> list[tuple[int, int, int]]:
"""
Get the remote TP rank (on P) that the current local TP rank
@@ -464,19 +478,36 @@ class TpKVTopology:
read from multiple remote ranks.
"""
tp_ratio = self.tp_ratio(remote_tp_size)
if tp_ratio > 0:
return [self.tp_rank // tp_ratio]
pp_ratio = self.pp_ratio(remote_pp_size)
target_pp_rank_list = []
target_tp_rank_list = []
if self.pp_size < remote_pp_size:
for i in range(pp_ratio):
target_pp_rank_list.append(self.pp_rank * pp_ratio + i)
else:
target_pp_rank_list.append(self.pp_rank // pp_ratio)
# P TP > D TP case, D reads from |tp_ratio| remote workers.
tp_ratio = -tp_ratio
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
if self.tp_size < remote_tp_size:
for i in range(tp_ratio):
target_tp_rank_list.append(self.tp_rank * tp_ratio + i)
else:
target_tp_rank_list.append(self.tp_rank // tp_ratio)
target_rank_list = []
for pp_rank in target_pp_rank_list:
for tp_rank in target_tp_rank_list:
target_rank = pp_rank * remote_tp_size + tp_rank
target_rank_list.append((target_rank, pp_rank, tp_rank))
return target_rank_list
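# Worked example with illustrative numbers (not from this commit): a decode
# worker at (pp_rank=0, tp_rank=1) in a PP=1, TP=2 deployment reading from a
# prefill instance with PP=2, TP=4. pp_ratio and tp_ratio are both 2, so it
# targets remote (pp, tp) pairs (0,2), (0,3), (1,2), (1,3).
def _example_target_ranks():
    local_pp, local_tp, pp_rank, tp_rank = 1, 2, 0, 1
    remote_pp, remote_tp = 2, 4
    pp_ratio = remote_pp // local_pp  # 2
    tp_ratio = remote_tp // local_tp  # 2
    pp_targets = [pp_rank * pp_ratio + i for i in range(pp_ratio)]  # [0, 1]
    tp_targets = [tp_rank * tp_ratio + i for i in range(tp_ratio)]  # [2, 3]
    # global remote rank = pp * remote_tp + tp
    return [(p * remote_tp + t, p, t) for p in pp_targets for t in tp_targets]
    # -> [(2, 0, 2), (3, 0, 3), (6, 1, 2), (7, 1, 3)]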
def get_target_remote_ranks_from_engine_id(
self,
remote_engine_id: EngineId,
) -> list[tuple[int, int, int]]:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_ranks(remote_tp_size)
remote_pp_size = self.remote_pp_size[remote_engine_id]
return self.get_target_remote_ranks(remote_tp_size, remote_pp_size)
def get_current_attn_backend(vllm_config: VllmConfig):

View File

@@ -543,6 +543,28 @@ class KVConnectorBase_V1(ABC):
)
return None
@classmethod
def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
"""
Check if this connector requires PIECEWISE CUDA graph mode.
Connectors that use asynchronous layer-by-layer operations
(wait_for_layer_load/save_kv_layer) should override this method
to return True when those operations are enabled. These operations
cannot be captured in CUDA graphs and will be skipped during replay,
causing data races. PIECEWISE mode allows Python code to execute
between graph pieces, ensuring proper synchronization.
Args:
extra_config: The kv_connector_extra_config dict from
KVTransferConfig.
Returns:
True if this connector requires PIECEWISE CUDA graph mode,
False otherwise.
"""
return False
def get_finished_count(self) -> int | None:
"""
Get the count of requests expected to complete send/receive operations

View File

@@ -17,6 +17,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
@@ -118,12 +119,12 @@ class ExampleConnector(KVConnectorBase_V1):
The number of elements in kv_caches and layer_names should be
the same.
"""
attn_metadata = forward_context.attn_metadata
def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> None:
"""Inject the KV cache into the layer.
@@ -145,6 +146,10 @@ class ExampleConnector(KVConnectorBase_V1):
num_pages * page_size, -1
)
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
elif isinstance(attn_metadata, TritonAttentionMetadata):
block_idxs = slot_mapping // self._block_size
offsets = slot_mapping % self._block_size
dst_kv_cache_layer[block_idxs, :, offsets] = src_kv_cache
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
@@ -186,7 +191,13 @@ class ExampleConnector(KVConnectorBase_V1):
layer_name, request.token_ids, request.mm_hashes
)
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping)
if isinstance(attn_metadata, dict):
inject_kv_into_layer(
kv_cache_layer,
kv_cache,
request.slot_mapping,
attn_metadata[layer_name],
)
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
@@ -229,6 +240,10 @@ class ExampleConnector(KVConnectorBase_V1):
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
elif isinstance(attn_metadata, TritonAttentionMetadata):
block_idxs = slot_mapping // self._block_size
offsets = slot_mapping % self._block_size
return layer[block_idxs, :, offsets]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]

View File

@@ -0,0 +1,354 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
import safetensors
import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
def extract_from_kv_cache(
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
num_tokens: int,
) -> torch.Tensor:
"""Extract data from KV cache
Assume the shape of the kv_cache is (num_pages, page_size, num_heads, head_size)
"""
padded_kv = kv_cache.flatten(0, 1)[slot_mapping]
# shape: [len(slot_mapping), num_heads, head_size]
return padded_kv[:num_tokens] # shape: [num_tokens, num_heads, head_size]
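# Quick shape check for the helper above (shapes are illustrative): 3 pages of
# 4 slots each, 2 heads, head_size 8; three tokens living in page 1.
import torch
_kv = torch.randn(3, 4, 2, 8)
_slots = torch.tensor([4, 5, 6, 7])
assert extract_from_kv_cache(_kv, _slots, num_tokens=3).shape == (3, 2, 8)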
@dataclass
class ReqMeta:
# Request ID
req_id: str
# Request filename
filename: str
# Request tokens
token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor
# Whether this request is a new request or partially computed already
new_req: bool
@staticmethod
def make_meta(
req_id: str,
filename: str,
token_ids: list[int],
block_ids: list[int],
block_size: int,
new_req: bool,
) -> "ReqMeta":
token_ids_tensor = torch.tensor(token_ids)
block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size)
slot_mapping = (
block_offsets.reshape((1, block_size))
+ block_ids_tensor.reshape((num_blocks, 1)) * block_size
)
slot_mapping = slot_mapping.flatten()
return ReqMeta(
req_id=req_id,
filename=filename,
token_ids=token_ids_tensor,
slot_mapping=slot_mapping,
new_req=new_req,
)
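# Worked example of the slot-mapping arithmetic above (illustrative values):
# block_size=4 and block_ids=[2, 5] place the request's tokens in slots
# 2*4..2*4+3 and 5*4..5*4+3.
import torch
_offsets = torch.arange(4).reshape(1, 4)
_blocks = torch.tensor([2, 5]).reshape(2, 1)
_slots = (_offsets + _blocks * 4).flatten()
# _slots -> tensor([ 8,  9, 10, 11, 20, 21, 22, 23])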
@dataclass
class ExampleHiddenStatesConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta] = field(default_factory=list)
def add_request(
self,
req_id: str,
filename: str,
token_ids: list[int],
block_ids: list[int],
block_size: int,
new_req: bool = True,
) -> None:
self.requests.append(
ReqMeta.make_meta(
req_id, filename, token_ids, block_ids, block_size, new_req
)
)
class ExampleHiddenStatesConnector(KVConnectorBase_V1):
"""
Simple debug implementation of a HiddenStatesConnector.
It extracts the hidden states from the KV cache and stores them to disk.
Must be used in conjunction with the `extract_hidden_states` spec decoding method.
"""
@property
def prefer_cross_layer_blocks(self) -> bool:
"""
Indicates whether this connector prefers KV blocks that hold KV data for all
layers, which can speed up KV data transfers. Defaults to False.
"""
# Must be False so that drafter kv cache isn't merged with verifier's
return False
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(
vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config,
)
self._block_size = vllm_config.cache_config.block_size
self._storage_path = self._kv_transfer_config.get_from_extra_config(
"shared_storage_path", "/tmp"
)
self.cache_layers: list[str] = [] # set by self.register_kv_caches
logger.info(self._kv_transfer_config)
logger.info("Shared storage path is %s", self._storage_path)
assert self._vllm_config.speculative_config is not None, (
"ExampleHiddenStatesConnector only works when using "
"'extract_hidden_states' speculative method"
)
spec_config = self._vllm_config.speculative_config.draft_model_config.hf_config
self.num_hidden_states = len(
getattr(spec_config, "eagle_aux_hidden_state_layer_ids", [])
)
self._request_filenames: dict[str, str] = {}
self._active_requests: dict[str, NewRequestData] = {}
self._req_blocks: dict[str, list[int]] = {}
# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, *args, **kwargs: Any) -> None:
pass # Empty implementation of abstract method
def wait_for_layer_load(self, layer_name: str) -> None:
pass # Empty implementation of abstract method
def wait_for_save(self):
pass # Empty implementation of abstract method
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
from vllm.model_executor.models.extract_hidden_states import (
CacheOnlyAttentionLayer,
)
# Filter layers to only include CacheOnlyAttentionLayers
layers = get_layers_from_vllm_config(
self._vllm_config, CacheOnlyAttentionLayer, list(kv_caches.keys())
)
self.cache_layers = list(layers.keys())
assert len(self.cache_layers) == 1, (
f"Expected 1 CacheOnlyAttentionLayer, got {len(self.cache_layers)}"
)
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
if layer_name not in self.cache_layers:
return
from vllm.model_executor.models.extract_hidden_states import (
CacheOnlyAttentionMetadata,
)
assert isinstance(attn_metadata, CacheOnlyAttentionMetadata), (
"ExampleHiddenStatesConnector only supports CacheOnlyAttentionBackend"
)
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, ExampleHiddenStatesConnectorMetadata)
os.makedirs(self._storage_path, exist_ok=True)
for request in connector_metadata.requests:
hidden_states = extract_from_kv_cache(
kv_layer, request.slot_mapping, request.token_ids.shape[0]
)
tensors = {
"hidden_states": hidden_states.detach().cpu(),
"token_ids": request.token_ids.detach().cpu(),
}
safetensors.torch.save_file(tensors, request.filename)
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
# This connector is store-only, so we don't need to load any tokens
return 0, False
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
# Usually used to handle allocation of new blocks for requests that are loading
# tokens from connector's external kv cache. We never load from external cache
# so this is a no-op.
assert num_external_tokens == 0, "This connector is store-only"
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = ExampleHiddenStatesConnectorMetadata()
for new_req in scheduler_output.scheduled_new_reqs:
token_ids = new_req.prompt_token_ids or []
filename = os.path.join(self._storage_path, f"{new_req.req_id}.safetensors")
meta.add_request(
new_req.req_id,
filename=filename,
token_ids=token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
)
self._request_filenames[new_req.req_id] = filename
self._active_requests[new_req.req_id] = new_req
self._req_blocks[new_req.req_id] = list(new_req.block_ids[0])
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
if req_id not in self._active_requests:
continue
new_block_ids = cached_reqs.new_block_ids[i]
cached_req = self._active_requests[req_id]
req_block_ids = self._req_blocks[req_id]
assert new_block_ids is not None
block_ids = new_block_ids[0]
req_block_ids.extend(block_ids)
filename = os.path.join(self._storage_path, f"{req_id}.safetensors")
meta.add_request(
req_id=req_id,
filename=filename,
token_ids=cached_req.prompt_token_ids or [],
block_ids=req_block_ids,
block_size=self._block_size,
new_req=False,
)
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called exactly once when a request has finished, before its blocks are
freed.
The connector may assume responsibility for freeing the blocks
asynchronously by returning True.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
req_id = request.request_id
req_filename = self._request_filenames.pop(req_id, None)
_ = self._active_requests.pop(req_id, None)
_ = self._req_blocks.pop(req_id, None)
return False, {"hidden_states_path": req_filename}
@classmethod
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
"""
Get the required KV cache layout for this connector.
Args:
vllm_config (VllmConfig): the vllm config.
Returns:
str: the required KV cache layout. e.g. HND, or NHD.
None if the connector does not require a specific layout.
"""
if cls is KVConnectorBase_V1:
raise TypeError(
"get_required_kvcache_layout should not be called "
"on the abstract base class"
)
# NHD means we have (num_tokens, num_heads)
# HND means we have (num_heads, num_tokens)
# For now, we only support NHD layout since this keeps the
# hidden states for each token together in memory.
# HND is primarily used when sharding heads across devices.
return "NHD"

View File

@@ -70,6 +70,16 @@ class LMCacheKVEvents(KVConnectorKVEvents):
class LMCacheConnectorV1(KVConnectorBase_V1):
@classmethod
def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
"""
LMCache requires PIECEWISE CUDA graph mode when layerwise
operations are enabled. The wait_for_layer_load and save_kv_layer
methods perform actual async synchronization that cannot be
captured in CUDA graphs.
"""
return extra_config.get("use_layerwise", False)
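# Hedged usage sketch (the config dict is hypothetical): enabling layerwise
# transfers in the connector's extra config is what flips this classmethod to
# True, which in turn should force PIECEWISE CUDA graph mode for the engine.
assert LMCacheConnectorV1.requires_piecewise_for_cudagraph({"use_layerwise": True})
assert not LMCacheConnectorV1.requires_piecewise_for_cudagraph({})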
def __init__(
self,
vllm_config: "VllmConfig",

View File

@@ -173,6 +173,29 @@ class MooncakeConnector(KVConnectorBase_V1):
self.connector_scheduler = None
self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)
############################################################
# Class Methods
############################################################
@classmethod
def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
if vllm_config.model_config is None:
logger.warning_once(
"Unable to detect current VLLM config. "
"Fallback to default kv cache layout."
)
return None
use_mla = vllm_config.model_config.use_mla
if use_mla:
# Return None when we have MLA, as the layout should not matter
# in that case, which falls back to the default behavior.
return None
logger.info_once(
"MooncakeConnector setting KV cache layout to HND for better xfer performance."
)
return "HND"
############################################################
# Scheduler Side Methods
############################################################
@@ -941,7 +964,13 @@ class MooncakeConnectorWorker:
assert tensor_size_bytes == curr_tensor_size_bytes, (
"All kv cache tensors must have the same size"
)
kernel_block_size = cache.shape[-2 if self.use_mla else -3]
cache_layout = get_kv_cache_layout()
if cache_layout == "HND":
kernel_block_size = cache.shape[-2]
else:
kernel_block_size = cache.shape[-3]
assert self.block_size == kernel_block_size
kv_data_ptrs.append(base_addr)
kv_data_lens.append(tensor_size_bytes)
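# Shape note for the layout-dependent read above (illustrative non-MLA shapes:
# 2 = K/V, 10 pages, block_size 16, 8 heads, head_size 64): NHD keeps the
# block (page-size) dimension at shape[-3], while HND moves it to shape[-2].
_nhd_shape = (2, 10, 16, 8, 64)   # ..., block_size, num_heads, head_size
_hnd_shape = (2, 10, 8, 16, 64)   # ..., num_heads, block_size, head_size
assert _nhd_shape[-3] == 16 and _hnd_shape[-2] == 16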

View File

@@ -112,6 +112,21 @@ class MultiConnector(KVConnectorBase_V1):
- Save to all connectors.
"""
@classmethod
def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
"""
MultiConnector requires PIECEWISE CUDA graph mode if any of its
child connectors require it.
"""
connectors_config = extra_config.get("connectors", [])
for conn_config in connectors_config:
temp_ktc = KVTransferConfig(**conn_config)
connector_cls = KVConnectorFactory.get_connector_class(temp_ktc)
child_extra_config = conn_config.get("kv_connector_extra_config", {})
if connector_cls.requires_piecewise_for_cudagraph(child_extra_config):
return True
return False
def __init__(
self,
vllm_config: "VllmConfig",

File diff suppressed because it is too large

View File

@@ -25,6 +25,7 @@ If you only need to use the distributed environment without model/pipeline
import contextlib
import gc
import os
import pickle
import weakref
from collections import namedtuple
@@ -33,7 +34,7 @@ from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from multiprocessing import shared_memory
from typing import Any, Protocol
from typing import TYPE_CHECKING, Any, Protocol
from unittest.mock import patch
import torch
@@ -54,6 +55,10 @@ from vllm.utils.system_utils import suppress_stdout
from vllm.utils.torch_utils import (
direct_register_custom_op,
)
if TYPE_CHECKING:
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
import ixformer.distributed as ixfd
import vllm._custom_ops as ops
@@ -327,6 +332,8 @@ class GroupCoordinator:
self.rank = torch.distributed.get_rank()
self.local_rank = local_rank
use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM", None) not in {"1", "Y", "y"}
self_device_group = None
self_cpu_group = None
@@ -339,7 +346,7 @@ class GroupCoordinator:
with suppress_stdout():
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if self.rank in ranks:
self.ixfd_group = ixfd.init_comm_with_store(device_group)
self.ixfd_group = ixfd.init_comm_with_store(device_group) if use_vllm_comm else None
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
@@ -372,8 +379,7 @@ class GroupCoordinator:
self.device_communicator = device_comm_cls(
cpu_group=self.cpu_group,
device=self.device,
# device_group=self.device_group,
device_group=self.ixfd_group if envs.VLLM_FORCE_NCCL_COMM else self.device_group,
device_group=self.ixfd_group if use_vllm_comm else self.device_group,
unique_name=self.unique_name,
)
@@ -385,11 +391,6 @@ class GroupCoordinator:
self.cpu_group, 1 << 22, 6
)
from vllm.platforms import current_platform
# self.use_custom_op_call = (
# current_platform.is_cuda_alike() or current_platform.is_tpu()
# )
self.use_custom_op_call = False
self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
@@ -468,14 +469,12 @@ class GroupCoordinator:
# only cuda uses this function,
# so we don't abstract it into the base class
maybe_ca_context = nullcontext()
# from vllm.distributed.device_communicators.cuda_communicator import (
# CudaCommunicator,
# )
from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase
from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator,
)
if self.device_communicator is not None:
# assert isinstance(self.device_communicator, CudaCommunicator)
assert isinstance(self.device_communicator, DeviceCommunicatorBase)
assert isinstance(self.device_communicator, CudaCommunicator)
ca_comm = self.device_communicator.ca_comm
if ca_comm is not None:
maybe_ca_context = ca_comm.capture() # type: ignore
@@ -608,9 +607,9 @@ class GroupCoordinator:
src=self.ranks[src],
group=self.device_group)
else:
torch.distributed.broadcast(input_,
src=self.ranks[src],
group=self.device_group)
torch.distributed.broadcast(
input_, src=self.ranks[src], group=self.device_group
)
return input_
def broadcast_object(self, obj: Any | None = None, src: int = 0):
@@ -764,10 +763,9 @@ class GroupCoordinator:
group=group,
async_op=True)
else:
handle = torch.distributed.broadcast(tensor,
src=self.ranks[src],
group=group,
async_op=True)
handle = torch.distributed.broadcast(
tensor, src=self.ranks[src], group=group, async_op=True
)
async_handles.append(handle)
for async_handle in async_handles:
async_handle.wait()
@@ -802,10 +800,8 @@ class GroupCoordinator:
async_op=True)
else:
handle = torch.distributed.broadcast(
tensor,
src=self.ranks[src],
group=group,
async_op=True)
tensor, src=self.ranks[src], group=group, async_op=True
)
async_handles.append(handle)
tensor_dict[key] = tensor
else:
@@ -876,6 +872,10 @@ class GroupCoordinator:
if self.world_size <= 1:
return []
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
@@ -893,10 +893,6 @@ class GroupCoordinator:
group = self.device_group
metadata_group = self.cpu_group
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
self.send_object(metadata_list, dst=dst)
@@ -917,6 +913,7 @@ class GroupCoordinator:
handle = torch.distributed.isend(
tensor, dst=self.ranks[dst], group=comm_group
)
if tensor.is_cuda:
tensor.record_stream(torch.cuda.current_stream(tensor.device))
handles.append(handle)
@@ -973,6 +970,11 @@ class GroupCoordinator:
]:
if not torch.distributed.is_initialized() or self.world_size == 1:
return None, [], []
if src is None:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
@@ -990,10 +992,6 @@ class GroupCoordinator:
group = self.device_group
metadata_group = self.cpu_group
if src is None:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
recv_metadata_list = self.recv_object(src=src)
tensor_dict: dict[str, Any] = {}
handles: list[Handle] = []
@@ -1072,14 +1070,13 @@ class GroupCoordinator:
return self.device_communicator.recv(size, dtype, src)
def destroy(self):
if hasattr(self, "device_group"):
# torch.distributed.destroy_process_group(self.device_group)
if self.device_group is not None:
if self.device_communicator and self.device_communicator.use_vllm_comm:
ixfd.destroy_process_group(self.device_group)
else:
torch.distributed.destroy_process_group(self.device_group)
del self.device_group
if hasattr(self, "cpu_group"):
self.device_group = None
if self.cpu_group is not None:
torch.distributed.destroy_process_group(self.cpu_group)
del self.cpu_group
if self.device_communicator is not None:
@@ -1094,7 +1091,6 @@ class GroupCoordinator:
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
extra_residual:torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
@@ -1105,13 +1101,12 @@ class GroupCoordinator:
if self.device_communicator is not None:
return self.device_communicator.dispatch_router_logits(
hidden_states,
extra_residual,
router_logits,
is_sequence_parallel,
extra_tensors,
)
else:
return hidden_states, extra_residual, router_logits
return hidden_states, router_logits
def dispatch(
self,
@@ -1189,6 +1184,55 @@ def init_model_parallel_group(
)
def _init_stateless_group(
group_ranks: list[list[int]],
group_name: str,
group_ports: list[list[int]],
host: str,
backend: str,
use_device_communicator: bool = True,
) -> "StatelessGroupCoordinator":
"""Create a StatelessGroupCoordinator with the given parameters."""
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
world = get_world_group()
return StatelessGroupCoordinator(
group_ranks=group_ranks,
local_rank=world.local_rank,
torch_distributed_backend=backend,
use_device_communicator=use_device_communicator,
group_name=group_name,
host=host,
group_ports=group_ports,
global_rank=world.rank,
global_world_size=world.world_size,
)
def _replace_active_groups(
*,
world: GroupCoordinator | None,
dp: GroupCoordinator | None,
ep: GroupCoordinator | None,
eplb: GroupCoordinator | None,
node_count: int | None,
) -> None:
"""Destroy the current DP/EP/WORLD/EPLB groups and replace them.
Destruction is collective — all ranks in the old groups must call this
function together. Pass all-``None`` to tear down without replacement.
"""
global _WORLD, _DP, _EP, _EPLB, _NODE_COUNT
for group in (_DP, _EP, _WORLD, _EPLB):
if group is not None:
group.destroy()
_WORLD = world
_DP = dp
_EP = ep
_EPLB = eplb
_NODE_COUNT = node_count
_TP: GroupCoordinator | None = None
@@ -1286,6 +1330,39 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable
def _init_elastic_ep_world(
config, local_rank: int, backend: str, rank: int, world_size: int
) -> None:
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
global _WORLD, _NODE_COUNT
assert _WORLD is None, "world group already initialized"
parallel_config = config.parallel_config
global_rank = parallel_config.data_parallel_rank * world_size + rank
global_world_size = parallel_config.world_size_across_dp
all_ranks = list(range(global_world_size))
group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)]
if global_rank in all_ranks:
group_ranks = [all_ranks]
group_ports = [parallel_config.get_next_stateless_world_group_port()]
world = StatelessGroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_device_communicator=False,
group_name="world",
host=parallel_config.data_parallel_master_ip,
group_ports=group_ports,
global_rank=global_rank,
global_world_size=global_world_size,
)
assert parallel_config.nnodes_within_dp == 1, (
"Elastic EP is not supported with multi-node TP/PP"
)
_NODE_COUNT = _node_count(world.tcp_store_group)
_WORLD = world
def init_distributed_environment(
world_size: int = -1,
rank: int = -1,
@@ -1305,6 +1382,7 @@ def init_distributed_environment(
from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config_or_none()
enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep
if (
config is not None
and config.parallel_config.distributed_executor_backend != "external_launcher"
@@ -1312,6 +1390,7 @@ def init_distributed_environment(
config.parallel_config.nnodes > 1
or config.parallel_config.data_parallel_size > 1
)
and not enable_elastic_ep
):
parallel_config = config.parallel_config
# adjust to take into account data parallelism
@@ -1365,6 +1444,18 @@ def init_distributed_environment(
rank=rank,
timeout=timeout,
)
if enable_elastic_ep:
tp_pp_cpu_group = torch.distributed.new_group(
backend="gloo", timeout=timeout
)
if _node_count(tp_pp_cpu_group) > 1:
# NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip
# to initialize all DP/EP groups, hence all ranks within TP/PP group
# must reside on the same node
raise RuntimeError(
"Elastic EP is not yet supported with multi-node TP/PP"
)
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
@@ -1373,6 +1464,9 @@ def init_distributed_environment(
# setting, where we can use rank as local rank
local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
if enable_elastic_ep:
_init_elastic_ep_world(config, local_rank, backend, rank, world_size)
return
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
@@ -1436,16 +1530,33 @@ def initialize_model_parallel(
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
data_parallel_size = 1
from vllm.config import get_current_vllm_config_or_none
from vllm.config import get_current_vllm_config
config = get_current_vllm_config_or_none()
if config is not None:
data_parallel_size = config.parallel_config.data_parallel_size
config = get_current_vllm_config()
data_parallel_size = config.parallel_config.data_parallel_size
enable_elastic_ep = config.parallel_config.enable_elastic_ep
if enable_elastic_ep:
# Use stateless world group for global information
world_size = get_world_group().world_size
rank = get_world_group().rank
backend = backend or "nccl"
tp_pp_pcp_size = (
tensor_model_parallel_size
* pipeline_model_parallel_size
* prefill_context_model_parallel_size
)
local_all_ranks = torch.arange(tp_pp_pcp_size).reshape(
pipeline_model_parallel_size,
prefill_context_model_parallel_size,
tensor_model_parallel_size,
)
else:
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group
)
# the layout order is: ExternalDP x DP x PP x TP
# ExternalDP is the data parallel group that is not part of the model,
@@ -1469,7 +1580,9 @@ def initialize_model_parallel(
assert _TP is None, "tensor model parallel group is already initialized"
group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(
group_ranks,
@@ -1488,6 +1601,11 @@ def initialize_model_parallel(
# TP group into tp_size//dcp_size DCP groups.
group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = local_all_ranks.reshape(
-1, decode_context_model_parallel_size
).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DCP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
@@ -1504,6 +1622,13 @@ def initialize_model_parallel(
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = (
local_all_ranks.transpose(1, 2)
.reshape(-1, prefill_context_model_parallel_size)
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_PCP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pcp"
)
@@ -1515,6 +1640,13 @@ def initialize_model_parallel(
all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = (
local_all_ranks.transpose(0, 2)
.reshape(-1, pipeline_model_parallel_size)
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_PP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pp"
)
@@ -1523,14 +1655,27 @@ def initialize_model_parallel(
assert _DP is None, "data parallel group is already initialized"
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="dp"
)
if enable_elastic_ep:
parallel_config = config.parallel_config
dp_ports = [
parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks
]
_DP = _init_stateless_group(
group_ranks,
"dp",
dp_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_DP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="dp"
)
global _EP
assert _EP is None, "expert parallel group is already initialized"
# Don't create EP group for dense models.
if config is None or config.model_config is None or config.model_config.is_moe:
if config.model_config is None or config.model_config.is_moe:
group_ranks = (
all_ranks.transpose(1, 2)
.reshape(
@@ -1542,9 +1687,22 @@ def initialize_model_parallel(
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_EP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="ep"
)
if enable_elastic_ep:
parallel_config = config.parallel_config
ep_ports = [
parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks
]
_EP = _init_stateless_group(
group_ranks,
"ep",
ep_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_EP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="ep"
)
# Create EPLB group with the same ranks as EP if EPLB is enabled.
# This is a separate process group to isolate EPLB communications
@@ -1557,10 +1715,25 @@ def initialize_model_parallel(
and config.parallel_config is not None
and config.parallel_config.enable_eplb
):
# Reuse the same group_ranks from EP
_EPLB = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="eplb"
)
if enable_elastic_ep:
eplb_ports = [
parallel_config.get_next_stateless_eplb_group_port()
for _ in group_ranks
]
_EPLB = _init_stateless_group(
group_ranks,
"eplb",
eplb_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_EPLB = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
group_name="eplb",
)
# If no EP group needed, _EP remains None
# If no EPLB group needed, _EPLB remains None
@@ -1590,7 +1763,11 @@ def ensure_model_parallel_initialized(
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
"""
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
world_group = get_world_group()
if hasattr(world_group, "backend"):
backend = backend or world_group.backend
else:
backend = backend or torch.distributed.get_backend(world_group.device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(
tensor_model_parallel_size,

View File

@@ -0,0 +1,322 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
from torch.distributed import Backend, ProcessGroup
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
from vllm.distributed.parallel_state import (
GroupCoordinator,
TensorMetadata,
_get_unique_name,
_register_group,
_split_tensor_dict,
)
from vllm.distributed.utils import (
StatelessProcessGroup,
stateless_destroy_torch_distributed_process_group,
stateless_init_torch_distributed_process_group,
)
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
class StatelessGroupCoordinator(GroupCoordinator):
"""
A stateless version of the GroupCoordinator class in parallel_state.
It creates CPU, device and TCPStore based communication groups
that are independent of PyTorch's WORLD group. Hence,
communication groups with a different set of participant GPUs
can be created without destroying the existing ones.
"""
def __init__(
self,
group_ranks: list[list[int]],
local_rank: int,
torch_distributed_backend: str | Backend,
use_device_communicator: bool,
use_message_queue_broadcaster: bool = False,
group_name: str | None = None,
host: str = "127.0.0.1",
group_ports: list[list[int]] | None = None,
global_rank: int = 0,
global_world_size: int = 1,
):
group_name = group_name or "anonymous"
self.unique_name = _get_unique_name(group_name)
_register_group(self)
self.rank = global_rank
self.local_rank = local_rank
self_device_group = None
self_cpu_group = None
self_tcp_store_group = None
from vllm.platforms import current_platform
backend = str(torch_distributed_backend)
self.backend = backend
assert group_ports is not None, "group_ports must be provided"
for idx, ranks in enumerate(group_ranks):
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
ports = group_ports[idx]
device_port = ports[0]
cpu_port = ports[1]
tcp_store_port = ports[2]
device_group = stateless_init_torch_distributed_process_group(
host=host,
port=device_port,
rank=self.rank_in_group,
world_size=self.world_size,
backend=backend,
group_name=f"{self.unique_name}_device",
)
cpu_group = stateless_init_torch_distributed_process_group(
host=host,
port=cpu_port,
rank=self.rank_in_group,
world_size=self.world_size,
backend="gloo",
group_name=f"{self.unique_name}_cpu",
)
tcp_store_group = StatelessProcessGroup.create(
host=host,
port=tcp_store_port,
rank=self.rank_in_group,
world_size=self.world_size,
)
self_device_group = device_group
self_cpu_group = cpu_group
self_tcp_store_group = tcp_store_group
assert self_cpu_group is not None
assert self_device_group is not None
assert self_tcp_store_group is not None
self.cpu_group = self_cpu_group
self.device_group = self_device_group
self.tcp_store_group = self_tcp_store_group
if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}")
elif current_platform.is_xpu():
self.device = torch.device(f"xpu:{local_rank}")
elif current_platform.is_out_of_tree():
self.device = torch.device(f"{current_platform.device_name}:{local_rank}")
else:
self.device = torch.device("cpu")
self.use_device_communicator = use_device_communicator
self.device_communicator = None
if use_device_communicator and self.world_size > 1:
device_comm_cls = resolve_obj_by_qualname(
current_platform.get_device_communicator_cls()
)
assert device_comm_cls == CudaCommunicator
self.device_communicator = CudaCommunicator(
cpu_group=self.cpu_group,
device=self.device,
device_group=self.device_group,
unique_name=self.unique_name,
global_ranks=self.ranks,
global_world_size=global_world_size,
tcp_store_group=self.tcp_store_group,
)
self.mq_broadcaster = None
self.use_custom_op_call = (
current_platform.is_cuda_alike() or current_platform.is_tpu()
)
self.use_cpu_custom_send_recv = False
def destroy(self):
if self.device_communicator:
self.device_communicator.destroy()
if self.device_group:
stateless_destroy_torch_distributed_process_group(self.device_group)
if self.cpu_group:
stateless_destroy_torch_distributed_process_group(self.cpu_group)
def size(self) -> int:
"""Return the world size of this group."""
return self.world_size
def broadcast(self, input_: torch.Tensor, src: int = 0):
if self.world_size == 1:
return input_
if self.device_communicator and input_.is_cuda:
return self.device_communicator.broadcast(input_, src)
else:
return self.tcp_store_group.broadcast(input_, src)
def broadcast_object(self, obj=None, src: int = 0):
if self.world_size == 1:
return obj
return self.tcp_store_group.broadcast_obj(obj, src)
def broadcast_object_list(
self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None
):
assert src < self.world_size
if self.world_size == 1:
return obj_list
if self.rank_in_group == src:
for obj in obj_list:
self.tcp_store_group.broadcast_obj(obj, src)
else:
for i in range(len(obj_list)):
obj_list[i] = self.tcp_store_group.broadcast_obj(None, src)
return obj_list
def broadcast_tensor_dict(
self,
tensor_dict: dict[str, torch.Tensor | Any] | None = None,
src: int = 0,
group: ProcessGroup | None = None,
metadata_group: ProcessGroup | None = None,
) -> dict[str, torch.Tensor | Any] | None:
if self.world_size == 1:
return tensor_dict
if self.rank_in_group == src:
assert isinstance(tensor_dict, dict), (
f"Expecting a dictionary, got {type(tensor_dict)}"
)
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
else:
metadata_list = None
tensor_list = []
recv_metadata_list: list[tuple[str, Any]] = self.tcp_store_group.broadcast_obj(
metadata_list, src
)
if self.rank_in_group != src:
tensor_dict = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(
value.size, dtype=value.dtype, device=value.device
)
tensor_list.append(tensor)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
for tensor in tensor_list:
if tensor.numel() == 0:
continue
if self.device_communicator and tensor.is_cuda:
tensor.copy_(self.device_communicator.broadcast(tensor, src))
else:
tensor.copy_(self.tcp_store_group.broadcast(tensor, src))
return tensor_dict
def send_object(self, obj, dst: int) -> None:
assert dst < self.world_size
assert dst != self.rank_in_group
self.tcp_store_group.send_obj(obj, dst)
def recv_object(self, src: int):
assert src < self.world_size
assert src != self.rank_in_group
return self.tcp_store_group.recv_obj(src)
def send_tensor_dict(
self,
tensor_dict: dict[str, torch.Tensor | Any],
dst: int | None = None,
all_gather_group: Optional["GroupCoordinator"] = None,
all_gather_tensors: dict[str, bool] | None = None,
) -> dict[str, torch.Tensor | Any] | None:
if self.world_size == 1:
return tensor_dict
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
self.tcp_store_group.send_obj(metadata_list, dst)
for tensor in tensor_list:
if tensor.numel() == 0:
continue
if self.device_communicator and tensor.is_cuda:
self.device_communicator.send(tensor, dst)
else:
self.tcp_store_group.send(tensor, dst)
return None
def recv_tensor_dict(
self,
src: int | None = None,
all_gather_group: Optional["GroupCoordinator"] = None,
all_gather_tensors: dict[str, bool] | None = None,
) -> dict[str, torch.Tensor | Any] | None:
if self.world_size == 1:
return None
if src is None:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size
recv_metadata_list = self.tcp_store_group.recv_obj(src)
tensor_dict = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
if tensor.numel() > 0:
if self.device_communicator and tensor.is_cuda:
tensor = self.device_communicator.recv(
tensor.size(), tensor.dtype, src
)
else:
tensor = self.tcp_store_group.recv(tensor, src)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
return tensor_dict
def barrier(self):
self.tcp_store_group.barrier()
def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> torch.Tensor | None:
if self.world_size == 1:
return input_
if self.device_communicator is None:
raise ValueError("No device communicator found")
if self.rank_in_group == dst:
gathered_list = [torch.empty_like(input_) for _ in range(self.world_size)]
gathered_list[self.rank_in_group] = input_
for src_rank in range(self.world_size):
if src_rank != self.rank_in_group:
gathered_list[src_rank] = self.device_communicator.recv(
input_.size(), input_.dtype, src_rank
)
return torch.cat(gathered_list, dim=dim)
else:
self.device_communicator.send(input_, dst)
return None
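# Minimal usage sketch (assumption, not from this diff): building a 2-rank
# stateless "dp" group on one node. Each group needs three free ports
# (device backend, gloo CPU, TCPStore); the port numbers are hypothetical, and
# the snippet must run once per rank with the matching global_rank/local_rank.
coordinator = StatelessGroupCoordinator(
    group_ranks=[[0, 1]],
    local_rank=0,
    torch_distributed_backend="nccl",
    use_device_communicator=True,
    group_name="dp",
    host="127.0.0.1",
    group_ports=[[29600, 29601, 29602]],
    global_rank=0,
    global_world_size=2,
)
obj = coordinator.broadcast_object({"step": 1}, src=0)  # goes over the TCPStore
coordinator.destroy()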

View File

@@ -18,7 +18,7 @@ from datetime import timedelta
from typing import Any
import torch
from torch.distributed import ProcessGroup, TCPStore
from torch.distributed import ProcessGroup, Store, TCPStore
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
@@ -228,6 +228,55 @@ class StatelessProcessGroup:
gathered_objs.append(recv_obj)
return gathered_objs
def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
"""Broadcast a tensor from source rank to all other ranks."""
if self.rank == src:
tensor_bytes = pickle.dumps(tensor)
self.expire_data()
key = f"broadcast_tensor/{src}/{self.broadcast_send_counter}"
self.store.set(key, tensor_bytes)
self.broadcast_send_counter += 1
self.entries.append((key, time.time()))
return tensor
else:
key = f"broadcast_tensor/{src}/{self.broadcast_recv_src_counter[src]}"
tensor = pickle.loads(self.store.get(key))
self.broadcast_recv_src_counter[src] += 1
return tensor
def send(self, tensor: torch.Tensor, dst: int):
"""Send a tensor to a destination rank."""
self.expire_data()
key = f"send_tensor/{dst}/{self.send_dst_counter[dst]}"
self.store.set(key, pickle.dumps(tensor))
self.send_dst_counter[dst] += 1
self.entries.append((key, time.time()))
def recv(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
"""Receive a tensor from a source rank."""
key = f"send_tensor/{self.rank}/{self.recv_src_counter[src]}"
received = pickle.loads(self.store.get(key))
self.recv_src_counter[src] += 1
tensor.copy_(received)
return tensor
def all_reduce(
self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM
) -> torch.Tensor:
"""All-reduce a tensor across all ranks."""
tensors = self.all_gather_obj(tensor)
result = tensors[0].clone()
for t in tensors[1:]:
if op == torch.distributed.ReduceOp.SUM:
result.add_(t)
elif op == torch.distributed.ReduceOp.PRODUCT:
result.mul_(t)
elif op == torch.distributed.ReduceOp.MAX:
result = torch.maximum(result, t)
elif op == torch.distributed.ReduceOp.MIN:
result = torch.minimum(result, t)
return result
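# Illustrative sketch (assumption, not part of this diff): exercising the new
# tensor helpers on an existing StatelessProcessGroup `pg` created via
# StatelessProcessGroup.create(...). Rank 0 broadcasts; every rank all-reduces.
t = torch.full((4,), float(pg.rank))
t = pg.broadcast(t, src=0)                    # every rank now holds rank 0's tensor
total = pg.all_reduce(torch.tensor([1.0]))    # SUM across all ranks by default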
def barrier(self, timeout: float = 30.0):
"""A robust barrier to synchronize all ranks.
@@ -448,8 +497,14 @@ def init_gloo_process_group(
def stateless_init_torch_distributed_process_group(
host: str, port: int, rank: int, world_size: int, backend: str
) -> ProcessGroup:
host: str,
port: int,
rank: int,
world_size: int,
backend: str,
group_name: str | None = None,
return_store: bool = False,
) -> ProcessGroup | tuple[ProcessGroup, Store]:
"""
A replacement for `torch.distributed.init_process_group` that does not
pollute the global state. The created ProcessGroup object can be used for
@@ -496,25 +551,35 @@ def stateless_init_torch_distributed_process_group(
# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
prefix_store = PrefixStore(init_method, store)
try:
if backend == "gloo":
pg = init_gloo_process_group(
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
else:
from vllm.platforms import current_platform
return current_platform.stateless_init_device_torch_dist_pg(
pg = current_platform.stateless_init_device_torch_dist_pg(
backend=backend,
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
except NotImplementedError:
# If platform doesn't implement stateless_init_device_torch_dist_pg, it
# will raise a NotImplementedError. In this case, we fall back to gloo.
return init_gloo_process_group(
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
if group_name is not None:
from torch._C._distributed_c10d import _register_process_group
pg._set_group_name(group_name)
_register_process_group(group_name, pg)
if return_store:
return pg, store
else:
return pg
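# Hedged usage sketch for the extended signature above; host/port values are
# hypothetical. With return_store=True the backing TCPStore is returned as well,
# and a non-None group_name registers the group with torch's c10d registry.
pg, store = stateless_init_torch_distributed_process_group(
    host="127.0.0.1",
    port=29700,
    rank=0,
    world_size=2,
    backend="gloo",
    group_name="example_cpu_group",
    return_store=True,
)
# ... use pg for collectives ...
stateless_destroy_torch_distributed_process_group(pg)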
def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:

View File

@@ -3,7 +3,7 @@
"""Base class for weight transfer engines."""
from abc import ABC, abstractmethod
from collections.abc import Callable
from collections.abc import Callable, Iterator
from dataclasses import KW_ONLY, dataclass, field
from typing import Any, Generic, TypeVar
@@ -156,3 +156,30 @@ class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
This should be called when the worker is shutting down.
"""
raise NotImplementedError
@staticmethod
@abstractmethod
def trainer_send_weights(
iterator: Iterator[tuple[str, torch.Tensor]],
trainer_args: dict[str, Any] | Any,
) -> None:
"""
Send weights from trainer to inference workers.
This is a static method that can be called from the trainer process
to send weights to all inference workers.
Args:
iterator: Iterator of model parameters. Returns (name, tensor) tuples.
The tensors should be on the appropriate device for the backend.
trainer_args: Dictionary containing backend-specific arguments needed
to send weights. The structure depends on the backend:
- NCCL: Contains 'group', 'src', 'packed', etc.
- IPC: Contains 'mode' ('http' or 'ray'),
'llm_handle' (for Ray), 'url' (for HTTP), etc.
Example:
>>> param_iter = ((n, p) for n, p in model.named_parameters())
>>> engine.trainer_send_weights(param_iter, trainer_args)
"""
raise NotImplementedError

View File

@@ -114,3 +114,9 @@ WeightTransferEngineFactory.register_engine(
"vllm.distributed.weight_transfer.nccl_engine",
"NCCLWeightTransferEngine",
)
WeightTransferEngineFactory.register_engine(
"ipc",
"vllm.distributed.weight_transfer.ipc_engine",
"IPCWeightTransferEngine",
)
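# Sketch (assumption): a custom backend could be registered the same way; the
# module path and class name below are hypothetical and would need to subclass
# WeightTransferEngine.
WeightTransferEngineFactory.register_engine(
    "my_rdma",
    "my_project.weight_transfer.rdma_engine",
    "RDMAWeightTransferEngine",
)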

View File

@@ -0,0 +1,291 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""IPC-based weight transfer engine using CUDA IPC for communication."""
import base64
import pickle
from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass
from typing import Any
import requests
import torch
from torch.multiprocessing.reductions import reduce_tensor
from vllm.config.parallel import ParallelConfig
from vllm.config.weight_transfer import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
WeightTransferEngine,
WeightTransferInitInfo,
WeightTransferUpdateInfo,
)
@dataclass
class IPCTrainerSendWeightsArgs:
"""Arguments for IPC trainer_send_weights method."""
mode: str
"""Transport mode: 'http' or 'ray'."""
llm_handle: Any = None
"""Ray ObjectRef to LLM handle (required for 'ray' mode)."""
url: str | None = None
"""Base URL for HTTP endpoint (required for 'http' mode)."""
def __post_init__(self):
"""Validate that required arguments are provided for the selected mode."""
if self.mode == "ray" and self.llm_handle is None:
raise ValueError("llm_handle is required for 'ray' mode")
if self.mode == "http" and self.url is None:
raise ValueError("url is required for 'http' mode")
if self.mode not in ("ray", "http"):
raise ValueError(f"mode must be 'ray' or 'http', got {self.mode}")
@dataclass
class IPCWeightTransferInitInfo(WeightTransferInitInfo):
"""Initialization info for IPC weight transfer backend. No init needed for IPC."""
pass
@dataclass
class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo):
"""Update info for IPC weight transfer backend.
Accepts IPC handles either directly via ``ipc_handles`` (Ray transport)
or as a base64-encoded pickle via ``ipc_handles_pickled`` (HTTP transport).
Exactly one of the two must be provided; if ``ipc_handles_pickled`` is set
it is unpickled into ``ipc_handles`` during ``__post_init__``.
"""
names: list[str]
dtype_names: list[str]
shapes: list[list[int]]
ipc_handles: list[dict[str, tuple[Callable, tuple]]] | None = None
"""IPC handles mapping physical GPU UUID to (func, args) tuple.
Each handle is a dictionary mapping GPU UUID strings to IPC handle tuples."""
ipc_handles_pickled: str | None = None
"""Base64-encoded pickled IPC handles, used for HTTP transport."""
def __post_init__(self):
if self.ipc_handles_pickled is not None:
if self.ipc_handles is not None:
raise ValueError(
"Cannot specify both `ipc_handles` and `ipc_handles_pickled`"
)
self.ipc_handles = pickle.loads(base64.b64decode(self.ipc_handles_pickled))
self.ipc_handles_pickled = None
if self.ipc_handles is None:
raise ValueError(
"Either `ipc_handles` or `ipc_handles_pickled` must be provided"
)
num_params = len(self.names)
if len(self.dtype_names) != num_params:
raise ValueError(
f"`dtype_names` should be of the same size as `names`: "
f"got {len(self.dtype_names)} and {len(self.names)}"
)
if len(self.shapes) != num_params:
raise ValueError(
f"`shapes` should be of the same size as `names`: "
f"got {len(self.shapes)} and {len(self.names)}"
)
if len(self.ipc_handles) != num_params:
raise ValueError(
f"`ipc_handles` should be of the same size as `names`: "
f"got {len(self.ipc_handles)} and {len(self.names)}"
)
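# Sketch (assumption, not from this diff): round-tripping handles through the
# pickled form used for HTTP transport. `handles` is a stand-in for what
# trainer_send_weights builds (one {gpu_uuid: (func, args)} dict per weight);
# base64 and pickle are already imported at the top of this file.
pickled = base64.b64encode(pickle.dumps(handles)).decode("utf-8")
info = IPCWeightTransferUpdateInfo(
    names=["w"],
    dtype_names=["float16"],
    shapes=[[16, 16]],
    ipc_handles_pickled=pickled,
)
assert info.ipc_handles is not None and info.ipc_handles_pickled is None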
class IPCWeightTransferEngine(
WeightTransferEngine[IPCWeightTransferInitInfo, IPCWeightTransferUpdateInfo]
):
"""
Weight transfer engine using CUDA IPC for communication between trainer and workers.
This implementation uses CUDA IPC to transfer weights from the trainer process
to inference workers running on the same node. IPC handles are used to share
GPU memory between these processes.
"""
# Define backend-specific dataclass types
init_info_cls = IPCWeightTransferInitInfo
update_info_cls = IPCWeightTransferUpdateInfo
def __init__(
self, config: WeightTransferConfig, parallel_config: ParallelConfig
) -> None:
"""
Initialize the IPC weight transfer engine.
Args:
config: The configuration for the weight transfer engine
parallel_config: The configuration for the parallel setup
"""
super().__init__(config, parallel_config)
def init_transfer_engine(self, init_info: IPCWeightTransferInitInfo) -> None:
"""
Initialize the weight transfer mechanism.
This is called once at the beginning of training.
No initialization needed for IPC backend.
Args:
init_info: IPC initialization info (empty)
"""
pass
def receive_weights(
self,
update_info: IPCWeightTransferUpdateInfo,
load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
) -> None:
"""
Receive weights from the trainer via CUDA IPC handles.
Args:
update_info: IPC update info containing parameter names, dtypes, shapes,
and IPC handles. Each IPC handle is a mapping between physical
GPU UUID and the IPC handle tuple (func, args).
load_weights: Callable that loads weights into the model. Called
incrementally for each weight to avoid OOM.
"""
assert update_info.ipc_handles is not None
weights = []
for name, _dtype_name, _shape, ipc_handle in zip(
update_info.names,
update_info.dtype_names,
update_info.shapes,
update_info.ipc_handles,
):
device_index = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device_index)
physical_gpu_id = str(props.uuid)
if physical_gpu_id not in ipc_handle:
raise ValueError(
f"IPC handle not found for GPU UUID {physical_gpu_id}. "
f"Available UUIDs: {list(ipc_handle.keys())}"
)
handle = ipc_handle[physical_gpu_id]
func, args = handle
list_args = list(args) # type: ignore
# Index 6 is the device_index parameter in torch's
# IPC handle tuple (rebuild_cuda_tensor). Update it
# to the current device since the logical index can
# differ between sender and receiver.
list_args[6] = device_index
weight = func(*list_args) # type: ignore
weights.append((name, weight))
load_weights(weights)
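# Sketch (assumption): worker-side call. `engine` is an IPCWeightTransferEngine,
# `update_info` was built from the trainer's payload, and `load_into_model` is a
# hypothetical stand-in for whatever callable applies (name, tensor) pairs.
def load_into_model(weights):
    for name, tensor in weights:
        print(f"loading {name}: {tuple(tensor.shape)}")

engine.receive_weights(update_info, load_into_model)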
def shutdown(self) -> None:
"""
Shutdown the weight transfer engine.
"""
pass
@staticmethod
def trainer_send_weights(
iterator: Iterator[tuple[str, torch.Tensor]],
trainer_args: dict[str, Any] | IPCTrainerSendWeightsArgs,
) -> None:
"""
Send weights from trainer to inference workers via CUDA IPC.
Supports two modes:
- 'ray': Sends weights via Ray RPC to a Ray-based LLM handle
- 'http': Sends weights via HTTP POST to a vLLM HTTP server
Args:
iterator: Iterator of model parameters. Returns (name, tensor) tuples.
Tensors should be on the same GPU as the inference workers.
trainer_args: Dictionary containing IPC-specific arguments.
Should contain keys from IPCTrainerSendWeightsArgs:
- mode: 'ray' or 'http'
- llm_handle: Ray ObjectRef (for 'ray' mode)
- url: Base URL string (for 'http' mode)
Example (Ray mode):
>>> from vllm.distributed.weight_transfer.ipc_engine import (
... IPCWeightTransferEngine,
... IPCTrainerSendWeightsArgs,
... )
>>> param_iter = ((n, p) for n, p in model.named_parameters())
>>> args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle)
>>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args))
Example (HTTP mode):
>>> args = IPCTrainerSendWeightsArgs(
... mode="http", url="http://localhost:8000"
... )
>>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args))
"""
# Parse trainer args - accept either dict or dataclass instance
if isinstance(trainer_args, dict):
args = IPCTrainerSendWeightsArgs(**trainer_args)
else:
args = trainer_args
# Get physical GPU UUID
device_index = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device_index)
gpu_uuid = str(props.uuid)
# Collect weight metadata and create IPC handles
names = []
dtype_names = []
shapes = []
ipc_handles = []
for name, tensor in iterator:
names.append(name)
dtype_names.append(str(tensor.dtype).split(".")[-1])
shapes.append(list(tensor.shape))
# Create IPC handle for this weight tensor
# The tensor must remain in memory for IPC to work
weight = tensor.detach().contiguous()
ipc_handle = reduce_tensor(weight)
ipc_handles.append({gpu_uuid: ipc_handle})
# Send weights based on mode
if args.mode == "ray":
# Ray mode: send via Ray RPC
import ray
update_info = asdict(
IPCWeightTransferUpdateInfo(
names=names,
dtype_names=dtype_names,
shapes=shapes,
ipc_handles=ipc_handles,
)
)
ray.get(
args.llm_handle.update_weights.remote(dict(update_info=update_info))
)
elif args.mode == "http":
# HTTP mode: send via HTTP POST with pickled handles
# Pickle and base64 encode IPC handles for HTTP transmission
pickled_handles = base64.b64encode(pickle.dumps(ipc_handles)).decode(
"utf-8"
)
url = f"{args.url}/update_weights"
payload = {
"update_info": {
"names": names,
"dtype_names": dtype_names,
"shapes": shapes,
"ipc_handles_pickled": pickled_handles,
}
}
response = requests.post(url, json=payload, timeout=300)
response.raise_for_status()

View File

@@ -35,6 +35,32 @@ class NCCLWeightTransferInitInfo(WeightTransferInitInfo):
world_size: int
@dataclass
class NCCLTrainerSendWeightsArgs:
"""Arguments for NCCL trainer_send_weights method."""
group: Any
"""Process group (PyNcclCommunicator) for NCCL communication."""
src: int = 0
"""Source rank (default 0, trainer is typically rank 0)."""
post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor] | None = None
"""Optional function to apply to each (name, tensor) pair before broadcasting.
If None, extracts just the tensor."""
packed: bool = False
"""Whether to use packed tensor broadcasting for efficiency.
When True, multiple tensors are batched together before broadcasting
to reduce NCCL communication overhead."""
stream: torch.cuda.Stream | None = None
"""CUDA stream to use for broadcasting if packed is False.
If packed is True, new streams will be created for each buffer."""
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
"""Size in bytes for each packed tensor buffer.
Must match the value used in NCCLWeightTransferUpdateInfo."""
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
"""Number of buffers for double/triple buffering during packed transfer.
Must match the value used in NCCLWeightTransferUpdateInfo."""
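# Sketch (assumption): non-packed, per-tensor broadcast on a dedicated CUDA
# stream, downcasting each parameter before sending; `comm` is a hypothetical
# PyNcclCommunicator shared by trainer and workers.
args = NCCLTrainerSendWeightsArgs(
    group=comm,
    src=0,
    packed=False,
    stream=torch.cuda.Stream(),
    post_iter_func=lambda item: item[1].to(torch.bfloat16),
)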
@dataclass
class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
"""Update info for NCCL weight transfer backend."""
@@ -47,7 +73,7 @@ class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
When True, multiple tensors are batched together before broadcasting
to reduce NCCL communication overhead."""
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
"""Size in bytes for each packed tensor buffer. Default is 1GB.
"""Size in bytes for each packed tensor buffer.
Both producer and consumer must use the same value."""
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
"""Number of buffers for double/triple buffering during packed transfer.
@@ -186,47 +212,38 @@ class NCCLWeightTransferEngine(
@staticmethod
def trainer_send_weights(
iterator: Iterator[tuple[str, torch.Tensor]],
group: Any,
src: int = 0,
post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor]
| None = None,
packed: bool = False,
stream: torch.cuda.Stream | None = None,
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
trainer_args: dict[str, Any] | NCCLTrainerSendWeightsArgs,
) -> None:
"""Broadcast weights from trainer to vLLM workers.
Args:
iterator: Iterator of model parameters. Returns (name, tensor) tuples
group: Process group (PyNcclCommunicator)
src: Source rank (default 0, trainer is typically rank 0)
post_iter_func: Optional function to apply to each (name, tensor) pair
before broadcasting. If None, extracts just the tensor.
packed: Whether to use packed tensor broadcasting for efficiency.
When True, multiple tensors are batched together before
broadcasting to reduce NCCL communication overhead.
stream: CUDA stream to use for broadcasting if packed is False.
If packed is True, new streams will be created for each buffer.
packed_buffer_size_bytes: Size in bytes for each packed tensor buffer.
Must match the value used in NCCLWeightTransferUpdateInfo.
packed_num_buffers: Number of buffers for double/triple buffering.
Must match the value used in NCCLWeightTransferUpdateInfo.
trainer_args: Dictionary or NCCLTrainerSendWeightsArgs instance containing
NCCL-specific arguments. If a dict, should contain keys from
NCCLTrainerSendWeightsArgs.
Example:
>>> from vllm.distributed.weight_transfer.nccl_engine import (
... NCCLWeightTransferEngine,
... NCCLTrainerSendWeightsArgs,
... )
>>> param_iter = ((n, p) for n, p in model.named_parameters())
>>> NCCLWeightTransferEngine.trainer_send_weights(
... param_iter, group, packed=True
... )
>>> args = NCCLTrainerSendWeightsArgs(group=group, packed=True)
>>> NCCLWeightTransferEngine.trainer_send_weights(param_iter, args)
"""
if post_iter_func is None:
# Parse trainer args - accept either dict or dataclass instance
if isinstance(trainer_args, dict):
args = NCCLTrainerSendWeightsArgs(**trainer_args)
else:
args = trainer_args
if args.post_iter_func is None:
# Default: extract just the tensor from (name, tensor) tuple
post_iter_func = lambda x: x[1]
else:
post_iter_func = args.post_iter_func
if packed:
if args.packed:
# Use packed tensor broadcasting for efficiency
from vllm.distributed.weight_transfer.packed_tensor import (
packed_broadcast_producer,
@@ -234,18 +251,20 @@ class NCCLWeightTransferEngine(
packed_broadcast_producer(
iterator=iterator,
group=group,
src=src,
group=args.group,
src=args.src,
post_iter_func=post_iter_func,
buffer_size_bytes=packed_buffer_size_bytes,
num_buffers=packed_num_buffers,
buffer_size_bytes=args.packed_buffer_size_bytes,
num_buffers=args.packed_num_buffers,
)
else:
# Use simple one-by-one broadcasting
for item in iterator:
tensor = post_iter_func(item)
group.broadcast(
tensor, src=src, stream=stream or torch.cuda.current_stream()
args.group.broadcast(
tensor,
src=args.src,
stream=args.stream or torch.cuda.current_stream(),
)
@staticmethod