Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

@@ -0,0 +1,529 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import gc
import weakref
from collections.abc import Iterable, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import P2POp
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.wrapper import reset_compile_wrapper
from vllm.config import (
CompilationMode,
set_current_vllm_config,
)
from vllm.distributed import (
get_dp_group,
get_ep_group,
get_pcp_group,
get_tp_group,
)
from vllm.distributed.elastic_ep.standby_state import (
create_standby_groups,
get_standby_dp_group,
get_standby_ep_group,
pop_standby_groups,
)
from vllm.distributed.parallel_state import (
_replace_active_groups,
prepare_communication_buffer_for_model,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.workspace import lock_workspace, unlock_workspace
logger = init_logger(__name__)
def batch_transfer_weights(
model: nn.Module,
is_sender: bool,
peer_rank: int,
dp_group: StatelessGroupCoordinator,
expert_weights: Sequence[Iterable[torch.Tensor]],
) -> None:
device_comm = dp_group.device_communicator
if device_comm is None:
raise ValueError("No device communicator found")
expert_weights_set = set()
for weight_group in expert_weights:
for weight in weight_group:
expert_weights_set.add(weight.data_ptr())
state_dict = model.state_dict()
all_params = []
for name, param in state_dict.items():
if name.endswith("expert_map"):
continue
if param.data_ptr() not in expert_weights_set:
all_params.append(param.data)
assert len(all_params) > 0
p2p_ops = []
for param in all_params:
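# NOTE: the P2POp is built with object.__new__ to skip P2POp's
# constructor checks and peer/group resolution; presumably the custom
# device communicator's batch_isend_irecv only consumes the op,
# tensor and group_peer attributes set here.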
op = object.__new__(P2POp)
if is_sender:
op.op = torch.distributed.isend
op.tensor = param
else:
op.op = torch.distributed.irecv
op.tensor = param
op.group_peer = peer_rank
p2p_ops.append(op)
device_comm.batch_isend_irecv(p2p_ops)
def broadcast_expert_mapping(
physical_to_logical: torch.Tensor | None,
num_local_physical_experts: int | None,
num_logical_experts: int | None,
dp_group: StatelessGroupCoordinator,
device: torch.device,
src_rank: int = 0,
) -> tuple[torch.Tensor, int, int]:
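# The tensor shape and scalar metadata are exchanged as CPU tensors via
# dp_group.tcp_store_group below; the expert-mapping tensor itself is
# then broadcast through dp_group.broadcast.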
if dp_group.rank_in_group == src_rank:
assert physical_to_logical is not None
assert num_local_physical_experts is not None
assert num_logical_experts is not None
assert physical_to_logical.dtype == torch.int64
shape_tensor = torch.tensor(
list(physical_to_logical.shape), dtype=torch.int64, device="cpu"
)
metadata_tensor = torch.tensor(
[num_local_physical_experts, num_logical_experts],
dtype=torch.int64,
device="cpu",
)
else:
shape_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
metadata_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
shape_tensor = dp_group.tcp_store_group.broadcast(shape_tensor, src_rank)
metadata_tensor = dp_group.tcp_store_group.broadcast(metadata_tensor, src_rank)
if dp_group.rank_in_group != src_rank:
assert device is not None
physical_to_logical = torch.empty(
tuple(shape_tensor.tolist()),
dtype=torch.int64,
device=device,
)
assert physical_to_logical is not None
physical_to_logical = dp_group.broadcast(physical_to_logical, src_rank)
num_local_physical_experts = int(metadata_tensor[0].item())
num_logical_experts = int(metadata_tensor[1].item())
return physical_to_logical, num_local_physical_experts, num_logical_experts
class ElasticEPScalingExecutor:
def __init__(self, worker):
self.worker_ref = weakref.ref(worker)
self.reconfig_request = None
@property
def worker(self):
worker = self.worker_ref()
if worker is None:
raise RuntimeError("Worker has been garbage collected")
return worker
def execute(self, execute_method: str, *args, **kwargs):
method = getattr(self, execute_method, None)
if method is None:
raise ValueError(f"Unknown execute method: {execute_method}")
return method(*args, **kwargs)
def create_standby_groups(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
self.reconfig_request = reconfig_request
new_dp_size = reconfig_request.new_data_parallel_size
world_size = self.worker.vllm_config.parallel_config.world_size
new_world_size_across_dp = world_size * new_dp_size
updated_config = copy.copy(self.worker.vllm_config)
updated_config.parallel_config = copy.deepcopy(
self.worker.vllm_config.parallel_config
)
updated_config.parallel_config.data_parallel_size = new_dp_size
with set_current_vllm_config(updated_config):
create_standby_groups(
new_dp_size=new_dp_size,
new_world_size_across_dp=new_world_size_across_dp,
master_ip=reconfig_request.new_data_parallel_master_ip,
world_group_ports=reconfig_request.new_stateless_world_group_port_list,
dp_group_ports=reconfig_request.new_stateless_dp_group_port_list,
ep_group_ports=reconfig_request.new_stateless_ep_group_port_list,
eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list,
)
self.worker.model_runner.eep_eplb_suppressed = True
standby_ep_group = get_standby_ep_group()
assert standby_ep_group is not None
if standby_ep_group.rank == 0:
logger.info("[Elastic EP] EPLB disabled during elastic scaling transition")
def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None:
standby_dp_group = get_standby_dp_group()
assert standby_dp_group is not None
# Broadcast old_dp_size to all workers in standby group
if standby_dp_group.rank_in_group < old_dp_size:
old_dp_size_tensor = torch.tensor(
[old_dp_size], dtype=torch.int64, device="cpu"
)
else:
old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast(
old_dp_size_tensor, 0
)
num_new_workers = new_dp_size - old_dp_size
dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
# Sender-receiver pairing: with k = num_new_workers // old_dp_size,
# the first (num_new_workers % old_dp_size) senders get k + 1
# contiguous receivers and the remaining senders get k receivers.
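# For example (hypothetical sizes): old_dp_size=2 and new_dp_size=5
# give num_new_workers=3, k=1 and remainder=1, so sender 0 transfers
# to new DP ranks [2, 3] and sender 1 transfers to new DP rank [4].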
num_dst_per_sender = num_new_workers // old_dp_size
remainder = num_new_workers % old_dp_size
if dp_rank < remainder:
recv_begin = dp_rank * (num_dst_per_sender + 1)
recv_end = recv_begin + num_dst_per_sender + 1
else:
recv_begin = (
remainder * (num_dst_per_sender + 1)
+ (dp_rank - remainder) * num_dst_per_sender
)
recv_end = recv_begin + num_dst_per_sender
ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end))
model = self.worker.model_runner.get_model()
for new_worker_rank in sorted(ranks_to_send):
batch_transfer_weights(
model=model,
is_sender=True,
peer_rank=new_worker_rank,
dp_group=standby_dp_group,
expert_weights=model.expert_weights,
)
torch.cuda.synchronize()
def broadcast_expert_mapping(self) -> None:
standby_dp_group = get_standby_dp_group()
assert standby_dp_group is not None
model_config = self.worker.model_runner.model_config
eplb_state = self.worker.model_runner.eplb_state
assert eplb_state is not None
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
physical_to_logical = eplb_model_state.physical_to_logical_map
num_physical_experts = physical_to_logical.shape[1]
num_local_physical_experts = num_physical_experts // get_ep_group().world_size
num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
broadcast_expert_mapping(
physical_to_logical=physical_to_logical,
num_local_physical_experts=num_local_physical_experts,
num_logical_experts=num_logical_experts,
dp_group=standby_dp_group,
src_rank=0,
device=self.worker.device,
)
def switch_and_remove(self) -> None:
_replace_active_groups(world=None, dp=None, ep=None, eplb=None, node_count=None)
def switch_and_prepare(self) -> None:
old_dp_size = get_dp_group().world_size
old_ep_size = get_ep_group().world_size
_replace_active_groups(**pop_standby_groups())
parallel_config = self.worker.vllm_config.parallel_config
reconfig_request = self.reconfig_request
assert reconfig_request is not None
new_dp_size = reconfig_request.new_data_parallel_size
new_ep_size = get_ep_group().world_size
parallel_config.data_parallel_size = new_dp_size
if (
reconfig_request.new_data_parallel_rank
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
if (
reconfig_request.new_data_parallel_rank_local
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank_local = (
reconfig_request.new_data_parallel_rank_local
)
parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
# Reconfigure MoE modules with new EP size
moe_modules = [
module
for module in self.worker.model_runner.model.modules()
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
]
num_local_experts = moe_modules[0].moe_config.num_local_experts
assert all(
module.moe_config.num_local_experts == num_local_experts
for module in moe_modules
), "All MoE modules must have the same number of experts"
for module in moe_modules:
module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts
tp_size = get_tp_group().world_size
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
sp_size = tp_size if is_sequence_parallel else 1
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=tp_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
sp_size_=sp_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config
# Update EPLB state
eplb_state = self.worker.model_runner.eplb_state
assert eplb_state is not None
model_config = self.worker.model_runner.model_config
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
num_physical_experts = num_local_experts * new_ep_size
num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
parallel_config.eplb_config.num_redundant_experts = (
num_physical_experts - num_logical_experts
)
old_physical_to_logical = eplb_model_state.physical_to_logical_map
num_moe_layers = old_physical_to_logical.shape[0]
num_local_experts = eplb_model_state.expert_load_pass.shape[1] // old_ep_size
if new_dp_size > old_dp_size:
expanded_physical_to_logical = torch.full(
(num_moe_layers, num_local_experts * new_ep_size),
-1,
dtype=old_physical_to_logical.dtype,
device=old_physical_to_logical.device,
)
expanded_physical_to_logical[:, : num_local_experts * old_ep_size] = (
old_physical_to_logical
)
eplb_model_state.physical_to_logical_map = expanded_physical_to_logical
old_num_physical_experts = eplb_model_state.expert_load_pass.shape[1]
pad_size = num_physical_experts - old_num_physical_experts
if new_dp_size > old_dp_size:
assert pad_size > 0
expanded_expert_load_pass = F.pad(
eplb_model_state.expert_load_pass, (0, pad_size), value=0
)
expanded_expert_load_window = F.pad(
eplb_model_state.expert_load_window, (0, pad_size), value=0
)
eplb_model_state.expert_load_pass = expanded_expert_load_pass
eplb_model_state.expert_load_window = expanded_expert_load_window
eplb_state.num_valid_physical_experts = old_num_physical_experts
else:
assert pad_size < 0
eplb_model_state.expert_load_pass = eplb_model_state.expert_load_pass[
:, :num_physical_experts
]
eplb_model_state.expert_load_window = eplb_model_state.expert_load_window[
:, :, :num_physical_experts
]
eplb_state.num_valid_physical_experts = num_physical_experts
model = self.worker.model_runner.get_model()
model.expert_weights = []
with set_current_vllm_config(self.worker.vllm_config):
model.set_eplb_state(
eplb_model_state.expert_load_pass,
eplb_model_state.logical_to_physical_map,
eplb_model_state.logical_replica_count,
)
model.update_physical_experts_metadata(
num_physical_experts=num_physical_experts,
num_local_physical_experts=num_local_experts,
)
# Force re-creation of the modular kernel (and all2all manager)
# for the new EP size by resetting quant_method to base
for module in moe_modules:
if hasattr(module.quant_method, "old_quant_method"):
module.quant_method = module.quant_method.old_quant_method
module.runner = module._init_runner()
prepare_communication_buffer_for_model(self.worker.model_runner.model)
if (
self.worker.vllm_config.compilation_config.mode
== CompilationMode.STOCK_TORCH_COMPILE
):
# NOTE(yongji): when using stock torch.compile,
# torch.compile is triggered during GPUModelRunner's load_model()
# TODO(yongji): check whether we need to re-trigger torch.compile here?
# any changes to the tensor shapes in execution should already
# be handled internally by torch.compile.
backend = self.worker.vllm_config.compilation_config.init_backend(
self.worker.vllm_config
)
compilation_counter.stock_torch_compile_count += 1
self.worker.model_runner.model.compile(fullgraph=True, backend=backend)
# release all previously captured CUDA graphs
if isinstance(self.worker.model_runner.model, CUDAGraphWrapper):
wrapper = self.worker.model_runner.model
wrapper.concrete_cudagraph_entries = {}
elif isinstance(self.worker.model_runner.model, UBatchWrapper):
raise RuntimeError("DBO is not yet supported in elastic EP")
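# Save the GPU/CPU block tables before re-warming: the
# compile_or_warm_up_model() call below runs dummy batches that would
# presumably overwrite them, so cloned copies are restored afterwards.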
multi_block_table = self.worker.model_runner.input_batch.block_table
saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = []
for bt in multi_block_table.block_tables:
saved_block_tables.append(
(bt.block_table.gpu.clone(), bt.block_table.cpu.clone())
)
multi_block_table.clear()
# reset the compile wrapper
torch.compiler.reset()
with set_current_vllm_config(self.worker.vllm_config):
reset_compile_wrapper(self.worker.model_runner.get_model())
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
unlock_workspace()
self.worker.compile_or_warm_up_model()
lock_workspace()
for bt, (saved_gpu, saved_cpu) in zip(
multi_block_table.block_tables, saved_block_tables
):
bt.block_table.gpu.copy_(saved_gpu)
bt.block_table.cpu.copy_(saved_cpu)
def perform_eplb_reshuffle(self, new_dp_size: int | None = None) -> None:
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Starting expert resharding...")
eplb_state = self.worker.model_runner.eplb_state
assert eplb_state is not None
model_config = self.worker.model_runner.model_config
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
is_async_enabled = eplb_state.is_async
eplb_state.is_async = False
if new_dp_size is None:
eplb_state.rearrange()
else:
# scale down
parallel_config = self.worker.vllm_config.parallel_config
tp_size = parallel_config.tensor_parallel_size
old_ep_size = parallel_config.data_parallel_size * tp_size
new_ep_size = new_dp_size * tp_size
rank_mapping = {
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
for old_ep_rank in range(old_ep_size)
}
eplb_state.rearrange(rank_mapping=rank_mapping)
# NOTE(yongji): check whether we need to synchronize here
torch.cuda.synchronize()
# reset expert_rearrangement_step to ensure all ranks are synchronized
eplb_state.expert_rearrangement_step = 0
eplb_state.num_valid_physical_experts = (
eplb_model_state.physical_to_logical_map.shape[1]
)
eplb_state.is_async = is_async_enabled
self.worker.model_runner.eep_eplb_suppressed = False
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed")
def receive_weights(self) -> None:
dp_group = get_dp_group()
assert isinstance(dp_group, StatelessGroupCoordinator)
new_dp_size = dp_group.world_size
dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
# Receive old_dp_size broadcasted during transfer_weights
old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
old_dp_size_tensor = dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0)
old_dp_size = int(old_dp_size_tensor[0].item())
# Calculate which existing worker will send to this new worker
num_new_workers = new_dp_size - old_dp_size
new_worker_idx = dp_rank - old_dp_size
num_dst_per_sender = num_new_workers // old_dp_size
remainder = num_new_workers % old_dp_size
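# Mirror of the pairing in transfer_weights(): with the same
# hypothetical sizes (old_dp_size=2, new_dp_size=5), new DP ranks 2
# and 3 receive from sender 0 and new DP rank 4 receives from sender 1.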
if new_worker_idx < remainder * (num_dst_per_sender + 1):
sender_rank = new_worker_idx // (num_dst_per_sender + 1)
else:
sender_rank = (
remainder
+ (new_worker_idx - remainder * (num_dst_per_sender + 1))
// num_dst_per_sender
)
model = self.worker.model_runner.get_model()
batch_transfer_weights(
model=model,
is_sender=False,
peer_rank=sender_rank,
dp_group=dp_group,
expert_weights=model.expert_weights,
)
torch.cuda.synchronize()
def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]:
dp_group = get_dp_group()
assert isinstance(dp_group, StatelessGroupCoordinator)
physical_to_logical, num_local_physical_experts, num_logical_experts = (
broadcast_expert_mapping(
physical_to_logical=None,
num_local_physical_experts=None,
num_logical_experts=None,
dp_group=dp_group,
src_rank=0,
device=self.worker.device,
)
)
num_moe_layers = physical_to_logical.shape[0]
new_dp_size = get_dp_group().world_size
tp_size = self.worker.vllm_config.parallel_config.tensor_parallel_size
new_ep_size = new_dp_size * tp_size
expanded_physical_to_logical = torch.full(
(num_moe_layers, num_local_physical_experts * new_ep_size),
-1,
dtype=physical_to_logical.dtype,
device=physical_to_logical.device,
)
old_num_physical_experts = physical_to_logical.shape[1]
expanded_physical_to_logical[:, :old_num_physical_experts] = physical_to_logical
return (
expanded_physical_to_logical,
num_logical_experts,
old_num_physical_experts,
)
def prepare_new_worker(self) -> None:
with set_current_vllm_config(self.worker.vllm_config):
prepare_communication_buffer_for_model(self.worker.model_runner.get_model())

@@ -0,0 +1,563 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import time
import weakref
from datetime import timedelta
from typing import TYPE_CHECKING, Literal
import torch.distributed
from vllm.config import ParallelConfig
from vllm.distributed import (
sched_yield,
stateless_destroy_torch_distributed_process_group,
)
from vllm.logger import init_logger
from vllm.v1.engine import (
EEPNotificationType,
ReconfigureDistributedRequest,
ReconfigureRankType,
)
from vllm.v1.engine.core import DPEngineCoreProc
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.v1.executor.abstract import Executor
logger = init_logger(__name__)
WorkerType = Literal["existing", "new", "removing"]
class ScaleUpExistingEngineState(enum.IntEnum):
WAIT_NEW_CORE_ENGINES_INIT = 0
CREATE_STANDBY_GROUPS = 1
TRANSFER_EXPERT_MAPPING = 2
WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT = 3
TRANSFER_WEIGHTS = 4
SYNC_KV_CACHE_MEMORY_SIZE = 5
SWITCH_AND_PREPARE = 6
EPLB_RESHUFFLE = 7
COMPLETE = 8
class ScaleUpNewEngineState(enum.IntEnum):
PREPARE = 0
EPLB_RESHUFFLE = 1
COMPLETE = 2
class ScaleDownRemainingEngineState(enum.IntEnum):
PREPARE = 0
EPLB_RESHUFFLE = 1
SWITCH_AND_PREPARE = 2
COMPLETE = 3
class ScaleDownRemovingEngineState(enum.IntEnum):
PREPARE = 0
EPLB_RESHUFFLE = 1
COMPLETE = 2
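# The four state machines above are driven by ElasticEPScalingState:
# _progress_existing_engine / _progress_new_engine walk the scale-up
# enums and _progress_remaining_engine / _progress_removing_engine walk
# the scale-down enums, advancing the corresponding state as each step
# completes.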
class _BarrierTimeoutError(RuntimeError):
"""
Raised when the first stage of the two-stage TCPStore-based barrier
used to synchronize the execution of all engines in the DP group
times out.
"""
class ElasticEPScalingState:
def __init__(
self,
model_executor: "Executor",
engine_core: "DPEngineCoreProc",
vllm_config: "VllmConfig",
new_parallel_config: ParallelConfig,
worker_type: WorkerType,
scale_type: Literal["scale_up", "scale_down"],
reconfig_request: ReconfigureDistributedRequest | None = None,
):
self.model_executor_ref = weakref.ref(model_executor)
self.engine_core_ref = weakref.ref(engine_core)
self.vllm_config = vllm_config
self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None
self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None
self.new_parallel_config: ParallelConfig = new_parallel_config
self.new_dp_group: torch.distributed.ProcessGroup | None = (
self.engine_core.dp_group if worker_type == "new" else None
)
self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None
self.worker_type = worker_type
self.scale_type = scale_type
self.reconfig_request = reconfig_request
if scale_type == "scale_up":
self.state = (
ScaleUpNewEngineState.PREPARE
if worker_type == "new"
else ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
)
else:
self.state = (
ScaleDownRemovingEngineState.PREPARE
if worker_type == "removing"
else ScaleDownRemainingEngineState.PREPARE
)
@property
def model_executor(self) -> "Executor":
model_executor = self.model_executor_ref()
if model_executor is None:
raise RuntimeError("Model executor has been garbage collected")
return model_executor
@property
def engine_core(self) -> "DPEngineCoreProc":
engine_core = self.engine_core_ref()
if engine_core is None:
raise RuntimeError("Engine core has been garbage collected")
return engine_core
def progress(self) -> bool:
if self.scale_type == "scale_up":
return (
self._progress_new_engine()
if self.worker_type == "new"
else self._progress_existing_engine()
)
return (
self._progress_removing_engine()
if self.worker_type == "removing"
else self._progress_remaining_engine()
)
def _execute_tcp_store_barrier(
self, dp_store, group_rank, group_size, barrier_id, timeout=None
):
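# Each rank announces arrival by setting its own
# arrival_<barrier_id>_<rank> key in the TCPStore, then polls until the
# arrival keys of all group_size ranks are present, raising
# _BarrierTimeoutError if the optional timeout elapses first.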
arrival_key = f"arrival_{barrier_id}_{group_rank}"
dp_store.set(arrival_key, b"1")
start_time = time.time()
processes_arrived: set[int] = set()
while len(processes_arrived) < group_size:
if (
timeout is not None
and time.time() - start_time > timeout.total_seconds()
):
raise _BarrierTimeoutError(
f"Barrier timed out after {timeout.total_seconds()} seconds"
)
for i in range(group_size):
if i in processes_arrived:
continue
key = f"arrival_{barrier_id}_{i}"
present = dp_store.check([key])
if present:
processes_arrived.add(i)
if len(processes_arrived) < group_size:
sched_yield()
def _staged_barrier(self, use_new_group: bool, barrier_name: str) -> bool:
"""
Execute a two-stage barrier to synchronize all engines in the DP group.
Some DP EngineCores may receive the reconfiguration notification
later than others and may already have proceeded to the engine step
(model forward) in the busy loop.
In that case, the EngineCores that have already reached
reconfiguration should skip it and execute model forward for one
more step, so that all EngineCores are synchronized on the next step.
The two-stage barrier achieves this. The first time an EngineCore
executes the barrier, a timeout before the barrier completes means
that some EngineCores have already entered the engine step. The
EngineCores that timed out then proceed to the engine step and
synchronize with the other EngineCores on the next step, using a
barrier without a timeout.
"""
dp_store = self.new_dp_store if use_new_group else self.old_dp_store
dp_group = self.new_dp_group if use_new_group else self.old_dp_group
assert dp_group is not None
group_rank = dp_group.rank()
group_size = dp_group.size()
barrier_id = f"eep_barrier_{barrier_name}"
sync_key = f"{barrier_id}_sync"
# TODO(yongji): figure out appropriate timeout for the barrier
timeout = None if dp_store.check([sync_key]) else timedelta(seconds=5)
try:
self._execute_tcp_store_barrier(
dp_store, group_rank, group_size, barrier_id, timeout=timeout
)
torch.distributed.barrier(dp_group)
if group_rank == 0:
dp_store.delete_key(sync_key)
for i in range(group_size):
dp_store.delete_key(f"arrival_{barrier_id}_{i}")
return True
except _BarrierTimeoutError as e:
if timeout is None:
raise RuntimeError("Unexpected timeout encountered") from e
dp_store.compare_set(sync_key, "", b"1")
return False
def _progress_existing_engine(self) -> bool:
state = self.state
if state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT:
return False
elif state == ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS:
# NOTE(yongji): wait for all existing workers to receive the request
if (
int(self.old_dp_store.get("eep_barrier_engine_count"))
< self.old_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=False, barrier_name="create_standby_groups"
):
return False
if self.old_dp_group.rank() == 0:
self.old_dp_store.delete_key("eep_barrier_engine_count")
self._create_standby_groups()
self.state = ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING
return True
elif state == ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING:
self._transfer_expert_mapping()
self.state = ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
return True
elif state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT:
return False
elif state == ScaleUpExistingEngineState.TRANSFER_WEIGHTS:
if (
int(self.old_dp_store.get("eep_barrier_engine_count"))
< self.old_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=False, barrier_name="transfer_weights"
):
return False
if self.old_dp_group.rank() == 0:
self.old_dp_store.delete_key("eep_barrier_engine_count")
self._transfer_weights()
self.state = ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE
return True
elif state == ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE:
self._sync_kv_cache_memory_size()
self.state = ScaleUpExistingEngineState.SWITCH_AND_PREPARE
return True
elif state == ScaleUpExistingEngineState.SWITCH_AND_PREPARE:
self._switch_and_prepare()
self.state = ScaleUpExistingEngineState.EPLB_RESHUFFLE
self.new_dp_store.add("eep_barrier_engine_count", 1)
return True
elif state == ScaleUpExistingEngineState.EPLB_RESHUFFLE:
assert self.new_dp_group is not None
if (
int(self.new_dp_store.get("eep_barrier_engine_count"))
< self.new_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=True, barrier_name="eplb_reshuffle"
):
return False
if self.new_dp_group.rank() == 0:
self.new_dp_store.delete_key("eep_barrier_engine_count")
self._eplb_reshuffle()
self.state = ScaleUpExistingEngineState.COMPLETE
self._update_parallel_config()
return True
else:
assert self.state == ScaleUpExistingEngineState.COMPLETE
return True
def _progress_new_engine(self) -> bool:
state = self.state
assert self.new_dp_group is not None
if state == ScaleUpNewEngineState.PREPARE:
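# New engines contribute zeros; the MAX all-reduce below picks up
# engines_running / current_wave / step_counter from the existing
# engines, which all-reduce their real values in _switch_and_prepare().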
tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu")
torch.distributed.all_reduce(
tensor,
op=torch.distributed.ReduceOp.MAX,
group=self.new_dp_group,
)
data = tensor.tolist()
self.engine_core.engines_running = bool(data[0])
self.engine_core.current_wave = int(data[1])
self.engine_core.step_counter = int(data[2])
self.state = ScaleUpNewEngineState.EPLB_RESHUFFLE
self.new_dp_store.add("eep_barrier_engine_count", 1)
return True
elif state == ScaleUpNewEngineState.EPLB_RESHUFFLE:
if (
int(self.new_dp_store.get("eep_barrier_engine_count"))
< self.new_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=True, barrier_name="eplb_reshuffle"
):
return False
assert self.new_dp_group.rank() > 0
self._eplb_reshuffle()
self.state = ScaleUpNewEngineState.COMPLETE
return True
else:
assert self.state == ScaleUpNewEngineState.COMPLETE
return True
def _progress_remaining_engine(self) -> bool:
state = self.state
if state == ScaleDownRemainingEngineState.PREPARE:
self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE
self.old_dp_store.add("eep_barrier_engine_count", 1)
return True
elif state == ScaleDownRemainingEngineState.EPLB_RESHUFFLE:
if (
int(self.old_dp_store.get("eep_barrier_engine_count"))
< self.old_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=False, barrier_name="eplb_reshuffle"
):
return False
if self.old_dp_group.rank() == 0:
self.old_dp_store.delete_key("eep_barrier_engine_count")
self._eplb_reshuffle_before_scale_down()
self.state = ScaleDownRemainingEngineState.SWITCH_AND_PREPARE
# NOTE(yongji): currently, after EPLB reshuffle
# that redistributes experts to remaining workers, workers
# to be removed will immediately initiate shutdown;
# existing workers can no longer execute forward steps using
# the old setup. In the future, we may keep
# the removing workers alive a bit longer,
# e.g., to drain in-batch requests.
self._create_standby_groups()
self._switch_and_prepare()
self._update_parallel_config()
self.state = ScaleDownRemainingEngineState.COMPLETE
return True
else:
assert self.state == ScaleDownRemainingEngineState.COMPLETE
return True
def _progress_removing_engine(self) -> bool:
state = self.state
if state == ScaleDownRemovingEngineState.PREPARE:
self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE
self.old_dp_store.add("eep_barrier_engine_count", 1)
return True
if state == ScaleDownRemovingEngineState.EPLB_RESHUFFLE:
if (
int(self.old_dp_store.get("eep_barrier_engine_count"))
< self.old_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=False, barrier_name="eplb_reshuffle"
):
return False
assert self.old_dp_group.rank() > 0
self._eplb_reshuffle_before_scale_down()
self._switch_and_remove()
self.state = ScaleDownRemovingEngineState.COMPLETE
self.engine_core._eep_send_engine_core_notification(
EEPNotificationType.SHUTDOWN_COMPLETE
)
self.engine_core.shutdown()
return True
else:
assert self.state == ScaleDownRemovingEngineState.COMPLETE
return True
def handle_notification(self, notification_type: EEPNotificationType):
assert self.worker_type != "new"
if (
notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY
and self.state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
):
self.old_dp_store.add("eep_barrier_engine_count", 1)
self.state = ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS
elif (
notification_type == EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
and self.state
== ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
):
self.old_dp_store.add("eep_barrier_engine_count", 1)
self.state = ScaleUpExistingEngineState.TRANSFER_WEIGHTS
def is_complete(self) -> bool:
if self.scale_type == "scale_up":
return (
self.state == ScaleUpNewEngineState.COMPLETE
if self.worker_type == "new"
else self.state == ScaleUpExistingEngineState.COMPLETE
)
return (
self.state == ScaleDownRemovingEngineState.COMPLETE
if self.worker_type == "removing"
else self.state == ScaleDownRemainingEngineState.COMPLETE
)
def _create_standby_groups(self):
self.new_dp_group, self.new_dp_store = (
self.new_parallel_config.stateless_init_dp_group(return_store=True)
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("create_standby_groups", self.reconfig_request)
)
if self.old_dp_group.rank() == 0:
logger.info("[Elastic EP] Created standby communication groups")
def _transfer_weights(self):
assert self.reconfig_request is not None
old_dp_size = self.old_dp_group.size()
new_dp_size = self.reconfig_request.new_data_parallel_size
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("transfer_weights", old_dp_size, new_dp_size)
)
if self.old_dp_group.rank() == 0:
logger.info("[Elastic EP] Transferred weights to new workers")
def _transfer_expert_mapping(self):
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("broadcast_expert_mapping",)
)
if self.old_dp_group.rank() == 0:
logger.info("[Elastic EP] Broadcasted expert mapping to new workers")
def _sync_kv_cache_memory_size(self):
assert self.engine_core.available_gpu_memory_for_kv_cache > 0
assert self.new_dp_group is not None
ParallelConfig.sync_kv_cache_memory_size(
self.new_dp_group,
self.engine_core.available_gpu_memory_for_kv_cache,
)
if self.old_dp_group.rank() == 0:
logger.info("[Elastic EP] Synced KV cache memory size to new workers")
def _switch_and_prepare(self):
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("switch_and_prepare",)
)
old_dp_group = self.old_dp_group
stateless_destroy_torch_distributed_process_group(old_dp_group)
assert self.new_dp_group is not None
new_dp_group = self.new_dp_group
self.engine_core.dp_group = new_dp_group
self.engine_core.dp_rank = new_dp_group.rank()
self.engine_core.dp_store = self.new_dp_store
engines_running = int(self.engine_core.engines_running)
current_wave = self.engine_core.current_wave
step_counter = self.engine_core.step_counter
tensor = torch.tensor(
[engines_running, current_wave, step_counter],
dtype=torch.int32,
device="cpu",
)
torch.distributed.all_reduce(
tensor, op=torch.distributed.ReduceOp.MAX, group=new_dp_group
)
data = tensor.tolist()
self.engine_core.engines_running = bool(data[0])
self.engine_core.current_wave = int(data[1])
self.engine_core.step_counter = int(data[2])
if new_dp_group.rank() == 0:
self.engine_core._eep_send_engine_core_notification(
EEPNotificationType.RECONFIGURE_FINISHED
)
logger.info("[Elastic EP] Switched to new setup")
def _eplb_reshuffle(self):
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("perform_eplb_reshuffle",)
)
assert self.new_dp_group is not None
if self.new_dp_group.rank() == 0:
logger.info("[Elastic EP] EPLB reshuffle completed")
def _eplb_reshuffle_before_scale_down(self):
assert self.reconfig_request is not None
self.model_executor.collective_rpc(
"elastic_ep_execute",
args=(
"perform_eplb_reshuffle",
self.reconfig_request.new_data_parallel_size,
),
)
if self.old_dp_group.rank() == 0:
logger.info("[Elastic EP] EPLB reshuffle completed")
def _switch_and_remove(self):
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("switch_and_remove",)
)
def _update_parallel_config(self):
assert self.reconfig_request is not None
reconfig_request = self.reconfig_request
parallel_config = self.vllm_config.parallel_config
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
if (
reconfig_request.new_data_parallel_rank
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
if (
reconfig_request.new_data_parallel_rank_local
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank_local = (
reconfig_request.new_data_parallel_rank_local
)
parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
parallel_config._data_parallel_master_port_list = (
reconfig_request.new_data_parallel_master_port_list
)
parallel_config._stateless_world_group_port_list = (
reconfig_request.new_stateless_world_group_port_list
)
parallel_config._stateless_dp_group_port_list = (
reconfig_request.new_stateless_dp_group_port_list
)
parallel_config._stateless_ep_group_port_list = (
reconfig_request.new_stateless_ep_group_port_list
)
parallel_config._stateless_eplb_group_port_list = (
reconfig_request.new_stateless_eplb_group_port_list
)

@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.parallel_state import (
_init_stateless_group,
_node_count,
get_pp_group,
get_tp_group,
get_world_group,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
_STANDBY_WORLD: StatelessGroupCoordinator | None = None
_STANDBY_WORLD_NODE_COUNT: int | None = None
_STANDBY_DP: StatelessGroupCoordinator | None = None
_STANDBY_EP: StatelessGroupCoordinator | None = None
_STANDBY_EPLB: StatelessGroupCoordinator | None = None
def get_standby_dp_group() -> StatelessGroupCoordinator | None:
return _STANDBY_DP
def get_standby_ep_group() -> StatelessGroupCoordinator | None:
return _STANDBY_EP
def get_standby_eplb_group() -> StatelessGroupCoordinator | None:
return _STANDBY_EPLB
def get_standby_world_group() -> StatelessGroupCoordinator | None:
return _STANDBY_WORLD
def create_standby_groups(
new_dp_size: int,
new_world_size_across_dp: int,
master_ip: str,
world_group_ports: list[list[int]],
dp_group_ports: list[list[int]],
ep_group_ports: list[list[int]],
eplb_group_ports: list[list[int]] | None = None,
backend: str | None = None,
) -> None:
global \
_STANDBY_WORLD, \
_STANDBY_WORLD_NODE_COUNT, \
_STANDBY_DP, \
_STANDBY_EP, \
_STANDBY_EPLB
assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size
world_group = get_world_group()
assert isinstance(world_group, StatelessGroupCoordinator)
backend = backend or world_group.backend
standby_world_ranks = [list(range(new_world_size_across_dp))]
_STANDBY_WORLD = _init_stateless_group(
standby_world_ranks,
"world",
world_group_ports,
master_ip,
backend,
use_device_communicator=False,
)
_STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group)
tp_size = get_tp_group().world_size
pp_size = get_pp_group().world_size
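# Worked example (hypothetical sizes): with tp_size=2, pp_size=1 and
# new_dp_size=2, all_ranks has shape (1, 2, 1, 2); the standby DP
# groups built below come out as [0, 2] and [1, 3], and the standby EP
# group spans all dp * tp ranks of a replica: [0, 1, 2, 3].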
all_ranks = torch.arange(new_world_size_across_dp).reshape(
-1, new_dp_size, pp_size, tp_size
)
standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0)
standby_dp_ranks = [x.tolist() for x in standby_dp_ranks]
_STANDBY_DP = _init_stateless_group(
standby_dp_ranks, "dp", dp_group_ports, master_ip, backend
)
standby_ep_ranks = (
all_ranks.transpose(1, 2).reshape(-1, new_dp_size * tp_size).unbind(0)
)
standby_ep_ranks = [x.tolist() for x in standby_ep_ranks]
_STANDBY_EP = _init_stateless_group(
standby_ep_ranks, "ep", ep_group_ports, master_ip, backend
)
if eplb_group_ports is not None:
_STANDBY_EPLB = _init_stateless_group(
standby_ep_ranks, "eplb", eplb_group_ports, master_ip, backend
)
def pop_standby_groups() -> dict:
"""Return all standby groups and clear the standby state."""
global \
_STANDBY_WORLD, \
_STANDBY_WORLD_NODE_COUNT, \
_STANDBY_DP, \
_STANDBY_EP, \
_STANDBY_EPLB
result = dict(
world=_STANDBY_WORLD,
dp=_STANDBY_DP,
ep=_STANDBY_EP,
eplb=_STANDBY_EPLB,
node_count=_STANDBY_WORLD_NODE_COUNT,
)
_STANDBY_WORLD = None
_STANDBY_WORLD_NODE_COUNT = None
_STANDBY_DP = None
_STANDBY_EP = None
_STANDBY_EPLB = None
return result