From dc6172efd3860ce95b40a7b3e93611f875f06d40 Mon Sep 17 00:00:00 2001 From: ttanzhiqiang <38750855+ttanzhiqiang@users.noreply.github.com> Date: Fri, 23 May 2025 10:18:10 +0800 Subject: [PATCH] update attention nz and mla nz (improve TPOT performance by 6ms) (#909) ### What this PR does / why we need it? Update attention nz and mla nz modules to improve TPOT performance by 6ms Convert W_UV and W_UK_T to NPU format in mla_v1.py Convert layer.weight to NPU format in w8a8.py Signed-off-by: ttanzhiqiang <389825161@qq.com> --- vllm_ascend/attention/mla_v1.py | 6 ++++-- vllm_ascend/quantization/w8a8.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 2d522e4..fabf95e 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -476,9 +476,11 @@ class AscendMLAImpl(MLAAttentionImpl): [self.qk_nope_head_dim, self.v_head_dim], dim=-1) # Convert from (L, N, V) to (N, L, V) - self.W_UV = W_UV.transpose(0, 1) + self.W_UV = W_UV.transpose(0, 1).contiguous() # Convert from (L, N, P) to (N, P, L) - self.W_UK_T = W_UK.permute(1, 2, 0) + self.W_UK_T = W_UK.permute(1, 2, 0).contiguous() + self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) + self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) def _forward_prefill( self, diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index f740d8f..db23cb0 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -110,5 +110,6 @@ class AscendW8A8LinearMethod: requires_grad=False).to(layer.aclnn_input_scale.dtype) if self.transpose_weight: layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29) layer.weight_scale.data = torch.flatten(layer.weight_scale.data) layer.weight_offset.data = torch.flatten(layer.weight_offset.data)