fix: reorder topk experts to ensure shared expert replaces minimal score (#8125)
This commit is contained in:
@@ -397,7 +397,9 @@ def grouped_topk_gpu(
|
|||||||
.reshape(num_token, -1)
|
.reshape(num_token, -1)
|
||||||
) # [n, e]
|
) # [n, e]
|
||||||
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||||
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
topk_weights, topk_ids = torch.topk(
|
||||||
|
tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
|
||||||
|
)
|
||||||
if num_fused_shared_experts:
|
if num_fused_shared_experts:
|
||||||
topk_ids[:, -1] = torch.randint(
|
topk_ids[:, -1] = torch.randint(
|
||||||
low=num_experts,
|
low=num_experts,
|
||||||
@@ -486,7 +488,9 @@ def biased_grouped_topk_impl(
|
|||||||
tmp_scores = scores_for_choice.masked_fill(
|
tmp_scores = scores_for_choice.masked_fill(
|
||||||
~score_mask.bool(), float("-inf")
|
~score_mask.bool(), float("-inf")
|
||||||
) # [n, e]
|
) # [n, e]
|
||||||
_, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
_, topk_ids = torch.topk(
|
||||||
|
tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
|
||||||
|
)
|
||||||
topk_weights = scores.gather(1, topk_ids)
|
topk_weights = scores.gather(1, topk_ids)
|
||||||
|
|
||||||
if num_fused_shared_experts:
|
if num_fused_shared_experts:
|
||||||
|
|||||||
Reference in New Issue
Block a user