Sync from v0.13
This commit is contained in:
3
vllm/distributed/eplb/__init__.py
Normal file
3
vllm/distributed/eplb/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Expert parallelism load balancer (EPLB)."""
|
||||
115
vllm/distributed/eplb/async_worker.py
Normal file
115
vllm/distributed/eplb/async_worker.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
The async worker that transfers experts in the background.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .rebalance_execute import transfer_layer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .eplb_state import EplbState
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def start_async_worker(
|
||||
state: "EplbState",
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
is_profile: bool = False,
|
||||
) -> threading.Thread:
|
||||
ep_group = get_ep_group().device_group
|
||||
rank = ep_group.rank()
|
||||
device_index = state.cuda_device_index
|
||||
|
||||
def thread_target() -> None:
|
||||
assert device_index is not None
|
||||
torch.cuda.set_device(device_index)
|
||||
cuda_stream = torch.cuda.Stream(device=device_index)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
transfer_run_periodically(
|
||||
state=state,
|
||||
ep_group=ep_group,
|
||||
is_profile=is_profile,
|
||||
rank_mapping=rank_mapping,
|
||||
cuda_stream=cuda_stream,
|
||||
)
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - diagnostic path
|
||||
logger.exception("async loop error (Rank %d): %s", rank, str(exc))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
thread = threading.Thread(target=thread_target, daemon=True)
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
|
||||
async def transfer_run_periodically(
|
||||
state: "EplbState",
|
||||
ep_group: ProcessGroup,
|
||||
is_profile: bool = False,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
cuda_stream: torch.cuda.Stream = None,
|
||||
) -> None:
|
||||
while True:
|
||||
await asyncio.to_thread(state.rearrange_event.wait)
|
||||
logger.info("async worker woke up for EPLB transfer")
|
||||
|
||||
for model_state in state.model_states.values():
|
||||
if not model_state.is_async_enabled:
|
||||
continue
|
||||
current_num_layers = model_state.model.num_moe_layers
|
||||
while (
|
||||
model_state.rebalanced
|
||||
and model_state.layer_to_transfer < current_num_layers
|
||||
):
|
||||
if (
|
||||
not model_state.ep_buffer_ready
|
||||
and model_state.rebalanced
|
||||
and model_state.new_physical_to_logical_map is not None
|
||||
):
|
||||
await asyncio.to_thread(model_state.buffer_lock.acquire)
|
||||
try:
|
||||
if model_state.layer_to_transfer >= current_num_layers:
|
||||
break
|
||||
|
||||
(
|
||||
model_state.is_unchanged,
|
||||
model_state.is_received_locally,
|
||||
model_state.experts_recv_loc,
|
||||
) = await transfer_layer(
|
||||
old_global_expert_indices=model_state.physical_to_logical_map,
|
||||
new_global_expert_indices=model_state.new_physical_to_logical_map,
|
||||
expert_weights=model_state.model.expert_weights,
|
||||
expert_weights_buffer=model_state.expert_buffer,
|
||||
ep_group=ep_group,
|
||||
is_profile=is_profile,
|
||||
layer=model_state.layer_to_transfer,
|
||||
cuda_stream=cuda_stream,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
event = torch.cuda.Event(blocking=False)
|
||||
cuda_stream.record_event(event)
|
||||
model_state.buffer_ready_event = event
|
||||
model_state.ep_buffer_ready = 1
|
||||
finally:
|
||||
model_state.buffer_lock.release()
|
||||
else:
|
||||
if not model_state.rebalanced:
|
||||
break
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
state.rearrange_event.clear()
|
||||
1164
vllm/distributed/eplb/eplb_state.py
Normal file
1164
vllm/distributed/eplb/eplb_state.py
Normal file
File diff suppressed because it is too large
Load Diff
19
vllm/distributed/eplb/policy/__init__.py
Normal file
19
vllm/distributed/eplb/policy/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import get_args
|
||||
|
||||
from vllm.config.parallel import EPLBPolicyOption
|
||||
|
||||
from .abstract import AbstractEplbPolicy
|
||||
from .default import DefaultEplbPolicy
|
||||
|
||||
EPLB_POLICIES = {"default": DefaultEplbPolicy}
|
||||
|
||||
# Ensure that the EPLB_POLICIES keys match the EPLBPolicyOption values
|
||||
assert set(EPLB_POLICIES.keys()) == set(get_args(EPLBPolicyOption))
|
||||
|
||||
__all__ = [
|
||||
"AbstractEplbPolicy",
|
||||
"DefaultEplbPolicy",
|
||||
"EPLB_POLICIES",
|
||||
]
|
||||
40
vllm/distributed/eplb/policy/abstract.py
Normal file
40
vllm/distributed/eplb/policy/abstract.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AbstractEplbPolicy(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def rebalance_experts(
|
||||
cls,
|
||||
weight: torch.Tensor,
|
||||
num_replicas: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_ranks: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Entry point for expert-parallelism load balancer.
|
||||
|
||||
Parameters:
|
||||
weight: [layers, num_logical_experts], the load statistics
|
||||
for all logical experts
|
||||
num_replicas: number of physical experts, must be a multiple of
|
||||
`num_ranks`
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes
|
||||
num_ranks: number of ranks, must be a multiple of `num_nodes`
|
||||
|
||||
Returns:
|
||||
physical_to_logical_map: [layers, num_replicas], the expert
|
||||
index of each replica
|
||||
logical_to_physical_map: [layers, num_logical_experts, X],
|
||||
the replica indices for each expert
|
||||
expert_count: [layers, num_logical_experts], number of
|
||||
physical replicas for each logical expert
|
||||
"""
|
||||
raise NotImplementedError
|
||||
267
vllm/distributed/eplb/policy/default.py
Normal file
267
vllm/distributed/eplb/policy/default.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Expert parallelism load balancer (EPLB) for vLLM.
|
||||
|
||||
This module implements the core rearrangement algorithm.
|
||||
|
||||
The rearrangement algorithm is adapted from
|
||||
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
|
||||
|
||||
Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
|
||||
on how the EPLB algorithm works.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .abstract import AbstractEplbPolicy
|
||||
|
||||
|
||||
class DefaultEplbPolicy(AbstractEplbPolicy):
|
||||
@classmethod
|
||||
def balanced_packing(
|
||||
cls, weight: torch.Tensor, num_packs: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Pack n weighted objects to m packs, such that each bin contains exactly
|
||||
n/m objects and the weights of all packs are as balanced as possible.
|
||||
|
||||
Parameters:
|
||||
weight: [X, n], the weight of each item
|
||||
num_packs: number of packs
|
||||
|
||||
Returns:
|
||||
pack_index: [X, n], the pack index of each item
|
||||
rank_in_pack: [X, n], the rank of the item in the pack
|
||||
"""
|
||||
num_layers, num_groups = weight.shape
|
||||
assert num_groups % num_packs == 0
|
||||
groups_per_pack = num_groups // num_packs
|
||||
|
||||
device = weight.device
|
||||
|
||||
if groups_per_pack == 1:
|
||||
pack_index = torch.arange(
|
||||
weight.size(-1), dtype=torch.int64, device=device
|
||||
).expand(weight.shape)
|
||||
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
weight_np = weight.cpu().numpy()
|
||||
|
||||
# Sort and get indices in decending order
|
||||
indices_np = np.argsort(-weight_np, axis=-1)
|
||||
|
||||
pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
|
||||
rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
|
||||
|
||||
# Run the packing algorithm
|
||||
for i in range(num_layers):
|
||||
pack_weights = [0.0] * num_packs
|
||||
pack_items = [0] * num_packs
|
||||
|
||||
for group in indices_np[i]:
|
||||
# Find a pack with capacity that has the lowest weight
|
||||
pack = min(
|
||||
(j for j in range(num_packs) if pack_items[j] < groups_per_pack),
|
||||
key=pack_weights.__getitem__,
|
||||
)
|
||||
|
||||
assert pack_items[pack] < groups_per_pack
|
||||
pack_index_np[i, group] = pack
|
||||
rank_in_pack_np[i, group] = pack_items[pack]
|
||||
pack_weights[pack] += weight_np[i, group]
|
||||
pack_items[pack] += 1
|
||||
|
||||
pack_index = torch.from_numpy(pack_index_np).to(device)
|
||||
rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
|
||||
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
@classmethod
|
||||
def replicate_experts(
|
||||
cls, weight: torch.Tensor, num_phy: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
|
||||
load of all replicas is minimized.
|
||||
|
||||
Parameters:
|
||||
weight: [X, num_log]
|
||||
num_phy: total number of experts after replication
|
||||
|
||||
Returns:
|
||||
phy2log: [X, num_phy], logical expert id of each physical expert
|
||||
rank: [X, num_phy], the replica rank
|
||||
logcnt: [X, num_log], number of replicas for each logical expert
|
||||
"""
|
||||
n, num_log = weight.shape
|
||||
num_redundant = num_phy - num_log
|
||||
assert num_redundant >= 0
|
||||
device = weight.device
|
||||
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
|
||||
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
|
||||
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
|
||||
arangen = torch.arange(n, dtype=torch.int64, device=device)
|
||||
for i in range(num_log, num_phy):
|
||||
redundant_indices = (weight / logcnt).max(dim=-1).indices
|
||||
phy2log[:, i] = redundant_indices
|
||||
rank[:, i] = logcnt[arangen, redundant_indices]
|
||||
logcnt[arangen, redundant_indices] += 1
|
||||
return phy2log, rank, logcnt
|
||||
|
||||
@classmethod
|
||||
def rebalance_experts_hierarchical(
|
||||
cls,
|
||||
weight: torch.Tensor,
|
||||
num_physical_experts: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_gpus: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Parameters:
|
||||
weight: [num_moe_layers, num_logical_experts]
|
||||
num_physical_experts: number of physical experts after replication
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes, where the intra-node network
|
||||
(e.g, NVLink) is faster
|
||||
num_gpus: number of GPUs, must be a multiple of `num_nodes`
|
||||
|
||||
Returns:
|
||||
phy2log: [layers, num_replicas], the expert
|
||||
index of each replica
|
||||
log2phy: [layers, num_logical_experts, X],
|
||||
the replica indices for each expert
|
||||
logcnt: [layers, num_logical_experts], number of
|
||||
physical replicas for each logical expert
|
||||
"""
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
assert num_logical_experts % num_groups == 0
|
||||
group_size = num_logical_experts // num_groups
|
||||
assert num_groups % num_nodes == 0
|
||||
groups_per_node = num_groups // num_nodes
|
||||
assert num_gpus % num_nodes == 0
|
||||
assert num_physical_experts % num_gpus == 0
|
||||
phy_experts_per_gpu = num_physical_experts // num_gpus
|
||||
|
||||
def inverse(perm: torch.Tensor) -> torch.Tensor:
|
||||
inv = torch.empty_like(perm)
|
||||
inv.scatter_(
|
||||
1,
|
||||
perm,
|
||||
torch.arange(
|
||||
perm.size(1), dtype=torch.int64, device=perm.device
|
||||
).expand(perm.shape),
|
||||
)
|
||||
return inv
|
||||
|
||||
# Step 1: pack groups to nodes
|
||||
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
|
||||
group_pack_index, group_rank_in_pack = cls.balanced_packing(
|
||||
tokens_per_group, num_nodes
|
||||
)
|
||||
log2mlog = (
|
||||
(
|
||||
(group_pack_index * groups_per_node + group_rank_in_pack) * group_size
|
||||
).unsqueeze(-1)
|
||||
+ torch.arange(
|
||||
group_size, dtype=torch.int64, device=group_pack_index.device
|
||||
)
|
||||
).flatten(-2)
|
||||
mlog2log = inverse(log2mlog)
|
||||
|
||||
# Step 2: construct redundant experts within nodes
|
||||
# [num_layers * num_nodes, num_logical_experts // num_nodes]
|
||||
tokens_per_mlog = weight.gather(-1, mlog2log).view(
|
||||
-1, num_logical_experts // num_nodes
|
||||
)
|
||||
phy2mlog, phyrank, mlogcnt = cls.replicate_experts(
|
||||
tokens_per_mlog, num_physical_experts // num_nodes
|
||||
)
|
||||
|
||||
# Step 3: pack physical_experts to GPUs
|
||||
# [num_layers * num_nodes, num_physical_experts // num_nodes]
|
||||
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
|
||||
pack_index, rank_in_pack = cls.balanced_packing(
|
||||
tokens_per_phy, num_gpus // num_nodes
|
||||
)
|
||||
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
|
||||
pphy2phy = inverse(phy2pphy)
|
||||
|
||||
pphy2mlog = phy2mlog.gather(
|
||||
-1, pphy2phy
|
||||
) # [num_layers * num_nodes, num_log_per_nodes]
|
||||
pphy2mlog = (
|
||||
pphy2mlog.view(num_layers, num_nodes, -1)
|
||||
+ torch.arange(
|
||||
0,
|
||||
num_logical_experts,
|
||||
num_logical_experts // num_nodes,
|
||||
device=group_pack_index.device,
|
||||
).view(1, -1, 1)
|
||||
).flatten(-2)
|
||||
pphy2log = mlog2log.gather(-1, pphy2mlog)
|
||||
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
|
||||
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
|
||||
return pphy2log, pphyrank, logcnt
|
||||
|
||||
@classmethod
|
||||
def rebalance_experts(
|
||||
cls,
|
||||
weight: torch.Tensor,
|
||||
num_replicas: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_ranks: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Entry point for expert-parallelism load balancer.
|
||||
|
||||
Parameters:
|
||||
weight: [layers, num_logical_experts], the load statistics for all
|
||||
logical experts
|
||||
num_replicas: number of physical experts, must be a multiple of
|
||||
`num_gpus`
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes, where the intra-node network
|
||||
(e.g, NVLink) is faster
|
||||
num_ranks: number of ranks, must be a multiple of `num_nodes`
|
||||
|
||||
Returns:
|
||||
phy2log: [layers, num_replicas], the expert
|
||||
index of each replica
|
||||
log2phy: [layers, num_logical_experts, X],
|
||||
the replica indices for each expert
|
||||
logcnt: [layers, num_logical_experts], number of
|
||||
physical replicas for each logical expert
|
||||
"""
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
weight = weight.float()
|
||||
if num_groups % num_nodes == 0:
|
||||
# use hierarchical load-balance policy
|
||||
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
|
||||
weight, num_replicas, num_groups, num_nodes, num_ranks
|
||||
)
|
||||
else:
|
||||
# use global load-balance policy
|
||||
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
|
||||
weight, num_replicas, 1, 1, num_ranks
|
||||
)
|
||||
num_redundant_experts = num_replicas - num_logical_experts
|
||||
maxlogcnt = num_redundant_experts + 1
|
||||
log2phy: torch.Tensor = torch.full(
|
||||
(num_layers, num_logical_experts, maxlogcnt),
|
||||
-1,
|
||||
dtype=torch.int64,
|
||||
device=logcnt.device,
|
||||
)
|
||||
log2phy.view(num_layers, -1).scatter_(
|
||||
-1,
|
||||
phy2log * maxlogcnt + phyrank,
|
||||
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
|
||||
num_layers, -1
|
||||
),
|
||||
)
|
||||
return phy2log, log2phy, logcnt
|
||||
529
vllm/distributed/eplb/rebalance_execute.py
Normal file
529
vllm/distributed/eplb/rebalance_execute.py
Normal file
@@ -0,0 +1,529 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
The actual execution of the rearrangement.
|
||||
|
||||
This involves the exchange of expert weights between GPUs.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable, MutableSequence, Sequence
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch.distributed import (
|
||||
P2POp,
|
||||
ProcessGroup,
|
||||
all_gather,
|
||||
batch_isend_irecv,
|
||||
get_global_rank,
|
||||
)
|
||||
|
||||
|
||||
def idx_local_to_global(
|
||||
local_idx: int,
|
||||
local_cnt: int,
|
||||
ep_rank: int,
|
||||
) -> int:
|
||||
"""
|
||||
Convert a local expert index to a global expert index.
|
||||
"""
|
||||
return ep_rank * local_cnt + local_idx
|
||||
|
||||
|
||||
def idx_global_to_local(
|
||||
global_idx: int,
|
||||
local_cnt: int,
|
||||
ep_rank: int,
|
||||
) -> int:
|
||||
"""
|
||||
Convert a global expert index to a local expert index.
|
||||
"""
|
||||
return global_idx - ep_rank * local_cnt
|
||||
|
||||
|
||||
def global_idx_to_rank(
|
||||
global_idx: int,
|
||||
local_cnt: int,
|
||||
) -> int:
|
||||
"""
|
||||
Convert a global expert index to a rank index.
|
||||
"""
|
||||
return global_idx // local_cnt
|
||||
|
||||
|
||||
def get_ep_ranks_with_expert(
|
||||
idx: int,
|
||||
num_local_experts: int,
|
||||
old_indices: Sequence[int],
|
||||
new_indices: Sequence[int],
|
||||
) -> tuple[MutableSequence[int], MutableSequence[int]]:
|
||||
"""
|
||||
Get the ranks of the experts that need to be exchanged.
|
||||
|
||||
Args:
|
||||
idx: The index of the expert.
|
||||
num_local_experts: The number of local experts.
|
||||
old_indices: The old indices of the experts.
|
||||
new_indices: The new indices of the experts.
|
||||
|
||||
Returns:
|
||||
A tuple of two lists:
|
||||
- The ranks of the experts that need to be sent.
|
||||
- The ranks of the experts that need to be received.
|
||||
"""
|
||||
global2rank = partial(
|
||||
global_idx_to_rank,
|
||||
local_cnt=num_local_experts,
|
||||
)
|
||||
|
||||
ranks_to_send: list[int] = []
|
||||
ranks_to_recv: list[int] = []
|
||||
|
||||
for i, e in enumerate(old_indices):
|
||||
if e == idx:
|
||||
rank = global2rank(i)
|
||||
if not ranks_to_send or ranks_to_send[-1] != rank:
|
||||
ranks_to_send.append(rank)
|
||||
|
||||
for i, e in enumerate(new_indices):
|
||||
if e == idx:
|
||||
rank = global2rank(i)
|
||||
if not ranks_to_recv or ranks_to_recv[-1] != rank:
|
||||
ranks_to_recv.append(rank)
|
||||
|
||||
# Remove those ranks that can get this expert locally.
|
||||
ranks_to_send_set = set(ranks_to_send)
|
||||
ranks_to_recv_actual = [
|
||||
rank for rank in ranks_to_recv if rank not in ranks_to_send_set
|
||||
]
|
||||
|
||||
return ranks_to_send, ranks_to_recv_actual
|
||||
|
||||
|
||||
def move_to_buffer(
|
||||
num_local_experts: int,
|
||||
old_indices: Sequence[int],
|
||||
new_indices: Sequence[int],
|
||||
expert_weights: Iterable[torch.Tensor],
|
||||
expert_weights_buffer: Sequence[torch.Tensor],
|
||||
cuda_stream: torch.cuda.Stream | None,
|
||||
ep_group: ProcessGroup,
|
||||
) -> tuple[list[bool], list[bool], dict[int, int]]:
|
||||
"""
|
||||
Perform expert weights rearrangement of one layer.
|
||||
"""
|
||||
ep_rank = ep_group.rank()
|
||||
local2global = partial(
|
||||
idx_local_to_global,
|
||||
local_cnt=num_local_experts,
|
||||
ep_rank=ep_rank,
|
||||
)
|
||||
|
||||
# 0. Do nothing for experts that did not change.
|
||||
is_unchanged = [
|
||||
old_indices[local2global(i)] == new_indices[local2global(i)]
|
||||
for i in range(num_local_experts)
|
||||
]
|
||||
|
||||
# 1. Perform weight copy inside the local rank.
|
||||
is_received_locally = is_unchanged[:]
|
||||
for src in range(num_local_experts):
|
||||
src_global = local2global(src)
|
||||
for dst in range(num_local_experts):
|
||||
dst_global = local2global(dst)
|
||||
if is_received_locally[dst]:
|
||||
continue
|
||||
if old_indices[src_global] == -1 or new_indices[dst_global] == -1:
|
||||
continue
|
||||
if old_indices[src_global] == new_indices[dst_global]:
|
||||
is_received_locally[dst] = True
|
||||
for weight, buffer in zip(expert_weights, expert_weights_buffer):
|
||||
with torch.cuda.stream(cuda_stream):
|
||||
buffer[dst].copy_(weight[src], non_blocking=True)
|
||||
|
||||
p2p_ops: list[P2POp] = []
|
||||
|
||||
# 2. Initiate sending of weights.
|
||||
experts_send_loc: dict[int, int] = {}
|
||||
for src in range(num_local_experts):
|
||||
expert = old_indices[local2global(src)]
|
||||
if expert == -1:
|
||||
continue
|
||||
if expert in experts_send_loc:
|
||||
continue
|
||||
experts_send_loc[expert] = src
|
||||
|
||||
# We need to sort here to match send/recv
|
||||
for expert, src in sorted(experts_send_loc.items()):
|
||||
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
|
||||
expert,
|
||||
num_local_experts,
|
||||
old_indices,
|
||||
new_indices,
|
||||
)
|
||||
|
||||
# Calculate the ranks to send by this rank
|
||||
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
|
||||
sender_pos = ranks_to_send.index(ep_rank)
|
||||
recv_begin = sender_pos * num_dst_per_sender
|
||||
recv_end = recv_begin + num_dst_per_sender
|
||||
recv_ranks = ranks_to_recv[recv_begin:recv_end]
|
||||
|
||||
# Tackle remainders
|
||||
remainder_start = len(ranks_to_send) * num_dst_per_sender
|
||||
recver_pos = remainder_start + sender_pos
|
||||
if recver_pos < len(ranks_to_recv):
|
||||
recv_ranks.append(ranks_to_recv[recver_pos])
|
||||
|
||||
for dst in recv_ranks:
|
||||
dst_global = get_global_rank(ep_group, dst)
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.isend,
|
||||
weight[src],
|
||||
dst_global,
|
||||
)
|
||||
for weight in expert_weights
|
||||
]
|
||||
|
||||
# 3. Initiate receiving of weights.
|
||||
experts_recv_loc: dict[int, int] = {}
|
||||
for dst in range(num_local_experts):
|
||||
if is_received_locally[dst]:
|
||||
continue
|
||||
expert = new_indices[local2global(dst)]
|
||||
if expert == -1:
|
||||
continue
|
||||
if expert in experts_recv_loc:
|
||||
continue
|
||||
experts_recv_loc[expert] = dst
|
||||
|
||||
# We need to sort here to match send/recv
|
||||
for expert, dst in sorted(experts_recv_loc.items()):
|
||||
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
|
||||
expert,
|
||||
num_local_experts,
|
||||
old_indices,
|
||||
new_indices,
|
||||
)
|
||||
|
||||
# Calculate the rank to recv by this rank
|
||||
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
|
||||
recver_pos = ranks_to_recv.index(ep_rank)
|
||||
remainder_start = len(ranks_to_send) * num_dst_per_sender
|
||||
if recver_pos < remainder_start:
|
||||
src = ranks_to_send[recver_pos // num_dst_per_sender]
|
||||
else:
|
||||
src = ranks_to_send[recver_pos - remainder_start]
|
||||
|
||||
src_global = get_global_rank(ep_group, src)
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.irecv,
|
||||
weight[dst],
|
||||
src_global,
|
||||
)
|
||||
for weight in expert_weights_buffer
|
||||
]
|
||||
|
||||
# 4. Execute the P2P operations. The real communication happens here.
|
||||
if p2p_ops and cuda_stream is not None:
|
||||
with torch.cuda.stream(cuda_stream):
|
||||
reqs = batch_isend_irecv(p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
elif p2p_ops:
|
||||
reqs = batch_isend_irecv(p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
# wait for the communication to finish
|
||||
return is_unchanged, is_received_locally, experts_recv_loc
|
||||
|
||||
|
||||
def move_from_buffer(
|
||||
expert_weights: Iterable[torch.Tensor],
|
||||
expert_weights_buffer: list[torch.Tensor],
|
||||
is_unchanged: list[bool],
|
||||
is_received_locally: list[bool],
|
||||
experts_recv_loc: dict[int, int],
|
||||
new_indices: Sequence[int],
|
||||
ep_group: ProcessGroup,
|
||||
) -> None:
|
||||
ep_rank = ep_group.rank()
|
||||
num_local_experts = len(is_unchanged)
|
||||
|
||||
local2global = partial(
|
||||
idx_local_to_global, local_cnt=num_local_experts, ep_rank=ep_rank
|
||||
)
|
||||
|
||||
for dst in range(num_local_experts):
|
||||
if is_unchanged[dst]:
|
||||
continue
|
||||
if is_received_locally[dst]:
|
||||
for weight, buffer in zip(expert_weights, expert_weights_buffer):
|
||||
weight[dst].copy_(buffer[dst], non_blocking=True)
|
||||
else:
|
||||
expert = new_indices[local2global(dst)]
|
||||
if expert == -1:
|
||||
continue
|
||||
src = experts_recv_loc[expert]
|
||||
for weight, buffer in zip(expert_weights, expert_weights_buffer):
|
||||
weight[dst].copy_(buffer[src], non_blocking=True)
|
||||
|
||||
|
||||
async def transfer_layer(
|
||||
old_global_expert_indices: torch.Tensor,
|
||||
new_global_expert_indices: torch.Tensor,
|
||||
expert_weights: Sequence[Iterable[torch.Tensor]],
|
||||
expert_weights_buffer: Sequence[torch.Tensor],
|
||||
ep_group: ProcessGroup,
|
||||
is_profile: bool = False,
|
||||
layer: int = 0,
|
||||
cuda_stream: torch.cuda.Stream | None = None,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
) -> tuple[list[bool], list[bool], dict[int, int]]:
|
||||
"""
|
||||
Rearranges the expert weights in place according to the new expert indices.
|
||||
|
||||
The value of the indices arguments are logical indices of the experts,
|
||||
while keys are physical.
|
||||
|
||||
Args:
|
||||
old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
|
||||
new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
|
||||
expert_weights: A sequence of shape (num_moe_layers)(weight_count)
|
||||
of tensors of shape (num_local_physical_experts, hidden_size_i).
|
||||
For example, a linear layer may have up and down projection,
|
||||
so weight_count = 2. Each weight's hidden size can be different.
|
||||
ep_group: The device process group for expert parallelism.
|
||||
is_profile (bool): If `True`, do not perform any actual weight copy.
|
||||
This is used during profile run, where we only perform dummy
|
||||
communications to reserve enough memory for the buffers.
|
||||
"""
|
||||
ep_size = ep_group.size()
|
||||
if rank_mapping is not None:
|
||||
if len(rank_mapping) == ep_group.size():
|
||||
# scale down
|
||||
new_global_expert_indices = _map_new_expert_indices_with_rank_mapping(
|
||||
new_global_expert_indices,
|
||||
rank_mapping,
|
||||
)
|
||||
else:
|
||||
# scale up
|
||||
old_global_expert_indices = _map_old_expert_indices_with_rank_mapping(
|
||||
old_global_expert_indices,
|
||||
rank_mapping,
|
||||
ep_group.size(),
|
||||
)
|
||||
|
||||
assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]
|
||||
num_moe_layers, num_physical_experts = old_global_expert_indices.shape
|
||||
assert len(expert_weights) == num_moe_layers
|
||||
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
|
||||
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
|
||||
assert num_physical_experts == ep_size * num_local_physical_experts
|
||||
|
||||
is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
|
||||
num_local_experts=num_local_physical_experts,
|
||||
old_indices=old_global_expert_indices[layer].tolist(),
|
||||
new_indices=new_global_expert_indices[layer].tolist(),
|
||||
expert_weights=expert_weights[layer],
|
||||
expert_weights_buffer=expert_weights_buffer,
|
||||
cuda_stream=cuda_stream,
|
||||
ep_group=ep_group,
|
||||
)
|
||||
return is_unchanged, is_received_locally, experts_recv_loc
|
||||
|
||||
|
||||
def rearrange_expert_weights_inplace(
|
||||
old_global_expert_indices: torch.Tensor,
|
||||
new_global_expert_indices: torch.Tensor,
|
||||
expert_weights: Sequence[Iterable[torch.Tensor]],
|
||||
ep_group: ProcessGroup,
|
||||
is_profile: bool = False,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Rearranges the expert weights in place according to the new expert indices.
|
||||
|
||||
The value of the indices arguments are logical indices of the experts,
|
||||
while keys are physical.
|
||||
|
||||
Args:
|
||||
old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
|
||||
new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
|
||||
expert_weights: A sequence of shape (num_moe_layers)(weight_count)
|
||||
of tensors of shape (num_local_physical_experts, hidden_size_i).
|
||||
For example, a linear layer may have up and down projection,
|
||||
so weight_count = 2. Each weight's hidden size can be different.
|
||||
ep_group: The device process group for expert parallelism.
|
||||
is_profile (bool): If `True`, do not perform any actual weight copy.
|
||||
This is used during profile run, where we only perform dummy
|
||||
communications to reserve enough memory for the buffers.
|
||||
rank_mapping: A dictionary mapping old rank to new rank.
|
||||
"""
|
||||
if rank_mapping is not None:
|
||||
if len(rank_mapping) == ep_group.size():
|
||||
# scale down
|
||||
new_global_expert_indices = _map_new_expert_indices_with_rank_mapping(
|
||||
new_global_expert_indices,
|
||||
rank_mapping,
|
||||
)
|
||||
else:
|
||||
# scale up
|
||||
old_global_expert_indices = _map_old_expert_indices_with_rank_mapping(
|
||||
old_global_expert_indices,
|
||||
rank_mapping,
|
||||
ep_group.size(),
|
||||
)
|
||||
|
||||
assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]
|
||||
|
||||
num_moe_layers, num_physical_experts = old_global_expert_indices.shape
|
||||
assert len(expert_weights) == num_moe_layers
|
||||
|
||||
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
|
||||
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
|
||||
|
||||
ep_size = ep_group.size()
|
||||
assert num_physical_experts == ep_size * num_local_physical_experts
|
||||
|
||||
# A buffer to hold the expert weights in one layer during the exchange.
|
||||
# NOTE: Currently we assume the same weights across different layers
|
||||
# have the same shape.
|
||||
expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||
|
||||
if is_profile:
|
||||
# Maximum send size is to send all local experts to all ranks,
|
||||
# So we use a dummy `all_gather` to reserve enough communication buffer
|
||||
for weight, buffer in zip(expert_weights[0], expert_weights_buffer):
|
||||
# A `/dev/null`-like buffer to avoid real memory allocation
|
||||
dummy_recv_buffer = [buffer for _ in range(ep_size)]
|
||||
# NOTE(bowen): Needed this barrier to avoid OOM during actual
|
||||
# execution. I'm not very sure why this is needed
|
||||
torch.distributed.barrier()
|
||||
all_gather(
|
||||
dummy_recv_buffer,
|
||||
weight,
|
||||
group=ep_group,
|
||||
)
|
||||
return
|
||||
|
||||
old_global_expert_indices_cpu = old_global_expert_indices.cpu()
|
||||
new_global_expert_indices_cpu = new_global_expert_indices.cpu()
|
||||
|
||||
# NOTE(bowen): We need this synchronize to run, but I don't know why.
|
||||
# If you figure out the reason, please let me know -- thank you!
|
||||
torch.cuda.synchronize()
|
||||
|
||||
for layer in range(num_moe_layers):
|
||||
is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
|
||||
num_local_experts=num_local_physical_experts,
|
||||
old_indices=old_global_expert_indices_cpu[layer].tolist(),
|
||||
new_indices=new_global_expert_indices_cpu[layer].tolist(),
|
||||
expert_weights=expert_weights[layer],
|
||||
expert_weights_buffer=expert_weights_buffer,
|
||||
cuda_stream=None,
|
||||
ep_group=ep_group,
|
||||
)
|
||||
|
||||
move_from_buffer(
|
||||
expert_weights=expert_weights[layer],
|
||||
expert_weights_buffer=expert_weights_buffer,
|
||||
is_unchanged=is_unchanged,
|
||||
is_received_locally=is_received_locally,
|
||||
experts_recv_loc=experts_recv_loc,
|
||||
new_indices=new_global_expert_indices[layer].tolist(),
|
||||
ep_group=ep_group,
|
||||
)
|
||||
|
||||
|
||||
def _map_old_expert_indices_with_rank_mapping(
|
||||
old_global_expert_indices: torch.Tensor,
|
||||
rank_mapping: dict[int, int],
|
||||
new_ep_size: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Map the old global expert indices to the new global expert indices.
|
||||
|
||||
Args:
|
||||
old_global_expert_indices:
|
||||
Shape (num_layers, old_ep_size * num_local_physical_experts).
|
||||
rank_mapping: Mapping from old rank to new rank.
|
||||
new_ep_size: New expert parallelism size.
|
||||
|
||||
Returns:
|
||||
Mapped expert indices with shape
|
||||
(num_layers, new_ep_size * num_local_physical_experts).
|
||||
"""
|
||||
num_layers, old_num_physical_experts = old_global_expert_indices.shape
|
||||
assert rank_mapping, "Rank mapping is required"
|
||||
|
||||
# Get sizes from parameters and rank_mapping
|
||||
old_ep_size = len(rank_mapping)
|
||||
num_local_physical_experts = old_num_physical_experts // old_ep_size
|
||||
new_num_physical_experts = new_ep_size * num_local_physical_experts
|
||||
|
||||
# Create mapped tensor with new shape, initialized to -1
|
||||
mapped_expert_indices = torch.full(
|
||||
(num_layers, new_num_physical_experts),
|
||||
fill_value=-1,
|
||||
dtype=old_global_expert_indices.dtype,
|
||||
device=old_global_expert_indices.device,
|
||||
)
|
||||
|
||||
# Handle rank mapping (scale up/down with rank changes)
|
||||
for old_rank in range(old_ep_size):
|
||||
new_rank = rank_mapping.get(old_rank)
|
||||
if new_rank is not None and new_rank >= 0 and new_rank < new_ep_size:
|
||||
# This old rank exists in the new configuration
|
||||
old_start_idx = old_rank * num_local_physical_experts
|
||||
old_end_idx = (old_rank + 1) * num_local_physical_experts
|
||||
new_start_idx = new_rank * num_local_physical_experts
|
||||
new_end_idx = (new_rank + 1) * num_local_physical_experts
|
||||
|
||||
mapped_expert_indices[:, new_start_idx:new_end_idx] = (
|
||||
old_global_expert_indices[:, old_start_idx:old_end_idx]
|
||||
)
|
||||
# If new_rank is None or >= new_ep_size, the experts remain -1
|
||||
# (scale down case)
|
||||
|
||||
return mapped_expert_indices
|
||||
|
||||
|
||||
def _map_new_expert_indices_with_rank_mapping(
|
||||
new_global_expert_indices: torch.Tensor,
|
||||
rank_mapping: dict[int, int],
|
||||
) -> torch.Tensor:
|
||||
num_layers, new_num_physical_experts = new_global_expert_indices.shape
|
||||
assert rank_mapping, "Rank mapping is required"
|
||||
|
||||
# Get sizes from parameters and rank_mapping
|
||||
old_ep_size = len(rank_mapping)
|
||||
new_ep_size = sum(new_rank != -1 for new_rank in rank_mapping.values())
|
||||
num_local_physical_experts = new_num_physical_experts // new_ep_size
|
||||
old_num_physical_experts = old_ep_size * num_local_physical_experts
|
||||
|
||||
mapped_expert_indices = torch.full(
|
||||
(num_layers, old_num_physical_experts),
|
||||
fill_value=-1,
|
||||
dtype=new_global_expert_indices.dtype,
|
||||
device=new_global_expert_indices.device,
|
||||
)
|
||||
|
||||
for old_rank in range(old_ep_size):
|
||||
new_rank = rank_mapping[old_rank]
|
||||
if new_rank >= 0 and new_rank < new_ep_size:
|
||||
old_start_idx = old_rank * num_local_physical_experts
|
||||
old_end_idx = (old_rank + 1) * num_local_physical_experts
|
||||
new_start_idx = new_rank * num_local_physical_experts
|
||||
new_end_idx = (new_rank + 1) * num_local_physical_experts
|
||||
|
||||
mapped_expert_indices[:, old_start_idx:old_end_idx] = (
|
||||
new_global_expert_indices[:, new_start_idx:new_end_idx]
|
||||
)
|
||||
|
||||
return mapped_expert_indices
|
||||
|
||||
|
||||
__all__ = ["transfer_layer", "move_from_buffer"]
|
||||
Reference in New Issue
Block a user