Add --max-total-tokens (#840)

This commit is contained in:
Liangsheng Yin
2024-07-30 13:33:55 -07:00
committed by GitHub
parent 1edd4e07d6
commit 6b0f2e9088
2 changed files with 24 additions and 2 deletions

View File

@@ -19,6 +19,7 @@ import importlib
import importlib.resources
import logging
import pkgutil
import warnings
from functools import lru_cache
from typing import Optional, Type
@@ -121,7 +122,11 @@ class ModelRunner:
# Load the model and create memory pool
self.load_model()
self.init_memory_pool(total_gpu_memory, server_args.max_num_reqs)
self.init_memory_pool(
total_gpu_memory,
server_args.max_num_reqs,
server_args.max_total_tokens,
)
self.init_cublas()
self.init_flash_infer()
@@ -203,8 +208,18 @@ class ModelRunner:
max_num_token = int(rest_memory * (1 << 30) // cell_size)
return max_num_token
def init_memory_pool(self, total_gpu_memory, max_num_reqs=None):
def init_memory_pool(
self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
):
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
if max_total_tokens is not None:
if max_total_tokens > self.max_total_num_tokens:
warnings.warn(
f"max_total_tokens={max_total_tokens} is larger than the profiled value "
f"{self.max_total_num_tokens}. "
f"Use the profiled value instead."
)
self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)
if self.max_total_num_tokens <= 0:
raise RuntimeError(