From 54bc504065bfd8c07bfc2957560a9e7f7ccb550c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 19 Apr 2024 18:33:18 +0800 Subject: [PATCH] Add Python API example for CED audio tagging. (#793) --- .../audio-tagging-from-a-file-ced.py | 119 ++++++++++++++++++ sherpa-onnx/python/csrc/audio-tagging.cc | 6 +- 2 files changed, 122 insertions(+), 3 deletions(-) create mode 100755 python-api-examples/audio-tagging-from-a-file-ced.py diff --git a/python-api-examples/audio-tagging-from-a-file-ced.py b/python-api-examples/audio-tagging-from-a-file-ced.py new file mode 100755 index 00000000..27ad40a9 --- /dev/null +++ b/python-api-examples/audio-tagging-from-a-file-ced.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 + +""" +This script shows how to use audio tagging Python APIs to tag a file. + +Please read the code to download the required model files and test wave file. +""" + +import logging +import time +from pathlib import Path + +import numpy as np +import sherpa_onnx +import soundfile as sf + + +def read_test_wave(): + # Please download the model files and test wave files from + # https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models + test_wave = "./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/test_wavs/6.wav" + + if not Path(test_wave).is_file(): + raise ValueError( + f"Please download {test_wave} from " + "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models" + ) + + # See https://python-soundfile.readthedocs.io/en/0.11.0/#soundfile.read + data, sample_rate = sf.read( + test_wave, + always_2d=True, + dtype="float32", + ) + data = data[:, 0] # use only the first channel + samples = np.ascontiguousarray(data) + + # samples is a 1-d array of dtype float32 + # sample_rate is a scalar + return samples, sample_rate + + +def create_audio_tagger(): + # Please download the model files and test wave files from + # https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models + model_file = "./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/model.int8.onnx" + label_file = ( + "./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/class_labels_indices.csv" + ) + + if not Path(model_file).is_file(): + raise ValueError( + f"Please download {model_file} from " + "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models" + ) + + if not Path(label_file).is_file(): + raise ValueError( + f"Please download {label_file} from " + "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models" + ) + + config = sherpa_onnx.AudioTaggingConfig( + model=sherpa_onnx.AudioTaggingModelConfig( + ced=model_file, + num_threads=1, + debug=True, + provider="cpu", + ), + labels=label_file, + top_k=5, + ) + if not config.validate(): + raise ValueError(f"Please check the config: {config}") + + print(config) + + return sherpa_onnx.AudioTagging(config) + + +def main(): + logging.info("Create audio tagger") + audio_tagger = create_audio_tagger() + + logging.info("Read test wave") + samples, sample_rate = read_test_wave() + + logging.info("Computing") + + start_time = time.time() + + stream = audio_tagger.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=samples) + result = audio_tagger.compute(stream) + end_time = time.time() + + elapsed_seconds = end_time - start_time + audio_duration = len(samples) / sample_rate + + real_time_factor = elapsed_seconds / audio_duration + logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}") + logging.info(f"Audio duration in seconds: {audio_duration:.3f}") + logging.info( + f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}" + ) + + s = "\n" + for i, e in enumerate(result): + s += f"{i}: {e}\n" + + logging.info(s) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/sherpa-onnx/python/csrc/audio-tagging.cc b/sherpa-onnx/python/csrc/audio-tagging.cc index 0808627e..336ad424 100644 --- a/sherpa-onnx/python/csrc/audio-tagging.cc +++ b/sherpa-onnx/python/csrc/audio-tagging.cc @@ -29,9 +29,9 @@ static void PybindAudioTaggingModelConfig(py::module *m) { .def(py::init<>()) .def(py::init(), - py::arg("zipformer"), py::arg("ced") = "", - py::arg("num_threads") = 1, py::arg("debug") = false, - py::arg("provider") = "cpu") + py::arg("zipformer") = OfflineZipformerAudioTaggingModelConfig{}, + py::arg("ced") = "", py::arg("num_threads") = 1, + py::arg("debug") = false, py::arg("provider") = "cpu") .def_readwrite("zipformer", &PyClass::zipformer) .def_readwrite("num_threads", &PyClass::num_threads) .def_readwrite("debug", &PyClass::debug)