395 lines
14 KiB
Python
395 lines
14 KiB
Python
# Copyright 2023-2024 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
import json
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
import torch.distributed
|
|
import torch.nn.functional as F
|
|
|
|
from sglang.srt.configs.model_config import ModelConfig
|
|
from sglang.srt.managers import deepseek_eplb
|
|
from sglang.srt.model_loader import get_model_architecture
|
|
from sglang.srt.server_args import ServerArgs
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class ExpertLocationMetadata:
    """Lookup tables relating logical experts to physical expert slots, per layer.

    All four tensors share the leading ``layers`` dimension; ``__post_init__``
    asserts that the layer, logical-expert and physical-expert dimensions of
    the individual tables agree with each other.
    """

    # Logical expert id stored in each physical slot.
    physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
    # Physical slots holding each logical expert, padded with -1.
    logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
    # Number of non -1 entries along the last axis of the table above.
    logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)
    # Physical slot the current rank dispatches to for each logical expert.
    logical_to_rank_dispatch_physical_map: torch.Tensor  # (layers, num_logical_experts)

    # -------------------------------- properties ------------------------------------

    @property
    def num_layers(self) -> int:
        """Size of the leading (layer) dimension of ``physical_to_logical_map``."""
        return self.physical_to_logical_map.size(0)

    @property
    def num_physical_experts(self) -> int:
        """Number of physical expert slots per layer."""
        return self.physical_to_logical_map.size(1)

    @property
    def num_local_physical_experts(self) -> int:
        """Physical slots per EP rank; asserts the slots divide evenly."""
        per_rank, leftover = divmod(self.num_physical_experts, self.ep_size)
        assert leftover == 0
        return per_rank

    @property
    def num_logical_experts(self) -> int:
        """Number of logical experts per layer."""
        return self.logical_to_all_physical_map.size(1)

    @property
    def ep_size(self):
        """Expert-parallel size, currently read from the distributed world size."""
        # TODO change when EP size != world size
        return torch.distributed.get_world_size()

    def __post_init__(self):
        """Assert that the four tables have mutually consistent shapes."""
        p2l_layers, p2l_physical = self.physical_to_logical_map.shape
        l2p_layers, l2p_logical, l2p_physical = self.logical_to_all_physical_map.shape
        valid_layers, valid_logical = self.logical_to_all_physical_map_num_valid.shape
        dispatch_layers, dispatch_logical = (
            self.logical_to_rank_dispatch_physical_map.shape
        )

        assert p2l_layers == l2p_layers == valid_layers == dispatch_layers
        assert l2p_logical == valid_logical == dispatch_logical
        assert p2l_physical == l2p_physical

    # -------------------------------- construction ------------------------------------

    @staticmethod
    def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
        """Trivial location - logical expert i corresponds to physical expert i"""
        common = ExpertLocationMetadata._init_common(server_args, model_config)
        expert_config = common["model_config_for_expert_location"]

        # Slots beyond the logical count wrap around (slot id modulo logical
        # count), and the same row is used for every layer.
        wrapped_slot_ids = (
            torch.arange(0, common["num_physical_experts"])
            % expert_config.num_logical_experts
        )
        physical_to_logical_map = wrapped_slot_ids.repeat(expert_config.num_layers, 1)

        return ExpertLocationMetadata.init_by_mapping(
            server_args,
            model_config,
            physical_to_logical_map=physical_to_logical_map,
        )

    @staticmethod
    def init_by_mapping(
        server_args: ServerArgs,
        model_config: ModelConfig,
        physical_to_logical_map,
    ):
        """Build metadata from an explicit physical->logical table.

        ``physical_to_logical_map`` may be a tensor or anything accepted by
        ``torch.tensor``; it is moved to ``server_args.device``.
        """
        mapping = (
            physical_to_logical_map
            if isinstance(physical_to_logical_map, torch.Tensor)
            else torch.tensor(physical_to_logical_map)
        )
        mapping = mapping.to(server_args.device)

        common = ExpertLocationMetadata._init_common(server_args, model_config)
        expert_config = common["model_config_for_expert_location"]
        inverse_mapping = _compute_logical_to_all_physical_map(
            mapping,
            num_logical_experts=expert_config.num_logical_experts,
        )

        return ExpertLocationMetadata._init_raw(
            ep_size=common["ep_size"],
            physical_to_logical_map=mapping,
            logical_to_all_physical_map=inverse_mapping,
        )

    @staticmethod
    def init_by_eplb(
        server_args: ServerArgs, model_config: ModelConfig, logical_count: torch.Tensor
    ):
        """Build metadata by running ``deepseek_eplb.rebalance_experts``.

        ``logical_count`` may be a tensor or anything accepted by
        ``torch.tensor``. A 2-D input gets a leading axis of size one before
        being moved to ``server_args.device``.
        """
        counts = (
            logical_count
            if isinstance(logical_count, torch.Tensor)
            else torch.tensor(logical_count)
        )
        if counts.dim() == 2:
            counts = counts.unsqueeze(0)
        counts = counts.to(server_args.device)

        common = ExpertLocationMetadata._init_common(server_args, model_config)
        expert_config = common["model_config_for_expert_location"]
        num_physical_experts = common["num_physical_experts"]

        # A disaggregation mode of "null" is treated as the decode phase.
        mode = server_args.disaggregation_mode
        phase = "decode" if mode == "null" else mode

        # The third returned value is not needed here.
        physical_to_logical_map, logical_to_all_physical_map, _ = (
            deepseek_eplb.rebalance_experts(
                tokens_per_expert=counts,
                num_physical_experts=num_physical_experts,
                num_local_physical_experts=num_physical_experts // common["ep_size"],
                num_groups=expert_config.num_groups,
                num_nodes=server_args.nnodes,
                phase=phase,
            )
        )

        return ExpertLocationMetadata._init_raw(
            ep_size=common["ep_size"],
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=logical_to_all_physical_map,
        )

    @staticmethod
    def _init_common(server_args: ServerArgs, model_config: ModelConfig):
        """Derive the sizes shared by all constructors.

        Returns a dict with keys ``model_config_for_expert_location``,
        ``num_physical_experts``, ``num_local_physical_experts`` and
        ``ep_size``. Asserts that the physical experts divide evenly over
        ``server_args.ep_size``.
        """
        expert_config = ModelConfigForExpertLocation.from_model_config(model_config)

        num_physical_experts = (
            expert_config.num_logical_experts + server_args.ep_num_redundant_experts
        )
        ep_size = server_args.ep_size
        num_local_physical_experts, leftover = divmod(num_physical_experts, ep_size)
        assert leftover == 0

        return {
            "model_config_for_expert_location": expert_config,
            "num_physical_experts": num_physical_experts,
            "num_local_physical_experts": num_local_physical_experts,
            "ep_size": ep_size,
        }

    @staticmethod
    def _init_raw(
        ep_size: int,
        physical_to_logical_map: torch.Tensor,
        logical_to_all_physical_map: torch.Tensor,
    ):
        """Assemble the dataclass from the two primary tables.

        The stored ``logical_to_all_physical_map`` is padded with -1 along its
        last axis up to ``num_physical_experts``; the valid counts and the
        per-rank dispatch table are derived from the unpadded input.
        """
        _, num_physical_experts = physical_to_logical_map.shape

        pad_right = num_physical_experts - logical_to_all_physical_map.shape[-1]
        padded_map = F.pad(logical_to_all_physical_map, (0, pad_right), value=-1)

        num_valid = torch.count_nonzero(logical_to_all_physical_map != -1, dim=-1)

        rank_dispatch_map = compute_logical_to_rank_dispatch_physical_map(
            logical_to_all_physical_map=logical_to_all_physical_map,
            logical_to_all_physical_map_num_valid=num_valid,
            num_gpus=ep_size,
            num_physical_experts=num_physical_experts,
            ep_rank=torch.distributed.get_rank(),
        )

        return ExpertLocationMetadata(
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=padded_map,
            logical_to_all_physical_map_num_valid=num_valid,
            logical_to_rank_dispatch_physical_map=rank_dispatch_map,
        )

    # -------------------------------- mutation ------------------------------------

    def update(
        self,
        other: "ExpertLocationMetadata",
    ):
        """Overwrite this object's tables in place with the contents of ``other``.

        Asserts both objects report the same ``ep_size``. The existing tensor
        objects are kept and written into, not rebound.
        """
        assert self.ep_size == other.ep_size

        tensor_fields = (
            "physical_to_logical_map",
            "logical_to_all_physical_map",
            "logical_to_all_physical_map_num_valid",
            "logical_to_rank_dispatch_physical_map",
        )
        for name in tensor_fields:
            target = getattr(self, name)
            target[...] = getattr(other, name)

    # -------------------------------- usage ------------------------------------

    def logical_to_all_physical(
        self, layer_id: int, logical_expert_id: int
    ) -> List[int]:
        """Return the physical slot ids of one logical expert, without -1 padding."""
        row = self.logical_to_all_physical_map[layer_id, logical_expert_id].tolist()
        return [slot for slot in row if slot != -1]
|
|
|
|
|
|
# Module-level slot for the shared metadata object; starts out empty.
_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None


def get_global_expert_location_metadata():
    """Return the object stored in the module-level slot (``None`` if unset)."""
    return _global_expert_location_metadata


def set_global_expert_location_metadata(value):
    """Store ``value`` in the module-level slot.

    Asserts that the slot currently holds ``None``, so a second registration
    of a non-None value fails.
    """
    global _global_expert_location_metadata
    assert get_global_expert_location_metadata() is None
    _global_expert_location_metadata = value
|
|
|
|
|
|
def _compute_logical_to_all_physical_map(
    physical_to_logical_map: torch.Tensor, num_logical_experts: int
):
    """Invert a physical->logical table into per-logical-expert slot lists.

    Args:
        physical_to_logical_map: 2-D tensor indexed ``[layer, physical slot]``
            whose entries are logical expert ids.
        num_logical_experts: number of logical experts per layer.

    Returns:
        A tensor on the same device as the input, indexed
        ``[layer, logical expert, k]``, listing the physical slots of each
        logical expert in ascending slot order and padded with -1 so that all
        rows have equal length.
    """
    # This is rarely called, so we use for loops for maximum clarity
    num_layers, _ = physical_to_logical_map.shape

    logical_to_all_physical_map = [
        [[] for _ in range(num_logical_experts)] for _ in range(num_layers)
    ]

    # Convert once with .tolist() instead of calling .item() per element:
    # the values are identical, but a tensor living on an accelerator is
    # copied to the host a single time rather than once per (layer, slot).
    rows = physical_to_logical_map.tolist()
    for slots_by_logical, row in zip(logical_to_all_physical_map, rows):
        for physical_expert_id, logical_expert_id in enumerate(row):
            slots_by_logical[logical_expert_id].append(physical_expert_id)

    logical_to_all_physical_map = _pad_nested_array(
        logical_to_all_physical_map, pad_value=-1
    )

    return torch.tensor(
        logical_to_all_physical_map, device=physical_to_logical_map.device
    )
|
|
|
|
|
|
def _pad_nested_array(arr, pad_value):
    """Right-pad the innermost lists of a two-level nested list to equal length.

    Args:
        arr: list of lists of lists.
        pad_value: value appended to every innermost list shorter than the
            longest one.

    Returns:
        A new nested list with the same outer structure; the innermost lists
        are new objects, the input is not modified. When ``arr`` contains no
        innermost list at all, the (empty) structure is returned unchanged
        instead of raising ``ValueError`` from ``max()``.
    """
    # default=0 covers the case where there is no innermost list to measure.
    max_len = max((len(inner) for outer in arr for inner in outer), default=0)
    return [
        [inner + [pad_value] * (max_len - len(inner)) for inner in outer]
        for outer in arr
    ]
|
|
|
|
|
|
# TODO use more sophisticated approaches
def compute_logical_to_rank_dispatch_physical_map(
    logical_to_all_physical_map: torch.Tensor,
    logical_to_all_physical_map_num_valid: torch.Tensor,
    num_gpus: int,
    num_physical_experts: int,
    ep_rank: int,
    base_seed: int = 42,
):
    """Choose one physical slot per (layer, logical expert) for rank ``ep_rank``.

    A slot is first drawn at random among the valid entries of each logical
    expert, using a generator seeded with ``base_seed + ep_rank``. The draw is
    then replaced by any valid slot that lies on ``ep_rank`` (slot id floor
    divided by the per-GPU slot count); when several such slots exist, the one
    in the highest column wins.

    Args:
        logical_to_all_physical_map: ``[layer, logical expert, k]`` slot ids,
            with -1 marking unused entries.
        logical_to_all_physical_map_num_valid: ``[layer, logical expert]``
            count of valid entries.
        num_gpus: number of ranks the physical slots are split across.
        num_physical_experts: total physical slots per layer.
        ep_rank: rank the dispatch table is computed for.
        base_seed: base value for the random generator seed.

    Returns:
        ``[layer, logical expert]`` tensor of chosen physical slot ids;
        asserted to contain no -1.
    """
    device = logical_to_all_physical_map.device

    slots_per_gpu = num_physical_experts // num_gpus
    num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
    output_shape = (num_layers, num_logical_experts)

    rng = torch.Generator(device=device)
    rng.manual_seed(base_seed + ep_rank)

    # Random starting choice: a raw draw reduced modulo the valid count.
    raw_draw = torch.randint(
        0, 65536, output_shape, dtype=torch.int32, device=device, generator=rng
    )
    picked_column = raw_draw % logical_to_all_physical_map_num_valid
    dispatch_map = torch.gather(
        logical_to_all_physical_map, dim=2, index=picked_column.unsqueeze(-1)
    ).squeeze(-1)
    assert dispatch_map.shape == output_shape

    # Prefer slots hosted on this rank over the random draw.
    max_valid = logical_to_all_physical_map_num_valid.max().item()
    for column in range(max_valid):
        candidates = logical_to_all_physical_map[:, :, column]
        hosted_here = (candidates != -1) & (
            (candidates // slots_per_gpu) == ep_rank
        )
        dispatch_map = torch.where(hosted_here, candidates, dispatch_map)

    assert torch.all(dispatch_map != -1)
    return dispatch_map
|
|
|
|
|
|
@dataclass
class ModelConfigForExpertLocation:
    """Model sizes needed to build expert-location tables."""

    num_layers: int
    num_logical_experts: int
    num_groups: Optional[int] = None

    @staticmethod
    def init_dummy():
        """Return a placeholder config with one layer and one logical expert."""
        return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)

    @staticmethod
    def from_model_config(model_config: ModelConfig):
        """Obtain the config from the model class resolved for ``model_config``.

        If the model class has a ``get_model_config_for_expert_location``
        attribute, it is called with ``model_config.hf_config`` and its result
        is returned; otherwise the dummy config is returned.
        """
        model_class, _ = get_model_architecture(model_config)
        if not hasattr(model_class, "get_model_config_for_expert_location"):
            return ModelConfigForExpertLocation.init_dummy()
        return model_class.get_model_config_for_expert_location(
            model_config.hf_config
        )
|
|
|
|
|
|
def compute_initial_expert_location_metadata(
    server_args: ServerArgs, model_config: ModelConfig
) -> ExpertLocationMetadata:
    """Build the starting expert placement named by ``server_args.init_expert_location``.

    The option value is interpreted as follows:

    * ``"trivial"``: use ``ExpertLocationMetadata.init_trivial``.
    * a string ending in ``.pt``: loaded with ``torch.load(..., weights_only=True)``.
    * a string ending in ``.json``: read as a file and parsed as JSON.
    * anything else: parsed directly as a JSON string.

    The loaded mapping is then dispatched on its keys: ``physical_to_logical_map``
    selects ``init_by_mapping`` (checked first), ``logical_count`` selects
    ``init_by_eplb``. Keys other than the one consumed are ignored.

    Raises:
        NotImplementedError: if the loaded mapping contains neither key.
    """
    data = server_args.init_expert_location
    if data == "trivial":
        logger.info("init_expert_location from trivial")
        return ExpertLocationMetadata.init_trivial(server_args, model_config)

    # TODO unify with the utils function
    if data.endswith(".pt"):
        data_dict = torch.load(data, weights_only=True)
    elif data.endswith(".json"):
        data_dict = json.loads(Path(data).read_text())
    else:
        data_dict = json.loads(data)

    if "physical_to_logical_map" in data_dict:
        logger.info(
            "init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
        )
        # Pass only the key that is consumed (as the eplb branch does), so a
        # payload carrying additional entries does not fail with TypeError
        # on an unexpected keyword argument.
        return ExpertLocationMetadata.init_by_mapping(
            server_args,
            model_config,
            physical_to_logical_map=data_dict["physical_to_logical_map"],
        )
    elif "logical_count" in data_dict:
        logger.info(
            "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
        )
        return ExpertLocationMetadata.init_by_eplb(
            server_args, model_config, logical_count=data_dict["logical_count"]
        )
    else:
        raise NotImplementedError(
            f"Unknown init_expert_location format ({list(data_dict.keys())=})"
        )
|