From 1b9175cb23004e3a40dfb97ab80e7e45032c5359 Mon Sep 17 00:00:00 2001
From: Stefan He
Date: Thu, 27 Mar 2025 00:45:11 -0700
Subject: [PATCH] [FA3 Attn Backend] Remove Unnecessary Device Sync for FA3
 (#4745)

Co-authored-by: Yubo Wang
---
 .../attention/flashattention_backend.py      | 26 ++++++++++++-------
 python/sglang/srt/managers/schedule_batch.py |  1 +
 2 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index d773cbf59..c470f64a0 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -29,11 +29,11 @@ class FlashAttentionMetadata:
     cu_seqlens_q: torch.Tensor = None
     cu_seqlens_k: torch.Tensor = None
+    max_seq_len_q: int = 0
     max_seq_len_k: int = 0
     window_size: tuple = (-1, -1)
     page_table: torch.Tensor = None
     cache_seqlens_int32: torch.Tensor = None
-    max_seq_len_q: int = 0


 class FlashAttentionBackend(AttentionBackend):
@@ -63,7 +63,6 @@ class FlashAttentionBackend(AttentionBackend):
         # Create metadata based on forward mode
         metadata = FlashAttentionMetadata()
-        extend_seq_lens = forward_batch.extend_seq_lens
         # Get sequence information
         seqlens_in_batch = forward_batch.seq_lens
         # Precompute int32 version of sequence lengths
@@ -85,15 +84,16 @@
                 0, batch_size + 1, dtype=torch.int32, device=device
             )
         else:
-            extend_no_prefix = not any(forward_batch.extend_prefix_lens)
             # Precompute cumulative sequence lengths
-            if not extend_no_prefix:
+            if any(forward_batch.extend_prefix_lens_cpu):
+                extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                 )
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
             else:
                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
-            metadata.max_seq_len_q = seqlens_in_batch.max().item()
+                metadata.max_seq_len_q = metadata.max_seq_len_k
         self.forward_metadata = metadata

     def forward_extend(
@@ -274,20 +274,26 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-        seqlens_in_batch = seq_lens[:bs]
         metadata = self.decode_cuda_graph_metadata[bs]
-        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+
+        # For CPU operations
+        max_len = seq_lens_cpu[:bs].max().item()
+        metadata.max_seq_len_k = max_len
+
+        # For GPU operations
+        seq_lens_in_batch = seq_lens[:bs]
+        metadata.cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32)
         metadata.cu_seqlens_k = torch.nn.functional.pad(
-            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+            torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
         )
-        # Precompute maximum sequence length
-        metadata.max_seq_len_k = seqlens_in_batch.max().item()
+
         # Only zero out the part out of max_len_k
         metadata.page_table[:, metadata.max_seq_len_k :].fill_(0)
         # Then do the copy
         metadata.page_table[:, : metadata.max_seq_len_k].copy_(
             self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k]
         )
+
         self.forward_decode_metadata = metadata

     def get_cuda_graph_seq_len_fill_value(self):
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index b7020474b..45e1d4be2 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1376,6 +1376,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if (
             global_server_args_dict["enable_flashinfer_mla"]
             or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "fa3"
         ):
             decode_seq_lens = self.seq_lens.cpu()
         else:
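
For context, the sync this patch removes comes from calling `.max().item()` on a CUDA tensor, which blocks the host until every queued kernel finishes before the scalar can be read back. The patch instead derives `max_seq_len_q` / `max_seq_len_k` from CPU-side copies that the scheduler already maintains (`seq_lens_cpu`, `extend_seq_lens_cpu`, `extend_prefix_lens_cpu`). A minimal sketch of the before/after pattern, assuming a CUDA device; the tensor names below mirror the patch but are illustrative, not sglang code:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = torch.randint(1, 4096, (64,), device=device)

# Before: reducing on the GPU and calling .item() forces a
# device-to-host sync, stalling the CPU behind all queued kernels.
max_seq_len_k = seq_lens.max().item()

# After: keep a CPU mirror of the sequence lengths (schedule_batch.py
# now does this for the fa3 backend too) and take the max on the host,
# so the attention backend never has to wait on the device.
seq_lens_cpu = seq_lens.cpu()  # one copy per batch, in the scheduler
max_seq_len_k = seq_lens_cpu.max().item()  # pure CPU op, no GPU sync
```

The same idea drives the CUDA-graph-replay hunk: `max_seq_len_k` is read from `seq_lens_cpu[:bs]` before the GPU-side metadata is filled in, so preparing replay metadata no longer stalls on the device.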