Support data parallelism (static) (#480)
Co-authored-by: Ying Sheng <ying.sheng@databricks.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
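In brief, `--dp-size N` starts N data parallel workers behind a single controller that dispatches requests either round robin or to the shortest queue. A minimal usage sketch (the model path is a placeholder; the `Runtime` constructor is shown in the server.py diff below):

from sglang.srt.server import Runtime

# Two data parallel replicas, each a single tensor parallel worker.
runtime = Runtime(
    model_path="meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    tp_size=1,
    dp_size=2,
    load_balance_method="round_robin",
)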
@@ -26,7 +26,8 @@ class GlobalConfig:
        self.concate_and_append_mode = "no_adjust"

        # Request dependency time due to network delay
-        self.request_dependency_time = 0.03
+        self.request_dependency_delay = 0.03
+        self.wait_for_new_request_delay = 0.0006

        # New generation token ratio estimation
        self.base_new_token_ratio = 0.4
@@ -5,7 +5,7 @@ from vllm.distributed import (
    tensor_model_parallel_all_gather,
)

-from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
+from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata


class LogitsProcessor(nn.Module):
@@ -5,7 +5,7 @@ from torch import nn
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
+from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata


class RadixAttention(nn.Module):
@@ -20,7 +20,7 @@ class RadixAttention(nn.Module):
        assert np.allclose(scaling, 1.0 / (head_dim**0.5))

-        from sglang.srt.managers.router.model_runner import global_server_args_dict
+        from sglang.srt.managers.controller.model_runner import global_server_args_dict

        if global_server_args_dict.get("enable_flashinfer", False):
            self.prefill_forward = self.prefill_forward_flashinfer
@@ -5,7 +5,7 @@ import torch
import triton
import triton.language as tl

-from sglang.srt.managers.router.model_runner import global_server_args_dict
+from sglang.srt.managers.controller.model_runner import global_server_args_dict
from sglang.srt.utils import wrap_kernel_launcher

if global_server_args_dict.get("attention_reduce_in_fp32", False):
python/sglang/srt/managers/controller/dp_worker.py (new file, 102 lines)
@@ -0,0 +1,102 @@
"""A data parallel worker thread."""
import asyncio
import logging
import queue
import threading
from typing import List, Callable

import uvloop
import zmq

from sglang.global_config import global_config
from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback

logger = logging.getLogger("srt.controller")
CHECKING_INTERVAL = 5

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


class DataParallelWorkerThread(threading.Thread):
    def __init__(
        self,
        worker_id: int,
        request_queue: queue.Queue,
        detokenizer_port: int,
        step_func: Callable,
    ):
        super(DataParallelWorkerThread, self).__init__()
        self.worker_id = worker_id
        self.request_queue = request_queue
        self.liveness = True
        self.request_dependency_delay = global_config.request_dependency_delay

        context = zmq.asyncio.Context()
        self.send_to_detokenizer = context.socket(zmq.PUSH)
        self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}")

        self.step = step_func

    async def loop_for_forward(self):
        while self.liveness:
            requests = []
            while not self.request_queue.empty():
                requests.append(self.request_queue.get())
            try:
                out_pyobjs = await self.step(requests)
            except Exception:
                for r in requests:
                    self.request_queue.put(r)
                logger.error(
                    f"Worker thread {self.worker_id}: "
                    f"failed to get back from Model Server\n"
                    f"{get_exception_traceback()}"
                )
                self.liveness = False
                continue  # editorial fix: out_pyobjs is undefined after a failed step

            for obj in out_pyobjs:
                self.send_to_detokenizer.send_pyobj(obj)

            # async sleep for receiving the subsequent request and avoiding cache miss
            if len(out_pyobjs) != 0:
                has_finished = any([obj.finished for obj in out_pyobjs])
                if has_finished:
                    await asyncio.sleep(self.request_dependency_delay)
            await asyncio.sleep(global_config.wait_for_new_request_delay)

    async def monitoring(self):
        while True:
            await asyncio.sleep(CHECKING_INTERVAL)
            # can plug in monitoring logic here

    def run(self):
        logger.info(f"DataParallelWorkerThread {self.worker_id} start")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.create_task(self.monitoring())
        loop.run_until_complete(self.loop_for_forward())


def start_data_parallel_worker(
    server_args: ServerArgs,
    port_args: PortArgs,
    model_overide_args,
    gpu_ids: List[int],
    worker_id: int,
):
    model_tp_client = ModelTpClient(
        gpu_ids,
        server_args,
        port_args.model_port_args[worker_id],
        model_overide_args,
    )
    worker_thread = DataParallelWorkerThread(
        worker_id=worker_id,
        request_queue=queue.Queue(),
        detokenizer_port=port_args.detokenizer_port,
        step_func=model_tp_client.step,
    )
    worker_thread.start()
    return worker_thread
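To see the worker thread's contract in isolation, a self-contained sketch with a stub step function (no model, no ZMQ; the real thread also opens a PUSH socket to the detokenizer):

import asyncio
import queue

async def stub_step(requests):
    # a real ModelTpClient.step runs the batch and returns detokenizer payloads
    print("stepping on", requests)
    return []

q = queue.Queue()
q.put("tokenized-request-0")  # stand-in for a TokenizedGenerateReqInput

async def one_iteration():
    # mirrors loop_for_forward: drain the queue, then await one step
    requests = []
    while not q.empty():
        requests.append(q.get())
    return await stub_step(requests)

asyncio.run(one_iteration())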
@@ -1,3 +1,4 @@
+"""Meta data for requests and batches"""
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import List
@@ -5,7 +6,7 @@ from typing import List
import numpy as np
import torch

-from sglang.srt.managers.router.radix_cache import RadixCache
+from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
python/sglang/srt/managers/controller/manager_multi.py (new file, 187 lines)
@@ -0,0 +1,187 @@
"""
A controller that manages multiple data parallel workers.
Each data parallel worker can manage multiple tensor parallel workers.
"""

import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto
from typing import Dict

import zmq
import zmq.asyncio

from sglang.global_config import global_config
from sglang.srt.managers.io_struct import (
    AbortReq,
    FlushCacheReq,
    TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.dp_worker import (
    DataParallelWorkerThread,
    start_data_parallel_worker,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback

logger = logging.getLogger("srt.controller")


class LoadBalanceMethod(Enum):
    ROUND_ROBIN = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, method: str):
        method = method.upper()
        try:
            return cls[method]
        except KeyError as exc:
            raise ValueError(f"Invalid load balance method: {method}") from exc


class Controller:
    def __init__(
        self,
        load_balance_method: str,
        server_args: ServerArgs,
        port_args: PortArgs,
        model_overide_args,
    ):
        self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
        self.server_args = server_args
        self.port_args = port_args

        if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
            self.round_robin_counter = 0

        self.dispatch_lookup = {
            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
        }
        self.dispatching = self.dispatch_lookup[self.load_balance_method]

        # Init communication
        context = zmq.asyncio.Context()
        self.recv_from_tokenizer = context.socket(zmq.PULL)
        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")

        # Init status
        self.recv_reqs = []

        # Start data parallel workers
        self.workers: Dict[int, DataParallelWorkerThread] = {}
        tp_size = server_args.tp_size

        def start_dp_worker(i):
            try:
                gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
                worker_thread = start_data_parallel_worker(
                    server_args, port_args, model_overide_args, gpu_ids, i
                )
                self.workers[i] = worker_thread
            except Exception:
                logger.error(
                    f"Failed to start local worker {i}\n{get_exception_traceback()}"
                )

        with ThreadPoolExecutor(server_args.dp_size) as executor:
            executor.map(start_dp_worker, range(server_args.dp_size))

    def have_any_live_worker(self):
        return any(worker_thread.liveness for worker_thread in self.workers.values())

    def put_req_to_worker(self, worker_id, req):
        self.workers[worker_id].request_queue.put(req)

    async def round_robin_scheduler(self, input_requests):
        available_workers = list(self.workers.keys())
        for r in input_requests:
            self.put_req_to_worker(available_workers[self.round_robin_counter], r)
            self.round_robin_counter = (self.round_robin_counter + 1) % len(
                available_workers
            )
        return

    async def shortest_queue_scheduler(self, input_requests):
        for r in input_requests:
            worker = min(
                self.workers, key=lambda w: self.workers[w].request_queue.qsize()
            )
            self.put_req_to_worker(worker, r)
        return

    async def remove_dead_workers(self):
        for i in list(self.workers.keys()):
            worker_thread = self.workers[i]
            if not worker_thread.liveness:
                worker_thread.join()
                # move unsuccessful requests back to the queue
                while not worker_thread.request_queue.empty():
                    self.recv_reqs.append(worker_thread.request_queue.get())
                del self.workers[i]
                logger.info(f"Stale worker {i} removed")

    async def loop_for_forward(self):
        while True:
            await self.remove_dead_workers()

            if self.have_any_live_worker():
                next_step_input = list(self.recv_reqs)
                self.recv_reqs = []
                if next_step_input:
                    await self.dispatching(next_step_input)
            # else:
            #     logger.error("There is no live worker.")

            await asyncio.sleep(global_config.wait_for_new_request_delay)

    async def loop_for_recv_requests(self):
        while True:
            recv_req = await self.recv_from_tokenizer.recv_pyobj()
            if isinstance(recv_req, FlushCacheReq):
                # TODO(lsyin): apply more specific flushCacheReq
                for worker_thread in self.workers.values():
                    worker_thread.request_queue.put(recv_req)
            elif isinstance(recv_req, TokenizedGenerateReqInput):
                self.recv_reqs.append(recv_req)
            elif isinstance(recv_req, AbortReq):
                in_queue = False
                for i, req in enumerate(self.recv_reqs):
                    if req.rid == recv_req.rid:
                        self.recv_reqs[i] = recv_req
                        in_queue = True
                        break
                if not in_queue:
                    # Send abort req to all TP groups
                    for worker in list(self.workers.keys()):
                        self.put_req_to_worker(worker, recv_req)
            else:
                logger.error(f"Invalid object: {recv_req}")


def start_controller_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    pipe_writer,
    model_overide_args=None,
):
    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    try:
        controller = Controller(
            server_args.load_balance_method, server_args, port_args, model_overide_args
        )
    except Exception:
        pipe_writer.send(get_exception_traceback())
        raise

    pipe_writer.send("init ok")
    loop = asyncio.get_event_loop()
    asyncio.set_event_loop(loop)
    loop.create_task(controller.loop_for_recv_requests())
    loop.run_until_complete(controller.loop_for_forward())
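For intuition, a worked example of the GPU-to-worker mapping in start_dp_worker above (illustrative sizes):

# With tp_size = 2 and dp_size = 2, worker i owns a contiguous block of GPUs.
tp_size, dp_size = 2, 2
for i in range(dp_size):
    gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
    print(i, gpu_ids)
# 0 [0, 1]
# 1 [2, 3]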
@@ -1,3 +1,4 @@
+"""A controller that manages a group of tensor parallel workers."""
import asyncio
import logging
@@ -6,15 +7,15 @@ import zmq
import zmq.asyncio

from sglang.global_config import global_config
-from sglang.srt.managers.router.model_rpc import ModelRpcClient
+from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


-class RouterManager:
-    def __init__(self, model_client: ModelRpcClient, port_args: PortArgs):
+class ControllerSingle:
+    def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
        # Init communication
        context = zmq.asyncio.Context(2)
        self.recv_from_tokenizer = context.socket(zmq.PULL)
@@ -30,7 +31,7 @@ class RouterManager:
        self.recv_reqs = []

        # Init some configs
-        self.request_dependency_time = global_config.request_dependency_time
+        self.request_dependency_delay = global_config.request_dependency_delay

    async def loop_for_forward(self):
        while True:
@@ -46,12 +47,12 @@ class RouterManager:
            if len(out_pyobjs) != 0:
                has_finished = any([obj.finished for obj in out_pyobjs])
                if has_finished:
-                    if self.request_dependency_time > 0:
+                    if self.request_dependency_delay > 0:
                        slept = True
-                        await asyncio.sleep(self.request_dependency_time)
+                        await asyncio.sleep(self.request_dependency_delay)

            if not slept:
-                await asyncio.sleep(0.0006)
+                await asyncio.sleep(global_config.wait_for_new_request_delay)

    async def loop_for_recv_requests(self):
        while True:
@@ -59,7 +60,7 @@ class RouterManager:
            self.recv_reqs.append(recv_req)


-def start_router_process(
+def start_controller_process(
    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
):
    logging.basicConfig(
@@ -68,8 +69,13 @@ def start_router_process(
    )

    try:
-        model_client = ModelRpcClient(server_args, port_args, model_overide_args)
-        router = RouterManager(model_client, port_args)
+        model_client = ModelTpClient(
+            list(range(server_args.tp_size)),
+            server_args,
+            port_args.model_port_args[0],
+            model_overide_args,
+        )
+        controller = ControllerSingle(model_client, port_args)
    except Exception:
        pipe_writer.send(get_exception_traceback())
        raise
@@ -78,5 +84,5 @@ def start_router_process(

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
-    loop.create_task(router.loop_for_recv_requests())
-    loop.run_until_complete(router.loop_for_forward())
+    loop.create_task(controller.loop_for_recv_requests())
+    loop.run_until_complete(controller.loop_for_forward())
@@ -15,13 +15,13 @@ from vllm.distributed import initialize_model_parallel
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model


-logger = logging.getLogger("model_runner")
+logger = logging.getLogger("srt.model_runner")

# for server args in model endpoints
global_server_args_dict = {}
@@ -215,14 +215,16 @@ class ModelRunner:
    def __init__(
        self,
        model_config,
-        mem_fraction_static,
-        tp_rank,
-        tp_size,
-        nccl_port,
+        mem_fraction_static: float,
+        gpu_id: int,
+        tp_rank: int,
+        tp_size: int,
+        nccl_port: int,
+        server_args: ServerArgs,
    ):
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
+        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
@@ -235,9 +237,9 @@ class ModelRunner:
        }

        # Init torch distributed
-        logger.info(f"[rank={self.tp_rank}] Set cuda device.")
-        torch.cuda.set_device(self.tp_rank)
-        logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
+        logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
+        torch.cuda.set_device(self.gpu_id)
+        logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=self.tp_size,
@@ -245,22 +247,26 @@ class ModelRunner:
            init_method=f"tcp://127.0.0.1:{self.nccl_port}",
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-        logger.info(f"[rank={self.tp_rank}] Init torch end.")

-        total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
+        total_gpu_memory = get_available_gpu_memory(
+            self.gpu_id, distributed=self.tp_size > 1
+        )

        if self.tp_size > 1:
-            total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
+            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
            if total_local_gpu_memory < total_gpu_memory * 0.9:
-                raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")
+                raise ValueError(
+                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
+                )

        self.load_model()
        self.init_memory_pool(total_gpu_memory)

        self.is_multimodal_model = is_multimodal_model(self.model_config)

    def load_model(self):
-        logger.info(f"[rank={self.tp_rank}] Load weight begin.")
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Load weight begin. "
+            f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )

        device_config = DeviceConfig()
        load_config = LoadConfig(load_format=self.server_args.load_format)
@@ -286,12 +292,16 @@ class ModelRunner:
            parallel_config=None,
            scheduler_config=None,
        )
-        logger.info(f"[rank={self.tp_rank}] Load weight end. "
-                    f"Type={type(self.model).__name__}. "
-                    f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Load weight end. "
+            f"Type={type(self.model).__name__}. "
+            f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )

    def profile_max_num_token(self, total_gpu_memory):
-        available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
+        available_gpu_memory = get_available_gpu_memory(
+            self.gpu_id, distributed=self.tp_size > 1
+        )
        head_dim = self.model_config.head_dim
        head_num = self.model_config.num_key_value_heads // self.tp_size
        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
@@ -306,7 +316,7 @@ class ModelRunner:

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
-                "Not enought memory. " "Please try to increase --mem-fraction-static."
+                "Not enought memory. Please try to increase --mem-fraction-static."
            )

        self.req_to_token_pool = ReqToTokenPool(
@@ -320,6 +330,10 @@ class ModelRunner:
            head_dim=self.model_config.head_dim,
            layer_num=self.model_config.num_hidden_layers,
        )
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Memory pool end. "
+            f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )

    @torch.inference_mode()
    def forward_prefill(self, batch: Batch):
@@ -424,8 +438,8 @@ def import_model_classes():
        if hasattr(module, "EntryClass"):
            entry = module.EntryClass
            if isinstance(entry, list):  # To support multiple model classes in one module
-                for cls in entry:
-                    model_arch_name_to_cls[cls.__name__] = cls
+                for tmp in entry:
+                    model_arch_name_to_cls[tmp.__name__] = tmp
            else:
                model_arch_name_to_cls[entry.__name__] = entry
    return model_arch_name_to_cls
@@ -442,4 +456,4 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:


# Monkey patch model loader
-setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
\ No newline at end of file
+setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
@@ -2,7 +2,7 @@ import random
from collections import defaultdict


-class Scheduler:
+class ScheduleHeuristic:
    def __init__(
        self,
        schedule_heuristic,
@@ -1,20 +1,13 @@
import asyncio
import logging
-import multiprocessing
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional
+from typing import List

import rpyc
import torch
from rpyc.utils.classic import obtain
-from rpyc.utils.server import ThreadedServer
-
-try:
-    from vllm.logger import _default_handler as vllm_default_logger
-except ImportError:
-    from vllm.logger import logger as vllm_default_logger

from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
@@ -26,38 +19,41 @@ from sglang.srt.managers.io_struct import (
    FlushCacheReq,
    TokenizedGenerateReqInput,
)
-from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req
-from sglang.srt.managers.router.model_runner import ModelRunner
-from sglang.srt.managers.router.radix_cache import RadixCache
-from sglang.srt.managers.router.scheduler import Scheduler
+from sglang.srt.managers.controller.infer_batch import Batch, FinishReason, ForwardMode, Req
+from sglang.srt.managers.controller.model_runner import ModelRunner
+from sglang.srt.managers.controller.radix_cache import RadixCache
+from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
from sglang.srt.model_config import ModelConfig
-from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.utils import (
    get_int_token_logit_bias,
    is_multimodal_model,
    set_random_seed,
+    start_rpyc_process,
+    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

-logger = logging.getLogger("model_rpc")
-vllm_default_logger.setLevel(logging.WARN)
-logging.getLogger("vllm.utils").setLevel(logging.WARN)
-logging.getLogger("vllm.selector").setLevel(logging.WARN)
+logger = logging.getLogger("srt.model_tp")


-class ModelRpcServer:
+class ModelTpServer:
    def __init__(
        self,
+        gpu_id: int,
        tp_rank: int,
        server_args: ServerArgs,
-        port_args: PortArgs,
-        model_overide_args: Optional[dict] = None,
+        model_port_args: ModelPortArgs,
+        model_overide_args,
    ):
-        server_args, port_args = [obtain(x) for x in [server_args, port_args]]
+        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
+        suppress_other_loggers()

        # Copy arguments
+        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
+        self.dp_size = server_args.dp_size
        self.schedule_heuristic = server_args.schedule_heuristic
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
@@ -68,16 +64,16 @@ class ModelRpcServer:
            context_length=server_args.context_length,
            model_overide_args=model_overide_args,
        )

-        # For model end global settings
        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
+            gpu_id=gpu_id,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
-            nccl_port=port_args.nccl_port,
+            nccl_port=model_port_args.nccl_port,
+            server_args=server_args,
        )

        if is_multimodal_model(server_args.model_path):
            self.processor = get_processor(
                server_args.tokenizer_path,
|
||||
self.max_prefill_tokens = max(
|
||||
self.model_config.context_len,
|
||||
(
|
||||
self.max_total_num_tokens // 6
|
||||
min(self.max_total_num_tokens // 6, 65536)
|
||||
if server_args.max_prefill_tokens is None
|
||||
else server_args.max_prefill_tokens
|
||||
),
|
||||
)
|
||||
self.max_running_requests = (self.max_total_num_tokens // 2
|
||||
if server_args.max_running_requests is None else server_args.max_running_requests)
|
||||
|
||||
self.int_token_logit_bias = torch.tensor(
|
||||
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
|
||||
)
|
||||
set_random_seed(server_args.random_seed)
|
||||
|
||||
# Print info
|
||||
logger.info(f"[rank={self.tp_rank}] "
|
||||
logger.info(
|
||||
f"[gpu_id={self.gpu_id}] "
|
||||
f"max_total_num_tokens={self.max_total_num_tokens}, "
|
||||
f"max_prefill_tokens={self.max_prefill_tokens}, "
|
||||
f"context_len={self.model_config.context_len}, "
|
||||
@@ -124,7 +120,7 @@ class ModelRpcServer:
            disable=server_args.disable_radix_cache,
        )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler = Scheduler(
+        self.scheduler = ScheduleHeuristic(
            self.schedule_heuristic,
            self.max_running_requests,
            self.max_prefill_tokens,
@@ -170,7 +166,7 @@ class ModelRpcServer:
        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

    def exposed_step(self, recv_reqs):
-        if self.tp_size != 1:
+        if self.tp_size * self.dp_size != 1:
            recv_reqs = obtain(recv_reqs)

        try:
@@ -188,7 +184,7 @@ class ModelRpcServer:
            # Forward
            self.forward_step()
        except Exception:
-            logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())
+            logger.error("Exception in ModelTpClient:\n" + get_exception_traceback())

        # Return results
        ret = self.out_pyobjs
@@ -224,16 +220,17 @@ class ModelRpcServer:
                self.token_to_kv_pool.available_size()
                + self.tree_cache.evictable_size()
            )
-            throuhgput = self.num_generated_tokens / (
+            throughput = self.num_generated_tokens / (
                time.time() - self.last_stats_tic
            )
            self.num_generated_tokens = 0
            self.last_stats_tic = time.time()
            logger.info(
+                f"[gpu_id={self.gpu_id}] "
                f"#running-req: {len(self.running_batch.reqs)}, "
                f"#token: {num_used}, "
                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                f"gen throughput (token/s): {throuhgput:.2f}, "
+                f"gen throughput (token/s): {throughput:.2f}, "
                f"#queue-req: {len(self.forward_queue)}"
            )
@@ -405,7 +402,7 @@ class ModelRpcServer:
            f"#new_token: {new_batch_input_tokens}. "
            f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
            f"#running_req: {running_req}. "
-            f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
+            f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%. "
        )
        # logger.debug(
        #     f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
@@ -724,20 +721,30 @@ class ModelRpcServer:
                break


-class ModelRpcService(rpyc.Service):
-    exposed_ModelRpcServer = ModelRpcServer
+class ModelTpService(rpyc.Service):
+    exposed_ModelTpServer = ModelTpServer


-class ModelRpcClient:
+class ModelTpClient:
    def __init__(
-        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
+        self,
+        gpu_ids: List[int],
+        server_args: ServerArgs,
+        model_port_args: ModelPortArgs,
+        model_overide_args,
    ):
-        tp_size = server_args.tp_size
+        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
+        self.tp_size = server_args.tp_size

-        if tp_size == 1:
+        if self.tp_size * server_args.dp_size == 1:
            # Init model
-            self.model_server = ModelRpcService().exposed_ModelRpcServer(
-                0, server_args, port_args, model_overide_args
+            assert len(gpu_ids) == 1
+            self.model_server = ModelTpService().exposed_ModelTpServer(
+                0,
+                gpu_ids[0],
+                server_args,
+                model_port_args,
+                model_overide_args,
            )

        # Wrap functions
@@ -749,19 +756,26 @@ class ModelRpcClient:

            self.step = async_wrap(self.model_server.exposed_step)
        else:
-            with ThreadPoolExecutor(tp_size) as executor:
+            with ThreadPoolExecutor(self.tp_size) as executor:
                # Launch model processes
-                rets = executor.map(start_model_process, port_args.model_rpc_ports)
-                self.remote_services = [x[0] for x in rets]
+                rets = executor.map(
+                    lambda args: start_rpyc_process(*args),
+                    [(ModelTpService, p) for p in model_port_args.model_tp_ports],
+                )
+                self.model_services = [x[0] for x in rets]
+                self.procs = [x[1] for x in rets]

                # Init model
                def init_model(i):
-                    return self.remote_services[i].ModelRpcServer(
-                        i, server_args, port_args, model_overide_args
+                    return self.model_services[i].ModelTpServer(
+                        gpu_ids[i],
+                        i,
+                        server_args,
+                        model_port_args,
+                        model_overide_args,
                    )

-                self.model_servers = executor.map(init_model, range(tp_size))
+                self.model_servers = executor.map(init_model, range(self.tp_size))

                # Wrap functions
                def async_wrap(func_name):
@@ -774,45 +788,4 @@ class ModelRpcClient:

                    return _func

-                self.step = async_wrap("step")
-
-
-def _init_service(port):
-    t = ThreadedServer(
-        ModelRpcService(),
-        port=port,
-        protocol_config={
-            "allow_public_attrs": True,
-            "allow_pickle": True,
-            "sync_request_timeout": 3600,
-        },
-    )
-    t.start()
-
-
-def start_model_process(port):
-    proc = multiprocessing.Process(target=_init_service, args=(port,))
-    proc.start()
-    time.sleep(1)
-
-    repeat_count = 0
-    while repeat_count < 20:
-        try:
-            con = rpyc.connect(
-                "localhost",
-                port,
-                config={
-                    "allow_public_attrs": True,
-                    "allow_pickle": True,
-                    "sync_request_timeout": 3600,
-                },
-            )
-            break
-        except ConnectionRefusedError:
-            time.sleep(1)
-            repeat_count += 1
-    if repeat_count == 20:
-        raise RuntimeError("init rpc env error!")
-
-    assert proc.is_alive()
-    return con.root, proc
+                self.step = async_wrap("step")
@@ -27,7 +27,6 @@ class GenerateReqInput:
    return_text_in_logprobs: bool = False
    # Whether to stream output
    stream: bool = False
-    # TODO: make all parameters a Union[List[T], T] to allow for batched requests

    def post_init(self):
@@ -135,4 +134,4 @@ class AbortReq:

@dataclass
class DetokenizeReqInput:
-    input_ids: List[int]
\ No newline at end of file
+    input_ids: List[int]
@@ -48,7 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata


@torch.compile
@@ -29,7 +29,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata


class DbrxRouter(nn.Module):
@@ -22,7 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata


class GemmaMLP(nn.Module):
@@ -37,7 +37,7 @@ from vllm.utils import print_warning_once
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.fused_moe import fused_moe
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata


use_fused = True
@@ -4,9 +4,13 @@
from typing import Any, Dict, Optional, Tuple, Iterable

import torch
+import tqdm
from torch import nn
from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size
+)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
@@ -24,7 +28,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata


class LlamaMLP(nn.Module):
@@ -284,6 +288,8 @@ class LlamaForCausalLM(nn.Module):
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
+        if get_tensor_model_parallel_rank() == 0:
+            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
@@ -10,8 +10,8 @@ from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.managers.router.infer_batch import ForwardMode
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.infer_batch import ForwardMode
+from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.mm_utils import (
    get_anyres_image_grid_shape,
    unpad_image,
@@ -10,8 +10,8 @@ from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.managers.router.infer_batch import ForwardMode
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.infer_batch import ForwardMode
+from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.mm_utils import (
    get_anyres_image_grid_shape,
    unpad_image,
@@ -35,7 +35,7 @@ from vllm.utils import print_warning_once

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata
@@ -30,7 +30,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata


class MixtralMLP(nn.Module):
@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata


class QWenMLP(nn.Module):
@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata

Qwen2Config = None
@@ -24,7 +24,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata


class StablelmMLP(nn.Module):
@@ -10,7 +10,7 @@ import sys
import threading
import time
from http import HTTPStatus
-from typing import List, Optional, Union
+from typing import Optional

# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -28,14 +28,15 @@ from sglang.srt.constrained import disable_cache
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.router.manager import start_router_process
+from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
+from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api_adapter import (
    load_chat_template_for_openai_api,
    v1_chat_completions,
    v1_completions,
)
-from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
from sglang.srt.utils import (
    API_KEY_HEADER_NAME,
    APIKeyValidatorMiddleware,
@@ -141,14 +142,28 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg

    # Allocate ports
    server_args.port, server_args.additional_ports = allocate_init_ports(
-        server_args.port, server_args.additional_ports, server_args.tp_size
+        server_args.port,
+        server_args.additional_ports,
+        server_args.tp_size,
+        server_args.dp_size,
    )

+    # Init local models port args
+    ports = server_args.additional_ports
+    tp = server_args.tp_size
+    model_port_args = []
+    for i in range(server_args.dp_size):
+        model_port_args.append(
+            ModelPortArgs(
+                nccl_port=ports[3 + i * (tp + 1)],
+                model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
+            )
+        )
    port_args = PortArgs(
-        tokenizer_port=server_args.additional_ports[0],
-        router_port=server_args.additional_ports[1],
-        detokenizer_port=server_args.additional_ports[2],
-        nccl_port=server_args.additional_ports[3],
-        model_rpc_ports=server_args.additional_ports[4:],
+        tokenizer_port=ports[0],
+        router_port=ports[1],
+        detokenizer_port=ports[2],
+        model_port_args=model_port_args,
    )

    # Launch processes
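A worked example of this port layout (illustrative port numbers): additional_ports holds three shared ports, then one (nccl + tp) block per data parallel worker, matching num_ports_needed = 4 + dp_size * (1 + tp_size) in utils.py below (the extra port is the HTTP server).

ports = list(range(30000, 30009))  # 3 shared + 2 * (1 + 2) per-worker ports
tp = 2
for i in range(2):  # dp_size = 2
    nccl_port = ports[3 + i * (tp + 1)]
    tp_ports = ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)]
    print(i, nccl_port, tp_ports)
# 0 30003 [30004, 30005]
# 1 30006 [30007, 30008]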
@@ -156,8 +171,12 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

+    if server_args.dp_size == 1:
+        start_process = start_controller_process_single
+    else:
+        start_process = start_controller_process_multi
    proc_router = mp.Process(
-        target=start_router_process,
+        target=start_process,
        args=(server_args, port_args, pipe_router_writer, model_overide_args),
    )
    proc_router.start()
@@ -251,19 +270,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
class Runtime:
    def __init__(
        self,
-        log_evel: str = "error",
+        log_level: str = "error",
        model_overide_args: Optional[dict] = None,
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
-        self.server_args = ServerArgs(*args, log_level=log_evel, **kwargs)
+        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Pre-allocate ports
        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
            self.server_args.port,
            self.server_args.additional_ports,
            self.server_args.tp_size,
+            self.server_args.dp_size,
        )

        self.url = self.server_args.url()
@@ -44,6 +44,10 @@ class ServerArgs:
    # Other
    api_key: str = ""

+    # Data parallelism
+    dp_size: int = 1
+    load_balance_method: str = "round_robin"
+
    # Optimization/debug options
    enable_flashinfer: bool = False
    attention_reduce_in_fp32: bool = False
@@ -226,6 +230,24 @@ class ServerArgs:
        help="Set API key of the server",
    )

+    # Data parallelism
+    parser.add_argument(
+        "--dp-size",
+        type=int,
+        default=ServerArgs.dp_size,
+        help="Data parallelism size.",
+    )
+    parser.add_argument(
+        "--load-balance-method",
+        type=str,
+        default=ServerArgs.load_balance_method,
+        help="Load balancing strategy for data parallelism.",
+        choices=[
+            "round_robin",
+            "shortest_queue",
+        ],
+    )
+
    # Optimization/debug options
    parser.add_argument(
        "--enable-flashinfer",
@@ -271,10 +293,15 @@ class ServerArgs:
    )


+@dataclasses.dataclass
+class ModelPortArgs:
+    nccl_port: int
+    model_tp_ports: List[int]
+
+
@dataclasses.dataclass
class PortArgs:
    tokenizer_port: int
    router_port: int
    detokenizer_port: int
-    nccl_port: int
-    model_rpc_ports: List[int]
+    model_port_args: List[ModelPortArgs]
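For orientation, a hedged sketch of the new arguments surface (placeholder model path; flag names from the parser additions above):

from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    tp_size=2,
    dp_size=2,                  # two replicas -> 4 GPUs total
    load_balance_method="shortest_queue",
)
# Equivalent CLI flags: --tp-size 2 --dp-size 2 --load-balance-method shortest_queue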
@@ -1,6 +1,7 @@
"""Common utilities."""

import base64
+import multiprocessing
import logging
import os
import random
@@ -12,12 +13,14 @@ from typing import List, Optional

import numpy as np
import requests
+import rpyc
import torch
import triton
+from rpyc.utils.server import ThreadedServer
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from starlette.middleware.base import BaseHTTPMiddleware
import torch.distributed as dist


logger = logging.getLogger(__name__)
@@ -120,14 +123,16 @@ def get_available_gpu_memory(gpu_id, distributed=False):


def set_random_seed(seed: int) -> None:
    """Set the random seed for all libraries."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)


def is_port_available(port):
    """Return whether a port is available."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@@ -142,7 +147,9 @@ def allocate_init_ports(
    port: Optional[int] = None,
    additional_ports: Optional[List[int]] = None,
    tp_size: int = 1,
+    dp_size: int = 1,
):
    """Allocate ports for all connections."""
    if additional_ports:
        ret_ports = [port] + additional_ports
    else:
@@ -151,20 +158,23 @@ def allocate_init_ports(
    ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
    cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000

-    while len(ret_ports) < 5 + tp_size:
+    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
+    num_ports_needed = 4 + dp_size * (1 + tp_size)
+    while len(ret_ports) < num_ports_needed:
        if cur_port not in ret_ports and is_port_available(cur_port):
            ret_ports.append(cur_port)
        cur_port += 1

-    if port and ret_ports[0] != port:
+    if port is not None and ret_ports[0] != port:
        logger.warn(
            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
        )

-    return ret_ports[0], ret_ports[1:]
+    return ret_ports[0], ret_ports[1:num_ports_needed]


def get_int_token_logit_bias(tokenizer, vocab_size):
    """Get the logit bias for integer-only tokens."""
    # a bug when model's vocab size > tokenizer.vocab_size
    vocab_size = tokenizer.vocab_size
    logit_bias = np.zeros(vocab_size, dtype=np.float32)
@@ -181,12 +191,8 @@ def wrap_kernel_launcher(kernel):
    if int(triton.__version__.split(".")[0]) >= 3:
        return None

-    if dist.is_initialized():
-        rank = dist.get_rank()
-    else:
-        rank = 0
-
-    kernels = kernel.cache[rank].values()
+    gpu_id = torch.cuda.current_device()
+    kernels = kernel.cache[gpu_id].values()
    kernel = next(iter(kernels))

    # Different trition versions use different low-level names
@@ -363,6 +369,63 @@ def load_image(image_file):
    return image, image_size


+def init_rpyc_service(service: rpyc.Service, port: int):
+    t = ThreadedServer(
+        service=service,
+        port=port,
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 3600,
+        },
+    )
+    t.logger.setLevel(logging.WARN)
+    t.start()
+
+
+def connect_to_rpyc_service(port, host="localhost"):
+    time.sleep(1)
+
+    repeat_count = 0
+    while repeat_count < 20:
+        try:
+            con = rpyc.connect(
+                host,
+                port,
+                config={
+                    "allow_public_attrs": True,
+                    "allow_pickle": True,
+                    "sync_request_timeout": 3600,
+                },
+            )
+            break
+        except ConnectionRefusedError:
+            time.sleep(1)
+            repeat_count += 1
+    if repeat_count == 20:
+        raise RuntimeError("init rpc env error!")
+
+    return con.root
+
+
+def start_rpyc_process(service: rpyc.Service, port: int):
+    # Return the proxy and the process
+    proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port))
+    proc.start()
+    proxy = connect_to_rpyc_service(port)
+    assert proc.is_alive()
+    return proxy, proc
+
+
+def suppress_other_loggers():
+    from vllm.logger import logger as vllm_default_logger
+
+    vllm_default_logger.setLevel(logging.WARN)
+    logging.getLogger("vllm.utils").setLevel(logging.WARN)
+    logging.getLogger("vllm.selector").setLevel(logging.WARN)
+    logging.getLogger("vllm.config").setLevel(logging.ERROR)
+
+
def assert_pkg_version(pkg: str, min_version: str):
    try:
        installed_version = version(pkg)
@@ -394,4 +457,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
                content={"detail": "Invalid API Key"},
            )
        response = await call_next(request)
-        return response
\ No newline at end of file
+        return response