From 79961afa8281f98f380d11db45c8d4b6e66a574f Mon Sep 17 00:00:00 2001
From: Minglei Zhu
Date: Wed, 7 May 2025 23:40:08 -0700
Subject: [PATCH] optimize pad operations in fa3 to accelerate by 100+us
 (#6077)

---
 .../attention/flashattention_backend.py      | 56 ++++++-------
 1 file changed, 17 insertions(+), 39 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index c148ac159..8618c01f3 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -1525,12 +1525,9 @@ class FlashAttentionBackend(AttentionBackend):
                     metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
                         self.speculative_step_id + 1
                     )
-                    metadata.cu_seqlens_k.copy_(
-                        torch.nn.functional.pad(
-                            torch.cumsum(
-                                metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                            ),
-                            (1, 0),
+                    metadata.cu_seqlens_k[1:].copy_(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                         )
                     )
 
@@ -1554,12 +1551,9 @@ class FlashAttentionBackend(AttentionBackend):
                     # metadata.max_seq_len_q = self.topk, already set in capture
                     metadata.max_seq_len_k = seq_lens_cpu.max().item()
                     # metadata.cu_seqlens_q already set in capture
-                    metadata.cu_seqlens_k.copy_(
-                        torch.nn.functional.pad(
-                            torch.cumsum(
-                                metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                            ),
-                            (1, 0),
+                    metadata.cu_seqlens_k[1:].copy_(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                         )
                     )
 
@@ -1616,13 +1610,8 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_k = (
                     seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
                 )
-                metadata.cu_seqlens_k.copy_(
-                    torch.nn.functional.pad(
-                        torch.cumsum(
-                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                        ),
-                        (1, 0),
-                    )
+                metadata.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
                 )
                 max_seq_pages = (
                     metadata.max_seq_len_k + self.page_size - 1
@@ -1641,13 +1630,8 @@ class FlashAttentionBackend(AttentionBackend):
                 # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
-                metadata.cu_seqlens_k.copy_(
-                    torch.nn.functional.pad(
-                        torch.cumsum(
-                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                        ),
-                        (1, 0),
-                    )
+                metadata.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
                 )
                 page_table = self.req_to_token[
                     req_pool_indices, : metadata.max_seq_len_k
@@ -1705,14 +1689,11 @@ class FlashAttentionBackend(AttentionBackend):
                     metadata_expand.cache_seqlens_int32.copy_(
                         mask.sum(dim=1).to(torch.int32)
                     )
-                    metadata_expand.cu_seqlens_k.copy_(
-                        torch.nn.functional.pad(
-                            torch.cumsum(
-                                metadata_expand.cache_seqlens_int32,
-                                dim=0,
-                                dtype=torch.int32,
-                            ),
-                            (1, 0),
+                    metadata_expand.cu_seqlens_k[1:].copy_(
+                        torch.cumsum(
+                            metadata_expand.cache_seqlens_int32,
+                            dim=0,
+                            dtype=torch.int32,
                         )
                     )
                     metadata_expand.max_seq_len_k = (
@@ -1723,11 +1704,8 @@ class FlashAttentionBackend(AttentionBackend):
             # Only support encoder size 1 for now
             metadata.encoder_max_seq_len_k = encoder_lens[0]
             metadata.encoder_lens_int32.copy_(encoder_lens[:1])
-            metadata.encoder_cu_seqlens_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
-                    (1, 0),
-                )
+            metadata.encoder_cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32)
             )
 
             metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(