"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""The arguments of the server."""

import argparse
import dataclasses
import logging
import random
from typing import List, Optional, Union

# Module-level logger, used for deprecation warnings and model-specific notes.
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class ServerArgs:
    """Arguments that configure the SGLang server.

    Instances are typically built from the command line with
    ``add_cli_args`` + ``from_cli_args``. ``__post_init__`` fills in the
    defaults that depend on other fields (tokenizer path, memory fraction,
    random seed, ...) and applies model-specific patches.
    """

    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    trust_remote_code: bool = True
    context_length: Optional[int] = None
    quantization: Optional[str] = None
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    is_embedding: bool = False

    # Port
    host: str = "127.0.0.1"
    port: int = 30000
    additional_ports: Optional[Union[List[int], int]] = None

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    # A value <= 0 (or None) disables chunked prefill; normalized to None
    # in __post_init__.
    chunked_prefill_size: Optional[int] = 8192
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False

    # Other
    api_key: Optional[str] = None
    file_storage_pth: str = "SGLang_storage"

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Distributed args
    nccl_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: Optional[int] = None

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # Optimization/debug options
    attention_backend: str = "flashinfer"
    sampling_backend: str = "flashinfer"
    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    enable_mixed_chunk: bool = False
    enable_torch_compile: bool = False
    torchao_config: str = ""
    enable_p2p_check: bool = False
    enable_mla: bool = False
    triton_attention_reduce_in_fp32: bool = False

    def __post_init__(self):
        """Fill in dependent defaults and apply model-specific patches."""
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        # Disable chunked prefill when requested via a non-positive size.
        if self.chunked_prefill_size is None or self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None

        # Mem fraction depends on the tensor parallelism size: larger TP
        # groups need more headroom for communication buffers.
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.83
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        # Normalize additional_ports to a list.
        if isinstance(self.additional_ports, int):
            self.additional_ports = [self.additional_ports]
        elif self.additional_ports is None:
            self.additional_ports = []

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # Deprecation warnings
        if self.disable_flashinfer:
            logger.warning(
                "The option '--disable-flashinfer' will be deprecated in the next release. "
                "Please use '--attention-backend triton' instead."
            )
        if self.disable_flashinfer_sampling:
            logger.warning(
                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
                "Please use '--sampling-backend pytorch' instead."
            )

        # Model-specific patches
        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
            logger.info(
                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
            )
            self.trust_remote_code = False

        if "gemma-2" in self.model_path.lower():
            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
            self.attention_backend = "flashinfer"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register all server arguments on the given argparse parser."""
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--additional-ports",
            type=int,
            nargs="*",
            default=[],
            help="The additional ports specified for the server.",
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip init tokenizer and pass input_ids in generate request",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "squeezellm",
                "bitsandbytes",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )

        # Data parallelism
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving args
        parser.add_argument(
            "--nccl-init-addr",
            type=str,
            help="The nccl init address of multi-node server.",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument("--node-rank", type=int, help="The node rank.")

        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )

        # Optimization/debug options
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action="store_true",
            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-regex-jump-forward",
            action="store_true",
            help="Disable regex jump-forward.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            default=False,
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile. Experimental feature.",
        )
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )
        parser.add_argument(
            "--enable-mla",
            action="store_true",
            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--efficient-weight-load",
            action="store_true",
            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a ServerArgs from a parsed argparse namespace.

        Copies the canonical `--tensor-parallel-size`/`--data-parallel-size`
        values onto the short field names, then keeps only the attributes
        that correspond to dataclass fields.
        """
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        """Return the base HTTP URL of the server."""
        return f"http://{self.host}:{self.port}"

    def check_server_args(self):
        """Validate cross-field constraints; raises AssertionError on violation."""
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.node_rank is not None
        ), "multi-node data parallel is not supported"
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    server_args = ServerArgs.from_cli_args(raw_args)
    return server_args
@dataclasses.dataclass
class PortArgs:
    """Ports used by the server's internal components."""

    # Port of the tokenizer process.
    tokenizer_port: int
    # Port of the controller process.
    controller_port: int
    # Port of the detokenizer process.
    detokenizer_port: int
    # Ports reserved for NCCL initialization (one per parallel group).
    nccl_ports: List[int]