diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py
index 45c9be37a..6ea6ff194 100644
--- a/python/sglang/srt/managers/schedule_policy.py
+++ b/python/sglang/srt/managers/schedule_policy.py
@@ -30,7 +30,9 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
 # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
-CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
+CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
+    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
+)
 
 
 class SchedulePolicy:
@@ -146,7 +148,7 @@ class PrefillAdder:
             [
                 min(
                     (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                    CLIP_MAX_NEW_TOKENS,
+                    CLIP_MAX_NEW_TOKENS_ESTIMATION,
                 )
                 * self.new_token_ratio
                 for r in running_batch.reqs
@@ -186,7 +188,7 @@ class PrefillAdder:
             len(req.prefix_indices),
             req.extend_input_len,
             (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
                 if not truncated
                 else 0
             ),
@@ -258,7 +260,7 @@ class PrefillAdder:
             self._prefill_one_req(
                 0,
                 req.extend_input_len,
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
             )
         else:
             # Chunked prefill
@@ -276,7 +278,7 @@ class PrefillAdder:
             return self.add_one_req_ignore_eos(req)
 
         total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
         input_tokens = req.extend_input_len
         prefix_len = len(req.prefix_indices)
@@ -302,7 +304,10 @@ class PrefillAdder:
                 self._prefill_one_req(
                     prefix_len,
                     input_tokens,
-                    min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                    min(
+                        req.sampling_params.max_new_tokens,
+                        CLIP_MAX_NEW_TOKENS_ESTIMATION,
+                    ),
                 )
             else:
                 # Chunked prefill
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 8f9553b5a..05e2812eb 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -113,12 +113,15 @@ class CudaGraphRunner:
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
 
         # Batch sizes to capture
-        if self.model_runner.server_args.disable_cuda_graph_padding:
+        if model_runner.server_args.disable_cuda_graph_padding:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
-            self.capture_bs = [1, 2, 3, 4] + [i * 8 for i in range(1, 21)]
+            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
         self.capture_bs = [
-            bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
+            bs
+            for bs in self.capture_bs
+            if bs <= model_runner.req_to_token_pool.size
+            and bs <= model_runner.server_args.max_cuda_graph_bs
         ]
         self.compile_bs = (
             [
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 83e6d2ded..df632ac78 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -120,6 +120,7 @@ class ServerArgs:
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
+    max_cuda_graph_bs: int = 160
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
@@ -624,6 +625,12 @@ class ServerArgs:
             default=ServerArgs.max_torch_compile_bs,
             help="Set the maximum batch size when using torch compile.",
         )
+        parser.add_argument(
+            "--max-cuda-graph-bs",
+            type=int,
+            default=ServerArgs.max_cuda_graph_bs,
+            help="Set the maximum batch size for cuda graph.",
+        )
         parser.add_argument(
             "--torchao-config",
             type=str,
diff --git a/test/srt/test_large_max_new_tokens.py b/test/srt/test_large_max_new_tokens.py
index 8efb83e7b..24c011c75 100644
--- a/test/srt/test_large_max_new_tokens.py
+++ b/test/srt/test_large_max_new_tokens.py
@@ -34,7 +34,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             api_key=cls.api_key,
             other_args=("--max-total-token", "1024", "--context-len", "8192"),
-            env={"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ},
+            env={"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256", **os.environ},
             return_stdout_stderr=(cls.stdout, cls.stderr),
         )
         cls.base_url += "/v1"
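
Reviewer note (not part of the patch): the interaction between the new `--max-cuda-graph-bs` cap and the existing `req_to_token_pool` limit in `CudaGraphRunner` can be reproduced standalone. In the sketch below, `pool_size` is a hypothetical stand-in for `model_runner.req_to_token_pool.size`.

```python
# Standalone sketch of the capture-list construction after this change.
# pool_size is a hypothetical stand-in for model_runner.req_to_token_pool.size.
pool_size = 2048
max_cuda_graph_bs = 64  # e.g. server launched with --max-cuda-graph-bs 64

# Default capture sizes: 1, 2, 4, then multiples of 8 up to 160.
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]

# Keep only batch sizes that fit both the token pool and the new cap.
capture_bs = [
    bs
    for bs in capture_bs
    if bs <= pool_size and bs <= max_cuda_graph_bs
]

print(capture_bs)  # [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64]
```

With the default of 160 the cap is a no-op for both default capture lists (their largest entries are 160 and 128), so existing deployments keep the same capture set. Likewise, the renamed `SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION` variable preserves the old default of 4096 and, per the comment in schedule_policy.py, still only clips the scheduler's estimate rather than the actual stop condition.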