Support fake perfectly balanced EP dispatch algorithm (#6571)
This commit is contained in:
@@ -18,6 +18,7 @@ from typing import Callable, Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from sglang.srt.managers import expert_location_dispatch
|
||||||
from sglang.srt.managers.expert_distribution import (
|
from sglang.srt.managers.expert_distribution import (
|
||||||
ExpertDistributionRecorder,
|
ExpertDistributionRecorder,
|
||||||
get_global_expert_distribution_recorder,
|
get_global_expert_distribution_recorder,
|
||||||
@@ -310,6 +311,15 @@ def select_experts(
|
|||||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||||
):
|
):
|
||||||
n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
|
n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
|
||||||
|
|
||||||
|
router_logits, correction_bias = (
|
||||||
|
expert_location_dispatch.transform_select_experts_inputs(
|
||||||
|
router_logits=router_logits,
|
||||||
|
correction_bias=correction_bias,
|
||||||
|
info=expert_location_dispatch_info,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# DeepSeek V2/V3/R1 series models use grouped_top_k
|
# DeepSeek V2/V3/R1 series models use grouped_top_k
|
||||||
if use_grouped_topk:
|
if use_grouped_topk:
|
||||||
assert topk_group is not None
|
assert topk_group is not None
|
||||||
|
|||||||
@@ -55,6 +55,18 @@ class ExpertLocationDispatchInfo:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def transform_select_experts_inputs(
    router_logits: torch.Tensor,
    correction_bias: Optional[torch.Tensor],
    info: Optional[ExpertLocationDispatchInfo],
):
    """Optionally replace the expert-selection inputs with synthetic ones.

    When ``info`` is provided and ``info.ep_dispatch_algorithm == "fake"``,
    the router logits are replaced by fresh standard-normal samples
    (``torch.randn_like``) and the correction bias, if one was given, is
    replaced by zeros (``torch.zeros_like``). In every other case both
    arguments are returned untouched.

    The input tensors are never modified in place; the "fake" path
    allocates new tensors.

    Args:
        router_logits: Router output used for expert selection.
        correction_bias: Optional bias applied during selection, or ``None``.
        info: Dispatch configuration, or ``None`` when no dispatch info is
            available.

    Returns:
        A ``(router_logits, correction_bias)`` tuple, either the original
        objects or their synthetic replacements.
    """
    if (info is not None) and (info.ep_dispatch_algorithm == "fake"):
        # NOTE(review): presumably random logits make the chosen experts
        # independent of the real routing so load is spread evenly across
        # experts -- confirm against the intended benchmark usage.
        router_logits = torch.randn_like(router_logits)
        if correction_bias is not None:
            # Zero the bias so it cannot skew the randomized selection.
            correction_bias = torch.zeros_like(correction_bias)
    return router_logits, correction_bias
|
||||||
|
|
||||||
|
|
||||||
def topk_ids_logical_to_physical(
|
def topk_ids_logical_to_physical(
|
||||||
topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
|
topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@@ -63,7 +75,7 @@ def topk_ids_logical_to_physical(
|
|||||||
|
|
||||||
if info.ep_dispatch_algorithm == "static":
|
if info.ep_dispatch_algorithm == "static":
|
||||||
return _topk_ids_logical_to_physical_static(topk_ids, info)
|
return _topk_ids_logical_to_physical_static(topk_ids, info)
|
||||||
if info.ep_dispatch_algorithm == "dynamic":
|
if info.ep_dispatch_algorithm in ["dynamic", "fake"]:
|
||||||
return _topk_ids_logical_to_physical_dynamic(topk_ids, info)
|
return _topk_ids_logical_to_physical_dynamic(topk_ids, info)
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@@ -172,7 +172,7 @@ class ServerArgs:
|
|||||||
enable_deepep_moe: bool = False
|
enable_deepep_moe: bool = False
|
||||||
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
|
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
|
||||||
ep_num_redundant_experts: int = 0
|
ep_num_redundant_experts: int = 0
|
||||||
ep_dispatch_algorithm: Optional[Literal["static", "dynamic"]] = None
|
ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
|
||||||
init_expert_location: str = "trivial"
|
init_expert_location: str = "trivial"
|
||||||
enable_eplb: bool = False
|
enable_eplb: bool = False
|
||||||
eplb_rebalance_num_iterations: int = 1000
|
eplb_rebalance_num_iterations: int = 1000
|
||||||
|
|||||||
Reference in New Issue
Block a user