support llava video (#426)

This commit is contained in:
Yuanhan Zhang
2024-05-14 07:57:00 +08:00
committed by GitHub
parent 5dc55a5f02
commit 0992d85f92
37 changed files with 1139 additions and 222 deletions

View File

@@ -131,11 +131,13 @@ def alloc_usable_network_port(num, used_list=()):
continue
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
s.bind(("", port))
s.listen(1) # Attempt to listen on the port
port_list.append(port)
except socket.error:
pass
pass # If any error occurs, this port is not usable
if len(port_list) == num:
return port_list
@@ -265,20 +267,102 @@ def wrap_kernel_launcher(kernel):
def is_multimodal_model(model):
if isinstance(model, str):
return "llava" in model or "yi-vl" in model
from sglang.srt.model_config import ModelConfig
if isinstance(model, str):
model = model.lower()
return "llava" in model or "yi-vl" in model or "llava-next" in model
if isinstance(model, ModelConfig):
model_path = model.path.lower()
return "llava" in model_path or "yi-vl" in model_path
raise Exception("unrecognized type")
return "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
raise ValueError("unrecognized type")
def decode_video_base64(video_base64):
from PIL import Image
# Decode the base64 string
video_bytes = base64.b64decode(video_base64)
# Placeholder for the start indices of each PNG image
img_starts = []
frame_format = "PNG" # str(os.getenv('FRAME_FORMAT', "JPEG"))
assert frame_format in [
"PNG",
"JPEG",
], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"
if frame_format == "PNG":
# Find each PNG start signature to isolate images
i = 0
while i < len(video_bytes) - 7: # Adjusted for the length of the PNG signature
# Check if we found the start of a PNG file
if (
video_bytes[i] == 0x89
and video_bytes[i + 1] == 0x50
and video_bytes[i + 2] == 0x4E
and video_bytes[i + 3] == 0x47
and video_bytes[i + 4] == 0x0D
and video_bytes[i + 5] == 0x0A
and video_bytes[i + 6] == 0x1A
and video_bytes[i + 7] == 0x0A
):
img_starts.append(i)
i += 8 # Skip the PNG signature
else:
i += 1
else:
# Find each JPEG start (0xFFD8) to isolate images
i = 0
while (
i < len(video_bytes) - 1
): # Adjusted for the length of the JPEG SOI signature
# Check if we found the start of a JPEG file
if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8:
img_starts.append(i)
# Move to the next byte to continue searching for the next image start
i += 2
else:
i += 1
frames = []
for start_idx in img_starts:
# Assuming each image is back-to-back, the end of one image is the start of another
# The last image goes until the end of the byte string
end_idx = (
img_starts[img_starts.index(start_idx) + 1]
if img_starts.index(start_idx) + 1 < len(img_starts)
else len(video_bytes)
)
img_bytes = video_bytes[start_idx:end_idx]
# Convert bytes to a PIL Image
img = Image.open(BytesIO(img_bytes))
# Convert PIL Image to a NumPy array
frame = np.array(img)
# Append the frame to the list of frames
frames.append(frame)
# Ensure there's at least one frame to avoid errors with np.stack
if frames:
return np.stack(frames, axis=0), img.size
else:
return np.array([]), (
0,
0,
) # Return an empty array and size tuple if no frames were found
def load_image(image_file):
from PIL import Image
image = None
image = image_size = None
if image_file.startswith("http://") or image_file.startswith("https://"):
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
@@ -289,10 +373,13 @@ def load_image(image_file):
elif image_file.startswith("data:"):
image_file = image_file.split(",")[1]
image = Image.open(BytesIO(base64.b64decode(image_file)))
elif image_file.startswith("video:"):
image_file = image_file.replace("video:", "")
image, image_size = decode_video_base64(image_file)
else:
image = Image.open(BytesIO(base64.b64decode(image_file)))
return image
return image, image_size
def assert_pkg_version(pkg: str, min_version: str):
@@ -304,7 +391,9 @@ def assert_pkg_version(pkg: str, min_version: str):
f"is less than the minimum required version {min_version}"
)
except PackageNotFoundError:
raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
raise Exception(
f"{pkg} with minimum required version {min_version} is not installed"
)
API_KEY_HEADER_NAME = "X-API-Key"