Use simulate acc len from sglang.environ (#10771)
This commit is contained in:
@@ -124,6 +124,8 @@ class Envs:
|
|||||||
SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False)
|
SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False)
|
||||||
SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False)
|
SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False)
|
||||||
SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False)
|
SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False)
|
||||||
|
SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1)
|
||||||
|
SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
|
||||||
|
|
||||||
# Model Parallel
|
# Model Parallel
|
||||||
SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)
|
SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)
|
||||||
|
|||||||
@@ -7,9 +7,23 @@ from sglang.srt.entrypoints.http_server import launch_server
|
|||||||
from sglang.srt.server_args import prepare_server_args
|
from sglang.srt.server_args import prepare_server_args
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
|
||||||
|
MOVE_ENVS_WARN = """
|
||||||
|
########################################################################
|
||||||
|
# For contributors and developers: #
|
||||||
|
# Please move environment variable definitions to 'sglang/environ.py' #
|
||||||
|
# using the following pattern: #
|
||||||
|
# SGLANG_XXX = EnvBool(False) #
|
||||||
|
# #
|
||||||
|
########################################################################
|
||||||
|
"""
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
server_args = prepare_server_args(sys.argv[1:])
|
server_args = prepare_server_args(sys.argv[1:])
|
||||||
|
|
||||||
|
from sglang.srt.server_args import print_deprecated_warning
|
||||||
|
|
||||||
|
print_deprecated_warning(MOVE_ENVS_WARN)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
launch_server(server_args)
|
launch_server(server_args)
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import torch.nn.functional as F
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
|
from sglang.environ import envs
|
||||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
@@ -23,7 +24,7 @@ from sglang.srt.managers.schedule_batch import (
|
|||||||
global_server_args_dict,
|
global_server_args_dict,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||||
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
|
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
|
||||||
|
|
||||||
if is_cuda():
|
if is_cuda():
|
||||||
@@ -42,8 +43,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# Simulate acceptance length for benchmarking purposes
|
# Simulate acceptance length for benchmarking purposes
|
||||||
SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
|
SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get() # turn off if < 0
|
||||||
SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
|
SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
|
||||||
|
|
||||||
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
|
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
|
||||||
|
|
||||||
@@ -500,13 +501,12 @@ class EagleVerifyInput:
|
|||||||
deterministic=True,
|
deterministic=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if SIMULATE_ACC_LEN:
|
if SIMULATE_ACC_LEN > 0.0:
|
||||||
# Do simulation
|
# Do simulation
|
||||||
accept_index = _generate_simulated_accept_index(
|
accept_index = _generate_simulated_accept_index(
|
||||||
accept_index=accept_index,
|
accept_index=accept_index,
|
||||||
predict=predict, # mutable
|
predict=predict, # mutable
|
||||||
accept_length=accept_length, # mutable
|
accept_length=accept_length, # mutable
|
||||||
simulate_acc_len=SIMULATE_ACC_LEN,
|
|
||||||
bs=bs,
|
bs=bs,
|
||||||
spec_steps=self.spec_steps,
|
spec_steps=self.spec_steps,
|
||||||
)
|
)
|
||||||
@@ -1131,14 +1131,16 @@ def _generate_simulated_accept_index(
|
|||||||
accept_index,
|
accept_index,
|
||||||
predict,
|
predict,
|
||||||
accept_length,
|
accept_length,
|
||||||
simulate_acc_len,
|
|
||||||
bs,
|
bs,
|
||||||
spec_steps,
|
spec_steps,
|
||||||
|
simulate_acc_len: float = SIMULATE_ACC_LEN,
|
||||||
|
simulate_acc_method: str = SIMULATE_ACC_METHOD,
|
||||||
):
|
):
|
||||||
simulate_acc_len_float = float(simulate_acc_len)
|
assert simulate_acc_len > 0.0
|
||||||
if SIMULATE_ACC_METHOD == "multinomial":
|
|
||||||
|
if simulate_acc_method == "multinomial":
|
||||||
simulated_values = torch.normal(
|
simulated_values = torch.normal(
|
||||||
mean=simulate_acc_len_float,
|
mean=simulate_acc_len,
|
||||||
std=1.0,
|
std=1.0,
|
||||||
size=(1,),
|
size=(1,),
|
||||||
device="cpu",
|
device="cpu",
|
||||||
@@ -1146,19 +1148,19 @@ def _generate_simulated_accept_index(
|
|||||||
# clamp simulated values to be between 1 and self.spec_steps
|
# clamp simulated values to be between 1 and self.spec_steps
|
||||||
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
|
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
|
||||||
simulate_acc_len = int(simulated_values.round().item())
|
simulate_acc_len = int(simulated_values.round().item())
|
||||||
elif SIMULATE_ACC_METHOD == "match-expected":
|
elif simulate_acc_method == "match-expected":
|
||||||
# multinomial sampling does not match the expected length
|
# multinomial sampling does not match the expected length
|
||||||
# we keep it for the sake of compatibility of existing tests
|
# we keep it for the sake of compatibility of existing tests
|
||||||
# but it's better to use "match-expected" for the cases that need to
|
# but it's better to use "match-expected" for the cases that need to
|
||||||
# match the expected length, One caveat is that this will only sample
|
# match the expected length, One caveat is that this will only sample
|
||||||
# either round down or round up of the expected length
|
# either round down or round up of the expected length
|
||||||
simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
|
simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
|
||||||
lower = int(simulate_acc_len_float // 1)
|
lower = int(simulate_acc_len // 1)
|
||||||
upper = lower + 1 if lower < spec_steps + 1 else lower
|
upper = lower + 1 if lower < spec_steps + 1 else lower
|
||||||
if lower == upper:
|
if lower == upper:
|
||||||
simulate_acc_len = lower
|
simulate_acc_len = lower
|
||||||
else:
|
else:
|
||||||
weight_upper = simulate_acc_len_float - lower
|
weight_upper = simulate_acc_len - lower
|
||||||
weight_lower = 1.0 - weight_upper
|
weight_lower = 1.0 - weight_upper
|
||||||
probs = torch.tensor([weight_lower, weight_upper], device="cpu")
|
probs = torch.tensor([weight_lower, weight_upper], device="cpu")
|
||||||
sampled_index = torch.multinomial(probs, num_samples=1)
|
sampled_index = torch.multinomial(probs, num_samples=1)
|
||||||
|
|||||||
Reference in New Issue
Block a user