Fix EPLB algorithm failing to run when using 3 nodes for prefill (#6629)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py because the upstream project is not published as a PyPI package
|
||||
|
||||
from typing import Literal, Tuple
|
||||
from typing import Literal, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -257,11 +257,15 @@ def rebalance_experts(
|
||||
tokens_per_expert: torch.Tensor,
|
||||
num_physical_experts: int,
|
||||
num_local_physical_experts: int,
|
||||
num_groups: int,
|
||||
num_groups: Optional[int],
|
||||
num_nodes: int,
|
||||
phase: Literal["prefill", "decode"],
|
||||
phase: Literal["prefill", "decode", "null"],
|
||||
):
|
||||
if phase == "prefill":
|
||||
if (
|
||||
(phase == "prefill")
|
||||
and (num_groups is not None)
|
||||
and (num_groups % num_nodes == 0)
|
||||
):
|
||||
return prefill_rebalance_experts(
|
||||
tokens_per_expert=tokens_per_expert,
|
||||
num_physical_experts=num_physical_experts,
|
||||
@@ -269,10 +273,8 @@ def rebalance_experts(
|
||||
num_groups=num_groups,
|
||||
num_nodes=num_nodes,
|
||||
)
|
||||
if phase == "decode":
|
||||
return decode_rebalance_experts(
|
||||
tokens_per_expert=tokens_per_expert,
|
||||
num_physical_experts=num_physical_experts,
|
||||
num_local_physical_experts=num_local_physical_experts,
|
||||
)
|
||||
raise NotImplementedError
|
||||
return decode_rebalance_experts(
|
||||
tokens_per_expert=tokens_per_expert,
|
||||
num_physical_experts=num_physical_experts,
|
||||
num_local_physical_experts=num_local_physical_experts,
|
||||
)
|
||||
|
||||
@@ -135,10 +135,6 @@ class ExpertLocationMetadata:
|
||||
model_config_for_expert_location = common["model_config_for_expert_location"]
|
||||
num_physical_experts = common["num_physical_experts"]
|
||||
|
||||
phase = server_args.disaggregation_mode
|
||||
if phase == "null" or model_config_for_expert_location.num_groups is None:
|
||||
phase = "decode"
|
||||
|
||||
physical_to_logical_map, logical_to_all_physical_map, expert_count = (
|
||||
deepseek_eplb.rebalance_experts(
|
||||
tokens_per_expert=logical_count,
|
||||
@@ -146,7 +142,7 @@ class ExpertLocationMetadata:
|
||||
num_local_physical_experts=num_physical_experts // common["ep_size"],
|
||||
num_groups=model_config_for_expert_location.num_groups,
|
||||
num_nodes=server_args.nnodes,
|
||||
phase=phase,
|
||||
phase=server_args.disaggregation_mode,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user