diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 3aeabc6..9891a02 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -842,7 +842,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): def _make_attention_mask(self, seq_lens, query_lens, position, attn_state) -> torch.Tensor: # Chunk Prefill situation. - if attn_state == AscendAttentionState.ChunkedPrefill: + if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla: return self.attn_mask_builder.get_splitfuse_attn_mask( seq_lens, query_lens, position, self.dtype, self.device) # Prefill without cache situation.