update attention nz and mla nz(Improve TPOP 6ms performance) (#909)
### What this PR does / why we need it? Update attention nz and mla nz modules to improve TPOP 6ms performance Convert W_UV and W_UK_T to NPU format in mla_v1.py Convert layer.weight to NPU format in w8a8.py Signed-off-by: ttanzhiqiang <389825161@qq.com>
This commit is contained in:
@@ -476,9 +476,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
# Convert from (L, N, V) to (N, L, V)
|
||||
self.W_UV = W_UV.transpose(0, 1)
|
||||
self.W_UV = W_UV.transpose(0, 1).contiguous()
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
self.W_UK_T = W_UK.permute(1, 2, 0)
|
||||
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
|
||||
self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
||||
self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user