Support deterministic inference with triton backend (#10694)
This commit is contained in:
committed by
GitHub
parent
f67d1f45bc
commit
134b4f7ec2
@@ -565,16 +565,8 @@ class Scheduler(
|
||||
if get_bool_env_var("SGLANG_GC_LOG"):
|
||||
configure_gc_logger()
|
||||
|
||||
# Init prefill kv split size when deterministic inference is enabled with flashinfer attention backend
|
||||
if (
|
||||
self.server_args.enable_deterministic_inference
|
||||
and self.server_args.attention_backend == "flashinfer"
|
||||
):
|
||||
self.truncation_align_size = get_int_env_var(
|
||||
"SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
|
||||
)
|
||||
else:
|
||||
self.truncation_align_size = None
|
||||
# Init prefill kv split size when deterministic inference is enabled with various attention backends
|
||||
self.init_deterministic_inference_config()
|
||||
|
||||
# Init request dispatcher
|
||||
self._request_dispatcher = TypeBasedDispatcher(
|
||||
@@ -621,6 +613,23 @@ class Scheduler(
|
||||
]
|
||||
)
|
||||
|
||||
def init_deterministic_inference_config(self):
    """Initialize deterministic inference configuration for different attention backends."""
    # Deterministic inference disabled: no prefill truncation alignment needed.
    if not self.server_args.enable_deterministic_inference:
        self.truncation_align_size = None
        return

    # Each supported backend maps to the env var (and default value) that
    # controls its prefill KV split / truncation-align size.
    env_by_backend = {
        "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
        "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
    }
    entry = env_by_backend.get(self.server_args.attention_backend)
    if entry is None:
        # Backend has no deterministic-inference tuning knob; leave unset.
        self.truncation_align_size = None
    else:
        env_name, fallback = entry
        self.truncation_align_size = get_int_env_var(env_name, fallback)
|
||||
|
||||
def init_tokenizer(self):
|
||||
server_args = self.server_args
|
||||
self.is_generation = self.model_config.is_generation
|
||||
|
||||
Reference in New Issue
Block a user