diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index f3e47ab..e216f92 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -50,9 +50,22 @@ def group_topk(hidden_states: torch.Tensor, topk_group = 0 if topk_group is None else topk_group num_expert_group = 0 if num_expert_group is None else num_expert_group - torch_npu._npu_group_topk(self=scores, - k=topk_group, - group_num=num_expert_group) + + # TODO: Replace this piece of code to npu_group_topk when CANN and NNAL version is update + num_token = scores.shape[0] + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values + group_idx = torch.topk(group_scores.to(torch.float32), + k=topk_group, + dim=-1, + sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, -1) + scores = scores.masked_fill(~score_mask.bool(), 0.0) + if e_score_correction_bias is not None: topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)[1] # Use original unbiased scores for the routing weights