fix:reorder topk experts to ensure shared expert replaces minimal score (#8125)

This commit is contained in:
erictanjn
2025-07-28 20:36:46 +08:00
committed by GitHub
parent 45bc170b36
commit a9dd3ec3e9

View File

@@ -397,7 +397,9 @@ def grouped_topk_gpu(
         .reshape(num_token, -1)
     )  # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    topk_weights, topk_ids = torch.topk(
+        tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
+    )
     if num_fused_shared_experts:
         topk_ids[:, -1] = torch.randint(
             low=num_experts,
@@ -486,7 +488,9 @@ def biased_grouped_topk_impl(
     tmp_scores = scores_for_choice.masked_fill(
         ~score_mask.bool(), float("-inf")
     )  # [n, e]
-    _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    _, topk_ids = torch.topk(
+        tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
+    )
     topk_weights = scores.gather(1, topk_ids)
     if num_fused_shared_experts: