diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 757ae295a..9ed2b5177 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -48,6 +48,80 @@ from sglang.srt.utils import (
 
 logger = logging.getLogger(__name__)
 
+# Define constants
+LOAD_FORMAT_CHOICES = [
+    "auto",
+    "pt",
+    "safetensors",
+    "npcache",
+    "dummy",
+    "sharded_state",
+    "gguf",
+    "bitsandbytes",
+    "layered",
+    "remote",
+]
+
+QUANTIZATION_CHOICES = [
+    "awq",
+    "fp8",
+    "gptq",
+    "marlin",
+    "gptq_marlin",
+    "awq_marlin",
+    "bitsandbytes",
+    "gguf",
+    "modelopt",
+    "modelopt_fp4",
+    "petit_nvfp4",
+    "w8a8_int8",
+    "w8a8_fp8",
+    "moe_wna16",
+    "qoq",
+    "w4afp8",
+    "mxfp4",
+]
+
+ATTENTION_BACKEND_CHOICES = [
+    # Common
+    "triton",
+    "torch_native",
+    # NVIDIA specific
+    "cutlass_mla",
+    "fa3",
+    "flashinfer",
+    "flashmla",
+    "trtllm_mla",
+    "trtllm_mha",
+    "dual_chunk_flash_attn",
+    # AMD specific
+    "aiter",
+    "wave",
+    # Other platforms
+    "intel_amx",
+    "ascend",
+]
+
+DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
+
+
+# Allow external code to add more choices
+def add_load_format_choices(choices):
+    LOAD_FORMAT_CHOICES.extend(choices)
+
+
+def add_quantization_method_choices(choices):
+    QUANTIZATION_CHOICES.extend(choices)
+
+
+def add_attention_backend_choices(choices):
+    ATTENTION_BACKEND_CHOICES.extend(choices)
+
+
+def add_disagg_transfer_backend_choices(choices):
+    DISAGG_TRANSFER_BACKEND_CHOICES.extend(choices)
+
+
 @dataclasses.dataclass
 class ServerArgs:
     # Model and tokenizer
@@ -761,18 +835,7 @@ class ServerArgs:
             "--load-format",
             type=str,
             default=ServerArgs.load_format,
-            choices=[
-                "auto",
-                "pt",
-                "safetensors",
-                "npcache",
-                "dummy",
-                "sharded_state",
-                "gguf",
-                "bitsandbytes",
-                "layered",
-                "remote",
-            ],
+            choices=LOAD_FORMAT_CHOICES,
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
             "and fall back to the pytorch bin format if safetensors format "
@@ -891,25 +954,7 @@ class ServerArgs:
             "--quantization",
             type=str,
             default=ServerArgs.quantization,
-            choices=[
-                "awq",
-                "fp8",
-                "gptq",
-                "marlin",
-                "gptq_marlin",
-                "awq_marlin",
-                "bitsandbytes",
-                "gguf",
-                "modelopt",
-                "modelopt_fp4",
-                "petit_nvfp4",
-                "w8a8_int8",
-                "w8a8_fp8",
-                "moe_wna16",
-                "qoq",
-                "w4afp8",
-                "mxfp4",
-            ],
+            choices=QUANTIZATION_CHOICES,
             help="The quantization method.",
         )
         parser.add_argument(
@@ -1359,43 +1404,24 @@ class ServerArgs:
         )
 
         # Kernel backend
-        ATTN_BACKENDS = [
-            # Common
-            "triton",
-            "torch_native",
-            # NVIDIA specific
-            "cutlass_mla",
-            "fa3",
-            "flashinfer",
-            "flashmla",
-            "trtllm_mla",
-            "trtllm_mha",
-            "dual_chunk_flash_attn",
-            # AMD specific
-            "aiter",
-            "wave",
-            # Other platforms
-            "intel_amx",
-            "ascend",
-        ]
         parser.add_argument(
             "--attention-backend",
             type=str,
-            choices=ATTN_BACKENDS,
+            choices=ATTENTION_BACKEND_CHOICES,
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
         parser.add_argument(
             "--prefill-attention-backend",
             type=str,
-            choices=ATTN_BACKENDS,
+            choices=ATTENTION_BACKEND_CHOICES,
             default=ServerArgs.prefill_attention_backend,
             help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
         )
         parser.add_argument(
             "--decode-attention-backend",
             type=str,
-            choices=ATTN_BACKENDS,
+            choices=ATTENTION_BACKEND_CHOICES,
             default=ServerArgs.decode_attention_backend,
             help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
         )
@@ -1959,7 +1985,7 @@ class ServerArgs:
             "--disaggregation-transfer-backend",
             type=str,
             default=ServerArgs.disaggregation_transfer_backend,
-            choices=["mooncake", "nixl", "ascend"],
+            choices=DISAGG_TRANSFER_BACKEND_CHOICES,
             help="The backend for disaggregation transfer. Default is mooncake.",
        )
        parser.add_argument(