# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import gc
import weakref
from collections.abc import Iterable, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import P2POp

from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.wrapper import reset_compile_wrapper
from vllm.config import (
    CompilationMode,
    set_current_vllm_config,
)
from vllm.distributed import (
    get_dp_group,
    get_ep_group,
    get_pcp_group,
    get_tp_group,
)
from vllm.distributed.elastic_ep.standby_state import (
    create_standby_groups,
    get_standby_dp_group,
    get_standby_ep_group,
    pop_standby_groups,
)
from vllm.distributed.parallel_state import (
    _replace_active_groups,
    prepare_communication_buffer_for_model,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.workspace import lock_workspace, unlock_workspace

logger = init_logger(__name__)


def batch_transfer_weights(
    model: nn.Module,
    is_sender: bool,
    peer_rank: int,
    dp_group: StatelessGroupCoordinator,
    expert_weights: Sequence[Iterable[torch.Tensor]],
) -> None:
    device_comm = dp_group.device_communicator
    if device_comm is None:
        raise ValueError("No device communicator found")

    expert_weights_set = set()
    for weight_group in expert_weights:
        for weight in weight_group:
            expert_weights_set.add(weight.data_ptr())

    state_dict = model.state_dict()
    all_params = []
    for name, param in state_dict.items():
        if name.endswith("expert_map"):
            continue
        if param.data_ptr() not in expert_weights_set:
            all_params.append(param.data)
    assert len(all_params) > 0

    p2p_ops = []
    for param in all_params:
        # Allocate P2POp via object.__new__ to skip its constructor; only the
        # op, tensor, and group_peer fields consumed by the stateless device
        # communicator are populated.
        op = object.__new__(P2POp)
        if is_sender:
            op.op = torch.distributed.isend
            op.tensor = param
        else:
            op.op = torch.distributed.irecv
            op.tensor = param
        op.group_peer = peer_rank
        p2p_ops.append(op)
    device_comm.batch_isend_irecv(p2p_ops)


def broadcast_expert_mapping(
    physical_to_logical: torch.Tensor | None,
    num_local_physical_experts: int | None,
    num_logical_experts: int | None,
    dp_group: StatelessGroupCoordinator,
    device: torch.device,
    src_rank: int = 0,
) -> tuple[torch.Tensor, int, int]:
    if dp_group.rank_in_group == src_rank:
        assert physical_to_logical is not None
        assert num_local_physical_experts is not None
        assert num_logical_experts is not None
        assert physical_to_logical.dtype == torch.int64
        shape_tensor = torch.tensor(
            list(physical_to_logical.shape), dtype=torch.int64, device="cpu"
        )
        metadata_tensor = torch.tensor(
            [num_local_physical_experts, num_logical_experts],
            dtype=torch.int64,
            device="cpu",
        )
    else:
        shape_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
        metadata_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
    shape_tensor = dp_group.tcp_store_group.broadcast(shape_tensor, src_rank)
    metadata_tensor = dp_group.tcp_store_group.broadcast(metadata_tensor, src_rank)
    if dp_group.rank_in_group != src_rank:
        assert device is not None
        physical_to_logical = torch.empty(
            tuple(shape_tensor.tolist()),
            dtype=torch.int64,
            device=device,
        )
    assert physical_to_logical is not None
    physical_to_logical = dp_group.broadcast(physical_to_logical, src_rank)
    num_local_physical_experts = int(metadata_tensor[0].item())
    num_logical_experts = int(metadata_tensor[1].item())
    return physical_to_logical, num_local_physical_experts, num_logical_experts


class ElasticEPScalingExecutor:
    def __init__(self, worker):
        self.worker_ref = weakref.ref(worker)
        self.reconfig_request = None

    @property
    def worker(self):
        worker = self.worker_ref()
        if worker is None:
            raise RuntimeError("Worker has been garbage collected")
        return worker

    def execute(self, execute_method: str, *args, **kwargs):
        method = getattr(self, execute_method, None)
        if method is None:
            raise ValueError(f"Unknown execute method: {execute_method}")
        return method(*args, **kwargs)

    def create_standby_groups(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        self.reconfig_request = reconfig_request
        new_dp_size = reconfig_request.new_data_parallel_size
        world_size = self.worker.vllm_config.parallel_config.world_size
        new_world_size_across_dp = world_size * new_dp_size
        updated_config = copy.copy(self.worker.vllm_config)
        updated_config.parallel_config = copy.deepcopy(
            self.worker.vllm_config.parallel_config
        )
        updated_config.parallel_config.data_parallel_size = new_dp_size
        with set_current_vllm_config(updated_config):
            create_standby_groups(
                new_dp_size=new_dp_size,
                new_world_size_across_dp=new_world_size_across_dp,
                master_ip=reconfig_request.new_data_parallel_master_ip,
                world_group_ports=reconfig_request.new_stateless_world_group_port_list,
                dp_group_ports=reconfig_request.new_stateless_dp_group_port_list,
                ep_group_ports=reconfig_request.new_stateless_ep_group_port_list,
                eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list,
            )
        self.worker.model_runner.eep_eplb_suppressed = True
        standby_ep_group = get_standby_ep_group()
        assert standby_ep_group is not None
        if standby_ep_group.rank == 0:
            logger.info(
                "[Elastic EP] EPLB disabled during elastic scaling transition"
            )

    def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None:
        standby_dp_group = get_standby_dp_group()
        assert standby_dp_group is not None
        # Broadcast old_dp_size to all workers in standby group
        if standby_dp_group.rank_in_group < old_dp_size:
            old_dp_size_tensor = torch.tensor(
                [old_dp_size], dtype=torch.int64, device="cpu"
            )
        else:
            old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
        old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast(
            old_dp_size_tensor, 0
        )
        num_new_workers = new_dp_size - old_dp_size
        dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
        # Sender-receiver pairing: with k = num_new_workers // old_dp_size, the
        # first num_new_workers % old_dp_size senders each get (k + 1)
        # contiguous receivers and the remaining senders get k receivers.
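        # Illustrative example (assumed values): scaling from old_dp_size=3 to
        # new_dp_size=8 gives num_new_workers=5, num_dst_per_sender=1, and
        # remainder=2, so sender dp_rank 0 serves new ranks {3, 4}, dp_rank 1
        # serves {5, 6}, and dp_rank 2 serves {7}.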
        num_dst_per_sender = num_new_workers // old_dp_size
        remainder = num_new_workers % old_dp_size
        if dp_rank < remainder:
            recv_begin = dp_rank * (num_dst_per_sender + 1)
            recv_end = recv_begin + num_dst_per_sender + 1
        else:
            recv_begin = (
                remainder * (num_dst_per_sender + 1)
                + (dp_rank - remainder) * num_dst_per_sender
            )
            recv_end = recv_begin + num_dst_per_sender
        ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end))

        model = self.worker.model_runner.get_model()
        for new_worker_rank in sorted(ranks_to_send):
            batch_transfer_weights(
                model=model,
                is_sender=True,
                peer_rank=new_worker_rank,
                dp_group=standby_dp_group,
                expert_weights=model.expert_weights,
            )
        torch.cuda.synchronize()

    def broadcast_expert_mapping(self) -> None:
        standby_dp_group = get_standby_dp_group()
        assert standby_dp_group is not None
        model_config = self.worker.model_runner.model_config
        eplb_state = self.worker.model_runner.eplb_state
        assert eplb_state is not None
        eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
        physical_to_logical = eplb_model_state.physical_to_logical_map
        num_physical_experts = physical_to_logical.shape[1]
        num_local_physical_experts = num_physical_experts // get_ep_group().world_size
        num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
        broadcast_expert_mapping(
            physical_to_logical=physical_to_logical,
            num_local_physical_experts=num_local_physical_experts,
            num_logical_experts=num_logical_experts,
            dp_group=standby_dp_group,
            src_rank=0,
            device=self.worker.device,
        )

    def switch_and_remove(self) -> None:
        _replace_active_groups(world=None, dp=None, ep=None, eplb=None, node_count=None)

    def switch_and_prepare(self) -> None:
        old_dp_size = get_dp_group().world_size
        old_ep_size = get_ep_group().world_size
        _replace_active_groups(**pop_standby_groups())

        parallel_config = self.worker.vllm_config.parallel_config
        reconfig_request = self.reconfig_request
        assert reconfig_request is not None
        new_dp_size = reconfig_request.new_data_parallel_size
        new_ep_size = get_ep_group().world_size
        parallel_config.data_parallel_size = new_dp_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        if (
            reconfig_request.new_data_parallel_rank_local
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank_local = (
                reconfig_request.new_data_parallel_rank_local
            )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )

        # Reconfigure MoE modules with new EP size
        moe_modules = [
            module
            for module in self.worker.model_runner.model.modules()
            if (
                module.__class__.__name__ == "FusedMoE"
                or module.__class__.__name__ == "SharedFusedMoE"
            )
        ]
        num_local_experts = moe_modules[0].moe_config.num_local_experts
        assert all(
            module.moe_config.num_local_experts == num_local_experts
            for module in moe_modules
        ), "All MoE modules must have the same number of experts"
        for module in moe_modules:
            module.moe_config.num_experts = num_local_experts * new_ep_size
            module.global_num_experts = module.moe_config.num_experts
            tp_size = get_tp_group().world_size
            is_sequence_parallel = parallel_config.use_sequence_parallel_moe
            sp_size = tp_size if is_sequence_parallel else 1
            module.moe_parallel_config = FusedMoEParallelConfig.make(
                tp_size_=tp_size,
                pcp_size_=get_pcp_group().world_size,
                dp_size_=get_dp_group().world_size,
                sp_size_=sp_size,
                vllm_parallel_config=parallel_config,
            )
            module.moe_config.moe_parallel_config = module.moe_parallel_config

        # Update EPLB state
        eplb_state = self.worker.model_runner.eplb_state
        assert eplb_state is not None
        model_config = self.worker.model_runner.model_config
        eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
        num_physical_experts = num_local_experts * new_ep_size
        num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
        parallel_config.eplb_config.num_redundant_experts = (
            num_physical_experts - num_logical_experts
        )
        old_physical_to_logical = eplb_model_state.physical_to_logical_map
        num_moe_layers = old_physical_to_logical.shape[0]
        num_local_experts = eplb_model_state.expert_load_pass.shape[1] // old_ep_size
        if new_dp_size > old_dp_size:
            expanded_physical_to_logical = torch.full(
                (num_moe_layers, num_local_experts * new_ep_size),
                -1,
                dtype=old_physical_to_logical.dtype,
                device=old_physical_to_logical.device,
            )
            expanded_physical_to_logical[:, : num_local_experts * old_ep_size] = (
                old_physical_to_logical
            )
            eplb_model_state.physical_to_logical_map = expanded_physical_to_logical
        old_num_physical_experts = eplb_model_state.expert_load_pass.shape[1]
        pad_size = num_physical_experts - old_num_physical_experts
        if new_dp_size > old_dp_size:
            assert pad_size > 0
            expanded_expert_load_pass = F.pad(
                eplb_model_state.expert_load_pass, (0, pad_size), value=0
            )
            expanded_expert_load_window = F.pad(
                eplb_model_state.expert_load_window, (0, pad_size), value=0
            )
            eplb_model_state.expert_load_pass = expanded_expert_load_pass
            eplb_model_state.expert_load_window = expanded_expert_load_window
            eplb_state.num_valid_physical_experts = old_num_physical_experts
        else:
            assert pad_size < 0
            eplb_model_state.expert_load_pass = eplb_model_state.expert_load_pass[
                :, :num_physical_experts
            ]
            eplb_model_state.expert_load_window = eplb_model_state.expert_load_window[
                :, :, :num_physical_experts
            ]
            eplb_state.num_valid_physical_experts = num_physical_experts

        model = self.worker.model_runner.get_model()
        model.expert_weights = []
        with set_current_vllm_config(self.worker.vllm_config):
            model.set_eplb_state(
                eplb_model_state.expert_load_pass,
                eplb_model_state.logical_to_physical_map,
                eplb_model_state.logical_replica_count,
            )
            model.update_physical_experts_metadata(
                num_physical_experts=num_physical_experts,
                num_local_physical_experts=num_local_experts,
            )
            # Force re-creation of the modular kernel (and all2all manager)
            # for the new EP size by resetting quant_method to base
            for module in moe_modules:
                if hasattr(module.quant_method, "old_quant_method"):
                    module.quant_method = module.quant_method.old_quant_method
                module.runner = module._init_runner()
            prepare_communication_buffer_for_model(self.worker.model_runner.model)

        if (
            self.worker.vllm_config.compilation_config.mode
            == CompilationMode.STOCK_TORCH_COMPILE
        ):
            # NOTE(yongji): when using stock torch.compile, torch.compile is
            # triggered during GPUModelRunner's load_model().
            # TODO(yongji): check whether we need to re-trigger torch.compile
            # here; any changes to the tensor shapes during execution should
            # already be handled internally by torch.compile.
            backend = self.worker.vllm_config.compilation_config.init_backend(
                self.worker.vllm_config
            )
            compilation_counter.stock_torch_compile_count += 1
            self.worker.model_runner.model.compile(fullgraph=True, backend=backend)

        # release all previously captured CUDA graphs
        if isinstance(self.worker.model_runner.model, CUDAGraphWrapper):
            wrapper = self.worker.model_runner.model
            wrapper.concrete_cudagraph_entries = {}
        elif isinstance(self.worker.model_runner.model, UBatchWrapper):
            raise RuntimeError("DBO is not yet supported in elastic EP")

        multi_block_table = self.worker.model_runner.input_batch.block_table
        saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = []
        for bt in multi_block_table.block_tables:
            saved_block_tables.append(
                (bt.block_table.gpu.clone(), bt.block_table.cpu.clone())
            )
        multi_block_table.clear()

        # reset the compile wrapper
        torch.compiler.reset()
        with set_current_vllm_config(self.worker.vllm_config):
            reset_compile_wrapper(self.worker.model_runner.get_model())
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

        unlock_workspace()
        self.worker.compile_or_warm_up_model()
        lock_workspace()

        for bt, (saved_gpu, saved_cpu) in zip(
            multi_block_table.block_tables, saved_block_tables
        ):
            bt.block_table.gpu.copy_(saved_gpu)
            bt.block_table.cpu.copy_(saved_cpu)

    def perform_eplb_reshuffle(self, new_dp_size: int | None = None) -> None:
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding...")
        eplb_state = self.worker.model_runner.eplb_state
        assert eplb_state is not None
        model_config = self.worker.model_runner.model_config
        eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
        is_async_enabled = eplb_state.is_async
        eplb_state.is_async = False
        if new_dp_size is None:
            eplb_state.rearrange()
        else:
            # scale down
            parallel_config = self.worker.vllm_config.parallel_config
            tp_size = parallel_config.tensor_parallel_size
            old_ep_size = parallel_config.data_parallel_size * tp_size
            new_ep_size = new_dp_size * tp_size
            rank_mapping = {
                old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
                for old_ep_rank in range(old_ep_size)
            }
            eplb_state.rearrange(rank_mapping=rank_mapping)
        # NOTE(yongji): check whether we need to synchronize here
        torch.cuda.synchronize()
        # reset expert_rearrangement_step to ensure all ranks are synchronized
        eplb_state.expert_rearrangement_step = 0
        eplb_state.num_valid_physical_experts = (
            eplb_model_state.physical_to_logical_map.shape[1]
        )
        eplb_state.is_async = is_async_enabled
        self.worker.model_runner.eep_eplb_suppressed = False
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed")

    def receive_weights(self) -> None:
        dp_group = get_dp_group()
        assert isinstance(dp_group, StatelessGroupCoordinator)
        new_dp_size = dp_group.world_size
        dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
        # Receive old_dp_size broadcasted during transfer_weights
        old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
        old_dp_size_tensor = dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0)
        old_dp_size = int(old_dp_size_tensor[0].item())
        # Calculate which existing worker will send to this new worker
        num_new_workers = new_dp_size - old_dp_size
        new_worker_idx = dp_rank - old_dp_size
        num_dst_per_sender = num_new_workers // old_dp_size
        remainder = num_new_workers % old_dp_size
        if new_worker_idx < remainder * (num_dst_per_sender + 1):
            sender_rank = new_worker_idx // (num_dst_per_sender + 1)
        else:
            sender_rank = (
                remainder
                + (new_worker_idx - remainder * (num_dst_per_sender + 1))
                // num_dst_per_sender
            )
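        # Mirror of the pairing in transfer_weights (illustrative values): with
        # old_dp_size=3 and new_dp_size=8, new_worker_idx 0 and 1 receive from
        # sender dp_rank 0, idx 2 and 3 from dp_rank 1, and idx 4 from
        # dp_rank 2.
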
        model = self.worker.model_runner.get_model()
        batch_transfer_weights(
            model=model,
            is_sender=False,
            peer_rank=sender_rank,
            dp_group=dp_group,
            expert_weights=model.expert_weights,
        )
        torch.cuda.synchronize()

    def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]:
        dp_group = get_dp_group()
        assert isinstance(dp_group, StatelessGroupCoordinator)
        physical_to_logical, num_local_physical_experts, num_logical_experts = (
            broadcast_expert_mapping(
                physical_to_logical=None,
                num_local_physical_experts=None,
                num_logical_experts=None,
                dp_group=dp_group,
                src_rank=0,
                device=self.worker.device,
            )
        )
        num_moe_layers = physical_to_logical.shape[0]
        new_dp_size = get_dp_group().world_size
        tp_size = self.worker.vllm_config.parallel_config.tensor_parallel_size
        new_ep_size = new_dp_size * tp_size
        expanded_physical_to_logical = torch.full(
            (num_moe_layers, num_local_physical_experts * new_ep_size),
            -1,
            dtype=physical_to_logical.dtype,
            device=physical_to_logical.device,
        )
        old_num_physical_experts = physical_to_logical.shape[1]
        expanded_physical_to_logical[:, :old_num_physical_experts] = physical_to_logical
        return (
            expanded_physical_to_logical,
            num_logical_experts,
            old_num_physical_experts,
        )

    def prepare_new_worker(self) -> None:
        with set_current_vllm_config(self.worker.vllm_config):
            prepare_communication_buffer_for_model(self.worker.model_runner.get_model())
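

# Illustrative driver sequence (assumed orchestration; the actual call sites
# live in the engine/executor layer outside this module). For a scale-up from
# 2 to 4 DP ranks, existing workers might be driven through execute() as:
#
#     executor = ElasticEPScalingExecutor(worker)
#     executor.execute("create_standby_groups", reconfig_request)
#     executor.execute("transfer_weights", old_dp_size=2, new_dp_size=4)
#     executor.execute("broadcast_expert_mapping")
#     executor.execute("switch_and_prepare")
#     executor.execute("perform_eplb_reshuffle")
#
# while newly added workers would instead call receive_weights(),
# receive_expert_mapping(), and prepare_new_worker().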