From 36acd2ff16d9f0ed888691b99179f9c1531794a6 Mon Sep 17 00:00:00 2001
From: Shu Wang
Date: Fri, 12 Sep 2025 05:20:30 -0500
Subject: [PATCH] Fix chunked prefix cache for nvfp4 (#10180)

Co-authored-by: Elfie Guo
---
 .../layers/attention/trtllm_mla_backend.py    | 20 +++++++++++++-
 .../sglang/srt/model_executor/model_runner.py |  3 ++-
 python/sglang/srt/models/deepseek_v2.py       | 27 +++++++++++++++++++
 3 files changed, 48 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index b8d62c3fa..9b6309d4c 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -20,6 +20,7 @@ from sglang.srt.layers.attention.utils import (
     create_flashmla_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
 
@@ -72,7 +73,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         kv_indptr_buf: Optional[torch.Tensor] = None,
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
-        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+        super().__init__(
+            model_runner,
+            skip_prefill,
+            kv_indptr_buf,
+            q_indptr_decode_buf,
+        )
 
         config = model_runner.model_config
 
@@ -112,6 +118,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
         self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
 
+        self.disable_chunked_prefix_cache = global_server_args_dict[
+            "disable_chunked_prefix_cache"
+        ]
+
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
         Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
@@ -301,6 +311,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             and not forward_batch.forward_mode.is_target_verify()
             and not forward_batch.forward_mode.is_draft_extend()
         ):
+            if self.disable_chunked_prefix_cache:
+                super().init_forward_metadata(forward_batch)
+
             seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
             cum_seq_lens_q = torch.cat(
                 (
@@ -540,6 +553,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             return super().forward_extend(
                 q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
             )
+        # chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
+        if forward_batch.attn_attend_prefix_cache is None:
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
 
         if not forward_batch.attn_attend_prefix_cache:
             q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 39d1ab5fd..8960b0bc8 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -560,18 +560,19 @@ class ModelRunner:
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
+        # TODO(kaixih@nvidia): remove this once we have a better solution for DP attention.
         # For more details, see: https://github.com/sgl-project/sglang/issues/8616
         elif (
             self.dp_size > 1
             and is_sm100_supported()
             and server_args.attention_backend != "triton"
+            and server_args.attention_backend == "trtllm_mla"
         ):
             logger.info(
                 "Disable chunked prefix cache when dp size > 1 and attention backend is not triton."
             )
             server_args.disable_chunked_prefix_cache = True
-
         if not server_args.disable_chunked_prefix_cache:
             logger.info("Chunked prefix cache is turned on.")
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 5452e7e8c..c2ceac39a 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1087,6 +1087,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             disable_ragged = (
                 attention_backend == "flashinfer" or attention_backend == "flashmla"
             ) and self.flashinfer_mla_disable_ragged
+
+            original_mode = getattr(forward_batch, "_original_forward_mode", None)
             if (
                 not disable_ragged
                 and forward_batch.forward_mode.is_extend()
@@ -1099,15 +1101,40 @@ class DeepseekV2AttentionMLA(nn.Module):
                     )
                     or sum_extend_prefix_lens == 0
                 )
+                # TODO(shuw@nvidia.com) Flashinfer cutlass and trtllm_mla backend have accuracy issue on blackwell for
+                # dp case. Redirect to mla kernel as a workaround.
+                # Tracked by https://github.com/sgl-project/sglang/issues/9806.
+                and not (
+                    original_mode is not None
+                    and original_mode.is_decode()
+                    and is_sm100_supported()
+                    and self.current_attention_backend in ("cutlass_mla", "flashinfer")
+                )
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
         elif attention_backend == "trtllm_mla":
+            original_mode = getattr(forward_batch, "_original_forward_mode", None)
+            if (
+                original_mode is not None
+                and original_mode.is_decode()
+                and is_sm100_supported()
+            ):
+                return _dispatch_mla_subtype()
+
+            sum_extend_prefix_lens = (
+                sum(forward_batch.extend_prefix_lens_cpu)
+                if forward_batch.extend_prefix_lens_cpu is not None
+                else 0
+            )
             if (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
+                and (
+                    not self.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
+                )
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
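
Note: the snippet below is not part of the patch. It is a minimal, self-contained sketch of the fallback the change introduces, using hypothetical class and field names (Batch, FlashInferMLA, TRTLLMMLA) rather than the real sglang classes: when chunked prefix cache is disabled, attn_attend_prefix_cache stays None and the TRT-LLM MLA prefill path defers to its FlashInfer MLA parent instead of the chunked-KV kernels.

from dataclasses import dataclass
from typing import Optional


@dataclass
class Batch:
    # None means "chunked prefix cache is not in use for this forward pass".
    attn_attend_prefix_cache: Optional[bool] = None


class FlashInferMLA:
    def forward_extend(self, batch: Batch) -> str:
        return "flashinfer_mla_prefill"


class TRTLLMMLA(FlashInferMLA):
    def __init__(self, disable_chunked_prefix_cache: bool):
        self.disable_chunked_prefix_cache = disable_chunked_prefix_cache

    def forward_extend(self, batch: Batch) -> str:
        # Mirrors the new guard in the patch: with no chunked-prefix metadata,
        # fall back to the parent (FlashInfer MLA) prefill path.
        if batch.attn_attend_prefix_cache is None:
            return super().forward_extend(batch)
        return "trtllm_mla_chunked_prefill"


if __name__ == "__main__":
    backend = TRTLLMMLA(disable_chunked_prefix_cache=True)
    print(backend.forward_extend(Batch()))                                # flashinfer_mla_prefill
    print(backend.forward_extend(Batch(attn_attend_prefix_cache=True)))   # trtllm_mla_chunked_prefill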