[2/N] Added the core structure of elastic EP and the eplb algorithm with faulty rank (#10606)
Co-authored-by: Xun Sun <UNIDY2002@outlook.com> Co-authored-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
74
python/sglang/srt/elastic_ep/elastic_ep.py
Normal file
74
python/sglang/srt/elastic_ep/elastic_ep.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.schedule_batch import ServerArgs
|
||||
from sglang.srt.utils import is_cpu, is_cuda
|
||||
|
||||
|
||||
@dataclass
class ElasticEPState:
    # 1/0 per-rank liveness mask on the EP device; None until initialized.
    active_ranks: Optional[torch.Tensor]
    # Snapshot of ``active_ranks`` taken at the last ``snapshot_active_to_last``.
    last_active_ranks: Optional[torch.Tensor]
    # Host-side copy of ``active_ranks`` refreshed by ``sync_active_to_cpu``.
    active_ranks_cpu: Optional[torch.Tensor]

    def is_active_equal_last(self) -> bool:
        """Return True iff the current and last active-rank masks are equal.

        Both masks being ``None`` counts as equal; exactly one ``None``
        compares unequal.  (Previously ``torch.equal`` was called
        unconditionally and raised ``TypeError`` on ``None`` inputs even
        though the fields are declared ``Optional``.)
        """
        if self.active_ranks is None or self.last_active_ranks is None:
            # Equal only when both are None.
            return self.active_ranks is self.last_active_ranks
        return torch.equal(self.active_ranks, self.last_active_ranks)

    def sync_active_to_cpu(self):
        """Refresh ``active_ranks_cpu`` from the device-side mask, if present."""
        if self.active_ranks is not None:
            self.active_ranks_cpu = self.active_ranks.detach().cpu().clone()

    def snapshot_active_to_last(self):
        """Record the current mask as ``last_active_ranks`` for later comparison."""
        if self.active_ranks is not None:
            self.last_active_ranks = self.active_ranks.clone()
|
||||
|
||||
|
||||
class ElasticEPStateManager:
    """Process-wide singleton holding the :class:`ElasticEPState` for this rank."""

    _instance: Optional[ElasticEPState] = None

    @classmethod
    def instance(cls) -> Optional[ElasticEPState]:
        # Fixed annotation: this is None until init() runs, and stays None
        # when elastic EP is disabled (callers already guard with `is not None`).
        return cls._instance

    @classmethod
    def init(cls, server_args: ServerArgs):
        """Create the singleton on first use; subsequent calls are no-ops.

        The state is only built when an elastic EP backend is configured;
        otherwise ``None`` is returned.
        """
        if cls._instance is not None:
            return cls._instance

        if server_args.elastic_ep_backend is not None:
            cls._instance = cls._build_state(ep_size=None, device=None)
        return cls._instance

    @staticmethod
    def _select_device() -> torch.device:
        # Elastic EP currently supports only CUDA and CPU tensors.
        if is_cuda():
            return torch.device("cuda")
        elif is_cpu():
            return torch.device("cpu")
        else:
            raise NotImplementedError("Only CUDA and CPU support elastic ep now.")

    @classmethod
    def _build_state(
        cls, *, ep_size: Optional[int] = None, device: Optional[torch.device] = None
    ) -> ElasticEPState:
        """Build a fresh state with every rank marked healthy."""
        active = cls.healthy_rank_state(ep_size=ep_size, device=device)
        return ElasticEPState(
            active_ranks=active,
            last_active_ranks=active.clone(),
            active_ranks_cpu=active.detach().cpu().clone(),
        )

    @classmethod
    def healthy_rank_state(
        cls, *, ep_size: Optional[int] = None, device: Optional[torch.device] = None
    ) -> torch.Tensor:
        """Return an all-ones int32 mask (1 == healthy), one entry per rank.

        Defaults: ``ep_size`` falls back to the distributed world size,
        ``device`` to :meth:`_select_device`.
        """
        size = ep_size if ep_size is not None else torch.distributed.get_world_size()
        dev = device if device is not None else cls._select_device()

        return torch.ones(size, dtype=torch.int32, device=dev)
|
||||
@@ -3,7 +3,8 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec
|
||||
from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
|
||||
from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec, elasticity_aware
|
||||
|
||||
|
||||
class EplbAlgorithm(Enum):
|
||||
@@ -11,6 +12,7 @@ class EplbAlgorithm(Enum):
|
||||
deepseek_hierarchical = auto()
|
||||
deepseek_vec = auto()
|
||||
deepseek_vec_hierarchical = auto()
|
||||
elasticity_aware = auto()
|
||||
# TODO may have more algorithm later
|
||||
|
||||
|
||||
@@ -45,6 +47,21 @@ def rebalance_experts(
|
||||
enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical,
|
||||
)
|
||||
|
||||
if algorithm == EplbAlgorithm.elasticity_aware:
|
||||
return elasticity_aware.rebalance_experts(
|
||||
weight=tokens_per_expert.sum(dim=0),
|
||||
num_replicas=num_physical_experts,
|
||||
num_groups=num_groups,
|
||||
num_nodes=num_nodes,
|
||||
num_gpus=num_physical_experts // num_local_physical_experts,
|
||||
enable_hierarchical=True,
|
||||
active_ranks=(
|
||||
ElasticEPStateManager.instance().active_ranks
|
||||
if ElasticEPStateManager.instance() is not None
|
||||
else ElasticEPStateManager.healthy_rank_state()
|
||||
),
|
||||
)
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
87
python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py
Normal file
87
python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.eplb.eplb_algorithms.deepseek import rebalance_experts_hierarchical
|
||||
|
||||
|
||||
def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
    enable_hierarchical: bool,
    active_ranks: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for the elasticity-aware expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g., NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`
        enable_hierarchical: use the hierarchical policy when all ranks are active
        active_ranks: per-rank liveness mask; entries are truthy for healthy ranks

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """

    num_layers, num_logical_experts = weight.shape
    # The balancing algorithm runs on CPU floats regardless of input device/dtype.
    weight = weight.float().cpu()
    num_active_ranks = active_ranks.sum().item()
    # Replicas hosted per GPU; assumes num_replicas is a multiple of num_gpus.
    num_local_experts = num_replicas // num_gpus
    if num_active_ranks < num_gpus:
        # Must fall back to global load-balance policy
        # and fix some params: balance only the replicas that fit on the
        # surviving ranks (num_groups=num_nodes=1 disables hierarchy).
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight,
            num_local_experts * num_active_ranks,
            1,
            1,
            num_active_ranks,
        )
    elif enable_hierarchical:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus
        )
    else:
        # use global load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus
        )
    # Widest replica count over all logical experts; unused slots stay -1.
    maxlogcnt = logcnt.max().item()
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    # Invert phy2log: slot (expert, replica-rank) receives the physical index.
    # Note the arange length covers only the replicas actually balanced
    # (num_local_experts * num_active_ranks).
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(
            num_local_experts * num_active_ranks,
            dtype=torch.int64,
            device=log2phy.device,
        ).expand(num_layers, -1),
    )
    if num_active_ranks < num_gpus:
        # Re-expand the compacted maps back to the full num_gpus layout:
        # insert a zero-filled slice for every faulty rank and shift the
        # physical indices in log2phy past each insertion point.
        phy2log_slices = list(
            phy2log.view(num_layers, num_active_ranks, -1).unbind(dim=1)
        )
        active_ranks_list = active_ranks.tolist()
        for idx, active_rank in enumerate(active_ranks_list):
            if not active_rank:
                # Placeholder slice for the dead rank (values are unused).
                phy2log_slices.insert(idx, torch.zeros_like(phy2log_slices[0]))
                log2phy = torch.where(
                    log2phy >= idx * num_local_experts,
                    log2phy + num_local_experts,
                    log2phy,
                )
        phy2log = torch.stack(phy2log_slices, dim=1).contiguous().view(num_layers, -1)
    return phy2log, log2phy, logcnt
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import NamedTuple, Optional, Tuple
|
||||
|
||||
from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.layers.dp_attention import get_is_extend_in_batch
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
@@ -63,14 +64,6 @@ class MooncakeCombineInput(NamedTuple):
|
||||
assert isinstance(MooncakeCombineInput, CombineInput)
|
||||
|
||||
|
||||
_ACTIVE_RANKS: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def get_ep_active_ranks() -> torch.Tensor:
|
||||
assert _ACTIVE_RANKS is not None, "_ACTIVE_RANKS is not initialized"
|
||||
return _ACTIVE_RANKS
|
||||
|
||||
|
||||
class EPBuffer:
|
||||
_buffer = None
|
||||
_hidden_size: Optional[int] = None
|
||||
@@ -153,12 +146,7 @@ class _MooncakeEPDispatcherImpl:
|
||||
self.first_execution = True
|
||||
self.timeout_us = 10000000
|
||||
|
||||
global _ACTIVE_RANKS
|
||||
if _ACTIVE_RANKS is None:
|
||||
_ACTIVE_RANKS = torch.ones(
|
||||
(self.num_experts,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
self.active_ranks = _ACTIVE_RANKS
|
||||
self.active_ranks = ElasticEPStateManager.instance().active_ranks
|
||||
|
||||
self.handle = None
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -51,6 +51,7 @@ from sglang.srt.distributed import (
|
||||
set_symm_mem_all_reduce,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
|
||||
from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
|
||||
from sglang.srt.eplb.eplb_manager import EPLBManager
|
||||
from sglang.srt.eplb.expert_distribution import (
|
||||
ExpertDistributionRecorder,
|
||||
@@ -379,6 +380,11 @@ class ModelRunner:
|
||||
)
|
||||
self.expert_location_updater = ExpertLocationUpdater()
|
||||
|
||||
(
|
||||
ElasticEPStateManager.init(self.server_args)
|
||||
if self.server_args.elastic_ep_backend
|
||||
else None
|
||||
)
|
||||
# Load the model
|
||||
self.sampler = Sampler()
|
||||
self.load_model()
|
||||
@@ -956,16 +962,33 @@ class ModelRunner:
|
||||
new_expert_location_metadata: ExpertLocationMetadata,
|
||||
update_layer_ids: List[int],
|
||||
):
|
||||
self.expert_location_updater.update(
|
||||
self.model.routed_experts_weights_of_layer,
|
||||
new_expert_location_metadata,
|
||||
update_layer_ids=update_layer_ids,
|
||||
nnodes=self.server_args.nnodes,
|
||||
rank=self.tp_rank,
|
||||
)
|
||||
if ElasticEPStateManager.instance() is not None:
|
||||
# TODO: refactor the weights update when elastic ep
|
||||
old_expert_location_metadata = get_global_expert_location_metadata()
|
||||
assert old_expert_location_metadata is not None
|
||||
old_expert_location_metadata.update(
|
||||
new_expert_location_metadata,
|
||||
update_layer_ids=update_layer_ids,
|
||||
)
|
||||
self.update_weights_from_disk(
|
||||
self.server_args.model_path,
|
||||
self.server_args.load_format,
|
||||
lambda name: "mlp.experts" in name and "mlp.shared_experts" not in name,
|
||||
)
|
||||
else:
|
||||
self.expert_location_updater.update(
|
||||
self.model.routed_experts_weights_of_layer,
|
||||
new_expert_location_metadata,
|
||||
update_layer_ids=update_layer_ids,
|
||||
nnodes=self.server_args.nnodes,
|
||||
rank=self.tp_rank,
|
||||
)
|
||||
|
||||
def update_weights_from_disk(
|
||||
self, model_path: str, load_format: str
|
||||
self,
|
||||
model_path: str,
|
||||
load_format: str,
|
||||
weight_name_filter: Optional[Callable[[str], bool]] = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""Update engine weights in-place from the disk."""
|
||||
logger.info(
|
||||
@@ -987,6 +1010,11 @@ class ModelRunner:
|
||||
iter = loader._get_weights_iterator(
|
||||
DefaultModelLoader.Source.init_new(config, self.model)
|
||||
)
|
||||
if weight_name_filter is not None:
|
||||
iter = (
|
||||
(name, weight) for name, weight in iter if weight_name_filter(name)
|
||||
)
|
||||
|
||||
return iter
|
||||
|
||||
def model_load_weights(model, iter):
|
||||
|
||||
@@ -600,6 +600,9 @@ class ServerArgs:
|
||||
# Handle any other necessary validations.
|
||||
self._handle_other_validations()
|
||||
|
||||
# Handle elastic expert parallelism.
|
||||
self._handle_elastic_ep()
|
||||
|
||||
def _handle_deprecated_args(self):
|
||||
# handle deprecated tool call parsers
|
||||
deprecated_tool_call_parsers = {"qwen25": "qwen", "glm45": "glm"}
|
||||
@@ -1225,6 +1228,15 @@ class ServerArgs:
|
||||
if self.enable_eplb:
|
||||
assert self.ep_size > 1
|
||||
|
||||
def _handle_elastic_ep(self):
|
||||
if self.elastic_ep_backend is not None:
|
||||
if self.enable_eplb:
|
||||
if self.eplb_algorithm == "auto":
|
||||
self.eplb_algorithm = "elasticity_aware"
|
||||
assert (
|
||||
self.eplb_algorithm == "elasticity_aware"
|
||||
), "Elastic EP requires eplb_algorithm to be set to 'auto' or 'elasticity_aware'."
|
||||
|
||||
def _handle_expert_distribution_metrics(self):
|
||||
if self.enable_expert_distribution_metrics and (
|
||||
self.expert_distribution_recorder_mode is None
|
||||
|
||||
Reference in New Issue
Block a user