Fix Qwen3MoE missing token padding optimization (#6820)
This commit is contained in:
@@ -66,6 +66,7 @@ def fused_topk(
 gating_output: torch.Tensor,
 topk: int,
 renormalize: bool,
+num_token_non_padded: Optional[torch.Tensor] = None,
 expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
 assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -91,6 +92,7 @@ def fused_topk(
 if renormalize:
 topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+_mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
 return topk_weights, topk_ids
 
 
@@ -363,15 +365,13 @@ def select_experts(
 renormalize=renormalize,
 )
 elif custom_routing_function is None:
-assert (
-num_token_non_padded is None
-), "num_token_non_padded is not yet supported in fused_topk"
 # Qwen3MOE uses fused_topk
 topk_weights, topk_ids = fused_topk(
 hidden_states=hidden_states,
 gating_output=router_logits,
 topk=top_k,
 renormalize=renormalize,
+num_token_non_padded=num_token_non_padded,
 expert_location_dispatch_info=expert_location_dispatch_info,
 )
 else:
|
|||||||
@@ -193,6 +193,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
 top_k=self.top_k,
 use_grouped_topk=False,
 renormalize=self.renormalize,
+num_token_non_padded=forward_batch.num_token_non_padded,
 expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
 layer_id=self.layer_id,
 ),
@@ -260,6 +261,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
 top_k=self.top_k,
 use_grouped_topk=False,
 renormalize=self.renormalize,
+num_token_non_padded=state.forward_batch.num_token_non_padded,
 expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
 layer_id=self.layer_id,
 ),
|
|||||||
Reference in New Issue
Block a user