diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 762db5322..d65379b94 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -176,6 +176,8 @@ def launch_server(
     model_overide_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
     """Launch an HTTP server."""
+    # Validate server args up front, before any workers are spawned.
+    server_args.check_server_args()
     global tokenizer_manager
 
@@ -230,8 +232,6 @@ def launch_server(
 
     # Handle multi-node tensor parallelism
     if server_args.nnodes > 1:
-        assert server_args.dp_size == 1, "Multi-node dp is not supported."
-
         if server_args.node_rank != 0:
             tp_size_local = server_args.tp_size // server_args.nnodes
             gpu_ids = [
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 477b6342d..d487cd7b8 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -364,6 +364,14 @@ class ServerArgs:
             f"disable_disk_cache={self.disable_disk_cache}, "
         )
 
+    def check_server_args(self):
+        """Validate cross-field constraints; mirrors the check removed from server.py."""
+        assert (
+            self.tp_size % self.nnodes == 0
+        ), "tp_size must be divisible by number of nodes"
+        assert not (
+            self.dp_size > 1 and self.nnodes > 1
+        ), "multi-node data parallel is not supported"
 
 @dataclasses.dataclass
 class PortArgs: