Remove the dependency of rpyc (#646)
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
import base64
|
||||
import fcntl
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
@@ -16,12 +15,10 @@ from typing import List, Optional
|
||||
import numpy as np
|
||||
import psutil
|
||||
import requests
|
||||
import rpyc
|
||||
import torch
|
||||
import triton
|
||||
from fastapi.responses import JSONResponse
|
||||
from packaging import version as pkg_version
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -148,7 +145,6 @@ def is_port_available(port):
|
||||
def allocate_init_ports(
|
||||
port: Optional[int] = None,
|
||||
additional_ports: Optional[List[int]] = None,
|
||||
tp_size: int = 1,
|
||||
dp_size: int = 1,
|
||||
):
|
||||
"""Allocate ports for all connections."""
|
||||
@@ -160,8 +156,8 @@ def allocate_init_ports(
|
||||
ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
|
||||
cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
|
||||
|
||||
# HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
|
||||
num_ports_needed = 4 + dp_size * (1 + tp_size)
|
||||
# HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
|
||||
num_ports_needed = 4 + dp_size
|
||||
while len(ret_ports) < num_ports_needed:
|
||||
if cur_port not in ret_ports and is_port_available(cur_port):
|
||||
ret_ports.append(cur_port)
|
||||
@@ -371,49 +367,6 @@ def load_image(image_file):
|
||||
return image, image_size
|
||||
|
||||
|
||||
def connect_rpyc_service(host, port):
|
||||
repeat_count = 0
|
||||
while repeat_count < 20:
|
||||
try:
|
||||
con = rpyc.connect(
|
||||
host,
|
||||
port,
|
||||
config={
|
||||
"allow_public_attrs": True,
|
||||
"allow_pickle": True,
|
||||
"sync_request_timeout": 3600,
|
||||
},
|
||||
)
|
||||
break
|
||||
except ConnectionRefusedError as e:
|
||||
time.sleep(1)
|
||||
repeat_count += 1
|
||||
if repeat_count == 20:
|
||||
raise RuntimeError(f"Connect rpyc error: {e}")
|
||||
|
||||
return con.root
|
||||
|
||||
|
||||
def start_rpyc_service(service: rpyc.Service, port: int):
|
||||
t = ThreadedServer(
|
||||
service=service,
|
||||
port=port,
|
||||
protocol_config={
|
||||
"allow_public_attrs": True,
|
||||
"allow_pickle": True,
|
||||
"sync_request_timeout": 3600,
|
||||
},
|
||||
)
|
||||
t.logger.setLevel(logging.WARN)
|
||||
t.start()
|
||||
|
||||
|
||||
def start_rpyc_service_process(service: rpyc.Service, port: int):
|
||||
proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
|
||||
proc.start()
|
||||
return proc
|
||||
|
||||
|
||||
def suppress_other_loggers():
|
||||
from vllm.logger import logger as vllm_default_logger
|
||||
|
||||
|
||||
Reference in New Issue
Block a user