[0.18.0][BugFix] Add PrefillNoCache state in mla _forward_decode for short prompt (#8264)
### What this PR does / why we need it? This PR is cherry-pick from #8263. This PR aims to fix short prompt problem. The root cause can be found in #8029. Since the previous pr may miss mixed long and short prompt batch, after discussion, we decide to add PrefillNoCache state in mla _forward_decode now instead. Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
@@ -838,14 +838,6 @@ class NPUModelRunner(GPUModelRunner):
|
||||
def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens):
|
||||
if np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] == 0):
|
||||
attn_state = AscendAttentionState.PrefillNoCache
|
||||
# If all prompts are shorter than or equal to decode threshold, they should
|
||||
# be treated as SpecDecoding for correct forward path in mla attention backend
|
||||
if (
|
||||
self.speculative_config
|
||||
and self.speculative_config.method == "mtp"
|
||||
and np.all(num_scheduled_tokens <= self.decode_threshold)
|
||||
):
|
||||
attn_state = AscendAttentionState.SpecDecoding
|
||||
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
|
||||
elif np.all(num_scheduled_tokens == 1):
|
||||
attn_state = AscendAttentionState.DecodeOnly
|
||||
|
||||
Reference in New Issue
Block a user