Use simulate acc len from sglang.environ (#10771)
This commit is contained in:
@@ -124,6 +124,8 @@ class Envs:
|
||||
SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False)
|
||||
SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False)
|
||||
SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False)
|
||||
SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1)
|
||||
SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
|
||||
|
||||
# Model Parallel
|
||||
SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)
|
||||
|
||||
@@ -7,9 +7,23 @@ from sglang.srt.entrypoints.http_server import launch_server
|
||||
from sglang.srt.server_args import prepare_server_args
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
|
||||
MOVE_ENVS_WARN = """
|
||||
########################################################################
|
||||
# For contributors and developers: #
|
||||
# Please move environment variable definitions to 'sglang/environ.py' #
|
||||
# using the following pattern: #
|
||||
# SGLANG_XXX = EnvBool(False) #
|
||||
# #
|
||||
########################################################################
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
server_args = prepare_server_args(sys.argv[1:])
|
||||
|
||||
from sglang.srt.server_args import print_deprecated_warning
|
||||
|
||||
print_deprecated_warning(MOVE_ENVS_WARN)
|
||||
|
||||
try:
|
||||
launch_server(server_args)
|
||||
finally:
|
||||
|
||||
@@ -12,6 +12,7 @@ import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.environ import envs
|
||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
@@ -23,7 +24,7 @@ from sglang.srt.managers.schedule_batch import (
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
|
||||
|
||||
if is_cuda():
|
||||
@@ -42,8 +43,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Simulate acceptance length for benchmarking purposes
|
||||
SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
|
||||
SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
|
||||
SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get() # turn off if < 0
|
||||
SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
|
||||
|
||||
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
|
||||
|
||||
@@ -500,13 +501,12 @@ class EagleVerifyInput:
|
||||
deterministic=True,
|
||||
)
|
||||
|
||||
if SIMULATE_ACC_LEN:
|
||||
if SIMULATE_ACC_LEN > 0.0:
|
||||
# Do simulation
|
||||
accept_index = _generate_simulated_accept_index(
|
||||
accept_index=accept_index,
|
||||
predict=predict, # mutable
|
||||
accept_length=accept_length, # mutable
|
||||
simulate_acc_len=SIMULATE_ACC_LEN,
|
||||
bs=bs,
|
||||
spec_steps=self.spec_steps,
|
||||
)
|
||||
@@ -1131,14 +1131,16 @@ def _generate_simulated_accept_index(
|
||||
accept_index,
|
||||
predict,
|
||||
accept_length,
|
||||
simulate_acc_len,
|
||||
bs,
|
||||
spec_steps,
|
||||
simulate_acc_len: float = SIMULATE_ACC_LEN,
|
||||
simulate_acc_method: str = SIMULATE_ACC_METHOD,
|
||||
):
|
||||
simulate_acc_len_float = float(simulate_acc_len)
|
||||
if SIMULATE_ACC_METHOD == "multinomial":
|
||||
assert simulate_acc_len > 0.0
|
||||
|
||||
if simulate_acc_method == "multinomial":
|
||||
simulated_values = torch.normal(
|
||||
mean=simulate_acc_len_float,
|
||||
mean=simulate_acc_len,
|
||||
std=1.0,
|
||||
size=(1,),
|
||||
device="cpu",
|
||||
@@ -1146,19 +1148,19 @@ def _generate_simulated_accept_index(
|
||||
# clamp simulated values to be between 1 and spec_steps + 1
|
||||
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
|
||||
simulate_acc_len = int(simulated_values.round().item())
|
||||
elif SIMULATE_ACC_METHOD == "match-expected":
|
||||
elif simulate_acc_method == "match-expected":
|
||||
# multinomial sampling does not match the expected length
|
||||
# we keep it for the sake of compatibility of existing tests
|
||||
# but it's better to use "match-expected" for the cases that need to
|
||||
# match the expected length. One caveat is that this will only sample
|
||||
# either the rounded-down or the rounded-up value of the expected length
|
||||
simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
|
||||
lower = int(simulate_acc_len_float // 1)
|
||||
simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
|
||||
lower = int(simulate_acc_len // 1)
|
||||
upper = lower + 1 if lower < spec_steps + 1 else lower
|
||||
if lower == upper:
|
||||
simulate_acc_len = lower
|
||||
else:
|
||||
weight_upper = simulate_acc_len_float - lower
|
||||
weight_upper = simulate_acc_len - lower
|
||||
weight_lower = 1.0 - weight_upper
|
||||
probs = torch.tensor([weight_lower, weight_upper], device="cpu")
|
||||
sampled_index = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
Reference in New Issue
Block a user