# Source file: sglang/python/sglang/srt/server_args.py
# (header restored from a git-blame web view; blame timestamps removed)
"""The arguments of the server."""
import argparse
import dataclasses
import random
from typing import List, Optional, Union
@dataclasses.dataclass
class ServerArgs:
    """All configurable arguments of the SRT server.

    Instances are normally built either directly (programmatic use) or via
    ``add_cli_args`` + ``from_cli_args`` (command-line use). ``__post_init__``
    fills in the defaults that depend on other fields.
    """

    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None  # falls back to model_path in __post_init__
    tokenizer_mode: str = "auto"
    load_format: str = "auto"
    dtype: str = "auto"
    trust_remote_code: bool = True
    context_length: Optional[int] = None  # None -> read from the model's config.json
    quantization: Optional[str] = None
    chat_template: Optional[str] = None

    # Port
    host: str = "127.0.0.1"
    port: int = 30000
    additional_ports: Optional[Union[List[int], int]] = None  # normalized to a list

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None  # None -> derived from tp_size
    max_prefill_tokens: Optional[int] = None
    max_running_requests: Optional[int] = None
    max_num_reqs: Optional[int] = None
    schedule_heuristic: str = "lpm"
    schedule_conservativeness: float = 1.0

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None  # None -> randomly drawn in __post_init__

    # Logging
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    show_time_cost: bool = False

    # Other
    api_key: str = ""

    # Data parallelism
    dp_size: int = 1
    load_balance_method: str = "round_robin"

    # Optimization/debug options
    disable_flashinfer: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
    disable_cuda_graph: bool = False
    disable_disk_cache: bool = False
    enable_torch_compile: bool = False
    attention_reduce_in_fp32: bool = False
    enable_p2p_check: bool = False
    efficient_weight_load: bool = False

    # Distributed args
    nccl_init_addr: Optional[str] = None
    nnodes: int = 1
    node_rank: Optional[int] = None

    def __post_init__(self):
        """Fill in derived defaults that depend on other fields."""
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.mem_fraction_static is None:
            # Larger tensor-parallel sizes reserve a smaller static fraction,
            # leaving more headroom for communication buffers, etc.
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.80
            elif self.tp_size >= 8:
                self.mem_fraction_static = 0.84
            elif self.tp_size >= 4:
                self.mem_fraction_static = 0.86
            elif self.tp_size >= 2:
                self.mem_fraction_static = 0.88
            else:
                self.mem_fraction_static = 0.89

        # Normalize additional_ports to a list.
        if isinstance(self.additional_ports, int):
            self.additional_ports = [self.additional_ports]
        elif self.additional_ports is None:
            self.additional_ports = []

        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register every ServerArgs field as a CLI flag on `parser`.

        Defaults are taken from the dataclass field defaults so the two
        stay in sync (boolean flags use argparse's `store_true` default).
        """
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
            required=True,
        )
        parser.add_argument(
            "--tokenizer-path",
            type=str,
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--additional-ports",
            type=int,
            nargs="*",
            default=[],
            help="The additional ports specified for the server.",
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=ServerArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help="Tokenizer mode. 'auto' will use the fast "
            "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--load-format",
            type=str,
            default=ServerArgs.load_format,
            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
            help="The format of the model weights to load. "
            '"auto" will try to load the weights in the safetensors format '
            "and fall back to the pytorch bin format if safetensors format "
            "is not available. "
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            "a numpy cache to speed up the loading. "
            '"dummy" will initialize the weights with random values, '
            "which is mainly for profiling.",
        )
        parser.add_argument(
            "--dtype",
            type=str,
            default=ServerArgs.dtype,
            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
            help="Data type for model weights and activations.\n\n"
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            "BF16 precision for BF16 models.\n"
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.',
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--quantization",
            type=str,
            default=ServerArgs.quantization,
            choices=[
                "awq",
                "fp8",
                "gptq",
                "marlin",
                "gptq_marlin",
                "squeezellm",
                "bitsandbytes",
            ],
            help="The quantization method.",
        )
        parser.add_argument(
            "--chat-template",
            type=str,
            default=ServerArgs.chat_template,
            # Typo fix: "buliltin" -> "builtin".
            help="The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
            default=ServerArgs.mem_fraction_static,
            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
        )
        parser.add_argument(
            "--max-prefill-tokens",
            type=int,
            default=ServerArgs.max_prefill_tokens,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--max-running-requests",
            type=int,
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
        parser.add_argument(
            "--max-num-reqs",
            type=int,
            # Use the dataclass default for consistency with the other flags.
            default=ServerArgs.max_num_reqs,
            help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
        )
        parser.add_argument(
            "--schedule-heuristic",
            type=str,
            default=ServerArgs.schedule_heuristic,
            choices=["lpm", "random", "fcfs", "dfs-weight"],
            help="The scheduling heuristic.",
        )
        parser.add_argument(
            "--schedule-conservativeness",
            type=float,
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        parser.add_argument(
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="The logging level of all loggers.",
        )
        parser.add_argument(
            "--log-level-http",
            type=str,
            default=ServerArgs.log_level_http,
            help="The logging level of HTTP server. If not set, reuse --log-level by default.",
        )
        parser.add_argument(
            "--log-requests",
            action="store_true",
            help="Log the inputs and outputs of all requests.",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks.",
        )
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server.",
        )

        # Data parallelism
        parser.add_argument(
            "--dp-size",
            type=int,
            default=ServerArgs.dp_size,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--load-balance-method",
            type=str,
            default=ServerArgs.load_balance_method,
            help="The load balancing strategy for data parallelism.",
            choices=[
                "round_robin",
                "shortest_queue",
            ],
        )

        # Multi-node distributed serving args
        parser.add_argument(
            "--nccl-init-addr",
            type=str,
            help="The nccl init address of multi-node server.",
        )
        parser.add_argument(
            # Use the dataclass default for consistency with the other flags.
            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
        parser.add_argument("--node-rank", type=int, help="The node rank.")

        # Optimization/debug options
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
            help="Disable flashinfer inference kernels.",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention for prefix caching.",
        )
        parser.add_argument(
            "--disable-regex-jump-forward",
            action="store_true",
            help="Disable regex jump-forward.",
        )
        parser.add_argument(
            "--disable-cuda-graph",
            action="store_true",
            help="Disable cuda graph.",
        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
            help="Optimize the model with torch.compile, experimental feature.",
        )
        parser.add_argument(
            "--attention-reduce-in-fp32",
            action="store_true",
            # Typo fix: "intermidiate" -> "intermediate".
            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16."
            "This only affects Triton attention kernels",
        )
        parser.add_argument(
            "--enable-p2p-check",
            action="store_true",
            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
        )
        parser.add_argument(
            "--efficient-weight-load",
            action="store_true",
            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a ServerArgs from a parsed argparse namespace.

        Reads exactly the attributes named after the dataclass fields, so
        extra namespace entries are ignored and missing ones raise.
        """
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

    def url(self):
        """Return the base HTTP URL of the server."""
        return f"http://{self.host}:{self.port}"

    def print_mode_args(self):
        """Return a one-line summary of the optimization/debug flags."""
        return (
            f"disable_flashinfer={self.disable_flashinfer}, "
            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
            f"disable_radix_cache={self.disable_radix_cache}, "
            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
            f"disable_disk_cache={self.disable_disk_cache}, "
        )
@dataclasses.dataclass
class PortArgs:
    """Ports used by the server's internal processes to talk to each other."""

    # Port for the tokenizer manager process.
    tokenizer_port: int
    # Port for the controller process.
    controller_port: int
    # Port for the detokenizer process.
    detokenizer_port: int
    # Ports reserved for NCCL initialization (one per parallel group).
    nccl_ports: List[int]