diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 305e46345..08b62458f 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -53,39 +53,44 @@ class FlashInferAttnBackend(AttentionBackend):
             device="cuda",
         )
 
-        if model_runner.sliding_window_size is None:
-            self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-                self.workspace_buffer, "NHD"
-            )
-            self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                self.workspace_buffer, "NHD"
-            )
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.workspace_buffer,
-                "NHD",
-                use_tensor_cores=self.decode_use_tensor_cores,
-            )
+        if model_runner.sliding_window_size is not None:
+            self.num_wrappers = 2
         else:
-            # Two wrappers: one for sliding window attention and one for full attention.
-            # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
-            self.prefill_wrapper_ragged = None
-            self.prefill_wrapper_paged = []
-            self.decode_wrapper = []
-            for _ in range(2):
-                self.prefill_wrapper_paged.append(
-                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-                )
-                self.decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                    )
+            self.num_wrappers = 1
+
+        # NOTE: we do not use ragged attention when there are multiple wrappers
+        self.prefill_wrapper_ragged = (
+            BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
+            if self.num_wrappers == 1
+            else None
+        )
+
+        # Two wrappers: one for sliding window attention and one for full attention.
+        # Using two wrappers is unnecessary in the current PR, but it prepares for future PRs.
+        self.prefill_wrappers_paged = []
+        self.decode_wrappers = []
+        for _ in range(self.num_wrappers):
+            self.prefill_wrappers_paged.append(
+                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+            )
+            self.decode_wrappers.append(
+                BatchDecodeWithPagedKVCacheWrapper(
+                    self.workspace_buffer,
+                    "NHD",
+                    use_tensor_cores=self.decode_use_tensor_cores,
                 )
+            )
 
         self.forward_metadata = None
         self.cuda_graph_metadata = {}
 
+    def _get_wrapper_idx(self, layer: nn.Module):
+        if self.num_wrappers == 1:
+            return 0
+
+        # TODO: make sure the idx is related to sliding window size
+        return layer.sliding_window_size == -1
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
             prefix_lens = None
@@ -99,7 +104,7 @@ class FlashInferAttnBackend(AttentionBackend):
             use_ragged = False
             if (
                 torch.sum(forward_batch.seq_lens).item() >= 4096
-                and self.model_runner.sliding_window_size is None
+                and self.num_wrappers == 1
             ):
                 use_ragged = True
 
@@ -119,7 +124,7 @@ class FlashInferAttnBackend(AttentionBackend):
             use_ragged,
             extend_no_prefix,
             total_num_tokens,
-            self.decode_wrapper,
+            self.decode_wrappers,
         )
 
     def init_cuda_graph_state(self, max_bs: int):
@@ -135,45 +140,30 @@ class FlashInferAttnBackend(AttentionBackend):
             (max_bs,), dtype=torch.int32, device="cuda"
         )
 
-        if self.model_runner.sliding_window_size is not None:
-            self.cuda_graph_kv_indptr = [
-                self.cuda_graph_kv_indptr,
-                self.cuda_graph_kv_indptr.clone(),
-            ]
-            self.cuda_graph_kv_indices = [
-                self.cuda_graph_kv_indices,
-                self.cuda_graph_kv_indices.clone(),
-            ]
+        # NOTE: the buffers are always in the form of a list
+        self.cuda_graph_kv_indptr = [self.cuda_graph_kv_indptr] + [
+            self.cuda_graph_kv_indptr.clone() for _ in range(self.num_wrappers - 1)
+        ]
+        self.cuda_graph_kv_indices = [self.cuda_graph_kv_indices] + [
+            self.cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
+        ]
 
     def init_forward_metadata_capture_cuda_graph(
         self, bs: int, req_pool_indices, seq_lens
     ):
-        if self.model_runner.sliding_window_size is None:
-            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.workspace_buffer,
-                "NHD",
-                use_cuda_graph=True,
-                use_tensor_cores=self.decode_use_tensor_cores,
-                paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1],
-                paged_kv_indices_buffer=self.cuda_graph_kv_indices,
-                paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
-            )
-        else:
-            decode_wrapper = []
-            for i in range(2):
-                decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                        paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
-                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[
-                            :bs
-                        ],
-                    )
+        decode_wrappers = []
+        for i in range(self.num_wrappers):
+            decode_wrappers.append(
+                BatchDecodeWithPagedKVCacheWrapper(
+                    self.workspace_buffer,
+                    "NHD",
+                    use_cuda_graph=True,
+                    use_tensor_cores=self.decode_use_tensor_cores,
+                    paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
+                    paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                    paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
                 )
+            )
 
         update_flashinfer_indices(
             ForwardMode.DECODE,
@@ -181,12 +171,12 @@ class FlashInferAttnBackend(AttentionBackend):
             req_pool_indices,
             seq_lens,
             None,
-            decode_wrapper,
+            decode_wrappers,
         )
 
-        self.cuda_graph_metadata[bs] = decode_wrapper
+        self.cuda_graph_metadata[bs] = decode_wrappers
 
-        self.forward_metadata = (False, False, None, decode_wrapper)
+        self.forward_metadata = (False, False, None, decode_wrappers)
 
     def init_forward_metadata_replay_cuda_graph(
         self, bs: int, req_pool_indices, seq_lens
@@ -204,17 +194,11 @@ class FlashInferAttnBackend(AttentionBackend):
         return 0
 
     def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
-        if not isinstance(self.prefill_wrapper_paged, list):
-            prefill_wrapper_paged = self.prefill_wrapper_paged
-        else:
-            if layer.sliding_window_size != -1:
-                prefill_wrapper_paged = self.prefill_wrapper_paged[0]
-            else:
-                prefill_wrapper_paged = self.prefill_wrapper_paged[1]
+        prefill_wrapper_paged = self.prefill_wrappers_paged[
+            self._get_wrapper_idx(layer)
+        ]
 
-        use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
-            self.forward_metadata
-        )
+        use_ragged, extend_no_prefix, _, _ = self.forward_metadata
 
         if not use_ragged:
             if k is not None:
@@ -260,15 +244,7 @@ class FlashInferAttnBackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
-        use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
-            self.forward_metadata
-        )
-
-        if isinstance(decode_wrapper, list):
-            if layer.sliding_window_size != -1:
-                decode_wrapper = decode_wrapper[0]
-            else:
-                decode_wrapper = decode_wrapper[1]
+        decode_wrapper = self.forward_metadata[-1][self._get_wrapper_idx(layer)]
 
         if k is not None:
             assert v is not None
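Aside (not part of the patch): the wrapper dispatch introduced above can be exercised in isolation. A minimal sketch, assuming SimpleNamespace stand-ins for the backend and layer objects rather than the real sglang/flashinfer classes:

from types import SimpleNamespace

class Backend:
    def __init__(self, sliding_window_size=None):
        # Mirrors the diff: two wrappers only when a sliding window is configured.
        self.num_wrappers = 2 if sliding_window_size is not None else 1
        # wrappers[0] serves sliding-window layers, wrappers[1] full attention.
        self.decode_wrappers = [f"wrapper_{i}" for i in range(self.num_wrappers)]

    def _get_wrapper_idx(self, layer):
        if self.num_wrappers == 1:
            return 0
        # bool is an int subclass: False -> 0 (sliding window), True -> 1 (full).
        return layer.sliding_window_size == -1

backend = Backend(sliding_window_size=4096)
swa_layer = SimpleNamespace(sliding_window_size=4096)  # hypothetical layer objects
full_layer = SimpleNamespace(sliding_window_size=-1)
assert backend.decode_wrappers[backend._get_wrapper_idx(swa_layer)] == "wrapper_0"
assert backend.decode_wrappers[backend._get_wrapper_idx(full_layer)] == "wrapper_1"

The trick is that bool subclasses int in Python, so layer.sliding_window_size == -1 indexes the list directly: False (0) selects the sliding-window wrapper and True (1) the full-attention wrapper, exactly the branches deleted from forward_extend and forward_decode.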
diff --git a/python/sglang/srt/layers/attention/flashinfer_utils.py b/python/sglang/srt/layers/attention/flashinfer_utils.py
index 291091b10..9568226ea 100644
--- a/python/sglang/srt/layers/attention/flashinfer_utils.py
+++ b/python/sglang/srt/layers/attention/flashinfer_utils.py
@@ -47,7 +47,7 @@ class FlashinferUpdater:
         req_pool_indices,
         seq_lens,
         prefix_lens,
-        decode_wrapper=None,
+        decode_wrappers=None,
         use_ragged=False,
     ):
         self.forward_mode = forward_mode
@@ -66,14 +66,14 @@ class FlashinferUpdater:
         self.head_dim = model_runner.model_config.head_dim
         self.batch_size = len(req_pool_indices)
 
-        self.decode_wrapper = (
-            decode_wrapper or self.model_runner.attn_backend.decode_wrapper
+        self.decode_wrappers = (
+            decode_wrappers or self.model_runner.attn_backend.decode_wrappers
         )
         self.prefill_wrapper_ragged = (
             self.model_runner.attn_backend.prefill_wrapper_ragged
         )
-        self.prefill_wrapper_paged = (
-            self.model_runner.attn_backend.prefill_wrapper_paged
+        self.prefill_wrappers_paged = (
+            self.model_runner.attn_backend.prefill_wrappers_paged
         )
 
         self.kv_last_page_len = torch.ones(
@@ -142,6 +142,7 @@ class FlashinferUpdater:
         )
 
     def _update_decode_indices(self, decode_wrapper):
+        assert not isinstance(decode_wrapper, list)
         decode_wrapper.end_forward()
         decode_wrapper.begin_forward(
             self.kv_indptr,
@@ -156,6 +157,9 @@ class FlashinferUpdater:
         )
 
     def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
+        assert not isinstance(paged_wrapper, list)
+        assert not isinstance(ragged_wrapper, list)
+
         # extend part
         qo_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
@@ -189,11 +193,11 @@ class FlashinferUpdater:
         self._init_indices_no_sliding_window()
 
         if self.forward_mode.is_decode():
-            self._update_decode_indices(self.decode_wrapper)
+            self._update_decode_indices(self.decode_wrappers[0])
         else:
             self._update_extend_indices(
                 self.prefill_wrapper_ragged,
-                self.prefill_wrapper_paged,
+                self.prefill_wrappers_paged[0],
             )
 
     def update_indices_sliding_window(self):
@@ -202,11 +206,11 @@
         for wrapper_id in range(2):
             self._init_indices_sliding_window(wrapper_id)
             if self.forward_mode.is_decode():
-                self._update_decode_indices(self.decode_wrapper[wrapper_id])
+                self._update_decode_indices(self.decode_wrappers[wrapper_id])
             else:
                 self._update_extend_indices(
                     None,
-                    self.prefill_wrapper_paged[wrapper_id],
+                    self.prefill_wrappers_paged[wrapper_id],
                 )
 
 
@@ -216,7 +220,7 @@ def update_flashinfer_indices(
     req_pool_indices,
    seq_lens,
     prefix_lens,
-    decode_wrapper=None,
+    decode_wrappers=None,
     use_ragged=False,
 ):
     updater = FlashinferUpdater(
@@ -225,7 +229,7 @@
         req_pool_indices,
         seq_lens,
         prefix_lens,
-        decode_wrapper,
+        decode_wrappers,
         use_ragged,
     )
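Aside (not part of the patch): init_cuda_graph_state now keeps the CUDA graph buffers in list form regardless of the wrapper count. A minimal sketch of the idiom, assuming CPU tensors and toy sizes in place of the real max_bs-sized CUDA buffers:

import torch

num_wrappers = 2
max_bs = 8
kv_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32)

# Wrapper 0 reuses the original allocation; every additional wrapper gets an
# independent clone. With num_wrappers == 1 this collapses to a one-element
# list, so downstream code can always index buffers[i] with no isinstance checks.
kv_indptr_buffers = [kv_indptr] + [kv_indptr.clone() for _ in range(num_wrappers - 1)]

assert len(kv_indptr_buffers) == num_wrappers
assert kv_indptr_buffers[0] is kv_indptr       # shared storage, no extra copy
assert kv_indptr_buffers[1] is not kv_indptr   # separate buffer for wrapper 1

This uniform layout is what lets init_forward_metadata_capture_cuda_graph pass cuda_graph_kv_indptr[i] for every wrapper, and it is why the asserts added in flashinfer_utils.py can insist that each _update_* helper receives a single wrapper, never a list.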