From 2dae104dca6ffc744bc21e8209f43a0595cae230 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 10 Jun 2025 03:58:44 -0700
Subject: [PATCH] Minor cleanup of fa3 backend (#6999)

---
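Note (editorial illustration, not part of the commit): most hunks below
drop a redundant ".to(torch.int32)" cast in front of Tensor.copy_ into a
preallocated int32 buffer. Tensor.copy_ performs the dtype conversion by
itself, so behavior is unchanged. A minimal standalone check, assuming
any recent PyTorch:

    import torch

    buf = torch.zeros(3, dtype=torch.int32)  # stands in for cache_seqlens_int32
    src = torch.tensor([4, 5, 6])            # int64 by default, like seq_lens
    buf.copy_(src)                           # copy_ casts int64 -> int32 itself
    assert buf.dtype == torch.int32 and buf.tolist() == [4, 5, 6]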
 .../attention/flashattention_backend.py      | 96 +++++++++----------
 .../attention/flashinfer_mla_backend.py      | 31 +++---
 2 files changed, 63 insertions(+), 64 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 8eebc6dc2..381ec4a1c 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -1469,7 +1469,7 @@ class FlashAttentionBackend(AttentionBackend):
                     "cache_seqlens"
                 ][:bs]
                 metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                    (seq_lens + self.speculative_num_draft_tokens)
                 )

                 metadata.max_seq_len_q = self.speculative_num_draft_tokens
@@ -1536,7 +1536,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
                 :bs
             ]
-            metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+            metadata.cache_seqlens_int32.copy_(seq_lens)

             num_tokens_per_bs = num_tokens // bs
             metadata.max_seq_len_q = num_tokens_per_bs
@@ -1600,38 +1600,32 @@ class FlashAttentionBackend(AttentionBackend):
         if spec_info is not None:
             # Draft Decode
             if self.topk <= 1:
-                metadata = self.decode_cuda_graph_metadata[bs]
                 # When topk = 1, we use the normal decode metadata
-                metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
-                )
-
-                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
-                    self.speculative_step_id + 1
-                )
-                metadata.cu_seqlens_k[1:].copy_(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    )
-                )
-
+                metadata = self.decode_cuda_graph_metadata[bs]
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
                 max_seq_pages = (
                     metadata.max_seq_len_k + self.page_size - 1
                 ) // self.page_size
-                page_indices = self.req_to_token[
-                    req_pool_indices[:, None],
-                    self.decode_cuda_graph_metadata["strided_indices"][
-                        :max_seq_pages
-                    ],
-                ]
-                page_indices //= self.page_size
-                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+                normal_decode_set_medadata(
+                    metadata.cache_seqlens_int32,
+                    metadata.cu_seqlens_k,
+                    metadata.page_table,
+                    self.req_to_token,
+                    req_pool_indices,
+                    self.decode_cuda_graph_metadata["strided_indices"],
+                    max_seq_pages,
+                    seq_lens,
+                    self.speculative_step_id + 1,
+                    self.page_size,
+                )
+
             else:
                 # When topk > 1, we need two specific draft decode metadata, and then merge states
                 # 1. The first half of metadata for prefix tokens
                 metadata = self.draft_decode_metadata_topk_normal[bs]
-                metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                metadata.cache_seqlens_int32.copy_(seq_lens)
                 # metadata.max_seq_len_q = self.topk, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
@@ -1654,7 +1648,7 @@ class FlashAttentionBackend(AttentionBackend):
                         self.speculative_num_steps, -1
                     ).T.contiguous()
                     metadata_expand.page_table[: cache_loc.shape[0]].copy_(
-                        cache_loc[:, :decode_length].contiguous().to(torch.int32)
+                        cache_loc[:, :decode_length]
                     )
                 # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
         else:
@@ -1665,12 +1659,15 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_k = max_len

             normal_decode_set_medadata(
-                metadata,
+                metadata.cache_seqlens_int32,
+                metadata.cu_seqlens_k,
+                metadata.page_table,
                 self.req_to_token,
                 req_pool_indices,
                 self.decode_cuda_graph_metadata["strided_indices"],
                 max_seq_pages,
                 seq_lens,
+                0,
                 self.page_size,
             )

@@ -1679,7 +1676,7 @@ class FlashAttentionBackend(AttentionBackend):
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
                 metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                    (seq_lens + self.speculative_num_draft_tokens)
                 )

                 metadata.max_seq_len_k = (
@@ -1701,7 +1698,7 @@ class FlashAttentionBackend(AttentionBackend):
                 # When topk > 1, we need two specific target verify metadata, and then merge states
                 # 1. The first half of metadata for prefix tokens
                 metadata = self.target_verify_metadata_topk_normal[bs]
-                metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                metadata.cache_seqlens_int32.copy_(seq_lens)
                 # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
@@ -1761,9 +1758,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata_expand.page_table.copy_(
                     non_masked_page_table.gather(1, sort_order)
                 )
-                metadata_expand.cache_seqlens_int32.copy_(
-                    mask.sum(dim=1).to(torch.int32)
-                )
+                metadata_expand.cache_seqlens_int32.copy_(mask.sum(dim=1))
                 metadata_expand.cu_seqlens_k[1:].copy_(
                     torch.cumsum(
                         metadata_expand.cache_seqlens_int32,
@@ -1776,14 +1771,14 @@ class FlashAttentionBackend(AttentionBackend):
                         dim=0,
                         dtype=torch.int32,
                     )
                 )
         elif forward_mode.is_draft_extend():
             metadata = self.draft_extend_metadata[bs]
-            metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+            metadata.cache_seqlens_int32.copy_(seq_lens)
             metadata.max_seq_len_k = seq_lens_cpu.max().item()
             metadata.cu_seqlens_k[1:].copy_(
                 torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
             )

             accept_length = spec_info.accept_length[:bs]
-            metadata.max_seq_len_q = accept_length.max().item()
+            metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
             metadata.cu_seqlens_q[1:].copy_(
                 torch.cumsum(accept_length, dim=0, dtype=torch.int32)
             )
@@ -1795,8 +1790,7 @@ class FlashAttentionBackend(AttentionBackend):
                 req_pool_indices[:, None],
                 self.draft_extend_metadata["strided_indices"][:max_seq_pages],
             ]
-            page_indices //= self.page_size
-            metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)

         if encoder_lens is not None:
             # Only support encoder size 1 for now
@@ -2045,6 +2039,8 @@ class FlashAttentionMultiStepBackend:
         assert isinstance(forward_batch.spec_info, EagleDraftInput)

         for i in range(self.speculative_num_steps - 1):
+            # TODO: incrementally update the metadata for the later steps,
+            # so that they do not need to recompute everything from scratch.
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
                 forward_batch.req_pool_indices,
@@ -2058,21 +2054,25 @@ class FlashAttentionMultiStepBackend:
             )


-@torch.compile(dynamic=True, backend=get_compiler_backend())
+# @torch.compile(dynamic=True, backend=get_compiler_backend())
+# TODO: fuse these kernels
+# NOTE: torch.compile makes it slower in speculative decoding
 def normal_decode_set_medadata(
-    metadata,
-    req_to_token,
-    req_pool_indices,
-    strided_indices,
-    max_seq_pages,
-    seq_lens,
-    page_size,
+    cache_seqlens_int32: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    page_table: torch.Tensor,
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    strided_indices: torch.Tensor,
+    max_seq_pages: torch.Tensor,
+    seq_lens: torch.Tensor,
+    seq_len_delta: int,
+    page_size: int,
 ):
-    metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-    metadata.cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))
+    cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
+    cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
     page_indices = req_to_token[
         req_pool_indices[:, None],
         strided_indices[:max_seq_pages][None, :],
     ]
-    metadata.page_table[:, :max_seq_pages].copy_(page_indices // page_size)
-    metadata.page_table[:, max_seq_pages:].fill_(0)
+    page_table[:, :max_seq_pages].copy_(page_indices // page_size)
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index 918184dfc..275518b6c 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -920,19 +920,18 @@ def fast_mla_decode_plan(
     self._page_size = page_size
     self._sm_scale = sm_scale

-    with self.device as device:
-        try:
-            # Standard version with just the required arguments (no use_profiler)
-            self._cached_module.plan.default(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                self._pin_memory_int_workspace_buffer,
-                qo_indptr_cpu,
-                kv_indptr_cpu,
-                kv_len_arr_cpu,
-                num_heads,
-                head_dim_ckv,
-                causal,
-            )
-        except Exception as e:
-            raise RuntimeError(f"Error in alternate MLA plan: {e}")
+    try:
+        # Standard version with just the required arguments (no use_profiler)
+        self._cached_module.plan.default(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_cpu,
+            kv_indptr_cpu,
+            kv_len_arr_cpu,
+            num_heads,
+            head_dim_ckv,
+            causal,
+        )
+    except Exception as e:
+        raise RuntimeError(f"Error in alternate MLA plan: {e}")
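Note (illustration only, not part of the patch): after the refactor,
normal_decode_set_medadata writes its results into preallocated,
CUDA-graph-captured buffers instead of assigning attributes on the
metadata object. The sketch below mirrors that logic on toy CPU tensors
so the page-table arithmetic is easy to verify; every name, shape, and
value here is invented for the example:

    import torch

    # Mirrors the patched normal_decode_set_medadata on toy CPU tensors.
    def set_decode_metadata(
        cache_seqlens_int32,  # [bs] int32, written in place
        cu_seqlens_k,         # [bs + 1] int32, written in place
        page_table,           # [bs, max_pages] int32, written in place
        req_to_token,         # [num_reqs, max_tokens] token slot table
        req_pool_indices,     # [bs] row of each request in req_to_token
        strided_indices,      # arange(0, max_tokens, page_size)
        max_seq_pages,        # pages needed for the longest sequence
        seq_lens,             # [bs] current sequence lengths
        seq_len_delta,        # e.g. speculative_step_id + 1 in draft decode
        page_size,
    ):
        cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
        cu_seqlens_k[1:].copy_(
            torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32)
        )
        # Take the first token slot of every page, then map slot id -> page id.
        page_indices = req_to_token[
            req_pool_indices[:, None],
            strided_indices[:max_seq_pages][None, :],
        ]
        page_table[:, :max_seq_pages].copy_(page_indices // page_size)

    page_size, bs, max_pages = 4, 2, 4
    # Request 0 owns token slots 0..15, request 1 owns slots 16..31.
    req_to_token = torch.arange(2 * 16, dtype=torch.int32).reshape(2, 16)
    cache_seqlens = torch.zeros(bs, dtype=torch.int32)
    cu_seqlens_k = torch.zeros(bs + 1, dtype=torch.int32)
    page_table = torch.zeros(bs, max_pages, dtype=torch.int32)

    set_decode_metadata(
        cache_seqlens,
        cu_seqlens_k,
        page_table,
        req_to_token,
        torch.tensor([0, 1]),
        torch.arange(0, 16, page_size),
        max_pages,
        torch.tensor([5, 9], dtype=torch.int32),
        1,  # one extra draft token appended to every sequence
        page_size,
    )
    print(cache_seqlens)  # tensor([ 6, 10], dtype=torch.int32)
    print(cu_seqlens_k)   # tensor([ 0,  6, 16], dtype=torch.int32)
    print(page_table)     # [[0, 1, 2, 3], [4, 5, 6, 7]]

The caller is expected to size max_seq_pages from max_seq_len_k, as the
patched decode path does, so the dropped fill_(0) on the page-table tail
is not needed for correctness of the attended range.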