diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 1c6da6934..806db8913 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -94,6 +94,11 @@ class TritonAttnBackend(AttentionBackend): "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false" ) self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits + self.split_tile_size = model_runner.server_args.triton_attention_split_tile_size + if self.split_tile_size is not None: + self.max_kv_splits = ( + self.max_context_len + self.split_tile_size - 1 + ) // self.split_tile_size # Check arguments assert not ( @@ -153,6 +158,12 @@ class TritonAttnBackend(AttentionBackend): num_kv_splits.fill_(self.max_kv_splits) return + if self.split_tile_size is not None: + num_kv_splits[:] = ( + seq_lens + self.split_tile_size - 1 + ) // self.split_tile_size + return + if num_seq < 256: SCHEDULE_SEQ = 256 else: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 925f51ea1..9e1d9e0c2 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -362,6 +362,7 @@ class ServerArgs: enable_p2p_check: bool = False triton_attention_reduce_in_fp32: bool = False triton_attention_num_kv_splits: int = 8 + triton_attention_split_tile_size: Optional[int] = None num_continuous_decode_steps: int = 1 delete_ckpt_after_loading: bool = False enable_memory_saver: bool = False @@ -2100,6 +2101,12 @@ class ServerArgs: default=ServerArgs.triton_attention_num_kv_splits, help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.", ) + parser.add_argument( + "--triton-attention-split-tile-size", + type=int, + default=ServerArgs.triton_attention_split_tile_size, + help="The size of split KV tile in flash decoding Triton kernel. Used for deterministic inference.", + ) parser.add_argument( "--num-continuous-decode-steps", type=int,