diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index e0bca7ec2..d12fd4975 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -13,6 +13,7 @@ # ============================================================================== import json import logging +import random from dataclasses import dataclass from pathlib import Path from typing import List, Optional @@ -205,10 +206,10 @@ class ExpertLocationMetadata: logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid, logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map=logical_to_all_physical_map, - logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid, num_gpus=ep_size, num_physical_experts=num_physical_experts, - ep_rank=torch.distributed.get_rank(), + # TODO improve when we have real EP rank + ep_rank=torch.distributed.get_rank() % ep_size, ), ) @@ -296,49 +297,82 @@ def _pad_nested_array(arr, pad_value): return padded -# TODO use more sophisticated approaches +# TODO optimize performance (rewrite and/or run in separate process with overlap) def compute_logical_to_rank_dispatch_physical_map( logical_to_all_physical_map: torch.Tensor, - logical_to_all_physical_map_num_valid: torch.Tensor, num_gpus: int, num_physical_experts: int, ep_rank: int, - base_seed: int = 42, + seed: int = 42, ): - device = logical_to_all_physical_map.device + r = random.Random(seed) num_local_physical_experts = num_physical_experts // num_gpus num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape + dtype = logical_to_all_physical_map.dtype - g = torch.Generator(device=device) - g.manual_seed(base_seed + ep_rank) - - output_shape = (num_layers, num_logical_experts) - chosen_index = ( - torch.randint( - 0, 65536, output_shape, dtype=torch.int32, device=device, generator=g - ) - % logical_to_all_physical_map_num_valid + logical_to_rank_dispatch_physical_map = torch.full( + size=(num_gpus, num_layers, num_logical_experts), + fill_value=-1, + dtype=dtype, ) - logical_to_rank_dispatch_physical_map = torch.gather( - logical_to_all_physical_map, dim=2, index=chosen_index.unsqueeze(-1) - ).squeeze(-1) - assert logical_to_rank_dispatch_physical_map.shape == output_shape - for index in range(logical_to_all_physical_map_num_valid.max().item()): - partial_logical_to_all_physical_map = logical_to_all_physical_map[:, :, index] - is_valid = partial_logical_to_all_physical_map != -1 - is_same_gpu = ( - partial_logical_to_all_physical_map // num_local_physical_experts - ) == ep_rank - logical_to_rank_dispatch_physical_map = torch.where( - is_valid & is_same_gpu, - partial_logical_to_all_physical_map, - logical_to_rank_dispatch_physical_map, - ) + for layer_id in range(num_layers): + for logical_expert_id in range(num_logical_experts): + candidate_physical_expert_ids = _logical_to_all_physical_raw( + logical_to_all_physical_map, layer_id, logical_expert_id + ) + output_partial = logical_to_rank_dispatch_physical_map[ + :, layer_id, logical_expert_id + ] + + for gpu_id in range(num_gpus): + same_gpu_physical_expert_ids = [ + physical_expert_id + for physical_expert_id in candidate_physical_expert_ids + if _compute_gpu_id_of_physical_expert( + physical_expert_id, num_local_physical_experts + ) + == gpu_id + ] + if len(same_gpu_physical_expert_ids) > 0: + output_partial[gpu_id] = same_gpu_physical_expert_ids[0] + + num_remain = torch.sum(output_partial == -1).item() + output_partial[output_partial == -1] = torch.tensor( + _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r), + dtype=dtype, + ) assert torch.all(logical_to_rank_dispatch_physical_map != -1) - return logical_to_rank_dispatch_physical_map + + device = logical_to_all_physical_map.device + return logical_to_rank_dispatch_physical_map[ep_rank, :, :].to(device) + + +def _logical_to_all_physical_raw( + logical_to_all_physical_map, layer_id: int, logical_expert_id: int +) -> List[int]: + return [ + physical_expert_id + for physical_expert_id in logical_to_all_physical_map[ + layer_id, logical_expert_id + ].tolist() + if physical_expert_id != -1 + ] + + +def _compute_gpu_id_of_physical_expert( + physical_expert_id: int, num_local_physical_experts: int +) -> int: + return physical_expert_id // num_local_physical_experts + + +def _fair_choices(arr: List, k: int, r: random.Random) -> List: + quotient, remainder = divmod(k, len(arr)) + ans = arr * quotient + r.sample(arr, k=remainder) + r.shuffle(ans) + return ans @dataclass