Fix EPLB algorithm failing to run when using 3 nodes for prefill (#6629)
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package
|
# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package
|
||||||
|
|
||||||
from typing import Literal, Tuple
|
from typing import Literal, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -257,11 +257,15 @@ def rebalance_experts(
|
|||||||
tokens_per_expert: torch.Tensor,
|
tokens_per_expert: torch.Tensor,
|
||||||
num_physical_experts: int,
|
num_physical_experts: int,
|
||||||
num_local_physical_experts: int,
|
num_local_physical_experts: int,
|
||||||
num_groups: int,
|
num_groups: Optional[int],
|
||||||
num_nodes: int,
|
num_nodes: int,
|
||||||
phase: Literal["prefill", "decode"],
|
phase: Literal["prefill", "decode", "null"],
|
||||||
):
|
):
|
||||||
if phase == "prefill":
|
if (
|
||||||
|
(phase == "prefill")
|
||||||
|
and (num_groups is not None)
|
||||||
|
and (num_groups % num_nodes == 0)
|
||||||
|
):
|
||||||
return prefill_rebalance_experts(
|
return prefill_rebalance_experts(
|
||||||
tokens_per_expert=tokens_per_expert,
|
tokens_per_expert=tokens_per_expert,
|
||||||
num_physical_experts=num_physical_experts,
|
num_physical_experts=num_physical_experts,
|
||||||
@@ -269,10 +273,8 @@ def rebalance_experts(
|
|||||||
num_groups=num_groups,
|
num_groups=num_groups,
|
||||||
num_nodes=num_nodes,
|
num_nodes=num_nodes,
|
||||||
)
|
)
|
||||||
if phase == "decode":
|
return decode_rebalance_experts(
|
||||||
return decode_rebalance_experts(
|
tokens_per_expert=tokens_per_expert,
|
||||||
tokens_per_expert=tokens_per_expert,
|
num_physical_experts=num_physical_experts,
|
||||||
num_physical_experts=num_physical_experts,
|
num_local_physical_experts=num_local_physical_experts,
|
||||||
num_local_physical_experts=num_local_physical_experts,
|
)
|
||||||
)
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|||||||
@@ -135,10 +135,6 @@ class ExpertLocationMetadata:
|
|||||||
model_config_for_expert_location = common["model_config_for_expert_location"]
|
model_config_for_expert_location = common["model_config_for_expert_location"]
|
||||||
num_physical_experts = common["num_physical_experts"]
|
num_physical_experts = common["num_physical_experts"]
|
||||||
|
|
||||||
phase = server_args.disaggregation_mode
|
|
||||||
if phase == "null" or model_config_for_expert_location.num_groups is None:
|
|
||||||
phase = "decode"
|
|
||||||
|
|
||||||
physical_to_logical_map, logical_to_all_physical_map, expert_count = (
|
physical_to_logical_map, logical_to_all_physical_map, expert_count = (
|
||||||
deepseek_eplb.rebalance_experts(
|
deepseek_eplb.rebalance_experts(
|
||||||
tokens_per_expert=logical_count,
|
tokens_per_expert=logical_count,
|
||||||
@@ -146,7 +142,7 @@ class ExpertLocationMetadata:
|
|||||||
num_local_physical_experts=num_physical_experts // common["ep_size"],
|
num_local_physical_experts=num_physical_experts // common["ep_size"],
|
||||||
num_groups=model_config_for_expert_location.num_groups,
|
num_groups=model_config_for_expert_location.num_groups,
|
||||||
num_nodes=server_args.nnodes,
|
num_nodes=server_args.nnodes,
|
||||||
phase=phase,
|
phase=server_args.disaggregation_mode,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user