Support fake perfectly balanced EP dispatch algorithm (#6571)

This commit is contained in:
fzyzcjy
2025-05-26 13:35:51 +08:00
committed by GitHub
parent 2c3a6fe1de
commit 0ca1811715
3 changed files with 24 additions and 2 deletions

View File

@@ -55,6 +55,18 @@ class ExpertLocationDispatchInfo:
)
def transform_select_experts_inputs(
    router_logits: torch.Tensor,
    correction_bias: Optional[torch.Tensor],
    info: Optional[ExpertLocationDispatchInfo],
):
    """Optionally swap the expert-selection inputs for synthetic ones.

    When ``info`` is given and its ``ep_dispatch_algorithm`` is ``"fake"``,
    the router logits are replaced by standard-normal noise with the same
    shape/dtype/device (``torch.randn_like``) and the correction bias, if
    present, is replaced by zeros (``torch.zeros_like``). In every other
    case both inputs are returned untouched.

    Args:
        router_logits: Logits produced by the router.
        correction_bias: Optional bias applied during expert selection.
        info: Dispatch configuration; ``None`` disables the transformation.

    Returns:
        A ``(router_logits, correction_bias)`` tuple, either the original
        objects or their synthetic replacements.
    """
    fake_dispatch = info is not None and info.ep_dispatch_algorithm == "fake"
    if not fake_dispatch:
        # Nothing to do: hand the caller's tensors back as-is.
        return router_logits, correction_bias

    # NOTE(review): presumably random logits with a zeroed bias make the
    # downstream top-k pick experts uniformly (the "perfectly balanced"
    # behaviour named in the commit title) -- confirm against the caller.
    random_logits = torch.randn_like(router_logits)
    neutral_bias = (
        None if correction_bias is None else torch.zeros_like(correction_bias)
    )
    return random_logits, neutral_bias
def topk_ids_logical_to_physical(
topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
@@ -63,7 +75,7 @@ def topk_ids_logical_to_physical(
if info.ep_dispatch_algorithm == "static":
return _topk_ids_logical_to_physical_static(topk_ids, info)
if info.ep_dispatch_algorithm == "dynamic":
if info.ep_dispatch_algorithm in ["dynamic", "fake"]:
return _topk_ids_logical_to_physical_dynamic(topk_ids, info)
raise NotImplementedError