Clean up server_args.py to have a dedicated function for model specific adjustments (#8983)
@@ -21,6 +21,7 @@ runtime_common = [
    "build",
    "compressed-tensors",
    "datasets",
    "einops",
    "fastapi",
    "hf_transfer",
    "huggingface_hub",
@@ -29,6 +30,7 @@ runtime_common = [
    "modelscope",
    "msgspec",
    "ninja",
    "openai==1.99.1",
    "openai-harmony==0.0.3",
    "orjson",
    "outlines==0.1.11",
@@ -48,6 +50,7 @@ runtime_common = [
    "torchao==0.9.0",
    "transformers==4.55.0",
    "timm==1.0.16",
    "tiktoken",
    "uvicorn",
    "uvloop",
    "xgrammar==0.1.22",
@@ -60,7 +63,6 @@ srt = [
    "torchaudio==2.8.0",
    "torchvision",
    "cuda-python",
    "einops",
    "flashinfer_python==0.2.10",
]

@@ -71,10 +73,7 @@ blackwell = [
    "torchaudio==2.8.0",
    "torchvision",
    "cuda-python",
    "einops",
    "flashinfer_python==0.2.10",
    "tiktoken",
    "openai==1.99.1",
]

# HIP (Heterogeneous-computing Interface for Portability) for AMD
@@ -101,7 +100,7 @@ srt_npu = ["sglang[runtime_common]"]
openai = ["openai==1.99.1", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
torch_memory_saver = ["torch_memory_saver>=0.0.8"]
torch_memory_saver = ["torch_memory_saver==0.0.8"]
decord = ["decord"]
test = [
    "accelerate",
@@ -64,13 +64,12 @@ class ModelConfig:
        hybrid_kvcache_ratio: Optional[float] = None,
        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
    ) -> None:

        # Parse args
        self.model_path = model_path
        self.revision = revision
        self.quantization = quantization
        self.model_impl = model_impl

        # Parse args
        self.maybe_pull_model_tokenizer_from_remote()
        self.model_override_args = json.loads(model_override_args)
        kwargs = {}
@@ -139,6 +138,7 @@ class ModelConfig:
            and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
        ):
            self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"

        # Check model type
        self.is_generation = is_generation_model(
            self.hf_config.architectures, is_embedding
@@ -282,12 +282,10 @@ class ModelConfig:
        # Cache attributes
        self.hf_eos_token_id = self.get_hf_eos_token_id()

        config = self.hf_config

        # multimodal
        self.image_token_id = getattr(config, "image_token_id", None) or getattr(
            config, "image_token_index", None
        )
        self.image_token_id = getattr(
            self.hf_config, "image_token_id", None
        ) or getattr(self.hf_config, "image_token_index", None)

    @staticmethod
    def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
@@ -9,8 +9,8 @@ logger = logging.getLogger(__name__)

try:
    from mcp import ClientSession
except ImportError:
    logger.warning("Ignoring mcp import error")
except ImportError as e:
    mcp = e

from openai_harmony import Author, Message, Role, StreamState, TextContent
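For context on the import hunk above: stashing the ImportError (mcp = e) defers the failure until the optional dependency is actually needed, instead of only logging it at import time. A minimal, self-contained sketch of that pattern; the require_mcp helper and its error message are illustrative and not part of this change:

# Deferred optional import: remember the ImportError instead of failing at import time.
try:
    import mcp  # optional dependency
    from mcp import ClientSession  # noqa: F401
except ImportError as e:
    mcp = e  # later code can check this before using the module


def require_mcp() -> None:
    # Hypothetical guard; raise only when the feature that needs mcp is used.
    if isinstance(mcp, ImportError):
        raise RuntimeError("the 'mcp' package is required for this feature") from mcp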
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, Union
from typing import Optional, Union

import torch

@@ -3,7 +3,7 @@ from __future__ import annotations

import builtins
import inspect
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
from typing import TYPE_CHECKING, Dict, Optional, Type

import torch
@@ -37,7 +37,6 @@ from sglang.srt.utils import (
    is_hip,
    is_port_available,
    is_remote_url,
    is_triton_kernels_available,
    is_valid_ipv6_address,
    nullable_str,
)
@@ -109,7 +108,7 @@ class ServerArgs:
    log_level: str = "info"
    log_level_http: Optional[str] = None
    log_requests: bool = False
    log_requests_level: int = 0
    log_requests_level: int = 2
    crash_dump_folder: Optional[str] = None
    show_time_cost: bool = False
    enable_metrics: bool = False
@@ -131,6 +130,7 @@ class ServerArgs:
    enable_cache_report: bool = False
    reasoning_parser: Optional[str] = None
    tool_call_parser: Optional[str] = None
    tool_server: Optional[str] = None

    # Data parallelism
    dp_size: int = 1
@@ -278,15 +278,11 @@ class ServerArgs:
    enable_pdmux: bool = False
    sm_group_num: int = 3

    # For tool server
    tool_server: Optional[str] = None

    # Deprecated arguments
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False

    def __post_init__(self):

        # Check deprecated arguments
        def print_deprecated_warning(message: str):
            logger.warning(f"\033[33m{message}\033[0m")
@@ -392,6 +388,9 @@ class ServerArgs:
            self.attention_backend = "torch_native"
            self.sampling_backend = "pytorch"

        # Model-specific adjustments
        self.model_specific_adjustments()

        # Set kernel backends
        if self.device == "cpu":
            if self.attention_backend is None:
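The hunk above adds the single call that replaces the inlined per-model logic removed below and re-added as model_specific_adjustments() at the end of this diff. A minimal standalone sketch of the same delegation pattern; the class and field names here are illustrative, not sglang's actual definitions:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    # Illustrative fields only; the real ServerArgs has many more.
    model_arch: str = "GptOssForCausalLM"
    attention_backend: Optional[str] = None

    def __post_init__(self):
        # ... generic defaults and validation stay here ...
        self.model_specific_adjustments()

    def model_specific_adjustments(self):
        # All per-architecture overrides live in one dedicated method.
        if self.model_arch == "GptOssForCausalLM" and self.attention_backend is None:
            self.attention_backend = "triton"


print(Args().attention_backend)  # -> triton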
@@ -470,55 +469,9 @@ class ServerArgs:
                "trtllm_mha backend does not support speculative decoding yet."
            )

        model_arch = self.get_hf_config().architectures[0]
        if model_arch in ["GptOssForCausalLM"]:
            if self.attention_backend is None:
                # default is triton, but we could have trtllm_mha as an option
                self.attention_backend = "triton"
            assert (
                self.attention_backend == "trtllm_mha"
                or self.attention_backend == "triton"
            )
            quantization_config = getattr(
                self.get_hf_config(), "quantization_config", None
            )
            is_mxfp4_quant_format = (
                quantization_config is not None
                and quantization_config.get("quant_method") == "mxfp4"
            )

            if is_sm100_supported() and is_mxfp4_quant_format:
                self.enable_flashinfer_mxfp4_moe = True
                self.enable_triton_kernel_moe = False
                logger.info(
                    "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
                )
            else:
                if self.enable_triton_kernel_moe:
                    assert (
                        self.ep_size == 1
                    ), "Triton kernel MoE is only supported when ep_size == 1"
                if not self.enable_triton_kernel_moe and self.ep_size == 1:
                    self.enable_triton_kernel_moe = True
                    logger.info(
                        "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
                    )

            self.disable_hybrid_swa_memory = True

            if is_mxfp4_quant_format:
                # use bf16 for mxfp4 triton kernels
                self.dtype = "bfloat16"

        if self.attention_backend == "dual_chunk_flash_attn":
            logger.warning(
                "Mixed chunk is disabled because of using dual chunk flash attention backend"
            )
            logger.warning(
                "Radix cache is disabled because of using dual chunk flash attention backend"
            )
            logger.warning(
                "Cuda graph is disabled because of using dual chunk flash attention backend"
                "Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend"
            )
            self.enable_mixed_chunk = False
            self.disable_cuda_graph = True
@@ -583,7 +536,7 @@ class ServerArgs:

        if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
            self.expert_distribution_recorder_mode = "stat"
            logger.info(
            logger.warning(
                "EPLB is enabled. The expert_distribution_recorder_mode is automatically set."
            )

@@ -591,9 +544,6 @@ class ServerArgs:
            self.ep_dispatch_algorithm is None
        ):
            self.ep_dispatch_algorithm = "static"
            logger.info(
                "EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
            )

        if self.enable_eplb:
            assert self.ep_size > 1 or self.moe_a2a_backend is not None
@@ -1112,7 +1062,7 @@ class ServerArgs:
        parser.add_argument(
            "--log-requests-level",
            type=int,
            default=0,
            default=ServerArgs.log_requests_level,
            help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
            choices=[0, 1, 2, 3],
        )
@@ -1245,6 +1195,12 @@ class ServerArgs:
            default=ServerArgs.tool_call_parser,
            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'.",
        )
        parser.add_argument(
            "--tool-server",
            type=str,
            default=None,
            help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
        )

        # Data parallelism
        parser.add_argument(
@@ -1344,55 +1300,41 @@ class ServerArgs:
        )

        # Kernel backend
        ATTN_BACKENDS = [
            "aiter",
            "cutlass_mla",
            "fa3",
            "flashinfer",
            "flashmla",
            "intel_amx",
            "torch_native",
            "ascend",
            "triton",
            "trtllm_mla",
            "trtllm_mha",
            "dual_chunk_flash_attn",
        ]
        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=[
                "aiter",
                "cutlass_mla",
                "fa3",
                "flashinfer",
                "flashmla",
                "intel_amx",
                "torch_native",
                "ascend",
                "triton",
                "trtllm_mla",
                "trtllm_mha",
                "dual_chunk_flash_attn",
            ],
            choices=ATTN_BACKENDS,
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--decode-attention-backend",
            type=str,
            choices=[
                "flashinfer",
                "triton",
                "torch_native",
                "fa3",
                "flashmla",
                "cutlass_mla",
            ],
            default=ServerArgs.decode_attention_backend,
            help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
        )

        parser.add_argument(
            "--prefill-attention-backend",
            type=str,
            choices=[
                "flashinfer",
                "triton",
                "torch_native",
                "fa3",
                "flashmla",
                "cutlass_mla",
            ],
            choices=ATTN_BACKENDS,
            default=ServerArgs.prefill_attention_backend,
            help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
        )
        parser.add_argument(
            "--decode-attention-backend",
            type=str,
            choices=ATTN_BACKENDS,
            default=ServerArgs.decode_attention_backend,
            help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
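The hunk above replaces the repeated hard-coded choices=[...] lists with one shared ATTN_BACKENDS constant, so the attention, prefill, and decode flags cannot drift apart. A small standalone sketch of that argparse pattern; the backend list is abbreviated and the loop is only illustrative, the real code spells out each add_argument call:

import argparse

# Abbreviated; the real list carries all supported backends.
ATTN_BACKENDS = ["triton", "flashinfer", "fa3", "torch_native"]

parser = argparse.ArgumentParser()
for flag in (
    "--attention-backend",
    "--prefill-attention-backend",
    "--decode-attention-backend",
):
    # One shared choices list keeps the three flags consistent.
    parser.add_argument(flag, type=str, choices=ATTN_BACKENDS, default=None)

args = parser.parse_args(["--attention-backend", "triton"])
print(args.attention_backend)  # -> triton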
@@ -1612,7 +1554,6 @@ class ServerArgs:
            default=ServerArgs.hicache_mem_layout,
            help="The layout of host memory pool for hierarchical cache.",
        )

        parser.add_argument(
            "--hicache-storage-backend",
            type=str,
@@ -1985,14 +1926,6 @@ class ServerArgs:
            help="Disable mmap while loading weight using safetensors.",
        )

        # For tool server
        parser.add_argument(
            "--tool-server",
            type=str,
            default=None,
            help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
        )

        # Deprecated arguments
        parser.add_argument(
            "--enable-ep-moe",
@@ -2056,25 +1989,6 @@ class ServerArgs:
            None,
        }, "moe_dense_tp_size only support 1 and None currently"

        # Check model architecture
        model_arch = self.get_hf_config().architectures[0]
        if "Llama4" in model_arch:
            assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"

        if model_arch in [
            "Gemma2ForCausalLM",
            "Gemma3ForCausalLM",
            "Gemma3ForConditionalGeneration",
            "Gemma3nForCausalLM",
            "Gemma3nForConditionalGeneration",
        ]:
            # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
            # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
            logger.warning(
                f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
            )
            self.disable_hybrid_swa_memory = True

        # Check LoRA
        self.check_lora_server_args()
@@ -2100,7 +2014,7 @@ class ServerArgs:
        if self.lora_paths:
            if self.enable_lora is None:
                self.enable_lora = True
                logger.info(
                logger.warning(
                    "--enable-lora is set to True because --lora-paths is provided."
                )
            elif self.enable_lora is False:
@@ -2172,6 +2086,58 @@ class ServerArgs:
                f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
            )

    def model_specific_adjustments(self):
        hf_config = self.get_hf_config()
        model_arch = hf_config.architectures[0]
        if model_arch in ["GptOssForCausalLM"]:
            if self.attention_backend is None:
                self.attention_backend = "triton"
            assert self.attention_backend in [
                "triton",
                "trtllm_mha",
            ], f"GptOssForCausalLM requires 'triton' or 'trtllm_mha' attention backend, but got {self.attention_backend}"
            quantization_config = getattr(hf_config, "quantization_config", None)
            is_mxfp4_quant_format = (
                quantization_config is not None
                and quantization_config.get("quant_method") == "mxfp4"
            )

            if is_sm100_supported() and is_mxfp4_quant_format:
                self.enable_flashinfer_mxfp4_moe = True
                self.enable_triton_kernel_moe = False
                logger.warning(
                    "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
                )
            else:
                if self.enable_triton_kernel_moe:
                    assert (
                        self.ep_size == 1
                    ), "Triton kernel MoE is only supported when ep_size == 1"
                if not self.enable_triton_kernel_moe and self.ep_size == 1:
                    self.enable_triton_kernel_moe = True
                    logger.warning(
                        "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
                    )
            self.disable_hybrid_swa_memory = True
            if is_mxfp4_quant_format:
                # use bf16 for mxfp4 triton kernels
                self.dtype = "bfloat16"
        elif "Llama4" in model_arch:
            assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
        elif model_arch in [
            "Gemma2ForCausalLM",
            "Gemma3ForCausalLM",
            "Gemma3ForConditionalGeneration",
            "Gemma3nForCausalLM",
            "Gemma3nForConditionalGeneration",
        ]:
            # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
            # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
            logger.warning(
                f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
            )
            self.disable_hybrid_swa_memory = True

    def adjust_mem_fraction_for_vlm(self, model_config):
        vision_config = getattr(model_config.hf_config, "vision_config", None)
        if vision_config is None: