fix:reorder topk experts to ensure shared expert replaces minimal score (#8125)

This commit is contained in:
erictanjn
2025-07-28 20:36:46 +08:00
committed by GitHub
parent 45bc170b36
commit a9dd3ec3e9

View File

@@ -397,7 +397,9 @@ def grouped_topk_gpu(
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
topk_weights, topk_ids = torch.topk(
tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
)
if num_fused_shared_experts:
topk_ids[:, -1] = torch.randint(
low=num_experts,
@@ -486,7 +488,9 @@ def biased_grouped_topk_impl(
tmp_scores = scores_for_choice.masked_fill(
~score_mask.bool(), float("-inf")
) # [n, e]
_, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
_, topk_ids = torch.topk(
tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
)
topk_weights = scores.gather(1, topk_ids)
if num_fused_shared_experts: