[server] Passing model_override_args to launch_server via the CLI. (#1298)
Signed-off-by: Kai-Hsun Chen <kaihsun@anyscale.com>
This commit is contained in:
@@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from typing import List, Optional, Union
|
||||
@@ -95,6 +96,9 @@ class ServerArgs:
|
||||
nnodes: int = 1
|
||||
node_rank: Optional[int] = None
|
||||
|
||||
# Model override args in JSON
|
||||
json_model_override_args: Optional[dict] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer_path is None:
|
||||
self.tokenizer_path = self.model_path
|
||||
@@ -455,10 +459,22 @@ class ServerArgs:
|
||||
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
|
||||
)
|
||||
|
||||
# Model override args
|
||||
parser.add_argument(
|
||||
"--json-model-override-args",
|
||||
type=str,
|
||||
help="A dictionary in JSON string format used to override default model configurations.",
|
||||
)
|
||||
|
||||
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
    """Build an instance of ``cls`` from a parsed argparse namespace.

    The namespace is updated in place before the instance is created:
    ``tp_size`` and ``dp_size`` are copied from ``tensor_parallel_size``
    and ``data_parallel_size``, and ``json_model_override_args`` is
    decoded from its JSON text form.

    Args:
        args: Namespace that carries one attribute per dataclass field of
            ``cls`` plus ``tensor_parallel_size`` and
            ``data_parallel_size``.

    Returns:
        A new ``cls`` populated from the matching attributes of ``args``.

    Raises:
        json.JSONDecodeError: If ``json_model_override_args`` is a
            non-empty string that is not valid JSON.
    """
    args.tp_size = args.tensor_parallel_size
    args.dp_size = args.data_parallel_size

    # The decoded value is written back onto ``args``. Only decode while
    # the value is still text, so that calling this method again with the
    # same namespace does not feed an already-decoded object to json.loads.
    raw_override = args.json_model_override_args
    if isinstance(raw_override, (str, bytes, bytearray)):
        args.json_model_override_args = (
            json.loads(raw_override) if raw_override else None
        )

    attrs = [attr.name for attr in dataclasses.fields(cls)]
    return cls(**{attr: getattr(args, attr) for attr in attrs})
|
||||
|
||||
@@ -482,6 +498,24 @@ class ServerArgs:
|
||||
self.disable_flashinfer = False
|
||||
|
||||
|
||||
def prepare_server_args(args: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        args: The raw command line tokens to parse, typically
            `sys.argv[1:]`. They are handed directly to
            `ArgumentParser.parse_args`, so this must be a list of strings
            and not an already-parsed namespace.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(args)
    server_args = ServerArgs.from_cli_args(raw_args)
    return server_args
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PortArgs:
|
||||
tokenizer_port: int
|
||||
|
||||
Reference in New Issue
Block a user