diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index 52cc13524..75fd7c2d7 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -61,7 +61,62 @@ class ExpertLocationUpdater: ) -def _update_expert_weights( +def _update_expert_weights(**kwargs): + if get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_CANARY"): + return _update_expert_weights_with_canary(**kwargs) + else: + return _update_expert_weights_raw(**kwargs) + + +# can add watchdog as well +def _update_expert_weights_with_canary( + routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]], + old_expert_location_metadata: ExpertLocationMetadata, + new_expert_location_metadata: ExpertLocationMetadata, + update_layer_ids: List[int], + nnodes: int, + rank: int, +): + num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts + + def _get_canary_value(meta: ExpertLocationMetadata, layer_id: int): + return meta.physical_to_logical_map_cpu[ + layer_id, + num_local_physical_experts * rank : num_local_physical_experts * (rank + 1), + ] + + routed_experts_weights_of_layer = { + k: [x for x in v] for k, v in routed_experts_weights_of_layer.items() + } + for layer_id in update_layer_ids: + canary_tensor = ( + _get_canary_value(old_expert_location_metadata, layer_id) + .clone() + .to(device=global_server_args_dict["device"], non_blocking=True) + ) + routed_experts_weights_of_layer[layer_id].append(canary_tensor) + + _update_expert_weights_raw( + routed_experts_weights_of_layer=routed_experts_weights_of_layer, + old_expert_location_metadata=old_expert_location_metadata, + new_expert_location_metadata=new_expert_location_metadata, + update_layer_ids=update_layer_ids, + nnodes=nnodes, + rank=rank, + ) + + for layer_id in update_layer_ids: + # can optimize speed if needed + expect_value = _get_canary_value(new_expert_location_metadata, layer_id) + actual_value = routed_experts_weights_of_layer[layer_id][-1].cpu() + assert torch.all(expect_value == actual_value), ( + f"{expect_value=} {actual_value=} {layer_id=} " + f"{old_expert_location_metadata.physical_to_logical_map_cpu.tolist()=} " + f"{new_expert_location_metadata.physical_to_logical_map_cpu.tolist()=} " + ) + + +def _update_expert_weights_raw( routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]], old_expert_location_metadata: ExpertLocationMetadata, new_expert_location_metadata: ExpertLocationMetadata,