enable prefix cache with dp (#10459)

This commit is contained in:
Shu Wang
2025-09-16 20:26:58 -05:00
committed by GitHub
parent e1d45bc280
commit 124097fc5b
2 changed files with 0 additions and 30 deletions

View File

@@ -1098,7 +1098,6 @@ class DeepseekV2AttentionMLA(nn.Module):
attention_backend == "flashinfer" or attention_backend == "flashmla"
) and self.flashinfer_mla_disable_ragged
original_mode = getattr(forward_batch, "_original_forward_mode", None)
if (
not disable_ragged
and forward_batch.forward_mode.is_extend()
@@ -1111,15 +1110,6 @@ class DeepseekV2AttentionMLA(nn.Module):
)
or sum_extend_prefix_lens == 0
)
# TODO(shuw@nvidia.com) Flashinfer cutlass and trtllm_mla backend have accuracy issue on blackwell for
# dp case. Redirect to mla kernel as a workaround.
# Tracked by https://github.com/sgl-project/sglang/issues/9806.
and not (
original_mode is not None
and original_mode.is_decode()
and is_sm100_supported()
and self.current_attention_backend in ("cutlass_mla", "flashinfer")
)
):
return AttnForwardMethod.MHA_CHUNKED_KV
else:
@@ -1128,14 +1118,6 @@ class DeepseekV2AttentionMLA(nn.Module):
# TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
return AttnForwardMethod.MHA_CHUNKED_KV
elif attention_backend == "trtllm_mla":
original_mode = getattr(forward_batch, "_original_forward_mode", None)
if (
original_mode is not None
and original_mode.is_decode()
and is_sm100_supported()
):
return _dispatch_mla_subtype()
sum_extend_prefix_lens = (
sum(forward_batch.extend_prefix_lens_cpu)
if forward_batch.extend_prefix_lens_cpu is not None