Support layerwise rebalancing experts (#6851)
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
|
|
||||||
@@ -20,6 +20,10 @@ class EPLBManager:
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self._model_runner = model_runner
|
self._model_runner = model_runner
|
||||||
self._server_args = model_runner.server_args
|
self._server_args = model_runner.server_args
|
||||||
|
self._rebalance_layers_per_chunk = (
|
||||||
|
self._server_args.eplb_rebalance_layers_per_chunk
|
||||||
|
)
|
||||||
|
self._rebalance_num_iterations = self._server_args.eplb_rebalance_num_iterations
|
||||||
|
|
||||||
# Otherwise, the circular buffer will contain stale data. If the case is needed, it can be implemented.
|
# Otherwise, the circular buffer will contain stale data. If the case is needed, it can be implemented.
|
||||||
assert (
|
assert (
|
||||||
@@ -31,17 +35,30 @@ class EPLBManager:
|
|||||||
get_global_expert_distribution_recorder().start_record()
|
get_global_expert_distribution_recorder().start_record()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[EPLBManager] system started, will rebalance per {self._server_args.eplb_rebalance_num_iterations} iterations."
|
f"[EPLBManager] system started, will rebalance per {self._rebalance_num_iterations} iterations."
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_forward_pass_end(self, forward_pass_id: int):
|
self._main_generator = self._entrypoint()
|
||||||
if forward_pass_id % self._server_args.eplb_rebalance_num_iterations == 0:
|
|
||||||
self.rebalance()
|
def on_forward_pass_end(self):
|
||||||
|
next(self._main_generator)
|
||||||
|
|
||||||
|
# can be more complex if needed
|
||||||
|
def _entrypoint(self):
|
||||||
|
while True:
|
||||||
|
for _ in range(self._rebalance_num_iterations):
|
||||||
|
yield
|
||||||
|
|
||||||
|
yield from self.rebalance()
|
||||||
|
|
||||||
def rebalance(self):
|
def rebalance(self):
|
||||||
logger.info("[EPLBManager] rebalance start")
|
logger.info("[EPLBManager] rebalance start")
|
||||||
torch.cuda.synchronize()
|
|
||||||
time_start = time.time()
|
enable_timing = self._rebalance_layers_per_chunk is None
|
||||||
|
|
||||||
|
if enable_timing:
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
time_start = time.time()
|
||||||
|
|
||||||
logical_count = get_global_expert_distribution_recorder().dump_record(
|
logical_count = get_global_expert_distribution_recorder().dump_record(
|
||||||
output_mode="object"
|
output_mode="object"
|
||||||
@@ -49,8 +66,31 @@ class EPLBManager:
|
|||||||
expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
|
expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
|
||||||
self._server_args, self._model_runner.model_config, logical_count
|
self._server_args, self._model_runner.model_config, logical_count
|
||||||
)
|
)
|
||||||
self._model_runner.update_expert_location(expert_location_metadata)
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
update_layer_ids_chunks = self._compute_update_layer_ids_chunks()
|
||||||
time_end = time.time()
|
for chunk_index, update_layer_ids in enumerate(update_layer_ids_chunks):
|
||||||
logger.info(f"[EPLBManager] rebalance end time={time_end - time_start:.3f}s")
|
if len(update_layer_ids_chunks) > 1:
|
||||||
|
yield
|
||||||
|
self._model_runner.update_expert_location(
|
||||||
|
expert_location_metadata,
|
||||||
|
update_layer_ids=update_layer_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = f"[EPLBManager] rebalance end"
|
||||||
|
if enable_timing:
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
time_end = time.time()
|
||||||
|
msg += f" time={time_end - time_start:.3f}s"
|
||||||
|
logger.info(msg)
|
||||||
|
|
||||||
|
def _compute_update_layer_ids_chunks(self) -> List[List[int]]:
|
||||||
|
all_layer_ids = sorted(
|
||||||
|
list(self._model_runner.model.routed_experts_weights_of_layer.keys())
|
||||||
|
)
|
||||||
|
chunk_size = self._rebalance_layers_per_chunk or 1000000
|
||||||
|
return list(_chunk_list(all_layer_ids, chunk_size=chunk_size))
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk_list(items: List, chunk_size):
|
||||||
|
for start_index in range(0, len(items), chunk_size):
|
||||||
|
yield items[start_index : start_index + chunk_size]
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ExpertLocationMetadata:
|
class ExpertLocationMetadata:
|
||||||
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
|
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
|
||||||
|
physical_to_logical_map_cpu: torch.Tensor
|
||||||
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
|
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
|
||||||
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
|
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
|
||||||
# (layers, num_logical_experts)
|
# (layers, num_logical_experts)
|
||||||
@@ -203,6 +204,7 @@ class ExpertLocationMetadata:
|
|||||||
|
|
||||||
return ExpertLocationMetadata(
|
return ExpertLocationMetadata(
|
||||||
physical_to_logical_map=physical_to_logical_map,
|
physical_to_logical_map=physical_to_logical_map,
|
||||||
|
physical_to_logical_map_cpu=physical_to_logical_map.cpu(),
|
||||||
logical_to_all_physical_map=logical_to_all_physical_map_padded,
|
logical_to_all_physical_map=logical_to_all_physical_map_padded,
|
||||||
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
||||||
logical_to_rank_dispatch_physical_map=(
|
logical_to_rank_dispatch_physical_map=(
|
||||||
@@ -223,6 +225,7 @@ class ExpertLocationMetadata:
|
|||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
other: "ExpertLocationMetadata",
|
other: "ExpertLocationMetadata",
|
||||||
|
update_layer_ids: List[int],
|
||||||
):
|
):
|
||||||
for field in [
|
for field in [
|
||||||
"ep_size",
|
"ep_size",
|
||||||
@@ -231,15 +234,21 @@ class ExpertLocationMetadata:
|
|||||||
|
|
||||||
for field in [
|
for field in [
|
||||||
"physical_to_logical_map",
|
"physical_to_logical_map",
|
||||||
|
"physical_to_logical_map_cpu",
|
||||||
"logical_to_all_physical_map",
|
"logical_to_all_physical_map",
|
||||||
"logical_to_all_physical_map_num_valid",
|
"logical_to_all_physical_map_num_valid",
|
||||||
"logical_to_rank_dispatch_physical_map",
|
"logical_to_rank_dispatch_physical_map",
|
||||||
]:
|
]:
|
||||||
src = getattr(other, field)
|
other_field = getattr(other, field)
|
||||||
dst = getattr(self, field)
|
self_field = getattr(self, field)
|
||||||
assert (src is not None) == (dst is not None)
|
assert (other_field is not None) == (self_field is not None)
|
||||||
if dst is not None:
|
if self_field is not None:
|
||||||
dst[...] = src
|
mask_update = torch.tensor(
|
||||||
|
[i in update_layer_ids for i in range(self.num_layers)]
|
||||||
|
)
|
||||||
|
mask_update = mask_update.view(*([-1] + [1] * (self_field.dim() - 1)))
|
||||||
|
mask_update = mask_update.to(self_field.device, non_blocking=True)
|
||||||
|
self_field[...] = torch.where(mask_update, other_field, self_field)
|
||||||
|
|
||||||
# -------------------------------- usage ------------------------------------
|
# -------------------------------- usage ------------------------------------
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from sglang.srt.managers.expert_location import (
|
|||||||
ExpertLocationMetadata,
|
ExpertLocationMetadata,
|
||||||
get_global_expert_location_metadata,
|
get_global_expert_location_metadata,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.utils import get_bool_env_var
|
from sglang.srt.utils import get_bool_env_var
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -37,6 +38,7 @@ class ExpertLocationUpdater:
|
|||||||
self,
|
self,
|
||||||
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
|
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
|
||||||
new_expert_location_metadata: ExpertLocationMetadata,
|
new_expert_location_metadata: ExpertLocationMetadata,
|
||||||
|
update_layer_ids: List[int],
|
||||||
nnodes: int,
|
nnodes: int,
|
||||||
rank: int,
|
rank: int,
|
||||||
):
|
):
|
||||||
@@ -46,45 +48,47 @@ class ExpertLocationUpdater:
|
|||||||
|
|
||||||
old_expert_location_metadata = get_global_expert_location_metadata()
|
old_expert_location_metadata = get_global_expert_location_metadata()
|
||||||
_update_expert_weights(
|
_update_expert_weights(
|
||||||
routed_experts_weights_of_layer,
|
routed_experts_weights_of_layer=routed_experts_weights_of_layer,
|
||||||
old_expert_location_metadata,
|
old_expert_location_metadata=old_expert_location_metadata,
|
||||||
new_expert_location_metadata,
|
new_expert_location_metadata=new_expert_location_metadata,
|
||||||
nnodes,
|
update_layer_ids=update_layer_ids,
|
||||||
rank,
|
nnodes=nnodes,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
old_expert_location_metadata.update(
|
||||||
|
new_expert_location_metadata,
|
||||||
|
update_layer_ids=update_layer_ids,
|
||||||
)
|
)
|
||||||
old_expert_location_metadata.update(new_expert_location_metadata)
|
|
||||||
|
|
||||||
|
|
||||||
def _update_expert_weights(
|
def _update_expert_weights(
|
||||||
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
|
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
|
||||||
old_expert_location_metadata: ExpertLocationMetadata,
|
old_expert_location_metadata: ExpertLocationMetadata,
|
||||||
new_expert_location_metadata: ExpertLocationMetadata,
|
new_expert_location_metadata: ExpertLocationMetadata,
|
||||||
|
update_layer_ids: List[int],
|
||||||
nnodes: int,
|
nnodes: int,
|
||||||
rank: int,
|
rank: int,
|
||||||
):
|
):
|
||||||
log_metrics = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS")
|
log_metrics = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS")
|
||||||
|
|
||||||
temp_buffers = create_temp_buffers(
|
temp_buffers = create_temp_buffers(
|
||||||
next(iter(routed_experts_weights_of_layer.values()))
|
routed_experts_weights_of_layer[update_layer_ids[0]]
|
||||||
)
|
)
|
||||||
|
|
||||||
world_size = torch.distributed.get_world_size()
|
world_size = torch.distributed.get_world_size()
|
||||||
num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts
|
num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts
|
||||||
num_gpu_per_node = world_size // nnodes
|
num_gpu_per_node = world_size // nnodes
|
||||||
|
|
||||||
old_physical_to_logical_map = (
|
for layer_id in update_layer_ids:
|
||||||
old_expert_location_metadata.physical_to_logical_map.tolist()
|
|
||||||
)
|
|
||||||
new_physical_to_logical_map = (
|
|
||||||
new_expert_location_metadata.physical_to_logical_map.tolist()
|
|
||||||
)
|
|
||||||
|
|
||||||
for layer_id in sorted(routed_experts_weights_of_layer.keys()):
|
|
||||||
update_expert_weights_single_layer(
|
update_expert_weights_single_layer(
|
||||||
routed_experts_weights=routed_experts_weights_of_layer[layer_id],
|
routed_experts_weights=routed_experts_weights_of_layer[layer_id],
|
||||||
temp_buffers=temp_buffers,
|
temp_buffers=temp_buffers,
|
||||||
old_physical_to_logical_map=old_physical_to_logical_map[layer_id],
|
old_physical_to_logical_map=old_expert_location_metadata.physical_to_logical_map_cpu[
|
||||||
new_physical_to_logical_map=new_physical_to_logical_map[layer_id],
|
layer_id
|
||||||
|
].tolist(),
|
||||||
|
new_physical_to_logical_map=new_expert_location_metadata.physical_to_logical_map_cpu[
|
||||||
|
layer_id
|
||||||
|
].tolist(),
|
||||||
num_local_physical_experts=num_local_physical_experts,
|
num_local_physical_experts=num_local_physical_experts,
|
||||||
num_gpu_per_node=num_gpu_per_node,
|
num_gpu_per_node=num_gpu_per_node,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
|
|||||||
@@ -611,11 +611,14 @@ class ModelRunner:
|
|||||||
) from None
|
) from None
|
||||||
|
|
||||||
def update_expert_location(
|
def update_expert_location(
|
||||||
self, new_expert_location_metadata: ExpertLocationMetadata
|
self,
|
||||||
|
new_expert_location_metadata: ExpertLocationMetadata,
|
||||||
|
update_layer_ids: List[int],
|
||||||
):
|
):
|
||||||
self.expert_location_updater.update(
|
self.expert_location_updater.update(
|
||||||
self.model.routed_experts_weights_of_layer,
|
self.model.routed_experts_weights_of_layer,
|
||||||
new_expert_location_metadata,
|
new_expert_location_metadata,
|
||||||
|
update_layer_ids=update_layer_ids,
|
||||||
nnodes=self.server_args.nnodes,
|
nnodes=self.server_args.nnodes,
|
||||||
rank=self.tp_rank,
|
rank=self.tp_rank,
|
||||||
)
|
)
|
||||||
@@ -1203,7 +1206,7 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.eplb_manager is not None:
|
if self.eplb_manager is not None:
|
||||||
self.eplb_manager.on_forward_pass_end(self.forward_pass_id)
|
self.eplb_manager.on_forward_pass_end()
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
@@ -180,6 +180,7 @@ class ServerArgs:
|
|||||||
enable_eplb: bool = False
|
enable_eplb: bool = False
|
||||||
eplb_algorithm: str = "auto"
|
eplb_algorithm: str = "auto"
|
||||||
eplb_rebalance_num_iterations: int = 1000
|
eplb_rebalance_num_iterations: int = 1000
|
||||||
|
eplb_rebalance_layers_per_chunk: Optional[int] = None
|
||||||
expert_distribution_recorder_mode: Optional[
|
expert_distribution_recorder_mode: Optional[
|
||||||
Literal["stat", "per_pass", "per_token"]
|
Literal["stat", "per_pass", "per_token"]
|
||||||
] = None
|
] = None
|
||||||
@@ -1367,6 +1368,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.eplb_rebalance_num_iterations,
|
default=ServerArgs.eplb_rebalance_num_iterations,
|
||||||
help="Number of iterations to automatically trigger a EPLB re-balance.",
|
help="Number of iterations to automatically trigger a EPLB re-balance.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--eplb-rebalance-layers-per-chunk",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.eplb_rebalance_layers_per_chunk,
|
||||||
|
help="Number of layers to rebalance per forward pass.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--expert-distribution-recorder-mode",
|
"--expert-distribution-recorder-mode",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from pathlib import Path
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage
|
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
from sglang.test.run_eval import run_eval
|
from sglang.test.run_eval import run_eval
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
@@ -17,7 +16,9 @@ from sglang.test.test_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestDynamicEPLB(CustomTestCase):
|
class _BaseTestDynamicEPLB(CustomTestCase):
|
||||||
|
extra_args = []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
|
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
|
||||||
@@ -51,8 +52,13 @@ class TestDynamicEPLB(CustomTestCase):
|
|||||||
"stat",
|
"stat",
|
||||||
"--ep-dispatch-algorithm",
|
"--ep-dispatch-algorithm",
|
||||||
"static",
|
"static",
|
||||||
|
*cls.extra_args,
|
||||||
],
|
],
|
||||||
env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
|
env={
|
||||||
|
"SGL_ENABLE_JIT_DEEPGEMM": "0",
|
||||||
|
"SGLANG_EXPERT_LOCATION_UPDATER_CANARY": "1",
|
||||||
|
**os.environ,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -72,6 +78,14 @@ class TestDynamicEPLB(CustomTestCase):
|
|||||||
self.assertGreater(metrics["score"], 0.5)
|
self.assertGreater(metrics["score"], 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDynamicEPLBSimple(_BaseTestDynamicEPLB):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestDynamicEPLBMultiChunk(_BaseTestDynamicEPLB):
|
||||||
|
extra_args = ["--eplb-rebalance-layers-per-chunk", "1"]
|
||||||
|
|
||||||
|
|
||||||
class TestStaticEPLB(CustomTestCase):
|
class TestStaticEPLB(CustomTestCase):
|
||||||
def test_save_expert_distribution_and_init_expert_location(self):
|
def test_save_expert_distribution_and_init_expert_location(self):
|
||||||
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "0"
|
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "0"
|
||||||
|
|||||||
Reference in New Issue
Block a user