Expose dtype argument (#569)
@@ -6,7 +6,7 @@ import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional, Type, Any
+from typing import List, Optional, Type
 
 import numpy as np
 import torch
@@ -119,7 +119,7 @@ class InputMetadata:
             head_dim,
             1,
             pos_encoding_mode="NONE",
-            data_type="float16",
+            data_type=self.token_to_kv_pool.kv_data[0].dtype
         )
 
     def init_extend_args(self):
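In the InputMetadata hunk above, the attention wrapper's data_type no longer hard-codes "float16": it is read off the first KV cache tensor, so the kernels always match whatever dtype the pool was actually allocated with. A minimal sketch of the idea, using a stand-in list of cache tensors rather than sglang's real pool:

    import torch

    # Stand-in for token_to_kv_pool.kv_data: one K/V cache tensor per layer.
    kv_data = [torch.empty((16, 2, 8, 128), dtype=torch.bfloat16) for _ in range(4)]

    # Deriving the kernel dtype from the cache itself means the two
    # can no longer drift apart when the server runs in bf16 or fp32.
    data_type = kv_data[0].dtype
    assert data_type == torch.bfloat16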
@@ -287,10 +287,11 @@ class ModelRunner:
             tokenizer=None,
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
-            dtype=torch.float16,
+            dtype=self.server_args.dtype,
             seed=42,
             skip_tokenizer_init=True,
         )
+        self.dtype = vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
 
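ModelRunner now passes the user-facing dtype string into vLLM's ModelConfig and reads the resolved torch.dtype back out as self.dtype. A rough sketch of the resolution rule implied by the --dtype help text further down; the lookup table is an assumption patterned after vLLM's internals, not code from this diff:

    import torch

    # Assumed string-to-dtype table (vLLM keeps a similar private mapping).
    STR_TO_DTYPE = {
        "half": torch.float16,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float": torch.float32,
        "float32": torch.float32,
    }

    def resolve_dtype(dtype_arg: str, checkpoint_dtype: torch.dtype) -> torch.dtype:
        # "auto" keeps the checkpoint dtype, except that fp32 checkpoints
        # are downcast to fp16, matching the --dtype help text.
        if dtype_arg == "auto":
            return torch.float16 if checkpoint_dtype == torch.float32 else checkpoint_dtype
        return STR_TO_DTYPE[dtype_arg]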
@@ -307,6 +308,7 @@ class ModelRunner:
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight end. "
             f"type={type(self.model).__name__}, "
+            f"dtype={self.dtype}, "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
@@ -316,7 +318,7 @@ class ModelRunner:
         )
         head_dim = self.model_config.head_dim
         head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
+        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype)
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
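In the memory-profiling hunk above, the trailing hard-coded 2 (the byte size of an fp16 element) is replaced by the element size of the resolved dtype; the other factor of 2 accounts for storing both K and V per token. torch._utils._element_size is a private but long-standing torch helper, and a quick check shows the values it feeds into cell_size:

    import torch

    # cell_size = head_num * head_dim * layer_num * 2 (K and V) * bytes per element
    for dtype in (torch.float16, torch.bfloat16, torch.float32):
        print(dtype, torch._utils._element_size(dtype))
    # torch.float16 2, torch.bfloat16 2, torch.float32 4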
@@ -337,7 +339,7 @@ class ModelRunner:
         )
         self.token_to_kv_pool = TokenToKVPool(
             self.max_total_num_tokens,
-            dtype=torch.float16,
+            dtype=self.dtype,
             head_num=self.model_config.get_num_kv_heads(self.tp_size),
             head_dim=self.model_config.head_dim,
             layer_num=self.model_config.num_hidden_layers,
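With the hunk above, TokenToKVPool allocates its cache buffers in the configured dtype rather than always in fp16. A toy sketch of such a pool's allocation, simplified from what the real class does (it also tracks free slots and reference counts):

    import torch

    class TinyKVPool:
        """Illustrative only; not sglang's TokenToKVPool."""

        def __init__(self, size, dtype, head_num, head_dim, layer_num):
            # One [size, 2, head_num, head_dim] buffer per layer;
            # index 0 along dim 1 holds K, index 1 holds V.
            self.kv_data = [
                torch.empty((size, 2, head_num, head_dim), dtype=dtype)
                for _ in range(layer_num)
            ]

    pool = TinyKVPool(1024, torch.bfloat16, head_num=8, head_dim=128, layer_num=4)
    assert pool.kv_data[0].dtype == torch.bfloat16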
@@ -120,7 +120,7 @@ class ModelTpServer:
                 f"[gpu_id={self.gpu_id}] "
                 f"max_total_num_tokens={self.max_total_num_tokens}, "
                 f"max_prefill_tokens={self.max_prefill_tokens}, "
-                f"context_len={self.model_config.context_len}, "
+                f"context_len={self.model_config.context_len}"
             )
         if self.tp_rank == 0:
             logger.info(
@@ -11,12 +11,13 @@ class ServerArgs:
     # Model and tokenizer
     model_path: str
     tokenizer_path: Optional[str] = None
-    load_format: str = "auto"
     tokenizer_mode: str = "auto"
-    chat_template: Optional[str] = None
+    load_format: str = "auto"
+    dtype: str = "auto"
     trust_remote_code: bool = True
     context_length: Optional[int] = None
     quantization: Optional[str] = None
+    chat_template: Optional[str] = None
 
     # Port
     host: str = "127.0.0.1"
@@ -107,6 +108,15 @@ class ServerArgs:
             default=[],
             help="The additional ports specified for the server.",
         )
+        parser.add_argument(
+            "--tokenizer-mode",
+            type=str,
+            default=ServerArgs.tokenizer_mode,
+            choices=["auto", "slow"],
+            help="Tokenizer mode. 'auto' will use the fast "
+            "tokenizer if available, and 'slow' will "
+            "always use the slow tokenizer.",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -124,20 +134,20 @@ class ServerArgs:
             "which is mainly for profiling.",
         )
         parser.add_argument(
-            "--tokenizer-mode",
+            "--dtype",
             type=str,
-            default=ServerArgs.tokenizer_mode,
-            choices=["auto", "slow"],
-            help="Tokenizer mode. 'auto' will use the fast "
-            "tokenizer if available, and 'slow' will "
-            "always use the slow tokenizer.",
-        )
-        parser.add_argument(
-            "--chat-template",
-            type=str,
-            default=ServerArgs.chat_template,
-            help="The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server",
-        )
+            default=ServerArgs.dtype,
+            choices=[
+                "auto", "half", "float16", "bfloat16", "float", "float32"
+            ],
+            help='Data type for model weights and activations.\n\n'
+            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
+            'BF16 precision for BF16 models.\n'
+            '* "half" for FP16. Recommended for AWQ quantization.\n'
+            '* "float16" is the same as "half".\n'
+            '* "bfloat16" for a balance between precision and range.\n'
+            '* "float" is shorthand for FP32 precision.\n'
+            '* "float32" for FP32 precision.')
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
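With the dataclass field and the CLI flag above, the data type can be picked at launch, e.g. python -m sglang.launch_server --model-path <model> --dtype bfloat16, or set programmatically as in the sketch below (the model path is a placeholder):

    from sglang.srt.server_args import ServerArgs

    # dtype is now a first-class server argument; it defaults to "auto".
    args = ServerArgs(model_path="meta-llama/Llama-2-7b-chat-hf", dtype="bfloat16")
    print(args.dtype)  # bfloat16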
@@ -155,6 +165,12 @@ class ServerArgs:
             default=ServerArgs.quantization,
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--chat-template",
+            type=str,
+            default=ServerArgs.chat_template,
+            help="The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,