diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index a21f392e8..7835c3fa1 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -564,18 +564,6 @@ class ModelRunner:
 
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
-        # TODO(kaixih@nvidia): remove this once we have a better solution for DP attention.
-        # For more details, see: https://github.com/sgl-project/sglang/issues/8616
-        elif (
-            self.dp_size > 1
-            and is_sm100_supported()
-            and server_args.attention_backend != "triton"
-            and server_args.attention_backend == "trtllm_mla"
-        ):
-            logger.info(
-                "Disable chunked prefix cache when dp size > 1 and attention backend is not triton."
-            )
-            server_args.disable_chunked_prefix_cache = True
 
         if not server_args.disable_chunked_prefix_cache:
             logger.info("Chunked prefix cache is turned on.")
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index f905851b6..5bd796daa 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1098,7 +1098,6 @@ class DeepseekV2AttentionMLA(nn.Module):
                 attention_backend == "flashinfer" or attention_backend == "flashmla"
             ) and self.flashinfer_mla_disable_ragged
 
-            original_mode = getattr(forward_batch, "_original_forward_mode", None)
             if (
                 not disable_ragged
                 and forward_batch.forward_mode.is_extend()
@@ -1111,15 +1110,6 @@ class DeepseekV2AttentionMLA(nn.Module):
                     )
                     or sum_extend_prefix_lens == 0
                 )
-                # TODO(shuw@nvidia.com) Flashinfer cutlass and trtllm_mla backend have accuracy issue on blackwell for
-                # dp case. Redirect to mla kernel as a workaround.
-                # Tracked by https://github.com/sgl-project/sglang/issues/9806.
-                and not (
-                    original_mode is not None
-                    and original_mode.is_decode()
-                    and is_sm100_supported()
-                    and self.current_attention_backend in ("cutlass_mla", "flashinfer")
-                )
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
@@ -1128,14 +1118,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
             return AttnForwardMethod.MHA_CHUNKED_KV
         elif attention_backend == "trtllm_mla":
-            original_mode = getattr(forward_batch, "_original_forward_mode", None)
-            if (
-                original_mode is not None
-                and original_mode.is_decode()
-                and is_sm100_supported()
-            ):
-                return _dispatch_mla_subtype()
-
             sum_extend_prefix_lens = (
                 sum(forward_batch.extend_prefix_lens_cpu)
                 if forward_batch.extend_prefix_lens_cpu is not None