vlm: enforce pybase64 for image and str encode/decode (#10700)
This commit is contained in:
@@ -31,7 +31,10 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import pybase64
|
||||
import requests
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
from tqdm.asyncio import tqdm
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
@@ -1020,14 +1023,6 @@ def sample_mmmu_requests(
|
||||
Returns:
|
||||
List of tuples (prompt, prompt_token_len, output_token_len).
|
||||
"""
|
||||
try:
|
||||
import io
|
||||
|
||||
import pybase64
|
||||
from datasets import load_dataset
|
||||
except ImportError:
|
||||
raise ImportError("Please install datasets: pip install datasets")
|
||||
|
||||
print("Loading MMMU dataset from HuggingFace...")
|
||||
|
||||
try:
|
||||
@@ -1396,13 +1391,6 @@ def sample_image_requests(
|
||||
- Text lengths follow the 'random' dataset sampling rule. ``prompt_len``
|
||||
only counts text tokens and excludes image data.
|
||||
"""
|
||||
try:
|
||||
import pybase64
|
||||
from PIL import Image
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install Pillow to generate random images: pip install pillow"
|
||||
) from e
|
||||
|
||||
# Parse resolution (supports presets and 'heightxwidth')
|
||||
width, height = parse_image_resolution(image_resolution)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import base64
|
||||
import pickle
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import pybase64
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import MultiprocessingSerializer
|
||||
@@ -77,14 +77,16 @@ class NaiveDistributed:
|
||||
)
|
||||
|
||||
_get_path(self._rank).write_text(
|
||||
base64.b64encode(pickle.dumps(obj)).decode("utf-8") + text_postfix
|
||||
pybase64.b64encode(pickle.dumps(obj)).decode("utf-8") + text_postfix
|
||||
)
|
||||
|
||||
def _read_one(interesting_rank: int):
|
||||
p = _get_path(interesting_rank)
|
||||
while True:
|
||||
if p.exists() and (text := p.read_text()).endswith(text_postfix):
|
||||
return pickle.loads(base64.b64decode(text[: -len(text_postfix)]))
|
||||
return pickle.loads(
|
||||
pybase64.b64decode(text[: -len(text_postfix)], validate=True)
|
||||
)
|
||||
time.sleep(0.001)
|
||||
|
||||
return [
|
||||
|
||||
@@ -872,9 +872,9 @@ def get_image_bytes(image_file: Union[str, bytes]):
|
||||
return f.read()
|
||||
elif image_file.startswith("data:"):
|
||||
image_file = image_file.split(",")[1]
|
||||
return pybase64.b64decode(image_file)
|
||||
return pybase64.b64decode(image_file, validate=True)
|
||||
elif isinstance(image_file, str):
|
||||
return pybase64.b64decode(image_file)
|
||||
return pybase64.b64decode(image_file, validate=True)
|
||||
else:
|
||||
raise NotImplementedError(f"Invalid image: {image_file}")
|
||||
|
||||
@@ -911,7 +911,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
|
||||
vr = VideoReader(tmp_file.name, ctx=ctx)
|
||||
elif video_file.startswith("data:"):
|
||||
_, encoded = video_file.split(",", 1)
|
||||
video_bytes = pybase64.b64decode(encoded)
|
||||
video_bytes = pybase64.b64decode(encoded, validate=True)
|
||||
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
||||
tmp_file.write(video_bytes)
|
||||
tmp_file.close()
|
||||
@@ -919,7 +919,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
|
||||
elif os.path.isfile(video_file):
|
||||
vr = VideoReader(video_file, ctx=ctx)
|
||||
else:
|
||||
video_bytes = pybase64.b64decode(video_file)
|
||||
video_bytes = pybase64.b64decode(video_file, validate=True)
|
||||
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
||||
tmp_file.write(video_bytes)
|
||||
tmp_file.close()
|
||||
@@ -2083,7 +2083,7 @@ class MultiprocessingSerializer:
|
||||
|
||||
if output_str:
|
||||
# Convert bytes to base64-encoded string
|
||||
output = pybase64.b64encode(output).decode("utf-8")
|
||||
output = pybase64.b64encode(output).decode("utf-8")
|
||||
|
||||
return output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user