Add canary for EPLB rebalancing (#6895)
This commit is contained in:
@@ -61,7 +61,62 @@ class ExpertLocationUpdater:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _update_expert_weights(
|
def _update_expert_weights(**kwargs):
    """Forward ``kwargs`` to the canary-checked or the raw weight updater.

    The canary-checked variant is chosen when the boolean environment
    variable ``SGLANG_EXPERT_LOCATION_UPDATER_CANARY`` is enabled; otherwise
    the raw updater is called directly. The selected function's return value
    is passed through unchanged.
    """
    canary_enabled = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_CANARY")
    updater = (
        _update_expert_weights_with_canary
        if canary_enabled
        else _update_expert_weights_raw
    )
    return updater(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# can add watchdog as well
|
||||||
|
def _update_expert_weights_with_canary(
    routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
    old_expert_location_metadata: ExpertLocationMetadata,
    new_expert_location_metadata: ExpertLocationMetadata,
    update_layer_ids: List[int],
    nnodes: int,
    rank: int,
):
    """Run the raw expert-weight update with an extra self-checking tensor.

    For every layer in ``update_layer_ids`` a "canary" tensor is appended to
    that layer's weight list. The canary is this rank's slice of the OLD
    ``physical_to_logical_map_cpu`` row for the layer. After
    ``_update_expert_weights_raw`` has run, the canary is compared against
    the same slice of the NEW map.

    NOTE(review): this presumes ``_update_expert_weights_raw`` relocates the
    entries of every tensor in the list, in place, according to the
    old -> new expert placement -- confirm against that function.

    Args:
        routed_experts_weights_of_layer: Per-layer lists of expert weight
            tensors. The lists are shallow-copied before the canary is
            appended, so the caller's lists do not grow.
        old_expert_location_metadata: Placement before the update.
        new_expert_location_metadata: Placement after the update.
        update_layer_ids: Layers to update and verify.
        nnodes: Forwarded unchanged to ``_update_expert_weights_raw``.
        rank: This process's rank; selects the local slice of the map.

    Raises:
        AssertionError: If, for any updated layer, the canary does not equal
            the expected slice of the new map. Raised explicitly rather than
            via ``assert`` so the check still runs under ``python -O``.
    """
    num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts

    def _get_canary_value(meta: ExpertLocationMetadata, layer_id: int):
        # Row `layer_id`, columns belonging to this rank's local physical experts.
        return meta.physical_to_logical_map_cpu[
            layer_id,
            num_local_physical_experts * rank : num_local_physical_experts * (rank + 1),
        ]

    # Shallow-copy each list so appending the canary does not mutate the
    # caller's lists; the weight tensors themselves are shared, not copied.
    routed_experts_weights_of_layer = {
        k: list(v) for k, v in routed_experts_weights_of_layer.items()
    }
    for layer_id in update_layer_ids:
        # clone() so the metadata's own CPU map is never handed to the updater.
        canary_tensor = (
            _get_canary_value(old_expert_location_metadata, layer_id)
            .clone()
            .to(device=global_server_args_dict["device"], non_blocking=True)
        )
        routed_experts_weights_of_layer[layer_id].append(canary_tensor)

    _update_expert_weights_raw(
        routed_experts_weights_of_layer=routed_experts_weights_of_layer,
        old_expert_location_metadata=old_expert_location_metadata,
        new_expert_location_metadata=new_expert_location_metadata,
        update_layer_ids=update_layer_ids,
        nnodes=nnodes,
        rank=rank,
    )

    for layer_id in update_layer_ids:
        # can optimize speed if needed
        expect_value = _get_canary_value(new_expert_location_metadata, layer_id)
        # The canary was appended last, so it is the final entry of the list.
        actual_value = routed_experts_weights_of_layer[layer_id][-1].cpu()
        if not torch.all(expect_value == actual_value):
            # Explicit raise (not `assert`) so the check survives `python -O`.
            raise AssertionError(
                f"{expect_value=} {actual_value=} {layer_id=} "
                f"{old_expert_location_metadata.physical_to_logical_map_cpu.tolist()=} "
                f"{new_expert_location_metadata.physical_to_logical_map_cpu.tolist()=} "
            )
|
||||||
|
|
||||||
|
|
||||||
|
def _update_expert_weights_raw(
|
||||||
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
|
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
|
||||||
old_expert_location_metadata: ExpertLocationMetadata,
|
old_expert_location_metadata: ExpertLocationMetadata,
|
||||||
new_expert_location_metadata: ExpertLocationMetadata,
|
new_expert_location_metadata: ExpertLocationMetadata,
|
||||||
|
|||||||
Reference in New Issue
Block a user