[Minor] Many cleanup (#1357)

This commit is contained in:
Lianmin Zheng
2024-09-09 04:14:11 -07:00
committed by GitHub
parent c9b75917d5
commit e4d68afcf0
24 changed files with 416 additions and 296 deletions

View File

@@ -76,6 +76,14 @@ class ServerArgs:
dp_size: int = 1
load_balance_method: str = "round_robin"
# Distributed args
nccl_init_addr: Optional[str] = None
nnodes: int = 1
node_rank: Optional[int] = None
# Model override args in JSON
json_model_override_args: str = "{}"
# Optimization/debug options
disable_flashinfer: bool = False
disable_flashinfer_sampling: bool = False
@@ -91,14 +99,6 @@ class ServerArgs:
enable_mla: bool = False
triton_attention_reduce_in_fp32: bool = False
# Distributed args
nccl_init_addr: Optional[str] = None
nnodes: int = 1
node_rank: Optional[int] = None
# Model override args in JSON
json_model_override_args: Optional[dict] = None
def __post_init__(self):
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
@@ -385,6 +385,14 @@ class ServerArgs:
)
parser.add_argument("--node-rank", type=int, help="The node rank.")
# Model override args
parser.add_argument(
"--json-model-override-args",
type=str,
help="A dictionary in JSON string format used to override default model configurations.",
default=ServerArgs.json_model_override_args,
)
# Optimization/debug options
parser.add_argument(
"--disable-flashinfer",
@@ -459,22 +467,10 @@ class ServerArgs:
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
)
# Model override args
parser.add_argument(
"--json-model-override-args",
type=str,
help="A dictionary in JSON string format used to override default model configurations.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
args.dp_size = args.data_parallel_size
args.json_model_override_args = (
json.loads(args.json_model_override_args)
if args.json_model_override_args
else None
)
attrs = [attr.name for attr in dataclasses.fields(cls)]
return cls(**{attr: getattr(args, attr) for attr in attrs})
@@ -498,7 +494,7 @@ class ServerArgs:
self.disable_flashinfer = False
def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
def prepare_server_args(argv: List[str]) -> ServerArgs:
"""
Prepare the server arguments from the command line arguments.
@@ -511,7 +507,7 @@ def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
"""
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
raw_args = parser.parse_args(args)
raw_args = parser.parse_args(argv)
server_args = ServerArgs.from_cli_args(raw_args)
return server_args