Remove the dependency of rpyc (#646)
This commit is contained in:
@@ -1,15 +1,14 @@
|
||||
"""A tensor parallel worker."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import multiprocessing
|
||||
import pickle
|
||||
import time
|
||||
import warnings
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List, Optional
|
||||
|
||||
import rpyc
|
||||
import torch
|
||||
from rpyc.utils.classic import obtain
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.constrained.fsm_cache import FSMCache
|
||||
@@ -32,13 +31,11 @@ from sglang.srt.managers.io_struct import (
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
from sglang.srt.server_args import ModelPortArgs, ServerArgs
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
connect_rpyc_service,
|
||||
get_int_token_logit_bias,
|
||||
is_multimodal_model,
|
||||
set_random_seed,
|
||||
start_rpyc_service_process,
|
||||
suppress_other_loggers,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
@@ -52,10 +49,9 @@ class ModelTpServer:
|
||||
gpu_id: int,
|
||||
tp_rank: int,
|
||||
server_args: ServerArgs,
|
||||
model_port_args: ModelPortArgs,
|
||||
nccl_port: int,
|
||||
model_overide_args: dict,
|
||||
):
|
||||
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
|
||||
suppress_other_loggers()
|
||||
|
||||
# Copy arguments
|
||||
@@ -79,7 +75,7 @@ class ModelTpServer:
|
||||
gpu_id=gpu_id,
|
||||
tp_rank=tp_rank,
|
||||
tp_size=server_args.tp_size,
|
||||
nccl_port=model_port_args.nccl_port,
|
||||
nccl_port=nccl_port,
|
||||
server_args=server_args,
|
||||
)
|
||||
|
||||
@@ -178,9 +174,6 @@ class ModelTpServer:
|
||||
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
|
||||
|
||||
def exposed_step(self, recv_reqs):
|
||||
if not isinstance(recv_reqs, list):
|
||||
recv_reqs = obtain(recv_reqs)
|
||||
|
||||
try:
|
||||
# Recv requests
|
||||
for recv_req in recv_reqs:
|
||||
@@ -425,12 +418,6 @@ class ModelTpServer:
|
||||
f"#running-req: {running_bs}, "
|
||||
f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
|
||||
)
|
||||
# logger.debug(
|
||||
# f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
|
||||
# f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
|
||||
# f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
|
||||
# f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
|
||||
# )
|
||||
|
||||
# Return the new batch
|
||||
new_batch = Batch.init_new(
|
||||
@@ -733,87 +720,74 @@ class ModelTpServer:
|
||||
break
|
||||
|
||||
|
||||
class ModelTpService(rpyc.Service):
|
||||
exposed_ModelTpServer = ModelTpServer
|
||||
def run_tp_server(
|
||||
gpu_id: int,
|
||||
tp_rank: int,
|
||||
server_args: ServerArgs,
|
||||
nccl_port: int,
|
||||
model_overide_args: dict,
|
||||
):
|
||||
"""Run a tensor parallel server."""
|
||||
try:
|
||||
model_server = ModelTpServer(
|
||||
gpu_id,
|
||||
tp_rank,
|
||||
server_args,
|
||||
nccl_port,
|
||||
model_overide_args,
|
||||
)
|
||||
tp_cpu_group = model_server.model_runner.tp_group.cpu_group
|
||||
|
||||
while True:
|
||||
recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
|
||||
model_server.exposed_step(recv_reqs)
|
||||
except Exception:
|
||||
logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
|
||||
raise
|
||||
|
||||
|
||||
class ModelTpClient:
|
||||
def __init__(
|
||||
self,
|
||||
gpu_ids: List[int],
|
||||
server_args: ServerArgs,
|
||||
model_port_args: ModelPortArgs,
|
||||
model_overide_args,
|
||||
):
|
||||
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
|
||||
self.tp_size = server_args.tp_size
|
||||
def launch_tp_servers(
|
||||
gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
|
||||
):
|
||||
"""Launch multiple tensor parallel servers."""
|
||||
procs = []
|
||||
for i in tp_rank_range:
|
||||
proc = multiprocessing.Process(
|
||||
target=run_tp_server,
|
||||
args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args),
|
||||
)
|
||||
proc.start()
|
||||
procs.append(proc)
|
||||
|
||||
if self.tp_size * server_args.dp_size == 1:
|
||||
# Init model
|
||||
assert len(gpu_ids) == 1
|
||||
self.model_server = ModelTpService().exposed_ModelTpServer(
|
||||
gpu_ids[0],
|
||||
0,
|
||||
server_args,
|
||||
model_port_args,
|
||||
model_overide_args,
|
||||
)
|
||||
return procs
|
||||
|
||||
# Wrap functions
|
||||
def async_wrap(f):
|
||||
async def _func(*args, **kwargs):
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return _func
|
||||
def broadcast_recv_input(data, rank, dist_group):
|
||||
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
|
||||
|
||||
self.step = async_wrap(self.model_server.exposed_step)
|
||||
if rank == 0:
|
||||
if len(data) == 0:
|
||||
tensor_size = torch.tensor([0], dtype=torch.long)
|
||||
dist.broadcast(tensor_size, src=0, group=dist_group)
|
||||
else:
|
||||
with ThreadPoolExecutor(self.tp_size) as executor:
|
||||
# Launch model processes
|
||||
if server_args.nnodes == 1:
|
||||
self.procs = list(
|
||||
executor.map(
|
||||
lambda args: start_rpyc_service_process(*args),
|
||||
[
|
||||
(ModelTpService, p)
|
||||
for p in model_port_args.model_tp_ports
|
||||
],
|
||||
)
|
||||
)
|
||||
addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
|
||||
else:
|
||||
addrs = [
|
||||
(ip, port)
|
||||
for ip, port in zip(
|
||||
model_port_args.model_tp_ips, model_port_args.model_tp_ports
|
||||
)
|
||||
]
|
||||
serialized_data = pickle.dumps(data)
|
||||
size = len(serialized_data)
|
||||
tensor_data = torch.ByteTensor(list(serialized_data))
|
||||
tensor_size = torch.tensor([size], dtype=torch.long)
|
||||
|
||||
self.model_services = list(
|
||||
executor.map(lambda args: connect_rpyc_service(*args), addrs)
|
||||
)
|
||||
dist.broadcast(tensor_size, src=0, group=dist_group)
|
||||
dist.broadcast(tensor_data, src=0, group=dist_group)
|
||||
else:
|
||||
tensor_size = torch.tensor([0], dtype=torch.long)
|
||||
dist.broadcast(tensor_size, src=0, group=dist_group)
|
||||
size = tensor_size.item()
|
||||
|
||||
# Init model
|
||||
def init_model(i):
|
||||
return self.model_services[i].ModelTpServer(
|
||||
gpu_ids[i],
|
||||
i,
|
||||
server_args,
|
||||
model_port_args,
|
||||
model_overide_args,
|
||||
)
|
||||
if size == 0:
|
||||
return []
|
||||
|
||||
self.model_servers = list(executor.map(init_model, range(self.tp_size)))
|
||||
tensor_data = torch.empty(size, dtype=torch.uint8)
|
||||
dist.broadcast(tensor_data, src=0, group=dist_group)
|
||||
|
||||
# Wrap functions
|
||||
def async_wrap(func_name):
|
||||
fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
|
||||
|
||||
async def _func(*args, **kwargs):
|
||||
tasks = [f(*args, **kwargs) for f in fs]
|
||||
await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
|
||||
return obtain(tasks[0].value)
|
||||
|
||||
return _func
|
||||
|
||||
self.step = async_wrap("step")
|
||||
serialized_data = bytes(tensor_data.tolist())
|
||||
data = pickle.loads(serialized_data)
|
||||
return data
|
||||
|
||||
Reference in New Issue
Block a user