Fix FA3 DeepSeek prefill performance regression (#5624)
Co-authored-by: ispobock <ispobaoke@gmail.com>
This commit is contained in:
@@ -583,13 +583,17 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
return AttnForwardMethod.MLA
|
return AttnForwardMethod.MLA
|
||||||
elif self.attention_backend == "fa3":
|
elif self.attention_backend == "fa3":
|
||||||
# Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
|
# Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
|
||||||
|
if forward_batch.extend_prefix_lens_cpu is not None:
|
||||||
|
sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
|
||||||
if (
|
if (
|
||||||
forward_batch.forward_mode.is_extend()
|
forward_batch.forward_mode.is_extend()
|
||||||
and not self.disable_chunked_prefix_cache
|
and not self.disable_chunked_prefix_cache
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
and sum(forward_batch.extend_prefix_lens_cpu)
|
and (
|
||||||
>= self.chunked_prefix_cache_threshold
|
sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
|
||||||
|
or sum_extend_prefix_lens == 0
|
||||||
|
)
|
||||||
):
|
):
|
||||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user