diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index a3d8f88eb..302907b67 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -35,6 +35,7 @@ class ForwardMetadata:
     window_kv_indptr: torch.Tensor
     window_kv_indices: torch.Tensor
     window_num_kv_splits: torch.Tensor
+    window_kv_offsets: torch.Tensor
 
 
 class TritonAttnBackend(AttentionBackend):
@@ -163,6 +164,7 @@ class TritonAttnBackend(AttentionBackend):
         window_kv_indptr = self.window_kv_indptr
         window_kv_indices = None
         window_num_kv_splits = None
+        window_kv_offsets = None
         spec_info = forward_batch.spec_info
 
         if forward_batch.forward_mode.is_decode_or_idle():
@@ -186,7 +188,7 @@ class TritonAttnBackend(AttentionBackend):
                     self.sliding_window_size is not None
                     and self.sliding_window_size > 0
                 ):
-                    window_kv_indptr, window_kv_indices, window_kv_lens = (
+                    window_kv_indptr, window_kv_indices, window_kv_lens, _ = (
                         update_sliding_window_buffer(
                             self.window_kv_indptr,
                             self.req_to_token,
@@ -249,17 +251,21 @@ class TritonAttnBackend(AttentionBackend):
                 )
 
             if self.sliding_window_size is not None and self.sliding_window_size > 0:
-                window_kv_indptr, window_kv_indices, window_kv_lens = (
-                    update_sliding_window_buffer(
-                        self.window_kv_indptr,
-                        self.req_to_token,
-                        self.sliding_window_size,
-                        forward_batch.seq_lens,
-                        forward_batch.req_pool_indices,
-                        bs,
-                        self.device,
-                        self.token_to_kv_pool_allocator,
-                    )
+                # window_kv_offsets is used to calculate the start position in the custom mask
+                (
+                    window_kv_indptr,
+                    window_kv_indices,
+                    window_kv_lens,
+                    window_kv_offsets,
+                ) = update_sliding_window_buffer(
+                    self.window_kv_indptr,
+                    self.req_to_token,
+                    self.sliding_window_size,
+                    forward_batch.seq_lens,
+                    forward_batch.req_pool_indices,
+                    bs,
+                    self.device,
+                    self.token_to_kv_pool_allocator,
                 )
 
             custom_mask = spec_info.custom_mask
@@ -312,15 +318,17 @@ class TritonAttnBackend(AttentionBackend):
                 )
             # Sliding window
            if self.sliding_window_size is not None and self.sliding_window_size > 0:
-                window_kv_indptr, window_kv_indices, _ = update_sliding_window_buffer(
-                    self.window_kv_indptr,
-                    self.req_to_token,
-                    self.sliding_window_size,
-                    forward_batch.extend_prefix_lens,
-                    forward_batch.req_pool_indices,
-                    bs,
-                    self.device,
-                    self.token_to_kv_pool_allocator,
+                window_kv_indptr, window_kv_indices, _, _ = (
+                    update_sliding_window_buffer(
+                        self.window_kv_indptr,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        forward_batch.extend_prefix_lens,
+                        forward_batch.req_pool_indices,
+                        bs,
+                        self.device,
+                        self.token_to_kv_pool_allocator,
+                    )
                 )
 
             qo_indptr = self.qo_indptr
@@ -346,6 +354,7 @@ class TritonAttnBackend(AttentionBackend):
             window_kv_indptr,
             window_kv_indices,
             window_num_kv_splits,
+            window_kv_offsets,
         )
 
     def init_cuda_graph_state(
@@ -400,6 +409,12 @@ class TritonAttnBackend(AttentionBackend):
                 device=self.device,
             )
 
+            self.cuda_graph_window_kv_offsets = torch.zeros(
+                (max_bs,),
+                dtype=torch.int32,
+                device=self.device,
+            )
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -414,6 +429,7 @@ class TritonAttnBackend(AttentionBackend):
         window_kv_indptr = self.window_kv_indptr
         window_kv_indices = None
         window_num_kv_splits = None
+        window_kv_offsets = None
 
         if forward_mode.is_decode_or_idle():
             if spec_info is None:
@@ -436,7 +452,7 @@ class TritonAttnBackend(AttentionBackend):
                 ):
                     window_kv_indices = self.cuda_graph_window_kv_indices
                     window_num_kv_splits = self.cuda_graph_window_num_kv_splits
-                    window_kv_indptr, window_kv_indices, _ = (
+                    window_kv_indptr, window_kv_indices, _, _ = (
                         update_sliding_window_buffer_cuda_graph(
                             self.window_kv_indptr,
                             window_kv_indices,
@@ -483,13 +499,14 @@ class TritonAttnBackend(AttentionBackend):
            if self.sliding_window_size is not None and self.sliding_window_size > 0:
                 window_kv_indices = self.cuda_graph_window_kv_indices
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
-                window_kv_indptr, window_kv_indices, _ = (
+                window_kv_offsets = self.cuda_graph_window_kv_offsets
+                window_kv_indptr, window_kv_indices, _, window_kv_offsets[:bs] = (
                     update_sliding_window_buffer_cuda_graph(
                         self.window_kv_indptr,
                         window_kv_indices,
                         self.req_to_token,
                         self.sliding_window_size,
-                        seq_lens,
+                        seq_lens[:bs],
                         req_pool_indices,
                         bs,
                         self.token_to_kv_pool_allocator,
@@ -551,6 +568,7 @@ class TritonAttnBackend(AttentionBackend):
             window_kv_indptr,
             window_kv_indices,
             window_num_kv_splits,
+            window_kv_offsets,
         )
 
     def init_forward_metadata_replay_cuda_graph(
@@ -589,7 +607,7 @@ class TritonAttnBackend(AttentionBackend):
                 ):
                     window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                     window_kv_indices = self.cuda_graph_window_kv_indices
-                    _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                    _, _, window_kv_lens, _ = update_sliding_window_buffer_cuda_graph(
                         self.window_kv_indptr,
                         window_kv_indices,
                         self.req_to_token,
@@ -635,15 +653,18 @@ class TritonAttnBackend(AttentionBackend):
            if self.sliding_window_size is not None and self.sliding_window_size > 0:
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                 window_kv_indices = self.cuda_graph_window_kv_indices
-                _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
-                    self.window_kv_indptr,
-                    window_kv_indices,
-                    self.req_to_token,
-                    self.sliding_window_size,
-                    seq_lens,
-                    req_pool_indices,
-                    bs,
-                    self.token_to_kv_pool_allocator,
+                window_kv_offsets = self.cuda_graph_window_kv_offsets
+                _, _, window_kv_lens, window_kv_offsets[:bs] = (
+                    update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens[:bs],
+                        req_pool_indices,
+                        bs,
+                        self.token_to_kv_pool_allocator,
+                    )
                 )
             custom_mask = self.cuda_graph_custom_mask
             custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
@@ -706,10 +727,12 @@ class TritonAttnBackend(AttentionBackend):
             )  # Needed for sliding window mask
             kv_indptr = self.forward_metadata.window_kv_indptr
             kv_indices = self.forward_metadata.window_kv_indices
+            window_kv_offsets = self.forward_metadata.window_kv_offsets
         else:
             sliding_window_size = -1
             kv_indptr = self.forward_metadata.kv_indptr
             kv_indices = self.forward_metadata.kv_indices
+            window_kv_offsets = None
 
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
@@ -729,6 +752,7 @@ class TritonAttnBackend(AttentionBackend):
             layer.logit_cap,
             sliding_window_size=sliding_window_size,
             sinks=sinks,
+            window_kv_offsets=window_kv_offsets,
         )
         return o
 
@@ -1011,7 +1035,7 @@ def update_sliding_window_buffer(
             window_kv_indices[:kv_last_index]
         )
     )
-    return window_kv_indptr, window_kv_indices, window_kv_lens
+    return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx
 
 
 def update_sliding_window_buffer_cuda_graph(
@@ -1048,4 +1072,4 @@ def update_sliding_window_buffer_cuda_graph(
             window_kv_indices[:kv_last_index]
         )
     )
-    return window_kv_indptr, window_kv_indices, window_kv_lens
+    return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx
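For reference, the backend changes above thread a fourth return value (window_kv_start_idx, surfaced as window_kv_offsets in ForwardMetadata) out of the sliding-window buffer helpers. Below is a minimal sketch of the invariant the new value satisfies, assuming, as the variable names in the diff suggest, that the helper clamps each request's visible KV length to the window and takes the offset as the window's start within the full sequence. window_offsets_sketch is a hypothetical illustration, not the library helper, which also rebuilds window_kv_indptr/window_kv_indices from req_to_token.

import torch

def window_offsets_sketch(seq_lens: torch.Tensor, sliding_window_size: int):
    # Each request attends to at most the last `sliding_window_size` tokens,
    # so the visible KV length is clamped to the window size (assumption).
    window_kv_lens = torch.clamp(seq_lens, max=sliding_window_size)
    # The window covers the LAST window_kv_lens tokens, so it starts at
    # seq_len - window_len within the full per-request KV sequence; this is
    # the per-request value carried in forward_metadata.window_kv_offsets.
    window_kv_start_idx = seq_lens - window_kv_lens
    return window_kv_lens, window_kv_start_idx

# Example with a window of 4: only the 7-token request skips a prefix.
lens, offsets = window_offsets_sketch(torch.tensor([2, 4, 7]), 4)
# lens -> tensor([2, 4, 4]), offsets -> tensor([0, 0, 3])

This also explains why the CUDA-graph paths assign into window_kv_offsets[:bs] during tuple unpacking: the unpacking writes through __setitem__ into the persistent cuda_graph_window_kv_offsets buffer, so replayed graphs read the refreshed offsets from a fixed address.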
diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
index 014eadab7..d8259be20 100644
--- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -190,7 +190,7 @@ def _decode_att_m_fwd(
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
 
-    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
+    batch, head_num = q.shape[0], q.shape[1]
 
     grid = (batch, head_num, MAX_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
@@ -433,7 +433,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
+    batch, head_num = q.shape[0], q.shape[1]
 
     kv_group_num = q.shape[1] // k_buffer.shape[1]
     BLOCK_H = 16
diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
index 8b459861d..b39f1a305 100644
--- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -52,6 +52,7 @@ def _fwd_kernel(
     mask_ptr,
     mask_indptr,
     sink_ptr,
+    window_kv_offset_ptr,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -95,6 +96,11 @@ def _fwd_kernel(
     if USE_CUSTOM_MASK:
         cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
 
+    # For SWA, only load the mask entries that lie inside the sliding window
+    window_kv_offset = 0
+    if USE_CUSTOM_MASK and SLIDING_WINDOW_SIZE > 0:
+        window_kv_offset = tl.load(window_kv_offset_ptr + cur_seq)
+
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
     offs_m = tl.arange(0, BLOCK_M)
@@ -139,7 +145,9 @@ def _fwd_kernel(
             custom_mask = tl.load(
                 mask_ptr
                 + cur_seq_mask_start_idx
-                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + (cur_block_m * BLOCK_M + offs_m[:, None])
+                * (cur_seq_len + window_kv_offset)
+                + window_kv_offset
                 + start_n
                 + offs_n[None, :],
                 mask=(mask_m[:, None] & mask_n[None, :]),
@@ -236,7 +244,9 @@ def _fwd_kernel(
             custom_mask = tl.load(
                 mask_ptr
                 + cur_seq_mask_start_idx
-                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + (cur_block_m * BLOCK_M + offs_m[:, None])
+                * (cur_seq_len + window_kv_offset)
+                + window_kv_offset
                 + cur_seq_len_prefix
                 + start_n
                 + offs_n[None, :],
@@ -362,6 +372,7 @@ def extend_attention_fwd(
     skip_prefix_custom_mask=True,
     sliding_window_size=-1,
     sinks=None,
+    window_kv_offsets=None,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -449,6 +460,7 @@ def extend_attention_fwd(
         custom_mask,
         mask_indptr,
         sinks,
+        window_kv_offsets,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
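The kernel-side change is the custom-mask address arithmetic: with a sliding window, each query row of the flattened mask still spans the full KV length (window_kv_offset + cur_seq_len), while the windowed kv_indices cover only the tail of the sequence, so every mask load is shifted right by window_kv_offset. Below is a NumPy sketch of that indexing, under the assumption that the mask is stored row-major per request; windowed_mask_row is a hypothetical helper for illustration, not part of the kernels.

import numpy as np

def windowed_mask_row(mask_flat, row, window_kv_offset, window_kv_len):
    # Each query row holds (window_kv_offset + window_kv_len) entries because
    # the custom mask is built for the FULL KV sequence. Only the last
    # window_kv_len columns correspond to the windowed kv_indices, mirroring
    # the kernel's index:
    #   row * (cur_seq_len + window_kv_offset) + window_kv_offset + col
    full_len = window_kv_offset + window_kv_len
    start = row * full_len + window_kv_offset
    return mask_flat[start : start + window_kv_len]

# Two query rows over a full KV length of 5 with a window of 3 (offset 2):
mask = np.arange(10)  # stand-in for a flattened boolean mask
assert windowed_mask_row(mask, 0, 2, 3).tolist() == [2, 3, 4]
assert windowed_mask_row(mask, 1, 2, 3).tolist() == [7, 8, 9]

The second tl.load site in the diff additionally adds cur_seq_len_prefix to the column index, since the extend path reads the suffix portion of the same row, past the prefix columns.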