From 20c90be23de7575f1d6a603b2cb6763f1ec903b8 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 28 Mar 2025 18:30:14 -0700 Subject: [PATCH] [Feature] Support FA3 backend for MLA (#4831) --- .../attention/flashattention_backend.py | 244 ++++++++++++------ .../sglang/srt/model_executor/model_runner.py | 6 +- python/sglang/srt/models/deepseek_v2.py | 4 + 3 files changed, 180 insertions(+), 74 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index e1f7ea76f..ac549d5ea 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -13,7 +13,9 @@ from typing import TYPE_CHECKING, Optional, Union import torch +from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode if TYPE_CHECKING: @@ -58,6 +60,9 @@ class FlashAttentionBackend(AttentionBackend): self.decode_cuda_graph_metadata = {} self.req_to_token = model_runner.req_to_token_pool.req_to_token self.page_size = model_runner.page_size + self.use_mla = ( + model_runner.model_config.attention_arch == AttentionArch.MLA + ) and (not global_server_args_dict["disable_mla"]) def init_forward_metadata(self, forward_batch: ForwardBatch): """Initialize forward metadata to cache repetitive calculations.""" @@ -117,23 +122,30 @@ class FlashAttentionBackend(AttentionBackend): forward_batch: ForwardBatch, save_kv_cache=True, ): - cache_loc = ( - forward_batch.out_cache_loc - if not layer.is_cross_attention - else forward_batch.encoder_out_cache_loc - ) if k is not None: assert v is not None if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_loc, k, v, layer.k_scale, layer.v_scale + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc ) + if not self.use_mla: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) + else: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, + cache_loc, + k, + v, + ) # Use precomputed metadata metadata = self.forward_metadata - # # Use Flash Attention for prefill # Calculate window size (can be moved to metadata if layer properties don't change) # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1 # here is two side inclusive @@ -142,36 +154,72 @@ class FlashAttentionBackend(AttentionBackend): if layer.sliding_window_size is not None else (-1, -1) ) - kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) - key_cache, value_cache = kv_cache[0], kv_cache[1] - - key_cache = key_cache.view( - -1, self.page_size, layer.tp_k_head_num, layer.head_dim - ) - value_cache = value_cache.view( - -1, self.page_size, layer.tp_v_head_num, layer.head_dim - ) page_table = metadata.page_table - o = flash_attn_with_kvcache( - q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - k_cache=key_cache, - v_cache=value_cache, - page_table=page_table, - cache_seqlens=metadata.cache_seqlens_int32, - cu_seqlens_q=metadata.cu_seqlens_q, - cu_seqlens_k_new=metadata.cu_seqlens_k, - max_seqlen_q=metadata.max_seq_len_q, - softmax_scale=layer.scaling, - causal=True, - window_size=window_size, - 
softcap=layer.logit_cap, - k_descale=layer.k_scale, - v_descale=layer.v_scale, - ) + # # Use Flash Attention for prefill + if not self.use_mla: + # Do multi-head attention + kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + key_cache, value_cache = kv_cache[0], kv_cache[1] + key_cache = key_cache.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim + ) + value_cache = value_cache.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim + ) + o = flash_attn_with_kvcache( + q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k_cache=key_cache, + v_cache=value_cache, + page_table=page_table, + cache_seqlens=metadata.cache_seqlens_int32, + cu_seqlens_q=metadata.cu_seqlens_q, + cu_seqlens_k_new=metadata.cu_seqlens_k, + max_seqlen_q=metadata.max_seq_len_q, + softmax_scale=layer.scaling, + causal=True, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=layer.k_scale, + v_descale=layer.v_scale, + ) + else: + # Do absorbed multi-latent attention + kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + k_rope = kv_cache[:, :, layer.v_head_dim :] + c_kv = kv_cache[:, :, : layer.v_head_dim] + k_rope_cache = k_rope.view( + -1, + self.page_size, + layer.tp_k_head_num, + layer.head_dim - layer.v_head_dim, + ) + c_kv_cache = c_kv.view( + -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim + ) - return o.view(-1, layer.tp_q_head_num * layer.head_dim) + q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + q_nope = q_all[:, :, : layer.v_head_dim] + q_rope = q_all[:, :, layer.v_head_dim :] + o = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope_cache, + v_cache=c_kv_cache, + qv=q_nope, + page_table=page_table, + cache_seqlens=metadata.cache_seqlens_int32, + cu_seqlens_q=metadata.cu_seqlens_q, + cu_seqlens_k_new=metadata.cu_seqlens_k, + max_seqlen_q=metadata.max_seq_len_q, + softmax_scale=layer.scaling, + causal=True, + softcap=layer.logit_cap, + k_descale=layer.k_scale, + v_descale=layer.v_scale, + ) + + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) def forward_decode( self, @@ -184,24 +232,29 @@ class FlashAttentionBackend(AttentionBackend): ) -> torch.Tensor: """Forward pass with FlashAttention using precomputed metadata.""" # Save KV cache if needed - if k is not None and v is not None and save_kv_cache: - cache_loc = ( - forward_batch.out_cache_loc - if not layer.is_cross_attention - else forward_batch.encoder_out_cache_loc - ) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_loc, k, v, layer.k_scale, layer.v_scale - ) + if k is not None: + assert v is not None + if save_kv_cache: + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) + if not self.use_mla: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) + else: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, + cache_loc, + k, + v, + ) - # Get KV cache - kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) - key_cache, value_cache = kv_cache[0], kv_cache[1] # Use precomputed metadata metadata = self.forward_metadata - # Pre-reshape query tensor - q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) # Calculate window size (can be moved to metadata if layer properties don't change) # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1 # here is two side inclusive @@ -210,33 +263,79 
@@ class FlashAttentionBackend(AttentionBackend): if layer.sliding_window_size is not None else (-1, -1) ) - # Run attention with precomputed values - key_cache = key_cache.view( - -1, self.page_size, layer.tp_k_head_num, layer.head_dim - ) - value_cache = value_cache.view( - -1, self.page_size, layer.tp_v_head_num, layer.head_dim - ) page_table = metadata.page_table - o = flash_attn_with_kvcache( - q=q_reshaped, - k_cache=key_cache, - v_cache=value_cache, - page_table=page_table, - cache_seqlens=metadata.cache_seqlens_int32, - cu_seqlens_q=metadata.cu_seqlens_q, - cu_seqlens_k_new=metadata.cu_seqlens_k, - max_seqlen_q=1, - softmax_scale=layer.scaling, - causal=True, - window_size=window_size, - softcap=layer.logit_cap, - k_descale=layer.k_scale, - v_descale=layer.v_scale, - ) - return o.view(-1, layer.tp_q_head_num * layer.head_dim) + if not self.use_mla: + # Do multi-head attention + + # Get KV cache + kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + key_cache, value_cache = kv_cache[0], kv_cache[1] + key_cache = key_cache.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim + ) + value_cache = value_cache.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim + ) + + # Pre-reshape query tensor + q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + + # Run attention with precomputed values + o = flash_attn_with_kvcache( + q=q_reshaped, + k_cache=key_cache, + v_cache=value_cache, + page_table=page_table, + cache_seqlens=metadata.cache_seqlens_int32, + cu_seqlens_q=metadata.cu_seqlens_q, + cu_seqlens_k_new=metadata.cu_seqlens_k, + max_seqlen_q=1, + softmax_scale=layer.scaling, + causal=True, + window_size=window_size, + softcap=layer.logit_cap, + k_descale=layer.k_scale, + v_descale=layer.v_scale, + ) + else: + # Do absorbed multi-latent attention + kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + k_rope = kv_cache[:, :, layer.v_head_dim :] + c_kv = kv_cache[:, :, : layer.v_head_dim] + k_rope_cache = k_rope.view( + -1, + self.page_size, + layer.tp_k_head_num, + layer.head_dim - layer.v_head_dim, + ) + c_kv_cache = c_kv.view( + -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim + ) + + q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + q_nope = q_all[:, :, : layer.v_head_dim] + q_rope = q_all[:, :, layer.v_head_dim :] + + o = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope_cache, + v_cache=c_kv_cache, + qv=q_nope, + page_table=page_table, + cache_seqlens=metadata.cache_seqlens_int32, + cu_seqlens_q=metadata.cu_seqlens_q, + cu_seqlens_k_new=metadata.cu_seqlens_k, + max_seqlen_q=1, + softmax_scale=layer.scaling, + causal=True, + softcap=layer.logit_cap, + k_descale=layer.k_scale, + v_descale=layer.v_scale, + ) + + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) def init_cuda_graph_state(self, max_bs: int): """Initialize CUDA graph state for the attention backend. 
@@ -286,7 +385,6 @@ class FlashAttentionBackend(AttentionBackend): metadata.page_table = self.decode_cuda_graph_metadata["page_table"][ req_pool_indices, : ] - if forward_mode == ForwardMode.DECODE: # Precompute cumulative sequence lengths metadata.cu_seqlens_q = torch.arange( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8e0217277..68e4a9a33 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -230,6 +230,10 @@ class ModelRunner: elif server_args.enable_flashmla: logger.info("MLA optimization is turned on. Use flashmla decode.") server_args.attention_backend = "flashmla" + elif server_args.attention_backend == "fa3": + logger.info( + f"MLA optimization is turned on. Use flash attention 3 backend." + ) else: logger.info("MLA optimization is turned on. Use triton backend.") server_args.attention_backend = "triton" @@ -879,7 +883,7 @@ class ModelRunner: "Please use `--attention-backend flashinfer`." ) logger.warning( - "FlashAttention v3 Backend is in Beta. Multimodal, Page > 1, FP8, MLA and Speculative Decoding are not supported." + "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported." ) from sglang.srt.layers.attention.flashattention_backend import ( FlashAttentionBackend, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 2133f5320..d6a968f2c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -655,6 +655,7 @@ class DeepseekV2AttentionMLA(nn.Module): self.flashinfer_mla_disable_ragged = global_server_args_dict[ "flashinfer_mla_disable_ragged" ] + self.attention_backend = global_server_args_dict["attention_backend"] self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1" def no_absorb(self, forward_batch: ForwardBatch) -> bool: @@ -667,6 +668,9 @@ class DeepseekV2AttentionMLA(nn.Module): and not forward_batch.forward_mode.is_draft_extend() and sum(forward_batch.extend_prefix_lens_cpu) == 0 ) + elif self.attention_backend == "fa3": + # Flash Attention: Keep absorbing for all extend/decode + return False else: # Triton: Use normal computation for prefill and use weight absorption for extend/decode return (
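
Note on the absorbed-MLA call above: FlashAttention-3's flash_attn_with_kvcache is invoked with q=q_rope, k_cache=the k_rope pages, v_cache=the c_kv pages, and qv=q_nope, so the attention score combines the rotary part (q_rope against k_rope) with the latent part (q_nope against c_kv), and the values are read from the same latent c_kv; that is why the output is reshaped to tp_q_head_num * layer.v_head_dim instead of layer.head_dim. The sketch below only illustrates the tensor bookkeeping of this split; the dimensions follow the DeepSeek-V2/V3 defaults (kv_lora_rank = 512, qk_rope_head_dim = 64), and all variable names are illustrative, not taken from the patch.

    import torch

    # Head dims after weight absorption (DeepSeek-V2/V3 defaults, assumed here):
    kv_lora_rank = 512                        # latent ("nope") width, also the value head dim
    qk_rope_head_dim = 64                     # rotary part of query/key
    head_dim = kv_lora_rank + qk_rope_head_dim  # 576, what layer.head_dim holds in the MLA branch
    v_head_dim = kv_lora_rank                   # 512, what layer.v_head_dim holds

    num_cache_tokens = 8                      # illustrative KV-pool size
    tp_q_head_num = 16                        # illustrative number of query heads per TP rank
    num_tokens = 4                            # illustrative number of query tokens

    # One combined buffer per cached token: [c_kv | k_rope], a single KV head after absorption.
    kv_cache = torch.randn(num_cache_tokens, 1, head_dim)
    c_kv = kv_cache[:, :, :v_head_dim]        # latent KV, acts as both the "nope" key and the value
    k_rope = kv_cache[:, :, v_head_dim:]      # rotary key part
    # (The patch additionally views these into (num_pages, page_size, heads, dim) for the paged call.)

    # Query after absorption: per head, [q_nope (512) | q_rope (64)].
    q_all = torch.randn(num_tokens, tp_q_head_num, head_dim)
    q_nope = q_all[:, :, :v_head_dim]
    q_rope = q_all[:, :, v_head_dim:]

    print(q_rope.shape, q_nope.shape, k_rope.shape, c_kv.shape)
    # torch.Size([4, 16, 64]) torch.Size([4, 16, 512]) torch.Size([8, 1, 64]) torch.Size([8, 1, 512])

With this patch, passing --attention-backend fa3 is kept for MLA models instead of falling back to the triton backend, and DeepseekV2AttentionMLA.no_absorb returns False for fa3, so both forward_extend and forward_decode always take the weight-absorbed branch shown in the diff.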