# Patch summary (review commentary; text before the `diff --git` header is not part of any hunk).
#
# Target file: python/sglang/srt/eplb/expert_location.py
#
# What the hunks below change:
#   * The call inside ExpertLocationMetadata now passes `server_args=server_args` to
#     compute_logical_to_rank_dispatch_physical_map, which gains a new leading
#     `server_args: ServerArgs` parameter.
#   * `num_local_physical_experts` is renamed to `num_local_gpu_physical_experts`
#     (still `num_physical_experts // num_gpus`), and two values are derived from it:
#       num_gpus_per_node               = server_args.ep_size // server_args.nnodes
#       num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node
#   * Per-GPU dispatch choice becomes a three-step preference:
#       1. the first candidate physical expert located on the same GPU;
#       2. otherwise the first candidate whose node id equals `gpu_id // num_gpus_per_node`;
#       3. any slot still equal to -1 is filled from `_fair_choices(...)`.
#   * A new helper `_compute_node_id_of_physical_expert` is added; like the GPU helper it is
#     a single floor division of the physical expert id by a per-unit expert count.
#
# NOTE(review): if server_args.nnodes > server_args.ep_size, `num_gpus_per_node` is 0, so both
#   `gpu_id // num_gpus_per_node` and the node helper's division would raise ZeroDivisionError
#   whenever the same-GPU list is empty -- confirm that configuration cannot occur, or guard it.
# NOTE(review): if ep_size is not a multiple of nnodes, the floor division makes the node
#   boundaries approximate -- TODO confirm divisibility is enforced elsewhere.
# NOTE(review): the per-GPU count uses the `num_gpus` argument while the per-node count uses
#   `server_args.ep_size`; the call site passes `num_gpus=ep_size`, but whether that always
#   equals `server_args.ep_size` is not visible in this patch -- verify.
# NOTE(review): the new helper names its parameter `num_local_host_physical_experts` while the
#   caller's variable is `num_local_node_physical_experts` ("host" vs "node") -- consider
#   unifying the terminology.
# NOTE(review): `ServerArgs` is used as an annotation; its import is not shown in these hunks.
#
# The patch text below is reproduced unchanged (its original line breaks were already collapsed).
diff --git a/python/sglang/srt/eplb/expert_location.py b/python/sglang/srt/eplb/expert_location.py index ee5f2c7ca..4db273781 100644 --- a/python/sglang/srt/eplb/expert_location.py +++ b/python/sglang/srt/eplb/expert_location.py @@ -231,6 +231,7 @@ class ExpertLocationMetadata: logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid, logical_to_rank_dispatch_physical_map=( compute_logical_to_rank_dispatch_physical_map( + server_args=server_args, logical_to_all_physical_map=logical_to_all_physical_map, num_gpus=ep_size, num_physical_experts=num_physical_experts, @@ -340,6 +341,7 @@ def _pad_nested_array(arr, pad_value): # TODO optimize performance (rewrite and/or run in separate process with overlap) def compute_logical_to_rank_dispatch_physical_map( + server_args: ServerArgs, logical_to_all_physical_map: torch.Tensor, num_gpus: int, num_physical_experts: int, @@ -348,7 +350,9 @@ def compute_logical_to_rank_dispatch_physical_map( ): r = random.Random(seed) - num_local_physical_experts = num_physical_experts // num_gpus + num_local_gpu_physical_experts = num_physical_experts // num_gpus + num_gpus_per_node = server_args.ep_size // server_args.nnodes + num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape dtype = logical_to_all_physical_map.dtype @@ -372,13 +376,28 @@ def compute_logical_to_rank_dispatch_physical_map( physical_expert_id for physical_expert_id in candidate_physical_expert_ids if _compute_gpu_id_of_physical_expert( - physical_expert_id, num_local_physical_experts + physical_expert_id, num_local_gpu_physical_experts ) == gpu_id ] if len(same_gpu_physical_expert_ids) > 0: + # 1. Prefer same-GPU experts output_partial[gpu_id] = same_gpu_physical_expert_ids[0] + else: + # 2. 
Otherwise, prefer same-node experts + node_id = gpu_id // num_gpus_per_node + same_node_physical_expert_ids = [ + physical_expert_id + for physical_expert_id in candidate_physical_expert_ids + if _compute_node_id_of_physical_expert( + physical_expert_id, num_local_node_physical_experts + ) + == node_id + ] + if len(same_node_physical_expert_ids) > 0: + output_partial[gpu_id] = same_node_physical_expert_ids[0] + # 3. Fill remaining slots with fair random choices num_remain = torch.sum(output_partial == -1).item() output_partial[output_partial == -1] = torch.tensor( _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r), @@ -404,9 +423,15 @@ def _logical_to_all_physical_raw( def _compute_gpu_id_of_physical_expert( - physical_expert_id: int, num_local_physical_experts: int + physical_expert_id: int, num_local_gpu_physical_experts: int ) -> int: - return physical_expert_id // num_local_physical_experts + return physical_expert_id // num_local_gpu_physical_experts + + +def _compute_node_id_of_physical_expert( + physical_expert_id: int, num_local_host_physical_experts: int +) -> int: + return physical_expert_id // num_local_host_physical_experts def _fair_choices(arr: List, k: int, r: random.Random) -> List: