# sglang/python/sglang/srt/managers/expert_location.py
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import torch
import torch.distributed
import torch.nn.functional as F

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.managers import deepseek_eplb
from sglang.srt.model_loader import get_model_architecture
from sglang.srt.server_args import ServerArgs

logger = logging.getLogger(__name__)


@dataclass
class ExpertLocationMetadata:
    physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
    logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
    logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)
    logical_to_rank_dispatch_physical_map: torch.Tensor  # (layers, num_logical_experts)

    # -------------------------------- properties ------------------------------------

    @property
    def num_layers(self) -> int:
        return self.physical_to_logical_map.shape[0]

    @property
    def num_physical_experts(self) -> int:
        return self.physical_to_logical_map.shape[1]

    @property
    def num_local_physical_experts(self) -> int:
        ans, remainder = divmod(self.num_physical_experts, self.ep_size)
        assert remainder == 0
        return ans

    @property
    def num_logical_experts(self) -> int:
        return self.logical_to_all_physical_map.shape[1]

    @property
    def ep_size(self) -> int:
        # TODO change when EP size != world size
        return torch.distributed.get_world_size()

    def __post_init__(self):
        num_layers_0, num_physical_experts_0 = self.physical_to_logical_map.shape
        num_layers_1, num_logical_experts_0, num_physical_experts_1 = (
            self.logical_to_all_physical_map.shape
        )
        num_layers_2, num_logical_experts_1 = (
            self.logical_to_all_physical_map_num_valid.shape
        )
        num_layers_3, num_logical_experts_2 = (
            self.logical_to_rank_dispatch_physical_map.shape
        )
        assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
        assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
        assert num_physical_experts_0 == num_physical_experts_1

    # -------------------------------- construction ------------------------------------

    @staticmethod
    def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
        """Trivial location: physical expert i maps to logical expert i % num_logical_experts."""
        common = ExpertLocationMetadata._init_common(server_args, model_config)
        num_physical_experts = common["num_physical_experts"]
        model_config_for_expert_location = common["model_config_for_expert_location"]
        num_layers = model_config_for_expert_location.num_layers
        num_logical_experts = model_config_for_expert_location.num_logical_experts

        physical_to_logical_map = (
            torch.arange(0, num_physical_experts).repeat(num_layers, 1)
            % num_logical_experts
        )

        return ExpertLocationMetadata.init_by_mapping(
            server_args,
            model_config,
            physical_to_logical_map=physical_to_logical_map,
        )
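
    # Worked example (illustrative numbers, not from the source): with
    # num_logical_experts = 4 and ep_num_redundant_experts = 2,
    # torch.arange(0, 6).repeat(num_layers, 1) % 4 gives every layer the row
    # [0, 1, 2, 3, 0, 1]; the redundant physical slots wrap back around to
    # logical experts 0 and 1.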

    @staticmethod
    def init_by_mapping(
        server_args: ServerArgs,
        model_config: ModelConfig,
        physical_to_logical_map,
    ):
        if not isinstance(physical_to_logical_map, torch.Tensor):
            physical_to_logical_map = torch.tensor(physical_to_logical_map)
        physical_to_logical_map = physical_to_logical_map.to(server_args.device)

        common = ExpertLocationMetadata._init_common(server_args, model_config)
        model_config_for_expert_location = common["model_config_for_expert_location"]
        logical_to_all_physical_map = _compute_logical_to_all_physical_map(
            physical_to_logical_map,
            num_logical_experts=model_config_for_expert_location.num_logical_experts,
        )

        return ExpertLocationMetadata._init_raw(
            ep_size=common["ep_size"],
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=logical_to_all_physical_map,
        )
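
    # Hypothetical call (values illustrative only): physical_to_logical_map may
    # be a plain nested list such as [[0, 1, 0, 1]]; it is converted to a tensor
    # and moved to server_args.device before the inverse map is computed.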

    @staticmethod
    def init_by_eplb(
        server_args: ServerArgs, model_config: ModelConfig, logical_count: torch.Tensor
    ):
        if not isinstance(logical_count, torch.Tensor):
            logical_count = torch.tensor(logical_count)
        if len(logical_count.shape) == 2:
            logical_count = logical_count.unsqueeze(0)
        logical_count = logical_count.to(server_args.device)

        common = ExpertLocationMetadata._init_common(server_args, model_config)
        model_config_for_expert_location = common["model_config_for_expert_location"]
        num_physical_experts = common["num_physical_experts"]

        phase = server_args.disaggregation_mode
        if phase == "null":
            phase = "decode"

        physical_to_logical_map, logical_to_all_physical_map, expert_count = (
            deepseek_eplb.rebalance_experts(
                tokens_per_expert=logical_count,
                num_physical_experts=num_physical_experts,
                num_local_physical_experts=num_physical_experts // common["ep_size"],
                num_groups=model_config_for_expert_location.num_groups,
                num_nodes=server_args.nnodes,
                phase=phase,
            )
        )

        return ExpertLocationMetadata._init_raw(
            ep_size=common["ep_size"],
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=logical_to_all_physical_map,
        )

    @staticmethod
    def _init_common(server_args: ServerArgs, model_config: ModelConfig):
        model_config_for_expert_location = (
            ModelConfigForExpertLocation.from_model_config(model_config)
        )

        num_physical_experts = (
            model_config_for_expert_location.num_logical_experts
            + server_args.ep_num_redundant_experts
        )
        ep_size = server_args.ep_size
        assert num_physical_experts % ep_size == 0
        num_local_physical_experts = num_physical_experts // ep_size

        return dict(
            model_config_for_expert_location=model_config_for_expert_location,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
            ep_size=ep_size,
        )
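
    # Arithmetic example (hypothetical sizes): 256 logical experts plus
    # ep_num_redundant_experts = 32 gives 288 physical experts, which must split
    # evenly across the EP group, e.g. 288 / 8 GPUs = 36 local physical experts.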

    @staticmethod
    def _init_raw(
        ep_size: int,
        physical_to_logical_map: torch.Tensor,
        logical_to_all_physical_map: torch.Tensor,
    ):
        _, num_physical_experts = physical_to_logical_map.shape

        logical_to_all_physical_map_padded = F.pad(
            logical_to_all_physical_map,
            (0, num_physical_experts - logical_to_all_physical_map.shape[-1]),
            value=-1,
        )

        logical_to_all_physical_map_num_valid = torch.count_nonzero(
            logical_to_all_physical_map != -1, dim=-1
        )

        return ExpertLocationMetadata(
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=logical_to_all_physical_map_padded,
            logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
            logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
                logical_to_all_physical_map=logical_to_all_physical_map,
                logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
                num_gpus=ep_size,
                num_physical_experts=num_physical_experts,
                ep_rank=torch.distributed.get_rank(),
            ),
        )
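
    # Padding note: the map is right-padded with -1 to num_physical_experts
    # columns, presumably the worst case (a logical expert can have at most
    # num_physical_experts replicas), so every rank sees a uniformly shaped tensor.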

    # -------------------------------- mutation ------------------------------------

    def update(
        self,
        other: "ExpertLocationMetadata",
    ):
        for field in [
            "ep_size",
        ]:
            assert getattr(self, field) == getattr(other, field)

        for field in [
            "physical_to_logical_map",
            "logical_to_all_physical_map",
            "logical_to_all_physical_map_num_valid",
            "logical_to_rank_dispatch_physical_map",
        ]:
            # Copy in place so the tensor objects (and any existing references
            # to them) are preserved
            dst = getattr(self, field)
            dst[...] = getattr(other, field)

    # -------------------------------- usage ------------------------------------

    def logical_to_all_physical(
        self, layer_id: int, logical_expert_id: int
    ) -> List[int]:
        return [
            physical_expert_id
            for physical_expert_id in self.logical_to_all_physical_map[
                layer_id, logical_expert_id
            ].tolist()
            if physical_expert_id != -1
        ]
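
    # Example: if the padded row for (layer_id, logical_expert_id) is
    # [2, 5, -1, -1], this returns [2, 5].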


_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None


def get_global_expert_location_metadata() -> Optional[ExpertLocationMetadata]:
    return _global_expert_location_metadata


def set_global_expert_location_metadata(value: ExpertLocationMetadata):
    global _global_expert_location_metadata
    assert _global_expert_location_metadata is None
    _global_expert_location_metadata = value


def _compute_logical_to_all_physical_map(
    physical_to_logical_map: torch.Tensor, num_logical_experts: int
):
    # This is rarely called, so we use for loops for maximum clarity

    num_layers, num_physical_experts = physical_to_logical_map.shape

    logical_to_all_physical_map = [
        [[] for _ in range(num_logical_experts)] for _ in range(num_layers)
    ]
    for layer_id in range(num_layers):
        for physical_expert_id in range(num_physical_experts):
            logical_expert_id = physical_to_logical_map[
                layer_id, physical_expert_id
            ].item()
            logical_to_all_physical_map[layer_id][logical_expert_id].append(
                physical_expert_id
            )

    logical_to_all_physical_map = _pad_nested_array(
        logical_to_all_physical_map, pad_value=-1
    )

    return torch.tensor(
        logical_to_all_physical_map, device=physical_to_logical_map.device
    )
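
# Example: a single layer whose physical_to_logical_map row is [0, 1, 0] inverts
# to logical 0 -> physical [0, 2] and logical 1 -> physical [1], which is then
# padded to [[0, 2], [1, -1]].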


def _pad_nested_array(arr, pad_value):
    max_len = max(len(inner) for outer in arr for inner in outer)
    padded = [
        [inner + [pad_value] * (max_len - len(inner)) for inner in outer]
        for outer in arr
    ]
    return padded
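
# e.g. _pad_nested_array([[[7], [8, 9]]], pad_value=-1) == [[[7, -1], [8, 9]]]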


# TODO use more sophisticated approaches
def compute_logical_to_rank_dispatch_physical_map(
    logical_to_all_physical_map: torch.Tensor,
    logical_to_all_physical_map_num_valid: torch.Tensor,
    num_gpus: int,
    num_physical_experts: int,
    ep_rank: int,
    base_seed: int = 42,
):
    device = logical_to_all_physical_map.device
    num_local_physical_experts = num_physical_experts // num_gpus
    num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape

    g = torch.Generator(device=device)
    g.manual_seed(base_seed + ep_rank)

    output_shape = (num_layers, num_logical_experts)

    # Start from a per-rank seeded random replica for every logical expert...
    chosen_index = (
        torch.randint(
            0, 65536, output_shape, dtype=torch.int32, device=device, generator=g
        )
        % logical_to_all_physical_map_num_valid
    )
    logical_to_rank_dispatch_physical_map = torch.gather(
        logical_to_all_physical_map, dim=2, index=chosen_index.unsqueeze(-1)
    ).squeeze(-1)
    assert logical_to_rank_dispatch_physical_map.shape == output_shape

    # ...then overwrite the pick with a replica that lives on this rank's GPU
    # whenever one exists.
    for index in range(logical_to_all_physical_map_num_valid.max().item()):
        partial_logical_to_all_physical_map = logical_to_all_physical_map[:, :, index]
        is_valid = partial_logical_to_all_physical_map != -1
        is_same_gpu = (
            partial_logical_to_all_physical_map // num_local_physical_experts
        ) == ep_rank
        logical_to_rank_dispatch_physical_map = torch.where(
            is_valid & is_same_gpu,
            partial_logical_to_all_physical_map,
            logical_to_rank_dispatch_physical_map,
        )

    assert torch.all(logical_to_rank_dispatch_physical_map != -1)
    return logical_to_rank_dispatch_physical_map
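
# Behavior sketch (hypothetical numbers): with num_physical_experts = 8 and
# num_gpus = 4, physical expert p lives on GPU p // 2. If a logical expert has
# replicas [1, 6], ep_rank 0 dispatches to 1 (1 // 2 == 0), ep_rank 3 dispatches
# to 6 (6 // 2 == 3), and ep_rank 2 keeps its seeded random pick, since neither
# replica is local.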


@dataclass
class ModelConfigForExpertLocation:
    num_layers: int
    num_logical_experts: int
    num_groups: Optional[int] = None

    @staticmethod
    def init_dummy():
        return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)

    @staticmethod
    def from_model_config(model_config: ModelConfig):
        model_class, _ = get_model_architecture(model_config)
        if hasattr(model_class, "get_model_config_for_expert_location"):
            return model_class.get_model_config_for_expert_location(
                model_config.hf_config
            )
        else:
            return ModelConfigForExpertLocation.init_dummy()
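
# A model opts in by exposing this hook on its model class; a hypothetical
# sketch for a DeepSeek-style HF config (field names assumed, not from this file):
#
#     @classmethod
#     def get_model_config_for_expert_location(cls, config):
#         return ModelConfigForExpertLocation(
#             num_layers=config.num_hidden_layers,
#             num_logical_experts=config.n_routed_experts,
#             num_groups=config.n_group,
#         )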


def compute_initial_expert_location_metadata(
    server_args: ServerArgs, model_config: ModelConfig
) -> ExpertLocationMetadata:
    data = server_args.init_expert_location
    if data == "trivial":
        logger.info("init_expert_location from trivial")
        return ExpertLocationMetadata.init_trivial(server_args, model_config)

    # TODO unify with the utils function
    if data.endswith(".pt"):
        data_dict = torch.load(data, weights_only=True)
    elif data.endswith(".json"):
        data_dict = json.loads(Path(data).read_text())
    else:
        data_dict = json.loads(data)

    if "physical_to_logical_map" in data_dict:
        logger.info(
            "init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
        )
        return ExpertLocationMetadata.init_by_mapping(
            server_args, model_config, **data_dict
        )
    elif "logical_count" in data_dict:
        logger.info(
            "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
        )
        return ExpertLocationMetadata.init_by_eplb(
            server_args, model_config, logical_count=data_dict["logical_count"]
        )
    else:
        raise NotImplementedError(
            f"Unknown init_expert_location format ({list(data_dict.keys())=})"
        )
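
# Accepted values for ServerArgs.init_expert_location, as handled above:
#   "trivial"                               -> init_trivial
#   a .pt / .json path, or an inline JSON string, containing either
#     {"physical_to_logical_map": ...}      -> init_by_mapping
#     {"logical_count": ...}                -> init_by_eplb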