update attention nz and mla nz (improve TPOP 6 ms performance) (#909)
### What this PR does / why we need it? Update the attention nz and MLA nz modules to improve TPOP performance by 6 ms: convert W_UV and W_UK_T to NPU format in mla_v1.py, and convert layer.weight to NPU format in w8a8.py. Signed-off-by: ttanzhiqiang <389825161@qq.com>
This commit is contained in:
@@ -476,9 +476,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
# Convert from (L, N, V) to (N, L, V)
|
||||
self.W_UV = W_UV.transpose(0, 1)
|
||||
self.W_UV = W_UV.transpose(0, 1).contiguous()
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
self.W_UK_T = W_UK.permute(1, 2, 0)
|
||||
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
|
||||
self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
||||
self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
|
||||
@@ -110,5 +110,6 @@ class AscendW8A8LinearMethod:
|
||||
requires_grad=False).to(layer.aclnn_input_scale.dtype)
|
||||
if self.transpose_weight:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
||||
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
||||
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
||||
|
||||
Reference in New Issue
Block a user