modify:Eliminate redundant operations in the code to improve performance (#137)

### What this PR does / why we need it?
Eliminate redundant operations in the code to improve performance

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed
---------

Signed-off-by: Yaphets24 <d_mym0618@163.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Yaphets24
2025-02-22 17:43:42 +08:00
committed by GitHub
parent 202b39a38c
commit d0b3cb4fa7
4 changed files with 52 additions and 29 deletions

View File

@@ -65,7 +65,7 @@ def group_topk(hidden_states: torch.Tensor,
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
return topk_weights, topk_ids.to(torch.int32)
def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
@@ -126,13 +126,12 @@ def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
down_out_list = torch.cat(down_out_list, dim=0)
# TODO: Reorder device memory 2 times here, replace the current
# implementation here when suitable operators become available.
routing_weights = topk_weights.to(down_out_list.dtype)
hidden_states = torch_npu.npu_moe_finalize_routing(
down_out_list,
skip1=None,
skip2=None,
bias=None,
scales=routing_weights,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids)
if len(ori_shape) == 3:

View File

@@ -18,7 +18,8 @@
from typing import Optional, Tuple
import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
def rope_forward_oot(
@@ -49,8 +50,43 @@ def rope_forward_oot(
self.cos_sin_cache,
self.is_neox_style,
)
return query, key
def rope_deepseek_forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
import torch_npu
if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
if offsets is not None:
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
# TODO: Remove the contiguous in the future.
ori_query_shape, ori_key_shape = query.shape, key.shape
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(query.shape[0], -1)
torch_npu.npu_rope(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
query = query.view(ori_query_shape)
key = key.view(ori_key_shape)
return query, key
RotaryEmbedding.forward_oot = rope_forward_oot
DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot