Clean up server_args.py to have a dedicated function for model specific adjustments (#8983)

Lianmin Zheng
2025-08-08 19:56:50 -07:00
committed by GitHub
parent 23f2afb2ce
commit 706bd69cc5
24 changed files with 201 additions and 340 deletions
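The heart of the change is organizational: the per-model branches that previously sat inline in ServerArgs.__post_init__ move into a single model_specific_adjustments() method keyed on the HuggingFace architecture string. A minimal sketch of the resulting shape, with simplified field names that are illustrative rather than taken from server_args.py (the real hunks appear further below):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyServerArgs:
    # Illustrative subset of fields; the real class has many more.
    attention_backend: Optional[str] = None
    disable_hybrid_swa_memory: bool = False

    def __post_init__(self):
        # Before this commit, per-model if/elif blocks lived inline here.
        # After, a single dedicated hook keeps __post_init__ readable.
        self.model_specific_adjustments()

    def model_specific_adjustments(self):
        model_arch = self.get_model_arch()
        if model_arch == "GptOssForCausalLM":
            if self.attention_backend is None:
                self.attention_backend = "triton"
        elif model_arch.startswith("Gemma"):
            self.disable_hybrid_swa_memory = True

    def get_model_arch(self) -> str:
        # Stand-in for get_hf_config().architectures[0] in the real code.
        return "GptOssForCausalLM"

print(ToyServerArgs().attention_backend)  # triton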

View File

@@ -21,6 +21,7 @@ runtime_common = [
"build",
"compressed-tensors",
"datasets",
"einops",
"fastapi",
"hf_transfer",
"huggingface_hub",
@@ -29,6 +30,7 @@ runtime_common = [
"modelscope",
"msgspec",
"ninja",
"openai==1.99.1",
"openai-harmony==0.0.3",
"orjson",
"outlines==0.1.11",
@@ -48,6 +50,7 @@ runtime_common = [
"torchao==0.9.0",
"transformers==4.55.0",
"timm==1.0.16",
"tiktoken",
"uvicorn",
"uvloop",
"xgrammar==0.1.22",
@@ -60,7 +63,6 @@ srt = [
"torchaudio==2.8.0",
"torchvision",
"cuda-python",
"einops",
"flashinfer_python==0.2.10",
]
@@ -71,10 +73,7 @@ blackwell = [
"torchaudio==2.8.0",
"torchvision",
"cuda-python",
"einops",
"flashinfer_python==0.2.10",
"tiktoken",
"openai==1.99.1",
]
# HIP (Heterogeneous-computing Interface for Portability) for AMD
@@ -101,7 +100,7 @@ srt_npu = ["sglang[runtime_common]"]
openai = ["openai==1.99.1", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
torch_memory_saver = ["torch_memory_saver>=0.0.8"]
torch_memory_saver = ["torch_memory_saver==0.0.8"]
decord = ["decord"]
test = [
"accelerate",

View File

@@ -64,13 +64,12 @@ class ModelConfig:
hybrid_kvcache_ratio: Optional[float] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
+ # Parse args
self.model_path = model_path
self.revision = revision
self.quantization = quantization
self.model_impl = model_impl
- # Parse args
self.maybe_pull_model_tokenizer_from_remote()
self.model_override_args = json.loads(model_override_args)
kwargs = {}
@@ -139,6 +138,7 @@ class ModelConfig:
and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
):
self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"
+ # Check model type
self.is_generation = is_generation_model(
self.hf_config.architectures, is_embedding
@@ -282,12 +282,10 @@ class ModelConfig:
# Cache attributes
self.hf_eos_token_id = self.get_hf_eos_token_id()
config = self.hf_config
# multimodal
- self.image_token_id = getattr(config, "image_token_id", None) or getattr(
- config, "image_token_index", None
- )
+ self.image_token_id = getattr(
+ self.hf_config, "image_token_id", None
+ ) or getattr(self.hf_config, "image_token_index", None)
@staticmethod
def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
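The image_token_id lookup in this hunk falls back across the two attribute names that HuggingFace configs use for the same value. A small sketch of that fallback with a stand-in config object (the token id below is an arbitrary example):

from types import SimpleNamespace

# Some HF configs expose image_token_id, others image_token_index; the
# `or` chain picks whichever one is present.
hf_config = SimpleNamespace(image_token_index=12345)  # arbitrary example value
image_token_id = getattr(hf_config, "image_token_id", None) or getattr(
    hf_config, "image_token_index", None
)
print(image_token_id)  # 12345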

View File

@@ -9,8 +9,8 @@ logger = logging.getLogger(__name__)
try:
from mcp import ClientSession
- except ImportError:
- logger.warning("Ignoring mcp import error")
+ except ImportError as e:
+ mcp = e
from openai_harmony import Author, Message, Role, StreamState, TextContent
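Rather than only logging when the optional mcp package is missing, the hunk above keeps the ImportError around so the failure can surface where MCP functionality is actually used. A self-contained sketch of that deferred-import-error pattern; the helper and variable names here are illustrative, not taken from this commit:

try:
    from mcp import ClientSession  # optional dependency
    _mcp_import_error = None
except ImportError as import_error:
    ClientSession = None
    _mcp_import_error = import_error

def require_mcp():
    # Fail loudly only when an MCP tool server is actually requested,
    # instead of warning at import time and failing obscurely later.
    if _mcp_import_error is not None:
        raise RuntimeError("the 'mcp' package is required for tool servers") from _mcp_import_error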

View File

@@ -1,4 +1,4 @@
- from typing import TYPE_CHECKING, Optional, Union
+ from typing import Optional, Union
import torch

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import builtins
import inspect
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
+ from typing import TYPE_CHECKING, Dict, Optional, Type
import torch

View File

@@ -37,7 +37,6 @@ from sglang.srt.utils import (
is_hip,
is_port_available,
is_remote_url,
- is_triton_kernels_available,
is_valid_ipv6_address,
nullable_str,
)
@@ -109,7 +108,7 @@ class ServerArgs:
log_level: str = "info"
log_level_http: Optional[str] = None
log_requests: bool = False
- log_requests_level: int = 0
+ log_requests_level: int = 2
crash_dump_folder: Optional[str] = None
show_time_cost: bool = False
enable_metrics: bool = False
@@ -131,6 +130,7 @@ class ServerArgs:
enable_cache_report: bool = False
reasoning_parser: Optional[str] = None
tool_call_parser: Optional[str] = None
+ tool_server: Optional[str] = None
# Data parallelism
dp_size: int = 1
@@ -278,15 +278,11 @@ class ServerArgs:
enable_pdmux: bool = False
sm_group_num: int = 3
- # For tool server
- tool_server: Optional[str] = None
# Deprecated arguments
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
def __post_init__(self):
# Check deprecated arguments
def print_deprecated_warning(message: str):
logger.warning(f"\033[33m{message}\033[0m")
@@ -392,6 +388,9 @@ class ServerArgs:
self.attention_backend = "torch_native"
self.sampling_backend = "pytorch"
+ # Model-specific adjustments
+ self.model_specific_adjustments()
# Set kernel backends
if self.device == "cpu":
if self.attention_backend is None:
@@ -470,55 +469,9 @@ class ServerArgs:
"trtllm_mha backend does not support speculative decoding yet."
)
- model_arch = self.get_hf_config().architectures[0]
- if model_arch in ["GptOssForCausalLM"]:
- if self.attention_backend is None:
- # default is triton, but we could have trtllm_mha as an option
- self.attention_backend = "triton"
- assert (
- self.attention_backend == "trtllm_mha"
- or self.attention_backend == "triton"
- )
- quantization_config = getattr(
- self.get_hf_config(), "quantization_config", None
- )
- is_mxfp4_quant_format = (
- quantization_config is not None
- and quantization_config.get("quant_method") == "mxfp4"
- )
- if is_sm100_supported() and is_mxfp4_quant_format:
- self.enable_flashinfer_mxfp4_moe = True
- self.enable_triton_kernel_moe = False
- logger.info(
- "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
- )
- else:
- if self.enable_triton_kernel_moe:
- assert (
- self.ep_size == 1
- ), "Triton kernel MoE is only supported when ep_size == 1"
- if not self.enable_triton_kernel_moe and self.ep_size == 1:
- self.enable_triton_kernel_moe = True
- logger.info(
- "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
- )
- self.disable_hybrid_swa_memory = True
- if is_mxfp4_quant_format:
- # use bf16 for mxfp4 triton kernels
- self.dtype = "bfloat16"
if self.attention_backend == "dual_chunk_flash_attn":
logger.warning(
"Mixed chunk is disabled because of using dual chunk flash attention backend"
)
logger.warning(
"Radix cache is disabled because of using dual chunk flash attention backend"
)
logger.warning(
"Cuda graph is disabled because of using dual chunk flash attention backend"
"Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend"
)
self.enable_mixed_chunk = False
self.disable_cuda_graph = True
@@ -583,7 +536,7 @@ class ServerArgs:
if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
self.expert_distribution_recorder_mode = "stat"
- logger.info(
+ logger.warning(
"EPLB is enabled. The expert_distribution_recorder_mode is automatically set."
)
@@ -591,9 +544,6 @@ class ServerArgs:
self.ep_dispatch_algorithm is None
):
self.ep_dispatch_algorithm = "static"
- logger.info(
- "EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
- )
if self.enable_eplb:
assert self.ep_size > 1 or self.moe_a2a_backend is not None
@@ -1112,7 +1062,7 @@ class ServerArgs:
parser.add_argument(
"--log-requests-level",
type=int,
- default=0,
+ default=ServerArgs.log_requests_level,
help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
choices=[0, 1, 2, 3],
)
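With default=ServerArgs.log_requests_level, the CLI default is read from the dataclass field instead of a duplicated literal, so changing the field (0 to 2 in this commit) automatically changes the flag. A minimal sketch of the pattern with an illustrative class name:

import argparse
from dataclasses import dataclass

@dataclass
class DemoArgs:
    log_requests_level: int = 2  # single source of truth for the default

parser = argparse.ArgumentParser()
parser.add_argument(
    "--log-requests-level",
    type=int,
    default=DemoArgs.log_requests_level,  # no duplicated literal to drift
    choices=[0, 1, 2, 3],
)
print(parser.parse_args([]).log_requests_level)  # 2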
@@ -1245,6 +1195,12 @@ class ServerArgs:
default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'.",
)
+ parser.add_argument(
+ "--tool-server",
+ type=str,
+ default=None,
+ help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
+ )
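Per the help text, --tool-server accepts either the literal 'demo' or a comma-separated list of URLs. How the value is consumed is not shown in this diff; the parsing below is only an assumed sketch of what the help text describes:

def parse_tool_server(value):
    # Assumed interpretation of the --tool-server help text above;
    # not code from this commit.
    if value is None:
        return None  # no tool server configured
    if value == "demo":
        return "demo"  # built-in demo tool server
    return [url.strip() for url in value.split(",") if url.strip()]

print(parse_tool_server("http://tools-a:8000,http://tools-b:8000"))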
# Data parallelism
parser.add_argument(
@@ -1344,55 +1300,41 @@ class ServerArgs:
)
# Kernel backend
+ ATTN_BACKENDS = [
+ "aiter",
+ "cutlass_mla",
+ "fa3",
+ "flashinfer",
+ "flashmla",
+ "intel_amx",
+ "torch_native",
+ "ascend",
+ "triton",
+ "trtllm_mla",
+ "trtllm_mha",
+ "dual_chunk_flash_attn",
+ ]
parser.add_argument(
"--attention-backend",
type=str,
- choices=[
- "aiter",
- "cutlass_mla",
- "fa3",
- "flashinfer",
- "flashmla",
- "intel_amx",
- "torch_native",
- "ascend",
- "triton",
- "trtllm_mla",
- "trtllm_mha",
- "dual_chunk_flash_attn",
- ],
+ choices=ATTN_BACKENDS,
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
)
parser.add_argument(
"--decode-attention-backend",
type=str,
choices=[
"flashinfer",
"triton",
"torch_native",
"fa3",
"flashmla",
"cutlass_mla",
],
default=ServerArgs.decode_attention_backend,
help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
)
parser.add_argument(
"--prefill-attention-backend",
type=str,
- choices=[
- "flashinfer",
- "triton",
- "torch_native",
- "fa3",
- "flashmla",
- "cutlass_mla",
- ],
+ choices=ATTN_BACKENDS,
default=ServerArgs.prefill_attention_backend,
help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
)
+ parser.add_argument(
+ "--decode-attention-backend",
+ type=str,
+ choices=ATTN_BACKENDS,
+ default=ServerArgs.decode_attention_backend,
+ help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
+ )
parser.add_argument(
"--sampling-backend",
type=str,
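The help strings above state that --prefill-attention-backend and --decode-attention-backend take priority over --attention-backend. A small sketch of that resolution order; the helper name is an assumption, not part of the commit:

def resolve_attention_backends(attention_backend, prefill_backend, decode_backend):
    # Phase-specific flags win; fall back to the generic --attention-backend.
    prefill = prefill_backend or attention_backend
    decode = decode_backend or attention_backend
    return prefill, decode

print(resolve_attention_backends("triton", None, "flashinfer"))  # ('triton', 'flashinfer')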
@@ -1612,7 +1554,6 @@ class ServerArgs:
default=ServerArgs.hicache_mem_layout,
help="The layout of host memory pool for hierarchical cache.",
)
parser.add_argument(
"--hicache-storage-backend",
type=str,
@@ -1985,14 +1926,6 @@ class ServerArgs:
help="Disable mmap while loading weight using safetensors.",
)
- # For tool server
- parser.add_argument(
- "--tool-server",
- type=str,
- default=None,
- help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
- )
# Deprecated arguments
parser.add_argument(
"--enable-ep-moe",
@@ -2056,25 +1989,6 @@ class ServerArgs:
None,
}, "moe_dense_tp_size only support 1 and None currently"
- # Check model architecture
- model_arch = self.get_hf_config().architectures[0]
- if "Llama4" in model_arch:
- assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
- if model_arch in [
- "Gemma2ForCausalLM",
- "Gemma3ForCausalLM",
- "Gemma3ForConditionalGeneration",
- "Gemma3nForCausalLM",
- "Gemma3nForConditionalGeneration",
- ]:
- # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
- # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
- logger.warning(
- f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
- )
- self.disable_hybrid_swa_memory = True
# Check LoRA
self.check_lora_server_args()
@@ -2100,7 +2014,7 @@ class ServerArgs:
if self.lora_paths:
if self.enable_lora is None:
self.enable_lora = True
- logger.info(
+ logger.warning(
"--enable-lora is set to True because --lora-paths is provided."
)
elif self.enable_lora is False:
@@ -2172,6 +2086,58 @@ class ServerArgs:
f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
)
+ def model_specific_adjustments(self):
+ hf_config = self.get_hf_config()
+ model_arch = hf_config.architectures[0]
+ if model_arch in ["GptOssForCausalLM"]:
+ if self.attention_backend is None:
+ self.attention_backend = "triton"
+ assert self.attention_backend in [
+ "triton",
+ "trtllm_mha",
+ ], f"GptOssForCausalLM requires 'triton' or 'trtllm_mha' attention backend, but got {self.attention_backend}"
+ quantization_config = getattr(hf_config, "quantization_config", None)
+ is_mxfp4_quant_format = (
+ quantization_config is not None
+ and quantization_config.get("quant_method") == "mxfp4"
+ )
+ if is_sm100_supported() and is_mxfp4_quant_format:
+ self.enable_flashinfer_mxfp4_moe = True
+ self.enable_triton_kernel_moe = False
+ logger.warning(
+ "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
+ )
+ else:
+ if self.enable_triton_kernel_moe:
+ assert (
+ self.ep_size == 1
+ ), "Triton kernel MoE is only supported when ep_size == 1"
+ if not self.enable_triton_kernel_moe and self.ep_size == 1:
+ self.enable_triton_kernel_moe = True
+ logger.warning(
+ "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
+ )
+ self.disable_hybrid_swa_memory = True
+ if is_mxfp4_quant_format:
+ # use bf16 for mxfp4 triton kernels
+ self.dtype = "bfloat16"
+ elif "Llama4" in model_arch:
+ assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
+ elif model_arch in [
+ "Gemma2ForCausalLM",
+ "Gemma3ForCausalLM",
+ "Gemma3ForConditionalGeneration",
+ "Gemma3nForCausalLM",
+ "Gemma3nForConditionalGeneration",
+ ]:
+ # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
+ # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
+ logger.warning(
+ f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
+ )
+ self.disable_hybrid_swa_memory = True
def adjust_mem_fraction_for_vlm(self, model_config):
vision_config = getattr(model_config.hf_config, "vision_config", None)
if vision_config is None:
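With the dispatch now centralized, supporting another architecture means appending one more elif on the architecture string inside model_specific_adjustments(). A hedged sketch of such an extension; the architecture name and the chosen backend below are hypothetical:

def model_specific_adjustments(self):
    hf_config = self.get_hf_config()
    model_arch = hf_config.architectures[0]
    if model_arch in ["GptOssForCausalLM"]:
        ...  # existing GPT-OSS handling shown above
    elif model_arch in ["HypotheticalNewForCausalLM"]:  # hypothetical example
        if self.attention_backend is None:
            self.attention_backend = "fa3"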