Support data parallelism (static) (#480)

Co-authored-by: Ying Sheng <ying.sheng@databricks.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
Ying Sheng
2024-05-27 21:24:10 -07:00
committed by GitHub
parent 565d727409
commit 0463f7fb52
32 changed files with 580 additions and 181 deletions

View File

@@ -1,6 +1,7 @@
"""Common utilities."""
import base64
import multiprocessing
import logging
import os
import random
@@ -12,12 +13,14 @@ from typing import List, Optional
import numpy as np
import requests
import rpyc
import torch
import triton
from rpyc.utils.server import ThreadedServer
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from starlette.middleware.base import BaseHTTPMiddleware
import torch.distributed as dist
logger = logging.getLogger(__name__)
@@ -120,14 +123,16 @@ def get_available_gpu_memory(gpu_id, distributed=False):
def set_random_seed(seed: int) -> None:
    """Seed Python's `random`, NumPy, and PyTorch RNGs for reproducibility.

    Args:
        seed: Seed value applied to every library's generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA keeps a generator per device; seed all of them when a GPU exists.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def is_port_available(port):
"""Return whether a port is available."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@@ -142,7 +147,9 @@ def allocate_init_ports(
port: Optional[int] = None,
additional_ports: Optional[List[int]] = None,
tp_size: int = 1,
dp_size: int = 1,
):
"""Allocate ports for all connections."""
if additional_ports:
ret_ports = [port] + additional_ports
else:
@@ -151,20 +158,23 @@ def allocate_init_ports(
ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
while len(ret_ports) < 5 + tp_size:
# HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
num_ports_needed = 4 + dp_size * (1 + tp_size)
while len(ret_ports) < num_ports_needed:
if cur_port not in ret_ports and is_port_available(cur_port):
ret_ports.append(cur_port)
cur_port += 1
if port and ret_ports[0] != port:
if port is not None and ret_ports[0] != port:
logger.warn(
f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
)
return ret_ports[0], ret_ports[1:]
return ret_ports[0], ret_ports[1:num_ports_needed]
def get_int_token_logit_bias(tokenizer, vocab_size):
"""Get the logit bias for integer-only tokens."""
# a bug when model's vocab size > tokenizer.vocab_size
vocab_size = tokenizer.vocab_size
logit_bias = np.zeros(vocab_size, dtype=np.float32)
@@ -181,12 +191,8 @@ def wrap_kernel_launcher(kernel):
if int(triton.__version__.split(".")[0]) >= 3:
return None
if dist.is_initialized():
rank = dist.get_rank()
else:
rank = 0
kernels = kernel.cache[rank].values()
gpu_id = torch.cuda.current_device()
kernels = kernel.cache[gpu_id].values()
kernel = next(iter(kernels))
# Different trition versions use different low-level names
@@ -363,6 +369,63 @@ def load_image(image_file):
return image, image_size
def init_rpyc_service(service: rpyc.Service, port: int):
    """Serve `service` over rpyc on `port`; blocks the calling thread.

    Public-attribute access and pickling are enabled so peers can exchange
    rich objects; the one-hour sync timeout accommodates long-running RPCs.
    """
    protocol_config = {
        "allow_public_attrs": True,
        "allow_pickle": True,
        "sync_request_timeout": 3600,
    }
    server = ThreadedServer(
        service=service, port=port, protocol_config=protocol_config
    )
    # Silence rpyc's per-connection INFO chatter.
    server.logger.setLevel(logging.WARN)
    server.start()
def connect_to_rpyc_service(port, host="localhost"):
    """Connect to an rpyc server with retries and return its root proxy.

    Retries roughly every second for up to 20 attempts while the server
    process is still coming up.

    Raises:
        RuntimeError: if the server never accepts a connection.
    """
    config = {
        "allow_public_attrs": True,
        "allow_pickle": True,
        "sync_request_timeout": 3600,
    }
    time.sleep(1)  # give the freshly-spawned server a head start
    attempt = 0
    while attempt < 20:
        try:
            con = rpyc.connect(host, port, config=config)
            break
        except ConnectionRefusedError:
            time.sleep(1)
            attempt += 1
    if attempt == 20:
        raise RuntimeError("init rpc env error!")
    return con.root
def start_rpyc_process(service: rpyc.Service, port: int):
    """Launch `service` in a child process and return (proxy, process).

    The rpyc server runs in its own process; a client connection is opened
    from this one and the child's liveness is verified before returning.
    """
    server_proc = multiprocessing.Process(
        target=init_rpyc_service, args=(service, port)
    )
    server_proc.start()
    root_proxy = connect_to_rpyc_service(port)
    # Connecting can succeed just before a crash; double-check liveness.
    assert server_proc.is_alive()
    return root_proxy, server_proc
def suppress_other_loggers():
    """Raise the log levels of noisy third-party (vllm) loggers."""
    from vllm.logger import logger as vllm_default_logger

    vllm_default_logger.setLevel(logging.WARN)
    for noisy_logger in ("vllm.utils", "vllm.selector"):
        logging.getLogger(noisy_logger).setLevel(logging.WARN)
    logging.getLogger("vllm.config").setLevel(logging.ERROR)
def assert_pkg_version(pkg: str, min_version: str):
try:
installed_version = version(pkg)
@@ -394,4 +457,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
content={"detail": "Invalid API Key"},
)
response = await call_next(request)
return response
return response