From 5d7edc8e55d7b6e28b1b765bbdd639d801463815 Mon Sep 17 00:00:00 2001
From: Stefan He
Date: Sun, 23 Mar 2025 23:28:11 -0700
Subject: [PATCH] Support FA3 as Attention backend by using `--attention-backend fa3` (#4680)

Co-authored-by: qsong
Co-authored-by: qingquansong
---
 python/sglang/bench_serving.py                 |   2 +
 .../attention/flashattention_backend.py        | 295 +++++++++++++++++
 .../sglang/srt/model_executor/model_runner.py  |  13 +
 python/sglang/srt/server_args.py               |   2 +-
 .../test/attention/test_flashattn_backend.py   | 311 ++++++++++++++++++
 5 files changed, 622 insertions(+), 1 deletion(-)
 create mode 100644 python/sglang/srt/layers/attention/flashattention_backend.py
 create mode 100644 python/sglang/test/attention/test_flashattn_backend.py

diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 9887c7161..e9ab9830e 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -501,6 +501,7 @@ def get_dataset(args, tokenizer):
             question_len=args.gsp_question_len,
             output_len=args.gsp_output_len,
             tokenizer=tokenizer,
+            args=args,
         )
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")
@@ -788,6 +789,7 @@ def sample_generated_shared_prefix_requests(
     question_len: int,
     output_len: int,
     tokenizer: PreTrainedTokenizerBase,
+    args: argparse.Namespace,
 ) -> List[Tuple[str, int, int]]:
     """Generate benchmark requests with shared system prompts using random tokens and caching."""
     cache_path = get_gen_prefix_cache_path(args, tokenizer)
+""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Union + +import torch + +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + +from flash_attn_interface import flash_attn_with_kvcache + + +@dataclass +class FlashAttentionMetadata: + """Metadata for decode operations to avoid redundant computations.""" + + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + max_seq_len_k: int = 0 + window_size: tuple = (-1, -1) + page_table: torch.Tensor = None + cache_seqlens_int32: torch.Tensor = None + max_seq_len_q: int = 0 + + +class FlashAttentionBackend(AttentionBackend): + """FlashAttention backend implementation.""" + + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + ): + super().__init__() + + assert not ( + model_runner.sliding_window_size is not None + and model_runner.model_config.is_encoder_decoder + ), "Sliding window and cross attention are not supported together" + + # Initialize metadata + self.forward_metadata: FlashAttentionMetadata = None + self.max_context_len = model_runner.model_config.context_len + self.device = model_runner.device + self.decode_cuda_graph_metadata = {} + self.req_to_token = model_runner.req_to_token_pool.req_to_token + + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Initialize forward metadata to cache repetitive calculations.""" + # Create metadata based on forward mode + metadata = FlashAttentionMetadata() + + extend_seq_lens = forward_batch.extend_seq_lens + # Get sequence information + seqlens_in_batch = forward_batch.seq_lens + # Precompute int32 version of sequence lengths + metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) + batch_size = len(seqlens_in_batch) + device = seqlens_in_batch.device + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) + ) + # Precompute maximum sequence length + metadata.max_seq_len_k = seqlens_in_batch.max().item() + # Precompute page table + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ] + if forward_batch.forward_mode == ForwardMode.DECODE: + # Precompute cumulative sequence lengths + metadata.cu_seqlens_q = torch.arange( + 0, batch_size + 1, dtype=torch.int32, device=device + ) + else: + extend_no_prefix = not any(forward_batch.extend_prefix_lens) + # Precompute cumulative sequence lengths + if not extend_no_prefix: + metadata.cu_seqlens_q = torch.nn.functional.pad( + torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0) + ) + else: + metadata.cu_seqlens_q = metadata.cu_seqlens_k + metadata.max_seq_len_q = seqlens_in_batch.max().item() + self.forward_metadata = metadata + + def forward_extend( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) + + if k is not None: + assert v is not None + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) + + # Use precomputed metadata + metadata = 
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
+
+        # Use precomputed metadata
+        metadata = self.forward_metadata
+
+        # Use Flash Attention for prefill.
+        # Calculate window size (this could move into metadata if the layer
+        # properties never change). We do not subtract 1 from
+        # layer.sliding_window_size here because
+        # model.get_attention_sliding_window_size() already does;
+        # the window is inclusive on both sides.
+        window_size = (
+            (layer.sliding_window_size, 0)
+            if layer.sliding_window_size is not None
+            else (-1, -1)
+        )
+        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        key_cache, value_cache = kv_cache[0], kv_cache[1]
+        o = flash_attn_with_kvcache(
+            q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+            k_cache=key_cache.unsqueeze(1),
+            v_cache=value_cache.unsqueeze(1),
+            page_table=metadata.page_table,
+            cache_seqlens=metadata.cache_seqlens_int32,
+            cu_seqlens_q=metadata.cu_seqlens_q,
+            cu_seqlens_k_new=metadata.cu_seqlens_k,
+            max_seqlen_q=metadata.max_seq_len_q,
+            softmax_scale=layer.scaling,
+            causal=True,
+            window_size=window_size,
+            softcap=layer.logit_cap,
+            k_descale=layer.k_scale,
+            v_descale=layer.v_scale,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention using precomputed metadata."""
+        # Save KV cache if needed
+        if k is not None and v is not None and save_kv_cache:
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+            )
+
+        # Get KV cache
+        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        key_cache, value_cache = kv_cache[0], kv_cache[1]
+
+        # Use precomputed metadata
+        metadata = self.forward_metadata
+
+        # Pre-reshape query tensor
+        q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        # Calculate window size; same convention as in forward_extend:
+        # model.get_attention_sliding_window_size() already subtracts 1,
+        # and the window is inclusive on both sides.
+        window_size = (
+            (layer.sliding_window_size, 0)
+            if layer.sliding_window_size is not None
+            else (-1, -1)
+        )
+        # Run attention with precomputed values
+        o = flash_attn_with_kvcache(
+            q=q_reshaped,
+            k_cache=key_cache.unsqueeze(1),
+            v_cache=value_cache.unsqueeze(1),
+            page_table=metadata.page_table,
+            cache_seqlens=metadata.cache_seqlens_int32,
+            cu_seqlens_q=metadata.cu_seqlens_q,
+            cu_seqlens_k_new=metadata.cu_seqlens_k,
+            max_seqlen_q=1,
+            softmax_scale=layer.scaling,
+            causal=True,
+            window_size=window_size,
+            softcap=layer.logit_cap,
+            k_descale=layer.k_scale,
+            v_descale=layer.v_scale,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
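[Editor's note] A word on the unsqueeze(1) calls above: with the page size of 1 assumed by this backend, expanding the token-to-KV pool to (num_slots, 1, num_kv_heads, head_dim) presents each KV slot as a page of size 1, so page-table entries are plain token indices taken from req_to_token and no block alignment is needed. A minimal shape sketch (illustrative only, made-up sizes):

    import torch

    num_slots, num_kv_heads, head_dim = 16, 2, 8
    key_cache = torch.randn(num_slots, num_kv_heads, head_dim)

    # (num_pages = num_slots, page_size = 1, num_kv_heads, head_dim)
    paged_k = key_cache.unsqueeze(1)
    assert paged_k.shape == (num_slots, 1, num_kv_heads, head_dim)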
+
+    def init_cuda_graph_state(self, max_bs: int):
+        """Initialize CUDA graph state for the attention backend.
+
+        Args:
+            max_bs (int): Maximum batch size to support in CUDA graphs
+
+        This creates fixed-size tensors that will be reused during CUDA graph replay
+        to avoid memory allocations.
+        """
+        # Initialize fixed-size tensors for decode operations
+        self.decode_cuda_graph_metadata = {
+            # Page table for token mapping (batch_size, max_context_len)
+            "page_table": torch.zeros(
+                max_bs, self.max_context_len, dtype=torch.int32, device=self.device
+            ),
+        }
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        """Initialize forward metadata for capturing CUDA graph."""
+        metadata = FlashAttentionMetadata()
+        # Get sequence information
+        metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+        batch_size = len(seq_lens)
+        device = seq_lens.device
+        metadata.cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+        )
+        # Precompute maximum sequence length
+        metadata.max_seq_len_k = seq_lens.max().item()
+        # Precompute page table
+        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+            req_pool_indices, :
+        ]
+        if forward_mode == ForwardMode.DECODE:
+            # Precompute cumulative sequence lengths
+            metadata.cu_seqlens_q = torch.arange(
+                0, batch_size + 1, dtype=torch.int32, device=device
+            )
+        else:
+            raise ValueError("CUDA graph capture only supports decode mode, not prefill")
+        self.decode_cuda_graph_metadata[bs] = metadata
+        self.forward_metadata = metadata
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        """Initialize forward metadata for replaying CUDA graph."""
+        seqlens_in_batch = seq_lens[:bs]
+        metadata = self.decode_cuda_graph_metadata[bs]
+        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+        metadata.cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+        # Precompute maximum sequence length
+        metadata.max_seq_len_k = seqlens_in_batch.max().item()
+        # Only zero out the part beyond max_seq_len_k
+        metadata.page_table[:, metadata.max_seq_len_k :].fill_(0)
+        # Then copy in the live prefix
+        metadata.page_table[:, : metadata.max_seq_len_k].copy_(
+            self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k]
+        )
+        self.forward_metadata = metadata
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        """Get the fill value for sequence length in CUDA graph."""
+        return 0
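[Editor's note] The replay path above works by mutating the buffers captured in the graph in place rather than reallocating them, since CUDA graphs replay fixed memory addresses. A small standalone sketch of the page-table update pattern (made-up sizes, not part of the patch):

    import torch

    max_bs, max_context_len = 4, 16
    page_table = torch.zeros(max_bs, max_context_len, dtype=torch.int32)

    # On replay: clear the stale tail, then copy only the live prefix.
    max_seq_len_k = 5
    new_entries = torch.arange(max_bs * max_seq_len_k, dtype=torch.int32).reshape(
        max_bs, max_seq_len_k
    )
    page_table[:, max_seq_len_k:].fill_(0)
    page_table[:, :max_seq_len_k].copy_(new_entries)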
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 6ae2af0df..7bee2cb8a 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -868,6 +868,19 @@ class ModelRunner:
             from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

             self.attn_backend = FlashMLABackend(self)
+        elif self.server_args.attention_backend == "fa3":
+            assert torch.cuda.get_device_capability()[0] >= 9, (
+                "FlashAttention v3 backend requires SM>=90. "
+                "Please use `--attention-backend flashinfer`."
+            )
+            logger.warning(
+                "FlashAttention v3 backend is in beta. Multimodal models, "
+                "page size > 1, FP8, MLA, and speculative decoding are not supported."
+            )
+            from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionBackend,
+            )
+
+            self.attn_backend = FlashAttentionBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 8ad0e4b1a..241682d4e 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -770,7 +770,7 @@ class ServerArgs:
         parser.add_argument(
             "--attention-backend",
             type=str,
-            choices=["flashinfer", "triton", "torch_native"],
+            choices=["flashinfer", "triton", "torch_native", "fa3"],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
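[Editor's note] With the server_args change above, the new backend is selected at launch time via the flag named in the commit title. For example (the model path is a placeholder):

    python -m sglang.launch_server --model-path <your-model> --attention-backend fa3

On devices below SM90, the assertion added in model_runner.py fires and suggests falling back to `--attention-backend flashinfer`.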
diff --git a/python/sglang/test/attention/test_flashattn_backend.py b/python/sglang/test/attention/test_flashattn_backend.py
new file mode 100644
index 000000000..4c37a8758
--- /dev/null
+++ b/python/sglang/test/attention/test_flashattn_backend.py
@@ -0,0 +1,311 @@
+import unittest
+
+import torch
+
+from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+
+
+class MockModelRunner:
+    model_config = type(
+        "ModelConfig", (), {"context_len": 2048, "is_multimodal": False}
+    )
+    sliding_window_size = None
+
+    def __init__(self, device="cuda"):
+        self.device = device
+        # Create a minimal req_to_token_pool exposing the req_to_token attribute
+        self.req_to_token_pool = type(
+            "TokenPool",
+            (),
+            {
+                # Pool rows (max concurrent requests); a typical size for
+                # CUDA graph decode
+                "size": 160,
+                # Mapping from (request, position) to KV cache slot
+                "req_to_token": torch.zeros(
+                    160, 2048, dtype=torch.int32, device=device
+                ),
+            },
+        )
+
+
+class MockReqToTokenPool:
+    def __init__(self, batch_size, seq_len, device):
+        self.req_to_token = (
+            torch.arange(batch_size * seq_len, device=device)
+            .reshape(batch_size, seq_len)
+            .to(torch.int32)
+        )
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
+class TestFlashAttentionBackend(unittest.TestCase):
+    def setUp(self):
+        """Set up test fixtures before each test method."""
+        self.model_runner = MockModelRunner()
+        self.backend = FlashAttentionBackend(self.model_runner)
+
+        # Common test parameters
+        self.batch_size = 2
+        self.seq_len = 4
+        self.num_heads = 2
+        self.head_dim = 8
+        self.device = "cuda"
+        self.dtype = torch.float16
+
+    def _create_attention_layer(self):
+        """Helper method to create an attention layer."""
+        return RadixAttention(
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+            scaling=1.0,
+            num_kv_heads=self.num_heads,
+            layer_id=0,
+        )
+
+    def _create_kv_pool(self, size):
+        """Helper method to create a KV pool."""
+        return MHATokenToKVPool(
+            size=size,
+            page_size=1,  # only consider page_size=1 for unit tests
+            dtype=self.dtype,
+            head_num=self.num_heads,
+            head_dim=self.head_dim,
+            layer_num=1,  # only consider one layer for unit tests
+            device=self.device,
+            enable_memory_saver=False,
+        )
+
+    def _create_qkv_tensors(self, tokens_len):
+        """Helper method to create q, k, v tensors."""
+        shape = (tokens_len, self.num_heads, self.head_dim)
+        return (
+            torch.randn(*shape, dtype=self.dtype, device=self.device),
+            torch.randn(*shape, dtype=self.dtype, device=self.device),
+            torch.randn(*shape, dtype=self.dtype, device=self.device),
+        )
+
+    def _verify_output(self, output, expected_shape):
+        """Helper method to verify output."""
+        self.assertEqual(
+            output.shape,
+            expected_shape,
+            f"Expected shape {expected_shape}, got {output.shape}",
+        )
+        self.assertEqual(output.dtype, self.dtype)
+        self.assertEqual(output.device.type, "cuda")
+        self.assertEqual(
+            torch.isnan(output).sum().item(), 0, "Output contains NaN values"
+        )
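[Editor's note] The MockReqToTokenPool above lays sequences out contiguously, so request i, position j maps to KV slot i * seq_len + j, which is what the page table hands to the kernel. A quick standalone check of that mapping:

    import torch

    batch_size, seq_len = 2, 4
    req_to_token = (
        torch.arange(batch_size * seq_len).reshape(batch_size, seq_len).to(torch.int32)
    )
    assert req_to_token[1, 2].item() == 1 * seq_len + 2  # slot 6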
+
+    def test_forward_extend(self):
+        """Test the standard extend operation."""
+        # Create test inputs
+        q, k, v = self._create_qkv_tensors(self.batch_size * self.seq_len)
+
+        # Create attention layer
+        layer = self._create_attention_layer()
+
+        # Create forward batch
+        forward_batch = ForwardBatch(
+            batch_size=self.batch_size,
+            input_ids=torch.randint(
+                0, 100, (self.batch_size, self.seq_len), device=self.device
+            ),
+            out_cache_loc=torch.arange(
+                self.batch_size * self.seq_len, device=self.device
+            ),
+            seq_lens_sum=self.batch_size * self.seq_len,
+            forward_mode=ForwardMode.EXTEND,
+            req_pool_indices=torch.arange(self.batch_size, device=self.device),
+            seq_lens=torch.tensor([self.seq_len] * self.batch_size, device=self.device),
+            # 0 cached prefix tokens, self.seq_len extend tokens per request
+            extend_prefix_lens=torch.tensor([0] * self.batch_size, device=self.device),
+            extend_seq_lens=torch.tensor(
+                [self.seq_len] * self.batch_size, device=self.device
+            ),
+            attn_backend=self.backend,
+        )
+
+        # Add token pool and KV cache
+        forward_batch.req_to_token_pool = MockReqToTokenPool(
+            self.batch_size, self.seq_len, self.device
+        )
+        forward_batch.token_to_kv_pool = self._create_kv_pool(
+            self.batch_size * self.seq_len
+        )
+
+        # Initialize forward metadata before running the attention
+        self.backend.init_forward_metadata(forward_batch)
+
+        # Run forward_extend
+        output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+
+        # Verify output
+        expected_shape = (
+            self.batch_size * self.seq_len,
+            self.num_heads * self.head_dim,
+        )
+        self._verify_output(output, expected_shape)
+
+    def test_forward_decode(self):
+        """Test the decode operation with cached tokens."""
+        # For decode, we only have one new token per sequence
+        decode_len = 1
+        curr_seq_len = self.seq_len + decode_len
+
+        # Create test inputs
+        q, k, v = self._create_qkv_tensors(self.batch_size * decode_len)
+
+        # Create attention layer
+        layer = self._create_attention_layer()
+
+        # Create forward batch
+        forward_batch = ForwardBatch(
+            batch_size=self.batch_size,
+            input_ids=torch.randint(
+                0, 100, (self.batch_size, decode_len), device=self.device
+            ),
+            out_cache_loc=torch.arange(
+                self.batch_size * self.seq_len,
+                self.batch_size * curr_seq_len,
+                device=self.device,
+            ),
+            seq_lens_sum=self.batch_size * curr_seq_len,
+            forward_mode=ForwardMode.DECODE,
+            req_pool_indices=torch.arange(self.batch_size, device=self.device),
+            seq_lens=torch.tensor(
+                [curr_seq_len] * self.batch_size, device=self.device
+            ),
+            attn_backend=self.backend,
+        )
+
+        # Add token pool and KV cache
+        forward_batch.req_to_token_pool = MockReqToTokenPool(
+            self.batch_size, curr_seq_len, self.device
+        )
+        forward_batch.token_to_kv_pool = self._create_kv_pool(
+            self.batch_size * curr_seq_len
+        )
+
+        # Pre-fill the KV cache with the existing context
+        cache_k, cache_v, _ = self._create_qkv_tensors(self.batch_size * self.seq_len)
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer,
+            torch.arange(self.batch_size * self.seq_len, device=self.device),
+            cache_k,
+            cache_v,
+            layer.k_scale,
+            layer.v_scale,
+        )
+
+        # Initialize forward metadata before running the attention
+        self.backend.init_forward_metadata(forward_batch)
+
+        # Run forward_decode
+        output = self.backend.forward_decode(q, k, v, layer, forward_batch)
+
+        # Verify output
+        expected_shape = (self.batch_size, self.num_heads * self.head_dim)
+        self._verify_output(output, expected_shape)
+
+    def test_forward_extend_with_prefix(self):
+        """Test extending from cached prefix tokens."""
+        # Define prefix and extend lengths
+        prefix_len = 2
+        extend_len = 2
+        total_len = prefix_len + extend_len
+
+        # Create test inputs for the extend portion
+        q, k, v = self._create_qkv_tensors(self.batch_size * extend_len)
+
+        # Create attention layer
+        layer = self._create_attention_layer()
+
+        # Create forward batch
+        forward_batch = ForwardBatch(
+            batch_size=self.batch_size,
+            input_ids=torch.randint(
+                0, 100, (self.batch_size, extend_len), device=self.device
+            ),
+            out_cache_loc=torch.arange(
+                self.batch_size * prefix_len,
+                self.batch_size * total_len,
+                device=self.device,
+            ),
+            seq_lens_sum=self.batch_size * total_len,
+            forward_mode=ForwardMode.EXTEND,
+            req_pool_indices=torch.arange(self.batch_size, device=self.device),
+            seq_lens=torch.tensor([total_len] * self.batch_size, device=self.device),
+            extend_prefix_lens=torch.tensor(
+                [prefix_len] * self.batch_size, device=self.device
+            ),
+            extend_seq_lens=torch.tensor(
+                [extend_len] * self.batch_size, device=self.device
+            ),
+            attn_backend=self.backend,
+        )
+
+        # Add token pool and KV cache
+        forward_batch.req_to_token_pool = MockReqToTokenPool(
+            self.batch_size, total_len, self.device
+        )
+        forward_batch.token_to_kv_pool = self._create_kv_pool(
+            self.batch_size * total_len
+        )
+
+        # Pre-fill the KV cache for the prefix with known values
+        cache_k = torch.ones(
+            self.batch_size * prefix_len,
+            self.num_heads,
+            self.head_dim,
+            dtype=self.dtype,
+            device=self.device,
+        )
+        cache_v = (
+            torch.ones(
+                self.batch_size * prefix_len,
+                self.num_heads,
+                self.head_dim,
+                dtype=self.dtype,
+                device=self.device,
+            )
+            * 2
+        )
+
+        # Set the prefix KV cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer,
+            torch.arange(self.batch_size * prefix_len, device=self.device),
+            cache_k,
+            cache_v,
+            layer.k_scale,
+            layer.v_scale,
+        )
+
+        # Initialize forward metadata before running the attention
+        self.backend.init_forward_metadata(forward_batch)
+
+        # Run forward_extend
+        output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+
+        # Verify output
+        expected_shape = (self.batch_size * extend_len, self.num_heads * self.head_dim)
+        self._verify_output(output, expected_shape)
+
+
+if __name__ == "__main__":
+    unittest.main()
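[Editor's note] The new tests can be run directly with the standard unittest runner (they skip automatically when CUDA is unavailable):

    python -m unittest sglang.test.attention.test_flashattn_backend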