Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
0
vllm/distributed/elastic_ep/__init__.py
Normal file
0
vllm/distributed/elastic_ep/__init__.py
Normal file
529
vllm/distributed/elastic_ep/elastic_execute.py
Normal file
529
vllm/distributed/elastic_ep/elastic_execute.py
Normal file
@@ -0,0 +1,529 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import gc
|
||||
import weakref
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed import P2POp
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.cuda_graph import CUDAGraphWrapper
|
||||
from vllm.compilation.wrapper import reset_compile_wrapper
|
||||
from vllm.config import (
|
||||
CompilationMode,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed import (
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
get_pcp_group,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.distributed.elastic_ep.standby_state import (
|
||||
create_standby_groups,
|
||||
get_standby_dp_group,
|
||||
get_standby_ep_group,
|
||||
pop_standby_groups,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
_replace_active_groups,
|
||||
prepare_communication_buffer_for_model,
|
||||
)
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
|
||||
from vllm.v1.worker.workspace import lock_workspace, unlock_workspace
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def batch_transfer_weights(
    model: nn.Module,
    is_sender: bool,
    peer_rank: int,
    dp_group: StatelessGroupCoordinator,
    expert_weights: Sequence[Iterable[torch.Tensor]],
) -> None:
    """Batch-send or batch-receive all non-expert weights of ``model``.

    Tensors whose storage appears in ``expert_weights`` are skipped, as are
    per-layer ``expert_map`` buffers — expert weights are redistributed
    separately (via the EPLB reshuffle), not point-to-point here.

    Args:
        model: Model whose ``state_dict`` tensors are transferred.
        is_sender: ``True`` to ``isend`` every tensor to ``peer_rank``;
            ``False`` to ``irecv`` into the local tensors in place.
        peer_rank: Peer rank within ``dp_group`` to exchange with.
        dp_group: Group coordinator whose device communicator performs the
            batched point-to-point exchange.
        expert_weights: Groups of expert tensors to exclude from transfer.

    Raises:
        ValueError: If ``dp_group`` has no device communicator.
    """
    device_comm = dp_group.device_communicator
    if device_comm is None:
        raise ValueError("No device communicator found")

    # Identify expert tensors by storage pointer so any aliasing views are
    # excluded as well.
    expert_weights_set = {
        weight.data_ptr()
        for weight_group in expert_weights
        for weight in weight_group
    }

    all_params = [
        param.data
        for name, param in model.state_dict().items()
        if not name.endswith("expert_map")
        and param.data_ptr() not in expert_weights_set
    ]
    assert len(all_params) > 0

    # Build P2POp entries without invoking P2POp.__init__ — presumably to
    # bypass its validation against the default process group, which is not
    # the group used here (TODO confirm against torch.distributed internals).
    transfer_fn = torch.distributed.isend if is_sender else torch.distributed.irecv
    p2p_ops = []
    for param in all_params:
        op = object.__new__(P2POp)
        op.op = transfer_fn
        op.tensor = param
        op.group_peer = peer_rank
        p2p_ops.append(op)
    device_comm.batch_isend_irecv(p2p_ops)
|
||||
|
||||
|
||||
def broadcast_expert_mapping(
    physical_to_logical: torch.Tensor | None,
    num_local_physical_experts: int | None,
    num_logical_experts: int | None,
    dp_group: StatelessGroupCoordinator,
    device: torch.device,
    src_rank: int = 0,
) -> tuple[torch.Tensor, int, int]:
    """Broadcast the physical->logical expert mapping from ``src_rank``.

    The source rank must supply ``physical_to_logical`` (int64, 2-D:
    presumably (num_moe_layers, num_physical_experts) — confirm with caller),
    plus the two scalar counts; other ranks pass ``None`` and receive them.

    Args:
        physical_to_logical: Mapping tensor on the source rank, else ``None``.
        num_local_physical_experts: Per-rank expert count on source, else
            ``None``.
        num_logical_experts: Logical expert count on source, else ``None``.
        dp_group: Group used for the broadcast; its ``tcp_store_group``
            carries the CPU-side shape/metadata, its device communicator the
            tensor payload.
        device: Device on which non-source ranks allocate the received tensor.
        src_rank: Rank that owns the authoritative mapping.

    Returns:
        ``(physical_to_logical, num_local_physical_experts,
        num_logical_experts)`` on every rank.
    """
    if dp_group.rank_in_group == src_rank:
        assert physical_to_logical is not None
        assert num_local_physical_experts is not None
        assert num_logical_experts is not None
        assert physical_to_logical.dtype == torch.int64
        shape_tensor = torch.tensor(
            list(physical_to_logical.shape), dtype=torch.int64, device="cpu"
        )
        metadata_tensor = torch.tensor(
            [num_local_physical_experts, num_logical_experts],
            dtype=torch.int64,
            device="cpu",
        )
    else:
        # Receivers learn the mapping's shape and the two counts first.
        shape_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
        metadata_tensor = torch.empty(2, dtype=torch.int64, device="cpu")

    # Shape/metadata travel over the CPU-side TCP-store group so receivers
    # can size the device allocation before the device broadcast.
    shape_tensor = dp_group.tcp_store_group.broadcast(shape_tensor, src_rank)
    metadata_tensor = dp_group.tcp_store_group.broadcast(metadata_tensor, src_rank)

    if dp_group.rank_in_group != src_rank:
        assert device is not None
        physical_to_logical = torch.empty(
            tuple(shape_tensor.tolist()),
            dtype=torch.int64,
            device=device,
        )

    assert physical_to_logical is not None
    physical_to_logical = dp_group.broadcast(physical_to_logical, src_rank)
    num_local_physical_experts = int(metadata_tensor[0].item())
    num_logical_experts = int(metadata_tensor[1].item())

    return physical_to_logical, num_local_physical_experts, num_logical_experts
|
||||
|
||||
|
||||
class ElasticEPScalingExecutor:
    """Worker-side executor for elastic expert-parallel (EP) scaling.

    Holds only a weak reference to the owning worker and dispatches the
    individual steps of a scale-up/scale-down transition (standby group
    creation, weight transfer, group switch, EPLB reshuffle) by name via
    :meth:`execute`.
    """

    def __init__(self, worker):
        # Weak reference avoids a reference cycle with the worker that owns
        # this executor.
        self.worker_ref = weakref.ref(worker)
        # Set by create_standby_groups(); consumed by switch_and_prepare().
        self.reconfig_request = None

    @property
    def worker(self):
        """Return the owning worker, or raise if it was collected."""
        worker = self.worker_ref()
        if worker is None:
            raise RuntimeError("Worker has been garbage collected")
        return worker

    def execute(self, execute_method: str, *args, **kwargs):
        """Dispatch ``execute_method`` by name to a method of this object.

        Raises:
            ValueError: If no such method exists.
        """
        method = getattr(self, execute_method, None)
        if method is None:
            raise ValueError(f"Unknown execute method: {execute_method}")
        return method(*args, **kwargs)

    def create_standby_groups(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Create standby communication groups sized for the new DP world.

        The active groups stay untouched; the standby groups are swapped in
        later by :meth:`switch_and_prepare`. EPLB is suppressed for the
        duration of the transition.
        """
        self.reconfig_request = reconfig_request
        new_dp_size = reconfig_request.new_data_parallel_size
        world_size = self.worker.vllm_config.parallel_config.world_size
        new_world_size_across_dp = world_size * new_dp_size
        # Shallow-copy the config but deep-copy parallel_config so the new
        # data_parallel_size is visible only inside the context below.
        updated_config = copy.copy(self.worker.vllm_config)
        updated_config.parallel_config = copy.deepcopy(
            self.worker.vllm_config.parallel_config
        )
        updated_config.parallel_config.data_parallel_size = new_dp_size
        with set_current_vllm_config(updated_config):
            # Calls the module-level create_standby_groups helper.
            create_standby_groups(
                new_dp_size=new_dp_size,
                new_world_size_across_dp=new_world_size_across_dp,
                master_ip=reconfig_request.new_data_parallel_master_ip,
                world_group_ports=reconfig_request.new_stateless_world_group_port_list,
                dp_group_ports=reconfig_request.new_stateless_dp_group_port_list,
                ep_group_ports=reconfig_request.new_stateless_ep_group_port_list,
                eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list,
            )
        self.worker.model_runner.eep_eplb_suppressed = True
        standby_ep_group = get_standby_ep_group()
        assert standby_ep_group is not None
        if standby_ep_group.rank == 0:
            logger.info("[Elastic EP] EPLB disabled during elastic scaling transition")

    def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None:
        """Send non-expert weights from existing workers to new workers.

        Counterpart of :meth:`receive_weights`; both sides must agree on the
        sender/receiver pairing computed below.
        """
        standby_dp_group = get_standby_dp_group()
        assert standby_dp_group is not None
        # Broadcast old_dp_size so new workers (which don't know it) can
        # recompute the same pairing in receive_weights().
        if standby_dp_group.rank_in_group < old_dp_size:
            old_dp_size_tensor = torch.tensor(
                [old_dp_size], dtype=torch.int64, device="cpu"
            )
        else:
            old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
        old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast(
            old_dp_size_tensor, 0
        )

        num_new_workers = new_dp_size - old_dp_size
        dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank

        # Sender-receiver pairing: the first (num_new_workers % old_dp_size)
        # senders get (k+1) contiguous receivers, the rest get k receivers.
        num_dst_per_sender = num_new_workers // old_dp_size
        remainder = num_new_workers % old_dp_size

        if dp_rank < remainder:
            recv_begin = dp_rank * (num_dst_per_sender + 1)
            recv_end = recv_begin + num_dst_per_sender + 1
        else:
            recv_begin = (
                remainder * (num_dst_per_sender + 1)
                + (dp_rank - remainder) * num_dst_per_sender
            )
            recv_end = recv_begin + num_dst_per_sender

        ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end))

        model = self.worker.model_runner.get_model()
        for new_worker_rank in sorted(ranks_to_send):
            batch_transfer_weights(
                model=model,
                is_sender=True,
                peer_rank=new_worker_rank,
                dp_group=standby_dp_group,
                expert_weights=model.expert_weights,
            )
        torch.cuda.synchronize()

    def broadcast_expert_mapping(self) -> None:
        """Broadcast this rank's expert mapping over the standby DP group.

        Source side of the exchange completed by
        :meth:`receive_expert_mapping` on new workers.
        """
        standby_dp_group = get_standby_dp_group()
        assert standby_dp_group is not None
        model_config = self.worker.model_runner.model_config
        eplb_state = self.worker.model_runner.eplb_state
        assert eplb_state is not None
        eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
        physical_to_logical = eplb_model_state.physical_to_logical_map
        num_physical_experts = physical_to_logical.shape[1]
        num_local_physical_experts = num_physical_experts // get_ep_group().world_size
        num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
        # Calls the module-level broadcast_expert_mapping helper.
        broadcast_expert_mapping(
            physical_to_logical=physical_to_logical,
            num_local_physical_experts=num_local_physical_experts,
            num_logical_experts=num_logical_experts,
            dp_group=standby_dp_group,
            src_rank=0,
            device=self.worker.device,
        )

    def switch_and_remove(self) -> None:
        """Drop the active groups (used on workers leaving the cluster)."""
        _replace_active_groups(world=None, dp=None, ep=None, eplb=None, node_count=None)

    def switch_and_prepare(self) -> None:
        """Swap standby groups in and rebuild EP-size-dependent state.

        Updates the parallel config, resizes every FusedMoE module, expands
        or shrinks the EPLB bookkeeping tensors, rebuilds communication
        buffers and compilation artifacts, and re-warms the model while
        preserving the live block tables.

        NOTE(review): indentation was lost in the extracted source; the exact
        extent of the ``with set_current_vllm_config`` scope below was
        reconstructed — confirm against upstream before relying on it.
        """
        old_dp_size = get_dp_group().world_size
        old_ep_size = get_ep_group().world_size

        _replace_active_groups(**pop_standby_groups())

        parallel_config = self.worker.vllm_config.parallel_config
        reconfig_request = self.reconfig_request
        assert reconfig_request is not None
        new_dp_size = reconfig_request.new_data_parallel_size
        new_ep_size = get_ep_group().world_size

        parallel_config.data_parallel_size = new_dp_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        if (
            reconfig_request.new_data_parallel_rank_local
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank_local = (
                reconfig_request.new_data_parallel_rank_local
            )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )

        # Reconfigure MoE modules with the new EP size.
        moe_modules = [
            module
            for module in self.worker.model_runner.model.modules()
            if (
                module.__class__.__name__ == "FusedMoE"
                or module.__class__.__name__ == "SharedFusedMoE"
            )
        ]
        num_local_experts = moe_modules[0].moe_config.num_local_experts
        assert all(
            module.moe_config.num_local_experts == num_local_experts
            for module in moe_modules
        ), "All MoE modules must have the same number of experts"
        # tp_size / sp_size are loop-invariant; hoisted out of the loop.
        tp_size = get_tp_group().world_size
        is_sequence_parallel = parallel_config.use_sequence_parallel_moe
        sp_size = tp_size if is_sequence_parallel else 1
        for module in moe_modules:
            module.moe_config.num_experts = num_local_experts * new_ep_size
            module.global_num_experts = module.moe_config.num_experts
            module.moe_parallel_config = FusedMoEParallelConfig.make(
                tp_size_=tp_size,
                pcp_size_=get_pcp_group().world_size,
                dp_size_=get_dp_group().world_size,
                sp_size_=sp_size,
                vllm_parallel_config=parallel_config,
            )
            module.moe_config.moe_parallel_config = module.moe_parallel_config

        # Update EPLB state.
        eplb_state = self.worker.model_runner.eplb_state
        assert eplb_state is not None
        model_config = self.worker.model_runner.model_config
        eplb_model_state = eplb_state.model_states[model_config.compute_hash()]

        num_physical_experts = num_local_experts * new_ep_size
        num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
        parallel_config.eplb_config.num_redundant_experts = (
            num_physical_experts - num_logical_experts
        )
        old_physical_to_logical = eplb_model_state.physical_to_logical_map
        num_moe_layers = old_physical_to_logical.shape[0]
        # Recomputed from the load-pass tensor and the OLD ep size.
        num_local_experts = eplb_model_state.expert_load_pass.shape[1] // old_ep_size
        if new_dp_size > old_dp_size:
            # Scale-up: grow the mapping, fill the new tail with -1 (invalid).
            expanded_physical_to_logical = torch.full(
                (num_moe_layers, num_local_experts * new_ep_size),
                -1,
                dtype=old_physical_to_logical.dtype,
                device=old_physical_to_logical.device,
            )
            expanded_physical_to_logical[:, : num_local_experts * old_ep_size] = (
                old_physical_to_logical
            )
            eplb_model_state.physical_to_logical_map = expanded_physical_to_logical

        old_num_physical_experts = eplb_model_state.expert_load_pass.shape[1]
        pad_size = num_physical_experts - old_num_physical_experts
        if new_dp_size > old_dp_size:
            assert pad_size > 0
            expanded_expert_load_pass = F.pad(
                eplb_model_state.expert_load_pass, (0, pad_size), value=0
            )
            expanded_expert_load_window = F.pad(
                eplb_model_state.expert_load_window, (0, pad_size), value=0
            )
            eplb_model_state.expert_load_pass = expanded_expert_load_pass
            eplb_model_state.expert_load_window = expanded_expert_load_window
            eplb_state.num_valid_physical_experts = old_num_physical_experts
        else:
            # Scale-down: truncate to the new physical expert count.
            assert pad_size < 0
            eplb_model_state.expert_load_pass = eplb_model_state.expert_load_pass[
                :, :num_physical_experts
            ]
            eplb_model_state.expert_load_window = eplb_model_state.expert_load_window[
                :, :, :num_physical_experts
            ]
            eplb_state.num_valid_physical_experts = num_physical_experts

        model = self.worker.model_runner.get_model()
        model.expert_weights = []
        with set_current_vllm_config(self.worker.vllm_config):
            model.set_eplb_state(
                eplb_model_state.expert_load_pass,
                eplb_model_state.logical_to_physical_map,
                eplb_model_state.logical_replica_count,
            )
            model.update_physical_experts_metadata(
                num_physical_experts=num_physical_experts,
                num_local_physical_experts=num_local_experts,
            )
        # Force re-creation of the modular kernel (and all2all manager)
        # for the new EP size by resetting quant_method to base.
        for module in moe_modules:
            if hasattr(module.quant_method, "old_quant_method"):
                module.quant_method = module.quant_method.old_quant_method
            module.runner = module._init_runner()
        prepare_communication_buffer_for_model(self.worker.model_runner.model)
        if (
            self.worker.vllm_config.compilation_config.mode
            == CompilationMode.STOCK_TORCH_COMPILE
        ):
            # NOTE(yongji): when using stock torch.compile,
            # torch.compile is triggered during GPUModelRunner's load_model()
            # TODO(yongji): check do we need to re-trigger torch.compile here?
            # any changes to the tensor shapes in execution should already
            # be handled internally by torch.compile.
            backend = self.worker.vllm_config.compilation_config.init_backend(
                self.worker.vllm_config
            )
            compilation_counter.stock_torch_compile_count += 1
            self.worker.model_runner.model.compile(fullgraph=True, backend=backend)

        # Release all previously captured CUDA graphs.
        if isinstance(self.worker.model_runner.model, CUDAGraphWrapper):
            wrapper = self.worker.model_runner.model
            wrapper.concrete_cudagraph_entries = {}
        elif isinstance(self.worker.model_runner.model, UBatchWrapper):
            raise RuntimeError("DBO is not yet supported in elastic EP")

        # Preserve live block tables across the warm-up below, which clears
        # and re-creates them.
        multi_block_table = self.worker.model_runner.input_batch.block_table
        saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = []
        for bt in multi_block_table.block_tables:
            saved_block_tables.append(
                (bt.block_table.gpu.clone(), bt.block_table.cpu.clone())
            )
        multi_block_table.clear()

        # Reset the compile wrapper so the model recompiles for the new sizes.
        torch.compiler.reset()
        with set_current_vllm_config(self.worker.vllm_config):
            reset_compile_wrapper(self.worker.model_runner.get_model())

        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        unlock_workspace()
        self.worker.compile_or_warm_up_model()
        lock_workspace()

        # Restore the saved block tables.
        for bt, (saved_gpu, saved_cpu) in zip(
            multi_block_table.block_tables, saved_block_tables
        ):
            bt.block_table.gpu.copy_(saved_gpu)
            bt.block_table.cpu.copy_(saved_cpu)

    def perform_eplb_reshuffle(self, new_dp_size: int | None = None) -> None:
        """Run a synchronous EPLB rearrangement after the group switch.

        Args:
            new_dp_size: ``None`` for scale-up (plain rearrange); for
                scale-down, the target DP size — ranks at or beyond the new
                EP size are mapped to -1 (dropped).
        """
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding...")

        eplb_state = self.worker.model_runner.eplb_state
        assert eplb_state is not None

        model_config = self.worker.model_runner.model_config
        eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
        # Temporarily force synchronous rearrangement; restored below.
        is_async_enabled = eplb_state.is_async
        eplb_state.is_async = False
        if new_dp_size is None:
            eplb_state.rearrange()
        else:
            # Scale down: drop ranks that no longer exist.
            parallel_config = self.worker.vllm_config.parallel_config
            tp_size = parallel_config.tensor_parallel_size
            old_ep_size = parallel_config.data_parallel_size * tp_size
            new_ep_size = new_dp_size * tp_size

            rank_mapping = {
                old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
                for old_ep_rank in range(old_ep_size)
            }

            eplb_state.rearrange(rank_mapping=rank_mapping)
        # NOTE(yongji): check whether we need to synchronize here
        torch.cuda.synchronize()
        # Reset expert_rearrangement_step to ensure all ranks are synchronized.
        eplb_state.expert_rearrangement_step = 0
        eplb_state.num_valid_physical_experts = (
            eplb_model_state.physical_to_logical_map.shape[1]
        )
        eplb_state.is_async = is_async_enabled
        self.worker.model_runner.eep_eplb_suppressed = False
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed")

    def receive_weights(self) -> None:
        """New-worker side of the weight transfer.

        Recomputes the same sender/receiver pairing as
        :meth:`transfer_weights` and receives non-expert weights in place.
        """
        dp_group = get_dp_group()
        assert isinstance(dp_group, StatelessGroupCoordinator)
        new_dp_size = dp_group.world_size
        dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank

        # Receive old_dp_size broadcasted during transfer_weights().
        old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
        old_dp_size_tensor = dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0)
        old_dp_size = int(old_dp_size_tensor[0].item())

        # Calculate which existing worker will send to this new worker
        # (inverse of the pairing computed in transfer_weights()).
        num_new_workers = new_dp_size - old_dp_size
        new_worker_idx = dp_rank - old_dp_size
        num_dst_per_sender = num_new_workers // old_dp_size
        remainder = num_new_workers % old_dp_size

        if new_worker_idx < remainder * (num_dst_per_sender + 1):
            sender_rank = new_worker_idx // (num_dst_per_sender + 1)
        else:
            sender_rank = (
                remainder
                + (new_worker_idx - remainder * (num_dst_per_sender + 1))
                // num_dst_per_sender
            )

        model = self.worker.model_runner.get_model()
        batch_transfer_weights(
            model=model,
            is_sender=False,
            peer_rank=sender_rank,
            dp_group=dp_group,
            expert_weights=model.expert_weights,
        )
        torch.cuda.synchronize()

    def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]:
        """New-worker side of the expert-mapping broadcast.

        Returns:
            Tuple of (mapping expanded to the new EP size with -1 padding,
            number of logical experts, old number of physical experts).
        """
        dp_group = get_dp_group()
        assert isinstance(dp_group, StatelessGroupCoordinator)
        physical_to_logical, num_local_physical_experts, num_logical_experts = (
            broadcast_expert_mapping(
                physical_to_logical=None,
                num_local_physical_experts=None,
                num_logical_experts=None,
                dp_group=dp_group,
                src_rank=0,
                device=self.worker.device,
            )
        )
        num_moe_layers = physical_to_logical.shape[0]
        new_dp_size = get_dp_group().world_size
        tp_size = self.worker.vllm_config.parallel_config.tensor_parallel_size
        new_ep_size = new_dp_size * tp_size
        expanded_physical_to_logical = torch.full(
            (num_moe_layers, num_local_physical_experts * new_ep_size),
            -1,
            dtype=physical_to_logical.dtype,
            device=physical_to_logical.device,
        )
        old_num_physical_experts = physical_to_logical.shape[1]
        expanded_physical_to_logical[:, :old_num_physical_experts] = physical_to_logical
        return (
            expanded_physical_to_logical,
            num_logical_experts,
            old_num_physical_experts,
        )

    def prepare_new_worker(self) -> None:
        """Allocate communication buffers on a freshly joined worker."""
        with set_current_vllm_config(self.worker.vllm_config):
            prepare_communication_buffer_for_model(self.worker.model_runner.get_model())
|
||||
563
vllm/distributed/elastic_ep/elastic_state.py
Normal file
563
vllm/distributed/elastic_ep/elastic_state.py
Normal file
@@ -0,0 +1,563 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import enum
|
||||
import time
|
||||
import weakref
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import torch.distributed
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed import (
|
||||
sched_yield,
|
||||
stateless_destroy_torch_distributed_process_group,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.engine import (
|
||||
EEPNotificationType,
|
||||
ReconfigureDistributedRequest,
|
||||
ReconfigureRankType,
|
||||
)
|
||||
from vllm.v1.engine.core import DPEngineCoreProc
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
WorkerType = Literal["existing", "new", "removing"]
|
||||
|
||||
|
||||
class ScaleUpExistingEngineState(enum.IntEnum):
    """State machine for an already-running engine during EP scale-up."""

    WAIT_NEW_CORE_ENGINES_INIT = 0
    CREATE_STANDBY_GROUPS = 1
    TRANSFER_EXPERT_MAPPING = 2
    WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT = 3
    TRANSFER_WEIGHTS = 4
    SYNC_KV_CACHE_MEMORY_SIZE = 5
    SWITCH_AND_PREPARE = 6
    EPLB_RESHUFFLE = 7
    COMPLETE = 8
|
||||
|
||||
|
||||
class ScaleUpNewEngineState(enum.IntEnum):
    """State machine for a newly started engine during EP scale-up."""

    PREPARE = 0
    EPLB_RESHUFFLE = 1
    COMPLETE = 2
|
||||
|
||||
|
||||
class ScaleDownRemainingEngineState(enum.IntEnum):
    """State machine for an engine that survives an EP scale-down."""

    PREPARE = 0
    EPLB_RESHUFFLE = 1
    SWITCH_AND_PREPARE = 2
    COMPLETE = 3
|
||||
|
||||
|
||||
class ScaleDownRemovingEngineState(enum.IntEnum):
    """State machine for an engine being removed during EP scale-down."""

    PREPARE = 0
    EPLB_RESHUFFLE = 1
    COMPLETE = 2
|
||||
|
||||
|
||||
class _BarrierTimeoutError(RuntimeError):
    """Timeout in the first stage of the two-staged TCPStore-based barrier.

    The barrier synchronizes the execution of all engines in the DP group;
    a first-stage timeout signals that some engines have already proceeded
    to an engine step and the caller should retry on the next step.
    """
|
||||
|
||||
|
||||
class ElasticEPScalingState:
|
||||
def __init__(
|
||||
self,
|
||||
model_executor: "Executor",
|
||||
engine_core: "DPEngineCoreProc",
|
||||
vllm_config: "VllmConfig",
|
||||
new_parallel_config: ParallelConfig,
|
||||
worker_type: WorkerType,
|
||||
scale_type: Literal["scale_up", "scale_down"],
|
||||
reconfig_request: ReconfigureDistributedRequest | None = None,
|
||||
):
|
||||
self.model_executor_ref = weakref.ref(model_executor)
|
||||
self.engine_core_ref = weakref.ref(engine_core)
|
||||
self.vllm_config = vllm_config
|
||||
self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None
|
||||
self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None
|
||||
self.new_parallel_config: ParallelConfig = new_parallel_config
|
||||
self.new_dp_group: torch.distributed.ProcessGroup | None = (
|
||||
self.engine_core.dp_group if worker_type == "new" else None
|
||||
)
|
||||
self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None
|
||||
self.worker_type = worker_type
|
||||
self.scale_type = scale_type
|
||||
self.reconfig_request = reconfig_request
|
||||
|
||||
if scale_type == "scale_up":
|
||||
self.state = (
|
||||
ScaleUpNewEngineState.PREPARE
|
||||
if worker_type == "new"
|
||||
else ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
|
||||
)
|
||||
else:
|
||||
self.state = (
|
||||
ScaleDownRemovingEngineState.PREPARE
|
||||
if worker_type == "removing"
|
||||
else ScaleDownRemainingEngineState.PREPARE
|
||||
)
|
||||
|
||||
@property
|
||||
def model_executor(self) -> "Executor":
|
||||
model_executor = self.model_executor_ref()
|
||||
if model_executor is None:
|
||||
raise RuntimeError("Model executor has been garbage collected")
|
||||
return model_executor
|
||||
|
||||
@property
|
||||
def engine_core(self) -> "DPEngineCoreProc":
|
||||
engine_core = self.engine_core_ref()
|
||||
if engine_core is None:
|
||||
raise RuntimeError("Engine core has been garbage collected")
|
||||
return engine_core
|
||||
|
||||
def progress(self) -> bool:
|
||||
if self.scale_type == "scale_up":
|
||||
return (
|
||||
self._progress_new_engine()
|
||||
if self.worker_type == "new"
|
||||
else self._progress_existing_engine()
|
||||
)
|
||||
return (
|
||||
self._progress_removing_engine()
|
||||
if self.worker_type == "removing"
|
||||
else self._progress_remaining_engine()
|
||||
)
|
||||
|
||||
def _execute_tcp_store_barrier(
|
||||
self, dp_store, group_rank, group_size, barrier_id, timeout=None
|
||||
):
|
||||
arrival_key = f"arrival_{barrier_id}_{group_rank}"
|
||||
dp_store.set(arrival_key, b"1")
|
||||
|
||||
start_time = time.time()
|
||||
processes_arrived: set[int] = set()
|
||||
|
||||
while len(processes_arrived) < group_size:
|
||||
if (
|
||||
timeout is not None
|
||||
and time.time() - start_time > timeout.total_seconds()
|
||||
):
|
||||
raise _BarrierTimeoutError(
|
||||
f"Barrier timed out after {timeout.total_seconds()} seconds"
|
||||
)
|
||||
|
||||
for i in range(group_size):
|
||||
if i in processes_arrived:
|
||||
continue
|
||||
|
||||
key = f"arrival_{barrier_id}_{i}"
|
||||
present = dp_store.check([key])
|
||||
if present:
|
||||
processes_arrived.add(i)
|
||||
|
||||
if len(processes_arrived) < group_size:
|
||||
sched_yield()
|
||||
|
||||
def _staged_barrier(self, use_new_group: bool, barrier_name: str) -> bool:
|
||||
"""
|
||||
Execute a two-staged barrier to synchronize all engines in the DP group.
|
||||
|
||||
Some DP EngineCores may receive the reconfiguration notifications
|
||||
later than others, and already proceed to engine step (model forward)
|
||||
in the busy loop.
|
||||
In this case, EngineCores that already proceed to reconfiguration
|
||||
should skip reconfiguration and execute model forward for one more
|
||||
step, so in the next step, all EngineCores will be synchronized.
|
||||
We use a two-staged barrier to achieve this. The first time each
|
||||
EngineCore executes the barrier, if a timeout is reached before the
|
||||
barrier completes, that means some EngineCores have already entered
|
||||
engine step. The EngineCores that timed out will then proceed to
|
||||
engine step, and will synchronize with the other EngineCores in the
|
||||
next step with a barrier without timeout.
|
||||
"""
|
||||
dp_store = self.new_dp_store if use_new_group else self.old_dp_store
|
||||
dp_group = self.new_dp_group if use_new_group else self.old_dp_group
|
||||
assert dp_group is not None
|
||||
|
||||
group_rank = dp_group.rank()
|
||||
group_size = dp_group.size()
|
||||
barrier_id = f"eep_barrier_{barrier_name}"
|
||||
sync_key = f"{barrier_id}_sync"
|
||||
|
||||
# TODO(yongji): figure out appropriate timeout for the barrier
|
||||
timeout = None if dp_store.check([sync_key]) else timedelta(seconds=5)
|
||||
|
||||
try:
|
||||
self._execute_tcp_store_barrier(
|
||||
dp_store, group_rank, group_size, barrier_id, timeout=timeout
|
||||
)
|
||||
torch.distributed.barrier(dp_group)
|
||||
if group_rank == 0:
|
||||
dp_store.delete_key(sync_key)
|
||||
for i in range(group_size):
|
||||
dp_store.delete_key(f"arrival_{barrier_id}_{i}")
|
||||
return True
|
||||
except _BarrierTimeoutError as e:
|
||||
if timeout is None:
|
||||
raise RuntimeError("Unexpected timeout encountered") from e
|
||||
dp_store.compare_set(sync_key, "", b"1")
|
||||
return False
|
||||
|
||||
def _progress_existing_engine(self) -> bool:
|
||||
state = self.state
|
||||
|
||||
if state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT:
|
||||
return False
|
||||
|
||||
elif state == ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS:
|
||||
# NOTE(yongji): wait for all existing workers to receive the request
|
||||
if (
|
||||
int(self.old_dp_store.get("eep_barrier_engine_count"))
|
||||
< self.old_dp_group.size()
|
||||
):
|
||||
return False
|
||||
if not self._staged_barrier(
|
||||
use_new_group=False, barrier_name="create_standby_groups"
|
||||
):
|
||||
return False
|
||||
if self.old_dp_group.rank() == 0:
|
||||
self.old_dp_store.delete_key("eep_barrier_engine_count")
|
||||
self._create_standby_groups()
|
||||
self.state = ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING
|
||||
return True
|
||||
|
||||
elif state == ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING:
|
||||
self._transfer_expert_mapping()
|
||||
self.state = ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
|
||||
return True
|
||||
|
||||
elif state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT:
|
||||
return False
|
||||
|
||||
elif state == ScaleUpExistingEngineState.TRANSFER_WEIGHTS:
|
||||
if (
|
||||
int(self.old_dp_store.get("eep_barrier_engine_count"))
|
||||
< self.old_dp_group.size()
|
||||
):
|
||||
return False
|
||||
if not self._staged_barrier(
|
||||
use_new_group=False, barrier_name="transfer_weights"
|
||||
):
|
||||
return False
|
||||
if self.old_dp_group.rank() == 0:
|
||||
self.old_dp_store.delete_key("eep_barrier_engine_count")
|
||||
self._transfer_weights()
|
||||
self.state = ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE
|
||||
return True
|
||||
|
||||
elif state == ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE:
|
||||
self._sync_kv_cache_memory_size()
|
||||
self.state = ScaleUpExistingEngineState.SWITCH_AND_PREPARE
|
||||
return True
|
||||
|
||||
elif state == ScaleUpExistingEngineState.SWITCH_AND_PREPARE:
|
||||
self._switch_and_prepare()
|
||||
self.state = ScaleUpExistingEngineState.EPLB_RESHUFFLE
|
||||
self.new_dp_store.add("eep_barrier_engine_count", 1)
|
||||
return True
|
||||
|
||||
elif state == ScaleUpExistingEngineState.EPLB_RESHUFFLE:
|
||||
assert self.new_dp_group is not None
|
||||
if (
|
||||
int(self.new_dp_store.get("eep_barrier_engine_count"))
|
||||
< self.new_dp_group.size()
|
||||
):
|
||||
return False
|
||||
if not self._staged_barrier(
|
||||
use_new_group=True, barrier_name="eplb_reshuffle"
|
||||
):
|
||||
return False
|
||||
if self.new_dp_group.rank() == 0:
|
||||
self.new_dp_store.delete_key("eep_barrier_engine_count")
|
||||
self._eplb_reshuffle()
|
||||
self.state = ScaleUpExistingEngineState.COMPLETE
|
||||
self._update_parallel_config()
|
||||
return True
|
||||
|
||||
else:
|
||||
assert self.state == ScaleUpExistingEngineState.COMPLETE
|
||||
return True
|
||||
|
||||
    def _progress_new_engine(self) -> bool:
        """Advance the scale-up state machine on a newly added engine.

        Returns:
            True when this call made progress (state advanced, or already
            COMPLETE); False when peer engines have not reached the barrier
            yet and the call must be retried.
        """
        state = self.state
        assert self.new_dp_group is not None

        if state == ScaleUpNewEngineState.PREPARE:
            # New engines contribute zeros; the MAX all-reduce pulls the
            # authoritative (engines_running, current_wave, step_counter)
            # values from the pre-existing engines in the new DP group.
            tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu")
            torch.distributed.all_reduce(
                tensor,
                op=torch.distributed.ReduceOp.MAX,
                group=self.new_dp_group,
            )
            data = tensor.tolist()
            self.engine_core.engines_running = bool(data[0])
            self.engine_core.current_wave = int(data[1])
            self.engine_core.step_counter = int(data[2])
            self.state = ScaleUpNewEngineState.EPLB_RESHUFFLE
            # Announce arrival at the reshuffle barrier.
            self.new_dp_store.add("eep_barrier_engine_count", 1)
            return True

        elif state == ScaleUpNewEngineState.EPLB_RESHUFFLE:
            # Wait until every engine in the new DP group has checked in.
            if (
                int(self.new_dp_store.get("eep_barrier_engine_count"))
                < self.new_dp_group.size()
            ):
                return False
            if not self._staged_barrier(
                use_new_group=True, barrier_name="eplb_reshuffle"
            ):
                return False
            # Rank 0 takes the existing-engine path (which also clears the
            # barrier key), so a "new" engine is never rank 0 here.
            assert self.new_dp_group.rank() > 0
            self._eplb_reshuffle()
            self.state = ScaleUpNewEngineState.COMPLETE
            return True

        else:
            assert self.state == ScaleUpNewEngineState.COMPLETE
            return True
    def _progress_remaining_engine(self) -> bool:
        """Advance the scale-down state machine on a surviving engine.

        Returns:
            True when progress was made (or already COMPLETE); False when
            peer engines have not reached the barrier yet and the call must
            be retried.
        """
        state = self.state

        if state == ScaleDownRemainingEngineState.PREPARE:
            self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE
            # Announce arrival at the reshuffle barrier.
            self.old_dp_store.add("eep_barrier_engine_count", 1)
            return True

        elif state == ScaleDownRemainingEngineState.EPLB_RESHUFFLE:
            # Wait until every engine in the old DP group has checked in.
            if (
                int(self.old_dp_store.get("eep_barrier_engine_count"))
                < self.old_dp_group.size()
            ):
                return False
            if not self._staged_barrier(
                use_new_group=False, barrier_name="eplb_reshuffle"
            ):
                return False
            # Rank 0 clears the barrier key so it can be reused later.
            if self.old_dp_group.rank() == 0:
                self.old_dp_store.delete_key("eep_barrier_engine_count")
            self._eplb_reshuffle_before_scale_down()
            self.state = ScaleDownRemainingEngineState.SWITCH_AND_PREPARE
            # NOTE(yongji): currently, after EPLB reshuffle
            # that redistributes experts to remaining workers, workers
            # to be removed will immediately initiate shutdown;
            # existing workers can no longer execute forward steps using
            # the old setup. In the future, we may keep
            # the removing workers alive a bit longer,
            # e.g., to drain in-batch requests.
            self._create_standby_groups()
            self._switch_and_prepare()
            self._update_parallel_config()
            self.state = ScaleDownRemainingEngineState.COMPLETE
            return True

        else:
            assert self.state == ScaleDownRemainingEngineState.COMPLETE
            return True
def _progress_removing_engine(self) -> bool:
|
||||
state = self.state
|
||||
|
||||
if state == ScaleDownRemovingEngineState.PREPARE:
|
||||
self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE
|
||||
self.old_dp_store.add("eep_barrier_engine_count", 1)
|
||||
return True
|
||||
|
||||
if state == ScaleDownRemovingEngineState.EPLB_RESHUFFLE:
|
||||
if (
|
||||
int(self.old_dp_store.get("eep_barrier_engine_count"))
|
||||
< self.old_dp_group.size()
|
||||
):
|
||||
return False
|
||||
if not self._staged_barrier(
|
||||
use_new_group=False, barrier_name="eplb_reshuffle"
|
||||
):
|
||||
return False
|
||||
assert self.old_dp_group.rank() > 0
|
||||
self._eplb_reshuffle_before_scale_down()
|
||||
self._switch_and_remove()
|
||||
self.state = ScaleDownRemovingEngineState.COMPLETE
|
||||
self.engine_core._eep_send_engine_core_notification(
|
||||
EEPNotificationType.SHUTDOWN_COMPLETE
|
||||
)
|
||||
self.engine_core.shutdown()
|
||||
return True
|
||||
|
||||
else:
|
||||
assert self.state == ScaleDownRemovingEngineState.COMPLETE
|
||||
return True
|
||||
|
||||
def handle_notification(self, notification_type: EEPNotificationType):
|
||||
assert self.worker_type != "new"
|
||||
if (
|
||||
notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY
|
||||
and self.state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
|
||||
):
|
||||
self.old_dp_store.add("eep_barrier_engine_count", 1)
|
||||
self.state = ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS
|
||||
elif (
|
||||
notification_type == EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
|
||||
and self.state
|
||||
== ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
|
||||
):
|
||||
self.old_dp_store.add("eep_barrier_engine_count", 1)
|
||||
self.state = ScaleUpExistingEngineState.TRANSFER_WEIGHTS
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
if self.scale_type == "scale_up":
|
||||
return (
|
||||
self.state == ScaleUpNewEngineState.COMPLETE
|
||||
if self.worker_type == "new"
|
||||
else self.state == ScaleUpExistingEngineState.COMPLETE
|
||||
)
|
||||
return (
|
||||
self.state == ScaleDownRemovingEngineState.COMPLETE
|
||||
if self.worker_type == "removing"
|
||||
else self.state == ScaleDownRemainingEngineState.COMPLETE
|
||||
)
|
||||
|
||||
def _create_standby_groups(self):
|
||||
self.new_dp_group, self.new_dp_store = (
|
||||
self.new_parallel_config.stateless_init_dp_group(return_store=True)
|
||||
)
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("create_standby_groups", self.reconfig_request)
|
||||
)
|
||||
if self.old_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] Created standby communication groups")
|
||||
|
||||
def _transfer_weights(self):
|
||||
assert self.reconfig_request is not None
|
||||
old_dp_size = self.old_dp_group.size()
|
||||
new_dp_size = self.reconfig_request.new_data_parallel_size
|
||||
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("transfer_weights", old_dp_size, new_dp_size)
|
||||
)
|
||||
if self.old_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] Transferred weights to new workers")
|
||||
|
||||
def _transfer_expert_mapping(self):
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("broadcast_expert_mapping",)
|
||||
)
|
||||
if self.old_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] Broadcasted expert mapping to new workers")
|
||||
|
||||
def _sync_kv_cache_memory_size(self):
|
||||
assert self.engine_core.available_gpu_memory_for_kv_cache > 0
|
||||
assert self.new_dp_group is not None
|
||||
ParallelConfig.sync_kv_cache_memory_size(
|
||||
self.new_dp_group,
|
||||
self.engine_core.available_gpu_memory_for_kv_cache,
|
||||
)
|
||||
if self.old_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] Synced KV cache memory size to new workers")
|
||||
|
||||
    def _switch_and_prepare(self):
        """Switch the engine core over to the new (standby) DP setup.

        Workers swap in their standby groups via collective RPC, the old
        DP group is destroyed, the engine core's group/rank/store are
        repointed, and the engine-wide scheduling state is re-synchronized
        with a MAX all-reduce over the new group.
        """
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("switch_and_prepare",)
        )
        # Workers have switched; the old DP group is no longer needed.
        old_dp_group = self.old_dp_group
        stateless_destroy_torch_distributed_process_group(old_dp_group)
        assert self.new_dp_group is not None
        new_dp_group = self.new_dp_group
        self.engine_core.dp_group = new_dp_group
        self.engine_core.dp_rank = new_dp_group.rank()
        self.engine_core.dp_store = self.new_dp_store
        # Re-establish consensus on scheduling state across the new group:
        # each member contributes its local view and MAX selects the most
        # advanced values (fresh engines contribute zeros).
        engines_running = int(self.engine_core.engines_running)
        current_wave = self.engine_core.current_wave
        step_counter = self.engine_core.step_counter
        tensor = torch.tensor(
            [engines_running, current_wave, step_counter],
            dtype=torch.int32,
            device="cpu",
        )
        torch.distributed.all_reduce(
            tensor, op=torch.distributed.ReduceOp.MAX, group=new_dp_group
        )
        data = tensor.tolist()
        self.engine_core.engines_running = bool(data[0])
        self.engine_core.current_wave = int(data[1])
        self.engine_core.step_counter = int(data[2])
        # Only the new group's rank 0 notifies the frontend.
        if new_dp_group.rank() == 0:
            self.engine_core._eep_send_engine_core_notification(
                EEPNotificationType.RECONFIGURE_FINISHED
            )
            logger.info("[Elastic EP] Switched to new setup")
def _eplb_reshuffle(self):
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("perform_eplb_reshuffle",)
|
||||
)
|
||||
assert self.new_dp_group is not None
|
||||
if self.new_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] EPLB reshuffle completed")
|
||||
|
||||
def _eplb_reshuffle_before_scale_down(self):
|
||||
assert self.reconfig_request is not None
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute",
|
||||
args=(
|
||||
"perform_eplb_reshuffle",
|
||||
self.reconfig_request.new_data_parallel_size,
|
||||
),
|
||||
)
|
||||
if self.old_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] EPLB reshuffle completed")
|
||||
|
||||
def _switch_and_remove(self):
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("switch_and_remove",)
|
||||
)
|
||||
|
||||
def _update_parallel_config(self):
|
||||
assert self.reconfig_request is not None
|
||||
reconfig_request = self.reconfig_request
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank_local = (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
)
|
||||
parallel_config.data_parallel_master_ip = (
|
||||
reconfig_request.new_data_parallel_master_ip
|
||||
)
|
||||
parallel_config.data_parallel_master_port = (
|
||||
reconfig_request.new_data_parallel_master_port
|
||||
)
|
||||
parallel_config._data_parallel_master_port_list = (
|
||||
reconfig_request.new_data_parallel_master_port_list
|
||||
)
|
||||
parallel_config._stateless_world_group_port_list = (
|
||||
reconfig_request.new_stateless_world_group_port_list
|
||||
)
|
||||
parallel_config._stateless_dp_group_port_list = (
|
||||
reconfig_request.new_stateless_dp_group_port_list
|
||||
)
|
||||
parallel_config._stateless_ep_group_port_list = (
|
||||
reconfig_request.new_stateless_ep_group_port_list
|
||||
)
|
||||
parallel_config._stateless_eplb_group_port_list = (
|
||||
reconfig_request.new_stateless_eplb_group_port_list
|
||||
)
|
||||
117
vllm/distributed/elastic_ep/standby_state.py
Normal file
117
vllm/distributed/elastic_ep/standby_state.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.distributed.parallel_state import (
|
||||
_init_stateless_group,
|
||||
_node_count,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
get_world_group,
|
||||
)
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
|
||||
_STANDBY_WORLD: StatelessGroupCoordinator | None = None
|
||||
_STANDBY_WORLD_NODE_COUNT: int | None = None
|
||||
_STANDBY_DP: StatelessGroupCoordinator | None = None
|
||||
_STANDBY_EP: StatelessGroupCoordinator | None = None
|
||||
_STANDBY_EPLB: StatelessGroupCoordinator | None = None
|
||||
|
||||
|
||||
def get_standby_dp_group() -> StatelessGroupCoordinator | None:
    """Return the standby data-parallel group, or None if none is staged."""
    return _STANDBY_DP
||||
|
||||
|
||||
def get_standby_ep_group() -> StatelessGroupCoordinator | None:
    """Return the standby expert-parallel group, or None if none is staged."""
    return _STANDBY_EP
||||
|
||||
|
||||
def get_standby_eplb_group() -> StatelessGroupCoordinator | None:
    """Return the standby EPLB group, or None if none is staged."""
    return _STANDBY_EPLB
||||
|
||||
|
||||
def get_standby_world_group() -> StatelessGroupCoordinator | None:
    """Return the standby world group, or None if none is staged."""
    return _STANDBY_WORLD
||||
|
||||
|
||||
def create_standby_groups(
    new_dp_size: int,
    new_world_size_across_dp: int,
    master_ip: str,
    world_group_ports: list[list[int]],
    dp_group_ports: list[list[int]],
    ep_group_ports: list[list[int]],
    eplb_group_ports: list[list[int]] | None = None,
    backend: str | None = None,
) -> None:
    """Create standby world/DP/EP (and optionally EPLB) stateless groups
    sized for the post-reconfiguration topology.

    The groups are stashed in module globals until activated; retrieve
    them via the get_standby_*_group accessors or pop_standby_groups().
    The EPLB group is only created when ``eplb_group_ports`` is given.
    """
    global \
        _STANDBY_WORLD, \
        _STANDBY_WORLD_NODE_COUNT, \
        _STANDBY_DP, \
        _STANDBY_EP, \
        _STANDBY_EPLB

    # The new world spans the current per-DP world size times the new DP size.
    assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size
    world_group = get_world_group()
    assert isinstance(world_group, StatelessGroupCoordinator)
    # Default to the same backend the active world group uses.
    backend = backend or world_group.backend

    standby_world_ranks = [list(range(new_world_size_across_dp))]
    _STANDBY_WORLD = _init_stateless_group(
        standby_world_ranks,
        "world",
        world_group_ports,
        master_ip,
        backend,
        use_device_communicator=False,
    )
    _STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group)

    tp_size = get_tp_group().world_size
    pp_size = get_pp_group().world_size

    # Global ranks laid out as (outer, dp, pp, tp); transposes below carve
    # out the rank lists for each standby group along the relevant axes.
    all_ranks = torch.arange(new_world_size_across_dp).reshape(
        -1, new_dp_size, pp_size, tp_size
    )
    standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0)
    standby_dp_ranks = [x.tolist() for x in standby_dp_ranks]
    _STANDBY_DP = _init_stateless_group(
        standby_dp_ranks, "dp", dp_group_ports, master_ip, backend
    )

    # EP groups fuse the DP and TP axes (dp_size * tp_size ranks each).
    standby_ep_ranks = (
        all_ranks.transpose(1, 2).reshape(-1, new_dp_size * tp_size).unbind(0)
    )
    standby_ep_ranks = [x.tolist() for x in standby_ep_ranks]
    _STANDBY_EP = _init_stateless_group(
        standby_ep_ranks, "ep", ep_group_ports, master_ip, backend
    )

    if eplb_group_ports is not None:
        # EPLB reuses the EP rank layout but with its own ports.
        _STANDBY_EPLB = _init_stateless_group(
            standby_ep_ranks, "eplb", eplb_group_ports, master_ip, backend
        )
||||
|
||||
|
||||
def pop_standby_groups() -> dict:
    """Return all standby groups and clear the standby state."""
    global \
        _STANDBY_WORLD, \
        _STANDBY_WORLD_NODE_COUNT, \
        _STANDBY_DP, \
        _STANDBY_EP, \
        _STANDBY_EPLB

    # Snapshot everything before resetting the module-level state.
    groups = {
        "world": _STANDBY_WORLD,
        "dp": _STANDBY_DP,
        "ep": _STANDBY_EP,
        "eplb": _STANDBY_EPLB,
        "node_count": _STANDBY_WORLD_NODE_COUNT,
    }
    _STANDBY_WORLD = None
    _STANDBY_DP = None
    _STANDBY_EP = None
    _STANDBY_EPLB = None
    _STANDBY_WORLD_NODE_COUNT = None
    return groups
|
||||
Reference in New Issue
Block a user