diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 0c600477..bac7345a 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -838,6 +838,14 @@ class NPUModelRunner(GPUModelRunner):
     def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens):
         if np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] == 0):
             attn_state = AscendAttentionState.PrefillNoCache
+            # If all prompts are shorter than or equal to the decode threshold, they should
+            # be treated as SpecDecoding so the MLA attention backend takes the correct forward path.
+            if (
+                    self.speculative_config
+                    and self.speculative_config.method == "mtp"
+                    and np.all(num_scheduled_tokens <= self.decode_threshold)
+            ):
+                attn_state = AscendAttentionState.SpecDecoding
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
         elif np.all(num_scheduled_tokens == 1):
             attn_state = AscendAttentionState.DecodeOnly
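
A minimal, self-contained sketch of the selection logic this hunk changes (not the actual runner code): the enum values, the SpecConfig stand-in for speculative_config, and the literal decode_threshold are simplified assumptions used only to illustrate how short MTP prompts now fall into SpecDecoding instead of PrefillNoCache.

    import numpy as np
    from dataclasses import dataclass
    from enum import Enum


    class AscendAttentionState(Enum):
        PrefillNoCache = "prefill_no_cache"
        SpecDecoding = "spec_decoding"
        DecodeOnly = "decode_only"


    @dataclass
    class SpecConfig:
        method: str = "mtp"


    def build_attn_state(num_computed_tokens_cpu, num_scheduled_tokens,
                         speculative_config, decode_threshold):
        if np.all(num_computed_tokens_cpu == 0):
            attn_state = AscendAttentionState.PrefillNoCache
            # With MTP speculative decoding enabled, prompts of at most
            # decode_threshold tokens each take the SpecDecoding path.
            if (speculative_config is not None
                    and speculative_config.method == "mtp"
                    and np.all(num_scheduled_tokens <= decode_threshold)):
                attn_state = AscendAttentionState.SpecDecoding
        elif np.all(num_scheduled_tokens == 1):
            attn_state = AscendAttentionState.DecodeOnly
        else:
            attn_state = None  # other states are handled elsewhere in the runner
        return attn_state


    # Example: two fresh requests, 2 scheduled tokens each, MTP enabled,
    # decode_threshold of 2 -> SpecDecoding rather than PrefillNoCache.
    print(build_attn_state(np.array([0, 0]), np.array([2, 2]), SpecConfig(), 2))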