Remove the dependency of rpyc (#646)

This commit is contained in:
Mingyi
2024-07-18 02:13:54 -07:00
committed by GitHub
parent d93388da3e
commit d774acad5c
11 changed files with 294 additions and 542 deletions

View File

@@ -3,7 +3,6 @@
import base64
import fcntl
import logging
import multiprocessing
import os
import random
import socket
@@ -16,12 +15,10 @@ from typing import List, Optional
import numpy as np
import psutil
import requests
import rpyc
import torch
import triton
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from rpyc.utils.server import ThreadedServer
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
@@ -148,7 +145,6 @@ def is_port_available(port):
def allocate_init_ports(
port: Optional[int] = None,
additional_ports: Optional[List[int]] = None,
tp_size: int = 1,
dp_size: int = 1,
):
"""Allocate ports for all connections."""
@@ -160,8 +156,8 @@ def allocate_init_ports(
ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
# HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
num_ports_needed = 4 + dp_size * (1 + tp_size)
# HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
num_ports_needed = 4 + dp_size
while len(ret_ports) < num_ports_needed:
if cur_port not in ret_ports and is_port_available(cur_port):
ret_ports.append(cur_port)
@@ -371,49 +367,6 @@ def load_image(image_file):
return image, image_size
def connect_rpyc_service(host, port):
repeat_count = 0
while repeat_count < 20:
try:
con = rpyc.connect(
host,
port,
config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 3600,
},
)
break
except ConnectionRefusedError as e:
time.sleep(1)
repeat_count += 1
if repeat_count == 20:
raise RuntimeError(f"Connect rpyc error: {e}")
return con.root
def start_rpyc_service(service: rpyc.Service, port: int):
t = ThreadedServer(
service=service,
port=port,
protocol_config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 3600,
},
)
t.logger.setLevel(logging.WARN)
t.start()
def start_rpyc_service_process(service: rpyc.Service, port: int):
proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
proc.start()
return proc
def suppress_other_loggers():
from vllm.logger import logger as vllm_default_logger