From cbbb738371a183f4a1eace147c9614ae6c8a2037 Mon Sep 17 00:00:00 2001 From: Stefan He Date: Tue, 5 Aug 2025 22:09:52 -0700 Subject: [PATCH] [2/3] Optimize Slime Update Weights: Avoid GPU-to-CPU Device Sync when update expert weights (#8753) --- python/sglang/srt/eplb/expert_location.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/eplb/expert_location.py b/python/sglang/srt/eplb/expert_location.py index ef35ce7a6..be0e23653 100644 --- a/python/sglang/srt/eplb/expert_location.py +++ b/python/sglang/srt/eplb/expert_location.py @@ -35,6 +35,7 @@ class ExpertLocationMetadata: physical_to_logical_map: torch.Tensor # (layers, num_physical_experts) physical_to_logical_map_cpu: torch.Tensor logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X) + logical_to_all_physical_map_cpu: torch.Tensor # CPU copy for performance logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts) # (layers, num_logical_experts) logical_to_rank_dispatch_physical_map: Optional[torch.Tensor] @@ -221,6 +222,7 @@ class ExpertLocationMetadata: physical_to_logical_map=physical_to_logical_map, physical_to_logical_map_cpu=physical_to_logical_map.cpu(), logical_to_all_physical_map=logical_to_all_physical_map_padded, + logical_to_all_physical_map_cpu=logical_to_all_physical_map_padded.cpu(), logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid, logical_to_rank_dispatch_physical_map=( compute_logical_to_rank_dispatch_physical_map( @@ -251,6 +253,7 @@ class ExpertLocationMetadata: "physical_to_logical_map", "physical_to_logical_map_cpu", "logical_to_all_physical_map", + "logical_to_all_physical_map_cpu", "logical_to_all_physical_map_num_valid", "logical_to_rank_dispatch_physical_map", ]: @@ -270,9 +273,10 @@ class ExpertLocationMetadata: def logical_to_all_physical( self, layer_id: int, logical_expert_id: int ) -> List[int]: + # Use CPU copy to avoid GPU→CPU sync on every call, which is expensive in update weights scenario return [ physical_expert_id - for physical_expert_id in self.logical_to_all_physical_map[ + for physical_expert_id in self.logical_to_all_physical_map_cpu[ layer_id, logical_expert_id ].tolist() if physical_expert_id != -1