def update_sliding_window_buffer(
    window_kv_indptr,
    req_to_token,
    sliding_window_size,
    seq_lens,
    req_pool_indices,
    bs,
    device,
):
    """Build kv indptr/indices buffers restricted to the sliding window.

    Clamps each sequence length to at most ``sliding_window_size + 1`` tokens
    (the window plus the current token — TODO confirm the +1 convention against
    the kernel's window mask) and gathers the corresponding tail slice of
    ``req_to_token`` so attention only sees the most recent window.

    Args:
        window_kv_indptr: preallocated (>= bs + 1,) int32 indptr buffer.
        req_to_token: request-to-token page table.
        sliding_window_size: window size in tokens (> 0).
        seq_lens: (bs,) current sequence lengths on ``device``.
        req_pool_indices: (bs,) request pool slots.
        bs: batch size.
        device: device for the freshly allocated indices buffer.

    Returns:
        (window_kv_indptr, window_kv_indices, window_kv_lens) with
        ``window_kv_indptr`` truncated to ``bs + 1`` entries.
    """
    # clamp() with a Python scalar replaces the original
    # torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1)), which
    # allocated a new CPU tensor on every call whose device/dtype did not
    # match seq_lens.
    window_kv_lens = torch.clamp(seq_lens, max=sliding_window_size + 1)
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indptr = window_kv_indptr[: bs + 1]
    # NOTE: reading indptr[-1] forces a device->host sync; unavoidable here
    # because the total number of window tokens sizes the new buffer.
    window_kv_indices = torch.empty(
        window_kv_indptr[-1], dtype=torch.int32, device=device
    )
    # Tokens that fell out of the window are skipped via this start offset.
    window_kv_start_idx = seq_lens - window_kv_lens
    create_flashinfer_kv_indices_triton[(bs,)](
        req_to_token,
        req_pool_indices,
        window_kv_lens,
        window_kv_indptr,
        window_kv_start_idx,
        window_kv_indices,
        req_to_token.stride(0),
    )
    return window_kv_indptr, window_kv_indices, window_kv_lens


def update_sliding_window_buffer_cuda_graph(
    window_kv_indptr,
    window_kv_indices,
    req_to_token,
    sliding_window_size,
    seq_lens,
    req_pool_indices,
    bs,
):
    """CUDA-graph-safe variant of :func:`update_sliding_window_buffer`.

    Writes into the preallocated ``window_kv_indices`` buffer instead of
    allocating one, so it performs no data-dependent allocation and is safe
    to run during graph capture/replay.

    Returns:
        (window_kv_indptr, window_kv_lens) with ``window_kv_indptr``
        truncated to ``bs + 1`` entries.
    """
    # Same scalar-clamp fix as update_sliding_window_buffer.
    window_kv_lens = torch.clamp(seq_lens, max=sliding_window_size + 1)
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indptr = window_kv_indptr[: bs + 1]
    window_kv_start_idx = seq_lens - window_kv_lens
    create_flashinfer_kv_indices_triton[(bs,)](
        req_to_token,
        req_pool_indices,
        window_kv_lens,
        window_kv_indptr,
        window_kv_start_idx,
        window_kv_indices,
        req_to_token.stride(0),
    )
    return window_kv_indptr, window_kv_lens
torch.Tensor + # Sliding window + window_kv_indptr: torch.Tensor + window_kv_indices: torch.Tensor + window_num_kv_splits: torch.Tensor class TritonAttnBackend(AttentionBackend): @@ -109,6 +172,13 @@ class TritonAttnBackend(AttentionBackend): max_bs = model_runner.req_to_token_pool.size + assert not ( + model_runner.sliding_window_size is not None + and model_runner.model_config.is_encoder_decoder + ), "Sliding window and cross attention are not supported together" + self.sliding_window_size = model_runner.sliding_window_size + + # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled if kv_indptr_buf is None: self.kv_indptr = torch.zeros( (max_bs + 1,), dtype=torch.int32, device=model_runner.device @@ -116,6 +186,18 @@ class TritonAttnBackend(AttentionBackend): else: self.kv_indptr = kv_indptr_buf + # If sliding window is enabled, we might need two sets of buffers + # because of interleaved attention types (e.g. for Gemma3) + self.window_kv_indptr = None + if self.sliding_window_size is not None and self.sliding_window_size > 0: + if kv_indptr_buf is None: + self.window_kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + else: + # When provided a buffer, create a clone for the second buffer + self.window_kv_indptr = torch.zeros_like(kv_indptr_buf) + self.req_to_token = model_runner.req_to_token_pool.req_to_token if not self.skip_prefill: @@ -191,6 +273,9 @@ class TritonAttnBackend(AttentionBackend): bs = forward_batch.batch_size kv_indptr = self.kv_indptr + window_kv_indptr = self.window_kv_indptr + window_kv_indices = None + window_num_kv_splits = None spec_info = forward_batch.spec_info if forward_batch.forward_mode.is_decode_or_idle(): @@ -209,6 +294,26 @@ class TritonAttnBackend(AttentionBackend): kv_indices, self.req_to_token.stride(0), ) + # Sliding window + if ( + self.sliding_window_size is not None + and self.sliding_window_size > 0 + ): + window_kv_indptr, 
window_kv_indices, window_kv_lens = ( + update_sliding_window_buffer( + self.window_kv_indptr, + self.req_to_token, + self.sliding_window_size, + forward_batch.seq_lens, + forward_batch.req_pool_indices, + bs, + self.device, + ) + ) + window_num_kv_splits = torch.empty( + (bs,), dtype=torch.int32, device=self.device + ) + self.get_num_kv_splits(window_num_kv_splits, window_kv_lens) else: kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices bs = kv_indptr.shape[0] - 1 @@ -224,7 +329,6 @@ class TritonAttnBackend(AttentionBackend): device=self.device, ) num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device) - self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens) qo_indptr = None @@ -232,6 +336,7 @@ class TritonAttnBackend(AttentionBackend): mask_indptr = None max_extend_len = None elif forward_batch.forward_mode.is_target_verify(): + # TODO: Support sliding window in spec inference bs = len(forward_batch.req_pool_indices) qo_indptr = torch.arange( 0, @@ -303,6 +408,17 @@ class TritonAttnBackend(AttentionBackend): kv_indices, self.req_to_token.stride(0), ) + # Sliding window + if self.sliding_window_size is not None and self.sliding_window_size > 0: + window_kv_indptr, window_kv_indices, _ = update_sliding_window_buffer( + self.window_kv_indptr, + self.req_to_token, + self.sliding_window_size, + forward_batch.extend_prefix_lens, + forward_batch.req_pool_indices, + bs, + self.device, + ) qo_indptr = self.qo_indptr qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0) @@ -324,6 +440,9 @@ class TritonAttnBackend(AttentionBackend): qo_indptr, custom_mask, mask_indptr, + window_kv_indptr, + window_kv_indices, + window_num_kv_splits, ) def init_cuda_graph_state( @@ -358,6 +477,20 @@ class TritonAttnBackend(AttentionBackend): device=self.device, ) + if self.sliding_window_size is not None and self.sliding_window_size > 0: + if kv_indices_buf is None: + self.cuda_graph_window_kv_indices = torch.zeros( + (max_bs * 
self.sliding_window_size), + dtype=torch.int32, + device=self.device, + ) + else: + self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf) + + self.cuda_graph_window_num_kv_splits = torch.full( + (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device + ) + def init_forward_metadata_capture_cuda_graph( self, bs: int, @@ -369,6 +502,9 @@ class TritonAttnBackend(AttentionBackend): spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], ): assert encoder_lens is None, "Not supported" + window_kv_indptr = self.window_kv_indptr + window_kv_indices = None + window_num_kv_splits = None if forward_mode.is_decode_or_idle(): if spec_info is None: @@ -385,6 +521,21 @@ class TritonAttnBackend(AttentionBackend): kv_indices, self.req_to_token.stride(0), ) + if ( + self.sliding_window_size is not None + and self.sliding_window_size > 0 + ): + window_kv_indices = self.cuda_graph_window_kv_indices + window_num_kv_splits = self.cuda_graph_window_num_kv_splits + window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph( + self.window_kv_indptr, + window_kv_indices, + self.req_to_token, + self.sliding_window_size, + seq_lens[:bs], + req_pool_indices, + bs, + ) else: kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices @@ -468,6 +619,9 @@ class TritonAttnBackend(AttentionBackend): qo_indptr, custom_mask, mask_indptr, + window_kv_indptr, + window_kv_indices, + window_num_kv_splits, ) def init_forward_metadata_replay_cuda_graph( @@ -500,11 +654,31 @@ class TritonAttnBackend(AttentionBackend): self.req_to_token.stride(0), ) num_token = bs + if ( + self.sliding_window_size is not None + and self.sliding_window_size > 0 + ): + window_num_kv_splits = self.cuda_graph_window_num_kv_splits + window_kv_indices = self.cuda_graph_window_kv_indices + _, window_kv_lens = update_sliding_window_buffer_cuda_graph( + self.window_kv_indptr, + window_kv_indices, + self.req_to_token, + self.sliding_window_size, + seq_lens[:bs], + req_pool_indices[:bs], + 
bs, + ) + self.get_num_kv_splits( + window_num_kv_splits[:num_token], window_kv_lens[:bs] + ) + else: kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices num_token = spec_info.kv_indptr.shape[0] - 1 self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs]) + elif forward_mode.is_target_verify(): # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr bs = len(req_pool_indices) @@ -582,6 +756,17 @@ class TritonAttnBackend(AttentionBackend): if layer.attn_type == AttentionType.ENCODER_ONLY: causal = False + if layer.sliding_window_size is not None and layer.sliding_window_size > -1: + sliding_window_size = ( + layer.sliding_window_size + ) # Needed for sliding window mask + kv_indptr = self.forward_metadata.window_kv_indptr + kv_indices = self.forward_metadata.window_kv_indices + else: + sliding_window_size = -1 + kv_indptr = self.forward_metadata.kv_indptr + kv_indices = self.forward_metadata.kv_indices + self.extend_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), k.contiguous(), @@ -590,14 +775,15 @@ class TritonAttnBackend(AttentionBackend): forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), self.forward_metadata.qo_indptr, - self.forward_metadata.kv_indptr, - self.forward_metadata.kv_indices, + kv_indptr, + kv_indices, self.forward_metadata.custom_mask, causal, self.forward_metadata.mask_indptr, self.forward_metadata.max_extend_len, layer.scaling, layer.logit_cap, + sliding_window_size, ) return o @@ -625,13 +811,20 @@ class TritonAttnBackend(AttentionBackend): layer, forward_batch.out_cache_loc, k, v ) + if layer.sliding_window_size is not None and layer.sliding_window_size > -1: + kv_indptr = self.forward_metadata.window_kv_indptr + kv_indices = self.forward_metadata.window_kv_indices + else: + kv_indptr = self.forward_metadata.kv_indptr + kv_indices = 
self.forward_metadata.kv_indices + self.decode_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), o.view(-1, layer.tp_q_head_num, layer.v_head_dim), - self.forward_metadata.kv_indptr, - self.forward_metadata.kv_indices, + kv_indptr, + kv_indices, self.forward_metadata.attn_logits, self.forward_metadata.attn_lse, self.forward_metadata.num_kv_splits, diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py index 35e5c21c6..67767df9b 100644 --- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -65,6 +65,7 @@ def _fwd_kernel( stride_buf_kh, stride_buf_vbs, stride_buf_vh, + SLIDING_WINDOW_SIZE: tl.constexpr, logit_cap: tl.constexpr, Lq: tl.constexpr, Lv: tl.constexpr, @@ -163,6 +164,7 @@ def _fwd_kernel( if logit_cap > 0: qk = logit_cap * tanh(qk / logit_cap) + final_mask = mask_m[:, None] & mask_n[None, :] if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK: custom_mask = tl.load( mask_ptr @@ -173,10 +175,14 @@ def _fwd_kernel( mask=(mask_m[:, None] & mask_n[None, :]), other=0, ) - custom_mask &= mask_m[:, None] & mask_n[None, :] - qk = tl.where(custom_mask, qk, float("-inf")) - else: - qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf")) + final_mask &= custom_mask + if SLIDING_WINDOW_SIZE > 0: + # Add mask where q_id <= kv_id + sliding_window_size + window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= ( + start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE + ) + final_mask &= window_mask + qk = tl.where(final_mask, qk, float("-inf")) n_e_max = tl.maximum(tl.max(qk, 1), e_max) re_scale = tl.exp(e_max - n_e_max) @@ -314,6 +320,7 @@ def extend_attention_fwd( sm_scale=None, logit_cap=0.0, skip_prefix_custom_mask=True, + 
sliding_window_size=-1, ): """ q_extend, k_extend, v_extend, o_extend: contiguous tensors @@ -412,6 +419,7 @@ def extend_attention_fwd( k_buffer.stride(1), v_buffer.stride(0), v_buffer.stride(1), + SLIDING_WINDOW_SIZE=sliding_window_size, logit_cap=logit_cap, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DPE=BLOCK_DPE, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index f89e8629f..68f570b8e 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1025,10 +1025,6 @@ class ModelRunner: return AiterAttnBackend(self) elif self.server_args.attention_backend == "triton": - assert self.sliding_window_size is None, ( - "Window attention is not supported in the triton attention backend. " - "Please use `--attention-backend flashinfer`." - ) assert not self.model_config.is_encoder_decoder, ( "Cross attention is not supported in the triton attention backend. " "Please use `--attention-backend flashinfer`." 
import time
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestSlidingWindowAttentionTriton(CustomTestCase):
    """Test sliding window attention functionality with triton backend."""

    @classmethod
    def setUpClass(cls):
        """Set up shared config; servers are launched per-test."""
        # Gemma3 model supports sliding window attention
        cls.model = "google/gemma-3-4b-it"
        cls.base_url = DEFAULT_URL_FOR_TEST

        cls.common_args = [
            "--trust-remote-code",
            "--attention-backend",
            "triton",
            "--context-length",
            "8192",
            "--random-seed",
            "42",
        ]

        cls.short_context_prompt = "The capital of France is"

        # Test prompt longer than window size
        cls.long_context_prompt = (
            """
            Once upon a time, there was a mountain. In the mountain, there was a temple. In the temple, there was an old monk telling a story. The story was:
            """
            * 100
        )
        cls.long_context_prompt += "\nNow, summarize the story in one sentence:"

    @classmethod
    def tearDownClass(cls):
        pass

    def _test_mmlu(self):
        """Run a small MMLU eval and require a reasonable score."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        print(f"MMLU metrics with sliding window: {metrics}")

        self.assertGreaterEqual(metrics["score"], 0.64)

    def _test_short_context_generation(self):
        """Greedy-decode a short prompt and check the expected answer."""
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": self.short_context_prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 256,
                },
            },
        )

        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertIn("paris", result["text"].lower())
        print(f"Short context generation result: {result['text']}")

    def _test_long_context_generation(self):
        """Generate from a prompt longer than the window; output must be non-empty."""
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": self.long_context_prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 256,
                },
            },
        )

        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertGreater(len(result["text"].strip()), 0)
        print(f"Long context generation result: {result['text'][:100]}...")

    def _launch_and_check(self, extra_args):
        """Launch a server, run all checks, and always tear the server down.

        The original tests killed the server only on the success path, so a
        failing assertion leaked the process and blocked the next launch;
        try/finally guarantees cleanup either way.
        """
        process = popen_launch_server(
            self.model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=self.common_args + extra_args,
        )
        try:
            self._test_short_context_generation()
            self._test_long_context_generation()
            self._test_mmlu()
        finally:
            kill_process_tree(process.pid)
            # Give the OS time to release the port before the next launch.
            time.sleep(5)

    def test_no_cuda_graph(self):
        self._launch_and_check(["--disable-cuda-graph"])

    def test_cuda_graph(self):
        self._launch_and_check([])


if __name__ == "__main__":
    unittest.main()