Remove the dependency of rpyc (#646)
@@ -21,7 +21,7 @@ dependencies = [

 [project.optional-dependencies]
 srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
-       "psutil", "pydantic", "rpyc", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"]
+       "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
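With this change the `srt` extra no longer pulls in rpyc; the zmq, psutil, and torch entries that remain cover the inter-process plumbing that replaces it. A quick check, assuming the published package name is `sglang`:

    pip install "sglang[srt]"
    python -c "import importlib.util; print(importlib.util.find_spec('rpyc'))"   # prints None in a fresh environment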
@@ -11,4 +11,4 @@ if __name__ == "__main__":
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)

-    launch_server(server_args, None)
+    launch_server(server_args)
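The CLI entry point can now drop the explicit `None`: as the signature change later in this diff shows, `launch_server` takes `model_overide_args` and `pipe_finish_writer` as optional trailing parameters. A minimal sketch of a programmatic launch under that assumption (model path illustrative):

    from sglang.srt.server import ServerArgs, launch_server

    if __name__ == "__main__":
        # Illustrative model path; other ServerArgs fields can be set the same way.
        server_args = ServerArgs(model_path="meta-llama/Llama-2-7b-chat-hf")
        launch_server(server_args)  # model_overide_args and pipe_finish_writer default to None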
@@ -1,7 +1,6 @@
 """Launch the inference server for Llava-video model."""

 import argparse
-import multiprocessing as mp

 from sglang.srt.server import ServerArgs, launch_server

@@ -27,6 +26,4 @@ if __name__ == "__main__":

     server_args = ServerArgs.from_cli_args(args)

-    pipe_reader, pipe_writer = mp.Pipe(duplex=False)
-
-    launch_server(server_args, pipe_writer, model_overide_args)
+    launch_server(server_args, model_overide_args, None)
@@ -1,113 +0,0 @@
-"""A data parallel worker thread."""
-
-import asyncio
-import logging
-import queue
-import threading
-from typing import Callable, List
-
-import uvloop
-import zmq
-
-from sglang.global_config import global_config
-from sglang.srt.managers.controller.tp_worker import ModelTpClient
-from sglang.srt.managers.io_struct import BatchTokenIDOut
-from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import kill_parent_process
-from sglang.utils import get_exception_traceback
-
-logger = logging.getLogger("srt.controller")
-CHECKING_INTERVAL = 5
-
-asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-
-
-class DataParallelWorkerThread(threading.Thread):
-    def __init__(
-        self,
-        worker_id: int,
-        request_queue: queue.Queue,
-        detokenizer_port: int,
-        step_func: Callable,
-    ):
-        super(DataParallelWorkerThread, self).__init__()
-        self.worker_id = worker_id
-        self.request_queue = request_queue
-        self.liveness = True
-        self.request_dependency_delay = global_config.request_dependency_delay
-
-        context = zmq.asyncio.Context()
-        self.send_to_detokenizer = context.socket(zmq.PUSH)
-        self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}")
-
-        self.step = step_func
-
-    async def loop_for_forward(self):
-        while self.liveness:
-            requests = []
-            while not self.request_queue.empty():
-                requests.append(self.request_queue.get())
-
-            out_pyobjs: List[BatchTokenIDOut] = []
-            try:
-                out_pyobjs = await self.step(requests)
-            except Exception:
-                for r in requests:
-                    self.request_queue.put(r)
-                logger.error(
-                    f"Worker thread {self.worker_id}: "
-                    f"failed to get back from Model Server\n"
-                    f"{get_exception_traceback()}"
-                )
-                self.liveness = False
-                # Crash the whole server when there are any errors.
-                # TODO(lianmin): make this an option.
-                kill_parent_process()
-                return
-
-            for obj in out_pyobjs:
-                self.send_to_detokenizer.send_pyobj(obj)
-
-            # async sleep for receiving the subsequent request and avoiding cache miss
-            if len(out_pyobjs) != 0:
-                has_finished = any(
-                    [obj.finished_reason is not None for obj in out_pyobjs]
-                )
-                if has_finished:
-                    await asyncio.sleep(self.request_dependency_delay)
-            await asyncio.sleep(global_config.wait_for_new_request_delay)
-
-    async def monitoring(self):
-        while True:
-            await asyncio.sleep(CHECKING_INTERVAL)
-            # can plug in monitoring logic here
-
-    def run(self):
-        logger.info(f"DataParallelWorkerThread {self.worker_id} start")
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-        loop.create_task(self.monitoring())
-        loop.run_until_complete(self.loop_for_forward())
-
-
-def start_data_parallel_worker(
-    server_args: ServerArgs,
-    port_args: PortArgs,
-    model_overide_args,
-    gpu_ids: List[int],
-    worker_id: int,
-):
-    model_tp_client = ModelTpClient(
-        gpu_ids,
-        server_args,
-        port_args.model_port_args[worker_id],
-        model_overide_args,
-    )
-    worker_thread = DataParallelWorkerThread(
-        worker_id=worker_id,
-        request_queue=queue.Queue(),
-        detokenizer_port=port_args.detokenizer_port,
-        step_func=model_tp_client.step,
-    )
-    worker_thread.start()
-    return worker_thread
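The deleted thread pulled requests from an in-process `queue.Queue`, awaited the rpyc-backed `ModelTpClient.step`, and pushed outputs to the detokenizer; the rest of this diff replaces it with one controller process per data parallel worker, fed through a `multiprocessing.Queue`. A minimal, self-contained sketch of that process-and-queue handoff (names illustrative, not the sglang API):

    import multiprocessing


    def worker_main(request_queue):
        # Each data parallel worker drains its own queue in its own process.
        while True:
            req = request_queue.get()
            if req is None:  # shutdown sentinel
                break
            print("processing", req)


    if __name__ == "__main__":
        q = multiprocessing.Queue()
        proc = multiprocessing.Process(target=worker_main, args=(q,))
        proc.start()
        q.put({"rid": 1})
        q.put(None)
        proc.join()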
@@ -3,19 +3,17 @@ A controller that manages multiple data parallel workers.
 Each data parallel worker can manage multiple tensor parallel workers.
 """

-import asyncio
+import dataclasses
 import logging
-from concurrent.futures import ThreadPoolExecutor
+import multiprocessing
+import os
 from enum import Enum, auto
-from typing import Dict

 import numpy as np
 import zmq
-import zmq.asyncio

-from sglang.global_config import global_config
-from sglang.srt.managers.controller.dp_worker import (
-    DataParallelWorkerThread,
-    start_data_parallel_worker,
+from sglang.srt.managers.controller.manager_single import (
+    start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -23,12 +21,14 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger("srt.controller")


 class LoadBalanceMethod(Enum):
+    """Load balance method."""
     ROUND_ROBIN = auto()
     SHORTEST_QUEUE = auto()
@@ -41,155 +41,155 @@ class LoadBalanceMethod(Enum):
             raise ValueError(f"Invalid load balance method: {method}") from exc


-class Controller:
+@dataclasses.dataclass
+class WorkerHandle:
+    """Store the handle of a data parallel worker."""
+    proc: multiprocessing.Process
+    queue: multiprocessing.Queue
+
+
+class ControllerMulti:
     """A controller that manages multiple data parallel workers."""

     def __init__(
         self,
-        load_balance_method: str,
         server_args: ServerArgs,
         port_args: PortArgs,
         model_overide_args,
     ):
-        self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
+        # Parse args
         self.server_args = server_args
         self.port_args = port_args
+        self.model_overide_args = model_overide_args
+        self.load_balance_method = LoadBalanceMethod.from_str(
+            server_args.load_balance_method)

-        if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
-            self.round_robin_counter = 0
+        # Init communication
+        context = zmq.Context()
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")

-        self.dispatch_lookup = {
+        # Dispatch method
+        self.round_robin_counter = 0
+        dispatch_lookup = {
             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
         }
-        self.dispatching = self.dispatch_lookup[self.load_balance_method]
-
-        # Init communication
-        context = zmq.asyncio.Context()
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
-
-        # Init status
-        self.recv_reqs = []
+        self.dispatching = dispatch_lookup[self.load_balance_method]

         # Start data parallel workers
-        self.workers: Dict[int, DataParallelWorkerThread] = {}
-        tp_size = server_args.tp_size
-
-        def start_dp_worker(i):
-            try:
-                gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
-                worker_thread = start_data_parallel_worker(
-                    server_args, port_args, model_overide_args, gpu_ids, i
-                )
-                self.workers[i] = worker_thread
-            except Exception:
-                logger.error(
-                    f"Failed to start local worker {i}\n{get_exception_traceback()}"
-                )
-
+        self.workers = []
         for i in range(server_args.dp_size):
-            start_dp_worker(i)
-
-        # Parallel launch is slower, probably due to the disk bandwidth limitations.
-        # with ThreadPoolExecutor(server_args.dp_size) as executor:
-        #     executor.map(start_dp_worker, range(server_args.dp_size))
+            self.start_dp_worker(i)

-    def have_any_live_worker(self):
-        return any(worker_thread.liveness for worker_thread in self.workers.values())
+    def start_dp_worker(self, dp_worker_id: int):
+        tp_size = self.server_args.tp_size

-    def put_req_to_worker(self, worker_id, req):
-        self.workers[worker_id].request_queue.put(req)
+        pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(duplex=False)

+        gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
+        queue = multiprocessing.Queue()
+        proc = multiprocessing.Process(
+            target=start_controller_process_single,
+            args=(
+                self.server_args,
+                self.port_args,
+                pipe_controller_writer,
+                self.model_overide_args,
+                True,
+                gpu_ids,
+                dp_worker_id,
+                queue,
+            ),
+        )
+        proc.start()

-    async def round_robin_scheduler(self, input_requests):
-        available_workers = list(self.workers.keys())
+        controller_init_state = pipe_controller_reader.recv()
+        if controller_init_state != "init ok":
+            raise RuntimeError(
+                f"Initialization failed. controller_init_state: {controller_init_state}"
+            )
+        self.workers.append(WorkerHandle(
+            proc=proc,
+            queue=queue,
+        ))

+    def round_robin_scheduler(self, input_requests):
         for r in input_requests:
-            self.put_req_to_worker(available_workers[self.round_robin_counter], r)
+            self.workers[self.round_robin_counter].queue.put(r)
             self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                available_workers
+                self.workers
             )
-        return

-    async def shortest_queue_scheduler(self, input_requests):
+    def shortest_queue_scheduler(self, input_requests):
         for r in input_requests:
-            worker = min(
-                self.workers, key=lambda w: self.workers[w].request_queue.qsize()
-            )
-            self.put_req_to_worker(worker, r)
-        return
+            queue_sizes = [worker.queue.qsize() for worker in self.workers]
+            wid = np.argmin(queue_sizes)
+            self.workers[wid].queue.put(r)

-    async def remove_dead_workers(self):
-        for i in list(self.workers.keys()):
-            worker_thread = self.workers[i]
-            if not worker_thread.liveness:
-                worker_thread.join()
-                # move unsuccessful requests back to the queue
-                while not worker_thread.request_queue.empty():
-                    self.recv_reqs.append(worker_thread.request_queue.get())
-                del self.workers[i]
-                logger.info(f"Stale worker {i} removed")
-
-    async def loop_for_forward(self):
+    def loop_for_forward(self):
         while True:
-            await self.remove_dead_workers()
-
-            if self.have_any_live_worker():
-                next_step_input = list(self.recv_reqs)
-                self.recv_reqs = []
-                if next_step_input:
-                    await self.dispatching(next_step_input)
-            # else:
-            #     logger.error("There is no live worker.")
-
-            await asyncio.sleep(global_config.wait_for_new_request_delay)
+            recv_reqs = self.recv_requests()
+            self.dispatching(recv_reqs)

-    async def loop_for_recv_requests(self):
+    def recv_requests(self):
+        recv_reqs = []
+
         while True:
-            recv_req = await self.recv_from_tokenizer.recv_pyobj()
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+            except zmq.ZMQError:
+                break
+
             if isinstance(recv_req, FlushCacheReq):
                 # TODO(lsyin): apply more specific flushCacheReq
-                for worker_thread in self.workers.values():
-                    worker_thread.request_queue.put(recv_req)
-            elif isinstance(recv_req, TokenizedGenerateReqInput):
-                self.recv_reqs.append(recv_req)
+                for worker in self.workers:
+                    worker.queue.put(recv_req)
             elif isinstance(recv_req, AbortReq):
                 in_queue = False
-                for i, req in enumerate(self.recv_reqs):
+                for i, req in enumerate(recv_reqs):
                     if req.rid == recv_req.rid:
-                        self.recv_reqs[i] = recv_req
+                        recv_reqs[i] = recv_req
                         in_queue = True
                         break
                 if not in_queue:
                     # Send abort req to all TP groups
-                    for worker in list(self.workers.keys()):
-                        self.put_req_to_worker(worker, recv_req)
+                    for worker in self.workers:
+                        worker.queue.put(recv_req)
+            elif isinstance(recv_req, TokenizedGenerateReqInput):
+                recv_reqs.append(recv_req)
             else:
                 logger.error(f"Invalid object: {recv_req}")

+        return recv_reqs
+

 def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_overide_args=None,
+    model_overide_args: dict,
 ):
     """Start a controller process."""

     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format="%(message)s",
     )

     try:
-        controller = Controller(
-            server_args.load_balance_method, server_args, port_args, model_overide_args
-        )
+        controller = ControllerMulti(server_args, port_args, model_overide_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise

     pipe_writer.send("init ok")

-    loop = asyncio.new_event_loop()
-    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
-    asyncio.set_event_loop(loop)
-    loop.create_task(controller.loop_for_recv_requests())
-    loop.run_until_complete(controller.loop_for_forward())
+    try:
+        controller.loop_for_forward()
+    except Exception:
+        logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
+    finally:
+        for w in controller.workers:
+            os.kill(w.proc.pid, 9)
+        kill_parent_process()
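`start_dp_worker` blocks on a `multiprocessing.Pipe` until the child controller reports "init ok", the same handshake `launch_server` uses further down to wait for model loading. The pattern in isolation:

    import multiprocessing


    def child(pipe_writer):
        try:
            pass  # heavy initialization (model load, socket setup) would happen here
        except Exception as e:
            pipe_writer.send(str(e))
            raise
        pipe_writer.send("init ok")


    if __name__ == "__main__":
        reader, writer = multiprocessing.Pipe(duplex=False)
        proc = multiprocessing.Process(target=child, args=(writer,))
        proc.start()
        state = reader.recv()  # blocks until the child reports readiness
        if state != "init ok":
            raise RuntimeError(f"Initialization failed: {state}")
        proc.join()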
@@ -3,126 +3,61 @@
 import logging
 import multiprocessing
 import os
-import pickle
 from typing import List

-import torch
-import torch.distributed as dist
 import zmq
-import zmq.asyncio

-from sglang.srt.managers.controller.tp_worker import ModelTpServer
-from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
+from sglang.srt.managers.controller.tp_worker import (
+    broadcast_recv_input, launch_tp_servers, ModelTpServer
+)
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger("srt.controller")


-def run_tp_server(
-    gpu_id: int,
-    tp_rank: int,
-    server_args: ServerArgs,
-    model_port_args: ModelPortArgs,
-    model_overide_args: dict,
-):
-    """Run a tp server."""
-    try:
-        model_server = ModelTpServer(
-            gpu_id,
-            tp_rank,
-            server_args,
-            model_port_args,
-            model_overide_args,
-        )
-        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
-
-        while True:
-            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
-            model_server.exposed_step(recv_reqs)
-    except Exception:
-        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
-        raise
-
-
-def launch_tp_servers(
-    gpu_ids, tp_rank_range, server_args, model_port_args, model_overide_args
-):
-    """Launch multiple tp servers."""
-    procs = []
-    for i in tp_rank_range:
-        proc = multiprocessing.Process(
-            target=run_tp_server,
-            args=(gpu_ids[i], i, server_args, model_port_args, model_overide_args),
-        )
-        proc.start()
-        procs.append(proc)
-
-    return procs
-
-
-def broadcast_recv_input(data, rank, dist_group):
-    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
-
-    if rank == 0:
-        if len(data) == 0:
-            tensor_size = torch.tensor([0], dtype=torch.long)
-            dist.broadcast(tensor_size, src=0, group=dist_group)
-        else:
-            serialized_data = pickle.dumps(data)
-            size = len(serialized_data)
-            tensor_data = torch.ByteTensor(list(serialized_data))
-            tensor_size = torch.tensor([size], dtype=torch.long)
-
-            dist.broadcast(tensor_size, src=0, group=dist_group)
-            dist.broadcast(tensor_data, src=0, group=dist_group)
-    else:
-        tensor_size = torch.tensor([0], dtype=torch.long)
-        dist.broadcast(tensor_size, src=0, group=dist_group)
-        size = tensor_size.item()
-
-        if size == 0:
-            return []
-
-        tensor_data = torch.empty(size, dtype=torch.uint8)
-        dist.broadcast(tensor_data, src=0, group=dist_group)
-
-        serialized_data = bytes(tensor_data.tolist())
-        data = pickle.loads(serialized_data)
-        return data
-
-
 class ControllerSingle:
     """A controller that manages a group of tensor parallel workers."""

     def __init__(
-        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        model_overide_args: dict,
+        gpu_ids: List[int],
+        is_data_parallel_worker: bool,
+        dp_worker_id: int,
+        mp_queue: multiprocessing.Queue,
     ):
         # Parse args
-        self.server_args = server_args
-        self.tp_procs = []
+        self.tp_size = server_args.tp_size
+        self.is_dp_worker = is_data_parallel_worker
+        self.dp_worker_id = dp_worker_id
+        self.mp_queue = mp_queue

         # Init communication
         context = zmq.Context(2)
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")

+        if not self.is_dp_worker:
+            self.recv_from_tokenizer = context.socket(zmq.PULL)
+            self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")

         self.send_to_detokenizer = context.socket(zmq.PUSH)
         self.send_to_detokenizer.connect(
             f"tcp://127.0.0.1:{port_args.detokenizer_port}"
         )

-        # Init model server
-        tp_size_local = server_args.tp_size // server_args.nnodes
-        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
-
         # Launch other tp ranks
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        self.tp_procs = []
         if tp_size_local > 1:
             tp_rank_range = range(1, tp_size_local)
             self.tp_procs = launch_tp_servers(
                 gpu_ids,
                 tp_rank_range,
                 server_args,
-                port_args.model_port_args[0],
+                port_args.nccl_ports[dp_worker_id],
                 model_overide_args,
             )

@@ -131,16 +66,19 @@ class ControllerSingle:
             gpu_ids[0],
             0,
             server_args,
-            port_args.model_port_args[0],
+            port_args.nccl_ports[dp_worker_id],
             model_overide_args,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group

     def loop_for_forward(self):
         while True:
-            recv_reqs = self.recv_requests()
+            if not self.is_dp_worker:
+                recv_reqs = self.recv_requests_from_zmq()
+            else:
+                recv_reqs = self.recv_requests_from_mp_queue()

-            if self.server_args.tp_size > 1:
+            if self.tp_size > 1:
                 broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)

             out_pyobjs = self.tp_server.exposed_step(recv_reqs)
@@ -148,27 +86,51 @@ class ControllerSingle:
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)

-    def recv_requests(self):
+    def recv_requests_from_zmq(self):
         recv_reqs = []
         while True:
             try:
                 recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-                recv_reqs.append(recv_req)
             except zmq.ZMQError:
                 break
+            recv_reqs.append(recv_req)

         return recv_reqs

+    def recv_requests_from_mp_queue(self):
+        recv_reqs = []
+        while not self.mp_queue.empty():
+            recv_reqs.append(self.mp_queue.get())
+        return recv_reqs
+

 def start_controller_process(
-    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args: dict
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    pipe_writer: multiprocessing.connection.Connection,
+    model_overide_args: dict,
+    is_data_parallel_worker: bool = False,
+    gpu_ids: List[int] = None,
+    dp_worker_id: int = None,
+    queue: multiprocessing.connection.Connection = None,
 ):
     """Start a controller process."""

     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format="%(message)s",
     )

+    if not is_data_parallel_worker:
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+        dp_worker_id = 0
+        queue = None
+
     try:
-        controller = ControllerSingle(server_args, port_args, model_overide_args)
+        controller = ControllerSingle(server_args, port_args, model_overide_args,
+                                      gpu_ids, is_data_parallel_worker,
+                                      dp_worker_id, queue)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
@@ -1,15 +1,14 @@
 """A tensor parallel worker."""

-import asyncio
 import logging
+import multiprocessing
+import pickle
 import time
 import warnings
-from concurrent.futures import ThreadPoolExecutor
 from typing import List, Optional

-import rpyc
 import torch
-from rpyc.utils.classic import obtain
+import torch.distributed as dist

 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
@@ -32,13 +31,11 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.server_args import ModelPortArgs, ServerArgs
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    connect_rpyc_service,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-    start_rpyc_service_process,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -52,10 +49,9 @@ class ModelTpServer:
         gpu_id: int,
         tp_rank: int,
         server_args: ServerArgs,
-        model_port_args: ModelPortArgs,
+        nccl_port: int,
         model_overide_args: dict,
     ):
-        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
         suppress_other_loggers()

         # Copy arguments
@@ -79,7 +75,7 @@ class ModelTpServer:
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
-            nccl_port=model_port_args.nccl_port,
+            nccl_port=nccl_port,
             server_args=server_args,
         )

@@ -178,9 +174,6 @@ class ModelTpServer:
         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

     def exposed_step(self, recv_reqs):
-        if not isinstance(recv_reqs, list):
-            recv_reqs = obtain(recv_reqs)
-
         try:
             # Recv requests
             for recv_req in recv_reqs:
@@ -425,12 +418,6 @@ class ModelTpServer:
                 f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
             )
-            # logger.debug(
-            #     f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-            #     f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-            #     f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-            #     f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            # )

         # Return the new batch
         new_batch = Batch.init_new(
@@ -733,87 +720,74 @@ class ModelTpServer:
                 break


-class ModelTpService(rpyc.Service):
-    exposed_ModelTpServer = ModelTpServer
-
-
-class ModelTpClient:
-    def __init__(
-        self,
-        gpu_ids: List[int],
-        server_args: ServerArgs,
-        model_port_args: ModelPortArgs,
-        model_overide_args,
-    ):
-        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
-        self.tp_size = server_args.tp_size
-
-        if self.tp_size * server_args.dp_size == 1:
-            # Init model
-            assert len(gpu_ids) == 1
-            self.model_server = ModelTpService().exposed_ModelTpServer(
-                gpu_ids[0],
-                0,
-                server_args,
-                model_port_args,
-                model_overide_args,
-            )
-
-            # Wrap functions
-            def async_wrap(f):
-                async def _func(*args, **kwargs):
-                    return f(*args, **kwargs)
-
-                return _func
-
-            self.step = async_wrap(self.model_server.exposed_step)
-        else:
-            with ThreadPoolExecutor(self.tp_size) as executor:
-                # Launch model processes
-                if server_args.nnodes == 1:
-                    self.procs = list(
-                        executor.map(
-                            lambda args: start_rpyc_service_process(*args),
-                            [
-                                (ModelTpService, p)
-                                for p in model_port_args.model_tp_ports
-                            ],
-                        )
-                    )
-                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
-                else:
-                    addrs = [
-                        (ip, port)
-                        for ip, port in zip(
-                            model_port_args.model_tp_ips, model_port_args.model_tp_ports
-                        )
-                    ]
-
-                self.model_services = list(
-                    executor.map(lambda args: connect_rpyc_service(*args), addrs)
-                )
-
-                # Init model
-                def init_model(i):
-                    return self.model_services[i].ModelTpServer(
-                        gpu_ids[i],
-                        i,
-                        server_args,
-                        model_port_args,
-                        model_overide_args,
-                    )
-
-                self.model_servers = list(executor.map(init_model, range(self.tp_size)))
-
-                # Wrap functions
-                def async_wrap(func_name):
-                    fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
-
-                    async def _func(*args, **kwargs):
-                        tasks = [f(*args, **kwargs) for f in fs]
-                        await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
-                        return obtain(tasks[0].value)
-
-                    return _func
-
-                self.step = async_wrap("step")
+def run_tp_server(
+    gpu_id: int,
+    tp_rank: int,
+    server_args: ServerArgs,
+    nccl_port: int,
+    model_overide_args: dict,
+):
+    """Run a tensor parallel server."""
+    try:
+        model_server = ModelTpServer(
+            gpu_id,
+            tp_rank,
+            server_args,
+            nccl_port,
+            model_overide_args,
+        )
+        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+        while True:
+            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+            model_server.exposed_step(recv_reqs)
+    except Exception:
+        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+        raise
+
+
+def launch_tp_servers(
+    gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
+):
+    """Launch multiple tensor parallel servers."""
+    procs = []
+    for i in tp_rank_range:
+        proc = multiprocessing.Process(
+            target=run_tp_server,
+            args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
+        )
+        proc.start()
+        procs.append(proc)
+
+    return procs
+
+
+def broadcast_recv_input(data, rank, dist_group):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+        return data
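`broadcast_recv_input` is the piece that actually replaces rpyc: rank 0 pickles the request batch, broadcasts the byte length, then broadcasts the bytes as a uint8 tensor over the tensor parallel CPU group, and every other rank unpickles its copy inside `run_tp_server`'s loop. A self-contained sketch of the same technique with two local gloo processes:

    import pickle

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def run(rank):
        dist.init_process_group(
            "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=2
        )
        if rank == 0:
            payload = pickle.dumps([{"rid": 1, "text": "hello"}])
            size = torch.tensor([len(payload)], dtype=torch.long)
            data = torch.tensor(list(payload), dtype=torch.uint8)
            dist.broadcast(size, src=0)
            dist.broadcast(data, src=0)
        else:
            size = torch.tensor([0], dtype=torch.long)
            dist.broadcast(size, src=0)
            data = torch.empty(size.item(), dtype=torch.uint8)
            dist.broadcast(data, src=0)
            print(pickle.loads(bytes(data.tolist())))


    if __name__ == "__main__":
        mp.spawn(run, nprocs=2)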
@@ -61,7 +61,7 @@ class TokenizerManager:
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

         self.send_to_router = context.socket(zmq.PUSH)
-        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}")
+        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

         self.model_path = server_args.model_path
         self.hf_config = get_config(
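Only the port name changes here: the tokenizer keeps PUSHing tokenized requests to the controller's PULL socket. The PUSH/PULL pair in isolation (pyzmq, port number illustrative):

    import zmq

    context = zmq.Context()

    pull = context.socket(zmq.PULL)  # controller side
    pull.bind("tcp://127.0.0.1:30001")

    push = context.socket(zmq.PUSH)  # tokenizer side
    push.connect("tcp://127.0.0.1:30001")

    push.send_pyobj({"rid": 1})
    print(pull.recv_pyobj())  # {'rid': 1}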
@@ -44,15 +44,13 @@ from sglang.srt.openai_api_adapter import (
     v1_chat_completions,
     v1_completions,
 )
-from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     API_KEY_HEADER_NAME,
     APIKeyValidatorMiddleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-    receive_addrs,
-    send_addrs_to_rank_0,
 )
 from sglang.utils import get_exception_traceback
@@ -98,6 +96,7 @@ async def flush_cache():


 async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
     if obj.stream:

         async def stream_results():
@@ -146,7 +145,10 @@ def _set_global_server_args(server_args: ServerArgs):
     }


-def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
+def launch_server(server_args: ServerArgs,
+                  model_overide_args: Optional[dict] = None,
+                  pipe_finish_writer: Optional[mp.connection.Connection] = None):
+    """Launch an HTTP server."""
     global tokenizer_manager

     logging.basicConfig(
@@ -173,39 +175,23 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)

+    _set_global_server_args(server_args)
+
     # Allocate ports
     assert server_args.tp_size % server_args.nnodes == 0
     tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        tp_size_local,
         server_args.dp_size,
     )
-
     ports = server_args.additional_ports
-    model_port_args = []
-    for i in range(server_args.dp_size):
-        model_port_args.append(
-            ModelPortArgs(
-                nccl_port=ports[3 + i * (tp_size_local + 1)],
-                model_tp_ips=[None] * tp_size_local,
-                model_tp_ports=ports[
-                    3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
-                ],
-            )
-        )
     port_args = PortArgs(
         tokenizer_port=ports[0],
-        router_port=ports[1],
+        controller_port=ports[1],
         detokenizer_port=ports[2],
-        model_port_args=model_port_args,
+        nccl_ports=ports[3:],
     )

-    # Handle multi-node tp
+    # Handle multi-node tensor parallelism
     if server_args.nnodes > 1:
         assert server_args.dp_size == 1, "Multi-node dp is not supported."
@@ -224,7 +210,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
             gpu_ids,
             tp_rank_range,
             server_args,
-            port_args.model_port_args[0],
+            ports[3],
             model_overide_args,
         )
         while True:
@@ -232,18 +218,18 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg

     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
-    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
+    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

     if server_args.dp_size == 1:
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
-    proc_router = mp.Process(
+    proc_controller = mp.Process(
         target=start_process,
-        args=(server_args, port_args, pipe_router_writer, model_overide_args),
+        args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
-    proc_router.start()
+    proc_controller.start()
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
@@ -255,68 +241,27 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     proc_detoken.start()

     # Wait for the model to finish loading
-    router_init_state = pipe_router_reader.recv()
+    controller_init_state = pipe_controller_reader.recv()
     detoken_init_state = pipe_detoken_reader.recv()

-    if router_init_state != "init ok" or detoken_init_state != "init ok":
-        proc_router.kill()
+    if controller_init_state != "init ok" or detoken_init_state != "init ok":
+        proc_controller.kill()
         proc_detoken.kill()
         print(
-            f"Initialization failed. router_init_state: {router_init_state}", flush=True
+            f"Initialization failed. controller_init_state: {controller_init_state}", flush=True
         )
         print(
             f"Initialization failed. detoken_init_state: {detoken_init_state}",
             flush=True,
         )
         sys.exit(1)
-    assert proc_router.is_alive() and proc_detoken.is_alive()
+    assert proc_controller.is_alive() and proc_detoken.is_alive()

     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

     # Send a warmup request
-    def _wait_and_warmup():
-        headers = {}
-        url = server_args.url()
-        if server_args.api_key:
-            headers[API_KEY_HEADER_NAME] = server_args.api_key
-
-        # Wait until the server is launched
-        for _ in range(120):
-            time.sleep(0.5)
-            try:
-                requests.get(url + "/get_model_info", timeout=5, headers=headers)
-                break
-            except requests.exceptions.RequestException:
-                pass
-
-        # Send a warmup request
-        try:
-            for _ in range(server_args.dp_size):
-                res = requests.post(
-                    url + "/generate",
-                    json={
-                        "text": "The capital city of France is",
-                        "sampling_params": {
-                            "temperature": 0,
-                            "max_new_tokens": 8,
-                        },
-                    },
-                    headers=headers,
-                    timeout=600,
-                )
-                assert res.status_code == 200
-        except Exception as e:
-            if pipe_finish_writer is not None:
-                pipe_finish_writer.send(get_exception_traceback())
-            print(f"Initialization failed. warmup error: {e}", flush=True)
-            raise e
-
-        logger.info("The server is fired up and ready to roll!")
-        if pipe_finish_writer is not None:
-            pipe_finish_writer.send("init ok")
-
-    t = threading.Thread(target=_wait_and_warmup)
+    t = threading.Thread(target=_wait_and_warmup, args=(server_args, pipe_finish_writer))
     t.start()

     # Listen for requests
@@ -333,6 +278,48 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     t.join()


+def _wait_and_warmup(server_args, pipe_finish_writer):
+    headers = {}
+    url = server_args.url()
+    if server_args.api_key:
+        headers[API_KEY_HEADER_NAME] = server_args.api_key
+
+    # Wait until the server is launched
+    for _ in range(120):
+        time.sleep(0.5)
+        try:
+            requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            break
+        except requests.exceptions.RequestException:
+            pass
+
+    # Send a warmup request
+    try:
+        for _ in range(server_args.dp_size):
+            res = requests.post(
+                url + "/generate",
+                json={
+                    "text": "The capital city of France is",
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": 8,
+                    },
+                },
+                headers=headers,
+                timeout=600,
+            )
+            assert res.status_code == 200
+    except Exception as e:
+        if pipe_finish_writer is not None:
+            pipe_finish_writer.send(get_exception_traceback())
+        print(f"Initialization failed. warmup error: {e}", flush=True)
+        raise e
+
+    logger.info("The server is fired up and ready to roll!")
+    if pipe_finish_writer is not None:
+        pipe_finish_writer.send("init ok")
+
+
 class Runtime:
     """
     A wrapper for the server.
@@ -354,7 +341,6 @@ class Runtime:
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
             self.server_args.port,
             self.server_args.additional_ports,
-            self.server_args.tp_size,
             self.server_args.dp_size,
         )
@@ -367,7 +353,7 @@ class Runtime:
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args, pipe_writer, model_overide_args),
+            args=(self.server_args, model_overide_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
@@ -337,16 +337,9 @@ class ServerArgs:
         )


-@dataclasses.dataclass
-class ModelPortArgs:
-    nccl_port: int
-    model_tp_ips: List[str]
-    model_tp_ports: List[int]
-
-
 @dataclasses.dataclass
 class PortArgs:
     tokenizer_port: int
-    router_port: int
+    controller_port: int
     detokenizer_port: int
-    model_port_args: List[ModelPortArgs]
+    nccl_ports: List[int]
@@ -3,7 +3,6 @@
 import base64
 import fcntl
 import logging
-import multiprocessing
 import os
 import random
 import socket
@@ -16,12 +15,10 @@ from typing import List, Optional
 import numpy as np
 import psutil
 import requests
-import rpyc
 import torch
 import triton
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
-from rpyc.utils.server import ThreadedServer
 from starlette.middleware.base import BaseHTTPMiddleware

 logger = logging.getLogger(__name__)
@@ -148,7 +145,6 @@ def is_port_available(port):
 def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
-    tp_size: int = 1,
     dp_size: int = 1,
 ):
     """Allocate ports for all connections."""
@@ -160,8 +156,8 @@ def allocate_init_ports(
     ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
     cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000

-    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
-    num_ports_needed = 4 + dp_size * (1 + tp_size)
+    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
+    num_ports_needed = 4 + dp_size
     while len(ret_ports) < num_ports_needed:
         if cur_port not in ret_ports and is_port_available(cur_port):
             ret_ports.append(cur_port)
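The two comments make the accounting explicit. Previously every tensor parallel rank needed its own rpyc port on top of the per-replica NCCL port, so num_ports_needed = 4 + dp_size * (1 + tp_size); now only the NCCL port remains, giving num_ports_needed = 4 + dp_size. For example, with tp_size=4 and dp_size=2 the old scheme reserved 4 + 2 * (1 + 4) = 14 ports, while the new one reserves 4 + 2 = 6.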
@@ -371,49 +367,6 @@ def load_image(image_file):
     return image, image_size


-def connect_rpyc_service(host, port):
-    repeat_count = 0
-    while repeat_count < 20:
-        try:
-            con = rpyc.connect(
-                host,
-                port,
-                config={
-                    "allow_public_attrs": True,
-                    "allow_pickle": True,
-                    "sync_request_timeout": 3600,
-                },
-            )
-            break
-        except ConnectionRefusedError as e:
-            time.sleep(1)
-        repeat_count += 1
-    if repeat_count == 20:
-        raise RuntimeError(f"Connect rpyc error: {e}")
-
-    return con.root
-
-
-def start_rpyc_service(service: rpyc.Service, port: int):
-    t = ThreadedServer(
-        service=service,
-        port=port,
-        protocol_config={
-            "allow_public_attrs": True,
-            "allow_pickle": True,
-            "sync_request_timeout": 3600,
-        },
-    )
-    t.logger.setLevel(logging.WARN)
-    t.start()
-
-
-def start_rpyc_service_process(service: rpyc.Service, port: int):
-    proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
-    proc.start()
-    return proc
-
-
 def suppress_other_loggers():
     from vllm.logger import logger as vllm_default_logger