Support dynamically rebalancing experts using EPLB (#6469)
This commit is contained in:
55
python/sglang/srt/managers/eplb_manager.py
Normal file
55
python/sglang/srt/managers/eplb_manager.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch.cuda
|
||||
|
||||
from sglang.srt.managers.expert_distribution import (
|
||||
get_global_expert_distribution_recorder,
|
||||
)
|
||||
from sglang.srt.managers.expert_location import ExpertLocationMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EPLBManager:
|
||||
def __init__(self, model_runner: "ModelRunner"):
|
||||
super().__init__()
|
||||
self._model_runner = model_runner
|
||||
self._server_args = model_runner.server_args
|
||||
|
||||
# Otherwise, the circular buffer will contain stale data. If the case is needed, it can be implemented.
|
||||
assert (
|
||||
self._server_args.eplb_rebalance_num_iterations
|
||||
<= self._server_args.expert_distribution_recorder_buffer_size
|
||||
), "eplb_rebalance_num_iterations must be less than expert_distribution_recorder_buffer_size"
|
||||
|
||||
get_global_expert_distribution_recorder().start_record()
|
||||
|
||||
logger.info(
|
||||
f"[EPLBManager] system started, will rebalance per {self._server_args.eplb_rebalance_num_iterations} iterations."
|
||||
)
|
||||
|
||||
def on_forward_pass_end(self, forward_pass_id: int):
|
||||
if forward_pass_id % self._server_args.eplb_rebalance_num_iterations == 0:
|
||||
self.rebalance()
|
||||
|
||||
def rebalance(self):
|
||||
logger.info("[EPLBManager] rebalance start")
|
||||
torch.cuda.synchronize()
|
||||
time_start = time.time()
|
||||
|
||||
logical_count = get_global_expert_distribution_recorder().dump_record(
|
||||
output_mode="object"
|
||||
)["logical_count"]
|
||||
expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
|
||||
self._server_args, self._model_runner.model_config, logical_count
|
||||
)
|
||||
self._model_runner.update_expert_location(expert_location_metadata)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
time_end = time.time()
|
||||
logger.info(f"[EPLBManager] rebalance end time={time_end - time_start:.3f}s")
|
||||
@@ -95,6 +95,8 @@ def update_expert_weights_single_layer(
|
||||
tensor.shape[0] == num_local_physical_experts
|
||||
for tensor in routed_experts_weights
|
||||
), f"{num_local_physical_experts=} {[x.shape for x in routed_experts_weights]=}"
|
||||
assert isinstance(old_physical_to_logical_map, list)
|
||||
assert isinstance(new_physical_to_logical_map, list)
|
||||
|
||||
output_logs = [] if debug else None
|
||||
|
||||
|
||||
@@ -51,6 +51,7 @@ from sglang.srt.layers.quantization.deep_gemm import (
|
||||
from sglang.srt.layers.sampler import Sampler
|
||||
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
|
||||
from sglang.srt.lora.lora_manager import LoRAManager
|
||||
from sglang.srt.managers.eplb_manager import EPLBManager
|
||||
from sglang.srt.managers.expert_distribution import (
|
||||
ExpertDistributionRecorder,
|
||||
get_global_expert_distribution_recorder,
|
||||
@@ -255,6 +256,12 @@ class ModelRunner:
|
||||
)
|
||||
)
|
||||
|
||||
self.eplb_manager = (
|
||||
EPLBManager(self)
|
||||
if self.server_args.enable_eplb and (not self.is_draft_worker)
|
||||
else None
|
||||
)
|
||||
|
||||
# Load the model
|
||||
self.sampler = Sampler()
|
||||
self.load_model()
|
||||
@@ -1152,10 +1159,15 @@ class ModelRunner:
|
||||
self.forward_pass_id,
|
||||
forward_batch,
|
||||
):
|
||||
return self._forward_raw(
|
||||
output = self._forward_raw(
|
||||
forward_batch, skip_attn_backend_init, pp_proxy_tensors
|
||||
)
|
||||
|
||||
if self.eplb_manager is not None:
|
||||
self.eplb_manager.on_forward_pass_end(self.forward_pass_id)
|
||||
|
||||
return output
|
||||
|
||||
def _forward_raw(
|
||||
self,
|
||||
forward_batch: ForwardBatch,
|
||||
|
||||
@@ -173,6 +173,8 @@ class ServerArgs:
|
||||
ep_num_redundant_experts: int = 0
|
||||
ep_dispatch_algorithm: Optional[Literal["static", "dynamic"]] = None
|
||||
init_expert_location: str = "trivial"
|
||||
enable_eplb: bool = False
|
||||
eplb_rebalance_num_iterations: int = 1000
|
||||
expert_distribution_recorder_mode: Optional[
|
||||
Literal["stat", "per_pass", "per_token"]
|
||||
] = None
|
||||
@@ -1293,6 +1295,17 @@ class ServerArgs:
|
||||
default=ServerArgs.init_expert_location,
|
||||
help="Initial location of EP experts.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-eplb",
|
||||
action="store_true",
|
||||
help="Enable EPLB algorithm",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eplb-rebalance-num-iterations",
|
||||
type=int,
|
||||
default=ServerArgs.eplb_rebalance_num_iterations,
|
||||
help="Number of iterations to automatically trigger a EPLB re-balance.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--expert-distribution-recorder-mode",
|
||||
type=str,
|
||||
|
||||
141
test/srt/test_eplb.py
Executable file
141
test/srt/test_eplb.py
Executable file
@@ -0,0 +1,141 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestDynamicEPLB(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp",
|
||||
"2",
|
||||
"--dp",
|
||||
"2",
|
||||
"--enable-dp-attention",
|
||||
"--enable-deepep-moe",
|
||||
"--deepep-mode",
|
||||
"normal",
|
||||
"--disable-cuda-graph",
|
||||
"--enable-eplb",
|
||||
"--ep-num-redundant-experts",
|
||||
"4",
|
||||
"--eplb-rebalance-num-iterations",
|
||||
"50",
|
||||
"--expert-distribution-recorder-buffer-size",
|
||||
"50",
|
||||
# TODO pr-chain: enable later
|
||||
# "--enable-expert-distribution-metrics",
|
||||
# TODO auto determine these flags
|
||||
"--expert-distribution-recorder-mode",
|
||||
"stat",
|
||||
"--ep-dispatch-algorithm",
|
||||
"static",
|
||||
],
|
||||
env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreater(metrics["score"], 0.5)
|
||||
|
||||
|
||||
class TestStaticEPLB(CustomTestCase):
|
||||
def test_save_expert_distribution_and_init_expert_location(self):
|
||||
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "0"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
engine_kwargs = dict(
|
||||
model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||
trust_remote_code=True,
|
||||
ep_num_redundant_experts=4,
|
||||
enable_dp_attention=True,
|
||||
enable_deepep_moe=True,
|
||||
deepep_mode="normal",
|
||||
disable_cuda_graph=True,
|
||||
expert_distribution_recorder_mode="stat",
|
||||
tp_size=2,
|
||||
dp_size=2,
|
||||
log_level="info",
|
||||
# TODO pr-chain: enable later
|
||||
# enable_expert_distribution_metrics=True,
|
||||
)
|
||||
|
||||
print(f"Action: start engine")
|
||||
os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir
|
||||
engine = sgl.Engine(
|
||||
**engine_kwargs,
|
||||
disable_overlap_schedule=True,
|
||||
)
|
||||
engine.start_expert_distribution_record()
|
||||
self._assert_engine_generate_correct(engine)
|
||||
|
||||
print(f"Action: dump_expert_distribution_record")
|
||||
engine.dump_expert_distribution_record()
|
||||
snapshot_path = list(Path(tmp_dir).glob("*.pt"))[0]
|
||||
assert snapshot_path is not None
|
||||
print(f"{snapshot_path=}")
|
||||
|
||||
print(f"Action: shutdown engine")
|
||||
engine.shutdown()
|
||||
del engine
|
||||
|
||||
print(f"Action: start engine with init_expert_location")
|
||||
engine = sgl.Engine(
|
||||
**engine_kwargs,
|
||||
init_expert_location=str(snapshot_path),
|
||||
port=21000,
|
||||
# TODO auto determine these flags
|
||||
ep_dispatch_algorithm="static",
|
||||
)
|
||||
self._assert_engine_generate_correct(engine)
|
||||
print(f"Action: shutdown engine")
|
||||
engine.shutdown()
|
||||
del engine
|
||||
|
||||
def _assert_engine_generate_correct(self, engine: sgl.Engine):
|
||||
output = engine.generate(
|
||||
prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"],
|
||||
sampling_params=dict(max_new_tokens=8, temperature=0.0),
|
||||
)
|
||||
print(f"engine.generate {output=}")
|
||||
self.assertEqual(
|
||||
[x["text"] for x in output],
|
||||
[", 4+4=8,", ", four plus four is eight, eight"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -210,8 +210,8 @@ def _execute_test(info: _TestInfo, rank: int, num_gpus: int, device: str):
|
||||
temp_buffers=expert_location_updater.create_temp_buffers(
|
||||
routed_experts_weights
|
||||
),
|
||||
old_physical_to_logical_map=physical_to_logical_map,
|
||||
new_physical_to_logical_map=new_physical_to_logical_map,
|
||||
old_physical_to_logical_map=physical_to_logical_map.tolist(),
|
||||
new_physical_to_logical_map=new_physical_to_logical_map.tolist(),
|
||||
num_local_physical_experts=num_local_physical_experts,
|
||||
num_gpu_per_node=num_gpu_per_node,
|
||||
rank=rank,
|
||||
|
||||
Reference in New Issue
Block a user