From c776234b4529fc94a170b3f33a4fdb03b4e9dd5a Mon Sep 17 00:00:00 2001
From: Chang Su
Date: Thu, 17 Apr 2025 02:07:43 -0700
Subject: [PATCH] Enable local attention during decode (#5479)

---
 .../attention/flashattention_backend.py      | 181 +++++++++++-------
 1 file changed, 113 insertions(+), 68 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index cee8ae4c8..0682e52b5 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -142,6 +142,16 @@ def make_local_attention_virtual_batches(
         seqlens_k_local: Key sequence lengths for local attention
         block_table_local: Block table for local attention
     """
+    # Adjust attention_chunk_size based on the actual sequence length
+    # to avoid index out of bounds errors
+    max_seq_len = seq_lens_np.max()
+    effective_chunk_size = min(attn_chunk_size, max_seq_len)
+    # Make sure effective_chunk_size is divisible by page_size
+    effective_chunk_size = (effective_chunk_size // page_size) * page_size
+    if effective_chunk_size < page_size:
+        effective_chunk_size = page_size
+    attn_chunk_size = effective_chunk_size
+
     q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
     actual_batch_size = seq_lens_np.shape[0]
 
@@ -344,6 +354,8 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
+
+                self._init_local_attn_metadata(metadata, device)
             else:
                 # Normal Decode
                 metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
@@ -357,6 +369,8 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
+
+                self._init_local_attn_metadata(metadata, device)
         elif forward_batch.forward_mode.is_target_verify():
             metadata.cache_seqlens_int32 = (
                 forward_batch.seq_lens + self.speculative_num_draft_tokens
@@ -405,49 +419,8 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
 
             # Setup local attention if enabled
-            if (
-                self.attention_chunk_size is not None
-                and forward_batch.forward_mode == ForwardMode.EXTEND
-            ):
-                # Convert tensors to numpy for local attention processing
-                cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
-                seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
-
-                # Adjust attention_chunk_size based on the actual sequence length
-                # to avoid index out of bounds errors
-                max_seq_len = seq_lens_np.max()
-                effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
-                # Make sure effective_chunk_size is divisible by page_size
-                effective_chunk_size = (
-                    effective_chunk_size // self.page_size
-                ) * self.page_size
-                if effective_chunk_size < self.page_size:
-                    effective_chunk_size = self.page_size
-
-                # Create local attention metadata
-                (
-                    seqlens_q_local_np,
-                    cu_seqlens_q_local_np,
-                    seqlens_k_local_np,
-                    block_table_local,
-                ) = make_local_attention_virtual_batches(
-                    effective_chunk_size,
-                    cu_seqlens_q_np,
-                    seq_lens_np,
-                    metadata.page_table,
-                    self.page_size,
-                )
-
-                local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                    local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
-                        device
-                    ),
-                    local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
-                    local_block_table=block_table_local,
-                    local_max_query_len=seqlens_q_local_np.max(),
-                    local_max_seq_len=seqlens_k_local_np.max(),
-                )
-                metadata.local_attn_metadata = local_metadata
+            if forward_batch.forward_mode == ForwardMode.EXTEND:
+                self._init_local_attn_metadata(metadata, device)
 
         # Encoder metadata for cross attention
         if forward_batch.encoder_lens is not None:
@@ -704,6 +677,10 @@ class FlashAttentionBackend(AttentionBackend):
 
         # Use precomputed metadata across all layers
        metadata = self.forward_metadata
+        local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
+        use_local_attention = (
+            self.attention_chunk_size is not None and local_attn_metadata is not None
+        )
 
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
@@ -738,33 +715,60 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )
 
-            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
             if layer.is_cross_attention:
-                page_table = metadata.encoder_page_table
-                cache_seqlens = metadata.encoder_lens_int32
-                cu_seqlens_k = metadata.encoder_cu_seqlens_k
-                window_size = (-1, -1)
+                # Always use non-chunked logic for cross-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=metadata.encoder_page_table,
+                    cache_seqlens=metadata.encoder_lens_int32,
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
+                    max_seqlen_q=1,
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    window_size=(-1, -1),
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
+            elif use_local_attention:
+                # Use chunked (local) attention batching for self-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=local_attn_metadata.local_block_table,
+                    cache_seqlens=local_attn_metadata.local_seqused_k,
+                    cu_seqlens_q=local_attn_metadata.local_query_start_loc,
+                    cu_seqlens_k_new=metadata.cu_seqlens_k,
+                    max_seqlen_q=local_attn_metadata.local_max_query_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    window_size=(-1, -1),
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
             else:
-                page_table = metadata.page_table
-                cache_seqlens = metadata.cache_seqlens_int32
-                cu_seqlens_k = metadata.cu_seqlens_k
-
-            o = flash_attn_with_kvcache(
-                q=q_reshaped,
-                k_cache=key_cache,
-                v_cache=value_cache,
-                page_table=page_table,
-                cache_seqlens=cache_seqlens,
-                cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=cu_seqlens_k,
-                max_seqlen_q=1,
-                softmax_scale=layer.scaling,
-                causal=causal,
-                window_size=window_size,
-                softcap=layer.logit_cap,
-                k_descale=k_descale,
-                v_descale=v_descale,
-            )
+                # Default: single-token self-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=metadata.page_table,
+                    cache_seqlens=metadata.cache_seqlens_int32,
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k_new=metadata.cu_seqlens_k,
+                    max_seqlen_q=1,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    window_size=window_size,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
         else:
             # Do absorbed multi-latent attention
             kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
@@ -986,6 +990,8 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
+        device = seq_lens.device
+
         if forward_mode.is_decode_or_idle():
             metadata = self.decode_cuda_graph_metadata[bs]
 
@@ -1012,6 +1018,8 @@ class FlashAttentionBackend(AttentionBackend):
                 ]
 
                 metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
+                self._init_local_attn_metadata(metadata, device)
             else:
                 # Normal Decode
                 max_len = seq_lens_cpu.max().item()
@@ -1035,6 +1043,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                 metadata.page_table[:, max_seq_pages:].fill_(0)
 
+            self._init_local_attn_metadata(metadata, device)
         elif forward_mode.is_target_verify():
             metadata = self.target_verify_metadata[bs]
             metadata.cache_seqlens_int32.copy_(
@@ -1085,6 +1094,42 @@ class FlashAttentionBackend(AttentionBackend):
         """Get the fill value for sequence length in CUDA graph."""
         return 0
 
+    def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
+        """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
+        if self.attention_chunk_size is None:
+            metadata.local_attn_metadata = None
+            return
+
+        cu_seqlens_q = metadata.cu_seqlens_q
+        cache_seqlens_int32 = metadata.cache_seqlens_int32
+        page_table = metadata.page_table
+        if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
+            metadata.local_attn_metadata = None
+            return
+
+        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
+        seq_lens_np = cache_seqlens_int32.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seq_lens_np,
+            page_table,
+            self.page_size,
+        )
+        local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+            local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
+            local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
+            local_block_table=block_table_local.to(device),
+            local_max_query_len=int(seqlens_q_local_np.max()),
+            local_max_seq_len=int(seqlens_k_local_np.max()),
+        )
+        metadata.local_attn_metadata = local_metadata
+
 
 class FlashAttentionMultiStepBackend:
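
Note on the first hunk: the chunk-size clamping that this patch moves into make_local_attention_virtual_batches can be exercised in isolation. The sketch below is not part of the patch; clamp_chunk_size is a hypothetical helper name that only mirrors the arithmetic shown above (cap the chunk at the longest sequence in the batch, round down to a multiple of page_size, never drop below one page).

import numpy as np

def clamp_chunk_size(attn_chunk_size: int, seq_lens_np: np.ndarray, page_size: int) -> int:
    """Mirror of the clamping added to make_local_attention_virtual_batches (illustrative only)."""
    # Cap at the longest sequence to avoid indexing past the page table.
    max_seq_len = int(seq_lens_np.max())
    effective = min(attn_chunk_size, max_seq_len)
    # Round down to a multiple of the KV-cache page size, with a floor of one page.
    effective = (effective // page_size) * page_size
    return max(effective, page_size)

# Example: a decode batch whose longest sequence (70 tokens) is shorter than the
# configured chunk (8192) ends up with a 64-token chunk when page_size is 16.
print(clamp_chunk_size(8192, np.array([70, 33, 12]), 16))  # -> 64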