diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py new file mode 100644 index 000000000..d9e503264 --- /dev/null +++ b/python/sglang/srt/managers/eplb_manager.py @@ -0,0 +1,55 @@ +import logging +import time +from typing import TYPE_CHECKING + +import torch.cuda + +from sglang.srt.managers.expert_distribution import ( + get_global_expert_distribution_recorder, +) +from sglang.srt.managers.expert_location import ExpertLocationMetadata + +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + +logger = logging.getLogger(__name__) + + +class EPLBManager: + def __init__(self, model_runner: "ModelRunner"): + super().__init__() + self._model_runner = model_runner + self._server_args = model_runner.server_args + + # Otherwise, the circular buffer will contain stale data. If the case is needed, it can be implemented. + assert ( + self._server_args.eplb_rebalance_num_iterations + <= self._server_args.expert_distribution_recorder_buffer_size + ), "eplb_rebalance_num_iterations must be less than expert_distribution_recorder_buffer_size" + + get_global_expert_distribution_recorder().start_record() + + logger.info( + f"[EPLBManager] system started, will rebalance per {self._server_args.eplb_rebalance_num_iterations} iterations." + ) + + def on_forward_pass_end(self, forward_pass_id: int): + if forward_pass_id % self._server_args.eplb_rebalance_num_iterations == 0: + self.rebalance() + + def rebalance(self): + logger.info("[EPLBManager] rebalance start") + torch.cuda.synchronize() + time_start = time.time() + + logical_count = get_global_expert_distribution_recorder().dump_record( + output_mode="object" + )["logical_count"] + expert_location_metadata = ExpertLocationMetadata.init_by_eplb( + self._server_args, self._model_runner.model_config, logical_count + ) + self._model_runner.update_expert_location(expert_location_metadata) + + torch.cuda.synchronize() + time_end = time.time() + logger.info(f"[EPLBManager] rebalance end time={time_end - time_start:.3f}s") diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index ba179e95c..266b79283 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -95,6 +95,8 @@ def update_expert_weights_single_layer( tensor.shape[0] == num_local_physical_experts for tensor in routed_experts_weights ), f"{num_local_physical_experts=} {[x.shape for x in routed_experts_weights]=}" + assert isinstance(old_physical_to_logical_map, list) + assert isinstance(new_physical_to_logical_map, list) output_logs = [] if debug else None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 1ac6371a4..4b4c61a23 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -51,6 +51,7 @@ from sglang.srt.layers.quantization.deep_gemm import ( from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager +from sglang.srt.managers.eplb_manager import EPLBManager from sglang.srt.managers.expert_distribution import ( ExpertDistributionRecorder, get_global_expert_distribution_recorder, @@ -255,6 +256,12 @@ class ModelRunner: ) ) + self.eplb_manager = ( + EPLBManager(self) + if self.server_args.enable_eplb and (not self.is_draft_worker) + else None + ) + # Load the model self.sampler = Sampler() self.load_model() @@ -1152,10 +1159,15 @@ class ModelRunner: self.forward_pass_id, forward_batch, ): - return self._forward_raw( + output = self._forward_raw( forward_batch, skip_attn_backend_init, pp_proxy_tensors ) + if self.eplb_manager is not None: + self.eplb_manager.on_forward_pass_end(self.forward_pass_id) + + return output + def _forward_raw( self, forward_batch: ForwardBatch, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 23be9bca5..9402128df 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -173,6 +173,8 @@ class ServerArgs: ep_num_redundant_experts: int = 0 ep_dispatch_algorithm: Optional[Literal["static", "dynamic"]] = None init_expert_location: str = "trivial" + enable_eplb: bool = False + eplb_rebalance_num_iterations: int = 1000 expert_distribution_recorder_mode: Optional[ Literal["stat", "per_pass", "per_token"] ] = None @@ -1293,6 +1295,17 @@ class ServerArgs: default=ServerArgs.init_expert_location, help="Initial location of EP experts.", ) + parser.add_argument( + "--enable-eplb", + action="store_true", + help="Enable EPLB algorithm", + ) + parser.add_argument( + "--eplb-rebalance-num-iterations", + type=int, + default=ServerArgs.eplb_rebalance_num_iterations, + help="Number of iterations to automatically trigger a EPLB re-balance.", + ) parser.add_argument( "--expert-distribution-recorder-mode", type=str, diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py new file mode 100755 index 000000000..f9c6fad20 --- /dev/null +++ b/test/srt/test_eplb.py @@ -0,0 +1,141 @@ +import os +import tempfile +import unittest +from pathlib import Path +from types import SimpleNamespace + +import sglang as sgl +from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestDynamicEPLB(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--dp", + "2", + "--enable-dp-attention", + "--enable-deepep-moe", + "--deepep-mode", + "normal", + "--disable-cuda-graph", + "--enable-eplb", + "--ep-num-redundant-experts", + "4", + "--eplb-rebalance-num-iterations", + "50", + "--expert-distribution-recorder-buffer-size", + "50", + # TODO pr-chain: enable later + # "--enable-expert-distribution-metrics", + # TODO auto determine these flags + "--expert-distribution-recorder-mode", + "stat", + "--ep-dispatch-algorithm", + "static", + ], + env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ}, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreater(metrics["score"], 0.5) + + +class TestStaticEPLB(CustomTestCase): + def test_save_expert_distribution_and_init_expert_location(self): + os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "0" + + with tempfile.TemporaryDirectory() as tmp_dir: + engine_kwargs = dict( + model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST, + trust_remote_code=True, + ep_num_redundant_experts=4, + enable_dp_attention=True, + enable_deepep_moe=True, + deepep_mode="normal", + disable_cuda_graph=True, + expert_distribution_recorder_mode="stat", + tp_size=2, + dp_size=2, + log_level="info", + # TODO pr-chain: enable later + # enable_expert_distribution_metrics=True, + ) + + print(f"Action: start engine") + os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir + engine = sgl.Engine( + **engine_kwargs, + disable_overlap_schedule=True, + ) + engine.start_expert_distribution_record() + self._assert_engine_generate_correct(engine) + + print(f"Action: dump_expert_distribution_record") + engine.dump_expert_distribution_record() + snapshot_path = list(Path(tmp_dir).glob("*.pt"))[0] + assert snapshot_path is not None + print(f"{snapshot_path=}") + + print(f"Action: shutdown engine") + engine.shutdown() + del engine + + print(f"Action: start engine with init_expert_location") + engine = sgl.Engine( + **engine_kwargs, + init_expert_location=str(snapshot_path), + port=21000, + # TODO auto determine these flags + ep_dispatch_algorithm="static", + ) + self._assert_engine_generate_correct(engine) + print(f"Action: shutdown engine") + engine.shutdown() + del engine + + def _assert_engine_generate_correct(self, engine: sgl.Engine): + output = engine.generate( + prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"], + sampling_params=dict(max_new_tokens=8, temperature=0.0), + ) + print(f"engine.generate {output=}") + self.assertEqual( + [x["text"] for x in output], + [", 4+4=8,", ", four plus four is eight, eight"], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_expert_location_updater.py b/test/srt/test_expert_location_updater.py index 49f98fd53..2b1249b1b 100644 --- a/test/srt/test_expert_location_updater.py +++ b/test/srt/test_expert_location_updater.py @@ -210,8 +210,8 @@ def _execute_test(info: _TestInfo, rank: int, num_gpus: int, device: str): temp_buffers=expert_location_updater.create_temp_buffers( routed_experts_weights ), - old_physical_to_logical_map=physical_to_logical_map, - new_physical_to_logical_map=new_physical_to_logical_map, + old_physical_to_logical_map=physical_to_logical_map.tolist(), + new_physical_to_logical_map=new_physical_to_logical_map.tolist(), num_local_physical_experts=num_local_physical_experts, num_gpu_per_node=num_gpu_per_node, rank=rank,