v1.0
This commit is contained in:
8
distributed/eplb/__init__.py
Normal file
8
distributed/eplb/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Expert parallelism load balancer (EPLB).
|
||||
"""
|
||||
|
||||
from .eplb_state import *
|
||||
from .rebalance_algo import *
|
||||
BIN
distributed/eplb/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
distributed/eplb/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
distributed/eplb/__pycache__/eplb_state.cpython-312.pyc
Normal file
BIN
distributed/eplb/__pycache__/eplb_state.cpython-312.pyc
Normal file
Binary file not shown.
BIN
distributed/eplb/__pycache__/rebalance_algo.cpython-312.pyc
Normal file
BIN
distributed/eplb/__pycache__/rebalance_algo.cpython-312.pyc
Normal file
Binary file not shown.
BIN
distributed/eplb/__pycache__/rebalance_execute.cpython-312.pyc
Normal file
BIN
distributed/eplb/__pycache__/rebalance_execute.cpython-312.pyc
Normal file
Binary file not shown.
837
distributed/eplb/eplb_state.py
Normal file
837
distributed/eplb/eplb_state.py
Normal file
@@ -0,0 +1,837 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Expert parallelism load balancer (EPLB) metrics and states.
|
||||
|
||||
# Glossary
|
||||
|
||||
- **Logical Expert**: An expert that is part of the model's logical structure.
|
||||
It holds a set of weights and is replicated across multiple physical
|
||||
experts.
|
||||
- **Redundant Expert**: To achieve load balancing, for some popular logical
|
||||
experts, we create additional copies of the expert weights. During inference,
|
||||
each of these copies can be routed to by the same set of tokens.
|
||||
- **Physical Expert**: An expert that is instantiated on a specific device.
|
||||
It is a replica of a logical expert and can be rearranged across devices.
|
||||
I.e., one logical expert may have multiple sets of weights initialized on
|
||||
different devices, and each of these sets is a physical expert.
|
||||
- **Local Physical Expert**: A physical expert that is instantiated on the
|
||||
current device.
|
||||
|
||||
For example: DeepSeek-R1 has 256 logical experts, so each MoE layer
|
||||
has 256 sets of linear layer weights in the model parameters. If we add 32
|
||||
redundant experts, DeepSeek-R1 will have 256 + 32 = 288 physical experts in
|
||||
total. And when deploying, we'll have 288 sets of linear layer weights for each
|
||||
MoE layer. If we have 32 EP ranks, then each GPU will hold 288 / 32 = 9 local
|
||||
physical experts.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup, all_reduce
|
||||
|
||||
from vllm.config import ModelConfig, ParallelConfig
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_ep_group,
|
||||
get_node_count,
|
||||
in_the_same_node_as,
|
||||
)
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import MixtureOfExperts
|
||||
|
||||
from .rebalance_algo import rebalance_experts
|
||||
from .rebalance_execute import rearrange_expert_weights_inplace
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EplbModelState:
|
||||
"""EPLB metrics."""
|
||||
|
||||
physical_to_logical_map: torch.Tensor
|
||||
"""
|
||||
Mapping from physical experts to logical experts.
|
||||
|
||||
Shape: (num_moe_layers, num_physical_experts)
|
||||
|
||||
# Example
|
||||
|
||||
For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
|
||||
EP ranks, the mapping could look like this:
|
||||
|
||||
```
|
||||
[[0, 1, 2, 3, 0, 1],
|
||||
[0, 2, 0, 1, 0, 3]]
|
||||
```
|
||||
"""
|
||||
logical_to_physical_map: torch.Tensor
|
||||
"""
|
||||
Mapping from logical experts to physical experts.
|
||||
|
||||
This is a sparse matrix, where -1 indicates no mapping.
|
||||
|
||||
Shape: (num_moe_layers, num_logical_experts, num_redundant_experts + 1)
|
||||
|
||||
# Example
|
||||
|
||||
For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
|
||||
EP ranks, the mapping could look like this:
|
||||
|
||||
```
|
||||
[[[0, 4, -1],
|
||||
[1, 5, -1],
|
||||
[2, -1, -1],
|
||||
[3, -1, -1]],
|
||||
[[0, 2, 4],
|
||||
[3, -1, -1],
|
||||
[1, -1, -1],
|
||||
[5, -1, -1]]]
|
||||
```
|
||||
"""
|
||||
logical_replica_count: torch.Tensor
|
||||
"""
|
||||
Number of replicas for each logical expert.
|
||||
This is exactly the non-`-1` count in the `logical_to_physical_map`.
|
||||
|
||||
Shape: (num_moe_layers, num_logical_experts)
|
||||
|
||||
# Example
|
||||
For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
|
||||
EP ranks, the count could look like this:
|
||||
|
||||
```
|
||||
[[2, 2, 1, 1],
|
||||
[3, 1, 1, 1]]
|
||||
"""
|
||||
|
||||
expert_load_pass: torch.Tensor
|
||||
"""
|
||||
Expert load during this forward pass.
|
||||
We use the token count each expert processes as the load.
|
||||
|
||||
Shape: (num_moe_layers, num_physical_experts)
|
||||
"""
|
||||
expert_load_window: torch.Tensor
|
||||
"""
|
||||
A sliding window of expert load.
|
||||
|
||||
Shape: (window_size, num_moe_layers, num_physical_experts)
|
||||
|
||||
NOTE: The expert_load_view now records load for all physical experts
|
||||
rather than just local experts. This ensures consistent load statistics
|
||||
across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
|
||||
The recorded load will be multiplied by dp_size when using naive all-to-all
|
||||
due to each DP rank contributing the same token set to the calculation.
|
||||
See:
|
||||
https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
|
||||
"""
|
||||
model_name: str
|
||||
model: MixtureOfExperts
|
||||
|
||||
|
||||
class EplbState:
|
||||
"""
|
||||
EplbState of each expert parallel model. Key is the model config hash.
|
||||
"""
|
||||
|
||||
def __init__(self, parallel_config: ParallelConfig, device: torch.device):
|
||||
self.parallel_config = parallel_config
|
||||
self.device = device
|
||||
self.model_states: dict[str, EplbModelState] = {}
|
||||
"""
|
||||
Current step in the sliding window.
|
||||
|
||||
Different from `expert_rearrangement_step`,
|
||||
each EP rank may have its own `expert_load_window_step`.
|
||||
"""
|
||||
self.expert_load_window_step: int = 0
|
||||
"""
|
||||
Size of the expert load sliding window.
|
||||
This is a constant and is taken from the config.
|
||||
"""
|
||||
self.expert_load_window_size: int = 0
|
||||
"""
|
||||
Steps after last rearrangement.
|
||||
Will trigger a rearrangement if it exceeds the threshold.
|
||||
|
||||
NOTE: Keep in mind that all EP ranks need to have the same
|
||||
`expert_rearrangement_step` value to ensure synchronization.
|
||||
Otherwise, the rearrangement will hang at collective
|
||||
communication calls.
|
||||
"""
|
||||
self.expert_rearrangement_step: int = 0
|
||||
"""
|
||||
Interval for expert rearrangement steps.
|
||||
This is a constant and is taken from the config.
|
||||
"""
|
||||
self.expert_rearrangement_step_interval: int = 0
|
||||
|
||||
@staticmethod
|
||||
def build_initial_global_physical_to_logical_map(
|
||||
num_routed_experts: int,
|
||||
num_redundant_experts: int,
|
||||
) -> Sequence[int]:
|
||||
"""
|
||||
Build an initial expert arrangement using the following structure:
|
||||
[original routed experts, redundant experts]
|
||||
|
||||
Returns:
|
||||
physical_to_logical_map (Sequence[int]): A list of integers,
|
||||
where each integer is the index of the logical expert
|
||||
that the corresponding physical expert maps to.
|
||||
"""
|
||||
global_physical_to_logical_map = list(range(num_routed_experts))
|
||||
global_physical_to_logical_map += [
|
||||
i % num_routed_experts for i in range(num_redundant_experts)
|
||||
]
|
||||
return global_physical_to_logical_map
|
||||
|
||||
def validate_ep_configuration(self, new_model: MixtureOfExperts):
|
||||
"""
|
||||
Validate that the expert parallel configuration of
|
||||
the new model is the same as the existing models.
|
||||
"""
|
||||
if len(self.model_states) > 0:
|
||||
model = next(iter(self.model_states.values())).model
|
||||
if (
|
||||
model.num_routed_experts != new_model.num_routed_experts
|
||||
or model.num_redundant_experts != new_model.num_redundant_experts
|
||||
or model.num_physical_experts != new_model.num_physical_experts
|
||||
or model.num_logical_experts != new_model.num_logical_experts
|
||||
or model.num_expert_groups != new_model.num_expert_groups
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Model: {} "
|
||||
"with config {} "
|
||||
"{} {} {} {} "
|
||||
"mismatch with new model {} "
|
||||
"with config {} "
|
||||
"{} {} {} {}".format(
|
||||
type(model),
|
||||
model.num_routed_experts,
|
||||
model.num_redundant_experts,
|
||||
model.num_physical_experts,
|
||||
model.num_logical_experts,
|
||||
model.num_expert_groups,
|
||||
type(new_model),
|
||||
new_model.num_routed_experts,
|
||||
new_model.num_redundant_experts,
|
||||
new_model.num_physical_experts,
|
||||
new_model.num_logical_experts,
|
||||
new_model.num_expert_groups,
|
||||
)
|
||||
)
|
||||
|
||||
def add_model(
|
||||
self,
|
||||
model: MixtureOfExperts,
|
||||
model_config: ModelConfig,
|
||||
global_expert_load: torch.Tensor | None = None,
|
||||
old_global_expert_indices: torch.Tensor | None = None,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
):
|
||||
"""
|
||||
Build the initial EPLB state.
|
||||
"""
|
||||
self.validate_ep_configuration(model)
|
||||
physical_to_logical_map_list = (
|
||||
EplbState.build_initial_global_physical_to_logical_map(
|
||||
model.num_routed_experts,
|
||||
model.num_redundant_experts,
|
||||
)
|
||||
)
|
||||
physical_to_logical_map = torch.tensor(
|
||||
physical_to_logical_map_list,
|
||||
device=self.device,
|
||||
)
|
||||
# Assuming 8 GPUs per node, this supports up to
|
||||
# (1023 + 1) / 8 = 128 nodes for now.
|
||||
# TODO(rui): make this configurable
|
||||
MAX_EXPERT_REDUNDANCY = 1023
|
||||
assert model.num_redundant_experts <= MAX_EXPERT_REDUNDANCY, (
|
||||
f"num_redundant_experts {model.num_redundant_experts} "
|
||||
f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}"
|
||||
)
|
||||
max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1
|
||||
logical_to_physical_map = torch.full(
|
||||
(model.num_logical_experts, max_slots_per_logical_expert),
|
||||
-1,
|
||||
device=self.device,
|
||||
)
|
||||
logical_replica_count = torch.zeros(
|
||||
(model.num_logical_experts,),
|
||||
device=self.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
for i in range(model.num_physical_experts):
|
||||
logical_idx = physical_to_logical_map[i]
|
||||
logical_to_physical_map[logical_idx, logical_replica_count[logical_idx]] = i
|
||||
logical_replica_count[logical_idx] += 1
|
||||
|
||||
# Duplicate initial mapping for all layers
|
||||
physical_to_logical_map = (
|
||||
physical_to_logical_map.unsqueeze(0)
|
||||
.expand(
|
||||
model.num_moe_layers,
|
||||
-1,
|
||||
)
|
||||
.contiguous()
|
||||
)
|
||||
logical_to_physical_map = (
|
||||
logical_to_physical_map.unsqueeze(0)
|
||||
.expand(
|
||||
model.num_moe_layers,
|
||||
-1,
|
||||
-1,
|
||||
)
|
||||
.contiguous()
|
||||
)
|
||||
logical_replica_count = (
|
||||
logical_replica_count.unsqueeze(0)
|
||||
.expand(
|
||||
model.num_moe_layers,
|
||||
-1,
|
||||
)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
expert_load_pass = torch.zeros(
|
||||
(model.num_moe_layers, model.num_physical_experts),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.expert_load_window_size = self.parallel_config.eplb_config.window_size
|
||||
expert_load_window = torch.zeros(
|
||||
(
|
||||
self.expert_load_window_size,
|
||||
model.num_moe_layers,
|
||||
model.num_physical_experts,
|
||||
),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Set the initial progress of rearrangement to 3/4
|
||||
eplb_step_interval = self.parallel_config.eplb_config.step_interval
|
||||
self.expert_rearrangement_step = max(
|
||||
0, eplb_step_interval - eplb_step_interval // 4
|
||||
)
|
||||
self.expert_rearrangement_step_interval = eplb_step_interval
|
||||
|
||||
if global_expert_load is not None:
|
||||
ep_group = get_ep_group().device_group
|
||||
assert global_expert_load.shape == (
|
||||
model.num_moe_layers,
|
||||
model.num_logical_experts,
|
||||
)
|
||||
assert global_expert_load.dtype == torch.int64
|
||||
|
||||
num_replicas = model.num_physical_experts
|
||||
num_groups = model.num_expert_groups
|
||||
num_nodes = get_node_count()
|
||||
num_gpus = ep_group.size()
|
||||
|
||||
if num_gpus % num_nodes != 0:
|
||||
num_nodes = 1
|
||||
logger.warning_once(
|
||||
f"num_gpus % num_nodes != 0, "
|
||||
"not using hierarchical rearrangement algorithm.\n"
|
||||
f"{num_gpus=}, {num_nodes=}"
|
||||
)
|
||||
|
||||
# Get new expert mappings
|
||||
(
|
||||
new_physical_to_logical_map,
|
||||
new_logical_to_physical_map,
|
||||
new_logical_replica_count,
|
||||
) = rebalance_experts(
|
||||
global_expert_load,
|
||||
num_replicas,
|
||||
num_groups,
|
||||
num_nodes,
|
||||
num_gpus,
|
||||
)
|
||||
|
||||
max_physical_slots = new_logical_to_physical_map.shape[-1]
|
||||
assert max_physical_slots <= logical_to_physical_map.shape[-1]
|
||||
new_logical_to_physical_map = torch.nn.functional.pad(
|
||||
new_logical_to_physical_map,
|
||||
(0, logical_to_physical_map.shape[-1] - max_physical_slots),
|
||||
value=-1,
|
||||
)
|
||||
physical_to_logical_map = new_physical_to_logical_map.to(self.device)
|
||||
logical_to_physical_map.copy_(new_logical_to_physical_map)
|
||||
logical_replica_count.copy_(new_logical_replica_count)
|
||||
|
||||
model.set_eplb_state(
|
||||
expert_load_pass,
|
||||
logical_to_physical_map,
|
||||
logical_replica_count,
|
||||
)
|
||||
if global_expert_load is not None:
|
||||
rearrange_expert_weights_inplace(
|
||||
old_global_expert_indices,
|
||||
new_physical_to_logical_map,
|
||||
model.expert_weights,
|
||||
ep_group,
|
||||
False,
|
||||
rank_mapping,
|
||||
)
|
||||
self.expert_rearrangement_step = 0
|
||||
|
||||
self.model_states[model_config.compute_hash()] = EplbModelState(
|
||||
physical_to_logical_map,
|
||||
logical_to_physical_map,
|
||||
logical_replica_count,
|
||||
expert_load_pass,
|
||||
expert_load_window,
|
||||
model_config.model,
|
||||
model,
|
||||
)
|
||||
|
||||
def step(
|
||||
self,
|
||||
is_dummy: bool = False,
|
||||
is_profile: bool = False,
|
||||
log_stats: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Step the EPLB state.
|
||||
|
||||
Args:
|
||||
is_dummy (bool): If `True`, this is a dummy step and the load
|
||||
metrics recorded in this forward pass will not count.
|
||||
Defaults to `False`.
|
||||
is_profile (bool): If `True`, perform a dummy rearrangement
|
||||
with maximum communication cost. This is used in
|
||||
`profile_run` to reserve enough memory
|
||||
for the communication buffer.
|
||||
log_stats (bool): If `True`, log the expert load metrics.
|
||||
|
||||
# Stats
|
||||
The metrics are all summed up across layers.
|
||||
- `avg_tokens`: The average load across ranks.
|
||||
- `max_tokens`: The maximum load across ranks.
|
||||
- `balancedness`: The ratio of average load to maximum load.
|
||||
"""
|
||||
|
||||
if is_profile:
|
||||
self.rearrange(is_profile=True)
|
||||
return
|
||||
|
||||
if is_dummy:
|
||||
# Do not record load metrics for dummy steps
|
||||
for eplb_model_state in self.model_states.values():
|
||||
eplb_model_state.expert_load_pass.zero_()
|
||||
|
||||
if log_stats:
|
||||
# Sync the expert load pass for each model (main and drafter).
|
||||
# expert_load_pass: (num_moe_layers, num_physical_experts)
|
||||
expert_load_pass_list = self._sync_load_pass()
|
||||
ep_group = get_ep_group().device_group
|
||||
for expert_load_pass, eplb_model_state in zip(
|
||||
expert_load_pass_list, self.model_states.values()
|
||||
):
|
||||
# num_tokens_per_rank: (num_moe_layers, num_ranks)
|
||||
num_tokens_per_rank = (
|
||||
expert_load_pass.reshape(
|
||||
expert_load_pass.shape[0], ep_group.size(), -1
|
||||
)
|
||||
.sum(dim=-1)
|
||||
.float()
|
||||
)
|
||||
|
||||
# Compute balancedness ratio:
|
||||
# for each layer:
|
||||
# (mean load across ranks) / (max load across ranks)
|
||||
avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0)
|
||||
max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(dim=0)
|
||||
|
||||
# Just to make type checker happy
|
||||
tokens_tensors: list[float] = torch.stack(
|
||||
[avg_tokens_tensor, max_tokens_tensor]
|
||||
).tolist()
|
||||
avg_tokens, max_tokens = tokens_tensors
|
||||
balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0
|
||||
|
||||
if ep_group.rank() == 0:
|
||||
logger.info(
|
||||
"EPLB step: %d for model %s: avg_tokens=%.2f, "
|
||||
"max_tokens=%d, balancedness=%.4f",
|
||||
self.expert_rearrangement_step,
|
||||
eplb_model_state.model_name,
|
||||
avg_tokens,
|
||||
max_tokens,
|
||||
balancedness,
|
||||
)
|
||||
|
||||
# Update the expert load sliding window
|
||||
if not is_dummy:
|
||||
for eplb_model_state in self.model_states.values():
|
||||
eplb_model_state.expert_load_window[self.expert_load_window_step] = (
|
||||
eplb_model_state.expert_load_pass.clone()
|
||||
)
|
||||
eplb_model_state.expert_load_pass.zero_()
|
||||
|
||||
self.expert_load_window_step += 1
|
||||
if self.expert_load_window_step >= self.expert_load_window_size:
|
||||
self.expert_load_window_step = 0
|
||||
|
||||
# Step the expert rearrangement step
|
||||
# Note that even if this is a dummy step, we still increment the
|
||||
# rearrangement step and perform rearrangement to ensure all ranks are
|
||||
# performing collective communication.
|
||||
self.expert_rearrangement_step += 1
|
||||
if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
|
||||
self.expert_rearrangement_step = 0
|
||||
self.rearrange()
|
||||
|
||||
def rearrange(
|
||||
self,
|
||||
is_profile: bool = False,
|
||||
execute_shuffle: bool = True,
|
||||
global_expert_loads: list[torch.Tensor] | None = None,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
Rearrange the experts according to the current load.
|
||||
|
||||
Args:
|
||||
is_profile (bool): If `True`, perform a dummy rearrangement.
|
||||
This is used in `profile_run` to reserve enough memory,
|
||||
no memory movement will be performed. Default is False.
|
||||
execute_shuffle (bool): If `True`, execute the shuffle
|
||||
in elastic expert parallel (EEP). Default is True.
|
||||
global_expert_loads (list[torch.Tensor] | None): The global expert
|
||||
loads when scaling is done in EEP.
|
||||
List of expert loads for the main and drafter
|
||||
(when spec decode is used) models.
|
||||
rank_mapping (dict[int, int] | None): The rank mapping
|
||||
when scaling is done in EEP.
|
||||
"""
|
||||
|
||||
ep_group = get_ep_group().device_group
|
||||
ep_rank = ep_group.rank()
|
||||
|
||||
time_start = None
|
||||
is_main_rank = ep_rank == 0
|
||||
if is_main_rank:
|
||||
torch.cuda.synchronize()
|
||||
time_start = time.perf_counter()
|
||||
logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")
|
||||
|
||||
if global_expert_loads is None:
|
||||
# Map the physical expert load to global logical experts
|
||||
global_expert_load_windows = []
|
||||
if not execute_shuffle:
|
||||
num_models = torch.tensor(
|
||||
[len(self.model_states)], dtype=torch.int32, device="cpu"
|
||||
)
|
||||
torch.distributed.broadcast(
|
||||
num_models, group=get_ep_group().cpu_group, group_src=0
|
||||
)
|
||||
|
||||
for eplb_model_state in self.model_states.values():
|
||||
logical_expert_load_window = torch.zeros(
|
||||
self.expert_load_window_size,
|
||||
eplb_model_state.model.num_moe_layers,
|
||||
eplb_model_state.model.num_logical_experts,
|
||||
dtype=eplb_model_state.expert_load_window.dtype,
|
||||
device=eplb_model_state.expert_load_window.device,
|
||||
)
|
||||
logical_expert_load_window.scatter_add_(
|
||||
dim=-1,
|
||||
index=eplb_model_state.physical_to_logical_map.unsqueeze(0)
|
||||
.expand_as(eplb_model_state.expert_load_window)
|
||||
.long(),
|
||||
src=eplb_model_state.expert_load_window,
|
||||
)
|
||||
|
||||
if not execute_shuffle:
|
||||
metadata = torch.tensor(
|
||||
[
|
||||
eplb_model_state.model.num_moe_layers,
|
||||
eplb_model_state.model.num_logical_experts,
|
||||
eplb_model_state.physical_to_logical_map.shape[1],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
torch.distributed.broadcast(
|
||||
metadata, group=get_ep_group().cpu_group, group_src=0
|
||||
)
|
||||
|
||||
global_expert_load_window = logical_expert_load_window.sum(dim=0)
|
||||
global_expert_load_windows.append(global_expert_load_window)
|
||||
# Perform all-reduce to get the expert load across all ranks for each model
|
||||
global_expert_load_windows = self._allreduce_list(
|
||||
global_expert_load_windows
|
||||
)
|
||||
if not execute_shuffle:
|
||||
for eplb_model_state, global_expert_load_window in zip(
|
||||
self.model_states.values(), global_expert_load_windows
|
||||
):
|
||||
# (num_moe_layers, old_num_physical_experts)
|
||||
old_global_expert_indices = eplb_model_state.physical_to_logical_map
|
||||
torch.distributed.broadcast(
|
||||
old_global_expert_indices, group=ep_group, group_src=0
|
||||
)
|
||||
if not execute_shuffle:
|
||||
return global_expert_load_windows
|
||||
else:
|
||||
assert execute_shuffle
|
||||
global_expert_load_windows = global_expert_loads
|
||||
|
||||
# TODO(bowen): Treat differently for prefill and decode nodes
|
||||
eplb_model_state = next(iter(self.model_states.values()))
|
||||
model = eplb_model_state.model
|
||||
num_replicas = model.num_physical_experts
|
||||
num_groups = model.num_expert_groups
|
||||
if rank_mapping is not None and len(rank_mapping) == ep_group.size():
|
||||
# NOTE(yongji): scale down, we need to rebalance the experts on
|
||||
# remaining GPUs, transfer the experts while we haven't shutdown
|
||||
# the GPUs to be released.
|
||||
cpu_group = get_ep_group().cpu_group
|
||||
num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
|
||||
num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values())
|
||||
num_replicas = (
|
||||
num_replicas // ep_group.size() * num_gpus
|
||||
) # handle num replicas change
|
||||
else:
|
||||
num_nodes = get_node_count()
|
||||
num_gpus = ep_group.size()
|
||||
|
||||
if num_gpus % num_nodes != 0:
|
||||
self.num_nodes = 1
|
||||
logger.warning_once(
|
||||
f"num_gpus % num_nodes != 0, "
|
||||
"not using hierarchical rearrangement algorithm.\n"
|
||||
f"{num_gpus=}, {num_nodes=}"
|
||||
)
|
||||
|
||||
for eplb_model_state, global_expert_load_window in zip(
|
||||
self.model_states.values(), global_expert_load_windows
|
||||
):
|
||||
# Get new expert mappings for the model
|
||||
(
|
||||
new_physical_to_logical_map,
|
||||
new_logical_to_physical_map,
|
||||
new_logical_replica_count,
|
||||
) = rebalance_experts(
|
||||
global_expert_load_window,
|
||||
num_replicas,
|
||||
num_groups,
|
||||
num_nodes,
|
||||
num_gpus,
|
||||
)
|
||||
|
||||
# Update expert weights
|
||||
rearrange_expert_weights_inplace(
|
||||
eplb_model_state.physical_to_logical_map,
|
||||
new_physical_to_logical_map,
|
||||
eplb_model_state.model.expert_weights,
|
||||
ep_group,
|
||||
is_profile,
|
||||
rank_mapping,
|
||||
)
|
||||
|
||||
if not is_profile:
|
||||
if (
|
||||
eplb_model_state.physical_to_logical_map.shape[1]
|
||||
!= new_physical_to_logical_map.shape[1]
|
||||
):
|
||||
eplb_model_state.physical_to_logical_map = (
|
||||
new_physical_to_logical_map.to(
|
||||
eplb_model_state.physical_to_logical_map.device
|
||||
)
|
||||
)
|
||||
else:
|
||||
eplb_model_state.physical_to_logical_map.copy_(
|
||||
new_physical_to_logical_map
|
||||
)
|
||||
max_physical_slots = new_logical_to_physical_map.shape[-1]
|
||||
assert (
|
||||
max_physical_slots
|
||||
<= eplb_model_state.logical_to_physical_map.shape[-1]
|
||||
)
|
||||
new_logical_to_physical_map = torch.nn.functional.pad(
|
||||
new_logical_to_physical_map,
|
||||
(
|
||||
0,
|
||||
eplb_model_state.logical_to_physical_map.shape[-1]
|
||||
- max_physical_slots,
|
||||
),
|
||||
value=-1,
|
||||
)
|
||||
eplb_model_state.logical_to_physical_map.copy_(
|
||||
new_logical_to_physical_map
|
||||
)
|
||||
eplb_model_state.logical_replica_count.copy_(new_logical_replica_count)
|
||||
|
||||
if is_main_rank:
|
||||
assert time_start is not None
|
||||
torch.cuda.synchronize()
|
||||
time_end = time.perf_counter()
|
||||
logger.info(
|
||||
"Rearranged experts%sin %.2f seconds.",
|
||||
" (profile) " if is_profile else " ",
|
||||
time_end - time_start,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
|
||||
"""
|
||||
Receive the expert load and old placement from the master rank.
|
||||
"""
|
||||
ep_group = get_ep_group()
|
||||
num_models = torch.empty(1, dtype=torch.int32, device="cpu")
|
||||
torch.distributed.broadcast(num_models, group=ep_group.cpu_group, group_src=0)
|
||||
num_models = num_models.item()
|
||||
global_expert_loads = []
|
||||
old_global_expert_indices_per_model = []
|
||||
for _ in range(num_models):
|
||||
metadata = torch.empty(3, dtype=torch.int32, device="cpu")
|
||||
torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0)
|
||||
num_moe_layers, num_logical_experts, num_old_physical_experts = (
|
||||
metadata.tolist()
|
||||
)
|
||||
global_expert_load = torch.zeros(
|
||||
(num_moe_layers, num_logical_experts),
|
||||
dtype=torch.int64,
|
||||
device=ep_group.device,
|
||||
)
|
||||
all_reduce(global_expert_load, group=ep_group.device_group)
|
||||
old_global_expert_indices = torch.empty(
|
||||
(num_moe_layers, num_old_physical_experts),
|
||||
dtype=torch.int64,
|
||||
device=ep_group.device,
|
||||
)
|
||||
torch.distributed.broadcast(
|
||||
old_global_expert_indices,
|
||||
group=ep_group.device_group,
|
||||
group_src=0,
|
||||
)
|
||||
global_expert_loads.append(global_expert_load)
|
||||
old_global_expert_indices_per_model.append(old_global_expert_indices)
|
||||
return global_expert_loads, old_global_expert_indices_per_model
|
||||
|
||||
@classmethod
|
||||
def get_eep_state(
|
||||
cls, parallel_config: ParallelConfig
|
||||
) -> tuple[
|
||||
list[torch.Tensor] | None,
|
||||
list[torch.Tensor] | None,
|
||||
dict[int, int] | None,
|
||||
]:
|
||||
num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu")
|
||||
torch.distributed.broadcast(
|
||||
num_local_physical_experts,
|
||||
group=get_ep_group().cpu_group,
|
||||
group_src=0,
|
||||
)
|
||||
num_local_physical_experts = int(num_local_physical_experts.item())
|
||||
new_ep_size = get_ep_group().world_size
|
||||
global_expert_loads, old_global_expert_indices_per_model = (
|
||||
EplbState.recv_state()
|
||||
)
|
||||
|
||||
# EP configuration for all models has to be the same so as eplb config
|
||||
num_logical_experts = global_expert_loads[0].shape[1]
|
||||
parallel_config.eplb_config.num_redundant_experts = (
|
||||
num_local_physical_experts * new_ep_size - num_logical_experts
|
||||
)
|
||||
assert (
|
||||
old_global_expert_indices_per_model[0].shape[1] % num_local_physical_experts
|
||||
== 0
|
||||
)
|
||||
old_ep_size = (
|
||||
old_global_expert_indices_per_model[0].shape[1]
|
||||
// num_local_physical_experts
|
||||
)
|
||||
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
|
||||
return (
|
||||
global_expert_loads,
|
||||
old_global_expert_indices_per_model,
|
||||
rank_mapping,
|
||||
)
|
||||
|
||||
def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
|
||||
"""
|
||||
All-reduce a list of tensors.
|
||||
"""
|
||||
if len(tensor_list) == 1:
|
||||
all_reduce(tensor_list[0], group=get_ep_group().device_group)
|
||||
return tensor_list
|
||||
assert all(t.dim() == 2 for t in tensor_list), "All tensors must be 2D."
|
||||
assert all(t.shape[1] == tensor_list[0].shape[1] for t in tensor_list), (
|
||||
"All tensors must have the same shape[1]."
|
||||
)
|
||||
# Concatenate, all_reduce, then unpack to original shapes.
|
||||
# We assume all tensors are 2D and shape[1] (num_physical_experts)
|
||||
# is the same across all models.
|
||||
shapes = [t.shape for t in tensor_list]
|
||||
concat_tensor = torch.cat(tensor_list, dim=0)
|
||||
|
||||
ep_group = get_ep_group().device_group
|
||||
all_reduce(concat_tensor, group=ep_group)
|
||||
|
||||
all_reduce_list = []
|
||||
offset = 0
|
||||
for shape in shapes:
|
||||
all_reduce_list.append(concat_tensor[offset : offset + shape[0], :])
|
||||
offset += shape[0]
|
||||
return all_reduce_list
|
||||
|
||||
def _sync_load_pass(self) -> list[torch.Tensor]:
|
||||
"""
|
||||
Sync the expert load pass across all ranks for log stats.
|
||||
Doesn't update the expert load pass in eplb_model_state.
|
||||
"""
|
||||
load_pass_list = []
|
||||
for eplb_model_state in self.model_states.values():
|
||||
load_pass_list.append(eplb_model_state.expert_load_pass.clone())
|
||||
return self._allreduce_list(load_pass_list)
|
||||
|
||||
|
||||
def _node_count_with_rank_mapping(
|
||||
pg: ProcessGroup | StatelessProcessGroup,
|
||||
rank_mapping: dict[int, int],
|
||||
) -> int:
|
||||
if isinstance(pg, ProcessGroup):
|
||||
world_size = torch.distributed.get_world_size(group=pg)
|
||||
else:
|
||||
world_size = pg.world_size
|
||||
|
||||
if world_size == 1:
|
||||
return 1
|
||||
|
||||
# Build node assignment map
|
||||
node_assignment = [0] * world_size # rank -> node_id
|
||||
next_node_id = 0
|
||||
|
||||
for current_rank in range(world_size):
|
||||
if node_assignment[current_rank] != 0:
|
||||
continue # Already assigned to a node
|
||||
|
||||
assert current_rank in rank_mapping
|
||||
if rank_mapping[current_rank] == -1:
|
||||
continue # Pending shutdown
|
||||
|
||||
# Assign current rank to a new node
|
||||
next_node_id += 1
|
||||
node_assignment[current_rank] = next_node_id
|
||||
|
||||
# Find all ranks on the same node as current_rank
|
||||
same_node_flags = in_the_same_node_as(pg, current_rank)
|
||||
for other_rank, is_same_node in enumerate(same_node_flags):
|
||||
if is_same_node and node_assignment[other_rank] == 0:
|
||||
node_assignment[other_rank] = next_node_id
|
||||
|
||||
return next_node_id
|
||||
260
distributed/eplb/rebalance_algo.py
Normal file
260
distributed/eplb/rebalance_algo.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Expert parallelism load balancer (EPLB) for vLLM.
|
||||
|
||||
This module implements the core rearrangement algorithm.
|
||||
|
||||
The rearrangement algorithm is adapted from
|
||||
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
|
||||
|
||||
Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
|
||||
on how the EPLB algorithm works.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def balanced_packing(
|
||||
weight: torch.Tensor, num_packs: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Pack n weighted objects to m packs, such that each bin contains exactly
|
||||
n/m objects and the weights of all packs are as balanced as possible.
|
||||
|
||||
Parameters:
|
||||
weight: [X, n], the weight of each item
|
||||
num_packs: number of packs
|
||||
|
||||
Returns:
|
||||
pack_index: [X, n], the pack index of each item
|
||||
rank_in_pack: [X, n], the rank of the item in the pack
|
||||
"""
|
||||
num_layers, num_groups = weight.shape
|
||||
assert num_groups % num_packs == 0
|
||||
groups_per_pack = num_groups // num_packs
|
||||
|
||||
device = weight.device
|
||||
|
||||
if groups_per_pack == 1:
|
||||
pack_index = torch.arange(
|
||||
weight.size(-1), dtype=torch.int64, device=device
|
||||
).expand(weight.shape)
|
||||
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
weight_np = weight.cpu().numpy()
|
||||
|
||||
# Sort and get indices in decending order
|
||||
indices_np = np.argsort(-weight_np, axis=-1)
|
||||
|
||||
pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
|
||||
rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
|
||||
|
||||
# Run the packing algorithm
|
||||
for i in range(num_layers):
|
||||
pack_weights = [0.0] * num_packs
|
||||
pack_items = [0] * num_packs
|
||||
|
||||
for group in indices_np[i]:
|
||||
# Find a pack with capacity that has the lowest weight
|
||||
pack = min(
|
||||
(j for j in range(num_packs) if pack_items[j] < groups_per_pack),
|
||||
key=pack_weights.__getitem__,
|
||||
)
|
||||
|
||||
assert pack_items[pack] < groups_per_pack
|
||||
pack_index_np[i, group] = pack
|
||||
rank_in_pack_np[i, group] = pack_items[pack]
|
||||
pack_weights[pack] += weight_np[i, group]
|
||||
pack_items[pack] += 1
|
||||
|
||||
pack_index = torch.from_numpy(pack_index_np).to(device)
|
||||
rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
|
||||
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
|
||||
def replicate_experts(
|
||||
weight: torch.Tensor, num_phy: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
|
||||
load of all replicas is minimized.
|
||||
|
||||
Parameters:
|
||||
weight: [X, num_log]
|
||||
num_phy: total number of experts after replication
|
||||
|
||||
Returns:
|
||||
phy2log: [X, num_phy], logical expert id of each physical expert
|
||||
rank: [X, num_phy], the replica rank
|
||||
logcnt: [X, num_log], number of replicas for each logical expert
|
||||
"""
|
||||
n, num_log = weight.shape
|
||||
num_redundant = num_phy - num_log
|
||||
assert num_redundant >= 0
|
||||
device = weight.device
|
||||
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
|
||||
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
|
||||
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
|
||||
arangen = torch.arange(n, dtype=torch.int64, device=device)
|
||||
for i in range(num_log, num_phy):
|
||||
redundant_indices = (weight / logcnt).max(dim=-1).indices
|
||||
phy2log[:, i] = redundant_indices
|
||||
rank[:, i] = logcnt[arangen, redundant_indices]
|
||||
logcnt[arangen, redundant_indices] += 1
|
||||
return phy2log, rank, logcnt
|
||||
|
||||
|
||||
def rebalance_experts_hierarchical(
|
||||
weight: torch.Tensor,
|
||||
num_physical_experts: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_gpus: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Parameters:
|
||||
weight: [num_moe_layers, num_logical_experts]
|
||||
num_physical_experts: number of physical experts after replication
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes, where the intra-node network
|
||||
(e.g., NVLink) is faster
|
||||
num_gpus: number of GPUs, must be a multiple of `num_nodes`
|
||||
|
||||
Returns:
|
||||
physical_to_logical_map (torch.Tensor):
|
||||
[num_moe_layers, num_physical_experts]
|
||||
logical_to_physical_map (torch.Tensor):
|
||||
[num_moe_layers, num_logical_experts, X]
|
||||
logical_count (torch.Tensor):
|
||||
[num_moe_layers, num_logical_experts]
|
||||
"""
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
assert num_logical_experts % num_groups == 0
|
||||
group_size = num_logical_experts // num_groups
|
||||
assert num_groups % num_nodes == 0
|
||||
groups_per_node = num_groups // num_nodes
|
||||
assert num_gpus % num_nodes == 0
|
||||
assert num_physical_experts % num_gpus == 0
|
||||
phy_experts_per_gpu = num_physical_experts // num_gpus
|
||||
|
||||
def inverse(perm: torch.Tensor) -> torch.Tensor:
|
||||
inv = torch.empty_like(perm)
|
||||
inv.scatter_(
|
||||
1,
|
||||
perm,
|
||||
torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(
|
||||
perm.shape
|
||||
),
|
||||
)
|
||||
return inv
|
||||
|
||||
# Step 1: pack groups to nodes
|
||||
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
|
||||
group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
|
||||
log2mlog = (
|
||||
(
|
||||
(group_pack_index * groups_per_node + group_rank_in_pack) * group_size
|
||||
).unsqueeze(-1)
|
||||
+ torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)
|
||||
).flatten(-2)
|
||||
mlog2log = inverse(log2mlog)
|
||||
|
||||
# Step 2: construct redundant experts within nodes
|
||||
# [num_layers * num_nodes, num_logical_experts // num_nodes]
|
||||
tokens_per_mlog = weight.gather(-1, mlog2log).view(
|
||||
-1, num_logical_experts // num_nodes
|
||||
)
|
||||
phy2mlog, phyrank, mlogcnt = replicate_experts(
|
||||
tokens_per_mlog, num_physical_experts // num_nodes
|
||||
)
|
||||
|
||||
# Step 3: pack physical_experts to GPUs
|
||||
# [num_layers * num_nodes, num_physical_experts // num_nodes]
|
||||
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
|
||||
pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
|
||||
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
|
||||
pphy2phy = inverse(phy2pphy)
|
||||
|
||||
pphy2mlog = phy2mlog.gather(
|
||||
-1, pphy2phy
|
||||
) # [num_layers * num_nodes, num_log_per_nodes]
|
||||
pphy2mlog = (
|
||||
pphy2mlog.view(num_layers, num_nodes, -1)
|
||||
+ torch.arange(
|
||||
0,
|
||||
num_logical_experts,
|
||||
num_logical_experts // num_nodes,
|
||||
device=group_pack_index.device,
|
||||
).view(1, -1, 1)
|
||||
).flatten(-2)
|
||||
pphy2log = mlog2log.gather(-1, pphy2mlog)
|
||||
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
|
||||
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
|
||||
return pphy2log, pphyrank, logcnt
|
||||
|
||||
|
||||
def rebalance_experts(
|
||||
weight: torch.Tensor,
|
||||
num_replicas: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_gpus: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Entry point for expert-parallelism load balancer.
|
||||
|
||||
Parameters:
|
||||
weight: [layers, num_logical_experts], the load statistics for all
|
||||
logical experts
|
||||
num_replicas: number of physical experts, must be a multiple of
|
||||
`num_gpus`
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes, where the intra-node network
|
||||
(e.g, NVLink) is faster
|
||||
num_gpus: number of GPUs, must be a multiple of `num_nodes`
|
||||
|
||||
Returns:
|
||||
physical_to_logical_map:
|
||||
[layers, num_replicas], the expert index of each replica
|
||||
logical_to_physical_map:
|
||||
[layers, num_logical_experts, X], the replica indices for each
|
||||
expert
|
||||
expert_count:
|
||||
[layers, num_logical_experts], number of physical
|
||||
replicas for each logical expert
|
||||
"""
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
weight = weight.float()
|
||||
if num_groups % num_nodes == 0:
|
||||
# use hierarchical load-balance policy
|
||||
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
else:
|
||||
# use global load-balance policy
|
||||
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
|
||||
weight, num_replicas, 1, 1, num_gpus
|
||||
)
|
||||
num_redundant_experts = num_replicas - num_logical_experts
|
||||
maxlogcnt = num_redundant_experts + 1
|
||||
log2phy: torch.Tensor = torch.full(
|
||||
(num_layers, num_logical_experts, maxlogcnt),
|
||||
-1,
|
||||
dtype=torch.int64,
|
||||
device=logcnt.device,
|
||||
)
|
||||
log2phy.view(num_layers, -1).scatter_(
|
||||
-1,
|
||||
phy2log * maxlogcnt + phyrank,
|
||||
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
|
||||
num_layers, -1
|
||||
),
|
||||
)
|
||||
return phy2log, log2phy, logcnt
|
||||
|
||||
|
||||
__all__ = ["rebalance_experts"]
|
||||
431
distributed/eplb/rebalance_execute.py
Normal file
431
distributed/eplb/rebalance_execute.py
Normal file
@@ -0,0 +1,431 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
The actual execution of the rearrangement.
|
||||
|
||||
This involves the exchange of expert weights between GPUs.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable, MutableSequence, Sequence
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch.distributed import (
|
||||
P2POp,
|
||||
ProcessGroup,
|
||||
all_gather,
|
||||
batch_isend_irecv,
|
||||
get_global_rank,
|
||||
)
|
||||
|
||||
|
||||
def idx_local_to_global(
|
||||
local_idx: int,
|
||||
local_cnt: int,
|
||||
ep_rank: int,
|
||||
) -> int:
|
||||
"""
|
||||
Convert a local expert index to a global expert index.
|
||||
"""
|
||||
return ep_rank * local_cnt + local_idx
|
||||
|
||||
|
||||
def idx_global_to_local(
|
||||
global_idx: int,
|
||||
local_cnt: int,
|
||||
ep_rank: int,
|
||||
) -> int:
|
||||
"""
|
||||
Convert a global expert index to a local expert index.
|
||||
"""
|
||||
return global_idx - ep_rank * local_cnt
|
||||
|
||||
|
||||
def global_idx_to_rank(
|
||||
global_idx: int,
|
||||
local_cnt: int,
|
||||
) -> int:
|
||||
"""
|
||||
Convert a global expert index to a rank index.
|
||||
"""
|
||||
return global_idx // local_cnt
|
||||
|
||||
|
||||
def get_ep_ranks_with_expert(
|
||||
idx: int,
|
||||
num_local_experts: int,
|
||||
old_indices: Sequence[int],
|
||||
new_indices: Sequence[int],
|
||||
) -> tuple[MutableSequence[int], MutableSequence[int]]:
|
||||
"""
|
||||
Get the ranks of the experts that need to be exchanged.
|
||||
|
||||
Args:
|
||||
idx: The index of the expert.
|
||||
num_local_experts: The number of local experts.
|
||||
old_indices: The old indices of the experts.
|
||||
new_indices: The new indices of the experts.
|
||||
|
||||
Returns:
|
||||
A tuple of two lists:
|
||||
- The ranks of the experts that need to be sent.
|
||||
- The ranks of the experts that need to be received.
|
||||
"""
|
||||
global2rank = partial(
|
||||
global_idx_to_rank,
|
||||
local_cnt=num_local_experts,
|
||||
)
|
||||
|
||||
ranks_to_send: list[int] = []
|
||||
ranks_to_recv: list[int] = []
|
||||
|
||||
for i, e in enumerate(old_indices):
|
||||
if e == idx:
|
||||
rank = global2rank(i)
|
||||
if not ranks_to_send or ranks_to_send[-1] != rank:
|
||||
ranks_to_send.append(rank)
|
||||
|
||||
for i, e in enumerate(new_indices):
|
||||
if e == idx:
|
||||
rank = global2rank(i)
|
||||
if not ranks_to_recv or ranks_to_recv[-1] != rank:
|
||||
ranks_to_recv.append(rank)
|
||||
|
||||
# Remove those ranks that can get this expert locally.
|
||||
ranks_to_send_set = set(ranks_to_send)
|
||||
ranks_to_recv_actual = [
|
||||
rank for rank in ranks_to_recv if rank not in ranks_to_send_set
|
||||
]
|
||||
|
||||
return ranks_to_send, ranks_to_recv_actual
|
||||
|
||||
|
||||
def shuffle_layer(
|
||||
num_local_experts: int,
|
||||
ep_rank: int,
|
||||
old_indices: Sequence[int],
|
||||
new_indices: Sequence[int],
|
||||
expert_weights: Iterable[torch.Tensor],
|
||||
expert_weights_buffer: Sequence[torch.Tensor],
|
||||
ep_group: ProcessGroup,
|
||||
) -> None:
|
||||
"""
|
||||
Perform expert weights rearrangement of one layer.
|
||||
"""
|
||||
local2global = partial(
|
||||
idx_local_to_global,
|
||||
local_cnt=num_local_experts,
|
||||
ep_rank=ep_rank,
|
||||
)
|
||||
|
||||
# 0. Do nothing for experts that did not change.
|
||||
is_unchanged = [
|
||||
old_indices[local2global(i)] == new_indices[local2global(i)]
|
||||
for i in range(num_local_experts)
|
||||
]
|
||||
|
||||
# 1. Perform weight copy inside the local rank.
|
||||
is_received_locally = is_unchanged[:]
|
||||
for src in range(num_local_experts):
|
||||
src_global = local2global(src)
|
||||
for dst in range(num_local_experts):
|
||||
dst_global = local2global(dst)
|
||||
if is_received_locally[dst]:
|
||||
continue
|
||||
if old_indices[src_global] == -1 or new_indices[dst_global] == -1:
|
||||
continue
|
||||
if old_indices[src_global] == new_indices[dst_global]:
|
||||
is_received_locally[dst] = True
|
||||
for weight, buffer in zip(expert_weights, expert_weights_buffer):
|
||||
buffer[dst].copy_(weight[src])
|
||||
|
||||
p2p_ops: list[P2POp] = []
|
||||
|
||||
# 2. Initiate sending of weights.
|
||||
experts_send_loc: dict[int, int] = {}
|
||||
for src in range(num_local_experts):
|
||||
expert = old_indices[local2global(src)]
|
||||
if expert == -1:
|
||||
continue
|
||||
if expert in experts_send_loc:
|
||||
continue
|
||||
experts_send_loc[expert] = src
|
||||
|
||||
# We need to sort here to match send/recv
|
||||
for expert, src in sorted(experts_send_loc.items()):
|
||||
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
|
||||
expert,
|
||||
num_local_experts,
|
||||
old_indices,
|
||||
new_indices,
|
||||
)
|
||||
|
||||
# Calculate the ranks to send by this rank
|
||||
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
|
||||
sender_pos = ranks_to_send.index(ep_rank)
|
||||
recv_begin = sender_pos * num_dst_per_sender
|
||||
recv_end = recv_begin + num_dst_per_sender
|
||||
recv_ranks = ranks_to_recv[recv_begin:recv_end]
|
||||
|
||||
# Tackle remainders
|
||||
remainder_start = len(ranks_to_send) * num_dst_per_sender
|
||||
recver_pos = remainder_start + sender_pos
|
||||
if recver_pos < len(ranks_to_recv):
|
||||
recv_ranks.append(ranks_to_recv[recver_pos])
|
||||
|
||||
for dst in recv_ranks:
|
||||
dst_global = get_global_rank(ep_group, dst)
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.isend,
|
||||
weight[src],
|
||||
dst_global,
|
||||
)
|
||||
for weight in expert_weights
|
||||
]
|
||||
|
||||
# 3. Initiate receiving of weights.
|
||||
experts_recv_loc: dict[int, int] = {}
|
||||
for dst in range(num_local_experts):
|
||||
if is_received_locally[dst]:
|
||||
continue
|
||||
expert = new_indices[local2global(dst)]
|
||||
if expert == -1:
|
||||
continue
|
||||
if expert in experts_recv_loc:
|
||||
continue
|
||||
experts_recv_loc[expert] = dst
|
||||
|
||||
# We need to sort here to match send/recv
|
||||
for expert, dst in sorted(experts_recv_loc.items()):
|
||||
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
|
||||
expert,
|
||||
num_local_experts,
|
||||
old_indices,
|
||||
new_indices,
|
||||
)
|
||||
|
||||
# Calculate the rank to recv by this rank
|
||||
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
|
||||
recver_pos = ranks_to_recv.index(ep_rank)
|
||||
remainder_start = len(ranks_to_send) * num_dst_per_sender
|
||||
if recver_pos < remainder_start:
|
||||
src = ranks_to_send[recver_pos // num_dst_per_sender]
|
||||
else:
|
||||
src = ranks_to_send[recver_pos - remainder_start]
|
||||
|
||||
src_global = get_global_rank(ep_group, src)
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.irecv,
|
||||
weight[dst],
|
||||
src_global,
|
||||
)
|
||||
for weight in expert_weights_buffer
|
||||
]
|
||||
|
||||
# 4. Execute the P2P operations. The real communication happens here.
|
||||
if p2p_ops:
|
||||
reqs = batch_isend_irecv(p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
# 5. Copy the weights from the buffer back to the original weights.
|
||||
for dst in range(num_local_experts):
|
||||
if is_unchanged[dst]:
|
||||
continue
|
||||
if is_received_locally[dst]:
|
||||
for weight, buffer in zip(expert_weights, expert_weights_buffer):
|
||||
weight[dst].copy_(buffer[dst])
|
||||
else:
|
||||
expert = new_indices[local2global(dst)]
|
||||
if expert == -1:
|
||||
continue
|
||||
src = experts_recv_loc[expert]
|
||||
for weight, buffer in zip(expert_weights, expert_weights_buffer):
|
||||
weight[dst].copy_(buffer[src])
|
||||
|
||||
|
||||
def rearrange_expert_weights_inplace(
|
||||
old_global_expert_indices: torch.Tensor,
|
||||
new_global_expert_indices: torch.Tensor,
|
||||
expert_weights: Sequence[Iterable[torch.Tensor]],
|
||||
ep_group: ProcessGroup,
|
||||
is_profile: bool = False,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Rearranges the expert weights in place according to the new expert indices.
|
||||
|
||||
The value of the indices arguments are logical indices of the experts,
|
||||
while keys are physical.
|
||||
|
||||
Args:
|
||||
old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
|
||||
new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
|
||||
expert_weights: A sequence of shape (num_moe_layers)(weight_count)
|
||||
of tensors of shape (num_local_physical_experts, hidden_size_i).
|
||||
For example, a linear layer may have up and down projection,
|
||||
so weight_count = 2. Each weight's hidden size can be different.
|
||||
ep_group: The device process group for expert parallelism.
|
||||
is_profile (bool): If `True`, do not perform any actual weight copy.
|
||||
This is used during profile run, where we only perform dummy
|
||||
communications to reserve enough memory for the buffers.
|
||||
rank_mapping: A dictionary mapping old rank to new rank.
|
||||
"""
|
||||
if rank_mapping is not None:
|
||||
if len(rank_mapping) == ep_group.size():
|
||||
# scale down
|
||||
new_global_expert_indices = _map_new_expert_indices_with_rank_mapping(
|
||||
new_global_expert_indices,
|
||||
rank_mapping,
|
||||
)
|
||||
else:
|
||||
# scale up
|
||||
old_global_expert_indices = _map_old_expert_indices_with_rank_mapping(
|
||||
old_global_expert_indices,
|
||||
rank_mapping,
|
||||
ep_group.size(),
|
||||
)
|
||||
|
||||
assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]
|
||||
|
||||
num_moe_layers, num_physical_experts = old_global_expert_indices.shape
|
||||
assert len(expert_weights) == num_moe_layers
|
||||
|
||||
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
|
||||
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
|
||||
|
||||
ep_rank = ep_group.rank()
|
||||
ep_size = ep_group.size()
|
||||
assert num_physical_experts == ep_size * num_local_physical_experts
|
||||
|
||||
# A buffer to hold the expert weights in one layer during the exchange.
|
||||
# NOTE: Currently we assume the same weights across different layers
|
||||
# have the same shape.
|
||||
expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||
|
||||
if is_profile:
|
||||
# Maximum send size is to send all local experts to all ranks,
|
||||
# So we use a dummy `all_gather` to reserve enough communication buffer
|
||||
for weight, buffer in zip(expert_weights[0], expert_weights_buffer):
|
||||
# A `/dev/null`-like buffer to avoid real memory allocation
|
||||
dummy_recv_buffer = [buffer for _ in range(ep_size)]
|
||||
# NOTE(bowen): Needed this barrier to avoid OOM during actual
|
||||
# execution. I'm not very sure why this is needed
|
||||
torch.distributed.barrier()
|
||||
all_gather(
|
||||
dummy_recv_buffer,
|
||||
weight,
|
||||
group=ep_group,
|
||||
)
|
||||
return
|
||||
|
||||
old_global_expert_indices_cpu = old_global_expert_indices.cpu()
|
||||
new_global_expert_indices_cpu = new_global_expert_indices.cpu()
|
||||
|
||||
# NOTE(bowen): We need this synchronize to run, but I don't know why.
|
||||
# If you figure out the reason, please let me know -- thank you!
|
||||
torch.cuda.synchronize()
|
||||
|
||||
for layer in range(num_moe_layers):
|
||||
shuffle_layer(
|
||||
num_local_physical_experts,
|
||||
ep_rank,
|
||||
old_global_expert_indices_cpu[layer].tolist(),
|
||||
new_global_expert_indices_cpu[layer].tolist(),
|
||||
expert_weights[layer],
|
||||
expert_weights_buffer,
|
||||
ep_group,
|
||||
)
|
||||
|
||||
|
||||
def _map_old_expert_indices_with_rank_mapping(
|
||||
old_global_expert_indices: torch.Tensor,
|
||||
rank_mapping: dict[int, int],
|
||||
new_ep_size: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Map the old global expert indices to the new global expert indices.
|
||||
|
||||
Args:
|
||||
old_global_expert_indices:
|
||||
Shape (num_layers, old_ep_size * num_local_physical_experts).
|
||||
rank_mapping: Mapping from old rank to new rank.
|
||||
new_ep_size: New expert parallelism size.
|
||||
|
||||
Returns:
|
||||
Mapped expert indices with shape
|
||||
(num_layers, new_ep_size * num_local_physical_experts).
|
||||
"""
|
||||
num_layers, old_num_physical_experts = old_global_expert_indices.shape
|
||||
assert rank_mapping, "Rank mapping is required"
|
||||
|
||||
# Get sizes from parameters and rank_mapping
|
||||
old_ep_size = len(rank_mapping)
|
||||
num_local_physical_experts = old_num_physical_experts // old_ep_size
|
||||
new_num_physical_experts = new_ep_size * num_local_physical_experts
|
||||
|
||||
# Create mapped tensor with new shape, initialized to -1
|
||||
mapped_expert_indices = torch.full(
|
||||
(num_layers, new_num_physical_experts),
|
||||
fill_value=-1,
|
||||
dtype=old_global_expert_indices.dtype,
|
||||
device=old_global_expert_indices.device,
|
||||
)
|
||||
|
||||
# Handle rank mapping (scale up/down with rank changes)
|
||||
for old_rank in range(old_ep_size):
|
||||
new_rank = rank_mapping.get(old_rank)
|
||||
if new_rank is not None and new_rank >= 0 and new_rank < new_ep_size:
|
||||
# This old rank exists in the new configuration
|
||||
old_start_idx = old_rank * num_local_physical_experts
|
||||
old_end_idx = (old_rank + 1) * num_local_physical_experts
|
||||
new_start_idx = new_rank * num_local_physical_experts
|
||||
new_end_idx = (new_rank + 1) * num_local_physical_experts
|
||||
|
||||
mapped_expert_indices[:, new_start_idx:new_end_idx] = (
|
||||
old_global_expert_indices[:, old_start_idx:old_end_idx]
|
||||
)
|
||||
# If new_rank is None or >= new_ep_size, the experts remain -1
|
||||
# (scale down case)
|
||||
|
||||
return mapped_expert_indices
|
||||
|
||||
|
||||
def _map_new_expert_indices_with_rank_mapping(
|
||||
new_global_expert_indices: torch.Tensor,
|
||||
rank_mapping: dict[int, int],
|
||||
) -> torch.Tensor:
|
||||
num_layers, new_num_physical_experts = new_global_expert_indices.shape
|
||||
assert rank_mapping, "Rank mapping is required"
|
||||
|
||||
# Get sizes from parameters and rank_mapping
|
||||
old_ep_size = len(rank_mapping)
|
||||
new_ep_size = sum(new_rank != -1 for new_rank in rank_mapping.values())
|
||||
num_local_physical_experts = new_num_physical_experts // new_ep_size
|
||||
old_num_physical_experts = old_ep_size * num_local_physical_experts
|
||||
|
||||
mapped_expert_indices = torch.full(
|
||||
(num_layers, old_num_physical_experts),
|
||||
fill_value=-1,
|
||||
dtype=new_global_expert_indices.dtype,
|
||||
device=new_global_expert_indices.device,
|
||||
)
|
||||
|
||||
for old_rank in range(old_ep_size):
|
||||
new_rank = rank_mapping[old_rank]
|
||||
if new_rank >= 0 and new_rank < new_ep_size:
|
||||
old_start_idx = old_rank * num_local_physical_experts
|
||||
old_end_idx = (old_rank + 1) * num_local_physical_experts
|
||||
new_start_idx = new_rank * num_local_physical_experts
|
||||
new_end_idx = (new_rank + 1) * num_local_physical_experts
|
||||
|
||||
mapped_expert_indices[:, old_start_idx:old_end_idx] = (
|
||||
new_global_expert_indices[:, new_start_idx:new_end_idx]
|
||||
)
|
||||
|
||||
return mapped_expert_indices
|
||||
|
||||
|
||||
__all__ = ["rearrange_expert_weights_inplace"]
|
||||
Reference in New Issue
Block a user