Support deterministic inference with triton backend (#10694)
This commit is contained in:
committed by
GitHub
parent
f67d1f45bc
commit
134b4f7ec2
@@ -201,6 +201,8 @@ class Envs:
|
|||||||
SGLANG_ENABLE_DETERMINISTIC_INFERENCE = EnvBool(False)
|
SGLANG_ENABLE_DETERMINISTIC_INFERENCE = EnvBool(False)
|
||||||
SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE = EnvInt(4096)
|
SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE = EnvInt(4096)
|
||||||
SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE = EnvInt(2048)
|
SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE = EnvInt(2048)
|
||||||
|
SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
|
||||||
|
SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
|
||||||
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,12 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
|
|||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.layers.radix_attention import AttentionType
|
from sglang.srt.layers.radix_attention import AttentionType
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.utils import get_bool_env_var, get_device_core_count, next_power_of_2
|
from sglang.srt.utils import (
|
||||||
|
get_bool_env_var,
|
||||||
|
get_device_core_count,
|
||||||
|
get_int_env_var,
|
||||||
|
next_power_of_2,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
@@ -94,7 +99,25 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
"SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
|
"SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
|
||||||
)
|
)
|
||||||
self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
|
self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
|
||||||
self.split_tile_size = model_runner.server_args.triton_attention_split_tile_size
|
|
||||||
|
# Decide whether to enable deterministic inference with batch-invariant operations
|
||||||
|
self.enable_deterministic = (
|
||||||
|
model_runner.server_args.enable_deterministic_inference
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure deterministic inference settings
|
||||||
|
if self.enable_deterministic:
|
||||||
|
# Use fixed split tile size for batch invariance
|
||||||
|
self.split_tile_size = get_int_env_var(
|
||||||
|
"SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256
|
||||||
|
)
|
||||||
|
# Set static_kv_splits to False to use deterministic logic instead
|
||||||
|
self.static_kv_splits = False
|
||||||
|
else:
|
||||||
|
self.split_tile_size = (
|
||||||
|
model_runner.server_args.triton_attention_split_tile_size
|
||||||
|
)
|
||||||
|
|
||||||
if self.split_tile_size is not None:
|
if self.split_tile_size is not None:
|
||||||
self.max_kv_splits = (
|
self.max_kv_splits = (
|
||||||
self.max_context_len + self.split_tile_size - 1
|
self.max_context_len + self.split_tile_size - 1
|
||||||
@@ -154,13 +177,23 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
num_group * num_seq == num_token
|
num_group * num_seq == num_token
|
||||||
), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"
|
), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"
|
||||||
|
|
||||||
if self.static_kv_splits or self.device_core_count <= 0:
|
# Legacy static path (non-deterministic): fill every entry with max_kv_splits
|
||||||
|
if (
|
||||||
|
self.static_kv_splits or self.device_core_count <= 0
|
||||||
|
) and not self.enable_deterministic:
|
||||||
num_kv_splits.fill_(self.max_kv_splits)
|
num_kv_splits.fill_(self.max_kv_splits)
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.split_tile_size is not None:
|
# deterministic
|
||||||
|
if self.split_tile_size is not None and self.enable_deterministic:
|
||||||
|
# expand seq_lens to match num_token
|
||||||
|
if num_group > 1:
|
||||||
|
expanded_seq_lens = seq_lens.repeat_interleave(num_group)
|
||||||
|
else:
|
||||||
|
expanded_seq_lens = seq_lens
|
||||||
|
|
||||||
num_kv_splits[:] = (
|
num_kv_splits[:] = (
|
||||||
seq_lens + self.split_tile_size - 1
|
expanded_seq_lens + self.split_tile_size - 1
|
||||||
) // self.split_tile_size
|
) // self.split_tile_size
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -565,16 +565,8 @@ class Scheduler(
|
|||||||
if get_bool_env_var("SGLANG_GC_LOG"):
|
if get_bool_env_var("SGLANG_GC_LOG"):
|
||||||
configure_gc_logger()
|
configure_gc_logger()
|
||||||
|
|
||||||
# Init prefill kv split size when deterministic inference is enabled with flashinfer attention backend
|
# Init the prefill truncation align size when deterministic inference is enabled (depends on the attention backend)
|
||||||
if (
|
self.init_deterministic_inference_config()
|
||||||
self.server_args.enable_deterministic_inference
|
|
||||||
and self.server_args.attention_backend == "flashinfer"
|
|
||||||
):
|
|
||||||
self.truncation_align_size = get_int_env_var(
|
|
||||||
"SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.truncation_align_size = None
|
|
||||||
|
|
||||||
# Init request dispatcher
|
# Init request dispatcher
|
||||||
self._request_dispatcher = TypeBasedDispatcher(
|
self._request_dispatcher = TypeBasedDispatcher(
|
||||||
@@ -621,6 +613,23 @@ class Scheduler(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def init_deterministic_inference_config(self):
|
||||||
|
"""Initialize deterministic inference configuration for different attention backends."""
|
||||||
|
if not self.server_args.enable_deterministic_inference:
|
||||||
|
self.truncation_align_size = None
|
||||||
|
return
|
||||||
|
|
||||||
|
backend_sizes = {
|
||||||
|
"flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
|
||||||
|
"triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
|
||||||
|
}
|
||||||
|
env_var, default_size = backend_sizes.get(
|
||||||
|
self.server_args.attention_backend, (None, None)
|
||||||
|
)
|
||||||
|
self.truncation_align_size = (
|
||||||
|
get_int_env_var(env_var, default_size) if env_var else None
|
||||||
|
)
|
||||||
|
|
||||||
def init_tokenizer(self):
|
def init_tokenizer(self):
|
||||||
server_args = self.server_args
|
server_args = self.server_args
|
||||||
self.is_generation = self.model_config.is_generation
|
self.is_generation = self.model_config.is_generation
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
|
|||||||
|
|
||||||
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
|
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
|
||||||
|
|
||||||
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3"]
|
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
|
||||||
|
|
||||||
|
|
||||||
# Allow external code to add more choices
|
# Allow external code to add more choices
|
||||||
|
|||||||
Reference in New Issue
Block a user