diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 7a704a9b..2db82e9f 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -1270,6 +1270,7 @@ class AscendMLAImpl(MLAAttentionImpl):
                 AscendAttentionState.SpecDecoding,
                 AscendAttentionState.ChunkedPrefill,
                 AscendAttentionState.DecodeOnly,
+                AscendAttentionState.PrefillNoCache,  # for extremely short prefills
             ] and self.speculative_config is not None
         ):
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 84264698..5fc405dd 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -838,14 +838,6 @@ class NPUModelRunner(GPUModelRunner):
     def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens):
         if np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] == 0):
             attn_state = AscendAttentionState.PrefillNoCache
-            # If all prompts are shorter than or equal to decode threshold, they should
-            # be treated as SpecDecoding for correct forward path in mla attention backend
-            if (
-                self.speculative_config
-                and self.speculative_config.method == "mtp"
-                and np.all(num_scheduled_tokens <= self.decode_threshold)
-            ):
-                attn_state = AscendAttentionState.SpecDecoding
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
         elif np.all(num_scheduled_tokens == 1):
             attn_state = AscendAttentionState.DecodeOnly
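
The two hunks work together: the model runner no longer rewrites PrefillNoCache to SpecDecoding for short MTP prefills, and instead the MLA backend's speculative-decoding branch accepts PrefillNoCache directly. Below is a minimal, hypothetical sketch of that interaction under simplified assumptions; build_attn_state and uses_spec_decode_path are illustrative stand-ins, not the actual vllm-ascend control flow.

# Hypothetical, simplified sketch of how the two hunks interact after this change.
# Names mirror the diff (AscendAttentionState, speculative_config) but the
# surrounding structure is illustrative, not the real vllm-ascend code.
from enum import Enum, auto


class AscendAttentionState(Enum):
    PrefillNoCache = auto()
    SpecDecoding = auto()
    ChunkedPrefill = auto()
    DecodeOnly = auto()


def build_attn_state(num_computed_tokens, num_scheduled_tokens):
    # After the model_runner_v1.py hunk, the runner no longer relabels
    # PrefillNoCache as SpecDecoding for short MTP prefills; the state simply
    # reflects the batch shape.
    if all(t == 0 for t in num_computed_tokens):
        return AscendAttentionState.PrefillNoCache
    if all(t == 1 for t in num_scheduled_tokens):
        return AscendAttentionState.DecodeOnly
    return AscendAttentionState.ChunkedPrefill


def uses_spec_decode_path(attn_state, speculative_config):
    # Mirrors the mla_v1.py hunk: PrefillNoCache now also takes the
    # speculative-decoding forward path when a speculative config is set.
    return (
        attn_state in (
            AscendAttentionState.SpecDecoding,
            AscendAttentionState.ChunkedPrefill,
            AscendAttentionState.DecodeOnly,
            AscendAttentionState.PrefillNoCache,  # for extremely short prefills
        )
        and speculative_config is not None
    )


# Example: an extremely short prefill (no cached tokens) is still labelled
# PrefillNoCache, yet with a speculative config it now reaches the same
# forward path that previously required the SpecDecoding relabeling.
state = build_attn_state(num_computed_tokens=[0, 0], num_scheduled_tokens=[2, 3])
assert state is AscendAttentionState.PrefillNoCache
assert uses_spec_decode_path(state, speculative_config=object())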