Support dynamically rebalancing experts using EPLB (#6469)
This commit is contained in:
55
python/sglang/srt/managers/eplb_manager.py
Normal file
55
python/sglang/srt/managers/eplb_manager.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch.cuda
|
||||
|
||||
from sglang.srt.managers.expert_distribution import (
|
||||
get_global_expert_distribution_recorder,
|
||||
)
|
||||
from sglang.srt.managers.expert_location import ExpertLocationMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EPLBManager:
|
||||
def __init__(self, model_runner: "ModelRunner"):
|
||||
super().__init__()
|
||||
self._model_runner = model_runner
|
||||
self._server_args = model_runner.server_args
|
||||
|
||||
# Otherwise, the circular buffer will contain stale data. If the case is needed, it can be implemented.
|
||||
assert (
|
||||
self._server_args.eplb_rebalance_num_iterations
|
||||
<= self._server_args.expert_distribution_recorder_buffer_size
|
||||
), "eplb_rebalance_num_iterations must be less than expert_distribution_recorder_buffer_size"
|
||||
|
||||
get_global_expert_distribution_recorder().start_record()
|
||||
|
||||
logger.info(
|
||||
f"[EPLBManager] system started, will rebalance per {self._server_args.eplb_rebalance_num_iterations} iterations."
|
||||
)
|
||||
|
||||
def on_forward_pass_end(self, forward_pass_id: int):
|
||||
if forward_pass_id % self._server_args.eplb_rebalance_num_iterations == 0:
|
||||
self.rebalance()
|
||||
|
||||
def rebalance(self):
|
||||
logger.info("[EPLBManager] rebalance start")
|
||||
torch.cuda.synchronize()
|
||||
time_start = time.time()
|
||||
|
||||
logical_count = get_global_expert_distribution_recorder().dump_record(
|
||||
output_mode="object"
|
||||
)["logical_count"]
|
||||
expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
|
||||
self._server_args, self._model_runner.model_config, logical_count
|
||||
)
|
||||
self._model_runner.update_expert_location(expert_location_metadata)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
time_end = time.time()
|
||||
logger.info(f"[EPLBManager] rebalance end time={time_end - time_start:.3f}s")
|
||||
@@ -95,6 +95,8 @@ def update_expert_weights_single_layer(
|
||||
tensor.shape[0] == num_local_physical_experts
|
||||
for tensor in routed_experts_weights
|
||||
), f"{num_local_physical_experts=} {[x.shape for x in routed_experts_weights]=}"
|
||||
assert isinstance(old_physical_to_logical_map, list)
|
||||
assert isinstance(new_physical_to_logical_map, list)
|
||||
|
||||
output_logs = [] if debug else None
|
||||
|
||||
|
||||
@@ -51,6 +51,7 @@ from sglang.srt.layers.quantization.deep_gemm import (
|
||||
from sglang.srt.layers.sampler import Sampler
|
||||
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
|
||||
from sglang.srt.lora.lora_manager import LoRAManager
|
||||
from sglang.srt.managers.eplb_manager import EPLBManager
|
||||
from sglang.srt.managers.expert_distribution import (
|
||||
ExpertDistributionRecorder,
|
||||
get_global_expert_distribution_recorder,
|
||||
@@ -255,6 +256,12 @@ class ModelRunner:
|
||||
)
|
||||
)
|
||||
|
||||
self.eplb_manager = (
|
||||
EPLBManager(self)
|
||||
if self.server_args.enable_eplb and (not self.is_draft_worker)
|
||||
else None
|
||||
)
|
||||
|
||||
# Load the model
|
||||
self.sampler = Sampler()
|
||||
self.load_model()
|
||||
@@ -1152,10 +1159,15 @@ class ModelRunner:
|
||||
self.forward_pass_id,
|
||||
forward_batch,
|
||||
):
|
||||
return self._forward_raw(
|
||||
output = self._forward_raw(
|
||||
forward_batch, skip_attn_backend_init, pp_proxy_tensors
|
||||
)
|
||||
|
||||
if self.eplb_manager is not None:
|
||||
self.eplb_manager.on_forward_pass_end(self.forward_pass_id)
|
||||
|
||||
return output
|
||||
|
||||
def _forward_raw(
|
||||
self,
|
||||
forward_batch: ForwardBatch,
|
||||
|
||||
@@ -173,6 +173,8 @@ class ServerArgs:
|
||||
ep_num_redundant_experts: int = 0
|
||||
ep_dispatch_algorithm: Optional[Literal["static", "dynamic"]] = None
|
||||
init_expert_location: str = "trivial"
|
||||
enable_eplb: bool = False
|
||||
eplb_rebalance_num_iterations: int = 1000
|
||||
expert_distribution_recorder_mode: Optional[
|
||||
Literal["stat", "per_pass", "per_token"]
|
||||
] = None
|
||||
@@ -1293,6 +1295,17 @@ class ServerArgs:
|
||||
default=ServerArgs.init_expert_location,
|
||||
help="Initial location of EP experts.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-eplb",
|
||||
action="store_true",
|
||||
help="Enable EPLB algorithm",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eplb-rebalance-num-iterations",
|
||||
type=int,
|
||||
default=ServerArgs.eplb_rebalance_num_iterations,
|
||||
help="Number of iterations to automatically trigger a EPLB re-balance.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--expert-distribution-recorder-mode",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user