diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index a11cecc..aa6c597 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -789,7 +789,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv_no_split = kv_no_split.view(
             B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
-        cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
+        cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
         _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv_no_split,
             self.kv_a_layernorm.weight,
diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py
index 350a6c4..f641f0d 100644
--- a/vllm_ascend/torchair/torchair_mla.py
+++ b/vllm_ascend/torchair/torchair_mla.py
@@ -1006,7 +1006,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
-        cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
+        cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
         _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv,
             self.kv_a_layernorm.weight,
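
A minimal sketch of the selection logic this patch changes, for illustration only. The `enable_kv_nz` flag and the mode strings come from the diff above; the standalone helper below is hypothetical and exists only to make the before/after behavior easy to check in isolation:

    def select_cache_mode(enable_kv_nz: bool) -> str:
        # After this patch, the NZ-layout paged-attention cache mode
        # passed to torch_npu.npu_kv_rmsnorm_rope_cache is "PA_NZ"
        # (previously "PA_BLK_NZ"); the default "PA" is unchanged.
        return "PA_NZ" if enable_kv_nz else "PA"

    assert select_cache_mode(True) == "PA_NZ"
    assert select_cache_mode(False) == "PA"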