[Feature]Support ragged prefill in flashinfer mla backend (#3967)

Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: pankajroark <pankajroark@users.noreply.github.com>
This commit is contained in:
Baizhou Zhang
2025-02-28 18:13:56 -08:00
committed by GitHub
parent f3b99f73b3
commit 90a4b7d98a
9 changed files with 308 additions and 407 deletions

View File

@@ -520,10 +520,11 @@ class DeepseekV2AttentionMLA(nn.Module):
def no_absorb() -> bool:
if global_server_args_dict["enable_flashinfer_mla"]:
# Flashinfer MLA: Only do not use absorb when prefilling/extending without radix cache
# Flashinfer MLA: Do not absorb when enabling ragged prefill
return (
global_server_args_dict["disable_radix_cache"]
not global_server_args_dict["flashinfer_mla_disable_ragged"]
and forward_batch.forward_mode.is_extend()
and forward_batch.extend_prefix_lens.sum() == 0
)
else:
# Triton: Use normal computation for prefill and use weight absorption for extend/decode