From a37e1247c183cff86a18f2ed1a075e40704b1c5e Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Tue, 8 Jul 2025 17:00:58 -0400 Subject: [PATCH] [Multimodal][Perf] Use `pybase64` instead of `base64` (#7724) --- python/pyproject.toml | 1 + python/sglang/bench_serving.py | 4 ++-- .../sglang/srt/entrypoints/http_server_engine.py | 2 +- python/sglang/srt/multimodal/mm_utils.py | 4 ++-- python/sglang/srt/utils.py | 16 +++++++++------- python/sglang/utils.py | 10 +++++----- test/srt/test_vision_openai_server_common.py | 1 - 7 files changed, 20 insertions(+), 18 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index e72488849..0d6d712a6 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -38,6 +38,7 @@ runtime_common = [ "psutil", "pydantic", "pynvml", + "pybase64", "python-multipart", "pyzmq>=25.1.2", "soundfile==0.13.1", diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 04c2202d2..3ba4eae0f 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -814,9 +814,9 @@ def sample_mmmu_requests( List of tuples (prompt, prompt_token_len, output_token_len). """ try: - import base64 import io + import pybase64 from datasets import load_dataset except ImportError: raise ImportError("Please install datasets: pip install datasets") @@ -867,7 +867,7 @@ def sample_mmmu_requests( # Encode image to base64 buffered = io.BytesIO() image.save(buffered, format="JPEG") - img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") + img_str = pybase64.b64encode(buffered.getvalue()).decode("utf-8") image_data = f"data:image/jpeg;base64,{img_str}" else: continue diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index b2edf1abe..d1db80d65 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -1,4 +1,3 @@ -import base64 import copy import dataclasses import multiprocessing @@ -7,6 +6,7 @@ import threading import time from typing import Any, Dict, List, Optional, Tuple, Union +import pybase64 import requests import torch import torch.distributed as dist diff --git a/python/sglang/srt/multimodal/mm_utils.py b/python/sglang/srt/multimodal/mm_utils.py index 9c05c1859..c399be806 100644 --- a/python/sglang/srt/multimodal/mm_utils.py +++ b/python/sglang/srt/multimodal/mm_utils.py @@ -28,12 +28,12 @@ LLaVA-Onevision : https://arxiv.org/pdf/2408.03326 """ import ast -import base64 import math import re from io import BytesIO import numpy as np +import pybase64 from PIL import Image from sglang.srt.utils import flatten_nested_list @@ -252,7 +252,7 @@ def process_anyres_image(image, processor, grid_pinpoints): def load_image_from_base64(image): - return Image.open(BytesIO(base64.b64decode(image))) + return Image.open(BytesIO(pybase64.b64decode(image, validate=True))) def expand2square(pil_img, background_color): diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 052e7328f..bc2affa1a 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -15,7 +15,6 @@ from __future__ import annotations -import base64 import builtins import ctypes import dataclasses @@ -68,6 +67,7 @@ from typing import ( import numpy as np import psutil +import pybase64 import requests import torch import torch.distributed @@ -616,7 +616,7 @@ def decode_video_base64(video_base64): from PIL import Image # Decode the base64 string - video_bytes = base64.b64decode(video_base64) + video_bytes = pybase64.b64decode(video_base64, validate=True) # Placeholder for the start indices of each PNG image img_starts = [] @@ -702,7 +702,9 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra audio, original_sr = sf.read(BytesIO(audio_file)) elif audio_file.startswith("data:"): audio_file = audio_file.split(",")[1] - audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file))) + audio, original_sr = sf.read( + BytesIO(pybase64.b64decode(audio_file, validate=True)) + ) elif audio_file.startswith("http://") or audio_file.startswith("https://"): timeout = int(os.getenv("REQUEST_TIMEOUT", "5")) response = requests.get(audio_file, stream=True, timeout=timeout) @@ -771,12 +773,12 @@ def load_image( image = Image.open(image_file) elif image_file.startswith("data:"): image_file = image_file.split(",")[1] - image = Image.open(BytesIO(base64.b64decode(image_file))) + image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True))) elif image_file.startswith("video:"): image_file = image_file.replace("video:", "") image, image_size = decode_video_base64(image_file) elif isinstance(image_file, str): - image = Image.open(BytesIO(base64.b64decode(image_file))) + image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True))) else: raise ValueError(f"Invalid image: {image}") @@ -1866,7 +1868,7 @@ class MultiprocessingSerializer: if output_str: # Convert bytes to base64-encoded string - output = base64.b64encode(output).decode("utf-8") + output = pybase64.b64encode(output).decode("utf-8") return output @@ -1883,7 +1885,7 @@ class MultiprocessingSerializer: """ if isinstance(data, str): # Decode base64 string to bytes - data = base64.b64decode(data) + data = pybase64.b64decode(data, validate=True) return ForkingPickler.loads(data) diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 6b3f36e19..83c653232 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -1,6 +1,5 @@ """Common utilities""" -import base64 import importlib import json import logging @@ -20,6 +19,7 @@ from json import dumps from typing import Any, Callable, List, Optional, Tuple, Type, Union import numpy as np +import pybase64 import requests from IPython.display import HTML, display from pydantic import BaseModel @@ -148,15 +148,15 @@ def encode_image_base64(image_path: Union[str, bytes]): if isinstance(image_path, str): with open(image_path, "rb") as image_file: data = image_file.read() - return base64.b64encode(data).decode("utf-8") + return pybase64.b64encode(data).decode("utf-8") elif isinstance(image_path, bytes): - return base64.b64encode(image_path).decode("utf-8") + return pybase64.b64encode(image_path).decode("utf-8") else: # image_path is PIL.WebPImagePlugin.WebPImageFile image = image_path buffered = BytesIO() image.save(buffered, format="PNG") - return base64.b64encode(buffered.getvalue()).decode("utf-8") + return pybase64.b64encode(buffered.getvalue()).decode("utf-8") def encode_frame(frame): @@ -223,7 +223,7 @@ def encode_video_base64(video_path: str, num_frames: int = 16): video_bytes = b"".join(encoded_frames) # Encode the concatenated bytes to base64 - video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8") + video_base64 = "video:" + pybase64.b64encode(video_bytes).decode("utf-8") return video_base64 diff --git a/test/srt/test_vision_openai_server_common.py b/test/srt/test_vision_openai_server_common.py index 3687d9381..42b8e889d 100644 --- a/test/srt/test_vision_openai_server_common.py +++ b/test/srt/test_vision_openai_server_common.py @@ -1,5 +1,4 @@ import base64 -import copy import io import json import os