diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 615e0a440..ea4c67a54 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -35,7 +35,8 @@ class ExpertLocationMetadata: physical_to_logical_map: torch.Tensor # (layers, num_physical_experts) logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X) logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts) - logical_to_rank_dispatch_physical_map: torch.Tensor # (layers, num_logical_experts) + # (layers, num_logical_experts) + logical_to_rank_dispatch_physical_map: Optional[torch.Tensor] # -------------------------------- properties ------------------------------------ @@ -70,11 +71,8 @@ class ExpertLocationMetadata: num_layers_2, num_logical_experts_1 = ( self.logical_to_all_physical_map_num_valid.shape ) - num_layers_3, num_logical_experts_2 = ( - self.logical_to_rank_dispatch_physical_map.shape - ) - assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3 - assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2 + assert num_layers_0 == num_layers_1 == num_layers_2 + assert num_logical_experts_0 == num_logical_experts_1 assert num_physical_experts_0 == num_physical_experts_1 # -------------------------------- construction ------------------------------------ @@ -117,6 +115,7 @@ class ExpertLocationMetadata: ) return ExpertLocationMetadata._init_raw( + server_args=server_args, ep_size=common["ep_size"], physical_to_logical_map=physical_to_logical_map, logical_to_all_physical_map=logical_to_all_physical_map, @@ -154,6 +153,7 @@ class ExpertLocationMetadata: ) return ExpertLocationMetadata._init_raw( + server_args=server_args, ep_size=common["ep_size"], physical_to_logical_map=physical_to_logical_map.to(server_args.device), logical_to_all_physical_map=logical_to_all_physical_map.to( @@ -184,6 +184,7 @@ class ExpertLocationMetadata: @staticmethod def _init_raw( + server_args: ServerArgs, ep_size: int, physical_to_logical_map: torch.Tensor, logical_to_all_physical_map: torch.Tensor, @@ -204,12 +205,16 @@ class ExpertLocationMetadata: physical_to_logical_map=physical_to_logical_map, logical_to_all_physical_map=logical_to_all_physical_map_padded, logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid, - logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map( - logical_to_all_physical_map=logical_to_all_physical_map, - num_gpus=ep_size, - num_physical_experts=num_physical_experts, - # TODO improve when we have real EP rank - ep_rank=torch.distributed.get_rank() % ep_size, + logical_to_rank_dispatch_physical_map=( + compute_logical_to_rank_dispatch_physical_map( + logical_to_all_physical_map=logical_to_all_physical_map, + num_gpus=ep_size, + num_physical_experts=num_physical_experts, + # TODO improve when we have real EP rank + ep_rank=torch.distributed.get_rank() % ep_size, + ) + if server_args.ep_dispatch_algorithm == "static" + else None ), ) @@ -230,8 +235,11 @@ class ExpertLocationMetadata: "logical_to_all_physical_map_num_valid", "logical_to_rank_dispatch_physical_map", ]: + src = getattr(other, field) dst = getattr(self, field) - dst[...] = getattr(other, field) + assert (src is not None) == (dst is not None) + if dst is not None: + dst[...] = src # -------------------------------- usage ------------------------------------ diff --git a/python/sglang/srt/managers/expert_location_dispatch.py b/python/sglang/srt/managers/expert_location_dispatch.py index 6880b01a2..547dd4e72 100644 --- a/python/sglang/srt/managers/expert_location_dispatch.py +++ b/python/sglang/srt/managers/expert_location_dispatch.py @@ -25,7 +25,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict class ExpertLocationDispatchInfo: ep_dispatch_algorithm: Literal["static", "random"] # (num_logical_experts,) - partial_logical_to_rank_dispatch_physical_map: torch.Tensor + partial_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor] # (num_logical_experts, X) partial_logical_to_all_physical_map: torch.Tensor # (num_logical_experts,) @@ -42,9 +42,14 @@ class ExpertLocationDispatchInfo: return cls( ep_dispatch_algorithm=ep_dispatch_algorithm, - partial_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[ - layer_id, : - ], + partial_logical_to_rank_dispatch_physical_map=( + expert_location_metadata.logical_to_rank_dispatch_physical_map[ + layer_id, : + ] + if expert_location_metadata.logical_to_rank_dispatch_physical_map + is not None + else None + ), partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[ layer_id, : ],