Add --max-total-tokens (#840)
This commit is contained in:
@@ -19,6 +19,7 @@ import importlib
|
||||
import importlib.resources
|
||||
import logging
|
||||
import pkgutil
|
||||
import warnings
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Type
|
||||
|
||||
@@ -121,7 +122,11 @@ class ModelRunner:
|
||||
|
||||
# Load the model and create memory pool
|
||||
self.load_model()
|
||||
self.init_memory_pool(total_gpu_memory, server_args.max_num_reqs)
|
||||
self.init_memory_pool(
|
||||
total_gpu_memory,
|
||||
server_args.max_num_reqs,
|
||||
server_args.max_total_tokens,
|
||||
)
|
||||
self.init_cublas()
|
||||
self.init_flash_infer()
|
||||
|
||||
@@ -203,8 +208,18 @@ class ModelRunner:
|
||||
max_num_token = int(rest_memory * (1 << 30) // cell_size)
|
||||
return max_num_token
|
||||
|
||||
def init_memory_pool(self, total_gpu_memory, max_num_reqs=None):
|
||||
def init_memory_pool(
|
||||
self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
|
||||
):
|
||||
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
|
||||
if max_total_tokens is not None:
|
||||
if max_total_tokens > self.max_total_num_tokens:
|
||||
warnings.warn(
|
||||
f"max_total_tokens={max_total_tokens} is larger than the profiled value "
|
||||
f"{self.max_total_num_tokens}. "
|
||||
f"Use the profiled value instead."
|
||||
)
|
||||
self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)
|
||||
|
||||
if self.max_total_num_tokens <= 0:
|
||||
raise RuntimeError(
|
||||
|
||||
@@ -44,6 +44,7 @@ class ServerArgs:
|
||||
max_prefill_tokens: Optional[int] = None
|
||||
max_running_requests: Optional[int] = None
|
||||
max_num_reqs: Optional[int] = None
|
||||
max_total_tokens: Optional[int] = None
|
||||
schedule_policy: str = "lpm"
|
||||
schedule_conservativeness: float = 1.0
|
||||
|
||||
@@ -231,6 +232,12 @@ class ServerArgs:
|
||||
default=ServerArgs.max_num_reqs,
|
||||
help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-total-tokens",
|
||||
type=int,
|
||||
default=ServerArgs.max_total_tokens,
|
||||
help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--schedule-policy",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user