diff --git a/benchmark/blog_v0_2/405b_sglang.sh b/benchmark/blog_v0_2/405b_sglang.sh
index eae5e2206..d31f8daf8 100644
--- a/benchmark/blog_v0_2/405b_sglang.sh
+++ b/benchmark/blog_v0_2/405b_sglang.sh
@@ -6,7 +6,7 @@
 # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
 
 # Launch sglang
-# python -m sglang.launch_server --model ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87
+# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87
 
 # offline
 python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11
diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index 9006b7150..611349577 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -480,6 +480,7 @@ def main(server_args, bench_args):
 
 
 if __name__ == "__main__":
+    # TODO(kevin85421): Make the parser setup unit testable.
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py
index 1df64e848..06aa140d9 100644
--- a/python/sglang/launch_server.py
+++ b/python/sglang/launch_server.py
@@ -1,20 +1,18 @@
 """Launch the inference server."""
 
-import argparse
 import os
+import sys
 
 from sglang.srt.server import launch_server
-from sglang.srt.server_args import ServerArgs
+from sglang.srt.server_args import prepare_server_args
 from sglang.srt.utils import kill_child_process
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    ServerArgs.add_cli_args(parser)
-    args = parser.parse_args()
-    server_args = ServerArgs.from_cli_args(args)
+    server_args = prepare_server_args(sys.argv[1:])
+    model_override_args = server_args.json_model_override_args
 
     try:
-        launch_server(server_args)
+        launch_server(server_args, model_override_args=model_override_args)
     except Exception as e:
         raise e
     finally:
diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py
index 43eefef4e..6b8d151ee 100644
--- a/python/sglang/launch_server_llavavid.py
+++ b/python/sglang/launch_server_llavavid.py
@@ -1,14 +1,11 @@
 """Launch the inference server for Llava-video model."""
 
-import argparse
+import sys
 
-from sglang.srt.server import ServerArgs, launch_server
+from sglang.srt.server import launch_server, prepare_server_args
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    ServerArgs.add_cli_args(parser)
-    args = parser.parse_args()
-    server_args = ServerArgs.from_cli_args(args)
+    server_args = prepare_server_args(sys.argv[1:])
 
     model_override_args = {}
     model_override_args["mm_spatial_pool_stride"] = 2
@@ -20,7 +17,7 @@ if __name__ == "__main__":
     model_override_args["max_sequence_length"] = 4096 * 2
     model_override_args["tokenizer_model_max_length"] = 4096 * 2
     model_override_args["model_max_length"] = 4096 * 2
-    if "34b" in args.model_path.lower():
+    if "34b" in server_args.model_path.lower():
         model_override_args["image_token_index"] = 64002
 
     launch_server(server_args, model_override_args, None)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 8a56c02e1..e21f02108 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -17,6 +17,7 @@ limitations under the License.
 
 import argparse
 import dataclasses
+import json
 import logging
 import random
 from typing import List, Optional, Union
@@ -95,6 +96,9 @@ class ServerArgs:
     nnodes: int = 1
     node_rank: Optional[int] = None
 
+    # Model override args in JSON
+    json_model_override_args: Optional[dict] = None
+
     def __post_init__(self):
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -455,10 +459,22 @@ class ServerArgs:
             help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
         )
 
+        # Model override args
+        parser.add_argument(
+            "--json-model-override-args",
+            type=str,
+            help="A dictionary in JSON string format used to override default model configurations.",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
         args.dp_size = args.data_parallel_size
+        args.json_model_override_args = (
+            json.loads(args.json_model_override_args)
+            if args.json_model_override_args
+            else None
+        )
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         return cls(**{attr: getattr(args, attr) for attr in attrs})
 
@@ -482,6 +498,24 @@ class ServerArgs:
             self.disable_flashinfer = False
 
 
+def prepare_server_args(args: List[str]) -> ServerArgs:
+    """
+    Prepare the server arguments from the command line arguments.
+
+    Args:
+        args: The command line argument strings. Typically, it should be `sys.argv[1:]`
+            to ensure compatibility with `parse_args` when no arguments are passed.
+
+    Returns:
+        The server arguments.
+    """
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    raw_args = parser.parse_args(args)
+    server_args = ServerArgs.from_cli_args(raw_args)
+    return server_args
+
+
 @dataclasses.dataclass
 class PortArgs:
     tokenizer_port: int
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index cafcf3f2d..d5982844c 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -19,6 +19,7 @@ suites = {
         "test_triton_attn_backend.py",
         "test_update_weights.py",
         "test_vision_openai_server.py",
+        "test_server_args.py",
     ],
     "sampling/penaltylib": glob.glob(
         "sampling/penaltylib/**/test_*.py", recursive=True
diff --git a/test/srt/test_server_args.py b/test/srt/test_server_args.py
new file mode 100644
index 000000000..71129e3eb
--- /dev/null
+++ b/test/srt/test_server_args.py
@@ -0,0 +1,24 @@
+import unittest
+
+from sglang.srt.server_args import prepare_server_args
+
+
+class TestPrepareServerArgs(unittest.TestCase):
+    def test_prepare_server_args(self):
+        server_args = prepare_server_args(
+            [
+                "--model-path",
+                "model_path",
+                "--json-model-override-args",
+                '{"rope_scaling": {"factor": 2.0, "type": "linear"}}',
+            ]
+        )
+        self.assertEqual(server_args.model_path, "model_path")
+        self.assertEqual(
+            server_args.json_model_override_args,
+            {"rope_scaling": {"factor": 2.0, "type": "linear"}},
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/test_serving_latency.py b/test/srt/test_serving_latency.py
index e762892c8..3dae4541a 100644
--- a/test/srt/test_serving_latency.py
+++ b/test/srt/test_serving_latency.py
@@ -12,7 +12,7 @@ class TestServingLatency(unittest.TestCase):
                 "python3",
                 "-m",
                 "sglang.bench_latency",
-                "--model",
+                "--model-path",
                 DEFAULT_MODEL_NAME_FOR_TEST,
                 "--batch-size",
                 "1",