support eplb for qwen3 (#6533)

This commit is contained in:
Yi Zhang
2025-05-24 09:31:30 +08:00
committed by GitHub
parent 7b02c32679
commit e6f113569e
3 changed files with 46 additions and 25 deletions

View File

@@ -65,6 +65,7 @@ def fused_topk(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -88,7 +89,7 @@ def fused_topk(
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
return topk_weights, topk_ids
@@ -355,12 +356,13 @@ def select_experts(
assert (
num_token_non_padded is None
), "num_token_non_padded is not yet supported in fused_topk"
assert expert_location_dispatch_info is None
# Qwen3MOE uses fused_topk
topk_weights, topk_ids = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
expert_location_dispatch_info=expert_location_dispatch_info,
)
else:
assert (