2024-05-12 04:54:07 -07:00
|
|
|
"""Common utilities."""
|
|
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
import base64
|
|
|
|
|
import os
|
|
|
|
|
import random
|
|
|
|
|
import socket
|
|
|
|
|
import time
|
2024-05-12 07:37:49 +08:00
|
|
|
from importlib.metadata import PackageNotFoundError, version
|
2024-01-08 04:37:50 +00:00
|
|
|
from io import BytesIO
|
2024-01-30 16:36:10 +00:00
|
|
|
from typing import List, Optional
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
import numpy as np
|
2024-05-11 20:55:00 -07:00
|
|
|
import pydantic
|
2024-01-08 04:37:50 +00:00
|
|
|
import requests
|
|
|
|
|
import torch
|
2024-05-12 04:54:07 -07:00
|
|
|
from fastapi.responses import JSONResponse
|
2024-05-12 07:37:49 +08:00
|
|
|
from packaging import version as pkg_version
|
2024-05-11 20:55:00 -07:00
|
|
|
from pydantic import BaseModel
|
|
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-05-12 20:49:04 -07:00
|
|
|
from sglang.utils import get_exception_traceback
|
|
|
|
|
|
2024-04-09 23:27:31 +08:00
|
|
|
# Switch read by mark_start / mark_end; timing is skipped while it is False.
show_time_cost = False
# Registry of timers, keyed by the name passed to mark_start / mark_end.
time_infos = dict()
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
|
2024-04-09 23:27:31 +08:00
|
|
|
def enable_show_time_cost():
    """Turn on the module-level `show_time_cost` flag."""
    global show_time_cost

    show_time_cost = True
|
|
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-04-09 23:27:31 +08:00
|
|
|
class TimeInfo:
    """Accumulated time for one named timer, together with its print settings."""

    def __init__(self, name, interval=0.1, color=0, indent=0):
        # Reporting configuration.
        self.name, self.interval = name, interval
        self.color, self.indent = color, indent
        # Running total, and the total recorded when check() last returned True.
        self.acc_time = self.last_acc_time = 0

    def check(self):
        """Return True once more than `interval` has accumulated since the last True.

        A True result also records the current total as the new reference point.
        """
        elapsed_since_report = self.acc_time - self.last_acc_time
        if not (elapsed_since_report > self.interval):
            return False
        self.last_acc_time = self.acc_time
        return True

    def pretty_print(self):
        """Print `name: total` in the configured ANSI color, prefixed by 2 * indent dashes."""
        print(
            f"\x1b[{self.color}m"
            + "-" * self.indent * 2
            + f"{self.name}: {self.acc_time:.3f}s\x1b[0m"
        )
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
|
2024-04-09 23:27:31 +08:00
|
|
|
def mark_start(name, interval=0.1, color=0, indent=0):
    """Start (or resume) the timer called `name`.

    Does nothing unless `show_time_cost` is enabled. The timer's TimeInfo is
    created on first use with the given interval/color/indent; later calls
    reuse the existing record and ignore those arguments.
    """
    global time_infos, show_time_cost
    if show_time_cost:
        # Wait for queued CUDA work so it is not charged to this timer.
        torch.cuda.synchronize()
        info = time_infos.get(name)
        if info is None:
            info = time_infos[name] = TimeInfo(name, interval, color, indent)
        # Subtract the start timestamp; mark_end adds the end timestamp back.
        info.acc_time -= time.time()
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
|
2024-04-09 23:27:31 +08:00
|
|
|
def mark_end(name):
    """Stop the timer called `name` and print it when its report interval has passed.

    Does nothing unless `show_time_cost` is enabled. The timer must already
    exist in `time_infos` (created by mark_start).
    """
    global time_infos, show_time_cost
    if show_time_cost:
        # Wait for queued CUDA work so the measured span includes it.
        torch.cuda.synchronize()
        info = time_infos[name]
        # Adding the end timestamp completes the "-start + end" accumulation.
        info.acc_time += time.time()
        if info.check():
            info.pretty_print()
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_time(show=False, min_cost_ms=0.0):
    """Decorator factory that reports how long a function takes, in milliseconds.

    Args:
        show: When True, time each call and print the duration.
        min_cost_ms: Only print durations strictly greater than this threshold.

    Returns:
        A decorator. The decorated function synchronizes CUDA before and after
        the wrapped call (regardless of `show`) and returns the wrapped
        function's result unchanged.
    """
    import functools

    def wrapper(func):
        # functools.wraps keeps the wrapped function's __name__, __doc__, etc.
        @functools.wraps(func)
        def inner_func(*args, **kwargs):
            torch.cuda.synchronize()
            if show:
                start_time = time.time()
            result = func(*args, **kwargs)
            # Synchronize again so asynchronous GPU work is included in the timing.
            torch.cuda.synchronize()
            if show:
                cost_time = (time.time() - start_time) * 1000
                if cost_time > min_cost_ms:
                    print(f"Function {func.__name__} took {cost_time} ms to run.")
            return result

        return inner_func

    return wrapper
|
|
|
|
|
|
|
|
|
|
|
2024-05-12 20:49:04 -07:00
|
|
|
def get_available_gpu_memory(gpu_id, distributed=True):
    """
    Return the free memory of device cuda:gpu_id, expressed in GiB.

    With distributed=True the free amount is MIN-reduced through
    torch.distributed.all_reduce, so the result is the smallest value
    among the participating processes.
    """
    assert gpu_id < torch.cuda.device_count()

    active_device = torch.cuda.current_device()
    if active_device != gpu_id:
        print(
            f"WARNING: current device is not {gpu_id}, but {active_device}, ",
            "which may cause useless memory allocation for torch CUDA context.",
        )

    # mem_get_info returns (free, total); only the free byte count is used.
    free_bytes = torch.cuda.mem_get_info(gpu_id)[0]

    if distributed:
        target = torch.device("cuda", gpu_id)
        tensor = torch.tensor(free_bytes, dtype=torch.float32).to(target)
        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
        free_bytes = tensor.item()

    # Bytes -> GiB.
    return free_bytes / (1 << 30)
|
|
|
|
|
|
|
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
def set_random_seed(seed: int) -> None:
    """Seed Python's `random` module and PyTorch (plus all CUDA devices when CUDA is available)."""
    torch.manual_seed(seed)
    random.seed(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def alloc_usable_network_port(num, used_list=()):
    """Find `num` TCP ports in [10000, 65535] that can currently be bound and listened on.

    Args:
        num: Number of ports wanted. A value <= 0 yields an empty list.
        used_list: Ports that must not be returned.

    Returns:
        A list of `num` port numbers in ascending order, or None when fewer
        than `num` usable ports exist in the scanned range.
    """
    # Nothing requested: answer immediately instead of probing (previously this
    # case returned None unless the very first probed port happened to be busy).
    if num <= 0:
        return []

    # Build the exclusion set once so each membership test is O(1).
    excluded = set(used_list)
    port_list = []
    for port in range(10000, 65536):
        if port in excluded:
            continue

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                s.bind(("", port))
                s.listen(1)  # Attempt to listen on the port
            except socket.error:
                # If any error occurs, this port is not usable
                continue

        port_list.append(port)
        if len(port_list) == num:
            return port_list
    return None
|
|
|
|
|
|
|
|
|
|
|
2024-01-30 08:34:51 -08:00
|
|
|
def check_port(port):
    """Return True when a TCP socket can be bound to `port` on all interfaces, else False."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        probe.bind(("", port))
    except socket.error:
        return False
    else:
        return True
    finally:
        # Always release the probe socket, whether or not the bind worked.
        probe.close()
|
|
|
|
|
|
|
|
|
|
|
2024-05-11 20:55:00 -07:00
|
|
|
def allocate_init_ports(
    port: Optional[int] = None,
    additional_ports: Optional[List[int]] = None,
    tp_size: int = 1,
):
    """Pick the server port plus `4 + tp_size` additional ports.

    Args:
        port: Preferred server port; defaults to 30000. If it cannot be bound,
            another usable port is chosen and a warning is printed.
        additional_ports: Preferred additional ports (a list, or a single int).
            Duplicates, the server port, and ports that cannot be bound are
            dropped; the remaining ones keep the caller's order.
        tp_size: Tensor-parallel size; `4 + tp_size` additional ports are returned.

    Returns:
        A tuple `(port, additional_ports)`.

    Raises:
        RuntimeError: If not enough usable ports can be found.
    """
    port = 30000 if port is None else port
    if additional_ports is None:
        additional_ports = []
    elif isinstance(additional_ports, int):
        additional_ports = [additional_ports]

    # first check on server port
    if not check_port(port):
        fallback = alloc_usable_network_port(1, used_list=[port])
        # The allocator returns None when it runs out of ports; fail clearly
        # instead of with "'NoneType' object is not subscriptable".
        if not fallback:
            raise RuntimeError(
                f"Port {port} is not available and no alternative port could be found."
            )
        new_port = fallback[0]
        print(f"WARNING: Port {port} is not available. Use {new_port} instead.")
        port = new_port

    num_needed = 4 + tp_size

    # then we check on additional ports: de-duplicate while keeping the
    # caller's order, and never reuse the server port
    requested = [p for p in dict.fromkeys(additional_ports) if p != port]
    # filter out ports that are already in use
    can_use_ports = [p for p in requested if check_port(p)]

    num_missing = num_needed - len(can_use_ports)
    if num_missing > 0:
        additional_can_use_ports = alloc_usable_network_port(
            num=num_missing, used_list=can_use_ports + [port]
        )
        if additional_can_use_ports is None:
            raise RuntimeError(
                f"Could not find {num_missing} more usable network ports."
            )
        can_use_ports.extend(additional_can_use_ports)

    additional_ports = can_use_ports[:num_needed]
    return port, additional_ports
|
|
|
|
|
|
2024-01-30 16:36:10 +00:00
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
def get_int_token_logit_bias(tokenizer, vocab_size):
    """Build a float32 logit-bias vector that suppresses non-integer tokens.

    A token keeps bias 0 when its stripped decoded text is all digits, is
    empty, or when it is the tokenizer's EOS token; every other token gets
    -1e5. The vector length is `tokenizer.vocab_size`; the `vocab_size`
    argument is overridden (see the comment below).
    """
    # a bug when model's vocab size > tokenizer.vocab_size
    vocab_size = tokenizer.vocab_size
    logit_bias = np.zeros(vocab_size, dtype=np.float32)

    for token_id in range(vocab_size):
        text = tokenizer.decode([token_id]).strip()
        if text.isdigit() or not text or token_id == tokenizer.eos_token_id:
            continue
        logit_bias[token_id] = -1e5

    return logit_bias
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def wrap_kernel_launcher(kernel):
    """A faster launcher for triton kernels.

    Takes the first compiled kernel from `kernel.cache[rank]` (rank is the
    torch.distributed rank, or 0 when distributed is not initialized) and
    returns a function that calls its low-level launcher directly.

    Returns:
        `ret_func(grid, num_warps, *args)`, where `grid` is indexable with
        three entries. Two positional launcher signatures are supported (with
        and without four extra `1` arguments); the one that works is remembered.

    Raises (from the returned function):
        TypeError: If the launcher rejects both signatures.
    """
    import torch.distributed as dist

    rank = dist.get_rank() if dist.is_initialized() else 0

    kernels = kernel.cache[rank].values()
    kernel = next(iter(kernels))

    # Different triton versions use different low-level names
    kfunction = kernel.cu_function if hasattr(kernel, "cu_function") else kernel.function
    run = kernel.c_wrapper if hasattr(kernel, "c_wrapper") else kernel.run

    # Which signature to try first; updated once the other one is found to work.
    add_cluster_dim = True

    def _launch(with_cluster_dim, grid, num_warps, args):
        # The two signatures differ only by four extra `1` arguments after
        # num_warps (presumably cluster dimensions, judging by the flag name).
        extra = (1, 1, 1, 1) if with_cluster_dim else ()
        run(
            grid[0],
            grid[1],
            grid[2],
            num_warps,
            *extra,
            kernel.shared,
            0,
            kfunction,
            None,
            None,
            kernel,
            *args,
        )

    def ret_func(grid, num_warps, *args):
        nonlocal add_cluster_dim

        try:
            _launch(add_cluster_dim, grid, num_warps, args)
        except TypeError:
            # Try the other signature exactly once. If it fails too, the
            # TypeError propagates instead of recursing until RecursionError.
            _launch(not add_cluster_dim, grid, num_warps, args)
            add_cluster_dim = not add_cluster_dim

    return ret_func
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_multimodal_model(model):
    """Tell whether `model` names a multimodal model.

    Args:
        model: Either a model name/path string, or a ModelConfig whose `path`
            attribute is inspected.

    Returns:
        True when the lower-cased name contains "llava" or "yi-vl".

    Raises:
        ValueError: If `model` is neither a str nor a ModelConfig.
    """
    # "llava-next" needs no separate entry: any name containing it also contains "llava".
    multimodal_keywords = ("llava", "yi-vl")

    if isinstance(model, str):
        name = model.lower()
    else:
        # Imported only when needed, so plain strings do not require the config module.
        from sglang.srt.model_config import ModelConfig

        if not isinstance(model, ModelConfig):
            raise ValueError("unrecognized type")
        name = model.path.lower()

    return any(keyword in name for keyword in multimodal_keywords)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def decode_video_base64(video_base64):
    """Decode a base64 payload of concatenated image files into a stack of frames.

    The decoded bytes are split at every occurrence of the frame-format
    signature (currently fixed to PNG); each slice is opened with PIL and
    converted to a NumPy array.

    Args:
        video_base64: Base64 text (or bytes) containing back-to-back images.

    Returns:
        `(frames, size)`: `frames` is `np.stack` of all decoded frames along
        axis 0 and `size` is the PIL `size` of the last frame. When no
        signature is found the result is `(np.array([]), (0, 0))`.
    """
    # Decode the base64 string
    video_bytes = base64.b64decode(video_base64)

    frame_format = "PNG"  # str(os.getenv('FRAME_FORMAT', "JPEG"))
    # Magic bytes marking the start of one image in each supported format.
    signatures = {
        "PNG": b"\x89PNG\r\n\x1a\n",
        "JPEG": b"\xff\xd8",
    }
    assert frame_format in signatures, "FRAME_FORMAT must be either 'PNG' or 'JPEG'"
    signature = signatures[frame_format]

    # Locate every signature with bytes.find; after a hit the search resumes
    # past the signature, matching a scan that skips over matched bytes.
    img_starts = []
    pos = video_bytes.find(signature)
    while pos != -1:
        img_starts.append(pos)
        pos = video_bytes.find(signature, pos + len(signature))

    if not img_starts:
        # Return an empty array and size tuple if no frames were found
        return np.array([]), (0, 0)

    # Imported only when there is something to decode.
    from PIL import Image

    # Images are assumed to be back-to-back: each one ends where the next
    # begins, and the last one runs to the end of the byte string.
    img_ends = img_starts[1:] + [len(video_bytes)]

    frames = []
    for start_idx, end_idx in zip(img_starts, img_ends):
        img = Image.open(BytesIO(video_bytes[start_idx:end_idx]))
        frames.append(np.array(img))

    return np.stack(frames, axis=0), img.size
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_image(image_file):
    """Load an image (or video frames) from a URL, data URI, video payload, file path, or raw base64.

    The input string is dispatched in this order:
      1. "http://" / "https://" prefix: downloaded with `requests`
         (timeout in seconds from the REQUEST_TIMEOUT env var, default 3).
      2. "data:" prefix: the part after the first comma is base64-decoded.
      3. "video:" prefix: the rest is passed to `decode_video_base64`.
      4. Ends with png/jpg/jpeg/webp/gif (case-insensitive): opened as a file path.
      5. Anything else: treated as raw base64 image data.

    The prefix checks run before the extension check because base64 text can
    itself end in "png", "jpg", "gif", ... and would otherwise be mistaken
    for a file path.

    Returns:
        `(image, image_size)`. `image_size` is None except for "video:" inputs,
        where both values come from `decode_video_base64`.
    """
    from PIL import Image

    image = image_size = None

    if image_file.startswith(("http://", "https://")):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
        response = requests.get(image_file, timeout=timeout)
        image = Image.open(BytesIO(response.content))
    elif image_file.startswith("data:"):
        image_file = image_file.split(",")[1]
        image = Image.open(BytesIO(base64.b64decode(image_file)))
    elif image_file.startswith("video:"):
        image_file = image_file.replace("video:", "")
        image, image_size = decode_video_base64(image_file)
    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
        image = Image.open(image_file)
    else:
        image = Image.open(BytesIO(base64.b64decode(image_file)))

    return image, image_size
|
2024-05-12 07:37:49 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def assert_pkg_version(pkg: str, min_version: str):
    """Raise an Exception unless `pkg` is installed with a version >= `min_version`."""
    try:
        installed_version = version(pkg)
    except PackageNotFoundError:
        raise Exception(
            f"{pkg} with minimum required version {min_version} is not installed"
        )

    required = pkg_version.parse(min_version)
    if pkg_version.parse(installed_version) < required:
        raise Exception(
            f"{pkg} is installed with version {installed_version} which "
            f"is less than the minimum required version {min_version}"
        )
|
2024-05-11 20:55:00 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Name of the HTTP request header that carries the client's API key.
API_KEY_HEADER_NAME = "X-API-Key"


class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    """Middleware that rejects requests whose X-API-Key header does not match the configured key."""

    def __init__(self, app, api_key: str):
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request, call_next):
        """Forward the request when the API key matches; otherwise answer 403.

        A missing or empty header, or a key that differs from `self.api_key`,
        yields a JSON response `{"detail": "Invalid API Key"}` with status 403.
        """
        import hmac

        # extract API key from the request headers
        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
        # compare_digest takes time independent of where the values differ, so
        # response timing does not leak how much of a guessed key is correct.
        # Both sides are encoded because compare_digest rejects non-ASCII str.
        if not api_key_header or not hmac.compare_digest(
            api_key_header.encode("utf-8"), self.api_key.encode("utf-8")
        ):
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )
        response = await call_next(request)
        return response
|
|
|
|
|
|
2024-05-12 04:54:07 -07:00
|
|
|
|
2024-05-11 20:55:00 -07:00
|
|
|
# FIXME: Remove this once we drop support for pydantic 1.x
# True when the major component of the installed pydantic version is 1.
IS_PYDANTIC_1 = int(pydantic.VERSION.partition(".")[0]) == 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def jsonify_pydantic_model(obj: BaseModel):
    """Serialize a pydantic model to a JSON string using the API of the installed major version.

    pydantic 1.x uses `obj.json(ensure_ascii=False)`; otherwise `obj.model_dump_json()`.
    """
    if not IS_PYDANTIC_1:
        return obj.model_dump_json()
    return obj.json(ensure_ascii=False)
|