Files
sglang/python/sglang/srt/utils.py

526 lines
16 KiB
Python
Raw Normal View History

2024-05-12 04:54:07 -07:00
"""Common utilities."""
import base64
import fcntl
2024-05-16 18:07:30 -07:00
import logging
import os
import random
2024-07-18 23:28:40 -07:00
import resource
import socket
import struct
import time
from importlib.metadata import PackageNotFoundError, version
from io import BytesIO
2024-01-30 16:36:10 +00:00
from typing import List, Optional
import numpy as np
import psutil
import requests
import torch
2024-07-18 23:28:40 -07:00
import torch.distributed as dist
import triton
2024-05-12 04:54:07 -07:00
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
2024-05-11 20:55:00 -07:00
from starlette.middleware.base import BaseHTTPMiddleware
2024-05-16 18:07:30 -07:00
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
2024-05-12 20:49:04 -07:00
2024-04-09 23:27:31 +08:00
# Global switch for the mark_start()/mark_end() timers; while False both
# helpers return immediately without measuring anything.
show_time_cost = False
# Maps a timer name to the TimeInfo object accumulating its elapsed time.
time_infos = {}
2024-04-09 23:27:31 +08:00
def enable_show_time_cost():
    """Switch on the module-wide flag that activates the timing helpers."""
    global show_time_cost
    show_time_cost = True
2024-04-09 23:27:31 +08:00
class TimeInfo:
    """Accumulated elapsed time for one named timer.

    Attributes:
        name: Label printed in front of the accumulated time.
        interval: Minimum growth of ``acc_time`` (seconds) since the last
            report before ``check`` asks for another report.
        color: ANSI SGR code used to colour the printed line.
        indent: Nesting depth; each level adds two dashes of prefix.
        acc_time: Total accumulated time in seconds.
        last_acc_time: Value of ``acc_time`` at the last positive ``check``.
    """

    def __init__(self, name, interval=0.1, color=0, indent=0):
        self.name, self.interval = name, interval
        self.color, self.indent = color, indent

        self.acc_time = self.last_acc_time = 0

    def check(self):
        """Return True (and remember the current total) once the total has
        grown by more than ``interval`` since the last positive check."""
        growth = self.acc_time - self.last_acc_time
        if not (growth > self.interval):
            return False
        self.last_acc_time = self.acc_time
        return True

    def pretty_print(self):
        """Print the coloured, dash-indented ``name: seconds`` line."""
        print(
            f"\x1b[{self.color}m",
            "-" * self.indent * 2,
            f"{self.name}: {self.acc_time:.3f}s\x1b[0m",
            sep="",
        )
2024-04-09 23:27:31 +08:00
def mark_start(name, interval=0.1, color=0, indent=0):
    """Start (or resume) the timer called ``name``.

    Does nothing unless ``show_time_cost`` is enabled. The first call for a
    given name creates its ``TimeInfo`` using ``interval``, ``color`` and
    ``indent``; later calls reuse the existing entry and ignore those values.
    """
    if not show_time_cost:
        return
    # Wait for queued CUDA work so the timestamp is taken after it completes.
    torch.cuda.synchronize()

    info = time_infos.get(name)
    if info is None:
        info = time_infos[name] = TimeInfo(name, interval, color, indent)
    # Subtracting the start time here and adding the end time in mark_end()
    # leaves the elapsed duration accumulated in acc_time.
    info.acc_time -= time.time()
2024-04-09 23:27:31 +08:00
def mark_end(name):
    """Stop the timer called ``name`` and print it when a report is due.

    Does nothing unless ``show_time_cost`` is enabled. Raises ``KeyError`` if
    no timer with this name was ever started.
    """
    if not show_time_cost:
        return
    # Wait for queued CUDA work so the timestamp is taken after it completes.
    torch.cuda.synchronize()

    info = time_infos[name]
    info.acc_time += time.time()
    if info.check():
        info.pretty_print()
def calculate_time(show=False, min_cost_ms=0.0):
    """Build a decorator that reports how long the wrapped function takes.

    Args:
        show: When True, the call is bracketed by ``torch.cuda.synchronize()``
            and its wall-clock duration is measured. When False, the wrapped
            function is called directly with no synchronization or timing.
        min_cost_ms: Only durations strictly greater than this many
            milliseconds are printed.

    Returns:
        A decorator; the decorated function returns whatever the original
        function returns and keeps its name and docstring.
    """
    import functools

    def wrapper(func):
        @functools.wraps(func)
        def inner_func(*args, **kwargs):
            if not show:
                # Timing disabled: skip the device syncs entirely so the
                # decorator adds no overhead and works without CUDA.
                return func(*args, **kwargs)

            torch.cuda.synchronize()
            start_time = time.time()
            result = func(*args, **kwargs)
            # Sync again so pending GPU work launched by func is included.
            torch.cuda.synchronize()
            cost_time = (time.time() - start_time) * 1000
            if cost_time > min_cost_ms:
                print(f"Function {func.__name__} took {cost_time} ms to run.")
            return result

        return inner_func

    return wrapper
2024-05-24 03:48:53 -07:00
def get_available_gpu_memory(gpu_id, distributed=False):
    """
    Report the free memory of device cuda:gpu_id in GiB.

    With distributed=True the value is MIN-reduced over the default process
    group, so every rank receives the smallest free amount among all GPUs.
    """
    assert gpu_id < torch.cuda.device_count()

    active_device = torch.cuda.current_device()
    if active_device != gpu_id:
        print(
            f"WARNING: current device is not {gpu_id}, but {active_device}, ",
            "which may cause useless memory allocation for torch CUDA context.",
        )

    # Release cached allocator blocks before querying free memory.
    torch.cuda.empty_cache()
    free_gpu_memory = torch.cuda.mem_get_info(gpu_id)[0]

    if distributed:
        target = torch.device("cuda", gpu_id)
        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(target)
        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
        free_gpu_memory = tensor.item()

    bytes_per_gib = 1 << 30
    return free_gpu_memory / bytes_per_gib
def set_random_seed(seed: int) -> None:
    """Seed Python's random, NumPy and PyTorch (CPU and, if present, CUDA)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
2024-05-16 18:07:30 -07:00
def is_port_available(port):
    """Tell whether a TCP socket can bind to and listen on ``port``.

    The probe binds on all IPv4 interfaces with SO_REUSEADDR set and is
    closed again before returning.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            probe.bind(("", port))
            probe.listen(1)
        except OSError:  # socket.error is an alias of OSError
            return False
        return True
2024-05-11 20:55:00 -07:00
def allocate_init_ports(
    port: Optional[int] = None,
    additional_ports: Optional[List[int]] = None,
    dp_size: int = 1,
):
    """Allocate ports for all connections.

    The requested ports are kept, in the order given, when they are
    available; any shortfall is filled by scanning upwards for free ports.

    Args:
        port: Preferred main (HTTP) port, or None for no preference.
        additional_ports: Preferred ports for the remaining connections.
        dp_size: Data-parallel size; one extra port is reserved per replica.

    Returns:
        A tuple ``(main_port, other_ports)`` where ``other_ports`` is a list
        of ``3 + dp_size`` ports.
    """
    # Skip a missing main port: probing None would raise TypeError in bind().
    requested = [] if port is None else [port]
    if additional_ports:
        requested.extend(additional_ports)

    # De-duplicate while preserving order so the requested main port stays
    # first whenever it is available (a set would scramble the order).
    ret_ports = []
    for candidate in requested:
        if candidate not in ret_ports and is_port_available(candidate):
            ret_ports.append(candidate)

    cur_port = ret_ports[-1] + 1 if ret_ports else 10000

    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
    num_ports_needed = 4 + dp_size
    while len(ret_ports) < num_ports_needed:
        if cur_port not in ret_ports and is_port_available(cur_port):
            ret_ports.append(cur_port)
        cur_port += 1

    if port is not None and ret_ports[0] != port:
        logger.warning(
            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
        )

    return ret_ports[0], ret_ports[1:num_ports_needed]
2024-01-30 16:36:10 +00:00
def get_int_token_logit_bias(tokenizer, vocab_size):
    """Build a logit-bias vector that only leaves integer tokens unpenalised.

    A token keeps bias 0 when its decoded, stripped text is all digits, is
    empty, or when it is the tokenizer's EOS token; every other token gets
    -1e5. The result is a float32 array of length ``tokenizer.vocab_size``.
    """
    # The passed-in vocab_size is deliberately overridden: there is a known
    # bug when the model's vocab size is larger than tokenizer.vocab_size.
    vocab_size = tokenizer.vocab_size

    def _is_blocked(token_id):
        text = tokenizer.decode([token_id]).strip()
        return not (
            text.isdigit() or len(text) == 0 or token_id == tokenizer.eos_token_id
        )

    blocked_mask = np.fromiter(
        (_is_blocked(token_id) for token_id in range(vocab_size)),
        dtype=bool,
        count=vocab_size,
    )
    logit_bias = np.zeros(vocab_size, dtype=np.float32)
    logit_bias[blocked_mask] = -1e5
    return logit_bias
def is_multimodal_model(model):
    """Tell whether ``model`` names a known multimodal model.

    ``model`` is either a model name/path string or a ``ModelConfig`` (whose
    ``path`` is inspected). The check is a case-insensitive substring match
    against a fixed set of tags. Any other type raises ``ValueError``.
    """
    from sglang.srt.model_config import ModelConfig

    if isinstance(model, str):
        lowered = model.lower()
    elif isinstance(model, ModelConfig):
        lowered = model.path.lower()
    else:
        raise ValueError("unrecognized type")

    multimodal_tags = ("llava", "yi-vl", "llava-next")
    return any(tag in lowered for tag in multimodal_tags)
def _find_signature_offsets(data, signature):
    """Return the offset of every non-overlapping ``signature`` match in ``data``."""
    offsets = []
    pos = data.find(signature)
    while pos != -1:
        offsets.append(pos)
        pos = data.find(signature, pos + len(signature))
    return offsets


def decode_video_base64(video_base64):
    """Decode a base64 blob of concatenated image frames.

    The decoded bytes are split at every image start signature; each piece
    is opened with PIL and converted to a NumPy array.

    Args:
        video_base64: Base64 text (or bytes) holding back-to-back frames.

    Returns:
        ``(frames, size)`` where ``frames`` is the frames stacked along axis 0
        and ``size`` is the PIL ``size`` of the last frame. When no frame
        signature is found, returns ``(np.array([]), (0, 0))``.
    """
    video_bytes = base64.b64decode(video_base64)

    # The frame format is currently fixed to PNG; the JPEG branch is kept for
    # an optional FRAME_FORMAT switch.
    frame_format = "PNG"
    assert frame_format in [
        "PNG",
        "JPEG",
    ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"

    if frame_format == "PNG":
        signature = b"\x89PNG\r\n\x1a\n"  # 8-byte PNG file signature
    else:
        signature = b"\xff\xd8"  # JPEG start-of-image marker

    img_starts = _find_signature_offsets(video_bytes, signature)
    if not img_starts:
        # Nothing to decode; avoids calling np.stack on an empty list.
        return np.array([]), (0, 0)

    # Imported only when there are frames to decode.
    from PIL import Image

    # Frames are assumed to be back-to-back: each one ends where the next one
    # starts, and the last one runs to the end of the byte string.
    img_ends = img_starts[1:] + [len(video_bytes)]

    frames = []
    for start_idx, end_idx in zip(img_starts, img_ends):
        img = Image.open(BytesIO(video_bytes[start_idx:end_idx]))
        frames.append(np.array(img))

    return np.stack(frames, axis=0), img.size
def load_image(image_file):
    """Load an image (or video frames) from a URL, data URI, file or base64.

    The string is interpreted, in order, as:
      1. an ``http://`` / ``https://`` URL, fetched with ``requests`` using a
         timeout of ``REQUEST_TIMEOUT`` seconds (default 3);
      2. a ``data:`` URI, whose payload after the first comma is base64;
      3. a ``video:``-prefixed base64 blob, decoded by ``decode_video_base64``;
      4. a local file path ending in png/jpg/jpeg/webp/gif;
      5. otherwise, a raw base64-encoded image.

    Returns:
        ``(image, image_size)``. For the video case these are the two values
        returned by ``decode_video_base64``; in every other case ``image`` is
        a PIL image and ``image_size`` is None.
    """
    from PIL import Image

    image = image_size = None

    if image_file.startswith(("http://", "https://")):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
        response = requests.get(image_file, timeout=timeout)
        image = Image.open(BytesIO(response.content))
    # The explicit prefixes are checked before the file-extension test:
    # a base64 payload can itself end in "png", "gif", ... and would
    # otherwise be mistaken for a file path.
    elif image_file.startswith("data:"):
        image_file = image_file.split(",", 1)[1]
        image = Image.open(BytesIO(base64.b64decode(image_file)))
    elif image_file.startswith("video:"):
        image_file = image_file[len("video:"):]
        image, image_size = decode_video_base64(image_file)
    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
        image = Image.open(image_file)
    else:
        # NOTE(review): un-prefixed base64 that happens to end in one of the
        # extensions above is still treated as a path; prefer a data: URI.
        image = Image.open(BytesIO(base64.b64decode(image_file)))

    return image, image_size
def suppress_other_loggers():
    """Raise the log level of vllm's loggers to cut down on their output."""
    from vllm.logger import logger as vllm_default_logger

    vllm_default_logger.setLevel(logging.WARN)

    level_by_logger = {
        "vllm.config": logging.ERROR,
        "vllm.distributed.device_communicators.pynccl": logging.WARN,
        "vllm.selector": logging.WARN,
        "vllm.utils": logging.WARN,
    }
    for logger_name, level in level_by_logger.items():
        logging.getLogger(logger_name).setLevel(level)
2024-07-03 23:19:33 -07:00
def assert_pkg_version(pkg: str, min_version: str, message: str):
    """Require that ``pkg`` is installed with at least ``min_version``.

    Raises:
        Exception: if the package is missing or its installed version is
            older than ``min_version``; ``message`` is appended to the text.
    """
    try:
        installed_version = version(pkg)
    except PackageNotFoundError:
        raise Exception(
            f"{pkg} with minimum required version {min_version} is not installed. "
            + message
        )

    too_old = pkg_version.parse(installed_version) < pkg_version.parse(min_version)
    if too_old:
        raise Exception(
            f"{pkg} is installed with version {installed_version}, which "
            f"is less than the minimum required version {min_version}. " + message
        )
2024-05-11 20:55:00 -07:00
def kill_parent_process():
    """Kill the parent process and every descendant of it except this process."""
    me = psutil.Process()
    parent = me.parent()

    for proc in parent.children(recursive=True):
        if proc.pid == me.pid:
            # Spare ourselves so the loop and the final kill can finish.
            continue
        os.kill(proc.pid, 9)  # signal 9

    os.kill(parent.pid, 9)
def monkey_patch_vllm_p2p_access_check(gpu_id: int):
    """
    Replace vllm's slow GPU p2p access check with a stub that returns True.

    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
    The gpu_id argument is not used.
    """
    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt

    def _always_allowed(*arg, **kwargs):
        return True

    tgt.gpu_p2p_access_check = _always_allowed
def monkey_patch_vllm_dummy_weight_loader():
    """
    Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.

    Replaces DummyModelLoader.load_model with a version that, after building
    the model, runs the per-module post-load hooks and then fills the weights
    with dummy values.
    """

    from vllm.model_executor.model_loader.loader import (
        CacheConfig,
        DeviceConfig,
        DummyModelLoader,
        LoRAConfig,
        ModelConfig,
        MultiModalConfig,
        ParallelConfig,
        SchedulerConfig,
        _initialize_model,
        initialize_dummy_weights,
        nn,
        set_default_torch_dtype,
    )

    # Replacement for DummyModelLoader.load_model; keyword-only arguments.
    # parallel_config and scheduler_config are accepted but not used here.
    def load_model(
        self,
        *,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
    ) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            # Build the model directly on the configured device.
            with torch.device(device_config.device):
                model = _initialize_model(
                    model_config,
                    self.load_config,
                    lora_config,
                    multimodal_config,
                    cache_config,
                )

            # Run the post-load hook of every module that has one.
            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    quant_method.process_weights_after_loading(module)
                # FIXME: Remove this after Mixtral is updated
                # to use quant_method.
                if hasattr(module, "process_weights_after_loading"):
                    module.process_weights_after_loading()

            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        return model.eval()

    # Install the patched method on the vllm class.
    setattr(DummyModelLoader, "load_model", load_model)
2024-05-11 20:55:00 -07:00
# Name of the HTTP request header that carries the client's API key.
API_KEY_HEADER_NAME = "X-API-Key"


class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    """Reject requests whose ``X-API-Key`` header does not match the key.

    Requests with a missing, empty or wrong key get a 403 JSON response;
    all others are passed on to the next handler unchanged.
    """

    def __init__(self, app, api_key: str):
        """Wrap ``app`` and remember the expected ``api_key``."""
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request, call_next):
        """Validate the API key header, then forward the request."""
        import hmac

        # extract API key from the request headers
        api_key_header = request.headers.get(API_KEY_HEADER_NAME)

        # Compare in constant time so response timing does not reveal how
        # much of a guessed key was correct. Both sides are encoded to bytes
        # because compare_digest rejects non-ASCII str arguments.
        key_matches = bool(api_key_header) and hmac.compare_digest(
            api_key_header.encode("utf-8"), self.api_key.encode("utf-8")
        )
        if not key_matches:
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )
        response = await call_next(request)
        return response
def get_ip_address(ifname):
    """
    Get the IP address of a network interface.

    :param ifname: Name of the network interface (e.g., 'eth0')
    :return: IP address of the network interface
    :raises OSError: if the ioctl fails (e.g. the interface does not exist)
    """
    # The socket only provides a file descriptor for the ioctl; the context
    # manager closes it even when the ioctl raises.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        ifreq = fcntl.ioctl(
            s.fileno(),
            0x8915,  # SIOCGIFADDR
            # Interface name truncated to 15 characters, zero-padded to 256 bytes.
            struct.pack("256s", bytes(ifname[:15], "utf-8")),
        )
    # Bytes 20..23 of the returned buffer hold the packed IPv4 address.
    return socket.inet_ntoa(ifreq[20:24])
def send_addrs_to_rank_0(model_port_args, server_args):
    """Send this node's IP address and tensor-parallel ports to node rank 0.

    Only valid on a non-zero node rank with dp_size == 1. The local IP is
    read from the interface named by SGLANG_SOCKET_IFNAME, falling back to
    NCCL_SOCKET_IFNAME and then "eth0". A temporary gloo process group is
    created for the exchange and destroyed afterwards.
    """
    assert server_args.node_rank != 0 and server_args.dp_size == 1

    fallback_ifname = os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", fallback_ifname)
    ip_addr = get_ip_address(ifname)

    # Record the local IP for each of this node's tensor-parallel ports.
    num_tp_ports = server_args.tp_size // server_args.nnodes
    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports

    # Message layout: the four IP octets followed by the port numbers.
    ip_addr = list(map(int, ip_addr.split(".")))
    payload = ip_addr + model_port_args.model_tp_ports
    addrs_tensor = torch.tensor(payload, dtype=torch.int)

    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://{server_args.nccl_init_addr}",
        rank=server_args.node_rank,
        world_size=server_args.nnodes,
    )
    dist.send(addrs_tensor, dst=0)
    print(
        f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}"
    )

    dist.barrier()
    dist.destroy_process_group()
def receive_addrs(model_port_args, server_args):
    """Collect the IP address and tensor-parallel ports of every other node.

    Only valid on node rank 0 with dp_size == 1. The local IP fills the first
    slots of ``model_tp_ips``; the data received from each other rank fills
    that rank's slots of ``model_tp_ips`` and ``model_tp_ports``. A temporary
    gloo process group is created for the exchange and destroyed afterwards.
    """
    assert server_args.node_rank == 0 and server_args.dp_size == 1

    fallback_ifname = os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", fallback_ifname)
    ip_addr = get_ip_address(ifname)

    num_tp_ports = server_args.tp_size // server_args.nnodes
    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports

    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://{server_args.nccl_init_addr}",
        rank=server_args.node_rank,
        world_size=server_args.nnodes,
    )

    for src_rank in range(1, server_args.nnodes):
        # Each message is four IP octets followed by num_tp_ports ports.
        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
        dist.recv(tensor, src=src_rank)

        ip = ".".join(map(str, tensor[:4].tolist()))
        ports = tensor[4:].tolist()

        # Slots owned by src_rank in the flat per-port lists.
        rank_slots = slice(num_tp_ports * src_rank, num_tp_ports * (src_rank + 1))
        model_port_args.model_tp_ips[rank_slots] = [ip] * num_tp_ports
        model_port_args.model_tp_ports[rank_slots] = ports
        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")

    dist.barrier()
    dist.destroy_process_group()
2024-07-18 23:28:40 -07:00
def set_ulimit(target_soft_limit=65535):
    """Raise the soft open-file limit (RLIMIT_NOFILE) towards a target.

    The soft limit is only ever raised, never lowered. A target above the
    hard limit is clamped to the hard limit, since a soft limit cannot
    exceed it. Failures are logged as warnings instead of being raised.

    Args:
        target_soft_limit: Desired soft limit for open file descriptors.
    """
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    # An unlimited soft limit already satisfies any target; setting a finite
    # value here would lower it.
    if current_soft == resource.RLIM_INFINITY:
        return

    if current_hard != resource.RLIM_INFINITY:
        target_soft_limit = min(target_soft_limit, current_hard)

    if current_soft >= target_soft_limit:
        return

    try:
        resource.setrlimit(resource_type, (target_soft_limit, current_hard))
    except (ValueError, OSError) as e:
        logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")