From b6aa5bbdbf342867695b58e8163638abcba6d486 Mon Sep 17 00:00:00 2001
From: Zetong Li <48438720+slippersss@users.noreply.github.com>
Date: Wed, 15 Apr 2026 09:23:52 +0800
Subject: [PATCH] [0.18.0][BugFix] Add PrefillNoCache state in mla
 _forward_decode for short prompt (#8264)

### What this PR does / why we need it?
This PR is cherry-picked from #8263.
This PR aims to fix the short-prompt problem. The root cause can be found
in #8029. Since the previous PR may miss a mixed batch of long and short
prompts, after discussion we decided to add the PrefillNoCache state in
mla _forward_decode now instead.

Signed-off-by: Zetong Li
---
 vllm_ascend/attention/mla_v1.py       | 1 +
 vllm_ascend/worker/model_runner_v1.py | 8 --------
 2 files changed, 1 insertion(+), 8 deletions(-)

diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 7a704a9b..2db82e9f 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -1270,6 +1270,7 @@ class AscendMLAImpl(MLAAttentionImpl):
                 AscendAttentionState.SpecDecoding,
                 AscendAttentionState.ChunkedPrefill,
                 AscendAttentionState.DecodeOnly,
+                AscendAttentionState.PrefillNoCache,  # for extremely short prefills
             ] and self.speculative_config is not None
         ):
 
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 84264698..5fc405dd 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -838,14 +838,6 @@ class NPUModelRunner(GPUModelRunner):
     def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens):
         if np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] == 0):
             attn_state = AscendAttentionState.PrefillNoCache
-            # If all prompts are shorter than or equal to decode threshold, they should
-            # be treated as SpecDecoding for correct forward path in mla attention backend
-            if (
-                self.speculative_config
-                and self.speculative_config.method == "mtp"
-                and np.all(num_scheduled_tokens <= self.decode_threshold)
-            ):
-                attn_state = AscendAttentionState.SpecDecoding
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
         elif np.all(num_scheduled_tokens == 1):
             attn_state = AscendAttentionState.DecodeOnly