[server] Passing model_override_args to launch_server via the CLI. (#1298)
Signed-off-by: Kai-Hsun Chen <kaihsun@anyscale.com>
This commit is contained in:
@@ -6,7 +6,7 @@
|
|||||||
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
|
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
|
||||||
|
|
||||||
# Launch sglang
|
# Launch sglang
|
||||||
# python -m sglang.launch_server --model ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87
|
# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87
|
||||||
|
|
||||||
# offline
|
# offline
|
||||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11
|
||||||
|
|||||||
@@ -480,6 +480,7 @@ def main(server_args, bench_args):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# TODO(kevin85421): Make the parser setup unit testable.
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
ServerArgs.add_cli_args(parser)
|
ServerArgs.add_cli_args(parser)
|
||||||
BenchArgs.add_cli_args(parser)
|
BenchArgs.add_cli_args(parser)
|
||||||
|
|||||||
@@ -1,20 +1,18 @@
|
|||||||
"""Launch the inference server."""
|
"""Launch the inference server."""
|
||||||
|
|
||||||
import argparse
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
from sglang.srt.server import launch_server
|
from sglang.srt.server import launch_server
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import prepare_server_args
|
||||||
from sglang.srt.utils import kill_child_process
|
from sglang.srt.utils import kill_child_process
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
server_args = prepare_server_args(sys.argv[1:])
|
||||||
ServerArgs.add_cli_args(parser)
|
model_override_args = server_args.json_model_override_args
|
||||||
args = parser.parse_args()
|
|
||||||
server_args = ServerArgs.from_cli_args(args)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
launch_server(server_args)
|
launch_server(server_args, model_override_args=model_override_args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -1,14 +1,11 @@
|
|||||||
"""Launch the inference server for Llava-video model."""
|
"""Launch the inference server for Llava-video model."""
|
||||||
|
|
||||||
import argparse
|
import sys
|
||||||
|
|
||||||
from sglang.srt.server import ServerArgs, launch_server
|
from sglang.srt.server import launch_server, prepare_server_args
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
server_args = prepare_server_args(sys.argv[1:])
|
||||||
ServerArgs.add_cli_args(parser)
|
|
||||||
args = parser.parse_args()
|
|
||||||
server_args = ServerArgs.from_cli_args(args)
|
|
||||||
|
|
||||||
model_override_args = {}
|
model_override_args = {}
|
||||||
model_override_args["mm_spatial_pool_stride"] = 2
|
model_override_args["mm_spatial_pool_stride"] = 2
|
||||||
@@ -20,7 +17,7 @@ if __name__ == "__main__":
|
|||||||
model_override_args["max_sequence_length"] = 4096 * 2
|
model_override_args["max_sequence_length"] = 4096 * 2
|
||||||
model_override_args["tokenizer_model_max_length"] = 4096 * 2
|
model_override_args["tokenizer_model_max_length"] = 4096 * 2
|
||||||
model_override_args["model_max_length"] = 4096 * 2
|
model_override_args["model_max_length"] = 4096 * 2
|
||||||
if "34b" in args.model_path.lower():
|
if "34b" in server_args.model_path.lower():
|
||||||
model_override_args["image_token_index"] = 64002
|
model_override_args["image_token_index"] = 64002
|
||||||
|
|
||||||
launch_server(server_args, model_override_args, None)
|
launch_server(server_args, model_override_args, None)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
@@ -95,6 +96,9 @@ class ServerArgs:
|
|||||||
nnodes: int = 1
|
nnodes: int = 1
|
||||||
node_rank: Optional[int] = None
|
node_rank: Optional[int] = None
|
||||||
|
|
||||||
|
# Model override args in JSON
|
||||||
|
json_model_override_args: Optional[dict] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.tokenizer_path is None:
|
if self.tokenizer_path is None:
|
||||||
self.tokenizer_path = self.model_path
|
self.tokenizer_path = self.model_path
|
||||||
@@ -455,10 +459,22 @@ class ServerArgs:
|
|||||||
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
|
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Model override args
|
||||||
|
parser.add_argument(
|
||||||
|
"--json-model-override-args",
|
||||||
|
type=str,
|
||||||
|
help="A dictionary in JSON string format used to override default model configurations.",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
args.tp_size = args.tensor_parallel_size
|
args.tp_size = args.tensor_parallel_size
|
||||||
args.dp_size = args.data_parallel_size
|
args.dp_size = args.data_parallel_size
|
||||||
|
args.json_model_override_args = (
|
||||||
|
json.loads(args.json_model_override_args)
|
||||||
|
if args.json_model_override_args
|
||||||
|
else None
|
||||||
|
)
|
||||||
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||||
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
||||||
|
|
||||||
@@ -482,6 +498,24 @@ class ServerArgs:
|
|||||||
self.disable_flashinfer = False
|
self.disable_flashinfer = False
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
|
||||||
|
"""
|
||||||
|
Prepare the server arguments from the command line arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: The command line arguments. Typically, it should be `sys.argv[1:]`
|
||||||
|
to ensure compatibility with `parse_args` when no arguments are passed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The server arguments.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
ServerArgs.add_cli_args(parser)
|
||||||
|
raw_args = parser.parse_args(args)
|
||||||
|
server_args = ServerArgs.from_cli_args(raw_args)
|
||||||
|
return server_args
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class PortArgs:
|
class PortArgs:
|
||||||
tokenizer_port: int
|
tokenizer_port: int
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ suites = {
|
|||||||
"test_triton_attn_backend.py",
|
"test_triton_attn_backend.py",
|
||||||
"test_update_weights.py",
|
"test_update_weights.py",
|
||||||
"test_vision_openai_server.py",
|
"test_vision_openai_server.py",
|
||||||
|
"test_server_args.py",
|
||||||
],
|
],
|
||||||
"sampling/penaltylib": glob.glob(
|
"sampling/penaltylib": glob.glob(
|
||||||
"sampling/penaltylib/**/test_*.py", recursive=True
|
"sampling/penaltylib/**/test_*.py", recursive=True
|
||||||
|
|||||||
24
test/srt/test_server_args.py
Normal file
24
test/srt/test_server_args.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from sglang.srt.server_args import prepare_server_args
|
||||||
|
|
||||||
|
|
||||||
|
class TestPrepareServerArgs(unittest.TestCase):
|
||||||
|
def test_prepare_server_args(self):
|
||||||
|
server_args = prepare_server_args(
|
||||||
|
[
|
||||||
|
"--model-path",
|
||||||
|
"model_path",
|
||||||
|
"--json-model-override-args",
|
||||||
|
'{"rope_scaling": {"factor": 2.0, "type": "linear"}}',
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.assertEqual(server_args.model_path, "model_path")
|
||||||
|
self.assertEqual(
|
||||||
|
server_args.json_model_override_args,
|
||||||
|
{"rope_scaling": {"factor": 2.0, "type": "linear"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -12,7 +12,7 @@ class TestServingLatency(unittest.TestCase):
|
|||||||
"python3",
|
"python3",
|
||||||
"-m",
|
"-m",
|
||||||
"sglang.bench_latency",
|
"sglang.bench_latency",
|
||||||
"--model",
|
"--model-path",
|
||||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
"--batch-size",
|
"--batch-size",
|
||||||
"1",
|
"1",
|
||||||
|
|||||||
Reference in New Issue
Block a user