[0.18.0][BugFix] Fix attention state of short prompt for correct forwarding (#8088)

### What this PR does / why we need it?
This PR is cherry-picked from #8029.

This PR fixes the attention state of short prompts so they are forwarded
correctly. A batch of short prompts (prefill tokens less than or equal to
num_spec_tokens + 1) is treated as decode requests by
split_decodes_and_prefills, which contradicts its original PrefillNoCache
attention state. As a result, these short prompts are routed into a
mismatched attention branch and incur errors.
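
For illustration, here is a minimal sketch of why such a batch gets bucketed as decode while still carrying a prefill state. The threshold arithmetic follows the PR description; the concrete values and the standalone checks are assumptions for illustration, not the actual `split_decodes_and_prefills` implementation:

```python
import numpy as np

# Hypothetical values for illustration only.
num_spec_tokens = 3                      # e.g. MTP speculative tokens
decode_threshold = num_spec_tokens + 1   # requests at or below this count as decode

# A fresh batch of very short prompts: nothing is cached yet, so the
# original logic labels the batch PrefillNoCache.
num_scheduled_tokens = np.array([2, 4, 3])
num_computed_tokens = np.zeros(3, dtype=np.int64)

bucketed_as_decode = bool(np.all(num_scheduled_tokens <= decode_threshold))  # True
labeled_fresh_prefill = bool(np.all(num_computed_tokens == 0))               # True

# Both conditions hold at once: the batch is forwarded down the decode path
# while carrying a PrefillNoCache attention state, which is the mismatch
# this PR fixes.
print(bucketed_as_decode, labeled_fresh_prefill)  # True True
```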

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
Tested by CI.

Signed-off-by: Zetong Li <slippersss@126.com>

@@ -838,6 +838,14 @@ class NPUModelRunner(GPUModelRunner):
    def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens):
        if np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] == 0):
            attn_state = AscendAttentionState.PrefillNoCache
            # If all prompts are shorter than or equal to the decode threshold,
            # they should be treated as SpecDecoding for the correct forward
            # path in the MLA attention backend.
            if (
                self.speculative_config
                and self.speculative_config.method == "mtp"
                and np.all(num_scheduled_tokens <= self.decode_threshold)
            ):
                attn_state = AscendAttentionState.SpecDecoding
        # We assume this is the decode stage, where prefill occurs but only
        # one token misses the cache.
        elif np.all(num_scheduled_tokens == 1):
            attn_state = AscendAttentionState.DecodeOnly
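
With this change, a batch of fresh short prompts under MTP is labeled `SpecDecoding` rather than `PrefillNoCache`, so the attention state agrees with the decode bucket that `split_decodes_and_prefills` places it in, and the MLA attention backend takes the matching branch.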