Files
sglang/python/sglang/srt/managers/controller/manager_single.py
2024-07-18 03:33:21 -07:00

158 lines
4.5 KiB
Python

"""A controller that manages a group of tensor parallel workers."""
import logging
import multiprocessing
import multiprocessing.connection
import os
from typing import List

import zmq

from sglang.srt.managers.controller.tp_worker import (
    ModelTpServer,
    broadcast_recv_input,
    launch_tp_servers,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import kill_parent_process
from sglang.utils import get_exception_traceback
# Module-level logger, registered under the fixed name "srt.controller"
# (not derived from __name__).
logger = logging.getLogger("srt.controller")
class ControllerSingle:
    """A controller that manages a group of tensor parallel workers.

    The controller hosts tensor-parallel rank 0 in-process (``ModelTpServer``),
    starts the remaining local ranks through ``launch_tp_servers``, and runs a
    loop that feeds incoming requests to rank 0 and pushes the resulting
    objects to the detokenizer socket.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        model_overide_args: dict,
        gpu_ids: List[int],
        is_data_parallel_worker: bool,
        dp_worker_id: int,
        mp_queue: multiprocessing.Queue,
    ):
        """Open the ZMQ sockets and start every local tensor-parallel rank.

        Args:
            server_args: Server configuration; ``tp_size`` and ``nnodes`` are
                read here, the whole object is forwarded to the TP ranks.
            port_args: Supplies ``controller_port``, ``detokenizer_port`` and
                the ``nccl_ports`` list (indexed by ``dp_worker_id``).
            model_overide_args: Forwarded unchanged to every TP rank.
            gpu_ids: ``gpu_ids[0]`` is handed to rank 0; the full list is
                handed to ``launch_tp_servers`` for the other ranks.
            is_data_parallel_worker: When true, requests are read from
                ``mp_queue`` and no tokenizer socket is created.
            dp_worker_id: Index into ``port_args.nccl_ports``.
            mp_queue: Request source in data-parallel-worker mode; not read
                otherwise.
        """
        # Parse args
        self.tp_size = server_args.tp_size
        self.is_dp_worker = is_data_parallel_worker
        self.dp_worker_id = dp_worker_id
        self.mp_queue = mp_queue

        # Init communication
        context = zmq.Context(2)

        # Only a standalone controller listens for tokenizer traffic; a
        # data-parallel worker is fed through ``mp_queue`` instead, so
        # ``recv_from_tokenizer`` does not exist in that mode.
        if not self.is_dp_worker:
            self.recv_from_tokenizer = context.socket(zmq.PULL)
            self.recv_from_tokenizer.bind(
                f"tcp://127.0.0.1:{port_args.controller_port}"
            )

        self.send_to_detokenizer = context.socket(zmq.PUSH)
        self.send_to_detokenizer.connect(
            f"tcp://127.0.0.1:{port_args.detokenizer_port}"
        )

        # Launch other tp ranks (ranks 1 .. tp_size_local - 1, if any)
        tp_size_local = server_args.tp_size // server_args.nnodes
        self.tp_procs = []
        if tp_size_local > 1:
            tp_rank_range = range(1, tp_size_local)
            self.tp_procs = launch_tp_servers(
                gpu_ids,
                tp_rank_range,
                server_args,
                port_args.nccl_ports[dp_worker_id],
                model_overide_args,
            )

        # Launch tp rank 0 inside this process
        self.tp_server = ModelTpServer(
            gpu_ids[0],
            0,
            server_args,
            port_args.nccl_ports[dp_worker_id],
            model_overide_args,
        )
        self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group

    def loop_for_forward(self):
        """Run the receive / step / send loop; does not return normally.

        Each iteration drains the pending requests without blocking, passes
        them to ``broadcast_recv_input`` when ``tp_size > 1`` (presumably to
        share them with the other TP ranks -- confirm in ``tp_worker``), runs
        one ``exposed_step`` on rank 0 and sends every returned object to the
        detokenizer.
        """
        while True:
            if not self.is_dp_worker:
                recv_reqs = self.recv_requests_from_zmq()
            else:
                recv_reqs = self.recv_requests_from_mp_queue()

            if self.tp_size > 1:
                broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)

            out_pyobjs = self.tp_server.exposed_step(recv_reqs)

            for obj in out_pyobjs:
                self.send_to_detokenizer.send_pyobj(obj)

    def recv_requests_from_zmq(self):
        """Drain every request currently waiting on the tokenizer socket.

        Returns:
            The received objects in arrival order; empty when nothing is
            pending.

        Raises:
            zmq.ZMQError: For any socket failure other than "no message
                available". Those used to be swallowed, which left the
                forward loop spinning without ever receiving a request.
        """
        recv_reqs = []
        while True:
            try:
                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
            except zmq.Again:
                # The non-blocking receive found nothing: socket is drained.
                break
            recv_reqs.append(recv_req)

        return recv_reqs

    def recv_requests_from_mp_queue(self):
        """Drain every request currently available on ``mp_queue``.

        Uses ``get_nowait()`` rather than ``empty()`` followed by a blocking
        ``get()``, so the call can never block the forward loop.

        Returns:
            The dequeued objects in FIFO order; empty when nothing is pending.
        """
        # Local import: multiprocessing queues raise the stdlib ``queue.Empty``.
        from queue import Empty

        recv_reqs = []
        while True:
            try:
                recv_req = self.mp_queue.get_nowait()
            except Empty:
                break
            recv_reqs.append(recv_req)
        return recv_reqs
def start_controller_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    pipe_writer: multiprocessing.connection.Connection,
    model_overide_args: dict,
    is_data_parallel_worker: bool = False,
    gpu_ids: List[int] = None,
    dp_worker_id: int = None,
    queue: multiprocessing.Queue = None,
):
    """Start a controller process.

    Builds a ``ControllerSingle``, reports the outcome of the initialisation
    through ``pipe_writer`` and then runs the controller loop until it fails.

    Args:
        server_args: Server configuration (``log_level``, ``tp_size`` and
            ``nnodes`` are read here).
        port_args: Port configuration forwarded to the controller.
        pipe_writer: Receives ``"init ok"`` once the controller is built, or
            the result of ``get_exception_traceback()`` if construction fails.
        model_overide_args: Forwarded unchanged to the controller.
        is_data_parallel_worker: When false, ``gpu_ids``, ``dp_worker_id`` and
            ``queue`` are ignored and replaced by locally computed values.
        gpu_ids: GPU ids for the TP ranks (data-parallel-worker mode only).
        dp_worker_id: Data-parallel worker index (data-parallel-worker mode
            only).
        queue: Request queue handed to the controller as ``mp_queue``
            (data-parallel-worker mode only).

    Raises:
        Exception: Whatever ``ControllerSingle(...)`` raised, re-raised after
            the traceback has been sent through ``pipe_writer``.
    """
    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    if not is_data_parallel_worker:
        # Standalone mode: derive the arguments locally. The local rank ids
        # 0 .. tp_size_local - 1 are repeated once per node.
        tp_size_local = server_args.tp_size // server_args.nnodes
        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
        dp_worker_id = 0
        queue = None

    try:
        controller = ControllerSingle(
            server_args,
            port_args,
            model_overide_args,
            gpu_ids,
            is_data_parallel_worker,
            dp_worker_id,
            queue,
        )
    except Exception:
        # Let the process waiting on the pipe see why start-up failed.
        pipe_writer.send(get_exception_traceback())
        raise

    pipe_writer.send("init ok")

    try:
        controller.loop_for_forward()
    except Exception:
        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
    finally:
        try:
            for t in controller.tp_procs:
                try:
                    os.kill(t.pid, 9)  # 9 == SIGKILL
                except ProcessLookupError:
                    # This worker is already gone (possibly the reason the
                    # loop failed); keep going so the others are still killed.
                    continue
        finally:
            # Must run even if killing a worker failed unexpectedly.
            kill_parent_process()