Add Python API and Python examples for audio tagging (#753)
This commit is contained in:
9
.github/scripts/test-python.sh
vendored
9
.github/scripts/test-python.sh
vendored
@@ -8,6 +8,15 @@ log() {
|
|||||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log "test audio tagging"
|
||||||
|
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
|
||||||
|
rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
|
||||||
|
python3 ./python-api-examples/audio-tagging-from-a-file.py
|
||||||
|
rm -rf sherpa-onnx-zipformer-audio-tagging-2024-04-09
|
||||||
|
|
||||||
|
|
||||||
log "test streaming zipformer2 ctc HLG decoding"
|
log "test streaming zipformer2 ctc HLG decoding"
|
||||||
|
|
||||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ fi
|
|||||||
if [ ! -f $onnxruntime_dir/onnxruntime.xcframework/ios-arm64/onnxruntime.a ]; then
|
if [ ! -f $onnxruntime_dir/onnxruntime.xcframework/ios-arm64/onnxruntime.a ]; then
|
||||||
mkdir -p $onnxruntime_dir
|
mkdir -p $onnxruntime_dir
|
||||||
pushd $onnxruntime_dir
|
pushd $onnxruntime_dir
|
||||||
# rm -f onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
|
|
||||||
wget -c https://${SHERPA_ONNX_GITHUB}/csukuangfj/onnxruntime-libs/releases/download/v${onnxruntime_version}/onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
|
wget -c https://${SHERPA_ONNX_GITHUB}/csukuangfj/onnxruntime-libs/releases/download/v${onnxruntime_version}/onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
|
||||||
tar xvf onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
|
tar xvf onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
|
||||||
rm onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
|
rm onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ function(download_kaldi_native_fbank)
|
|||||||
|
|
||||||
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
|
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
|
||||||
set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
|
set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
|
||||||
# set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.19.1.tar.gz")
|
|
||||||
set(kaldi_native_fbank_HASH "SHA256=0cae8cbb9ea42916b214e088912f9e8f2f648f54756b305f93f552382f31f904")
|
set(kaldi_native_fbank_HASH "SHA256=0cae8cbb9ea42916b214e088912f9e8f2f648f54756b305f93f552382f31f904")
|
||||||
|
|
||||||
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||||
|
|||||||
121
python-api-examples/audio-tagging-from-a-file.py
Executable file
121
python-api-examples/audio-tagging-from-a-file.py
Executable file
@@ -0,0 +1,121 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script shows how to use audio tagging Python APIs to tag a file.
|
||||||
|
|
||||||
|
Please read the code to download the required model files and test wave file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import sherpa_onnx
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
|
def read_test_wave(
    test_wave="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav",
):
    """Load a wave file and return the samples of its first channel.

    Please download the model files and test wave files from
    https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models

    Args:
      test_wave:
        Path of the wave file to read. Defaults to the test wave that is
        shipped with the 2024-04-09 zipformer audio tagging model.

    Returns:
      A tuple ``(samples, sample_rate)`` where ``samples`` is a contiguous
      1-d array of dtype float32 and ``sample_rate`` is a scalar.

    Raises:
      ValueError: If ``test_wave`` is not an existing file.
    """
    if not Path(test_wave).is_file():
        raise ValueError(
            f"Please download {test_wave} from "
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
        )

    # See https://python-soundfile.readthedocs.io/en/0.11.0/#soundfile.read
    # always_2d=True gives a 2-d array even for mono files, so the channel
    # selection below works for any channel count.
    data, sample_rate = sf.read(test_wave, always_2d=True, dtype="float32")

    # Use only the first channel and copy the column slice into a
    # contiguous buffer.
    samples = np.ascontiguousarray(data[:, 0])

    return samples, sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
def create_audio_tagger(
    model_dir="./sherpa-onnx-zipformer-audio-tagging-2024-04-09",
):
    """Build a ``sherpa_onnx.AudioTagging`` object from a model directory.

    Please download the model files and test wave files from
    https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models

    Args:
      model_dir:
        Directory that contains ``model.onnx`` and
        ``class_labels_indices.csv``. Defaults to the directory extracted
        from the 2024-04-09 zipformer audio tagging model archive.

    Returns:
      A ``sherpa_onnx.AudioTagging`` instance configured to run on CPU with
      one thread and to return the top 5 events.

    Raises:
      ValueError: If a required file is missing or the config fails
        validation.
    """
    model_file = f"{model_dir}/model.onnx"
    label_file = f"{model_dir}/class_labels_indices.csv"

    # Check the model first and the labels second so that the first missing
    # file is the one reported.
    for required_file in (model_file, label_file):
        if not Path(required_file).is_file():
            raise ValueError(
                f"Please download {required_file} from "
                "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
            )

    config = sherpa_onnx.AudioTaggingConfig(
        model=sherpa_onnx.AudioTaggingModelConfig(
            zipformer=sherpa_onnx.OfflineZipformerAudioTaggingModelConfig(
                model=model_file,
            ),
            num_threads=1,
            debug=True,
            provider="cpu",
        ),
        labels=label_file,
        top_k=5,
    )
    if not config.validate():
        raise ValueError(f"Please check the config: {config}")

    print(config)

    return sherpa_onnx.AudioTagging(config)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Tag the test wave and log the timing and the detected audio events.

    Creates the audio tagger, reads the test wave, runs the tagger on it,
    and logs the elapsed time, the audio duration, the real-time factor
    (RTF) and one line per returned event.
    """
    logging.info("Create audio tagger")
    audio_tagger = create_audio_tagger()

    logging.info("Read test wave")
    samples, sample_rate = read_test_wave()

    logging.info("Computing")

    # perf_counter() is monotonic, so the measured interval cannot be skewed
    # by system clock adjustments the way time.time() can.
    start_time = time.perf_counter()

    stream = audio_tagger.create_stream()
    stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
    result = audio_tagger.compute(stream)

    end_time = time.perf_counter()

    elapsed_seconds = end_time - start_time
    audio_duration = len(samples) / sample_rate

    logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
    logging.info(f"Audio duration in seconds: {audio_duration:.3f}")

    if audio_duration > 0:
        real_time_factor = elapsed_seconds / audio_duration
        logging.info(
            f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
        )
    else:
        # An empty wave has zero duration; skip the division instead of
        # raising ZeroDivisionError.
        logging.warning("Audio duration is 0; RTF is undefined")

    # Leading newline keeps the event list on its own lines in the log.
    s = "\n" + "".join(f"{i}: {e}\n" for i, e in enumerate(result))

    logging.info(s)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Configure logging before main() so every message carries a timestamp,
    # the level, and the source file and line.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    main()
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
include_directories(${CMAKE_SOURCE_DIR})
|
include_directories(${CMAKE_SOURCE_DIR})
|
||||||
|
|
||||||
set(srcs
|
set(srcs
|
||||||
|
audio-tagging.cc
|
||||||
circular-buffer.cc
|
circular-buffer.cc
|
||||||
display.cc
|
display.cc
|
||||||
endpoint.cc
|
endpoint.cc
|
||||||
|
|||||||
87
sherpa-onnx/python/csrc/audio-tagging.cc
Normal file
87
sherpa-onnx/python/csrc/audio-tagging.cc
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
// sherpa-onnx/python/csrc/audio-tagging.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/audio-tagging.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/audio-tagging.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Registers OfflineZipformerAudioTaggingModelConfig on module `m`: a default
// constructor, a constructor taking `model`, the read/write `model`
// attribute, `validate()` and `__str__`.
static void PybindOfflineZipformerAudioTaggingModelConfig(py::module *m) {
  using Config = OfflineZipformerAudioTaggingModelConfig;

  py::class_<Config> cls(*m, "OfflineZipformerAudioTaggingModelConfig");

  cls.def(py::init<>());
  cls.def(py::init<const std::string &>(), py::arg("model"));

  cls.def_readwrite("model", &Config::model);

  cls.def("validate", &Config::Validate);
  cls.def("__str__", &Config::ToString);
}
|
||||||
|
|
||||||
|
// Registers AudioTaggingModelConfig on module `m`. The nested zipformer
// config class is registered first since it is used as a constructor
// argument type here.
static void PybindAudioTaggingModelConfig(py::module *m) {
  PybindOfflineZipformerAudioTaggingModelConfig(m);

  using Config = AudioTaggingModelConfig;

  py::class_<Config> cls(*m, "AudioTaggingModelConfig");

  cls.def(py::init<>());
  // Only `zipformer` is required from Python; the rest have defaults.
  cls.def(py::init<const OfflineZipformerAudioTaggingModelConfig &, int32_t,
                   bool, const std::string &>(),
          py::arg("zipformer"), py::arg("num_threads") = 1,
          py::arg("debug") = false, py::arg("provider") = "cpu");

  cls.def_readwrite("zipformer", &Config::zipformer);
  cls.def_readwrite("num_threads", &Config::num_threads);
  cls.def_readwrite("debug", &Config::debug);
  cls.def_readwrite("provider", &Config::provider);

  cls.def("validate", &Config::Validate);
  cls.def("__str__", &Config::ToString);
}
|
||||||
|
|
||||||
|
// Registers AudioTaggingConfig on module `m`, after registering the model
// config class it takes as a constructor argument.
static void PybindAudioTaggingConfig(py::module *m) {
  PybindAudioTaggingModelConfig(m);

  using Config = AudioTaggingConfig;

  py::class_<Config> cls(*m, "AudioTaggingConfig");

  cls.def(py::init<>());
  // `model` and `labels` are required from Python; `top_k` defaults to 5.
  cls.def(py::init<const AudioTaggingModelConfig &, const std::string &,
                   int32_t>(),
          py::arg("model"), py::arg("labels"), py::arg("top_k") = 5);

  cls.def_readwrite("model", &Config::model);
  cls.def_readwrite("labels", &Config::labels);
  cls.def_readwrite("top_k", &Config::top_k);

  cls.def("validate", &Config::Validate);
  cls.def("__str__", &Config::ToString);
}
|
||||||
|
|
||||||
|
// Registers AudioEvent on module `m` with read-only `name`, `index` and
// `prob` properties plus `__str__`. No constructor is bound here.
static void PybindAudioEvent(py::module *m) {
  using Event = AudioEvent;

  py::class_<Event> cls(*m, "AudioEvent");

  // Each getter returns its field by value.
  cls.def_property_readonly(
      "name", [](const Event &self) -> std::string { return self.name; });
  cls.def_property_readonly(
      "index", [](const Event &self) -> int32_t { return self.index; });
  cls.def_property_readonly(
      "prob", [](const Event &self) -> float { return self.prob; });

  cls.def("__str__", &Event::ToString);
}
|
||||||
|
|
||||||
|
// Registers the audio tagging API on module `m`: the config classes, the
// AudioEvent class, and the AudioTagging class with its `create_stream()`
// and `compute()` methods.
void PybindAudioTagging(py::module *m) {
  PybindAudioTaggingConfig(m);
  PybindAudioEvent(m);

  using Tagger = AudioTagging;

  py::class_<Tagger> cls(*m, "AudioTagging");

  // All three bindings release the GIL for the duration of the C++ call.
  cls.def(py::init<const AudioTaggingConfig &>(), py::arg("config"),
          py::call_guard<py::gil_scoped_release>());

  cls.def("create_stream", &Tagger::CreateStream,
          py::call_guard<py::gil_scoped_release>());

  // `top_k` defaults to -1 when the Python caller omits it.
  cls.def("compute", &Tagger::Compute, py::arg("s"), py::arg("top_k") = -1,
          py::call_guard<py::gil_scoped_release>());
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
16
sherpa-onnx/python/csrc/audio-tagging.h
Normal file
16
sherpa-onnx/python/csrc/audio-tagging.h
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// sherpa-onnx/python/csrc/audio-tagging.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_
|
||||||
|
#define SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Registers the audio tagging classes (AudioTagging, its config classes and
// AudioEvent) on the given pybind11 module.
void PybindAudioTagging(py::module *m);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_
|
||||||
@@ -16,7 +16,7 @@ void PybindOfflineTtsVitsModelConfig(py::module *m) {
|
|||||||
py::class_<PyClass>(*m, "OfflineTtsVitsModelConfig")
|
py::class_<PyClass>(*m, "OfflineTtsVitsModelConfig")
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
.def(py::init<const std::string &, const std::string &,
|
.def(py::init<const std::string &, const std::string &,
|
||||||
const std::string &, const std::string, float, float,
|
const std::string &, const std::string &, float, float,
|
||||||
float>(),
|
float>(),
|
||||||
py::arg("model"), py::arg("lexicon"), py::arg("tokens"),
|
py::arg("model"), py::arg("lexicon"), py::arg("tokens"),
|
||||||
py::arg("data_dir") = "", py::arg("noise_scale") = 0.667,
|
py::arg("data_dir") = "", py::arg("noise_scale") = 0.667,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||||
|
|
||||||
#include "sherpa-onnx/python/csrc/alsa.h"
|
#include "sherpa-onnx/python/csrc/alsa.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/audio-tagging.h"
|
||||||
#include "sherpa-onnx/python/csrc/circular-buffer.h"
|
#include "sherpa-onnx/python/csrc/circular-buffer.h"
|
||||||
#include "sherpa-onnx/python/csrc/display.h"
|
#include "sherpa-onnx/python/csrc/display.h"
|
||||||
#include "sherpa-onnx/python/csrc/endpoint.h"
|
#include "sherpa-onnx/python/csrc/endpoint.h"
|
||||||
@@ -38,6 +39,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
|||||||
m.doc() = "pybind11 binding of sherpa-onnx";
|
m.doc() = "pybind11 binding of sherpa-onnx";
|
||||||
|
|
||||||
PybindWaveWriter(&m);
|
PybindWaveWriter(&m);
|
||||||
|
PybindAudioTagging(&m);
|
||||||
|
|
||||||
PybindFeatures(&m);
|
PybindFeatures(&m);
|
||||||
PybindOnlineCtcFstDecoderConfig(&m);
|
PybindOnlineCtcFstDecoderConfig(&m);
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ static void PybindSpeakerEmbeddingExtractorConfig(py::module *m) {
|
|||||||
using PyClass = SpeakerEmbeddingExtractorConfig;
|
using PyClass = SpeakerEmbeddingExtractorConfig;
|
||||||
py::class_<PyClass>(*m, "SpeakerEmbeddingExtractorConfig")
|
py::class_<PyClass>(*m, "SpeakerEmbeddingExtractorConfig")
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
.def(py::init<const std::string &, int32_t, bool, const std::string>(),
|
.def(py::init<const std::string &, int32_t, bool, const std::string &>(),
|
||||||
py::arg("model"), py::arg("num_threads") = 1,
|
py::arg("model"), py::arg("num_threads") = 1,
|
||||||
py::arg("debug") = false, py::arg("provider") = "cpu")
|
py::arg("debug") = false, py::arg("provider") = "cpu")
|
||||||
.def_readwrite("model", &PyClass::model)
|
.def_readwrite("model", &PyClass::model)
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ static void PybindSpokenLanguageIdentificationConfig(py::module *m) {
|
|||||||
py::class_<PyClass>(*m, "SpokenLanguageIdentificationConfig")
|
py::class_<PyClass>(*m, "SpokenLanguageIdentificationConfig")
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
.def(py::init<const SpokenLanguageIdentificationWhisperConfig &, int32_t,
|
.def(py::init<const SpokenLanguageIdentificationWhisperConfig &, int32_t,
|
||||||
bool, const std::string>(),
|
bool, const std::string &>(),
|
||||||
py::arg("whisper"), py::arg("num_threads") = 1,
|
py::arg("whisper"), py::arg("num_threads") = 1,
|
||||||
py::arg("debug") = false, py::arg("provider") = "cpu")
|
py::arg("debug") = false, py::arg("provider") = "cpu")
|
||||||
.def_readwrite("whisper", &PyClass::whisper)
|
.def_readwrite("whisper", &PyClass::whisper)
|
||||||
@@ -53,7 +53,7 @@ void PybindSpokenLanguageIdentification(py::module *m) {
|
|||||||
py::arg("config"), py::call_guard<py::gil_scoped_release>())
|
py::arg("config"), py::call_guard<py::gil_scoped_release>())
|
||||||
.def("create_stream", &PyClass::CreateStream,
|
.def("create_stream", &PyClass::CreateStream,
|
||||||
py::call_guard<py::gil_scoped_release>())
|
py::call_guard<py::gil_scoped_release>())
|
||||||
.def("compute", &PyClass::Compute,
|
.def("compute", &PyClass::Compute, py::arg("s"),
|
||||||
py::call_guard<py::gil_scoped_release>());
|
py::call_guard<py::gil_scoped_release>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
from _sherpa_onnx import (
|
from _sherpa_onnx import (
|
||||||
Alsa,
|
Alsa,
|
||||||
|
AudioEvent,
|
||||||
|
AudioTagging,
|
||||||
|
AudioTaggingConfig,
|
||||||
|
AudioTaggingModelConfig,
|
||||||
CircularBuffer,
|
CircularBuffer,
|
||||||
Display,
|
Display,
|
||||||
OfflineStream,
|
OfflineStream,
|
||||||
@@ -7,6 +11,7 @@ from _sherpa_onnx import (
|
|||||||
OfflineTtsConfig,
|
OfflineTtsConfig,
|
||||||
OfflineTtsModelConfig,
|
OfflineTtsModelConfig,
|
||||||
OfflineTtsVitsModelConfig,
|
OfflineTtsVitsModelConfig,
|
||||||
|
OfflineZipformerAudioTaggingModelConfig,
|
||||||
OnlineStream,
|
OnlineStream,
|
||||||
SileroVadModelConfig,
|
SileroVadModelConfig,
|
||||||
SpeakerEmbeddingExtractor,
|
SpeakerEmbeddingExtractor,
|
||||||
|
|||||||
Reference in New Issue
Block a user