from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import torch

if TYPE_CHECKING:
    # Annotation-only import: avoids pulling the heavy scheduler module
    # (and its transitive imports) in at runtime.
    from sglang.srt.managers.schedule_batch import ServerArgs


@dataclass
class ElasticEPState:
    """Bookkeeping for elastic expert parallelism (EP) rank health.

    Attributes:
        active_ranks: 1-D int32 tensor with one entry per EP rank;
            1 = healthy, 0 = faulty (see ``healthy_rank_state``).
        last_active_ranks: snapshot of ``active_ranks`` taken at the last
            rebalance, used to detect membership changes.
        active_ranks_cpu: CPU copy of ``active_ranks`` for host-side reads.
    """

    active_ranks: Optional[torch.Tensor]
    last_active_ranks: Optional[torch.Tensor]
    active_ranks_cpu: Optional[torch.Tensor]

    def is_active_equal_last(self) -> bool:
        """Return True when the active-rank set is unchanged since the last snapshot.

        NOTE(review): assumes both tensors are set; callers initialize them
        via ``ElasticEPStateManager._build_state``.
        """
        return torch.equal(self.active_ranks, self.last_active_ranks)

    def sync_active_to_cpu(self) -> None:
        """Refresh ``active_ranks_cpu`` from the (possibly device-resident) tensor."""
        if self.active_ranks is not None:
            self.active_ranks_cpu = self.active_ranks.detach().cpu().clone()

    def snapshot_active_to_last(self) -> None:
        """Record the current ``active_ranks`` as the last-seen state."""
        if self.active_ranks is not None:
            self.last_active_ranks = self.active_ranks.clone()


class ElasticEPStateManager:
    """Process-wide singleton holder for the :class:`ElasticEPState`."""

    _instance: Optional[ElasticEPState] = None

    @classmethod
    def instance(cls) -> Optional[ElasticEPState]:
        """Return the singleton state, or None when elastic EP is disabled.

        (Return type fixed to ``Optional``: callers such as the EPLB
        dispatch explicitly test the result against ``None``.)
        """
        return cls._instance

    @classmethod
    def init(cls, server_args: "ServerArgs") -> Optional[ElasticEPState]:
        """Create the singleton when an elastic EP backend is configured.

        Idempotent: returns the existing instance if already initialized.
        Returns None (and leaves the singleton unset) when
        ``server_args.elastic_ep_backend`` is not configured.
        """
        if cls._instance is not None:
            return cls._instance

        if server_args.elastic_ep_backend is not None:
            cls._instance = cls._build_state(ep_size=None, device=None)
            return cls._instance
        # Elastic EP not enabled: explicit None instead of falling through.
        return None

    @staticmethod
    def _select_device() -> torch.device:
        """Pick the device for the rank-state tensor (CUDA preferred)."""
        # Local import: only needed here, keeps module import light.
        from sglang.srt.utils import is_cpu, is_cuda

        if is_cuda():
            return torch.device("cuda")
        elif is_cpu():
            return torch.device("cpu")
        else:
            raise NotImplementedError("Only CUDA and CPU support elastic ep now.")

    @classmethod
    def _build_state(
        cls, *, ep_size: Optional[int] = None, device: Optional[torch.device] = None
    ) -> ElasticEPState:
        """Build a state in which every rank is marked healthy."""
        active = cls.healthy_rank_state(ep_size=ep_size, device=device)
        return ElasticEPState(
            active_ranks=active,
            last_active_ranks=active.clone(),
            active_ranks_cpu=active.detach().cpu().clone(),
        )

    @classmethod
    def healthy_rank_state(
        cls, *, ep_size: Optional[int] = None, device: Optional[torch.device] = None
    ) -> torch.Tensor:
        """Return an all-ones int32 tensor marking every EP rank healthy.

        Defaults to ``torch.distributed.get_world_size()`` ranks and the
        auto-selected device when ``ep_size`` / ``device`` are omitted.
        """
        size = ep_size if ep_size is not None else torch.distributed.get_world_size()
        dev = device if device is not None else cls._select_device()

        return torch.ones(size, dtype=torch.int32, device=dev)
def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
    enable_hierarchical: bool,
    active_ranks: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for the elasticity-aware expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g., NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`
        enable_hierarchical: use the hierarchical load-balance policy when all
            ranks are healthy
        active_ranks: [num_gpus] tensor of 0/1 flags; 0 marks a faulty rank
            that must not host any experts

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """

    num_layers, num_logical_experts = weight.shape
    weight = weight.float().cpu()
    num_active_ranks = active_ranks.sum().item()
    num_local_experts = num_replicas // num_gpus
    if num_active_ranks < num_gpus:
        # Some ranks are faulty: rebalance across healthy ranks only.
        # Must fall back to global load-balance policy and fix some params.
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight,
            num_local_experts * num_active_ranks,
            1,
            1,
            num_active_ranks,
        )
    elif enable_hierarchical:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus
        )
    else:
        # use global load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus
        )
    maxlogcnt = logcnt.max().item()
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(
            num_local_experts * num_active_ranks,
            dtype=torch.int64,
            device=log2phy.device,
        ).expand(num_layers, -1),
    )
    if num_active_ranks < num_gpus:
        # Re-expand the compacted mapping back onto the full rank grid:
        # insert an all-zero expert slice for each faulty rank, and shift
        # every physical index in log2phy at or past the inserted slot.
        phy2log_slices = list(
            phy2log.view(num_layers, num_active_ranks, -1).unbind(dim=1)
        )
        active_ranks_list = active_ranks.tolist()
        for idx, active_rank in enumerate(active_ranks_list):
            if not active_rank:
                phy2log_slices.insert(idx, torch.zeros_like(phy2log_slices[0]))
                log2phy = torch.where(
                    log2phy >= idx * num_local_experts,
                    log2phy + num_local_experts,
                    log2phy,
                )
        phy2log = torch.stack(phy2log_slices, dim=1).contiguous().view(num_layers, -1)
    return phy2log, log2phy, logcnt
b/python/sglang/srt/model_executor/model_runner.py @@ -24,7 +24,7 @@ import threading import time from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -51,6 +51,7 @@ from sglang.srt.distributed import ( set_symm_mem_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state +from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager from sglang.srt.eplb.eplb_manager import EPLBManager from sglang.srt.eplb.expert_distribution import ( ExpertDistributionRecorder, @@ -379,6 +380,11 @@ class ModelRunner: ) self.expert_location_updater = ExpertLocationUpdater() + ( + ElasticEPStateManager.init(self.server_args) + if self.server_args.elastic_ep_backend + else None + ) # Load the model self.sampler = Sampler() self.load_model() @@ -956,16 +962,33 @@ class ModelRunner: new_expert_location_metadata: ExpertLocationMetadata, update_layer_ids: List[int], ): - self.expert_location_updater.update( - self.model.routed_experts_weights_of_layer, - new_expert_location_metadata, - update_layer_ids=update_layer_ids, - nnodes=self.server_args.nnodes, - rank=self.tp_rank, - ) + if ElasticEPStateManager.instance() is not None: + # TODO: refactor the weights update when elastic ep + old_expert_location_metadata = get_global_expert_location_metadata() + assert old_expert_location_metadata is not None + old_expert_location_metadata.update( + new_expert_location_metadata, + update_layer_ids=update_layer_ids, + ) + self.update_weights_from_disk( + self.server_args.model_path, + self.server_args.load_format, + lambda name: "mlp.experts" in name and "mlp.shared_experts" not in name, + ) + else: + self.expert_location_updater.update( + self.model.routed_experts_weights_of_layer, + new_expert_location_metadata, + update_layer_ids=update_layer_ids, + nnodes=self.server_args.nnodes, 
+ rank=self.tp_rank, + ) def update_weights_from_disk( - self, model_path: str, load_format: str + self, + model_path: str, + load_format: str, + weight_name_filter: Optional[Callable[[str], bool]] = None, ) -> tuple[bool, str]: """Update engine weights in-place from the disk.""" logger.info( @@ -987,6 +1010,11 @@ class ModelRunner: iter = loader._get_weights_iterator( DefaultModelLoader.Source.init_new(config, self.model) ) + if weight_name_filter is not None: + iter = ( + (name, weight) for name, weight in iter if weight_name_filter(name) + ) + return iter def model_load_weights(model, iter): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0c2bbf6a3..7435725dc 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -600,6 +600,9 @@ class ServerArgs: # Handle any other necessary validations. self._handle_other_validations() + # Handle elastic expert parallelism. + self._handle_elastic_ep() + def _handle_deprecated_args(self): # handle deprecated tool call parsers deprecated_tool_call_parsers = {"qwen25": "qwen", "glm45": "glm"} @@ -1225,6 +1228,15 @@ class ServerArgs: if self.enable_eplb: assert self.ep_size > 1 + def _handle_elastic_ep(self): + if self.elastic_ep_backend is not None: + if self.enable_eplb: + if self.eplb_algorithm == "auto": + self.eplb_algorithm = "elasticity_aware" + assert ( + self.eplb_algorithm == "elasticity_aware" + ), "Elastic EP requires eplb_algorithm to be set to 'auto' or 'elasticity_aware'." 
+ def _handle_expert_distribution_metrics(self): if self.enable_expert_distribution_metrics and ( self.expert_distribution_recorder_mode is None diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index 111260a8c..391cdc4c6 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -3,6 +3,7 @@ from types import SimpleNamespace from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_disaggregation_utils import get_rdma_devices_args from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -11,8 +12,12 @@ from sglang.test.test_utils import ( popen_launch_server, ) +ib_devices = get_rdma_devices_args() + + +class TestTP(CustomTestCase): + extra_args = [] -class TestPureDP(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA @@ -25,13 +30,10 @@ class TestPureDP(CustomTestCase): "--trust-remote-code", "--tp", "4", - "--enable-dp-attention", - "--dp", - "4", "--elastic-ep-backend", "mooncake", "--mooncake-ib-device", - "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", + ib_devices, "--moe-a2a-backend", "deepep", "--deepep-mode", @@ -44,6 +46,7 @@ class TestPureDP(CustomTestCase): "512", "--mem-fraction-static", "0.5", + *cls.extra_args, ], ) @@ -67,219 +70,73 @@ class TestPureDP(CustomTestCase): self.assertGreater(metrics["accuracy"], 0.60) -class TestHybridDPTP(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "2", - "--elastic-ep-backend", - "mooncake", - "--mooncake-ib-device", - 
class TestPureDP(TestTP):
    # Pure data-parallel attention: dp == tp == 4.
    extra_args = [
        "--tp",
        "4",
        "--enable-dp-attention",
        "--dp",
        "4",
    ]


class TestHybridDPTP(TestTP):
    # Hybrid data-/tensor-parallel attention: dp (2) < tp (4).
    extra_args = [
        "--tp",
        "4",
        "--enable-dp-attention",
        "--dp",
        "2",
    ]


class TestNoGatherdBuffer(TestTP):
    # moe-dense-tp-size=1 exercises the no-gathered-buffer path.
    extra_args = [
        "--tp",
        "4",
        "--enable-dp-attention",
        "--dp",
        "4",
        "--moe-dense-tp-size",
        "1",
    ]


class TestTBO(TestTP):
    # Two-batch overlap on top of DP attention.
    extra_args = [
        "--tp",
        "4",
        "--enable-dp-attention",
        "--dp",
        "4",
        "--moe-dense-tp-size",
        "1",
        "--enable-two-batch-overlap",
    ]


# Renamed from TestMooncakeWitchEPLB: fixed "Witch" -> "With" typo. The class
# is newly introduced by this change and referenced nowhere else.
class TestMooncakeWithEPLB(TestTP):
    # EPLB enabled with redundant experts and periodic rebalancing, which
    # routes through the elasticity_aware algorithm under elastic EP.
    extra_args = [
        "--tp",
        "4",
        "--enable-dp-attention",
        "--dp",
        "4",
        "--moe-dense-tp-size",
        "1",
        "--enable-two-batch-overlap",
        "--enable-eplb",
        "--ep-num-redundant-experts",
        "4",
        "--eplb-rebalance-num-iterations",
        "50",
        "--expert-distribution-recorder-buffer-size",
        "50",
        "--expert-distribution-recorder-mode",
        "stat",
        "--ep-dispatch-algorithm",
        "static",
    ]