Remove the rpyc dependency (#646)

Authored by Mingyi on 2024-07-18 02:13:54 -07:00; committed by GitHub
parent d93388da3e
commit d774acad5c
11 changed files with 294 additions and 542 deletions


@@ -1,15 +1,14 @@
 """A tensor parallel worker."""
-import asyncio
 import logging
 import multiprocessing
+import pickle
 import time
 import warnings
-from concurrent.futures import ThreadPoolExecutor
 from typing import List, Optional
-import rpyc
 import torch
-from rpyc.utils.classic import obtain
+import torch.distributed as dist
 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
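The import swap above summarizes the whole commit: request objects are no longer passed through rpyc proxies, but pickled and carried as raw bytes inside a torch tensor via torch.distributed. Below is a minimal sketch of that serialization round-trip on its own, outside any process group; the Req dataclass is an illustrative stand-in for sglang's tokenized request objects, not part of the library.

# Sketch: how a Python object crosses process boundaries as a uint8 tensor.
# "Req" is an illustrative stand-in for sglang's tokenized request objects.
import pickle
from dataclasses import dataclass

import torch


@dataclass
class Req:
    rid: str
    input_ids: list


reqs = [Req(rid="abc", input_ids=[1, 2, 3])]

# Sender side: pickle, then wrap the bytes in a ByteTensor (as broadcast_recv_input does).
payload = pickle.dumps(reqs)
tensor_data = torch.ByteTensor(list(payload))

# Receiver side: turn the tensor back into bytes and unpickle.
restored = pickle.loads(bytes(tensor_data.tolist()))
assert restored[0].rid == "abc"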
@@ -32,13 +31,11 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.server_args import ModelPortArgs, ServerArgs
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    connect_rpyc_service,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-    start_rpyc_service_process,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -52,10 +49,9 @@ class ModelTpServer:
         gpu_id: int,
         tp_rank: int,
         server_args: ServerArgs,
-        model_port_args: ModelPortArgs,
+        nccl_port: int,
         model_overide_args: dict,
     ):
-        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
         suppress_other_loggers()
         # Copy arguments
@@ -79,7 +75,7 @@ class ModelTpServer:
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
-            nccl_port=model_port_args.nccl_port,
+            nccl_port=nccl_port,
             server_args=server_args,
         )
@@ -178,9 +174,6 @@ class ModelTpServer:
         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

     def exposed_step(self, recv_reqs):
-        if not isinstance(recv_reqs, list):
-            recv_reqs = obtain(recv_reqs)
         try:
             # Recv requests
             for recv_req in recv_reqs:
@@ -425,12 +418,6 @@ class ModelTpServer:
                 f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
             )
-            # logger.debug(
-            #     f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-            #     f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-            #     f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-            #     f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            # )

         # Return the new batch
         new_batch = Batch.init_new(
@@ -733,87 +720,74 @@ class ModelTpServer:
                     break

-class ModelTpService(rpyc.Service):
-    exposed_ModelTpServer = ModelTpServer
-
-
-class ModelTpClient:
-    def __init__(
-        self,
-        gpu_ids: List[int],
-        server_args: ServerArgs,
-        model_port_args: ModelPortArgs,
-        model_overide_args,
-    ):
-        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
-        self.tp_size = server_args.tp_size
-
-        if self.tp_size * server_args.dp_size == 1:
-            # Init model
-            assert len(gpu_ids) == 1
-            self.model_server = ModelTpService().exposed_ModelTpServer(
-                gpu_ids[0],
-                0,
-                server_args,
-                model_port_args,
-                model_overide_args,
-            )
-
-            # Wrap functions
-            def async_wrap(f):
-                async def _func(*args, **kwargs):
-                    return f(*args, **kwargs)
-
-                return _func
-
-            self.step = async_wrap(self.model_server.exposed_step)
-        else:
-            with ThreadPoolExecutor(self.tp_size) as executor:
-                # Launch model processes
-                if server_args.nnodes == 1:
-                    self.procs = list(
-                        executor.map(
-                            lambda args: start_rpyc_service_process(*args),
-                            [
-                                (ModelTpService, p)
-                                for p in model_port_args.model_tp_ports
-                            ],
-                        )
-                    )
-                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
-                else:
-                    addrs = [
-                        (ip, port)
-                        for ip, port in zip(
-                            model_port_args.model_tp_ips, model_port_args.model_tp_ports
-                        )
-                    ]
-
-                self.model_services = list(
-                    executor.map(lambda args: connect_rpyc_service(*args), addrs)
-                )
-
-                # Init model
-                def init_model(i):
-                    return self.model_services[i].ModelTpServer(
-                        gpu_ids[i],
-                        i,
-                        server_args,
-                        model_port_args,
-                        model_overide_args,
-                    )
-
-                self.model_servers = list(executor.map(init_model, range(self.tp_size)))
-
-            # Wrap functions
-            def async_wrap(func_name):
-                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
-
-                async def _func(*args, **kwargs):
-                    tasks = [f(*args, **kwargs) for f in fs]
-                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
-                    return obtain(tasks[0].value)
-
-                return _func
-
-            self.step = async_wrap("step")
+def run_tp_server(
+    gpu_id: int,
+    tp_rank: int,
+    server_args: ServerArgs,
+    nccl_port: int,
+    model_overide_args: dict,
+):
+    """Run a tensor parallel server."""
+    try:
+        model_server = ModelTpServer(
+            gpu_id,
+            tp_rank,
+            server_args,
+            nccl_port,
+            model_overide_args,
+        )
+        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+        while True:
+            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+            model_server.exposed_step(recv_reqs)
+    except Exception:
+        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+        raise
+
+
+def launch_tp_servers(
+    gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
+):
+    """Launch multiple tensor parallel servers."""
+    procs = []
+    for i in tp_rank_range:
+        proc = multiprocessing.Process(
+            target=run_tp_server,
+            args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
+        )
+        proc.start()
+        procs.append(proc)
+
+    return procs
+
+
+def broadcast_recv_input(data, rank, dist_group):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+
+    return data
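The added functions replace the rpyc client/service pair with plain multiprocessing plus a pickle-over-torch.distributed broadcast: rank 0 calls broadcast_recv_input with the request list, every other tensor-parallel rank calls it with None and gets a local copy back. Below is a minimal, self-contained sketch of that pattern with the gloo backend on CPU; the two-process setup, the TCP port 29631, and the string "requests" are placeholder assumptions for illustration, and the real code reuses the model runner's tp CPU group rather than creating a fresh process group.

# Sketch: the size-then-bytes broadcast protocol used by broadcast_recv_input,
# demonstrated with 2 local processes and the gloo backend (assumed free port 29631).
import multiprocessing
import pickle

import torch
import torch.distributed as dist


def _broadcast_pyobj(data, rank, group):
    # Rank 0 sends; all other ranks receive a deserialized local copy.
    if rank == 0:
        payload = pickle.dumps(data)
        tensor_size = torch.tensor([len(payload)], dtype=torch.long)
        tensor_data = torch.ByteTensor(list(payload))
        dist.broadcast(tensor_size, src=0, group=group)
        dist.broadcast(tensor_data, src=0, group=group)
        return data
    tensor_size = torch.tensor([0], dtype=torch.long)
    dist.broadcast(tensor_size, src=0, group=group)
    tensor_data = torch.empty(tensor_size.item(), dtype=torch.uint8)
    dist.broadcast(tensor_data, src=0, group=group)
    return pickle.loads(bytes(tensor_data.tolist()))


def worker(rank, world_size):
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29631",
        rank=rank,
        world_size=world_size,
    )
    reqs = ["req-0", "req-1"] if rank == 0 else None
    reqs = _broadcast_pyobj(reqs, rank, dist.group.WORLD)
    print(f"rank {rank} received {reqs}")
    dist.destroy_process_group()


if __name__ == "__main__":
    procs = [multiprocessing.Process(target=worker, args=(r, 2)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

In the commit itself this removes the rpyc serialization and socket hop for every scheduling step, at the cost of requiring all tensor-parallel workers to sit in the same torch.distributed CPU group.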