From dcd0005058dbd6fd8672378565890cbda924b792 Mon Sep 17 00:00:00 2001 From: HongtaoYang <75939043+SidaoY@users.noreply.github.com> Date: Thu, 6 Mar 2025 09:02:46 +0800 Subject: [PATCH] [Fix] Remove npu_group_topk before CANN version update (#242) Remove npu_group_topk before CANN version update. Signed-off-by: SidaoY <1024863041@qq.com> --- vllm_ascend/ops/fused_moe.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index f3e47ab..e216f92 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -50,9 +50,22 @@ def group_topk(hidden_states: torch.Tensor, topk_group = 0 if topk_group is None else topk_group num_expert_group = 0 if num_expert_group is None else num_expert_group - torch_npu._npu_group_topk(self=scores, - k=topk_group, - group_num=num_expert_group) + + # TODO: Replace this piece of code with npu_group_topk once the CANN and NNAL versions are updated + num_token = scores.shape[0] + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values + group_idx = torch.topk(group_scores.to(torch.float32), + k=topk_group, + dim=-1, + sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, -1) + scores = scores.masked_fill(~score_mask.bool(), 0.0) + if e_score_correction_bias is not None: topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)[1] # Use original unbiased scores for the routing weights