diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 1e220b9e2..051603a87 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -27,19 +27,42 @@ from sgl_kernel.flash_attn import flash_attn_with_kvcache @dataclass class FlashAttentionMetadata: - """Metadata for decode operations to avoid redundant computations.""" + """Metadata to be init once in the model forward pass, + each layer's forward pass can reuse the metadata.""" + # Cumulative sequence lengths for query cu_seqlens_q: torch.Tensor = None + # Cumulative sequence lengths for key cu_seqlens_k: torch.Tensor = None + # Maximum sequence length for query max_seq_len_q: int = 0 + # Maximum sequence length for key max_seq_len_k: int = 0 + # Window size (typically used by Gemma) window_size: tuple = (-1, -1) + # Page table, the index of KV Cache Tables/Blocks page_table: torch.Tensor = None + # Sequence lengths for the forward batch cache_seqlens_int32: torch.Tensor = None class FlashAttentionBackend(AttentionBackend): - """FlashAttention backend implementation.""" + """FlashAttention backend implementation. + + Note about the init: + - If no spec decoding + - FlashAttentionBackend will be init once when the server starts. + - If spec decoding + - FlashAttentionBackend will be init once for the target worker + - FlashAttentionMultiStepBackend will be once for the draft worker + - It will spawn num_steps FlashAttentionBackend for the draft worker + + Note about CUDA Graph: + - We only support CUDA Graph for Decode (Normal Decode and Draft Decode) and Target Verify. + - We don't support CUDA Graph for Extend and Draft Extend. + - When server init, init_cuda_graph_state will be called first and then init_cuda_graph_capture will be called. + - For each forward batch, init_replay_cuda_graph will be called first and then replay the graph. + """ def __init__( self, @@ -56,41 +79,42 @@ class FlashAttentionBackend(AttentionBackend): and model_runner.model_config.is_encoder_decoder ), "Sliding window and cross attention are not supported together" - # Initialize metadata self.forward_metadata: FlashAttentionMetadata = None self.max_context_len = model_runner.model_config.context_len self.device = model_runner.device self.decode_cuda_graph_metadata = {} + self.target_verify_metadata = {} self.req_to_token = model_runner.req_to_token_pool.req_to_token self.page_size = model_runner.page_size self.use_mla = ( model_runner.model_config.attention_arch == AttentionArch.MLA ) and (not global_server_args_dict["disable_mla"]) self.skip_prefill = skip_prefill - self.topk = topk - self.speculative_num_steps = speculative_num_steps + + # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding + assert ( + topk <= 1 + ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend" + + self.topk = 1 self.step_id = step_id + self.speculative_num_steps = speculative_num_steps def init_forward_metadata(self, forward_batch: ForwardBatch): """Initialize forward metadata to cache repetitive calculations.""" - # Create metadata based on forward mode metadata = FlashAttentionMetadata() - - # Get sequence information seqlens_in_batch = forward_batch.seq_lens - # Precompute int32 version of sequence lengths batch_size = len(seqlens_in_batch) device = seqlens_in_batch.device - - if forward_batch.forward_mode == ForwardMode.DECODE: - if self.skip_prefill: + if forward_batch.forward_mode.is_decode(): + # Skip Prefill or Draft Decode + # Note: Draft Decode will be ran on the Draft Worker + if forward_batch.spec_info is not None: metadata.cu_seqlens_q = torch.arange( - 0, batch_size * self.topk + 1, dtype=torch.int32, device=device + 0, batch_size + 1, dtype=torch.int32, device=device ) seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1) - metadata.cache_seqlens_int32 = ( - (seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32) - ) + metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32) metadata.cu_seqlens_k = torch.nn.functional.pad( torch.cumsum( metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 @@ -103,86 +127,58 @@ class FlashAttentionBackend(AttentionBackend): metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : metadata.max_seq_len_k ] - metadata.page_table = metadata.page_table.repeat_interleave( - self.topk, dim=0 - ) cache_loc = forward_batch.out_cache_loc.view( self.speculative_num_steps, -1 ).T - # Calculate page table indices and cache location indices to update the page table. - batch_indices = torch.arange( - batch_size, device=device - ).repeat_interleave(self.topk * (self.step_id + 1)) - topk_indices = torch.arange(self.topk, device=device).repeat( - batch_size * (self.step_id + 1) - ) - row_indices = batch_indices * self.topk + topk_indices - page_table_col_base_indices = seqlens_in_batch.unsqueeze( - 1 - ) + torch.arange(self.step_id + 1, device=device) - page_table_col_indices = page_table_col_base_indices.view(-1).repeat( - self.topk - ) - - cache_loc_col_indices = torch.arange( - self.step_id + 1, device=device, dtype=torch.int32 - ).repeat(batch_size * self.topk) - - metadata.page_table[row_indices, page_table_col_indices] = cache_loc[ - row_indices, cache_loc_col_indices - ].to(torch.int32) - else: + for idx, single_seq_len in enumerate(seq_lens_with_decode): + real_bsz_start_idx = idx + real_bsz_end_idx = idx + 1 + metadata.page_table[ + real_bsz_start_idx:real_bsz_end_idx, + (single_seq_len - (self.step_id + 1)) : single_seq_len, + ] = cache_loc[ + real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1) + ] + else: # Normal Decode without Spec Decoding metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) metadata.cu_seqlens_k = torch.nn.functional.pad( torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) ) - # Precompute maximum sequence length metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() - # Precompute page table metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : metadata.max_seq_len_k ] metadata.cu_seqlens_q = torch.arange( 0, batch_size + 1, dtype=torch.int32, device=device ) - elif forward_batch.forward_mode == ForwardMode.TARGET_VERIFY: + elif forward_batch.forward_mode.is_target_verify(): + # Note: Target Verify will be ran on the Target Worker draft_token_num = forward_batch.spec_info.draft_token_num - - metadata.cu_seqlens_q = torch.arange( - 0, batch_size * draft_token_num + 1, dtype=torch.int32, device=device + metadata.cache_seqlens_int32 = ( + forward_batch.seq_lens + draft_token_num + ).to(torch.int32) + metadata.max_seq_len_q = draft_token_num + metadata.max_seq_len_k = ( + forward_batch.seq_lens_cpu.max().item() + draft_token_num ) - - aug_seq_lens = (forward_batch.seq_lens + draft_token_num).to(torch.int32) - metadata.cache_seqlens_int32 = aug_seq_lens.repeat_interleave( - forward_batch.spec_info.draft_token_num + metadata.cu_seqlens_q = torch.arange( + 0, + batch_size * draft_token_num + 1, + draft_token_num, + dtype=torch.int32, + device=device, ) metadata.cu_seqlens_k = torch.nn.functional.pad( torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0), ) - metadata.max_seq_len_k = ( - forward_batch.seq_lens_cpu.max().item() + draft_token_num - ) metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : metadata.max_seq_len_k - ].repeat_interleave(draft_token_num, dim=0) - aug_cum_len = torch.nn.functional.pad( - torch.cumsum(aug_seq_lens, dim=0, dtype=torch.int32), (1, 0) - ) - for idx, single_seq_len in enumerate(aug_seq_lens): - metadata.page_table[ - idx * draft_token_num : (idx + 1) * draft_token_num, :single_seq_len - ] *= forward_batch.spec_info.custom_mask[ - aug_cum_len[idx] - * draft_token_num : aug_cum_len[idx + 1] - * draft_token_num - ].view( - draft_token_num, -1 - ) + ] - metadata.max_seq_len_q = 1 - else: + elif forward_batch.forward_mode.is_extend_or_draft_extend(): + # Normal or Draft Extend (Both of them will be ran on the Target Worker) metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) metadata.cu_seqlens_k = torch.nn.functional.pad( torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) @@ -208,7 +204,6 @@ class FlashAttentionBackend(AttentionBackend): metadata.max_seq_len_q = metadata.max_seq_len_k # Precompute strided indices - # [0, page_size, 2 * page_size, ...] if self.page_size > 1: self.strided_indices = torch.arange( 0, metadata.page_table.shape[1], self.page_size, device=self.device @@ -227,7 +222,6 @@ class FlashAttentionBackend(AttentionBackend): forward_batch: ForwardBatch, save_kv_cache=True, ): - if k is not None: assert v is not None if save_kv_cache: @@ -262,7 +256,7 @@ class FlashAttentionBackend(AttentionBackend): page_table = metadata.page_table - # # Use Flash Attention for prefill + # Use Flash Attention for prefill if not self.use_mla: # Do multi-head attention kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) @@ -368,7 +362,6 @@ class FlashAttentionBackend(AttentionBackend): if layer.sliding_window_size is not None else (-1, -1) ) - page_table = metadata.page_table if not self.use_mla: @@ -437,7 +430,6 @@ class FlashAttentionBackend(AttentionBackend): k_descale=layer.k_scale, v_descale=layer.v_scale, ) - return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) def init_cuda_graph_state(self, max_bs: int): @@ -449,11 +441,6 @@ class FlashAttentionBackend(AttentionBackend): This creates fixed-size tensors that will be reused during CUDA graph replay to avoid memory allocations. """ - if self.speculative_num_steps > 0: - raise NotImplementedError( - "FlashAttentionBackend Spec Decoding does not support CUDA graph yet, stay tuned!" - ) - self.decode_cuda_graph_metadata = { # Page table for token mapping (batch_size, max_context_len) "page_table": torch.zeros( @@ -462,6 +449,39 @@ class FlashAttentionBackend(AttentionBackend): dtype=torch.int32, device=self.device, ), + "page_table_draft_decode": torch.zeros( + max_bs, + (self.max_context_len + self.page_size - 1) // self.page_size, + dtype=torch.int32, + device=self.device, + ), + "strided_indices": torch.arange( + 0, self.max_context_len, self.page_size, device=self.device + ), + "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device), + "cu_seqlens_q": torch.arange( + 0, max_bs + 128, dtype=torch.int32, device=self.device + ), + "cu_seqlens_k": torch.zeros( + max_bs + 128, dtype=torch.int32, device=self.device + ), + } + + self.target_verify_metadata = { + "page_table": torch.zeros( + max_bs, + (self.max_context_len + self.page_size - 1) // self.page_size, + dtype=torch.int32, + device=self.device, + ), + "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device), + "cu_seqlens_q": torch.zeros( + max_bs + 128, dtype=torch.int32, device=self.device + ), + "cu_seqlens_k": torch.zeros( + max_bs + 128, dtype=torch.int32, device=self.device + ), + "max_seqlen_q": 0, "strided_indices": torch.arange( 0, self.max_context_len, self.page_size, device=self.device ), @@ -479,27 +499,89 @@ class FlashAttentionBackend(AttentionBackend): ): """Initialize forward metadata for capturing CUDA graph.""" metadata = FlashAttentionMetadata() - # Get sequence information - metadata.cache_seqlens_int32 = seq_lens.to(torch.int32) - batch_size = len(seq_lens) device = seq_lens.device - metadata.cu_seqlens_k = torch.nn.functional.pad( - torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0) - ) - # Precompute maximum sequence length - metadata.max_seq_len_k = seq_lens.max().item() - # Precompute page table - metadata.page_table = self.decode_cuda_graph_metadata["page_table"][ - req_pool_indices, : - ] - if forward_mode.is_cuda_graph(): - # Precompute cumulative sequence lengths - metadata.cu_seqlens_q = torch.arange( - 0, batch_size + 1, dtype=torch.int32, device=device + if forward_mode.is_decode(): + if spec_info is not None: + # Draft Decode + metadata.cu_seqlens_q = torch.arange( + 0, bs + 1, dtype=torch.int32, device=device + ) + metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[ + "cache_seqlens" + ][:bs] + + metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][ + : bs + 1 + ] + + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1) + metadata.page_table = self.decode_cuda_graph_metadata[ + "page_table_draft_decode" + ][req_pool_indices, :] + else: + # Normal Decode + # Get sequence information + metadata.cache_seqlens_int32 = seq_lens.to(torch.int32) + batch_size = len(seq_lens) + device = seq_lens.device + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0) + ) + # Precompute maximum sequence length + metadata.max_seq_len_k = seq_lens.max().item() + # Precompute page table + metadata.page_table = self.decode_cuda_graph_metadata["page_table"][ + req_pool_indices, : + ] + # Precompute cumulative sequence lengths + metadata.cu_seqlens_q = torch.arange( + 0, batch_size + 1, dtype=torch.int32, device=device + ) + self.decode_cuda_graph_metadata[bs] = metadata + elif forward_mode.is_target_verify(): + draft_token_num = spec_info.draft_token_num + + metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][ + :bs + ] + metadata.cache_seqlens_int32.copy_( + (seq_lens + draft_token_num).to(torch.int32) ) - else: - raise ValueError("Do not support Prefill Mode cuda graph") - self.decode_cuda_graph_metadata[bs] = metadata + + metadata.max_seq_len_q = draft_token_num + metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num + + metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][ + torch.arange( + 0, + bs * draft_token_num + 1, + draft_token_num, + dtype=torch.int32, + device=device, + ) + ] + cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)] + cu_k.copy_( + torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + ) + metadata.cu_seqlens_k = cu_k + metadata.page_table = self.target_verify_metadata["page_table"][ + req_pool_indices, : + ] + + self.target_verify_metadata[bs] = metadata + self.forward_metadata = metadata def init_forward_metadata_replay_cuda_graph( @@ -512,28 +594,91 @@ class FlashAttentionBackend(AttentionBackend): forward_mode: ForwardMode, spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], seq_lens_cpu: Optional[torch.Tensor], + out_cache_loc: torch.Tensor = None, ): # """Initialize forward metadata for replaying CUDA graph.""" - metadata = self.decode_cuda_graph_metadata[bs] + device = seq_lens.device + seq_lens = seq_lens[:bs] + req_pool_indices = req_pool_indices[:bs] + seq_lens_cpu = seq_lens_cpu[:bs] + if forward_mode.is_decode(): + metadata = self.decode_cuda_graph_metadata[bs] - # For CPU operations - max_len = seq_lens_cpu[:bs].max().item() - metadata.max_seq_len_k = max_len + if spec_info is not None: + # Draft Decode + max_len = seq_lens_cpu.max().item() + metadata.max_seq_len_k = max_len + (self.step_id + 1) - # For GPU operations - seq_lens_in_batch = seq_lens[:bs] - metadata.cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32) - metadata.cu_seqlens_k = torch.nn.functional.pad( - torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0) - ) + metadata.cache_seqlens_int32.copy_( + (seq_lens + (self.step_id + 1)).to(torch.int32) + ) + + metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1) + + metadata.cu_seqlens_k.copy_( + torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + ) + + page_table = self.req_to_token[ + req_pool_indices, : metadata.max_seq_len_k + ] + + metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table) + else: + # Normal Decode + max_len = seq_lens_cpu.max().item() + metadata.max_seq_len_k = max_len + + metadata.cache_seqlens_int32 = seq_lens.to(torch.int32) + metadata.cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0) + ) + + max_seq_pages = ( + metadata.max_seq_len_k + self.page_size - 1 + ) // self.page_size + page_indices = self.req_to_token[ + :, + self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages], + ] + page_indices = page_indices[req_pool_indices] // self.page_size + metadata.page_table[:, :max_seq_pages].copy_(page_indices) + metadata.page_table[:, max_seq_pages:].fill_(0) + + elif forward_mode.is_target_verify(): + metadata = self.target_verify_metadata[bs] + draft_token_num = spec_info.draft_token_num + + metadata.cu_seqlens_q.copy_( + torch.arange( + 0, + bs * draft_token_num + 1, + draft_token_num, + dtype=torch.int32, + device=device, + ) + ) + metadata.cache_seqlens_int32.copy_( + (seq_lens + draft_token_num).to(torch.int32) + ) + + metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num + metadata.cu_seqlens_k.copy_( + torch.nn.functional.pad( + torch.cumsum( + metadata.cache_seqlens_int32, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + ) + page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k] + metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table) - max_seq_pages = (metadata.max_seq_len_k + self.page_size - 1) // self.page_size - page_indices = self.req_to_token[ - :, self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages] - ] - page_indices = page_indices[req_pool_indices[:bs]] // self.page_size - metadata.page_table[:, :max_seq_pages].copy_(page_indices) - metadata.page_table[:, max_seq_pages:].fill_(0) self.forward_metadata = metadata def get_cuda_graph_seq_len_fill_value(self): @@ -555,7 +700,6 @@ class FlashAttentionMultiStepBackend: self.attn_backends.append( FlashAttentionBackend( model_runner, - skip_prefill=True, topk=self.topk, speculative_num_steps=self.speculative_num_steps, step_id=i, @@ -570,7 +714,10 @@ class FlashAttentionMultiStepBackend: for i in range(self.speculative_num_steps): self.attn_backends[i].init_cuda_graph_state(max_bs) - def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): + def init_forward_metadata_capture_cuda_graph( + self, + forward_batch: ForwardBatch, + ): assert forward_batch.spec_info is not None assert isinstance(forward_batch.spec_info, EagleDraftInput) @@ -601,4 +748,5 @@ class FlashAttentionMultiStepBackend: forward_mode=ForwardMode.DECODE, spec_info=forward_batch.spec_info, seq_lens_cpu=forward_batch.seq_lens_cpu, + out_cache_loc=forward_batch.out_cache_loc, ) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index daed3dc29..96a13a999 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -104,6 +104,9 @@ class ForwardMode(IntEnum): or self == ForwardMode.IDLE ) + def is_extend_or_draft_extend(self): + return self == ForwardMode.EXTEND or self == ForwardMode.DRAFT_EXTEND + def is_dummy_first(self): return self == ForwardMode.DUMMY_FIRST