diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 8d7ac3c8b..8b7adf944 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -82,7 +82,8 @@ class ModelRpcServer(rpyc.Service):
         self.max_total_num_token = self.model_runner.max_total_num_token
         self.max_num_running_seq = self.max_total_num_token // 2
         self.max_prefill_num_token = max(
-            self.model_config.context_len, self.max_total_num_token // 6
+            self.model_config.context_len,
+            self.max_total_num_token // 6 if server_args.max_prefill_num_token is None else server_args.max_prefill_num_token,
         )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 30a182a3e..560c93b28 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -430,7 +430,8 @@ class Runtime:
         load_format: str = "auto",
         tokenizer_mode: str = "auto",
         trust_remote_code: bool = True,
-        mem_fraction_static: float = 0.9,
+        mem_fraction_static: float = ServerArgs.mem_fraction_static,
+        max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
         tp_size: int = 1,
         model_mode: List[str] = (),
         schedule_heuristic: str = "lpm",
@@ -451,6 +452,7 @@
             tokenizer_mode=tokenizer_mode,
             trust_remote_code=trust_remote_code,
             mem_fraction_static=mem_fraction_static,
+            max_prefill_num_token=max_prefill_num_token,
             tp_size=tp_size,
             model_mode=model_mode,
             schedule_heuristic=schedule_heuristic,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 39622967b..866a93ac2 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -15,6 +15,7 @@ class ServerArgs:
     chat_template: Optional[str] = None
     trust_remote_code: bool = True
     mem_fraction_static: Optional[float] = None
+    max_prefill_num_token: Optional[int] = None
     tp_size: int = 1
     model_mode: List[str] = ()
     schedule_heuristic: str = "lpm"
@@ -109,6 +110,12 @@
         default=ServerArgs.mem_fraction_static,
         help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
     )
+    parser.add_argument(
+        "--max-prefill-num-token",
+        type=int,
+        default=ServerArgs.max_prefill_num_token,
+        help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
+    )
     parser.add_argument(
         "--tp-size",
         type=int,