modify:Eliminate redundant operations in the code to improve performance (#137)
### What this PR does / why we need it? Eliminate redundant operations in the code to improve performance ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? CI passed --------- Signed-off-by: Yaphets24 <d_mym0618@163.com> Signed-off-by: MengqingCao <cmq0113@163.com> Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -742,30 +742,20 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
self.qk_head_dim)
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
|
||||
dim=-1)
|
||||
if attn_metadata.num_prefills > 0:
|
||||
assert attn_metadata.prefill_metadata is not None
|
||||
assert attn_metadata.prefill_metadata.seq_lens is not None
|
||||
np_positions = np.concatenate([
|
||||
np.arange(i) for i in attn_metadata.prefill_metadata.seq_lens
|
||||
])
|
||||
positions = torch.tensor(np_positions,
|
||||
device=hidden_states_or_q_c.device)
|
||||
else:
|
||||
assert attn_metadata.decode_metadata is not None
|
||||
np_positions = np.array(attn_metadata.decode_metadata.seq_lens) - 1
|
||||
positions = torch.tensor(np_positions,
|
||||
device=hidden_states_or_q_c.device)
|
||||
|
||||
k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
|
||||
|
||||
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
|
||||
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
|
||||
q_pe = q_pe.reshape(num_tokens, -1)
|
||||
k_pe = k_pe.reshape(num_tokens, -1)
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
|
||||
k_pe)
|
||||
q_pe = q_pe.view(ori_q_pe_shape)
|
||||
k_pe = k_pe.view(ori_k_pe_shape)
|
||||
else:
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
|
||||
k_pe)
|
||||
|
||||
if self.w_kc is None or self.w_vc is None:
|
||||
kv_b_proj_weight = self.kv_b_proj.weight.reshape(
|
||||
@@ -786,16 +776,14 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
k_cache = torch.cat(
|
||||
[kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
|
||||
dim=2)
|
||||
k_pe = k_pe.repeat(1, self.num_heads, 1)
|
||||
k_pe = k_pe.expand(-1, self.num_heads, -1)
|
||||
key = torch.cat([k_nope.view(num_tokens, kv_heads_num, -1), k_pe],
|
||||
dim=2)
|
||||
else:
|
||||
kv_heads_num = self.num_kv_heads
|
||||
q_nope_t = torch_npu.npu_transpose(q_nope, (1, 0, 2),
|
||||
require_contiguous=True)
|
||||
q_nope_t = torch.transpose(q_nope, 0, 1)
|
||||
q_nope_out = torch.bmm(q_nope_t, self.w_kc)
|
||||
q_nope = torch_npu.npu_transpose(q_nope_out, (1, 0, 2),
|
||||
require_contiguous=True)
|
||||
q_nope = torch.transpose(q_nope_out, 0, 1)
|
||||
k_cache = torch.cat(
|
||||
[kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
|
||||
dim=2)
|
||||
@@ -895,12 +883,10 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
inputLayout=0,
|
||||
outDataType=-1,
|
||||
attnOut=attn_output)
|
||||
attn_output_t = torch_npu.npu_transpose(attn_output, (1, 0, 2),
|
||||
require_contiguous=True)
|
||||
attn_output_t = torch.transpose(attn_output, 0, 1)
|
||||
attn_output_t = torch.bmm(attn_output_t, self.w_vc)
|
||||
attn_output = torch_npu.npu_transpose(attn_output_t, (1, 0, 2),
|
||||
require_contiguous=True)
|
||||
attn_output = torch.transpose(attn_output_t, 0, 1)
|
||||
|
||||
output, _ = self.o_proj(attn_output.view(num_tokens, -1))
|
||||
output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))
|
||||
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user