From 134b4f7ec23012a9782ae63a44040122ca778ed5 Mon Sep 17 00:00:00 2001
From: "Ethan (Yusheng) Su"
Date: Sun, 21 Sep 2025 18:20:40 -0700
Subject: [PATCH] Support deterministic inference with triton backend (#10694)

---
 python/sglang/environ.py                       |  2 +
 .../srt/layers/attention/triton_backend.py     | 43 ++++++++++++++++---
 python/sglang/srt/managers/scheduler.py        | 29 ++++++++-----
 python/sglang/srt/server_args.py               |  2 +-
 4 files changed, 60 insertions(+), 16 deletions(-)

diff --git a/python/sglang/environ.py b/python/sglang/environ.py
index 12470ba9a..e28120702 100644
--- a/python/sglang/environ.py
+++ b/python/sglang/environ.py
@@ -201,6 +201,8 @@ class Envs:
     SGLANG_ENABLE_DETERMINISTIC_INFERENCE = EnvBool(False)
     SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE = EnvInt(4096)
     SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE = EnvInt(2048)
+    SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
+    SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
 
     # fmt: on
 
diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index 806db8913..55b5c6e54 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -12,7 +12,12 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import get_bool_env_var, get_device_core_count, next_power_of_2
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_device_core_count,
+    get_int_env_var,
+    next_power_of_2,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -94,7 +99,25 @@ class TritonAttnBackend(AttentionBackend):
             "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
         )
         self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
-        self.split_tile_size = model_runner.server_args.triton_attention_split_tile_size
+
+        # Decide whether to enable deterministic inference with batch-invariant operations
+        self.enable_deterministic = (
+            model_runner.server_args.enable_deterministic_inference
+        )
+
+        # Configure deterministic inference settings
+        if self.enable_deterministic:
+            # Use a fixed split tile size for batch invariance
+            self.split_tile_size = get_int_env_var(
+                "SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256
+            )
+            # Set static_kv_splits to False so the deterministic logic is used instead
+            self.static_kv_splits = False
+        else:
+            self.split_tile_size = (
+                model_runner.server_args.triton_attention_split_tile_size
+            )
+
         if self.split_tile_size is not None:
             self.max_kv_splits = (
                 self.max_context_len + self.split_tile_size - 1
             ) // self.split_tile_size
@@ -154,13 +177,23 @@ class TritonAttnBackend(AttentionBackend):
             num_group * num_seq == num_token
         ), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"
-        if self.static_kv_splits or self.device_core_count <= 0:
+        # Legacy dynamic splitting logic (non-deterministic)
+        if (
+            self.static_kv_splits or self.device_core_count <= 0
+        ) and not self.enable_deterministic:
             num_kv_splits.fill_(self.max_kv_splits)
             return
 
-        if self.split_tile_size is not None:
+        # Deterministic splitting: a fixed tile size keeps splits batch-invariant
+        if self.split_tile_size is not None and self.enable_deterministic:
+            # Expand seq_lens to one entry per token so it matches num_token
+            if num_group > 1:
+                expanded_seq_lens = seq_lens.repeat_interleave(num_group)
+            else:
+                expanded_seq_lens = seq_lens
+
             num_kv_splits[:] = (
-                seq_lens + self.split_tile_size - 1
+                expanded_seq_lens + self.split_tile_size - 1
             ) // self.split_tile_size
             return
 
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index c83f43122..7501f6e0c 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -565,16 +565,8 @@ class Scheduler(
         if get_bool_env_var("SGLANG_GC_LOG"):
             configure_gc_logger()
 
-        # Init prefill kv split size when deterministic inference is enabled with flashinfer attention backend
-        if (
-            self.server_args.enable_deterministic_inference
-            and self.server_args.attention_backend == "flashinfer"
-        ):
-            self.truncation_align_size = get_int_env_var(
-                "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
-            )
-        else:
-            self.truncation_align_size = None
+        # Init prefill KV split size when deterministic inference is enabled, per attention backend
+        self.init_deterministic_inference_config()
 
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
@@ -621,6 +613,23 @@ class Scheduler(
             ]
         )
 
+    def init_deterministic_inference_config(self):
+        """Initialize deterministic inference configuration for different attention backends."""
+        if not self.server_args.enable_deterministic_inference:
+            self.truncation_align_size = None
+            return
+
+        backend_sizes = {
+            "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
+            "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
+        }
+        env_var, default_size = backend_sizes.get(
+            self.server_args.attention_backend, (None, None)
+        )
+        self.truncation_align_size = (
+            get_int_env_var(env_var, default_size) if env_var else None
+        )
+
     def init_tokenizer(self):
         server_args = self.server_args
         self.is_generation = self.model_config.is_generation
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index f92e3e1a6..b1e474f77 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -118,7 +118,7 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
 
 GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
 
-DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3"]
+DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
 
 # Allow external code to add more choices
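
Why the fixed tile size makes decode batch-invariant: with SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE pinned, the number of KV splits for a request is ceil(seq_len / tile_size), a function of that request's own length only, so the attention reduction order cannot change when other requests join or leave the batch. The legacy path instead sizes splits from the device core count and current batch load, which lets a request's reduction tree vary with its neighbors. Below is a minimal standalone sketch of the split computation added above (plain PyTorch; the helper name and example values are illustrative, not part of the patch):

    import torch

    def deterministic_num_kv_splits(
        seq_lens: torch.Tensor, num_token: int, split_tile_size: int = 256
    ) -> torch.Tensor:
        # One split per fixed-size tile of the KV cache, ceil-divided,
        # mirroring the expanded_seq_lens logic in the hunk above.
        num_seq = seq_lens.numel()
        num_group = num_token // num_seq
        if num_group > 1:
            # Expand to one entry per token when tokens share a sequence
            seq_lens = seq_lens.repeat_interleave(num_group)
        return (seq_lens + split_tile_size - 1) // split_tile_size

    lens = torch.tensor([300, 1024, 7])
    print(deterministic_num_kv_splits(lens, num_token=3))      # tensor([2, 4, 1])
    print(deterministic_num_kv_splits(lens[:1], num_token=1))  # tensor([2]) -- unchanged by batch size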
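
End to end, enabling this path would look roughly like the following. The --attention-backend flag and launch module are existing SGLang CLI surface; the --enable-deterministic-inference spelling is assumed from the enable_deterministic_inference server arg touched by this patch, and the two environment variables (shown at their defaults) are the ones registered in environ.py above:

    SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE=256 \
    SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE=4096 \
    python -m sglang.launch_server \
        --model-path <model> \
        --attention-backend triton \
        --enable-deterministic-inference

Smaller tile sizes produce more, smaller splits (more scheduling overhead), while larger ones reduce parallelism, so the defaults appear to be a throughput/reproducibility compromise rather than hard requirements.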