Remove the dependency of rpyc (#646)
This commit is contained in:
@@ -21,7 +21,7 @@ dependencies = [
|
|||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
|
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
|
||||||
"psutil", "pydantic", "rpyc", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"]
|
"psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"]
|
||||||
openai = ["openai>=1.0", "tiktoken"]
|
openai = ["openai>=1.0", "tiktoken"]
|
||||||
anthropic = ["anthropic>=0.20.0"]
|
anthropic = ["anthropic>=0.20.0"]
|
||||||
litellm = ["litellm>=1.0.0"]
|
litellm = ["litellm>=1.0.0"]
|
||||||
|
|||||||
@@ -11,4 +11,4 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
server_args = ServerArgs.from_cli_args(args)
|
server_args = ServerArgs.from_cli_args(args)
|
||||||
|
|
||||||
launch_server(server_args, None)
|
launch_server(server_args)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Launch the inference server for Llava-video model."""
|
"""Launch the inference server for Llava-video model."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import multiprocessing as mp
|
|
||||||
|
|
||||||
from sglang.srt.server import ServerArgs, launch_server
|
from sglang.srt.server import ServerArgs, launch_server
|
||||||
|
|
||||||
@@ -27,6 +26,4 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
server_args = ServerArgs.from_cli_args(args)
|
server_args = ServerArgs.from_cli_args(args)
|
||||||
|
|
||||||
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
|
launch_server(server_args, model_overide_args, None)
|
||||||
|
|
||||||
launch_server(server_args, pipe_writer, model_overide_args)
|
|
||||||
|
|||||||
@@ -1,113 +0,0 @@
|
|||||||
"""A data parallel worker thread."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import queue
|
|
||||||
import threading
|
|
||||||
from typing import Callable, List
|
|
||||||
|
|
||||||
import uvloop
|
|
||||||
import zmq
|
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
|
||||||
from sglang.srt.managers.controller.tp_worker import ModelTpClient
|
|
||||||
from sglang.srt.managers.io_struct import BatchTokenIDOut
|
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
|
||||||
from sglang.srt.utils import kill_parent_process
|
|
||||||
from sglang.utils import get_exception_traceback
|
|
||||||
|
|
||||||
logger = logging.getLogger("srt.controller")
|
|
||||||
CHECKING_INTERVAL = 5
|
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
|
||||||
|
|
||||||
|
|
||||||
class DataParallelWorkerThread(threading.Thread):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
worker_id: int,
|
|
||||||
request_queue: queue.Queue,
|
|
||||||
detokenizer_port: int,
|
|
||||||
step_func: Callable,
|
|
||||||
):
|
|
||||||
super(DataParallelWorkerThread, self).__init__()
|
|
||||||
self.worker_id = worker_id
|
|
||||||
self.request_queue = request_queue
|
|
||||||
self.liveness = True
|
|
||||||
self.request_dependency_delay = global_config.request_dependency_delay
|
|
||||||
|
|
||||||
context = zmq.asyncio.Context()
|
|
||||||
self.send_to_detokenizer = context.socket(zmq.PUSH)
|
|
||||||
self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}")
|
|
||||||
|
|
||||||
self.step = step_func
|
|
||||||
|
|
||||||
async def loop_for_forward(self):
|
|
||||||
while self.liveness:
|
|
||||||
requests = []
|
|
||||||
while not self.request_queue.empty():
|
|
||||||
requests.append(self.request_queue.get())
|
|
||||||
|
|
||||||
out_pyobjs: List[BatchTokenIDOut] = []
|
|
||||||
try:
|
|
||||||
out_pyobjs = await self.step(requests)
|
|
||||||
except Exception:
|
|
||||||
for r in requests:
|
|
||||||
self.request_queue.put(r)
|
|
||||||
logger.error(
|
|
||||||
f"Worker thread {self.worker_id}: "
|
|
||||||
f"failed to get back from Model Server\n"
|
|
||||||
f"{get_exception_traceback()}"
|
|
||||||
)
|
|
||||||
self.liveness = False
|
|
||||||
# Crash the whole server when there are any errors.
|
|
||||||
# TODO(lianmin): make this an option.
|
|
||||||
kill_parent_process()
|
|
||||||
return
|
|
||||||
|
|
||||||
for obj in out_pyobjs:
|
|
||||||
self.send_to_detokenizer.send_pyobj(obj)
|
|
||||||
|
|
||||||
# async sleep for receiving the subsequent request and avoiding cache miss
|
|
||||||
if len(out_pyobjs) != 0:
|
|
||||||
has_finished = any(
|
|
||||||
[obj.finished_reason is not None for obj in out_pyobjs]
|
|
||||||
)
|
|
||||||
if has_finished:
|
|
||||||
await asyncio.sleep(self.request_dependency_delay)
|
|
||||||
await asyncio.sleep(global_config.wait_for_new_request_delay)
|
|
||||||
|
|
||||||
async def monitoring(self):
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(CHECKING_INTERVAL)
|
|
||||||
# can plug in monitoring logic here
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
logger.info(f"DataParallelWorkerThread {self.worker_id} start")
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
loop.create_task(self.monitoring())
|
|
||||||
loop.run_until_complete(self.loop_for_forward())
|
|
||||||
|
|
||||||
|
|
||||||
def start_data_parallel_worker(
|
|
||||||
server_args: ServerArgs,
|
|
||||||
port_args: PortArgs,
|
|
||||||
model_overide_args,
|
|
||||||
gpu_ids: List[int],
|
|
||||||
worker_id: int,
|
|
||||||
):
|
|
||||||
model_tp_client = ModelTpClient(
|
|
||||||
gpu_ids,
|
|
||||||
server_args,
|
|
||||||
port_args.model_port_args[worker_id],
|
|
||||||
model_overide_args,
|
|
||||||
)
|
|
||||||
worker_thread = DataParallelWorkerThread(
|
|
||||||
worker_id=worker_id,
|
|
||||||
request_queue=queue.Queue(),
|
|
||||||
detokenizer_port=port_args.detokenizer_port,
|
|
||||||
step_func=model_tp_client.step,
|
|
||||||
)
|
|
||||||
worker_thread.start()
|
|
||||||
return worker_thread
|
|
||||||
@@ -3,19 +3,17 @@ A controller that manages multiple data parallel workers.
|
|||||||
Each data parallel worker can manage multiple tensor parallel workers.
|
Each data parallel worker can manage multiple tensor parallel workers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
import multiprocessing
|
||||||
|
import os
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.srt.managers.controller.manager_single import (
|
||||||
from sglang.srt.managers.controller.dp_worker import (
|
start_controller_process as start_controller_process_single,
|
||||||
DataParallelWorkerThread,
|
|
||||||
start_data_parallel_worker,
|
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
@@ -23,12 +21,14 @@ from sglang.srt.managers.io_struct import (
|
|||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
|
from sglang.srt.utils import kill_parent_process
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger("srt.controller")
|
logger = logging.getLogger("srt.controller")
|
||||||
|
|
||||||
|
|
||||||
class LoadBalanceMethod(Enum):
|
class LoadBalanceMethod(Enum):
|
||||||
|
"""Load balance method."""
|
||||||
ROUND_ROBIN = auto()
|
ROUND_ROBIN = auto()
|
||||||
SHORTEST_QUEUE = auto()
|
SHORTEST_QUEUE = auto()
|
||||||
|
|
||||||
@@ -41,155 +41,155 @@ class LoadBalanceMethod(Enum):
|
|||||||
raise ValueError(f"Invalid load balance method: {method}") from exc
|
raise ValueError(f"Invalid load balance method: {method}") from exc
|
||||||
|
|
||||||
|
|
||||||
class Controller:
|
@dataclasses.dataclass
|
||||||
|
class WorkerHandle:
|
||||||
|
"""Store the handle of a data parallel worker."""
|
||||||
|
proc: multiprocessing.Process
|
||||||
|
queue: multiprocessing.Queue
|
||||||
|
|
||||||
|
|
||||||
|
class ControllerMulti:
|
||||||
"""A controller that manages multiple data parallel workers."""
|
"""A controller that manages multiple data parallel workers."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
load_balance_method: str,
|
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
model_overide_args,
|
model_overide_args,
|
||||||
):
|
):
|
||||||
self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
|
# Parse args
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
self.port_args = port_args
|
self.port_args = port_args
|
||||||
|
self.model_overide_args = model_overide_args
|
||||||
|
self.load_balance_method = LoadBalanceMethod.from_str(
|
||||||
|
server_args.load_balance_method)
|
||||||
|
|
||||||
if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
|
# Init communication
|
||||||
self.round_robin_counter = 0
|
context = zmq.Context()
|
||||||
|
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
||||||
|
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
|
||||||
|
|
||||||
self.dispatch_lookup = {
|
# Dispatch method
|
||||||
|
self.round_robin_counter = 0
|
||||||
|
dispatch_lookup = {
|
||||||
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
|
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
|
||||||
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
|
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
|
||||||
}
|
}
|
||||||
self.dispatching = self.dispatch_lookup[self.load_balance_method]
|
self.dispatching = dispatch_lookup[self.load_balance_method]
|
||||||
|
|
||||||
# Init communication
|
|
||||||
context = zmq.asyncio.Context()
|
|
||||||
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
|
||||||
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
|
|
||||||
|
|
||||||
# Init status
|
|
||||||
self.recv_reqs = []
|
|
||||||
|
|
||||||
# Start data parallel workers
|
# Start data parallel workers
|
||||||
self.workers: Dict[int, DataParallelWorkerThread] = {}
|
self.workers = []
|
||||||
tp_size = server_args.tp_size
|
|
||||||
|
|
||||||
def start_dp_worker(i):
|
|
||||||
try:
|
|
||||||
gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
|
|
||||||
worker_thread = start_data_parallel_worker(
|
|
||||||
server_args, port_args, model_overide_args, gpu_ids, i
|
|
||||||
)
|
|
||||||
self.workers[i] = worker_thread
|
|
||||||
except Exception:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to start local worker {i}\n{get_exception_traceback()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
for i in range(server_args.dp_size):
|
for i in range(server_args.dp_size):
|
||||||
start_dp_worker(i)
|
self.start_dp_worker(i)
|
||||||
|
|
||||||
# Parallel launch is slower, probably due to the disk bandwidth limitations.
|
def start_dp_worker(self, dp_worker_id: int):
|
||||||
# with ThreadPoolExecutor(server_args.dp_size) as executor:
|
tp_size = self.server_args.tp_size
|
||||||
# executor.map(start_dp_worker, range(server_args.dp_size))
|
|
||||||
|
|
||||||
def have_any_live_worker(self):
|
pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(duplex=False)
|
||||||
return any(worker_thread.liveness for worker_thread in self.workers.values())
|
|
||||||
|
|
||||||
def put_req_to_worker(self, worker_id, req):
|
gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
|
||||||
self.workers[worker_id].request_queue.put(req)
|
queue = multiprocessing.Queue()
|
||||||
|
proc = multiprocessing.Process(
|
||||||
|
target=start_controller_process_single,
|
||||||
|
args=(
|
||||||
|
self.server_args,
|
||||||
|
self.port_args,
|
||||||
|
pipe_controller_writer,
|
||||||
|
self.model_overide_args,
|
||||||
|
True,
|
||||||
|
gpu_ids,
|
||||||
|
dp_worker_id,
|
||||||
|
queue,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
proc.start()
|
||||||
|
|
||||||
async def round_robin_scheduler(self, input_requests):
|
controller_init_state = pipe_controller_reader.recv()
|
||||||
available_workers = list(self.workers.keys())
|
if controller_init_state != "init ok":
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Initialization failed. controller_init_state: {controller_init_state}"
|
||||||
|
)
|
||||||
|
self.workers.append(WorkerHandle(
|
||||||
|
proc=proc,
|
||||||
|
queue=queue,
|
||||||
|
))
|
||||||
|
|
||||||
|
def round_robin_scheduler(self, input_requests):
|
||||||
for r in input_requests:
|
for r in input_requests:
|
||||||
self.put_req_to_worker(available_workers[self.round_robin_counter], r)
|
self.workers[self.round_robin_counter].queue.put(r)
|
||||||
self.round_robin_counter = (self.round_robin_counter + 1) % len(
|
self.round_robin_counter = (self.round_robin_counter + 1) % len(
|
||||||
available_workers
|
self.workers
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
|
||||||
async def shortest_queue_scheduler(self, input_requests):
|
def shortest_queue_scheduler(self, input_requests):
|
||||||
for r in input_requests:
|
for r in input_requests:
|
||||||
worker = min(
|
queue_sizes = [worker.queue.qsize() for worker in self.workers]
|
||||||
self.workers, key=lambda w: self.workers[w].request_queue.qsize()
|
wid = np.argmin(queue_sizes)
|
||||||
)
|
self.workers[wid].queue.put(r)
|
||||||
self.put_req_to_worker(worker, r)
|
|
||||||
return
|
|
||||||
|
|
||||||
async def remove_dead_workers(self):
|
def loop_for_forward(self):
|
||||||
for i in list(self.workers.keys()):
|
|
||||||
worker_thread = self.workers[i]
|
|
||||||
if not worker_thread.liveness:
|
|
||||||
worker_thread.join()
|
|
||||||
# move unsuccessful requests back to the queue
|
|
||||||
while not worker_thread.request_queue.empty():
|
|
||||||
self.recv_reqs.append(worker_thread.request_queue.get())
|
|
||||||
del self.workers[i]
|
|
||||||
logger.info(f"Stale worker {i} removed")
|
|
||||||
|
|
||||||
async def loop_for_forward(self):
|
|
||||||
while True:
|
while True:
|
||||||
await self.remove_dead_workers()
|
recv_reqs = self.recv_requests()
|
||||||
|
self.dispatching(recv_reqs)
|
||||||
|
|
||||||
if self.have_any_live_worker():
|
def recv_requests(self):
|
||||||
next_step_input = list(self.recv_reqs)
|
recv_reqs = []
|
||||||
self.recv_reqs = []
|
|
||||||
if next_step_input:
|
|
||||||
await self.dispatching(next_step_input)
|
|
||||||
# else:
|
|
||||||
# logger.error("There is no live worker.")
|
|
||||||
|
|
||||||
await asyncio.sleep(global_config.wait_for_new_request_delay)
|
|
||||||
|
|
||||||
async def loop_for_recv_requests(self):
|
|
||||||
while True:
|
while True:
|
||||||
recv_req = await self.recv_from_tokenizer.recv_pyobj()
|
try:
|
||||||
|
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
|
||||||
|
except zmq.ZMQError:
|
||||||
|
break
|
||||||
|
|
||||||
if isinstance(recv_req, FlushCacheReq):
|
if isinstance(recv_req, FlushCacheReq):
|
||||||
# TODO(lsyin): apply more specific flushCacheReq
|
# TODO(lsyin): apply more specific flushCacheReq
|
||||||
for worker_thread in self.workers.values():
|
for worker in self.workers:
|
||||||
worker_thread.request_queue.put(recv_req)
|
worker.queue.put(recv_req)
|
||||||
elif isinstance(recv_req, TokenizedGenerateReqInput):
|
|
||||||
self.recv_reqs.append(recv_req)
|
|
||||||
elif isinstance(recv_req, AbortReq):
|
elif isinstance(recv_req, AbortReq):
|
||||||
in_queue = False
|
in_queue = False
|
||||||
for i, req in enumerate(self.recv_reqs):
|
for i, req in enumerate(recv_reqs):
|
||||||
if req.rid == recv_req.rid:
|
if req.rid == recv_req.rid:
|
||||||
self.recv_reqs[i] = recv_req
|
recv_reqs[i] = recv_req
|
||||||
in_queue = True
|
in_queue = True
|
||||||
break
|
break
|
||||||
if not in_queue:
|
if not in_queue:
|
||||||
# Send abort req to all TP groups
|
# Send abort req to all TP groups
|
||||||
for worker in list(self.workers.keys()):
|
for worker in self.workers:
|
||||||
self.put_req_to_worker(worker, recv_req)
|
worker.queue.put(recv_req)
|
||||||
|
elif isinstance(recv_req, TokenizedGenerateReqInput):
|
||||||
|
recv_reqs.append(recv_req)
|
||||||
else:
|
else:
|
||||||
logger.error(f"Invalid object: {recv_req}")
|
logger.error(f"Invalid object: {recv_req}")
|
||||||
|
|
||||||
|
return recv_reqs
|
||||||
|
|
||||||
|
|
||||||
def start_controller_process(
|
def start_controller_process(
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
pipe_writer,
|
pipe_writer,
|
||||||
model_overide_args=None,
|
model_overide_args: dict,
|
||||||
):
|
):
|
||||||
|
"""Start a controller process."""
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=getattr(logging, server_args.log_level.upper()),
|
level=getattr(logging, server_args.log_level.upper()),
|
||||||
format="%(message)s",
|
format="%(message)s",
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
controller = Controller(
|
controller = ControllerMulti(server_args, port_args, model_overide_args)
|
||||||
server_args.load_balance_method, server_args, port_args, model_overide_args
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pipe_writer.send(get_exception_traceback())
|
pipe_writer.send(get_exception_traceback())
|
||||||
raise
|
raise
|
||||||
|
|
||||||
pipe_writer.send("init ok")
|
pipe_writer.send("init ok")
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
try:
|
||||||
loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
|
controller.loop_for_forward()
|
||||||
|
except Exception:
|
||||||
asyncio.set_event_loop(loop)
|
logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
|
||||||
loop.create_task(controller.loop_for_recv_requests())
|
finally:
|
||||||
loop.run_until_complete(controller.loop_for_forward())
|
for w in controller.workers:
|
||||||
|
os.kill(w.proc.pid, 9)
|
||||||
|
kill_parent_process()
|
||||||
|
|||||||
@@ -3,126 +3,61 @@
|
|||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import pickle
|
from typing import List
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
|
||||||
|
|
||||||
from sglang.srt.managers.controller.tp_worker import ModelTpServer
|
from sglang.srt.managers.controller.tp_worker import (
|
||||||
from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
|
broadcast_recv_input, launch_tp_servers, ModelTpServer
|
||||||
|
)
|
||||||
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import kill_parent_process
|
from sglang.srt.utils import kill_parent_process
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger("srt.controller")
|
logger = logging.getLogger("srt.controller")
|
||||||
|
|
||||||
|
|
||||||
def run_tp_server(
|
|
||||||
gpu_id: int,
|
|
||||||
tp_rank: int,
|
|
||||||
server_args: ServerArgs,
|
|
||||||
model_port_args: ModelPortArgs,
|
|
||||||
model_overide_args: dict,
|
|
||||||
):
|
|
||||||
"""Run a tp server."""
|
|
||||||
try:
|
|
||||||
model_server = ModelTpServer(
|
|
||||||
gpu_id,
|
|
||||||
tp_rank,
|
|
||||||
server_args,
|
|
||||||
model_port_args,
|
|
||||||
model_overide_args,
|
|
||||||
)
|
|
||||||
tp_cpu_group = model_server.model_runner.tp_group.cpu_group
|
|
||||||
|
|
||||||
while True:
|
|
||||||
recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
|
|
||||||
model_server.exposed_step(recv_reqs)
|
|
||||||
except Exception:
|
|
||||||
logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def launch_tp_servers(
|
|
||||||
gpu_ids, tp_rank_range, server_args, model_port_args, model_overide_args
|
|
||||||
):
|
|
||||||
"""Launch multiple tp servers."""
|
|
||||||
procs = []
|
|
||||||
for i in tp_rank_range:
|
|
||||||
proc = multiprocessing.Process(
|
|
||||||
target=run_tp_server,
|
|
||||||
args=(gpu_ids[i], i, server_args, model_port_args, model_overide_args),
|
|
||||||
)
|
|
||||||
proc.start()
|
|
||||||
procs.append(proc)
|
|
||||||
|
|
||||||
return procs
|
|
||||||
|
|
||||||
|
|
||||||
def broadcast_recv_input(data, rank, dist_group):
|
|
||||||
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
|
|
||||||
|
|
||||||
if rank == 0:
|
|
||||||
if len(data) == 0:
|
|
||||||
tensor_size = torch.tensor([0], dtype=torch.long)
|
|
||||||
dist.broadcast(tensor_size, src=0, group=dist_group)
|
|
||||||
else:
|
|
||||||
serialized_data = pickle.dumps(data)
|
|
||||||
size = len(serialized_data)
|
|
||||||
tensor_data = torch.ByteTensor(list(serialized_data))
|
|
||||||
tensor_size = torch.tensor([size], dtype=torch.long)
|
|
||||||
|
|
||||||
dist.broadcast(tensor_size, src=0, group=dist_group)
|
|
||||||
dist.broadcast(tensor_data, src=0, group=dist_group)
|
|
||||||
else:
|
|
||||||
tensor_size = torch.tensor([0], dtype=torch.long)
|
|
||||||
dist.broadcast(tensor_size, src=0, group=dist_group)
|
|
||||||
size = tensor_size.item()
|
|
||||||
|
|
||||||
if size == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
tensor_data = torch.empty(size, dtype=torch.uint8)
|
|
||||||
dist.broadcast(tensor_data, src=0, group=dist_group)
|
|
||||||
|
|
||||||
serialized_data = bytes(tensor_data.tolist())
|
|
||||||
data = pickle.loads(serialized_data)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class ControllerSingle:
|
class ControllerSingle:
|
||||||
"""A controller that manages a group of tensor parallel workers."""
|
"""A controller that manages a group of tensor parallel workers."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict
|
self,
|
||||||
|
server_args: ServerArgs,
|
||||||
|
port_args: PortArgs,
|
||||||
|
model_overide_args: dict,
|
||||||
|
gpu_ids: List[int],
|
||||||
|
is_data_parallel_worker: bool,
|
||||||
|
dp_worker_id: int,
|
||||||
|
mp_queue: multiprocessing.Queue,
|
||||||
):
|
):
|
||||||
# Parse args
|
# Parse args
|
||||||
self.server_args = server_args
|
self.tp_size = server_args.tp_size
|
||||||
self.tp_procs = []
|
self.is_dp_worker = is_data_parallel_worker
|
||||||
|
self.dp_worker_id = dp_worker_id
|
||||||
|
self.mp_queue = mp_queue
|
||||||
|
|
||||||
# Init communication
|
# Init communication
|
||||||
context = zmq.Context(2)
|
context = zmq.Context(2)
|
||||||
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
|
||||||
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
|
if not self.is_dp_worker:
|
||||||
|
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
||||||
|
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
|
||||||
|
|
||||||
self.send_to_detokenizer = context.socket(zmq.PUSH)
|
self.send_to_detokenizer = context.socket(zmq.PUSH)
|
||||||
self.send_to_detokenizer.connect(
|
self.send_to_detokenizer.connect(
|
||||||
f"tcp://127.0.0.1:{port_args.detokenizer_port}"
|
f"tcp://127.0.0.1:{port_args.detokenizer_port}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init model server
|
|
||||||
tp_size_local = server_args.tp_size // server_args.nnodes
|
|
||||||
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
|
|
||||||
|
|
||||||
# Launch other tp ranks
|
# Launch other tp ranks
|
||||||
|
tp_size_local = server_args.tp_size // server_args.nnodes
|
||||||
|
self.tp_procs = []
|
||||||
if tp_size_local > 1:
|
if tp_size_local > 1:
|
||||||
tp_rank_range = range(1, tp_size_local)
|
tp_rank_range = range(1, tp_size_local)
|
||||||
self.tp_procs = launch_tp_servers(
|
self.tp_procs = launch_tp_servers(
|
||||||
gpu_ids,
|
gpu_ids,
|
||||||
tp_rank_range,
|
tp_rank_range,
|
||||||
server_args,
|
server_args,
|
||||||
port_args.model_port_args[0],
|
port_args.nccl_ports[dp_worker_id],
|
||||||
model_overide_args,
|
model_overide_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -131,16 +66,19 @@ class ControllerSingle:
|
|||||||
gpu_ids[0],
|
gpu_ids[0],
|
||||||
0,
|
0,
|
||||||
server_args,
|
server_args,
|
||||||
port_args.model_port_args[0],
|
port_args.nccl_ports[dp_worker_id],
|
||||||
model_overide_args,
|
model_overide_args,
|
||||||
)
|
)
|
||||||
self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
|
self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
|
||||||
|
|
||||||
def loop_for_forward(self):
|
def loop_for_forward(self):
|
||||||
while True:
|
while True:
|
||||||
recv_reqs = self.recv_requests()
|
if not self.is_dp_worker:
|
||||||
|
recv_reqs = self.recv_requests_from_zmq()
|
||||||
|
else:
|
||||||
|
recv_reqs = self.recv_requests_from_mp_queue()
|
||||||
|
|
||||||
if self.server_args.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
|
broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
|
||||||
|
|
||||||
out_pyobjs = self.tp_server.exposed_step(recv_reqs)
|
out_pyobjs = self.tp_server.exposed_step(recv_reqs)
|
||||||
@@ -148,27 +86,51 @@ class ControllerSingle:
|
|||||||
for obj in out_pyobjs:
|
for obj in out_pyobjs:
|
||||||
self.send_to_detokenizer.send_pyobj(obj)
|
self.send_to_detokenizer.send_pyobj(obj)
|
||||||
|
|
||||||
def recv_requests(self):
|
def recv_requests_from_zmq(self):
|
||||||
recv_reqs = []
|
recv_reqs = []
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
|
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
|
||||||
recv_reqs.append(recv_req)
|
|
||||||
except zmq.ZMQError:
|
except zmq.ZMQError:
|
||||||
break
|
break
|
||||||
|
recv_reqs.append(recv_req)
|
||||||
|
|
||||||
|
return recv_reqs
|
||||||
|
|
||||||
|
def recv_requests_from_mp_queue(self):
|
||||||
|
recv_reqs = []
|
||||||
|
while not self.mp_queue.empty():
|
||||||
|
recv_reqs.append(self.mp_queue.get())
|
||||||
return recv_reqs
|
return recv_reqs
|
||||||
|
|
||||||
|
|
||||||
def start_controller_process(
|
def start_controller_process(
|
||||||
server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args: dict
|
server_args: ServerArgs,
|
||||||
|
port_args: PortArgs,
|
||||||
|
pipe_writer: multiprocessing.connection.Connection,
|
||||||
|
model_overide_args: dict,
|
||||||
|
is_data_parallel_worker: bool = False,
|
||||||
|
gpu_ids: List[int] = None,
|
||||||
|
dp_worker_id: int = None,
|
||||||
|
queue: multiprocessing.connection.Connection = None,
|
||||||
):
|
):
|
||||||
|
"""Start a controller process."""
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=getattr(logging, server_args.log_level.upper()),
|
level=getattr(logging, server_args.log_level.upper()),
|
||||||
format="%(message)s",
|
format="%(message)s",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not is_data_parallel_worker:
|
||||||
|
tp_size_local = server_args.tp_size // server_args.nnodes
|
||||||
|
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
|
||||||
|
dp_worker_id = 0
|
||||||
|
queue = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
controller = ControllerSingle(server_args, port_args, model_overide_args)
|
controller = ControllerSingle(server_args, port_args, model_overide_args,
|
||||||
|
gpu_ids, is_data_parallel_worker,
|
||||||
|
dp_worker_id, queue)
|
||||||
except Exception:
|
except Exception:
|
||||||
pipe_writer.send(get_exception_traceback())
|
pipe_writer.send(get_exception_traceback())
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -1,15 +1,14 @@
|
|||||||
"""A tensor parallel worker."""
|
"""A tensor parallel worker."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
|
import multiprocessing
|
||||||
|
import pickle
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import rpyc
|
|
||||||
import torch
|
import torch
|
||||||
from rpyc.utils.classic import obtain
|
import torch.distributed as dist
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.constrained.fsm_cache import FSMCache
|
from sglang.srt.constrained.fsm_cache import FSMCache
|
||||||
@@ -32,13 +31,11 @@ from sglang.srt.managers.io_struct import (
|
|||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_config import ModelConfig
|
from sglang.srt.model_config import ModelConfig
|
||||||
from sglang.srt.server_args import ModelPortArgs, ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
connect_rpyc_service,
|
|
||||||
get_int_token_logit_bias,
|
get_int_token_logit_bias,
|
||||||
is_multimodal_model,
|
is_multimodal_model,
|
||||||
set_random_seed,
|
set_random_seed,
|
||||||
start_rpyc_service_process,
|
|
||||||
suppress_other_loggers,
|
suppress_other_loggers,
|
||||||
)
|
)
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
@@ -52,10 +49,9 @@ class ModelTpServer:
|
|||||||
gpu_id: int,
|
gpu_id: int,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
model_port_args: ModelPortArgs,
|
nccl_port: int,
|
||||||
model_overide_args: dict,
|
model_overide_args: dict,
|
||||||
):
|
):
|
||||||
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
|
|
||||||
suppress_other_loggers()
|
suppress_other_loggers()
|
||||||
|
|
||||||
# Copy arguments
|
# Copy arguments
|
||||||
@@ -79,7 +75,7 @@ class ModelTpServer:
|
|||||||
gpu_id=gpu_id,
|
gpu_id=gpu_id,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
tp_size=server_args.tp_size,
|
tp_size=server_args.tp_size,
|
||||||
nccl_port=model_port_args.nccl_port,
|
nccl_port=nccl_port,
|
||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -178,9 +174,6 @@ class ModelTpServer:
|
|||||||
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
|
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
|
||||||
|
|
||||||
def exposed_step(self, recv_reqs):
|
def exposed_step(self, recv_reqs):
|
||||||
if not isinstance(recv_reqs, list):
|
|
||||||
recv_reqs = obtain(recv_reqs)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Recv requests
|
# Recv requests
|
||||||
for recv_req in recv_reqs:
|
for recv_req in recv_reqs:
|
||||||
@@ -425,12 +418,6 @@ class ModelTpServer:
|
|||||||
f"#running-req: {running_bs}, "
|
f"#running-req: {running_bs}, "
|
||||||
f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
|
f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
|
||||||
)
|
)
|
||||||
# logger.debug(
|
|
||||||
# f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
|
|
||||||
# f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
|
|
||||||
# f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
|
|
||||||
# f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
|
|
||||||
# )
|
|
||||||
|
|
||||||
# Return the new batch
|
# Return the new batch
|
||||||
new_batch = Batch.init_new(
|
new_batch = Batch.init_new(
|
||||||
@@ -733,87 +720,74 @@ class ModelTpServer:
|
|||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
class ModelTpService(rpyc.Service):
|
def run_tp_server(
|
||||||
exposed_ModelTpServer = ModelTpServer
|
gpu_id: int,
|
||||||
|
tp_rank: int,
|
||||||
|
server_args: ServerArgs,
|
||||||
|
nccl_port: int,
|
||||||
|
model_overide_args: dict,
|
||||||
|
):
|
||||||
|
"""Run a tensor parallel server."""
|
||||||
|
try:
|
||||||
|
model_server = ModelTpServer(
|
||||||
|
gpu_id,
|
||||||
|
tp_rank,
|
||||||
|
server_args,
|
||||||
|
nccl_port,
|
||||||
|
model_overide_args,
|
||||||
|
)
|
||||||
|
tp_cpu_group = model_server.model_runner.tp_group.cpu_group
|
||||||
|
|
||||||
|
while True:
|
||||||
|
recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
|
||||||
|
model_server.exposed_step(recv_reqs)
|
||||||
|
except Exception:
|
||||||
|
logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
class ModelTpClient:
|
def launch_tp_servers(
|
||||||
def __init__(
|
gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
|
||||||
self,
|
):
|
||||||
gpu_ids: List[int],
|
"""Launch multiple tensor parallel servers."""
|
||||||
server_args: ServerArgs,
|
procs = []
|
||||||
model_port_args: ModelPortArgs,
|
for i in tp_rank_range:
|
||||||
model_overide_args,
|
proc = multiprocessing.Process(
|
||||||
):
|
target=run_tp_server,
|
||||||
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
|
args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
|
||||||
self.tp_size = server_args.tp_size
|
)
|
||||||
|
proc.start()
|
||||||
|
procs.append(proc)
|
||||||
|
|
||||||
if self.tp_size * server_args.dp_size == 1:
|
return procs
|
||||||
# Init model
|
|
||||||
assert len(gpu_ids) == 1
|
|
||||||
self.model_server = ModelTpService().exposed_ModelTpServer(
|
|
||||||
gpu_ids[0],
|
|
||||||
0,
|
|
||||||
server_args,
|
|
||||||
model_port_args,
|
|
||||||
model_overide_args,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wrap functions
|
|
||||||
def async_wrap(f):
|
|
||||||
async def _func(*args, **kwargs):
|
|
||||||
return f(*args, **kwargs)
|
|
||||||
|
|
||||||
return _func
|
def broadcast_recv_input(data, rank, dist_group):
|
||||||
|
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
|
||||||
|
|
||||||
self.step = async_wrap(self.model_server.exposed_step)
|
if rank == 0:
|
||||||
|
if len(data) == 0:
|
||||||
|
tensor_size = torch.tensor([0], dtype=torch.long)
|
||||||
|
dist.broadcast(tensor_size, src=0, group=dist_group)
|
||||||
else:
|
else:
|
||||||
with ThreadPoolExecutor(self.tp_size) as executor:
|
serialized_data = pickle.dumps(data)
|
||||||
# Launch model processes
|
size = len(serialized_data)
|
||||||
if server_args.nnodes == 1:
|
tensor_data = torch.ByteTensor(list(serialized_data))
|
||||||
self.procs = list(
|
tensor_size = torch.tensor([size], dtype=torch.long)
|
||||||
executor.map(
|
|
||||||
lambda args: start_rpyc_service_process(*args),
|
|
||||||
[
|
|
||||||
(ModelTpService, p)
|
|
||||||
for p in model_port_args.model_tp_ports
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
|
|
||||||
else:
|
|
||||||
addrs = [
|
|
||||||
(ip, port)
|
|
||||||
for ip, port in zip(
|
|
||||||
model_port_args.model_tp_ips, model_port_args.model_tp_ports
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
self.model_services = list(
|
dist.broadcast(tensor_size, src=0, group=dist_group)
|
||||||
executor.map(lambda args: connect_rpyc_service(*args), addrs)
|
dist.broadcast(tensor_data, src=0, group=dist_group)
|
||||||
)
|
else:
|
||||||
|
tensor_size = torch.tensor([0], dtype=torch.long)
|
||||||
|
dist.broadcast(tensor_size, src=0, group=dist_group)
|
||||||
|
size = tensor_size.item()
|
||||||
|
|
||||||
# Init model
|
if size == 0:
|
||||||
def init_model(i):
|
return []
|
||||||
return self.model_services[i].ModelTpServer(
|
|
||||||
gpu_ids[i],
|
|
||||||
i,
|
|
||||||
server_args,
|
|
||||||
model_port_args,
|
|
||||||
model_overide_args,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model_servers = list(executor.map(init_model, range(self.tp_size)))
|
tensor_data = torch.empty(size, dtype=torch.uint8)
|
||||||
|
dist.broadcast(tensor_data, src=0, group=dist_group)
|
||||||
|
|
||||||
# Wrap functions
|
serialized_data = bytes(tensor_data.tolist())
|
||||||
def async_wrap(func_name):
|
data = pickle.loads(serialized_data)
|
||||||
fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
|
return data
|
||||||
|
|
||||||
async def _func(*args, **kwargs):
|
|
||||||
tasks = [f(*args, **kwargs) for f in fs]
|
|
||||||
await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
|
|
||||||
return obtain(tasks[0].value)
|
|
||||||
|
|
||||||
return _func
|
|
||||||
|
|
||||||
self.step = async_wrap("step")
|
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ class TokenizerManager:
|
|||||||
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
|
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
|
||||||
|
|
||||||
self.send_to_router = context.socket(zmq.PUSH)
|
self.send_to_router = context.socket(zmq.PUSH)
|
||||||
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}")
|
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
|
||||||
|
|
||||||
self.model_path = server_args.model_path
|
self.model_path = server_args.model_path
|
||||||
self.hf_config = get_config(
|
self.hf_config = get_config(
|
||||||
|
|||||||
@@ -44,15 +44,13 @@ from sglang.srt.openai_api_adapter import (
|
|||||||
v1_chat_completions,
|
v1_chat_completions,
|
||||||
v1_completions,
|
v1_completions,
|
||||||
)
|
)
|
||||||
from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
API_KEY_HEADER_NAME,
|
API_KEY_HEADER_NAME,
|
||||||
APIKeyValidatorMiddleware,
|
APIKeyValidatorMiddleware,
|
||||||
allocate_init_ports,
|
allocate_init_ports,
|
||||||
assert_pkg_version,
|
assert_pkg_version,
|
||||||
enable_show_time_cost,
|
enable_show_time_cost,
|
||||||
receive_addrs,
|
|
||||||
send_addrs_to_rank_0,
|
|
||||||
)
|
)
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
@@ -98,6 +96,7 @@ async def flush_cache():
|
|||||||
|
|
||||||
|
|
||||||
async def generate_request(obj: GenerateReqInput, request: Request):
|
async def generate_request(obj: GenerateReqInput, request: Request):
|
||||||
|
"""Handle a generate request."""
|
||||||
if obj.stream:
|
if obj.stream:
|
||||||
|
|
||||||
async def stream_results():
|
async def stream_results():
|
||||||
@@ -146,7 +145,10 @@ def _set_global_server_args(server_args: ServerArgs):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
|
def launch_server(server_args: ServerArgs,
|
||||||
|
model_overide_args: Optional[dict] = None,
|
||||||
|
pipe_finish_writer: Optional[mp.connection.Connection] = None):
|
||||||
|
"""Launch an HTTP server."""
|
||||||
global tokenizer_manager
|
global tokenizer_manager
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -173,39 +175,23 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
if server_args.chat_template:
|
if server_args.chat_template:
|
||||||
# TODO: replace this with huggingface transformers template
|
# TODO: replace this with huggingface transformers template
|
||||||
load_chat_template_for_openai_api(server_args.chat_template)
|
load_chat_template_for_openai_api(server_args.chat_template)
|
||||||
|
|
||||||
_set_global_server_args(server_args)
|
_set_global_server_args(server_args)
|
||||||
|
|
||||||
# Allocate ports
|
# Allocate ports
|
||||||
assert server_args.tp_size % server_args.nnodes == 0
|
|
||||||
tp_size_local = server_args.tp_size // server_args.nnodes
|
|
||||||
server_args.port, server_args.additional_ports = allocate_init_ports(
|
server_args.port, server_args.additional_ports = allocate_init_ports(
|
||||||
server_args.port,
|
server_args.port,
|
||||||
server_args.additional_ports,
|
server_args.additional_ports,
|
||||||
tp_size_local,
|
|
||||||
server_args.dp_size,
|
server_args.dp_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
ports = server_args.additional_ports
|
ports = server_args.additional_ports
|
||||||
model_port_args = []
|
|
||||||
for i in range(server_args.dp_size):
|
|
||||||
model_port_args.append(
|
|
||||||
ModelPortArgs(
|
|
||||||
nccl_port=ports[3 + i * (tp_size_local + 1)],
|
|
||||||
model_tp_ips=[None] * tp_size_local,
|
|
||||||
model_tp_ports=ports[
|
|
||||||
3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
port_args = PortArgs(
|
port_args = PortArgs(
|
||||||
tokenizer_port=ports[0],
|
tokenizer_port=ports[0],
|
||||||
router_port=ports[1],
|
controller_port=ports[1],
|
||||||
detokenizer_port=ports[2],
|
detokenizer_port=ports[2],
|
||||||
model_port_args=model_port_args,
|
nccl_ports=ports[3:],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle multi-node tp
|
# Handle multi-node tensor parallelism
|
||||||
if server_args.nnodes > 1:
|
if server_args.nnodes > 1:
|
||||||
assert server_args.dp_size == 1, "Multi-node dp is not supported."
|
assert server_args.dp_size == 1, "Multi-node dp is not supported."
|
||||||
|
|
||||||
@@ -224,7 +210,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
gpu_ids,
|
gpu_ids,
|
||||||
tp_rank_range,
|
tp_rank_range,
|
||||||
server_args,
|
server_args,
|
||||||
port_args.model_port_args[0],
|
ports[3],
|
||||||
model_overide_args,
|
model_overide_args,
|
||||||
)
|
)
|
||||||
while True:
|
while True:
|
||||||
@@ -232,18 +218,18 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
|
|
||||||
# Launch processes
|
# Launch processes
|
||||||
tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
|
tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
|
||||||
pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
|
pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
|
||||||
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
|
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
|
||||||
|
|
||||||
if server_args.dp_size == 1:
|
if server_args.dp_size == 1:
|
||||||
start_process = start_controller_process_single
|
start_process = start_controller_process_single
|
||||||
else:
|
else:
|
||||||
start_process = start_controller_process_multi
|
start_process = start_controller_process_multi
|
||||||
proc_router = mp.Process(
|
proc_controller = mp.Process(
|
||||||
target=start_process,
|
target=start_process,
|
||||||
args=(server_args, port_args, pipe_router_writer, model_overide_args),
|
args=(server_args, port_args, pipe_controller_writer, model_overide_args),
|
||||||
)
|
)
|
||||||
proc_router.start()
|
proc_controller.start()
|
||||||
proc_detoken = mp.Process(
|
proc_detoken = mp.Process(
|
||||||
target=start_detokenizer_process,
|
target=start_detokenizer_process,
|
||||||
args=(
|
args=(
|
||||||
@@ -255,68 +241,27 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
proc_detoken.start()
|
proc_detoken.start()
|
||||||
|
|
||||||
# Wait for the model to finish loading
|
# Wait for the model to finish loading
|
||||||
router_init_state = pipe_router_reader.recv()
|
controller_init_state = pipe_controller_reader.recv()
|
||||||
detoken_init_state = pipe_detoken_reader.recv()
|
detoken_init_state = pipe_detoken_reader.recv()
|
||||||
|
|
||||||
if router_init_state != "init ok" or detoken_init_state != "init ok":
|
if controller_init_state != "init ok" or detoken_init_state != "init ok":
|
||||||
proc_router.kill()
|
proc_controller.kill()
|
||||||
proc_detoken.kill()
|
proc_detoken.kill()
|
||||||
print(
|
print(
|
||||||
f"Initialization failed. router_init_state: {router_init_state}", flush=True
|
f"Initialization failed. controller_init_state: {controller_init_state}", flush=True
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f"Initialization failed. detoken_init_state: {detoken_init_state}",
|
f"Initialization failed. detoken_init_state: {detoken_init_state}",
|
||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
assert proc_router.is_alive() and proc_detoken.is_alive()
|
assert proc_controller.is_alive() and proc_detoken.is_alive()
|
||||||
|
|
||||||
if server_args.api_key and server_args.api_key != "":
|
if server_args.api_key and server_args.api_key != "":
|
||||||
app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
|
app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
|
||||||
|
|
||||||
# Send a warmup request
|
# Send a warmup request
|
||||||
def _wait_and_warmup():
|
t = threading.Thread(target=_wait_and_warmup, args=(server_args, pipe_finish_writer))
|
||||||
headers = {}
|
|
||||||
url = server_args.url()
|
|
||||||
if server_args.api_key:
|
|
||||||
headers[API_KEY_HEADER_NAME] = server_args.api_key
|
|
||||||
|
|
||||||
# Wait until the server is launched
|
|
||||||
for _ in range(120):
|
|
||||||
time.sleep(0.5)
|
|
||||||
try:
|
|
||||||
requests.get(url + "/get_model_info", timeout=5, headers=headers)
|
|
||||||
break
|
|
||||||
except requests.exceptions.RequestException:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Send a warmup request
|
|
||||||
try:
|
|
||||||
for _ in range(server_args.dp_size):
|
|
||||||
res = requests.post(
|
|
||||||
url + "/generate",
|
|
||||||
json={
|
|
||||||
"text": "The capital city of France is",
|
|
||||||
"sampling_params": {
|
|
||||||
"temperature": 0,
|
|
||||||
"max_new_tokens": 8,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
headers=headers,
|
|
||||||
timeout=600,
|
|
||||||
)
|
|
||||||
assert res.status_code == 200
|
|
||||||
except Exception as e:
|
|
||||||
if pipe_finish_writer is not None:
|
|
||||||
pipe_finish_writer.send(get_exception_traceback())
|
|
||||||
print(f"Initialization failed. warmup error: {e}", flush=True)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
logger.info("The server is fired up and ready to roll!")
|
|
||||||
if pipe_finish_writer is not None:
|
|
||||||
pipe_finish_writer.send("init ok")
|
|
||||||
|
|
||||||
t = threading.Thread(target=_wait_and_warmup)
|
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
# Listen for requests
|
# Listen for requests
|
||||||
@@ -333,6 +278,48 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
t.join()
|
t.join()
|
||||||
|
|
||||||
|
|
||||||
|
def _wait_and_warmup(server_args, pipe_finish_writer):
|
||||||
|
headers = {}
|
||||||
|
url = server_args.url()
|
||||||
|
if server_args.api_key:
|
||||||
|
headers[API_KEY_HEADER_NAME] = server_args.api_key
|
||||||
|
|
||||||
|
# Wait until the server is launched
|
||||||
|
for _ in range(120):
|
||||||
|
time.sleep(0.5)
|
||||||
|
try:
|
||||||
|
requests.get(url + "/get_model_info", timeout=5, headers=headers)
|
||||||
|
break
|
||||||
|
except requests.exceptions.RequestException:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Send a warmup request
|
||||||
|
try:
|
||||||
|
for _ in range(server_args.dp_size):
|
||||||
|
res = requests.post(
|
||||||
|
url + "/generate",
|
||||||
|
json={
|
||||||
|
"text": "The capital city of France is",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 8,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers=headers,
|
||||||
|
timeout=600,
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
except Exception as e:
|
||||||
|
if pipe_finish_writer is not None:
|
||||||
|
pipe_finish_writer.send(get_exception_traceback())
|
||||||
|
print(f"Initialization failed. warmup error: {e}", flush=True)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
logger.info("The server is fired up and ready to roll!")
|
||||||
|
if pipe_finish_writer is not None:
|
||||||
|
pipe_finish_writer.send("init ok")
|
||||||
|
|
||||||
|
|
||||||
class Runtime:
|
class Runtime:
|
||||||
"""
|
"""
|
||||||
A wrapper for the server.
|
A wrapper for the server.
|
||||||
@@ -354,7 +341,6 @@ class Runtime:
|
|||||||
self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
|
self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
|
||||||
self.server_args.port,
|
self.server_args.port,
|
||||||
self.server_args.additional_ports,
|
self.server_args.additional_ports,
|
||||||
self.server_args.tp_size,
|
|
||||||
self.server_args.dp_size,
|
self.server_args.dp_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -367,7 +353,7 @@ class Runtime:
|
|||||||
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
|
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
|
||||||
proc = mp.Process(
|
proc = mp.Process(
|
||||||
target=launch_server,
|
target=launch_server,
|
||||||
args=(self.server_args, pipe_writer, model_overide_args),
|
args=(self.server_args, model_overide_args, pipe_writer),
|
||||||
)
|
)
|
||||||
proc.start()
|
proc.start()
|
||||||
pipe_writer.close()
|
pipe_writer.close()
|
||||||
|
|||||||
@@ -337,16 +337,9 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class ModelPortArgs:
|
|
||||||
nccl_port: int
|
|
||||||
model_tp_ips: List[str]
|
|
||||||
model_tp_ports: List[int]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class PortArgs:
|
class PortArgs:
|
||||||
tokenizer_port: int
|
tokenizer_port: int
|
||||||
router_port: int
|
controller_port: int
|
||||||
detokenizer_port: int
|
detokenizer_port: int
|
||||||
model_port_args: List[ModelPortArgs]
|
nccl_ports: List[int]
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
import base64
|
import base64
|
||||||
import fcntl
|
import fcntl
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import socket
|
import socket
|
||||||
@@ -16,12 +15,10 @@ from typing import List, Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import psutil
|
import psutil
|
||||||
import requests
|
import requests
|
||||||
import rpyc
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from packaging import version as pkg_version
|
from packaging import version as pkg_version
|
||||||
from rpyc.utils.server import ThreadedServer
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -148,7 +145,6 @@ def is_port_available(port):
|
|||||||
def allocate_init_ports(
|
def allocate_init_ports(
|
||||||
port: Optional[int] = None,
|
port: Optional[int] = None,
|
||||||
additional_ports: Optional[List[int]] = None,
|
additional_ports: Optional[List[int]] = None,
|
||||||
tp_size: int = 1,
|
|
||||||
dp_size: int = 1,
|
dp_size: int = 1,
|
||||||
):
|
):
|
||||||
"""Allocate ports for all connections."""
|
"""Allocate ports for all connections."""
|
||||||
@@ -160,8 +156,8 @@ def allocate_init_ports(
|
|||||||
ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
|
ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
|
||||||
cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
|
cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
|
||||||
|
|
||||||
# HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
|
# HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
|
||||||
num_ports_needed = 4 + dp_size * (1 + tp_size)
|
num_ports_needed = 4 + dp_size
|
||||||
while len(ret_ports) < num_ports_needed:
|
while len(ret_ports) < num_ports_needed:
|
||||||
if cur_port not in ret_ports and is_port_available(cur_port):
|
if cur_port not in ret_ports and is_port_available(cur_port):
|
||||||
ret_ports.append(cur_port)
|
ret_ports.append(cur_port)
|
||||||
@@ -371,49 +367,6 @@ def load_image(image_file):
|
|||||||
return image, image_size
|
return image, image_size
|
||||||
|
|
||||||
|
|
||||||
def connect_rpyc_service(host, port):
|
|
||||||
repeat_count = 0
|
|
||||||
while repeat_count < 20:
|
|
||||||
try:
|
|
||||||
con = rpyc.connect(
|
|
||||||
host,
|
|
||||||
port,
|
|
||||||
config={
|
|
||||||
"allow_public_attrs": True,
|
|
||||||
"allow_pickle": True,
|
|
||||||
"sync_request_timeout": 3600,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except ConnectionRefusedError as e:
|
|
||||||
time.sleep(1)
|
|
||||||
repeat_count += 1
|
|
||||||
if repeat_count == 20:
|
|
||||||
raise RuntimeError(f"Connect rpyc error: {e}")
|
|
||||||
|
|
||||||
return con.root
|
|
||||||
|
|
||||||
|
|
||||||
def start_rpyc_service(service: rpyc.Service, port: int):
|
|
||||||
t = ThreadedServer(
|
|
||||||
service=service,
|
|
||||||
port=port,
|
|
||||||
protocol_config={
|
|
||||||
"allow_public_attrs": True,
|
|
||||||
"allow_pickle": True,
|
|
||||||
"sync_request_timeout": 3600,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
t.logger.setLevel(logging.WARN)
|
|
||||||
t.start()
|
|
||||||
|
|
||||||
|
|
||||||
def start_rpyc_service_process(service: rpyc.Service, port: int):
|
|
||||||
proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
|
|
||||||
proc.start()
|
|
||||||
return proc
|
|
||||||
|
|
||||||
|
|
||||||
def suppress_other_loggers():
|
def suppress_other_loggers():
|
||||||
from vllm.logger import logger as vllm_default_logger
|
from vllm.logger import logger as vllm_default_logger
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user