From 632b7d8cc9880ff9a9b4b9500991329f5c199197 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Tue, 23 Sep 2025 12:59:50 +0800 Subject: [PATCH] Use simulate acc len from `sglang.environ` (#10771) --- python/sglang/environ.py | 2 ++ python/sglang/launch_server.py | 14 ++++++++++ python/sglang/srt/speculative/eagle_utils.py | 28 +++++++++++--------- 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/python/sglang/environ.py b/python/sglang/environ.py index e28120702..de0d52742 100644 --- a/python/sglang/environ.py +++ b/python/sglang/environ.py @@ -124,6 +124,8 @@ class Envs: SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False) SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False) SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False) + SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1) + SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial") # Model Parallel SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True) diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index caae7b0f6..0808fad4c 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -7,9 +7,23 @@ from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import prepare_server_args from sglang.srt.utils import kill_process_tree +MOVE_ENVS_WARN = """ +######################################################################## +# For contributors and developers: # +# Please move environment variable definitions to 'sglang/environ.py' # +# using the following pattern: # +# SGLANG_XXX = EnvBool(False) # +# # +######################################################################## +""" + if __name__ == "__main__": server_args = prepare_server_args(sys.argv[1:]) + from sglang.srt.server_args import print_deprecated_warning + + print_deprecated_warning(MOVE_ENVS_WARN) + try: launch_server(server_args) finally: diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index e6c55df18..165363fb8 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -12,6 +12,7 @@ import torch.nn.functional as F import triton import triton.language as tl +from sglang.environ import envs from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.logits_processor import LogitsProcessorOutput @@ -23,7 +24,7 @@ from sglang.srt.managers.schedule_batch import ( global_server_args_dict, ) from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator -from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.utils import is_cuda, is_hip, next_power_of_2 if is_cuda(): @@ -42,8 +43,8 @@ logger = logging.getLogger(__name__) # Simulate acceptance length for benchmarking purposes -SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN") -SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial") +SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get() # turn off if < 0 +SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get() TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly @@ -500,13 +501,12 @@ class EagleVerifyInput: deterministic=True, ) - if SIMULATE_ACC_LEN: + if SIMULATE_ACC_LEN > 0.0: # Do simulation accept_index = _generate_simulated_accept_index( accept_index=accept_index, predict=predict, # mutable accept_length=accept_length, # mutable - simulate_acc_len=SIMULATE_ACC_LEN, bs=bs, spec_steps=self.spec_steps, ) @@ -1131,14 +1131,16 @@ def _generate_simulated_accept_index( accept_index, predict, accept_length, - simulate_acc_len, bs, spec_steps, + simulate_acc_len: float = SIMULATE_ACC_LEN, + simulate_acc_method: str = SIMULATE_ACC_METHOD, ): - simulate_acc_len_float = float(simulate_acc_len) - if SIMULATE_ACC_METHOD == "multinomial": + assert simulate_acc_len > 0.0 + + if simulate_acc_method == "multinomial": simulated_values = torch.normal( - mean=simulate_acc_len_float, + mean=simulate_acc_len, std=1.0, size=(1,), device="cpu", @@ -1146,19 +1148,19 @@ def _generate_simulated_accept_index( # clamp simulated values to be between 1 and self.spec_steps simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1) simulate_acc_len = int(simulated_values.round().item()) - elif SIMULATE_ACC_METHOD == "match-expected": + elif simulate_acc_method == "match-expected": # multinomial sampling does not match the expected length # we keep it for the sake of compatibility of existing tests # but it's better to use "match-expected" for the cases that need to # match the expected length, One caveat is that this will only sample # either round down or round up of the expected length - simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float)) - lower = int(simulate_acc_len_float // 1) + simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len)) + lower = int(simulate_acc_len // 1) upper = lower + 1 if lower < spec_steps + 1 else lower if lower == upper: simulate_acc_len = lower else: - weight_upper = simulate_acc_len_float - lower + weight_upper = simulate_acc_len - lower weight_lower = 1.0 - weight_upper probs = torch.tensor([weight_lower, weight_upper], device="cpu") sampled_index = torch.multinomial(probs, num_samples=1)