Multi-node Tensor Parallelism (#550)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
Ying Sheng
2024-06-17 20:41:24 -07:00
committed by GitHub
parent 53a7ebd89a
commit 09593e9bc9
10 changed files with 167 additions and 46 deletions

View File

@@ -56,6 +56,11 @@ class ServerArgs:
disable_regex_jump_forward: bool = False
disable_disk_cache: bool = False
# Distributed args
nccl_init_addr: Optional[str] = None
nnodes: int = 1
node_rank: Optional[int] = None
def __post_init__(self):
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
@@ -252,6 +257,24 @@ class ServerArgs:
],
)
# Multi-node distributed serving args
parser.add_argument(
"--nccl-init-addr",
type=str,
help="The nccl init address of multi-node server."
)
parser.add_argument(
"--nnodes",
type=int,
default=1,
help="Number of nodes"
)
parser.add_argument(
"--node-rank",
type=int,
help="The node rank."
)
# Optimization/debug options
parser.add_argument(
"--enable-flashinfer",
@@ -300,6 +323,7 @@ class ServerArgs:
# Plain dataclass grouping one NCCL port with two lists describing model
# tensor-parallel (TP) endpoints.
# NOTE(review): the field meanings below are inferred from the names only;
# confirm against the code that constructs and consumes this class.
@dataclasses.dataclass
class ModelPortArgs:
# A single integer port; presumably the port used for NCCL initialization
# of this model's process group — TODO confirm.
nccl_port: int
# List of strings; presumably one IP address per tensor-parallel worker
# (relevant when workers span several nodes) — TODO confirm.
model_tp_ips: List[str]
# List of integers; presumably one port per tensor-parallel worker, paired
# index-by-index with model_tp_ips — TODO confirm.
model_tp_ports: List[int]