Improve tensor parallel performance (#625)
Co-authored-by: Mingyi <wisclmy0611@gmail.com>
@@ -2,11 +2,10 @@
 - `backend`: Various backends for the language interpreter.
 - `lang`: The frontend language.
-- `srt`: The runtime for running local models.
+- `srt`: The serving engine for running local models. (SRT = SGLang Runtime).
 - `test`: Test utilities.
 - `api.py`: Public API.
 - `bench_latency.py`: Benchmark utilities.
 - `global_config.py`: The global configs and constants.
 - `launch_server.py`: The entry point of launching local server.
 - `utils.py`: Common utilities.
@@ -42,6 +42,8 @@ class LoadBalanceMethod(Enum):


 class Controller:
+    """A controller that manages multiple data parallel workers."""
+
     def __init__(
         self,
         load_balance_method: str,
@@ -183,9 +185,11 @@ def start_controller_process(
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise

     pipe_writer.send("init ok")
-    loop = asyncio.get_event_loop()
+    loop = asyncio.new_event_loop()
+    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
+    asyncio.set_event_loop(loop)
     loop.create_task(controller.loop_for_recv_requests())
     loop.run_until_complete(controller.loop_for_forward())
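The second hunk replaces the implicit asyncio.get_event_loop() with an explicitly created loop whose default executor is a 256-thread pool, so blocking worker calls dispatched through run_in_executor cannot exhaust the loop's default pool. A minimal sketch of the pattern, not taken from the repo (the worker function and pool size are invented):

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor


def blocking_model_step(batch_id: int) -> str:
    # Stand-in for a blocking call into a model worker.
    time.sleep(0.01)
    return f"batch-{batch_id} done"


async def main():
    loop = asyncio.get_running_loop()
    # run_in_executor(None, ...) uses the loop's *default* executor, so sizing it
    # (the commit uses max_workers=256) bounds how many blocking calls run at once.
    results = await asyncio.gather(
        *(loop.run_in_executor(None, blocking_model_step, i) for i in range(4))
    )
    print(results)


loop = asyncio.new_event_loop()                        # own the loop in this process
loop.set_default_executor(ThreadPoolExecutor(max_workers=8))
asyncio.set_event_loop(loop)
loop.run_until_complete(main())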
@@ -1,28 +1,104 @@
 """A controller that manages a group of tensor parallel workers."""

-import asyncio
+import multiprocessing
 import logging
-from concurrent.futures import ThreadPoolExecutor
+import os
+import pickle

-import uvloop
+import torch
+import torch.distributed as dist
 import zmq
-import zmq.asyncio

 from sglang.global_config import global_config
-from sglang.srt.managers.controller.tp_worker import ModelTpClient
-from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.managers.controller.tp_worker import ModelTpServer
+from sglang.srt.server_args import PortArgs, ServerArgs, ModelPortArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

-asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-
 logger = logging.getLogger("srt.controller")


+def run_tp_server(
+    gpu_id: int,
+    tp_rank: int,
+    server_args: ServerArgs,
+    model_port_args: ModelPortArgs,
+    model_overide_args: dict,
+):
+    """Run a tp server."""
+    try:
+        model_server = ModelTpServer(
+            gpu_id,
+            tp_rank,
+            server_args,
+            model_port_args,
+            model_overide_args,
+        )
+        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+        while True:
+            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+            model_server.exposed_step(recv_reqs)
+    except Exception:
+        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+        raise
+
+
+def launch_tp_servers(gpu_ids, tp_rank_range, server_args,
+                      model_port_args, model_overide_args):
+    """Launch multiple tp servers."""
+    procs = []
+    for i in tp_rank_range:
+        proc = multiprocessing.Process(target=run_tp_server, args=(
+            gpu_ids[i], i, server_args, model_port_args, model_overide_args
+        ))
+        proc.start()
+        procs.append(proc)
+
+    return procs
+
+
+def broadcast_recv_input(data, rank, dist_group):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+        return data
+
+
 class ControllerSingle:
-    def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
+    """A controller that manages a group of tensor parallel workers."""
+
+    def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
+        # Parse args
+        self.server_args = server_args
+
         # Init communication
-        context = zmq.asyncio.Context(2)
+        context = zmq.Context(2)
         self.recv_from_tokenizer = context.socket(zmq.PULL)
         self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
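The hunk above replaces the rpyc-based ModelTpClient with three module-level helpers: run_tp_server pins one tensor-parallel rank to a GPU and blocks on broadcast_recv_input, launch_tp_servers forks those ranks with multiprocessing, and broadcast_recv_input ships each step's requests from rank 0 to every rank as a length-prefixed pickle over a CPU torch.distributed group. A self-contained sketch of that size-then-payload protocol follows; the gloo rendezvous, port, and helper names are illustrative and not the repo's code:

import pickle

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def broadcast_pyobj(obj, rank, group):
    # Rank 0 pickles the object, broadcasts its byte length, then the bytes;
    # the other ranks size an empty uint8 buffer from the first broadcast.
    if rank == 0:
        payload = pickle.dumps(obj)
        size = torch.tensor([len(payload)], dtype=torch.long)
        data = torch.frombuffer(bytearray(payload), dtype=torch.uint8)
        dist.broadcast(size, src=0, group=group)
        dist.broadcast(data, src=0, group=group)
        return obj
    size = torch.tensor([0], dtype=torch.long)
    dist.broadcast(size, src=0, group=group)
    data = torch.empty(size.item(), dtype=torch.uint8)
    dist.broadcast(data, src=0, group=group)
    return pickle.loads(data.numpy().tobytes())


def worker(rank, world_size):
    dist.init_process_group(
        "gloo", rank=rank, world_size=world_size,
        init_method="tcp://127.0.0.1:29501",   # illustrative local rendezvous
    )
    reqs = ["req-0", "req-1"] if rank == 0 else None
    reqs = broadcast_pyobj(reqs, rank, dist.group.WORLD)
    print(f"rank {rank} got {reqs}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)   # two CPU-only ranks, like rank 0 plus one worker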
@@ -31,44 +107,52 @@ class ControllerSingle:
             f"tcp://127.0.0.1:{port_args.detokenizer_port}"
         )

-        # Init status
-        self.model_client = model_client
-        self.recv_reqs = []
+        # Init model server
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]

-        # Init some configs
-        self.request_dependency_delay = global_config.request_dependency_delay
+        # Launch other tp ranks
+        if tp_size_local > 1:
+            tp_rank_range = range(1, tp_size_local)
+            self.tp_procs = launch_tp_servers(
+                gpu_ids, tp_rank_range, server_args,
+                port_args.model_port_args[0], model_overide_args)

-    async def loop_for_forward(self):
+        # Launch tp rank 0
+        self.tp_server = ModelTpServer(
+            gpu_ids[0],
+            0,
+            server_args,
+            port_args.model_port_args[0],
+            model_overide_args,
+        )
+        self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
+
+    def loop_for_forward(self):
         while True:
-            next_step_input = list(self.recv_reqs)
-            self.recv_reqs = []
-            out_pyobjs = await self.model_client.step(next_step_input)
+            recv_reqs = self.recv_requests()
+
+            if self.server_args.tp_size > 1:
+                broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
+
+            out_pyobjs = self.tp_server.exposed_step(recv_reqs)

             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)

-            # async sleep for receiving the subsequent request and avoiding cache miss
-            slept = False
-            if len(out_pyobjs) != 0:
-                has_finished = any(
-                    [obj.finished_reason is not None for obj in out_pyobjs]
-                )
-                if has_finished:
-                    if self.request_dependency_delay > 0:
-                        slept = True
-                        await asyncio.sleep(self.request_dependency_delay)
-
-            if not slept:
-                await asyncio.sleep(global_config.wait_for_new_request_delay)
-
-    async def loop_for_recv_requests(self):
+    def recv_requests(self):
+        recv_reqs = []
         while True:
-            recv_req = await self.recv_from_tokenizer.recv_pyobj()
-            self.recv_reqs.append(recv_req)
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                recv_reqs.append(recv_req)
+            except zmq.ZMQError:
+                break
+        return recv_reqs


 def start_controller_process(
-    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args: dict
 ):
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
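Two changes carry this hunk: loop_for_forward becomes a plain synchronous loop that drives the in-process ModelTpServer (broadcasting the batch to the other ranks first when tp_size > 1), and recv_requests drains the tokenizer's PULL socket without blocking, so an empty queue never stalls a forward step. A minimal sketch of the non-blocking drain, with an invented port and payloads:

import time

import zmq


def drain_socket(sock: zmq.Socket) -> list:
    # Collect everything currently queued; return immediately once the queue is empty.
    msgs = []
    while True:
        try:
            msgs.append(sock.recv_pyobj(zmq.NOBLOCK))
        except zmq.ZMQError:
            break  # EAGAIN: nothing left, go run the next model step
    return msgs


if __name__ == "__main__":
    ctx = zmq.Context(2)
    pull = ctx.socket(zmq.PULL)
    pull.bind("tcp://127.0.0.1:28561")

    push = ctx.socket(zmq.PUSH)
    push.connect("tcp://127.0.0.1:28561")
    for i in range(3):
        push.send_pyobj({"req": i})

    time.sleep(0.1)                # give the messages time to arrive
    print(drain_socket(pull))      # three dicts
    print(drain_socket(pull))      # [] -- returns without waiting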
@@ -76,27 +160,18 @@ def start_controller_process(
     )

     try:
-        tp_size_local = server_args.tp_size // server_args.nnodes
-        model_client = ModelTpClient(
-            [i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
-            server_args,
-            port_args.model_port_args[0],
-            model_overide_args,
-        )
-        controller = ControllerSingle(model_client, port_args)
+        controller = ControllerSingle(server_args, port_args, model_overide_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise

     pipe_writer.send("init ok")

-    loop = asyncio.new_event_loop()
-    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
-    asyncio.set_event_loop(loop)
-    loop.create_task(controller.loop_for_recv_requests())
     try:
-        loop.run_until_complete(controller.loop_for_forward())
+        controller.loop_for_forward()
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
+        for t in controller.tp_procs:
+            os.kill(t.pid, 9)
         kill_parent_process()
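Because the other tensor-parallel ranks now live in child processes that block inside a collective, start_controller_process must kill them explicitly when the forward loop exits; otherwise a crashed rank 0 leaves the workers wedged on a broadcast that never completes. A toy sketch of that ownership pattern (the worker body and the simulated failure are invented, POSIX only):

import multiprocessing
import os
import signal
import time


def worker_loop(rank: int):
    # Stand-in for run_tp_server: blocks forever waiting on the next broadcast.
    while True:
        time.sleep(1)


if __name__ == "__main__":
    procs = [multiprocessing.Process(target=worker_loop, args=(r,)) for r in (1, 2)]
    for p in procs:
        p.start()
    try:
        raise RuntimeError("controller loop crashed")   # simulate a failing forward loop
    except RuntimeError as exc:
        print("controller exiting:", exc)
    finally:
        for p in procs:
            os.kill(p.pid, signal.SIGKILL)   # mirrors os.kill(t.pid, 9) in the diff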
@@ -11,7 +11,7 @@ import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import init_distributed_environment, initialize_model_parallel
+from vllm.distributed import init_distributed_environment, initialize_model_parallel, get_tp_group
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

@@ -75,6 +75,7 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        self.tp_group = get_tp_group()
         total_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
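Storing the result of get_tp_group() is what later lets the controller reach tp_group.cpu_group: NCCL carries the model's GPU tensors while a companion gloo group carries the small pickled scheduling messages on CPU. The sketch below shows the same two-group idea with plain torch.distributed rather than vLLM's internals; it assumes torchrun-style environment variables and one GPU per rank:

import os

import torch
import torch.distributed as dist


def init_groups():
    rank = int(os.environ["RANK"])
    torch.cuda.set_device(rank % torch.cuda.device_count())
    dist.init_process_group(backend="nccl")       # GPU collectives for model tensors
    cpu_group = dist.new_group(backend="gloo")    # CPU collectives for control messages
    return rank, cpu_group


if __name__ == "__main__":
    # Run with: torchrun --nproc_per_node=2 this_script.py
    rank, cpu_group = init_groups()
    x = torch.ones(1, device="cuda")
    dist.all_reduce(x)                            # NCCL path (model traffic)
    flag = torch.tensor([rank], dtype=torch.long)
    dist.broadcast(flag, src=0, group=cpu_group)  # gloo path (scheduler traffic)
    print(f"rank {rank}: gpu sum={x.item()}, cpu flag={flag.item()}")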
@@ -53,7 +53,7 @@ class ModelTpServer:
         tp_rank: int,
         server_args: ServerArgs,
         model_port_args: ModelPortArgs,
-        model_overide_args,
+        model_overide_args: dict,
     ):
         server_args, model_port_args = obtain(server_args), obtain(model_port_args)
         suppress_other_loggers()
@@ -178,7 +178,7 @@ class ModelTpServer:
         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

     def exposed_step(self, recv_reqs):
-        if self.tp_size * self.dp_size != 1:
+        if not isinstance(recv_reqs, list):
             recv_reqs = obtain(recv_reqs)

         try:
@@ -206,11 +206,11 @@ class ModelTpServer:

     @torch.inference_mode()
     def forward_step(self):
-        new_batch = self.get_new_fill_batch()
+        new_batch = self.get_new_prefill_batch()

         if new_batch is not None:
-            # Run a new fill batch
-            self.forward_fill_batch(new_batch)
+            # Run a new prefill batch
+            self.forward_prefill_batch(new_batch)
             self.cache_filled_batch(new_batch)

             if not new_batch.is_empty():
@@ -219,7 +219,7 @@ class ModelTpServer:
             else:
                 self.running_batch.merge(new_batch)
         else:
-            # Run decode batch
+            # Run a decode batch
             if self.running_batch is not None:
                 # Run a few decode batches continuously for reducing overhead
                 for _ in range(global_config.num_continue_decode_steps):
@@ -312,7 +312,7 @@ class ModelTpServer:
             )
             self.forward_queue.append(req)

-    def get_new_fill_batch(self) -> Optional[Batch]:
+    def get_new_prefill_batch(self) -> Optional[Batch]:
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
@@ -436,7 +436,7 @@ class ModelTpServer:
         self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
         return new_batch

-    def forward_fill_batch(self, batch: Batch):
+    def forward_prefill_batch(self, batch: Batch):
         # Build batch tensors
         batch.prepare_for_extend(
             self.model_config.vocab_size, self.int_token_logit_bias
@@ -746,8 +746,8 @@ class ModelTpClient:
             # Init model
             assert len(gpu_ids) == 1
             self.model_server = ModelTpService().exposed_ModelTpServer(
-                0,
                 gpu_ids[0],
+                0,
                 server_args,
                 model_port_args,
                 model_overide_args,
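The exposed_step change is about where requests come from: on the new broadcast path they arrive as an ordinary local list, while the surviving rpyc path still hands over netrefs that must be copied by value before use. A tiny sketch of that guard using rpyc's classic obtain (the surrounding names are invented):

from rpyc.utils.classic import obtain


def normalize_requests(recv_reqs):
    # Broadcast path: already a local list, pass through with no serialization cost.
    # rpyc path: a remote netref, copy it by value into this process first.
    if not isinstance(recv_reqs, list):
        recv_reqs = obtain(recv_reqs)
    return recv_reqs


print(normalize_requests([{"rid": 0}, {"rid": 1}]))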
@@ -33,9 +33,9 @@ from sglang.srt.managers.controller.manager_multi import (
     start_controller_process as start_controller_process_multi,
 )
 from sglang.srt.managers.controller.manager_single import (
+    launch_tp_servers,
     start_controller_process as start_controller_process_single,
 )
-from sglang.srt.managers.controller.tp_worker import ModelTpService
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -53,7 +53,6 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     receive_addrs,
     send_addrs_to_rank_0,
-    start_rpyc_service_process,
 )
 from sglang.utils import get_exception_traceback

@@ -192,21 +191,17 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         model_port_args=model_port_args,
     )

-    # TODO multi-node dp is not supported
-    assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
+    # Handle multi-node tp
     if server_args.nnodes > 1:
+        assert server_args.dp_size == 1, "Multi-node dp is not supported."
+
         if server_args.node_rank != 0:
-            send_addrs_to_rank_0(model_port_args[0], server_args)
-        else:
-            receive_addrs(model_port_args[0], server_args)
-        for i in range(tp_size_local):
-            start_rpyc_service_process(
-                ModelTpService, model_port_args[0].model_tp_ports[i]
-            )
-        if server_args.node_rank != 0:
-            logger.info(
-                f"[node_rank={server_args.node_rank}]: Listen for connections..."
-            )
+            tp_size_local = server_args.tp_size // server_args.nnodes
+            gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+            tp_rank_range = list(range(server_args.node_rank * tp_size_local,
+                                       (server_args.node_rank + 1) * tp_size_local))
+            procs = launch_tp_servers(gpu_ids, tp_rank_range, server_args,
+                                      port_args.model_port_args[0], model_overide_args)
            while True:
                pass
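The multi-node branch above replaces the per-port rpyc services with the same launch_tp_servers helper used locally: each non-zero node computes the contiguous slice of global tensor-parallel ranks it owns, maps them onto its local GPU ids, launches them, and then simply keeps the launcher process alive. A small sketch of that rank and GPU arithmetic (the function name is invented):

def plan_ranks(tp_size: int, nnodes: int):
    # Each node hosts tp_size // nnodes contiguous global ranks and reuses
    # local GPU ids 0 .. tp_size_local-1 on its own machine.
    tp_size_local = tp_size // nnodes
    gpu_ids = [i for _ in range(nnodes) for i in range(tp_size_local)]
    plan = {}
    for node_rank in range(nnodes):
        tp_rank_range = list(range(node_rank * tp_size_local,
                                   (node_rank + 1) * tp_size_local))
        plan[node_rank] = [(r, gpu_ids[r]) for r in tp_rank_range]
    return plan


# tp_size=8 over 2 nodes: node 0 hosts tp ranks 0-3 on GPUs 0-3, node 1 hosts ranks 4-7 on GPUs 0-3.
print(plan_ranks(8, 2))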
@@ -67,10 +67,12 @@ class ServerArgs:
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
-            if self.tp_size >= 8:
+            if self.tp_size >= 16:
+                self.mem_fraction_static = 0.74
+            elif self.tp_size >= 8:
                 self.mem_fraction_static = 0.78
             elif self.tp_size >= 4:
-                self.mem_fraction_static = 0.80
+                self.mem_fraction_static = 0.82
             elif self.tp_size >= 2:
                 self.mem_fraction_static = 0.85
             else:
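The new defaults add a tier for very large tensor-parallel groups and nudge the 4-GPU tier up: the static fraction of GPU memory reserved for weights and KV cache shrinks as tp_size grows, presumably to leave per-GPU headroom for communication buffers. A sketch of the resulting table (the single-GPU fallback value is an assumption, since the final else branch is outside the hunk):

def default_mem_fraction_static(tp_size: int) -> float:
    # Mirrors the new tiers in the diff; the value for tp_size == 1 is assumed.
    if tp_size >= 16:
        return 0.74
    elif tp_size >= 8:
        return 0.78
    elif tp_size >= 4:
        return 0.82
    elif tp_size >= 2:
        return 0.85
    return 0.88  # assumed single-GPU default, not shown in the hunk


for tp in (1, 2, 4, 8, 16):
    print(tp, default_mem_fraction_static(tp))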