Files
sglang/python/sglang/srt/managers/eplb_manager.py

57 lines
2.0 KiB
Python

import logging
import time
from typing import TYPE_CHECKING
import torch.cuda
from sglang.srt.managers.expert_distribution import (
get_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import ExpertLocationMetadata
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
logger = logging.getLogger(__name__)
class EPLBManager:
    """Drives periodic EPLB (expert-parallel load balancing) rebalancing.

    Ensures the global expert-distribution recorder is running, and every
    ``eplb_rebalance_num_iterations`` forward passes recomputes expert
    placement from the recorded per-expert logical counts, pushing the new
    location metadata into the model runner.
    """

    def __init__(self, model_runner: "ModelRunner"):
        super().__init__()
        self._model_runner = model_runner
        self._server_args = model_runner.server_args

        # The rebalance interval must cover at least one full buffer of
        # recorder data. Otherwise, the circular buffer will contain stale
        # data. If that case is needed, it can be implemented.
        assert (
            self._server_args.eplb_rebalance_num_iterations
            >= self._server_args.expert_distribution_recorder_buffer_size
        ), "eplb_rebalance_num_iterations must be greater than or equal to expert_distribution_recorder_buffer_size"

        # Start recording expert distribution if nothing else has already.
        if not get_global_expert_distribution_recorder().recording:
            get_global_expert_distribution_recorder().start_record()

        logger.info(
            f"[EPLBManager] system started, will rebalance per {self._server_args.eplb_rebalance_num_iterations} iterations."
        )

    def on_forward_pass_end(self, forward_pass_id: int):
        """Trigger a rebalance once per ``eplb_rebalance_num_iterations`` passes.

        Called after every forward pass; fires when ``forward_pass_id`` is an
        exact multiple of the configured interval.
        """
        if forward_pass_id % self._server_args.eplb_rebalance_num_iterations == 0:
            self.rebalance()

    def rebalance(self):
        """Recompute expert placement from recorded counts and apply it.

        Dumps the recorder's ``logical_count``, derives new
        ``ExpertLocationMetadata`` via the EPLB algorithm, and updates the
        model runner. CUDA is synchronized on both sides so the logged wall
        time reflects the full (GPU-inclusive) cost of the update.
        """
        logger.info("[EPLBManager] rebalance start")
        torch.cuda.synchronize()
        time_start = time.time()

        logical_count = get_global_expert_distribution_recorder().dump_record(
            output_mode="object"
        )["logical_count"]
        expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
            self._server_args, self._model_runner.model_config, logical_count
        )
        self._model_runner.update_expert_location(expert_location_metadata)

        torch.cuda.synchronize()
        time_end = time.time()
        logger.info(f"[EPLBManager] rebalance end time={time_end - time_start:.3f}s")