Support data parallelism (static) (#480)

Co-authored-by: Ying Sheng <ying.sheng@databricks.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
Ying Sheng
2024-05-27 21:24:10 -07:00
committed by GitHub
parent 565d727409
commit 0463f7fb52
32 changed files with 580 additions and 181 deletions

View File

@@ -26,7 +26,8 @@ class GlobalConfig:
self.concate_and_append_mode = "no_adjust"
# Request dependency time due to network delay
self.request_dependency_time = 0.03
self.request_dependency_delay = 0.03
self.wait_for_new_request_delay = 0.0006
# New generation token ratio estimation
self.base_new_token_ratio = 0.4

View File

@@ -5,7 +5,7 @@ from vllm.distributed import (
tensor_model_parallel_all_gather,
)
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
class LogitsProcessor(nn.Module):

View File

@@ -5,7 +5,7 @@ from torch import nn
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
class RadixAttention(nn.Module):
@@ -20,7 +20,7 @@ class RadixAttention(nn.Module):
assert np.allclose(scaling, 1.0 / (head_dim**0.5))
from sglang.srt.managers.router.model_runner import global_server_args_dict
from sglang.srt.managers.controller.model_runner import global_server_args_dict
if global_server_args_dict.get("enable_flashinfer", False):
self.prefill_forward = self.prefill_forward_flashinfer

View File

@@ -5,7 +5,7 @@ import torch
import triton
import triton.language as tl
from sglang.srt.managers.router.model_runner import global_server_args_dict
from sglang.srt.managers.controller.model_runner import global_server_args_dict
from sglang.srt.utils import wrap_kernel_launcher
if global_server_args_dict.get("attention_reduce_in_fp32", False):

View File

@@ -0,0 +1,102 @@
"""A data parallel worker thread."""
import asyncio
import logging
import queue
import threading
from typing import Callable, List

import uvloop
import zmq
# Explicitly import the asyncio bindings: `import zmq` alone does not
# guarantee the `zmq.asyncio` attribute exists (it only works when another
# module has already imported it).
import zmq.asyncio

from sglang.global_config import global_config
from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback
logger = logging.getLogger("srt.controller")
CHECKING_INTERVAL = 5
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
class DataParallelWorkerThread(threading.Thread):
    """A worker thread that drives one data-parallel replica.

    The thread owns a private event loop. It repeatedly drains its request
    queue, forwards the requests to the tensor-parallel client via
    ``step_func``, and pushes the outputs to the detokenizer over ZMQ.
    When a step fails, the pending requests are put back on the queue and
    ``liveness`` is set to False so the controller can reap this worker.
    """

    def __init__(
        self,
        worker_id: int,
        request_queue: queue.Queue,
        detokenizer_port: int,
        step_func: Callable,
    ):
        super().__init__()
        self.worker_id = worker_id
        self.request_queue = request_queue
        # Controller polls this flag; False means the worker should be reaped.
        self.liveness = True
        self.request_dependency_delay = global_config.request_dependency_delay

        # PUSH socket toward the detokenizer process.
        context = zmq.asyncio.Context()
        self.send_to_detokenizer = context.socket(zmq.PUSH)
        self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}")

        # Async callable: List[request] -> List[output pyobj].
        self.step = step_func

    async def loop_for_forward(self):
        """Main loop: drain queue -> step model -> forward outputs."""
        while self.liveness:
            requests = []
            while not self.request_queue.empty():
                requests.append(self.request_queue.get())

            try:
                out_pyobjs = await self.step(requests)
            except Exception:
                # Return unprocessed requests so the controller can redispatch
                # them to a live worker, then mark this worker dead.
                for r in requests:
                    self.request_queue.put(r)
                logger.error(
                    f"Worker thread {self.worker_id}: "
                    f"failed to get back from Model Server\n"
                    f"{get_exception_traceback()}"
                )
                self.liveness = False
                # Bug fix: without this break, `out_pyobjs` below would be
                # unbound and raise NameError instead of exiting cleanly.
                break

            for obj in out_pyobjs:
                self.send_to_detokenizer.send_pyobj(obj)

            # async sleep for receiving the subsequent request and avoiding cache miss
            if len(out_pyobjs) != 0:
                has_finished = any(obj.finished for obj in out_pyobjs)
                if has_finished:
                    await asyncio.sleep(self.request_dependency_delay)
            await asyncio.sleep(global_config.wait_for_new_request_delay)

    async def monitoring(self):
        """Periodic health-check hook (currently a no-op placeholder)."""
        while True:
            await asyncio.sleep(CHECKING_INTERVAL)
            # can plug in monitoring logic here

    def run(self):
        """Thread entry point: create a dedicated loop and run until death."""
        logger.info(f"DataParallelWorkerThread {self.worker_id} start")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.create_task(self.monitoring())
        loop.run_until_complete(self.loop_for_forward())
def start_data_parallel_worker(
    server_args: ServerArgs,
    port_args: PortArgs,
    model_overide_args,
    gpu_ids: List[int],
    worker_id: int,
):
    """Build the TP client for one data-parallel rank, wrap it in a worker
    thread with a fresh request queue, start the thread, and return it."""
    tp_client = ModelTpClient(
        gpu_ids,
        server_args,
        port_args.model_port_args[worker_id],
        model_overide_args,
    )
    thread = DataParallelWorkerThread(
        worker_id=worker_id,
        request_queue=queue.Queue(),
        detokenizer_port=port_args.detokenizer_port,
        step_func=tp_client.step,
    )
    thread.start()
    return thread

View File

@@ -1,3 +1,4 @@
"""Meta data for requests and batches"""
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import List
@@ -5,7 +6,7 @@ from typing import List
import numpy as np
import torch
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool

View File

@@ -0,0 +1,187 @@
"""
A controller that manages multiple data parallel workers.
Each data parallel worker can manage multiple tensor parallel workers.
"""
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto
from typing import Dict
import zmq
import zmq.asyncio
from sglang.global_config import global_config
from sglang.srt.managers.io_struct import (
AbortReq,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.dp_worker import (
DataParallelWorkerThread,
start_data_parallel_worker,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback
logger = logging.getLogger("srt.controller")
class LoadBalanceMethod(Enum):
    """Strategies for dispatching requests across data-parallel workers."""

    ROUND_ROBIN = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, method: str):
        """Resolve a member from its case-insensitive name.

        Raises ValueError (chaining the KeyError) for unknown names.
        """
        normalized = method.upper()
        try:
            return cls[normalized]
        except KeyError as exc:
            raise ValueError(f"Invalid load balance method: {normalized}") from exc
class Controller:
    """A controller that fans requests out to multiple data parallel workers.

    Receives tokenized requests from the tokenizer over ZMQ, buffers them in
    ``recv_reqs``, and dispatches them to per-worker queues according to the
    configured load-balancing method. Dead workers are reaped and their
    undelivered requests are recycled.
    """

    def __init__(
        self,
        load_balance_method: str,
        server_args: ServerArgs,
        port_args: PortArgs,
        model_overide_args,
    ):
        self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
        self.server_args = server_args
        self.port_args = port_args
        # Round-robin needs a rotating cursor; other methods are stateless.
        if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
            self.round_robin_counter = 0
        self.dispatch_lookup = {
            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
        }
        # Bound async dispatcher selected once at construction time.
        self.dispatching = self.dispatch_lookup[self.load_balance_method]
        # Init communication
        context = zmq.asyncio.Context()
        self.recv_from_tokenizer = context.socket(zmq.PULL)
        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
        # Init status
        # Requests received but not yet dispatched to a worker queue.
        self.recv_reqs = []
        # Start data parallel workers
        self.workers: Dict[int, DataParallelWorkerThread] = {}
        tp_size = server_args.tp_size

        def start_dp_worker(i):
            # Worker i owns the contiguous GPU slice [i*tp_size, (i+1)*tp_size).
            try:
                gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
                worker_thread = start_data_parallel_worker(
                    server_args, port_args, model_overide_args, gpu_ids, i
                )
                self.workers[i] = worker_thread
            except Exception:
                # A failed worker is logged and skipped; the controller keeps
                # running with whatever workers did come up.
                logger.error(
                    f"Failed to start local worker {i}\n{get_exception_traceback()}"
                )

        # Launch all DP workers concurrently.
        with ThreadPoolExecutor(server_args.dp_size) as executor:
            executor.map(start_dp_worker, range(server_args.dp_size))

    def have_any_live_worker(self):
        """Return True if at least one worker thread is still alive."""
        return any(worker_thread.liveness for worker_thread in self.workers.values())

    def put_req_to_worker(self, worker_id, req):
        """Enqueue a request on the given worker's request queue."""
        self.workers[worker_id].request_queue.put(req)

    async def round_robin_scheduler(self, input_requests):
        """Dispatch requests to workers in rotating order."""
        available_workers = list(self.workers.keys())
        for r in input_requests:
            self.put_req_to_worker(available_workers[self.round_robin_counter], r)
            self.round_robin_counter = (self.round_robin_counter + 1) % len(
                available_workers
            )
        return

    async def shortest_queue_scheduler(self, input_requests):
        """Dispatch each request to the worker with the shortest queue."""
        for r in input_requests:
            # NOTE: qsize() is approximate under concurrency; good enough here.
            worker = min(
                self.workers, key=lambda w: self.workers[w].request_queue.qsize()
            )
            self.put_req_to_worker(worker, r)
        return

    async def remove_dead_workers(self):
        """Join dead worker threads and recycle their queued requests."""
        for i in list(self.workers.keys()):
            worker_thread = self.workers[i]
            if not worker_thread.liveness:
                worker_thread.join()
                # move unsuccessful requests back to the queue
                while not worker_thread.request_queue.empty():
                    self.recv_reqs.append(worker_thread.request_queue.get())
                del self.workers[i]
                logger.info(f"Stale worker {i} removed")

    async def loop_for_forward(self):
        """Main loop: reap dead workers, then dispatch buffered requests."""
        while True:
            await self.remove_dead_workers()
            if self.have_any_live_worker():
                # Swap out the buffer before dispatching so new arrivals
                # accumulate in a fresh list.
                next_step_input = list(self.recv_reqs)
                self.recv_reqs = []
                if next_step_input:
                    await self.dispatching(next_step_input)
            #else:
            #    logger.error("There is no live worker.")
            await asyncio.sleep(global_config.wait_for_new_request_delay)

    async def loop_for_recv_requests(self):
        """Receive objects from the tokenizer and route them by type."""
        while True:
            recv_req = await self.recv_from_tokenizer.recv_pyobj()
            if isinstance(recv_req, FlushCacheReq):
                # TODO(lsyin): apply more specific flushCacheReq
                for worker_thread in self.workers.values():
                    worker_thread.request_queue.put(recv_req)
            elif isinstance(recv_req, TokenizedGenerateReqInput):
                self.recv_reqs.append(recv_req)
            elif isinstance(recv_req, AbortReq):
                # If the aborted request is still buffered here, replace it
                # in place; otherwise broadcast the abort to all workers.
                in_queue = False
                for i, req in enumerate(self.recv_reqs):
                    if req.rid == recv_req.rid:
                        self.recv_reqs[i] = recv_req
                        in_queue = True
                        break
                if not in_queue:
                    # Send abort req to all TP groups
                    for worker in list(self.workers.keys()):
                        self.put_req_to_worker(worker, recv_req)
            else:
                logger.error(f"Invalid object: {recv_req}")
def start_controller_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    pipe_writer,
    model_overide_args=None,
):
    """Process entry point for the multi-worker controller.

    Builds the Controller (reporting init success/failure to the parent via
    ``pipe_writer``), then runs its receive and forward loops forever.
    """
    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    try:
        controller = Controller(
            server_args.load_balance_method, server_args, port_args, model_overide_args
        )
    except Exception:
        # Ship the traceback to the parent process before dying.
        pipe_writer.send(get_exception_traceback())
        raise
    pipe_writer.send("init ok")

    # Create a dedicated loop for this process: get_event_loop() +
    # set_event_loop(loop) was a deprecated no-op pattern, and the
    # single-controller entry point already uses new_event_loop().
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.create_task(controller.loop_for_recv_requests())
    loop.run_until_complete(controller.loop_for_forward())

View File

@@ -1,3 +1,4 @@
"""A controller that manages a group of tensor parallel workers."""
import asyncio
import logging
@@ -6,15 +7,15 @@ import zmq
import zmq.asyncio
from sglang.global_config import global_config
from sglang.srt.managers.router.model_rpc import ModelRpcClient
from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
class RouterManager:
def __init__(self, model_client: ModelRpcClient, port_args: PortArgs):
class ControllerSingle:
def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
# Init communication
context = zmq.asyncio.Context(2)
self.recv_from_tokenizer = context.socket(zmq.PULL)
@@ -30,7 +31,7 @@ class RouterManager:
self.recv_reqs = []
# Init some configs
self.request_dependency_time = global_config.request_dependency_time
self.request_dependency_delay = global_config.request_dependency_delay
async def loop_for_forward(self):
while True:
@@ -46,12 +47,12 @@ class RouterManager:
if len(out_pyobjs) != 0:
has_finished = any([obj.finished for obj in out_pyobjs])
if has_finished:
if self.request_dependency_time > 0:
if self.request_dependency_delay > 0:
slept = True
await asyncio.sleep(self.request_dependency_time)
await asyncio.sleep(self.request_dependency_delay)
if not slept:
await asyncio.sleep(0.0006)
await asyncio.sleep(global_config.wait_for_new_request_delay)
async def loop_for_recv_requests(self):
while True:
@@ -59,7 +60,7 @@ class RouterManager:
self.recv_reqs.append(recv_req)
def start_router_process(
def start_controller_process(
server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
):
logging.basicConfig(
@@ -68,8 +69,13 @@ def start_router_process(
)
try:
model_client = ModelRpcClient(server_args, port_args, model_overide_args)
router = RouterManager(model_client, port_args)
model_client = ModelTpClient(
list(range(server_args.tp_size)),
server_args,
port_args.model_port_args[0],
model_overide_args,
)
controller = ControllerSingle(model_client, port_args)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
@@ -78,5 +84,5 @@ def start_router_process(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(router.loop_for_recv_requests())
loop.run_until_complete(router.loop_for_forward())
loop.create_task(controller.loop_for_recv_requests())
loop.run_until_complete(controller.loop_for_forward())

View File

@@ -15,13 +15,13 @@ from vllm.distributed import initialize_model_parallel
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
logger = logging.getLogger("model_runner")
logger = logging.getLogger("srt.model_runner")
# for server args in model endpoints
global_server_args_dict = {}
@@ -215,14 +215,16 @@ class ModelRunner:
def __init__(
self,
model_config,
mem_fraction_static,
tp_rank,
tp_size,
nccl_port,
mem_fraction_static: float,
gpu_id: int,
tp_rank: int,
tp_size: int,
nccl_port: int,
server_args: ServerArgs,
):
self.model_config = model_config
self.mem_fraction_static = mem_fraction_static
self.gpu_id = gpu_id
self.tp_rank = tp_rank
self.tp_size = tp_size
self.nccl_port = nccl_port
@@ -235,9 +237,9 @@ class ModelRunner:
}
# Init torch distributed
logger.info(f"[rank={self.tp_rank}] Set cuda device.")
torch.cuda.set_device(self.tp_rank)
logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
torch.cuda.set_device(self.gpu_id)
logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
torch.distributed.init_process_group(
backend="nccl",
world_size=self.tp_size,
@@ -245,22 +247,26 @@ class ModelRunner:
init_method=f"tcp://127.0.0.1:{self.nccl_port}",
)
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
logger.info(f"[rank={self.tp_rank}] Init torch end.")
total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
total_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
)
if self.tp_size > 1:
total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
if total_local_gpu_memory < total_gpu_memory * 0.9:
raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")
raise ValueError(
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
)
self.load_model()
self.init_memory_pool(total_gpu_memory)
self.is_multimodal_model = is_multimodal_model(self.model_config)
def load_model(self):
logger.info(f"[rank={self.tp_rank}] Load weight begin.")
logger.info(
f"[gpu_id={self.gpu_id}] Load weight begin. "
f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
device_config = DeviceConfig()
load_config = LoadConfig(load_format=self.server_args.load_format)
@@ -286,12 +292,16 @@ class ModelRunner:
parallel_config=None,
scheduler_config=None,
)
logger.info(f"[rank={self.tp_rank}] Load weight end. "
f"Type={type(self.model).__name__}. "
f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
logger.info(
f"[gpu_id={self.gpu_id}] Load weight end. "
f"Type={type(self.model).__name__}. "
f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
def profile_max_num_token(self, total_gpu_memory):
available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
available_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
)
head_dim = self.model_config.head_dim
head_num = self.model_config.num_key_value_heads // self.tp_size
cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
@@ -306,7 +316,7 @@ class ModelRunner:
if self.max_total_num_tokens <= 0:
raise RuntimeError(
"Not enought memory. " "Please try to increase --mem-fraction-static."
"Not enought memory. Please try to increase --mem-fraction-static."
)
self.req_to_token_pool = ReqToTokenPool(
@@ -320,6 +330,10 @@ class ModelRunner:
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
)
logger.info(
f"[gpu_id={self.gpu_id}] Memory pool end. "
f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
@torch.inference_mode()
def forward_prefill(self, batch: Batch):
@@ -424,8 +438,8 @@ def import_model_classes():
if hasattr(module, "EntryClass"):
entry = module.EntryClass
if isinstance(entry, list): # To support multiple model classes in one module
for cls in entry:
model_arch_name_to_cls[cls.__name__] = cls
for tmp in entry:
model_arch_name_to_cls[tmp.__name__] = tmp
else:
model_arch_name_to_cls[entry.__name__] = entry
return model_arch_name_to_cls
@@ -442,4 +456,4 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
# Monkey patch model loader
setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)

View File

@@ -2,7 +2,7 @@ import random
from collections import defaultdict
class Scheduler:
class ScheduleHeuristic:
def __init__(
self,
schedule_heuristic,

View File

@@ -1,20 +1,13 @@
import asyncio
import logging
import multiprocessing
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional
from typing import List
import rpyc
import torch
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer
try:
from vllm.logger import _default_handler as vllm_default_logger
except ImportError:
from vllm.logger import logger as vllm_default_logger
from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
@@ -26,38 +19,41 @@ from sglang.srt.managers.io_struct import (
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.managers.controller.infer_batch import Batch, FinishReason, ForwardMode, Req
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.utils import (
get_int_token_logit_bias,
is_multimodal_model,
set_random_seed,
start_rpyc_process,
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback
logger = logging.getLogger("model_rpc")
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)
logging.getLogger("vllm.selector").setLevel(logging.WARN)
logger = logging.getLogger("srt.model_tp")
class ModelRpcServer:
class ModelTpServer:
def __init__(
self,
gpu_id: int,
tp_rank: int,
server_args: ServerArgs,
port_args: PortArgs,
model_overide_args: Optional[dict] = None,
model_port_args: ModelPortArgs,
model_overide_args,
):
server_args, port_args = [obtain(x) for x in [server_args, port_args]]
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
suppress_other_loggers()
# Copy arguments
self.gpu_id = gpu_id
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.dp_size = server_args.dp_size
self.schedule_heuristic = server_args.schedule_heuristic
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
@@ -68,16 +64,16 @@ class ModelRpcServer:
context_length=server_args.context_length,
model_overide_args=model_overide_args,
)
# For model end global settings
self.model_runner = ModelRunner(
model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static,
gpu_id=gpu_id,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
nccl_port=port_args.nccl_port,
nccl_port=model_port_args.nccl_port,
server_args=server_args,
)
if is_multimodal_model(server_args.model_path):
self.processor = get_processor(
server_args.tokenizer_path,
@@ -95,21 +91,21 @@ class ModelRpcServer:
self.max_prefill_tokens = max(
self.model_config.context_len,
(
self.max_total_num_tokens // 6
min(self.max_total_num_tokens // 6, 65536)
if server_args.max_prefill_tokens is None
else server_args.max_prefill_tokens
),
)
self.max_running_requests = (self.max_total_num_tokens // 2
if server_args.max_running_requests is None else server_args.max_running_requests)
self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
)
set_random_seed(server_args.random_seed)
# Print info
logger.info(f"[rank={self.tp_rank}] "
logger.info(
f"[gpu_id={self.gpu_id}] "
f"max_total_num_tokens={self.max_total_num_tokens}, "
f"max_prefill_tokens={self.max_prefill_tokens}, "
f"context_len={self.model_config.context_len}, "
@@ -124,7 +120,7 @@ class ModelRpcServer:
disable=server_args.disable_radix_cache,
)
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.scheduler = Scheduler(
self.scheduler = ScheduleHeuristic(
self.schedule_heuristic,
self.max_running_requests,
self.max_prefill_tokens,
@@ -170,7 +166,7 @@ class ModelRpcServer:
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
def exposed_step(self, recv_reqs):
if self.tp_size != 1:
if self.tp_size * self.dp_size != 1:
recv_reqs = obtain(recv_reqs)
try:
@@ -188,7 +184,7 @@ class ModelRpcServer:
# Forward
self.forward_step()
except Exception:
logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())
logger.error("Exception in ModelTpClient:\n" + get_exception_traceback())
# Return results
ret = self.out_pyobjs
@@ -224,16 +220,17 @@ class ModelRpcServer:
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
throuhgput = self.num_generated_tokens / (
throughput = self.num_generated_tokens / (
time.time() - self.last_stats_tic
)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
logger.info(
f"[gpu_id={self.gpu_id}] "
f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throuhgput:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
f"#queue-req: {len(self.forward_queue)}"
)
@@ -405,7 +402,7 @@ class ModelRpcServer:
f"#new_token: {new_batch_input_tokens}. "
f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
f"#running_req: {running_req}. "
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%. "
)
# logger.debug(
# f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
@@ -724,20 +721,30 @@ class ModelRpcServer:
break
class ModelRpcService(rpyc.Service):
exposed_ModelRpcServer = ModelRpcServer
class ModelTpService(rpyc.Service):
exposed_ModelTpServer = ModelTpServer
class ModelRpcClient:
class ModelTpClient:
def __init__(
self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
self,
gpu_ids: List[int],
server_args: ServerArgs,
model_port_args: ModelPortArgs,
model_overide_args,
):
tp_size = server_args.tp_size
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
self.tp_size = server_args.tp_size
if tp_size == 1:
if self.tp_size * server_args.dp_size == 1:
# Init model
self.model_server = ModelRpcService().exposed_ModelRpcServer(
0, server_args, port_args, model_overide_args
assert len(gpu_ids) == 1
self.model_server = ModelTpService().exposed_ModelTpServer(
0,
gpu_ids[0],
server_args,
model_port_args,
model_overide_args,
)
# Wrap functions
@@ -749,19 +756,26 @@ class ModelRpcClient:
self.step = async_wrap(self.model_server.exposed_step)
else:
with ThreadPoolExecutor(tp_size) as executor:
with ThreadPoolExecutor(self.tp_size) as executor:
# Launch model processes
rets = executor.map(start_model_process, port_args.model_rpc_ports)
self.remote_services = [x[0] for x in rets]
rets = executor.map(
lambda args: start_rpyc_process(*args),
[(ModelTpService, p) for p in model_port_args.model_tp_ports],
)
self.model_services = [x[0] for x in rets]
self.procs = [x[1] for x in rets]
# Init model
def init_model(i):
return self.remote_services[i].ModelRpcServer(
i, server_args, port_args, model_overide_args
return self.model_services[i].ModelTpServer(
gpu_ids[i],
i,
server_args,
model_port_args,
model_overide_args,
)
self.model_servers = executor.map(init_model, range(tp_size))
self.model_servers = executor.map(init_model, range(self.tp_size))
# Wrap functions
def async_wrap(func_name):
@@ -774,45 +788,4 @@ class ModelRpcClient:
return _func
self.step = async_wrap("step")
def _init_service(port):
    """Serve the model RPC service on *port*; blocks until the server stops."""
    protocol_config = {
        "allow_public_attrs": True,
        "allow_pickle": True,
        "sync_request_timeout": 3600,
    }
    server = ThreadedServer(
        ModelRpcService(), port=port, protocol_config=protocol_config
    )
    server.start()
def start_model_process(port):
    """Spawn a process serving the model RPC service and connect to it.

    Returns ``(remote_root, process)``. Raises RuntimeError when the server
    is still refusing connections after ~20 one-second retries.
    """
    proc = multiprocessing.Process(target=_init_service, args=(port,))
    proc.start()
    # Give the child a head start before the first connection attempt.
    time.sleep(1)

    rpc_config = {
        "allow_public_attrs": True,
        "allow_pickle": True,
        "sync_request_timeout": 3600,
    }
    for _ in range(20):
        try:
            con = rpyc.connect("localhost", port, config=rpc_config)
            break
        except ConnectionRefusedError:
            time.sleep(1)
    else:
        raise RuntimeError("init rpc env error!")

    assert proc.is_alive()
    return con.root, proc
self.step = async_wrap("step")

View File

@@ -27,7 +27,6 @@ class GenerateReqInput:
return_text_in_logprobs: bool = False
# Whether to stream output
stream: bool = False
# TODO: make all parameters a Union[List[T], T] to allow for batched requests
def post_init(self):
@@ -135,4 +134,4 @@ class AbortReq:
@dataclass
class DetokenizeReqInput:
input_ids: List[int]
input_ids: List[int]

View File

@@ -48,7 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.model_runner import InputMetadata
@torch.compile

View File

@@ -29,7 +29,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.model_runner import InputMetadata
class DbrxRouter(nn.Module):

View File

@@ -22,7 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.model_runner import InputMetadata
class GemmaMLP(nn.Module):

View File

@@ -37,7 +37,7 @@ from vllm.utils import print_warning_once
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.fused_moe import fused_moe
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.model_runner import InputMetadata
use_fused = True

View File

@@ -4,9 +4,13 @@
from typing import Any, Dict, Optional, Tuple, Iterable
import torch
import tqdm
from torch import nn
from transformers import LlamaConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
@@ -24,7 +28,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.model_runner import InputMetadata
class LlamaMLP(nn.Module):
@@ -284,6 +288,8 @@ class LlamaForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
if get_tensor_model_parallel_rank() == 0:
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
continue

View File

@@ -10,8 +10,8 @@ from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.infer_batch import ForwardMode
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape,
unpad_image,

View File

@@ -10,8 +10,8 @@ from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.infer_batch import ForwardMode
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape,
unpad_image,

View File

@@ -35,7 +35,7 @@ from vllm.utils import print_warning_once
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.model_runner import InputMetadata

View File

@@ -30,7 +30,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.model_runner import InputMetadata
class MixtralMLP(nn.Module):

View File

@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.model_runner import InputMetadata
class QWenMLP(nn.Module):

View File

@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.model_runner import InputMetadata
Qwen2Config = None

View File

@@ -24,7 +24,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.managers.controller.model_runner import InputMetadata
class StablelmMLP(nn.Module):

View File

@@ -10,7 +10,7 @@ import sys
import threading
import time
from http import HTTPStatus
from typing import List, Optional, Union
from typing import Optional
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -28,14 +28,15 @@ from sglang.srt.constrained import disable_cache
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api_adapter import (
load_chat_template_for_openai_api,
v1_chat_completions,
v1_completions,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
from sglang.srt.utils import (
API_KEY_HEADER_NAME,
APIKeyValidatorMiddleware,
@@ -141,14 +142,28 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
# Allocate ports
server_args.port, server_args.additional_ports = allocate_init_ports(
server_args.port, server_args.additional_ports, server_args.tp_size
server_args.port,
server_args.additional_ports,
server_args.tp_size,
server_args.dp_size,
)
# Init local models port args
ports = server_args.additional_ports
tp = server_args.tp_size
model_port_args = []
for i in range(server_args.dp_size):
model_port_args.append(
ModelPortArgs(
nccl_port=ports[3 + i * (tp + 1)],
model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
)
)
port_args = PortArgs(
tokenizer_port=server_args.additional_ports[0],
router_port=server_args.additional_ports[1],
detokenizer_port=server_args.additional_ports[2],
nccl_port=server_args.additional_ports[3],
model_rpc_ports=server_args.additional_ports[4:],
tokenizer_port=ports[0],
router_port=ports[1],
detokenizer_port=ports[2],
model_port_args=model_port_args,
)
# Launch processes
@@ -156,8 +171,12 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
if server_args.dp_size == 1:
start_process = start_controller_process_single
else:
start_process = start_controller_process_multi
proc_router = mp.Process(
target=start_router_process,
target=start_process,
args=(server_args, port_args, pipe_router_writer, model_overide_args),
)
proc_router.start()
@@ -251,19 +270,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
class Runtime:
def __init__(
self,
log_evel: str = "error",
log_level: str = "error",
model_overide_args: Optional[dict] = None,
*args,
**kwargs,
):
"""See the arguments in server_args.py::ServerArgs"""
self.server_args = ServerArgs(*args, log_level=log_evel, **kwargs)
self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
# Pre-allocate ports
self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
self.server_args.port,
self.server_args.additional_ports,
self.server_args.tp_size,
self.server_args.dp_size,
)
self.url = self.server_args.url()

View File

@@ -44,6 +44,10 @@ class ServerArgs:
# Other
api_key: str = ""
# Data parallelism
dp_size: int = 1
load_balance_method: str = "round_robin"
# Optimization/debug options
enable_flashinfer: bool = False
attention_reduce_in_fp32: bool = False
@@ -226,6 +230,24 @@ class ServerArgs:
help="Set API key of the server",
)
# Data parallelism
parser.add_argument(
"--dp-size",
type=int,
default=ServerArgs.dp_size,
help="Data parallelism size.",
)
parser.add_argument(
"--load-balance-method",
type=str,
default=ServerArgs.load_balance_method,
help="Load balancing strategy for data parallelism.",
choices=[
"round_robin",
"shortest_queue",
],
)
# Optimization/debug options
parser.add_argument(
"--enable-flashinfer",
@@ -271,10 +293,15 @@ class ServerArgs:
)
@dataclasses.dataclass
class ModelPortArgs:
    """Ports used by one data-parallel model replica (one tensor-parallel group)."""
    # Port for this replica's distributed (NCCL) process-group init.
    nccl_port: int
    # One port per tensor-parallel worker in this replica — presumably the
    # rpyc model-worker ports; sliced from additional_ports in launch_server.
    model_tp_ports: List[int]
@dataclasses.dataclass
class PortArgs:
tokenizer_port: int
router_port: int
detokenizer_port: int
nccl_port: int
model_rpc_ports: List[int]
model_port_args: List[ModelPortArgs]

View File

@@ -1,6 +1,7 @@
"""Common utilities."""
import base64
import multiprocessing
import logging
import os
import random
@@ -12,12 +13,14 @@ from typing import List, Optional
import numpy as np
import requests
import rpyc
import torch
import triton
from rpyc.utils.server import ThreadedServer
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from starlette.middleware.base import BaseHTTPMiddleware
import torch.distributed as dist
logger = logging.getLogger(__name__)
@@ -120,14 +123,16 @@ def get_available_gpu_memory(gpu_id, distributed=False):
def set_random_seed(seed: int) -> None:
    """Seed every RNG the runtime relies on: stdlib, NumPy, and Torch (incl. CUDA)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Only touch the CUDA generators when a GPU is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def is_port_available(port):
"""Return whether a port is available."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@@ -142,7 +147,9 @@ def allocate_init_ports(
port: Optional[int] = None,
additional_ports: Optional[List[int]] = None,
tp_size: int = 1,
dp_size: int = 1,
):
"""Allocate ports for all connections."""
if additional_ports:
ret_ports = [port] + additional_ports
else:
@@ -151,20 +158,23 @@ def allocate_init_ports(
ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
while len(ret_ports) < 5 + tp_size:
# HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
num_ports_needed = 4 + dp_size * (1 + tp_size)
while len(ret_ports) < num_ports_needed:
if cur_port not in ret_ports and is_port_available(cur_port):
ret_ports.append(cur_port)
cur_port += 1
if port and ret_ports[0] != port:
if port is not None and ret_ports[0] != port:
logger.warn(
f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
)
return ret_ports[0], ret_ports[1:]
return ret_ports[0], ret_ports[1:num_ports_needed]
def get_int_token_logit_bias(tokenizer, vocab_size):
"""Get the logit bias for integer-only tokens."""
# a bug when model's vocab size > tokenizer.vocab_size
vocab_size = tokenizer.vocab_size
logit_bias = np.zeros(vocab_size, dtype=np.float32)
@@ -181,12 +191,8 @@ def wrap_kernel_launcher(kernel):
if int(triton.__version__.split(".")[0]) >= 3:
return None
if dist.is_initialized():
rank = dist.get_rank()
else:
rank = 0
kernels = kernel.cache[rank].values()
gpu_id = torch.cuda.current_device()
kernels = kernel.cache[gpu_id].values()
kernel = next(iter(kernels))
# Different trition versions use different low-level names
@@ -363,6 +369,63 @@ def load_image(image_file):
return image, image_size
def init_rpyc_service(service: rpyc.Service, port: int):
    """Serve `service` over rpyc on `port`; blocks the calling thread forever."""
    protocol_config = {
        "allow_public_attrs": True,
        "allow_pickle": True,
        "sync_request_timeout": 3600,
    }
    server = ThreadedServer(
        service=service, port=port, protocol_config=protocol_config
    )
    # Silence per-connection chatter from the rpyc server logger.
    server.logger.setLevel(logging.WARN)
    server.start()
def connect_to_rpyc_service(port, host="localhost"):
    """Connect to an rpyc service, retrying up to 20 times, and return its root.

    Raises RuntimeError if the service never becomes reachable.
    """
    config = {
        "allow_public_attrs": True,
        "allow_pickle": True,
        "sync_request_timeout": 3600,
    }
    # Give the freshly-forked server a moment to bind before the first attempt.
    time.sleep(1)
    for _attempt in range(20):
        try:
            con = rpyc.connect(host, port, config=config)
            return con.root
        except ConnectionRefusedError:
            time.sleep(1)
    raise RuntimeError("init rpc env error!")
def start_rpyc_process(service: rpyc.Service, port: int):
    """Launch `service` in a child process; return (remote proxy, process)."""
    server_proc = multiprocessing.Process(
        target=init_rpyc_service, args=(service, port)
    )
    server_proc.start()
    remote_root = connect_to_rpyc_service(port)
    # The connect above retries; by now the server must still be up.
    assert server_proc.is_alive()
    return remote_root, server_proc
def suppress_other_loggers():
    """Quiet down noisy third-party (vllm) loggers."""
    from vllm.logger import logger as vllm_default_logger

    vllm_default_logger.setLevel(logging.WARN)
    for noisy_name in ("vllm.utils", "vllm.selector"):
        logging.getLogger(noisy_name).setLevel(logging.WARN)
    logging.getLogger("vllm.config").setLevel(logging.ERROR)
def assert_pkg_version(pkg: str, min_version: str):
try:
installed_version = version(pkg)
@@ -394,4 +457,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
content={"detail": "Invalid API Key"},
)
response = await call_next(request)
return response
return response