Support data parallelism (static) (#480)
Co-authored-by: Ying Sheng <ying.sheng@databricks.com> Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""Common utilities."""
|
||||
|
||||
import base64
|
||||
import multiprocessing
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
@@ -12,12 +13,14 @@ from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import rpyc
|
||||
import torch
|
||||
import triton
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
from fastapi.responses import JSONResponse
|
||||
from packaging import version as pkg_version
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -120,14 +123,16 @@ def get_available_gpu_memory(gpu_id, distributed=False):
|
||||
|
||||
|
||||
def set_random_seed(seed: int) -> None:
|
||||
"""Set the random seed for all libraries."""
|
||||
random.seed(seed)
|
||||
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def is_port_available(port):
|
||||
"""Return whether a port is available."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
try:
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
@@ -142,7 +147,9 @@ def allocate_init_ports(
|
||||
port: Optional[int] = None,
|
||||
additional_ports: Optional[List[int]] = None,
|
||||
tp_size: int = 1,
|
||||
dp_size: int = 1,
|
||||
):
|
||||
"""Allocate ports for all connections."""
|
||||
if additional_ports:
|
||||
ret_ports = [port] + additional_ports
|
||||
else:
|
||||
@@ -151,20 +158,23 @@ def allocate_init_ports(
|
||||
ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
|
||||
cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
|
||||
|
||||
while len(ret_ports) < 5 + tp_size:
|
||||
# HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
|
||||
num_ports_needed = 4 + dp_size * (1 + tp_size)
|
||||
while len(ret_ports) < num_ports_needed:
|
||||
if cur_port not in ret_ports and is_port_available(cur_port):
|
||||
ret_ports.append(cur_port)
|
||||
cur_port += 1
|
||||
|
||||
if port and ret_ports[0] != port:
|
||||
if port is not None and ret_ports[0] != port:
|
||||
logger.warn(
|
||||
f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
|
||||
)
|
||||
|
||||
return ret_ports[0], ret_ports[1:]
|
||||
return ret_ports[0], ret_ports[1:num_ports_needed]
|
||||
|
||||
|
||||
def get_int_token_logit_bias(tokenizer, vocab_size):
|
||||
"""Get the logit bias for integer-only tokens."""
|
||||
# a bug when model's vocab size > tokenizer.vocab_size
|
||||
vocab_size = tokenizer.vocab_size
|
||||
logit_bias = np.zeros(vocab_size, dtype=np.float32)
|
||||
@@ -181,12 +191,8 @@ def wrap_kernel_launcher(kernel):
|
||||
if int(triton.__version__.split(".")[0]) >= 3:
|
||||
return None
|
||||
|
||||
if dist.is_initialized():
|
||||
rank = dist.get_rank()
|
||||
else:
|
||||
rank = 0
|
||||
|
||||
kernels = kernel.cache[rank].values()
|
||||
gpu_id = torch.cuda.current_device()
|
||||
kernels = kernel.cache[gpu_id].values()
|
||||
kernel = next(iter(kernels))
|
||||
|
||||
# Different trition versions use different low-level names
|
||||
@@ -363,6 +369,63 @@ def load_image(image_file):
|
||||
return image, image_size
|
||||
|
||||
|
||||
def init_rpyc_service(service: rpyc.Service, port: int):
|
||||
t = ThreadedServer(
|
||||
service=service,
|
||||
port=port,
|
||||
protocol_config={
|
||||
"allow_public_attrs": True,
|
||||
"allow_pickle": True,
|
||||
"sync_request_timeout": 3600
|
||||
},
|
||||
)
|
||||
t.logger.setLevel(logging.WARN)
|
||||
t.start()
|
||||
|
||||
|
||||
def connect_to_rpyc_service(port, host="localhost"):
|
||||
time.sleep(1)
|
||||
|
||||
repeat_count = 0
|
||||
while repeat_count < 20:
|
||||
try:
|
||||
con = rpyc.connect(
|
||||
host,
|
||||
port,
|
||||
config={
|
||||
"allow_public_attrs": True,
|
||||
"allow_pickle": True,
|
||||
"sync_request_timeout": 3600
|
||||
},
|
||||
)
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
time.sleep(1)
|
||||
repeat_count += 1
|
||||
if repeat_count == 20:
|
||||
raise RuntimeError("init rpc env error!")
|
||||
|
||||
return con.root
|
||||
|
||||
|
||||
def start_rpyc_process(service: rpyc.Service, port: int):
|
||||
# Return the proxy and the process
|
||||
proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port))
|
||||
proc.start()
|
||||
proxy = connect_to_rpyc_service(port)
|
||||
assert proc.is_alive()
|
||||
return proxy, proc
|
||||
|
||||
|
||||
def suppress_other_loggers():
|
||||
from vllm.logger import logger as vllm_default_logger
|
||||
|
||||
vllm_default_logger.setLevel(logging.WARN)
|
||||
logging.getLogger("vllm.utils").setLevel(logging.WARN)
|
||||
logging.getLogger("vllm.selector").setLevel(logging.WARN)
|
||||
logging.getLogger("vllm.config").setLevel(logging.ERROR)
|
||||
|
||||
|
||||
def assert_pkg_version(pkg: str, min_version: str):
|
||||
try:
|
||||
installed_version = version(pkg)
|
||||
@@ -394,4 +457,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
|
||||
content={"detail": "Invalid API Key"},
|
||||
)
|
||||
response = await call_next(request)
|
||||
return response
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user