Waiting for BMM NZ support(Improve TPOP 2ms performance) (#1131)

### What this PR does / why we need it?
W_UV/W_UK_T cannot be converted to nz, because this position will be
fused into transposebatchmatmul, which does not support nz. The weights
are actually converted back to nd in each run.

### Does this PR introduce _any_ user-facing change?
Use #1098 as the baseline, p90 TPOT 90.79ms->88.58ms, improve TPOP 2ms

### How was this patch tested?
use #1101

---------

Signed-off-by: ttanzhiqiang <389825161@qq.com>
This commit is contained in:
ttanzhiqiang
2025-06-15 19:57:02 +08:00
committed by GitHub
parent 0d2074a1ec
commit 4270682383

View File

@@ -648,8 +648,10 @@ class AscendMLAImpl(MLAAttentionImpl):
self.W_UV = W_UV.transpose(0, 1).contiguous()
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
# Waiting for BMM NZ support
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
def _compute_prefill_context(
self,