[1/2] Support deterministic inference with flashinfer attention backend (#10645)

Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
This commit is contained in:
Baizhou Zhang
2025-09-19 23:34:29 -07:00
committed by GitHub
parent 1d1ce62495
commit 8ecef73f12
10 changed files with 427 additions and 6 deletions

View File

@@ -172,6 +172,7 @@ from sglang.srt.utils import (
freeze_gc,
get_available_gpu_memory,
get_bool_env_var,
get_int_env_var,
get_zmq_socket,
is_cpu,
kill_itself_when_parent_died,
@@ -565,6 +566,17 @@ class Scheduler(
if get_bool_env_var("SGLANG_GC_LOG"):
configure_gc_logger()
# Init prefill kv split size when deterministic inference is enabled with flashinfer attention backend
if (
self.server_args.enable_deterministic_inference
and self.server_args.attention_backend == "flashinfer"
):
self.truncation_align_size = get_int_env_var(
"SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
)
else:
self.truncation_align_size = None
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
@@ -1846,7 +1858,11 @@ class Scheduler(
continue
req.init_next_round_input(self.tree_cache)
res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
res = adder.add_one_req(
req,
has_chunked_req=(self.chunked_req is not None),
truncation_align_size=self.truncation_align_size,
)
if res != AddReqResult.CONTINUE:
if res == AddReqResult.NO_TOKEN: