diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py
index 629af6a2a..4e8e90ec4 100644
--- a/python/sglang/global_config.py
+++ b/python/sglang/global_config.py
@@ -16,9 +16,9 @@ class GlobalConfig:
         self.wait_for_new_request_delay = 0.0006
 
         # Runtime constants: New generation token ratio estimation
-        self.base_new_token_ratio = 0.4
+        self.init_new_token_ratio = 0.7
         self.base_min_new_token_ratio = 0.2
-        self.new_token_ratio_decay = 0.0001
+        self.new_token_ratio_decay = 0.001
         self.new_token_ratio_recovery = 0.05
 
         # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
@@ -27,6 +27,7 @@ class GlobalConfig:
 
         # Runtime constants: others
         self.num_continue_decode_steps = 10
+        self.retract_decode_steps = 20
         self.flashinfer_workspace_size = 192 * 1024 * 1024
 
         # Output tokenization configs
diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index e7c5ab5f7..d22f4edb9 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -9,6 +9,7 @@
 import numpy as np
 import torch
 from flashinfer.sampling import top_k_top_p_sampling_from_probs
+from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.managers.controller.radix_cache import RadixCache
@@ -431,7 +432,8 @@ class Batch:
 
     def retract_decode(self):
         sorted_indices = [i for i in range(len(self.reqs))]
-        # TODO(lsyin): improve the priority of retraction
+
+        # TODO(lsyin): improve retraction policy for radix cache
         sorted_indices.sort(
             key=lambda i: (
                 len(self.reqs[i].output_ids),
@@ -443,7 +445,17 @@
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        while self.token_to_kv_pool.available_size() < len(self.reqs):
+        while (
+            self.token_to_kv_pool.available_size()
+            < len(sorted_indices) * global_config.retract_decode_steps
+        ):
+            if len(sorted_indices) == 1:
+                # Corner case: only one request left
+                assert (
+                    self.token_to_kv_pool.available_size() > 0
+                ), "No space left for only one request"
+                break
+
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
@@ -468,7 +480,16 @@
 
         self.filter_batch(sorted_indices)
 
-        return retracted_reqs
+        # Reqs in batch are filtered
+        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
+        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
+
+        new_estimate_ratio = (
+            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
+        ) / total_max_new_tokens
+        new_estimate_ratio = min(1.0, new_estimate_ratio)
+
+        return retracted_reqs, new_estimate_ratio
 
     def check_for_jump_forward(self, model_runner):
         jump_forward_reqs = []
diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py
index 2563c2912..f5401bc62 100644
--- a/python/sglang/srt/managers/controller/tp_worker.py
+++ b/python/sglang/srt/managers/controller/tp_worker.py
@@ -228,6 +228,7 @@ class ModelTpServer:
                         break
             else:
                 self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
 
     def print_stats(self):
         num_used = self.max_total_num_tokens - (
@@ -536,9 +537,10 @@
         # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
-            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
-            retracted_reqs = batch.retract_decode()
+            retracted_reqs, new_token_ratio = batch.retract_decode()
+            self.new_token_ratio = new_token_ratio
+
             logger.info(
                 "decode out of memory happened, "
                 f"#retracted_reqs: {len(retracted_reqs)}, "