Clean up server_args.py to have a dedicated function for model specific adjustments (#8983)
This commit is contained in:
@@ -37,7 +37,6 @@ from sglang.srt.utils import (
|
||||
is_hip,
|
||||
is_port_available,
|
||||
is_remote_url,
|
||||
is_triton_kernels_available,
|
||||
is_valid_ipv6_address,
|
||||
nullable_str,
|
||||
)
|
||||
@@ -109,7 +108,7 @@ class ServerArgs:
|
||||
log_level: str = "info"
|
||||
log_level_http: Optional[str] = None
|
||||
log_requests: bool = False
|
||||
log_requests_level: int = 0
|
||||
log_requests_level: int = 2
|
||||
crash_dump_folder: Optional[str] = None
|
||||
show_time_cost: bool = False
|
||||
enable_metrics: bool = False
|
||||
@@ -131,6 +130,7 @@ class ServerArgs:
|
||||
enable_cache_report: bool = False
|
||||
reasoning_parser: Optional[str] = None
|
||||
tool_call_parser: Optional[str] = None
|
||||
tool_server: Optional[str] = None
|
||||
|
||||
# Data parallelism
|
||||
dp_size: int = 1
|
||||
@@ -278,15 +278,11 @@ class ServerArgs:
|
||||
enable_pdmux: bool = False
|
||||
sm_group_num: int = 3
|
||||
|
||||
# For tool server
|
||||
tool_server: Optional[str] = None
|
||||
|
||||
# Deprecated arguments
|
||||
enable_ep_moe: bool = False
|
||||
enable_deepep_moe: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
# Check deprecated arguments
|
||||
def print_deprecated_warning(message: str):
|
||||
logger.warning(f"\033[33m{message}\033[0m")
|
||||
@@ -392,6 +388,9 @@ class ServerArgs:
|
||||
self.attention_backend = "torch_native"
|
||||
self.sampling_backend = "pytorch"
|
||||
|
||||
# Model-specific adjustments
|
||||
self.model_specific_adjustments()
|
||||
|
||||
# Set kernel backends
|
||||
if self.device == "cpu":
|
||||
if self.attention_backend is None:
|
||||
@@ -470,55 +469,9 @@ class ServerArgs:
|
||||
"trtllm_mha backend does not support speculative decoding yet."
|
||||
)
|
||||
|
||||
model_arch = self.get_hf_config().architectures[0]
|
||||
if model_arch in ["GptOssForCausalLM"]:
|
||||
if self.attention_backend is None:
|
||||
# default is triton, but we could have trtllm_mha as an option
|
||||
self.attention_backend = "triton"
|
||||
assert (
|
||||
self.attention_backend == "trtllm_mha"
|
||||
or self.attention_backend == "triton"
|
||||
)
|
||||
quantization_config = getattr(
|
||||
self.get_hf_config(), "quantization_config", None
|
||||
)
|
||||
is_mxfp4_quant_format = (
|
||||
quantization_config is not None
|
||||
and quantization_config.get("quant_method") == "mxfp4"
|
||||
)
|
||||
|
||||
if is_sm100_supported() and is_mxfp4_quant_format:
|
||||
self.enable_flashinfer_mxfp4_moe = True
|
||||
self.enable_triton_kernel_moe = False
|
||||
logger.info(
|
||||
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
|
||||
)
|
||||
else:
|
||||
if self.enable_triton_kernel_moe:
|
||||
assert (
|
||||
self.ep_size == 1
|
||||
), "Triton kernel MoE is only supported when ep_size == 1"
|
||||
if not self.enable_triton_kernel_moe and self.ep_size == 1:
|
||||
self.enable_triton_kernel_moe = True
|
||||
logger.info(
|
||||
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
|
||||
)
|
||||
|
||||
self.disable_hybrid_swa_memory = True
|
||||
|
||||
if is_mxfp4_quant_format:
|
||||
# use bf16 for mxfp4 triton kernels
|
||||
self.dtype = "bfloat16"
|
||||
|
||||
if self.attention_backend == "dual_chunk_flash_attn":
|
||||
logger.warning(
|
||||
"Mixed chunk is disabled because of using dual chunk flash attention backend"
|
||||
)
|
||||
logger.warning(
|
||||
"Radix cache is disabled because of using dual chunk flash attention backend"
|
||||
)
|
||||
logger.warning(
|
||||
"Cuda graph is disabled because of using dual chunk flash attention backend"
|
||||
"Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend"
|
||||
)
|
||||
self.enable_mixed_chunk = False
|
||||
self.disable_cuda_graph = True
|
||||
@@ -583,7 +536,7 @@ class ServerArgs:
|
||||
|
||||
if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
|
||||
self.expert_distribution_recorder_mode = "stat"
|
||||
logger.info(
|
||||
logger.warning(
|
||||
"EPLB is enabled. The expert_distribution_recorder_mode is automatically set."
|
||||
)
|
||||
|
||||
@@ -591,9 +544,6 @@ class ServerArgs:
|
||||
self.ep_dispatch_algorithm is None
|
||||
):
|
||||
self.ep_dispatch_algorithm = "static"
|
||||
logger.info(
|
||||
"EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
|
||||
)
|
||||
|
||||
if self.enable_eplb:
|
||||
assert self.ep_size > 1 or self.moe_a2a_backend is not None
|
||||
@@ -1112,7 +1062,7 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--log-requests-level",
|
||||
type=int,
|
||||
default=0,
|
||||
default=ServerArgs.log_requests_level,
|
||||
help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
|
||||
choices=[0, 1, 2, 3],
|
||||
)
|
||||
@@ -1245,6 +1195,12 @@ class ServerArgs:
|
||||
default=ServerArgs.tool_call_parser,
|
||||
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tool-server",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
|
||||
)
|
||||
|
||||
# Data parallelism
|
||||
parser.add_argument(
|
||||
@@ -1344,55 +1300,41 @@ class ServerArgs:
|
||||
)
|
||||
|
||||
# Kernel backend
|
||||
ATTN_BACKENDS = [
|
||||
"aiter",
|
||||
"cutlass_mla",
|
||||
"fa3",
|
||||
"flashinfer",
|
||||
"flashmla",
|
||||
"intel_amx",
|
||||
"torch_native",
|
||||
"ascend",
|
||||
"triton",
|
||||
"trtllm_mla",
|
||||
"trtllm_mha",
|
||||
"dual_chunk_flash_attn",
|
||||
]
|
||||
parser.add_argument(
|
||||
"--attention-backend",
|
||||
type=str,
|
||||
choices=[
|
||||
"aiter",
|
||||
"cutlass_mla",
|
||||
"fa3",
|
||||
"flashinfer",
|
||||
"flashmla",
|
||||
"intel_amx",
|
||||
"torch_native",
|
||||
"ascend",
|
||||
"triton",
|
||||
"trtllm_mla",
|
||||
"trtllm_mha",
|
||||
"dual_chunk_flash_attn",
|
||||
],
|
||||
choices=ATTN_BACKENDS,
|
||||
default=ServerArgs.attention_backend,
|
||||
help="Choose the kernels for attention layers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decode-attention-backend",
|
||||
type=str,
|
||||
choices=[
|
||||
"flashinfer",
|
||||
"triton",
|
||||
"torch_native",
|
||||
"fa3",
|
||||
"flashmla",
|
||||
"cutlass_mla",
|
||||
],
|
||||
default=ServerArgs.decode_attention_backend,
|
||||
help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prefill-attention-backend",
|
||||
type=str,
|
||||
choices=[
|
||||
"flashinfer",
|
||||
"triton",
|
||||
"torch_native",
|
||||
"fa3",
|
||||
"flashmla",
|
||||
"cutlass_mla",
|
||||
],
|
||||
choices=ATTN_BACKENDS,
|
||||
default=ServerArgs.prefill_attention_backend,
|
||||
help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decode-attention-backend",
|
||||
type=str,
|
||||
choices=ATTN_BACKENDS,
|
||||
default=ServerArgs.decode_attention_backend,
|
||||
help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sampling-backend",
|
||||
type=str,
|
||||
@@ -1612,7 +1554,6 @@ class ServerArgs:
|
||||
default=ServerArgs.hicache_mem_layout,
|
||||
help="The layout of host memory pool for hierarchical cache.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hicache-storage-backend",
|
||||
type=str,
|
||||
@@ -1985,14 +1926,6 @@ class ServerArgs:
|
||||
help="Disable mmap while loading weight using safetensors.",
|
||||
)
|
||||
|
||||
# For tool server
|
||||
parser.add_argument(
|
||||
"--tool-server",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
|
||||
)
|
||||
|
||||
# Deprecated arguments
|
||||
parser.add_argument(
|
||||
"--enable-ep-moe",
|
||||
@@ -2056,25 +1989,6 @@ class ServerArgs:
|
||||
None,
|
||||
}, "moe_dense_tp_size only support 1 and None currently"
|
||||
|
||||
# Check model architecture
|
||||
model_arch = self.get_hf_config().architectures[0]
|
||||
if "Llama4" in model_arch:
|
||||
assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
|
||||
|
||||
if model_arch in [
|
||||
"Gemma2ForCausalLM",
|
||||
"Gemma3ForCausalLM",
|
||||
"Gemma3ForConditionalGeneration",
|
||||
"Gemma3nForCausalLM",
|
||||
"Gemma3nForConditionalGeneration",
|
||||
]:
|
||||
# FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
|
||||
# It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
|
||||
logger.warning(
|
||||
f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
|
||||
)
|
||||
self.disable_hybrid_swa_memory = True
|
||||
|
||||
# Check LoRA
|
||||
self.check_lora_server_args()
|
||||
|
||||
@@ -2100,7 +2014,7 @@ class ServerArgs:
|
||||
if self.lora_paths:
|
||||
if self.enable_lora is None:
|
||||
self.enable_lora = True
|
||||
logger.info(
|
||||
logger.warning(
|
||||
"--enable-lora is set to True because --lora-paths is provided."
|
||||
)
|
||||
elif self.enable_lora is False:
|
||||
@@ -2172,6 +2086,58 @@ class ServerArgs:
|
||||
f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
|
||||
)
|
||||
|
||||
def model_specific_adjustments(self):
|
||||
hf_config = self.get_hf_config()
|
||||
model_arch = hf_config.architectures[0]
|
||||
if model_arch in ["GptOssForCausalLM"]:
|
||||
if self.attention_backend is None:
|
||||
self.attention_backend = "triton"
|
||||
assert self.attention_backend in [
|
||||
"triton",
|
||||
"trtllm_mha",
|
||||
], f"GptOssForCausalLM requires 'triton' or 'trtllm_mha' attention backend, but got {self.attention_backend}"
|
||||
quantization_config = getattr(hf_config, "quantization_config", None)
|
||||
is_mxfp4_quant_format = (
|
||||
quantization_config is not None
|
||||
and quantization_config.get("quant_method") == "mxfp4"
|
||||
)
|
||||
|
||||
if is_sm100_supported() and is_mxfp4_quant_format:
|
||||
self.enable_flashinfer_mxfp4_moe = True
|
||||
self.enable_triton_kernel_moe = False
|
||||
logger.warning(
|
||||
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
|
||||
)
|
||||
else:
|
||||
if self.enable_triton_kernel_moe:
|
||||
assert (
|
||||
self.ep_size == 1
|
||||
), "Triton kernel MoE is only supported when ep_size == 1"
|
||||
if not self.enable_triton_kernel_moe and self.ep_size == 1:
|
||||
self.enable_triton_kernel_moe = True
|
||||
logger.warning(
|
||||
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
|
||||
)
|
||||
self.disable_hybrid_swa_memory = True
|
||||
if is_mxfp4_quant_format:
|
||||
# use bf16 for mxfp4 triton kernels
|
||||
self.dtype = "bfloat16"
|
||||
elif "Llama4" in model_arch:
|
||||
assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
|
||||
elif model_arch in [
|
||||
"Gemma2ForCausalLM",
|
||||
"Gemma3ForCausalLM",
|
||||
"Gemma3ForConditionalGeneration",
|
||||
"Gemma3nForCausalLM",
|
||||
"Gemma3nForConditionalGeneration",
|
||||
]:
|
||||
# FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
|
||||
# It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
|
||||
logger.warning(
|
||||
f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
|
||||
)
|
||||
self.disable_hybrid_swa_memory = True
|
||||
|
||||
def adjust_mem_fraction_for_vlm(self, model_config):
|
||||
vision_config = getattr(model_config.hf_config, "vision_config", None)
|
||||
if vision_config is None:
|
||||
|
||||
Reference in New Issue
Block a user