From 93470a14116a60fe5dd43f0599206e8ccabdc211 Mon Sep 17 00:00:00 2001 From: Stefan He Date: Mon, 7 Apr 2025 11:52:42 -0700 Subject: [PATCH] Refactor and Optimize FA3 Code (#5090) Co-authored-by: Qingquan Song --- .../attention/flashattention_backend.py | 242 +++++++----------- 1 file changed, 97 insertions(+), 145 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 62604fe56..45e64c45e 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -1,24 +1,16 @@ from __future__ import annotations -import numpy as np - -from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput - -""" -Support different attention backends. -Now there are three backends: FlashInfer, Triton and FlashAttention. -Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode. -""" - from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, Union +import numpy as np import torch from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention @@ -30,22 +22,25 @@ from sgl_kernel.flash_attn import flash_attn_with_kvcache @dataclass class FlashAttentionMetadata: """Metadata to be init once in the model forward pass, - each layer's forward pass can reuse the metadata.""" + each layer's forward pass can reuse the metadata. - # Cumulative sequence lengths for query - cu_seqlens_q: torch.Tensor = None - # Cumulative sequence lengths for key - cu_seqlens_k: torch.Tensor = None + For each init metadata function, we will try set up them in below order + """ + + # Sequence lengths for the forward batch + cache_seqlens_int32: torch.Tensor = None # Maximum sequence length for query max_seq_len_q: int = 0 # Maximum sequence length for key max_seq_len_k: int = 0 + # Cumulative sequence lengths for query + cu_seqlens_q: torch.Tensor = None + # Cumulative sequence lengths for key + cu_seqlens_k: torch.Tensor = None # Window size (typically used by Gemma) window_size: tuple = (-1, -1) # Page table, the index of KV Cache Tables/Blocks page_table: torch.Tensor = None - # Sequence lengths for the forward batch - cache_seqlens_int32: torch.Tensor = None @dataclass class LocalAttentionMetadata: @@ -270,9 +265,9 @@ class FlashAttentionBackend(AttentionBackend): self, model_runner: ModelRunner, skip_prefill: bool = False, + speculative_step_id=0, topk=0, speculative_num_steps=0, - step_id=0, ): super().__init__() @@ -293,14 +288,12 @@ class FlashAttentionBackend(AttentionBackend): ) and (not global_server_args_dict["disable_mla"]) self.skip_prefill = skip_prefill - # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding - assert ( - topk <= 1 - ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend" - - self.topk = 1 - self.step_id = step_id + self.topk = topk self.speculative_num_steps = speculative_num_steps + self.speculative_num_draft_tokens = ( + model_runner.server_args.speculative_num_draft_tokens + ) + self.speculative_step_id = speculative_step_id # Local attention settings self.attention_chunk_size = ( @@ -310,71 +303,59 @@ class FlashAttentionBackend(AttentionBackend): ) def init_forward_metadata(self, forward_batch: ForwardBatch): - """Initialize forward metadata to cache repetitive calculations.""" + """Initialize forward metadata hence all layers in the forward pass can reuse it.""" metadata = FlashAttentionMetadata() seqlens_in_batch = forward_batch.seq_lens batch_size = len(seqlens_in_batch) device = seqlens_in_batch.device + if forward_batch.forward_mode.is_decode(): - # Skip Prefill or Draft Decode - # Note: Draft Decode will be ran on the Draft Worker + # Draft Decode if forward_batch.spec_info is not None: + metadata.cache_seqlens_int32 = ( + seqlens_in_batch + (self.speculative_step_id + 1) + ).to(torch.int32) + metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + ( + self.speculative_step_id + 1 + ) metadata.cu_seqlens_q = torch.arange( 0, batch_size + 1, dtype=torch.int32, device=device ) - seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1) - metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32) metadata.cu_seqlens_k = torch.nn.functional.pad( torch.cumsum( metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 ), (1, 0), ) - metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + ( - self.step_id + 1 - ) metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : metadata.max_seq_len_k ] - cache_loc = forward_batch.out_cache_loc.view( - self.speculative_num_steps, -1 - ).T - - for idx, single_seq_len in enumerate(seq_lens_with_decode): - real_bsz_start_idx = idx - real_bsz_end_idx = idx + 1 - metadata.page_table[ - real_bsz_start_idx:real_bsz_end_idx, - (single_seq_len - (self.step_id + 1)) : single_seq_len, - ] = cache_loc[ - real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1) - ] - else: # Normal Decode without Spec Decoding + else: + # Normal Decode metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) - metadata.cu_seqlens_k = torch.nn.functional.pad( - torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) - ) metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() - metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ - forward_batch.req_pool_indices, : metadata.max_seq_len_k - ] metadata.cu_seqlens_q = torch.arange( 0, batch_size + 1, dtype=torch.int32, device=device ) + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) + ) + metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len_k + ] elif forward_batch.forward_mode.is_target_verify(): - # Note: Target Verify will be ran on the Target Worker - draft_token_num = forward_batch.spec_info.draft_token_num metadata.cache_seqlens_int32 = ( - forward_batch.seq_lens + draft_token_num + forward_batch.seq_lens + self.speculative_num_draft_tokens ).to(torch.int32) - metadata.max_seq_len_q = draft_token_num + metadata.max_seq_len_q = self.speculative_num_draft_tokens metadata.max_seq_len_k = ( - forward_batch.seq_lens_cpu.max().item() + draft_token_num + forward_batch.seq_lens_cpu.max().item() + + self.speculative_num_draft_tokens ) metadata.cu_seqlens_q = torch.arange( 0, - batch_size * draft_token_num + 1, - draft_token_num, + batch_size * self.speculative_num_draft_tokens + 1, + self.speculative_num_draft_tokens, dtype=torch.int32, device=device, ) @@ -387,31 +368,27 @@ class FlashAttentionBackend(AttentionBackend): ] elif forward_batch.forward_mode.is_extend_or_draft_extend(): - # Normal or Draft Extend (Both of them will be ran on the Target Worker) metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) + metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() metadata.cu_seqlens_k = torch.nn.functional.pad( torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) ) - # Precompute maximum sequence length - metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() - # Precompute page table metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : metadata.max_seq_len_k ] - # Precompute cumulative sequence lengths if ( any(forward_batch.extend_prefix_lens_cpu) or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND ): extend_seq_lens = forward_batch.extend_seq_lens + metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu) metadata.cu_seqlens_q = torch.nn.functional.pad( torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0) ) - metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu) else: - metadata.cu_seqlens_q = metadata.cu_seqlens_k metadata.max_seq_len_q = metadata.max_seq_len_k + metadata.cu_seqlens_q = metadata.cu_seqlens_k # Setup local attention if enabled if ( @@ -458,7 +435,7 @@ class FlashAttentionBackend(AttentionBackend): ) metadata.local_attn_metadata = local_metadata - # Precompute strided indices + # Convert the page table to a strided format which is needed by FA3 API if self.page_size > 1: self.strided_indices = torch.arange( 0, metadata.page_table.shape[1], self.page_size, device=self.device @@ -498,7 +475,7 @@ class FlashAttentionBackend(AttentionBackend): v, ) - # Use precomputed metadata + # Use precomputed metadata across all layers metadata = self.forward_metadata # Calculate window size (can be moved to metadata if layer properties don't change) @@ -606,8 +583,6 @@ class FlashAttentionBackend(AttentionBackend): forward_batch: ForwardBatch, save_kv_cache=True, ) -> torch.Tensor: - """Forward pass with FlashAttention using precomputed metadata.""" - # Save KV cache if needed if k is not None: assert v is not None if save_kv_cache: @@ -628,7 +603,7 @@ class FlashAttentionBackend(AttentionBackend): v, ) - # Use precomputed metadata + # Use precomputed metadata across all layers metadata = self.forward_metadata # Calculate window size (can be moved to metadata if layer properties don't change) @@ -639,12 +614,9 @@ class FlashAttentionBackend(AttentionBackend): if layer.sliding_window_size is not None else (-1, -1) ) - page_table = metadata.page_table if not self.use_mla: # Do multi-head attention - - # Get KV cache kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) key_cache, value_cache = kv_cache[0], kv_cache[1] key_cache = key_cache.view( @@ -654,13 +626,12 @@ class FlashAttentionBackend(AttentionBackend): -1, self.page_size, layer.tp_v_head_num, layer.head_dim ) - # Pre-reshape query tensor q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) o = flash_attn_with_kvcache( q=q_reshaped, k_cache=key_cache, v_cache=value_cache, - page_table=page_table, + page_table=metadata.page_table, cache_seqlens=metadata.cache_seqlens_int32, cu_seqlens_q=metadata.cu_seqlens_q, cu_seqlens_k_new=metadata.cu_seqlens_k, @@ -696,7 +667,7 @@ class FlashAttentionBackend(AttentionBackend): k_cache=k_rope_cache, v_cache=c_kv_cache, qv=q_nope, - page_table=page_table, + page_table=metadata.page_table, cache_seqlens=metadata.cache_seqlens_int32, cu_seqlens_q=metadata.cu_seqlens_q, cu_seqlens_k_new=metadata.cu_seqlens_k, @@ -719,7 +690,13 @@ class FlashAttentionBackend(AttentionBackend): to avoid memory allocations. """ self.decode_cuda_graph_metadata = { - # Page table for token mapping (batch_size, max_context_len) + "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device), + "cu_seqlens_q": torch.arange( + 0, max_bs + 1, dtype=torch.int32, device=self.device + ), + "cu_seqlens_k": torch.zeros( + max_bs + 1, dtype=torch.int32, device=self.device + ), "page_table": torch.zeros( max_bs, (self.max_context_len + self.page_size - 1) // self.page_size, @@ -735,30 +712,22 @@ class FlashAttentionBackend(AttentionBackend): "strided_indices": torch.arange( 0, self.max_context_len, self.page_size, device=self.device ), - "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device), - "cu_seqlens_q": torch.arange( - 0, max_bs + 128, dtype=torch.int32, device=self.device - ), - "cu_seqlens_k": torch.zeros( - max_bs + 128, dtype=torch.int32, device=self.device - ), } self.target_verify_metadata = { + "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device), + "cu_seqlens_q": torch.zeros( + max_bs + 1, dtype=torch.int32, device=self.device + ), + "cu_seqlens_k": torch.zeros( + max_bs + 1, dtype=torch.int32, device=self.device + ), "page_table": torch.zeros( max_bs, (self.max_context_len + self.page_size - 1) // self.page_size, dtype=torch.int32, device=self.device, ), - "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device), - "cu_seqlens_q": torch.zeros( - max_bs + 128, dtype=torch.int32, device=self.device - ), - "cu_seqlens_k": torch.zeros( - max_bs + 128, dtype=torch.int32, device=self.device - ), - "max_seqlen_q": 0, "strided_indices": torch.arange( 0, self.max_context_len, self.page_size, device=self.device ), @@ -780,24 +749,21 @@ class FlashAttentionBackend(AttentionBackend): if forward_mode.is_decode(): if spec_info is not None: # Draft Decode - metadata.cu_seqlens_q = torch.arange( - 0, bs + 1, dtype=torch.int32, device=device - ) metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[ "cache_seqlens" ][:bs] - + metadata.max_seq_len_k = seq_lens.max().item() + ( + self.speculative_step_id + 1 + ) metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][ : bs + 1 ] - metadata.cu_seqlens_k = torch.nn.functional.pad( torch.cumsum( metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 ), (1, 0), ) - metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1) metadata.page_table = self.decode_cuda_graph_metadata[ "page_table_draft_decode" ][req_pool_indices, :] @@ -822,37 +788,30 @@ class FlashAttentionBackend(AttentionBackend): ) self.decode_cuda_graph_metadata[bs] = metadata elif forward_mode.is_target_verify(): - draft_token_num = spec_info.draft_token_num - metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][ :bs ] metadata.cache_seqlens_int32.copy_( - (seq_lens + draft_token_num).to(torch.int32) + (seq_lens + self.speculative_num_draft_tokens).to(torch.int32) ) - metadata.max_seq_len_q = draft_token_num - metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num + metadata.max_seq_len_q = self.speculative_num_draft_tokens + metadata.max_seq_len_k = ( + seq_lens.max().item() + self.speculative_num_draft_tokens + ) - metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][ - torch.arange( - 0, - bs * draft_token_num + 1, - draft_token_num, - dtype=torch.int32, - device=device, - ) + metadata.cu_seqlens_q = torch.arange( + 0, + bs * self.speculative_num_draft_tokens + 1, + self.speculative_num_draft_tokens, + dtype=torch.int32, + device=device, + ) + + metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][ + : (bs + 1) ] - cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)] - cu_k.copy_( - torch.nn.functional.pad( - torch.cumsum( - metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 - ), - (1, 0), - ) - ) - metadata.cu_seqlens_k = cu_k + metadata.page_table = self.target_verify_metadata["page_table"][ req_pool_indices, : ] @@ -874,24 +833,21 @@ class FlashAttentionBackend(AttentionBackend): out_cache_loc: torch.Tensor = None, ): # """Initialize forward metadata for replaying CUDA graph.""" - device = seq_lens.device seq_lens = seq_lens[:bs] - req_pool_indices = req_pool_indices[:bs] seq_lens_cpu = seq_lens_cpu[:bs] + req_pool_indices = req_pool_indices[:bs] if forward_mode.is_decode(): metadata = self.decode_cuda_graph_metadata[bs] if spec_info is not None: # Draft Decode - max_len = seq_lens_cpu.max().item() - metadata.max_seq_len_k = max_len + (self.step_id + 1) - metadata.cache_seqlens_int32.copy_( - (seq_lens + (self.step_id + 1)).to(torch.int32) + (seq_lens + (self.speculative_step_id + 1)).to(torch.int32) ) - metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1) - + metadata.max_seq_len_k = seq_lens_cpu.max().item() + ( + self.speculative_step_id + 1 + ) metadata.cu_seqlens_k.copy_( torch.nn.functional.pad( torch.cumsum( @@ -929,22 +885,13 @@ class FlashAttentionBackend(AttentionBackend): elif forward_mode.is_target_verify(): metadata = self.target_verify_metadata[bs] - draft_token_num = spec_info.draft_token_num - - metadata.cu_seqlens_q.copy_( - torch.arange( - 0, - bs * draft_token_num + 1, - draft_token_num, - dtype=torch.int32, - device=device, - ) - ) metadata.cache_seqlens_int32.copy_( - (seq_lens + draft_token_num).to(torch.int32) + (seq_lens + self.speculative_num_draft_tokens).to(torch.int32) ) - metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num + metadata.max_seq_len_k = ( + seq_lens_cpu.max().item() + self.speculative_num_draft_tokens + ) metadata.cu_seqlens_k.copy_( torch.nn.functional.pad( torch.cumsum( @@ -972,14 +919,19 @@ class FlashAttentionMultiStepBackend: self.topk = topk self.speculative_num_steps = speculative_num_steps + # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding + assert ( + self.topk == 1 + ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend" + self.attn_backends = [] for i in range(self.speculative_num_steps): self.attn_backends.append( FlashAttentionBackend( model_runner, + speculative_step_id=i, topk=self.topk, speculative_num_steps=self.speculative_num_steps, - step_id=i, ) )