SWA Prefix Cache (#7367)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Author: Hanming Lu
Date: 2025-07-13 12:31:07 -07:00
Committed by: GitHub
Parent: 0c55cbcfc5
Commit: 9379da77de
16 changed files with 1742 additions and 158 deletions

python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -26,6 +26,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import is_flashinfer_available, next_power_of_2
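
The only change in this hunk is the new SWATokenToKVPoolAllocator import. For readers who don't have the allocator in front of them, here is a minimal sketch of the one method this diff relies on; everything except the method name translate_loc_from_full_to_swa is an assumption for illustration, not the real implementation:

import torch

class SWATokenToKVPoolAllocatorSketch:
    """Hypothetical stand-in for sglang.srt.mem_cache.allocator.SWATokenToKVPoolAllocator."""

    def __init__(self, full_pool_size: int):
        # Assumed layout: a per-slot mapping from full-attention pool locations
        # to locations in the smaller sliding-window (SWA) pool; -1 marks slots
        # with no SWA counterpart.
        self.full_to_swa_index_mapping = torch.full(
            (full_pool_size,), -1, dtype=torch.int64
        )

    def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor) -> torch.Tensor:
        # Re-index KV locations so FlashInfer reads the SWA pool rather than
        # the full pool. This is the call added later in this diff.
        return self.full_to_swa_index_mapping[kv_indices]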
@@ -589,6 +590,7 @@ class FlashInferIndicesUpdaterDecode:
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_last_page_len = attn_backend.kv_last_page_len
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator

        # Dispatch the update function
        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
@@ -655,6 +657,10 @@ class FlashInferIndicesUpdaterDecode:
                paged_kernel_lens_sum_tmp = seq_lens_sum
                kv_start_idx_tmp = None

            use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
                self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
            )

            self.call_begin_forward(
                decode_wrappers[wrapper_id],
                req_pool_indices,
@@ -663,6 +669,7 @@ class FlashInferIndicesUpdaterDecode:
                self.kv_indptr[wrapper_id],
                kv_start_idx_tmp,
                spec_info,
                use_sliding_window_kv_pool=use_sliding_window_kv_pool,
            )

    def update_cross_attention(
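
For context on the wrapper_id == 0 check above: when dispatch_reason is WrapperDispatch.SLIDING_WINDOW, the backend keeps two decode wrappers, with index 0 serving the sliding-window layers and index 1 the full-attention layers, so only wrapper 0 ever needs its indices translated into the SWA pool. A standalone illustration of the gating (the *Sketch classes are assumptions, not SGLang types):

class FullAllocatorSketch: ...
class SWAAllocatorSketch: ...

def use_swa_pool(wrapper_id: int, allocator) -> bool:
    # Mirrors the condition in the diff: sliding-window wrapper + hybrid allocator.
    return wrapper_id == 0 and isinstance(allocator, SWAAllocatorSketch)

assert use_swa_pool(0, SWAAllocatorSketch())        # SWA layers, hybrid pool
assert not use_swa_pool(1, SWAAllocatorSketch())    # full-attention layers
assert not use_swa_pool(0, FullAllocatorSketch())   # no hybrid pool configured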
@@ -704,6 +711,7 @@ class FlashInferIndicesUpdaterDecode:
        kv_indptr: torch.Tensor,
        kv_start_idx: torch.Tensor,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        use_sliding_window_kv_pool: bool = False,
    ):
        if spec_info is None:
            bs = len(req_pool_indices)
@@ -731,6 +739,14 @@ class FlashInferIndicesUpdaterDecode:
            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
            bs = kv_indptr.shape[0] - 1

        if use_sliding_window_kv_pool:
            kv_last_index = kv_indptr[-1]
            kv_indices[:kv_last_index] = (
                self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
                    kv_indices[:kv_last_index]
                )
            )

        wrapper.begin_forward(
            kv_indptr,
            kv_indices,
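
A toy walk-through of the translation block above, with made-up sizes: kv_indices is built in full-pool coordinates (via req_to_token), but the SWA wrapper must index the smaller sliding-window pool, so every in-window location is remapped in place before begin_forward:

import torch

# Hypothetical mapping: full-pool slots 12 and 13 live at SWA slots 3 and 4.
full_to_swa = torch.full((32,), -1, dtype=torch.int64)
full_to_swa[12], full_to_swa[13] = 3, 4

kv_indptr = torch.tensor([0, 2])     # one request with 2 in-window tokens
kv_indices = torch.tensor([12, 13])  # full-pool coordinates
kv_last_index = kv_indptr[-1]
kv_indices[:kv_last_index] = full_to_swa[kv_indices[:kv_last_index]]
print(kv_indices)  # tensor([3, 4]) -- valid offsets into the SWA pool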
@@ -765,6 +781,7 @@ class FlashInferIndicesUpdaterPrefill:
        self.kv_last_page_len = attn_backend.kv_last_page_len
        self.qo_indptr = attn_backend.qo_indptr
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
        self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged

        # Dispatch the update function
@@ -848,6 +865,9 @@ class FlashInferIndicesUpdaterPrefill:
                paged_kernel_lens_sum = seq_lens_sum

            kv_start_idx = seq_lens - paged_kernel_lens
            use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
                self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
            )

            self.call_begin_forward(
                self.prefill_wrapper_ragged,
@@ -862,6 +882,7 @@ class FlashInferIndicesUpdaterPrefill:
                self.qo_indptr[wrapper_id],
                use_ragged,
                spec_info,
                use_sliding_window_kv_pool=use_sliding_window_kv_pool,
            )

    def update_cross_attention(
@@ -916,6 +937,7 @@ class FlashInferIndicesUpdaterPrefill:
        qo_indptr: torch.Tensor,
        use_ragged: bool,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        use_sliding_window_kv_pool: bool = False,
    ):
        bs = len(seq_lens)
        if spec_info is None:
@@ -964,6 +986,14 @@ class FlashInferIndicesUpdaterPrefill:
                q_data_type=self.q_data_type,
            )

        if use_sliding_window_kv_pool:
            kv_last_index = kv_indptr[-1]
            kv_indices[:kv_last_index] = (
                self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
                    kv_indices[:kv_last_index]
                )
            )

        # cached part
        wrapper_paged.begin_forward(
            qo_indptr,
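
The decode and prefill updaters now carry an identical translation block. Not part of the PR, but a sketch of how the duplication could be factored into one shared helper, under the same allocator assumptions as above:

def maybe_translate_to_swa(kv_indptr, kv_indices, allocator, enabled: bool) -> None:
    # Remap kv_indices from full-pool to SWA-pool coordinates in place; a
    # no-op for full-attention wrappers and non-hybrid allocators.
    if not enabled:
        return
    kv_last_index = kv_indptr[-1]
    kv_indices[:kv_last_index] = allocator.translate_loc_from_full_to_swa(
        kv_indices[:kv_last_index]
    )

Both call sites would then reduce to maybe_translate_to_swa(kv_indptr, kv_indices, self.token_to_kv_pool_allocator, use_sliding_window_kv_pool) immediately before begin_forward.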