Fix zmq binding (#2930)
Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com>
This commit is contained in:
@@ -66,7 +66,7 @@ class DataParallelController:
|
|||||||
self.context = zmq.Context(1 + server_args.dp_size)
|
self.context = zmq.Context(1 + server_args.dp_size)
|
||||||
if server_args.node_rank == 0:
|
if server_args.node_rank == 0:
|
||||||
self.recv_from_tokenizer = get_zmq_socket(
|
self.recv_from_tokenizer = get_zmq_socket(
|
||||||
self.context, zmq.PULL, port_args.scheduler_input_ipc_name
|
self.context, zmq.PULL, port_args.scheduler_input_ipc_name, False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Dispatch method
|
# Dispatch method
|
||||||
@@ -93,6 +93,7 @@ class DataParallelController:
|
|||||||
self.context,
|
self.context,
|
||||||
zmq.PUSH,
|
zmq.PUSH,
|
||||||
dp_port_args[dp_rank].scheduler_input_ipc_name,
|
dp_port_args[dp_rank].scheduler_input_ipc_name,
|
||||||
|
True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def launch_dp_schedulers(self, server_args, port_args):
|
def launch_dp_schedulers(self, server_args, port_args):
|
||||||
|
|||||||
@@ -58,10 +58,10 @@ class DetokenizerManager:
|
|||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
context = zmq.Context(2)
|
context = zmq.Context(2)
|
||||||
self.recv_from_scheduler = get_zmq_socket(
|
self.recv_from_scheduler = get_zmq_socket(
|
||||||
context, zmq.PULL, port_args.detokenizer_ipc_name
|
context, zmq.PULL, port_args.detokenizer_ipc_name, True
|
||||||
)
|
)
|
||||||
self.send_to_tokenizer = get_zmq_socket(
|
self.send_to_tokenizer = get_zmq_socket(
|
||||||
context, zmq.PUSH, port_args.tokenizer_ipc_name
|
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
|
||||||
)
|
)
|
||||||
|
|
||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
|
|||||||
@@ -162,21 +162,21 @@ class Scheduler:
|
|||||||
|
|
||||||
if self.attn_tp_rank == 0:
|
if self.attn_tp_rank == 0:
|
||||||
self.recv_from_tokenizer = get_zmq_socket(
|
self.recv_from_tokenizer = get_zmq_socket(
|
||||||
context, zmq.PULL, port_args.scheduler_input_ipc_name
|
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
|
||||||
)
|
)
|
||||||
self.send_to_tokenizer = get_zmq_socket(
|
self.send_to_tokenizer = get_zmq_socket(
|
||||||
context, zmq.PUSH, port_args.tokenizer_ipc_name
|
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
|
||||||
)
|
)
|
||||||
|
|
||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
# Directly send to the TokenizerManager
|
# Directly send to the TokenizerManager
|
||||||
self.send_to_detokenizer = get_zmq_socket(
|
self.send_to_detokenizer = get_zmq_socket(
|
||||||
context, zmq.PUSH, port_args.tokenizer_ipc_name
|
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Send to the DetokenizerManager
|
# Send to the DetokenizerManager
|
||||||
self.send_to_detokenizer = get_zmq_socket(
|
self.send_to_detokenizer = get_zmq_socket(
|
||||||
context, zmq.PUSH, port_args.detokenizer_ipc_name
|
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.recv_from_tokenizer = None
|
self.recv_from_tokenizer = None
|
||||||
|
|||||||
@@ -119,10 +119,10 @@ class TokenizerManager:
|
|||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
context = zmq.asyncio.Context(2)
|
context = zmq.asyncio.Context(2)
|
||||||
self.recv_from_detokenizer = get_zmq_socket(
|
self.recv_from_detokenizer = get_zmq_socket(
|
||||||
context, zmq.PULL, port_args.tokenizer_ipc_name
|
context, zmq.PULL, port_args.tokenizer_ipc_name, True
|
||||||
)
|
)
|
||||||
self.send_to_scheduler = get_zmq_socket(
|
self.send_to_scheduler = get_zmq_socket(
|
||||||
context, zmq.PUSH, port_args.scheduler_input_ipc_name
|
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Read model args
|
# Read model args
|
||||||
|
|||||||
@@ -789,7 +789,9 @@ def first_rank_print(*args, **kwargs):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str):
|
def get_zmq_socket(
|
||||||
|
context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
|
||||||
|
):
|
||||||
mem = psutil.virtual_memory()
|
mem = psutil.virtual_memory()
|
||||||
total_mem = mem.total / 1024**3
|
total_mem = mem.total / 1024**3
|
||||||
available_mem = mem.available / 1024**3
|
available_mem = mem.available / 1024**3
|
||||||
@@ -802,14 +804,17 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
|
|||||||
if socket_type == zmq.PUSH:
|
if socket_type == zmq.PUSH:
|
||||||
socket.setsockopt(zmq.SNDHWM, 0)
|
socket.setsockopt(zmq.SNDHWM, 0)
|
||||||
socket.setsockopt(zmq.SNDBUF, buf_size)
|
socket.setsockopt(zmq.SNDBUF, buf_size)
|
||||||
socket.connect(endpoint)
|
|
||||||
elif socket_type == zmq.PULL:
|
elif socket_type == zmq.PULL:
|
||||||
socket.setsockopt(zmq.RCVHWM, 0)
|
socket.setsockopt(zmq.RCVHWM, 0)
|
||||||
socket.setsockopt(zmq.RCVBUF, buf_size)
|
socket.setsockopt(zmq.RCVBUF, buf_size)
|
||||||
socket.bind(endpoint)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported socket type: {socket_type}")
|
raise ValueError(f"Unsupported socket type: {socket_type}")
|
||||||
|
|
||||||
|
if bind:
|
||||||
|
socket.bind(endpoint)
|
||||||
|
else:
|
||||||
|
socket.connect(endpoint)
|
||||||
|
|
||||||
return socket
|
return socket
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user