Multi-node Tensor Parallelism (#550)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
@@ -56,6 +56,11 @@ class ServerArgs:
|
||||
disable_regex_jump_forward: bool = False
|
||||
disable_disk_cache: bool = False
|
||||
|
||||
# Distributed args
|
||||
nccl_init_addr: Optional[str] = None
|
||||
nnodes: int = 1
|
||||
node_rank: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer_path is None:
|
||||
self.tokenizer_path = self.model_path
|
||||
@@ -252,6 +257,24 @@ class ServerArgs:
|
||||
],
|
||||
)
|
||||
|
||||
# Multi-node distributed serving args
|
||||
parser.add_argument(
|
||||
"--nccl-init-addr",
|
||||
type=str,
|
||||
help="The nccl init address of multi-node server."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--nnodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of nodes"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--node-rank",
|
||||
type=int,
|
||||
help="The node rank."
|
||||
)
|
||||
|
||||
# Optimization/debug options
|
||||
parser.add_argument(
|
||||
"--enable-flashinfer",
|
||||
@@ -300,6 +323,7 @@ class ServerArgs:
|
||||
@dataclasses.dataclass
|
||||
class ModelPortArgs:
|
||||
nccl_port: int
|
||||
model_tp_ips: List[str]
|
||||
model_tp_ports: List[int]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user