[1/2] Support deterministic inference with flashinfer attention backend (#10645)
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com> Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
This commit is contained in:
@@ -197,6 +197,11 @@ class Envs:
|
|||||||
SGLANG_SYNC_TOKEN_IDS_ACROSS_TP = EnvBool(False)
|
SGLANG_SYNC_TOKEN_IDS_ACROSS_TP = EnvBool(False)
|
||||||
SGLANG_ENABLE_COLOCATED_BATCH_GEN = EnvBool(False)
|
SGLANG_ENABLE_COLOCATED_BATCH_GEN = EnvBool(False)
|
||||||
|
|
||||||
|
# Deterministic inference
|
||||||
|
SGLANG_ENABLE_DETERMINISTIC_INFERENCE = EnvBool(False)
|
||||||
|
SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE = EnvInt(4096)
|
||||||
|
SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE = EnvInt(2048)
|
||||||
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
|
|||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
|
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
|
get_int_env_var,
|
||||||
is_flashinfer_available,
|
is_flashinfer_available,
|
||||||
is_sm100_supported,
|
is_sm100_supported,
|
||||||
next_power_of_2,
|
next_power_of_2,
|
||||||
@@ -40,6 +41,7 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
|
|
||||||
if is_flashinfer_available():
|
if is_flashinfer_available():
|
||||||
from flashinfer import (
|
from flashinfer import (
|
||||||
BatchDecodeWithPagedKVCacheWrapper,
|
BatchDecodeWithPagedKVCacheWrapper,
|
||||||
@@ -123,12 +125,33 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
):
|
):
|
||||||
global_config.flashinfer_workspace_size = 512 * 1024 * 1024
|
global_config.flashinfer_workspace_size = 512 * 1024 * 1024
|
||||||
|
|
||||||
|
# When deterministic inference is enabled, tensor cores should be used for decode
|
||||||
|
# Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph
|
||||||
|
# More information can be found here: https://github.com/flashinfer-ai/flashinfer/pull/1675
|
||||||
|
self.enable_deterministic = (
|
||||||
|
model_runner.server_args.enable_deterministic_inference
|
||||||
|
)
|
||||||
|
self.prefill_split_tile_size = None
|
||||||
|
self.decode_split_tile_size = None
|
||||||
|
self.disable_cuda_graph_kv_split = False
|
||||||
|
if self.enable_deterministic:
|
||||||
|
self.decode_use_tensor_cores = True
|
||||||
|
self.prefill_split_tile_size = get_int_env_var(
|
||||||
|
"SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
|
||||||
|
)
|
||||||
|
self.decode_split_tile_size = get_int_env_var(
|
||||||
|
"SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
|
||||||
|
)
|
||||||
|
self.disable_cuda_graph_kv_split = True
|
||||||
|
global_config.flashinfer_workspace_size = 2048 * 1024 * 1024
|
||||||
|
|
||||||
# Allocate buffers
|
# Allocate buffers
|
||||||
global global_workspace_buffer
|
global global_workspace_buffer
|
||||||
if global_workspace_buffer is None:
|
if global_workspace_buffer is None:
|
||||||
# different from flashinfer zero_init_global_workspace_buffer
|
# different from flashinfer zero_init_global_workspace_buffer
|
||||||
|
global_workspace_size = global_config.flashinfer_workspace_size
|
||||||
global_workspace_buffer = torch.empty(
|
global_workspace_buffer = torch.empty(
|
||||||
global_config.flashinfer_workspace_size,
|
global_workspace_size,
|
||||||
dtype=torch.uint8,
|
dtype=torch.uint8,
|
||||||
device=model_runner.device,
|
device=model_runner.device,
|
||||||
)
|
)
|
||||||
@@ -219,6 +242,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
decode_wrappers=self.decode_wrappers,
|
decode_wrappers=self.decode_wrappers,
|
||||||
encoder_lens=forward_batch.encoder_lens,
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
spec_info=forward_batch.spec_info,
|
spec_info=forward_batch.spec_info,
|
||||||
|
fixed_split_size=self.decode_split_tile_size,
|
||||||
|
disable_split_kv=False,
|
||||||
)
|
)
|
||||||
self.forward_metadata = DecodeMetadata(self.decode_wrappers)
|
self.forward_metadata = DecodeMetadata(self.decode_wrappers)
|
||||||
elif forward_batch.forward_mode.is_draft_extend():
|
elif forward_batch.forward_mode.is_draft_extend():
|
||||||
@@ -258,7 +283,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
use_ragged = False
|
use_ragged = False
|
||||||
extend_no_prefix = False
|
extend_no_prefix = False
|
||||||
else:
|
else:
|
||||||
use_ragged = True
|
use_ragged = not self.enable_deterministic
|
||||||
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
|
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
|
||||||
|
|
||||||
self.indices_updater_prefill.update(
|
self.indices_updater_prefill.update(
|
||||||
@@ -271,6 +296,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
use_ragged=use_ragged,
|
use_ragged=use_ragged,
|
||||||
encoder_lens=forward_batch.encoder_lens,
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
spec_info=None,
|
spec_info=None,
|
||||||
|
fixed_split_size=self.prefill_split_tile_size,
|
||||||
)
|
)
|
||||||
self.forward_metadata = PrefillMetadata(
|
self.forward_metadata = PrefillMetadata(
|
||||||
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
|
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
|
||||||
@@ -347,6 +373,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
decode_wrappers=decode_wrappers,
|
decode_wrappers=decode_wrappers,
|
||||||
encoder_lens=encoder_lens,
|
encoder_lens=encoder_lens,
|
||||||
spec_info=spec_info,
|
spec_info=spec_info,
|
||||||
|
fixed_split_size=None,
|
||||||
|
disable_split_kv=self.disable_cuda_graph_kv_split,
|
||||||
)
|
)
|
||||||
self.decode_cuda_graph_metadata[bs] = decode_wrappers
|
self.decode_cuda_graph_metadata[bs] = decode_wrappers
|
||||||
self.forward_metadata = DecodeMetadata(decode_wrappers)
|
self.forward_metadata = DecodeMetadata(decode_wrappers)
|
||||||
@@ -439,6 +467,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
decode_wrappers=self.decode_cuda_graph_metadata[bs],
|
decode_wrappers=self.decode_cuda_graph_metadata[bs],
|
||||||
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
||||||
spec_info=spec_info,
|
spec_info=spec_info,
|
||||||
|
fixed_split_size=None,
|
||||||
|
disable_split_kv=self.disable_cuda_graph_kv_split,
|
||||||
)
|
)
|
||||||
elif forward_mode.is_target_verify():
|
elif forward_mode.is_target_verify():
|
||||||
self.indices_updater_prefill.update(
|
self.indices_updater_prefill.update(
|
||||||
@@ -646,6 +676,8 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
spec_info: Optional[
|
spec_info: Optional[
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
],
|
],
|
||||||
|
fixed_split_size: Optional[int] = None,
|
||||||
|
disable_split_kv: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
# Keep the signature for type checking. It will be assigned during runtime.
|
# Keep the signature for type checking. It will be assigned during runtime.
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -661,6 +693,8 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
spec_info: Optional[
|
spec_info: Optional[
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
],
|
],
|
||||||
|
fixed_split_size: Optional[int] = None,
|
||||||
|
disable_split_kv: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||||
self.call_begin_forward(
|
self.call_begin_forward(
|
||||||
@@ -672,6 +706,8 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
None,
|
None,
|
||||||
spec_info,
|
spec_info,
|
||||||
seq_lens_cpu,
|
seq_lens_cpu,
|
||||||
|
fixed_split_size=fixed_split_size,
|
||||||
|
disable_split_kv=disable_split_kv,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_sliding_window(
|
def update_sliding_window(
|
||||||
@@ -685,6 +721,8 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
spec_info: Optional[
|
spec_info: Optional[
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
],
|
],
|
||||||
|
fixed_split_size: Optional[int] = None,
|
||||||
|
disable_split_kv: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
assert self.sliding_window_size is not None
|
assert self.sliding_window_size is not None
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
@@ -735,6 +773,8 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
spec_info: Optional[
|
spec_info: Optional[
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
],
|
],
|
||||||
|
fixed_split_size: Optional[int] = None,
|
||||||
|
disable_split_kv: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
@@ -771,6 +811,8 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
],
|
],
|
||||||
seq_lens_cpu: Optional[torch.Tensor],
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
use_sliding_window_kv_pool: bool = False,
|
use_sliding_window_kv_pool: bool = False,
|
||||||
|
fixed_split_size: Optional[int] = None,
|
||||||
|
disable_split_kv: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
if spec_info is None:
|
if spec_info is None:
|
||||||
bs = len(req_pool_indices)
|
bs = len(req_pool_indices)
|
||||||
@@ -825,6 +867,10 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
data_type=self.data_type,
|
data_type=self.data_type,
|
||||||
q_data_type=self.q_data_type,
|
q_data_type=self.q_data_type,
|
||||||
non_blocking=True,
|
non_blocking=True,
|
||||||
|
fixed_split_size=fixed_split_size,
|
||||||
|
disable_split_kv=(
|
||||||
|
disable_split_kv if disable_split_kv is not None else False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if locally_override:
|
if locally_override:
|
||||||
@@ -876,6 +922,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
spec_info: Optional[
|
spec_info: Optional[
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
],
|
],
|
||||||
|
fixed_split_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
# Keep the signature for type checking. It will be assigned during runtime.
|
# Keep the signature for type checking. It will be assigned during runtime.
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -893,6 +940,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
spec_info: Optional[
|
spec_info: Optional[
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
],
|
],
|
||||||
|
fixed_split_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
if use_ragged:
|
if use_ragged:
|
||||||
# TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
|
# TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
|
||||||
@@ -916,6 +964,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
self.qo_indptr[0],
|
self.qo_indptr[0],
|
||||||
use_ragged,
|
use_ragged,
|
||||||
spec_info,
|
spec_info,
|
||||||
|
fixed_split_size=fixed_split_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_sliding_window(
|
def update_sliding_window(
|
||||||
@@ -931,6 +980,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
spec_info: Optional[
|
spec_info: Optional[
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
],
|
],
|
||||||
|
fixed_split_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
@@ -979,6 +1029,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
spec_info: Optional[
|
spec_info: Optional[
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
],
|
],
|
||||||
|
fixed_split_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
@@ -1024,6 +1075,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
],
|
],
|
||||||
use_sliding_window_kv_pool: bool = False,
|
use_sliding_window_kv_pool: bool = False,
|
||||||
|
fixed_split_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
bs = len(seq_lens)
|
bs = len(seq_lens)
|
||||||
if spec_info is None:
|
if spec_info is None:
|
||||||
@@ -1094,6 +1146,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
kv_data_type=self.data_type,
|
kv_data_type=self.data_type,
|
||||||
custom_mask=custom_mask,
|
custom_mask=custom_mask,
|
||||||
non_blocking=True,
|
non_blocking=True,
|
||||||
|
fixed_split_size=fixed_split_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1327,6 +1380,8 @@ def fast_decode_plan(
|
|||||||
rope_scale: Optional[float] = None,
|
rope_scale: Optional[float] = None,
|
||||||
rope_theta: Optional[float] = None,
|
rope_theta: Optional[float] = None,
|
||||||
non_blocking: bool = True,
|
non_blocking: bool = True,
|
||||||
|
fixed_split_size: Optional[int] = None,
|
||||||
|
disable_split_kv: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
|
A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
|
||||||
@@ -1352,6 +1407,9 @@ def fast_decode_plan(
|
|||||||
|
|
||||||
if self.use_tensor_cores:
|
if self.use_tensor_cores:
|
||||||
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
|
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
|
||||||
|
# Here we set fixed_split_size to -1 to avoid the assertion error in flashinfer's plan function
|
||||||
|
if fixed_split_size is None:
|
||||||
|
fixed_split_size = -1
|
||||||
|
|
||||||
if self.is_cuda_graph_enabled:
|
if self.is_cuda_graph_enabled:
|
||||||
if batch_size != self._fixed_batch_size:
|
if batch_size != self._fixed_batch_size:
|
||||||
@@ -1433,8 +1491,8 @@ def fast_decode_plan(
|
|||||||
head_dim,
|
head_dim,
|
||||||
False, # causal
|
False, # causal
|
||||||
window_left,
|
window_left,
|
||||||
-1,
|
fixed_split_size,
|
||||||
False,
|
disable_split_kv,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Error in standard plan: {e}")
|
raise RuntimeError(f"Error in standard plan: {e}")
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
"""Fused operators for normalization layers."""
|
"""Fused operators for normalization layers."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -80,6 +81,8 @@ class RMSNorm(CustomOp):
|
|||||||
)
|
)
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
self._forward_method = self.forward_aiter
|
self._forward_method = self.forward_aiter
|
||||||
|
if os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] == "1":
|
||||||
|
self._forward_method = self.forward_native
|
||||||
|
|
||||||
def forward_cuda(
|
def forward_cuda(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"enable_symm_mem",
|
"enable_symm_mem",
|
||||||
"enable_custom_logit_processor",
|
"enable_custom_logit_processor",
|
||||||
"disaggregation_mode",
|
"disaggregation_mode",
|
||||||
|
"enable_deterministic_inference",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Put some global args for easy access
|
# Put some global args for easy access
|
||||||
|
|||||||
@@ -541,7 +541,9 @@ class PrefillAdder:
|
|||||||
|
|
||||||
return self.budget_state()
|
return self.budget_state()
|
||||||
|
|
||||||
def add_one_req(self, req: Req, has_chunked_req: bool):
|
def add_one_req(
|
||||||
|
self, req: Req, has_chunked_req: bool, truncation_align_size: Optional[int]
|
||||||
|
):
|
||||||
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
|
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
|
||||||
return self.add_one_req_ignore_eos(req, has_chunked_req)
|
return self.add_one_req_ignore_eos(req, has_chunked_req)
|
||||||
|
|
||||||
@@ -600,6 +602,17 @@ class PrefillAdder:
|
|||||||
if trunc_len <= 0:
|
if trunc_len <= 0:
|
||||||
return AddReqResult.OTHER
|
return AddReqResult.OTHER
|
||||||
|
|
||||||
|
# When truncation align size is set, we want to assert that the prefill prefix length is multiple of truncation align size
|
||||||
|
# A typical use case is when deterministic inference is enabled with flashinfer attention backend,
|
||||||
|
# we need the prefill prefix length to be multiple of attention split size
|
||||||
|
if truncation_align_size is not None:
|
||||||
|
if trunc_len < truncation_align_size:
|
||||||
|
return AddReqResult.OTHER
|
||||||
|
else:
|
||||||
|
trunc_len = truncation_align_size * (
|
||||||
|
trunc_len // truncation_align_size
|
||||||
|
)
|
||||||
|
|
||||||
# Chunked prefill
|
# Chunked prefill
|
||||||
req.extend_input_len = trunc_len
|
req.extend_input_len = trunc_len
|
||||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
||||||
|
|||||||
@@ -172,6 +172,7 @@ from sglang.srt.utils import (
|
|||||||
freeze_gc,
|
freeze_gc,
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
|
get_int_env_var,
|
||||||
get_zmq_socket,
|
get_zmq_socket,
|
||||||
is_cpu,
|
is_cpu,
|
||||||
kill_itself_when_parent_died,
|
kill_itself_when_parent_died,
|
||||||
@@ -565,6 +566,17 @@ class Scheduler(
|
|||||||
if get_bool_env_var("SGLANG_GC_LOG"):
|
if get_bool_env_var("SGLANG_GC_LOG"):
|
||||||
configure_gc_logger()
|
configure_gc_logger()
|
||||||
|
|
||||||
|
# Init prefill kv split size when deterministic inference is enabled with flashinfer attention backend
|
||||||
|
if (
|
||||||
|
self.server_args.enable_deterministic_inference
|
||||||
|
and self.server_args.attention_backend == "flashinfer"
|
||||||
|
):
|
||||||
|
self.truncation_align_size = get_int_env_var(
|
||||||
|
"SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.truncation_align_size = None
|
||||||
|
|
||||||
# Init request dispatcher
|
# Init request dispatcher
|
||||||
self._request_dispatcher = TypeBasedDispatcher(
|
self._request_dispatcher = TypeBasedDispatcher(
|
||||||
[
|
[
|
||||||
@@ -1846,7 +1858,11 @@ class Scheduler(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
req.init_next_round_input(self.tree_cache)
|
req.init_next_round_input(self.tree_cache)
|
||||||
res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
|
res = adder.add_one_req(
|
||||||
|
req,
|
||||||
|
has_chunked_req=(self.chunked_req is not None),
|
||||||
|
truncation_align_size=self.truncation_align_size,
|
||||||
|
)
|
||||||
|
|
||||||
if res != AddReqResult.CONTINUE:
|
if res != AddReqResult.CONTINUE:
|
||||||
if res == AddReqResult.NO_TOKEN:
|
if res == AddReqResult.NO_TOKEN:
|
||||||
|
|||||||
@@ -406,6 +406,12 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
|
self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
|
||||||
|
|
||||||
|
# Enable batch invariant mode
|
||||||
|
if server_args.enable_deterministic_inference:
|
||||||
|
from batch_invariant_ops import enable_batch_invariant_mode
|
||||||
|
|
||||||
|
enable_batch_invariant_mode()
|
||||||
|
|
||||||
# Init memory pool and attention backends
|
# Init memory pool and attention backends
|
||||||
self.init_memory_pool(
|
self.init_memory_pool(
|
||||||
min_per_gpu_memory,
|
min_per_gpu_memory,
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ class SamplingBatchInfo:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||||
global_server_args_dict = cls._get_global_server_args_dict()
|
global_server_args_dict = cls._get_global_server_args_dict()
|
||||||
|
enable_deterministic = global_server_args_dict["enable_deterministic_inference"]
|
||||||
|
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
device = batch.device
|
device = batch.device
|
||||||
|
|||||||
@@ -118,6 +118,8 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
|
|||||||
|
|
||||||
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
|
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
|
||||||
|
|
||||||
|
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer"]
|
||||||
|
|
||||||
|
|
||||||
# Allow external code to add more choices
|
# Allow external code to add more choices
|
||||||
def add_load_format_choices(choices):
|
def add_load_format_choices(choices):
|
||||||
@@ -437,6 +439,9 @@ class ServerArgs:
|
|||||||
max_mamba_cache_size: Optional[int] = None
|
max_mamba_cache_size: Optional[int] = None
|
||||||
mamba_ssm_dtype: str = "float32"
|
mamba_ssm_dtype: str = "float32"
|
||||||
|
|
||||||
|
# For deterministic inference
|
||||||
|
enable_deterministic_inference: bool = False
|
||||||
|
|
||||||
# Deprecated arguments
|
# Deprecated arguments
|
||||||
enable_ep_moe: bool = False
|
enable_ep_moe: bool = False
|
||||||
enable_deepep_moe: bool = False
|
enable_deepep_moe: bool = False
|
||||||
@@ -980,6 +985,29 @@ class ServerArgs:
|
|||||||
"Please set --tokenizer-metrics-custom-labels-header when setting --tokenizer-metrics-allowed-customer-labels."
|
"Please set --tokenizer-metrics-custom-labels-header when setting --tokenizer-metrics-allowed-customer-labels."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Deterministic inference
|
||||||
|
os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = (
|
||||||
|
"1" if self.enable_deterministic_inference else "0"
|
||||||
|
)
|
||||||
|
if self.enable_deterministic_inference:
|
||||||
|
# Check batch_invariant_ops dependency
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
if not importlib.util.find_spec("batch_invariant_ops"):
|
||||||
|
raise ValueError(
|
||||||
|
"batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check some settings
|
||||||
|
self.disable_radix_cache = True
|
||||||
|
logger.warning(
|
||||||
|
"Currently radix cache is disabled for deterministic inference. It will be supported in the future."
|
||||||
|
)
|
||||||
|
if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
|
||||||
|
raise ValueError(
|
||||||
|
f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
# Model and tokenizer
|
# Model and tokenizer
|
||||||
@@ -2470,6 +2498,13 @@ class ServerArgs:
|
|||||||
help="Number of sm partition groups.",
|
help="Number of sm partition groups.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# For deterministic inference
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-deterministic-inference",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable deterministic inference mode with batch invariant ops.",
|
||||||
|
)
|
||||||
|
|
||||||
# Deprecated arguments
|
# Deprecated arguments
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-ep-moe",
|
"--enable-ep-moe",
|
||||||
|
|||||||
283
python/sglang/test/test_deterministic.py
Normal file
283
python/sglang/test/test_deterministic.py
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
"""
|
||||||
|
Batch the same prompt in random batch sizes, and test if the results are consistent across different trials.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 -m sglang.test.test_deterministic --n-trials <numer_of_trials> --test-mode <single|mixed|prefix> --profile
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import dataclasses
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.profiler import run_profile
|
||||||
|
|
||||||
|
PROMPT_1 = "Tell me about Richard Feynman: "
|
||||||
|
PROMPT_2 = "Generate 1000 random numbers. Go directly into it, don't say Sure and don't say here are numbers. Just start with a number."
|
||||||
|
dirpath = os.path.dirname(__file__)
|
||||||
|
with open("python/sglang/test/long_prompt.txt", "r") as f:
|
||||||
|
LONG_PROMPT = f.read()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class BenchArgs:
|
||||||
|
host: str = "localhost"
|
||||||
|
port: int = 30000
|
||||||
|
batch_size: int = 1
|
||||||
|
temperature: float = 0.0
|
||||||
|
max_new_tokens: int = 100
|
||||||
|
frequency_penalty: float = 0.0
|
||||||
|
presence_penalty: float = 0.0
|
||||||
|
return_logprob: bool = False
|
||||||
|
stream: bool = False
|
||||||
|
profile: bool = False
|
||||||
|
profile_steps: int = 3
|
||||||
|
profile_by_stage: bool = False
|
||||||
|
test_mode: str = "single"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--host", type=str, default=BenchArgs.host)
|
||||||
|
parser.add_argument("--port", type=int, default=BenchArgs.port)
|
||||||
|
parser.add_argument("--n-trials", type=int, default=50)
|
||||||
|
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-new-tokens", type=int, default=BenchArgs.max_new_tokens
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--frequency-penalty", type=float, default=BenchArgs.frequency_penalty
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--presence-penalty", type=float, default=BenchArgs.presence_penalty
|
||||||
|
)
|
||||||
|
parser.add_argument("--return-logprob", action="store_true")
|
||||||
|
parser.add_argument("--stream", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-mode",
|
||||||
|
type=str,
|
||||||
|
default=BenchArgs.test_mode,
|
||||||
|
choices=["single", "mixed", "prefix"],
|
||||||
|
)
|
||||||
|
parser.add_argument("--profile", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--profile-steps", type=int, default=BenchArgs.profile_steps
|
||||||
|
)
|
||||||
|
parser.add_argument("--profile-by-stage", action="store_true")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
|
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||||
|
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
||||||
|
|
||||||
|
|
||||||
|
def send_single(
|
||||||
|
args,
|
||||||
|
batch_size: int,
|
||||||
|
profile: bool = False,
|
||||||
|
profile_steps: int = 3,
|
||||||
|
profile_by_stage: bool = False,
|
||||||
|
):
|
||||||
|
|
||||||
|
base_url = f"http://{args.host}:{args.port}"
|
||||||
|
prompt = [PROMPT_1] * batch_size
|
||||||
|
|
||||||
|
json_data = {
|
||||||
|
"text": prompt,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": args.temperature,
|
||||||
|
"max_new_tokens": args.max_new_tokens,
|
||||||
|
"frequency_penalty": args.frequency_penalty,
|
||||||
|
"presence_penalty": args.presence_penalty,
|
||||||
|
},
|
||||||
|
"return_logprob": args.return_logprob,
|
||||||
|
"stream": args.stream,
|
||||||
|
}
|
||||||
|
|
||||||
|
if profile:
|
||||||
|
run_profile(
|
||||||
|
base_url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
|
||||||
|
)
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
f"{base_url}/generate",
|
||||||
|
json=json_data,
|
||||||
|
stream=args.stream,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.stream:
|
||||||
|
for chunk in response.iter_lines(decode_unicode=False):
|
||||||
|
chunk = chunk.decode("utf-8")
|
||||||
|
if chunk and chunk.startswith("data:"):
|
||||||
|
if chunk == "data: [DONE]":
|
||||||
|
break
|
||||||
|
ret = json.loads(chunk[5:].strip("\n"))
|
||||||
|
else:
|
||||||
|
ret = response.json()
|
||||||
|
ret = ret[0]
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(ret)
|
||||||
|
return -1
|
||||||
|
|
||||||
|
return ret["text"]
|
||||||
|
|
||||||
|
|
||||||
|
def send_mixed(args, batch_size: int):
|
||||||
|
num_long_prompt = 0 if batch_size <= 10 else random.randint(1, 10)
|
||||||
|
num_prompt_1 = random.randint(1, batch_size - num_long_prompt)
|
||||||
|
num_prompt_2 = batch_size - num_prompt_1 - num_long_prompt
|
||||||
|
|
||||||
|
json_data = {
|
||||||
|
"text": [PROMPT_1] * num_prompt_1
|
||||||
|
+ [PROMPT_2] * num_prompt_2
|
||||||
|
+ [LONG_PROMPT] * num_long_prompt,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": args.temperature,
|
||||||
|
"max_new_tokens": args.max_new_tokens,
|
||||||
|
"frequency_penalty": args.frequency_penalty,
|
||||||
|
"presence_penalty": args.presence_penalty,
|
||||||
|
},
|
||||||
|
"return_logprob": args.return_logprob,
|
||||||
|
"stream": args.stream,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
f"http://{args.host}:{args.port}/generate",
|
||||||
|
json=json_data,
|
||||||
|
stream=args.stream,
|
||||||
|
)
|
||||||
|
ret = response.json()
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(ret)
|
||||||
|
return -1, -1, -1
|
||||||
|
|
||||||
|
prompt_1_ret = [ret[i]["text"] for i in range(num_prompt_1)]
|
||||||
|
prompt_2_ret = [
|
||||||
|
ret[i]["text"] for i in range(num_prompt_1, num_prompt_1 + num_prompt_2)
|
||||||
|
]
|
||||||
|
long_prompt_ret = [
|
||||||
|
ret[i]["text"]
|
||||||
|
for i in range(
|
||||||
|
num_prompt_1 + num_prompt_2, num_prompt_1 + num_prompt_2 + num_long_prompt
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
return prompt_1_ret, prompt_2_ret, long_prompt_ret
|
||||||
|
|
||||||
|
|
||||||
|
def send_prefix(args, batch_size: int, prompts: List[str]):
    """Send a batch of prompts sampled (with replacement) from *prompts*.

    Flushes the server cache first, then returns a dict mapping each prompt
    index to the list of texts generated for it, or (-1, -1, -1) on a
    non-200 response.
    """
    requests.post(f"http://{args.host}:{args.port}/flush_cache")

    # Sample prompt indices with replacement to fill the batch.
    sampled_indices = [
        random.randint(0, len(prompts) - 1) for _ in range(batch_size)
    ]
    batch_data = [prompts[idx] for idx in sampled_indices]

    payload = {
        "text": batch_data,
        "sampling_params": {
            "temperature": args.temperature,
            "max_new_tokens": args.max_new_tokens,
            "frequency_penalty": args.frequency_penalty,
            "presence_penalty": args.presence_penalty,
        },
        "return_logprob": args.return_logprob,
        "stream": args.stream,
    }

    response = requests.post(
        f"http://{args.host}:{args.port}/generate",
        json=payload,
        stream=args.stream,
    )
    ret = response.json()
    if response.status_code != 200:
        print(ret)
        return -1, -1, -1

    # Group outputs by the prompt they were generated from.
    grouped = {idx: [] for idx in range((len(prompts)))}
    for pos, prompt_idx in enumerate(sampled_indices):
        grouped[prompt_idx].append(ret[pos]["text"])

    return grouped
|
||||||
|
|
||||||
|
|
||||||
|
def test_deterministic(args):
    """Check determinism of server generations under three batching scenarios.

    Modes (selected by ``args.test_mode``):
        single: the same prompt at batch sizes 1..n_trials — a deterministic
                server yields exactly one unique completion.
        mixed:  batches mixing two short prompts and one long prompt.
        prefix: prompts sharing common prefixes of varying lengths, to
                exercise prefix-cache reuse.

    Raises:
        ValueError: if ``args.test_mode`` is not one of the modes above.
    """
    # Warm up the server so startup effects (compilation, cache population)
    # do not pollute the trials.
    for _ in range(3):
        send_single(args, 16, args.profile)

    if args.test_mode == "single":
        # In single mode, we test the deterministic behavior by sending the same prompt in batch sizes ranging from 1 to n_trials.
        texts = []
        for trial in range(1, args.n_trials + 1):
            batch_size = trial
            text = send_single(args, batch_size, args.profile)
            text = text.replace("\n", " ")
            print(f"Trial {trial} with batch size {batch_size}: {text}")
            texts.append(text)

        print(f"Total samples: {len(texts)}, Unique samples: {len(set(texts))}")
    elif args.test_mode == "mixed":
        # In mixed mode, we send a mixture of two short prompts and one long prompt in the same batch with batch size ranging from 1 to n_trials.
        output_prompt_1 = []
        output_prompt_2 = []
        output_long_prompt = []
        for trial in range(1, args.n_trials + 1):
            batch_size = trial
            ret_prompt_1, ret_prompt_2, ret_long_prompt = send_mixed(args, batch_size)
            output_prompt_1.extend(ret_prompt_1)
            output_prompt_2.extend(ret_prompt_2)
            output_long_prompt.extend(ret_long_prompt)

            print(
                f"Testing Trial {trial} with batch size {batch_size}, number of prompt 1: {len(ret_prompt_1)}, number of prompt 2: {len(ret_prompt_2)}, number of long prompt: {len(ret_long_prompt)}"
            )

        print(
            f"Prompt 1: total samples: {len(output_prompt_1)}, Unique samples: {len(set(output_prompt_1))}"
        )
        print(
            f"Prompt 2: total samples: {len(output_prompt_2)}, Unique samples: {len(set(output_prompt_2))}"
        )
        print(
            f"Long prompt: total samples: {len(output_long_prompt)}, Unique samples: {len(set(output_long_prompt))}"
        )

    elif args.test_mode == "prefix":
        # In prefix mode, we create prompts from the same long prompt, with different lengths of common prefix.
        len_prefix = [1, 511, 2048, 4097]
        num_prompts = len(len_prefix)
        # Fix: derive the ranges from num_prompts instead of a hard-coded 4,
        # so editing len_prefix keeps everything consistent.
        outputs = {i: [] for i in range(num_prompts)}
        prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(num_prompts)]
        # Fix: distinct loop variables (trial / j) — the original reused `i`
        # for both the trial loop and the per-prompt loops, shadowing it.
        for trial in range(1, args.n_trials + 1):
            batch_size = trial
            ret_dict = send_prefix(args, batch_size, prompts)
            msg = f"Testing Trial {trial} with batch size {batch_size},"
            for j in range(num_prompts):
                msg += f" # prefix length {len_prefix[j]}: {len(ret_dict[j])},"
            print(msg)
            for j in range(num_prompts):
                outputs[j].extend(ret_dict[j])

        for j in range(num_prompts):
            print(
                f"Prompt {j} with prefix length {len_prefix[j]}: total samples: {len(outputs[j])}, Unique samples: {len(set(outputs[j]))}"
            )

    else:
        raise ValueError(f"Invalid test mode: {args.test_mode}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Parse the benchmark CLI flags and run the determinism test suite.
    cli_parser = argparse.ArgumentParser()
    BenchArgs.add_cli_args(cli_parser)
    cli_args = cli_parser.parse_args()

    test_deterministic(cli_args)
|
||||||
Reference in New Issue
Block a user