update attention nz and mla nz(Improve TPOP 6ms performance) (#909)

### What this PR does / why we need it?
Update attention nz and mla nz modules to improve TPOP 6ms performance
Convert W_UV and W_UK_T to NPU format in mla_v1.py
Convert layer.weight to NPU format in w8a8.py

Signed-off-by: ttanzhiqiang <389825161@qq.com>
This commit is contained in:
ttanzhiqiang
2025-05-23 10:18:10 +08:00
committed by GitHub
parent 7153d8890b
commit dc6172efd3
2 changed files with 5 additions and 2 deletions

View File

@@ -476,9 +476,11 @@ class AscendMLAImpl(MLAAttentionImpl):
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1)
self.W_UV = W_UV.transpose(0, 1).contiguous()
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
def _forward_prefill(
self,

View File

@@ -110,5 +110,6 @@ class AscendW8A8LinearMethod:
requires_grad=False).to(layer.aclnn_input_scale.dtype)
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)