minor refactor: move check server args to server_args.py (#774)
This commit is contained in:
@@ -176,6 +176,8 @@ def launch_server(
|
|||||||
model_overide_args: Optional[dict] = None,
|
model_overide_args: Optional[dict] = None,
|
||||||
pipe_finish_writer: Optional[mp.connection.Connection] = None,
|
pipe_finish_writer: Optional[mp.connection.Connection] = None,
|
||||||
):
|
):
|
||||||
|
server_args.check_server_args()
|
||||||
|
|
||||||
"""Launch an HTTP server."""
|
"""Launch an HTTP server."""
|
||||||
global tokenizer_manager
|
global tokenizer_manager
|
||||||
|
|
||||||
@@ -230,8 +232,6 @@ def launch_server(
|
|||||||
|
|
||||||
# Handle multi-node tensor parallelism
|
# Handle multi-node tensor parallelism
|
||||||
if server_args.nnodes > 1:
|
if server_args.nnodes > 1:
|
||||||
assert server_args.dp_size == 1, "Multi-node dp is not supported."
|
|
||||||
|
|
||||||
if server_args.node_rank != 0:
|
if server_args.node_rank != 0:
|
||||||
tp_size_local = server_args.tp_size // server_args.nnodes
|
tp_size_local = server_args.tp_size // server_args.nnodes
|
||||||
gpu_ids = [
|
gpu_ids = [
|
||||||
|
|||||||
@@ -364,6 +364,14 @@ class ServerArgs:
|
|||||||
f"disable_disk_cache={self.disable_disk_cache}, "
|
f"disable_disk_cache={self.disable_disk_cache}, "
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def check_server_args(self):
    """Sanity-check the parallelism settings held on this object.

    Raises:
        AssertionError: if ``tp_size`` is not a multiple of ``nnodes``, or
            if ``dp_size`` is greater than 1 while ``node_rank`` is set.
    """
    # The tensor-parallel size has to split evenly over the nodes.
    assert (
        self.tp_size % self.nnodes == 0
    ), "tp_size must be divisible by number of nodes"

    # Data parallelism (dp_size > 1) is rejected whenever a node rank has
    # been given; presumably a non-None node_rank marks a multi-node launch.
    assert (
        self.dp_size <= 1 or self.node_rank is None
    ), "multi-node data parallel is not supported"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class PortArgs:
|
class PortArgs:
|
||||||
|
|||||||
Reference in New Issue
Block a user