Improve the user control of new_token_ratio (#1811)
This commit is contained in:
@@ -14,9 +14,15 @@ class GlobalConfig:
|
|||||||
self.default_backend = None
|
self.default_backend = None
|
||||||
|
|
||||||
# Runtime constants: New generation token ratio estimation
|
# Runtime constants: New generation token ratio estimation
|
||||||
self.init_new_token_ratio = 0.7
|
self.default_init_new_token_ratio = float(
|
||||||
self.base_min_new_token_ratio = 0.1
|
os.environ.get("SGLANG_INIT_NEW_TOKEN_RATIO", 0.7)
|
||||||
self.new_token_ratio_decay = 0.001
|
)
|
||||||
|
self.default_min_new_token_ratio_factor = float(
|
||||||
|
os.environ.get("SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR", 0.14)
|
||||||
|
)
|
||||||
|
self.default_new_token_ratio_decay_steps = float(
|
||||||
|
os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
|
||||||
|
)
|
||||||
|
|
||||||
# Runtime constants: others
|
# Runtime constants: others
|
||||||
self.retract_decode_steps = 20
|
self.retract_decode_steps = 20
|
||||||
|
|||||||
@@ -254,13 +254,22 @@ class Scheduler:
|
|||||||
assert (
|
assert (
|
||||||
server_args.schedule_conservativeness >= 0
|
server_args.schedule_conservativeness >= 0
|
||||||
), "Invalid schedule_conservativeness"
|
), "Invalid schedule_conservativeness"
|
||||||
self.min_new_token_ratio = min(
|
|
||||||
global_config.base_min_new_token_ratio
|
self.init_new_token_ratio = min(
|
||||||
|
global_config.default_init_new_token_ratio
|
||||||
* server_args.schedule_conservativeness,
|
* server_args.schedule_conservativeness,
|
||||||
1.0,
|
1.0,
|
||||||
)
|
)
|
||||||
self.new_token_ratio = self.min_new_token_ratio
|
self.min_new_token_ratio = min(
|
||||||
self.new_token_ratio_decay = global_config.new_token_ratio_decay
|
self.init_new_token_ratio
|
||||||
|
* global_config.default_min_new_token_ratio_factor,
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
self.new_token_ratio_decay = (
|
||||||
|
self.init_new_token_ratio - self.min_new_token_ratio
|
||||||
|
) / global_config.default_new_token_ratio_decay_steps
|
||||||
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
|
|
||||||
self.batch_is_full = False
|
self.batch_is_full = False
|
||||||
|
|
||||||
# Init profiler
|
# Init profiler
|
||||||
@@ -307,7 +316,7 @@ class Scheduler:
|
|||||||
self.process_batch_result(batch, result)
|
self.process_batch_result(batch, result)
|
||||||
else:
|
else:
|
||||||
self.check_memory()
|
self.check_memory()
|
||||||
self.new_token_ratio = global_config.init_new_token_ratio
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
|
|
||||||
@@ -334,7 +343,7 @@ class Scheduler:
|
|||||||
self.process_batch_result(tmp_batch, tmp_result)
|
self.process_batch_result(tmp_batch, tmp_result)
|
||||||
elif batch is None:
|
elif batch is None:
|
||||||
self.check_memory()
|
self.check_memory()
|
||||||
self.new_token_ratio = global_config.init_new_token_ratio
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
|
|
||||||
|
|||||||
@@ -121,13 +121,13 @@ class CudaGraphRunner:
|
|||||||
bs
|
bs
|
||||||
for bs in self.capture_bs
|
for bs in self.capture_bs
|
||||||
if bs <= model_runner.req_to_token_pool.size
|
if bs <= model_runner.req_to_token_pool.size
|
||||||
and bs <= model_runner.server_args.max_cuda_graph_bs
|
and bs <= model_runner.server_args.cuda_graph_max_bs
|
||||||
]
|
]
|
||||||
self.compile_bs = (
|
self.compile_bs = (
|
||||||
[
|
[
|
||||||
bs
|
bs
|
||||||
for bs in self.capture_bs
|
for bs in self.capture_bs
|
||||||
if bs <= self.model_runner.server_args.max_torch_compile_bs
|
if bs <= self.model_runner.server_args.torch_compile_max_bs
|
||||||
]
|
]
|
||||||
if self.use_torch_compile
|
if self.use_torch_compile
|
||||||
else []
|
else []
|
||||||
|
|||||||
@@ -119,8 +119,8 @@ class ServerArgs:
|
|||||||
enable_overlap_schedule: bool = False
|
enable_overlap_schedule: bool = False
|
||||||
enable_mixed_chunk: bool = False
|
enable_mixed_chunk: bool = False
|
||||||
enable_torch_compile: bool = False
|
enable_torch_compile: bool = False
|
||||||
max_torch_compile_bs: int = 32
|
torch_compile_max_bs: int = 32
|
||||||
max_cuda_graph_bs: int = 160
|
cuda_graph_max_bs: int = 160
|
||||||
torchao_config: str = ""
|
torchao_config: str = ""
|
||||||
enable_p2p_check: bool = False
|
enable_p2p_check: bool = False
|
||||||
triton_attention_reduce_in_fp32: bool = False
|
triton_attention_reduce_in_fp32: bool = False
|
||||||
@@ -620,15 +620,15 @@ class ServerArgs:
|
|||||||
help="Optimize the model with torch.compile. Experimental feature.",
|
help="Optimize the model with torch.compile. Experimental feature.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-torch-compile-bs",
|
"--torch-compile-max-bs",
|
||||||
type=int,
|
type=int,
|
||||||
default=ServerArgs.max_torch_compile_bs,
|
default=ServerArgs.torch_compile_max_bs,
|
||||||
help="Set the maximum batch size when using torch compile.",
|
help="Set the maximum batch size when using torch compile.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-cuda-graph-bs",
|
"--cuda-graph-max-bs",
|
||||||
type=int,
|
type=int,
|
||||||
default=ServerArgs.max_cuda_graph_bs,
|
default=ServerArgs.cuda_graph_max_bs,
|
||||||
help="Set the maximum batch size for cuda graph.",
|
help="Set the maximum batch size for cuda graph.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user