Add Python API and Python examples for audio tagging (#753)
This commit is contained in:
9
.github/scripts/test-python.sh
vendored
9
.github/scripts/test-python.sh
vendored
@@ -8,6 +8,15 @@ log() {
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "test audio tagging"
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
|
||||
tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
|
||||
rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
|
||||
python3 ./python-api-examples/audio-tagging-from-a-file.py
|
||||
rm -rf sherpa-onnx-zipformer-audio-tagging-2024-04-09
|
||||
|
||||
|
||||
log "test streaming zipformer2 ctc HLG decoding"
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
|
||||
|
||||
@@ -17,7 +17,6 @@ fi
|
||||
if [ ! -f $onnxruntime_dir/onnxruntime.xcframework/ios-arm64/onnxruntime.a ]; then
|
||||
mkdir -p $onnxruntime_dir
|
||||
pushd $onnxruntime_dir
|
||||
# rm -f onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
|
||||
wget -c https://${SHERPA_ONNX_GITHUB}/csukuangfj/onnxruntime-libs/releases/download/v${onnxruntime_version}/onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
|
||||
tar xvf onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
|
||||
rm onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
|
||||
|
||||
@@ -3,7 +3,6 @@ function(download_kaldi_native_fbank)
|
||||
|
||||
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
|
||||
set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
|
||||
# set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.19.1.tar.gz")
|
||||
set(kaldi_native_fbank_HASH "SHA256=0cae8cbb9ea42916b214e088912f9e8f2f648f54756b305f93f552382f31f904")
|
||||
|
||||
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
|
||||
121
python-api-examples/audio-tagging-from-a-file.py
Executable file
121
python-api-examples/audio-tagging-from-a-file.py
Executable file
@@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
This script shows how to use audio tagging Python APIs to tag a file.
|
||||
|
||||
Please read the code to download the required model files and test wave file.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import sherpa_onnx
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def read_test_wave():
    """Load the bundled test wave and return ``(samples, sample_rate)``.

    Please download the model files and test wave files from
    https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
    """
    test_wave = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav"

    if not Path(test_wave).is_file():
        raise ValueError(
            f"Please download {test_wave} from "
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
        )

    # See https://python-soundfile.readthedocs.io/en/0.11.0/#soundfile.read
    # always_2d=True guarantees shape (num_frames, num_channels) even for mono.
    data, sample_rate = sf.read(test_wave, always_2d=True, dtype="float32")

    # Keep only the first channel and make the buffer C-contiguous so it can
    # be handed straight to the C++ backend.
    samples = np.ascontiguousarray(data[:, 0])

    # samples is a 1-d array of dtype float32; sample_rate is a scalar.
    return samples, sample_rate
|
||||
|
||||
|
||||
def create_audio_tagger():
    """Construct and return a ``sherpa_onnx.AudioTagging`` instance.

    Please download the model files and test wave files from
    https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
    """
    model_file = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.onnx"
    label_file = (
        "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv"
    )

    # Both the model and the label CSV must exist before we build the config.
    for required in (model_file, label_file):
        if not Path(required).is_file():
            raise ValueError(
                f"Please download {required} from "
                "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
            )

    zipformer_cfg = sherpa_onnx.OfflineZipformerAudioTaggingModelConfig(
        model=model_file,
    )
    model_cfg = sherpa_onnx.AudioTaggingModelConfig(
        zipformer=zipformer_cfg,
        num_threads=1,
        debug=True,
        provider="cpu",
    )
    config = sherpa_onnx.AudioTaggingConfig(
        model=model_cfg,
        labels=label_file,
        top_k=5,
    )

    if not config.validate():
        raise ValueError(f"Please check the config: {config}")

    print(config)

    return sherpa_onnx.AudioTagging(config)
|
||||
|
||||
|
||||
def main():
    """Tag one wave file and log the top audio events plus the RTF."""
    logging.info("Create audio tagger")
    tagger = create_audio_tagger()

    logging.info("Read test wave")
    samples, sample_rate = read_test_wave()

    logging.info("Computing")

    start_time = time.time()
    stream = tagger.create_stream()
    stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
    events = tagger.compute(stream)
    elapsed_seconds = time.time() - start_time

    audio_duration = len(samples) / sample_rate
    real_time_factor = elapsed_seconds / audio_duration

    logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
    logging.info(f"Audio duration in seconds: {audio_duration:.3f}")
    logging.info(
        f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
    )

    # One event per line, surrounded by blank lines, to match the original
    # "\n0: ...\n1: ...\n" log layout.
    lines = [""]
    for idx, event in enumerate(events):
        lines.append(f"{idx}: {event}")
    lines.append("")
    logging.info("\n".join(lines))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
||||
@@ -1,6 +1,7 @@
|
||||
include_directories(${CMAKE_SOURCE_DIR})
|
||||
|
||||
set(srcs
|
||||
audio-tagging.cc
|
||||
circular-buffer.cc
|
||||
display.cc
|
||||
endpoint.cc
|
||||
|
||||
87
sherpa-onnx/python/csrc/audio-tagging.cc
Normal file
87
sherpa-onnx/python/csrc/audio-tagging.cc
Normal file
@@ -0,0 +1,87 @@
|
||||
// sherpa-onnx/python/csrc/audio-tagging.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/audio-tagging.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Exposes OfflineZipformerAudioTaggingModelConfig to Python: a default
// constructor, a constructor taking the model path, read/write access to
// the single "model" field, plus validate() and __str__ helpers.
static void PybindOfflineZipformerAudioTaggingModelConfig(py::module *m) {
  using PyClass = OfflineZipformerAudioTaggingModelConfig;
  py::class_<PyClass>(*m, "OfflineZipformerAudioTaggingModelConfig")
      .def(py::init<>())
      .def(py::init<const std::string &>(), py::arg("model"))
      .def_readwrite("model", &PyClass::model)
      .def("validate", &PyClass::Validate)
      .def("__str__", &PyClass::ToString);
}
|
||||
|
||||
// Exposes AudioTaggingModelConfig to Python. The nested zipformer config is
// registered first so that pybind11 knows the type when it appears as a
// constructor argument below.
static void PybindAudioTaggingModelConfig(py::module *m) {
  PybindOfflineZipformerAudioTaggingModelConfig(m);

  using PyClass = AudioTaggingModelConfig;

  py::class_<PyClass>(*m, "AudioTaggingModelConfig")
      .def(py::init<>())
      .def(py::init<const OfflineZipformerAudioTaggingModelConfig &, int32_t,
                    bool, const std::string &>(),
           py::arg("zipformer"), py::arg("num_threads") = 1,
           py::arg("debug") = false, py::arg("provider") = "cpu")
      .def_readwrite("zipformer", &PyClass::zipformer)
      .def_readwrite("num_threads", &PyClass::num_threads)
      .def_readwrite("debug", &PyClass::debug)
      .def_readwrite("provider", &PyClass::provider)
      .def("validate", &PyClass::Validate)
      .def("__str__", &PyClass::ToString);
}
|
||||
|
||||
// Exposes AudioTaggingConfig to Python. Registers the model config (and,
// transitively, the zipformer config) first so the nested types are known
// to pybind11 before being used in the constructor signature.
static void PybindAudioTaggingConfig(py::module *m) {
  PybindAudioTaggingModelConfig(m);

  using PyClass = AudioTaggingConfig;

  py::class_<PyClass>(*m, "AudioTaggingConfig")
      .def(py::init<>())
      .def(py::init<const AudioTaggingModelConfig &, const std::string &,
                    int32_t>(),
           py::arg("model"), py::arg("labels"), py::arg("top_k") = 5)
      .def_readwrite("model", &PyClass::model)
      .def_readwrite("labels", &PyClass::labels)
      .def_readwrite("top_k", &PyClass::top_k)
      .def("validate", &PyClass::Validate)
      .def("__str__", &PyClass::ToString);
}
|
||||
|
||||
// Exposes AudioEvent to Python as a read-only result type: name, index and
// prob are surfaced as read-only properties (results are not meant to be
// mutated from Python), plus a __str__ for convenient printing/logging.
static void PybindAudioEvent(py::module *m) {
  using PyClass = AudioEvent;

  py::class_<PyClass>(*m, "AudioEvent")
      .def_property_readonly(
          "name", [](const PyClass &self) -> std::string { return self.name; })
      .def_property_readonly(
          "index", [](const PyClass &self) -> int32_t { return self.index; })
      .def_property_readonly(
          "prob", [](const PyClass &self) -> float { return self.prob; })
      .def("__str__", &PyClass::ToString);
}
|
||||
|
||||
// Entry point for the audio-tagging bindings: registers all config classes,
// the AudioEvent result type, and finally the AudioTagging class itself.
// The GIL is released during construction, create_stream and compute so
// other Python threads can run while the C++ side does the heavy work.
void PybindAudioTagging(py::module *m) {
  PybindAudioTaggingConfig(m);
  PybindAudioEvent(m);

  using PyClass = AudioTagging;

  py::class_<PyClass>(*m, "AudioTagging")
      .def(py::init<const AudioTaggingConfig &>(), py::arg("config"),
           py::call_guard<py::gil_scoped_release>())
      .def("create_stream", &PyClass::CreateStream,
           py::call_guard<py::gil_scoped_release>())
      // top_k = -1 means "use the top_k from the config".
      .def("compute", &PyClass::Compute, py::arg("s"), py::arg("top_k") = -1,
           py::call_guard<py::gil_scoped_release>());
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/audio-tagging.h
Normal file
16
sherpa-onnx/python/csrc/audio-tagging.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/audio-tagging.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_
#define SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

// Registers the audio-tagging configs, the AudioEvent result type, and the
// AudioTagging class on the given pybind11 module.
// Implemented in audio-tagging.cc.
void PybindAudioTagging(py::module *m);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_
|
||||
@@ -16,7 +16,7 @@ void PybindOfflineTtsVitsModelConfig(py::module *m) {
|
||||
py::class_<PyClass>(*m, "OfflineTtsVitsModelConfig")
|
||||
.def(py::init<>())
|
||||
.def(py::init<const std::string &, const std::string &,
|
||||
const std::string &, const std::string, float, float,
|
||||
const std::string &, const std::string &, float, float,
|
||||
float>(),
|
||||
py::arg("model"), py::arg("lexicon"), py::arg("tokens"),
|
||||
py::arg("data_dir") = "", py::arg("noise_scale") = 0.667,
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
#include "sherpa-onnx/python/csrc/alsa.h"
|
||||
#include "sherpa-onnx/python/csrc/audio-tagging.h"
|
||||
#include "sherpa-onnx/python/csrc/circular-buffer.h"
|
||||
#include "sherpa-onnx/python/csrc/display.h"
|
||||
#include "sherpa-onnx/python/csrc/endpoint.h"
|
||||
@@ -38,6 +39,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
m.doc() = "pybind11 binding of sherpa-onnx";
|
||||
|
||||
PybindWaveWriter(&m);
|
||||
PybindAudioTagging(&m);
|
||||
|
||||
PybindFeatures(&m);
|
||||
PybindOnlineCtcFstDecoderConfig(&m);
|
||||
|
||||
@@ -14,7 +14,7 @@ static void PybindSpeakerEmbeddingExtractorConfig(py::module *m) {
|
||||
using PyClass = SpeakerEmbeddingExtractorConfig;
|
||||
py::class_<PyClass>(*m, "SpeakerEmbeddingExtractorConfig")
|
||||
.def(py::init<>())
|
||||
.def(py::init<const std::string &, int32_t, bool, const std::string>(),
|
||||
.def(py::init<const std::string &, int32_t, bool, const std::string &>(),
|
||||
py::arg("model"), py::arg("num_threads") = 1,
|
||||
py::arg("debug") = false, py::arg("provider") = "cpu")
|
||||
.def_readwrite("model", &PyClass::model)
|
||||
|
||||
@@ -33,7 +33,7 @@ static void PybindSpokenLanguageIdentificationConfig(py::module *m) {
|
||||
py::class_<PyClass>(*m, "SpokenLanguageIdentificationConfig")
|
||||
.def(py::init<>())
|
||||
.def(py::init<const SpokenLanguageIdentificationWhisperConfig &, int32_t,
|
||||
bool, const std::string>(),
|
||||
bool, const std::string &>(),
|
||||
py::arg("whisper"), py::arg("num_threads") = 1,
|
||||
py::arg("debug") = false, py::arg("provider") = "cpu")
|
||||
.def_readwrite("whisper", &PyClass::whisper)
|
||||
@@ -53,7 +53,7 @@ void PybindSpokenLanguageIdentification(py::module *m) {
|
||||
py::arg("config"), py::call_guard<py::gil_scoped_release>())
|
||||
.def("create_stream", &PyClass::CreateStream,
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("compute", &PyClass::Compute,
|
||||
.def("compute", &PyClass::Compute, py::arg("s"),
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from _sherpa_onnx import (
|
||||
Alsa,
|
||||
AudioEvent,
|
||||
AudioTagging,
|
||||
AudioTaggingConfig,
|
||||
AudioTaggingModelConfig,
|
||||
CircularBuffer,
|
||||
Display,
|
||||
OfflineStream,
|
||||
@@ -7,6 +11,7 @@ from _sherpa_onnx import (
|
||||
OfflineTtsConfig,
|
||||
OfflineTtsModelConfig,
|
||||
OfflineTtsVitsModelConfig,
|
||||
OfflineZipformerAudioTaggingModelConfig,
|
||||
OnlineStream,
|
||||
SileroVadModelConfig,
|
||||
SpeakerEmbeddingExtractor,
|
||||
|
||||
Reference in New Issue
Block a user