modify:Eliminate redundant operations in the code to improve performance (#137)
### What this PR does / why we need it? Eliminate redundant operations in the code to improve performance ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? CI passed --------- Signed-off-by: Yaphets24 <d_mym0618@163.com> Signed-off-by: MengqingCao <cmq0113@163.com> Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -65,7 +65,7 @@ def group_topk(hidden_states: torch.Tensor,
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
return topk_weights, topk_ids.to(torch.int32)
|
||||
|
||||
|
||||
def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
|
||||
@@ -126,13 +126,12 @@ def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
|
||||
down_out_list = torch.cat(down_out_list, dim=0)
|
||||
# TODO: Reorder device memory 2 times here, replace the current
|
||||
# implementation here when suitable operators become available.
|
||||
routing_weights = topk_weights.to(down_out_list.dtype)
|
||||
hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
down_out_list,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=routing_weights,
|
||||
scales=topk_weights,
|
||||
expanded_src_to_dst_row=expanded_row_idx,
|
||||
export_for_source_row=topk_ids)
|
||||
if len(ori_shape) == 3:
|
||||
|
||||
@@ -18,7 +18,8 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
|
||||
|
||||
def rope_forward_oot(
|
||||
@@ -49,8 +50,43 @@ def rope_forward_oot(
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
return query, key
|
||||
|
||||
|
||||
def rope_deepseek_forward_oot(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
import torch_npu
|
||||
|
||||
if self.cos_sin_cache.device != query.device:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
||||
if self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||
if offsets is not None:
|
||||
raise NotImplementedError(
|
||||
"Batched rotary embedding is currently not supported on NPU.")
|
||||
else:
|
||||
# TODO: Remove the contiguous in the future.
|
||||
ori_query_shape, ori_key_shape = query.shape, key.shape
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(query.shape[0], -1)
|
||||
torch_npu.npu_rope(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
query = query.view(ori_query_shape)
|
||||
key = key.view(ori_key_shape)
|
||||
|
||||
return query, key
|
||||
|
||||
|
||||
RotaryEmbedding.forward_oot = rope_forward_oot
|
||||
DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot
|
||||
|
||||
Reference in New Issue
Block a user