[2/3] Optimize Slime Update Weights: Avoid GPU-to-CPU Device Sync when updating expert weights (#8753)
This commit is contained in:
@@ -35,6 +35,7 @@ class ExpertLocationMetadata:
|
|||||||
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
|
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
|
||||||
physical_to_logical_map_cpu: torch.Tensor
|
physical_to_logical_map_cpu: torch.Tensor
|
||||||
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
|
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
|
||||||
|
logical_to_all_physical_map_cpu: torch.Tensor # CPU copy for performance
|
||||||
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
|
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
|
||||||
# (layers, num_logical_experts)
|
# (layers, num_logical_experts)
|
||||||
logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
|
logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
|
||||||
@@ -221,6 +222,7 @@ class ExpertLocationMetadata:
|
|||||||
physical_to_logical_map=physical_to_logical_map,
|
physical_to_logical_map=physical_to_logical_map,
|
||||||
physical_to_logical_map_cpu=physical_to_logical_map.cpu(),
|
physical_to_logical_map_cpu=physical_to_logical_map.cpu(),
|
||||||
logical_to_all_physical_map=logical_to_all_physical_map_padded,
|
logical_to_all_physical_map=logical_to_all_physical_map_padded,
|
||||||
|
logical_to_all_physical_map_cpu=logical_to_all_physical_map_padded.cpu(),
|
||||||
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
||||||
logical_to_rank_dispatch_physical_map=(
|
logical_to_rank_dispatch_physical_map=(
|
||||||
compute_logical_to_rank_dispatch_physical_map(
|
compute_logical_to_rank_dispatch_physical_map(
|
||||||
@@ -251,6 +253,7 @@ class ExpertLocationMetadata:
|
|||||||
"physical_to_logical_map",
|
"physical_to_logical_map",
|
||||||
"physical_to_logical_map_cpu",
|
"physical_to_logical_map_cpu",
|
||||||
"logical_to_all_physical_map",
|
"logical_to_all_physical_map",
|
||||||
|
"logical_to_all_physical_map_cpu",
|
||||||
"logical_to_all_physical_map_num_valid",
|
"logical_to_all_physical_map_num_valid",
|
||||||
"logical_to_rank_dispatch_physical_map",
|
"logical_to_rank_dispatch_physical_map",
|
||||||
]:
|
]:
|
||||||
@@ -270,9 +273,10 @@ class ExpertLocationMetadata:
|
|||||||
def logical_to_all_physical(
|
def logical_to_all_physical(
|
||||||
self, layer_id: int, logical_expert_id: int
|
self, layer_id: int, logical_expert_id: int
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
|
# Use the CPU copy to avoid a GPU→CPU sync on every call, which is expensive in the update-weights scenario
|
||||||
return [
|
return [
|
||||||
physical_expert_id
|
physical_expert_id
|
||||||
for physical_expert_id in self.logical_to_all_physical_map[
|
for physical_expert_id in self.logical_to_all_physical_map_cpu[
|
||||||
layer_id, logical_expert_id
|
layer_id, logical_expert_id
|
||||||
].tolist()
|
].tolist()
|
||||||
if physical_expert_id != -1
|
if physical_expert_id != -1
|
||||||
|
|||||||
Reference in New Issue
Block a user