# Files
# bi_150-vllm/vllm/multimodal/media/audio.py
#
# 90 lines
# 2.7 KiB
# Python
#
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
from io import BytesIO
from pathlib import Path
import numpy.typing as npt
import pybase64
import torch
from vllm.utils.import_utils import PlaceholderModule
from vllm.utils.serial_utils import tensor2base64
from .base import MediaIO
# librosa (decoding) and soundfile (encoding) are optional dependencies.
# If either is missing, bind a PlaceholderModule in its place so that
# importing this module still succeeds.
# NOTE(review): presumably the placeholder raises a descriptive error on
# first use -- see vllm.utils.import_utils.PlaceholderModule to confirm.
try:
    import librosa
except ImportError:
    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]

try:
    import soundfile
except ImportError:
    soundfile = PlaceholderModule("soundfile")  # type: ignore[assignment]
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
    """Load and encode raw audio as a ``(waveform, sample_rate)`` tuple.

    Decoding is delegated to ``librosa.load`` and encoding to
    ``soundfile.write``.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()

        # `kwargs` contains custom arguments from --media-io-kwargs for this
        # modality. The loaders defined in this class do not read them; they
        # are stored so that custom implementations (e.g. subclasses) can
        # forward them to their underlying media loaders.
        self.kwargs = kwargs

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
        """Decode an in-memory audio file into ``(waveform, sample_rate)``.

        ``sr=None`` tells librosa to keep the native sampling rate instead
        of resampling.
        """
        return librosa.load(BytesIO(data), sr=None)

    def load_base64(
        self,
        media_type: str,
        data: str,
    ) -> tuple[npt.NDArray, float]:
        """Decode base64-encoded audio bytes into ``(waveform, sample_rate)``.

        ``media_type`` is ignored by this implementation.
        """
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
        """Decode an audio file from disk at its native sampling rate."""
        return librosa.load(filepath, sr=None)

    def encode_base64(
        self,
        media: tuple[npt.NDArray, float],
        *,
        audio_format: str = "WAV",
    ) -> str:
        """Encode ``(waveform, sample_rate)`` as base64 text.

        Args:
            media: The waveform and its sampling rate in Hz. The loaders in
                this class type the rate as ``float``, so whole-number
                floats (e.g. ``16000.0``) are accepted.
            audio_format: Container format passed to ``soundfile.write``.

        Returns:
            The encoded audio file as a base64 string.

        Raises:
            ValueError: If the sampling rate is not a whole number.
        """
        audio, sr = media

        # soundfile.write documents `samplerate` as an int, while this
        # class's media type carries a float rate; normalise here and refuse
        # to silently truncate a fractional rate.
        sample_rate = int(sr)
        if sample_rate != sr:
            raise ValueError(
                f"Sample rate must be a whole number of Hz, got {sr!r}"
            )

        with BytesIO() as buffer:
            soundfile.write(buffer, audio, sample_rate, format=audio_format)
            data = buffer.getvalue()

        return base64.b64encode(data).decode("utf-8")
class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """Load and encode precomputed audio embeddings as ``torch.Tensor``."""

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _load_tensor(source: BytesIO | Path) -> torch.Tensor:
        """Deserialize a single tensor from ``source`` and return it dense.

        Shared by :meth:`load_bytes` and :meth:`load_file` so both entry
        points apply identical safety checks.

        Raises:
            ValueError: If the deserialized object is not a ``torch.Tensor``.
        """
        # Enable sparse tensor integrity checks to prevent out-of-bounds
        # writes from maliciously crafted tensors. Densification is kept
        # inside the same region (defensive: every use of the untrusted
        # tensor happens while the checks are enabled).
        with torch.sparse.check_sparse_tensor_invariants():
            tensor = torch.load(source, weights_only=True)
            # weights_only=True limits the unpickler to tensors, primitive
            # types and containers, so the payload may still be a dict or
            # list; fail clearly instead of with an AttributeError below.
            if not isinstance(tensor, torch.Tensor):
                raise ValueError(
                    "Expected the audio embedding payload to contain a "
                    f"torch.Tensor, got {type(tensor).__name__}"
                )
            return tensor.to_dense()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        """Deserialize an embedding tensor from in-memory serialized bytes."""
        return self._load_tensor(BytesIO(data))

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        """Deserialize an embedding tensor from base64 text.

        ``validate=True`` makes decoding reject non-base64 characters rather
        than skipping them. ``media_type`` is ignored by this implementation.
        """
        return self.load_bytes(pybase64.b64decode(data, validate=True))

    def load_file(self, filepath: Path) -> torch.Tensor:
        """Deserialize an embedding tensor from a file on disk."""
        return self._load_tensor(filepath)

    def encode_base64(self, media: torch.Tensor) -> str:
        """Serialize ``media`` to base64 text via ``tensor2base64``."""
        return tensor2base64(media)