[Fix] Remove npu_group_topk before CANN version update (#242)
Remove npu_group_topk before CANN version update. Signed-off-by: SidaoY <1024863041@qq.com>
This commit is contained in:
@@ -50,9 +50,22 @@ def group_topk(hidden_states: torch.Tensor,
|
|||||||
|
|
||||||
topk_group = 0 if topk_group is None else topk_group
|
topk_group = 0 if topk_group is None else topk_group
|
||||||
num_expert_group = 0 if num_expert_group is None else num_expert_group
|
num_expert_group = 0 if num_expert_group is None else num_expert_group
|
||||||
torch_npu._npu_group_topk(self=scores,
|
|
||||||
|
# TODO: Replace this piece of code to npu_group_topk when CANN and NNAL version is update
|
||||||
|
num_token = scores.shape[0]
|
||||||
|
group_scores = scores.view(num_token, num_expert_group,
|
||||||
|
-1).max(dim=-1).values
|
||||||
|
group_idx = torch.topk(group_scores.to(torch.float32),
|
||||||
k=topk_group,
|
k=topk_group,
|
||||||
group_num=num_expert_group)
|
dim=-1,
|
||||||
|
sorted=False)[1]
|
||||||
|
group_mask = torch.zeros_like(group_scores)
|
||||||
|
group_mask.scatter_(1, group_idx, 1)
|
||||||
|
score_mask = group_mask.unsqueeze(-1).expand(
|
||||||
|
num_token, num_expert_group,
|
||||||
|
scores.shape[-1] // num_expert_group).reshape(num_token, -1)
|
||||||
|
scores = scores.masked_fill(~score_mask.bool(), 0.0)
|
||||||
|
|
||||||
if e_score_correction_bias is not None:
|
if e_score_correction_bias is not None:
|
||||||
topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)[1]
|
topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)[1]
|
||||||
# Use original unbiased scores for the routing weights
|
# Use original unbiased scores for the routing weights
|
||||||
|
|||||||
Reference in New Issue
Block a user