Support data parallelism (static) (#480)

Co-authored-by: Ying Sheng <ying.sheng@databricks.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
Ying Sheng
2024-05-27 21:24:10 -07:00
committed by GitHub
parent 565d727409
commit 0463f7fb52
32 changed files with 580 additions and 181 deletions

View File

@@ -44,6 +44,10 @@ class ServerArgs:
# Other
api_key: str = ""
# Data parallelism
# dp_size defaults to 1; presumably the number of data-parallel replicas — TODO confirm.
dp_size: int = 1
# load_balance_method defaults to "round_robin"; the --load-balance-method CLI
# flag in this file also accepts "shortest_queue".
load_balance_method: str = "round_robin"
# Optimization/debug options
enable_flashinfer: bool = False
attention_reduce_in_fp32: bool = False
@@ -226,6 +230,24 @@ class ServerArgs:
help="Set API key of the server",
)
# Data parallelism
# --dp-size: parsed as an int; when omitted it falls back to the
# ServerArgs.dp_size class-level default.
parser.add_argument(
"--dp-size",
type=int,
default=ServerArgs.dp_size,
help="Data parallelism size.",
)
# --load-balance-method: parsed as a str; argparse rejects any value outside
# `choices`. When omitted it falls back to ServerArgs.load_balance_method.
parser.add_argument(
"--load-balance-method",
type=str,
default=ServerArgs.load_balance_method,
help="Load balancing strategy for data parallelism.",
choices=[
"round_robin",
"shortest_queue",
],
)
# Optimization/debug options
parser.add_argument(
"--enable-flashinfer",
@@ -271,10 +293,15 @@ class ServerArgs:
)
@dataclasses.dataclass
class ModelPortArgs:
    """Bundle of ports stored per element of ``PortArgs.model_port_args``.

    Attributes:
        nccl_port: A single integer port. Presumably the port used to
            initialize NCCL communication for this group of model
            workers -- confirm against the code that allocates it.
        model_tp_ports: A list of integer ports. Presumably one port per
            tensor-parallel worker in the group -- confirm against the
            code that allocates it.
    """

    nccl_port: int
    model_tp_ports: List[int]
@dataclasses.dataclass
class PortArgs:
# Dataclass of integer port fields (tokenizer, router, detokenizer) plus the
# model-worker port fields listed below.
# NOTE(review): this diff view dropped the +/- markers. The hunk header
# (-271,10 +293,15) implies that nccl_port and model_rpc_ports are the removed
# lines and model_port_args is the added replacement — confirm against the
# actual file before relying on all six fields being present together.
tokenizer_port: int
router_port: int
detokenizer_port: int
nccl_port: int
model_rpc_ports: List[int]
# A list of ModelPortArgs entries; presumably one entry per data-parallel
# worker — TODO confirm against the code that builds this list.
model_port_args: List[ModelPortArgs]