Speed up rebalancing when using non-static dispatch algorithms (#6812)
This commit is contained in:
@@ -35,7 +35,8 @@ class ExpertLocationMetadata:
|
|||||||
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
|
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
|
||||||
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
|
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
|
||||||
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
|
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
|
||||||
logical_to_rank_dispatch_physical_map: torch.Tensor # (layers, num_logical_experts)
|
# (layers, num_logical_experts)
|
||||||
|
logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
|
||||||
|
|
||||||
# -------------------------------- properties ------------------------------------
|
# -------------------------------- properties ------------------------------------
|
||||||
|
|
||||||
@@ -70,11 +71,8 @@ class ExpertLocationMetadata:
|
|||||||
num_layers_2, num_logical_experts_1 = (
|
num_layers_2, num_logical_experts_1 = (
|
||||||
self.logical_to_all_physical_map_num_valid.shape
|
self.logical_to_all_physical_map_num_valid.shape
|
||||||
)
|
)
|
||||||
num_layers_3, num_logical_experts_2 = (
|
assert num_layers_0 == num_layers_1 == num_layers_2
|
||||||
self.logical_to_rank_dispatch_physical_map.shape
|
assert num_logical_experts_0 == num_logical_experts_1
|
||||||
)
|
|
||||||
assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
|
|
||||||
assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
|
|
||||||
assert num_physical_experts_0 == num_physical_experts_1
|
assert num_physical_experts_0 == num_physical_experts_1
|
||||||
|
|
||||||
# -------------------------------- construction ------------------------------------
|
# -------------------------------- construction ------------------------------------
|
||||||
@@ -117,6 +115,7 @@ class ExpertLocationMetadata:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ExpertLocationMetadata._init_raw(
|
return ExpertLocationMetadata._init_raw(
|
||||||
|
server_args=server_args,
|
||||||
ep_size=common["ep_size"],
|
ep_size=common["ep_size"],
|
||||||
physical_to_logical_map=physical_to_logical_map,
|
physical_to_logical_map=physical_to_logical_map,
|
||||||
logical_to_all_physical_map=logical_to_all_physical_map,
|
logical_to_all_physical_map=logical_to_all_physical_map,
|
||||||
@@ -154,6 +153,7 @@ class ExpertLocationMetadata:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ExpertLocationMetadata._init_raw(
|
return ExpertLocationMetadata._init_raw(
|
||||||
|
server_args=server_args,
|
||||||
ep_size=common["ep_size"],
|
ep_size=common["ep_size"],
|
||||||
physical_to_logical_map=physical_to_logical_map.to(server_args.device),
|
physical_to_logical_map=physical_to_logical_map.to(server_args.device),
|
||||||
logical_to_all_physical_map=logical_to_all_physical_map.to(
|
logical_to_all_physical_map=logical_to_all_physical_map.to(
|
||||||
@@ -184,6 +184,7 @@ class ExpertLocationMetadata:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _init_raw(
|
def _init_raw(
|
||||||
|
server_args: ServerArgs,
|
||||||
ep_size: int,
|
ep_size: int,
|
||||||
physical_to_logical_map: torch.Tensor,
|
physical_to_logical_map: torch.Tensor,
|
||||||
logical_to_all_physical_map: torch.Tensor,
|
logical_to_all_physical_map: torch.Tensor,
|
||||||
@@ -204,12 +205,16 @@ class ExpertLocationMetadata:
|
|||||||
physical_to_logical_map=physical_to_logical_map,
|
physical_to_logical_map=physical_to_logical_map,
|
||||||
logical_to_all_physical_map=logical_to_all_physical_map_padded,
|
logical_to_all_physical_map=logical_to_all_physical_map_padded,
|
||||||
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
||||||
logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
|
logical_to_rank_dispatch_physical_map=(
|
||||||
logical_to_all_physical_map=logical_to_all_physical_map,
|
compute_logical_to_rank_dispatch_physical_map(
|
||||||
num_gpus=ep_size,
|
logical_to_all_physical_map=logical_to_all_physical_map,
|
||||||
num_physical_experts=num_physical_experts,
|
num_gpus=ep_size,
|
||||||
# TODO improve when we have real EP rank
|
num_physical_experts=num_physical_experts,
|
||||||
ep_rank=torch.distributed.get_rank() % ep_size,
|
# TODO improve when we have real EP rank
|
||||||
|
ep_rank=torch.distributed.get_rank() % ep_size,
|
||||||
|
)
|
||||||
|
if server_args.ep_dispatch_algorithm == "static"
|
||||||
|
else None
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -230,8 +235,11 @@ class ExpertLocationMetadata:
|
|||||||
"logical_to_all_physical_map_num_valid",
|
"logical_to_all_physical_map_num_valid",
|
||||||
"logical_to_rank_dispatch_physical_map",
|
"logical_to_rank_dispatch_physical_map",
|
||||||
]:
|
]:
|
||||||
|
src = getattr(other, field)
|
||||||
dst = getattr(self, field)
|
dst = getattr(self, field)
|
||||||
dst[...] = getattr(other, field)
|
assert (src is not None) == (dst is not None)
|
||||||
|
if dst is not None:
|
||||||
|
dst[...] = src
|
||||||
|
|
||||||
# -------------------------------- usage ------------------------------------
|
# -------------------------------- usage ------------------------------------
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|||||||
class ExpertLocationDispatchInfo:
|
class ExpertLocationDispatchInfo:
|
||||||
ep_dispatch_algorithm: Literal["static", "random"]
|
ep_dispatch_algorithm: Literal["static", "random"]
|
||||||
# (num_logical_experts,)
|
# (num_logical_experts,)
|
||||||
partial_logical_to_rank_dispatch_physical_map: torch.Tensor
|
partial_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
|
||||||
# (num_logical_experts, X)
|
# (num_logical_experts, X)
|
||||||
partial_logical_to_all_physical_map: torch.Tensor
|
partial_logical_to_all_physical_map: torch.Tensor
|
||||||
# (num_logical_experts,)
|
# (num_logical_experts,)
|
||||||
@@ -42,9 +42,14 @@ class ExpertLocationDispatchInfo:
|
|||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
ep_dispatch_algorithm=ep_dispatch_algorithm,
|
ep_dispatch_algorithm=ep_dispatch_algorithm,
|
||||||
partial_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[
|
partial_logical_to_rank_dispatch_physical_map=(
|
||||||
layer_id, :
|
expert_location_metadata.logical_to_rank_dispatch_physical_map[
|
||||||
],
|
layer_id, :
|
||||||
|
]
|
||||||
|
if expert_location_metadata.logical_to_rank_dispatch_physical_map
|
||||||
|
is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[
|
partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[
|
||||||
layer_id, :
|
layer_id, :
|
||||||
],
|
],
|
||||||
|
|||||||
Reference in New Issue
Block a user