Support updating expert locations dynamically (#6388)
This commit is contained in:
@@ -22,6 +22,7 @@ import torch.distributed
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
|
from sglang.srt.managers import deepseek_eplb
|
||||||
from sglang.srt.model_loader import get_model_architecture
|
from sglang.srt.model_loader import get_model_architecture
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
@@ -207,6 +208,26 @@ class ExpertLocationMetadata:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# -------------------------------- mutation ------------------------------------
|
||||||
|
|
||||||
|
def update(
    self,
    other: "ExpertLocationMetadata",
):
    """Overwrite this metadata's mapping attributes in place from ``other``.

    ``ep_size`` must already be equal on both objects (checked with
    ``assert``).  Each mapping attribute is written with ``dst[...] = src``,
    so the objects already held by ``self`` are reused rather than rebound.
    """
    # Settings that have to agree before the mappings can be swapped.
    for invariant_name in ("ep_size",):
        assert getattr(self, invariant_name) == getattr(other, invariant_name)

    mapping_names = (
        "physical_to_logical_map",
        "logical_to_all_physical_map",
        "logical_to_all_physical_map_num_valid",
        "logical_to_rank_dispatch_physical_map",
    )
    for mapping_name in mapping_names:
        target = getattr(self, mapping_name)
        # Ellipsis assignment writes into the existing storage of ``target``.
        target[...] = getattr(other, mapping_name)
|
||||||
|
|
||||||
# -------------------------------- usage ------------------------------------
|
# -------------------------------- usage ------------------------------------
|
||||||
|
|
||||||
def logical_to_all_physical(
|
def logical_to_all_physical(
|
||||||
|
|||||||
420
python/sglang/srt/model_executor/expert_location_updater.py
Normal file
420
python/sglang/srt/model_executor/expert_location_updater.py
Normal file
@@ -0,0 +1,420 @@
|
|||||||
|
# Copyright 2023-2025 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
from torch.distributed import P2POp
|
||||||
|
|
||||||
|
from sglang.srt.managers.expert_location import (
|
||||||
|
ExpertLocationMetadata,
|
||||||
|
get_global_expert_location_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def update_expert_location(
    routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
    new_expert_location_metadata: ExpertLocationMetadata,
    nnodes: int,
    rank: int,
):
    """Move expert weights to a new placement, then adopt the new metadata.

    The current global metadata is used as the "old" placement for the weight
    transfer; once the transfer returns, that same metadata object is updated
    in place from ``new_expert_location_metadata``.
    """
    current_metadata = get_global_expert_location_metadata()

    # Weights first: the transfer needs the old placement to still be intact.
    _update_expert_weights(
        routed_experts_weights_of_layer=routed_experts_weights_of_layer,
        old_expert_location_metadata=current_metadata,
        new_expert_location_metadata=new_expert_location_metadata,
        nnodes=nnodes,
        rank=rank,
    )

    current_metadata.update(new_expert_location_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def _update_expert_weights(
    routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
    old_expert_location_metadata: ExpertLocationMetadata,
    new_expert_location_metadata: ExpertLocationMetadata,
    nnodes: int,
    rank: int,
):
    """Apply the old -> new placement change to every layer, in layer-id order.

    One set of temporary buffers, shaped like the first layer's weight
    tensors, is allocated once and reused for all layers.
    """
    first_layer_weights = next(iter(routed_experts_weights_of_layer.values()))
    temp_buffers = create_temp_buffers(first_layer_weights)

    gpus_per_node = torch.distributed.get_world_size() // nnodes
    experts_per_rank = old_expert_location_metadata.num_local_physical_experts

    # Plain nested lists, indexed by layer id below.
    old_maps = old_expert_location_metadata.physical_to_logical_map.tolist()
    new_maps = new_expert_location_metadata.physical_to_logical_map.tolist()

    for layer_id in sorted(routed_experts_weights_of_layer):
        update_expert_weights_single_layer(
            routed_experts_weights=routed_experts_weights_of_layer[layer_id],
            temp_buffers=temp_buffers,
            old_physical_to_logical_map=old_maps[layer_id],
            new_physical_to_logical_map=new_maps[layer_id],
            num_local_physical_experts=experts_per_rank,
            num_gpu_per_node=gpus_per_node,
            rank=rank,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def create_temp_buffers(sample_tensors):
    """Return one uninitialized tensor per sample, made with ``torch.empty_like``."""
    buffers = []
    for sample in sample_tensors:
        buffers.append(torch.empty_like(sample))
    return buffers
|
||||||
|
|
||||||
|
|
||||||
|
def update_expert_weights_single_layer(
    routed_experts_weights: List[torch.Tensor],
    temp_buffers: List[torch.Tensor],
    old_physical_to_logical_map: List[int],  # (num_physical_experts,)
    new_physical_to_logical_map: List[int],  # (num_physical_experts,)
    num_local_physical_experts: int,
    num_gpu_per_node: int,
    rank: int,
    debug: bool = False,
):
    """Rearrange one layer's local expert weights to match a new placement.

    For every local destination slot the function decides where the needed
    logical expert comes from (already in place, another local slot, a temp
    buffer filled for an earlier local slot, a rank on the same node, or a
    rank on another node).  Incoming data is staged in ``temp_buffers``; the
    weights are only overwritten after all P2P requests have been waited on.

    Args:
        routed_experts_weights: Weight tensors of this layer; the first
            dimension of each must equal ``num_local_physical_experts``.
        temp_buffers: Staging tensors indexed the same way as
            ``routed_experts_weights``.
        old_physical_to_logical_map: Logical expert id per global physical
            slot before the update.
        new_physical_to_logical_map: Logical expert id per global physical
            slot after the update.
        num_local_physical_experts: Number of physical slots per rank.
        num_gpu_per_node: Number of ranks per node.
        rank: This process's rank.
        debug: When true, collect and return human-readable decision logs.

    Returns:
        A list of log strings when ``debug`` is true, otherwise ``None``.
    """
    assert all(
        tensor.shape[0] == num_local_physical_experts
        for tensor in routed_experts_weights
    ), f"{num_local_physical_experts=} {[x.shape for x in routed_experts_weights]=}"

    # Normalize both maps to plain ints.  The send side deduplicates logical
    # ids through a ``set``; elements obtained by indexing a tensor hash by
    # object identity, so without this a tensor-typed map would produce
    # duplicate isend groups that no receiver has a matching irecv for.
    old_physical_to_logical_map = [int(x) for x in old_physical_to_logical_map]
    new_physical_to_logical_map = [int(x) for x in new_physical_to_logical_map]

    output_logs = [] if debug else None

    num_physical_experts = len(old_physical_to_logical_map)
    num_tensors = len(routed_experts_weights)

    self_node_id = rank // num_gpu_per_node

    # Half-open range of global physical slots owned by this rank.
    local_expert_location_range = (
        rank * num_local_physical_experts,
        (rank + 1) * num_local_physical_experts,
    )

    def _entrypoint():
        # List[Tuple[logical_expert_id, List[P2POp]]]
        p2p_op_infos: List[Tuple[int, List[P2POp]]] = []
        # List[Tuple[temp_buffers_expert_location, routed_experts_weights_expert_location]]
        buffer2weight_copy_infos: List[Tuple[int, int]] = []

        _handle_recv(buffer2weight_copy_infos, p2p_op_infos)
        _create_isend_ops(p2p_op_infos)
        _execute_p2p_ops(p2p_op_infos)
        # Only now is it safe to overwrite weights: all sends have completed.
        _execute_buffer2weight_copies(buffer2weight_copy_infos)

        if debug:
            output_logs.append(f"{p2p_op_infos=}")
            output_logs.append(f"{buffer2weight_copy_infos=}")

    def _handle_recv(buffer2weight_copy_infos, p2p_op_infos):
        """Plan how every local destination slot obtains its new expert."""
        for dst_expert_location in range(*local_expert_location_range):
            _handle_recv_of_dst_expert_location(
                dst_expert_location, buffer2weight_copy_infos, p2p_op_infos
            )

    def _handle_recv_of_dst_expert_location(
        dst_expert_location: int, buffer2weight_copy_infos, p2p_op_infos
    ):
        logical_expert_id = new_physical_to_logical_map[dst_expert_location]

        # case 1: unchanged
        if old_physical_to_logical_map[dst_expert_location] == logical_expert_id:
            if debug:
                output_logs.append(
                    f"handle_recv_of_dst_expert_location {dst_expert_location=} case=unchanged"
                )
            return

        # case 2: same-gpu
        # Another local slot currently holds the expert: stage a copy now and
        # write it into the weights later.
        for src_expert_location in range(*local_expert_location_range):
            if old_physical_to_logical_map[src_expert_location] == logical_expert_id:
                for i in range(num_tensors):
                    _get_tensor(temp_buffers, i, dst_expert_location).copy_(
                        _get_tensor(routed_experts_weights, i, src_expert_location)
                    )
                buffer2weight_copy_infos.append(
                    (dst_expert_location, dst_expert_location)
                )
                if debug:
                    output_logs.append(
                        f"handle_recv_of_dst_expert_location {dst_expert_location=} case=same-gpu {src_expert_location=}"
                    )
                return

        # case 3: free-rider
        # An earlier local slot also becomes this expert; reuse the temp
        # buffer of that slot instead of receiving the data a second time.
        for src_expert_location in range(
            rank * num_local_physical_experts, dst_expert_location
        ):
            if new_physical_to_logical_map[src_expert_location] == logical_expert_id:
                buffer2weight_copy_infos.append(
                    (src_expert_location, dst_expert_location)
                )
                if debug:
                    output_logs.append(
                        f"handle_recv_of_dst_expert_location {dst_expert_location=} case=free-rider {src_expert_location=}"
                    )
                return

        same_node_mapping, cross_node_mapping, need_comm_self_node_dst_ranks = (
            _compute_comm_info(logical_expert_id=logical_expert_id)
        )

        # case 4: same-node
        if rank in need_comm_self_node_dst_ranks:
            chosen_src_rank = same_node_mapping.chunk_value_from_element_value(
                element_value=rank
            )
            _create_p2p_recv_and_buffer2weight_copy(
                buffer2weight_copy_infos,
                p2p_op_infos,
                src_rank=chosen_src_rank,
                logical_expert_id=logical_expert_id,
                dst_expert_location=dst_expert_location,
            )
            if debug:
                output_logs.append(
                    f"handle_recv_of_dst_expert_location {dst_expert_location=} case=same-node {chosen_src_rank=}"
                )
            return

        # case 5: cross-node
        # Future work: can optimize when there are multiple ranks in the same dst node that uses the same logical expert
        chosen_src_rank = cross_node_mapping.chunk_value_from_element_value(
            element_value=rank
        )
        _create_p2p_recv_and_buffer2weight_copy(
            buffer2weight_copy_infos,
            p2p_op_infos,
            src_rank=chosen_src_rank,
            logical_expert_id=logical_expert_id,
            dst_expert_location=dst_expert_location,
        )
        if debug:
            output_logs.append(
                f"handle_recv_of_dst_expert_location {dst_expert_location=} case=cross-node {chosen_src_rank=}"
            )

    def _create_p2p_recv_and_buffer2weight_copy(
        buffer2weight_copy_infos,
        p2p_op_infos,
        *,
        logical_expert_id: int,
        src_rank: int,
        dst_expert_location: int,
    ):
        """Queue irecv ops into the slot's temp buffer plus the later copy."""
        p2p_op_infos.append(
            (
                logical_expert_id,
                [
                    P2POp(
                        op=torch.distributed.irecv,
                        tensor=_get_tensor(temp_buffers, i, dst_expert_location),
                        peer=src_rank,
                    )
                    for i in range(num_tensors)
                ],
            )
        )
        buffer2weight_copy_infos.append((dst_expert_location, dst_expert_location))

    def _create_isend_ops(p2p_op_infos):
        """Queue sends once per distinct logical expert held locally."""
        handled_logical_expert_ids = set()
        for src_expert_location in range(*local_expert_location_range):
            logical_expert_id = old_physical_to_logical_map[src_expert_location]

            if logical_expert_id in handled_logical_expert_ids:
                continue
            handled_logical_expert_ids.add(logical_expert_id)

            _create_isend_ops_of_logical_expert_id(
                logical_expert_id, src_expert_location, p2p_op_infos
            )

    def _create_isend_ops_of_logical_expert_id(
        logical_expert_id, src_expert_location, p2p_op_infos
    ):
        same_node_mapping, cross_node_mapping, need_comm_self_node_dst_ranks = (
            _compute_comm_info(logical_expert_id=logical_expert_id)
        )

        same_node_dst_ranks = same_node_mapping.element_values_from_chunk_value(
            chunk_value=rank
        )
        cross_node_dst_ranks = cross_node_mapping.element_values_from_chunk_value(
            chunk_value=rank
        )
        all_dst_ranks = same_node_dst_ranks + cross_node_dst_ranks

        if debug:
            output_logs.append(
                f"create_isend_ops_of_logical_expert_id {logical_expert_id=} {src_expert_location=} {same_node_dst_ranks=} {cross_node_dst_ranks=}"
            )

        p2p_op_infos.append(
            (
                logical_expert_id,
                [
                    P2POp(
                        op=torch.distributed.isend,
                        tensor=_get_tensor(
                            routed_experts_weights, i, src_expert_location
                        ),
                        peer=dst_rank,
                    )
                    for dst_rank in all_dst_ranks
                    for i in range(num_tensors)
                ],
            )
        )

    def _compute_comm_info(logical_expert_id: int):
        """Work out which ranks send ``logical_expert_id`` to which ranks.

        Every rank evaluates this from the same two maps, so senders and
        receivers arrive at the same pairing.
        """
        # Ranks that currently hold the expert (ascending, duplicates dropped).
        all_src_ranks = _deduplicate_ordered(
            [
                x // num_local_physical_experts
                for x in range(num_physical_experts)
                if old_physical_to_logical_map[x] == logical_expert_id
            ]
        )
        all_src_nodes = [x // num_gpu_per_node for x in all_src_ranks]
        self_node_src_ranks = [
            x for x in all_src_ranks if x // num_gpu_per_node == self_node_id
        ]

        # Ranks that need the expert afterwards but do not hold it now.
        need_comm_dst_ranks = _deduplicate_ordered(
            [
                x // num_local_physical_experts
                for x in range(num_physical_experts)
                if new_physical_to_logical_map[x] == logical_expert_id
                and x // num_local_physical_experts not in all_src_ranks
            ]
        )
        # Same-node transfers are only possible if this node has a source.
        need_comm_self_node_dst_ranks = (
            [x for x in need_comm_dst_ranks if x // num_gpu_per_node == self_node_id]
            if len(self_node_src_ranks) > 0
            else []
        )
        need_comm_cross_node_dst_ranks = [
            x
            for x in need_comm_dst_ranks
            if (x // num_gpu_per_node) not in all_src_nodes
        ]

        same_node_mapping = _ChunkUtils(
            chunk_values=self_node_src_ranks,
            element_values=need_comm_self_node_dst_ranks,
        )

        cross_node_mapping = _ChunkUtils(
            chunk_values=all_src_ranks,
            element_values=need_comm_cross_node_dst_ranks,
        )

        return same_node_mapping, cross_node_mapping, need_comm_self_node_dst_ranks

    def _execute_p2p_ops(p2p_op_infos):
        """Launch all queued ops, ordered by logical expert id, and wait."""
        sorted_infos = sorted(p2p_op_infos, key=lambda info: info[0])
        p2p_ops = [op for _, ops in sorted_infos for op in ops]
        if len(p2p_ops) == 0:
            return

        reqs = torch.distributed.batch_isend_irecv(p2p_ops)
        for req in reqs:
            req.wait()

    def _execute_buffer2weight_copies(buffer2weight_copy_infos):
        """Copy staged data from ``temp_buffers`` into the real weights."""
        for (
            temp_buffers_expert_location,
            routed_experts_weights_expert_location,
        ) in buffer2weight_copy_infos:
            for i in range(num_tensors):
                _get_tensor(
                    routed_experts_weights, i, routed_experts_weights_expert_location
                ).copy_(_get_tensor(temp_buffers, i, temp_buffers_expert_location))

    def _get_tensor(tensors, tensor_index: int, expert_location: int) -> torch.Tensor:
        return tensors[tensor_index][_get_local_expert_location(expert_location)]

    def _get_local_expert_location(expert_location: int) -> int:
        # Only slots owned by this rank can be addressed locally.
        assert (
            local_expert_location_range[0]
            <= expert_location
            < local_expert_location_range[1]
        )
        return expert_location % num_local_physical_experts

    _entrypoint()

    return output_logs
|
||||||
|
|
||||||
|
|
||||||
|
class _ChunkUtils:
|
||||||
|
def __init__(self, *, chunk_values: List, element_values: List):
|
||||||
|
self.chunk_values = chunk_values
|
||||||
|
self.element_values = element_values
|
||||||
|
|
||||||
|
def chunk_value_from_element_value(self, element_value):
|
||||||
|
chunk_index = self._chunk_index_from_element_index(
|
||||||
|
num_elements=len(self.element_values),
|
||||||
|
num_chunks=len(self.chunk_values),
|
||||||
|
element_index=self.element_values.index(element_value),
|
||||||
|
)
|
||||||
|
return self.chunk_values[chunk_index]
|
||||||
|
|
||||||
|
def element_values_from_chunk_value(self, chunk_value) -> List:
|
||||||
|
if len(self.element_values) == 0:
|
||||||
|
return []
|
||||||
|
element_slice = self._element_slice_from_chunk_index(
|
||||||
|
num_elements=len(self.element_values),
|
||||||
|
num_chunks=len(self.chunk_values),
|
||||||
|
chunk_index=self.chunk_values.index(chunk_value),
|
||||||
|
)
|
||||||
|
return self.element_values[element_slice]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _chunk_index_from_element_index(
|
||||||
|
num_elements: int, num_chunks: int, element_index: int
|
||||||
|
) -> int:
|
||||||
|
short_chunk_size, num_long_chunks = divmod(num_elements, num_chunks)
|
||||||
|
num_elements_for_long_chunks = num_long_chunks * (short_chunk_size + 1)
|
||||||
|
if element_index < num_elements_for_long_chunks:
|
||||||
|
return element_index // (short_chunk_size + 1)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
num_long_chunks
|
||||||
|
+ (element_index - num_elements_for_long_chunks) // short_chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _element_slice_from_chunk_index(
|
||||||
|
num_elements: int, num_chunks: int, chunk_index: int
|
||||||
|
) -> slice:
|
||||||
|
short_chunk_size, num_long_chunks = divmod(num_elements, num_chunks)
|
||||||
|
start = chunk_index * short_chunk_size + min(chunk_index, num_long_chunks)
|
||||||
|
end = start + short_chunk_size + int(chunk_index < num_long_chunks)
|
||||||
|
return slice(start, end)
|
||||||
|
|
||||||
|
|
||||||
|
def _deduplicate_ordered(arr: List[int]):
|
||||||
|
output = []
|
||||||
|
for item in arr:
|
||||||
|
if len(output) == 0 or item != output[-1]:
|
||||||
|
output.append(item)
|
||||||
|
return output
|
||||||
@@ -57,6 +57,7 @@ from sglang.srt.managers.expert_distribution import (
|
|||||||
set_global_expert_distribution_recorder,
|
set_global_expert_distribution_recorder,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.expert_location import (
|
from sglang.srt.managers.expert_location import (
|
||||||
|
ExpertLocationMetadata,
|
||||||
compute_initial_expert_location_metadata,
|
compute_initial_expert_location_metadata,
|
||||||
get_global_expert_location_metadata,
|
get_global_expert_location_metadata,
|
||||||
set_global_expert_location_metadata,
|
set_global_expert_location_metadata,
|
||||||
@@ -70,6 +71,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
|||||||
TokenToKVPoolAllocator,
|
TokenToKVPoolAllocator,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
|
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
|
||||||
|
from sglang.srt.model_executor import expert_location_updater
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_loader import get_model
|
from sglang.srt.model_loader import get_model
|
||||||
@@ -575,6 +577,16 @@ class ModelRunner:
|
|||||||
f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
|
f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
|
def update_expert_location(
    self, new_expert_location_metadata: ExpertLocationMetadata
):
    """Forward an expert-placement update for this runner's model.

    Hands the model's per-layer routed expert weights, the server's node
    count and this runner's TP rank to ``expert_location_updater``.
    """
    layer_weights = self.model.routed_experts_weights_of_layer
    expert_location_updater.update_expert_location(
        layer_weights,
        new_expert_location_metadata,
        nnodes=self.server_args.nnodes,
        rank=self.tp_rank,
    )
|
||||||
|
|
||||||
def update_weights_from_disk(
|
def update_weights_from_disk(
|
||||||
self, model_path: str, load_format: str
|
self, model_path: str, load_format: str
|
||||||
) -> tuple[bool, str]:
|
) -> tuple[bool, str]:
|
||||||
|
|||||||
@@ -317,6 +317,13 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
def _enable_deepep_moe(self):
|
def _enable_deepep_moe(self):
|
||||||
return global_server_args_dict["enable_deepep_moe"]
|
return global_server_args_dict["enable_deepep_moe"]
|
||||||
|
|
||||||
|
def get_moe_weights(self):
    """Return ``.data`` of every parameter of ``self.experts`` except ``correction_bias``."""
    excluded_names = ("correction_bias",)
    weights = []
    for param_name, param in self.experts.named_parameters():
        if param_name in excluded_names:
            continue
        weights.append(param.data)
    return weights
|
||||||
|
|
||||||
def op_gate(self, state):
|
def op_gate(self, state):
|
||||||
if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
|
if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
|
||||||
state.forward_batch.forward_mode, state.hidden_states_mlp_input
|
state.forward_batch.forward_mode, state.hidden_states_mlp_input
|
||||||
@@ -1599,6 +1606,14 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
self_attn.w_vc = w_vc.contiguous()
|
self_attn.w_vc = w_vc.contiguous()
|
||||||
self_attn.use_deep_gemm_bmm = True
|
self_attn.use_deep_gemm_bmm = True
|
||||||
|
|
||||||
|
# TODO support nextn later
|
||||||
|
if not is_nextn:
|
||||||
|
self.routed_experts_weights_of_layer = {
|
||||||
|
layer_id: layer.mlp.get_moe_weights()
|
||||||
|
for layer_id, layer in enumerate(self.model.layers)
|
||||||
|
if isinstance(layer.mlp, DeepseekV2MoE)
|
||||||
|
}
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
|
||||||
if is_nextn:
|
if is_nextn:
|
||||||
if hasattr(self.config, "num_nextn_predict_layers"):
|
if hasattr(self.config, "num_nextn_predict_layers"):
|
||||||
|
|||||||
255
test/srt/test_expert_location_updater.py
Normal file
255
test/srt/test_expert_location_updater.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
import os
|
||||||
|
import traceback
|
||||||
|
import unittest
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
from torch.multiprocessing import Process
|
||||||
|
|
||||||
|
from sglang.srt.model_executor import expert_location_updater
|
||||||
|
from sglang.test.test_utils import CustomTestCase, find_available_port
|
||||||
|
from sglang.utils import is_in_ci
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _TestInfo:
|
||||||
|
nnodes: int
|
||||||
|
num_logical_experts: int
|
||||||
|
num_physical_experts: int
|
||||||
|
num_repeat: int = 5000
|
||||||
|
|
||||||
|
|
||||||
|
class TestExpertLocationUpdater(CustomTestCase):
    """Runs the expert weight updater across spawned worker processes."""

    @classmethod
    def setUpClass(cls):
        mp.set_start_method("spawn", force=True)

    def test_cpu(self):
        self._test_common(device="cpu")

        large_case = _TestInfo(
            nnodes=4,
            num_logical_experts=256,
            num_physical_experts=288,
            num_repeat=10000,
        )
        self._test_core(num_gpus=32, device="cpu", infos=[large_case])

    def test_cpu_slow(self):
        # Skipped in CI.
        if is_in_ci():
            return

        huge_case = _TestInfo(
            nnodes=18,
            num_logical_experts=256,
            num_physical_experts=288,
            num_repeat=10000,
        )
        self._test_core(num_gpus=144, device="cpu", infos=[huge_case])

    def test_gpu(self):
        # Skipped in CI.
        if is_in_ci():
            return
        self._test_common(device="cuda")

    def _test_common(self, device):
        """Run the 8-rank grid of node/expert-count combinations on ``device``."""
        infos = [
            _TestInfo(
                nnodes=nnodes,
                num_logical_experts=num_logical_experts,
                num_physical_experts=num_physical_experts,
            )
            for nnodes in [1, 2, 4]
            for num_logical_experts in [2, 5, 20, 256]
            for num_physical_experts in [8, 16, 256, 288]
            # Each logical expert needs at least one physical slot.
            if num_logical_experts <= num_physical_experts
        ]

        self._test_core(num_gpus=8, device=device, infos=infos)

    def _test_core(
        self,
        num_gpus: int,
        device: str,
        infos: List[_TestInfo],
    ):
        """Start one process per rank, check each reported success, then join."""
        master_port = find_available_port(23456)
        output_reader, output_writer = mp.Pipe(duplex=False)

        processes = []
        for rank in range(num_gpus):
            worker_kwargs = dict(
                rank=rank,
                num_gpus=num_gpus,
                output_writer=output_writer,
                master_port=master_port,
                device=device,
                infos=infos,
            )
            worker = Process(target=_run_subprocess, kwargs=worker_kwargs)
            worker.start()
            processes.append(worker)

        # Every worker sends exactly one boolean over the shared pipe.
        for _ in range(num_gpus):
            worker_ok = output_reader.recv()
            self.assertTrue(worker_ok, "Subprocess has error, please see logs above.")

        for worker in processes:
            worker.join()
|
||||||
|
|
||||||
|
|
||||||
|
def _run_subprocess(
    rank: int,
    num_gpus: int,
    master_port: int,
    device: str,
    infos: List[_TestInfo],
    output_writer,
):
    """Worker entry point: join the process group and run every scenario.

    Sends a single boolean through ``output_writer``: ``True`` if all
    scenarios finished, ``False`` if any exception was raised (the error is
    printed together with its traceback).
    """
    try:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(master_port)

        torch.random.manual_seed(42)

        backend_by_device = {"cpu": "gloo", "cuda": None}
        torch.distributed.init_process_group(
            rank=rank,
            world_size=num_gpus,
            backend=backend_by_device[device],
        )
        if device == "cuda":
            torch.cuda.set_device(f"cuda:{rank}")

        for info in infos:
            _execute_test(info, rank=rank, num_gpus=num_gpus, device=device)
    except Exception as e:
        # Process boundary: report the failure to the parent instead of dying silently.
        print(f"subprocess[{rank=}] has error: {e}", flush=True)
        traceback.print_exc()
        execution_ok = False
    else:
        execution_ok = True

    output_writer.send(execution_ok)
    output_writer.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _execute_test(info: _TestInfo, rank: int, num_gpus: int, device: str):
    """Repeatedly relocate experts at random and verify the local weights.

    Each expert's "weights" encode its logical id, so after an update the
    local tensors must equal what a fresh construction from the new mapping
    would produce.  On a mismatch anywhere, all ranks gather their debug logs
    to rank 0, which prints them, and every rank raises ``AssertionError``.
    """
    if rank == 0:
        print(f"Test: {num_gpus=} {info=}", flush=True)

    assert info.num_physical_experts % num_gpus == 0
    num_local_physical_experts = info.num_physical_experts // num_gpus
    assert num_gpus % info.nnodes == 0
    num_gpu_per_node = num_gpus // info.nnodes

    def _create_routed_experts_weights(physical_to_logical_map):
        # Slice out the slots owned by this rank.
        local_logical_expert_ids = physical_to_logical_map[
            rank * num_local_physical_experts : (rank + 1) * num_local_physical_experts
        ].cpu()
        return [
            local_logical_expert_ids.to(device).clone(),
            torch.tensor(
                [
                    [local_logical_expert_id * 10, local_logical_expert_id * 100]
                    for local_logical_expert_id in local_logical_expert_ids.tolist()
                ],
                device=device,
            ),
        ]

    def _create_physical_to_logical_map():
        if rank == 0:
            # Every logical expert appears at least once; the remaining slots
            # are random replicas, and the whole layout is then shuffled.
            ans = torch.concat(
                [
                    torch.arange(0, info.num_logical_experts),
                    torch.randint(
                        0,
                        info.num_logical_experts,
                        (info.num_physical_experts - info.num_logical_experts,),
                    ),
                ]
            )
            ans = ans[torch.randperm(ans.shape[0])]
        else:
            ans = torch.empty((info.num_physical_experts,), dtype=torch.int64)

        assert ans.dtype == torch.int64 and ans.shape == (info.num_physical_experts,)
        ans = ans.to(device)
        # Rank 0's layout is the one every rank uses.
        torch.distributed.broadcast(ans, src=0)

        return ans.cpu()

    physical_to_logical_map = _create_physical_to_logical_map()
    routed_experts_weights = _create_routed_experts_weights(physical_to_logical_map)

    for i in range(info.num_repeat):
        if rank == 0 and ((i % 500 == 0) or (i == info.num_repeat - 1)):
            print(f"Step {i}/{info.num_repeat}", flush=True)

        new_physical_to_logical_map = _create_physical_to_logical_map()
        expect_new_weights = _create_routed_experts_weights(new_physical_to_logical_map)

        output_logs = expert_location_updater.update_expert_weights_single_layer(
            routed_experts_weights=routed_experts_weights,
            temp_buffers=expert_location_updater.create_temp_buffers(
                routed_experts_weights
            ),
            # The updater is declared to take ``List[int]`` maps and
            # deduplicates logical ids through a set; tensor elements hash by
            # identity, so pass plain ints.
            old_physical_to_logical_map=physical_to_logical_map.tolist(),
            new_physical_to_logical_map=new_physical_to_logical_map.tolist(),
            num_local_physical_experts=num_local_physical_experts,
            num_gpu_per_node=num_gpu_per_node,
            rank=rank,
            debug=True,
        )

        local_has_error = not all(
            torch.all(x == y)
            for x, y in zip(routed_experts_weights, expect_new_weights, strict=True)
        )
        # MAX-reduce so every rank learns whether any rank failed.
        global_has_error = torch.tensor(local_has_error, device=device)
        torch.distributed.all_reduce(
            global_has_error, op=torch.distributed.ReduceOp.MAX
        )

        if global_has_error.cpu().item():
            output_logs_str = "\n".join(output_logs)
            local_message = (
                f"===================== rank {rank} ============================\n"
                f"{num_gpus=} {info=}\n"
                f"{routed_experts_weights[0].tolist()=}\n"
                f"{expect_new_weights[0].tolist()=}\n"
                f"{physical_to_logical_map.tolist()=}\n"
                f"{new_physical_to_logical_map.tolist()=}\n"
                f"===logs===\n"
                f"{output_logs_str}\n"
                f"==============================================================\n"
            )

            global_messages = ([None] * num_gpus) if rank == 0 else None
            torch.distributed.gather_object(local_message, global_messages, dst=0)

            if rank == 0:
                print("\n\n".join(global_messages), flush=True)
            raise AssertionError(f"Error happens, see logs above")

        physical_to_logical_map = new_physical_to_logical_map
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user