Add --max-total-tokens (#840)
@@ -19,6 +19,7 @@ import importlib
 import importlib.resources
 import logging
 import pkgutil
+import warnings
 from functools import lru_cache
 from typing import Optional, Type

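The only change in this hunk is the new import of warnings, used by the clamp logic added to init_memory_pool further down. By default warnings.warn does not raise; it emits a UserWarning to stderr and execution continues, which is the intended soft-failure behavior here:

    import warnings

    # Prints "UserWarning: value too large" to stderr; does not raise.
    warnings.warn("value too large")
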
@@ -121,7 +122,11 @@ class ModelRunner:

         # Load the model and create memory pool
         self.load_model()
-        self.init_memory_pool(total_gpu_memory, server_args.max_num_reqs)
+        self.init_memory_pool(
+            total_gpu_memory,
+            server_args.max_num_reqs,
+            server_args.max_total_tokens,
+        )
         self.init_cublas()
         self.init_flash_infer()

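The counterpart change to ServerArgs itself is not shown in this excerpt. Assuming the common argparse-backed dataclass pattern, the new flag would be wired up roughly as in the sketch below; the field name matches server_args.max_total_tokens above, but the help text and add_cli_args signature are illustrative, not taken from the commit:

    import argparse
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ServerArgs:
        max_num_reqs: Optional[int] = None
        # New: optional user-supplied cap on the token memory pool.
        max_total_tokens: Optional[int] = None

        @staticmethod
        def add_cli_args(parser: argparse.ArgumentParser):
            parser.add_argument(
                "--max-total-tokens",
                type=int,
                default=None,
                help="Cap the total tokens in the memory pool; values above "
                "the profiled capacity are ignored with a warning.",
            )
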
@@ -203,8 +208,18 @@ class ModelRunner:
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

-    def init_memory_pool(self, total_gpu_memory, max_num_reqs=None):
+    def init_memory_pool(
+        self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
+    ):
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
+        if max_total_tokens is not None:
+            if max_total_tokens > self.max_total_num_tokens:
+                warnings.warn(
+                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
+                    f"{self.max_total_num_tokens}. "
+                    f"Using the profiled value instead."
+                )
+            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
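The net effect is a one-sided clamp: the profiled capacity stays the hard upper bound, and --max-total-tokens can only shrink the pool, never grow it. A minimal standalone sketch of the same behavior (resolve_max_total_tokens and its example values are illustrative, not part of the commit):

    import warnings
    from typing import Optional

    def resolve_max_total_tokens(profiled: int, requested: Optional[int] = None) -> int:
        # Start from the profiled KV-cache capacity, the hard upper bound.
        max_total_num_tokens = profiled
        if requested is not None:
            if requested > profiled:
                warnings.warn(
                    f"max_total_tokens={requested} is larger than the profiled "
                    f"value {profiled}. Using the profiled value instead."
                )
            # A user value can only lower the capacity, never raise it.
            max_total_num_tokens = min(profiled, requested)
        if max_total_num_tokens <= 0:
            raise RuntimeError("Not enough memory for the token pool.")
        return max_total_num_tokens

    assert resolve_max_total_tokens(120_000, 8_192) == 8_192      # cap applied
    assert resolve_max_total_tokens(120_000, 200_000) == 120_000  # warns, keeps profiled value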