[Bugfix] Fix scores mask for moe topk (#3705)
This commit is contained in:
@@ -141,7 +141,9 @@ def biased_grouped_topk(
|
|||||||
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
|
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
|
||||||
.reshape(num_token, -1)
|
.reshape(num_token, -1)
|
||||||
) # [n, e]
|
) # [n, e]
|
||||||
tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
tmp_scores = scores_for_choice.masked_fill(
|
||||||
|
~score_mask.bool(), float("-inf")
|
||||||
|
) # [n, e]
|
||||||
_, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
_, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
||||||
topk_weights = scores.gather(1, topk_ids)
|
topk_weights = scores.gather(1, topk_ids)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user