Support layerwise rebalancing experts (#6851)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import torch.cuda
|
||||
|
||||
@@ -20,6 +20,10 @@ class EPLBManager:
|
||||
super().__init__()
|
||||
self._model_runner = model_runner
|
||||
self._server_args = model_runner.server_args
|
||||
self._rebalance_layers_per_chunk = (
|
||||
self._server_args.eplb_rebalance_layers_per_chunk
|
||||
)
|
||||
self._rebalance_num_iterations = self._server_args.eplb_rebalance_num_iterations
|
||||
|
||||
# Otherwise, the circular buffer will contain stale data. If the case is needed, it can be implemented.
|
||||
assert (
|
||||
@@ -31,17 +35,30 @@ class EPLBManager:
|
||||
get_global_expert_distribution_recorder().start_record()
|
||||
|
||||
logger.info(
|
||||
f"[EPLBManager] system started, will rebalance per {self._server_args.eplb_rebalance_num_iterations} iterations."
|
||||
f"[EPLBManager] system started, will rebalance per {self._rebalance_num_iterations} iterations."
|
||||
)
|
||||
|
||||
def on_forward_pass_end(self, forward_pass_id: int):
|
||||
if forward_pass_id % self._server_args.eplb_rebalance_num_iterations == 0:
|
||||
self.rebalance()
|
||||
self._main_generator = self._entrypoint()
|
||||
|
||||
def on_forward_pass_end(self):
|
||||
next(self._main_generator)
|
||||
|
||||
# can be more complex if needed
|
||||
def _entrypoint(self):
|
||||
while True:
|
||||
for _ in range(self._rebalance_num_iterations):
|
||||
yield
|
||||
|
||||
yield from self.rebalance()
|
||||
|
||||
def rebalance(self):
|
||||
logger.info("[EPLBManager] rebalance start")
|
||||
torch.cuda.synchronize()
|
||||
time_start = time.time()
|
||||
|
||||
enable_timing = self._rebalance_layers_per_chunk is None
|
||||
|
||||
if enable_timing:
|
||||
torch.cuda.synchronize()
|
||||
time_start = time.time()
|
||||
|
||||
logical_count = get_global_expert_distribution_recorder().dump_record(
|
||||
output_mode="object"
|
||||
@@ -49,8 +66,31 @@ class EPLBManager:
|
||||
expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
|
||||
self._server_args, self._model_runner.model_config, logical_count
|
||||
)
|
||||
self._model_runner.update_expert_location(expert_location_metadata)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
time_end = time.time()
|
||||
logger.info(f"[EPLBManager] rebalance end time={time_end - time_start:.3f}s")
|
||||
update_layer_ids_chunks = self._compute_update_layer_ids_chunks()
|
||||
for chunk_index, update_layer_ids in enumerate(update_layer_ids_chunks):
|
||||
if len(update_layer_ids_chunks) > 1:
|
||||
yield
|
||||
self._model_runner.update_expert_location(
|
||||
expert_location_metadata,
|
||||
update_layer_ids=update_layer_ids,
|
||||
)
|
||||
|
||||
msg = f"[EPLBManager] rebalance end"
|
||||
if enable_timing:
|
||||
torch.cuda.synchronize()
|
||||
time_end = time.time()
|
||||
msg += f" time={time_end - time_start:.3f}s"
|
||||
logger.info(msg)
|
||||
|
||||
def _compute_update_layer_ids_chunks(self) -> List[List[int]]:
|
||||
all_layer_ids = sorted(
|
||||
list(self._model_runner.model.routed_experts_weights_of_layer.keys())
|
||||
)
|
||||
chunk_size = self._rebalance_layers_per_chunk or 1000000
|
||||
return list(_chunk_list(all_layer_ids, chunk_size=chunk_size))
|
||||
|
||||
|
||||
def _chunk_list(items: List, chunk_size):
|
||||
for start_index in range(0, len(items), chunk_size):
|
||||
yield items[start_index : start_index + chunk_size]
|
||||
|
||||
@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class ExpertLocationMetadata:
|
||||
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
|
||||
physical_to_logical_map_cpu: torch.Tensor
|
||||
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
|
||||
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
|
||||
# (layers, num_logical_experts)
|
||||
@@ -203,6 +204,7 @@ class ExpertLocationMetadata:
|
||||
|
||||
return ExpertLocationMetadata(
|
||||
physical_to_logical_map=physical_to_logical_map,
|
||||
physical_to_logical_map_cpu=physical_to_logical_map.cpu(),
|
||||
logical_to_all_physical_map=logical_to_all_physical_map_padded,
|
||||
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
||||
logical_to_rank_dispatch_physical_map=(
|
||||
@@ -223,6 +225,7 @@ class ExpertLocationMetadata:
|
||||
def update(
|
||||
self,
|
||||
other: "ExpertLocationMetadata",
|
||||
update_layer_ids: List[int],
|
||||
):
|
||||
for field in [
|
||||
"ep_size",
|
||||
@@ -231,15 +234,21 @@ class ExpertLocationMetadata:
|
||||
|
||||
for field in [
|
||||
"physical_to_logical_map",
|
||||
"physical_to_logical_map_cpu",
|
||||
"logical_to_all_physical_map",
|
||||
"logical_to_all_physical_map_num_valid",
|
||||
"logical_to_rank_dispatch_physical_map",
|
||||
]:
|
||||
src = getattr(other, field)
|
||||
dst = getattr(self, field)
|
||||
assert (src is not None) == (dst is not None)
|
||||
if dst is not None:
|
||||
dst[...] = src
|
||||
other_field = getattr(other, field)
|
||||
self_field = getattr(self, field)
|
||||
assert (other_field is not None) == (self_field is not None)
|
||||
if self_field is not None:
|
||||
mask_update = torch.tensor(
|
||||
[i in update_layer_ids for i in range(self.num_layers)]
|
||||
)
|
||||
mask_update = mask_update.view(*([-1] + [1] * (self_field.dim() - 1)))
|
||||
mask_update = mask_update.to(self_field.device, non_blocking=True)
|
||||
self_field[...] = torch.where(mask_update, other_field, self_field)
|
||||
|
||||
# -------------------------------- usage ------------------------------------
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from sglang.srt.managers.expert_location import (
|
||||
ExpertLocationMetadata,
|
||||
get_global_expert_location_metadata,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -37,6 +38,7 @@ class ExpertLocationUpdater:
|
||||
self,
|
||||
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
|
||||
new_expert_location_metadata: ExpertLocationMetadata,
|
||||
update_layer_ids: List[int],
|
||||
nnodes: int,
|
||||
rank: int,
|
||||
):
|
||||
@@ -46,45 +48,47 @@ class ExpertLocationUpdater:
|
||||
|
||||
old_expert_location_metadata = get_global_expert_location_metadata()
|
||||
_update_expert_weights(
|
||||
routed_experts_weights_of_layer,
|
||||
old_expert_location_metadata,
|
||||
new_expert_location_metadata,
|
||||
nnodes,
|
||||
rank,
|
||||
routed_experts_weights_of_layer=routed_experts_weights_of_layer,
|
||||
old_expert_location_metadata=old_expert_location_metadata,
|
||||
new_expert_location_metadata=new_expert_location_metadata,
|
||||
update_layer_ids=update_layer_ids,
|
||||
nnodes=nnodes,
|
||||
rank=rank,
|
||||
)
|
||||
old_expert_location_metadata.update(
|
||||
new_expert_location_metadata,
|
||||
update_layer_ids=update_layer_ids,
|
||||
)
|
||||
old_expert_location_metadata.update(new_expert_location_metadata)
|
||||
|
||||
|
||||
def _update_expert_weights(
|
||||
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
|
||||
old_expert_location_metadata: ExpertLocationMetadata,
|
||||
new_expert_location_metadata: ExpertLocationMetadata,
|
||||
update_layer_ids: List[int],
|
||||
nnodes: int,
|
||||
rank: int,
|
||||
):
|
||||
log_metrics = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS")
|
||||
|
||||
temp_buffers = create_temp_buffers(
|
||||
next(iter(routed_experts_weights_of_layer.values()))
|
||||
routed_experts_weights_of_layer[update_layer_ids[0]]
|
||||
)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts
|
||||
num_gpu_per_node = world_size // nnodes
|
||||
|
||||
old_physical_to_logical_map = (
|
||||
old_expert_location_metadata.physical_to_logical_map.tolist()
|
||||
)
|
||||
new_physical_to_logical_map = (
|
||||
new_expert_location_metadata.physical_to_logical_map.tolist()
|
||||
)
|
||||
|
||||
for layer_id in sorted(routed_experts_weights_of_layer.keys()):
|
||||
for layer_id in update_layer_ids:
|
||||
update_expert_weights_single_layer(
|
||||
routed_experts_weights=routed_experts_weights_of_layer[layer_id],
|
||||
temp_buffers=temp_buffers,
|
||||
old_physical_to_logical_map=old_physical_to_logical_map[layer_id],
|
||||
new_physical_to_logical_map=new_physical_to_logical_map[layer_id],
|
||||
old_physical_to_logical_map=old_expert_location_metadata.physical_to_logical_map_cpu[
|
||||
layer_id
|
||||
].tolist(),
|
||||
new_physical_to_logical_map=new_expert_location_metadata.physical_to_logical_map_cpu[
|
||||
layer_id
|
||||
].tolist(),
|
||||
num_local_physical_experts=num_local_physical_experts,
|
||||
num_gpu_per_node=num_gpu_per_node,
|
||||
rank=rank,
|
||||
|
||||
@@ -611,11 +611,14 @@ class ModelRunner:
|
||||
) from None
|
||||
|
||||
def update_expert_location(
|
||||
self, new_expert_location_metadata: ExpertLocationMetadata
|
||||
self,
|
||||
new_expert_location_metadata: ExpertLocationMetadata,
|
||||
update_layer_ids: List[int],
|
||||
):
|
||||
self.expert_location_updater.update(
|
||||
self.model.routed_experts_weights_of_layer,
|
||||
new_expert_location_metadata,
|
||||
update_layer_ids=update_layer_ids,
|
||||
nnodes=self.server_args.nnodes,
|
||||
rank=self.tp_rank,
|
||||
)
|
||||
@@ -1203,7 +1206,7 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
if self.eplb_manager is not None:
|
||||
self.eplb_manager.on_forward_pass_end(self.forward_pass_id)
|
||||
self.eplb_manager.on_forward_pass_end()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -180,6 +180,7 @@ class ServerArgs:
|
||||
enable_eplb: bool = False
|
||||
eplb_algorithm: str = "auto"
|
||||
eplb_rebalance_num_iterations: int = 1000
|
||||
eplb_rebalance_layers_per_chunk: Optional[int] = None
|
||||
expert_distribution_recorder_mode: Optional[
|
||||
Literal["stat", "per_pass", "per_token"]
|
||||
] = None
|
||||
@@ -1367,6 +1368,12 @@ class ServerArgs:
|
||||
default=ServerArgs.eplb_rebalance_num_iterations,
|
||||
help="Number of iterations to automatically trigger a EPLB re-balance.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eplb-rebalance-layers-per-chunk",
|
||||
type=int,
|
||||
default=ServerArgs.eplb_rebalance_layers_per_chunk,
|
||||
help="Number of layers to rebalance per forward pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--expert-distribution-recorder-mode",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user