fix:reorder topk experts to ensure shared expert replaces minimal score (#8125)
This commit is contained in:
@@ -397,7 +397,9 @@ def grouped_topk_gpu(
|
||||
.reshape(num_token, -1)
|
||||
) # [n, e]
|
||||
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
||||
topk_weights, topk_ids = torch.topk(
|
||||
tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
|
||||
)
|
||||
if num_fused_shared_experts:
|
||||
topk_ids[:, -1] = torch.randint(
|
||||
low=num_experts,
|
||||
@@ -486,7 +488,9 @@ def biased_grouped_topk_impl(
|
||||
tmp_scores = scores_for_choice.masked_fill(
|
||||
~score_mask.bool(), float("-inf")
|
||||
) # [n, e]
|
||||
_, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
||||
_, topk_ids = torch.topk(
|
||||
tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
|
||||
)
|
||||
topk_weights = scores.gather(1, topk_ids)
|
||||
|
||||
if num_fused_shared_experts:
|
||||
|
||||
Reference in New Issue
Block a user