From 26c0f13126867f6fa8f14e872ba63c12171c51cf Mon Sep 17 00:00:00 2001 From: Stefan He Date: Thu, 27 Mar 2025 22:07:14 -0700 Subject: [PATCH] Support Page Size > 1 for FA3 (#4832) Co-authored-by: Qingquan Song Co-authored-by: Baizhou Zhang --- .../attention/flashattention_backend.py | 68 ++++++++++++++----- 1 file changed, 50 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index c470f64a0..365a0a54f 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -57,6 +57,7 @@ class FlashAttentionBackend(AttentionBackend): self.device = model_runner.device self.decode_cuda_graph_metadata = {} self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.page_size = model_runner.page_size def init_forward_metadata(self, forward_batch: ForwardBatch): """Initialize forward metadata to cache repetitive calculations.""" @@ -78,6 +79,17 @@ class FlashAttentionBackend(AttentionBackend): metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : metadata.max_seq_len_k ] + + # Precompute strided indices + # [0, page_size, 2 * page_size, ...] + if self.page_size > 1: + self.strided_indices = torch.arange( + 0, metadata.page_table.shape[1], self.page_size, device=self.device + ) + metadata.page_table = ( + metadata.page_table[:, self.strided_indices] // self.page_size + ) + if forward_batch.forward_mode == ForwardMode.DECODE: # Precompute cumulative sequence lengths metadata.cu_seqlens_q = torch.arange( @@ -132,11 +144,21 @@ class FlashAttentionBackend(AttentionBackend): ) kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) key_cache, value_cache = kv_cache[0], kv_cache[1] + + key_cache = key_cache.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim + ) + value_cache = value_cache.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim + ) + + page_table = metadata.page_table + o = flash_attn_with_kvcache( q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - k_cache=key_cache.unsqueeze(1), - v_cache=value_cache.unsqueeze(1), - page_table=metadata.page_table, + k_cache=key_cache, + v_cache=value_cache, + page_table=page_table, cache_seqlens=metadata.cache_seqlens_int32, cu_seqlens_q=metadata.cu_seqlens_q, cu_seqlens_k_new=metadata.cu_seqlens_k, @@ -175,13 +197,11 @@ class FlashAttentionBackend(AttentionBackend): # Get KV cache kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) key_cache, value_cache = kv_cache[0], kv_cache[1] - # Use precomputed metadata metadata = self.forward_metadata # Pre-reshape query tensor q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) - # Calculate window size (can be moved to metadata if layer properties don't change) # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1 # here is two side inclusive @@ -191,11 +211,20 @@ class FlashAttentionBackend(AttentionBackend): else (-1, -1) ) # Run attention with precomputed values + key_cache = key_cache.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim + ) + value_cache = value_cache.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim + ) + + page_table = metadata.page_table + o = flash_attn_with_kvcache( q=q_reshaped, - k_cache=key_cache.unsqueeze(1), - v_cache=value_cache.unsqueeze(1), - page_table=metadata.page_table, + k_cache=key_cache, + v_cache=value_cache, + page_table=page_table, cache_seqlens=metadata.cache_seqlens_int32, cu_seqlens_q=metadata.cu_seqlens_q, cu_seqlens_k_new=metadata.cu_seqlens_k, @@ -207,7 +236,6 @@ class FlashAttentionBackend(AttentionBackend): k_descale=layer.k_scale, v_descale=layer.v_scale, ) - return o.view(-1, layer.tp_q_head_num * layer.head_dim) def init_cuda_graph_state(self, max_bs: int): @@ -223,7 +251,13 @@ class FlashAttentionBackend(AttentionBackend): self.decode_cuda_graph_metadata = { # Page table for token mapping (batch_size, max_context_len) "page_table": torch.zeros( - max_bs, self.max_context_len, dtype=torch.int32, device=self.device + max_bs, + (self.max_context_len + self.page_size - 1) // self.page_size, + dtype=torch.int32, + device=self.device, + ), + "strided_indices": torch.arange( + 0, self.max_context_len, self.page_size, device=self.device ), } @@ -252,6 +286,7 @@ class FlashAttentionBackend(AttentionBackend): metadata.page_table = self.decode_cuda_graph_metadata["page_table"][ req_pool_indices, : ] + if forward_mode == ForwardMode.DECODE: # Precompute cumulative sequence lengths metadata.cu_seqlens_q = torch.arange( @@ -287,14 +322,11 @@ class FlashAttentionBackend(AttentionBackend): torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0) ) - # Only zero out the part out of max_len_k - metadata.page_table[:, metadata.max_seq_len_k :].fill_(0) - # Then do the copy - metadata.page_table[:, : metadata.max_seq_len_k].copy_( - self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k] - ) - - self.forward_decode_metadata = metadata + metadata.page_table = self.req_to_token[ + :, self.decode_cuda_graph_metadata["strided_indices"] + ] + metadata.page_table = metadata.page_table[req_pool_indices[:bs]] + self.forward_metadata = metadata def get_cuda_graph_seq_len_fill_value(self): """Get the fill value for sequence length in CUDA graph."""