[server] Passing model_override_args to launch_server via the CLI. (#1298)

Signed-off-by: Kai-Hsun Chen <kaihsun@anyscale.com>
This commit is contained in:
Kai-Hsun Chen
2024-09-09 02:14:25 -07:00
committed by GitHub
parent 662ecd9368
commit c9b75917d5
8 changed files with 71 additions and 16 deletions

View File

@@ -6,7 +6,7 @@
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
# Launch sglang # Launch sglang
# python -m sglang.launch_server --model ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87 # python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87
# offline # offline
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11 python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11

View File

@@ -480,6 +480,7 @@ def main(server_args, bench_args):
if __name__ == "__main__": if __name__ == "__main__":
# TODO(kevin85421): Make the parser setup unit testable.
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser) ServerArgs.add_cli_args(parser)
BenchArgs.add_cli_args(parser) BenchArgs.add_cli_args(parser)

View File

@@ -1,20 +1,18 @@
"""Launch the inference server.""" """Launch the inference server."""
import argparse
import os import os
import sys
from sglang.srt.server import launch_server from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_child_process
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() server_args = prepare_server_args(sys.argv[1:])
ServerArgs.add_cli_args(parser) model_override_args = server_args.json_model_override_args
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
try: try:
launch_server(server_args) launch_server(server_args, model_override_args=model_override_args)
except Exception as e: except Exception as e:
raise e raise e
finally: finally:

View File

@@ -1,14 +1,11 @@
"""Launch the inference server for Llava-video model.""" """Launch the inference server for Llava-video model."""
import argparse import sys
from sglang.srt.server import ServerArgs, launch_server from sglang.srt.server import launch_server, prepare_server_args
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() server_args = prepare_server_args(sys.argv[1:])
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
model_override_args = {} model_override_args = {}
model_override_args["mm_spatial_pool_stride"] = 2 model_override_args["mm_spatial_pool_stride"] = 2
@@ -20,7 +17,7 @@ if __name__ == "__main__":
model_override_args["max_sequence_length"] = 4096 * 2 model_override_args["max_sequence_length"] = 4096 * 2
model_override_args["tokenizer_model_max_length"] = 4096 * 2 model_override_args["tokenizer_model_max_length"] = 4096 * 2
model_override_args["model_max_length"] = 4096 * 2 model_override_args["model_max_length"] = 4096 * 2
if "34b" in args.model_path.lower(): if "34b" in server_args.model_path.lower():
model_override_args["image_token_index"] = 64002 model_override_args["image_token_index"] = 64002
launch_server(server_args, model_override_args, None) launch_server(server_args, model_override_args, None)

View File

@@ -17,6 +17,7 @@ limitations under the License.
import argparse import argparse
import dataclasses import dataclasses
import json
import logging import logging
import random import random
from typing import List, Optional, Union from typing import List, Optional, Union
@@ -95,6 +96,9 @@ class ServerArgs:
nnodes: int = 1 nnodes: int = 1
node_rank: Optional[int] = None node_rank: Optional[int] = None
# Model override args in JSON
json_model_override_args: Optional[dict] = None
def __post_init__(self): def __post_init__(self):
if self.tokenizer_path is None: if self.tokenizer_path is None:
self.tokenizer_path = self.model_path self.tokenizer_path = self.model_path
@@ -455,10 +459,22 @@ class ServerArgs:
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).", help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
) )
# Model override args
parser.add_argument(
"--json-model-override-args",
type=str,
help="A dictionary in JSON string format used to override default model configurations.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size args.tp_size = args.tensor_parallel_size
args.dp_size = args.data_parallel_size args.dp_size = args.data_parallel_size
args.json_model_override_args = (
json.loads(args.json_model_override_args)
if args.json_model_override_args
else None
)
attrs = [attr.name for attr in dataclasses.fields(cls)] attrs = [attr.name for attr in dataclasses.fields(cls)]
return cls(**{attr: getattr(args, attr) for attr in attrs}) return cls(**{attr: getattr(args, attr) for attr in attrs})
@@ -482,6 +498,24 @@ class ServerArgs:
self.disable_flashinfer = False self.disable_flashinfer = False
def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
"""
Prepare the server arguments from the command line arguments.
Args:
args: The command line arguments. Typically, it should be `sys.argv[1:]`
to ensure compatibility with `parse_args` when no arguments are passed.
Returns:
The server arguments.
"""
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
raw_args = parser.parse_args(args)
server_args = ServerArgs.from_cli_args(raw_args)
return server_args
@dataclasses.dataclass @dataclasses.dataclass
class PortArgs: class PortArgs:
tokenizer_port: int tokenizer_port: int

View File

@@ -19,6 +19,7 @@ suites = {
"test_triton_attn_backend.py", "test_triton_attn_backend.py",
"test_update_weights.py", "test_update_weights.py",
"test_vision_openai_server.py", "test_vision_openai_server.py",
"test_server_args.py",
], ],
"sampling/penaltylib": glob.glob( "sampling/penaltylib": glob.glob(
"sampling/penaltylib/**/test_*.py", recursive=True "sampling/penaltylib/**/test_*.py", recursive=True

View File

@@ -0,0 +1,24 @@
import unittest
from sglang.srt.server_args import prepare_server_args
class TestPrepareServerArgs(unittest.TestCase):
def test_prepare_server_args(self):
server_args = prepare_server_args(
[
"--model-path",
"model_path",
"--json-model-override-args",
'{"rope_scaling": {"factor": 2.0, "type": "linear"}}',
]
)
self.assertEqual(server_args.model_path, "model_path")
self.assertEqual(
server_args.json_model_override_args,
{"rope_scaling": {"factor": 2.0, "type": "linear"}},
)
if __name__ == "__main__":
unittest.main()

View File

@@ -12,7 +12,7 @@ class TestServingLatency(unittest.TestCase):
"python3", "python3",
"-m", "-m",
"sglang.bench_latency", "sglang.bench_latency",
"--model", "--model-path",
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
"--batch-size", "--batch-size",
"1", "1",