support eplb for qwen3 (#6533)
This commit is contained in:
@@ -65,6 +65,7 @@ def fused_topk(
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
):
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
@@ -88,7 +89,7 @@ def fused_topk(
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
@@ -355,12 +356,13 @@ def select_experts(
|
||||
assert (
|
||||
num_token_non_padded is None
|
||||
), "num_token_non_padded is not yet supported in fused_topk"
|
||||
assert expert_location_dispatch_info is None
|
||||
# Qwen3MOE uses fused_topk
|
||||
topk_weights, topk_ids = fused_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
|
||||
Reference in New Issue
Block a user