Fix chunked prefix cache for nvfp4 (#10180)
Co-authored-by: Elfie Guo <elfieg@nvidia.com>
This commit is contained in:
@@ -1087,6 +1087,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
disable_ragged = (
|
||||
attention_backend == "flashinfer" or attention_backend == "flashmla"
|
||||
) and self.flashinfer_mla_disable_ragged
|
||||
|
||||
original_mode = getattr(forward_batch, "_original_forward_mode", None)
|
||||
if (
|
||||
not disable_ragged
|
||||
and forward_batch.forward_mode.is_extend()
|
||||
@@ -1099,15 +1101,40 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
)
|
||||
or sum_extend_prefix_lens == 0
|
||||
)
|
||||
# TODO(shuw@nvidia.com) Flashinfer cutlass and trtllm_mla backend have accuracy issue on blackwell for
|
||||
# dp case. Redirect to mla kernel as a workaround.
|
||||
# Tracked by https://github.com/sgl-project/sglang/issues/9806.
|
||||
and not (
|
||||
original_mode is not None
|
||||
and original_mode.is_decode()
|
||||
and is_sm100_supported()
|
||||
and self.current_attention_backend in ("cutlass_mla", "flashinfer")
|
||||
)
|
||||
):
|
||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||
else:
|
||||
return _dispatch_mla_subtype()
|
||||
elif attention_backend == "trtllm_mla":
|
||||
original_mode = getattr(forward_batch, "_original_forward_mode", None)
|
||||
if (
|
||||
original_mode is not None
|
||||
and original_mode.is_decode()
|
||||
and is_sm100_supported()
|
||||
):
|
||||
return _dispatch_mla_subtype()
|
||||
|
||||
sum_extend_prefix_lens = (
|
||||
sum(forward_batch.extend_prefix_lens_cpu)
|
||||
if forward_batch.extend_prefix_lens_cpu is not None
|
||||
else 0
|
||||
)
|
||||
if (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
and not forward_batch.forward_mode.is_target_verify()
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
and (
|
||||
not self.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
|
||||
)
|
||||
):
|
||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user