Upgrade to vLLM 0.17.0 with the Corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions


@@ -25,6 +25,7 @@ If you only need to use the distributed environment without model/pipeline
import contextlib
import gc
import os
import pickle
import weakref
from collections import namedtuple
@@ -33,7 +34,7 @@ from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from multiprocessing import shared_memory
from typing import Any, Protocol
from typing import TYPE_CHECKING, Any, Protocol
from unittest.mock import patch
import torch
@@ -54,6 +55,10 @@ from vllm.utils.system_utils import suppress_stdout
from vllm.utils.torch_utils import (
direct_register_custom_op,
)
if TYPE_CHECKING:
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
import ixformer.distributed as ixfd
import vllm._custom_ops as ops
@@ -327,6 +332,8 @@ class GroupCoordinator:
self.rank = torch.distributed.get_rank()
self.local_rank = local_rank
use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM", None) not in {"1", "Y", "y"}
self_device_group = None
self_cpu_group = None
@@ -339,7 +346,7 @@ class GroupCoordinator:
with suppress_stdout():
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if self.rank in ranks:
self.ixfd_group = ixfd.init_comm_with_store(device_group)
self.ixfd_group = ixfd.init_comm_with_store(device_group) if use_vllm_comm else None
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
@@ -372,8 +379,7 @@ class GroupCoordinator:
self.device_communicator = device_comm_cls(
cpu_group=self.cpu_group,
device=self.device,
# device_group=self.device_group,
device_group=self.ixfd_group if envs.VLLM_FORCE_NCCL_COMM else self.device_group,
device_group=self.ixfd_group if use_vllm_comm else self.device_group,
unique_name=self.unique_name,
)
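The two hunks above gate the ixformer communicator behind the VLLM_FORCE_NCCL_COMM environment variable: the group falls back to the plain torch device group only when the variable is explicitly set. A minimal sketch of that gate, assuming the same "1"/"Y"/"y" convention as the diff (the helper name is invented for illustration):

import os

def _use_ixformer_comm() -> bool:
    # Unset or any other value -> True: use the ixformer comm group.
    # An explicit "1", "Y" or "y" -> False: force the plain torch/NCCL group.
    return os.environ.get("VLLM_FORCE_NCCL_COMM", None) not in {"1", "Y", "y"}

# e.g. launching with VLLM_FORCE_NCCL_COMM=1 makes _use_ixformer_comm() return False.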
@@ -385,11 +391,6 @@ class GroupCoordinator:
self.cpu_group, 1 << 22, 6
)
from vllm.platforms import current_platform
# self.use_custom_op_call = (
# current_platform.is_cuda_alike() or current_platform.is_tpu()
# )
self.use_custom_op_call = False
self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
@@ -468,14 +469,12 @@ class GroupCoordinator:
# only cuda uses this function,
# so we don't abstract it into the base class
maybe_ca_context = nullcontext()
# from vllm.distributed.device_communicators.cuda_communicator import (
# CudaCommunicator,
# )
from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase
from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator,
)
if self.device_communicator is not None:
# assert isinstance(self.device_communicator, CudaCommunicator)
assert isinstance(self.device_communicator, DeviceCommunicatorBase)
assert isinstance(self.device_communicator, CudaCommunicator)
ca_comm = self.device_communicator.ca_comm
if ca_comm is not None:
maybe_ca_context = ca_comm.capture() # type: ignore
@@ -608,9 +607,9 @@ class GroupCoordinator:
src=self.ranks[src],
group=self.device_group)
else:
torch.distributed.broadcast(input_,
src=self.ranks[src],
group=self.device_group)
torch.distributed.broadcast(
input_, src=self.ranks[src], group=self.device_group
)
return input_
def broadcast_object(self, obj: Any | None = None, src: int = 0):
@@ -764,10 +763,9 @@ class GroupCoordinator:
group=group,
async_op=True)
else:
handle = torch.distributed.broadcast(tensor,
src=self.ranks[src],
group=group,
async_op=True)
handle = torch.distributed.broadcast(
tensor, src=self.ranks[src], group=group, async_op=True
)
async_handles.append(handle)
for async_handle in async_handles:
async_handle.wait()
@@ -802,10 +800,8 @@ class GroupCoordinator:
async_op=True)
else:
handle = torch.distributed.broadcast(
tensor,
src=self.ranks[src],
group=group,
async_op=True)
tensor, src=self.ranks[src], group=group, async_op=True
)
async_handles.append(handle)
tensor_dict[key] = tensor
else:
@@ -876,6 +872,10 @@ class GroupCoordinator:
if self.world_size <= 1:
return []
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
@@ -893,10 +893,6 @@ class GroupCoordinator:
group = self.device_group
metadata_group = self.cpu_group
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
self.send_object(metadata_list, dst=dst)
@@ -917,6 +913,7 @@ class GroupCoordinator:
handle = torch.distributed.isend(
tensor, dst=self.ranks[dst], group=comm_group
)
if tensor.is_cuda:
tensor.record_stream(torch.cuda.current_stream(tensor.device))
handles.append(handle)
@@ -973,6 +970,11 @@ class GroupCoordinator:
]:
if not torch.distributed.is_initialized() or self.world_size == 1:
return None, [], []
if src is None:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
@@ -990,10 +992,6 @@ class GroupCoordinator:
group = self.device_group
metadata_group = self.cpu_group
if src is None:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
recv_metadata_list = self.recv_object(src=src)
tensor_dict: dict[str, Any] = {}
handles: list[Handle] = []
@@ -1072,14 +1070,13 @@ class GroupCoordinator:
return self.device_communicator.recv(size, dtype, src)
def destroy(self):
if hasattr(self, "device_group"):
# torch.distributed.destroy_process_group(self.device_group)
if self.device_group is not None:
if self.device_communicator and self.device_communicator.use_vllm_comm:
ixfd.destroy_process_group(self.device_group)
else:
torch.distributed.destroy_process_group(self.device_group)
del self.device_group
if hasattr(self, "cpu_group"):
self.device_group = None
if self.cpu_group is not None:
torch.distributed.destroy_process_group(self.cpu_group)
del self.cpu_group
if self.device_communicator is not None:
@@ -1094,7 +1091,6 @@ class GroupCoordinator:
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
extra_residual:torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
@@ -1105,13 +1101,12 @@ class GroupCoordinator:
if self.device_communicator is not None:
return self.device_communicator.dispatch_router_logits(
hidden_states,
extra_residual,
router_logits,
is_sequence_parallel,
extra_tensors,
)
else:
return hidden_states, extra_residual, router_logits
return hidden_states, router_logits
def dispatch(
self,
@@ -1189,6 +1184,55 @@ def init_model_parallel_group(
)
def _init_stateless_group(
group_ranks: list[list[int]],
group_name: str,
group_ports: list[list[int]],
host: str,
backend: str,
use_device_communicator: bool = True,
) -> "StatelessGroupCoordinator":
"""Create a StatelessGroupCoordinator with the given parameters."""
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
world = get_world_group()
return StatelessGroupCoordinator(
group_ranks=group_ranks,
local_rank=world.local_rank,
torch_distributed_backend=backend,
use_device_communicator=use_device_communicator,
group_name=group_name,
host=host,
group_ports=group_ports,
global_rank=world.rank,
global_world_size=world.world_size,
)
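This helper is reused by the elastic-EP DP/EP/EPLB branches further down. A hypothetical call for one 4-rank DP group (all values invented for illustration; the real call takes ports from parallel_config.get_next_stateless_dp_group_port() and the host from data_parallel_master_ip):

dp_group = _init_stateless_group(
    group_ranks=[[0, 1, 2, 3]],
    group_name="dp",
    group_ports=[[29510]],   # assumed shape: one port list per rank group
    host="10.0.0.1",
    backend="nccl",
)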
def _replace_active_groups(
*,
world: GroupCoordinator | None,
dp: GroupCoordinator | None,
ep: GroupCoordinator | None,
eplb: GroupCoordinator | None,
node_count: int | None,
) -> None:
"""Destroy the current DP/EP/WORLD/EPLB groups and replace them.
Destruction is collective — all ranks in the old groups must call this
function together. Pass all-``None`` to tear down without replacement.
"""
global _WORLD, _DP, _EP, _EPLB, _NODE_COUNT
for group in (_DP, _EP, _WORLD, _EPLB):
if group is not None:
group.destroy()
_WORLD = world
_DP = dp
_EP = ep
_EPLB = eplb
_NODE_COUNT = node_count
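Because destruction is collective, the simplest valid use is the all-None teardown mentioned in the docstring, called on every rank of the old groups:

_replace_active_groups(world=None, dp=None, ep=None, eplb=None, node_count=None)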
_TP: GroupCoordinator | None = None
@@ -1286,6 +1330,39 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable
def _init_elastic_ep_world(
config, local_rank: int, backend: str, rank: int, world_size: int
) -> None:
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
global _WORLD, _NODE_COUNT
assert _WORLD is None, "world group already initialized"
parallel_config = config.parallel_config
global_rank = parallel_config.data_parallel_rank * world_size + rank
global_world_size = parallel_config.world_size_across_dp
all_ranks = list(range(global_world_size))
group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)]
if global_rank in all_ranks:
group_ranks = [all_ranks]
group_ports = [parallel_config.get_next_stateless_world_group_port()]
world = StatelessGroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_device_communicator=False,
group_name="world",
host=parallel_config.data_parallel_master_ip,
group_ports=group_ports,
global_rank=global_rank,
global_world_size=global_world_size,
)
assert parallel_config.nnodes_within_dp == 1, (
"Elastic EP is not supported with multi-node TP/PP"
)
_NODE_COUNT = _node_count(world.tcp_store_group)
_WORLD = world
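A worked example of the rank arithmetic above (numbers are hypothetical): with two DP replicas of four ranks each, local rank 2 in DP replica 1 gets global rank 1 * 4 + 2 = 6, and the stateless world group spans global ranks 0..7:

data_parallel_rank, world_size, rank = 1, 4, 2
global_rank = data_parallel_rank * world_size + rank   # -> 6
global_world_size = 2 * 4                               # world_size_across_dp
all_ranks = list(range(global_world_size))              # [0, 1, ..., 7]
group_ranks = [all_ranks]                               # one flat world group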
def init_distributed_environment(
world_size: int = -1,
rank: int = -1,
@@ -1305,6 +1382,7 @@ def init_distributed_environment(
from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config_or_none()
enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep
if (
config is not None
and config.parallel_config.distributed_executor_backend != "external_launcher"
@@ -1312,6 +1390,7 @@ def init_distributed_environment(
config.parallel_config.nnodes > 1
or config.parallel_config.data_parallel_size > 1
)
and not enable_elastic_ep
):
parallel_config = config.parallel_config
# adjust to take into account data parallelism
@@ -1365,6 +1444,18 @@ def init_distributed_environment(
rank=rank,
timeout=timeout,
)
if enable_elastic_ep:
tp_pp_cpu_group = torch.distributed.new_group(
backend="gloo", timeout=timeout
)
if _node_count(tp_pp_cpu_group) > 1:
# NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip
# to initialize all DP/EP groups, hence all ranks within TP/PP group
# must reside on the same node
raise RuntimeError(
"Elastic EP is not yet supported with multi-node TP/PP"
)
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
@@ -1373,6 +1464,9 @@ def init_distributed_environment(
# setting, where we can use rank as local rank
local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
if enable_elastic_ep:
_init_elastic_ep_world(config, local_rank, backend, rank, world_size)
return
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
@@ -1436,16 +1530,33 @@ def initialize_model_parallel(
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
data_parallel_size = 1
from vllm.config import get_current_vllm_config_or_none
from vllm.config import get_current_vllm_config
config = get_current_vllm_config_or_none()
if config is not None:
data_parallel_size = config.parallel_config.data_parallel_size
config = get_current_vllm_config()
data_parallel_size = config.parallel_config.data_parallel_size
enable_elastic_ep = config.parallel_config.enable_elastic_ep
if enable_elastic_ep:
# Use stateless world group for global information
world_size = get_world_group().world_size
rank = get_world_group().rank
backend = backend or "nccl"
tp_pp_pcp_size = (
tensor_model_parallel_size
* pipeline_model_parallel_size
* prefill_context_model_parallel_size
)
local_all_ranks = torch.arange(tp_pp_pcp_size).reshape(
pipeline_model_parallel_size,
prefill_context_model_parallel_size,
tensor_model_parallel_size,
)
else:
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group
)
# the layout order is: ExternalDP x DP x PP x TP
# ExternalDP is the data parallel group that is not part of the model,
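In the elastic-EP path, local_all_ranks is a small (PP, PCP, TP) index tensor whose views yield the per-parallelism rank groups used in the hunks below. A toy example (sizes chosen only for illustration) with PP=2, PCP=1, TP=2:

import torch

pp, pcp, tp = 2, 1, 2
local_all_ranks = torch.arange(pp * pcp * tp).reshape(pp, pcp, tp)
# tensor([[[0, 1]],
#         [[2, 3]]])

tp_groups = [x.tolist() for x in local_all_ranks.view(-1, tp).unbind(0)]
# [[0, 1], [2, 3]]  -- consecutive ranks form a TP group

pp_groups = [
    x.tolist()
    for x in local_all_ranks.transpose(0, 2).reshape(-1, pp).unbind(0)
]
# [[0, 2], [1, 3]]  -- strided ranks form a PP group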
@@ -1469,7 +1580,9 @@ def initialize_model_parallel(
assert _TP is None, "tensor model parallel group is already initialized"
group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(
group_ranks,
@@ -1488,6 +1601,11 @@ def initialize_model_parallel(
# TP group into tp_size//dcp_size DCP groups.
group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = local_all_ranks.reshape(
-1, decode_context_model_parallel_size
).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DCP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
@@ -1504,6 +1622,13 @@ def initialize_model_parallel(
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = (
local_all_ranks.transpose(1, 2)
.reshape(-1, prefill_context_model_parallel_size)
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_PCP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pcp"
)
@@ -1515,6 +1640,13 @@ def initialize_model_parallel(
all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = (
local_all_ranks.transpose(0, 2)
.reshape(-1, pipeline_model_parallel_size)
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_PP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pp"
)
@@ -1523,14 +1655,27 @@ def initialize_model_parallel(
assert _DP is None, "data parallel group is already initialized"
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="dp"
)
if enable_elastic_ep:
parallel_config = config.parallel_config
dp_ports = [
parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks
]
_DP = _init_stateless_group(
group_ranks,
"dp",
dp_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_DP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="dp"
)
global _EP
assert _EP is None, "expert parallel group is already initialized"
# Don't create EP group for dense models.
if config is None or config.model_config is None or config.model_config.is_moe:
if config.model_config is None or config.model_config.is_moe:
group_ranks = (
all_ranks.transpose(1, 2)
.reshape(
@@ -1542,9 +1687,22 @@ def initialize_model_parallel(
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_EP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="ep"
)
if enable_elastic_ep:
parallel_config = config.parallel_config
ep_ports = [
parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks
]
_EP = _init_stateless_group(
group_ranks,
"ep",
ep_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_EP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="ep"
)
# Create EPLB group with the same ranks as EP if EPLB is enabled.
# This is a separate process group to isolate EPLB communications
@@ -1557,10 +1715,25 @@ def initialize_model_parallel(
and config.parallel_config is not None
and config.parallel_config.enable_eplb
):
# Reuse the same group_ranks from EP
_EPLB = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="eplb"
)
if enable_elastic_ep:
eplb_ports = [
parallel_config.get_next_stateless_eplb_group_port()
for _ in group_ranks
]
_EPLB = _init_stateless_group(
group_ranks,
"eplb",
eplb_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_EPLB = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
group_name="eplb",
)
# If no EP group needed, _EP remains None
# If no EPLB group needed, _EPLB remains None
@@ -1590,7 +1763,11 @@ def ensure_model_parallel_initialized(
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
"""
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
world_group = get_world_group()
if hasattr(world_group, "backend"):
backend = backend or world_group.backend
else:
backend = backend or torch.distributed.get_backend(world_group.device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(
tensor_model_parallel_size,