From 9708d353b756563107e346081298a142fabd584f Mon Sep 17 00:00:00 2001 From: Yongfei Xu Date: Fri, 22 Aug 2025 09:19:44 +0800 Subject: [PATCH] Support MHA with chunked prefix cache for flashinfer/flashmla backend, support page size > 1 for MHA chunked prefix (#8616) Co-authored-by: xuyongfei.xyf --- .../attention/flashattention_backend.py | 18 ++- .../attention/flashinfer_mla_backend.py | 142 +++++++++++++++++- python/sglang/srt/managers/schedule_batch.py | 1 + .../srt/model_executor/forward_batch_info.py | 3 + .../sglang/srt/model_executor/model_runner.py | 3 - python/sglang/srt/models/deepseek_v2.py | 92 ++++-------- 6 files changed, 184 insertions(+), 75 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 50e952e22..3bdf7c7c2 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -776,14 +776,13 @@ class FlashAttentionBackend(AttentionBackend): o = result else: if ( - not global_server_args_dict["disable_chunked_prefix_cache"] - and forward_batch.attn_attend_prefix_cache is not None + forward_batch.attn_attend_prefix_cache is not None and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_draft_extend() ): # Do multi-head attention with chunked prefix cache - if forward_batch.attn_attend_prefix_cache: + assert not global_server_args_dict["disable_chunked_prefix_cache"] # MHA for chunked prefix kv cache when running model with MLA assert forward_batch.prefix_chunk_idx is not None assert forward_batch.prefix_chunk_cu_seq_lens is not None @@ -792,7 +791,8 @@ class FlashAttentionBackend(AttentionBackend): chunk_idx = forward_batch.prefix_chunk_idx assert chunk_idx >= 0 - output, lse, *rest = flash_attn_varlen_func( + assert forward_batch.mha_return_lse + output = flash_attn_varlen_func( q=q.view(-1, layer.tp_q_head_num, layer.head_dim), k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype), @@ -806,7 +806,7 @@ class FlashAttentionBackend(AttentionBackend): ) else: # MHA for extend part of sequence without attending prefix kv cache - output, lse, *rest = flash_attn_varlen_func( + output = flash_attn_varlen_func( q=q.view(-1, layer.tp_q_head_num, layer.head_dim), k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype), @@ -816,9 +816,13 @@ class FlashAttentionBackend(AttentionBackend): max_seqlen_k=metadata.max_seq_len_q, softmax_scale=layer.scaling, causal=True, - return_softmax_lse=True, + return_softmax_lse=forward_batch.mha_return_lse, ) - return output, lse + if forward_batch.mha_return_lse: + output, lse, *rest = output + lse = torch.transpose(lse, 0, 1).contiguous() + return output, lse + return output else: # Do absorbed multi-latent attention kv_cache = forward_batch.token_to_kv_pool.get_key_buffer( diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index fb476a762..a295cc906 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -59,6 +59,115 @@ class PrefillMetadata: global_workspace_buffer = None +class FlashInferMhaChunkKVRunner: + def __init__( + self, model_runner: ModelRunner, attn_backend: "FlashInferMlaAttnBackend" + ): + # Parse Constants + self.num_local_heads = ( + model_runner.model_config.num_attention_heads // get_attention_tp_size() + ) + self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim + self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim + self.v_head_dim = model_runner.model_config.v_head_dim + self.data_type = model_runner.dtype + self.q_data_type = model_runner.dtype + + # Buffers and wrappers + self.qo_indptr = attn_backend.qo_indptr + self.workspace_buffer = attn_backend.workspace_buffer + self.fmha_backend = attn_backend.fmha_backend + + self.chunk_ragged_wrappers = [] + self.ragged_wrapper = attn_backend.prefill_wrapper_ragged + + def update_prefix_chunks(self, num_prefix_chunks: int): + while num_prefix_chunks > len(self.chunk_ragged_wrappers): + ragged_wrapper = BatchPrefillWithRaggedKVCacheWrapper( + self.workspace_buffer, "NHD", backend=self.fmha_backend + ) + self.chunk_ragged_wrappers.append(ragged_wrapper) + + def update_wrapper( + self, + forward_batch: ForwardBatch, + ): + assert forward_batch.num_prefix_chunks is not None + num_prefix_chunks = forward_batch.num_prefix_chunks + self.update_prefix_chunks(num_prefix_chunks) + + prefix_lens = forward_batch.extend_prefix_lens + seq_lens = forward_batch.seq_lens + + bs = len(seq_lens) + qo_indptr = self.qo_indptr + qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) + qo_indptr = qo_indptr[: bs + 1] + + for chunk_idx in range(forward_batch.num_prefix_chunks): + # MHA for chunked prefix kv cache when running model with MLA + assert forward_batch.prefix_chunk_idx is not None + assert forward_batch.prefix_chunk_cu_seq_lens is not None + assert forward_batch.prefix_chunk_max_seq_lens is not None + + kv_indptr = forward_batch.prefix_chunk_cu_seq_lens[chunk_idx] + wrapper = self.chunk_ragged_wrappers[chunk_idx] + wrapper.begin_forward( + qo_indptr=qo_indptr, + kv_indptr=kv_indptr, + num_qo_heads=self.num_local_heads, + num_kv_heads=self.num_local_heads, + head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim, + head_dim_vo=self.v_head_dim, + q_data_type=self.q_data_type, + causal=False, + ) + # ragged prefill + self.ragged_wrapper.begin_forward( + qo_indptr=qo_indptr, + kv_indptr=qo_indptr, + num_qo_heads=self.num_local_heads, + num_kv_heads=self.num_local_heads, + head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim, + head_dim_vo=self.v_head_dim, + q_data_type=self.q_data_type, + causal=True, + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + ): + logits_soft_cap = layer.logit_cap + if forward_batch.attn_attend_prefix_cache: + chunk_idx = forward_batch.prefix_chunk_idx + assert chunk_idx >= 0 + wrapper = self.chunk_ragged_wrappers[chunk_idx] + o1, s1 = wrapper.forward_return_lse( + q.view(-1, layer.tp_q_head_num, layer.head_dim), + k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), + v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype), + causal=False, + sm_scale=layer.scaling, + logits_soft_cap=logits_soft_cap, + ) + else: + o1, s1 = self.ragged_wrapper.forward_return_lse( + q.view(-1, layer.tp_q_head_num, layer.head_dim), + k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), + v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype), + causal=True, + sm_scale=layer.scaling, + logits_soft_cap=logits_soft_cap, + ) + + return o1, s1 + + class FlashInferMLAAttnBackend(AttentionBackend): """Flashinfer attention kernels.""" @@ -74,6 +183,12 @@ class FlashInferMLAAttnBackend(AttentionBackend): self.max_context_len = model_runner.model_config.context_len self.device = model_runner.device self.skip_prefill = skip_prefill + self.enable_chunk_kv = ( + not skip_prefill + and global_server_args_dict["disaggregation_mode"] != "decode" + and not global_server_args_dict["disable_chunked_prefix_cache"] + and not global_server_args_dict["flashinfer_mla_disable_ragged"] + ) self.page_size = model_runner.page_size # Allocate buffers @@ -117,11 +232,11 @@ class FlashInferMLAAttnBackend(AttentionBackend): else: self.q_indptr_decode = q_indptr_decode_buf - fmha_backend = "auto" + self.fmha_backend = "auto" if is_sm100_supported(): - fmha_backend = "cutlass" + self.fmha_backend = "cutlass" self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper( - self.workspace_buffer, "NHD", backend=fmha_backend + self.workspace_buffer, "NHD", backend=self.fmha_backend ) if not self.skip_prefill: @@ -145,6 +260,8 @@ class FlashInferMLAAttnBackend(AttentionBackend): self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill( model_runner, self ) + if self.enable_chunk_kv: + self.mha_chunk_kv_cache = FlashInferMhaChunkKVRunner(model_runner, self) self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode( model_runner, self @@ -373,6 +490,10 @@ class FlashInferMLAAttnBackend(AttentionBackend): def get_cuda_graph_seq_len_fill_value(self): return 1 + def init_mha_chunk_metadata(self, forward_batch: ForwardBatch): + """Init the metadata for a forward pass.""" + self.mha_chunk_kv_cache.update_wrapper(forward_batch) + def forward_extend( self, q: torch.Tensor, @@ -384,6 +505,16 @@ class FlashInferMLAAttnBackend(AttentionBackend): q_rope: Optional[torch.Tensor] = None, k_rope: Optional[torch.Tensor] = None, ): + if ( + forward_batch.attn_attend_prefix_cache is not None + and forward_batch.mha_return_lse + ): # MHA Chunk + assert self.enable_chunk_kv + assert q_rope is None + assert k_rope is None + o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch) + return o1, s1 + cache_loc = forward_batch.out_cache_loc logits_soft_cap = layer.logit_cap prefill_wrapper_paged = self.forward_metadata.prefill_wrapper @@ -412,8 +543,8 @@ class FlashInferMLAAttnBackend(AttentionBackend): k = torch.cat([k, k_rope], dim=-1) o = self.prefill_wrapper_ragged.forward( qall, - k.view(-1, layer.tp_k_head_num, layer.head_dim), - v.view(-1, layer.tp_k_head_num, layer.v_head_dim), + k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), + v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype), causal=True, sm_scale=layer.scaling, logits_soft_cap=logits_soft_cap, @@ -732,6 +863,7 @@ class FlashInferMLAIndicesUpdaterPrefill: head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim, head_dim_vo=self.v_head_dim, q_data_type=self.q_data_type, + causal=True, ) else: # mla paged prefill diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 5b45154db..95ec32999 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -106,6 +106,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ "enable_symm_mem", "quantization", "enable_custom_logit_processor", + "disaggregation_mode", ] # Put some global args for easy access diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index bceb0759e..65c0a07f8 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -241,6 +241,9 @@ class ForwardBatch: prefix_chunk_num_tokens: Optional[List[int]] = None # KV Indices for each chunk prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None + # For MLA chunked prefix cache used in chunked prefill + # Tell attention backend whether lse needs to be returned + mha_return_lse: Optional[bool] = None # For multimodal mm_inputs: Optional[List[MultimodalInputs]] = None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index c43c502da..acfeaee3d 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -518,9 +518,6 @@ class ModelRunner: if not self.use_mla_backend: server_args.disable_chunked_prefix_cache = True - elif self.page_size > 1: - logger.info("Disable chunked prefix cache when page size > 1.") - server_args.disable_chunked_prefix_cache = True if not server_args.disable_chunked_prefix_cache: logger.info("Chunked prefix cache is turned on.") diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 434cec4b1..391627c7a 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -995,29 +995,31 @@ class DeepseekV2AttentionMLA(nn.Module): if attention_backend == "ascend": return AttnForwardMethod.MLA - elif attention_backend == "flashinfer": + elif ( + attention_backend == "flashinfer" + or attention_backend == "fa3" + or attention_backend == "flashmla" + ): + # Use MHA with chunked KV cache when prefilling on long sequences. + sum_extend_prefix_lens = ( + sum(forward_batch.extend_prefix_lens_cpu) + if forward_batch.extend_prefix_lens_cpu is not None + else 0 + ) # Flashinfer MLA: Do not absorb when enabling ragged prefill + disable_ragged = ( + attention_backend == "flashinfer" or attention_backend == "flashmla" + ) and self.flashinfer_mla_disable_ragged if ( - not self.flashinfer_mla_disable_ragged + not disable_ragged and forward_batch.forward_mode.is_extend() and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_draft_extend() - and sum(forward_batch.extend_prefix_lens_cpu) == 0 - ): - return AttnForwardMethod.MHA - else: - return _dispatch_mla_subtype() - elif attention_backend == "fa3": - # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences. - if forward_batch.extend_prefix_lens_cpu is not None: - sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu) - if ( - forward_batch.forward_mode.is_extend() - and not self.disable_chunked_prefix_cache - and not forward_batch.forward_mode.is_target_verify() - and not forward_batch.forward_mode.is_draft_extend() and ( - sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold + ( + sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold + and not self.disable_chunked_prefix_cache + ) or sum_extend_prefix_lens == 0 ) ): @@ -1685,7 +1687,6 @@ class DeepseekV2AttentionMLA(nn.Module): k[..., self.qk_nope_head_dim :] = k_pe output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False) - lse = torch.transpose(lse, 0, 1).contiguous() tmp_output = torch.empty_like(accum_output) tmp_lse = torch.empty_like(accum_lse) merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse) @@ -1707,55 +1708,26 @@ class DeepseekV2AttentionMLA(nn.Module): # will be helpful for understanding the purpose of this function. # First do normal mha forward to get output for extended part - if self.q_lora_rank is not None: - q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 - ) - q = self.q_a_layernorm(q) - q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) - else: - q = self.q_proj(hidden_states)[0].view( - -1, self.num_local_heads, self.qk_head_dim - ) - latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] - _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - latent_cache = latent_cache.unsqueeze(1) - kv_a = self.kv_a_layernorm(kv_a) - kv = self.kv_b_proj(kv_a)[0] - kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim :] - k_pe = latent_cache[:, :, self.kv_lora_rank :] - - q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim :] = q_pe - k = torch.empty_like(q) - k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe - - latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank :] = k_pe - - # Save latent cache - forward_batch.token_to_kv_pool.set_kv_buffer( - self.attn_mha, forward_batch.out_cache_loc, latent_cache, None + return self.forward_normal_prepare( + positions, hidden_states, forward_batch, zero_allocator ) - return q, k, v, forward_batch - def forward_normal_chunked_kv_core(self, q, k, v, forward_batch): + has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu) + # Only initialize the info once + if has_extend_prefix and forward_batch.num_prefix_chunks is None: + forward_batch.prepare_chunked_prefix_cache_info(q.device) + if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"): + forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch) + + forward_batch.mha_return_lse = has_extend_prefix # Do mha for extended part without prefix forward_batch.set_attn_attend_prefix_cache(False) - attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False) - lse = torch.transpose(lse, 0, 1).contiguous() + attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False) # Do mha attention with chunked prefix cache if there are any sequence with prefix - if any(forward_batch.extend_prefix_lens_cpu): - # Only initialize the info once - if forward_batch.num_prefix_chunks is None: - forward_batch.prepare_chunked_prefix_cache_info(q.device) - + if has_extend_prefix: + attn_output, lse = attn_output forward_batch.set_attn_attend_prefix_cache(True) attn_output = self._chunked_prefix_attn_mha( q=q,