# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The arguments of the server."""
import argparse
import dataclasses
import logging
import random
import tempfile
from typing import List, Optional
from sglang.srt.utils import (
get_amdgpu_memory_capacity,
get_nvgpu_memory_capacity,
is_flashinfer_available,
is_hip,
is_ipv6,
is_port_available,
)
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class ServerArgs:
# Model and tokenizer
model_path: str
tokenizer_path: Optional[str] = None
tokenizer_mode: str = "auto"
skip_tokenizer_init: bool = False
load_format: str = "auto"
trust_remote_code: bool = True
dtype: str = "auto"
kv_cache_dtype: str = "auto"
quantization: Optional[str] = None
context_length: Optional[int] = None
device: str = "cuda"
served_model_name: Optional[str] = None
chat_template: Optional[str] = None
is_embedding: bool = False
# Port
host: str = "127.0.0.1"
port: int = 30000
# Memory and scheduling
mem_fraction_static: Optional[float] = None
max_running_requests: Optional[int] = None
max_total_tokens: Optional[int] = None
chunked_prefill_size: int = 8192
max_prefill_tokens: int = 16384
schedule_policy: str = "lpm"
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
# Other runtime options
tp_size: int = 1
stream_interval: int = 1
random_seed: Optional[int] = None
constrained_json_whitespace_pattern: Optional[str] = None
watchdog_timeout: float = 300
download_dir: Optional[str] = None
base_gpu_id: int = 0
# Logging
log_level: str = "info"
log_level_http: Optional[str] = None
log_requests: bool = False
show_time_cost: bool = False
enable_metrics: bool = False
decode_log_interval: int = 40
# API related
api_key: Optional[str] = None
file_storage_pth: str = "SGLang_storage"
enable_cache_report: bool = False
# Data parallelism
dp_size: int = 1
load_balance_method: str = "round_robin"
# Multi-node distributed serving
dist_init_addr: Optional[str] = None
nnodes: int = 1
node_rank: int = 0
# Model override args in JSON
json_model_override_args: str = "{}"
# Double Sparsity
enable_double_sparsity: bool = False
ds_channel_config_path: Optional[str] = None
ds_heavy_channel_num: int = 32
ds_heavy_token_num: int = 256
ds_heavy_channel_type: str = "qk"
ds_sparse_decode_threshold: int = 4096
# LoRA
lora_paths: Optional[List[str]] = None
max_loras_per_batch: int = 8
# Kernel backend
attention_backend: Optional[str] = None
sampling_backend: Optional[str] = None
grammar_backend: Optional[str] = "outlines"
# Optimization/debug options
disable_radix_cache: bool = False
disable_jump_forward: bool = False
disable_cuda_graph: bool = False
disable_cuda_graph_padding: bool = False
disable_disk_cache: bool = False
disable_custom_all_reduce: bool = False
disable_mla: bool = False
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
cuda_graph_max_bs: int = 160
torchao_config: str = ""
enable_nan_detection: bool = False
enable_p2p_check: bool = False
triton_attention_reduce_in_fp32: bool = False
num_continuous_decode_steps: int = 1
delete_ckpt_after_loading: bool = False
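# Note: each field above maps to a CLI flag with underscores replaced by dashes
# (e.g. `cuda_graph_max_bs` <-> `--cuda-graph-max-bs`), and the defaults declared
# here are reused as the argparse defaults in `add_cli_args` below.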
def __post_init__(self):
# Set missing default values
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
if self.served_model_name is None:
self.served_model_name = self.model_path
if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
# Disable chunked prefill
self.chunked_prefill_size = None
if self.random_seed is None:
self.random_seed = random.randint(0, 1 << 30)
# Mem fraction depends on the tensor parallelism size
if self.mem_fraction_static is None:
if self.tp_size >= 16:
self.mem_fraction_static = 0.79
elif self.tp_size >= 8:
self.mem_fraction_static = 0.82
elif self.tp_size >= 4:
self.mem_fraction_static = 0.85
elif self.tp_size >= 2:
self.mem_fraction_static = 0.87
else:
self.mem_fraction_static = 0.88
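# Heuristic defaults: larger tp deployments get a slightly smaller static fraction,
# presumably to leave headroom for communication buffers and per-rank overhead.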
# Adjust for GPUs with small memory capacities
if is_hip():
gpu_mem = get_amdgpu_memory_capacity()
else:
gpu_mem = get_nvgpu_memory_capacity()
if gpu_mem < 25000:
self.chunked_prefill_size //= 4 # make it 2048
self.cuda_graph_max_bs = 4
logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")
# Choose kernel backends
if not is_flashinfer_available():
self.attention_backend = "triton"
self.sampling_backend = "pytorch"
if self.attention_backend is None:
self.attention_backend = "flashinfer"
if self.sampling_backend is None:
self.sampling_backend = "flashinfer"
# Others
if self.enable_dp_attention:
self.dp_size = self.tp_size
self.chunked_prefill_size = self.chunked_prefill_size // 2
self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
self.disable_overlap_schedule = True
logger.info(
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
"Data parallel size is adjusted to be the same as tensor parallel size. "
"Overlap schedule is disabled."
)
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and port args
parser.add_argument(
"--model-path",
type=str,
help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
required=True,
)
parser.add_argument(
"--tokenizer-path",
type=str,
default=ServerArgs.tokenizer_path,
help="The path of the tokenizer.",
)
parser.add_argument(
"--host", type=str, default=ServerArgs.host, help="The host of the server."
)
parser.add_argument(
"--port", type=int, default=ServerArgs.port, help="The port of the server."
)
parser.add_argument(
"--tokenizer-mode",
type=str,
default=ServerArgs.tokenizer_mode,
choices=["auto", "slow"],
help="Tokenizer mode. 'auto' will use the fast "
"tokenizer if available, and 'slow' will "
"always use the slow tokenizer.",
)
parser.add_argument(
"--skip-tokenizer-init",
action="store_true",
help="If set, skip init tokenizer and pass input_ids in generate request",
)
parser.add_argument(
"--load-format",
type=str,
default=ServerArgs.load_format,
choices=["auto", "pt", "safetensors", "npcache", "dummy"],
help="The format of the model weights to load. "
'"auto" will try to load the weights in the safetensors format '
"and fall back to the pytorch bin format if safetensors format "
"is not available. "
'"pt" will load the weights in the pytorch bin format. '
'"safetensors" will load the weights in the safetensors format. '
'"npcache" will load the weights in pytorch format and store '
"a numpy cache to speed up the loading. "
'"dummy" will initialize the weights with random values, '
"which is mainly for profiling.",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
)
parser.add_argument(
"--dtype",
type=str,
default=ServerArgs.dtype,
choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
help="Data type for model weights and activations.\n\n"
'* "auto" will use FP16 precision for FP32 and FP16 models, and '
"BF16 precision for BF16 models.\n"
'* "half" for FP16. Recommended for AWQ quantization.\n'
'* "float16" is the same as "half".\n'
'* "bfloat16" for a balance between precision and range.\n'
'* "float" is shorthand for FP32 precision.\n'
'* "float32" for FP32 precision.',
)
parser.add_argument(
"--kv-cache-dtype",
type=str,
default=ServerArgs.kv_cache_dtype,
choices=["auto", "fp8_e5m2"],
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
)
parser.add_argument(
"--quantization",
type=str,
default=ServerArgs.quantization,
choices=[
"awq",
"fp8",
"gptq",
"marlin",
"gptq_marlin",
"awq_marlin",
"bitsandbytes",
],
help="The quantization method.",
)
parser.add_argument(
"--context-length",
type=int,
default=ServerArgs.context_length,
help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda", "xpu", "hpu"],
help="The device type.",
)
parser.add_argument(
"--served-model-name",
type=str,
default=ServerArgs.served_model_name,
help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
)
parser.add_argument(
"--chat-template",
type=str,
default=ServerArgs.chat_template,
help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
)
parser.add_argument(
"--is-embedding",
action="store_true",
help="Whether to use a CausalLM as an embedding model.",
)
# Memory and scheduling
parser.add_argument(
"--mem-fraction-static",
type=float,
default=ServerArgs.mem_fraction_static,
help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
)
parser.add_argument(
"--max-running-requests",
type=int,
default=ServerArgs.max_running_requests,
help="The maximum number of running requests.",
)
parser.add_argument(
"--max-total-tokens",
type=int,
default=ServerArgs.max_total_tokens,
help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
"This option is typically used for development and debugging purposes.",
)
parser.add_argument(
"--chunked-prefill-size",
type=int,
default=ServerArgs.chunked_prefill_size,
help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
)
parser.add_argument(
"--max-prefill-tokens",
type=int,
default=ServerArgs.max_prefill_tokens,
help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
)
parser.add_argument(
"--schedule-policy",
type=str,
default=ServerArgs.schedule_policy,
choices=["lpm", "random", "fcfs", "dfs-weight"],
help="The scheduling policy of the requests.",
)
parser.add_argument(
"--schedule-conservativeness",
type=float,
default=ServerArgs.schedule_conservativeness,
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
)
parser.add_argument(
"--cpu-offload-gb",
type=int,
default=ServerArgs.cpu_offload_gb,
help="How many GBs of RAM to reserve for CPU offloading",
)
# Other runtime options
parser.add_argument(
"--tensor-parallel-size",
"--tp-size",
type=int,
default=ServerArgs.tp_size,
help="The tensor parallelism size.",
)
parser.add_argument(
"--stream-interval",
type=int,
default=ServerArgs.stream_interval,
help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
)
parser.add_argument(
"--random-seed",
type=int,
default=ServerArgs.random_seed,
help="The random seed.",
)
parser.add_argument(
"--constrained-json-whitespace-pattern",
type=str,
default=ServerArgs.constrained_json_whitespace_pattern,
help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
)
parser.add_argument(
"--watchdog-timeout",
type=float,
default=ServerArgs.watchdog_timeout,
help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
)
parser.add_argument(
"--download-dir",
type=str,
default=ServerArgs.download_dir,
help="Model download directory.",
)
parser.add_argument(
"--base-gpu-id",
type=int,
default=ServerArgs.base_gpu_id,
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
)
# Logging
parser.add_argument(
"--log-level",
type=str,
default=ServerArgs.log_level,
help="The logging level of all loggers.",
)
parser.add_argument(
"--log-level-http",
type=str,
default=ServerArgs.log_level_http,
help="The logging level of HTTP server. If not set, reuse --log-level by default.",
)
parser.add_argument(
"--log-requests",
action="store_true",
help="Log the inputs and outputs of all requests.",
)
parser.add_argument(
"--show-time-cost",
action="store_true",
help="Show time cost of custom marks.",
)
parser.add_argument(
"--enable-metrics",
action="store_true",
help="Enable log prometheus metrics.",
)
parser.add_argument(
"--decode-log-interval",
type=int,
default=ServerArgs.decode_log_interval,
help="The log interval of decode batch",
)
# API related
parser.add_argument(
"--api-key",
type=str,
default=ServerArgs.api_key,
help="Set API key of the server. It is also used in the OpenAI API compatible server.",
)
parser.add_argument(
"--file-storage-pth",
type=str,
default=ServerArgs.file_storage_pth,
help="The path of the file storage in backend.",
)
parser.add_argument(
"--enable-cache-report",
action="store_true",
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
)
# Data parallelism
parser.add_argument(
"--data-parallel-size",
"--dp-size",
type=int,
default=ServerArgs.dp_size,
help="The data parallelism size.",
)
parser.add_argument(
"--load-balance-method",
type=str,
default=ServerArgs.load_balance_method,
help="The load balancing strategy for data parallelism.",
choices=[
"round_robin",
"shortest_queue",
],
)
# Multi-node distributed serving
parser.add_argument(
"--dist-init-addr",
"--nccl-init-addr", # For backward compatbility. This will be removed in the future.
type=str,
help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
)
parser.add_argument(
"--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
)
parser.add_argument(
"--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
)
# Model override args
parser.add_argument(
"--json-model-override-args",
type=str,
help="A dictionary in JSON string format used to override default model configurations.",
default=ServerArgs.json_model_override_args,
)
# Double Sparsity
parser.add_argument(
"--enable-double-sparsity",
action="store_true",
help="Enable double sparsity attention",
)
parser.add_argument(
"--ds-channel-config-path",
type=str,
default=ServerArgs.ds_channel_config_path,
help="The path of the double sparsity channel config",
)
parser.add_argument(
"--ds-heavy-channel-num",
type=int,
default=ServerArgs.ds_heavy_channel_num,
help="The number of heavy channels in double sparsity attention",
)
parser.add_argument(
"--ds-heavy-token-num",
type=int,
default=ServerArgs.ds_heavy_token_num,
help="The number of heavy tokens in double sparsity attention",
)
parser.add_argument(
"--ds-heavy-channel-type",
type=str,
default=ServerArgs.ds_heavy_channel_type,
help="The type of heavy channels in double sparsity attention",
)
parser.add_argument(
"--ds-sparse-decode-threshold",
type=int,
default=ServerArgs.ds_sparse_decode_threshold,
help="The type of heavy channels in double sparsity attention",
)
# LoRA
parser.add_argument(
"--lora-paths",
type=str,
nargs="*",
default=None,
action=LoRAPathAction,
help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
)
parser.add_argument(
"--max-loras-per-batch",
type=int,
default=8,
help="Maximum number of adapters for a running batch, include base-only request",
)
# Kernel backend
parser.add_argument(
"--attention-backend",
type=str,
choices=["flashinfer", "triton"],
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
)
parser.add_argument(
"--sampling-backend",
type=str,
choices=["flashinfer", "pytorch"],
default=ServerArgs.sampling_backend,
help="Choose the kernels for sampling layers.",
)
parser.add_argument(
"--grammar-backend",
type=str,
choices=["xgrammar", "outlines"],
default=ServerArgs.grammar_backend,
help="Choose the backend for grammar-guided decoding.",
)
# Optimization/debug options
parser.add_argument(
"--disable-radix-cache",
action="store_true",
help="Disable RadixAttention for prefix caching.",
)
parser.add_argument(
"--disable-jump-forward",
action="store_true",
help="Disable jump-forward for grammar-guided decoding.",
)
parser.add_argument(
"--disable-cuda-graph",
action="store_true",
help="Disable cuda graph.",
)
parser.add_argument(
"--disable-cuda-graph-padding",
action="store_true",
help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
)
parser.add_argument(
"--disable-disk-cache",
action="store_true",
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
)
parser.add_argument(
"--disable-custom-all-reduce",
action="store_true",
help="Disable the custom all-reduce kernel and fall back to NCCL.",
)
parser.add_argument(
"--disable-mla",
action="store_true",
help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
)
parser.add_argument(
"--disable-nan-detection",
action="store_true",
help="Disable the NaN detection for better performance.",
)
parser.add_argument(
"--disable-overlap-schedule",
action="store_true",
help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
)
parser.add_argument(
"--enable-mixed-chunk",
action="store_true",
help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
)
parser.add_argument(
"--enable-dp-attention",
action="store_true",
help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
)
parser.add_argument(
"--enable-torch-compile",
action="store_true",
help="Optimize the model with torch.compile. Experimental feature.",
)
parser.add_argument(
"--torch-compile-max-bs",
type=int,
default=ServerArgs.torch_compile_max_bs,
help="Set the maximum batch size when using torch compile.",
)
parser.add_argument(
"--cuda-graph-max-bs",
type=int,
default=ServerArgs.cuda_graph_max_bs,
help="Set the maximum batch size for cuda graph.",
)
parser.add_argument(
"--torchao-config",
type=str,
default=ServerArgs.torchao_config,
help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
)
parser.add_argument(
"--enable-nan-detection",
action="store_true",
help="Enable the NaN detection for debugging purposes.",
)
parser.add_argument(
"--enable-p2p-check",
action="store_true",
help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
)
parser.add_argument(
"--triton-attention-reduce-in-fp32",
action="store_true",
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
"This only affects Triton attention kernels.",
)
parser.add_argument(
"--num-continuous-decode-steps",
type=int,
default=ServerArgs.num_continuous_decode_steps,
help="Run multiple continuous decoding steps to reduce scheduling overhead. "
"This can potentially increase throughput but may also increase time-to-first-token latency. "
"The default value is 1, meaning only run one decoding step at a time.",
)
parser.add_argument(
"--delete-ckpt-after-loading",
action="store_true",
help="Delete the model checkpoint after loading the model.",
)
# Deprecated arguments
parser.add_argument(
"--enable-overlap-schedule",
action=DeprecatedAction,
help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.",
)
parser.add_argument(
"--disable-flashinfer",
action=DeprecatedAction,
help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
)
parser.add_argument(
"--disable-flashinfer-sampling",
action=DeprecatedAction,
help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
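# argparse stores each value under the dest derived from the first long option
# (e.g. --tensor-parallel-size -> args.tensor_parallel_size), so copy those values
# back to the shorter dataclass field names before constructing the instance.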
args.tp_size = args.tensor_parallel_size
args.dp_size = args.data_parallel_size
attrs = [attr.name for attr in dataclasses.fields(cls)]
return cls(**{attr: getattr(args, attr) for attr in attrs})
def url(self):
if is_ipv6(self.host):
return f"http://[{self.host}]:{self.port}"
else:
return f"http://{self.host}:{self.port}"
def check_server_args(self):
assert (
self.tp_size % self.nnodes == 0
), "tp_size must be divisible by number of nodes"
assert not (
self.dp_size > 1 and self.nnodes != 1
), "multi-node data parallel is not supported"
assert (
self.max_loras_per_batch > 0
# FIXME
and (self.lora_paths is None or self.disable_cuda_graph)
and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and cuda graph and radix attention is in progress"
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths
self.lora_paths = {}
for lora_path in lora_paths:
if "=" in lora_path:
name, path = lora_path.split("=", 1)
self.lora_paths[name] = path
else:
self.lora_paths[lora_path] = lora_path
def prepare_server_args(argv: List[str]) -> ServerArgs:
"""
Prepare the server arguments from the command line arguments.
Args:
argv: The command line arguments. Typically, it should be `sys.argv[1:]`
to ensure compatibility with `parse_args` when no arguments are passed.
Returns:
The server arguments.
"""
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
raw_args = parser.parse_args(argv)
server_args = ServerArgs.from_cli_args(raw_args)
return server_args
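# Example (a minimal sketch; the model path is illustrative and a real launcher
# would pass sys.argv[1:]):
#   server_args = prepare_server_args(
#       ["--model-path", "meta-llama/Meta-Llama-3-8B-Instruct", "--tp-size", "2"]
#   )
#   print(server_args.url())  # http://127.0.0.1:30000 with the defaults above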
@dataclasses.dataclass
class PortArgs:
# The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
tokenizer_ipc_name: str
# The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
scheduler_input_ipc_name: str
# The ipc filename for detokenizer to receive inputs from scheduler (zmq)
detokenizer_ipc_name: str
# The port for nccl initialization (torch.dist)
nccl_port: int
@staticmethod
def init_new(server_args) -> "PortArgs":
port = server_args.port + random.randint(100, 1000)
while True:
if is_port_available(port):
break
port += 42
return PortArgs(
tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
nccl_port=port,
)
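# Sketch of what init_new produces: three throwaway temp-file paths used as ZMQ ipc
# endpoints, plus a free TCP port near `server_args.port` for torch.distributed/NCCL,
# found by probing upward in steps of 42 from a randomized starting offset.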
class LoRAPathAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, {})
for lora_path in values:
if "=" in lora_path:
name, path = lora_path.split("=", 1)
getattr(namespace, self.dest)[name] = path
else:
getattr(namespace, self.dest)[lora_path] = lora_path
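# Example (illustrative paths): `--lora-paths french=/ckpts/fr-lora /ckpts/base-lora`
# yields {"french": "/ckpts/fr-lora", "/ckpts/base-lora": "/ckpts/base-lora"}.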
class DeprecatedAction(argparse.Action):
def __init__(self, option_strings, dest, nargs=0, **kwargs):
super(DeprecatedAction, self).__init__(
option_strings, dest, nargs=nargs, **kwargs
)
def __call__(self, parser, namespace, values, option_string=None):
raise ValueError(self.help)