Support data parallelism (static) (#480)
Co-authored-by: Ying Sheng <ying.sheng@databricks.com> Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -26,7 +26,8 @@ class GlobalConfig:
|
|||||||
self.concate_and_append_mode = "no_adjust"
|
self.concate_and_append_mode = "no_adjust"
|
||||||
|
|
||||||
# Request dependency time due to network delay
|
# Request dependency time due to network delay
|
||||||
self.request_dependency_time = 0.03
|
self.request_dependency_delay = 0.03
|
||||||
|
self.wait_for_new_request_delay = 0.0006
|
||||||
|
|
||||||
# New generation token ratio estimation
|
# New generation token ratio estimation
|
||||||
self.base_new_token_ratio = 0.4
|
self.base_new_token_ratio = 0.4
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from vllm.distributed import (
|
|||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
|
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class LogitsProcessor(nn.Module):
|
class LogitsProcessor(nn.Module):
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from torch import nn
|
|||||||
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
||||||
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
||||||
from sglang.srt.layers.token_attention import token_attention_fwd
|
from sglang.srt.layers.token_attention import token_attention_fwd
|
||||||
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
|
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class RadixAttention(nn.Module):
|
class RadixAttention(nn.Module):
|
||||||
@@ -20,7 +20,7 @@ class RadixAttention(nn.Module):
|
|||||||
|
|
||||||
assert np.allclose(scaling, 1.0 / (head_dim**0.5))
|
assert np.allclose(scaling, 1.0 / (head_dim**0.5))
|
||||||
|
|
||||||
from sglang.srt.managers.router.model_runner import global_server_args_dict
|
from sglang.srt.managers.controller.model_runner import global_server_args_dict
|
||||||
|
|
||||||
if global_server_args_dict.get("enable_flashinfer", False):
|
if global_server_args_dict.get("enable_flashinfer", False):
|
||||||
self.prefill_forward = self.prefill_forward_flashinfer
|
self.prefill_forward = self.prefill_forward_flashinfer
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.managers.router.model_runner import global_server_args_dict
|
from sglang.srt.managers.controller.model_runner import global_server_args_dict
|
||||||
from sglang.srt.utils import wrap_kernel_launcher
|
from sglang.srt.utils import wrap_kernel_launcher
|
||||||
|
|
||||||
if global_server_args_dict.get("attention_reduce_in_fp32", False):
|
if global_server_args_dict.get("attention_reduce_in_fp32", False):
|
||||||
|
|||||||
102
python/sglang/srt/managers/controller/dp_worker.py
Normal file
102
python/sglang/srt/managers/controller/dp_worker.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
"""A data parallel worker thread."""
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
from typing import List, Callable
|
||||||
|
|
||||||
|
import uvloop
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from sglang.global_config import global_config
|
||||||
|
from sglang.srt.managers.controller.tp_worker import ModelTpClient
|
||||||
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
|
logger = logging.getLogger("srt.controller")
|
||||||
|
CHECKING_INTERVAL = 5
|
||||||
|
|
||||||
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
|
|
||||||
|
class DataParallelWorkerThread(threading.Thread):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
worker_id: int,
|
||||||
|
request_queue: queue.Queue,
|
||||||
|
detokenizer_port: int,
|
||||||
|
step_func: Callable,
|
||||||
|
):
|
||||||
|
super(DataParallelWorkerThread, self).__init__()
|
||||||
|
self.worker_id = worker_id
|
||||||
|
self.request_queue = request_queue
|
||||||
|
self.liveness = True
|
||||||
|
self.request_dependency_delay = global_config.request_dependency_delay
|
||||||
|
|
||||||
|
context = zmq.asyncio.Context()
|
||||||
|
self.send_to_detokenizer = context.socket(zmq.PUSH)
|
||||||
|
self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}")
|
||||||
|
|
||||||
|
self.step = step_func
|
||||||
|
|
||||||
|
async def loop_for_forward(self):
|
||||||
|
while self.liveness:
|
||||||
|
requests = []
|
||||||
|
while not self.request_queue.empty():
|
||||||
|
requests.append(self.request_queue.get())
|
||||||
|
try:
|
||||||
|
out_pyobjs = await self.step(requests)
|
||||||
|
except Exception:
|
||||||
|
for r in requests:
|
||||||
|
self.request_queue.put(r)
|
||||||
|
logger.error(
|
||||||
|
f"Worker thread {self.worker_id}: "
|
||||||
|
f"failed to get back from Model Server\n"
|
||||||
|
f"{get_exception_traceback()}"
|
||||||
|
)
|
||||||
|
self.liveness = False
|
||||||
|
|
||||||
|
for obj in out_pyobjs:
|
||||||
|
self.send_to_detokenizer.send_pyobj(obj)
|
||||||
|
|
||||||
|
# async sleep for receiving the subsequent request and avoiding cache miss
|
||||||
|
if len(out_pyobjs) != 0:
|
||||||
|
has_finished = any([obj.finished for obj in out_pyobjs])
|
||||||
|
if has_finished:
|
||||||
|
await asyncio.sleep(self.request_dependency_delay)
|
||||||
|
await asyncio.sleep(global_config.wait_for_new_request_delay)
|
||||||
|
|
||||||
|
async def monitoring(self):
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(CHECKING_INTERVAL)
|
||||||
|
# can plug in monitoring logic here
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
logger.info(f"DataParallelWorkerThread {self.worker_id} start")
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
loop.create_task(self.monitoring())
|
||||||
|
loop.run_until_complete(self.loop_for_forward())
|
||||||
|
|
||||||
|
|
||||||
|
def start_data_parallel_worker(
|
||||||
|
server_args: ServerArgs,
|
||||||
|
port_args: PortArgs,
|
||||||
|
model_overide_args,
|
||||||
|
gpu_ids: List[int],
|
||||||
|
worker_id: int,
|
||||||
|
):
|
||||||
|
model_tp_client = ModelTpClient(
|
||||||
|
gpu_ids,
|
||||||
|
server_args,
|
||||||
|
port_args.model_port_args[worker_id],
|
||||||
|
model_overide_args,
|
||||||
|
)
|
||||||
|
worker_thread = DataParallelWorkerThread(
|
||||||
|
worker_id=worker_id,
|
||||||
|
request_queue=queue.Queue(),
|
||||||
|
detokenizer_port=port_args.detokenizer_port,
|
||||||
|
step_func=model_tp_client.step,
|
||||||
|
)
|
||||||
|
worker_thread.start()
|
||||||
|
return worker_thread
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
"""Meta data for requests and batches"""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import List
|
from typing import List
|
||||||
@@ -5,7 +6,7 @@ from typing import List
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.managers.router.radix_cache import RadixCache
|
from sglang.srt.managers.controller.radix_cache import RadixCache
|
||||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||||
|
|
||||||
|
|
||||||
187
python/sglang/srt/managers/controller/manager_multi.py
Normal file
187
python/sglang/srt/managers/controller/manager_multi.py
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
"""
|
||||||
|
A controller that manages multiple data parallel workers.
|
||||||
|
Each data parallel worker can manage multiple tensor parallel workers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from enum import Enum, auto
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio
|
||||||
|
|
||||||
|
from sglang.global_config import global_config
|
||||||
|
from sglang.srt.managers.io_struct import (
|
||||||
|
AbortReq,
|
||||||
|
FlushCacheReq,
|
||||||
|
TokenizedGenerateReqInput,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.controller.dp_worker import (
|
||||||
|
DataParallelWorkerThread,
|
||||||
|
start_data_parallel_worker,
|
||||||
|
)
|
||||||
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
|
logger = logging.getLogger("srt.controller")
|
||||||
|
|
||||||
|
|
||||||
|
class LoadBalanceMethod(Enum):
|
||||||
|
ROUND_ROBIN = auto()
|
||||||
|
SHORTEST_QUEUE = auto()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_str(cls, method: str):
|
||||||
|
method = method.upper()
|
||||||
|
try:
|
||||||
|
return cls[method]
|
||||||
|
except KeyError as exc:
|
||||||
|
raise ValueError(f"Invalid load balance method: {method}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
class Controller:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
load_balance_method: str,
|
||||||
|
server_args: ServerArgs,
|
||||||
|
port_args: PortArgs,
|
||||||
|
model_overide_args,
|
||||||
|
):
|
||||||
|
self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
|
||||||
|
self.server_args = server_args
|
||||||
|
self.port_args = port_args
|
||||||
|
|
||||||
|
if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
|
||||||
|
self.round_robin_counter = 0
|
||||||
|
|
||||||
|
self.dispatch_lookup = {
|
||||||
|
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
|
||||||
|
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
|
||||||
|
}
|
||||||
|
self.dispatching = self.dispatch_lookup[self.load_balance_method]
|
||||||
|
|
||||||
|
# Init communication
|
||||||
|
context = zmq.asyncio.Context()
|
||||||
|
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
||||||
|
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
|
||||||
|
|
||||||
|
# Init status
|
||||||
|
self.recv_reqs = []
|
||||||
|
|
||||||
|
# Start data parallel workers
|
||||||
|
self.workers: Dict[int, DataParallelWorkerThread] = {}
|
||||||
|
tp_size = server_args.tp_size
|
||||||
|
|
||||||
|
def start_dp_worker(i):
|
||||||
|
try:
|
||||||
|
gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
|
||||||
|
worker_thread = start_data_parallel_worker(
|
||||||
|
server_args, port_args, model_overide_args, gpu_ids, i
|
||||||
|
)
|
||||||
|
self.workers[i] = worker_thread
|
||||||
|
except Exception:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to start local worker {i}\n{get_exception_traceback()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(server_args.dp_size) as executor:
|
||||||
|
executor.map(start_dp_worker, range(server_args.dp_size))
|
||||||
|
|
||||||
|
def have_any_live_worker(self):
|
||||||
|
return any(worker_thread.liveness for worker_thread in self.workers.values())
|
||||||
|
|
||||||
|
def put_req_to_worker(self, worker_id, req):
|
||||||
|
self.workers[worker_id].request_queue.put(req)
|
||||||
|
|
||||||
|
async def round_robin_scheduler(self, input_requests):
|
||||||
|
available_workers = list(self.workers.keys())
|
||||||
|
for r in input_requests:
|
||||||
|
self.put_req_to_worker(available_workers[self.round_robin_counter], r)
|
||||||
|
self.round_robin_counter = (self.round_robin_counter + 1) % len(
|
||||||
|
available_workers
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
async def shortest_queue_scheduler(self, input_requests):
|
||||||
|
for r in input_requests:
|
||||||
|
worker = min(
|
||||||
|
self.workers, key=lambda w: self.workers[w].request_queue.qsize()
|
||||||
|
)
|
||||||
|
self.put_req_to_worker(worker, r)
|
||||||
|
return
|
||||||
|
|
||||||
|
async def remove_dead_workers(self):
|
||||||
|
for i in list(self.workers.keys()):
|
||||||
|
worker_thread = self.workers[i]
|
||||||
|
if not worker_thread.liveness:
|
||||||
|
worker_thread.join()
|
||||||
|
# move unsuccessful requests back to the queue
|
||||||
|
while not worker_thread.request_queue.empty():
|
||||||
|
self.recv_reqs.append(worker_thread.request_queue.get())
|
||||||
|
del self.workers[i]
|
||||||
|
logger.info(f"Stale worker {i} removed")
|
||||||
|
|
||||||
|
async def loop_for_forward(self):
|
||||||
|
while True:
|
||||||
|
await self.remove_dead_workers()
|
||||||
|
|
||||||
|
if self.have_any_live_worker():
|
||||||
|
next_step_input = list(self.recv_reqs)
|
||||||
|
self.recv_reqs = []
|
||||||
|
if next_step_input:
|
||||||
|
await self.dispatching(next_step_input)
|
||||||
|
#else:
|
||||||
|
# logger.error("There is no live worker.")
|
||||||
|
|
||||||
|
await asyncio.sleep(global_config.wait_for_new_request_delay)
|
||||||
|
|
||||||
|
async def loop_for_recv_requests(self):
|
||||||
|
while True:
|
||||||
|
recv_req = await self.recv_from_tokenizer.recv_pyobj()
|
||||||
|
if isinstance(recv_req, FlushCacheReq):
|
||||||
|
# TODO(lsyin): apply more specific flushCacheReq
|
||||||
|
for worker_thread in self.workers.values():
|
||||||
|
worker_thread.request_queue.put(recv_req)
|
||||||
|
elif isinstance(recv_req, TokenizedGenerateReqInput):
|
||||||
|
self.recv_reqs.append(recv_req)
|
||||||
|
elif isinstance(recv_req, AbortReq):
|
||||||
|
in_queue = False
|
||||||
|
for i, req in enumerate(self.recv_reqs):
|
||||||
|
if req.rid == recv_req.rid:
|
||||||
|
self.recv_reqs[i] = recv_req
|
||||||
|
in_queue = True
|
||||||
|
break
|
||||||
|
if not in_queue:
|
||||||
|
# Send abort req to all TP groups
|
||||||
|
for worker in list(self.workers.keys()):
|
||||||
|
self.put_req_to_worker(worker, recv_req)
|
||||||
|
else:
|
||||||
|
logger.error(f"Invalid object: {recv_req}")
|
||||||
|
|
||||||
|
|
||||||
|
def start_controller_process(
|
||||||
|
server_args: ServerArgs,
|
||||||
|
port_args: PortArgs,
|
||||||
|
pipe_writer,
|
||||||
|
model_overide_args=None,
|
||||||
|
):
|
||||||
|
logging.basicConfig(
|
||||||
|
level=getattr(logging, server_args.log_level.upper()),
|
||||||
|
format="%(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
controller = Controller(
|
||||||
|
server_args.load_balance_method, server_args, port_args, model_overide_args
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pipe_writer.send(get_exception_traceback())
|
||||||
|
raise
|
||||||
|
|
||||||
|
pipe_writer.send("init ok")
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
loop.create_task(controller.loop_for_recv_requests())
|
||||||
|
loop.run_until_complete(controller.loop_for_forward())
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
"""A controller that manages a group of tensor parallel workers."""
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -6,15 +7,15 @@ import zmq
|
|||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.managers.router.model_rpc import ModelRpcClient
|
from sglang.srt.managers.controller.tp_worker import ModelTpClient
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
|
|
||||||
class RouterManager:
|
class ControllerSingle:
|
||||||
def __init__(self, model_client: ModelRpcClient, port_args: PortArgs):
|
def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
|
||||||
# Init communication
|
# Init communication
|
||||||
context = zmq.asyncio.Context(2)
|
context = zmq.asyncio.Context(2)
|
||||||
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
||||||
@@ -30,7 +31,7 @@ class RouterManager:
|
|||||||
self.recv_reqs = []
|
self.recv_reqs = []
|
||||||
|
|
||||||
# Init some configs
|
# Init some configs
|
||||||
self.request_dependency_time = global_config.request_dependency_time
|
self.request_dependency_delay = global_config.request_dependency_delay
|
||||||
|
|
||||||
async def loop_for_forward(self):
|
async def loop_for_forward(self):
|
||||||
while True:
|
while True:
|
||||||
@@ -46,12 +47,12 @@ class RouterManager:
|
|||||||
if len(out_pyobjs) != 0:
|
if len(out_pyobjs) != 0:
|
||||||
has_finished = any([obj.finished for obj in out_pyobjs])
|
has_finished = any([obj.finished for obj in out_pyobjs])
|
||||||
if has_finished:
|
if has_finished:
|
||||||
if self.request_dependency_time > 0:
|
if self.request_dependency_delay > 0:
|
||||||
slept = True
|
slept = True
|
||||||
await asyncio.sleep(self.request_dependency_time)
|
await asyncio.sleep(self.request_dependency_delay)
|
||||||
|
|
||||||
if not slept:
|
if not slept:
|
||||||
await asyncio.sleep(0.0006)
|
await asyncio.sleep(global_config.wait_for_new_request_delay)
|
||||||
|
|
||||||
async def loop_for_recv_requests(self):
|
async def loop_for_recv_requests(self):
|
||||||
while True:
|
while True:
|
||||||
@@ -59,7 +60,7 @@ class RouterManager:
|
|||||||
self.recv_reqs.append(recv_req)
|
self.recv_reqs.append(recv_req)
|
||||||
|
|
||||||
|
|
||||||
def start_router_process(
|
def start_controller_process(
|
||||||
server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
|
server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
|
||||||
):
|
):
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -68,8 +69,13 @@ def start_router_process(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_client = ModelRpcClient(server_args, port_args, model_overide_args)
|
model_client = ModelTpClient(
|
||||||
router = RouterManager(model_client, port_args)
|
list(range(server_args.tp_size)),
|
||||||
|
server_args,
|
||||||
|
port_args.model_port_args[0],
|
||||||
|
model_overide_args,
|
||||||
|
)
|
||||||
|
controller = ControllerSingle(model_client, port_args)
|
||||||
except Exception:
|
except Exception:
|
||||||
pipe_writer.send(get_exception_traceback())
|
pipe_writer.send(get_exception_traceback())
|
||||||
raise
|
raise
|
||||||
@@ -78,5 +84,5 @@ def start_router_process(
|
|||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
loop.create_task(router.loop_for_recv_requests())
|
loop.create_task(controller.loop_for_recv_requests())
|
||||||
loop.run_until_complete(router.loop_for_forward())
|
loop.run_until_complete(controller.loop_for_forward())
|
||||||
@@ -15,13 +15,13 @@ from vllm.distributed import initialize_model_parallel
|
|||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
|
||||||
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
|
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
|
||||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
|
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("model_runner")
|
logger = logging.getLogger("srt.model_runner")
|
||||||
|
|
||||||
# for server args in model endpoints
|
# for server args in model endpoints
|
||||||
global_server_args_dict = {}
|
global_server_args_dict = {}
|
||||||
@@ -215,14 +215,16 @@ class ModelRunner:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_config,
|
model_config,
|
||||||
mem_fraction_static,
|
mem_fraction_static: float,
|
||||||
tp_rank,
|
gpu_id: int,
|
||||||
tp_size,
|
tp_rank: int,
|
||||||
nccl_port,
|
tp_size: int,
|
||||||
|
nccl_port: int,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
):
|
):
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.mem_fraction_static = mem_fraction_static
|
self.mem_fraction_static = mem_fraction_static
|
||||||
|
self.gpu_id = gpu_id
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
self.nccl_port = nccl_port
|
self.nccl_port = nccl_port
|
||||||
@@ -235,9 +237,9 @@ class ModelRunner:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Init torch distributed
|
# Init torch distributed
|
||||||
logger.info(f"[rank={self.tp_rank}] Set cuda device.")
|
logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
|
||||||
torch.cuda.set_device(self.tp_rank)
|
torch.cuda.set_device(self.gpu_id)
|
||||||
logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
|
logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
|
||||||
torch.distributed.init_process_group(
|
torch.distributed.init_process_group(
|
||||||
backend="nccl",
|
backend="nccl",
|
||||||
world_size=self.tp_size,
|
world_size=self.tp_size,
|
||||||
@@ -245,22 +247,26 @@ class ModelRunner:
|
|||||||
init_method=f"tcp://127.0.0.1:{self.nccl_port}",
|
init_method=f"tcp://127.0.0.1:{self.nccl_port}",
|
||||||
)
|
)
|
||||||
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
||||||
logger.info(f"[rank={self.tp_rank}] Init torch end.")
|
total_gpu_memory = get_available_gpu_memory(
|
||||||
|
self.gpu_id, distributed=self.tp_size > 1
|
||||||
total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
|
)
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
|
total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
|
||||||
if total_local_gpu_memory < total_gpu_memory * 0.9:
|
if total_local_gpu_memory < total_gpu_memory * 0.9:
|
||||||
raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")
|
raise ValueError(
|
||||||
|
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
|
||||||
|
)
|
||||||
|
|
||||||
self.load_model()
|
self.load_model()
|
||||||
self.init_memory_pool(total_gpu_memory)
|
self.init_memory_pool(total_gpu_memory)
|
||||||
|
|
||||||
self.is_multimodal_model = is_multimodal_model(self.model_config)
|
self.is_multimodal_model = is_multimodal_model(self.model_config)
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
logger.info(f"[rank={self.tp_rank}] Load weight begin.")
|
logger.info(
|
||||||
|
f"[gpu_id={self.gpu_id}] Load weight begin. "
|
||||||
|
f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
||||||
|
)
|
||||||
|
|
||||||
device_config = DeviceConfig()
|
device_config = DeviceConfig()
|
||||||
load_config = LoadConfig(load_format=self.server_args.load_format)
|
load_config = LoadConfig(load_format=self.server_args.load_format)
|
||||||
@@ -286,12 +292,16 @@ class ModelRunner:
|
|||||||
parallel_config=None,
|
parallel_config=None,
|
||||||
scheduler_config=None,
|
scheduler_config=None,
|
||||||
)
|
)
|
||||||
logger.info(f"[rank={self.tp_rank}] Load weight end. "
|
logger.info(
|
||||||
|
f"[gpu_id={self.gpu_id}] Load weight end. "
|
||||||
f"Type={type(self.model).__name__}. "
|
f"Type={type(self.model).__name__}. "
|
||||||
f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
|
f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
||||||
|
)
|
||||||
|
|
||||||
def profile_max_num_token(self, total_gpu_memory):
|
def profile_max_num_token(self, total_gpu_memory):
|
||||||
available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
|
available_gpu_memory = get_available_gpu_memory(
|
||||||
|
self.gpu_id, distributed=self.tp_size > 1
|
||||||
|
)
|
||||||
head_dim = self.model_config.head_dim
|
head_dim = self.model_config.head_dim
|
||||||
head_num = self.model_config.num_key_value_heads // self.tp_size
|
head_num = self.model_config.num_key_value_heads // self.tp_size
|
||||||
cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
|
cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
|
||||||
@@ -306,7 +316,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
if self.max_total_num_tokens <= 0:
|
if self.max_total_num_tokens <= 0:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Not enought memory. " "Please try to increase --mem-fraction-static."
|
"Not enought memory. Please try to increase --mem-fraction-static."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.req_to_token_pool = ReqToTokenPool(
|
self.req_to_token_pool = ReqToTokenPool(
|
||||||
@@ -320,6 +330,10 @@ class ModelRunner:
|
|||||||
head_dim=self.model_config.head_dim,
|
head_dim=self.model_config.head_dim,
|
||||||
layer_num=self.model_config.num_hidden_layers,
|
layer_num=self.model_config.num_hidden_layers,
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[gpu_id={self.gpu_id}] Memory pool end. "
|
||||||
|
f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
||||||
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_prefill(self, batch: Batch):
|
def forward_prefill(self, batch: Batch):
|
||||||
@@ -424,8 +438,8 @@ def import_model_classes():
|
|||||||
if hasattr(module, "EntryClass"):
|
if hasattr(module, "EntryClass"):
|
||||||
entry = module.EntryClass
|
entry = module.EntryClass
|
||||||
if isinstance(entry, list): # To support multiple model classes in one module
|
if isinstance(entry, list): # To support multiple model classes in one module
|
||||||
for cls in entry:
|
for tmp in entry:
|
||||||
model_arch_name_to_cls[cls.__name__] = cls
|
model_arch_name_to_cls[tmp.__name__] = tmp
|
||||||
else:
|
else:
|
||||||
model_arch_name_to_cls[entry.__name__] = entry
|
model_arch_name_to_cls[entry.__name__] = entry
|
||||||
return model_arch_name_to_cls
|
return model_arch_name_to_cls
|
||||||
@@ -2,7 +2,7 @@ import random
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
class Scheduler:
|
class ScheduleHeuristic:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
schedule_heuristic,
|
schedule_heuristic,
|
||||||
@@ -1,20 +1,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
import rpyc
|
import rpyc
|
||||||
import torch
|
import torch
|
||||||
from rpyc.utils.classic import obtain
|
from rpyc.utils.classic import obtain
|
||||||
from rpyc.utils.server import ThreadedServer
|
|
||||||
|
|
||||||
try:
|
|
||||||
from vllm.logger import _default_handler as vllm_default_logger
|
|
||||||
except ImportError:
|
|
||||||
from vllm.logger import logger as vllm_default_logger
|
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.constrained.fsm_cache import FSMCache
|
from sglang.srt.constrained.fsm_cache import FSMCache
|
||||||
@@ -26,38 +19,41 @@ from sglang.srt.managers.io_struct import (
|
|||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req
|
from sglang.srt.managers.controller.infer_batch import Batch, FinishReason, ForwardMode, Req
|
||||||
from sglang.srt.managers.router.model_runner import ModelRunner
|
from sglang.srt.managers.controller.model_runner import ModelRunner
|
||||||
from sglang.srt.managers.router.radix_cache import RadixCache
|
from sglang.srt.managers.controller.radix_cache import RadixCache
|
||||||
from sglang.srt.managers.router.scheduler import Scheduler
|
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
|
||||||
from sglang.srt.model_config import ModelConfig
|
from sglang.srt.model_config import ModelConfig
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import ModelPortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_int_token_logit_bias,
|
get_int_token_logit_bias,
|
||||||
is_multimodal_model,
|
is_multimodal_model,
|
||||||
set_random_seed,
|
set_random_seed,
|
||||||
|
start_rpyc_process,
|
||||||
|
suppress_other_loggers,
|
||||||
)
|
)
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger("model_rpc")
|
logger = logging.getLogger("srt.model_tp")
|
||||||
vllm_default_logger.setLevel(logging.WARN)
|
|
||||||
logging.getLogger("vllm.utils").setLevel(logging.WARN)
|
|
||||||
logging.getLogger("vllm.selector").setLevel(logging.WARN)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelRpcServer:
|
class ModelTpServer:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
gpu_id: int,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
port_args: PortArgs,
|
model_port_args: ModelPortArgs,
|
||||||
model_overide_args: Optional[dict] = None,
|
model_overide_args,
|
||||||
):
|
):
|
||||||
server_args, port_args = [obtain(x) for x in [server_args, port_args]]
|
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
|
||||||
|
suppress_other_loggers()
|
||||||
|
|
||||||
# Copy arguments
|
# Copy arguments
|
||||||
|
self.gpu_id = gpu_id
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
self.tp_size = server_args.tp_size
|
self.tp_size = server_args.tp_size
|
||||||
|
self.dp_size = server_args.dp_size
|
||||||
self.schedule_heuristic = server_args.schedule_heuristic
|
self.schedule_heuristic = server_args.schedule_heuristic
|
||||||
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
||||||
|
|
||||||
@@ -68,16 +64,16 @@ class ModelRpcServer:
|
|||||||
context_length=server_args.context_length,
|
context_length=server_args.context_length,
|
||||||
model_overide_args=model_overide_args,
|
model_overide_args=model_overide_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
# For model end global settings
|
|
||||||
self.model_runner = ModelRunner(
|
self.model_runner = ModelRunner(
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
mem_fraction_static=server_args.mem_fraction_static,
|
mem_fraction_static=server_args.mem_fraction_static,
|
||||||
|
gpu_id=gpu_id,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
tp_size=server_args.tp_size,
|
tp_size=server_args.tp_size,
|
||||||
nccl_port=port_args.nccl_port,
|
nccl_port=model_port_args.nccl_port,
|
||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_multimodal_model(server_args.model_path):
|
if is_multimodal_model(server_args.model_path):
|
||||||
self.processor = get_processor(
|
self.processor = get_processor(
|
||||||
server_args.tokenizer_path,
|
server_args.tokenizer_path,
|
||||||
@@ -95,21 +91,21 @@ class ModelRpcServer:
|
|||||||
self.max_prefill_tokens = max(
|
self.max_prefill_tokens = max(
|
||||||
self.model_config.context_len,
|
self.model_config.context_len,
|
||||||
(
|
(
|
||||||
self.max_total_num_tokens // 6
|
min(self.max_total_num_tokens // 6, 65536)
|
||||||
if server_args.max_prefill_tokens is None
|
if server_args.max_prefill_tokens is None
|
||||||
else server_args.max_prefill_tokens
|
else server_args.max_prefill_tokens
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.max_running_requests = (self.max_total_num_tokens // 2
|
self.max_running_requests = (self.max_total_num_tokens // 2
|
||||||
if server_args.max_running_requests is None else server_args.max_running_requests)
|
if server_args.max_running_requests is None else server_args.max_running_requests)
|
||||||
|
|
||||||
self.int_token_logit_bias = torch.tensor(
|
self.int_token_logit_bias = torch.tensor(
|
||||||
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
|
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
|
||||||
)
|
)
|
||||||
set_random_seed(server_args.random_seed)
|
set_random_seed(server_args.random_seed)
|
||||||
|
|
||||||
# Print info
|
# Print info
|
||||||
logger.info(f"[rank={self.tp_rank}] "
|
logger.info(
|
||||||
|
f"[gpu_id={self.gpu_id}] "
|
||||||
f"max_total_num_tokens={self.max_total_num_tokens}, "
|
f"max_total_num_tokens={self.max_total_num_tokens}, "
|
||||||
f"max_prefill_tokens={self.max_prefill_tokens}, "
|
f"max_prefill_tokens={self.max_prefill_tokens}, "
|
||||||
f"context_len={self.model_config.context_len}, "
|
f"context_len={self.model_config.context_len}, "
|
||||||
@@ -124,7 +120,7 @@ class ModelRpcServer:
|
|||||||
disable=server_args.disable_radix_cache,
|
disable=server_args.disable_radix_cache,
|
||||||
)
|
)
|
||||||
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
||||||
self.scheduler = Scheduler(
|
self.scheduler = ScheduleHeuristic(
|
||||||
self.schedule_heuristic,
|
self.schedule_heuristic,
|
||||||
self.max_running_requests,
|
self.max_running_requests,
|
||||||
self.max_prefill_tokens,
|
self.max_prefill_tokens,
|
||||||
@@ -170,7 +166,7 @@ class ModelRpcServer:
|
|||||||
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
|
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
|
||||||
|
|
||||||
def exposed_step(self, recv_reqs):
|
def exposed_step(self, recv_reqs):
|
||||||
if self.tp_size != 1:
|
if self.tp_size * self.dp_size != 1:
|
||||||
recv_reqs = obtain(recv_reqs)
|
recv_reqs = obtain(recv_reqs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -188,7 +184,7 @@ class ModelRpcServer:
|
|||||||
# Forward
|
# Forward
|
||||||
self.forward_step()
|
self.forward_step()
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())
|
logger.error("Exception in ModelTpClient:\n" + get_exception_traceback())
|
||||||
|
|
||||||
# Return results
|
# Return results
|
||||||
ret = self.out_pyobjs
|
ret = self.out_pyobjs
|
||||||
@@ -224,16 +220,17 @@ class ModelRpcServer:
|
|||||||
self.token_to_kv_pool.available_size()
|
self.token_to_kv_pool.available_size()
|
||||||
+ self.tree_cache.evictable_size()
|
+ self.tree_cache.evictable_size()
|
||||||
)
|
)
|
||||||
throuhgput = self.num_generated_tokens / (
|
throughput = self.num_generated_tokens / (
|
||||||
time.time() - self.last_stats_tic
|
time.time() - self.last_stats_tic
|
||||||
)
|
)
|
||||||
self.num_generated_tokens = 0
|
self.num_generated_tokens = 0
|
||||||
self.last_stats_tic = time.time()
|
self.last_stats_tic = time.time()
|
||||||
logger.info(
|
logger.info(
|
||||||
|
f"[gpu_id={self.gpu_id}] "
|
||||||
f"#running-req: {len(self.running_batch.reqs)}, "
|
f"#running-req: {len(self.running_batch.reqs)}, "
|
||||||
f"#token: {num_used}, "
|
f"#token: {num_used}, "
|
||||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
||||||
f"gen throughput (token/s): {throuhgput:.2f}, "
|
f"gen throughput (token/s): {throughput:.2f}, "
|
||||||
f"#queue-req: {len(self.forward_queue)}"
|
f"#queue-req: {len(self.forward_queue)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -724,20 +721,30 @@ class ModelRpcServer:
|
|||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
class ModelRpcService(rpyc.Service):
|
class ModelTpService(rpyc.Service):
|
||||||
exposed_ModelRpcServer = ModelRpcServer
|
exposed_ModelTpServer = ModelTpServer
|
||||||
|
|
||||||
|
|
||||||
class ModelRpcClient:
|
class ModelTpClient:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
|
self,
|
||||||
|
gpu_ids: List[int],
|
||||||
|
server_args: ServerArgs,
|
||||||
|
model_port_args: ModelPortArgs,
|
||||||
|
model_overide_args,
|
||||||
):
|
):
|
||||||
tp_size = server_args.tp_size
|
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
|
||||||
|
self.tp_size = server_args.tp_size
|
||||||
|
|
||||||
if tp_size == 1:
|
if self.tp_size * server_args.dp_size == 1:
|
||||||
# Init model
|
# Init model
|
||||||
self.model_server = ModelRpcService().exposed_ModelRpcServer(
|
assert len(gpu_ids) == 1
|
||||||
0, server_args, port_args, model_overide_args
|
self.model_server = ModelTpService().exposed_ModelTpServer(
|
||||||
|
0,
|
||||||
|
gpu_ids[0],
|
||||||
|
server_args,
|
||||||
|
model_port_args,
|
||||||
|
model_overide_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wrap functions
|
# Wrap functions
|
||||||
@@ -749,19 +756,26 @@ class ModelRpcClient:
|
|||||||
|
|
||||||
self.step = async_wrap(self.model_server.exposed_step)
|
self.step = async_wrap(self.model_server.exposed_step)
|
||||||
else:
|
else:
|
||||||
with ThreadPoolExecutor(tp_size) as executor:
|
with ThreadPoolExecutor(self.tp_size) as executor:
|
||||||
# Launch model processes
|
# Launch model processes
|
||||||
rets = executor.map(start_model_process, port_args.model_rpc_ports)
|
rets = executor.map(
|
||||||
self.remote_services = [x[0] for x in rets]
|
lambda args: start_rpyc_process(*args),
|
||||||
|
[(ModelTpService, p) for p in model_port_args.model_tp_ports],
|
||||||
|
)
|
||||||
|
self.model_services = [x[0] for x in rets]
|
||||||
self.procs = [x[1] for x in rets]
|
self.procs = [x[1] for x in rets]
|
||||||
|
|
||||||
# Init model
|
# Init model
|
||||||
def init_model(i):
|
def init_model(i):
|
||||||
return self.remote_services[i].ModelRpcServer(
|
return self.model_services[i].ModelTpServer(
|
||||||
i, server_args, port_args, model_overide_args
|
gpu_ids[i],
|
||||||
|
i,
|
||||||
|
server_args,
|
||||||
|
model_port_args,
|
||||||
|
model_overide_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_servers = executor.map(init_model, range(tp_size))
|
self.model_servers = executor.map(init_model, range(self.tp_size))
|
||||||
|
|
||||||
# Wrap functions
|
# Wrap functions
|
||||||
def async_wrap(func_name):
|
def async_wrap(func_name):
|
||||||
@@ -775,44 +789,3 @@ class ModelRpcClient:
|
|||||||
return _func
|
return _func
|
||||||
|
|
||||||
self.step = async_wrap("step")
|
self.step = async_wrap("step")
|
||||||
|
|
||||||
|
|
||||||
def _init_service(port):
|
|
||||||
t = ThreadedServer(
|
|
||||||
ModelRpcService(),
|
|
||||||
port=port,
|
|
||||||
protocol_config={
|
|
||||||
"allow_public_attrs": True,
|
|
||||||
"allow_pickle": True,
|
|
||||||
"sync_request_timeout": 3600,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
t.start()
|
|
||||||
|
|
||||||
|
|
||||||
def start_model_process(port):
|
|
||||||
proc = multiprocessing.Process(target=_init_service, args=(port,))
|
|
||||||
proc.start()
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
repeat_count = 0
|
|
||||||
while repeat_count < 20:
|
|
||||||
try:
|
|
||||||
con = rpyc.connect(
|
|
||||||
"localhost",
|
|
||||||
port,
|
|
||||||
config={
|
|
||||||
"allow_public_attrs": True,
|
|
||||||
"allow_pickle": True,
|
|
||||||
"sync_request_timeout": 3600,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except ConnectionRefusedError:
|
|
||||||
time.sleep(1)
|
|
||||||
repeat_count += 1
|
|
||||||
if repeat_count == 20:
|
|
||||||
raise RuntimeError("init rpc env error!")
|
|
||||||
|
|
||||||
assert proc.is_alive()
|
|
||||||
return con.root, proc
|
|
||||||
@@ -27,7 +27,6 @@ class GenerateReqInput:
|
|||||||
return_text_in_logprobs: bool = False
|
return_text_in_logprobs: bool = False
|
||||||
# Whether to stream output
|
# Whether to stream output
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
# TODO: make all parameters a Union[List[T], T] to allow for batched requests
|
|
||||||
|
|
||||||
def post_init(self):
|
def post_init(self):
|
||||||
|
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@torch.compile
|
@torch.compile
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class DbrxRouter(nn.Module):
|
class DbrxRouter(nn.Module):
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class GemmaMLP(nn.Module):
|
class GemmaMLP(nn.Module):
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ from vllm.utils import print_warning_once
|
|||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.fused_moe import fused_moe
|
from sglang.srt.layers.fused_moe import fused_moe
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
use_fused = True
|
use_fused = True
|
||||||
|
|||||||
@@ -4,9 +4,13 @@
|
|||||||
from typing import Any, Dict, Optional, Tuple, Iterable
|
from typing import Any, Dict, Optional, Tuple, Iterable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import tqdm
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import (
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
@@ -24,7 +28,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class LlamaMLP(nn.Module):
|
class LlamaMLP(nn.Module):
|
||||||
@@ -284,6 +288,8 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
("gate_up_proj", "up_proj", 1),
|
("gate_up_proj", "up_proj", 1),
|
||||||
]
|
]
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
|
if get_tensor_model_parallel_rank() == 0:
|
||||||
|
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
|||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
from sglang.srt.managers.router.infer_batch import ForwardMode
|
from sglang.srt.managers.controller.infer_batch import ForwardMode
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
from sglang.srt.mm_utils import (
|
from sglang.srt.mm_utils import (
|
||||||
get_anyres_image_grid_shape,
|
get_anyres_image_grid_shape,
|
||||||
unpad_image,
|
unpad_image,
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
|||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
from sglang.srt.managers.router.infer_batch import ForwardMode
|
from sglang.srt.managers.controller.infer_batch import ForwardMode
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
from sglang.srt.mm_utils import (
|
from sglang.srt.mm_utils import (
|
||||||
get_anyres_image_grid_shape,
|
get_anyres_image_grid_shape,
|
||||||
unpad_image,
|
unpad_image,
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from vllm.utils import print_warning_once
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class MixtralMLP(nn.Module):
|
class MixtralMLP(nn.Module):
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class QWenMLP(nn.Module):
|
class QWenMLP(nn.Module):
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
|
|
||||||
Qwen2Config = None
|
Qwen2Config = None
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class StablelmMLP(nn.Module):
|
class StablelmMLP(nn.Module):
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import List, Optional, Union
|
from typing import Optional
|
||||||
|
|
||||||
# Fix a bug of Python threading
|
# Fix a bug of Python threading
|
||||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||||
@@ -28,14 +28,15 @@ from sglang.srt.constrained import disable_cache
|
|||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
from sglang.srt.managers.router.manager import start_router_process
|
from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
|
||||||
|
from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
from sglang.srt.openai_api_adapter import (
|
from sglang.srt.openai_api_adapter import (
|
||||||
load_chat_template_for_openai_api,
|
load_chat_template_for_openai_api,
|
||||||
v1_chat_completions,
|
v1_chat_completions,
|
||||||
v1_completions,
|
v1_completions,
|
||||||
)
|
)
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
API_KEY_HEADER_NAME,
|
API_KEY_HEADER_NAME,
|
||||||
APIKeyValidatorMiddleware,
|
APIKeyValidatorMiddleware,
|
||||||
@@ -141,14 +142,28 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
|
|
||||||
# Allocate ports
|
# Allocate ports
|
||||||
server_args.port, server_args.additional_ports = allocate_init_ports(
|
server_args.port, server_args.additional_ports = allocate_init_ports(
|
||||||
server_args.port, server_args.additional_ports, server_args.tp_size
|
server_args.port,
|
||||||
|
server_args.additional_ports,
|
||||||
|
server_args.tp_size,
|
||||||
|
server_args.dp_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Init local models port args
|
||||||
|
ports = server_args.additional_ports
|
||||||
|
tp = server_args.tp_size
|
||||||
|
model_port_args = []
|
||||||
|
for i in range(server_args.dp_size):
|
||||||
|
model_port_args.append(
|
||||||
|
ModelPortArgs(
|
||||||
|
nccl_port=ports[3 + i * (tp + 1)],
|
||||||
|
model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
|
||||||
|
)
|
||||||
)
|
)
|
||||||
port_args = PortArgs(
|
port_args = PortArgs(
|
||||||
tokenizer_port=server_args.additional_ports[0],
|
tokenizer_port=ports[0],
|
||||||
router_port=server_args.additional_ports[1],
|
router_port=ports[1],
|
||||||
detokenizer_port=server_args.additional_ports[2],
|
detokenizer_port=ports[2],
|
||||||
nccl_port=server_args.additional_ports[3],
|
model_port_args=model_port_args,
|
||||||
model_rpc_ports=server_args.additional_ports[4:],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Launch processes
|
# Launch processes
|
||||||
@@ -156,8 +171,12 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
|
pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
|
||||||
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
|
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
|
||||||
|
|
||||||
|
if server_args.dp_size == 1:
|
||||||
|
start_process = start_controller_process_single
|
||||||
|
else:
|
||||||
|
start_process = start_controller_process_multi
|
||||||
proc_router = mp.Process(
|
proc_router = mp.Process(
|
||||||
target=start_router_process,
|
target=start_process,
|
||||||
args=(server_args, port_args, pipe_router_writer, model_overide_args),
|
args=(server_args, port_args, pipe_router_writer, model_overide_args),
|
||||||
)
|
)
|
||||||
proc_router.start()
|
proc_router.start()
|
||||||
@@ -251,19 +270,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
class Runtime:
|
class Runtime:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
log_evel: str = "error",
|
log_level: str = "error",
|
||||||
model_overide_args: Optional[dict] = None,
|
model_overide_args: Optional[dict] = None,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""See the arguments in server_args.py::ServerArgs"""
|
"""See the arguments in server_args.py::ServerArgs"""
|
||||||
self.server_args = ServerArgs(*args, log_level=log_evel, **kwargs)
|
self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
|
||||||
|
|
||||||
# Pre-allocate ports
|
# Pre-allocate ports
|
||||||
self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
|
self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
|
||||||
self.server_args.port,
|
self.server_args.port,
|
||||||
self.server_args.additional_ports,
|
self.server_args.additional_ports,
|
||||||
self.server_args.tp_size,
|
self.server_args.tp_size,
|
||||||
|
self.server_args.dp_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.url = self.server_args.url()
|
self.url = self.server_args.url()
|
||||||
|
|||||||
@@ -44,6 +44,10 @@ class ServerArgs:
|
|||||||
# Other
|
# Other
|
||||||
api_key: str = ""
|
api_key: str = ""
|
||||||
|
|
||||||
|
# Data parallelism
|
||||||
|
dp_size: int = 1
|
||||||
|
load_balance_method: str = "round_robin"
|
||||||
|
|
||||||
# Optimization/debug options
|
# Optimization/debug options
|
||||||
enable_flashinfer: bool = False
|
enable_flashinfer: bool = False
|
||||||
attention_reduce_in_fp32: bool = False
|
attention_reduce_in_fp32: bool = False
|
||||||
@@ -226,6 +230,24 @@ class ServerArgs:
|
|||||||
help="Set API key of the server",
|
help="Set API key of the server",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Data parallelism
|
||||||
|
parser.add_argument(
|
||||||
|
"--dp-size",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.dp_size,
|
||||||
|
help="Data parallelism size.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--load-balance-method",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.load_balance_method,
|
||||||
|
help="Load balancing strategy for data parallelism.",
|
||||||
|
choices=[
|
||||||
|
"round_robin",
|
||||||
|
"shortest_queue",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# Optimization/debug options
|
# Optimization/debug options
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-flashinfer",
|
"--enable-flashinfer",
|
||||||
@@ -271,10 +293,15 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ModelPortArgs:
|
||||||
|
nccl_port: int
|
||||||
|
model_tp_ports: List[int]
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class PortArgs:
|
class PortArgs:
|
||||||
tokenizer_port: int
|
tokenizer_port: int
|
||||||
router_port: int
|
router_port: int
|
||||||
detokenizer_port: int
|
detokenizer_port: int
|
||||||
nccl_port: int
|
model_port_args: List[ModelPortArgs]
|
||||||
model_rpc_ports: List[int]
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Common utilities."""
|
"""Common utilities."""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
import multiprocessing
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@@ -12,12 +13,14 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
import rpyc
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
|
from rpyc.utils.server import ThreadedServer
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from packaging import version as pkg_version
|
from packaging import version as pkg_version
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -120,14 +123,16 @@ def get_available_gpu_memory(gpu_id, distributed=False):
|
|||||||
|
|
||||||
|
|
||||||
def set_random_seed(seed: int) -> None:
|
def set_random_seed(seed: int) -> None:
|
||||||
|
"""Set the random seed for all libraries."""
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
def is_port_available(port):
|
def is_port_available(port):
|
||||||
|
"""Return whether a port is available."""
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
try:
|
try:
|
||||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
@@ -142,7 +147,9 @@ def allocate_init_ports(
|
|||||||
port: Optional[int] = None,
|
port: Optional[int] = None,
|
||||||
additional_ports: Optional[List[int]] = None,
|
additional_ports: Optional[List[int]] = None,
|
||||||
tp_size: int = 1,
|
tp_size: int = 1,
|
||||||
|
dp_size: int = 1,
|
||||||
):
|
):
|
||||||
|
"""Allocate ports for all connections."""
|
||||||
if additional_ports:
|
if additional_ports:
|
||||||
ret_ports = [port] + additional_ports
|
ret_ports = [port] + additional_ports
|
||||||
else:
|
else:
|
||||||
@@ -151,20 +158,23 @@ def allocate_init_ports(
|
|||||||
ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
|
ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
|
||||||
cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
|
cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
|
||||||
|
|
||||||
while len(ret_ports) < 5 + tp_size:
|
# HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
|
||||||
|
num_ports_needed = 4 + dp_size * (1 + tp_size)
|
||||||
|
while len(ret_ports) < num_ports_needed:
|
||||||
if cur_port not in ret_ports and is_port_available(cur_port):
|
if cur_port not in ret_ports and is_port_available(cur_port):
|
||||||
ret_ports.append(cur_port)
|
ret_ports.append(cur_port)
|
||||||
cur_port += 1
|
cur_port += 1
|
||||||
|
|
||||||
if port and ret_ports[0] != port:
|
if port is not None and ret_ports[0] != port:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
|
f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
return ret_ports[0], ret_ports[1:]
|
return ret_ports[0], ret_ports[1:num_ports_needed]
|
||||||
|
|
||||||
|
|
||||||
def get_int_token_logit_bias(tokenizer, vocab_size):
|
def get_int_token_logit_bias(tokenizer, vocab_size):
|
||||||
|
"""Get the logit bias for integer-only tokens."""
|
||||||
# a bug when model's vocab size > tokenizer.vocab_size
|
# a bug when model's vocab size > tokenizer.vocab_size
|
||||||
vocab_size = tokenizer.vocab_size
|
vocab_size = tokenizer.vocab_size
|
||||||
logit_bias = np.zeros(vocab_size, dtype=np.float32)
|
logit_bias = np.zeros(vocab_size, dtype=np.float32)
|
||||||
@@ -181,12 +191,8 @@ def wrap_kernel_launcher(kernel):
|
|||||||
if int(triton.__version__.split(".")[0]) >= 3:
|
if int(triton.__version__.split(".")[0]) >= 3:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if dist.is_initialized():
|
gpu_id = torch.cuda.current_device()
|
||||||
rank = dist.get_rank()
|
kernels = kernel.cache[gpu_id].values()
|
||||||
else:
|
|
||||||
rank = 0
|
|
||||||
|
|
||||||
kernels = kernel.cache[rank].values()
|
|
||||||
kernel = next(iter(kernels))
|
kernel = next(iter(kernels))
|
||||||
|
|
||||||
# Different trition versions use different low-level names
|
# Different trition versions use different low-level names
|
||||||
@@ -363,6 +369,63 @@ def load_image(image_file):
|
|||||||
return image, image_size
|
return image, image_size
|
||||||
|
|
||||||
|
|
||||||
|
def init_rpyc_service(service: rpyc.Service, port: int):
|
||||||
|
t = ThreadedServer(
|
||||||
|
service=service,
|
||||||
|
port=port,
|
||||||
|
protocol_config={
|
||||||
|
"allow_public_attrs": True,
|
||||||
|
"allow_pickle": True,
|
||||||
|
"sync_request_timeout": 3600
|
||||||
|
},
|
||||||
|
)
|
||||||
|
t.logger.setLevel(logging.WARN)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
|
||||||
|
def connect_to_rpyc_service(port, host="localhost"):
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
repeat_count = 0
|
||||||
|
while repeat_count < 20:
|
||||||
|
try:
|
||||||
|
con = rpyc.connect(
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
config={
|
||||||
|
"allow_public_attrs": True,
|
||||||
|
"allow_pickle": True,
|
||||||
|
"sync_request_timeout": 3600
|
||||||
|
},
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except ConnectionRefusedError:
|
||||||
|
time.sleep(1)
|
||||||
|
repeat_count += 1
|
||||||
|
if repeat_count == 20:
|
||||||
|
raise RuntimeError("init rpc env error!")
|
||||||
|
|
||||||
|
return con.root
|
||||||
|
|
||||||
|
|
||||||
|
def start_rpyc_process(service: rpyc.Service, port: int):
|
||||||
|
# Return the proxy and the process
|
||||||
|
proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port))
|
||||||
|
proc.start()
|
||||||
|
proxy = connect_to_rpyc_service(port)
|
||||||
|
assert proc.is_alive()
|
||||||
|
return proxy, proc
|
||||||
|
|
||||||
|
|
||||||
|
def suppress_other_loggers():
|
||||||
|
from vllm.logger import logger as vllm_default_logger
|
||||||
|
|
||||||
|
vllm_default_logger.setLevel(logging.WARN)
|
||||||
|
logging.getLogger("vllm.utils").setLevel(logging.WARN)
|
||||||
|
logging.getLogger("vllm.selector").setLevel(logging.WARN)
|
||||||
|
logging.getLogger("vllm.config").setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
|
||||||
def assert_pkg_version(pkg: str, min_version: str):
|
def assert_pkg_version(pkg: str, min_version: str):
|
||||||
try:
|
try:
|
||||||
installed_version = version(pkg)
|
installed_version = version(pkg)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from dataclasses import dataclass
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from sglang.srt.managers.router.model_runner import ModelRunner
|
from sglang.srt.managers.controller.model_runner import ModelRunner
|
||||||
from sglang.srt.model_config import ModelConfig
|
from sglang.srt.model_config import ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
|
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
|
||||||
from sglang.srt.managers.router.model_runner import ModelRunner
|
from sglang.srt.managers.controller.model_runner import ModelRunner
|
||||||
from sglang.srt.model_config import ModelConfig
|
from sglang.srt.model_config import ModelConfig
|
||||||
from sglang.srt.sampling_params import SamplingParams
|
from sglang.srt.sampling_params import SamplingParams
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from sglang.srt.managers.router.model_runner import ModelRunner
|
from sglang.srt.managers.controller.model_runner import ModelRunner
|
||||||
from sglang.srt.model_config import ModelConfig
|
from sglang.srt.model_config import ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_processor
|
from sglang.srt.hf_transformers_utils import get_processor
|
||||||
from sglang.srt.managers.router.infer_batch import ForwardMode
|
from sglang.srt.managers.controller.infer_batch import ForwardMode
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata, ModelRunner
|
from sglang.srt.managers.controller.model_runner import InputMetadata, ModelRunner
|
||||||
from sglang.srt.model_config import ModelConfig
|
from sglang.srt.model_config import ModelConfig
|
||||||
from sglang.srt.utils import load_image
|
from sglang.srt.utils import load_image
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user