Sync from v0.13
This commit is contained in:
147
vllm/multimodal/audio.py
Normal file
147
vllm/multimodal/audio.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import pybase64
|
||||
import torch
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
from vllm.utils.serial_utils import tensor2base64
|
||||
|
||||
from .base import MediaIO
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
import soundfile
|
||||
except ImportError:
|
||||
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
|
||||
|
||||
|
||||
def resample_audio_librosa(
|
||||
audio: npt.NDArray[np.floating],
|
||||
*,
|
||||
orig_sr: float,
|
||||
target_sr: float,
|
||||
) -> npt.NDArray[np.floating]:
|
||||
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
|
||||
|
||||
|
||||
def resample_audio_scipy(
|
||||
audio: npt.NDArray[np.floating],
|
||||
*,
|
||||
orig_sr: float,
|
||||
target_sr: float,
|
||||
):
|
||||
# lazy import scipy.signal, otherwise it will crash doc build.
|
||||
import scipy.signal
|
||||
|
||||
if orig_sr > target_sr:
|
||||
return scipy.signal.resample_poly(audio, 1, orig_sr // target_sr)
|
||||
elif orig_sr < target_sr:
|
||||
return scipy.signal.resample_poly(audio, target_sr // orig_sr, 1)
|
||||
return audio
|
||||
|
||||
|
||||
class AudioResampler:
|
||||
"""Resample audio data to a target sample rate."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_sr: float | None = None,
|
||||
method: Literal["librosa", "scipy"] = "librosa",
|
||||
):
|
||||
self.target_sr = target_sr
|
||||
self.method = method
|
||||
|
||||
def resample(
|
||||
self,
|
||||
audio: npt.NDArray[np.floating],
|
||||
*,
|
||||
orig_sr: float,
|
||||
) -> npt.NDArray[np.floating]:
|
||||
if self.target_sr is None:
|
||||
raise RuntimeError(
|
||||
"Audio resampling is not supported when `target_sr` is not provided"
|
||||
)
|
||||
if self.method == "librosa":
|
||||
return resample_audio_librosa(
|
||||
audio, orig_sr=orig_sr, target_sr=self.target_sr
|
||||
)
|
||||
elif self.method == "scipy":
|
||||
return resample_audio_scipy(
|
||||
audio, orig_sr=orig_sr, target_sr=self.target_sr
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid resampling method: {self.method}. "
|
||||
"Supported methods are 'librosa' and 'scipy'."
|
||||
)
|
||||
|
||||
|
||||
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
# `kwargs` contains custom arguments from
|
||||
# --media-io-kwargs for this modality.
|
||||
# They can be passed to the underlying
|
||||
# media loaders (e.g. custom implementations)
|
||||
# for flexible control.
|
||||
self.kwargs = kwargs
|
||||
|
||||
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
|
||||
return librosa.load(BytesIO(data), sr=None)
|
||||
|
||||
def load_base64(
|
||||
self,
|
||||
media_type: str,
|
||||
data: str,
|
||||
) -> tuple[npt.NDArray, float]:
|
||||
return self.load_bytes(base64.b64decode(data))
|
||||
|
||||
def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
|
||||
return librosa.load(filepath, sr=None)
|
||||
|
||||
def encode_base64(self, media: tuple[npt.NDArray, int]) -> str:
|
||||
audio, sr = media
|
||||
|
||||
with BytesIO() as buffer:
|
||||
soundfile.write(buffer, audio, sr, format="WAV")
|
||||
data = buffer.getvalue()
|
||||
|
||||
return base64.b64encode(data).decode("utf-8")
|
||||
|
||||
|
||||
class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def load_bytes(self, data: bytes) -> torch.Tensor:
|
||||
buffer = BytesIO(data)
|
||||
# Enable sparse tensor integrity checks to prevent out-of-bounds
|
||||
# writes from maliciously crafted tensors
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.load(buffer, weights_only=True)
|
||||
return tensor.to_dense()
|
||||
|
||||
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
|
||||
return self.load_bytes(pybase64.b64decode(data, validate=True))
|
||||
|
||||
def load_file(self, filepath: Path) -> torch.Tensor:
|
||||
# Enable sparse tensor integrity checks to prevent out-of-bounds
|
||||
# writes from maliciously crafted tensors
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.load(filepath, weights_only=True)
|
||||
return tensor.to_dense()
|
||||
|
||||
def encode_base64(self, media: torch.Tensor) -> str:
|
||||
return tensor2base64(media)
|
||||
Reference in New Issue
Block a user