Improve EPLB logical to physical dispatch map (#6727)

This commit is contained in:
fzyzcjy
2025-05-30 10:23:54 +08:00
committed by GitHub
parent 51cdd81f97
commit 2c3b71d678

View File

@@ -13,6 +13,7 @@
# ============================================================================== # ==============================================================================
import json import json
import logging import logging
import random
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional
@@ -205,10 +206,10 @@ class ExpertLocationMetadata:
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid, logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map( logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
logical_to_all_physical_map=logical_to_all_physical_map, logical_to_all_physical_map=logical_to_all_physical_map,
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
num_gpus=ep_size, num_gpus=ep_size,
num_physical_experts=num_physical_experts, num_physical_experts=num_physical_experts,
ep_rank=torch.distributed.get_rank(), # TODO improve when we have real EP rank
ep_rank=torch.distributed.get_rank() % ep_size,
), ),
) )
@@ -296,49 +297,82 @@ def _pad_nested_array(arr, pad_value):
return padded return padded
# TODO use more sophisticated approaches # TODO optimize performance (rewrite and/or run in separate process with overlap)
def compute_logical_to_rank_dispatch_physical_map(
    logical_to_all_physical_map: torch.Tensor,
    num_gpus: int,
    num_physical_experts: int,
    ep_rank: int,
    seed: int = 42,
) -> torch.Tensor:
    """Pick, for one rank, the physical expert each logical expert dispatches to.

    A table of shape ``(num_gpus, num_layers, num_logical_experts)`` is built
    for all GPUs, and only the row for ``ep_rank`` is returned.  For every
    (layer, logical expert) slot:

    1. A GPU that hosts at least one replica of the logical expert is assigned
       the first such replica, in the order the replicas appear in
       ``logical_to_all_physical_map``.
    2. The GPUs still unassigned share all replicas as evenly as possible via
       ``_fair_choices``.

    The random source is seeded from ``seed`` only, never from ``ep_rank``, so
    the full table is identical for every ``ep_rank`` given the same inputs.

    Args:
        logical_to_all_physical_map: Integer tensor indexed as
            ``[layer, logical_expert, replica_slot]``; entries equal to ``-1``
            are ignored.
        num_gpus: Number of GPUs the physical experts are spread over.
        num_physical_experts: Total number of physical experts; a physical
            expert ``p`` is attributed to GPU
            ``p // (num_physical_experts // num_gpus)``.
        ep_rank: Row of the table to return.
        seed: Seed for the ``random.Random`` used in the fair assignment.

    Returns:
        Tensor of shape ``(num_layers, num_logical_experts)`` with the dtype
        and device of ``logical_to_all_physical_map``.

    Raises:
        ValueError: If some logical expert has no physical replica at all.
    """
    r = random.Random(seed)

    num_local_physical_experts = num_physical_experts // num_gpus
    num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
    dtype = logical_to_all_physical_map.dtype

    # -1 marks "not assigned yet"; every entry must be overwritten below.
    logical_to_rank_dispatch_physical_map = torch.full(
        size=(num_gpus, num_layers, num_logical_experts),
        fill_value=-1,
        dtype=dtype,
    )

    for layer_id in range(num_layers):
        for logical_expert_id in range(num_logical_experts):
            candidate_physical_expert_ids = _logical_to_all_physical_raw(
                logical_to_all_physical_map, layer_id, logical_expert_id
            )
            if len(candidate_physical_expert_ids) == 0:
                # Fail with context instead of a ZeroDivisionError raised
                # deep inside the fair-choice helper.
                raise ValueError(
                    f"logical expert {logical_expert_id} in layer {layer_id} "
                    "has no physical replica"
                )

            # Basic indexing returns a view, so the writes below land in the
            # full table.
            output_partial = logical_to_rank_dispatch_physical_map[
                :, layer_id, logical_expert_id
            ]

            # Step 1: prefer a replica that lives on the GPU itself.
            for gpu_id in range(num_gpus):
                same_gpu_physical_expert_ids = [
                    physical_expert_id
                    for physical_expert_id in candidate_physical_expert_ids
                    if _compute_gpu_id_of_physical_expert(
                        physical_expert_id, num_local_physical_experts
                    )
                    == gpu_id
                ]
                if len(same_gpu_physical_expert_ids) > 0:
                    output_partial[gpu_id] = same_gpu_physical_expert_ids[0]

            # Step 2: spread the remaining GPUs evenly over all replicas.
            num_remain = torch.sum(output_partial == -1).item()
            if num_remain > 0:
                output_partial[output_partial == -1] = torch.tensor(
                    _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
                    dtype=dtype,
                )

    assert torch.all(logical_to_rank_dispatch_physical_map != -1)

    device = logical_to_all_physical_map.device
    return logical_to_rank_dispatch_physical_map[ep_rank, :, :].to(device)
def _logical_to_all_physical_raw(
logical_to_all_physical_map, layer_id: int, logical_expert_id: int
) -> List[int]:
return [
physical_expert_id
for physical_expert_id in logical_to_all_physical_map[
layer_id, logical_expert_id
].tolist()
if physical_expert_id != -1
]
def _compute_gpu_id_of_physical_expert(
physical_expert_id: int, num_local_physical_experts: int
) -> int:
return physical_expert_id // num_local_physical_experts
def _fair_choices(arr: List, k: int, r: random.Random) -> List:
quotient, remainder = divmod(k, len(arr))
ans = arr * quotient + r.sample(arr, k=remainder)
r.shuffle(ans)
return ans
@dataclass @dataclass