SWA (sliding-window attention) Prefix Cache (#7367)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
@@ -26,6 +26,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.utils import is_flashinfer_available, next_power_of_2
|
||||
@@ -589,6 +590,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
self.kv_indptr = attn_backend.kv_indptr
|
||||
self.kv_last_page_len = attn_backend.kv_last_page_len
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
|
||||
|
||||
# Dispatch the update function
|
||||
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
||||
@@ -655,6 +657,10 @@ class FlashInferIndicesUpdaterDecode:
|
||||
paged_kernel_lens_sum_tmp = seq_lens_sum
|
||||
kv_start_idx_tmp = None
|
||||
|
||||
use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
|
||||
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
|
||||
)
|
||||
|
||||
self.call_begin_forward(
|
||||
decode_wrappers[wrapper_id],
|
||||
req_pool_indices,
|
||||
@@ -663,6 +669,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
self.kv_indptr[wrapper_id],
|
||||
kv_start_idx_tmp,
|
||||
spec_info,
|
||||
use_sliding_window_kv_pool=use_sliding_window_kv_pool,
|
||||
)
|
||||
|
||||
def update_cross_attention(
|
||||
@@ -704,6 +711,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
kv_indptr: torch.Tensor,
|
||||
kv_start_idx: torch.Tensor,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
use_sliding_window_kv_pool: bool = False,
|
||||
):
|
||||
if spec_info is None:
|
||||
bs = len(req_pool_indices)
|
||||
@@ -731,6 +739,14 @@ class FlashInferIndicesUpdaterDecode:
|
||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||
bs = kv_indptr.shape[0] - 1
|
||||
|
||||
if use_sliding_window_kv_pool:
|
||||
kv_last_index = kv_indptr[-1]
|
||||
kv_indices[:kv_last_index] = (
|
||||
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
|
||||
kv_indices[:kv_last_index]
|
||||
)
|
||||
)
|
||||
|
||||
wrapper.begin_forward(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
@@ -765,6 +781,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
self.kv_last_page_len = attn_backend.kv_last_page_len
|
||||
self.qo_indptr = attn_backend.qo_indptr
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
|
||||
self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
|
||||
|
||||
# Dispatch the update function
|
||||
@@ -848,6 +865,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
paged_kernel_lens_sum = seq_lens_sum
|
||||
|
||||
kv_start_idx = seq_lens - paged_kernel_lens
|
||||
use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
|
||||
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
|
||||
)
|
||||
|
||||
self.call_begin_forward(
|
||||
self.prefill_wrapper_ragged,
|
||||
@@ -862,6 +882,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
self.qo_indptr[wrapper_id],
|
||||
use_ragged,
|
||||
spec_info,
|
||||
use_sliding_window_kv_pool=use_sliding_window_kv_pool,
|
||||
)
|
||||
|
||||
def update_cross_attention(
|
||||
@@ -916,6 +937,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
qo_indptr: torch.Tensor,
|
||||
use_ragged: bool,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
use_sliding_window_kv_pool: bool = False,
|
||||
):
|
||||
bs = len(seq_lens)
|
||||
if spec_info is None:
|
||||
@@ -964,6 +986,14 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
q_data_type=self.q_data_type,
|
||||
)
|
||||
|
||||
if use_sliding_window_kv_pool:
|
||||
kv_last_index = kv_indptr[-1]
|
||||
kv_indices[:kv_last_index] = (
|
||||
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
|
||||
kv_indices[:kv_last_index]
|
||||
)
|
||||
)
|
||||
|
||||
# cached part
|
||||
wrapper_paged.begin_forward(
|
||||
qo_indptr,
|
||||
|
||||
Reference in New Issue
Block a user