modify:Eliminate redundant operations in the code to improve performance (#137)

### What this PR does / why we need it?
Eliminate redundant operations in the code to improve performance

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed
---------

Signed-off-by: Yaphets24 <d_mym0618@163.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Yaphets24
2025-02-22 17:43:42 +08:00
committed by GitHub
parent 202b39a38c
commit d0b3cb4fa7
4 changed files with 52 additions and 29 deletions

View File

@@ -742,30 +742,20 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
self.qk_head_dim)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
dim=-1)
if attn_metadata.num_prefills > 0:
assert attn_metadata.prefill_metadata is not None
assert attn_metadata.prefill_metadata.seq_lens is not None
np_positions = np.concatenate([
np.arange(i) for i in attn_metadata.prefill_metadata.seq_lens
])
positions = torch.tensor(np_positions,
device=hidden_states_or_q_c.device)
else:
assert attn_metadata.decode_metadata is not None
np_positions = np.array(attn_metadata.decode_metadata.seq_lens) - 1
positions = torch.tensor(np_positions,
device=hidden_states_or_q_c.device)
k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
q_pe = q_pe.reshape(num_tokens, -1)
k_pe = k_pe.reshape(num_tokens, -1)
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
k_pe)
q_pe = q_pe.view(ori_q_pe_shape)
k_pe = k_pe.view(ori_k_pe_shape)
else:
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
k_pe)
if self.w_kc is None or self.w_vc is None:
kv_b_proj_weight = self.kv_b_proj.weight.reshape(
@@ -786,16 +776,14 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
k_cache = torch.cat(
[kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
dim=2)
k_pe = k_pe.repeat(1, self.num_heads, 1)
k_pe = k_pe.expand(-1, self.num_heads, -1)
key = torch.cat([k_nope.view(num_tokens, kv_heads_num, -1), k_pe],
dim=2)
else:
kv_heads_num = self.num_kv_heads
q_nope_t = torch_npu.npu_transpose(q_nope, (1, 0, 2),
require_contiguous=True)
q_nope_t = torch.transpose(q_nope, 0, 1)
q_nope_out = torch.bmm(q_nope_t, self.w_kc)
q_nope = torch_npu.npu_transpose(q_nope_out, (1, 0, 2),
require_contiguous=True)
q_nope = torch.transpose(q_nope_out, 0, 1)
k_cache = torch.cat(
[kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
dim=2)
@@ -895,12 +883,10 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
inputLayout=0,
outDataType=-1,
attnOut=attn_output)
attn_output_t = torch_npu.npu_transpose(attn_output, (1, 0, 2),
require_contiguous=True)
attn_output_t = torch.transpose(attn_output, 0, 1)
attn_output_t = torch.bmm(attn_output_t, self.w_vc)
attn_output = torch_npu.npu_transpose(attn_output_t, (1, 0, 2),
require_contiguous=True)
attn_output = torch.transpose(attn_output_t, 0, 1)
output, _ = self.o_proj(attn_output.view(num_tokens, -1))
output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))
return output

View File

@@ -1137,6 +1137,8 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
if not bypass_model_exec:
with set_forward_context(model_input.attn_metadata,
self.vllm_config, virtual_engine):
if model_input.attn_metadata is not None:
model_input.attn_metadata.input_positions = model_input.input_positions
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,

View File

@@ -65,7 +65,7 @@ def group_topk(hidden_states: torch.Tensor,
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
return topk_weights, topk_ids.to(torch.int32)
def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
@@ -126,13 +126,12 @@ def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
down_out_list = torch.cat(down_out_list, dim=0)
# TODO: Reorder device memory 2 times here, replace the current
# implementation here when suitable operators become available.
routing_weights = topk_weights.to(down_out_list.dtype)
hidden_states = torch_npu.npu_moe_finalize_routing(
down_out_list,
skip1=None,
skip2=None,
bias=None,
scales=routing_weights,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids)
if len(ori_shape) == 3:

View File

@@ -18,7 +18,8 @@
from typing import Optional, Tuple
import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
def rope_forward_oot(
@@ -49,8 +50,43 @@ def rope_forward_oot(
self.cos_sin_cache,
self.is_neox_style,
)
return query, key
def rope_deepseek_forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
import torch_npu
if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
if offsets is not None:
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
# TODO: Remove the contiguous in the future.
ori_query_shape, ori_key_shape = query.shape, key.shape
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(query.shape[0], -1)
torch_npu.npu_rope(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
query = query.view(ori_query_shape)
key = key.view(ori_key_shape)
return query, key
RotaryEmbedding.forward_oot = rope_forward_oot
DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot