diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 4e8543213..9579b19f2 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -1587,8 +1587,9 @@ class FlashAttentionBackend(AttentionBackend):
         metadata.max_seq_len_k = max_len
 
         metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-        metadata.cu_seqlens_k = torch.nn.functional.pad(
-            torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+        # Copy cumsum in place into cu_seqlens_k[1:] (index 0 stays 0) instead of allocating a new padded tensor each call; assumes cu_seqlens_k is preallocated — confirm
+        metadata.cu_seqlens_k[1:].copy_(
+            torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
         )
 
         max_seq_pages = (