modify:Eliminate redundant operations in the code to improve performance (#137)

### What this PR does / why we need it?
Eliminate redundant operations in the code to improve performance

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed
---------

Signed-off-by: Yaphets24 <d_mym0618@163.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Yaphets24
2025-02-22 17:43:42 +08:00
committed by GitHub
parent 202b39a38c
commit d0b3cb4fa7
4 changed files with 52 additions and 29 deletions

View File

@@ -65,7 +65,7 @@ def group_topk(hidden_states: torch.Tensor,
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
return topk_weights, topk_ids.to(torch.int32)
def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
@@ -126,13 +126,12 @@ def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
down_out_list = torch.cat(down_out_list, dim=0)
# TODO: Reorder device memory 2 times here, replace the current
# implementation here when suitable operators become available.
routing_weights = topk_weights.to(down_out_list.dtype)
hidden_states = torch_npu.npu_moe_finalize_routing(
down_out_list,
skip1=None,
skip2=None,
bias=None,
scales=routing_weights,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids)
if len(ori_shape) == 3: