[Feature]Support ragged prefill in flashinfer mla backend (#3967)
Co-authored-by: Yineng Zhang <me@zhyncs.com> Co-authored-by: pankajroark <pankajroark@users.noreply.github.com>
This commit is contained in:
@@ -520,10 +520,11 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
|
||||
def no_absorb() -> bool:
|
||||
if global_server_args_dict["enable_flashinfer_mla"]:
|
||||
# Flashinfer MLA: Only do not use absorb when prefilling/extending without radix cache
|
||||
# Flashinfer MLA: Do not absorb when enabling ragged prefill
|
||||
return (
|
||||
global_server_args_dict["disable_radix_cache"]
|
||||
not global_server_args_dict["flashinfer_mla_disable_ragged"]
|
||||
and forward_batch.forward_mode.is_extend()
|
||||
and forward_batch.extend_prefix_lens.sum() == 0
|
||||
)
|
||||
else:
|
||||
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
||||
|
||||
Reference in New Issue
Block a user