support llava video (#426)
This commit is contained in:
@@ -131,11 +131,13 @@ def alloc_usable_network_port(num, used_list=()):
|
||||
continue
|
||||
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
try:
|
||||
s.bind(("", port))
|
||||
s.listen(1) # Attempt to listen on the port
|
||||
port_list.append(port)
|
||||
except socket.error:
|
||||
pass
|
||||
pass # If any error occurs, this port is not usable
|
||||
|
||||
if len(port_list) == num:
|
||||
return port_list
|
||||
@@ -265,20 +267,102 @@ def wrap_kernel_launcher(kernel):
|
||||
|
||||
|
||||
def is_multimodal_model(model):
    """Return whether ``model`` names a multimodal (vision-language) model.

    The check is a case-insensitive substring match for "llava" or "yi-vl".
    Names containing "llava-next" already contain "llava", so they match
    without a separate test.

    Args:
        model: Either a model name/path string, or a ``ModelConfig`` whose
            ``path`` attribute is inspected.

    Returns:
        True if the lower-cased name or path contains "llava" or "yi-vl",
        False otherwise.

    Raises:
        ValueError: If ``model`` is neither a ``str`` nor a ``ModelConfig``.
    """
    if isinstance(model, str):
        name = model.lower()
    else:
        # Imported lazily so plain-string lookups never need the config module.
        from sglang.srt.model_config import ModelConfig

        if not isinstance(model, ModelConfig):
            raise ValueError("unrecognized type")
        name = model.path.lower()

    return "llava" in name or "yi-vl" in name
|
||||
|
||||
|
||||
# Byte sequences that mark the beginning of one encoded frame, per format.
_FRAME_SIGNATURES = {
    "PNG": b"\x89PNG\r\n\x1a\n",  # 8-byte PNG file signature
    "JPEG": b"\xff\xd8",  # JPEG start-of-image (SOI) marker
}


def _find_frame_starts(data, signature):
    """Return the offsets of every non-overlapping ``signature`` in ``data``."""
    starts = []
    pos = data.find(signature)
    while pos != -1:
        starts.append(pos)
        # Resume after the signature, exactly like skipping it in a manual scan.
        pos = data.find(signature, pos + len(signature))
    return starts


def decode_video_base64(video_base64, frame_format="PNG"):
    """Decode a base64 blob of concatenated image frames into a frame array.

    The decoded bytes are expected to hold encoded images stored back to
    back. Each frame starts at an occurrence of the format's signature and
    runs until the next occurrence (or the end of the data for the last one).

    Args:
        video_base64: Base64-encoded bytes (``str`` or bytes-like) containing
            the concatenated frames.
        frame_format: Encoding of the individual frames, ``"PNG"`` (default)
            or ``"JPEG"``.

    Returns:
        A tuple ``(frames, size)``. ``frames`` is the result of stacking the
        per-frame NumPy arrays along a new leading axis, and ``size`` is the
        PIL ``size`` of the last decoded frame. When no frame signature is
        found, returns ``(np.array([]), (0, 0))``.

    Raises:
        ValueError: If ``frame_format`` is not ``"PNG"`` or ``"JPEG"``.
    """
    signature = _FRAME_SIGNATURES.get(frame_format)
    if signature is None:
        raise ValueError("FRAME_FORMAT must be either 'PNG' or 'JPEG'")

    video_bytes = base64.b64decode(video_base64)
    img_starts = _find_frame_starts(video_bytes, signature)

    # Nothing to decode: avoid np.stack on an empty list.
    if not img_starts:
        return np.array([]), (0, 0)

    # Only needed once there is at least one frame to decode.
    from PIL import Image

    # Frames are back-to-back, so each one ends where the next begins.
    img_ends = img_starts[1:] + [len(video_bytes)]

    frames = []
    frame_size = (0, 0)
    for start_idx, end_idx in zip(img_starts, img_ends):
        with Image.open(BytesIO(video_bytes[start_idx:end_idx])) as img:
            # np.array forces the lazy PIL decode while the image is open.
            frames.append(np.array(img))
            frame_size = img.size

    return np.stack(frames, axis=0), frame_size
|
||||
|
||||
|
||||
def load_image(image_file):
|
||||
from PIL import Image
|
||||
|
||||
image = None
|
||||
image = image_size = None
|
||||
|
||||
if image_file.startswith("http://") or image_file.startswith("https://"):
|
||||
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
|
||||
@@ -289,10 +373,13 @@ def load_image(image_file):
|
||||
elif image_file.startswith("data:"):
|
||||
image_file = image_file.split(",")[1]
|
||||
image = Image.open(BytesIO(base64.b64decode(image_file)))
|
||||
elif image_file.startswith("video:"):
|
||||
image_file = image_file.replace("video:", "")
|
||||
image, image_size = decode_video_base64(image_file)
|
||||
else:
|
||||
image = Image.open(BytesIO(base64.b64decode(image_file)))
|
||||
|
||||
return image
|
||||
return image, image_size
|
||||
|
||||
|
||||
def assert_pkg_version(pkg: str, min_version: str):
|
||||
@@ -304,7 +391,9 @@ def assert_pkg_version(pkg: str, min_version: str):
|
||||
f"is less than the minimum required version {min_version}"
|
||||
)
|
||||
except PackageNotFoundError:
|
||||
raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
|
||||
raise Exception(
|
||||
f"{pkg} with minimum required version {min_version} is not installed"
|
||||
)
|
||||
|
||||
|
||||
# NOTE(review): presumably the HTTP header name used to carry the API key — confirm against callers.
API_KEY_HEADER_NAME = "X-API-Key"
|
||||
|
||||
Reference in New Issue
Block a user