[0.18.0][BugFix] Add PrefillNoCache state in mla _forward_decode for short prompt (#8264)
### What this PR does / why we need it? This PR is cherry-pick from #8263. This PR aims to fix short prompt problem. The root cause can be found in #8029. Since the previous pr may miss mixed long and short prompt batch, after discussion, we decide to add PrefillNoCache state in mla _forward_decode now instead. Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
@@ -1270,6 +1270,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
AscendAttentionState.SpecDecoding,
|
AscendAttentionState.SpecDecoding,
|
||||||
AscendAttentionState.ChunkedPrefill,
|
AscendAttentionState.ChunkedPrefill,
|
||||||
AscendAttentionState.DecodeOnly,
|
AscendAttentionState.DecodeOnly,
|
||||||
|
AscendAttentionState.PrefillNoCache, # for extremely short prefills
|
||||||
]
|
]
|
||||||
and self.speculative_config is not None
|
and self.speculative_config is not None
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -838,14 +838,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens):
|
def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens):
|
||||||
if np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] == 0):
|
if np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] == 0):
|
||||||
attn_state = AscendAttentionState.PrefillNoCache
|
attn_state = AscendAttentionState.PrefillNoCache
|
||||||
# If all prompts are shorter than or equal to decode threshold, they should
|
|
||||||
# be treated as SpecDecoding for correct forward path in mla attention backend
|
|
||||||
if (
|
|
||||||
self.speculative_config
|
|
||||||
and self.speculative_config.method == "mtp"
|
|
||||||
and np.all(num_scheduled_tokens <= self.decode_threshold)
|
|
||||||
):
|
|
||||||
attn_state = AscendAttentionState.SpecDecoding
|
|
||||||
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
|
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
|
||||||
elif np.all(num_scheduled_tokens == 1):
|
elif np.all(num_scheduled_tokens == 1):
|
||||||
attn_state = AscendAttentionState.DecodeOnly
|
attn_state = AscendAttentionState.DecodeOnly
|
||||||
|
|||||||
Reference in New Issue
Block a user