Clean up server_args.py (#7037)

This commit is contained in:
Lianmin Zheng
2025-06-10 05:34:29 -07:00
committed by GitHub
parent 019851d099
commit 6406408a70
7 changed files with 394 additions and 331 deletions

View File

@@ -118,7 +118,7 @@ def _compile_warning_1():
if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
logger.warning(
"Entering DeepGEMM JIT Pre-Compile session. "
"And it may takes a long time(Typically 10-20 mins) "
"It may take a long time (typically 10-20 mins) "
"if you have not run `sglang.compile_deep_gemm`. "
"It is recommended to run `sglang.compile_deep_gemm` with the same args as `sglang.launch_server`"
" for pre-compilation to reduce the overhead if you have not run it before. "

View File

@@ -72,32 +72,33 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
GLOBAL_SERVER_ARGS_KEYS = [
"attention_backend",
"mm_attention_backend",
"debug_tensor_dump_inject",
"debug_tensor_dump_output_folder",
"chunked_prefill_size",
"deepep_mode",
"device",
"disable_chunked_prefix_cache",
"disable_radix_cache",
"enable_deepep_moe",
"enable_dp_attention",
"enable_two_batch_overlap",
"enable_dp_lm_head",
"enable_deepep_moe",
"deepep_mode",
"enable_ep_moe",
"moe_dense_tp_size",
"ep_dispatch_algorithm",
"deepep_config",
"ep_num_redundant_experts",
"enable_nan_detection",
"flashinfer_mla_disable_ragged",
"max_micro_batch_size",
"moe_dense_tp_size",
"ep_dispatch_algorithm",
"disable_shared_experts_fusion",
"sampling_backend",
"speculative_accept_threshold_acc",
"speculative_accept_threshold_single",
"torchao_config",
"triton_attention_reduce_in_fp32",
"ep_num_redundant_experts",
"mm_attention_backend",
"num_reserved_decode_tokens",
]
# Put some global args for easy access

View File

@@ -17,12 +17,14 @@ from __future__ import annotations
import bisect
import inspect
import logging
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
import tqdm
from torch.profiler import ProfilerActivity, profile
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
@@ -40,11 +42,14 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
from sglang.srt.utils import (
empty_context,
get_available_gpu_memory,
get_device_memory_capacity,
rank0_log,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -207,6 +212,9 @@ class CudaGraphRunner:
model_runner.server_args.enable_two_batch_overlap
)
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
self.enable_profile_cuda_graph = (
model_runner.server_args.enable_profile_cuda_graph
)
self.tp_size = model_runner.server_args.tp_size
self.dp_size = model_runner.server_args.dp_size
self.pp_size = model_runner.server_args.pp_size
@@ -339,44 +347,67 @@ class CudaGraphRunner:
return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
def capture(self):
def capture(self) -> None:
profile_context = empty_context()
if self.enable_profile_cuda_graph:
profile_context = profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
)
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
avail_mem = get_available_gpu_memory(
self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
)
# Reverse the order to enable better memory sharing across cuda graphs.
capture_range = (
tqdm.tqdm(list(reversed(self.capture_bs)))
if get_tensor_model_parallel_rank() == 0
else reversed(self.capture_bs)
)
for bs in capture_range:
if get_tensor_model_parallel_rank() == 0:
avail_mem = get_available_gpu_memory(
self.model_runner.device,
self.model_runner.gpu_id,
empty_cache=False,
)
capture_range.set_description(
f"Capturing batches ({avail_mem=:.2f} GB)"
)
with profile_context as prof:
self.stream = graph_capture_context.stream
avail_mem = get_available_gpu_memory(
self.model_runner.device,
self.model_runner.gpu_id,
empty_cache=False,
)
# Reverse the order to enable better memory sharing across cuda graphs.
capture_range = (
tqdm.tqdm(list(reversed(self.capture_bs)))
if get_tensor_model_parallel_rank() == 0
else reversed(self.capture_bs)
)
for i, bs in enumerate(capture_range):
if get_tensor_model_parallel_rank() == 0:
avail_mem = get_available_gpu_memory(
self.model_runner.device,
self.model_runner.gpu_id,
empty_cache=False,
)
capture_range.set_description(
f"Capturing batches ({avail_mem=:.2f} GB)"
)
with patch_model(
self.model_runner.model,
bs in self.compile_bs,
num_tokens=bs * self.num_tokens_per_bs,
tp_group=self.model_runner.tp_group,
) as forward:
(
graph,
output_buffers,
) = self.capture_one_batch_size(bs, forward)
self.graphs[bs] = graph
self.output_buffers[bs] = output_buffers
with patch_model(
self.model_runner.model,
bs in self.compile_bs,
num_tokens=bs * self.num_tokens_per_bs,
tp_group=self.model_runner.tp_group,
) as forward:
(
graph,
output_buffers,
) = self.capture_one_batch_size(bs, forward)
self.graphs[bs] = graph
self.output_buffers[bs] = output_buffers
# Save gemlite cache after each capture
save_gemlite_cache()
# Save gemlite cache after each capture
save_gemlite_cache()
if self.enable_profile_cuda_graph:
log_message = (
"Sorted by CUDA Time:\n"
+ prof.key_averages(group_by_input_shape=True).table(
sort_by="cuda_time_total", row_limit=10
)
+ "\n\nSorted by CPU Time:\n"
+ prof.key_averages(group_by_input_shape=True).table(
sort_by="cpu_time_total", row_limit=10
)
)
logger.info(log_message)
def capture_one_batch_size(self, bs: int, forward: Callable):
graph = torch.cuda.CUDAGraph()
@@ -443,7 +474,7 @@ class CudaGraphRunner:
token_to_kv_pool=self.model_runner.token_to_kv_pool,
attn_backend=self.model_runner.attn_backend,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens.sum(),
seq_lens_sum=seq_lens.sum().item(),
encoder_lens=encoder_lens,
return_logprob=False,
positions=positions,

View File

@@ -112,14 +112,12 @@ class ServerArgs:
file_storage_path: str = "sglang_storage"
enable_cache_report: bool = False
reasoning_parser: Optional[str] = None
tool_call_parser: Optional[str] = None
# Data parallelism
dp_size: int = 1
load_balance_method: str = "round_robin"
# Expert parallelism
ep_size: int = 1
# Multi-node distributed serving
dist_init_addr: Optional[str] = None
nnodes: int = 1
@@ -138,6 +136,7 @@ class ServerArgs:
attention_backend: Optional[str] = None
sampling_backend: Optional[str] = None
grammar_backend: Optional[str] = None
mm_attention_backend: Optional[str] = None
# Speculative decoding
speculative_algorithm: Optional[str] = None
@@ -149,28 +148,8 @@ class ServerArgs:
speculative_accept_threshold_acc: float = 1.0
speculative_token_map: Optional[str] = None
# Double Sparsity
enable_double_sparsity: bool = False
ds_channel_config_path: Optional[str] = None
ds_heavy_channel_num: int = 32
ds_heavy_token_num: int = 256
ds_heavy_channel_type: str = "qk"
ds_sparse_decode_threshold: int = 4096
# Optimization/debug options
disable_radix_cache: bool = False
disable_cuda_graph: bool = False
disable_cuda_graph_padding: bool = False
enable_nccl_nvls: bool = False
enable_tokenizer_batch_encode: bool = False
disable_outlines_disk_cache: bool = False
disable_custom_all_reduce: bool = False
enable_mscclpp: bool = False
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
enable_dp_lm_head: bool = False
enable_two_batch_overlap: bool = False
# Expert parallelism
ep_size: int = 1
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
@@ -187,10 +166,36 @@ class ServerArgs:
expert_distribution_recorder_buffer_size: Optional[int] = None
enable_expert_distribution_metrics: bool = False
deepep_config: Optional[str] = None
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
moe_dense_tp_size: Optional[int] = None
# Double Sparsity
enable_double_sparsity: bool = False
ds_channel_config_path: Optional[str] = None
ds_heavy_channel_num: int = 32
ds_heavy_token_num: int = 256
ds_heavy_channel_type: str = "qk"
ds_sparse_decode_threshold: int = 4096
# Optimization/debug options
disable_radix_cache: bool = False
cuda_graph_max_bs: Optional[int] = None
cuda_graph_bs: Optional[List[int]] = None
disable_cuda_graph: bool = False
disable_cuda_graph_padding: bool = False
enable_profile_cuda_graph: bool = False
enable_nccl_nvls: bool = False
enable_tokenizer_batch_encode: bool = False
disable_outlines_disk_cache: bool = False
disable_custom_all_reduce: bool = False
enable_mscclpp: bool = False
disable_overlap_schedule: bool = False
disable_overlap_cg_plan: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
enable_dp_lm_head: bool = False
enable_two_batch_overlap: bool = False
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
torchao_config: str = ""
enable_nan_detection: bool = False
enable_p2p_check: bool = False
@@ -201,29 +206,28 @@ class ServerArgs:
enable_memory_saver: bool = False
allow_auto_truncate: bool = False
enable_custom_logit_processor: bool = False
tool_call_parser: Optional[str] = None
enable_hierarchical_cache: bool = False
hicache_ratio: float = 2.0
hicache_size: int = 0
hicache_write_policy: str = "write_through_selective"
flashinfer_mla_disable_ragged: bool = False
warmups: Optional[str] = None
moe_dense_tp_size: Optional[int] = None
disable_shared_experts_fusion: bool = False
disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False
mm_attention_backend: Optional[str] = None
warmups: Optional[str] = None
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
debug_tensor_dump_input_file: Optional[str] = None
debug_tensor_dump_inject: bool = False
debug_tensor_dump_prefill_only: bool = False
# For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
disaggregation_mode: str = "null"
disaggregation_bootstrap_port: int = 8998
disaggregation_transfer_backend: str = "mooncake"
disaggregation_bootstrap_port: int = 8998
disaggregation_ib_device: Optional[str] = None
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
pdlb_url: Optional[str] = None
def __post_init__(self):
@@ -390,7 +394,7 @@ class ServerArgs:
if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
self.expert_distribution_recorder_mode = "stat"
logger.info(
f"EPLB is enabled. The expert_distribution_recorder_mode is automatically set."
"EPLB is enabled. The expert_distribution_recorder_mode is automatically set."
)
if (self.enable_eplb or (self.init_expert_location is not None)) and (
@@ -398,7 +402,7 @@ class ServerArgs:
):
self.ep_dispatch_algorithm = "static"
logger.info(
f"EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
"EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
)
if self.enable_expert_distribution_metrics and (
@@ -929,6 +933,13 @@ class ServerArgs:
default=ServerArgs.reasoning_parser,
help=f"Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}.",
)
parser.add_argument(
"--tool-call-parser",
type=str,
choices=["qwen25", "mistral", "llama3", "deepseekv3", "pythonic"],
default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'.",
)
# Data parallelism
parser.add_argument(
@@ -949,15 +960,6 @@ class ServerArgs:
],
)
# Expert parallelism
parser.add_argument(
"--expert-parallel-size",
"--ep-size",
type=int,
default=ServerArgs.ep_size,
help="The expert parallelism size.",
)
# Multi-node distributed serving
parser.add_argument(
"--dist-init-addr",
@@ -1038,21 +1040,6 @@ class ServerArgs:
default=ServerArgs.grammar_backend,
help="Choose the backend for grammar-guided decoding.",
)
parser.add_argument(
"--enable-flashinfer-mla",
action=DeprecatedAction,
help="--enable-flashinfer-mla is deprecated. Please use '--attention-backend flashinfer' instead.",
)
parser.add_argument(
"--enable-flashmla",
action=DeprecatedAction,
help="--enable-flashmla is deprecated. Please use '--attention-backend flashmla' instead.",
)
parser.add_argument(
"--flashinfer-mla-disable-ragged",
action="store_true",
help="Not using ragged prefill wrapper when running flashinfer mla",
)
# Speculative decoding
parser.add_argument(
@@ -1102,236 +1089,32 @@ class ServerArgs:
help="The path of the draft model's small vocab table.",
default=ServerArgs.speculative_token_map,
)
# Double Sparsity
parser.add_argument(
"--enable-double-sparsity",
action="store_true",
help="Enable double sparsity attention",
)
parser.add_argument(
"--ds-channel-config-path",
"--mm-attention-backend",
type=str,
default=ServerArgs.ds_channel_config_path,
help="The path of the double sparsity channel config",
)
parser.add_argument(
"--ds-heavy-channel-num",
type=int,
default=ServerArgs.ds_heavy_channel_num,
help="The number of heavy channels in double sparsity attention",
)
parser.add_argument(
"--ds-heavy-token-num",
type=int,
default=ServerArgs.ds_heavy_token_num,
help="The number of heavy tokens in double sparsity attention",
)
parser.add_argument(
"--ds-heavy-channel-type",
type=str,
default=ServerArgs.ds_heavy_channel_type,
help="The type of heavy channels in double sparsity attention",
)
parser.add_argument(
"--ds-sparse-decode-threshold",
type=int,
default=ServerArgs.ds_sparse_decode_threshold,
help="The type of heavy channels in double sparsity attention",
choices=["sdpa", "fa3", "triton_attn"],
default=ServerArgs.mm_attention_backend,
help="Set multimodal attention backend.",
)
# Optimization/debug options
# Expert parallelism
parser.add_argument(
"--disable-radix-cache",
action="store_true",
help="Disable RadixAttention for prefix caching.",
)
parser.add_argument(
"--disable-cuda-graph",
action="store_true",
help="Disable cuda graph.",
)
parser.add_argument(
"--disable-cuda-graph-padding",
action="store_true",
help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
)
parser.add_argument(
"--enable-nccl-nvls",
action="store_true",
help="Enable NCCL NVLS for prefill heavy requests when available.",
)
parser.add_argument(
"--enable-tokenizer-batch-encode",
action="store_true",
help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
)
parser.add_argument(
"--disable-outlines-disk-cache",
action="store_true",
help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
)
parser.add_argument(
"--disable-custom-all-reduce",
action="store_true",
help="Disable the custom all-reduce kernel and fall back to NCCL.",
)
parser.add_argument(
"--enable-mscclpp",
action="store_true",
help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
)
parser.add_argument(
"--disable-overlap-schedule",
action="store_true",
help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
)
parser.add_argument(
"--enable-mixed-chunk",
action="store_true",
help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
)
parser.add_argument(
"--enable-dp-attention",
action="store_true",
help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported.",
)
parser.add_argument(
"--enable-dp-lm-head",
action="store_true",
help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
"--expert-parallel-size",
"--ep-size",
type=int,
default=ServerArgs.ep_size,
help="The expert parallelism size.",
)
parser.add_argument(
"--enable-ep-moe",
action="store_true",
help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
)
parser.add_argument(
"--enable-two-batch-overlap",
action="store_true",
help="Enabling two micro batches to overlap.",
)
parser.add_argument(
"--enable-torch-compile",
action="store_true",
help="Optimize the model with torch.compile. Experimental feature.",
)
parser.add_argument(
"--torch-compile-max-bs",
type=int,
default=ServerArgs.torch_compile_max_bs,
help="Set the maximum batch size when using torch compile.",
)
parser.add_argument(
"--cuda-graph-max-bs",
type=int,
default=ServerArgs.cuda_graph_max_bs,
help="Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value.",
)
parser.add_argument(
"--cuda-graph-bs",
type=int,
nargs="+",
help="Set the list of batch sizes for cuda graph.",
)
parser.add_argument(
"--torchao-config",
type=str,
default=ServerArgs.torchao_config,
help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
)
parser.add_argument(
"--enable-nan-detection",
action="store_true",
help="Enable the NaN detection for debugging purposes.",
)
parser.add_argument(
"--enable-p2p-check",
action="store_true",
help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
)
parser.add_argument(
"--triton-attention-reduce-in-fp32",
action="store_true",
help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16."
"This only affects Triton attention kernels.",
)
parser.add_argument(
"--triton-attention-num-kv-splits",
type=int,
default=ServerArgs.triton_attention_num_kv_splits,
help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
)
parser.add_argument(
"--num-continuous-decode-steps",
type=int,
default=ServerArgs.num_continuous_decode_steps,
help="Run multiple continuous decoding steps to reduce scheduling overhead. "
"This can potentially increase throughput but may also increase time-to-first-token latency. "
"The default value is 1, meaning only run one decoding step at a time.",
)
parser.add_argument(
"--delete-ckpt-after-loading",
action="store_true",
help="Delete the model checkpoint after loading the model.",
)
parser.add_argument(
"--enable-memory-saver",
action="store_true",
help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
)
parser.add_argument(
"--allow-auto-truncate",
action="store_true",
help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
)
parser.add_argument(
"--enable-custom-logit-processor",
action="store_true",
help="Enable users to pass custom logit processors to the server (disabled by default for security)",
)
parser.add_argument(
"--tool-call-parser",
type=str,
choices=["qwen25", "mistral", "llama3", "deepseekv3", "pythonic"],
default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'.",
)
parser.add_argument(
"--enable-hierarchical-cache",
action="store_true",
help="Enable hierarchical cache",
)
parser.add_argument(
"--hicache-ratio",
type=float,
default=ServerArgs.hicache_ratio,
help="The ratio of the size of host KV cache memory pool to the size of device pool.",
)
parser.add_argument(
"--hicache-size",
type=int,
default=ServerArgs.hicache_size,
help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
)
parser.add_argument(
"--hicache-write-policy",
type=str,
choices=["write_back", "write_through", "write_through_selective"],
default=ServerArgs.hicache_write_policy,
help="The write policy of hierarchical cache.",
)
parser.add_argument(
"--enable-deepep-moe",
action="store_true",
help="Enabling DeepEP MoE implementation for EP MoE.",
)
parser.add_argument(
"--moe-dense-tp-size",
type=int,
default=ServerArgs.moe_dense_tp_size,
help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
)
parser.add_argument(
"--deepep-mode",
type=str,
@@ -1403,6 +1186,234 @@ class ServerArgs:
default=ServerArgs.deepep_config,
help="Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path.",
)
parser.add_argument(
"--moe-dense-tp-size",
type=int,
default=ServerArgs.moe_dense_tp_size,
help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
)
# Double Sparsity
parser.add_argument(
"--enable-double-sparsity",
action="store_true",
help="Enable double sparsity attention",
)
parser.add_argument(
"--ds-channel-config-path",
type=str,
default=ServerArgs.ds_channel_config_path,
help="The path of the double sparsity channel config",
)
parser.add_argument(
"--ds-heavy-channel-num",
type=int,
default=ServerArgs.ds_heavy_channel_num,
help="The number of heavy channels in double sparsity attention",
)
parser.add_argument(
"--ds-heavy-token-num",
type=int,
default=ServerArgs.ds_heavy_token_num,
help="The number of heavy tokens in double sparsity attention",
)
parser.add_argument(
"--ds-heavy-channel-type",
type=str,
default=ServerArgs.ds_heavy_channel_type,
help="The type of heavy channels in double sparsity attention",
)
parser.add_argument(
"--ds-sparse-decode-threshold",
type=int,
default=ServerArgs.ds_sparse_decode_threshold,
help="The token-count threshold above which sparse decoding is used in double sparsity attention",
)
# Optimization/debug options
parser.add_argument(
"--disable-radix-cache",
action="store_true",
help="Disable RadixAttention for prefix caching.",
)
parser.add_argument(
"--cuda-graph-max-bs",
type=int,
default=ServerArgs.cuda_graph_max_bs,
help="Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value.",
)
parser.add_argument(
"--cuda-graph-bs",
type=int,
nargs="+",
help="Set the list of batch sizes for cuda graph.",
)
parser.add_argument(
"--disable-cuda-graph",
action="store_true",
help="Disable cuda graph.",
)
parser.add_argument(
"--disable-cuda-graph-padding",
action="store_true",
help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
)
parser.add_argument(
"--enable-profile-cuda-graph",
action="store_true",
help="Enable profiling of cuda graph capture.",
)
parser.add_argument(
"--enable-nccl-nvls",
action="store_true",
help="Enable NCCL NVLS for prefill heavy requests when available.",
)
parser.add_argument(
"--enable-tokenizer-batch-encode",
action="store_true",
help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
)
parser.add_argument(
"--disable-outlines-disk-cache",
action="store_true",
help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
)
parser.add_argument(
"--disable-custom-all-reduce",
action="store_true",
help="Disable the custom all-reduce kernel and fall back to NCCL.",
)
parser.add_argument(
"--enable-mscclpp",
action="store_true",
help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
)
parser.add_argument(
"--disable-overlap-schedule",
action="store_true",
help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
)
parser.add_argument(
"--disable-overlap-cg-plan",
action="store_true",
help="Disable the overlap optimization for cudagraph preparation in eagle verify.",
)
parser.add_argument(
"--enable-mixed-chunk",
action="store_true",
help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
)
parser.add_argument(
"--enable-dp-attention",
action="store_true",
help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported.",
)
parser.add_argument(
"--enable-dp-lm-head",
action="store_true",
help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
)
parser.add_argument(
"--enable-two-batch-overlap",
action="store_true",
help="Enabling two micro batches to overlap.",
)
parser.add_argument(
"--enable-torch-compile",
action="store_true",
help="Optimize the model with torch.compile. Experimental feature.",
)
parser.add_argument(
"--torch-compile-max-bs",
type=int,
default=ServerArgs.torch_compile_max_bs,
help="Set the maximum batch size when using torch compile.",
)
parser.add_argument(
"--torchao-config",
type=str,
default=ServerArgs.torchao_config,
help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
)
parser.add_argument(
"--enable-nan-detection",
action="store_true",
help="Enable the NaN detection for debugging purposes.",
)
parser.add_argument(
"--enable-p2p-check",
action="store_true",
help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
)
parser.add_argument(
"--triton-attention-reduce-in-fp32",
action="store_true",
help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16."
"This only affects Triton attention kernels.",
)
parser.add_argument(
"--triton-attention-num-kv-splits",
type=int,
default=ServerArgs.triton_attention_num_kv_splits,
help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
)
parser.add_argument(
"--num-continuous-decode-steps",
type=int,
default=ServerArgs.num_continuous_decode_steps,
help="Run multiple continuous decoding steps to reduce scheduling overhead. "
"This can potentially increase throughput but may also increase time-to-first-token latency. "
"The default value is 1, meaning only run one decoding step at a time.",
)
parser.add_argument(
"--delete-ckpt-after-loading",
action="store_true",
help="Delete the model checkpoint after loading the model.",
)
parser.add_argument(
"--enable-memory-saver",
action="store_true",
help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
)
parser.add_argument(
"--allow-auto-truncate",
action="store_true",
help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
)
parser.add_argument(
"--enable-custom-logit-processor",
action="store_true",
help="Enable users to pass custom logit processors to the server (disabled by default for security)",
)
parser.add_argument(
"--enable-hierarchical-cache",
action="store_true",
help="Enable hierarchical cache",
)
parser.add_argument(
"--hicache-ratio",
type=float,
default=ServerArgs.hicache_ratio,
help="The ratio of the size of host KV cache memory pool to the size of device pool.",
)
parser.add_argument(
"--hicache-size",
type=int,
default=ServerArgs.hicache_size,
help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
)
parser.add_argument(
"--hicache-write-policy",
type=str,
choices=["write_back", "write_through", "write_through_selective"],
default=ServerArgs.hicache_write_policy,
help="The write policy of hierarchical cache.",
)
parser.add_argument(
"--flashinfer-mla-disable-ragged",
action="store_true",
help="Not using ragged prefill wrapper when running flashinfer mla",
)
parser.add_argument(
"--disable-shared-experts-fusion",
action="store_true",
@@ -1418,8 +1429,6 @@ class ServerArgs:
action="store_true",
help="Adopt base image processor instead of fast image processor.",
)
# Server warmups
parser.add_argument(
"--warmups",
type=str,
@@ -1447,6 +1456,11 @@ class ServerArgs:
default=ServerArgs.debug_tensor_dump_inject,
help="Inject the outputs from jax as the input of every layer.",
)
parser.add_argument(
"--debug-tensor-dump-prefill-only",
action="store_true",
help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
)
# Disaggregation
parser.add_argument(
@@ -1456,12 +1470,6 @@ class ServerArgs:
choices=["null", "prefill", "decode"],
help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
)
parser.add_argument(
"--disaggregation-bootstrap-port",
type=int,
default=ServerArgs.disaggregation_bootstrap_port,
help="Bootstrap server port on the prefill server. Default is 8998.",
)
parser.add_argument(
"--disaggregation-transfer-backend",
type=str,
@@ -1469,6 +1477,12 @@ class ServerArgs:
choices=["mooncake", "nixl"],
help="The backend for disaggregation transfer. Default is mooncake.",
)
parser.add_argument(
"--disaggregation-bootstrap-port",
type=int,
default=ServerArgs.disaggregation_bootstrap_port,
help="Bootstrap server port on the prefill server. Default is 8998.",
)
parser.add_argument(
"--disaggregation-ib-device",
type=str,
@@ -1477,6 +1491,12 @@ class ServerArgs:
"or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
"Default is None, which triggers automatic device detection when mooncake backend is enabled.",
)
parser.add_argument(
"--num-reserved-decode-tokens",
type=int,
default=ServerArgs.num_reserved_decode_tokens,
help="Number of decode tokens that will have memory reserved when adding new request to the running batch.",
)
parser.add_argument(
"--pdlb-url",
type=str,
@@ -1484,14 +1504,6 @@ class ServerArgs:
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
)
parser.add_argument(
"--mm-attention-backend",
type=str,
choices=["sdpa", "fa3", "triton_attn"],
default=ServerArgs.mm_attention_backend,
help="Set multimodal attention backend.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size

View File

@@ -41,6 +41,9 @@ class EAGLEDraftCudaGraphRunner:
self.tp_size = self.model_runner.tp_size
self.topk = model_runner.server_args.speculative_eagle_topk
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
self.enable_profile_cuda_graph = (
model_runner.server_args.enable_profile_cuda_graph
)
server_args = model_runner.server_args
# Batch sizes to capture

View File

@@ -39,6 +39,9 @@ class EAGLEDraftExtendCudaGraphRunner:
self.dp_size = model_runner.server_args.dp_size
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
self.topk = model_runner.server_args.speculative_eagle_topk
self.enable_profile_cuda_graph = (
model_runner.server_args.enable_profile_cuda_graph
)
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.padded_static_len = -1

View File

@@ -837,6 +837,7 @@ class CustomCacheManager(FileCacheManager):
def set_ulimit(target_soft_limit=65535):
# number of open files
resource_type = resource.RLIMIT_NOFILE
current_soft, current_hard = resource.getrlimit(resource_type)
@@ -846,6 +847,18 @@ def set_ulimit(target_soft_limit=65535):
except ValueError as e:
logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")
# stack size
resource_type = resource.RLIMIT_STACK
current_soft, current_hard = resource.getrlimit(resource_type)
target_soft_limit_stack_size = 1024 * target_soft_limit
if current_soft < target_soft_limit_stack_size:
try:
resource.setrlimit(
resource_type, (target_soft_limit_stack_size, current_hard)
)
except ValueError as e:
logger.warning(f"Fail to set RLIMIT_STACK: {e}")
def add_api_key_middleware(app, api_key: str):
@app.middleware("http")