Expose dtype argument (#569)
This commit is contained in:
@@ -6,7 +6,7 @@ import logging
|
|||||||
import pkgutil
|
import pkgutil
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import List, Optional, Type, Any
|
from typing import List, Optional, Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -119,7 +119,7 @@ class InputMetadata:
|
|||||||
head_dim,
|
head_dim,
|
||||||
1,
|
1,
|
||||||
pos_encoding_mode="NONE",
|
pos_encoding_mode="NONE",
|
||||||
data_type="float16",
|
data_type=self.token_to_kv_pool.kv_data[0].dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_extend_args(self):
|
def init_extend_args(self):
|
||||||
@@ -287,10 +287,11 @@ class ModelRunner:
|
|||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
tokenizer_mode=None,
|
tokenizer_mode=None,
|
||||||
trust_remote_code=self.server_args.trust_remote_code,
|
trust_remote_code=self.server_args.trust_remote_code,
|
||||||
dtype=torch.float16,
|
dtype=self.server_args.dtype,
|
||||||
seed=42,
|
seed=42,
|
||||||
skip_tokenizer_init=True,
|
skip_tokenizer_init=True,
|
||||||
)
|
)
|
||||||
|
self.dtype = vllm_model_config.dtype
|
||||||
if self.model_config.model_overide_args is not None:
|
if self.model_config.model_overide_args is not None:
|
||||||
vllm_model_config.hf_config.update(self.model_config.model_overide_args)
|
vllm_model_config.hf_config.update(self.model_config.model_overide_args)
|
||||||
|
|
||||||
@@ -307,6 +308,7 @@ class ModelRunner:
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"[gpu_id={self.gpu_id}] Load weight end. "
|
f"[gpu_id={self.gpu_id}] Load weight end. "
|
||||||
f"type={type(self.model).__name__}, "
|
f"type={type(self.model).__name__}, "
|
||||||
|
f"dtype={self.dtype}, "
|
||||||
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -316,7 +318,7 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
head_dim = self.model_config.head_dim
|
head_dim = self.model_config.head_dim
|
||||||
head_num = self.model_config.get_num_kv_heads(self.tp_size)
|
head_num = self.model_config.get_num_kv_heads(self.tp_size)
|
||||||
cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
|
cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype)
|
||||||
rest_memory = available_gpu_memory - total_gpu_memory * (
|
rest_memory = available_gpu_memory - total_gpu_memory * (
|
||||||
1 - self.mem_fraction_static
|
1 - self.mem_fraction_static
|
||||||
)
|
)
|
||||||
@@ -337,7 +339,7 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
self.token_to_kv_pool = TokenToKVPool(
|
self.token_to_kv_pool = TokenToKVPool(
|
||||||
self.max_total_num_tokens,
|
self.max_total_num_tokens,
|
||||||
dtype=torch.float16,
|
dtype=self.dtype,
|
||||||
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
||||||
head_dim=self.model_config.head_dim,
|
head_dim=self.model_config.head_dim,
|
||||||
layer_num=self.model_config.num_hidden_layers,
|
layer_num=self.model_config.num_hidden_layers,
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ class ModelTpServer:
|
|||||||
f"[gpu_id={self.gpu_id}] "
|
f"[gpu_id={self.gpu_id}] "
|
||||||
f"max_total_num_tokens={self.max_total_num_tokens}, "
|
f"max_total_num_tokens={self.max_total_num_tokens}, "
|
||||||
f"max_prefill_tokens={self.max_prefill_tokens}, "
|
f"max_prefill_tokens={self.max_prefill_tokens}, "
|
||||||
f"context_len={self.model_config.context_len}, "
|
f"context_len={self.model_config.context_len}"
|
||||||
)
|
)
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -11,12 +11,13 @@ class ServerArgs:
|
|||||||
# Model and tokenizer
|
# Model and tokenizer
|
||||||
model_path: str
|
model_path: str
|
||||||
tokenizer_path: Optional[str] = None
|
tokenizer_path: Optional[str] = None
|
||||||
load_format: str = "auto"
|
|
||||||
tokenizer_mode: str = "auto"
|
tokenizer_mode: str = "auto"
|
||||||
chat_template: Optional[str] = None
|
load_format: str = "auto"
|
||||||
|
dtype: str = "auto"
|
||||||
trust_remote_code: bool = True
|
trust_remote_code: bool = True
|
||||||
context_length: Optional[int] = None
|
context_length: Optional[int] = None
|
||||||
quantization: Optional[str] = None
|
quantization: Optional[str] = None
|
||||||
|
chat_template: Optional[str] = None
|
||||||
|
|
||||||
# Port
|
# Port
|
||||||
host: str = "127.0.0.1"
|
host: str = "127.0.0.1"
|
||||||
@@ -107,6 +108,15 @@ class ServerArgs:
|
|||||||
default=[],
|
default=[],
|
||||||
help="The additional ports specified for the server.",
|
help="The additional ports specified for the server.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokenizer-mode",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.tokenizer_mode,
|
||||||
|
choices=["auto", "slow"],
|
||||||
|
help="Tokenizer mode. 'auto' will use the fast "
|
||||||
|
"tokenizer if available, and 'slow' will "
|
||||||
|
"always use the slow tokenizer.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--load-format",
|
"--load-format",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -124,20 +134,20 @@ class ServerArgs:
|
|||||||
"which is mainly for profiling.",
|
"which is mainly for profiling.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokenizer-mode",
|
"--dtype",
|
||||||
type=str,
|
type=str,
|
||||||
default=ServerArgs.tokenizer_mode,
|
default=ServerArgs.dtype,
|
||||||
choices=["auto", "slow"],
|
choices=[
|
||||||
help="Tokenizer mode. 'auto' will use the fast "
|
"auto", "half", "float16", "bfloat16", "float", "float32"
|
||||||
"tokenizer if available, and 'slow' will "
|
],
|
||||||
"always use the slow tokenizer.",
|
help='Data type for model weights and activations.\n\n'
|
||||||
)
|
'* "auto" will use FP16 precision for FP32 and FP16 models, and '
|
||||||
parser.add_argument(
|
'BF16 precision for BF16 models.\n'
|
||||||
"--chat-template",
|
'* "half" for FP16. Recommended for AWQ quantization.\n'
|
||||||
type=str,
|
'* "float16" is the same as "half".\n'
|
||||||
default=ServerArgs.chat_template,
|
'* "bfloat16" for a balance between precision and range.\n'
|
||||||
help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server",
|
'* "float" is shorthand for FP32 precision.\n'
|
||||||
)
|
'* "float32" for FP32 precision.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--trust-remote-code",
|
"--trust-remote-code",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -155,6 +165,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.quantization,
|
default=ServerArgs.quantization,
|
||||||
help="The quantization method.",
|
help="The quantization method.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--chat-template",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.chat_template,
|
||||||
|
help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--mem-fraction-static",
|
"--mem-fraction-static",
|
||||||
type=float,
|
type=float,
|
||||||
|
|||||||
Reference in New Issue
Block a user