Clean up (#422)
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
"""The arguments of the server."""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
from typing import List, Optional, Union
|
||||
@@ -5,33 +7,44 @@ from typing import List, Optional, Union
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ServerArgs:
|
||||
# Model and tokenizer
|
||||
model_path: str
|
||||
tokenizer_path: Optional[str] = None
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 30000
|
||||
additional_ports: Optional[Union[List[int], int]] = None
|
||||
load_format: str = "auto"
|
||||
tokenizer_mode: str = "auto"
|
||||
chat_template: Optional[str] = None
|
||||
trust_remote_code: bool = True
|
||||
context_length: Optional[int] = None
|
||||
|
||||
# Port
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 30000
|
||||
additional_ports: Optional[Union[List[int], int]] = None
|
||||
|
||||
# Memory and scheduling
|
||||
mem_fraction_static: Optional[float] = None
|
||||
max_prefill_num_token: Optional[int] = None
|
||||
context_length: Optional[int] = None
|
||||
tp_size: int = 1
|
||||
schedule_heuristic: str = "lpm"
|
||||
schedule_conservativeness: float = 1.0
|
||||
attention_reduce_in_fp32: bool = False
|
||||
random_seed: int = 42
|
||||
|
||||
# Other runtime options
|
||||
tp_size: int = 1
|
||||
stream_interval: int = 8
|
||||
random_seed: int = 42
|
||||
|
||||
# Logging
|
||||
log_level: str = "info"
|
||||
disable_log_stats: bool = False
|
||||
log_stats_interval: int = 10
|
||||
log_level: str = "info"
|
||||
api_key: str = ""
|
||||
show_time_cost: bool = False
|
||||
|
||||
# optional modes
|
||||
disable_radix_cache: bool = False
|
||||
# Other
|
||||
api_key: str = ""
|
||||
|
||||
# Optimization/debug options
|
||||
enable_flashinfer: bool = False
|
||||
attention_reduce_in_fp32: bool = False
|
||||
disable_radix_cache: bool = False
|
||||
disable_regex_jump_forward: bool = False
|
||||
disable_disk_cache: bool = False
|
||||
|
||||
@@ -66,15 +79,16 @@ class ServerArgs:
|
||||
default=ServerArgs.tokenizer_path,
|
||||
help="The path of the tokenizer.",
|
||||
)
|
||||
parser.add_argument("--host", type=str, default=ServerArgs.host)
|
||||
parser.add_argument("--port", type=int, default=ServerArgs.port)
|
||||
# we want to be able to pass a list of ports
|
||||
parser.add_argument("--host", type=str, default=ServerArgs.host,
|
||||
help="The host of the server.")
|
||||
parser.add_argument("--port", type=int, default=ServerArgs.port,
|
||||
help="The port of the server.")
|
||||
parser.add_argument(
|
||||
"--additional-ports",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="Additional ports specified for launching server.",
|
||||
help="Additional ports specified for the server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load-format",
|
||||
@@ -112,6 +126,12 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context-length",
|
||||
type=int,
|
||||
default=ServerArgs.context_length,
|
||||
help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem-fraction-static",
|
||||
type=float,
|
||||
@@ -124,18 +144,6 @@ class ServerArgs:
|
||||
default=ServerArgs.max_prefill_num_token,
|
||||
help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context-length",
|
||||
type=int,
|
||||
default=ServerArgs.context_length,
|
||||
help="The model's maximum context length. Use this to reduce the context length to save memory. Defaults to None (will use the value from the model's config.json instead).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp-size",
|
||||
type=int,
|
||||
default=ServerArgs.tp_size,
|
||||
help="Tensor parallelism degree.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--schedule-heuristic",
|
||||
type=str,
|
||||
@@ -149,15 +157,10 @@ class ServerArgs:
|
||||
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random-seed",
|
||||
"--tp-size",
|
||||
type=int,
|
||||
default=ServerArgs.random_seed,
|
||||
help="Random seed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention-reduce-in-fp32",
|
||||
action="store_true",
|
||||
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
|
||||
default=ServerArgs.tp_size,
|
||||
help="Tensor parallelism size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stream-interval",
|
||||
@@ -165,11 +168,17 @@ class ServerArgs:
|
||||
default=ServerArgs.stream_interval,
|
||||
help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random-seed",
|
||||
type=int,
|
||||
default=ServerArgs.random_seed,
|
||||
help="Random seed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
type=str,
|
||||
default=ServerArgs.log_level,
|
||||
help="Log level",
|
||||
help="Logging level",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-log-stats",
|
||||
@@ -182,29 +191,34 @@ class ServerArgs:
|
||||
default=ServerArgs.log_stats_interval,
|
||||
help="Log stats interval in second.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=ServerArgs.api_key,
|
||||
help="Set API Key",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show-time-cost",
|
||||
action="store_true",
|
||||
help="Show time cost of custom marks",
|
||||
)
|
||||
|
||||
# optional modes
|
||||
parser.add_argument(
|
||||
"--disable-radix-cache",
|
||||
action="store_true",
|
||||
help="Disable RadixAttention",
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=ServerArgs.api_key,
|
||||
help="Set API key of the server",
|
||||
)
|
||||
|
||||
# Optimization/debug options
|
||||
parser.add_argument(
|
||||
"--enable-flashinfer",
|
||||
action="store_true",
|
||||
help="Enable flashinfer inference kernels",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention-reduce-in-fp32",
|
||||
action="store_true",
|
||||
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-radix-cache",
|
||||
action="store_true",
|
||||
help="Disable RadixAttention",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-regex-jump-forward",
|
||||
action="store_true",
|
||||
@@ -224,13 +238,13 @@ class ServerArgs:
|
||||
def url(self):
|
||||
return f"http://{self.host}:{self.port}"
|
||||
|
||||
def get_optional_modes_logging(self):
|
||||
def print_mode_args(self):
|
||||
return (
|
||||
f"disable_radix_cache={self.disable_radix_cache}, "
|
||||
f"enable_flashinfer={self.enable_flashinfer}, "
|
||||
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
|
||||
f"disable_radix_cache={self.disable_radix_cache}, "
|
||||
f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
|
||||
f"disable_disk_cache={self.disable_disk_cache}, "
|
||||
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
|
||||
)
|
||||
|
||||
|
||||
@@ -240,4 +254,4 @@ class PortArgs:
|
||||
router_port: int
|
||||
detokenizer_port: int
|
||||
nccl_port: int
|
||||
model_rpc_ports: List[int]
|
||||
model_rpc_ports: List[int]
|
||||
Reference in New Issue
Block a user