Upgrade to vLLM 0.17.0 (Corex v4.1 overlay)
@@ -3,14 +3,13 @@
from typing import Any

import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
from vllm.utils.import_utils import has_deep_ep, has_mori

from .base_device_communicator import All2AllManagerBase, Cache

@@ -32,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
debugging.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def naive_multicast(
|
||||
self,
|
||||
@@ -139,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
all-gather (dispatch) and reduce-scatter (combine).
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
@@ -235,107 +234,17 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
pass
|
||||
|
||||
|
||||
class PPLXAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on PPLX kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
assert has_pplx(), (
|
||||
"pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
|
||||
" to install pplx_kernels."
|
||||
)
|
||||
super().__init__(cpu_group)
|
||||
|
||||
if self.internode:
|
||||
# inter-node communication needs nvshmem,
|
||||
# intra-node communication uses p2p mapping directly
|
||||
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
|
||||
nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_get_unique_id,
|
||||
nvshmem_init,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
|
||||
self.rank,
|
||||
self.world_size,
|
||||
)
|
||||
uid = (
|
||||
nvshmem_get_unique_id()
|
||||
if self.rank == 0
|
||||
else nvshmem_alloc_empty_unique_id()
|
||||
)
|
||||
dist.broadcast(
|
||||
uid,
|
||||
src=dist.get_process_group_ranks(self.cpu_group)[0],
|
||||
group=self.cpu_group,
|
||||
)
|
||||
logger.debug("PPLX NVSHMEM UID = %s", uid)
|
||||
nvshmem_init(uid, self.rank, self.world_size)
|
||||
|
||||
self.handle_cache = Cache()
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
import pplx_kernels as pplx # type: ignore[import-not-found]
|
||||
|
||||
return self.handle_cache.get_or_create(
|
||||
kwargs,
|
||||
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
|
||||
)
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
with self.handle_cache._lock:
|
||||
for _, handle in self.handle_cache._cache.items():
|
||||
handle.destroy()
|
||||
|
||||
if self.internode:
|
||||
from pplx_kernels.nvshmem import (
|
||||
nvshmem_finalize, # type: ignore[import-not-found]
|
||||
)
|
||||
|
||||
logger.debug("PPLX NVSHMEM finalize")
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
assert has_deep_ep(), (
|
||||
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
|
||||
" to install DeepEP kernels."
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
self.handle_cache = Cache()
|
||||
|
||||
# This is the DeepEP default. Stick to it till we can establish
|
||||
@@ -373,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
with self.handle_cache._lock:
|
||||
for _, handle in self.handle_cache._cache.items():
|
||||
handle.destroy()
|
||||
self.handle_cache._cache.clear()
|
||||
|
||||
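The handle cache that destroy() walks above is a lock-guarded dict keyed by the kwargs used to build each kernel handle. A minimal sketch of that pattern; the field names (_lock, _cache, get_or_create) mirror how the diff uses them, but the real Cache in base_device_communicator may differ in detail.

import threading

class HandleCache:
    """get_or_create() builds a handle at most once per distinct kwargs."""

    def __init__(self):
        self._lock = threading.Lock()
        self._cache: dict = {}

    def get_or_create(self, kwargs: dict, factory):
        key = tuple(sorted(kwargs.items()))  # kwargs values must be hashable
        with self._lock:
            if key not in self._cache:
                self._cache[key] = factory(**kwargs)
            return self._cache[key]

    def destroy_all(self):
        # Matches the explicit-destroy pattern in the diff's destroy() override.
        with self._lock:
            for handle in self._cache.values():
                handle.destroy()
            self._cache.clear()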
|
||||
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
@@ -381,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def _make_all2all_kwargs(self) -> dict[Any, Any]:
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
@@ -405,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
@@ -438,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
All2All communication based on DeepEP Low-Latency kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def _make_all2all_kwargs(
|
||||
self,
|
||||
@@ -476,8 +389,9 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
allow_nvlink_for_low_latency_mode=True,
|
||||
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
|
||||
# allow_nvlink_for_low_latency_mode=True,
|
||||
# allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
@@ -509,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
assert has_flashinfer_all2all(), (
|
||||
"flashinfer all2all module not found. Please install/check flashinfer"
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
logger.debug(
|
||||
"Initialize for flashinfer All2All rank=%d, world size=%d",
|
||||
self.rank,
|
||||
|
||||
@@ -27,6 +27,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
KiB = 1024
|
||||
MiB = 1024 * 1024
|
||||
# Max size for each world size in case symmetric memory is available
|
||||
# For different SM architectures
|
||||
@@ -60,17 +61,44 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
|
||||
},
|
||||
}
|
||||
|
||||
# NCCL symmetric memory allreduce configuration based on H100 and GB200 benchmarks.
|
||||
# PyNCCL-symm outperforms custom_AR for small and large tensor sizes,
|
||||
# while custom_AR wins for mid-range sizes.
|
||||
#
|
||||
# Benchmark results (8 GPUs):
|
||||
# 2K - 16K: PyNCCL-symm wins (1.35x - 1.48x faster)
|
||||
# 32K - 64K: custom_AR wins
|
||||
# 128K - 1G: PyNCCL-symm wins (1.12x - 6.14x faster)
|
||||
#
|
||||
# Benchmark results (4 GPUs):
|
||||
# 2K - 16K: PyNCCL-symm wins (1.21x - 1.30x faster)
|
||||
# 32K - 256K: custom_AR wins (1.07x - 1.35x faster)
|
||||
# 512K - 1G: PyNCCL-symm wins (1.10x - 2.32x faster)
|
||||
#
|
||||
# The config defines ranges where custom_AR is preferred (symm_mem disabled).
|
||||
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
|
||||
"min_world_size": 4,
|
||||
"thresholds": {
|
||||
4: 2 * MiB, # 2 MB
|
||||
8: 1 * MiB, # 1 MB
|
||||
# Ranges where custom_AR outperforms NCCL symm_mem: (lower_bound, upper_bound)
|
||||
# NCCL symm_mem will NOT be used for sizes in range: lower < size < upper
|
||||
"custom_ar_preferred_ranges": {
|
||||
4: (16 * KiB, 512 * KiB), # custom_AR wins for 32K-256K
|
||||
8: (16 * KiB, 128 * KiB), # custom_AR wins for 32K-64K
|
||||
},
|
||||
"always_use_above_world_size": 8, # Always use symm mem for world_size > 8
|
||||
}
|
||||
|
||||
|
||||
def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) -> bool:
|
||||
"""
|
||||
Determine if NCCL symmetric memory allreduce should be used.
|
||||
|
||||
Based on H100 and GB200 benchmarks, NCCL symm_mem is preferred for:
|
||||
- Small tensors (≤16K): Lower latency than custom_AR
|
||||
- Large tensors (≥128K for 8 GPUs, ≥512K for 4 GPUs): Better bandwidth
|
||||
|
||||
Custom_AR is preferred for mid-range sizes where its P2P approach
|
||||
has lower overhead than the symm_mem copy-in/copy-out pattern.
|
||||
"""
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
is_symmetric_memory_enabled,
|
||||
)
|
||||
@@ -80,11 +108,20 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
|
||||
|
||||
if not is_symmetric_memory_enabled():
|
||||
return False
|
||||
|
||||
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
|
||||
return False
|
||||
threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
|
||||
if threshold is not None and input_tensor.nbytes >= threshold:
|
||||
return True
|
||||
|
||||
tensor_size = input_tensor.nbytes
|
||||
custom_ar_range = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["custom_ar_preferred_ranges"].get(
|
||||
world_size
|
||||
)
|
||||
|
||||
if custom_ar_range is not None:
|
||||
lower_bound, upper_bound = custom_ar_range
|
||||
# Use symm_mem for small sizes (≤ lower_bound) and large sizes (≥ upper_bound)
|
||||
# Use custom_AR (not symm_mem) for mid-range sizes
|
||||
return tensor_size <= lower_bound or tensor_size >= upper_bound
|
||||
return world_size > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"]
|
||||
|
||||
|
||||
|
||||
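The size-based routing above reduces to a small predicate over the tensor's byte size. A self-contained sketch using the custom_ar_preferred_ranges form of the config shown above; the helper name is illustrative and the is_symmetric_memory_enabled() pre-check is deliberately omitted.

KiB = 1024
MiB = 1024 * 1024

NCCL_SYMM_MEM_ALL_REDUCE_CONFIG = {
    "min_world_size": 4,
    "custom_ar_preferred_ranges": {
        4: (16 * KiB, 512 * KiB),  # custom_AR wins for 32K-256K
        8: (16 * KiB, 128 * KiB),  # custom_AR wins for 32K-64K
    },
    "always_use_above_world_size": 8,
}

def prefer_nccl_symm_mem(world_size: int, nbytes: int) -> bool:
    """Return True when NCCL symm_mem should handle an allreduce of `nbytes`."""
    cfg = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG
    if world_size < cfg["min_world_size"]:
        return False
    rng = cfg["custom_ar_preferred_ranges"].get(world_size)
    if rng is not None:
        lower, upper = rng
        # custom_AR owns the mid-range; symm_mem takes the small and large ends.
        return nbytes <= lower or nbytes >= upper
    return world_size > cfg["always_use_above_world_size"]

# Example: on 8 GPUs a 64 KiB tensor stays on custom_AR, a 1 MiB tensor uses symm_mem.
assert prefer_nccl_symm_mem(8, 64 * KiB) is False
assert prefer_nccl_symm_mem(8, 1 * MiB) is True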
@@ -30,8 +30,9 @@ class All2AllManagerBase:
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
self.cpu_group = cpu_group
|
||||
self.tcp_store_group = tcp_store_group
|
||||
|
||||
# compute some common properties
|
||||
from vllm.distributed.parallel_state import (
|
||||
@@ -48,12 +49,17 @@ class All2AllManagerBase:
|
||||
# when we create this object
|
||||
self.dp_rank = self.dp_group.rank_in_group
|
||||
self.dp_world_size = self.dp_group.world_size
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
self.rank = cpu_group.rank()
|
||||
self.world_size = cpu_group.size()
|
||||
|
||||
# all2all communication often has separate implementations for
|
||||
# intra-node and inter-node communication
|
||||
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
|
||||
if tcp_store_group is None:
|
||||
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
|
||||
else:
|
||||
self.internode = not all(
|
||||
in_the_same_node_as(tcp_store_group, source_rank=0)
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
# get a handle for the all2all communication,
|
||||
@@ -122,17 +128,36 @@ class DeviceCommunicatorBase:
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
global_ranks: list[int] | None = None,
|
||||
global_world_size: int | None = None,
|
||||
):
|
||||
self.device = device or torch.device("cpu")
|
||||
self.cpu_group = cpu_group
|
||||
self.device_group = device_group
|
||||
self.unique_name = unique_name
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
self.ranks = dist.get_process_group_ranks(cpu_group)
|
||||
self.global_rank = dist.get_rank()
|
||||
self.global_world_size = dist.get_world_size()
|
||||
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
|
||||
|
||||
# Check if this is a stateless process group
|
||||
from torch.distributed.distributed_c10d import _world
|
||||
|
||||
is_stateless = _world.pg_map.get(cpu_group, None) is None
|
||||
|
||||
if is_stateless:
|
||||
# For stateless groups, we can't use torch.distributed methods
|
||||
self.rank = cpu_group.rank()
|
||||
self.world_size = cpu_group.size()
|
||||
assert global_ranks is not None
|
||||
assert global_world_size is not None
|
||||
self.ranks = global_ranks
|
||||
self.global_rank = self.ranks[self.rank]
|
||||
self.global_world_size = global_world_size
|
||||
self.rank_in_group = self.rank
|
||||
else:
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
self.ranks = dist.get_process_group_ranks(cpu_group)
|
||||
self.global_rank = dist.get_rank()
|
||||
self.global_world_size = dist.get_world_size()
|
||||
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
|
||||
|
||||
use_ep = False
|
||||
all2all_backend = None
|
||||
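A minimal sketch of the stateless-vs-registered branch above: stateless groups are absent from torch's _world.pg_map, so the torch.distributed bookkeeping calls cannot be used for them. It assumes torch.distributed is initialized for the registered path and that a stateless group exposes rank()/size(), as vLLM's stateless process groups do.

import torch.distributed as dist
from torch.distributed.distributed_c10d import _world

def resolve_rank_info(cpu_group, global_ranks=None, global_world_size=None):
    """Mirror of the constructor branch above, returned as a plain dict."""
    if _world.pg_map.get(cpu_group, None) is None:
        # Stateless group: the group object carries rank/size, and the caller
        # must supply the global rank list and global world size.
        assert global_ranks is not None and global_world_size is not None
        rank = cpu_group.rank()
        return {
            "rank": rank,
            "world_size": cpu_group.size(),
            "global_rank": global_ranks[rank],
            "global_world_size": global_world_size,
            "rank_in_group": rank,
        }
    # Regular group registered with torch.distributed.
    rank = dist.get_rank(cpu_group)
    return {
        "rank": rank,
        "world_size": dist.get_world_size(cpu_group),
        "global_rank": dist.get_rank(),
        "global_world_size": dist.get_world_size(),
        "rank_in_group": dist.get_group_rank(cpu_group, dist.get_rank()),
    }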
@@ -146,7 +171,7 @@ class DeviceCommunicatorBase:
|
||||
use_ep = config.parallel_config.data_parallel_size > 1
|
||||
all2all_backend = config.parallel_config.all2all_backend
|
||||
|
||||
self.is_ep_communicator = "ep" in unique_name
|
||||
self.is_ep_communicator = unique_name.split(":")[0] == "ep"
|
||||
self.use_all2all = self.is_ep_communicator and use_ep
|
||||
self.all2all_backend = all2all_backend
|
||||
self.all2all_manager: All2AllManagerBase | None = None
|
||||
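The prefix check matters because unique_name has the form "<group>:<suffix>"; a plain substring test would also fire for any group name that merely contains "ep". A tiny illustration (the last name is hypothetical):

def is_ep_group(unique_name: str) -> bool:
    return unique_name.split(":")[0] == "ep"

assert is_ep_group("ep:0")
assert not is_ep_group("tp:0")
assert not is_ep_group("prep:0")  # hypothetical; note that "ep" in "prep:0" is True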
@@ -175,9 +200,7 @@ class DeviceCommunicatorBase:
|
||||
group=self.device_group,
|
||||
async_op=True)
|
||||
else:
|
||||
torch.distributed.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
@@ -263,10 +286,9 @@ class DeviceCommunicatorBase:
|
||||
group=self.device_group,
|
||||
async_op=True)
|
||||
else:
|
||||
torch.distributed.gather(input_,
|
||||
gather_list,
|
||||
dst=self.ranks[dst],
|
||||
group=self.device_group)
|
||||
torch.distributed.gather(
|
||||
input_, gather_list, dst=self.ranks[dst], group=self.device_group
|
||||
)
|
||||
if self.rank_in_group == dst:
|
||||
output_tensor = torch.cat(gather_list, dim=dim)
|
||||
else:
|
||||
@@ -292,6 +314,13 @@ class DeviceCommunicatorBase:
|
||||
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
|
||||
"""Broadcast a tensor from source rank to all ranks."""
|
||||
if self.world_size == 1:
|
||||
return tensor
|
||||
torch.distributed.broadcast(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
@@ -360,3 +389,6 @@ class DeviceCommunicatorBase:
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
return hidden_states
|
||||
|
||||
def batch_isend_irecv(self, p2p_ops: list):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -35,8 +35,15 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
)
|
||||
and hasattr(torch.ops._C, "init_shm_manager")
|
||||
and (unique_name.startswith("tp") or unique_name.startswith("pp"))
|
||||
and self._all_group_ranks_share_shm_group_name()
|
||||
):
|
||||
self.dist_module = _CPUSHMDistributed(self)
|
||||
elif unique_name.startswith("tp") or unique_name.startswith("pp"):
|
||||
logger.info(
|
||||
"CPU SHM communicator disabled for group %s: ranks do not share "
|
||||
"the same SHM group name, falling back to torch.distributed.",
|
||||
unique_name,
|
||||
)
|
||||
|
||||
if self.use_all2all:
|
||||
if self.all2all_backend != "naive": # type: ignore[has-type]
|
||||
@@ -52,6 +59,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
logger.info("Using naive all2all manager.")
|
||||
|
||||
def _all_group_ranks_share_shm_group_name(self) -> bool:
|
||||
"""
|
||||
CPUSHM requires all ranks in this group to agree on one SHM group name.
|
||||
This is a lightweight consistency check for VLLM_DIST_IDENT/name inputs.
|
||||
"""
|
||||
local_name = _CPUSHMDistributed.make_group_name(self)
|
||||
names: list[str] = [""] * self.world_size
|
||||
torch.distributed.all_gather_object(
|
||||
names,
|
||||
local_name,
|
||||
group=self.device_group,
|
||||
)
|
||||
return len(set(names)) == 1
|
||||
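The same agree-on-one-value pattern generalizes to any per-rank configuration string. A generic sketch with a hypothetical helper name; it assumes an initialized process group whose backend supports object collectives.

import torch.distributed as dist

def all_ranks_agree(value: str, group=None) -> bool:
    """Gather one string per rank and check that they are all identical."""
    world = dist.get_world_size(group)
    gathered: list[str | None] = [None] * world
    dist.all_gather_object(gathered, value, group=group)
    return len(set(gathered)) == 1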
|
||||
def all_reduce(self, input_):
|
||||
self.dist_module.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
@@ -193,16 +214,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
class _CPUSHMDistributed:
|
||||
def __init__(self, communicator: CpuCommunicator):
|
||||
self.communicator = communicator
|
||||
|
||||
self.group_name = self.make_group_name(communicator)
|
||||
|
||||
self.handle = self._init_cpu_shm()
|
||||
|
||||
@staticmethod
|
||||
def make_group_name(communicator: CpuCommunicator) -> str:
|
||||
instance_identifier = os.environ["VLLM_DIST_IDENT"]
|
||||
unique_name = communicator.unique_name
|
||||
instance_identifier = f"{instance_identifier}-{unique_name}"
|
||||
self.communicator = communicator
|
||||
|
||||
group_ranks = [str(rank) for rank in self.communicator.ranks]
|
||||
group_ranks = [str(rank) for rank in communicator.ranks]
|
||||
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
|
||||
self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
|
||||
|
||||
self.handle = self._init_cpu_shm()
|
||||
return f"{instance_identifier}-{shm_group_identifier}-cpushm"
|
||||
|
||||
def _init_cpu_shm(self) -> int:
|
||||
thread_num_tensor = torch.tensor(
|
||||
|
||||
@@ -16,6 +16,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import StatelessProcessGroup
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
import ixformer.distributed as ixfd
|
||||
import os
|
||||
@@ -29,8 +30,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
global_ranks: list[int] | None = None,
|
||||
global_world_size: int | None = None,
|
||||
tcp_store_group: StatelessProcessGroup | None = None,
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
super().__init__(
|
||||
cpu_group,
|
||||
device,
|
||||
device_group,
|
||||
unique_name,
|
||||
global_ranks,
|
||||
global_world_size,
|
||||
)
|
||||
if "tp" not in unique_name:
|
||||
# custom allreduce or torch symm mem can be used only by tp
|
||||
use_custom_allreduce = False
|
||||
@@ -46,8 +57,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
self.use_torch_symm_mem = use_torch_symm_mem
|
||||
self.use_flashinfer_allreduce = use_flashinfer_allreduce
|
||||
|
||||
|
||||
self.use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM", None) not in ["1", "Y", "y"]
|
||||
|
||||
# lazy import to avoid documentation build error
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import (
|
||||
CustomAllreduce,
|
||||
@@ -64,7 +76,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.pynccl_comm: PyNcclCommunicator | None = None
|
||||
if self.world_size > 1:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
group=self.cpu_group,
|
||||
group=self.cpu_group if tcp_store_group is None else tcp_store_group,
|
||||
device=self.device,
|
||||
)
|
||||
if is_symmetric_memory_enabled():
|
||||
@@ -109,23 +121,27 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
if self.all2all_backend == "naive":
|
||||
from .all2all import NaiveAll2AllManager
|
||||
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = NaiveAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "allgather_reducescatter":
|
||||
from .all2all import AgRsAll2AllManager
|
||||
|
||||
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
|
||||
elif self.all2all_backend == "pplx":
|
||||
from .all2all import PPLXAll2AllManager
|
||||
|
||||
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = AgRsAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "deepep_high_throughput":
|
||||
from .all2all import DeepEPHTAll2AllManager
|
||||
|
||||
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = DeepEPHTAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "deepep_low_latency":
|
||||
from .all2all import DeepEPLLAll2AllManager
|
||||
|
||||
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = DeepEPLLAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "mori":
|
||||
from .all2all import MoriAll2AllManager
|
||||
|
||||
@@ -133,7 +149,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
elif self.all2all_backend == "flashinfer_all2allv":
|
||||
from .all2all import FlashInferAllToAllManager
|
||||
|
||||
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
|
||||
self.all2all_manager = FlashInferAllToAllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
|
||||
|
||||
@@ -188,27 +206,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
return out
|
||||
if self.world_size == 1:
|
||||
return input_
|
||||
|
||||
if self.use_vllm_comm:
|
||||
# torch.ops.ixf_ops.vllm_all_reduce(input_, async_op=True)
|
||||
ixfd.all_reduce(input_, group=self.device_group, async_op=True)
|
||||
else:
|
||||
torch.distributed.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is None or pynccl_comm.disabled:
|
||||
out = input_.clone()
|
||||
torch.distributed.all_reduce(out, group=self.device_group)
|
||||
return out
|
||||
assert pynccl_comm is not None
|
||||
out = pynccl_comm.all_reduce(input_)
|
||||
if out is None:
|
||||
# fall back to the default all-reduce using PyTorch.
|
||||
# this usually happens during testing.
|
||||
# when we run the model, allreduce only happens for the TP
|
||||
# group, where we always have either custom allreduce or pynccl.
|
||||
out = input_.clone()
|
||||
torch.distributed.all_reduce(out, group=self.device_group)
|
||||
return out
|
||||
|
||||
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
|
||||
world_size = self.world_size
|
||||
@@ -230,10 +233,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
|
||||
)
|
||||
|
||||
# pynccl_comm.reduce_scatter(output, input_tensor)
|
||||
torch.distributed.reduce_scatter_tensor(output,
|
||||
input_tensor,
|
||||
group=self.device_group)
|
||||
# Perform reduce-scatter operation
|
||||
ixfd.reduce_scatter_tensor(output, input_tensor, group=self.device_group, async_op=True)
|
||||
|
||||
# Reshape before returning
|
||||
return output.movedim(0, dim).contiguous()
|
||||
@@ -278,12 +279,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
|
||||
pynccl_comm = self.pynccl_comm
|
||||
# if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
# pynccl_comm.send(tensor, dst)
|
||||
# else:
|
||||
# torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
||||
if self.use_vllm_comm:
|
||||
ixfd.send(tensor, self.ranks[dst], self.device_group)
|
||||
else:
|
||||
@@ -298,17 +293,24 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
|
||||
tensor = torch.empty(size, dtype=dtype, device=self.device)
|
||||
# pynccl_comm = self.pynccl_comm
|
||||
# if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
# pynccl_comm.recv(tensor, src)
|
||||
# else:
|
||||
# torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
if self.use_vllm_comm:
|
||||
ixfd.recv(tensor, self.ranks[src], self.device_group)
|
||||
else:
|
||||
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
|
||||
"""Broadcast a tensor from source rank to all ranks."""
|
||||
if self.world_size == 1:
|
||||
return tensor
|
||||
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
pynccl_comm.broadcast(tensor, src)
|
||||
return tensor
|
||||
else:
|
||||
raise ValueError("No PyNCCL communicator found")
|
||||
|
||||
def destroy(self):
|
||||
if self.pynccl_comm is not None:
|
||||
self.pynccl_comm = None
|
||||
@@ -319,7 +321,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.fi_ar_comm = None
|
||||
if self.all2all_manager is not None:
|
||||
self.all2all_manager.destroy()
|
||||
self.all2all_manager = None
|
||||
self.all2all_manager = None # type: ignore[assignment]
|
||||
|
||||
def all_gatherv(
|
||||
self,
|
||||
@@ -372,7 +374,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
extra_residual: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
@@ -409,16 +410,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
# return self.all2all_manager.dispatch(
|
||||
# hidden_states,
|
||||
# topk_weights,
|
||||
# topk_ids,
|
||||
# is_sequence_parallel,
|
||||
# extra_tensors=extra_tensors,
|
||||
# )
|
||||
hidden_states, extra_residual, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, extra_residual, router_logits)
|
||||
return hidden_states, extra_residual, router_logits
|
||||
return self.all2all_manager.dispatch(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel,
|
||||
extra_tensors=extra_tensors,
|
||||
)
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
@@ -432,3 +430,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
hidden_states,
|
||||
is_sequence_parallel,
|
||||
)
|
||||
|
||||
def batch_isend_irecv(self, p2p_ops: list):
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
pynccl_comm.batch_isend_irecv(p2p_ops)
|
||||
else:
|
||||
raise ValueError("No PyNCCL communicator found")
|
||||
|
||||
@@ -312,10 +312,19 @@ class PyNcclCommunicator:
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if tensor.dtype in [
|
||||
torch.float8_e5m2,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
torch.float8_e5m2fnuz,
|
||||
]:
|
||||
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
|
||||
else:
|
||||
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
|
||||
self.nccl.ncclSend(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
nccl_dtype,
|
||||
dst,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
@@ -330,10 +339,19 @@ class PyNcclCommunicator:
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if tensor.dtype in [
|
||||
torch.float8_e5m2,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
torch.float8_e5m2fnuz,
|
||||
]:
|
||||
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
|
||||
else:
|
||||
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
|
||||
self.nccl.ncclRecv(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
nccl_dtype,
|
||||
src,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
@@ -384,3 +402,17 @@ class PyNcclCommunicator:
|
||||
|
||||
def deregister_comm_window(self, window):
|
||||
return self.nccl.ncclCommWindowDeregister(self.comm, window)
|
||||
|
||||
def batch_isend_irecv(self, p2p_ops: list, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.group_start()
|
||||
for op in p2p_ops:
|
||||
if op.op is torch.distributed.isend:
|
||||
self.send(op.tensor, op.group_peer, stream)
|
||||
elif op.op is torch.distributed.irecv:
|
||||
self.recv(op.tensor, op.group_peer, stream)
|
||||
|
||||
self.group_end()
|
||||
|
||||
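NCCL has no fp8 data type, so the send/recv changes above transmit fp8 tensors as uint8; both have 1-byte elements, so the byte count and the bit pattern are preserved. A small sketch of the same reinterpretation at the tensor level (requires a PyTorch build with fp8 dtypes; this helper is illustrative, not the PyNcclCommunicator API):

import torch

def as_nccl_payload(t: torch.Tensor) -> torch.Tensor:
    """View fp8 tensors as uint8 so a uint8 collective moves the raw bytes."""
    fp8_dtypes = (
        torch.float8_e5m2,
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
        torch.float8_e5m2fnuz,
    )
    if t.dtype in fp8_dtypes:
        return t.view(torch.uint8)  # same storage, element size stays 1 byte
    return t

x = torch.empty(4, dtype=torch.float8_e4m3fn)
assert as_nccl_payload(x).dtype == torch.uint8
assert as_nccl_payload(x).numel() == x.numel()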
vllm/distributed/elastic_ep/__init__.py (new file, 0 additions)
vllm/distributed/elastic_ep/elastic_execute.py (new file, 529 additions)
@@ -0,0 +1,529 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import gc
|
||||
import weakref
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed import P2POp
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.cuda_graph import CUDAGraphWrapper
|
||||
from vllm.compilation.wrapper import reset_compile_wrapper
|
||||
from vllm.config import (
|
||||
CompilationMode,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed import (
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
get_pcp_group,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.distributed.elastic_ep.standby_state import (
|
||||
create_standby_groups,
|
||||
get_standby_dp_group,
|
||||
get_standby_ep_group,
|
||||
pop_standby_groups,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
_replace_active_groups,
|
||||
prepare_communication_buffer_for_model,
|
||||
)
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
|
||||
from vllm.v1.worker.workspace import lock_workspace, unlock_workspace
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def batch_transfer_weights(
|
||||
model: nn.Module,
|
||||
is_sender: bool,
|
||||
peer_rank: int,
|
||||
dp_group: StatelessGroupCoordinator,
|
||||
expert_weights: Sequence[Iterable[torch.Tensor]],
|
||||
) -> None:
|
||||
device_comm = dp_group.device_communicator
|
||||
if device_comm is None:
|
||||
raise ValueError("No device communicator found")
|
||||
|
||||
expert_weights_set = set()
|
||||
for weight_group in expert_weights:
|
||||
for weight in weight_group:
|
||||
expert_weights_set.add(weight.data_ptr())
|
||||
|
||||
state_dict = model.state_dict()
|
||||
all_params = []
|
||||
|
||||
for name, param in state_dict.items():
|
||||
if name.endswith("expert_map"):
|
||||
continue
|
||||
if param.data_ptr() not in expert_weights_set:
|
||||
all_params.append(param.data)
|
||||
|
||||
assert len(all_params) > 0
|
||||
p2p_ops = []
|
||||
for param in all_params:
|
||||
op = object.__new__(P2POp)
|
||||
if is_sender:
|
||||
op.op = torch.distributed.isend
|
||||
op.tensor = param
|
||||
else:
|
||||
op.op = torch.distributed.irecv
|
||||
op.tensor = param
|
||||
op.group_peer = peer_rank
|
||||
p2p_ops.append(op)
|
||||
device_comm.batch_isend_irecv(p2p_ops)
|
||||
|
||||
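batch_transfer_weights builds P2POp descriptors by hand (via object.__new__) because the stateless device communicator, not torch.distributed, executes them. With a regular initialized process group the equivalent, using the public API, would look like this sketch (tensors and peer rank are placeholders):

import torch
import torch.distributed as dist

def plain_batch_transfer(params: list[torch.Tensor], is_sender: bool, peer: int) -> None:
    """Issue one batched isend/irecv per parameter against the default group."""
    op = dist.isend if is_sender else dist.irecv
    p2p_ops = [dist.P2POp(op, p, peer) for p in params]
    reqs = dist.batch_isend_irecv(p2p_ops)
    for req in reqs:
        req.wait()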
|
||||
def broadcast_expert_mapping(
|
||||
physical_to_logical: torch.Tensor | None,
|
||||
num_local_physical_experts: int | None,
|
||||
num_logical_experts: int | None,
|
||||
dp_group: StatelessGroupCoordinator,
|
||||
device: torch.device,
|
||||
src_rank: int = 0,
|
||||
) -> tuple[torch.Tensor, int, int]:
|
||||
if dp_group.rank_in_group == src_rank:
|
||||
assert physical_to_logical is not None
|
||||
assert num_local_physical_experts is not None
|
||||
assert num_logical_experts is not None
|
||||
assert physical_to_logical.dtype == torch.int64
|
||||
shape_tensor = torch.tensor(
|
||||
list(physical_to_logical.shape), dtype=torch.int64, device="cpu"
|
||||
)
|
||||
metadata_tensor = torch.tensor(
|
||||
[num_local_physical_experts, num_logical_experts],
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
)
|
||||
else:
|
||||
shape_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
|
||||
metadata_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
|
||||
|
||||
shape_tensor = dp_group.tcp_store_group.broadcast(shape_tensor, src_rank)
|
||||
metadata_tensor = dp_group.tcp_store_group.broadcast(metadata_tensor, src_rank)
|
||||
|
||||
if dp_group.rank_in_group != src_rank:
|
||||
assert device is not None
|
||||
physical_to_logical = torch.empty(
|
||||
tuple(shape_tensor.tolist()),
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
assert physical_to_logical is not None
|
||||
physical_to_logical = dp_group.broadcast(physical_to_logical, src_rank)
|
||||
num_local_physical_experts = int(metadata_tensor[0].item())
|
||||
num_logical_experts = int(metadata_tensor[1].item())
|
||||
|
||||
return physical_to_logical, num_local_physical_experts, num_logical_experts
|
||||
|
||||
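broadcast_expert_mapping ships the shape and expert counts first so non-source ranks can allocate a correctly sized tensor before receiving the map itself. A compact sketch of that two-step pattern with plain torch.distributed; the real code above sends the metadata over the CPU TCP-store group and the payload over the stateless device group.

import torch
import torch.distributed as dist

def broadcast_2d_int64(t: torch.Tensor | None, src: int, device: torch.device) -> torch.Tensor:
    """Broadcast a 2-D int64 tensor whose shape only the source rank knows."""
    rank = dist.get_rank()
    if rank == src:
        shape = torch.tensor(list(t.shape), dtype=torch.int64, device=device)
    else:
        shape = torch.empty(2, dtype=torch.int64, device=device)
    dist.broadcast(shape, src)  # step 1: metadata, so receivers can allocate
    if rank != src:
        t = torch.empty(tuple(shape.tolist()), dtype=torch.int64, device=device)
    dist.broadcast(t, src)      # step 2: the mapping itself
    return t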
|
||||
class ElasticEPScalingExecutor:
|
||||
def __init__(self, worker):
|
||||
self.worker_ref = weakref.ref(worker)
|
||||
self.reconfig_request = None
|
||||
|
||||
@property
|
||||
def worker(self):
|
||||
worker = self.worker_ref()
|
||||
if worker is None:
|
||||
raise RuntimeError("Worker has been garbage collected")
|
||||
return worker
|
||||
|
||||
def execute(self, execute_method: str, *args, **kwargs):
|
||||
method = getattr(self, execute_method, None)
|
||||
if method is None:
|
||||
raise ValueError(f"Unknown execute method: {execute_method}")
|
||||
return method(*args, **kwargs)
|
||||
|
||||
def create_standby_groups(
|
||||
self, reconfig_request: ReconfigureDistributedRequest
|
||||
) -> None:
|
||||
self.reconfig_request = reconfig_request
|
||||
new_dp_size = reconfig_request.new_data_parallel_size
|
||||
world_size = self.worker.vllm_config.parallel_config.world_size
|
||||
new_world_size_across_dp = world_size * new_dp_size
|
||||
updated_config = copy.copy(self.worker.vllm_config)
|
||||
updated_config.parallel_config = copy.deepcopy(
|
||||
self.worker.vllm_config.parallel_config
|
||||
)
|
||||
updated_config.parallel_config.data_parallel_size = new_dp_size
|
||||
with set_current_vllm_config(updated_config):
|
||||
create_standby_groups(
|
||||
new_dp_size=new_dp_size,
|
||||
new_world_size_across_dp=new_world_size_across_dp,
|
||||
master_ip=reconfig_request.new_data_parallel_master_ip,
|
||||
world_group_ports=reconfig_request.new_stateless_world_group_port_list,
|
||||
dp_group_ports=reconfig_request.new_stateless_dp_group_port_list,
|
||||
ep_group_ports=reconfig_request.new_stateless_ep_group_port_list,
|
||||
eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list,
|
||||
)
|
||||
self.worker.model_runner.eep_eplb_suppressed = True
|
||||
standby_ep_group = get_standby_ep_group()
|
||||
assert standby_ep_group is not None
|
||||
if standby_ep_group.rank == 0:
|
||||
logger.info("[Elastic EP] EPLB disabled during elastic scaling transition")
|
||||
|
||||
def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None:
|
||||
standby_dp_group = get_standby_dp_group()
|
||||
assert standby_dp_group is not None
|
||||
# Broadcast old_dp_size to all workers in standby group
|
||||
if standby_dp_group.rank_in_group < old_dp_size:
|
||||
old_dp_size_tensor = torch.tensor(
|
||||
[old_dp_size], dtype=torch.int64, device="cpu"
|
||||
)
|
||||
else:
|
||||
old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
|
||||
old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast(
|
||||
old_dp_size_tensor, 0
|
||||
)
|
||||
|
||||
num_new_workers = new_dp_size - old_dp_size
|
||||
dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
# Sender-receiver pairing: the first (num_new_workers % old_dp_size) senders
# each take one extra contiguous receiver; the remaining senders take
# num_dst_per_sender receivers each.
|
||||
num_dst_per_sender = num_new_workers // old_dp_size
|
||||
remainder = num_new_workers % old_dp_size
|
||||
|
||||
if dp_rank < remainder:
|
||||
recv_begin = dp_rank * (num_dst_per_sender + 1)
|
||||
recv_end = recv_begin + num_dst_per_sender + 1
|
||||
else:
|
||||
recv_begin = (
|
||||
remainder * (num_dst_per_sender + 1)
|
||||
+ (dp_rank - remainder) * num_dst_per_sender
|
||||
)
|
||||
recv_end = recv_begin + num_dst_per_sender
|
||||
|
||||
ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end))
|
||||
|
||||
model = self.worker.model_runner.get_model()
|
||||
for new_worker_rank in sorted(ranks_to_send):
|
||||
batch_transfer_weights(
|
||||
model=model,
|
||||
is_sender=True,
|
||||
peer_rank=new_worker_rank,
|
||||
dp_group=standby_dp_group,
|
||||
expert_weights=model.expert_weights,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
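The receiver assignment above is plain bucketing arithmetic: the first `remainder` senders take one extra receiver each. A worked example, factored into a standalone helper (hypothetical name), for old_dp_size=3 scaling up to new_dp_size=8:

def receivers_for(dp_rank: int, old_dp_size: int, new_dp_size: int) -> list[int]:
    """Global DP ranks of the new workers that sender `dp_rank` feeds."""
    num_new = new_dp_size - old_dp_size
    per_sender, remainder = divmod(num_new, old_dp_size)
    if dp_rank < remainder:
        begin = dp_rank * (per_sender + 1)
        end = begin + per_sender + 1
    else:
        begin = remainder * (per_sender + 1) + (dp_rank - remainder) * per_sender
        end = begin + per_sender
    return list(range(old_dp_size + begin, old_dp_size + end))

# old_dp_size=3, new_dp_size=8: senders 0 and 1 feed two new workers, sender 2 feeds one.
assert receivers_for(0, 3, 8) == [3, 4]
assert receivers_for(1, 3, 8) == [5, 6]
assert receivers_for(2, 3, 8) == [7]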
|
||||
def broadcast_expert_mapping(self) -> None:
|
||||
standby_dp_group = get_standby_dp_group()
|
||||
assert standby_dp_group is not None
|
||||
model_config = self.worker.model_runner.model_config
|
||||
eplb_state = self.worker.model_runner.eplb_state
|
||||
assert eplb_state is not None
|
||||
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
|
||||
physical_to_logical = eplb_model_state.physical_to_logical_map
|
||||
num_physical_experts = physical_to_logical.shape[1]
|
||||
num_local_physical_experts = num_physical_experts // get_ep_group().world_size
|
||||
num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
|
||||
broadcast_expert_mapping(
|
||||
physical_to_logical=physical_to_logical,
|
||||
num_local_physical_experts=num_local_physical_experts,
|
||||
num_logical_experts=num_logical_experts,
|
||||
dp_group=standby_dp_group,
|
||||
src_rank=0,
|
||||
device=self.worker.device,
|
||||
)
|
||||
|
||||
def switch_and_remove(self) -> None:
|
||||
_replace_active_groups(world=None, dp=None, ep=None, eplb=None, node_count=None)
|
||||
|
||||
def switch_and_prepare(self) -> None:
|
||||
old_dp_size = get_dp_group().world_size
|
||||
old_ep_size = get_ep_group().world_size
|
||||
|
||||
_replace_active_groups(**pop_standby_groups())
|
||||
|
||||
parallel_config = self.worker.vllm_config.parallel_config
|
||||
reconfig_request = self.reconfig_request
|
||||
assert reconfig_request is not None
|
||||
new_dp_size = reconfig_request.new_data_parallel_size
|
||||
new_ep_size = get_ep_group().world_size
|
||||
|
||||
parallel_config.data_parallel_size = new_dp_size
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank_local = (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
)
|
||||
parallel_config.data_parallel_master_ip = (
|
||||
reconfig_request.new_data_parallel_master_ip
|
||||
)
|
||||
parallel_config.data_parallel_master_port = (
|
||||
reconfig_request.new_data_parallel_master_port
|
||||
)
|
||||
|
||||
# Reconfigure MoE modules with new EP size
|
||||
moe_modules = [
|
||||
module
|
||||
for module in self.worker.model_runner.model.modules()
|
||||
if (
|
||||
module.__class__.__name__ == "FusedMoE"
|
||||
or module.__class__.__name__ == "SharedFusedMoE"
|
||||
)
|
||||
]
|
||||
num_local_experts = moe_modules[0].moe_config.num_local_experts
|
||||
assert all(
|
||||
module.moe_config.num_local_experts == num_local_experts
|
||||
for module in moe_modules
|
||||
), "All MoE modules must have the same number of experts"
|
||||
for module in moe_modules:
|
||||
module.moe_config.num_experts = num_local_experts * new_ep_size
|
||||
module.global_num_experts = module.moe_config.num_experts
|
||||
tp_size = get_tp_group().world_size
|
||||
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
sp_size = tp_size if is_sequence_parallel else 1
|
||||
module.moe_parallel_config = FusedMoEParallelConfig.make(
|
||||
tp_size_=tp_size,
|
||||
pcp_size_=get_pcp_group().world_size,
|
||||
dp_size_=get_dp_group().world_size,
|
||||
sp_size_=sp_size,
|
||||
vllm_parallel_config=parallel_config,
|
||||
)
|
||||
module.moe_config.moe_parallel_config = module.moe_parallel_config
|
||||
|
||||
# Update EPLB state
|
||||
eplb_state = self.worker.model_runner.eplb_state
|
||||
assert eplb_state is not None
|
||||
model_config = self.worker.model_runner.model_config
|
||||
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
|
||||
|
||||
num_physical_experts = num_local_experts * new_ep_size
|
||||
num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
|
||||
parallel_config.eplb_config.num_redundant_experts = (
|
||||
num_physical_experts - num_logical_experts
|
||||
)
|
||||
old_physical_to_logical = eplb_model_state.physical_to_logical_map
|
||||
num_moe_layers = old_physical_to_logical.shape[0]
|
||||
num_local_experts = eplb_model_state.expert_load_pass.shape[1] // old_ep_size
|
||||
if new_dp_size > old_dp_size:
|
||||
expanded_physical_to_logical = torch.full(
|
||||
(num_moe_layers, num_local_experts * new_ep_size),
|
||||
-1,
|
||||
dtype=old_physical_to_logical.dtype,
|
||||
device=old_physical_to_logical.device,
|
||||
)
|
||||
expanded_physical_to_logical[:, : num_local_experts * old_ep_size] = (
|
||||
old_physical_to_logical
|
||||
)
|
||||
eplb_model_state.physical_to_logical_map = expanded_physical_to_logical
|
||||
|
||||
old_num_physical_experts = eplb_model_state.expert_load_pass.shape[1]
|
||||
pad_size = num_physical_experts - old_num_physical_experts
|
||||
if new_dp_size > old_dp_size:
|
||||
assert pad_size > 0
|
||||
expanded_expert_load_pass = F.pad(
|
||||
eplb_model_state.expert_load_pass, (0, pad_size), value=0
|
||||
)
|
||||
expanded_expert_load_window = F.pad(
|
||||
eplb_model_state.expert_load_window, (0, pad_size), value=0
|
||||
)
|
||||
eplb_model_state.expert_load_pass = expanded_expert_load_pass
|
||||
eplb_model_state.expert_load_window = expanded_expert_load_window
|
||||
eplb_state.num_valid_physical_experts = old_num_physical_experts
|
||||
else:
|
||||
assert pad_size < 0
|
||||
eplb_model_state.expert_load_pass = eplb_model_state.expert_load_pass[
|
||||
:, :num_physical_experts
|
||||
]
|
||||
eplb_model_state.expert_load_window = eplb_model_state.expert_load_window[
|
||||
:, :, :num_physical_experts
|
||||
]
|
||||
eplb_state.num_valid_physical_experts = num_physical_experts
|
||||
|
||||
model = self.worker.model_runner.get_model()
|
||||
model.expert_weights = []
|
||||
with set_current_vllm_config(self.worker.vllm_config):
|
||||
model.set_eplb_state(
|
||||
eplb_model_state.expert_load_pass,
|
||||
eplb_model_state.logical_to_physical_map,
|
||||
eplb_model_state.logical_replica_count,
|
||||
)
|
||||
model.update_physical_experts_metadata(
|
||||
num_physical_experts=num_physical_experts,
|
||||
num_local_physical_experts=num_local_experts,
|
||||
)
|
||||
# Force re-creation of the modular kernel (and all2all manager)
|
||||
# for the new EP size by resetting quant_method to base
|
||||
for module in moe_modules:
|
||||
if hasattr(module.quant_method, "old_quant_method"):
|
||||
module.quant_method = module.quant_method.old_quant_method
|
||||
module.runner = module._init_runner()
|
||||
prepare_communication_buffer_for_model(self.worker.model_runner.model)
|
||||
if (
|
||||
self.worker.vllm_config.compilation_config.mode
|
||||
== CompilationMode.STOCK_TORCH_COMPILE
|
||||
):
|
||||
# NOTE(yongji): when using stock torch.compile,
|
||||
# torch.compile is triggered during GPUModelRunner's load_model()
|
||||
# TODO(yongji): check whether we need to re-trigger torch.compile here.
|
||||
# any changes to the tensor shapes in execution should already
|
||||
# be handled internally by torch.compile.
|
||||
backend = self.worker.vllm_config.compilation_config.init_backend(
|
||||
self.worker.vllm_config
|
||||
)
|
||||
compilation_counter.stock_torch_compile_count += 1
|
||||
self.worker.model_runner.model.compile(fullgraph=True, backend=backend)
|
||||
|
||||
# release all previously captured CUDA graphs
|
||||
if isinstance(self.worker.model_runner.model, CUDAGraphWrapper):
|
||||
wrapper = self.worker.model_runner.model
|
||||
wrapper.concrete_cudagraph_entries = {}
|
||||
elif isinstance(self.worker.model_runner.model, UBatchWrapper):
|
||||
raise RuntimeError("DBO is not yet supported in elastic EP")
|
||||
|
||||
multi_block_table = self.worker.model_runner.input_batch.block_table
|
||||
saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = []
|
||||
for bt in multi_block_table.block_tables:
|
||||
saved_block_tables.append(
|
||||
(bt.block_table.gpu.clone(), bt.block_table.cpu.clone())
|
||||
)
|
||||
multi_block_table.clear()
|
||||
|
||||
# reset the compile wrapper
|
||||
torch.compiler.reset()
|
||||
with set_current_vllm_config(self.worker.vllm_config):
|
||||
reset_compile_wrapper(self.worker.model_runner.get_model())
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
unlock_workspace()
|
||||
self.worker.compile_or_warm_up_model()
|
||||
lock_workspace()
|
||||
|
||||
for bt, (saved_gpu, saved_cpu) in zip(
|
||||
multi_block_table.block_tables, saved_block_tables
|
||||
):
|
||||
bt.block_table.gpu.copy_(saved_gpu)
|
||||
bt.block_table.cpu.copy_(saved_cpu)
|
||||
|
||||
def perform_eplb_reshuffle(self, new_dp_size: int | None = None) -> None:
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Starting expert resharding...")
|
||||
|
||||
eplb_state = self.worker.model_runner.eplb_state
|
||||
assert eplb_state is not None
|
||||
|
||||
model_config = self.worker.model_runner.model_config
|
||||
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
|
||||
is_async_enabled = eplb_state.is_async
|
||||
eplb_state.is_async = False
|
||||
if new_dp_size is None:
|
||||
eplb_state.rearrange()
|
||||
else:
|
||||
# scale down
|
||||
parallel_config = self.worker.vllm_config.parallel_config
|
||||
tp_size = parallel_config.tensor_parallel_size
|
||||
old_ep_size = parallel_config.data_parallel_size * tp_size
|
||||
new_ep_size = new_dp_size * tp_size
|
||||
|
||||
rank_mapping = {
|
||||
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
|
||||
for old_ep_rank in range(old_ep_size)
|
||||
}
|
||||
|
||||
eplb_state.rearrange(rank_mapping=rank_mapping)
|
||||
# NOTE(yongji): check whether we need to synchronize here
|
||||
torch.cuda.synchronize()
|
||||
# reset expert_rearrangement_step to ensure all ranks are synchronized
|
||||
eplb_state.expert_rearrangement_step = 0
|
||||
eplb_state.num_valid_physical_experts = (
|
||||
eplb_model_state.physical_to_logical_map.shape[1]
|
||||
)
|
||||
eplb_state.is_async = is_async_enabled
|
||||
self.worker.model_runner.eep_eplb_suppressed = False
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Expert resharding completed")
|
||||
|
||||
def receive_weights(self) -> None:
|
||||
dp_group = get_dp_group()
|
||||
assert isinstance(dp_group, StatelessGroupCoordinator)
|
||||
new_dp_size = dp_group.world_size
|
||||
dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
# Receive old_dp_size broadcasted during transfer_weights
|
||||
old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
|
||||
old_dp_size_tensor = dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0)
|
||||
old_dp_size = int(old_dp_size_tensor[0].item())
|
||||
|
||||
# Calculate which existing worker will send to this new worker
|
||||
num_new_workers = new_dp_size - old_dp_size
|
||||
new_worker_idx = dp_rank - old_dp_size
|
||||
num_dst_per_sender = num_new_workers // old_dp_size
|
||||
remainder = num_new_workers % old_dp_size
|
||||
|
||||
if new_worker_idx < remainder * (num_dst_per_sender + 1):
|
||||
sender_rank = new_worker_idx // (num_dst_per_sender + 1)
|
||||
else:
|
||||
sender_rank = (
|
||||
remainder
|
||||
+ (new_worker_idx - remainder * (num_dst_per_sender + 1))
|
||||
// num_dst_per_sender
|
||||
)
|
||||
|
||||
model = self.worker.model_runner.get_model()
|
||||
batch_transfer_weights(
|
||||
model=model,
|
||||
is_sender=False,
|
||||
peer_rank=sender_rank,
|
||||
dp_group=dp_group,
|
||||
expert_weights=model.expert_weights,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]:
|
||||
dp_group = get_dp_group()
|
||||
assert isinstance(dp_group, StatelessGroupCoordinator)
|
||||
physical_to_logical, num_local_physical_experts, num_logical_experts = (
|
||||
broadcast_expert_mapping(
|
||||
physical_to_logical=None,
|
||||
num_local_physical_experts=None,
|
||||
num_logical_experts=None,
|
||||
dp_group=dp_group,
|
||||
src_rank=0,
|
||||
device=self.worker.device,
|
||||
)
|
||||
)
|
||||
num_moe_layers = physical_to_logical.shape[0]
|
||||
new_dp_size = get_dp_group().world_size
|
||||
tp_size = self.worker.vllm_config.parallel_config.tensor_parallel_size
|
||||
new_ep_size = new_dp_size * tp_size
|
||||
expanded_physical_to_logical = torch.full(
|
||||
(num_moe_layers, num_local_physical_experts * new_ep_size),
|
||||
-1,
|
||||
dtype=physical_to_logical.dtype,
|
||||
device=physical_to_logical.device,
|
||||
)
|
||||
old_num_physical_experts = physical_to_logical.shape[1]
|
||||
expanded_physical_to_logical[:, :old_num_physical_experts] = physical_to_logical
|
||||
return (
|
||||
expanded_physical_to_logical,
|
||||
num_logical_experts,
|
||||
old_num_physical_experts,
|
||||
)
|
||||
|
||||
def prepare_new_worker(self) -> None:
|
||||
with set_current_vllm_config(self.worker.vllm_config):
|
||||
prepare_communication_buffer_for_model(self.worker.model_runner.get_model())
|
||||
vllm/distributed/elastic_ep/elastic_state.py (new file, 563 additions)
@@ -0,0 +1,563 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import enum
|
||||
import time
|
||||
import weakref
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import torch.distributed
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed import (
|
||||
sched_yield,
|
||||
stateless_destroy_torch_distributed_process_group,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.engine import (
|
||||
EEPNotificationType,
|
||||
ReconfigureDistributedRequest,
|
||||
ReconfigureRankType,
|
||||
)
|
||||
from vllm.v1.engine.core import DPEngineCoreProc
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
WorkerType = Literal["existing", "new", "removing"]
|
||||
|
||||
|
||||
class ScaleUpExistingEngineState(enum.IntEnum):
|
||||
WAIT_NEW_CORE_ENGINES_INIT = 0
|
||||
CREATE_STANDBY_GROUPS = 1
|
||||
TRANSFER_EXPERT_MAPPING = 2
|
||||
WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT = 3
|
||||
TRANSFER_WEIGHTS = 4
|
||||
SYNC_KV_CACHE_MEMORY_SIZE = 5
|
||||
SWITCH_AND_PREPARE = 6
|
||||
EPLB_RESHUFFLE = 7
|
||||
COMPLETE = 8
|
||||
|
||||
|
||||
class ScaleUpNewEngineState(enum.IntEnum):
|
||||
PREPARE = 0
|
||||
EPLB_RESHUFFLE = 1
|
||||
COMPLETE = 2
|
||||
|
||||
|
||||
class ScaleDownRemainingEngineState(enum.IntEnum):
|
||||
PREPARE = 0
|
||||
EPLB_RESHUFFLE = 1
|
||||
SWITCH_AND_PREPARE = 2
|
||||
COMPLETE = 3
|
||||
|
||||
|
||||
class ScaleDownRemovingEngineState(enum.IntEnum):
|
||||
PREPARE = 0
|
||||
EPLB_RESHUFFLE = 1
|
||||
COMPLETE = 2
|
||||
|
||||
|
||||
class _BarrierTimeoutError(RuntimeError):
|
||||
"""
|
||||
Exception raised for timeout
|
||||
in the first stage of our two-staged
|
||||
TCPStore based barrier to synchronize the
|
||||
execution of all engines in the DP group.
|
||||
"""
|
||||
|
||||
|
||||
class ElasticEPScalingState:
|
||||
def __init__(
|
||||
self,
|
||||
model_executor: "Executor",
|
||||
engine_core: "DPEngineCoreProc",
|
||||
vllm_config: "VllmConfig",
|
||||
new_parallel_config: ParallelConfig,
|
||||
worker_type: WorkerType,
|
||||
scale_type: Literal["scale_up", "scale_down"],
|
||||
reconfig_request: ReconfigureDistributedRequest | None = None,
|
||||
):
|
||||
self.model_executor_ref = weakref.ref(model_executor)
|
||||
self.engine_core_ref = weakref.ref(engine_core)
|
||||
self.vllm_config = vllm_config
|
||||
self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None
|
||||
self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None
|
||||
self.new_parallel_config: ParallelConfig = new_parallel_config
|
||||
self.new_dp_group: torch.distributed.ProcessGroup | None = (
|
||||
self.engine_core.dp_group if worker_type == "new" else None
|
||||
)
|
||||
self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None
|
||||
self.worker_type = worker_type
|
||||
self.scale_type = scale_type
|
||||
self.reconfig_request = reconfig_request
|
||||
|
||||
if scale_type == "scale_up":
|
||||
self.state = (
|
||||
ScaleUpNewEngineState.PREPARE
|
||||
if worker_type == "new"
|
||||
else ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
|
||||
)
|
||||
else:
|
||||
self.state = (
|
||||
ScaleDownRemovingEngineState.PREPARE
|
||||
if worker_type == "removing"
|
||||
else ScaleDownRemainingEngineState.PREPARE
|
||||
)
|
||||
|
||||
@property
|
||||
def model_executor(self) -> "Executor":
|
||||
model_executor = self.model_executor_ref()
|
||||
if model_executor is None:
|
||||
raise RuntimeError("Model executor has been garbage collected")
|
||||
return model_executor
|
||||
|
||||
@property
|
||||
def engine_core(self) -> "DPEngineCoreProc":
|
||||
engine_core = self.engine_core_ref()
|
||||
if engine_core is None:
|
||||
raise RuntimeError("Engine core has been garbage collected")
|
||||
return engine_core
|
||||
|
||||
def progress(self) -> bool:
|
||||
if self.scale_type == "scale_up":
|
||||
return (
|
||||
self._progress_new_engine()
|
||||
if self.worker_type == "new"
|
||||
else self._progress_existing_engine()
|
||||
)
|
||||
return (
|
||||
self._progress_removing_engine()
|
||||
if self.worker_type == "removing"
|
||||
else self._progress_remaining_engine()
|
||||
)
|
||||
|
||||
    def _execute_tcp_store_barrier(
        self, dp_store, group_rank, group_size, barrier_id, timeout=None
    ):
        arrival_key = f"arrival_{barrier_id}_{group_rank}"
        dp_store.set(arrival_key, b"1")

        start_time = time.time()
        processes_arrived: set[int] = set()

        while len(processes_arrived) < group_size:
            if (
                timeout is not None
                and time.time() - start_time > timeout.total_seconds()
            ):
                raise _BarrierTimeoutError(
                    f"Barrier timed out after {timeout.total_seconds()} seconds"
                )

            for i in range(group_size):
                if i in processes_arrived:
                    continue

                key = f"arrival_{barrier_id}_{i}"
                present = dp_store.check([key])
                if present:
                    processes_arrived.add(i)

            if len(processes_arrived) < group_size:
                sched_yield()

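    # A minimal sketch of the store primitives the barrier above relies on (assuming
    # dp_store is a torch.distributed.TCPStore shared by all DP ranks; the key names
    # here are illustrative):
    #
    #     dp_store.set(f"arrival_demo_{rank}", b"1")        # announce arrival
    #     dp_store.check([f"arrival_demo_{peer}"])          # non-blocking presence check
    #     dp_store.delete_key(f"arrival_demo_{peer}")       # cleanup by rank 0
    #     dp_store.add("eep_barrier_engine_count", 1)       # shared counter used below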
    def _staged_barrier(self, use_new_group: bool, barrier_name: str) -> bool:
        """
        Execute a two-staged barrier to synchronize all engines in the DP group.

        Some DP EngineCores may receive the reconfiguration notification later
        than others and may have already proceeded to the engine step (model
        forward) in the busy loop. In this case, the EngineCores that have
        already proceeded to reconfiguration should skip it and execute model
        forward for one more step, so that in the next step all EngineCores
        are synchronized. We use a two-staged barrier to achieve this. The
        first time each EngineCore executes the barrier, if a timeout is
        reached before the barrier completes, it means some EngineCores have
        already entered the engine step. The EngineCores that timed out then
        proceed to the engine step and synchronize with the other EngineCores
        in the next step using a barrier without a timeout.
        """
        dp_store = self.new_dp_store if use_new_group else self.old_dp_store
        dp_group = self.new_dp_group if use_new_group else self.old_dp_group
        assert dp_group is not None

        group_rank = dp_group.rank()
        group_size = dp_group.size()
        barrier_id = f"eep_barrier_{barrier_name}"
        sync_key = f"{barrier_id}_sync"

        # TODO(yongji): figure out appropriate timeout for the barrier
        timeout = None if dp_store.check([sync_key]) else timedelta(seconds=5)

        try:
            self._execute_tcp_store_barrier(
                dp_store, group_rank, group_size, barrier_id, timeout=timeout
            )
            torch.distributed.barrier(dp_group)
            if group_rank == 0:
                dp_store.delete_key(sync_key)
                for i in range(group_size):
                    dp_store.delete_key(f"arrival_{barrier_id}_{i}")
            return True
        except _BarrierTimeoutError as e:
            if timeout is None:
                raise RuntimeError("Unexpected timeout encountered") from e
            dp_store.compare_set(sync_key, "", b"1")
            return False

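    # Illustrative usage (a sketch of how the callers below treat the return value):
    #
    #     if not self._staged_barrier(use_new_group=False, barrier_name="transfer_weights"):
    #         return False  # some engines already entered engine step; retry on next progress()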
    def _progress_existing_engine(self) -> bool:
        state = self.state

        if state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT:
            return False

        elif state == ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS:
            # NOTE(yongji): wait for all existing workers to receive the request
            if (
                int(self.old_dp_store.get("eep_barrier_engine_count"))
                < self.old_dp_group.size()
            ):
                return False
            if not self._staged_barrier(
                use_new_group=False, barrier_name="create_standby_groups"
            ):
                return False
            if self.old_dp_group.rank() == 0:
                self.old_dp_store.delete_key("eep_barrier_engine_count")
            self._create_standby_groups()
            self.state = ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING
            return True

        elif state == ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING:
            self._transfer_expert_mapping()
            self.state = ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
            return True

        elif state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT:
            return False

        elif state == ScaleUpExistingEngineState.TRANSFER_WEIGHTS:
            if (
                int(self.old_dp_store.get("eep_barrier_engine_count"))
                < self.old_dp_group.size()
            ):
                return False
            if not self._staged_barrier(
                use_new_group=False, barrier_name="transfer_weights"
            ):
                return False
            if self.old_dp_group.rank() == 0:
                self.old_dp_store.delete_key("eep_barrier_engine_count")
            self._transfer_weights()
            self.state = ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE
            return True

        elif state == ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE:
            self._sync_kv_cache_memory_size()
            self.state = ScaleUpExistingEngineState.SWITCH_AND_PREPARE
            return True

        elif state == ScaleUpExistingEngineState.SWITCH_AND_PREPARE:
            self._switch_and_prepare()
            self.state = ScaleUpExistingEngineState.EPLB_RESHUFFLE
            self.new_dp_store.add("eep_barrier_engine_count", 1)
            return True

        elif state == ScaleUpExistingEngineState.EPLB_RESHUFFLE:
            assert self.new_dp_group is not None
            if (
                int(self.new_dp_store.get("eep_barrier_engine_count"))
                < self.new_dp_group.size()
            ):
                return False
            if not self._staged_barrier(
                use_new_group=True, barrier_name="eplb_reshuffle"
            ):
                return False
            if self.new_dp_group.rank() == 0:
                self.new_dp_store.delete_key("eep_barrier_engine_count")
            self._eplb_reshuffle()
            self.state = ScaleUpExistingEngineState.COMPLETE
            self._update_parallel_config()
            return True

        else:
            assert self.state == ScaleUpExistingEngineState.COMPLETE
            return True

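    # Note on the "eep_barrier_engine_count" key used above (comment only): it acts as a
    # TCPStore counter. Each engine increments it via dp_store.add(...) when it reaches a
    # barrier-guarded state, the gate waits until the count reaches dp_group.size(), and
    # rank 0 deletes the key after the barrier so the same key can be reused next stage.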
    def _progress_new_engine(self) -> bool:
        state = self.state
        assert self.new_dp_group is not None

        if state == ScaleUpNewEngineState.PREPARE:
            tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu")
            torch.distributed.all_reduce(
                tensor,
                op=torch.distributed.ReduceOp.MAX,
                group=self.new_dp_group,
            )
            data = tensor.tolist()
            self.engine_core.engines_running = bool(data[0])
            self.engine_core.current_wave = int(data[1])
            self.engine_core.step_counter = int(data[2])
            self.state = ScaleUpNewEngineState.EPLB_RESHUFFLE
            self.new_dp_store.add("eep_barrier_engine_count", 1)
            return True

        elif state == ScaleUpNewEngineState.EPLB_RESHUFFLE:
            if (
                int(self.new_dp_store.get("eep_barrier_engine_count"))
                < self.new_dp_group.size()
            ):
                return False
            if not self._staged_barrier(
                use_new_group=True, barrier_name="eplb_reshuffle"
            ):
                return False
            assert self.new_dp_group.rank() > 0
            self._eplb_reshuffle()
            self.state = ScaleUpNewEngineState.COMPLETE
            return True

        else:
            assert self.state == ScaleUpNewEngineState.COMPLETE
            return True

    def _progress_remaining_engine(self) -> bool:
        state = self.state

        if state == ScaleDownRemainingEngineState.PREPARE:
            self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE
            self.old_dp_store.add("eep_barrier_engine_count", 1)
            return True

        elif state == ScaleDownRemainingEngineState.EPLB_RESHUFFLE:
            if (
                int(self.old_dp_store.get("eep_barrier_engine_count"))
                < self.old_dp_group.size()
            ):
                return False
            if not self._staged_barrier(
                use_new_group=False, barrier_name="eplb_reshuffle"
            ):
                return False
            if self.old_dp_group.rank() == 0:
                self.old_dp_store.delete_key("eep_barrier_engine_count")
            self._eplb_reshuffle_before_scale_down()
            self.state = ScaleDownRemainingEngineState.SWITCH_AND_PREPARE
            # NOTE(yongji): currently, after the EPLB reshuffle that redistributes
            # experts to remaining workers, workers to be removed will immediately
            # initiate shutdown; existing workers can no longer execute forward
            # steps using the old setup. In the future, we may keep the removing
            # workers alive a bit longer, e.g., to drain in-batch requests.
            self._create_standby_groups()
            self._switch_and_prepare()
            self._update_parallel_config()
            self.state = ScaleDownRemainingEngineState.COMPLETE
            return True

        else:
            assert self.state == ScaleDownRemainingEngineState.COMPLETE
            return True

    def _progress_removing_engine(self) -> bool:
        state = self.state

        if state == ScaleDownRemovingEngineState.PREPARE:
            self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE
            self.old_dp_store.add("eep_barrier_engine_count", 1)
            return True

        if state == ScaleDownRemovingEngineState.EPLB_RESHUFFLE:
            if (
                int(self.old_dp_store.get("eep_barrier_engine_count"))
                < self.old_dp_group.size()
            ):
                return False
            if not self._staged_barrier(
                use_new_group=False, barrier_name="eplb_reshuffle"
            ):
                return False
            assert self.old_dp_group.rank() > 0
            self._eplb_reshuffle_before_scale_down()
            self._switch_and_remove()
            self.state = ScaleDownRemovingEngineState.COMPLETE
            self.engine_core._eep_send_engine_core_notification(
                EEPNotificationType.SHUTDOWN_COMPLETE
            )
            self.engine_core.shutdown()
            return True

        else:
            assert self.state == ScaleDownRemovingEngineState.COMPLETE
            return True

    def handle_notification(self, notification_type: EEPNotificationType):
        assert self.worker_type != "new"
        if (
            notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY
            and self.state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
        ):
            self.old_dp_store.add("eep_barrier_engine_count", 1)
            self.state = ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS
        elif (
            notification_type == EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
            and self.state
            == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
        ):
            self.old_dp_store.add("eep_barrier_engine_count", 1)
            self.state = ScaleUpExistingEngineState.TRANSFER_WEIGHTS

    def is_complete(self) -> bool:
        if self.scale_type == "scale_up":
            return (
                self.state == ScaleUpNewEngineState.COMPLETE
                if self.worker_type == "new"
                else self.state == ScaleUpExistingEngineState.COMPLETE
            )
        return (
            self.state == ScaleDownRemovingEngineState.COMPLETE
            if self.worker_type == "removing"
            else self.state == ScaleDownRemainingEngineState.COMPLETE
        )

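    # Illustrative driver (a sketch, assuming the DP engine busy loop owns a
    # `scaling_state` instance of this class; the exact call site may differ):
    #
    #     if scaling_state is not None and not scaling_state.is_complete():
    #         scaling_state.progress()  # once per loop iteration; False means "keep waiting"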
    def _create_standby_groups(self):
        self.new_dp_group, self.new_dp_store = (
            self.new_parallel_config.stateless_init_dp_group(return_store=True)
        )
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("create_standby_groups", self.reconfig_request)
        )
        if self.old_dp_group.rank() == 0:
            logger.info("[Elastic EP] Created standby communication groups")

    def _transfer_weights(self):
        assert self.reconfig_request is not None
        old_dp_size = self.old_dp_group.size()
        new_dp_size = self.reconfig_request.new_data_parallel_size

        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("transfer_weights", old_dp_size, new_dp_size)
        )
        if self.old_dp_group.rank() == 0:
            logger.info("[Elastic EP] Transferred weights to new workers")

    def _transfer_expert_mapping(self):
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("broadcast_expert_mapping",)
        )
        if self.old_dp_group.rank() == 0:
            logger.info("[Elastic EP] Broadcasted expert mapping to new workers")

    def _sync_kv_cache_memory_size(self):
        assert self.engine_core.available_gpu_memory_for_kv_cache > 0
        assert self.new_dp_group is not None
        ParallelConfig.sync_kv_cache_memory_size(
            self.new_dp_group,
            self.engine_core.available_gpu_memory_for_kv_cache,
        )
        if self.old_dp_group.rank() == 0:
            logger.info("[Elastic EP] Synced KV cache memory size to new workers")

    def _switch_and_prepare(self):
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("switch_and_prepare",)
        )
        old_dp_group = self.old_dp_group
        stateless_destroy_torch_distributed_process_group(old_dp_group)
        assert self.new_dp_group is not None
        new_dp_group = self.new_dp_group
        self.engine_core.dp_group = new_dp_group
        self.engine_core.dp_rank = new_dp_group.rank()
        self.engine_core.dp_store = self.new_dp_store
        engines_running = int(self.engine_core.engines_running)
        current_wave = self.engine_core.current_wave
        step_counter = self.engine_core.step_counter
        tensor = torch.tensor(
            [engines_running, current_wave, step_counter],
            dtype=torch.int32,
            device="cpu",
        )
        torch.distributed.all_reduce(
            tensor, op=torch.distributed.ReduceOp.MAX, group=new_dp_group
        )
        data = tensor.tolist()
        self.engine_core.engines_running = bool(data[0])
        self.engine_core.current_wave = int(data[1])
        self.engine_core.step_counter = int(data[2])
        if new_dp_group.rank() == 0:
            self.engine_core._eep_send_engine_core_notification(
                EEPNotificationType.RECONFIGURE_FINISHED
            )
            logger.info("[Elastic EP] Switched to new setup")

    def _eplb_reshuffle(self):
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("perform_eplb_reshuffle",)
        )
        assert self.new_dp_group is not None
        if self.new_dp_group.rank() == 0:
            logger.info("[Elastic EP] EPLB reshuffle completed")

    def _eplb_reshuffle_before_scale_down(self):
        assert self.reconfig_request is not None
        self.model_executor.collective_rpc(
            "elastic_ep_execute",
            args=(
                "perform_eplb_reshuffle",
                self.reconfig_request.new_data_parallel_size,
            ),
        )
        if self.old_dp_group.rank() == 0:
            logger.info("[Elastic EP] EPLB reshuffle completed")

    def _switch_and_remove(self):
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("switch_and_remove",)
        )

    def _update_parallel_config(self):
        assert self.reconfig_request is not None
        reconfig_request = self.reconfig_request
        parallel_config = self.vllm_config.parallel_config
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        if (
            reconfig_request.new_data_parallel_rank_local
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank_local = (
                reconfig_request.new_data_parallel_rank_local
            )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        parallel_config._data_parallel_master_port_list = (
            reconfig_request.new_data_parallel_master_port_list
        )
        parallel_config._stateless_world_group_port_list = (
            reconfig_request.new_stateless_world_group_port_list
        )
        parallel_config._stateless_dp_group_port_list = (
            reconfig_request.new_stateless_dp_group_port_list
        )
        parallel_config._stateless_ep_group_port_list = (
            reconfig_request.new_stateless_ep_group_port_list
        )
        parallel_config._stateless_eplb_group_port_list = (
            reconfig_request.new_stateless_eplb_group_port_list
        )

vllm/distributed/elastic_ep/standby_state.py (new file, 117 lines)
@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.distributed.parallel_state import (
    _init_stateless_group,
    _node_count,
    get_pp_group,
    get_tp_group,
    get_world_group,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator

_STANDBY_WORLD: StatelessGroupCoordinator | None = None
_STANDBY_WORLD_NODE_COUNT: int | None = None
_STANDBY_DP: StatelessGroupCoordinator | None = None
_STANDBY_EP: StatelessGroupCoordinator | None = None
_STANDBY_EPLB: StatelessGroupCoordinator | None = None


def get_standby_dp_group() -> StatelessGroupCoordinator | None:
    return _STANDBY_DP


def get_standby_ep_group() -> StatelessGroupCoordinator | None:
    return _STANDBY_EP


def get_standby_eplb_group() -> StatelessGroupCoordinator | None:
    return _STANDBY_EPLB


def get_standby_world_group() -> StatelessGroupCoordinator | None:
    return _STANDBY_WORLD


def create_standby_groups(
    new_dp_size: int,
    new_world_size_across_dp: int,
    master_ip: str,
    world_group_ports: list[list[int]],
    dp_group_ports: list[list[int]],
    ep_group_ports: list[list[int]],
    eplb_group_ports: list[list[int]] | None = None,
    backend: str | None = None,
) -> None:
    global \
        _STANDBY_WORLD, \
        _STANDBY_WORLD_NODE_COUNT, \
        _STANDBY_DP, \
        _STANDBY_EP, \
        _STANDBY_EPLB

    assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size
    world_group = get_world_group()
    assert isinstance(world_group, StatelessGroupCoordinator)
    backend = backend or world_group.backend

    standby_world_ranks = [list(range(new_world_size_across_dp))]
    _STANDBY_WORLD = _init_stateless_group(
        standby_world_ranks,
        "world",
        world_group_ports,
        master_ip,
        backend,
        use_device_communicator=False,
    )
    _STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group)

    tp_size = get_tp_group().world_size
    pp_size = get_pp_group().world_size

    all_ranks = torch.arange(new_world_size_across_dp).reshape(
        -1, new_dp_size, pp_size, tp_size
    )
    standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0)
    standby_dp_ranks = [x.tolist() for x in standby_dp_ranks]
    _STANDBY_DP = _init_stateless_group(
        standby_dp_ranks, "dp", dp_group_ports, master_ip, backend
    )

    standby_ep_ranks = (
        all_ranks.transpose(1, 2).reshape(-1, new_dp_size * tp_size).unbind(0)
    )
    standby_ep_ranks = [x.tolist() for x in standby_ep_ranks]
    _STANDBY_EP = _init_stateless_group(
        standby_ep_ranks, "ep", ep_group_ports, master_ip, backend
    )

    if eplb_group_ports is not None:
        _STANDBY_EPLB = _init_stateless_group(
            standby_ep_ranks, "eplb", eplb_group_ports, master_ip, backend
        )


def pop_standby_groups() -> dict:
    """Return all standby groups and clear the standby state."""
    global \
        _STANDBY_WORLD, \
        _STANDBY_WORLD_NODE_COUNT, \
        _STANDBY_DP, \
        _STANDBY_EP, \
        _STANDBY_EPLB

    result = dict(
        world=_STANDBY_WORLD,
        dp=_STANDBY_DP,
        ep=_STANDBY_EP,
        eplb=_STANDBY_EPLB,
        node_count=_STANDBY_WORLD_NODE_COUNT,
    )
    _STANDBY_WORLD = None
    _STANDBY_WORLD_NODE_COUNT = None
    _STANDBY_DP = None
    _STANDBY_EP = None
    _STANDBY_EPLB = None
    return result
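# Illustrative usage (a sketch; master_ip and the port lists below are placeholders that
# in practice come from the ReconfigureDistributedRequest handled by the engine core):
#
#     create_standby_groups(
#         new_dp_size=new_dp_size,
#         new_world_size_across_dp=new_dp_size * torch.distributed.get_world_size(),
#         master_ip=master_ip,
#         world_group_ports=world_group_ports,
#         dp_group_ports=dp_group_ports,
#         ep_group_ports=ep_group_ports,
#         eplb_group_ports=eplb_group_ports,
#     )
#     ...
#     groups = pop_standby_groups()
#     # {"world": ..., "dp": ..., "ep": ..., "eplb": ..., "node_count": ...}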
@@ -24,7 +24,6 @@ logger = init_logger(__name__)
|
||||
|
||||
def start_async_worker(
|
||||
state: "EplbState",
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
is_profile: bool = False,
|
||||
) -> threading.Thread:
|
||||
eplb_group = get_eplb_group().device_group
|
||||
@@ -45,7 +44,6 @@ def start_async_worker(
|
||||
eplb_group=eplb_group,
|
||||
cuda_stream=cuda_stream,
|
||||
is_profile=is_profile,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - diagnostic path
|
||||
@@ -107,7 +105,6 @@ async def transfer_run_periodically(
|
||||
eplb_group: ProcessGroup,
|
||||
cuda_stream: torch.cuda.Stream,
|
||||
is_profile: bool = False,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
) -> None:
|
||||
while True:
|
||||
await asyncio.to_thread(state.rearrange_event.wait)
|
||||
@@ -176,7 +173,6 @@ async def transfer_run_periodically(
|
||||
ep_group=eplb_group,
|
||||
is_profile=is_profile,
|
||||
cuda_stream=cuda_stream,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
event = torch.cuda.Event(blocking=False)
|
||||
cuda_stream.record_event(event)
|
||||
|
||||
@@ -40,6 +40,7 @@ from vllm.distributed.parallel_state import (
|
||||
get_node_count,
|
||||
in_the_same_node_as,
|
||||
)
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import MixtureOfExperts
|
||||
@@ -159,7 +160,7 @@ class EplbModelState:
|
||||
|
||||
NOTE: The expert_load_view now records load for all physical experts
|
||||
rather than just local experts. This ensures consistent load statistics
|
||||
across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
|
||||
across different dispatch methods (naive all-to-all, DeepEP).
|
||||
The recorded load will be multiplied by dp_size when using naive all-to-all
|
||||
due to each DP rank contributing the same token set to the calculation.
|
||||
See:
|
||||
@@ -302,6 +303,14 @@ class EplbState:
|
||||
"""
|
||||
CUDA device index for the async EPLB worker thread.
|
||||
"""
|
||||
self.num_valid_physical_experts: int = 0
|
||||
"""
|
||||
Number of valid physical experts.
|
||||
This is the number of physical experts that are
|
||||
actually mapped to logical experts. In elastic EP,
|
||||
newly started EP ranks may not have physical experts
|
||||
mapped yet.
|
||||
"""
|
||||
if self.device.type == "cuda":
|
||||
self.cuda_device_index = self.device.index
|
||||
if self.cuda_device_index is None and torch.cuda.is_available():
|
||||
@@ -367,9 +376,6 @@ class EplbState:
|
||||
self,
|
||||
model: MixtureOfExperts,
|
||||
model_config: ModelConfig,
|
||||
global_expert_load: torch.Tensor | None = None,
|
||||
old_global_expert_indices: torch.Tensor | None = None,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
):
|
||||
"""
|
||||
Build the initial EPLB state.
|
||||
@@ -462,75 +468,15 @@ class EplbState:
|
||||
)
|
||||
self.expert_rearrangement_step_interval = eplb_step_interval
|
||||
|
||||
# Set the policy based on the selected eplb algorithm type.
|
||||
policy_type = self.parallel_config.eplb_config.policy
|
||||
self.policy = EPLB_POLICIES[policy_type]
|
||||
logger.debug("Selected EPLB policy: %s", policy_type)
|
||||
if global_expert_load is not None:
|
||||
ep_group = get_ep_group().device_group
|
||||
assert global_expert_load.shape == (
|
||||
model.num_moe_layers,
|
||||
model.num_logical_experts,
|
||||
)
|
||||
assert global_expert_load.dtype == torch.int64
|
||||
|
||||
num_replicas = model.num_physical_experts
|
||||
num_groups = model.num_expert_groups
|
||||
num_nodes = get_node_count()
|
||||
num_gpus = ep_group.size()
|
||||
|
||||
if num_gpus % num_nodes != 0:
|
||||
num_nodes = 1
|
||||
logger.warning_once(
|
||||
f"num_gpus % num_nodes != 0, "
|
||||
"not using hierarchical rearrangement algorithm.\n"
|
||||
f"{num_gpus=}, {num_nodes=}"
|
||||
)
|
||||
|
||||
# Get new expert mappings
|
||||
(
|
||||
new_physical_to_logical_map,
|
||||
new_logical_to_physical_map,
|
||||
new_logical_replica_count,
|
||||
) = self.policy.rebalance_experts(
|
||||
global_expert_load,
|
||||
num_replicas,
|
||||
num_groups,
|
||||
num_nodes,
|
||||
num_gpus,
|
||||
)
|
||||
|
||||
max_physical_slots = new_logical_to_physical_map.shape[-1]
|
||||
assert max_physical_slots <= logical_to_physical_map.shape[-1]
|
||||
new_logical_to_physical_map = torch.nn.functional.pad(
|
||||
new_logical_to_physical_map,
|
||||
(0, logical_to_physical_map.shape[-1] - max_physical_slots),
|
||||
value=-1,
|
||||
)
|
||||
physical_to_logical_map = new_physical_to_logical_map.to(self.device)
|
||||
logical_to_physical_map.copy_(new_logical_to_physical_map)
|
||||
logical_replica_count.copy_(new_logical_replica_count)
|
||||
else:
|
||||
new_physical_to_logical_map = None
|
||||
|
||||
new_logical_to_physical_map = None
|
||||
|
||||
new_logical_replica_count = None
|
||||
model.set_eplb_state(
|
||||
expert_load_pass,
|
||||
logical_to_physical_map,
|
||||
logical_replica_count,
|
||||
)
|
||||
if global_expert_load is not None:
|
||||
rearrange_expert_weights_inplace(
|
||||
old_global_expert_indices,
|
||||
new_physical_to_logical_map,
|
||||
model.expert_weights,
|
||||
ep_group,
|
||||
False,
|
||||
rank_mapping,
|
||||
)
|
||||
self.expert_rearrangement_step = 0
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]
|
||||
|
||||
@@ -561,11 +507,12 @@ class EplbState:
|
||||
recv_dst_rows=np.array([]),
|
||||
),
|
||||
cuda_device_index=self.cuda_device_index,
|
||||
new_physical_to_logical_map=new_physical_to_logical_map,
|
||||
new_logical_to_physical_map=new_logical_to_physical_map,
|
||||
new_logical_replica_count=new_logical_replica_count,
|
||||
new_physical_to_logical_map=None,
|
||||
new_logical_to_physical_map=None,
|
||||
new_logical_replica_count=None,
|
||||
)
|
||||
self.model_states[model_config.compute_hash()] = model_state
|
||||
self.num_valid_physical_experts = model.num_physical_experts
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -696,8 +643,6 @@ class EplbState:
|
||||
def rearrange(
|
||||
self,
|
||||
is_profile: bool = False,
|
||||
execute_shuffle: bool = True,
|
||||
global_expert_loads: list[torch.Tensor] | None = None,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
@@ -707,12 +652,6 @@ class EplbState:
|
||||
is_profile (bool): If `True`, perform a dummy rearrangement.
|
||||
This is used in `profile_run` to reserve enough memory,
|
||||
no memory movement will be performed. Default is False.
|
||||
execute_shuffle (bool): If `True`, execute the shuffle
|
||||
in elastic expert parallel (EEP). Default is True.
|
||||
global_expert_loads (list[torch.Tensor] | None): The global expert
|
||||
loads when scaling is done in EEP.
|
||||
List of expert loads for the main and drafter
|
||||
(when spec decode is used) models.
|
||||
rank_mapping (dict[int, int] | None): The rank mapping
|
||||
when scaling is done in EEP.
|
||||
"""
|
||||
@@ -734,67 +673,34 @@ class EplbState:
|
||||
"(profile)" if is_profile else "",
|
||||
)
|
||||
|
||||
if global_expert_loads is None:
|
||||
# Map the physical expert load to global logical experts
|
||||
global_expert_load_windows = []
|
||||
if not execute_shuffle:
|
||||
num_models = torch.tensor(
|
||||
[len(self.model_states)], dtype=torch.int32, device="cpu"
|
||||
)
|
||||
torch.distributed.broadcast(
|
||||
num_models, group=get_ep_group().cpu_group, group_src=0
|
||||
)
|
||||
|
||||
for eplb_model_state in self.model_states.values():
|
||||
logical_expert_load_window = torch.zeros(
|
||||
self.expert_load_window_size,
|
||||
eplb_model_state.model.num_moe_layers,
|
||||
eplb_model_state.model.num_logical_experts,
|
||||
dtype=eplb_model_state.expert_load_window.dtype,
|
||||
device=eplb_model_state.expert_load_window.device,
|
||||
)
|
||||
logical_expert_load_window.scatter_add_(
|
||||
dim=-1,
|
||||
index=eplb_model_state.physical_to_logical_map.unsqueeze(0)
|
||||
.expand_as(eplb_model_state.expert_load_window)
|
||||
.long(),
|
||||
src=eplb_model_state.expert_load_window,
|
||||
)
|
||||
|
||||
if not execute_shuffle:
|
||||
metadata = torch.tensor(
|
||||
[
|
||||
eplb_model_state.model.num_moe_layers,
|
||||
eplb_model_state.model.num_logical_experts,
|
||||
eplb_model_state.physical_to_logical_map.shape[1],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
torch.distributed.broadcast(
|
||||
metadata, group=get_ep_group().cpu_group, group_src=0
|
||||
)
|
||||
|
||||
global_expert_load_window = logical_expert_load_window.sum(dim=0)
|
||||
global_expert_load_windows.append(global_expert_load_window)
|
||||
# Perform all-reduce to get the expert load across all ranks for each model
|
||||
global_expert_load_windows = self._allreduce_list(
|
||||
global_expert_load_windows
|
||||
# Map the physical expert load to global logical experts
|
||||
global_expert_load_windows = []
|
||||
for eplb_model_state in self.model_states.values():
|
||||
expert_load_window = eplb_model_state.expert_load_window[
|
||||
:, :, : self.num_valid_physical_experts
|
||||
]
|
||||
logical_expert_load_window = torch.zeros(
|
||||
self.expert_load_window_size,
|
||||
eplb_model_state.model.num_moe_layers,
|
||||
eplb_model_state.model.num_logical_experts,
|
||||
dtype=eplb_model_state.expert_load_window.dtype,
|
||||
device=eplb_model_state.expert_load_window.device,
|
||||
)
|
||||
if not execute_shuffle:
|
||||
for eplb_model_state, global_expert_load_window in zip(
|
||||
self.model_states.values(), global_expert_load_windows
|
||||
):
|
||||
# (num_moe_layers, old_num_physical_experts)
|
||||
old_global_expert_indices = eplb_model_state.physical_to_logical_map
|
||||
torch.distributed.broadcast(
|
||||
old_global_expert_indices, group=ep_group, group_src=0
|
||||
)
|
||||
if not execute_shuffle:
|
||||
return global_expert_load_windows
|
||||
else:
|
||||
assert execute_shuffle
|
||||
global_expert_load_windows = global_expert_loads
|
||||
logical_expert_load_window.scatter_add_(
|
||||
dim=-1,
|
||||
index=eplb_model_state.physical_to_logical_map[
|
||||
:, : self.num_valid_physical_experts
|
||||
]
|
||||
.unsqueeze(0)
|
||||
.expand_as(expert_load_window)
|
||||
.long(),
|
||||
src=expert_load_window,
|
||||
)
|
||||
|
||||
global_expert_load_window = logical_expert_load_window.sum(dim=0)
|
||||
global_expert_load_windows.append(global_expert_load_window)
|
||||
# Perform all-reduce to get the expert load across all ranks for each model
|
||||
global_expert_load_windows = self._allreduce_list(global_expert_load_windows)
|
||||
|
||||
# TODO(bowen): Treat differently for prefill and decode nodes
|
||||
eplb_model_state = next(iter(self.model_states.values()))
|
||||
@@ -806,8 +712,10 @@ class EplbState:
|
||||
# NOTE(yongji): scale down, we need to rebalance the experts on
|
||||
# remaining GPUs, transfer the experts while we haven't shutdown
|
||||
# the GPUs to be released.
|
||||
cpu_group = get_ep_group().cpu_group
|
||||
num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
|
||||
coordinator = get_ep_group()
|
||||
assert isinstance(coordinator, StatelessGroupCoordinator)
|
||||
tcp_store_group = coordinator.tcp_store_group
|
||||
num_nodes = _node_count_with_rank_mapping(tcp_store_group, rank_mapping)
|
||||
num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values())
|
||||
num_replicas = (
|
||||
num_replicas // ep_group.size() * num_gpus
|
||||
@@ -933,7 +841,6 @@ class EplbState:
|
||||
if self.async_worker is None:
|
||||
self.async_worker = start_async_worker(
|
||||
self,
|
||||
rank_mapping=rank_mapping,
|
||||
is_profile=is_profile,
|
||||
)
|
||||
|
||||
@@ -1089,83 +996,6 @@ class EplbState:
|
||||
model_state.new_logical_to_physical_map = None
|
||||
model_state.new_logical_replica_count = None
|
||||
|
||||
@staticmethod
|
||||
def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
|
||||
"""
|
||||
Receive the expert load and old placement from the master rank.
|
||||
"""
|
||||
ep_group = get_ep_group()
|
||||
num_models = torch.empty(1, dtype=torch.int32, device="cpu")
|
||||
torch.distributed.broadcast(num_models, group=ep_group.cpu_group, group_src=0)
|
||||
num_models = num_models.item()
|
||||
global_expert_loads = []
|
||||
old_global_expert_indices_per_model = []
|
||||
for _ in range(num_models):
|
||||
metadata = torch.empty(3, dtype=torch.int32, device="cpu")
|
||||
torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0)
|
||||
num_moe_layers, num_logical_experts, num_old_physical_experts = (
|
||||
metadata.tolist()
|
||||
)
|
||||
global_expert_load = torch.zeros(
|
||||
(num_moe_layers, num_logical_experts),
|
||||
dtype=torch.int64,
|
||||
device=ep_group.device,
|
||||
)
|
||||
all_reduce(global_expert_load, group=ep_group.device_group)
|
||||
old_global_expert_indices = torch.empty(
|
||||
(num_moe_layers, num_old_physical_experts),
|
||||
dtype=torch.int64,
|
||||
device=ep_group.device,
|
||||
)
|
||||
torch.distributed.broadcast(
|
||||
old_global_expert_indices,
|
||||
group=ep_group.device_group,
|
||||
group_src=0,
|
||||
)
|
||||
global_expert_loads.append(global_expert_load)
|
||||
old_global_expert_indices_per_model.append(old_global_expert_indices)
|
||||
return global_expert_loads, old_global_expert_indices_per_model
|
||||
|
||||
@classmethod
|
||||
def get_eep_state(
|
||||
cls, parallel_config: ParallelConfig
|
||||
) -> tuple[
|
||||
list[torch.Tensor] | None,
|
||||
list[torch.Tensor] | None,
|
||||
dict[int, int] | None,
|
||||
]:
|
||||
num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu")
|
||||
torch.distributed.broadcast(
|
||||
num_local_physical_experts,
|
||||
group=get_ep_group().cpu_group,
|
||||
group_src=0,
|
||||
)
|
||||
num_local_physical_experts = int(num_local_physical_experts.item())
|
||||
new_ep_size = get_ep_group().world_size
|
||||
global_expert_loads, old_global_expert_indices_per_model = (
|
||||
EplbState.recv_state()
|
||||
)
|
||||
|
||||
# EP configuration for all models has to be the same, and so does the EPLB config
|
||||
num_logical_experts = global_expert_loads[0].shape[1]
|
||||
parallel_config.eplb_config.num_redundant_experts = (
|
||||
num_local_physical_experts * new_ep_size - num_logical_experts
|
||||
)
|
||||
assert (
|
||||
old_global_expert_indices_per_model[0].shape[1] % num_local_physical_experts
|
||||
== 0
|
||||
)
|
||||
old_ep_size = (
|
||||
old_global_expert_indices_per_model[0].shape[1]
|
||||
// num_local_physical_experts
|
||||
)
|
||||
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
|
||||
return (
|
||||
global_expert_loads,
|
||||
old_global_expert_indices_per_model,
|
||||
rank_mapping,
|
||||
)
|
||||
|
||||
def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
|
||||
"""
|
||||
All-reduce a list of tensors.
|
||||
@@ -1203,6 +1033,60 @@ class EplbState:
|
||||
load_pass_list.append(eplb_model_state.expert_load_pass.clone())
|
||||
return self._allreduce_list(load_pass_list)
|
||||
|
||||
@classmethod
|
||||
def from_mapping(
|
||||
cls,
|
||||
model: MixtureOfExperts,
|
||||
model_config: ModelConfig,
|
||||
device: torch.device,
|
||||
parallel_config: ParallelConfig,
|
||||
expanded_physical_to_logical: torch.Tensor,
|
||||
num_valid_physical_experts: int,
|
||||
) -> "EplbState":
|
||||
eplb_state = cls(
|
||||
parallel_config=parallel_config,
|
||||
device=device,
|
||||
)
|
||||
eplb_state.add_model(
|
||||
model=model,
|
||||
model_config=model_config,
|
||||
)
|
||||
eplb_state.num_valid_physical_experts = num_valid_physical_experts
|
||||
num_moe_layers = expanded_physical_to_logical.shape[0]
|
||||
num_physical_experts = expanded_physical_to_logical.shape[1]
|
||||
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
|
||||
eplb_model_state.physical_to_logical_map.copy_(expanded_physical_to_logical)
|
||||
|
||||
logical_to_physical_map = torch.full(
|
||||
(
|
||||
num_moe_layers,
|
||||
model.num_logical_experts,
|
||||
eplb_model_state.logical_to_physical_map.shape[2],
|
||||
),
|
||||
-1,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
logical_replica_count = torch.zeros(
|
||||
(num_moe_layers, model.num_logical_experts),
|
||||
dtype=torch.int64,
|
||||
)
|
||||
expanded_physical_to_logical_numpy = expanded_physical_to_logical.cpu().numpy()
|
||||
for layer_idx in range(num_moe_layers):
|
||||
for phys_idx in range(num_physical_experts):
|
||||
logical_idx = expanded_physical_to_logical_numpy[layer_idx, phys_idx]
|
||||
if logical_idx >= 0:
|
||||
replica_idx = logical_replica_count[layer_idx, logical_idx]
|
||||
logical_to_physical_map[layer_idx, logical_idx, replica_idx] = (
|
||||
phys_idx
|
||||
)
|
||||
logical_replica_count[layer_idx, logical_idx] += 1
|
||||
|
||||
logical_to_physical_map = logical_to_physical_map.to(device)
|
||||
logical_replica_count = logical_replica_count.to(device)
|
||||
eplb_model_state.logical_to_physical_map.copy_(logical_to_physical_map)
|
||||
eplb_model_state.logical_replica_count.copy_(logical_replica_count)
|
||||
return eplb_state
|
||||
|
||||
|
||||
@dataclass
|
||||
class EplbLayerState:
|
||||
|
||||
@@ -19,6 +19,8 @@ from torch.distributed import (
|
||||
get_global_rank,
|
||||
)
|
||||
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -249,10 +251,18 @@ def move_to_buffer(
|
||||
b[dst].copy_(w[src_local], non_blocking=True)
|
||||
|
||||
p2p_ops: list[P2POp] = []
|
||||
if isinstance(get_ep_group(), StatelessGroupCoordinator):
|
||||
ep_group = get_ep_group()
|
||||
is_stateless = True
|
||||
else:
|
||||
is_stateless = False
|
||||
|
||||
# Pre-compute global ranks mapping
|
||||
# Pre-compute global ranks mapping (only needed for non-stateless groups)
|
||||
ep_size = ep_group.size()
|
||||
rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
|
||||
if not is_stateless:
|
||||
rank_to_global = {
|
||||
rank: get_global_rank(ep_group, rank) for rank in range(ep_size)
|
||||
}
|
||||
|
||||
# 2. Post sends
|
||||
if send_count > 0:
|
||||
@@ -284,15 +294,23 @@ def move_to_buffer(
|
||||
if recver_pos < len(ranks_to_recv):
|
||||
recv_ranks.append(ranks_to_recv[recver_pos])
|
||||
for dst in recv_ranks:
|
||||
dst_global = rank_to_global[dst]
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.isend,
|
||||
w[src],
|
||||
dst_global,
|
||||
)
|
||||
for w in expert_weights
|
||||
]
|
||||
if is_stateless:
|
||||
for w in expert_weights:
|
||||
op = object.__new__(P2POp)
|
||||
op.op = torch.distributed.isend
|
||||
op.tensor = w[src]
|
||||
op.group_peer = dst
|
||||
p2p_ops.append(op)
|
||||
else:
|
||||
dst_global = rank_to_global[dst]
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.isend,
|
||||
w[src],
|
||||
dst_global,
|
||||
)
|
||||
for w in expert_weights
|
||||
]
|
||||
|
||||
# 3. Post recvs
|
||||
if recv_count > 0:
|
||||
@@ -321,26 +339,40 @@ def move_to_buffer(
|
||||
src = ranks_to_send[recver_pos // num_dst_per_sender]
|
||||
else:
|
||||
src = ranks_to_send[recver_pos - remainder_start]
|
||||
src_global = rank_to_global[src]
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.irecv,
|
||||
b[dst],
|
||||
src_global,
|
||||
)
|
||||
for b in expert_weights_buffers
|
||||
]
|
||||
if is_stateless:
|
||||
for b in expert_weights_buffers:
|
||||
op = object.__new__(P2POp)
|
||||
op.op = torch.distributed.irecv
|
||||
op.tensor = b[dst]
|
||||
op.group_peer = src
|
||||
p2p_ops.append(op)
|
||||
else:
|
||||
src_global = rank_to_global[src]
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.irecv,
|
||||
b[dst],
|
||||
src_global,
|
||||
)
|
||||
for b in expert_weights_buffers
|
||||
]
|
||||
|
||||
# 4. Execute the P2P operations. The real communication happens here.
|
||||
if p2p_ops and cuda_stream is not None:
|
||||
with torch.cuda.stream(cuda_stream):
|
||||
if is_stateless:
|
||||
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
|
||||
else:
|
||||
reqs = batch_isend_irecv(p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
elif p2p_ops:
|
||||
if is_stateless:
|
||||
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
|
||||
else:
|
||||
reqs = batch_isend_irecv(p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
elif p2p_ops:
|
||||
reqs = batch_isend_irecv(p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
# wait for the communication to finish
|
||||
return (
|
||||
is_unchanged,
|
||||
|
||||
@@ -209,6 +209,10 @@ class KVConnectorKVEvents(ABC):
|
||||
def clear_events(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def merge(self, other: "KVConnectorKVEvents") -> "KVConnectorKVEvents":
|
||||
self.add_events(other.get_all_events())
|
||||
return self
|
||||
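    # Illustrative usage (comment only): because merge() returns self, per-worker event
    # batches can be folded together in a chain, e.g.
    #     merged = events_a.merge(events_b).merge(events_c)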
|
||||
|
||||
class EventPublisher(ABC):
|
||||
"""Lightweight publisher for EventBatch batches with data parallelism
|
||||
|
||||
@@ -149,6 +149,12 @@ KVConnectorFactory.register_connector(
|
||||
"ExampleConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"ExampleHiddenStatesConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.example_hidden_states_connector",
|
||||
"ExampleHiddenStatesConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"P2pNcclConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector",
|
||||
|
||||
@@ -413,7 +413,20 @@ class TpKVTopology:
|
||||
f"by local tensor parallel size {self.tp_size}."
|
||||
)
|
||||
# P TP > D TP case, return the ratio as negative
|
||||
return -remote_tp_size // self.tp_size
|
||||
return remote_tp_size // self.tp_size
|
||||
|
||||
def pp_ratio(
|
||||
self,
|
||||
remote_pp_size: int,
|
||||
) -> int:
|
||||
"""
|
||||
Calculate the pipeline parallel ratio between local and remote PP.
|
||||
"""
|
||||
        assert (
            self.pp_size % remote_pp_size == 0
            or remote_pp_size % self.pp_size == 0
        ), (
            f"Local pipeline parallel size {self.pp_size} is not divisible "
            f"by remote pipeline parallel size {remote_pp_size} or vice versa."
        )
        return (
            self.pp_size // remote_pp_size
            if self.pp_size % remote_pp_size == 0
            else remote_pp_size // self.pp_size
        )
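    # Worked examples (comment only, following the rule above): with self.pp_size == 2,
    # pp_ratio(4) == 2 (remote side larger) and pp_ratio(1) == 2 (local side larger);
    # callers tell the two cases apart by comparing self.pp_size with remote_pp_size.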
|
||||
|
||||
def block_size_ratio(
|
||||
self,
|
||||
@@ -457,6 +470,7 @@ class TpKVTopology:
|
||||
def get_target_remote_ranks(
|
||||
self,
|
||||
remote_tp_size: int,
|
||||
remote_pp_size: int
|
||||
) -> list[int]:
|
||||
"""
|
||||
Get the remote TP rank (on P) that the current local TP rank
|
||||
@@ -464,19 +478,36 @@ class TpKVTopology:
|
||||
read from multiple remote ranks.
|
||||
"""
|
||||
tp_ratio = self.tp_ratio(remote_tp_size)
|
||||
if tp_ratio > 0:
|
||||
return [self.tp_rank // tp_ratio]
|
||||
pp_ratio = self.pp_ratio(remote_pp_size)
|
||||
target_pp_rank_list = []
|
||||
target_tp_rank_list = []
|
||||
if self.pp_size < remote_pp_size:
|
||||
for i in range(pp_ratio):
|
||||
target_pp_rank_list.append(self.pp_rank * pp_ratio + i)
|
||||
else:
|
||||
target_pp_rank_list.append(self.pp_rank // pp_ratio)
|
||||
|
||||
# P TP > D TP case, D reads from |tp_ratio| remote workers.
|
||||
tp_ratio = -tp_ratio
|
||||
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
|
||||
if self.tp_size < remote_tp_size:
|
||||
for i in range(tp_ratio):
|
||||
target_tp_rank_list.append(self.tp_rank * tp_ratio + i)
|
||||
else:
|
||||
target_tp_rank_list.append(self.tp_rank // tp_ratio)
|
||||
|
||||
target_rank_list = []
|
||||
for pp_rank in target_pp_rank_list:
|
||||
for tp_rank in target_tp_rank_list:
|
||||
target_rank = pp_rank * remote_tp_size + tp_rank
|
||||
target_rank_list.append((target_rank, pp_rank, tp_rank))
|
||||
|
||||
return target_rank_list
|
||||
|
||||
def get_target_remote_ranks_from_engine_id(
|
||||
self,
|
||||
remote_engine_id: EngineId,
|
||||
) -> list[int]:
|
||||
remote_tp_size = self.remote_tp_size[remote_engine_id]
|
||||
return self.get_target_remote_ranks(remote_tp_size)
|
||||
remote_pp_size = self.remote_pp_size[remote_engine_id]
|
||||
return self.get_target_remote_ranks(remote_tp_size, remote_pp_size)
|
||||
|
||||
|
||||
def get_current_attn_backend(vllm_config: VllmConfig):
|
||||
|
||||
@@ -543,6 +543,28 @@ class KVConnectorBase_V1(ABC):
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if this connector requires PIECEWISE CUDA graph mode.
|
||||
|
||||
Connectors that use asynchronous layer-by-layer operations
|
||||
(wait_for_layer_load/save_kv_layer) should override this method
|
||||
to return True when those operations are enabled. These operations
|
||||
cannot be captured in CUDA graphs and will be skipped during replay,
|
||||
causing data races. PIECEWISE mode allows Python code to execute
|
||||
between graph pieces, ensuring proper synchronization.
|
||||
|
||||
Args:
|
||||
extra_config: The kv_connector_extra_config dict from
|
||||
KVTransferConfig.
|
||||
|
||||
Returns:
|
||||
True if this connector requires PIECEWISE CUDA graph mode,
|
||||
False otherwise.
|
||||
"""
|
||||
return False
|
||||
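    # Illustrative override (a sketch, not an API defined by this change; the config key
    # below is hypothetical): a connector that streams KV layer-by-layer could opt in as
    #
    #     @classmethod
    #     def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
    #         return bool(extra_config.get("enable_async_layer_transfer", False))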
|
||||
def get_finished_count(self) -> int | None:
|
||||
"""
|
||||
Get the count of requests expected to complete send/receive operations
|
||||
|
||||
@@ -17,6 +17,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -118,12 +119,12 @@ class ExampleConnector(KVConnectorBase_V1):
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
"""
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
def inject_kv_into_layer(
|
||||
dst_kv_cache_layer: torch.Tensor,
|
||||
src_kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> None:
|
||||
"""Inject the KV cache into the layer.
|
||||
|
||||
@@ -145,6 +146,10 @@ class ExampleConnector(KVConnectorBase_V1):
|
||||
num_pages * page_size, -1
|
||||
)
|
||||
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
|
||||
elif isinstance(attn_metadata, TritonAttentionMetadata):
|
||||
block_idxs = slot_mapping // self._block_size
|
||||
offsets = slot_mapping % self._block_size
|
||||
dst_kv_cache_layer[block_idxs, :, offsets] = src_kv_cache
|
||||
else:
|
||||
num_pages = dst_kv_cache_layer_shape[1]
|
||||
page_size = dst_kv_cache_layer_shape[2]
|
||||
@@ -186,7 +191,13 @@ class ExampleConnector(KVConnectorBase_V1):
|
||||
layer_name, request.token_ids, request.mm_hashes
|
||||
)
|
||||
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
|
||||
inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping)
|
||||
if isinstance(attn_metadata, dict):
|
||||
inject_kv_into_layer(
|
||||
kv_cache_layer,
|
||||
kv_cache,
|
||||
request.slot_mapping,
|
||||
attn_metadata[layer_name],
|
||||
)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""Blocking until the KV for a specific layer is loaded into vLLM's
|
||||
@@ -229,6 +240,10 @@ class ExampleConnector(KVConnectorBase_V1):
|
||||
if isinstance(attn_metadata, MLACommonMetadata):
|
||||
num_pages, page_size = layer.shape[0], layer.shape[1]
|
||||
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
|
||||
elif isinstance(attn_metadata, TritonAttentionMetadata):
|
||||
block_idxs = slot_mapping // self._block_size
|
||||
offsets = slot_mapping % self._block_size
|
||||
return layer[block_idxs, :, offsets]
|
||||
num_pages, page_size = layer.shape[1], layer.shape[2]
|
||||
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
|
||||
|
||||
|
||||
@@ -0,0 +1,354 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def extract_from_kv_cache(
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
num_tokens: int,
|
||||
) -> torch.Tensor:
|
||||
"""Extract data from KV cache
|
||||
Assume the shape of the kv_cache is (num_pages, page_size, num_heads, head_size)
|
||||
"""
|
||||
|
||||
padded_kv = kv_cache.flatten(0, 1)[slot_mapping]
|
||||
# shape: [len(slot_mapping), num_heads, head_size]
|
||||
return padded_kv[:num_tokens] # shape: [num_tokens, num_heads, head_size]
|
||||
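# Worked example for extract_from_kv_cache (a sketch with made-up shapes):
#
#     kv = torch.randn(4, 16, 8, 64)          # (num_pages, page_size, num_heads, head_size)
#     slots = torch.tensor([0, 1, 17, 33])    # flat indices into num_pages * page_size
#     out = extract_from_kv_cache(kv, slots, num_tokens=3)
#     assert out.shape == (3, 8, 64)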
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
# Request ID
|
||||
req_id: str
|
||||
# Request filename
|
||||
filename: str
|
||||
# Request tokens
|
||||
token_ids: torch.Tensor
|
||||
# Slot mappings, should have the same length as token_ids
|
||||
slot_mapping: torch.Tensor
|
||||
# Whether this request is a new request or partially computed already
|
||||
new_req: bool
|
||||
|
||||
@staticmethod
|
||||
def make_meta(
|
||||
req_id: str,
|
||||
filename: str,
|
||||
token_ids: list[int],
|
||||
block_ids: list[int],
|
||||
block_size: int,
|
||||
new_req: bool,
|
||||
) -> "ReqMeta":
|
||||
token_ids_tensor = torch.tensor(token_ids)
|
||||
block_ids_tensor = torch.tensor(block_ids)
|
||||
num_blocks = block_ids_tensor.shape[0]
|
||||
block_offsets = torch.arange(0, block_size)
|
||||
slot_mapping = (
|
||||
block_offsets.reshape((1, block_size))
|
||||
+ block_ids_tensor.reshape((num_blocks, 1)) * block_size
|
||||
)
|
||||
slot_mapping = slot_mapping.flatten()
|
||||
return ReqMeta(
|
||||
req_id=req_id,
|
||||
filename=filename,
|
||||
token_ids=token_ids_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
new_req=new_req,
|
||||
)
|
||||
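# Worked example for ReqMeta.make_meta (comment only): with block_size=4 and
# block_ids=[2, 5], the computed slot_mapping is [8, 9, 10, 11, 20, 21, 22, 23],
# i.e. one flat KV slot per token position covered by the allocated blocks.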
|
||||
|
||||
@dataclass
|
||||
class ExampleHiddenStatesConnectorMetadata(KVConnectorMetadata):
|
||||
requests: list[ReqMeta] = field(default_factory=list)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
req_id: str,
|
||||
filename: str,
|
||||
token_ids: list[int],
|
||||
block_ids: list[int],
|
||||
block_size: int,
|
||||
new_req: bool = True,
|
||||
) -> None:
|
||||
self.requests.append(
|
||||
ReqMeta.make_meta(
|
||||
req_id, filename, token_ids, block_ids, block_size, new_req
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ExampleHiddenStatesConnector(KVConnectorBase_V1):
|
||||
"""
|
||||
Simple debug implementation of a HiddenStatesConnector.
|
||||
|
||||
Simply extracts the hidden states from the kv cache and stores them to disk.
|
||||
Must be used in conjunction with the `extract_hidden_states` spec decoding method.
|
||||
"""
|
||||
|
||||
@property
|
||||
def prefer_cross_layer_blocks(self) -> bool:
|
||||
"""
|
||||
Indicates whether this connector prefers KV blocks that hold KV data for all
|
||||
layers, which can speed up KV data transfers. Defaults to False.
|
||||
"""
|
||||
# Must be False so that drafter kv cache isn't merged with verifier's
|
||||
return False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
role=role,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
self._storage_path = self._kv_transfer_config.get_from_extra_config(
|
||||
"shared_storage_path", "/tmp"
|
||||
)
|
||||
self.cache_layers: list[str] = [] # set by self.register_kv_caches
|
||||
logger.info(self._kv_transfer_config)
|
||||
logger.info("Shared storage path is %s", self._storage_path)
|
||||
|
||||
assert self._vllm_config.speculative_config is not None, (
|
||||
"ExampleHiddenStatesConnector only works when using "
|
||||
"'extract_hidden_states' speculative method"
|
||||
)
|
||||
spec_config = self._vllm_config.speculative_config.draft_model_config.hf_config
|
||||
self.num_hidden_states = len(
|
||||
getattr(spec_config, "eagle_aux_hidden_state_layer_ids", [])
|
||||
)
|
||||
|
||||
self._request_filenames: dict[str, str] = {}
|
||||
self._active_requests: dict[str, NewRequestData] = {}
|
||||
self._req_blocks: dict[str, list[int]] = {}
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
def start_load_kv(self, *args, **kwargs: Any) -> None:
|
||||
pass # Empty implementation of abstract method
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
pass # Empty implementation of abstract method
|
||||
|
||||
def wait_for_save(self):
|
||||
pass # Empty implementation of abstract method
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
from vllm.model_executor.models.extract_hidden_states import (
|
||||
CacheOnlyAttentionLayer,
|
||||
)
|
||||
|
||||
# Filter layers to only include CacheOnlyAttentionLayers
|
||||
layers = get_layers_from_vllm_config(
|
||||
self._vllm_config, CacheOnlyAttentionLayer, list(kv_caches.keys())
|
||||
)
|
||||
self.cache_layers = list(layers.keys())
|
||||
assert len(self.cache_layers) == 1, (
|
||||
f"Expected 1 CacheOnlyAttentionLayer, got {len(self.cache_layers)}"
|
||||
)
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
||||
to the connector.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
if layer_name not in self.cache_layers:
|
||||
return
|
||||
|
||||
from vllm.model_executor.models.extract_hidden_states import (
|
||||
CacheOnlyAttentionMetadata,
|
||||
)
|
||||
|
||||
assert isinstance(attn_metadata, CacheOnlyAttentionMetadata), (
|
||||
"ExampleHiddenStatesConnector only supports CacheOnlyAttentionBackend"
|
||||
)
|
||||
|
||||
connector_metadata = self._get_connector_metadata()
|
||||
assert isinstance(connector_metadata, ExampleHiddenStatesConnectorMetadata)
|
||||
|
||||
os.makedirs(self._storage_path, exist_ok=True)
|
||||
for request in connector_metadata.requests:
|
||||
hidden_states = extract_from_kv_cache(
|
||||
kv_layer, request.slot_mapping, request.token_ids.shape[0]
|
||||
)
|
||||
tensors = {
|
||||
"hidden_states": hidden_states.detach().cpu(),
|
||||
"token_ids": request.token_ids.detach().cpu(),
|
||||
}
|
||||
safetensors.torch.save_file(tensors, request.filename)
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
# This connector is store-only, so we don't need to load any tokens
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
# Usually used to handle allocation of new blocks for requests that are loading
# tokens from the connector's external KV cache. We never load from the external
# cache, so this is a no-op.
|
||||
assert num_external_tokens == 0, "This connector is store-only"
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
"""Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify any fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
meta = ExampleHiddenStatesConnectorMetadata()
|
||||
for new_req in scheduler_output.scheduled_new_reqs:
|
||||
token_ids = new_req.prompt_token_ids or []
|
||||
filename = os.path.join(self._storage_path, f"{new_req.req_id}.safetensors")
|
||||
meta.add_request(
|
||||
new_req.req_id,
|
||||
filename=filename,
|
||||
token_ids=token_ids,
|
||||
block_ids=new_req.block_ids[0],
|
||||
block_size=self._block_size,
|
||||
)
|
||||
self._request_filenames[new_req.req_id] = filename
|
||||
self._active_requests[new_req.req_id] = new_req
|
||||
self._req_blocks[new_req.req_id] = list(new_req.block_ids[0])
|
||||
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
for i, req_id in enumerate(cached_reqs.req_ids):
|
||||
if req_id not in self._active_requests:
|
||||
continue
|
||||
|
||||
new_block_ids = cached_reqs.new_block_ids[i]
|
||||
|
||||
cached_req = self._active_requests[req_id]
|
||||
req_block_ids = self._req_blocks[req_id]
|
||||
|
||||
assert new_block_ids is not None
|
||||
block_ids = new_block_ids[0]
|
||||
|
||||
req_block_ids.extend(block_ids)
|
||||
filename = os.path.join(self._storage_path, f"{req_id}.safetensors")
|
||||
|
||||
meta.add_request(
|
||||
req_id=req_id,
|
||||
filename=filename,
|
||||
token_ids=cached_req.prompt_token_ids or [],
|
||||
block_ids=req_block_ids,
|
||||
block_size=self._block_size,
|
||||
new_req=False,
|
||||
)
|
||||
|
||||
return meta
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called exactly once when a request has finished, before its blocks are
|
||||
freed.
|
||||
|
||||
The connector may assume responsibility for freeing the blocks
|
||||
asynchronously by returning True.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
req_id = request.request_id
|
||||
req_filename = self._request_filenames.pop(req_id, None)
|
||||
_ = self._active_requests.pop(req_id, None)
|
||||
_ = self._req_blocks.pop(req_id, None)
|
||||
|
||||
return False, {"hidden_states_path": req_filename}
|
||||
|
||||
@classmethod
|
||||
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
|
||||
"""
|
||||
Get the required KV cache layout for this connector.
|
||||
Args:
|
||||
vllm_config (VllmConfig): the vllm config.
|
||||
|
||||
Returns:
|
||||
str: the required KV cache layout, e.g. HND or NHD.
|
||||
None if the connector does not require a specific layout.
|
||||
"""
|
||||
|
||||
if cls is KVConnectorBase_V1:
|
||||
raise TypeError(
|
||||
"get_required_kvcache_layout should not be called "
|
||||
"on the abstract base class"
|
||||
)
|
||||
# NHD means we have (num_tokens, num_heads, head_dim)
# HND means we have (num_heads, num_tokens, head_dim)
|
||||
# For now, we only support NHD layout since this keeps the
|
||||
# hidden states for each token together in memory.
|
||||
# HND is primarily used when sharding heads across devices.
|
||||
return "NHD"
|
||||
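For context, a minimal sketch (not part of the diff) of how a client might read back the hidden states this connector writes. It assumes the `hidden_states_path` returned by request_finished() in the kv_transfer_params points at the safetensors file produced in save_kv_layer(); the helper name is hypothetical.

# Hypothetical consumer of the files written by ExampleHiddenStatesConnector.
import safetensors.torch

def load_saved_hidden_states(hidden_states_path: str):
    tensors = safetensors.torch.load_file(hidden_states_path)
    # Keys match what save_kv_layer() stores for each request.
    return tensors["hidden_states"], tensors["token_ids"]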
@@ -70,6 +70,16 @@ class LMCacheKVEvents(KVConnectorKVEvents):
|
||||
|
||||
|
||||
class LMCacheConnectorV1(KVConnectorBase_V1):
|
||||
@classmethod
|
||||
def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
LMCache requires PIECEWISE CUDA graph mode when layerwise
|
||||
operations are enabled. The wait_for_layer_load and save_kv_layer
|
||||
methods perform actual async synchronization that cannot be
|
||||
captured in CUDA graphs.
|
||||
"""
|
||||
return extra_config.get("use_layerwise", False)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
|
||||
@@ -173,6 +173,29 @@ class MooncakeConnector(KVConnectorBase_V1):
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)
|
||||
|
||||
|
||||
############################################################
|
||||
# Class Methods
|
||||
############################################################
|
||||
@classmethod
|
||||
def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
|
||||
if vllm_config.model_config is None:
|
||||
logger.warning_once(
|
||||
"Unable to detect current VLLM config. "
|
||||
"Fallback to default kv cache layout."
|
||||
)
|
||||
return None
|
||||
use_mla = vllm_config.model_config.use_mla
|
||||
if use_mla:
|
||||
# Return None when we have MLA, as the layout should not matter
# in that case; fall back to the default behavior.
|
||||
return None
|
||||
logger.info_once(
|
||||
"MooncakeConnector setting KV cache layout to HND for better xfer performance."
|
||||
)
|
||||
return "HND"
|
||||
|
||||
############################################################
|
||||
# Scheduler Side Methods
|
||||
############################################################
|
||||
@@ -941,7 +964,13 @@ class MooncakeConnectorWorker:
|
||||
assert tensor_size_bytes == curr_tensor_size_bytes, (
|
||||
"All kv cache tensors must have the same size"
|
||||
)
|
||||
kernel_block_size = cache.shape[-2 if self.use_mla else -3]
|
||||
|
||||
cache_layout = get_kv_cache_layout()
|
||||
if cache_layout == "HND":
|
||||
kernel_block_size = cache.shape[-2]
|
||||
else:
|
||||
kernel_block_size = cache.shape[-3]
|
||||
|
||||
assert self.block_size == kernel_block_size
|
||||
kv_data_ptrs.append(base_addr)
|
||||
kv_data_lens.append(tensor_size_bytes)
|
||||
|
||||
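To illustrate the layout-dependent indexing above, a small sketch (not part of the diff) of where the block-size dimension sits for NHD versus HND caches; the leading shapes are illustrative only, not the exact vLLM cache shapes.

import torch

# Illustrative per-layer KV cache shapes for a non-MLA model.
nhd_cache = torch.empty(10, 16, 8, 64)  # (num_blocks, block_size, num_heads, head_dim)
hnd_cache = torch.empty(10, 8, 16, 64)  # (num_blocks, num_heads, block_size, head_dim)

assert nhd_cache.shape[-3] == 16  # NHD: block size is the third-to-last dim
assert hnd_cache.shape[-2] == 16  # HND: block size is the second-to-last dim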
@@ -112,6 +112,21 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
- Save to all connectors.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
MultiConnector requires PIECEWISE CUDA graph mode if any of its
|
||||
child connectors require it.
|
||||
"""
|
||||
connectors_config = extra_config.get("connectors", [])
|
||||
for conn_config in connectors_config:
|
||||
temp_ktc = KVTransferConfig(**conn_config)
|
||||
connector_cls = KVConnectorFactory.get_connector_class(temp_ktc)
|
||||
child_extra_config = conn_config.get("kv_connector_extra_config", {})
|
||||
if connector_cls.requires_piecewise_for_cudagraph(child_extra_config):
|
||||
return True
|
||||
return False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
|
||||
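As a hedged illustration of the aggregation above, the extra config is assumed to carry one entry per child connector; the connector names and fields below are placeholders, not a tested configuration.

# Hypothetical kv_connector_extra_config for MultiConnector. Each child entry
# is turned into a KVTransferConfig and its connector class is asked whether
# it needs the PIECEWISE CUDA graph mode.
extra_config = {
    "connectors": [
        {
            "kv_connector": "LMCacheConnectorV1",
            "kv_role": "kv_both",
            "kv_connector_extra_config": {"use_layerwise": True},
        },
        {
            "kv_connector": "SharedStorageConnector",
            "kv_role": "kv_both",
        },
    ]
}
# With LMCache's layerwise mode enabled, this would return True:
# MultiConnector.requires_piecewise_for_cudagraph(extra_config)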
File diff suppressed because it is too large
@@ -25,6 +25,7 @@ If you only need to use the distributed environment without model/pipeline
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import os
|
||||
import pickle
|
||||
import weakref
|
||||
from collections import namedtuple
|
||||
@@ -33,7 +34,7 @@ from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from multiprocessing import shared_memory
|
||||
from typing import Any, Protocol
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@@ -54,6 +55,10 @@ from vllm.utils.system_utils import suppress_stdout
|
||||
from vllm.utils.torch_utils import (
|
||||
direct_register_custom_op,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
|
||||
import ixformer.distributed as ixfd
|
||||
import vllm._custom_ops as ops
|
||||
|
||||
@@ -327,6 +332,8 @@ class GroupCoordinator:
|
||||
self.rank = torch.distributed.get_rank()
|
||||
self.local_rank = local_rank
|
||||
|
||||
use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM", None) not in {"1", "Y", "y"}
|
||||
|
||||
self_device_group = None
|
||||
self_cpu_group = None
|
||||
|
||||
@@ -339,7 +346,7 @@ class GroupCoordinator:
|
||||
with suppress_stdout():
|
||||
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
||||
if self.rank in ranks:
|
||||
self.ixfd_group = ixfd.init_comm_with_store(device_group)
|
||||
self.ixfd_group = ixfd.init_comm_with_store(device_group) if use_vllm_comm else None
|
||||
self.ranks = ranks
|
||||
self.world_size = len(ranks)
|
||||
self.rank_in_group = ranks.index(self.rank)
|
||||
@@ -372,8 +379,7 @@ class GroupCoordinator:
|
||||
self.device_communicator = device_comm_cls(
|
||||
cpu_group=self.cpu_group,
|
||||
device=self.device,
|
||||
# device_group=self.device_group,
|
||||
device_group=self.ixfd_group if envs.VLLM_FORCE_NCCL_COMM else self.device_group,
|
||||
device_group=self.ixfd_group if use_vllm_comm else self.device_group,
|
||||
unique_name=self.unique_name,
|
||||
)
|
||||
|
||||
@@ -385,11 +391,6 @@ class GroupCoordinator:
|
||||
self.cpu_group, 1 << 22, 6
|
||||
)
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# self.use_custom_op_call = (
|
||||
# current_platform.is_cuda_alike() or current_platform.is_tpu()
|
||||
# )
|
||||
self.use_custom_op_call = False
|
||||
|
||||
self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
|
||||
@@ -468,14 +469,12 @@ class GroupCoordinator:
|
||||
# only cuda uses this function,
|
||||
# so we don't abstract it into the base class
|
||||
maybe_ca_context = nullcontext()
|
||||
# from vllm.distributed.device_communicators.cuda_communicator import (
|
||||
# CudaCommunicator,
|
||||
# )
|
||||
from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase
|
||||
from vllm.distributed.device_communicators.cuda_communicator import (
|
||||
CudaCommunicator,
|
||||
)
|
||||
|
||||
if self.device_communicator is not None:
|
||||
# assert isinstance(self.device_communicator, CudaCommunicator)
|
||||
assert isinstance(self.device_communicator, DeviceCommunicatorBase)
|
||||
assert isinstance(self.device_communicator, CudaCommunicator)
|
||||
ca_comm = self.device_communicator.ca_comm
|
||||
if ca_comm is not None:
|
||||
maybe_ca_context = ca_comm.capture() # type: ignore
|
||||
@@ -608,9 +607,9 @@ class GroupCoordinator:
|
||||
src=self.ranks[src],
|
||||
group=self.device_group)
|
||||
else:
|
||||
torch.distributed.broadcast(input_,
|
||||
src=self.ranks[src],
|
||||
group=self.device_group)
|
||||
torch.distributed.broadcast(
|
||||
input_, src=self.ranks[src], group=self.device_group
|
||||
)
|
||||
return input_
|
||||
|
||||
def broadcast_object(self, obj: Any | None = None, src: int = 0):
|
||||
@@ -764,10 +763,9 @@ class GroupCoordinator:
|
||||
group=group,
|
||||
async_op=True)
|
||||
else:
|
||||
handle = torch.distributed.broadcast(tensor,
|
||||
src=self.ranks[src],
|
||||
group=group,
|
||||
async_op=True)
|
||||
handle = torch.distributed.broadcast(
|
||||
tensor, src=self.ranks[src], group=group, async_op=True
|
||||
)
|
||||
async_handles.append(handle)
|
||||
for async_handle in async_handles:
|
||||
async_handle.wait()
|
||||
@@ -802,10 +800,8 @@ class GroupCoordinator:
|
||||
async_op=True)
|
||||
else:
|
||||
handle = torch.distributed.broadcast(
|
||||
tensor,
|
||||
src=self.ranks[src],
|
||||
group=group,
|
||||
async_op=True)
|
||||
tensor, src=self.ranks[src], group=group, async_op=True
|
||||
)
|
||||
async_handles.append(handle)
|
||||
tensor_dict[key] = tensor
|
||||
else:
|
||||
@@ -876,6 +872,10 @@ class GroupCoordinator:
|
||||
if self.world_size <= 1:
|
||||
return []
|
||||
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
assert dst < self.world_size, f"Invalid dst rank ({dst})"
|
||||
|
||||
if self.use_cpu_custom_send_recv:
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
@@ -893,10 +893,6 @@ class GroupCoordinator:
|
||||
group = self.device_group
|
||||
metadata_group = self.cpu_group
|
||||
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
assert dst < self.world_size, f"Invalid dst rank ({dst})"
|
||||
|
||||
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
|
||||
self.send_object(metadata_list, dst=dst)
|
||||
|
||||
@@ -917,6 +913,7 @@ class GroupCoordinator:
|
||||
handle = torch.distributed.isend(
|
||||
tensor, dst=self.ranks[dst], group=comm_group
|
||||
)
|
||||
|
||||
if tensor.is_cuda:
|
||||
tensor.record_stream(torch.cuda.current_stream(tensor.device))
|
||||
handles.append(handle)
|
||||
@@ -973,6 +970,11 @@ class GroupCoordinator:
|
||||
]:
|
||||
if not torch.distributed.is_initialized() or self.world_size == 1:
|
||||
return None, [], []
|
||||
|
||||
if src is None:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||
|
||||
if self.use_cpu_custom_send_recv:
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
@@ -990,10 +992,6 @@ class GroupCoordinator:
|
||||
group = self.device_group
|
||||
metadata_group = self.cpu_group
|
||||
|
||||
if src is None:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||
|
||||
recv_metadata_list = self.recv_object(src=src)
|
||||
tensor_dict: dict[str, Any] = {}
|
||||
handles: list[Handle] = []
|
||||
@@ -1072,14 +1070,13 @@ class GroupCoordinator:
|
||||
return self.device_communicator.recv(size, dtype, src)
|
||||
|
||||
def destroy(self):
|
||||
if hasattr(self, "device_group"):
|
||||
# torch.distributed.destroy_process_group(self.device_group)
|
||||
if self.device_group is not None:
|
||||
if self.device_communicator and self.device_communicator.use_vllm_comm:
|
||||
ixfd.destroy_process_group(self.device_group)
|
||||
else:
|
||||
torch.distributed.destroy_process_group(self.device_group)
|
||||
del self.device_group
|
||||
if hasattr(self, "cpu_group"):
|
||||
self.device_group = None
|
||||
if self.cpu_group is not None:
|
||||
torch.distributed.destroy_process_group(self.cpu_group)
|
||||
del self.cpu_group
|
||||
if self.device_communicator is not None:
|
||||
@@ -1094,7 +1091,6 @@ class GroupCoordinator:
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
extra_residual: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
@@ -1105,13 +1101,12 @@ class GroupCoordinator:
|
||||
if self.device_communicator is not None:
|
||||
return self.device_communicator.dispatch_router_logits(
|
||||
hidden_states,
|
||||
extra_residual,
|
||||
router_logits,
|
||||
is_sequence_parallel,
|
||||
extra_tensors,
|
||||
)
|
||||
else:
|
||||
return hidden_states, extra_residual, router_logits
|
||||
return hidden_states, router_logits
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
@@ -1189,6 +1184,55 @@ def init_model_parallel_group(
|
||||
)
|
||||
|
||||
|
||||
def _init_stateless_group(
|
||||
group_ranks: list[list[int]],
|
||||
group_name: str,
|
||||
group_ports: list[list[int]],
|
||||
host: str,
|
||||
backend: str,
|
||||
use_device_communicator: bool = True,
|
||||
) -> "StatelessGroupCoordinator":
|
||||
"""Create a StatelessGroupCoordinator with the given parameters."""
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
|
||||
world = get_world_group()
|
||||
return StatelessGroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=world.local_rank,
|
||||
torch_distributed_backend=backend,
|
||||
use_device_communicator=use_device_communicator,
|
||||
group_name=group_name,
|
||||
host=host,
|
||||
group_ports=group_ports,
|
||||
global_rank=world.rank,
|
||||
global_world_size=world.world_size,
|
||||
)
|
||||
|
||||
|
||||
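A hedged sketch of the port bookkeeping this helper assumes: each group gets a triple of ports (device group, gloo CPU group, TCPStore) which StatelessGroupCoordinator later unpacks as ports[0], ports[1], ports[2]. The numbers are placeholders.

# Hypothetical group_ports for two DP groups of two ranks each.
group_ranks = [[0, 1], [2, 3]]
group_ports = [
    [29600, 29601, 29602],  # (device_port, cpu_port, tcp_store_port) for group [0, 1]
    [29610, 29611, 29612],  # (device_port, cpu_port, tcp_store_port) for group [2, 3]
]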
def _replace_active_groups(
|
||||
*,
|
||||
world: GroupCoordinator | None,
|
||||
dp: GroupCoordinator | None,
|
||||
ep: GroupCoordinator | None,
|
||||
eplb: GroupCoordinator | None,
|
||||
node_count: int | None,
|
||||
) -> None:
|
||||
"""Destroy the current DP/EP/WORLD/EPLB groups and replace them.
|
||||
|
||||
Destruction is collective — all ranks in the old groups must call this
|
||||
function together. Pass all-``None`` to tear down without replacement.
|
||||
"""
|
||||
global _WORLD, _DP, _EP, _EPLB, _NODE_COUNT
|
||||
for group in (_DP, _EP, _WORLD, _EPLB):
|
||||
if group is not None:
|
||||
group.destroy()
|
||||
_WORLD = world
|
||||
_DP = dp
|
||||
_EP = ep
|
||||
_EPLB = eplb
|
||||
_NODE_COUNT = node_count
|
||||
|
||||
|
||||
_TP: GroupCoordinator | None = None
|
||||
|
||||
|
||||
@@ -1286,6 +1330,39 @@ def set_custom_all_reduce(enable: bool):
|
||||
_ENABLE_CUSTOM_ALL_REDUCE = enable
|
||||
|
||||
|
||||
def _init_elastic_ep_world(
|
||||
config, local_rank: int, backend: str, rank: int, world_size: int
|
||||
) -> None:
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
|
||||
global _WORLD, _NODE_COUNT
|
||||
assert _WORLD is None, "world group already initialized"
|
||||
parallel_config = config.parallel_config
|
||||
global_rank = parallel_config.data_parallel_rank * world_size + rank
|
||||
global_world_size = parallel_config.world_size_across_dp
|
||||
all_ranks = list(range(global_world_size))
|
||||
group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)]
|
||||
if global_rank in all_ranks:
|
||||
group_ranks = [all_ranks]
|
||||
group_ports = [parallel_config.get_next_stateless_world_group_port()]
|
||||
world = StatelessGroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=local_rank,
|
||||
torch_distributed_backend=backend,
|
||||
use_device_communicator=False,
|
||||
group_name="world",
|
||||
host=parallel_config.data_parallel_master_ip,
|
||||
group_ports=group_ports,
|
||||
global_rank=global_rank,
|
||||
global_world_size=global_world_size,
|
||||
)
|
||||
assert parallel_config.nnodes_within_dp == 1, (
|
||||
"Elastic EP is not supported with multi-node TP/PP"
|
||||
)
|
||||
_NODE_COUNT = _node_count(world.tcp_store_group)
|
||||
_WORLD = world
|
||||
|
||||
|
||||
def init_distributed_environment(
|
||||
world_size: int = -1,
|
||||
rank: int = -1,
|
||||
@@ -1305,6 +1382,7 @@ def init_distributed_environment(
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
|
||||
config = get_current_vllm_config_or_none()
|
||||
enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep
|
||||
if (
|
||||
config is not None
|
||||
and config.parallel_config.distributed_executor_backend != "external_launcher"
|
||||
@@ -1312,6 +1390,7 @@ def init_distributed_environment(
|
||||
config.parallel_config.nnodes > 1
|
||||
or config.parallel_config.data_parallel_size > 1
|
||||
)
|
||||
and not enable_elastic_ep
|
||||
):
|
||||
parallel_config = config.parallel_config
|
||||
# adjust to take into account data parallelism
|
||||
@@ -1365,6 +1444,18 @@ def init_distributed_environment(
|
||||
rank=rank,
|
||||
timeout=timeout,
|
||||
)
|
||||
if enable_elastic_ep:
|
||||
tp_pp_cpu_group = torch.distributed.new_group(
|
||||
backend="gloo", timeout=timeout
|
||||
)
|
||||
if _node_count(tp_pp_cpu_group) > 1:
|
||||
# NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip
|
||||
# to initialize all DP/EP groups, hence all ranks within TP/PP group
|
||||
# must reside on the same node
|
||||
raise RuntimeError(
|
||||
"Elastic EP is not yet supported with multi-node TP/PP"
|
||||
)
|
||||
|
||||
# set the local rank
|
||||
# local_rank is not available in torch ProcessGroup,
|
||||
# see https://github.com/pytorch/pytorch/issues/122816
|
||||
@@ -1373,6 +1464,9 @@ def init_distributed_environment(
|
||||
# setting, where we can use rank as local rank
|
||||
local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
|
||||
global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
|
||||
if enable_elastic_ep:
|
||||
_init_elastic_ep_world(config, local_rank, backend, rank, world_size)
|
||||
return
|
||||
if _WORLD is None:
|
||||
ranks = list(range(torch.distributed.get_world_size()))
|
||||
_WORLD = init_world_group(ranks, local_rank, backend)
|
||||
@@ -1436,16 +1530,33 @@ def initialize_model_parallel(
|
||||
"""
|
||||
# Get world size and rank. Ensure some consistencies.
|
||||
assert torch.distributed.is_initialized()
|
||||
world_size: int = torch.distributed.get_world_size()
|
||||
rank = torch.distributed.get_rank()
|
||||
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
|
||||
|
||||
data_parallel_size = 1
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
config = get_current_vllm_config_or_none()
|
||||
if config is not None:
|
||||
data_parallel_size = config.parallel_config.data_parallel_size
|
||||
config = get_current_vllm_config()
|
||||
data_parallel_size = config.parallel_config.data_parallel_size
|
||||
enable_elastic_ep = config.parallel_config.enable_elastic_ep
|
||||
if enable_elastic_ep:
|
||||
# Use stateless world group for global information
|
||||
world_size = get_world_group().world_size
|
||||
rank = get_world_group().rank
|
||||
backend = backend or "nccl"
|
||||
tp_pp_pcp_size = (
|
||||
tensor_model_parallel_size
|
||||
* pipeline_model_parallel_size
|
||||
* prefill_context_model_parallel_size
|
||||
)
|
||||
local_all_ranks = torch.arange(tp_pp_pcp_size).reshape(
|
||||
pipeline_model_parallel_size,
|
||||
prefill_context_model_parallel_size,
|
||||
tensor_model_parallel_size,
|
||||
)
|
||||
else:
|
||||
world_size = torch.distributed.get_world_size()
|
||||
rank = torch.distributed.get_rank()
|
||||
backend = backend or torch.distributed.get_backend(
|
||||
get_world_group().device_group
|
||||
)
|
||||
|
||||
# the layout order is: ExternalDP x DP x PP x TP
|
||||
# ExternalDP is the data parallel group that is not part of the model,
|
||||
@@ -1469,7 +1580,9 @@ def initialize_model_parallel(
|
||||
assert _TP is None, "tensor model parallel group is already initialized"
|
||||
group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
|
||||
if enable_elastic_ep:
|
||||
group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
# message queue broadcaster is only used in tensor model parallel group
|
||||
_TP = init_model_parallel_group(
|
||||
group_ranks,
|
||||
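A standalone sketch (not part of the diff) of the rank-grid trick used throughout this function: reshape a flat rank tensor and unbind along one axis to enumerate the groups. The sizes below are toy values.

import torch

# 8 ranks, DP=2, PP=2, TP=2 (the real layout order is ExternalDP x DP x PP x TP).
all_ranks = torch.arange(8).reshape(2, 2, 2)  # (DP, PP, TP)
tp_groups = [x.tolist() for x in all_ranks.view(-1, 2).unbind(0)]
# tp_groups == [[0, 1], [2, 3], [4, 5], [6, 7]]
pp_groups = [x.tolist() for x in all_ranks.transpose(1, 2).reshape(-1, 2).unbind(0)]
# pp_groups == [[0, 2], [1, 3], [4, 6], [5, 7]]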
@@ -1488,6 +1601,11 @@ def initialize_model_parallel(
|
||||
# TP group into tp_size//dcp_size DCP groups.
|
||||
group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
if enable_elastic_ep:
|
||||
group_ranks = local_all_ranks.reshape(
|
||||
-1, decode_context_model_parallel_size
|
||||
).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_DCP = init_model_parallel_group(
|
||||
group_ranks,
|
||||
get_world_group().local_rank,
|
||||
@@ -1504,6 +1622,13 @@ def initialize_model_parallel(
|
||||
.unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
if enable_elastic_ep:
|
||||
group_ranks = (
|
||||
local_all_ranks.transpose(1, 2)
|
||||
.reshape(-1, prefill_context_model_parallel_size)
|
||||
.unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_PCP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="pcp"
|
||||
)
|
||||
@@ -1515,6 +1640,13 @@ def initialize_model_parallel(
|
||||
all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
if enable_elastic_ep:
|
||||
group_ranks = (
|
||||
local_all_ranks.transpose(0, 2)
|
||||
.reshape(-1, pipeline_model_parallel_size)
|
||||
.unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_PP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="pp"
|
||||
)
|
||||
@@ -1523,14 +1655,27 @@ def initialize_model_parallel(
|
||||
assert _DP is None, "data parallel group is already initialized"
|
||||
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_DP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="dp"
|
||||
)
|
||||
if enable_elastic_ep:
|
||||
parallel_config = config.parallel_config
|
||||
dp_ports = [
|
||||
parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks
|
||||
]
|
||||
_DP = _init_stateless_group(
|
||||
group_ranks,
|
||||
"dp",
|
||||
dp_ports,
|
||||
parallel_config.data_parallel_master_ip,
|
||||
backend,
|
||||
)
|
||||
else:
|
||||
_DP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="dp"
|
||||
)
|
||||
|
||||
global _EP
|
||||
assert _EP is None, "expert parallel group is already initialized"
|
||||
# Don't create EP group for dense models.
|
||||
if config is None or config.model_config is None or config.model_config.is_moe:
|
||||
if config.model_config is None or config.model_config.is_moe:
|
||||
group_ranks = (
|
||||
all_ranks.transpose(1, 2)
|
||||
.reshape(
|
||||
@@ -1542,9 +1687,22 @@ def initialize_model_parallel(
|
||||
.unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_EP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="ep"
|
||||
)
|
||||
if enable_elastic_ep:
|
||||
parallel_config = config.parallel_config
|
||||
ep_ports = [
|
||||
parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks
|
||||
]
|
||||
_EP = _init_stateless_group(
|
||||
group_ranks,
|
||||
"ep",
|
||||
ep_ports,
|
||||
parallel_config.data_parallel_master_ip,
|
||||
backend,
|
||||
)
|
||||
else:
|
||||
_EP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="ep"
|
||||
)
|
||||
|
||||
# Create EPLB group with the same ranks as EP if EPLB is enabled.
|
||||
# This is a separate process group to isolate EPLB communications
|
||||
@@ -1557,10 +1715,25 @@ def initialize_model_parallel(
|
||||
and config.parallel_config is not None
|
||||
and config.parallel_config.enable_eplb
|
||||
):
|
||||
# Reuse the same group_ranks from EP
|
||||
_EPLB = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="eplb"
|
||||
)
|
||||
if enable_elastic_ep:
|
||||
eplb_ports = [
|
||||
parallel_config.get_next_stateless_eplb_group_port()
|
||||
for _ in group_ranks
|
||||
]
|
||||
_EPLB = _init_stateless_group(
|
||||
group_ranks,
|
||||
"eplb",
|
||||
eplb_ports,
|
||||
parallel_config.data_parallel_master_ip,
|
||||
backend,
|
||||
)
|
||||
else:
|
||||
_EPLB = init_model_parallel_group(
|
||||
group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="eplb",
|
||||
)
|
||||
# If no EP group needed, _EP remains None
|
||||
# If no EPLB group needed, _EPLB remains None
|
||||
|
||||
@@ -1590,7 +1763,11 @@ def ensure_model_parallel_initialized(
|
||||
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
|
||||
values if the model parallel groups are initialized.
|
||||
"""
|
||||
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
|
||||
world_group = get_world_group()
|
||||
if hasattr(world_group, "backend"):
|
||||
backend = backend or world_group.backend
|
||||
else:
|
||||
backend = backend or torch.distributed.get_backend(world_group.device_group)
|
||||
if not model_parallel_is_initialized():
|
||||
initialize_model_parallel(
|
||||
tensor_model_parallel_size,
|
||||
|
||||
322
vllm/distributed/stateless_coordinator.py
Normal file
@@ -0,0 +1,322 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import Backend, ProcessGroup
|
||||
|
||||
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
|
||||
from vllm.distributed.parallel_state import (
|
||||
GroupCoordinator,
|
||||
TensorMetadata,
|
||||
_get_unique_name,
|
||||
_register_group,
|
||||
_split_tensor_dict,
|
||||
)
|
||||
from vllm.distributed.utils import (
|
||||
StatelessProcessGroup,
|
||||
stateless_destroy_torch_distributed_process_group,
|
||||
stateless_init_torch_distributed_process_group,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class StatelessGroupCoordinator(GroupCoordinator):
|
||||
"""
|
||||
A stateless version of the GroupCoordinator class in parallel_state.
It creates CPU, device, and TCPStore-based communication groups
that are independent of PyTorch's WORLD group. Hence,
communication groups with a different set of participant GPUs
can be created without destroying the existing ones.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group_ranks: list[list[int]],
|
||||
local_rank: int,
|
||||
torch_distributed_backend: str | Backend,
|
||||
use_device_communicator: bool,
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
group_name: str | None = None,
|
||||
host: str = "127.0.0.1",
|
||||
group_ports: list[list[int]] | None = None,
|
||||
global_rank: int = 0,
|
||||
global_world_size: int = 1,
|
||||
):
|
||||
group_name = group_name or "anonymous"
|
||||
self.unique_name = _get_unique_name(group_name)
|
||||
_register_group(self)
|
||||
|
||||
self.rank = global_rank
|
||||
self.local_rank = local_rank
|
||||
|
||||
self_device_group = None
|
||||
self_cpu_group = None
|
||||
self_tcp_store_group = None
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
backend = str(torch_distributed_backend)
|
||||
self.backend = backend
|
||||
assert group_ports is not None, "group_ports is not provided"
|
||||
for idx, ranks in enumerate(group_ranks):
|
||||
if self.rank in ranks:
|
||||
self.ranks = ranks
|
||||
self.world_size = len(ranks)
|
||||
self.rank_in_group = ranks.index(self.rank)
|
||||
|
||||
ports = group_ports[idx]
|
||||
device_port = ports[0]
|
||||
cpu_port = ports[1]
|
||||
tcp_store_port = ports[2]
|
||||
|
||||
device_group = stateless_init_torch_distributed_process_group(
|
||||
host=host,
|
||||
port=device_port,
|
||||
rank=self.rank_in_group,
|
||||
world_size=self.world_size,
|
||||
backend=backend,
|
||||
group_name=f"{self.unique_name}_device",
|
||||
)
|
||||
cpu_group = stateless_init_torch_distributed_process_group(
|
||||
host=host,
|
||||
port=cpu_port,
|
||||
rank=self.rank_in_group,
|
||||
world_size=self.world_size,
|
||||
backend="gloo",
|
||||
group_name=f"{self.unique_name}_cpu",
|
||||
)
|
||||
tcp_store_group = StatelessProcessGroup.create(
|
||||
host=host,
|
||||
port=tcp_store_port,
|
||||
rank=self.rank_in_group,
|
||||
world_size=self.world_size,
|
||||
)
|
||||
|
||||
self_device_group = device_group
|
||||
self_cpu_group = cpu_group
|
||||
self_tcp_store_group = tcp_store_group
|
||||
|
||||
assert self_cpu_group is not None
|
||||
assert self_device_group is not None
|
||||
assert self_tcp_store_group is not None
|
||||
|
||||
self.cpu_group = self_cpu_group
|
||||
self.device_group = self_device_group
|
||||
self.tcp_store_group = self_tcp_store_group
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
self.device = torch.device(f"cuda:{local_rank}")
|
||||
elif current_platform.is_xpu():
|
||||
self.device = torch.device(f"xpu:{local_rank}")
|
||||
elif current_platform.is_out_of_tree():
|
||||
self.device = torch.device(f"{current_platform.device_name}:{local_rank}")
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
self.use_device_communicator = use_device_communicator
|
||||
self.device_communicator = None
|
||||
if use_device_communicator and self.world_size > 1:
|
||||
device_comm_cls = resolve_obj_by_qualname(
|
||||
current_platform.get_device_communicator_cls()
|
||||
)
|
||||
assert device_comm_cls == CudaCommunicator
|
||||
self.device_communicator = CudaCommunicator(
|
||||
cpu_group=self.cpu_group,
|
||||
device=self.device,
|
||||
device_group=self.device_group,
|
||||
unique_name=self.unique_name,
|
||||
global_ranks=self.ranks,
|
||||
global_world_size=global_world_size,
|
||||
tcp_store_group=self.tcp_store_group,
|
||||
)
|
||||
|
||||
self.mq_broadcaster = None
|
||||
|
||||
self.use_custom_op_call = (
|
||||
current_platform.is_cuda_alike() or current_platform.is_tpu()
|
||||
)
|
||||
self.use_cpu_custom_send_recv = False
|
||||
|
||||
def destroy(self):
|
||||
if self.device_communicator:
|
||||
self.device_communicator.destroy()
|
||||
if self.device_group:
|
||||
stateless_destroy_torch_distributed_process_group(self.device_group)
|
||||
if self.cpu_group:
|
||||
stateless_destroy_torch_distributed_process_group(self.cpu_group)
|
||||
|
||||
def size(self) -> int:
|
||||
"""Return the world size of this group."""
|
||||
return self.world_size
|
||||
|
||||
def broadcast(self, input_: torch.Tensor, src: int = 0):
|
||||
if self.world_size == 1:
|
||||
return input_
|
||||
|
||||
if self.device_communicator and input_.is_cuda:
|
||||
return self.device_communicator.broadcast(input_, src)
|
||||
else:
|
||||
return self.tcp_store_group.broadcast(input_, src)
|
||||
|
||||
def broadcast_object(self, obj=None, src: int = 0):
|
||||
if self.world_size == 1:
|
||||
return obj
|
||||
return self.tcp_store_group.broadcast_obj(obj, src)
|
||||
|
||||
def broadcast_object_list(
|
||||
self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None
|
||||
):
|
||||
assert src < self.world_size
|
||||
|
||||
if self.world_size == 1:
|
||||
return obj_list
|
||||
|
||||
if self.rank_in_group == src:
|
||||
for obj in obj_list:
|
||||
self.tcp_store_group.broadcast_obj(obj, src)
|
||||
else:
|
||||
for i in range(len(obj_list)):
|
||||
obj_list[i] = self.tcp_store_group.broadcast_obj(None, src)
|
||||
|
||||
return obj_list
|
||||
|
||||
def broadcast_tensor_dict(
|
||||
self,
|
||||
tensor_dict: dict[str, torch.Tensor | Any] | None = None,
|
||||
src: int = 0,
|
||||
group: ProcessGroup | None = None,
|
||||
metadata_group: ProcessGroup | None = None,
|
||||
) -> dict[str, torch.Tensor | Any] | None:
|
||||
if self.world_size == 1:
|
||||
return tensor_dict
|
||||
|
||||
if self.rank_in_group == src:
|
||||
assert isinstance(tensor_dict, dict), (
|
||||
f"Expecting a dictionary, got {type(tensor_dict)}"
|
||||
)
|
||||
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
|
||||
else:
|
||||
metadata_list = None
|
||||
tensor_list = []
|
||||
|
||||
recv_metadata_list: list[tuple[str, Any]] = self.tcp_store_group.broadcast_obj(
|
||||
metadata_list, src
|
||||
)
|
||||
|
||||
if self.rank_in_group != src:
|
||||
tensor_dict = {}
|
||||
for key, value in recv_metadata_list:
|
||||
if isinstance(value, TensorMetadata):
|
||||
tensor = torch.empty(
|
||||
value.size, dtype=value.dtype, device=value.device
|
||||
)
|
||||
tensor_list.append(tensor)
|
||||
tensor_dict[key] = tensor
|
||||
else:
|
||||
tensor_dict[key] = value
|
||||
|
||||
for tensor in tensor_list:
|
||||
if tensor.numel() == 0:
|
||||
continue
|
||||
if self.device_communicator and tensor.is_cuda:
|
||||
tensor.copy_(self.device_communicator.broadcast(tensor, src))
|
||||
else:
|
||||
tensor.copy_(self.tcp_store_group.broadcast(tensor, src))
|
||||
|
||||
return tensor_dict
|
||||
|
||||
def send_object(self, obj, dst: int) -> None:
|
||||
assert dst < self.world_size
|
||||
assert dst != self.rank_in_group
|
||||
self.tcp_store_group.send_obj(obj, dst)
|
||||
|
||||
def recv_object(self, src: int):
|
||||
assert src < self.world_size
|
||||
assert src != self.rank_in_group
|
||||
return self.tcp_store_group.recv_obj(src)
|
||||
|
||||
def send_tensor_dict(
|
||||
self,
|
||||
tensor_dict: dict[str, torch.Tensor | Any],
|
||||
dst: int | None = None,
|
||||
all_gather_group: Optional["GroupCoordinator"] = None,
|
||||
all_gather_tensors: dict[str, bool] | None = None,
|
||||
) -> dict[str, torch.Tensor | Any] | None:
|
||||
if self.world_size == 1:
|
||||
return tensor_dict
|
||||
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
assert dst < self.world_size
|
||||
|
||||
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
|
||||
self.tcp_store_group.send_obj(metadata_list, dst)
|
||||
|
||||
for tensor in tensor_list:
|
||||
if tensor.numel() == 0:
|
||||
continue
|
||||
if self.device_communicator and tensor.is_cuda:
|
||||
self.device_communicator.send(tensor, dst)
|
||||
else:
|
||||
self.tcp_store_group.send(tensor, dst)
|
||||
|
||||
return None
|
||||
|
||||
def recv_tensor_dict(
|
||||
self,
|
||||
src: int | None = None,
|
||||
all_gather_group: Optional["GroupCoordinator"] = None,
|
||||
all_gather_tensors: dict[str, bool] | None = None,
|
||||
) -> dict[str, torch.Tensor | Any] | None:
|
||||
if self.world_size == 1:
|
||||
return None
|
||||
|
||||
if src is None:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
assert src < self.world_size
|
||||
|
||||
recv_metadata_list = self.tcp_store_group.recv_obj(src)
|
||||
tensor_dict = {}
|
||||
for key, value in recv_metadata_list:
|
||||
if isinstance(value, TensorMetadata):
|
||||
tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
|
||||
if tensor.numel() > 0:
|
||||
if self.device_communicator and tensor.is_cuda:
|
||||
tensor = self.device_communicator.recv(
|
||||
tensor.size(), tensor.dtype, src
|
||||
)
|
||||
else:
|
||||
tensor = self.tcp_store_group.recv(tensor, src)
|
||||
tensor_dict[key] = tensor
|
||||
else:
|
||||
tensor_dict[key] = value
|
||||
return tensor_dict
|
||||
|
||||
def barrier(self):
|
||||
self.tcp_store_group.barrier()
|
||||
|
||||
def gather(
|
||||
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> torch.Tensor | None:
|
||||
if self.world_size == 1:
|
||||
return input_
|
||||
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
|
||||
if self.rank_in_group == dst:
|
||||
gathered_list = [torch.empty_like(input_) for _ in range(self.world_size)]
|
||||
gathered_list[self.rank_in_group] = input_
|
||||
for src_rank in range(self.world_size):
|
||||
if src_rank != self.rank_in_group:
|
||||
gathered_list[src_rank] = self.device_communicator.recv(
|
||||
input_.size(), input_.dtype, src_rank
|
||||
)
|
||||
return torch.cat(gathered_list, dim=dim)
|
||||
else:
|
||||
self.device_communicator.send(input_, dst)
|
||||
return None
|
||||
@@ -18,7 +18,7 @@ from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup, TCPStore
|
||||
from torch.distributed import ProcessGroup, Store, TCPStore
|
||||
from torch.distributed.distributed_c10d import (
|
||||
Backend,
|
||||
PrefixStore,
|
||||
@@ -228,6 +228,55 @@ class StatelessProcessGroup:
|
||||
gathered_objs.append(recv_obj)
|
||||
return gathered_objs
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
|
||||
"""Broadcast a tensor from source rank to all other ranks."""
|
||||
if self.rank == src:
|
||||
tensor_bytes = pickle.dumps(tensor)
|
||||
self.expire_data()
|
||||
key = f"broadcast_tensor/{src}/{self.broadcast_send_counter}"
|
||||
self.store.set(key, tensor_bytes)
|
||||
self.broadcast_send_counter += 1
|
||||
self.entries.append((key, time.time()))
|
||||
return tensor
|
||||
else:
|
||||
key = f"broadcast_tensor/{src}/{self.broadcast_recv_src_counter[src]}"
|
||||
tensor = pickle.loads(self.store.get(key))
|
||||
self.broadcast_recv_src_counter[src] += 1
|
||||
return tensor
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: int):
|
||||
"""Send a tensor to a destination rank."""
|
||||
self.expire_data()
|
||||
key = f"send_tensor/{dst}/{self.send_dst_counter[dst]}"
|
||||
self.store.set(key, pickle.dumps(tensor))
|
||||
self.send_dst_counter[dst] += 1
|
||||
self.entries.append((key, time.time()))
|
||||
|
||||
def recv(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
|
||||
"""Receive a tensor from a source rank."""
|
||||
key = f"send_tensor/{self.rank}/{self.recv_src_counter[src]}"
|
||||
received = pickle.loads(self.store.get(key))
|
||||
self.recv_src_counter[src] += 1
|
||||
tensor.copy_(received)
|
||||
return tensor
|
||||
|
||||
def all_reduce(
|
||||
self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM
|
||||
) -> torch.Tensor:
|
||||
"""All-reduce a tensor across all ranks."""
|
||||
tensors = self.all_gather_obj(tensor)
|
||||
result = tensors[0].clone()
|
||||
for t in tensors[1:]:
|
||||
if op == torch.distributed.ReduceOp.SUM:
|
||||
result.add_(t)
|
||||
elif op == torch.distributed.ReduceOp.PRODUCT:
|
||||
result.mul_(t)
|
||||
elif op == torch.distributed.ReduceOp.MAX:
|
||||
result = torch.maximum(result, t)
|
||||
elif op == torch.distributed.ReduceOp.MIN:
|
||||
result = torch.minimum(result, t)
|
||||
return result
|
||||
|
||||
def barrier(self, timeout: float = 30.0):
|
||||
"""A robust barrier to synchronize all ranks.
|
||||
|
||||
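A minimal two-process sketch of the TCPStore-backed tensor path added above. It assumes StatelessProcessGroup.create takes the same host/port/rank/world_size arguments used elsewhere in this diff and that both ranks run the same function; host and port are placeholders.

import torch
from vllm.distributed.utils import StatelessProcessGroup

def demo(rank: int):
    pg = StatelessProcessGroup.create(
        host="127.0.0.1", port=29500, rank=rank, world_size=2
    )
    t = torch.arange(4) if rank == 0 else torch.zeros(4, dtype=torch.long)
    t = pg.broadcast(t, src=0)  # non-source ranks read the pickled tensor from the store
    total = pg.all_reduce(t)    # emulated with all_gather_obj + local reduction
    pg.barrier()
    return total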
@@ -448,8 +497,14 @@ def init_gloo_process_group(
|
||||
|
||||
|
||||
def stateless_init_torch_distributed_process_group(
|
||||
host: str, port: int, rank: int, world_size: int, backend: str
|
||||
) -> ProcessGroup:
|
||||
host: str,
|
||||
port: int,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
backend: str,
|
||||
group_name: str | None = None,
|
||||
return_store: bool = False,
|
||||
) -> ProcessGroup | tuple[ProcessGroup, Store]:
|
||||
"""
|
||||
A replacement for `torch.distributed.init_process_group` that does not
|
||||
pollute the global state. The created ProcessGroup object can be used for
|
||||
@@ -496,25 +551,35 @@ def stateless_init_torch_distributed_process_group(
|
||||
# Use a PrefixStore to avoid accidental overrides of keys used by
|
||||
# different systems (e.g. RPC) in case the store is multi-tenant.
|
||||
prefix_store = PrefixStore(init_method, store)
|
||||
try:
|
||||
|
||||
if backend == "gloo":
|
||||
pg = init_gloo_process_group(
|
||||
prefix_store=prefix_store,
|
||||
group_rank=group_rank,
|
||||
group_size=group_size,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
return current_platform.stateless_init_device_torch_dist_pg(
|
||||
pg = current_platform.stateless_init_device_torch_dist_pg(
|
||||
backend=backend,
|
||||
prefix_store=prefix_store,
|
||||
group_rank=group_rank,
|
||||
group_size=group_size,
|
||||
timeout=timeout,
|
||||
)
|
||||
except NotImplementedError:
|
||||
# If platform doesn't implement stateless_init_device_torch_dist_pg, it
|
||||
# will raise a NotImplementedError. In this case, we fall back to gloo.
|
||||
return init_gloo_process_group(
|
||||
prefix_store=prefix_store,
|
||||
group_rank=group_rank,
|
||||
group_size=group_size,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if group_name is not None:
|
||||
from torch._C._distributed_c10d import _register_process_group
|
||||
|
||||
pg._set_group_name(group_name)
|
||||
_register_process_group(group_name, pg)
|
||||
|
||||
if return_store:
|
||||
return pg, store
|
||||
else:
|
||||
return pg
|
||||
|
||||
|
||||
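A hedged usage sketch of the extended signature above; the host/port values are placeholders, and the returned store is only available when return_store=True.

# Create a named, stateless gloo group and keep the TCPStore for later reuse.
pg, store = stateless_init_torch_distributed_process_group(
    host="127.0.0.1",
    port=29700,
    rank=0,
    world_size=1,
    backend="gloo",
    group_name="demo_cpu",
    return_store=True,
)
# ... use pg for collectives ...
stateless_destroy_torch_distributed_process_group(pg)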
def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"""Base class for weight transfer engines."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import KW_ONLY, dataclass, field
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
@@ -156,3 +156,30 @@ class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
|
||||
This should be called when the worker is shutting down.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def trainer_send_weights(
|
||||
iterator: Iterator[tuple[str, torch.Tensor]],
|
||||
trainer_args: dict[str, Any] | Any,
|
||||
) -> None:
|
||||
"""
|
||||
Send weights from trainer to inference workers.
|
||||
|
||||
This is a static method that can be called from the trainer process
|
||||
to send weights to all inference workers.
|
||||
|
||||
Args:
|
||||
iterator: Iterator of model parameters. Returns (name, tensor) tuples.
|
||||
The tensors should be on the appropriate device for the backend.
|
||||
trainer_args: Dictionary containing backend-specific arguments needed
|
||||
to send weights. The structure depends on the backend:
|
||||
- NCCL: Contains 'group', 'src', 'packed', etc.
|
||||
- IPC: Contains 'mode' ('http' or 'ray'),
|
||||
'llm_handle' (for Ray), 'url' (for HTTP), etc.
|
||||
|
||||
Example:
|
||||
>>> param_iter = ((n, p) for n, p in model.named_parameters())
|
||||
>>> engine.trainer_send_weights(param_iter, trainer_args)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -114,3 +114,9 @@ WeightTransferEngineFactory.register_engine(
|
||||
"vllm.distributed.weight_transfer.nccl_engine",
|
||||
"NCCLWeightTransferEngine",
|
||||
)
|
||||
|
||||
WeightTransferEngineFactory.register_engine(
|
||||
"ipc",
|
||||
"vllm.distributed.weight_transfer.ipc_engine",
|
||||
"IPCWeightTransferEngine",
|
||||
)
|
||||
|
||||
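For illustration, registering a third backend follows the same pattern; the backend name, module path, and class name below are hypothetical.

# Hypothetical out-of-tree backend registered alongside "nccl" and "ipc".
WeightTransferEngineFactory.register_engine(
    "my_rdma",
    "my_package.rdma_engine",
    "RDMAWeightTransferEngine",
)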
291
vllm/distributed/weight_transfer/ipc_engine.py
Normal file
@@ -0,0 +1,291 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""IPC-based weight transfer engine using CUDA IPC for communication."""
|
||||
|
||||
import base64
|
||||
import pickle
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from torch.multiprocessing.reductions import reduce_tensor
|
||||
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.weight_transfer import WeightTransferConfig
|
||||
from vllm.distributed.weight_transfer.base import (
|
||||
WeightTransferEngine,
|
||||
WeightTransferInitInfo,
|
||||
WeightTransferUpdateInfo,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPCTrainerSendWeightsArgs:
|
||||
"""Arguments for IPC trainer_send_weights method."""
|
||||
|
||||
mode: str
|
||||
"""Transport mode: 'http' or 'ray'."""
|
||||
llm_handle: Any = None
|
||||
"""Ray ObjectRef to LLM handle (required for 'ray' mode)."""
|
||||
url: str | None = None
|
||||
"""Base URL for HTTP endpoint (required for 'http' mode)."""
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate that required arguments are provided for the selected mode."""
|
||||
if self.mode == "ray" and self.llm_handle is None:
|
||||
raise ValueError("llm_handle is required for 'ray' mode")
|
||||
if self.mode == "http" and self.url is None:
|
||||
raise ValueError("url is required for 'http' mode")
|
||||
if self.mode not in ("ray", "http"):
|
||||
raise ValueError(f"mode must be 'ray' or 'http', got {self.mode}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPCWeightTransferInitInfo(WeightTransferInitInfo):
|
||||
"""Initialization info for IPC weight transfer backend. No init needed for IPC."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo):
|
||||
"""Update info for IPC weight transfer backend.
|
||||
|
||||
Accepts IPC handles either directly via ``ipc_handles`` (Ray transport)
|
||||
or as a base64-encoded pickle via ``ipc_handles_pickled`` (HTTP transport).
|
||||
Exactly one of the two must be provided; if ``ipc_handles_pickled`` is set
|
||||
it is unpickled into ``ipc_handles`` during ``__post_init__``.
|
||||
"""
|
||||
|
||||
names: list[str]
|
||||
dtype_names: list[str]
|
||||
shapes: list[list[int]]
|
||||
ipc_handles: list[dict[str, tuple[Callable, tuple]]] | None = None
|
||||
"""IPC handles mapping physical GPU UUID to (func, args) tuple.
|
||||
Each handle is a dictionary mapping GPU UUID strings to IPC handle tuples."""
|
||||
ipc_handles_pickled: str | None = None
|
||||
"""Base64-encoded pickled IPC handles, used for HTTP transport."""
|
||||
|
||||
def __post_init__(self):
|
||||
if self.ipc_handles_pickled is not None:
|
||||
if self.ipc_handles is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both `ipc_handles` and `ipc_handles_pickled`"
|
||||
)
|
||||
self.ipc_handles = pickle.loads(base64.b64decode(self.ipc_handles_pickled))
|
||||
self.ipc_handles_pickled = None
|
||||
|
||||
if self.ipc_handles is None:
|
||||
raise ValueError(
|
||||
"Either `ipc_handles` or `ipc_handles_pickled` must be provided"
|
||||
)
|
||||
|
||||
num_params = len(self.names)
|
||||
if len(self.dtype_names) != num_params:
|
||||
raise ValueError(
|
||||
f"`dtype_names` should be of the same size as `names`: "
|
||||
f"got {len(self.dtype_names)} and {len(self.names)}"
|
||||
)
|
||||
if len(self.shapes) != num_params:
|
||||
raise ValueError(
|
||||
f"`shapes` should be of the same size as `names`: "
|
||||
f"got {len(self.shapes)} and {len(self.names)}"
|
||||
)
|
||||
if len(self.ipc_handles) != num_params:
|
||||
raise ValueError(
|
||||
f"`ipc_handles` should be of the same size as `names`: "
|
||||
f"got {len(self.ipc_handles)} and {len(self.names)}"
|
||||
)
|
||||
|
||||
|
||||
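A small sketch of the HTTP path implied by __post_init__ above: the server-side handler is assumed to rebuild the dataclass straight from the JSON payload, letting the pickled handles be decoded in place. The payload shape mirrors what trainer_send_weights posts; the helper name is hypothetical.

# Hypothetical server-side handling of an /update_weights payload.
def build_update_info(payload: dict) -> IPCWeightTransferUpdateInfo:
    info = payload["update_info"]
    return IPCWeightTransferUpdateInfo(
        names=info["names"],
        dtype_names=info["dtype_names"],
        shapes=info["shapes"],
        # __post_init__ base64-decodes and unpickles this into ipc_handles.
        ipc_handles_pickled=info["ipc_handles_pickled"],
    )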
class IPCWeightTransferEngine(
|
||||
WeightTransferEngine[IPCWeightTransferInitInfo, IPCWeightTransferUpdateInfo]
|
||||
):
|
||||
"""
|
||||
Weight transfer engine using CUDA IPC for communication between trainer and workers.
|
||||
|
||||
This implementation uses CUDA IPC to transfer weights from the trainer (rank 0)
|
||||
to all inference workers in a process group. IPC handles are used to share
|
||||
memory between processes on the same node.
|
||||
"""
|
||||
|
||||
# Define backend-specific dataclass types
|
||||
init_info_cls = IPCWeightTransferInitInfo
|
||||
update_info_cls = IPCWeightTransferUpdateInfo
|
||||
|
||||
def __init__(
|
||||
self, config: WeightTransferConfig, parallel_config: ParallelConfig
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the IPC weight transfer engine.
|
||||
|
||||
Args:
|
||||
config: The configuration for the weight transfer engine
|
||||
parallel_config: The configuration for the parallel setup
|
||||
"""
|
||||
super().__init__(config, parallel_config)
|
||||
|
||||
def init_transfer_engine(self, init_info: IPCWeightTransferInitInfo) -> None:
|
||||
"""
|
||||
Initialize the weight transfer mechanism.
|
||||
This is called once at the beginning of training.
|
||||
No initialization needed for IPC backend.
|
||||
|
||||
Args:
|
||||
init_info: IPC initialization info (empty)
|
||||
"""
|
||||
pass
|
||||
|
||||
def receive_weights(
|
||||
self,
|
||||
update_info: IPCWeightTransferUpdateInfo,
|
||||
load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
|
||||
) -> None:
|
||||
"""
|
||||
Receive weights from the trainer via CUDA IPC handles.
|
||||
|
||||
Args:
|
||||
update_info: IPC update info containing parameter names, dtypes, shapes,
|
||||
and IPC handles. Each IPC handle is a mapping between physical
|
||||
GPU UUID and the IPC handle tuple (func, args).
|
||||
load_weights: Callable that loads weights into the model. Called
|
||||
incrementally for each weight to avoid OOM.
|
||||
"""
|
||||
assert update_info.ipc_handles is not None
|
||||
weights = []
|
||||
for name, _dtype_name, _shape, ipc_handle in zip(
|
||||
update_info.names,
|
||||
update_info.dtype_names,
|
||||
update_info.shapes,
|
||||
update_info.ipc_handles,
|
||||
):
|
||||
device_index = torch.cuda.current_device()
|
||||
props = torch.cuda.get_device_properties(device_index)
|
||||
physical_gpu_id = str(props.uuid)
|
||||
|
||||
if physical_gpu_id not in ipc_handle:
|
||||
raise ValueError(
|
||||
f"IPC handle not found for GPU UUID {physical_gpu_id}. "
|
||||
f"Available UUIDs: {list(ipc_handle.keys())}"
|
||||
)
|
||||
|
||||
handle = ipc_handle[physical_gpu_id]
|
||||
|
||||
func, args = handle
|
||||
list_args = list(args) # type: ignore
|
||||
# Index 6 is the device_index parameter in torch's
|
||||
# IPC handle tuple (rebuild_cuda_tensor). Update it
|
||||
# to the current device since the logical index can
|
||||
# differ between sender and receiver.
|
||||
list_args[6] = device_index
|
||||
weight = func(*list_args) # type: ignore
|
||||
weights.append((name, weight))
|
||||
|
||||
load_weights(weights)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
Shutdown the weight transfer engine.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def trainer_send_weights(
|
||||
iterator: Iterator[tuple[str, torch.Tensor]],
|
||||
trainer_args: dict[str, Any] | IPCTrainerSendWeightsArgs,
|
||||
) -> None:
|
||||
"""
|
||||
Send weights from trainer to inference workers via CUDA IPC.
|
||||
|
||||
Supports two modes:
|
||||
- 'ray': Sends weights via Ray RPC to a Ray-based LLM handle
|
||||
- 'http': Sends weights via HTTP POST to a vLLM HTTP server
|
||||
|
||||
Args:
|
||||
iterator: Iterator of model parameters. Returns (name, tensor) tuples.
|
||||
Tensors should be on the same GPU as the inference workers.
|
||||
trainer_args: Dictionary containing IPC-specific arguments.
|
||||
Should contain keys from IPCTrainerSendWeightsArgs:
|
||||
- mode: 'ray' or 'http'
|
||||
- llm_handle: Ray ObjectRef (for 'ray' mode)
|
||||
- url: Base URL string (for 'http' mode)
|
||||
|
||||
Example (Ray mode):
|
||||
>>> from vllm.distributed.weight_transfer.ipc_engine import (
|
||||
... IPCWeightTransferEngine,
|
||||
... IPCTrainerSendWeightsArgs,
|
||||
... )
|
||||
>>> param_iter = ((n, p) for n, p in model.named_parameters())
|
||||
>>> args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle)
|
||||
>>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args))
|
||||
|
||||
Example (HTTP mode):
|
||||
>>> args = IPCTrainerSendWeightsArgs(
|
||||
... mode="http", url="http://localhost:8000"
|
||||
... )
|
||||
>>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args))
|
||||
"""
|
||||
        # Parse trainer args - accept either dict or dataclass instance
        if isinstance(trainer_args, dict):
            args = IPCTrainerSendWeightsArgs(**trainer_args)
        else:
            args = trainer_args

        # Get physical GPU UUID
        device_index = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(device_index)
        gpu_uuid = str(props.uuid)

        # Collect weight metadata and create IPC handles
        names = []
        dtype_names = []
        shapes = []
        ipc_handles = []

        for name, tensor in iterator:
            names.append(name)
            dtype_names.append(str(tensor.dtype).split(".")[-1])
            shapes.append(list(tensor.shape))

            # Create IPC handle for this weight tensor
            # The tensor must remain in memory for IPC to work
            weight = tensor.detach().contiguous()
            ipc_handle = reduce_tensor(weight)
            ipc_handles.append({gpu_uuid: ipc_handle})

        # Send weights based on mode
        if args.mode == "ray":
            # Ray mode: send via Ray RPC
            import ray

            update_info = asdict(
                IPCWeightTransferUpdateInfo(
                    names=names,
                    dtype_names=dtype_names,
                    shapes=shapes,
                    ipc_handles=ipc_handles,
                )
            )
            ray.get(
                args.llm_handle.update_weights.remote(dict(update_info=update_info))
            )
        elif args.mode == "http":
            # HTTP mode: send via HTTP POST with pickled handles
            # Pickle and base64 encode IPC handles for HTTP transmission
            pickled_handles = base64.b64encode(pickle.dumps(ipc_handles)).decode(
                "utf-8"
            )

            url = f"{args.url}/update_weights"
            payload = {
                "update_info": {
                    "names": names,
                    "dtype_names": dtype_names,
                    "shapes": shapes,
                    "ipc_handles_pickled": pickled_handles,
                }
            }
            response = requests.post(url, json=payload, timeout=300)
            response.raise_for_status()
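# How the HTTP payload above can be consumed on the receiving side follows from
# the encoding: "ipc_handles_pickled" is just pickle wrapped in base64. A
# hedged sketch (decode_ipc_handles is a hypothetical helper, not the actual
# vLLM server handler):

import base64
import pickle


def decode_ipc_handles(update_info: dict) -> list[dict]:
    """Invert the base64 + pickle encoding applied in trainer_send_weights."""
    raw = base64.b64decode(update_info["ipc_handles_pickled"])
    # Each entry maps a GPU UUID to a (rebuild_fn, args) CUDA IPC handle.
    return pickle.loads(raw)
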
@@ -35,6 +35,32 @@ class NCCLWeightTransferInitInfo(WeightTransferInitInfo):
    world_size: int


@dataclass
class NCCLTrainerSendWeightsArgs:
    """Arguments for NCCL trainer_send_weights method."""

    group: Any
    """Process group (PyNcclCommunicator) for NCCL communication."""
    src: int = 0
    """Source rank (default 0, trainer is typically rank 0)."""
    post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor] | None = None
    """Optional function to apply to each (name, tensor) pair before broadcasting.
    If None, extracts just the tensor."""
    packed: bool = False
    """Whether to use packed tensor broadcasting for efficiency.
    When True, multiple tensors are batched together before broadcasting
    to reduce NCCL communication overhead."""
    stream: torch.cuda.Stream | None = None
    """CUDA stream to use for broadcasting if packed is False.
    If packed is True, new streams will be created for each buffer."""
    packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
    """Size in bytes for each packed tensor buffer.
    Must match the value used in NCCLWeightTransferUpdateInfo."""
    packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
    """Number of buffers for double/triple buffering during packed transfer.
    Must match the value used in NCCLWeightTransferUpdateInfo."""

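# A conceptual sketch of what "packed" broadcasting buys (this is not vLLM's
# packed_broadcast_producer, only an illustration of the idea): many small
# tensors are staged into one flat buffer so a single NCCL broadcast replaces
# one call per tensor. Assumes all staged tensors share the buffer's dtype.

import torch


def pack_into_buffer(
    tensors: list[torch.Tensor], buffer: torch.Tensor
) -> list[tuple[int, int]]:
    """Copy tensors into a flat 1-D buffer; return (offset, numel) per tensor."""
    layout, cursor = [], 0
    for t in tensors:
        flat = t.reshape(-1)
        buffer[cursor : cursor + flat.numel()].copy_(flat)
        layout.append((cursor, flat.numel()))
        cursor += flat.numel()
    # The producer broadcasts `buffer` once; consumers slice it back out using
    # `layout` plus the shapes/dtypes carried in the update info.
    return layout
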
@dataclass
class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
    """Update info for NCCL weight transfer backend."""
@@ -47,7 +73,7 @@ class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
    When True, multiple tensors are batched together before broadcasting
    to reduce NCCL communication overhead."""
    packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
    """Size in bytes for each packed tensor buffer. Default is 1GB.
    """Size in bytes for each packed tensor buffer.
    Both producer and consumer must use the same value."""
    packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
    """Number of buffers for double/triple buffering during packed transfer.
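# The "Must match" / "Both producer and consumer must use the same value" notes
# above are easiest to honor by deriving both sides from one shared setting. A
# sketch only: `group` (the PyNcclCommunicator) and the consumer-side
# construction are assumed context not shown in this diff.

packed_cfg = dict(
    packed_buffer_size_bytes=1 << 30,  # 1 GiB, the docstring's stated default
    packed_num_buffers=2,
)
trainer_args = NCCLTrainerSendWeightsArgs(group=group, packed=True, **packed_cfg)
# Reuse the same packed_cfg values when building NCCLWeightTransferUpdateInfo
# on the consumer side so the buffer geometry agrees on both ends.
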
@@ -186,47 +212,38 @@ class NCCLWeightTransferEngine(
    @staticmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, torch.Tensor]],
        group: Any,
        src: int = 0,
        post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor]
        | None = None,
        packed: bool = False,
        stream: torch.cuda.Stream | None = None,
        packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
        packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
        trainer_args: dict[str, Any] | NCCLTrainerSendWeightsArgs,
    ) -> None:
        """Broadcast weights from trainer to vLLM workers.

        Args:
            iterator: Iterator of model parameters. Returns (name, tensor) tuples
            group: Process group (PyNcclCommunicator)
            src: Source rank (default 0, trainer is typically rank 0)
            post_iter_func: Optional function to apply to each (name, tensor) pair
                before broadcasting. If None, extracts just the tensor.
            packed: Whether to use packed tensor broadcasting for efficiency.
                When True, multiple tensors are batched together before
                broadcasting to reduce NCCL communication overhead.
            stream: CUDA stream to use for broadcasting if packed is False.
                If packed is True, new streams will be created for each buffer.
            packed_buffer_size_bytes: Size in bytes for each packed tensor buffer.
                Must match the value used in NCCLWeightTransferUpdateInfo.
            packed_num_buffers: Number of buffers for double/triple buffering.
                Must match the value used in NCCLWeightTransferUpdateInfo.
            trainer_args: Dictionary or NCCLTrainerSendWeightsArgs instance containing
                NCCL-specific arguments. If a dict, should contain keys from
                NCCLTrainerSendWeightsArgs.

        Example:
            >>> from vllm.distributed.weight_transfer.nccl_engine import (
            ...     NCCLWeightTransferEngine,
            ...     NCCLTrainerSendWeightsArgs,
            ... )
            >>> param_iter = ((n, p) for n, p in model.named_parameters())
            >>> NCCLWeightTransferEngine.trainer_send_weights(
            ...     param_iter, group, packed=True
            ... )
            >>> args = NCCLTrainerSendWeightsArgs(group=group, packed=True)
            >>> NCCLWeightTransferEngine.trainer_send_weights(param_iter, args)
        """
        if post_iter_func is None:
        # Parse trainer args - accept either dict or dataclass instance
        if isinstance(trainer_args, dict):
            args = NCCLTrainerSendWeightsArgs(**trainer_args)
        else:
            args = trainer_args

        if args.post_iter_func is None:
            # Default: extract just the tensor from (name, tensor) tuple
            post_iter_func = lambda x: x[1]
        else:
            post_iter_func = args.post_iter_func

        if packed:
        if args.packed:
            # Use packed tensor broadcasting for efficiency
            from vllm.distributed.weight_transfer.packed_tensor import (
                packed_broadcast_producer,
@@ -234,18 +251,20 @@ class NCCLWeightTransferEngine(

            packed_broadcast_producer(
                iterator=iterator,
                group=group,
                src=src,
                group=args.group,
                src=args.src,
                post_iter_func=post_iter_func,
                buffer_size_bytes=packed_buffer_size_bytes,
                num_buffers=packed_num_buffers,
                buffer_size_bytes=args.packed_buffer_size_bytes,
                num_buffers=args.packed_num_buffers,
            )
        else:
            # Use simple one-by-one broadcasting
            for item in iterator:
                tensor = post_iter_func(item)
                group.broadcast(
                    tensor, src=src, stream=stream or torch.cuda.current_stream()
                args.group.broadcast(
                    tensor,
                    src=args.src,
                    stream=args.stream or torch.cuda.current_stream(),
                )

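# One concrete use of post_iter_func (a sketch; `group` and `param_iter` are
# the same assumed context as in the docstring example): cast weights to the
# serving dtype on the fly instead of broadcasting the trainer's dtype.

import torch


def cast_to_bf16(item: tuple[str, torch.Tensor]) -> torch.Tensor:
    # Drop the name and hand back a contiguous bf16 copy for broadcasting.
    return item[1].detach().to(torch.bfloat16).contiguous()


args = NCCLTrainerSendWeightsArgs(group=group, post_iter_func=cast_to_bf16, packed=True)
NCCLWeightTransferEngine.trainer_send_weights(param_iter, args)
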
    @staticmethod