diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index e68c2e1b9..44a4498fa 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -19,6 +19,7 @@ import importlib
 import importlib.resources
 import logging
 import pkgutil
+import warnings
 from functools import lru_cache
 from typing import Optional, Type
 
@@ -121,7 +122,11 @@ class ModelRunner:
 
         # Load the model and create memory pool
         self.load_model()
-        self.init_memory_pool(total_gpu_memory, server_args.max_num_reqs)
+        self.init_memory_pool(
+            total_gpu_memory,
+            server_args.max_num_reqs,
+            server_args.max_total_tokens,
+        )
         self.init_cublas()
         self.init_flash_infer()
 
@@ -203,8 +208,18 @@ class ModelRunner:
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token
 
-    def init_memory_pool(self, total_gpu_memory, max_num_reqs=None):
+    def init_memory_pool(
+        self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
+    ):
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
+        if max_total_tokens is not None:
+            if max_total_tokens > self.max_total_num_tokens:
+                warnings.warn(
+                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
+                    f"{self.max_total_num_tokens}. "
+                    f"Use the profiled value instead."
+                )
+            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)
 
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 68f6db7cb..ab4a350cf 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -44,6 +44,7 @@ class ServerArgs:
     max_prefill_tokens: Optional[int] = None
     max_running_requests: Optional[int] = None
     max_num_reqs: Optional[int] = None
+    max_total_tokens: Optional[int] = None
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
 
@@ -231,6 +232,12 @@ class ServerArgs:
             default=ServerArgs.max_num_reqs,
             help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
         )
+        parser.add_argument(
+            "--max-total-tokens",
+            type=int,
+            default=ServerArgs.max_total_tokens,
+            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
+        )
         parser.add_argument(
             "--schedule-policy",
             type=str,
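
Usage note (not part of the patch): a minimal sketch of how the new cap might be exercised through the ServerArgs dataclass extended above, or equivalently with the --max-total-tokens flag added in server_args.py. The model path and token budget below are placeholders, and any other ServerArgs fields are assumed to take their defaults.

    from sglang.srt.server_args import ServerArgs

    # Cap the KV-cache token pool at 8192 tokens for a debugging run.
    # Per init_memory_pool() above, the profiled value still wins if it is smaller,
    # and a warning is emitted if the requested cap exceeds it.
    args = ServerArgs(model_path="placeholder/model", max_total_tokens=8192)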