# Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """The arguments of the server.""" import argparse import dataclasses import logging import random import tempfile from typing import List, Optional from sglang.srt.utils import ( get_amdgpu_memory_capacity, get_nvgpu_memory_capacity, is_flashinfer_available, is_hip, is_ipv6, is_port_available, ) logger = logging.getLogger(__name__) @dataclasses.dataclass class ServerArgs: # Model and tokenizer model_path: str tokenizer_path: Optional[str] = None tokenizer_mode: str = "auto" skip_tokenizer_init: bool = False load_format: str = "auto" trust_remote_code: bool = True dtype: str = "auto" kv_cache_dtype: str = "auto" quantization: Optional[str] = None context_length: Optional[int] = None device: str = "cuda" served_model_name: Optional[str] = None chat_template: Optional[str] = None is_embedding: bool = False # Port host: str = "127.0.0.1" port: int = 30000 # Memory and scheduling mem_fraction_static: Optional[float] = None max_running_requests: Optional[int] = None max_total_tokens: Optional[int] = None chunked_prefill_size: int = 8192 max_prefill_tokens: int = 16384 schedule_policy: str = "lpm" schedule_conservativeness: float = 1.0 cpu_offload_gb: int = 0 # Other runtime options tp_size: int = 1 stream_interval: int = 1 random_seed: Optional[int] = None constrained_json_whitespace_pattern: Optional[str] = None watchdog_timeout: float = 300 download_dir: Optional[str] = None base_gpu_id: int = 0 # Logging log_level: str = "info" log_level_http: Optional[str] = None log_requests: bool = False show_time_cost: bool = False enable_metrics: bool = False decode_log_interval: int = 40 # API related api_key: Optional[str] = None file_storage_pth: str = "SGLang_storage" enable_cache_report: bool = False # Data parallelism dp_size: int = 1 load_balance_method: str = "round_robin" # Multi-node distributed serving dist_init_addr: Optional[str] = None nnodes: int = 1 node_rank: int = 0 # Model override args in JSON json_model_override_args: str = "{}" # Double Sparsity enable_double_sparsity: bool = False ds_channel_config_path: str = None ds_heavy_channel_num: int = 32 ds_heavy_token_num: int = 256 ds_heavy_channel_type: str = "qk" ds_sparse_decode_threshold: int = 4096 # LoRA lora_paths: Optional[List[str]] = None max_loras_per_batch: int = 8 # Kernel backend attention_backend: Optional[str] = None sampling_backend: Optional[str] = None grammar_backend: Optional[str] = "outlines" # Optimization/debug options disable_radix_cache: bool = False disable_jump_forward: bool = False disable_cuda_graph: bool = False disable_cuda_graph_padding: bool = False disable_disk_cache: bool = False disable_custom_all_reduce: bool = False disable_mla: bool = False disable_overlap_schedule: bool = False enable_mixed_chunk: bool = False enable_dp_attention: bool = False enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: int = 160 torchao_config: str = "" enable_nan_detection: bool = False enable_p2p_check: bool = False triton_attention_reduce_in_fp32: bool = False num_continuous_decode_steps: int = 1 delete_ckpt_after_loading: bool = False def __post_init__(self): # Set missing default values if self.tokenizer_path is None: self.tokenizer_path = self.model_path if self.served_model_name is None: self.served_model_name = self.model_path if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0: # Disable chunked prefill self.chunked_prefill_size = None if self.random_seed is None: self.random_seed = random.randint(0, 1 << 30) # Mem fraction depends on the tensor parallelism size if self.mem_fraction_static is None: if self.tp_size >= 16: self.mem_fraction_static = 0.79 elif self.tp_size >= 8: self.mem_fraction_static = 0.82 elif self.tp_size >= 4: self.mem_fraction_static = 0.85 elif self.tp_size >= 2: self.mem_fraction_static = 0.87 else: self.mem_fraction_static = 0.88 # Adjust for GPUs with small memory capacities if is_hip(): gpu_mem = get_amdgpu_memory_capacity() else: gpu_mem = get_nvgpu_memory_capacity() if gpu_mem < 25000: self.chunked_prefill_size //= 4 # make it 2048 self.cuda_graph_max_bs = 4 logger.info("Automatically adjust --chunked-prefill-size for small GPUs.") # Choose kernel backends if not is_flashinfer_available(): self.attention_backend = "triton" self.sampling_backend = "pytorch" if self.attention_backend is None: self.attention_backend = "flashinfer" if self.sampling_backend is None: self.sampling_backend = "flashinfer" # Others if self.enable_dp_attention: self.dp_size = self.tp_size self.chunked_prefill_size = self.chunked_prefill_size // 2 self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96) self.schedule_conservativeness = self.schedule_conservativeness * 0.3 self.disable_overlap_schedule = True logger.info( f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. " f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. " f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. " "Data parallel size is adjusted to be the same as tensor parallel size. " "Overlap schedule is disabled." ) @staticmethod def add_cli_args(parser: argparse.ArgumentParser): # Model and port args parser.add_argument( "--model-path", type=str, help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.", required=True, ) parser.add_argument( "--tokenizer-path", type=str, default=ServerArgs.tokenizer_path, help="The path of the tokenizer.", ) parser.add_argument( "--host", type=str, default=ServerArgs.host, help="The host of the server." ) parser.add_argument( "--port", type=int, default=ServerArgs.port, help="The port of the server." ) parser.add_argument( "--tokenizer-mode", type=str, default=ServerArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " "tokenizer if available, and 'slow' will " "always use the slow tokenizer.", ) parser.add_argument( "--skip-tokenizer-init", action="store_true", help="If set, skip init tokenizer and pass input_ids in generate request", ) parser.add_argument( "--load-format", type=str, default=ServerArgs.load_format, choices=["auto", "pt", "safetensors", "npcache", "dummy"], help="The format of the model weights to load. " '"auto" will try to load the weights in the safetensors format ' "and fall back to the pytorch bin format if safetensors format " "is not available. " '"pt" will load the weights in the pytorch bin format. ' '"safetensors" will load the weights in the safetensors format. ' '"npcache" will load the weights in pytorch format and store ' "a numpy cache to speed up the loading. " '"dummy" will initialize the weights with random values, ' "which is mainly for profiling.", ) parser.add_argument( "--trust-remote-code", action="store_true", help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", ) parser.add_argument( "--dtype", type=str, default=ServerArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" '* "auto" will use FP16 precision for FP32 and FP16 models, and ' "BF16 precision for BF16 models.\n" '* "half" for FP16. Recommended for AWQ quantization.\n' '* "float16" is the same as "half".\n' '* "bfloat16" for a balance between precision and range.\n' '* "float" is shorthand for FP32 precision.\n' '* "float32" for FP32 precision.', ) parser.add_argument( "--kv-cache-dtype", type=str, default=ServerArgs.kv_cache_dtype, choices=["auto", "fp8_e5m2"], help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.', ) parser.add_argument( "--quantization", type=str, default=ServerArgs.quantization, choices=[ "awq", "fp8", "gptq", "marlin", "gptq_marlin", "awq_marlin", "bitsandbytes", ], help="The quantization method.", ) parser.add_argument( "--context-length", type=int, default=ServerArgs.context_length, help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).", ) parser.add_argument( "--device", type=str, default="cuda", choices=["cuda", "xpu", "hpu"], help="The device type.", ) parser.add_argument( "--served-model-name", type=str, default=ServerArgs.served_model_name, help="Override the model name returned by the v1/models endpoint in OpenAI API server.", ) parser.add_argument( "--chat-template", type=str, default=ServerArgs.chat_template, help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.", ) parser.add_argument( "--is-embedding", action="store_true", help="Whether to use a CausalLM as an embedding model.", ) # Memory and scheduling parser.add_argument( "--mem-fraction-static", type=float, default=ServerArgs.mem_fraction_static, help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.", ) parser.add_argument( "--max-running-requests", type=int, default=ServerArgs.max_running_requests, help="The maximum number of running requests.", ) parser.add_argument( "--max-total-tokens", type=int, default=ServerArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. " "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", type=int, default=ServerArgs.chunked_prefill_size, help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill", ) parser.add_argument( "--max-prefill-tokens", type=int, default=ServerArgs.max_prefill_tokens, help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.", ) parser.add_argument( "--schedule-policy", type=str, default=ServerArgs.schedule_policy, choices=["lpm", "random", "fcfs", "dfs-weight"], help="The scheduling policy of the requests.", ) parser.add_argument( "--schedule-conservativeness", type=float, default=ServerArgs.schedule_conservativeness, help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.", ) parser.add_argument( "--cpu-offload-gb", type=int, default=ServerArgs.cpu_offload_gb, help="How many GBs of RAM to reserve for CPU offloading", ) # Other runtime options parser.add_argument( "--tensor-parallel-size", "--tp-size", type=int, default=ServerArgs.tp_size, help="The tensor parallelism size.", ) parser.add_argument( "--stream-interval", type=int, default=ServerArgs.stream_interval, help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher", ) parser.add_argument( "--random-seed", type=int, default=ServerArgs.random_seed, help="The random seed.", ) parser.add_argument( "--constrained-json-whitespace-pattern", type=str, default=ServerArgs.constrained_json_whitespace_pattern, help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*", ) parser.add_argument( "--watchdog-timeout", type=float, default=ServerArgs.watchdog_timeout, help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.", ) parser.add_argument( "--download-dir", type=str, default=ServerArgs.download_dir, help="Model download directory.", ) parser.add_argument( "--base-gpu-id", type=int, default=ServerArgs.base_gpu_id, help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.", ) # Logging parser.add_argument( "--log-level", type=str, default=ServerArgs.log_level, help="The logging level of all loggers.", ) parser.add_argument( "--log-level-http", type=str, default=ServerArgs.log_level_http, help="The logging level of HTTP server. If not set, reuse --log-level by default.", ) parser.add_argument( "--log-requests", action="store_true", help="Log the inputs and outputs of all requests.", ) parser.add_argument( "--show-time-cost", action="store_true", help="Show time cost of custom marks.", ) parser.add_argument( "--enable-metrics", action="store_true", help="Enable log prometheus metrics.", ) parser.add_argument( "--decode-log-interval", type=int, default=ServerArgs.decode_log_interval, help="The log interval of decode batch", ) # API related parser.add_argument( "--api-key", type=str, default=ServerArgs.api_key, help="Set API key of the server. It is also used in the OpenAI API compatible server.", ) parser.add_argument( "--file-storage-pth", type=str, default=ServerArgs.file_storage_pth, help="The path of the file storage in backend.", ) parser.add_argument( "--enable-cache-report", action="store_true", help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.", ) # Data parallelism parser.add_argument( "--data-parallel-size", "--dp-size", type=int, default=ServerArgs.dp_size, help="The data parallelism size.", ) parser.add_argument( "--load-balance-method", type=str, default=ServerArgs.load_balance_method, help="The load balancing strategy for data parallelism.", choices=[ "round_robin", "shortest_queue", ], ) # Multi-node distributed serving parser.add_argument( "--dist-init-addr", "--nccl-init-addr", # For backward compatbility. This will be removed in the future. type=str, help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).", ) parser.add_argument( "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes." ) parser.add_argument( "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank." ) # Model override args parser.add_argument( "--json-model-override-args", type=str, help="A dictionary in JSON string format used to override default model configurations.", default=ServerArgs.json_model_override_args, ) # Double Sparsity parser.add_argument( "--enable-double-sparsity", action="store_true", help="Enable double sparsity attention", ) parser.add_argument( "--ds-channel-config-path", type=str, default=ServerArgs.ds_channel_config_path, help="The path of the double sparsity channel config", ) parser.add_argument( "--ds-heavy-channel-num", type=int, default=ServerArgs.ds_heavy_channel_num, help="The number of heavy channels in double sparsity attention", ) parser.add_argument( "--ds-heavy-token-num", type=int, default=ServerArgs.ds_heavy_token_num, help="The number of heavy tokens in double sparsity attention", ) parser.add_argument( "--ds-heavy-channel-type", type=str, default=ServerArgs.ds_heavy_channel_type, help="The type of heavy channels in double sparsity attention", ) parser.add_argument( "--ds-sparse-decode-threshold", type=int, default=ServerArgs.ds_sparse_decode_threshold, help="The type of heavy channels in double sparsity attention", ) # LoRA parser.add_argument( "--lora-paths", type=str, nargs="*", default=None, action=LoRAPathAction, help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}", ) parser.add_argument( "--max-loras-per-batch", type=int, default=8, help="Maximum number of adapters for a running batch, include base-only request", ) # Kernel backend parser.add_argument( "--attention-backend", type=str, choices=["flashinfer", "triton"], default=ServerArgs.attention_backend, help="Choose the kernels for attention layers.", ) parser.add_argument( "--sampling-backend", type=str, choices=["flashinfer", "pytorch"], default=ServerArgs.sampling_backend, help="Choose the kernels for sampling layers.", ) parser.add_argument( "--grammar-backend", type=str, choices=["xgrammar", "outlines"], default=ServerArgs.grammar_backend, help="Choose the backend for grammar-guided decoding.", ) # Optimization/debug options parser.add_argument( "--disable-radix-cache", action="store_true", help="Disable RadixAttention for prefix caching.", ) parser.add_argument( "--disable-jump-forward", action="store_true", help="Disable jump-forward for grammar-guided decoding.", ) parser.add_argument( "--disable-cuda-graph", action="store_true", help="Disable cuda graph.", ) parser.add_argument( "--disable-cuda-graph-padding", action="store_true", help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.", ) parser.add_argument( "--disable-disk-cache", action="store_true", help="Disable disk cache to avoid possible crashes related to file system or high concurrency.", ) parser.add_argument( "--disable-custom-all-reduce", action="store_true", help="Disable the custom all-reduce kernel and fall back to NCCL.", ) parser.add_argument( "--disable-mla", action="store_true", help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.", ) parser.add_argument( "--disable-nan-detection", action="store_true", help="Disable the NaN detection for better performance.", ) parser.add_argument( "--disable-overlap-schedule", action="store_true", help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.", ) parser.add_argument( "--enable-mixed-chunk", action="store_true", help="Enabling mixing prefill and decode in a batch when using chunked prefill.", ) parser.add_argument( "--enable-dp-attention", action="store_true", help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.", ) parser.add_argument( "--enable-torch-compile", action="store_true", help="Optimize the model with torch.compile. Experimental feature.", ) parser.add_argument( "--torch-compile-max-bs", type=int, default=ServerArgs.torch_compile_max_bs, help="Set the maximum batch size when using torch compile.", ) parser.add_argument( "--cuda-graph-max-bs", type=int, default=ServerArgs.cuda_graph_max_bs, help="Set the maximum batch size for cuda graph.", ) parser.add_argument( "--torchao-config", type=str, default=ServerArgs.torchao_config, help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row", ) parser.add_argument( "--enable-nan-detection", action="store_true", help="Enable the NaN detection for debugging purposes.", ) parser.add_argument( "--enable-p2p-check", action="store_true", help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.", ) parser.add_argument( "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." "This only affects Triton attention kernels.", ) parser.add_argument( "--num-continuous-decode-steps", type=int, default=ServerArgs.num_continuous_decode_steps, help="Run multiple continuous decoding steps to reduce scheduling overhead. " "This can potentially increase throughput but may also increase time-to-first-token latency. " "The default value is 1, meaning only run one decoding step at a time.", ) parser.add_argument( "--delete-ckpt-after-loading", action="store_true", help="Delete the model checkpoint after loading the model.", ) # Deprecated arguments parser.add_argument( "--enable-overlap-schedule", action=DeprecatedAction, help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.", ) parser.add_argument( "--disable-flashinfer", action=DeprecatedAction, help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.", ) parser.add_argument( "--disable-flashinfer-sampling", action=DeprecatedAction, help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.", ) @classmethod def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size args.dp_size = args.data_parallel_size attrs = [attr.name for attr in dataclasses.fields(cls)] return cls(**{attr: getattr(args, attr) for attr in attrs}) def url(self): if is_ipv6(self.host): return f"http://[{self.host}]:{self.port}" else: return f"http://{self.host}:{self.port}" def check_server_args(self): assert ( self.tp_size % self.nnodes == 0 ), "tp_size must be divisible by number of nodes" assert not ( self.dp_size > 1 and self.nnodes != 1 ), "multi-node data parallel is not supported" assert ( self.max_loras_per_batch > 0 # FIXME and (self.lora_paths is None or self.disable_cuda_graph) and (self.lora_paths is None or self.disable_radix_cache) ), "compatibility of lora and cuda graph and radix attention is in progress" assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative" if isinstance(self.lora_paths, list): lora_paths = self.lora_paths self.lora_paths = {} for lora_path in lora_paths: if "=" in lora_path: name, path = lora_path.split("=", 1) self.lora_paths[name] = path else: self.lora_paths[lora_path] = lora_path def prepare_server_args(argv: List[str]) -> ServerArgs: """ Prepare the server arguments from the command line arguments. Args: args: The command line arguments. Typically, it should be `sys.argv[1:]` to ensure compatibility with `parse_args` when no arguments are passed. Returns: The server arguments. """ parser = argparse.ArgumentParser() ServerArgs.add_cli_args(parser) raw_args = parser.parse_args(argv) server_args = ServerArgs.from_cli_args(raw_args) return server_args @dataclasses.dataclass class PortArgs: # The ipc filename for tokenizer to receive inputs from detokenizer (zmq) tokenizer_ipc_name: str # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq) scheduler_input_ipc_name: str # The ipc filename for detokenizer to receive inputs from scheduler (zmq) detokenizer_ipc_name: str # The port for nccl initialization (torch.dist) nccl_port: int @staticmethod def init_new(server_args) -> "PortArgs": port = server_args.port + random.randint(100, 1000) while True: if is_port_available(port): break port += 42 return PortArgs( tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, nccl_port=port, ) class LoRAPathAction(argparse.Action): def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, {}) for lora_path in values: if "=" in lora_path: name, path = lora_path.split("=", 1) getattr(namespace, self.dest)[name] = path else: getattr(namespace, self.dest)[lora_path] = lora_path class DeprecatedAction(argparse.Action): def __init__(self, option_strings, dest, nargs=0, **kwargs): super(DeprecatedAction, self).__init__( option_strings, dest, nargs=nargs, **kwargs ) def __call__(self, parser, namespace, values, option_string=None): raise ValueError(self.help)