# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import base64
from io import BytesIO
from pathlib import Path

import numpy.typing as npt
import pybase64
import torch

from vllm.utils.import_utils import PlaceholderModule
from vllm.utils.serial_utils import tensor2base64

from .base import MediaIO

# Optional dependencies: when one is missing, a placeholder object is bound
# to the same name in its place.
try:
    import librosa
except ImportError:
    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]

try:
    import soundfile
except ImportError:
    soundfile = PlaceholderModule("soundfile")  # type: ignore[assignment]


class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
    """Load and encode audio as ``(samples, sample_rate)`` pairs.

    Decoding is delegated to ``librosa.load`` and encoding to
    ``soundfile.write``.
    """

    def __init__(self, **kwargs) -> None:
        """Store modality-specific keyword arguments.

        Args:
            **kwargs: Custom arguments from ``--media-io-kwargs`` for this
                modality. They are kept on ``self.kwargs`` so that
                underlying media loaders (e.g. custom implementations) can
                use them for flexible control.
        """
        super().__init__()
        self.kwargs = kwargs

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
        """Decode audio from an in-memory byte string.

        Args:
            data: Raw bytes of an encoded audio file.

        Returns:
            Whatever ``librosa.load`` returns for the data, annotated as a
            ``(samples, sample_rate)`` tuple.
        """
        stream = BytesIO(data)
        # sr=None is forwarded to librosa.load unchanged.
        return librosa.load(stream, sr=None)

    def load_base64(
        self,
        media_type: str,
        data: str,
    ) -> tuple[npt.NDArray, float]:
        """Decode audio from a base64 string.

        Args:
            media_type: Media type of the payload; not used here.
            data: Base64-encoded audio file contents.

        Returns:
            The result of ``load_bytes`` on the decoded bytes.
        """
        raw_audio = base64.b64decode(data)
        return self.load_bytes(raw_audio)

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
        """Decode audio from a file on disk.

        Args:
            filepath: Path of the audio file to read.

        Returns:
            Whatever ``librosa.load`` returns for the file, annotated as a
            ``(samples, sample_rate)`` tuple.
        """
        return librosa.load(filepath, sr=None)

    def encode_base64(
        self,
        media: tuple[npt.NDArray, int],
        *,
        audio_format: str = "WAV",
    ) -> str:
        """Serialize audio to a base64 string.

        Args:
            media: ``(samples, sample_rate)`` pair to encode.
            audio_format: Container format passed to ``soundfile.write``.

        Returns:
            The encoded audio file as base64 text.
        """
        samples, sample_rate = media

        with BytesIO() as sink:
            soundfile.write(sink, samples, sample_rate, format=audio_format)
            # Read the bytes out before the buffer is closed.
            encoded = base64.b64encode(sink.getvalue())

        return encoded.decode("utf-8")


class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """Load and encode audio embeddings stored as serialized tensors."""

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _load_dense(source) -> torch.Tensor:
        """Deserialize a tensor from ``source`` and return it densified.

        Args:
            source: A file-like object or path accepted by ``torch.load``.

        Returns:
            The result of calling ``to_dense()`` on the loaded object.
        """
        # Enable sparse tensor integrity checks to prevent out-of-bounds
        # writes from maliciously crafted tensors
        with torch.sparse.check_sparse_tensor_invariants():
            loaded = torch.load(source, weights_only=True)
            return loaded.to_dense()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        """Load an embedding tensor from serialized bytes.

        Args:
            data: Bytes in the format read by ``torch.load``.

        Returns:
            The loaded tensor after ``to_dense()``.
        """
        return self._load_dense(BytesIO(data))

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        """Load an embedding tensor from a base64 string.

        Args:
            media_type: Media type of the payload; not used here.
            data: Base64 text, decoded with ``validate=True``.

        Returns:
            The result of ``load_bytes`` on the decoded bytes.
        """
        raw_tensor = pybase64.b64decode(data, validate=True)
        return self.load_bytes(raw_tensor)

    def load_file(self, filepath: Path) -> torch.Tensor:
        """Load an embedding tensor from a file on disk.

        Args:
            filepath: Path of the serialized tensor file.

        Returns:
            The loaded tensor after ``to_dense()``.
        """
        return self._load_dense(filepath)

    def encode_base64(self, media: torch.Tensor) -> str:
        """Encode a tensor as base64 by delegating to ``tensor2base64``."""
        return tensor2base64(media)