From 14cb544d56b06b25483c4cf9c817b657acff8604 Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Thu, 15 Aug 2024 00:53:24 -0700 Subject: [PATCH] [Fix] fix flashinfer usage for window attention (#1107) --- python/sglang/srt/layers/radix_attention.py | 5 +---- .../srt/model_executor/forward_batch_info.py | 14 ++++++-------- python/sglang/srt/model_executor/model_runner.py | 11 +++++------ 3 files changed, 12 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 978a5d4c0..a7474326f 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -120,12 +120,9 @@ class RadixAttention(nn.Module): # using two wrappers is unnecessary in the current PR, but are prepared for future PRs prefill_wrapper_ragged = input_metadata.flashinfer_prefill_wrapper_ragged prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged - if self.sliding_window_size != -1: - prefill_wrapper_ragged = prefill_wrapper_ragged[0] + if self.sliding_window_size != -1 or self.reuse: prefill_wrapper_paged = prefill_wrapper_paged[0] else: - if isinstance(prefill_wrapper_ragged, list): - prefill_wrapper_ragged = prefill_wrapper_ragged[1] if isinstance(prefill_wrapper_paged, list): prefill_wrapper_paged = prefill_wrapper_paged[1] diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 809b3329d..66479b255 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -324,9 +324,11 @@ def update_flashinfer_indices( else: kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") for wrapper_id in range(2): - if flashinfer_use_ragged: + if flashinfer_use_ragged and wrapper_id == 1: + # full attention use ragged+paged paged_kernel_lens = prefix_lens else: + # window attention use paged only paged_kernel_lens = seq_lens if wrapper_id == 0 and forward_mode == ForwardMode.DECODE: @@ -374,13 +376,9 @@ def update_flashinfer_indices( ) qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0) - if flashinfer_use_ragged: - model_runner.flashinfer_prefill_wrapper_ragged[ - wrapper_id - ].end_forward() - model_runner.flashinfer_prefill_wrapper_ragged[ - wrapper_id - ].begin_forward( + if flashinfer_use_ragged and wrapper_id == 1: + model_runner.flashinfer_prefill_wrapper_ragged.end_forward() + model_runner.flashinfer_prefill_wrapper_ragged.begin_forward( qo_indptr, qo_indptr, num_qo_heads, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 675ca60d0..748069fc2 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -342,15 +342,14 @@ class ModelRunner: dtype=torch.uint8, device="cuda", ) - self.flashinfer_prefill_wrapper_ragged = [] + self.flashinfer_prefill_wrapper_ragged = ( + BatchPrefillWithRaggedKVCacheWrapper( + self.flashinfer_workspace_buffer, "NHD" + ) + ) self.flashinfer_prefill_wrapper_paged = [] self.flashinfer_decode_wrapper = [] for i in range(2): - self.flashinfer_prefill_wrapper_ragged.append( - BatchPrefillWithRaggedKVCacheWrapper( - self.flashinfer_workspace_buffer, "NHD" - ) - ) self.flashinfer_prefill_wrapper_paged.append( BatchPrefillWithPagedKVCacheWrapper( self.flashinfer_workspace_buffer, "NHD"