[Minor] Many cleanup (#1357)
This commit is contained in:
@@ -76,6 +76,14 @@ class ServerArgs:
|
||||
dp_size: int = 1
|
||||
load_balance_method: str = "round_robin"
|
||||
|
||||
# Distributed args
|
||||
nccl_init_addr: Optional[str] = None
|
||||
nnodes: int = 1
|
||||
node_rank: Optional[int] = None
|
||||
|
||||
# Model override args in JSON
|
||||
json_model_override_args: str = "{}"
|
||||
|
||||
# Optimization/debug options
|
||||
disable_flashinfer: bool = False
|
||||
disable_flashinfer_sampling: bool = False
|
||||
@@ -91,14 +99,6 @@ class ServerArgs:
|
||||
enable_mla: bool = False
|
||||
triton_attention_reduce_in_fp32: bool = False
|
||||
|
||||
# Distributed args
|
||||
nccl_init_addr: Optional[str] = None
|
||||
nnodes: int = 1
|
||||
node_rank: Optional[int] = None
|
||||
|
||||
# Model override args in JSON
|
||||
json_model_override_args: Optional[dict] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer_path is None:
|
||||
self.tokenizer_path = self.model_path
|
||||
@@ -385,6 +385,14 @@ class ServerArgs:
|
||||
)
|
||||
parser.add_argument("--node-rank", type=int, help="The node rank.")
|
||||
|
||||
# Model override args
|
||||
parser.add_argument(
|
||||
"--json-model-override-args",
|
||||
type=str,
|
||||
help="A dictionary in JSON string format used to override default model configurations.",
|
||||
default=ServerArgs.json_model_override_args,
|
||||
)
|
||||
|
||||
# Optimization/debug options
|
||||
parser.add_argument(
|
||||
"--disable-flashinfer",
|
||||
@@ -459,22 +467,10 @@ class ServerArgs:
|
||||
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
|
||||
)
|
||||
|
||||
# Model override args
|
||||
parser.add_argument(
|
||||
"--json-model-override-args",
|
||||
type=str,
|
||||
help="A dictionary in JSON string format used to override default model configurations.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
args.tp_size = args.tensor_parallel_size
|
||||
args.dp_size = args.data_parallel_size
|
||||
args.json_model_override_args = (
|
||||
json.loads(args.json_model_override_args)
|
||||
if args.json_model_override_args
|
||||
else None
|
||||
)
|
||||
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
||||
|
||||
@@ -498,7 +494,7 @@ class ServerArgs:
|
||||
self.disable_flashinfer = False
|
||||
|
||||
|
||||
def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
|
||||
def prepare_server_args(argv: List[str]) -> ServerArgs:
|
||||
"""
|
||||
Prepare the server arguments from the command line arguments.
|
||||
|
||||
@@ -511,7 +507,7 @@ def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
ServerArgs.add_cli_args(parser)
|
||||
raw_args = parser.parse_args(args)
|
||||
raw_args = parser.parse_args(argv)
|
||||
server_args = ServerArgs.from_cli_args(raw_args)
|
||||
return server_args
|
||||
|
||||
|
||||
Reference in New Issue
Block a user