2024-07-28 23:07:12 +10:00
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
2024-05-11 20:55:00 -07:00
""" The arguments of the server. """
2024-01-08 04:37:50 +00:00
import argparse
import dataclasses
2024-08-13 17:01:26 -07:00
import logging
2024-05-31 23:33:34 -07:00
import random
2024-10-04 00:45:52 -07:00
import tempfile
from typing import List , Optional
2024-01-08 04:37:50 +00:00
2024-10-06 22:54:05 -07:00
from sglang . srt . utils import is_flashinfer_available , is_ipv6 , is_port_available
2024-09-17 00:43:52 -07:00
2024-08-13 17:01:26 -07:00
logger = logging . getLogger ( __name__ )
2024-01-08 04:37:50 +00:00
@dataclasses.dataclass
class ServerArgs:
    """Configuration of the server.

    Every field has a matching command-line option registered by
    ``add_cli_args``; ``from_cli_args`` builds an instance from the parsed
    namespace. ``__post_init__`` fills in the values that are derived from
    other fields (tokenizer path, memory fraction, random seed, kernel
    backends) and applies a few model-specific overrides.
    """

    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    # NOTE(review): the dataclass default is True, but ``--trust-remote-code``
    # is registered as ``store_true`` in add_cli_args, so from_cli_args passes
    # False unless the flag is given. The True default only applies when
    # ServerArgs is constructed directly.
    trust_remote_code: bool = True
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization: Optional[str] = None
    context_length: Optional[int] = None
    device: str = "cuda"
    served_model_name: Optional[str] = None
    chat_template: Optional[str] = None
    is_embedding: bool = False

    # Port
    host: str = "127.0.0.1"
    port: int = 30000

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    # A value <= 0 is replaced by None in __post_init__ (chunked prefill disabled).
    chunked_prefill_size: Optional[int] = 8192
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None
    constrained_json_whitespace_pattern: Optional[str] = None
    watchdog_timeout: float = 300

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False
    enable_metrics: bool = False
    decode_log_interval: int = 40

    # API related
    api_key: Optional[str] = None
    file_storage_pth: str = "SGLang_storage"
    enable_cache_report: bool = False

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Multi-node distributed serving
    dist_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: int = 0

    # Model override args in JSON
    json_model_override_args: str = "{}"

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8

    # Kernel backend
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
    grammar_backend: Optional[str] = "outlines"

    # Optimization/debug options
    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    disable_radix_cache: bool = False
    disable_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
    disable_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_mla: bool = False
    disable_penalizer: bool = False
    disable_nan_detection: bool = False
    enable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: int = 160
    torchao_config: str = ""
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    num_continuous_decode_steps: int = 1
    delete_ckpt_after_loading: bool = False

    def __post_init__(self):
        """Fill in derived defaults and apply deprecation/model-specific overrides."""
        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.served_model_name is None:
            self.served_model_name = self.model_path

        if self.chunked_prefill_size <= 0:
            # Disable chunked prefill
            self.chunked_prefill_size = None

        # Mem fraction depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.83
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.85
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.87
            else:
                self.mem_fraction_static = 0.88

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

        # Deprecation warnings: the legacy flags are mapped onto the backend options.
        if self.disable_flashinfer:
            logger.warning(
                "The option '--disable-flashinfer' will be deprecated in the next release. "
                "Please use '--attention-backend triton' instead."
            )
            self.attention_backend = "triton"
        if self.disable_flashinfer_sampling:
            logger.warning(
                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
                "Please use '--sampling-backend pytorch' instead."
            )
            self.sampling_backend = "pytorch"

        # Without flashinfer, both backends are forced to the fallbacks,
        # overriding whatever was requested.
        if not is_flashinfer_available():
            self.attention_backend = "triton"
            self.sampling_backend = "pytorch"

        # Default kernel backends
        if self.attention_backend is None:
            self.attention_backend = "flashinfer"
        if self.sampling_backend is None:
            self.sampling_backend = "flashinfer"

        if self.enable_overlap_schedule:
            logger.warning(
                "Overlap scheduler mode is enabled. This is an experimental feature. "
                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
                "and embedding APIs are not supported and will lead to wrong results. "
                "The NaN detection is also disabled."
            )
            self.disable_penalizer = True
            self.disable_nan_detection = True

        # Model-specific patches
        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
            logger.info(
                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
            )
            self.trust_remote_code = False

        if "gemma-2" in self.model_path.lower():
            # NOTE(review): this runs after the availability fallback above, so
            # "flashinfer" is selected even when is_flashinfer_available()
            # returned False or the user asked for triton.
            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
            self.attention_backend = "flashinfer"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register one command-line option per ``ServerArgs`` field on ``parser``.

        Defaults are read from the dataclass fields so the CLI and the
        dataclass cannot drift apart.
        """
        # Model and port args
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="If set, skip init tokenizer and pass input_ids in generate request",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,
            default=ServerArgs.kv_cache_dtype,
            choices=["auto", "fp8_e5m2"],
            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "awq_marlin",
                "bitsandbytes",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--device",
            type=str,
            default=ServerArgs.device,
            choices=["cuda", "xpu"],
            help="The device type.",
        )
        parser.add_argument(
            "--served-model-name",
            type=str,
            default=ServerArgs.served_model_name,
            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--is-embedding",
            action="store_true",
            help="Whether to use a CausalLM as an embedding model.",
        )

        # Memory and scheduling
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
            "This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=ServerArgs.chunked_prefill_size,
            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--schedule-policy",
            type=str,
            default=ServerArgs.schedule_policy,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling policy of the requests.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )

        # Other runtime options
        # argparse derives the destination from the first long option, so the
        # value lands in ``tensor_parallel_size``; from_cli_args copies it to
        # ``tp_size``.
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--constrained-json-whitespace-pattern",
            type=str,
            default=ServerArgs.constrained_json_whitespace_pattern,
            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
        )
        parser.add_argument(
            "--watchdog-timeout",
            type=float,
            default=ServerArgs.watchdog_timeout,
            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
        )

        # Logging
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--enable-metrics",
            action="store_true",
            help="Enable log prometheus metrics.",
        )
        parser.add_argument(
            "--decode-log-interval",
            type=int,
            default=ServerArgs.decode_log_interval,
            help="The log interval of decode batch",
        )

        # API related
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )
        parser.add_argument(
            "--enable-cache-report",
            action="store_true",
            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )

        # Data parallelism
        # Stored as ``data_parallel_size``; from_cli_args copies it to ``dp_size``.
        parser.add_argument(
            "--data-parallel-size",
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving
        parser.add_argument(
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
            type=str,
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
        )
        parser.add_argument(
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument(
            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
        )

        # Model override args
        parser.add_argument(
            "--json-model-override-args",
            type=str,
            help="A dictionary in JSON string format used to override default model configurations.",
            default=ServerArgs.json_model_override_args,
        )

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The threshold for sparse decoding in double sparsity attention",
        )

        # LoRA
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=ServerArgs.max_loras_per_batch,
            help="Maximum number of adapters for a running batch, include base-only request",
        )

        # Kernel backend
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
            choices=["flashinfer", "pytorch"],
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
        parser.add_argument(
            "--grammar-backend",
            type=str,
            choices=["xgrammar", "outlines"],
            default=ServerArgs.grammar_backend,
            help="Choose the backend for grammar-guided decoding.",
        )

        # Optimization/debug options
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action="store_true",
            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-jump-forward",
            action="store_true",
            help="Disable jump-forward for grammar-guided decoding.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-cuda-graph-padding",
            action="store_true",
            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--disable-mla",
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
        parser.add_argument(
            "--disable-penalizer",
            action="store_true",
            help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
        )
        parser.add_argument(
            "--disable-nan-detection",
            action="store_true",
            help="Disable the NaN detection for better performance.",
        )
        parser.add_argument(
            "--enable-overlap-schedule",
            action="store_true",
            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile. Experimental feature.",
        )
        parser.add_argument(
            "--torch-compile-max-bs",
            type=int,
            default=ServerArgs.torch_compile_max_bs,
            help="Set the maximum batch size when using torch compile.",
        )
        parser.add_argument(
            "--cuda-graph-max-bs",
            type=int,
            default=ServerArgs.cuda_graph_max_bs,
            help="Set the maximum batch size for cuda graph.",
        )
        parser.add_argument(
            "--torchao-config",
            type=str,
            default=ServerArgs.torchao_config,
            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16. "
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--num-continuous-decode-steps",
            type=int,
            default=ServerArgs.num_continuous_decode_steps,
            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
        parser.add_argument(
            "--delete-ckpt-after-loading",
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a ``ServerArgs`` from a namespace produced by ``add_cli_args``.

        The namespace is modified in place: ``tp_size`` and ``dp_size`` are
        copied from the long-option destinations before the fields are read.
        """
        args.tp_size = args.tensor_parallel_size
        args.dp_size = args.data_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        """Return the HTTP base URL of the server, bracketing IPv6 hosts."""
        if is_ipv6(self.host):
            return f"http://[{self.host}]:{self.port}"
        else:
            return f"http://{self.host}:{self.port}"

    def check_server_args(self):
        """Validate option combinations and normalize ``lora_paths`` to a dict.

        Raises:
            AssertionError: if ``tp_size`` is not divisible by ``nnodes``, if
                data parallelism is combined with multiple nodes, or if LoRA is
                requested without disabling CUDA graph and the radix cache.
        """
        assert (
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
            self.dp_size > 1 and self.nnodes != 1
        ), "multi-node data parallel is not supported"
        assert (
            self.max_loras_per_batch > 0
            # FIXME
            and (self.lora_paths is None or self.disable_cuda_graph)
            and (self.lora_paths is None or self.disable_radix_cache)
        ), "compatibility of lora and cuda graph and radix attention is in progress"

        # A list of "name=path" / "path" entries becomes a name -> path
        # mapping; a bare path is used as its own name.
        if isinstance(self.lora_paths, list):
            lora_paths = self.lora_paths
            self.lora_paths = {}
            for lora_path in lora_paths:
                if "=" in lora_path:
                    name, path = lora_path.split("=", 1)
                    self.lora_paths[name] = path
                else:
                    self.lora_paths[lora_path] = lora_path
2024-01-08 04:37:50 +00:00
2024-09-09 04:14:11 -07:00
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Build the server arguments from command-line arguments.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
    """
    cli_parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(cli_parser)
    return ServerArgs.from_cli_args(cli_parser.parse_args(argv))
2024-01-08 04:37:50 +00:00
@dataclasses.dataclass
class PortArgs:
    """IPC file names and the NCCL port used by the server processes."""

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    @staticmethod
    def _new_ipc_filename() -> str:
        """Create a persistent temporary file and return its path."""
        # Close the handle explicitly instead of relying on garbage collection;
        # delete=False keeps the file on disk after it is closed.
        with tempfile.NamedTemporaryFile(delete=False) as ipc_file:
            return ipc_file.name

    @staticmethod
    def init_new(server_args) -> "PortArgs":
        """Allocate fresh IPC file names and pick an available NCCL port.

        The port search starts at ``server_args.port + 42`` and advances in
        steps of 42 until ``is_port_available`` accepts a candidate.
        """
        port = server_args.port + 42
        while not is_port_available(port):
            port += 42

        return PortArgs(
            tokenizer_ipc_name=PortArgs._new_ipc_filename(),
            scheduler_input_ipc_name=PortArgs._new_ipc_filename(),
            detokenizer_ipc_name=PortArgs._new_ipc_filename(),
            nccl_port=port,
        )
2024-09-18 00:56:06 -07:00
class LoRAPathAction(argparse.Action):
    """argparse action that stores LoRA adapter arguments as a name -> path dict.

    Each value is either ``name=path`` (split on the first ``=``) or a bare
    path, which is then used as its own name.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        adapters = {}
        # Publish the (initially empty) mapping first, then fill it in place.
        setattr(namespace, self.dest, adapters)
        for entry in values:
            name, separator, path = entry.partition("=")
            adapters[name] = path if separator else entry