78 lines
2.0 KiB
Python
78 lines
2.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import base64
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import numpy.typing as npt
|
|
|
|
from vllm.inputs.registry import InputContext
|
|
from vllm.utils import PlaceholderModule
|
|
|
|
from .base import MediaIO, MultiModalPlugin
|
|
from .inputs import AudioItem, ModalityData, MultiModalKwargs
|
|
|
|
try:
|
|
import librosa
|
|
except ImportError:
|
|
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
|
|
|
try:
|
|
import soundfile
|
|
except ImportError:
|
|
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
|
|
|
|
|
|
class AudioPlugin(MultiModalPlugin):
    """Multi-modal plugin responsible for the ``"audio"`` modality.

    No default behaviour is supplied for audio: both default hooks below
    raise ``NotImplementedError``.
    """

    def get_data_key(self) -> str:
        """Return the key that identifies audio data (``"audio"``)."""
        return "audio"

    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[AudioItem],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        """Default input mapper hook; always unsupported for audio.

        Raises:
            NotImplementedError: Always, regardless of the arguments.
        """
        message = "There is no default audio input mapper"
        raise NotImplementedError(message)

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """Default max-token hook; always unsupported for audio.

        Raises:
            NotImplementedError: Always, regardless of ``ctx``.
        """
        message = "There is no default maximum multimodal tokens"
        raise NotImplementedError(message)
|
|
|
|
|
|
def resample_audio(
    audio: npt.NDArray[np.floating],
    *,
    orig_sr: float,
    target_sr: float,
) -> npt.NDArray[np.floating]:
    """Resample ``audio`` from ``orig_sr`` to ``target_sr``.

    This is a thin wrapper that forwards its arguments unchanged to
    ``librosa.resample``.

    Args:
        audio: Floating-point audio samples to resample.
        orig_sr: Sampling rate of ``audio`` (keyword-only).
        target_sr: Desired sampling rate of the result (keyword-only).

    Returns:
        Whatever ``librosa.resample`` returns for the given arguments.
    """
    resampled = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
    return resampled
|
|
|
|
|
|
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
    """Media I/O for audio, represented as ``(samples, sampling_rate)``.

    Decoding is delegated to ``librosa.load`` and WAV encoding to
    ``soundfile.write``.
    """

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
        """Decode raw audio bytes into ``(samples, sampling_rate)``.

        The bytes are wrapped in an in-memory stream and passed to
        ``librosa.load`` with ``sr=None``.
        """
        stream = BytesIO(data)
        return librosa.load(stream, sr=None)

    def load_base64(
        self,
        media_type: str,
        data: str,
    ) -> tuple[npt.NDArray, float]:
        """Decode base64-encoded audio into ``(samples, sampling_rate)``.

        ``media_type`` is accepted but not used here; the decoded payload
        is handed to :meth:`load_bytes`.
        """
        raw_bytes = base64.b64decode(data)
        return self.load_bytes(raw_bytes)

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
        """Load audio from ``filepath`` via ``librosa.load`` with ``sr=None``."""
        loaded = librosa.load(filepath, sr=None)
        return loaded

    def encode_base64(self, media: tuple[npt.NDArray, float]) -> str:
        """Serialize ``media`` as WAV and return it as a base64 string.

        Args:
            media: A ``(samples, sampling_rate)`` pair.

        Returns:
            The base64 encoding of the WAV bytes, decoded as UTF-8 text.
        """
        waveform, sample_rate = media

        # Write the WAV container into memory rather than onto disk.
        with BytesIO() as wav_stream:
            soundfile.write(wav_stream, waveform, sample_rate, format="WAV")
            wav_bytes = wav_stream.getvalue()

        encoded = base64.b64encode(wav_bytes)
        return encoded.decode('utf-8')
|