Support deterministic inference with triton backend (#10694)
commit 134b4f7ec2 (parent f67d1f45bc), committed via GitHub
@@ -201,6 +201,8 @@ class Envs:
     SGLANG_ENABLE_DETERMINISTIC_INFERENCE = EnvBool(False)
     SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE = EnvInt(4096)
     SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE = EnvInt(2048)
+    SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
+    SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
     # fmt: on
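The two new variables mirror the existing FlashInfer knobs. A minimal sketch of how such an integer environment variable can be read; `get_int_env_var` is imported from `sglang.srt.utils` later in this diff, but the body below is an assumption, not the library's actual implementation:

```python
import os

def get_int_env_var(name: str, default: int = 0) -> int:
    """Hypothetical stand-in for sglang.srt.utils.get_int_env_var."""
    value = os.getenv(name)
    if value is None or not value.strip():
        return default
    try:
        return int(value)
    except ValueError:
        return default

# Example: the deterministic decode tile size introduced by this commit.
tile_size = get_int_env_var("SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256)
```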
@@ -12,7 +12,12 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import get_bool_env_var, get_device_core_count, next_power_of_2
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_device_core_count,
+    get_int_env_var,
+    next_power_of_2,
+)

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -94,7 +99,25 @@ class TritonAttnBackend(AttentionBackend):
             "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
         )
         self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
-        self.split_tile_size = model_runner.server_args.triton_attention_split_tile_size
+
+        # Decide whether to enable deterministic inference with batch-invariant operations
+        self.enable_deterministic = (
+            model_runner.server_args.enable_deterministic_inference
+        )
+
+        # Configure deterministic inference settings
+        if self.enable_deterministic:
+            # Use a fixed split tile size for batch invariance
+            self.split_tile_size = get_int_env_var(
+                "SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256
+            )
+            # Set static_kv_splits to False so the deterministic split logic is used instead
+            self.static_kv_splits = False
+        else:
+            self.split_tile_size = (
+                model_runner.server_args.triton_attention_split_tile_size
+            )

         if self.split_tile_size is not None:
             self.max_kv_splits = (
                 self.max_context_len + self.split_tile_size - 1
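When a tile size is set, the backend caps the number of KV splits at the ceiling division of the maximum context length by the tile size (the `(a + b - 1) // b` idiom visible above). A worked check with hypothetical values:

```python
max_context_len = 8192   # hypothetical model context limit
split_tile_size = 256    # default of SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE

# Ceiling division: every tile of up to 256 KV entries gets one split.
max_kv_splits = (max_context_len + split_tile_size - 1) // split_tile_size
assert max_kv_splits == 32
```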
@@ -154,13 +177,23 @@ class TritonAttnBackend(AttentionBackend):
             num_group * num_seq == num_token
         ), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"

-        if self.static_kv_splits or self.device_core_count <= 0:
+        # Legacy dynamic splitting logic (non-deterministic)
+        if (
+            self.static_kv_splits or self.device_core_count <= 0
+        ) and not self.enable_deterministic:
             num_kv_splits.fill_(self.max_kv_splits)
             return

-        if self.split_tile_size is not None:
+        # deterministic
+        if self.split_tile_size is not None and self.enable_deterministic:
+            # expand seq_lens to match num_token
+            if num_group > 1:
+                expanded_seq_lens = seq_lens.repeat_interleave(num_group)
+            else:
+                expanded_seq_lens = seq_lens
+
             num_kv_splits[:] = (
-                seq_lens + self.split_tile_size - 1
+                expanded_seq_lens + self.split_tile_size - 1
             ) // self.split_tile_size
             return
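Because each token's split count is now a pure function of its own sequence length and a fixed tile size, the result no longer depends on how the batch is composed. A small worked example (values hypothetical, `num_group` as in the assertion above):

```python
import torch

split_tile_size = 256
num_group = 2  # hypothetical: tokens per sequence in this batch

seq_lens = torch.tensor([300, 1024, 257])                  # per-sequence KV lengths
expanded_seq_lens = seq_lens.repeat_interleave(num_group)  # one entry per token

# Ceiling division per token; invariant to which other requests are batched.
num_kv_splits = (expanded_seq_lens + split_tile_size - 1) // split_tile_size
print(num_kv_splits)  # tensor([2, 2, 4, 4, 2, 2])
```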
@@ -565,16 +565,8 @@ class Scheduler(
         if get_bool_env_var("SGLANG_GC_LOG"):
             configure_gc_logger()

-        # Init prefill kv split size when deterministic inference is enabled with flashinfer attention backend
-        if (
-            self.server_args.enable_deterministic_inference
-            and self.server_args.attention_backend == "flashinfer"
-        ):
-            self.truncation_align_size = get_int_env_var(
-                "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
-            )
-        else:
-            self.truncation_align_size = None
+        # Init prefill kv split size when deterministic inference is enabled with various attention backends
+        self.init_deterministic_inference_config()

         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
@@ -621,6 +613,23 @@ class Scheduler(
             ]
         )

+    def init_deterministic_inference_config(self):
+        """Initialize deterministic inference configuration for different attention backends."""
+        if not self.server_args.enable_deterministic_inference:
+            self.truncation_align_size = None
+            return
+
+        backend_sizes = {
+            "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
+            "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
+        }
+        env_var, default_size = backend_sizes.get(
+            self.server_args.attention_backend, (None, None)
+        )
+        self.truncation_align_size = (
+            get_int_env_var(env_var, default_size) if env_var else None
+        )
+
     def init_tokenizer(self):
         server_args = self.server_args
         self.is_generation = self.model_config.is_generation
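The table-driven helper replaces the flashinfer-only `if/else` removed above: backends without an entry (including `fa3`) fall through to `None`. A self-contained spot-check of the three cases, using a simplified env reader in place of `get_int_env_var`:

```python
import os

# Hypothetical spot-check of the lookup logic, outside the Scheduler class.
backend_sizes = {
    "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
    "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
}

for backend in ["flashinfer", "triton", "fa3"]:
    env_var, default_size = backend_sizes.get(backend, (None, None))
    # Fall back to the default when the variable is unset (simplified reader).
    align = int(os.getenv(env_var, default_size)) if env_var else None
    print(backend, align)  # -> flashinfer 4096, triton 4096, fa3 None
```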
@@ -118,7 +118,7 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
 GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]

-DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3"]
+DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]

 # Allow external code to add more choices
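Adding `"triton"` to this list is what lets users combine the Triton attention backend with deterministic inference. A minimal sketch of the kind of validation such a choices list supports; the helper below is hypothetical, not the actual server_args wiring:

```python
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]

def check_deterministic_backend(attention_backend: str, enable_deterministic: bool) -> None:
    # Hypothetical guard: reject backends without a batch-invariant implementation.
    if enable_deterministic and attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
        raise ValueError(
            f"Deterministic inference is not supported with attention backend "
            f"{attention_backend!r}; choose one of {DETERMINISTIC_ATTENTION_BACKEND_CHOICES}."
        )
```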