Files
2026-01-19 10:38:50 +08:00

148 lines
4.4 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
from io import BytesIO
from pathlib import Path
from typing import Literal
import numpy as np
import numpy.typing as npt
import pybase64
import torch
from vllm.utils.import_utils import PlaceholderModule
from vllm.utils.serial_utils import tensor2base64
from .base import MediaIO
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
try:
import soundfile
except ImportError:
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
def resample_audio_librosa(
    audio: npt.NDArray[np.floating],
    *,
    orig_sr: float,
    target_sr: float,
) -> npt.NDArray[np.floating]:
    """Resample ``audio`` from ``orig_sr`` to ``target_sr`` with librosa.

    Args:
        audio: Audio samples to convert.
        orig_sr: Sample rate ``audio`` was recorded at.
        target_sr: Sample rate to convert to.

    Returns:
        Whatever :func:`librosa.resample` returns for these arguments.
    """
    rates = {"orig_sr": orig_sr, "target_sr": target_sr}
    return librosa.resample(audio, **rates)
def resample_audio_scipy(
    audio: npt.NDArray[np.floating],
    *,
    orig_sr: float,
    target_sr: float,
) -> npt.NDArray[np.floating]:
    """Resample ``audio`` from ``orig_sr`` to ``target_sr`` via polyphase filtering.

    The conversion ratio ``target_sr / orig_sr`` is expressed as a reduced
    fraction ``up / down`` and passed to :func:`scipy.signal.resample_poly`.
    Non-integer ratios (e.g. 44.1 kHz -> 16 kHz, i.e. 160/441) therefore land
    on the requested rate. Integer-division factors would instead silently
    pick the wrong rate.

    Args:
        audio: Samples to resample (presumably 1-D mono audio; scipy's
            ``resample_poly`` works along its default axis).
        orig_sr: Sample rate of ``audio``. Must be positive.
        target_sr: Desired sample rate. Must be positive.

    Returns:
        The resampled samples, or ``audio`` itself when both rates are equal.

    Raises:
        ValueError: If either sample rate is not positive.
    """
    # lazy import scipy.signal, otherwise it will crash doc build.
    import scipy.signal
    from fractions import Fraction

    if orig_sr <= 0 or target_sr <= 0:
        raise ValueError(
            "Sample rates must be positive, got "
            f"orig_sr={orig_sr!r} and target_sr={target_sr!r}"
        )
    if orig_sr == target_sr:
        return audio

    # Fraction(float) is exact, so integer-valued rates (int or float) reduce
    # to the exact up/down pair (e.g. 16000/44100 -> 160/441).
    exact = Fraction(float(target_sr)) / Fraction(float(orig_sr))
    # Bound the factors so oddball non-integer rates cannot explode the
    # polyphase filter length. Common rate pairs stay well below this bound.
    max_denominator = 10_000
    ratio = exact.limit_denominator(max_denominator)
    if ratio.numerator == 0:
        # Extreme downsampling ratios can round to zero; keep the exact one.
        ratio = exact
    return scipy.signal.resample_poly(audio, ratio.numerator, ratio.denominator)
class AudioResampler:
"""Resample audio data to a target sample rate."""
def __init__(
self,
target_sr: float | None = None,
method: Literal["librosa", "scipy"] = "librosa",
):
self.target_sr = target_sr
self.method = method
def resample(
self,
audio: npt.NDArray[np.floating],
*,
orig_sr: float,
) -> npt.NDArray[np.floating]:
if self.target_sr is None:
raise RuntimeError(
"Audio resampling is not supported when `target_sr` is not provided"
)
if self.method == "librosa":
return resample_audio_librosa(
audio, orig_sr=orig_sr, target_sr=self.target_sr
)
elif self.method == "scipy":
return resample_audio_scipy(
audio, orig_sr=orig_sr, target_sr=self.target_sr
)
else:
raise ValueError(
f"Invalid resampling method: {self.method}. "
"Supported methods are 'librosa' and 'scipy'."
)
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
    """Media IO for audio clips represented as ``(samples, sample_rate)``.

    Decoding goes through :func:`librosa.load`; encoding writes a WAV file
    with :func:`soundfile.write` and base64-encodes it.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()

        # `kwargs` contains custom arguments from
        # --media-io-kwargs for this modality.
        # They can be passed to the underlying
        # media loaders (e.g. custom implementations)
        # for flexible control.
        self.kwargs = kwargs

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
        """Decode in-memory audio bytes into ``(samples, sample_rate)``."""
        # sr=None asks librosa not to resample, so the native rate is kept.
        return librosa.load(BytesIO(data), sr=None)

    def load_base64(
        self,
        media_type: str,
        data: str,
    ) -> tuple[npt.NDArray, float]:
        """Decode a base64 audio payload; ``media_type`` is not used here."""
        # NOTE(review): this decode is non-strict (characters outside the
        # base64 alphabet are discarded), unlike AudioEmbeddingMediaIO which
        # validates. Left as-is to avoid rejecting previously accepted input.
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
        """Decode an audio file from disk into ``(samples, sample_rate)``."""
        return librosa.load(filepath, sr=None)

    def encode_base64(self, media: tuple[npt.NDArray, float]) -> str:
        """Encode ``(samples, sample_rate)`` as a base64 string of a WAV file.

        Raises:
            ValueError: If the sample rate is not a whole number (WAV stores
                an integer rate).
        """
        audio, sr = media

        # Sample rates in this module are typed as float, but soundfile
        # documents `samplerate` as an int. Normalise integral values such as
        # 16000.0 instead of passing a float into soundfile.
        samplerate = int(sr)
        if samplerate != sr:
            raise ValueError(
                f"WAV encoding requires an integer sample rate, got {sr!r}"
            )

        with BytesIO() as buffer:
            soundfile.write(buffer, audio, samplerate, format="WAV")
            data = buffer.getvalue()

        return base64.b64encode(data).decode("utf-8")
class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """Media IO for precomputed audio embeddings serialized as torch tensors."""

    def __init__(self) -> None:
        super().__init__()

    def _load_tensor(self, source: BytesIO | Path) -> torch.Tensor:
        """Deserialize one tensor from ``source`` and return it densified.

        Raises:
            ValueError: If the payload does not deserialize to a tensor.
        """
        # Enable sparse tensor integrity checks to prevent out-of-bounds
        # writes from maliciously crafted tensors. `to_dense` is called while
        # the checks are still enabled.
        with torch.sparse.check_sparse_tensor_invariants():
            obj = torch.load(source, weights_only=True)
            # weights_only=True still admits containers (dicts, lists); reject
            # them explicitly instead of failing with an AttributeError.
            if not isinstance(obj, torch.Tensor):
                raise ValueError(
                    "Expected a serialized torch.Tensor, "
                    f"got {type(obj).__name__}"
                )
            return obj.to_dense()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        """Load an embedding tensor from serialized bytes."""
        return self._load_tensor(BytesIO(data))

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        """Load an embedding tensor from strictly validated base64 text."""
        return self.load_bytes(pybase64.b64decode(data, validate=True))

    def load_file(self, filepath: Path) -> torch.Tensor:
        """Load an embedding tensor from a file on disk."""
        return self._load_tensor(filepath)

    def encode_base64(self, media: torch.Tensor) -> str:
        """Serialize ``media`` to base64 text."""
        return tensor2base64(media)