Speed up rebalancing when using non-static dispatch algorithms (#6812)
This commit is contained in:
@@ -35,7 +35,8 @@ class ExpertLocationMetadata:
|
||||
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
|
||||
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
|
||||
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
|
||||
logical_to_rank_dispatch_physical_map: torch.Tensor # (layers, num_logical_experts)
|
||||
# (layers, num_logical_experts); None unless ep_dispatch_algorithm == "static"
|
||||
logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
|
||||
|
||||
# -------------------------------- properties ------------------------------------
|
||||
|
||||
@@ -70,11 +71,8 @@ class ExpertLocationMetadata:
|
||||
num_layers_2, num_logical_experts_1 = (
|
||||
self.logical_to_all_physical_map_num_valid.shape
|
||||
)
|
||||
num_layers_3, num_logical_experts_2 = (
|
||||
self.logical_to_rank_dispatch_physical_map.shape
|
||||
)
|
||||
assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
|
||||
assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
|
||||
assert num_layers_0 == num_layers_1 == num_layers_2
|
||||
assert num_logical_experts_0 == num_logical_experts_1
|
||||
assert num_physical_experts_0 == num_physical_experts_1
|
||||
|
||||
# -------------------------------- construction ------------------------------------
|
||||
@@ -117,6 +115,7 @@ class ExpertLocationMetadata:
|
||||
)
|
||||
|
||||
return ExpertLocationMetadata._init_raw(
|
||||
server_args=server_args,
|
||||
ep_size=common["ep_size"],
|
||||
physical_to_logical_map=physical_to_logical_map,
|
||||
logical_to_all_physical_map=logical_to_all_physical_map,
|
||||
@@ -154,6 +153,7 @@ class ExpertLocationMetadata:
|
||||
)
|
||||
|
||||
return ExpertLocationMetadata._init_raw(
|
||||
server_args=server_args,
|
||||
ep_size=common["ep_size"],
|
||||
physical_to_logical_map=physical_to_logical_map.to(server_args.device),
|
||||
logical_to_all_physical_map=logical_to_all_physical_map.to(
|
||||
@@ -184,6 +184,7 @@ class ExpertLocationMetadata:
|
||||
|
||||
@staticmethod
|
||||
def _init_raw(
|
||||
server_args: ServerArgs,
|
||||
ep_size: int,
|
||||
physical_to_logical_map: torch.Tensor,
|
||||
logical_to_all_physical_map: torch.Tensor,
|
||||
@@ -204,12 +205,16 @@ class ExpertLocationMetadata:
|
||||
physical_to_logical_map=physical_to_logical_map,
|
||||
logical_to_all_physical_map=logical_to_all_physical_map_padded,
|
||||
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
||||
logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
|
||||
logical_to_all_physical_map=logical_to_all_physical_map,
|
||||
num_gpus=ep_size,
|
||||
num_physical_experts=num_physical_experts,
|
||||
# TODO improve when we have real EP rank
|
||||
ep_rank=torch.distributed.get_rank() % ep_size,
|
||||
logical_to_rank_dispatch_physical_map=(
|
||||
compute_logical_to_rank_dispatch_physical_map(
|
||||
logical_to_all_physical_map=logical_to_all_physical_map,
|
||||
num_gpus=ep_size,
|
||||
num_physical_experts=num_physical_experts,
|
||||
# TODO improve when we have real EP rank
|
||||
ep_rank=torch.distributed.get_rank() % ep_size,
|
||||
)
|
||||
if server_args.ep_dispatch_algorithm == "static"
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
@@ -230,8 +235,11 @@ class ExpertLocationMetadata:
|
||||
"logical_to_all_physical_map_num_valid",
|
||||
"logical_to_rank_dispatch_physical_map",
|
||||
]:
|
||||
src = getattr(other, field)
|
||||
dst = getattr(self, field)
|
||||
dst[...] = getattr(other, field)
|
||||
assert (src is not None) == (dst is not None)
|
||||
if dst is not None:
|
||||
dst[...] = src
|
||||
|
||||
# -------------------------------- usage ------------------------------------
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
class ExpertLocationDispatchInfo:
|
||||
ep_dispatch_algorithm: Literal["static", "random"]
|
||||
# (num_logical_experts,); None when the metadata's logical_to_rank_dispatch_physical_map is None
|
||||
partial_logical_to_rank_dispatch_physical_map: torch.Tensor
|
||||
partial_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
|
||||
# (num_logical_experts, X)
|
||||
partial_logical_to_all_physical_map: torch.Tensor
|
||||
# (num_logical_experts,)
|
||||
@@ -42,9 +42,14 @@ class ExpertLocationDispatchInfo:
|
||||
|
||||
return cls(
|
||||
ep_dispatch_algorithm=ep_dispatch_algorithm,
|
||||
partial_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[
|
||||
layer_id, :
|
||||
],
|
||||
partial_logical_to_rank_dispatch_physical_map=(
|
||||
expert_location_metadata.logical_to_rank_dispatch_physical_map[
|
||||
layer_id, :
|
||||
]
|
||||
if expert_location_metadata.logical_to_rank_dispatch_physical_map
|
||||
is not None
|
||||
else None
|
||||
),
|
||||
partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[
|
||||
layer_id, :
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user