diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 391627c7a..95b962fa3 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -999,6 +999,8 @@ class DeepseekV2AttentionMLA(nn.Module): attention_backend == "flashinfer" or attention_backend == "fa3" or attention_backend == "flashmla" + or attention_backend == "trtllm_mla" + or attention_backend == "cutlass_mla" ): # Use MHA with chunked KV cache when prefilling on long sequences. sum_extend_prefix_lens = (