Support updating expert locations dynamically (#6388)

This commit is contained in:
fzyzcjy
2025-05-22 12:59:33 +08:00
committed by GitHub
parent 121f92c583
commit fc992a09f9
5 changed files with 723 additions and 0 deletions

View File

@@ -22,6 +22,7 @@ import torch.distributed
import torch.nn.functional as F
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.managers import deepseek_eplb
from sglang.srt.model_loader import get_model_architecture
from sglang.srt.server_args import ServerArgs
@@ -207,6 +208,26 @@ class ExpertLocationMetadata:
),
)
# -------------------------------- mutation ------------------------------------
def update(
self,
other: "ExpertLocationMetadata",
):
for field in [
"ep_size",
]:
assert getattr(self, field) == getattr(other, field)
for field in [
"physical_to_logical_map",
"logical_to_all_physical_map",
"logical_to_all_physical_map_num_valid",
"logical_to_rank_dispatch_physical_map",
]:
dst = getattr(self, field)
dst[...] = getattr(other, field)
# -------------------------------- usage ------------------------------------
def logical_to_all_physical(

View File

@@ -0,0 +1,420 @@
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
from typing import Dict, List, Tuple
import torch
import torch.distributed
from torch.distributed import P2POp
from sglang.srt.managers.expert_location import (
ExpertLocationMetadata,
get_global_expert_location_metadata,
)
logger = logging.getLogger(__name__)
def update_expert_location(
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
new_expert_location_metadata: ExpertLocationMetadata,
nnodes: int,
rank: int,
):
old_expert_location_metadata = get_global_expert_location_metadata()
_update_expert_weights(
routed_experts_weights_of_layer,
old_expert_location_metadata,
new_expert_location_metadata,
nnodes,
rank,
)
old_expert_location_metadata.update(new_expert_location_metadata)
def _update_expert_weights(
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
old_expert_location_metadata: ExpertLocationMetadata,
new_expert_location_metadata: ExpertLocationMetadata,
nnodes: int,
rank: int,
):
temp_buffers = create_temp_buffers(
next(iter(routed_experts_weights_of_layer.values()))
)
world_size = torch.distributed.get_world_size()
num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts
num_gpu_per_node = world_size // nnodes
old_physical_to_logical_map = (
old_expert_location_metadata.physical_to_logical_map.tolist()
)
new_physical_to_logical_map = (
new_expert_location_metadata.physical_to_logical_map.tolist()
)
for layer_id in sorted(routed_experts_weights_of_layer.keys()):
update_expert_weights_single_layer(
routed_experts_weights=routed_experts_weights_of_layer[layer_id],
temp_buffers=temp_buffers,
old_physical_to_logical_map=old_physical_to_logical_map[layer_id],
new_physical_to_logical_map=new_physical_to_logical_map[layer_id],
num_local_physical_experts=num_local_physical_experts,
num_gpu_per_node=num_gpu_per_node,
rank=rank,
)
def create_temp_buffers(sample_tensors):
return [torch.empty_like(tensor) for tensor in sample_tensors]
def update_expert_weights_single_layer(
routed_experts_weights: List[torch.Tensor],
temp_buffers: List[torch.Tensor],
old_physical_to_logical_map: List[int], # (num_physical_Experts,)
new_physical_to_logical_map: List[int], # (num_physical_Experts,)
num_local_physical_experts: int,
num_gpu_per_node: int,
rank: int,
debug: bool = False,
):
assert all(
tensor.shape[0] == num_local_physical_experts
for tensor in routed_experts_weights
), f"{num_local_physical_experts=} {[x.shape for x in routed_experts_weights]=}"
output_logs = [] if debug else None
num_physical_experts = len(old_physical_to_logical_map)
num_tensors = len(routed_experts_weights)
self_node_id = rank // num_gpu_per_node
local_expert_location_range = (
rank * num_local_physical_experts,
(rank + 1) * num_local_physical_experts,
)
def _entrypoint():
# List[Tuple[logical_expert_id, List[P2POp]]]
p2p_op_infos: List[Tuple[int, List[P2POp]]] = []
# List[Tuple[temp_buffers_expert_location, routed_experts_weights_expert_location]]
buffer2weight_copy_infos: List[Tuple[int, int]] = []
_handle_recv(buffer2weight_copy_infos, p2p_op_infos)
_create_isend_ops(p2p_op_infos)
_execute_p2p_ops(p2p_op_infos)
_execute_buffer2weight_copies(buffer2weight_copy_infos)
if debug:
output_logs.append(f"{p2p_op_infos=}")
output_logs.append(f"{buffer2weight_copy_infos=}")
def _handle_recv(buffer2weight_copy_infos, p2p_op_infos):
for dst_expert_location in range(*local_expert_location_range):
_handle_recv_of_dst_expert_location(
dst_expert_location, buffer2weight_copy_infos, p2p_op_infos
)
def _handle_recv_of_dst_expert_location(
dst_expert_location: int, buffer2weight_copy_infos, p2p_op_infos
):
logical_expert_id = new_physical_to_logical_map[dst_expert_location]
# case 1: unchanged
if old_physical_to_logical_map[dst_expert_location] == logical_expert_id:
if debug:
output_logs.append(
f"handle_recv_of_dst_expert_location {dst_expert_location=} case=unchanged"
)
return
# case 2: same-gpu
for src_expert_location in range(*local_expert_location_range):
if old_physical_to_logical_map[src_expert_location] == logical_expert_id:
for i in range(num_tensors):
_get_tensor(temp_buffers, i, dst_expert_location).copy_(
_get_tensor(routed_experts_weights, i, src_expert_location)
)
buffer2weight_copy_infos.append(
(dst_expert_location, dst_expert_location)
)
if debug:
output_logs.append(
f"handle_recv_of_dst_expert_location {dst_expert_location=} case=same-gpu {src_expert_location=}"
)
return
# case 3: free-rider
for src_expert_location in range(
rank * num_local_physical_experts, dst_expert_location
):
if new_physical_to_logical_map[src_expert_location] == logical_expert_id:
buffer2weight_copy_infos.append(
(src_expert_location, dst_expert_location)
)
if debug:
output_logs.append(
f"handle_recv_of_dst_expert_location {dst_expert_location=} case=free-rider {src_expert_location=}"
)
return
same_node_mapping, cross_node_mapping, need_comm_self_node_dst_ranks = (
_compute_comm_info(logical_expert_id=logical_expert_id)
)
# case 4: same-node
if rank in need_comm_self_node_dst_ranks:
chosen_src_rank = same_node_mapping.chunk_value_from_element_value(
element_value=rank
)
_create_p2p_recv_and_buffer2weight_copy(
buffer2weight_copy_infos,
p2p_op_infos,
src_rank=chosen_src_rank,
logical_expert_id=logical_expert_id,
dst_expert_location=dst_expert_location,
)
if debug:
output_logs.append(
f"handle_recv_of_dst_expert_location {dst_expert_location=} case=same-node {chosen_src_rank=}"
)
return
# case 5: cross-node
# Future work: can optimize when there are multiple ranks in the same dst node that uses the same logical expert
chosen_src_rank = cross_node_mapping.chunk_value_from_element_value(
element_value=rank
)
_create_p2p_recv_and_buffer2weight_copy(
buffer2weight_copy_infos,
p2p_op_infos,
src_rank=chosen_src_rank,
logical_expert_id=logical_expert_id,
dst_expert_location=dst_expert_location,
)
if debug:
output_logs.append(
f"handle_recv_of_dst_expert_location {dst_expert_location=} case=cross-node {chosen_src_rank=}"
)
return
def _create_p2p_recv_and_buffer2weight_copy(
buffer2weight_copy_infos,
p2p_op_infos,
*,
logical_expert_id: int,
src_rank: int,
dst_expert_location: int,
):
p2p_op_infos.append(
(
logical_expert_id,
[
P2POp(
op=torch.distributed.irecv,
tensor=_get_tensor(temp_buffers, i, dst_expert_location),
peer=src_rank,
)
for i in range(num_tensors)
],
)
)
buffer2weight_copy_infos.append((dst_expert_location, dst_expert_location))
def _create_isend_ops(p2p_op_infos):
handled_logical_expert_ids = set()
for src_expert_location in range(*local_expert_location_range):
logical_expert_id = old_physical_to_logical_map[src_expert_location]
if logical_expert_id in handled_logical_expert_ids:
continue
handled_logical_expert_ids.add(logical_expert_id)
_create_isend_ops_of_logical_expert_id(
logical_expert_id, src_expert_location, p2p_op_infos
)
def _create_isend_ops_of_logical_expert_id(
logical_expert_id, src_expert_location, p2p_op_infos
):
same_node_mapping, cross_node_mapping, need_comm_self_node_dst_ranks = (
_compute_comm_info(logical_expert_id=logical_expert_id)
)
same_node_dst_ranks = same_node_mapping.element_values_from_chunk_value(
chunk_value=rank
)
cross_node_dst_ranks = cross_node_mapping.element_values_from_chunk_value(
chunk_value=rank
)
all_dst_ranks = same_node_dst_ranks + cross_node_dst_ranks
if debug:
output_logs.append(
f"create_isend_ops_of_logical_expert_id {logical_expert_id=} {src_expert_location=} {same_node_dst_ranks=} {cross_node_dst_ranks=}"
)
p2p_op_infos.append(
(
logical_expert_id,
[
P2POp(
op=torch.distributed.isend,
tensor=_get_tensor(
routed_experts_weights, i, src_expert_location
),
peer=dst_rank,
)
for dst_rank in all_dst_ranks
for i in range(num_tensors)
],
)
)
def _compute_comm_info(logical_expert_id: int):
all_src_ranks = _deduplicate_ordered(
[
x // num_local_physical_experts
for x in range(num_physical_experts)
if old_physical_to_logical_map[x] == logical_expert_id
]
)
all_src_nodes = [x // num_gpu_per_node for x in all_src_ranks]
self_node_src_ranks = [
x for x in all_src_ranks if x // num_gpu_per_node == self_node_id
]
need_comm_dst_ranks = _deduplicate_ordered(
[
x // num_local_physical_experts
for x in range(num_physical_experts)
if new_physical_to_logical_map[x] == logical_expert_id
and x // num_local_physical_experts not in all_src_ranks
]
)
need_comm_self_node_dst_ranks = (
[x for x in need_comm_dst_ranks if x // num_gpu_per_node == self_node_id]
if len(self_node_src_ranks) > 0
else []
)
need_comm_cross_node_dst_ranks = [
x
for x in need_comm_dst_ranks
if (x // num_gpu_per_node) not in all_src_nodes
]
same_node_mapping = _ChunkUtils(
chunk_values=self_node_src_ranks,
element_values=need_comm_self_node_dst_ranks,
)
cross_node_mapping = _ChunkUtils(
chunk_values=all_src_ranks,
element_values=need_comm_cross_node_dst_ranks,
)
return same_node_mapping, cross_node_mapping, need_comm_self_node_dst_ranks
def _execute_p2p_ops(p2p_op_infos):
sorted_infos = sorted(p2p_op_infos, key=lambda info: info[0])
p2p_ops = [op for _, ops in sorted_infos for op in ops]
if len(p2p_ops) == 0:
return
reqs = torch.distributed.batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
def _execute_buffer2weight_copies(buffer2weight_copy_infos):
for (
temp_buffers_expert_location,
routed_experts_weights_expert_location,
) in buffer2weight_copy_infos:
for i in range(num_tensors):
_get_tensor(
routed_experts_weights, i, routed_experts_weights_expert_location
).copy_(_get_tensor(temp_buffers, i, temp_buffers_expert_location))
def _get_tensor(tensors, tensor_index: int, expert_location: int) -> torch.Tensor:
return tensors[tensor_index][_get_local_expert_location(expert_location)]
def _get_local_expert_location(expert_location: int) -> int:
assert (
local_expert_location_range[0]
<= expert_location
< local_expert_location_range[1]
)
return expert_location % num_local_physical_experts
_entrypoint()
return output_logs
class _ChunkUtils:
def __init__(self, *, chunk_values: List, element_values: List):
self.chunk_values = chunk_values
self.element_values = element_values
def chunk_value_from_element_value(self, element_value):
chunk_index = self._chunk_index_from_element_index(
num_elements=len(self.element_values),
num_chunks=len(self.chunk_values),
element_index=self.element_values.index(element_value),
)
return self.chunk_values[chunk_index]
def element_values_from_chunk_value(self, chunk_value) -> List:
if len(self.element_values) == 0:
return []
element_slice = self._element_slice_from_chunk_index(
num_elements=len(self.element_values),
num_chunks=len(self.chunk_values),
chunk_index=self.chunk_values.index(chunk_value),
)
return self.element_values[element_slice]
@staticmethod
def _chunk_index_from_element_index(
num_elements: int, num_chunks: int, element_index: int
) -> int:
short_chunk_size, num_long_chunks = divmod(num_elements, num_chunks)
num_elements_for_long_chunks = num_long_chunks * (short_chunk_size + 1)
if element_index < num_elements_for_long_chunks:
return element_index // (short_chunk_size + 1)
else:
return (
num_long_chunks
+ (element_index - num_elements_for_long_chunks) // short_chunk_size
)
@staticmethod
def _element_slice_from_chunk_index(
num_elements: int, num_chunks: int, chunk_index: int
) -> slice:
short_chunk_size, num_long_chunks = divmod(num_elements, num_chunks)
start = chunk_index * short_chunk_size + min(chunk_index, num_long_chunks)
end = start + short_chunk_size + int(chunk_index < num_long_chunks)
return slice(start, end)
def _deduplicate_ordered(arr: List[int]):
output = []
for item in arr:
if len(output) == 0 or item != output[-1]:
output.append(item)
return output

View File

@@ -57,6 +57,7 @@ from sglang.srt.managers.expert_distribution import (
set_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import (
ExpertLocationMetadata,
compute_initial_expert_location_metadata,
get_global_expert_location_metadata,
set_global_expert_location_metadata,
@@ -70,6 +71,7 @@ from sglang.srt.mem_cache.memory_pool import (
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
from sglang.srt.model_executor import expert_location_updater
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader import get_model
@@ -575,6 +577,16 @@ class ModelRunner:
f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
) from None
def update_expert_location(
self, new_expert_location_metadata: ExpertLocationMetadata
):
expert_location_updater.update_expert_location(
self.model.routed_experts_weights_of_layer,
new_expert_location_metadata,
nnodes=self.server_args.nnodes,
rank=self.tp_rank,
)
def update_weights_from_disk(
self, model_path: str, load_format: str
) -> tuple[bool, str]:

View File

@@ -317,6 +317,13 @@ class DeepseekV2MoE(nn.Module):
def _enable_deepep_moe(self):
return global_server_args_dict["enable_deepep_moe"]
def get_moe_weights(self):
return [
x.data
for name, x in self.experts.named_parameters()
if name not in ["correction_bias"]
]
def op_gate(self, state):
if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
state.forward_batch.forward_mode, state.hidden_states_mlp_input
@@ -1599,6 +1606,14 @@ class DeepseekV2ForCausalLM(nn.Module):
self_attn.w_vc = w_vc.contiguous()
self_attn.use_deep_gemm_bmm = True
# TODO support nextn later
if not is_nextn:
self.routed_experts_weights_of_layer = {
layer_id: layer.mlp.get_moe_weights()
for layer_id, layer in enumerate(self.model.layers)
if isinstance(layer.mlp, DeepseekV2MoE)
}
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
if is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):