Add canary for EPLB rebalancing (#6895)
This commit is contained in:
@@ -61,7 +61,62 @@ class ExpertLocationUpdater:
|
||||
)
|
||||
|
||||
|
||||
def _update_expert_weights(
|
||||
def _update_expert_weights(**kwargs):
|
||||
if get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_CANARY"):
|
||||
return _update_expert_weights_with_canary(**kwargs)
|
||||
else:
|
||||
return _update_expert_weights_raw(**kwargs)
|
||||
|
||||
|
||||
# can add watchdog as well
|
||||
def _update_expert_weights_with_canary(
|
||||
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
|
||||
old_expert_location_metadata: ExpertLocationMetadata,
|
||||
new_expert_location_metadata: ExpertLocationMetadata,
|
||||
update_layer_ids: List[int],
|
||||
nnodes: int,
|
||||
rank: int,
|
||||
):
|
||||
num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts
|
||||
|
||||
def _get_canary_value(meta: ExpertLocationMetadata, layer_id: int):
|
||||
return meta.physical_to_logical_map_cpu[
|
||||
layer_id,
|
||||
num_local_physical_experts * rank : num_local_physical_experts * (rank + 1),
|
||||
]
|
||||
|
||||
routed_experts_weights_of_layer = {
|
||||
k: [x for x in v] for k, v in routed_experts_weights_of_layer.items()
|
||||
}
|
||||
for layer_id in update_layer_ids:
|
||||
canary_tensor = (
|
||||
_get_canary_value(old_expert_location_metadata, layer_id)
|
||||
.clone()
|
||||
.to(device=global_server_args_dict["device"], non_blocking=True)
|
||||
)
|
||||
routed_experts_weights_of_layer[layer_id].append(canary_tensor)
|
||||
|
||||
_update_expert_weights_raw(
|
||||
routed_experts_weights_of_layer=routed_experts_weights_of_layer,
|
||||
old_expert_location_metadata=old_expert_location_metadata,
|
||||
new_expert_location_metadata=new_expert_location_metadata,
|
||||
update_layer_ids=update_layer_ids,
|
||||
nnodes=nnodes,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
for layer_id in update_layer_ids:
|
||||
# can optimize speed if needed
|
||||
expect_value = _get_canary_value(new_expert_location_metadata, layer_id)
|
||||
actual_value = routed_experts_weights_of_layer[layer_id][-1].cpu()
|
||||
assert torch.all(expect_value == actual_value), (
|
||||
f"{expect_value=} {actual_value=} {layer_id=} "
|
||||
f"{old_expert_location_metadata.physical_to_logical_map_cpu.tolist()=} "
|
||||
f"{new_expert_location_metadata.physical_to_logical_map_cpu.tolist()=} "
|
||||
)
|
||||
|
||||
|
||||
def _update_expert_weights_raw(
|
||||
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
|
||||
old_expert_location_metadata: ExpertLocationMetadata,
|
||||
new_expert_location_metadata: ExpertLocationMetadata,
|
||||
|
||||
Reference in New Issue
Block a user