From 9d5fa68b903d295d2b39201d54905c6801f60f7f Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sun, 8 Jun 2025 23:05:40 -0700
Subject: [PATCH] Use torch.compile to fuse flash attention decode metadata
 preparation (#6973)

---
 .../attention/flashattention_backend.py      | 49 ++++++++++++-------
 1 file changed, 31 insertions(+), 18 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 08a62c0dd..8eebc6dc2 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -11,6 +11,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -1657,30 +1658,22 @@ class FlashAttentionBackend(AttentionBackend):
                 )
                 # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
             else:
-                metadata = self.decode_cuda_graph_metadata[bs]
                 # Normal Decode
+                metadata = self.decode_cuda_graph_metadata[bs]
                 max_len = seq_lens_cpu.max().item()
+                max_seq_pages = (max_len + self.page_size - 1) // self.page_size
                 metadata.max_seq_len_k = max_len
 
-                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-                # Optimize cumulative sequence length calculation
-                metadata.cu_seqlens_k[1:].copy_(
-                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
+                normal_decode_set_metadata(
+                    metadata,
+                    self.req_to_token,
+                    req_pool_indices,
+                    self.decode_cuda_graph_metadata["strided_indices"],
+                    max_seq_pages,
+                    seq_lens,
+                    self.page_size,
                 )
 
-                max_seq_pages = (
-                    metadata.max_seq_len_k + self.page_size - 1
-                ) // self.page_size
-                page_indices = self.req_to_token[
-                    req_pool_indices[:, None],
-                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
-                        None, :
-                    ],
-                ]
-                page_indices //= self.page_size
-                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
-                metadata.page_table[:, max_seq_pages:].fill_(0)
-
                 self._update_local_attn_metadata_for_replay(metadata, bs)
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
@@ -2063,3 +2056,23 @@ class FlashAttentionMultiStepBackend:
             seq_lens_cpu=forward_batch.seq_lens_cpu,
             out_cache_loc=forward_batch.out_cache_loc,
         )
+
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def normal_decode_set_metadata(
+    metadata,
+    req_to_token,
+    req_pool_indices,
+    strided_indices,
+    max_seq_pages,
+    seq_lens,
+    page_size,
+):
+    metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+    metadata.cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))
+    page_indices = req_to_token[
+        req_pool_indices[:, None],
+        strided_indices[:max_seq_pages][None, :],
+    ]
+    metadata.page_table[:, :max_seq_pages].copy_(page_indices // page_size)
+    metadata.page_table[:, max_seq_pages:].fill_(0)
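
Notes:

normal_decode_set_metadata batches the small per-replay tensor ops (the int32
cast of seq_lens, the cumulative-sum copy into cu_seqlens_k, the strided
page-table gather, and the zero-fill of the unused tail) into one region that
torch.compile can fuse into fewer kernel launches. The sketch below shows the
same pattern end to end; it is a minimal stand-alone approximation, where
DecodeMetadata, set_decode_metadata, and the local get_compiler_backend are
hypothetical stand-ins for sglang's FlashAttentionMetadata, the patched
helper, and sglang.srt.utils.get_compiler_backend.

import torch


def get_compiler_backend():
    # Hypothetical fallback for sglang.srt.utils.get_compiler_backend():
    # inductor when a GPU is available, plain eager otherwise.
    return "inductor" if torch.cuda.is_available() else "eager"


class DecodeMetadata:
    # Hypothetical stand-in modeling only the fields the helper mutates.
    def __init__(self, max_bs, max_pages, device):
        self.cache_seqlens_int32 = torch.zeros(max_bs, dtype=torch.int32, device=device)
        self.cu_seqlens_k = torch.zeros(max_bs + 1, dtype=torch.int32, device=device)
        self.page_table = torch.zeros(max_bs, max_pages, dtype=torch.int32, device=device)


@torch.compile(dynamic=True, backend=get_compiler_backend())
def set_decode_metadata(
    metadata, req_to_token, req_pool_indices, strided_indices,
    max_seq_pages, seq_lens, page_size,
):
    # Same op sequence as the patched helper: cast, prefix-sum copy,
    # strided gather, page-table write, zero-fill of the unused tail.
    metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
    metadata.cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))
    page_indices = req_to_token[
        req_pool_indices[:, None],
        strided_indices[:max_seq_pages][None, :],
    ]
    metadata.page_table[:, :max_seq_pages].copy_(page_indices // page_size)
    metadata.page_table[:, max_seq_pages:].fill_(0)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    bs, page_size, max_ctx = 4, 16, 256
    meta = DecodeMetadata(bs, max_ctx // page_size, device)
    # Dummy slot mapping: request i's token j lives in KV slot i * max_ctx + j.
    req_to_token = torch.arange(8 * max_ctx, dtype=torch.int32, device=device).view(8, max_ctx)
    seq_lens = torch.tensor([31, 64, 17, 128], device=device)
    max_pages = (int(seq_lens.max().item()) + page_size - 1) // page_size
    set_decode_metadata(
        meta,
        req_to_token,
        torch.tensor([0, 2, 5, 7], device=device),           # req_pool_indices
        torch.arange(0, max_ctx, page_size, device=device),  # strided_indices
        max_pages,
        seq_lens,
        page_size,
    )
    print(meta.cu_seqlens_k)  # tensor([  0,  31,  95, 112, 240], dtype=torch.int32)

Because batch size and sequence lengths differ from replay to replay,
dynamic=True asks torch.compile to trace with symbolic shapes rather than
recompiling the helper for every new batch size.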
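
The page-table arithmetic deserves a note: req_to_token maps (request row,
token position) to a physical KV-cache slot, strided_indices picks the first
token of every page, and integer division by page_size turns that slot index
into a page number. A toy check of the arithmetic, with hypothetical slot
numbers:

import torch

page_size = 16
# Hypothetical req_to_token row: a request whose 40 tokens occupy contiguous
# KV slots 512..551, i.e. physical pages 32, 33, and 34.
row = torch.cat([
    torch.arange(512, 512 + 40, dtype=torch.int32),
    torch.zeros(256 - 40, dtype=torch.int32),  # padding out to max context
])

seq_len = 40
max_seq_pages = (seq_len + page_size - 1) // page_size     # ceil(40 / 16) = 3
strided_indices = torch.arange(0, row.numel(), page_size)  # 0, 16, 32, ...

pages = row[strided_indices[:max_seq_pages]] // page_size
print(pages)  # tensor([32, 33, 34], dtype=torch.int32)

Sampling only one slot per page is sufficient under the assumption the strided
gather in the patch already relies on: each page_size-token chunk of a request
sits in a single physical page, so every slot in the chunk shares the same
quotient.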