Support fake perfectly balanced EP dispatch algorithm (#6571)
This commit is contained in:
@@ -55,6 +55,18 @@ class ExpertLocationDispatchInfo:
|
||||
)
|
||||
|
||||
|
||||
def transform_select_experts_inputs(
|
||||
router_logits: torch.Tensor,
|
||||
correction_bias: Optional[torch.Tensor],
|
||||
info: Optional[ExpertLocationDispatchInfo],
|
||||
):
|
||||
if (info is not None) and (info.ep_dispatch_algorithm == "fake"):
|
||||
router_logits = torch.randn_like(router_logits)
|
||||
if correction_bias is not None:
|
||||
correction_bias = torch.zeros_like(correction_bias)
|
||||
return router_logits, correction_bias
|
||||
|
||||
|
||||
def topk_ids_logical_to_physical(
|
||||
topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
|
||||
) -> torch.Tensor:
|
||||
@@ -63,7 +75,7 @@ def topk_ids_logical_to_physical(
|
||||
|
||||
if info.ep_dispatch_algorithm == "static":
|
||||
return _topk_ids_logical_to_physical_static(topk_ids, info)
|
||||
if info.ep_dispatch_algorithm == "dynamic":
|
||||
if info.ep_dispatch_algorithm in ["dynamic", "fake"]:
|
||||
return _topk_ids_logical_to_physical_dynamic(topk_ids, info)
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
Reference in New Issue
Block a user