Improve EPLB logical to physical dispatch map (#6727)
This commit is contained in:
@@ -13,6 +13,7 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
@@ -205,10 +206,10 @@ class ExpertLocationMetadata:
|
|||||||
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
||||||
logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
|
logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
|
||||||
logical_to_all_physical_map=logical_to_all_physical_map,
|
logical_to_all_physical_map=logical_to_all_physical_map,
|
||||||
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
|
||||||
num_gpus=ep_size,
|
num_gpus=ep_size,
|
||||||
num_physical_experts=num_physical_experts,
|
num_physical_experts=num_physical_experts,
|
||||||
ep_rank=torch.distributed.get_rank(),
|
# TODO improve when we have real EP rank
|
||||||
|
ep_rank=torch.distributed.get_rank() % ep_size,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -296,49 +297,82 @@ def _pad_nested_array(arr, pad_value):
|
|||||||
return padded
|
return padded
|
||||||
|
|
||||||
|
|
||||||
# TODO use more sophisticated approaches
|
# TODO optimize performance (rewrite and/or run in separate process with overlap)
|
||||||
def compute_logical_to_rank_dispatch_physical_map(
    logical_to_all_physical_map: torch.Tensor,
    num_gpus: int,
    num_physical_experts: int,
    ep_rank: int,
    seed: int = 42,
) -> torch.Tensor:
    """Choose, for one rank, the physical expert each logical expert dispatches to.

    A table of shape ``(num_gpus, num_layers, num_logical_experts)`` is built
    for *all* GPUs, and the slice belonging to ``ep_rank`` is returned.  For
    every (layer, logical expert) pair:

    1. Each GPU that hosts at least one replica of the logical expert is
       assigned the first such replica found in the candidate list.
    2. The GPUs left unassigned receive replicas drawn by ``_fair_choices``,
       which spreads them as evenly as possible over all candidates.

    The random draws come from a single ``random.Random(seed)``, so the result
    is deterministic for a given ``seed`` and input.

    Args:
        logical_to_all_physical_map: Tensor indexed as
            ``[layer, logical_expert, replica_slot]`` holding physical expert
            ids, with ``-1`` marking unused slots.
        num_gpus: Number of GPUs the physical experts are split across.
        num_physical_experts: Total number of physical experts; each GPU is
            taken to own ``num_physical_experts // num_gpus`` consecutive ids.
        ep_rank: Index of the GPU whose slice of the table is returned.
        seed: Seed for the pseudo-random generator used for the fair fill.

    Returns:
        Tensor of shape ``(num_layers, num_logical_experts)`` with the dtype
        of ``logical_to_all_physical_map``, moved to that tensor's device.
    """
    r = random.Random(seed)

    num_local_physical_experts = num_physical_experts // num_gpus
    num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
    dtype = logical_to_all_physical_map.dtype

    # -1 marks "not decided yet"; every entry must be overwritten below.
    logical_to_rank_dispatch_physical_map = torch.full(
        size=(num_gpus, num_layers, num_logical_experts),
        fill_value=-1,
        dtype=dtype,
    )

    for layer_id in range(num_layers):
        for logical_expert_id in range(num_logical_experts):
            candidate_physical_expert_ids = _logical_to_all_physical_raw(
                logical_to_all_physical_map, layer_id, logical_expert_id
            )
            # View into the table: writes below update the table in place.
            output_partial = logical_to_rank_dispatch_physical_map[
                :, layer_id, logical_expert_id
            ]

            # Group the candidates by owning GPU once, remembering only the
            # first replica seen on each GPU.
            first_replica_on_gpu = {}
            for physical_expert_id in candidate_physical_expert_ids:
                owner_gpu_id = _compute_gpu_id_of_physical_expert(
                    physical_expert_id, num_local_physical_experts
                )
                first_replica_on_gpu.setdefault(owner_gpu_id, physical_expert_id)

            # Prefer a replica that lives on the dispatching GPU itself.
            for owner_gpu_id, physical_expert_id in first_replica_on_gpu.items():
                if 0 <= owner_gpu_id < num_gpus:
                    output_partial[owner_gpu_id] = physical_expert_id

            # GPUs without a local replica share the candidates fairly.
            num_remain = torch.sum(output_partial == -1).item()
            output_partial[output_partial == -1] = torch.tensor(
                _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
                dtype=dtype,
            )

    assert torch.all(logical_to_rank_dispatch_physical_map != -1)

    device = logical_to_all_physical_map.device
    return logical_to_rank_dispatch_physical_map[ep_rank, :, :].to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def _logical_to_all_physical_raw(
    logical_to_all_physical_map: torch.Tensor, layer_id: int, logical_expert_id: int
) -> List[int]:
    """Return the physical expert ids stored for one logical expert of one layer.

    Reads ``logical_to_all_physical_map[layer_id, logical_expert_id]`` and
    drops the ``-1`` entries, keeping the remaining ids in their stored order.
    """
    return [
        physical_expert_id
        for physical_expert_id in logical_to_all_physical_map[
            layer_id, logical_expert_id
        ].tolist()
        if physical_expert_id != -1
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_gpu_id_of_physical_expert(
    physical_expert_id: int, num_local_physical_experts: int
) -> int:
    """Return the index of the GPU that owns ``physical_expert_id``.

    Computed as ``physical_expert_id // num_local_physical_experts``, i.e.
    consecutive runs of ``num_local_physical_experts`` ids map to the same GPU.
    """
    return physical_expert_id // num_local_physical_experts
|
||||||
|
|
||||||
|
|
||||||
|
def _fair_choices(arr: List, k: int, r: random.Random) -> List:
    """Pick ``k`` items from ``arr`` using every item as evenly as possible.

    Every element of ``arr`` appears ``k // len(arr)`` times; ``k % len(arr)``
    distinct elements, sampled without replacement, appear once more.  The
    combined list is then shuffled with ``r``.

    Args:
        arr: Candidate items to choose from.
        k: Number of items to return; must be non-negative.
        r: Random generator used for the sampling and the shuffle.

    Returns:
        A new list of length ``k``.

    Raises:
        ValueError: If ``k`` is negative, or if ``arr`` is empty while
            ``k`` is positive.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    if not arr:
        # divmod by len(arr) == 0 would raise an opaque ZeroDivisionError.
        if k == 0:
            return []
        raise ValueError("cannot choose from an empty list")

    quotient, remainder = divmod(k, len(arr))
    ans = arr * quotient + r.sample(arr, k=remainder)
    r.shuffle(ans)
    return ans
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
Reference in New Issue
Block a user