Clean up server_args.py to have a dedicated function for model specific adjustments (#8983)

This commit is contained in:
Lianmin Zheng
2025-08-08 19:56:50 -07:00
committed by GitHub
parent 23f2afb2ce
commit 706bd69cc5
24 changed files with 201 additions and 340 deletions

View File

@@ -37,7 +37,6 @@ from sglang.srt.utils import (
is_hip,
is_port_available,
is_remote_url,
is_triton_kernels_available,
is_valid_ipv6_address,
nullable_str,
)
@@ -109,7 +108,7 @@ class ServerArgs:
log_level: str = "info"
log_level_http: Optional[str] = None
log_requests: bool = False
log_requests_level: int = 0
log_requests_level: int = 2
crash_dump_folder: Optional[str] = None
show_time_cost: bool = False
enable_metrics: bool = False
@@ -131,6 +130,7 @@ class ServerArgs:
enable_cache_report: bool = False
reasoning_parser: Optional[str] = None
tool_call_parser: Optional[str] = None
tool_server: Optional[str] = None
# Data parallelism
dp_size: int = 1
@@ -278,15 +278,11 @@ class ServerArgs:
enable_pdmux: bool = False
sm_group_num: int = 3
# For tool server
tool_server: Optional[str] = None
# Deprecated arguments
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
def __post_init__(self):
# Check deprecated arguments
def print_deprecated_warning(message: str):
logger.warning(f"\033[33m{message}\033[0m")
@@ -392,6 +388,9 @@ class ServerArgs:
self.attention_backend = "torch_native"
self.sampling_backend = "pytorch"
# Model-specific adjustments
self.model_specific_adjustments()
# Set kernel backends
if self.device == "cpu":
if self.attention_backend is None:
@@ -470,55 +469,9 @@ class ServerArgs:
"trtllm_mha backend does not support speculative decoding yet."
)
model_arch = self.get_hf_config().architectures[0]
if model_arch in ["GptOssForCausalLM"]:
if self.attention_backend is None:
# default is triton, but we could have trtllm_mha as an option
self.attention_backend = "triton"
assert (
self.attention_backend == "trtllm_mha"
or self.attention_backend == "triton"
)
quantization_config = getattr(
self.get_hf_config(), "quantization_config", None
)
is_mxfp4_quant_format = (
quantization_config is not None
and quantization_config.get("quant_method") == "mxfp4"
)
if is_sm100_supported() and is_mxfp4_quant_format:
self.enable_flashinfer_mxfp4_moe = True
self.enable_triton_kernel_moe = False
logger.info(
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
)
else:
if self.enable_triton_kernel_moe:
assert (
self.ep_size == 1
), "Triton kernel MoE is only supported when ep_size == 1"
if not self.enable_triton_kernel_moe and self.ep_size == 1:
self.enable_triton_kernel_moe = True
logger.info(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
)
self.disable_hybrid_swa_memory = True
if is_mxfp4_quant_format:
# use bf16 for mxfp4 triton kernels
self.dtype = "bfloat16"
if self.attention_backend == "dual_chunk_flash_attn":
logger.warning(
"Mixed chunk is disabled because of using dual chunk flash attention backend"
)
logger.warning(
"Radix cache is disabled because of using dual chunk flash attention backend"
)
logger.warning(
"Cuda graph is disabled because of using dual chunk flash attention backend"
"Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend"
)
self.enable_mixed_chunk = False
self.disable_cuda_graph = True
@@ -583,7 +536,7 @@ class ServerArgs:
if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
self.expert_distribution_recorder_mode = "stat"
logger.info(
logger.warning(
"EPLB is enabled. The expert_distribution_recorder_mode is automatically set."
)
@@ -591,9 +544,6 @@ class ServerArgs:
self.ep_dispatch_algorithm is None
):
self.ep_dispatch_algorithm = "static"
logger.info(
"EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
)
if self.enable_eplb:
assert self.ep_size > 1 or self.moe_a2a_backend is not None
@@ -1112,7 +1062,7 @@ class ServerArgs:
parser.add_argument(
"--log-requests-level",
type=int,
default=0,
default=ServerArgs.log_requests_level,
help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
choices=[0, 1, 2, 3],
)
@@ -1245,6 +1195,12 @@ class ServerArgs:
default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'.",
)
parser.add_argument(
"--tool-server",
type=str,
default=None,
help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
)
# Data parallelism
parser.add_argument(
@@ -1344,55 +1300,41 @@ class ServerArgs:
)
# Kernel backend
ATTN_BACKENDS = [
"aiter",
"cutlass_mla",
"fa3",
"flashinfer",
"flashmla",
"intel_amx",
"torch_native",
"ascend",
"triton",
"trtllm_mla",
"trtllm_mha",
"dual_chunk_flash_attn",
]
parser.add_argument(
"--attention-backend",
type=str,
choices=[
"aiter",
"cutlass_mla",
"fa3",
"flashinfer",
"flashmla",
"intel_amx",
"torch_native",
"ascend",
"triton",
"trtllm_mla",
"trtllm_mha",
"dual_chunk_flash_attn",
],
choices=ATTN_BACKENDS,
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
)
parser.add_argument(
"--decode-attention-backend",
type=str,
choices=[
"flashinfer",
"triton",
"torch_native",
"fa3",
"flashmla",
"cutlass_mla",
],
default=ServerArgs.decode_attention_backend,
help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
)
parser.add_argument(
"--prefill-attention-backend",
type=str,
choices=[
"flashinfer",
"triton",
"torch_native",
"fa3",
"flashmla",
"cutlass_mla",
],
choices=ATTN_BACKENDS,
default=ServerArgs.prefill_attention_backend,
help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
)
parser.add_argument(
"--decode-attention-backend",
type=str,
choices=ATTN_BACKENDS,
default=ServerArgs.decode_attention_backend,
help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
)
parser.add_argument(
"--sampling-backend",
type=str,
@@ -1612,7 +1554,6 @@ class ServerArgs:
default=ServerArgs.hicache_mem_layout,
help="The layout of host memory pool for hierarchical cache.",
)
parser.add_argument(
"--hicache-storage-backend",
type=str,
@@ -1985,14 +1926,6 @@ class ServerArgs:
help="Disable mmap while loading weight using safetensors.",
)
# For tool server
parser.add_argument(
"--tool-server",
type=str,
default=None,
help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
)
# Deprecated arguments
parser.add_argument(
"--enable-ep-moe",
@@ -2056,25 +1989,6 @@ class ServerArgs:
None,
}, "moe_dense_tp_size only support 1 and None currently"
# Check model architecture
model_arch = self.get_hf_config().architectures[0]
if "Llama4" in model_arch:
assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
if model_arch in [
"Gemma2ForCausalLM",
"Gemma3ForCausalLM",
"Gemma3ForConditionalGeneration",
"Gemma3nForCausalLM",
"Gemma3nForConditionalGeneration",
]:
# FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
# It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
logger.warning(
f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
)
self.disable_hybrid_swa_memory = True
# Check LoRA
self.check_lora_server_args()
@@ -2100,7 +2014,7 @@ class ServerArgs:
if self.lora_paths:
if self.enable_lora is None:
self.enable_lora = True
logger.info(
logger.warning(
"--enable-lora is set to True because --lora-paths is provided."
)
elif self.enable_lora is False:
@@ -2172,6 +2086,58 @@ class ServerArgs:
f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
)
def model_specific_adjustments(self):
hf_config = self.get_hf_config()
model_arch = hf_config.architectures[0]
if model_arch in ["GptOssForCausalLM"]:
if self.attention_backend is None:
self.attention_backend = "triton"
assert self.attention_backend in [
"triton",
"trtllm_mha",
], f"GptOssForCausalLM requires 'triton' or 'trtllm_mha' attention backend, but got {self.attention_backend}"
quantization_config = getattr(hf_config, "quantization_config", None)
is_mxfp4_quant_format = (
quantization_config is not None
and quantization_config.get("quant_method") == "mxfp4"
)
if is_sm100_supported() and is_mxfp4_quant_format:
self.enable_flashinfer_mxfp4_moe = True
self.enable_triton_kernel_moe = False
logger.warning(
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
)
else:
if self.enable_triton_kernel_moe:
assert (
self.ep_size == 1
), "Triton kernel MoE is only supported when ep_size == 1"
if not self.enable_triton_kernel_moe and self.ep_size == 1:
self.enable_triton_kernel_moe = True
logger.warning(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
)
self.disable_hybrid_swa_memory = True
if is_mxfp4_quant_format:
# use bf16 for mxfp4 triton kernels
self.dtype = "bfloat16"
elif "Llama4" in model_arch:
assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
elif model_arch in [
"Gemma2ForCausalLM",
"Gemma3ForCausalLM",
"Gemma3ForConditionalGeneration",
"Gemma3nForCausalLM",
"Gemma3nForConditionalGeneration",
]:
# FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
# It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
logger.warning(
f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
)
self.disable_hybrid_swa_memory = True
def adjust_mem_fraction_for_vlm(self, model_config):
vision_config = getattr(model_config.hf_config, "vision_config", None)
if vision_config is None: