Support multi-node DP attention (#2925)
Co-authored-by: dhou-xai <dhou@x.ai>
This commit is contained in:
@@ -239,15 +239,14 @@ class ServerArgs:
|
||||
|
||||
# Others
|
||||
if self.enable_dp_attention:
|
||||
assert self.tp_size % self.dp_size == 0
|
||||
self.dp_size = self.tp_size
|
||||
self.chunked_prefill_size = self.chunked_prefill_size // 2
|
||||
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
|
||||
self.disable_overlap_schedule = True
|
||||
logger.warning(
|
||||
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
|
||||
f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
|
||||
"Data parallel size is adjusted to be the same as tensor parallel size. "
|
||||
"Overlap scheduler is disabled."
|
||||
)
|
||||
|
||||
# Speculative Decoding
|
||||
@@ -880,8 +879,8 @@ class ServerArgs:
|
||||
self.tp_size % self.nnodes == 0
|
||||
), "tp_size must be divisible by number of nodes"
|
||||
assert not (
|
||||
self.dp_size > 1 and self.nnodes != 1
|
||||
), "multi-node data parallel is not supported"
|
||||
self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
|
||||
), "multi-node data parallel is not supported unless dp attention!"
|
||||
assert (
|
||||
self.max_loras_per_batch > 0
|
||||
# FIXME
|
||||
@@ -919,6 +918,9 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
|
||||
return server_args
|
||||
|
||||
|
||||
ZMQ_TCP_PORT_DELTA = 233
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PortArgs:
|
||||
# The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
|
||||
@@ -932,7 +934,7 @@ class PortArgs:
|
||||
nccl_port: int
|
||||
|
||||
@staticmethod
|
||||
def init_new(server_args) -> "PortArgs":
|
||||
def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
|
||||
port = server_args.port + random.randint(100, 1000)
|
||||
while True:
|
||||
if is_port_available(port):
|
||||
@@ -942,12 +944,39 @@ class PortArgs:
|
||||
else:
|
||||
port -= 43
|
||||
|
||||
return PortArgs(
|
||||
tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
|
||||
scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
|
||||
detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
|
||||
nccl_port=port,
|
||||
)
|
||||
if not server_args.enable_dp_attention:
|
||||
# Normal case, use IPC within a single node
|
||||
return PortArgs(
|
||||
tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
|
||||
scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
|
||||
detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
|
||||
nccl_port=port,
|
||||
)
|
||||
else:
|
||||
# DP attention. Use TCP + port to handle both single-node and multi-node.
|
||||
if server_args.nnodes == 1 and server_args.dist_init_addr is None:
|
||||
dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
|
||||
else:
|
||||
dist_init_addr = server_args.dist_init_addr.split(":")
|
||||
assert (
|
||||
len(dist_init_addr) == 2
|
||||
), "please provide --dist-init-addr as host:port of head node"
|
||||
|
||||
dist_init_host, dist_init_port = dist_init_addr
|
||||
port_base = int(dist_init_port) + 1
|
||||
if dp_rank is None:
|
||||
scheduler_input_port = (
|
||||
port_base + 2
|
||||
) # TokenizerManager to DataParallelController
|
||||
else:
|
||||
scheduler_input_port = port_base + 2 + 1 + dp_rank
|
||||
|
||||
return PortArgs(
|
||||
tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
|
||||
scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
|
||||
detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
|
||||
nccl_port=port,
|
||||
)
|
||||
|
||||
|
||||
class LoRAPathAction(argparse.Action):
|
||||
|
||||
Reference in New Issue
Block a user