Add minimal vLLM 0.16.1 build repo for BI-V150

2026-04-18 10:56:22 +08:00
commit d69657327e
1895 changed files with 615301 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Expert parallelism load balancer (EPLB)."""

View File

@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
The async worker that transfers experts in the background.
"""
import asyncio
import threading
from typing import TYPE_CHECKING
import torch
from torch.distributed import ProcessGroup
from vllm.distributed.parallel_state import get_eplb_group
from vllm.logger import init_logger
from .rebalance_execute import transfer_layer
if TYPE_CHECKING:
from .eplb_state import EplbModelState, EplbState
logger = init_logger(__name__)
def start_async_worker(
state: "EplbState",
rank_mapping: dict[int, int] | None = None,
is_profile: bool = False,
) -> threading.Thread:
eplb_group = get_eplb_group().device_group
rank = eplb_group.rank()
device_index = state.cuda_device_index
assert state.is_async
def thread_target() -> None:
assert device_index is not None
torch.cuda.set_device(device_index)
cuda_stream = torch.cuda.Stream(device=device_index)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
transfer_run_periodically(
state=state,
eplb_group=eplb_group,
cuda_stream=cuda_stream,
is_profile=is_profile,
rank_mapping=rank_mapping,
)
)
except Exception as exc: # pragma: no cover - diagnostic path
logger.exception("async loop error (Rank %d): %s", rank, str(exc))
finally:
loop.close()
thread = threading.Thread(target=thread_target, daemon=True)
thread.start()
return thread
def run_rebalance_experts(
model_state: "EplbModelState",
eplb_state: "EplbState",
physical_to_logical_map_cpu: torch.Tensor,
) -> None:
assert model_state.eplb_stats is not None
eplb_stats = model_state.eplb_stats
# Wait for the main thread's all-reduce and clone to complete before
# accessing the global_expert_load_window tensor.
assert model_state.window_ready_event is not None
model_state.window_ready_event.wait()
model_state.window_ready_event = None
# Move the global expert load window to CPU for computation.
global_expert_load_window = eplb_stats.global_expert_load_window.cpu()
# Compute new expert mappings for the model
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = eplb_state.policy.rebalance_experts(
global_expert_load_window,
eplb_stats.num_replicas,
eplb_stats.num_groups,
eplb_stats.num_nodes,
eplb_stats.num_gpus,
physical_to_logical_map_cpu,
)
assert new_physical_to_logical_map.device == torch.device("cpu")
model_state.new_physical_to_logical_map = new_physical_to_logical_map
max_slots = model_state.logical_to_physical_map.shape[-1]
padded_logical = torch.nn.functional.pad(
new_logical_to_physical_map,
(0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
value=-1,
).to(model_state.logical_to_physical_map.device)
new_replica = new_logical_replica_count.to(model_state.logical_replica_count.device)
model_state.new_logical_to_physical_map = padded_logical
model_state.new_logical_replica_count = new_replica
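# Editorial note (not in the original file): the pad above right-fills the new
# logical-to-physical map to the buffer's slot width with -1. For example, a
# hypothetical row [[5, 9]] padded to max_slots == 4 becomes [[5, 9, -1, -1]],
# so unused replica slots stay marked as invalid.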
async def transfer_run_periodically(
state: "EplbState",
eplb_group: ProcessGroup,
cuda_stream: torch.cuda.Stream,
is_profile: bool = False,
rank_mapping: dict[int, int] | None = None,
) -> None:
while True:
await asyncio.to_thread(state.rearrange_event.wait)
logger.info("async worker woke up for EPLB transfer")
assert state.is_async
for model_state in state.model_states.values():
rebalancing_algorithm_executed = False
physical_to_logical_map_cpu = None
current_num_layers = model_state.model.num_moe_layers
while (
model_state.rebalanced
and model_state.layer_to_transfer < current_num_layers
):
if not model_state.ep_buffer_ready and model_state.rebalanced:
# Polling the lock directly in the async thread avoids
# the thread switch overhead of asyncio.to_thread.
# This is typically faster than offloading to a worker thread.
while not model_state.buffer_lock.acquire(blocking=False):
await asyncio.sleep(0)
try:
if model_state.layer_to_transfer >= current_num_layers:
break
if (
not rebalancing_algorithm_executed
or model_state.new_physical_to_logical_map is None
):
# Move the physical_to_logical_map to CPU
# for rebalancing and transfer_layer.
physical_to_logical_map_cpu = (
model_state.physical_to_logical_map.cpu()
)
run_rebalance_experts(
model_state, state, physical_to_logical_map_cpu
)
rebalancing_algorithm_executed = True
logger.info(
"Async worker computed new indices for model %s",
model_state.model_name,
)
assert model_state.new_physical_to_logical_map is not None
assert physical_to_logical_map_cpu is not None
layer_idx = model_state.layer_to_transfer
old_layer_indices = physical_to_logical_map_cpu[layer_idx]
new_layer_indices = model_state.new_physical_to_logical_map[
layer_idx
]
# Wait for the main thread to finish consuming the buffer
# before initiating an EPLB transfer on another layer.
if model_state.buffer_consumed_event is not None:
cuda_stream.wait_event(model_state.buffer_consumed_event)
model_state.buffer_consumed_event = None
(
model_state.is_unchanged,
model_state.is_received_locally,
model_state.recv_metadata,
) = await transfer_layer(
old_layer_indices=old_layer_indices,
new_layer_indices=new_layer_indices,
expert_weights=model_state.model.expert_weights[layer_idx],
expert_weights_buffer=model_state.expert_buffer,
ep_group=eplb_group,
is_profile=is_profile,
cuda_stream=cuda_stream,
rank_mapping=rank_mapping,
)
event = torch.cuda.Event(blocking=False)
cuda_stream.record_event(event)
model_state.buffer_ready_event = event
model_state.ep_buffer_ready = 1
finally:
model_state.buffer_lock.release()
else:
if not model_state.rebalanced:
break
await asyncio.sleep(0.001)
state.rearrange_event.clear()
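
A minimal, self-contained sketch (editorial, not part of the diff) of the pattern start_async_worker uses above: a daemon thread that owns a private asyncio event loop and runs one long-lived coroutine to completion. CUDA device/stream setup and all vLLM state are omitted.

import asyncio
import threading
from collections.abc import Callable, Coroutine
from typing import Any

def start_background_loop(
    coro_factory: Callable[[], Coroutine[Any, Any, None]],
) -> threading.Thread:
    """Run an async worker in a daemon thread with its own event loop."""

    def thread_target() -> None:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(coro_factory())
        finally:
            loop.close()

    thread = threading.Thread(target=thread_target, daemon=True)
    thread.start()
    return thread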

File diff suppressed because it is too large

View File

@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for EPLB (Expert Parallel Load Balancing)."""
import os
from vllm.config import ParallelConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
def override_envs_for_eplb(parallel_config: ParallelConfig) -> None:
"""
Override environment variables for EPLB when specific conditions are met.
Args:
parallel_config: The parallel configuration object.
"""
is_data_parallel = parallel_config.data_parallel_size > 1
is_eplb_enabled = parallel_config.enable_eplb
async_eplb = parallel_config.eplb_config.use_async
is_deepep_ll = parallel_config.all2all_backend == "deepep_low_latency"
# Override NCCL_MAX_CTAS to avoid hangs when using async EPLB with the
# DeepEP low-latency backend.
#
# The hang happens when two ranks interleave kernel launches differently
# between NCCL collectives (used by async EPLB weight exchange) and DeepEP
# low-latency (LL) kernels. DeepEP LL uses a cooperative launch and tries
# to reserve a large fraction of the GPU's SMs; if those SMs are currently
# occupied by NCCL, the DeepEP LL launch blocks until enough SMs are
# freed.
#
# If rank A enters DeepEP LL in the main thread while rank B is still executing
# NCCL in the async thread, rank A can block waiting for SMs, while rank B can
# block inside NCCL waiting for rank A to participate in the collective.
# This circular wait causes a deadlock.
# Limiting NCCL occupancy via NCCL_MAX_CTAS leaves space for the DeepEP
# cooperative kernel to launch and complete, breaking the deadlock.
# See: https://github.com/deepseek-ai/DeepEP/issues/496
if is_data_parallel and is_eplb_enabled and is_deepep_ll and async_eplb:
current_value_str = os.getenv("NCCL_MAX_CTAS")
if current_value_str and current_value_str.isdigit():
return
override_value = 8
os.environ["NCCL_MAX_CTAS"] = str(override_value)
logger.info_once(
f"EPLB: Setting NCCL_MAX_CTAS={override_value} "
"for expert parallel with EPLB and deepep_low_latency backend",
scope="global",
)
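
A usage sketch (editorial, not part of the diff): since the helper returns early when NCCL_MAX_CTAS is already a numeric string, an operator can pin a value of their own before engine startup and the EPLB default of 8 is never applied.

import os

# Hypothetical operator-chosen cap, exported before vLLM initializes.
os.environ["NCCL_MAX_CTAS"] = "16"
# override_envs_for_eplb(parallel_config) now sees a digit-valued setting
# and returns without modifying it.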

View File

@@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import get_args
from vllm.config.parallel import EPLBPolicyOption
from .abstract import AbstractEplbPolicy
from .default import DefaultEplbPolicy
EPLB_POLICIES = {"default": DefaultEplbPolicy}
# Ensure that the EPLB_POLICIES keys match the EPLBPolicyOption values
assert set(EPLB_POLICIES.keys()) == set(get_args(EPLBPolicyOption))
__all__ = [
"AbstractEplbPolicy",
"DefaultEplbPolicy",
"EPLB_POLICIES",
]

View File

@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
class AbstractEplbPolicy(ABC):
@classmethod
@abstractmethod
def rebalance_experts(
cls,
weight: torch.Tensor,
num_replicas: int,
num_groups: int,
num_nodes: int,
num_ranks: int,
old_global_expert_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Entry point for expert-parallelism load balancer.
Parameters:
weight: [layers, num_logical_experts], the load statistics
for all logical experts
num_replicas: number of physical experts, must be a multiple of
`num_ranks`
num_groups: number of expert groups
num_nodes: number of server nodes
num_ranks: number of ranks, must be a multiple of `num_nodes`
old_global_expert_indices: [layers, num_replicas], the old global
expert indices (physical-to-logical map). Used to avoid unnecessary
weight copying for experts moving within one rank.
Returns:
physical_to_logical_map: [layers, num_replicas], the expert
index of each replica
logical_to_physical_map: [layers, num_logical_experts, X],
the replica indices for each expert
expert_count: [layers, num_logical_experts], number of
physical replicas for each logical expert
"""
raise NotImplementedError

View File

@@ -0,0 +1,376 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB) for vLLM.
This module implements the core rearrangement algorithm.
The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
See [#12](https://github.com/deepseek-ai/EPLB/issues/12) for an example of
how the EPLB algorithm works.
"""
import numpy as np
import torch
from .abstract import AbstractEplbPolicy
class DefaultEplbPolicy(AbstractEplbPolicy):
@classmethod
def balanced_packing(
cls, weight: np.ndarray, num_packs: int
) -> tuple[np.ndarray, np.ndarray]:
"""
Pack n weighted objects into m packs, such that each pack contains exactly
n/m objects and the weights of all packs are as balanced as possible.
Parameters:
weight: [X, n], the weight of each item
num_packs: number of packs
Returns:
pack_index: [X, n], the pack index of each item
rank_in_pack: [X, n], the rank of the item in the pack
"""
num_layers, num_groups = weight.shape
assert num_groups % num_packs == 0
groups_per_pack = num_groups // num_packs
if groups_per_pack == 1:
pack_index = np.tile(np.arange(num_groups, dtype=np.int64), (num_layers, 1))
rank_in_pack = np.zeros_like(pack_index, dtype=np.int64)
return pack_index, rank_in_pack
# Sort and get indices in descending order
indices = np.argsort(-weight, axis=-1)
pack_index = np.full((num_layers, num_groups), -1, dtype=np.int64)
rank_in_pack = np.full((num_layers, num_groups), -1, dtype=np.int64)
pack_weights = np.zeros((num_layers, num_packs), dtype=np.float64)
pack_items = np.zeros((num_layers, num_packs), dtype=np.int64)
# Run the packing algorithm
for layer_idx in range(num_layers):
weights_row = pack_weights[layer_idx]
items_row = pack_items[layer_idx]
for group in indices[layer_idx]:
# Pick the lightest pack; full packs are masked out by inf.
pack = int(np.argmin(weights_row))
pack_index[layer_idx, group] = pack
rank_in_pack[layer_idx, group] = items_row[pack]
weights_row[pack] += weight[layer_idx, group]
items_row[pack] += 1
if items_row[pack] == groups_per_pack:
# Mark as unavailable for future selections.
weights_row[pack] = np.inf
return pack_index, rank_in_pack
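# Worked example (editorial, not in the original file): with a single layer of
# four groups weighted [8, 1, 5, 4] and num_packs == 2, the greedy pass visits
# groups in descending weight: 8 -> pack 0, 5 -> pack 1, 4 -> pack 1 (now full),
# 1 -> pack 0. The result is pack_index == [[0, 0, 1, 1]] and
# rank_in_pack == [[0, 1, 0, 1]].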
@classmethod
def replicate_experts(
cls, weight: np.ndarray, num_phy: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
load of all replicas is minimized.
Parameters:
weight: [X, num_log]
num_phy: total number of experts after replication
Returns:
phy2log: [X, num_phy], logical expert id of each physical expert
replica_idx: [X, num_phy], the replica rank of each physical expert within its logical expert
logcnt: [X, num_log], number of replicas for each logical expert
"""
n, num_log = weight.shape
num_redundant = num_phy - num_log
assert num_redundant >= 0
phy2log = np.tile(np.arange(num_phy, dtype=np.int64), (n, 1))
replica_idx = np.zeros((n, num_phy), dtype=np.int64)
logcnt = np.ones((n, num_log), dtype=np.int64)
arangen = np.arange(n, dtype=np.int64)
for i in range(num_log, num_phy):
redundant_indices = np.argmax(weight / logcnt, axis=-1)
phy2log[:, i] = redundant_indices
replica_idx[:, i] = logcnt[arangen, redundant_indices]
logcnt[arangen, redundant_indices] += 1
return phy2log, replica_idx, logcnt
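# Worked example (editorial, not in the original file): with one layer of logical
# loads [6, 2, 1] and num_phy == 5, both redundant replicas go to the hottest
# expert: phy2log == [[0, 1, 2, 0, 0]], replica_idx == [[0, 0, 0, 1, 2]] and
# logcnt == [[3, 1, 1]], cutting expert 0's per-replica load from 6 to 2.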
@classmethod
def rebalance_experts_hierarchical(
cls,
weight: np.ndarray,
num_physical_experts: int,
num_groups: int,
num_nodes: int,
num_gpus: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Parameters:
weight: [num_moe_layers, num_logical_experts]
num_physical_experts: number of physical experts after replication
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g., NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
phy2log: [layers, num_physical_experts], the logical expert
index of each physical expert
pphy_replicas_idx: [layers, num_physical_experts], the replica
rank of each physical expert within its logical expert
logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert
"""
num_layers, num_logical_experts = weight.shape
assert num_logical_experts % num_groups == 0
group_size = num_logical_experts // num_groups
assert num_groups % num_nodes == 0
groups_per_node = num_groups // num_nodes
assert num_gpus % num_nodes == 0
assert num_physical_experts % num_gpus == 0
phy_experts_per_gpu = num_physical_experts // num_gpus
def inverse(perm: np.ndarray) -> np.ndarray:
inv = np.empty_like(perm)
row_idx = np.arange(perm.shape[0])[:, None]
col_idx = np.arange(perm.shape[1], dtype=np.int64)
inv[row_idx, perm] = col_idx
return inv
# Step 1: pack groups to nodes
tokens_per_group = weight.reshape(num_layers, num_groups, group_size).sum(
axis=-1
)
group_pack_index, group_rank_in_pack = cls.balanced_packing(
tokens_per_group, num_nodes
)
# Map each logical expert into a node-local ordering based on packed groups.
log2mlog = (
(
(group_pack_index * groups_per_node + group_rank_in_pack)[..., None]
* group_size
)
+ np.arange(group_size, dtype=np.int64)
).reshape(num_layers, num_logical_experts)
mlog2log = inverse(log2mlog)
# Step 2: construct redundant experts within nodes
# Reorder weights into the node-local layout so replication is done per node.
tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=1).reshape(
-1, num_logical_experts // num_nodes
)
phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts(
tokens_per_mlog, num_physical_experts // num_nodes
)
# Step 3: pack physical_experts to GPUs
# Effective per-physical load = logical load divided by replica count.
tokens_per_phy = np.take_along_axis(tokens_per_mlog / mlogcnt, phy2mlog, axis=1)
pack_index, rank_in_pack = cls.balanced_packing(
tokens_per_phy, num_gpus // num_nodes
)
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
pphy2phy = inverse(phy2pphy)
# Reorder node-local logical indices into the post-packing physical order.
pphy2mlog = np.take_along_axis(phy2mlog, pphy2phy, axis=1)
pphy2mlog = (
pphy2mlog.reshape(num_layers, num_nodes, -1)
+ np.arange(
0,
num_logical_experts,
num_logical_experts // num_nodes,
dtype=np.int64,
)[None, :, None]
).reshape(num_layers, -1)
# Map node-local logical indices back to global logical expert ids.
pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=1)
# Reorder replica ranks to the post-packing physical ordering.
pphy_replicas_idx = np.take_along_axis(replicas_idx, pphy2phy, axis=1).reshape(
num_layers, -1
)
# Convert replica counts back to the original logical ordering.
logcnt = np.take_along_axis(mlogcnt.reshape(num_layers, -1), log2mlog, axis=1)
return pphy2log, pphy_replicas_idx, logcnt
@classmethod
def preserve_intragpu_slots(
cls,
phy2log: np.ndarray,
phy_replicas_idx: np.ndarray,
num_ranks: int,
old_phy2log: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
"""
Reorder the new mapping per GPU so that experts that remain on the same GPU
keep their previous slot positions when possible. Incoming experts to that GPU
fill any remaining available slots. This is applied only when the number of GPUs
is unchanged and the slots per GPU remain the same between
the old and new mappings.
"""
num_phy_experts = phy2log.shape[1]
if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
return phy2log, phy_replicas_idx
# Slots (local physical experts) owned by each GPU
slots_per_gpu = num_phy_experts // num_ranks
num_layers = phy2log.shape[0]
post_phy2log = phy2log.copy()
post_phy_replicas_idx = phy_replicas_idx.copy()
for gpu_idx in range(num_ranks):
start = gpu_idx * slots_per_gpu
end = start + slots_per_gpu
# Experts across all layers for this GPU
old_local = old_phy2log[:, start:end] # [layers, slots]
new_local = phy2log[:, start:end] # [layers, slots]
new_ridx = phy_replicas_idx[:, start:end] # [layers, slots]
used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)
# First pass: preserve same-logical experts in their previous slots
for slot_idx in range(slots_per_gpu):
# matches: [layers, slots], True where a new local expert has the same
# logical id as the old expert at 'slot_idx' and has not been claimed yet
matches = (new_local == old_local[:, slot_idx][:, None]) & (
~used_new_indices
)
has_any = matches.any(axis=1)
if np.any(has_any):
first_idx = np.argmax(matches, axis=1)
layer_indices = np.nonzero(has_any)[0]
matched_new_positions = first_idx[layer_indices]
post_phy2log[layer_indices, start + slot_idx] = new_local[
layer_indices, matched_new_positions
]
post_phy_replicas_idx[layer_indices, start + slot_idx] = new_ridx[
layer_indices, matched_new_positions
]
used_new_indices[layer_indices, matched_new_positions] = True
preserved_positions[layer_indices, slot_idx] = True
# Second pass: fill remaining slots with remaining new experts
remaining_mask = ~used_new_indices # [layers, slots]
fill_mask = ~preserved_positions # [layers, slots]
if remaining_mask.any() and fill_mask.any():
idx_base = np.tile(np.arange(slots_per_gpu), (num_layers, 1))
# Sentinel value for unavailable positions.
large = slots_per_gpu + 1
# Priorities: keep original index for available spots, set sentinel
# for unavailable; lower is earlier.
remaining_priority = np.where(remaining_mask, idx_base, large)
fill_priority = np.where(fill_mask, idx_base, large)
# Sort to get ordered indices of available src/dst positions per layer.
remaining_indices = np.argsort(remaining_priority, axis=1)
fill_indices = np.argsort(fill_priority, axis=1)
# Fill count per layer (cannot exceed either side).
remaining_counts = remaining_mask.sum(axis=1)
fill_counts = fill_mask.sum(axis=1)
take_counts = np.minimum(remaining_counts, fill_counts)
# Assign remaining new experts to remaining slots per layer.
for layer_idx in range(num_layers):
k = int(take_counts[layer_idx])
if k <= 0:
continue
src_pos = remaining_indices[layer_idx, :k]
dst_pos = fill_indices[layer_idx, :k]
post_phy2log[layer_idx, start + dst_pos] = new_local[
layer_idx, src_pos
]
post_phy_replicas_idx[layer_idx, start + dst_pos] = new_ridx[
layer_idx, src_pos
]
return post_phy2log, post_phy_replicas_idx
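# Worked example (editorial, not in the original file): with num_ranks == 2, two
# slots per GPU, old mapping [[0, 1, 2, 3]] and new mapping [[1, 2, 0, 3]],
# experts 1 and 3 keep their old slots and the incoming experts fill the rest,
# giving [[2, 1, 0, 3]] instead of [[1, 2, 0, 3]]; only two weight rows change
# relative to the old layout instead of three.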
@classmethod
def rebalance_experts(
cls,
weight: torch.Tensor,
num_replicas: int,
num_groups: int,
num_nodes: int,
num_ranks: int,
old_global_expert_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Entry point for expert-parallelism load balancer.
Parameters:
weight: [layers, num_logical_experts], the load statistics for all
logical experts
num_replicas: number of physical experts, must be a multiple of
`num_ranks`
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g., NVLink) is faster
num_ranks: number of ranks, must be a multiple of `num_nodes`
old_global_expert_indices: [layers, num_replicas], the old global
expert indices (physical-to-logical map). Used to avoid unnecessary
weight copying for experts moving within one rank.
Returns:
phy2log: [layers, num_replicas], the expert
index of each replica
log2phy: [layers, num_logical_experts, X],
the replica indices for each expert
logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert
"""
device = weight.device
num_layers, num_logical_experts = weight.shape
weight_np = weight.float().cpu().numpy()
old_phy2log_np = (
old_global_expert_indices.cpu().numpy()
if old_global_expert_indices is not None
else None
)
if num_groups % num_nodes == 0:
# use hierarchical load-balance policy
phy2log_np, phy_replicas_idx_np, logcnt_np = (
cls.rebalance_experts_hierarchical(
weight_np, num_replicas, num_groups, num_nodes, num_ranks
)
)
else:
# use global load-balance policy
phy2log_np, phy_replicas_idx_np, logcnt_np = (
cls.rebalance_experts_hierarchical(
weight_np, num_replicas, 1, 1, num_ranks
)
)
# Optional postprocessing: preserve slots for experts that stay on the
# same GPU. Only applied when the number of GPUs and the slots per GPU
# are unchanged; this avoids unnecessary weight copies for experts that
# merely move within a GPU.
if old_global_expert_indices is not None:
phy2log_np, phy_replicas_idx_np = cls.preserve_intragpu_slots(
phy2log_np, phy_replicas_idx_np, num_ranks, old_phy2log_np
)
num_redundant_experts = num_replicas - num_logical_experts
maxlogcnt = num_redundant_experts + 1
log2phy_np = np.full(
(num_layers, num_logical_experts, maxlogcnt), -1, dtype=np.int64
)
layer_indices = np.arange(num_layers)[:, None]
replica_indices = np.tile(
np.arange(num_replicas, dtype=np.int64), (num_layers, 1)
)
log2phy_np[layer_indices, phy2log_np, phy_replicas_idx_np] = replica_indices
phy2log = torch.from_numpy(phy2log_np).to(device)
log2phy = torch.from_numpy(log2phy_np).to(device)
logcnt = torch.from_numpy(logcnt_np).to(device)
return phy2log, log2phy, logcnt
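
A small usage sketch (editorial, not part of the diff): the toy shapes below are hypothetical, and the import path is assumed from the relative imports in this commit rather than shown by the diff.

import torch

# Assumed module path; the diff does not display file paths.
from vllm.distributed.eplb.policies.default import DefaultEplbPolicy

# One MoE layer, 4 logical experts, 8 physical slots over 4 ranks
# (2 nodes, 2 expert groups), with expert 0 far hotter than the rest.
load = torch.tensor([[10.0, 3.0, 2.0, 1.0]])
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
    load,
    num_replicas=8,
    num_groups=2,
    num_nodes=2,
    num_ranks=4,
)
# phy2log: (1, 8)    logical expert id of each physical slot
# log2phy: (1, 4, 5) physical slots of each logical expert, padded with -1
# logcnt:  (1, 4)    replica count per logical expert; rows sum to 8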

View File

@@ -0,0 +1,708 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
The actual execution of the rearrangement.
This involves the exchange of expert weights between GPUs.
"""
from collections.abc import Sequence
from dataclasses import dataclass
import numpy as np
import torch
from torch.distributed import (
P2POp,
ProcessGroup,
all_gather,
batch_isend_irecv,
get_global_rank,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass
class RecvMetadata:
"""Metadata describing remote receives during EPLB rebalancing."""
recv_primary_mask: np.ndarray
"""Mask of (num_local_experts,) indicating primary experts received."""
recv_count: int
"""Number of received experts for the layer."""
recv_expert_ids: np.ndarray
"""Expert ids (num_local_experts,) of remote primary experts."""
recv_dst_rows: np.ndarray
"""Target expert indices (num_local_experts,) in local tensors to send."""
# Type alias for the result of move_to_buffer or transfer_layer
MoveToBufferResult = tuple[np.ndarray, np.ndarray, RecvMetadata]
def get_ep_ranks_with_experts_batch(
expert_ids: np.ndarray,
num_local_experts: int,
old_indices: np.ndarray,
new_indices: np.ndarray,
) -> tuple[dict[int, list[int]], dict[int, list[int]]]:
"""
Get the ranks of the experts that need to be exchanged.
Args:
expert_ids: 1D array of expert indices to query.
num_local_experts: The number of local experts.
old_indices: The old indices of the experts.
new_indices: The new indices of the experts.
Returns:
A tuple of two dictionaries mapping expert_id to:
- ranks_to_send: The ranks that have this expert and need to send.
- ranks_to_recv: The ranks that need to receive this expert.
"""
ranks_to_send_map: dict[int, list[int]] = {}
ranks_to_recv_map: dict[int, list[int]] = {}
# Fast path: if no experts, return empty dicts
if expert_ids.size == 0:
return ranks_to_send_map, ranks_to_recv_map
unique_experts = np.unique(expert_ids)
num_positions = len(old_indices)
position_indices = np.arange(num_positions, dtype=np.int32)
# Vectorized approach: find all positions matching any query expert in one pass
# Use np.isin to get boolean masks for all relevant positions at once
old_relevant_mask = np.isin(old_indices, unique_experts)
new_relevant_mask = np.isin(new_indices, unique_experts)
# Process old_indices (send ranks)
if np.any(old_relevant_mask):
old_relevant_positions = position_indices[old_relevant_mask]
old_relevant_experts = old_indices[old_relevant_mask]
old_relevant_ranks = old_relevant_positions // num_local_experts
# Sort by expert first, then by position (to maintain first-appearance order)
sort_order = np.lexsort((old_relevant_positions, old_relevant_experts))
sorted_experts = old_relevant_experts[sort_order]
sorted_ranks = old_relevant_ranks[sort_order]
# Find boundaries where expert changes
expert_boundaries = np.concatenate(
[[0], np.where(np.diff(sorted_experts) != 0)[0] + 1, [len(sorted_experts)]]
)
# For each expert, extract unique ranks in order of first appearance
for i in range(len(expert_boundaries) - 1):
start, end = expert_boundaries[i], expert_boundaries[i + 1]
expert = int(sorted_experts[start])
expert_ranks = sorted_ranks[start:end]
# Get unique ranks preserving order
_, unique_idx = np.unique(expert_ranks, return_index=True)
unique_ranks = expert_ranks[np.sort(unique_idx)]
ranks_to_send_map[expert] = unique_ranks.tolist()
# Process new_indices (recv ranks)
if np.any(new_relevant_mask):
new_relevant_positions = position_indices[new_relevant_mask]
new_relevant_experts = new_indices[new_relevant_mask]
new_relevant_ranks = new_relevant_positions // num_local_experts
# Sort by expert first, then by position
sort_order = np.lexsort((new_relevant_positions, new_relevant_experts))
sorted_experts = new_relevant_experts[sort_order]
sorted_ranks = new_relevant_ranks[sort_order]
# Find boundaries where expert changes
expert_boundaries = np.concatenate(
[[0], np.where(np.diff(sorted_experts) != 0)[0] + 1, [len(sorted_experts)]]
)
# For each expert, extract unique ranks and exclude local copies
for i in range(len(expert_boundaries) - 1):
start, end = expert_boundaries[i], expert_boundaries[i + 1]
expert = int(sorted_experts[start])
expert_ranks = sorted_ranks[start:end]
# Get unique ranks preserving order
_, unique_idx = np.unique(expert_ranks, return_index=True)
unique_ranks = expert_ranks[np.sort(unique_idx)]
# Remove ranks that have local copies (in send map)
send_ranks_set = set(ranks_to_send_map.get(expert, []))
recv_ranks_actual = [
int(r) for r in unique_ranks if r not in send_ranks_set
]
ranks_to_recv_map[expert] = recv_ranks_actual
# Handle experts that only appear in old (send only) or new (recv only)
for expert in unique_experts:
expert = int(expert)
if expert not in ranks_to_send_map:
ranks_to_send_map[expert] = []
if expert not in ranks_to_recv_map:
ranks_to_recv_map[expert] = []
return ranks_to_send_map, ranks_to_recv_map
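# Worked example (editorial, not in the original file): with num_local_experts
# == 2, old_indices == [0, 1, 2, 3], new_indices == [0, 2, 1, 3] and
# expert_ids == [1, 2], expert 1 lives on rank 0 and is needed on rank 1, while
# expert 2 lives on rank 1 and is needed on rank 0, so the maps come back as
# ranks_to_send == {1: [0], 2: [1]} and ranks_to_recv == {1: [1], 2: [0]}.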
def move_to_buffer(
num_local_experts: int,
old_indices: np.ndarray,
new_indices: np.ndarray,
expert_weights: Sequence[torch.Tensor],
expert_weights_buffers: Sequence[torch.Tensor],
cuda_stream: torch.cuda.Stream | None,
ep_group: ProcessGroup,
) -> MoveToBufferResult:
"""
Stages expert weights into the exchange buffers and posts the P2P
sends/receives during EPLB rebalancing.
Args:
num_local_experts: Number of local experts.
old_indices: (num_experts_total,) ndarray of current (old)
global-to-local expert assignments.
new_indices: (num_experts_total,) ndarray of desired (new)
global-to-local assignments after rebalance.
expert_weights: Original expert weights for the layer.
expert_weights_buffers: Intermediate buffers (one per tensor).
cuda_stream: CUDA stream for async copies (can be None for sync mode).
ep_group: Distributed process group for expert parallel comms.
Returns:
is_unchanged (np.ndarray): (num_local_experts,), True where an expert row
is unchanged after rebalance.
is_received_locally (np.ndarray): (num_local_experts,), True where a row
can be updated from local data.
RecvMetadata: Metadata needed for completing remote weight transfers.
"""
assert old_indices.shape == new_indices.shape
ep_rank = ep_group.rank()
recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_)
send_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
send_src_rows = np.full((num_local_experts,), -1, dtype=np.int32)
recv_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
recv_dst_rows = np.full((num_local_experts,), -1, dtype=np.int32)
base = ep_rank * num_local_experts
local_rows = np.arange(num_local_experts, dtype=np.int32)
local_global = base + local_rows
old_local_expert_ids = old_indices[local_global]
new_local_expert_ids = new_indices[local_global]
# Unchanged mask
is_unchanged = old_local_expert_ids == new_local_expert_ids
# Local receive eligibility
new_valid = new_local_expert_ids != -1
can_recv_local = np.isin(
new_local_expert_ids, old_local_expert_ids, assume_unique=False
)
is_received_locally = np.logical_or(
is_unchanged, np.logical_and(new_valid, can_recv_local)
)
# Send map: first src row per unique expert present locally in old mapping
send_count = 0
valid_old = old_local_expert_ids != -1
if np.any(valid_old):
uniq_experts, first_idx = np.unique(
old_local_expert_ids[valid_old], return_index=True
)
filtered_rows = local_rows[valid_old]
src_rows = filtered_rows[first_idx]
send_count = int(uniq_experts.shape[0])
send_expert_ids[:send_count] = uniq_experts
send_src_rows[:send_count] = src_rows
# Recv map: primary dst per unique expert needed remotely
recv_count = 0
need_recv_mask = np.logical_and(~is_received_locally, new_valid)
if np.any(need_recv_mask):
desired_experts = new_local_expert_ids[need_recv_mask]
desired_dsts = local_rows[need_recv_mask]
uniq_recv_experts, uniq_indices = np.unique(desired_experts, return_index=True)
dst_rows = desired_dsts[uniq_indices]
recv_count = int(uniq_recv_experts.shape[0])
recv_expert_ids[:recv_count] = uniq_recv_experts
recv_dst_rows[:recv_count] = dst_rows
recv_primary_mask[dst_rows] = True
eligible_local_buffer_mask = np.logical_and(~is_unchanged, is_received_locally)
# 1. Local moves into tmp buffers
if bool(eligible_local_buffer_mask.any()) and send_count > 0:
dest_indices = np.nonzero(eligible_local_buffer_mask)[0].tolist()
expert_to_src_map = dict(
zip(send_expert_ids[:send_count], send_src_rows[:send_count])
)
for dst in dest_indices:
expert = new_local_expert_ids[dst]
src_local = expert_to_src_map.get(expert, -1)
if src_local != -1:
for w, b in zip(expert_weights, expert_weights_buffers):
b[dst].copy_(w[src_local], non_blocking=True)
p2p_ops: list[P2POp] = []
# Pre-compute global ranks mapping
ep_size = ep_group.size()
rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
# 2. Post sends
if send_count > 0:
experts = send_expert_ids[:send_count]
srcs = send_src_rows[:send_count]
order = np.argsort(experts, kind="stable")
experts = experts[order]
srcs = srcs[order]
send_map, recv_map = get_ep_ranks_with_experts_batch(
experts,
num_local_experts,
old_indices,
new_indices,
)
for expert, src in zip(experts.tolist(), srcs.tolist()):
ranks_to_send = send_map[expert]
ranks_to_recv = recv_map[expert]
if not ranks_to_send or not ranks_to_recv:
continue
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
sender_pos = ranks_to_send.index(ep_rank)
recv_begin = sender_pos * num_dst_per_sender
recv_end = recv_begin + num_dst_per_sender
recv_ranks = ranks_to_recv[recv_begin:recv_end]
remainder_start = len(ranks_to_send) * num_dst_per_sender
recver_pos = remainder_start + sender_pos
if recver_pos < len(ranks_to_recv):
recv_ranks.append(ranks_to_recv[recver_pos])
for dst in recv_ranks:
dst_global = rank_to_global[dst]
p2p_ops += [
P2POp(
torch.distributed.isend,
w[src],
dst_global,
)
for w in expert_weights
]
# 3. Post recvs
if recv_count > 0:
experts = recv_expert_ids[:recv_count]
dsts = recv_dst_rows[:recv_count]
order = np.argsort(experts, kind="stable")
experts = experts[order]
dsts = dsts[order]
send_map, recv_map = get_ep_ranks_with_experts_batch(
experts,
num_local_experts,
old_indices,
new_indices,
)
for expert, dst in zip(experts.tolist(), dsts.tolist()):
ranks_to_send = send_map[expert]
ranks_to_recv = recv_map[expert]
if not ranks_to_send or not ranks_to_recv:
continue
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
recver_pos = ranks_to_recv.index(ep_rank)
remainder_start = len(ranks_to_send) * num_dst_per_sender
if recver_pos < remainder_start:
src = ranks_to_send[recver_pos // num_dst_per_sender]
else:
src = ranks_to_send[recver_pos - remainder_start]
src_global = rank_to_global[src]
p2p_ops += [
P2POp(
torch.distributed.irecv,
b[dst],
src_global,
)
for b in expert_weights_buffers
]
# 4. Execute the P2P operations. The real communication happens here.
if p2p_ops and cuda_stream is not None:
with torch.cuda.stream(cuda_stream):
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
elif p2p_ops:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
# wait for the communication to finish
return (
is_unchanged,
is_received_locally,
RecvMetadata(
recv_primary_mask=recv_primary_mask,
recv_count=recv_count,
recv_expert_ids=recv_expert_ids,
recv_dst_rows=recv_dst_rows,
),
)
def move_from_buffer(
expert_weights: Sequence[torch.Tensor],
expert_weights_buffers: list[torch.Tensor],
is_unchanged: np.ndarray,
is_received_locally: np.ndarray,
recv_metadata: RecvMetadata,
new_indices: np.ndarray,
ep_rank: int,
) -> None:
"""
Copies expert weights from communication buffers back to the target weight tensors
after EPLB rebalancing.
Args:
expert_weights: List of the actual MoE layer weights used in the execution.
expert_weights_buffers: Intermediate buffers containing the expert weights
after the transfer has completed.
is_unchanged: (num_local_experts,), True where an expert row is unchanged.
is_received_locally: (num_local_experts,), True where a row is updated locally.
recv_metadata: RecvMetadata containing remote receive metadata.
new_indices: (num_experts_total,) mapping from global physical positions
to the desired logical expert id after rebalance.
ep_rank: Rank of the process in the expert parallel group.
"""
recv_primary_mask = recv_metadata.recv_primary_mask
recv_count = recv_metadata.recv_count
recv_expert_ids = recv_metadata.recv_expert_ids
recv_dst_rows = recv_metadata.recv_dst_rows
num_local_experts = is_unchanged.shape[0]
# Mask for rows to copy back from buffers:
# copy if locally received OR remote primary recv
copy_mask = np.logical_or(is_received_locally, recv_primary_mask)
dest_mask_np = np.logical_and(~is_unchanged, copy_mask)
if bool(dest_mask_np.any()):
dest_indices = np.nonzero(dest_mask_np)[0].tolist()
for dst in dest_indices:
for w, b in zip(expert_weights, expert_weights_buffers):
w[dst].copy_(b[dst], non_blocking=True)
if recv_count == 0:
return
# Duplicate remote received rows to non-primary duplicate dsts
base = ep_rank * num_local_experts
local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)]
duplicate_mask = np.logical_and(
np.logical_and(~is_unchanged, ~is_received_locally),
np.logical_and(~recv_primary_mask, local_experts != -1),
)
# No non-primary duplicate destinations: every received expert is needed only once.
if not bool(duplicate_mask.any()):
return
dup_dst_rows = np.nonzero(duplicate_mask)[0]
dup_experts = local_experts[dup_dst_rows]
prim_experts = recv_expert_ids[:recv_count]
prim_dsts = recv_dst_rows[:recv_count]
order = np.argsort(prim_experts, kind="stable")
prim_experts_sorted = prim_experts[order]
prim_dsts_sorted = prim_dsts[order]
pos = np.searchsorted(prim_experts_sorted, dup_experts)
valid = np.logical_and(
pos < prim_experts_sorted.shape[0],
prim_experts_sorted[np.minimum(pos, prim_experts_sorted.shape[0] - 1)]
== dup_experts,
)
if not bool(valid.any()):
return
matched_dst_rows = dup_dst_rows[valid]
matched_src_rows = prim_dsts_sorted[pos[valid]]
for dst, src in zip(matched_dst_rows.tolist(), matched_src_rows.tolist()):
for w in expert_weights:
w[dst].copy_(w[src], non_blocking=True)
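# Worked example (editorial, not in the original file): suppose the primary
# receives were experts [3, 7] landing in local rows [0, 2], and two further
# local rows still need duplicates of experts [7, 3]. Sorting the primaries and
# calling np.searchsorted yields positions [1, 0], so those rows are filled by
# copying from local rows 2 and 0 respectively, with no extra communication.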
async def transfer_layer(
old_layer_indices: torch.Tensor,
new_layer_indices: torch.Tensor,
expert_weights: Sequence[torch.Tensor],
expert_weights_buffer: Sequence[torch.Tensor],
ep_group: ProcessGroup,
is_profile: bool = False,
cuda_stream: torch.cuda.Stream | None = None,
rank_mapping: dict[int, int] | None = None,
) -> MoveToBufferResult:
"""
Stages one layer's expert weights into the exchange buffers according to the
new expert indices; the copy back into the live weights happens later via
`move_from_buffer`. The values of the index arguments are logical expert ids,
while their positions are physical slots.
Args:
old_layer_indices: Shape (num_physical_experts,).
new_layer_indices: Shape (num_physical_experts,).
expert_weights: Iterable of weight tensors for this layer, each with shape
(num_local_physical_experts, hidden_size_i).
For example, a linear layer may have up and down projection.
expert_weights_buffer: Intermediate buffers (one per weight tensor).
ep_group: The device process group for expert parallelism.
is_profile (bool): If `True`, do not perform any actual weight copy.
This is used during profile run, where we only perform dummy
communications to reserve enough memory for the buffers.
cuda_stream: CUDA stream for async copies (can be None for sync mode).
rank_mapping: Optional rank mapping for elastic expert parallelism.
Returns:
is_unchanged (np.ndarray): (num_local_experts,), True where expert
is left unchanged.
is_received_locally (np.ndarray): (num_local_experts,), True where expert
can be received locally.
RecvMetadata: Metadata needed for completing remote weight transfers.
"""
ep_size = ep_group.size()
if rank_mapping is not None:
# Add a layer dimension for compatibility with mapping functions
old_layer_indices_2d = old_layer_indices.unsqueeze(0)
new_layer_indices_2d = new_layer_indices.unsqueeze(0)
if len(rank_mapping) == ep_group.size():
# scale down
new_layer_indices_2d = _map_new_expert_indices_with_rank_mapping(
new_layer_indices_2d,
rank_mapping,
)
else:
# scale up
old_layer_indices_2d = _map_old_expert_indices_with_rank_mapping(
old_layer_indices_2d,
rank_mapping,
ep_group.size(),
)
# Remove the layer dimension
old_layer_indices = old_layer_indices_2d.squeeze(0)
new_layer_indices = new_layer_indices_2d.squeeze(0)
assert old_layer_indices.shape == new_layer_indices.shape
num_physical_experts = old_layer_indices.shape[0]
assert len(expert_weights[0]) >= 1
num_local_physical_experts = expert_weights[0].shape[0]
assert num_physical_experts == ep_size * num_local_physical_experts
old_layer_indices_np = old_layer_indices.cpu().numpy()
new_layer_indices_np = new_layer_indices.cpu().numpy()
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts,
old_indices=old_layer_indices_np,
new_indices=new_layer_indices_np,
expert_weights=expert_weights,
expert_weights_buffers=expert_weights_buffer,
cuda_stream=cuda_stream,
ep_group=ep_group,
)
return is_unchanged, is_received_locally, recv_metadata
def rearrange_expert_weights_inplace(
old_global_expert_indices: torch.Tensor,
new_global_expert_indices: torch.Tensor,
expert_weights: Sequence[Sequence[torch.Tensor]],
ep_group: ProcessGroup,
is_profile: bool = False,
rank_mapping: dict[int, int] | None = None,
) -> None:
"""
Rearranges the expert weights in place according to the new expert indices.
The values of the index arguments are logical expert ids, while their
positions are physical slots.
Args:
old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
expert_weights: A sequence of shape (num_moe_layers)(weight_count)
of tensors of shape (num_local_physical_experts, hidden_size_i).
For example, a linear layer may have up and down projection,
so weight_count = 2. Each weight's hidden size can be different.
ep_group: The device process group for expert parallelism.
is_profile (bool): If `True`, do not perform any actual weight copy.
This is used during profile run, where we only perform dummy
communications to reserve enough memory for the buffers.
rank_mapping: A dictionary mapping old rank to new rank.
"""
if rank_mapping is not None:
if len(rank_mapping) == ep_group.size():
# scale down
new_global_expert_indices = _map_new_expert_indices_with_rank_mapping(
new_global_expert_indices,
rank_mapping,
)
else:
# scale up
old_global_expert_indices = _map_old_expert_indices_with_rank_mapping(
old_global_expert_indices,
rank_mapping,
ep_group.size(),
)
assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]
num_moe_layers, num_physical_experts = old_global_expert_indices.shape
assert len(expert_weights) == num_moe_layers
assert len(expert_weights[0]) >= 1
num_local_physical_experts = expert_weights[0][0].shape[0]
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
ep_size = ep_group.size()
assert num_physical_experts == ep_size * num_local_physical_experts
first_layer_weights = list(expert_weights[0])
# Buffers to hold the expert weights during the exchange.
# NOTE: Currently we assume the same weights across different layers
# have the same shape.
weights_buffer: list[torch.Tensor] = [
torch.empty_like(w) for w in first_layer_weights
]
if is_profile:
# Reserve communication buffers via a minimal dummy all_gather on first layer
for weight, buffer in zip(expert_weights[0], weights_buffer):
dummy_recv_buffer = [buffer for _ in range(ep_size)]
torch.distributed.barrier()
all_gather(
dummy_recv_buffer,
weight,
group=ep_group,
)
return
# NOTE(bowen): We need this synchronize to run, but I don't know why.
# If you figure out the reason, please let me know -- thank you!
torch.cuda.synchronize()
old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
for layer_idx in range(num_moe_layers):
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts,
old_indices=old_global_expert_indices_cpu[layer_idx],
new_indices=new_global_expert_indices_cpu[layer_idx],
expert_weights=expert_weights[layer_idx],
expert_weights_buffers=weights_buffer,
cuda_stream=None,
ep_group=ep_group,
)
move_from_buffer(
expert_weights=expert_weights[layer_idx],
expert_weights_buffers=weights_buffer,
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
recv_metadata=recv_metadata,
new_indices=new_global_expert_indices_cpu[layer_idx],
ep_rank=ep_group.rank(),
)
def _map_old_expert_indices_with_rank_mapping(
old_global_expert_indices: torch.Tensor,
rank_mapping: dict[int, int],
new_ep_size: int,
) -> torch.Tensor:
"""
Map the old global expert indices to the new global expert indices.
Args:
old_global_expert_indices:
Shape (num_layers, old_ep_size * num_local_physical_experts).
rank_mapping: Mapping from old rank to new rank.
new_ep_size: New expert parallelism size.
Returns:
Mapped expert indices with shape
(num_layers, new_ep_size * num_local_physical_experts).
"""
num_layers, old_num_physical_experts = old_global_expert_indices.shape
assert rank_mapping, "Rank mapping is required"
# Get sizes from parameters and rank_mapping
old_ep_size = len(rank_mapping)
num_local_physical_experts = old_num_physical_experts // old_ep_size
new_num_physical_experts = new_ep_size * num_local_physical_experts
# Create mapped tensor with new shape, initialized to -1
mapped_expert_indices = torch.full(
(num_layers, new_num_physical_experts),
fill_value=-1,
dtype=old_global_expert_indices.dtype,
device=old_global_expert_indices.device,
)
# Handle rank mapping (scale up/down with rank changes)
for old_rank in range(old_ep_size):
new_rank = rank_mapping.get(old_rank)
if new_rank is not None and new_rank >= 0 and new_rank < new_ep_size:
# This old rank exists in the new configuration
old_start_idx = old_rank * num_local_physical_experts
old_end_idx = (old_rank + 1) * num_local_physical_experts
new_start_idx = new_rank * num_local_physical_experts
new_end_idx = (new_rank + 1) * num_local_physical_experts
mapped_expert_indices[:, new_start_idx:new_end_idx] = (
old_global_expert_indices[:, old_start_idx:old_end_idx]
)
# If new_rank is None or >= new_ep_size, the experts remain -1
# (scale down case)
return mapped_expert_indices
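# Worked example (editorial, not in the original file): scaling up from 2 to 4
# ranks with rank_mapping == {0: 0, 1: 1} and 2 local experts per rank, an old
# map [[0, 1, 2, 3]] becomes [[0, 1, 2, 3, -1, -1, -1, -1]]: the slots owned by
# the two new ranks start out as -1 and are filled by the rebalance.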
def _map_new_expert_indices_with_rank_mapping(
new_global_expert_indices: torch.Tensor,
rank_mapping: dict[int, int],
) -> torch.Tensor:
num_layers, new_num_physical_experts = new_global_expert_indices.shape
assert rank_mapping, "Rank mapping is required"
# Get sizes from parameters and rank_mapping
old_ep_size = len(rank_mapping)
new_ep_size = sum(new_rank != -1 for new_rank in rank_mapping.values())
num_local_physical_experts = new_num_physical_experts // new_ep_size
old_num_physical_experts = old_ep_size * num_local_physical_experts
mapped_expert_indices = torch.full(
(num_layers, old_num_physical_experts),
fill_value=-1,
dtype=new_global_expert_indices.dtype,
device=new_global_expert_indices.device,
)
for old_rank in range(old_ep_size):
new_rank = rank_mapping[old_rank]
if new_rank >= 0 and new_rank < new_ep_size:
old_start_idx = old_rank * num_local_physical_experts
old_end_idx = (old_rank + 1) * num_local_physical_experts
new_start_idx = new_rank * num_local_physical_experts
new_end_idx = (new_rank + 1) * num_local_physical_experts
mapped_expert_indices[:, old_start_idx:old_end_idx] = (
new_global_expert_indices[:, new_start_idx:new_end_idx]
)
return mapped_expert_indices
__all__ = ["transfer_layer", "move_from_buffer", "RecvMetadata"]