Support audio tagging using zipformer (#747)
This commit is contained in:
32
.github/scripts/test-audio-tagging.sh
vendored
Executable file
32
.github/scripts/test-audio-tagging.sh
vendored
Executable file
@@ -0,0 +1,32 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -ex
|
||||
|
||||
# Print a timestamped message tagged with the calling file, line and function.
# This function is from espnet
log() {
  local caller=${BASH_SOURCE[1]}
  local where="${caller##*/}:${BASH_LINENO[0]}:${FUNCNAME[1]}"
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${where}) $*"
}
|
||||
|
||||
echo "EXE is $EXE"
echo "PATH: $PATH"

# With `set -e` in effect this aborts the test if the binary is not on PATH.
which "$EXE"

log "------------------------------------------------------------"
log "Run zipformer for audio tagging "
log "------------------------------------------------------------"

# The release archive and the directory it unpacks to share the same stem,
# so derive the archive name instead of repeating it.
repo=sherpa-onnx-zipformer-audio-tagging-2024-04-09
archive=$repo.tar.bz2

curl -SL -O "https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/$archive"
tar xvf "$archive"
rm "$archive"

ls -lh "$repo"

# Tag every bundled test wave; all expansions are quoted so that paths with
# spaces or glob characters are passed through unchanged.
for w in 1.wav 2.wav 3.wav 4.wav; do
  "$EXE" \
    --zipformer-model="$repo/model.onnx" \
    --labels="$repo/class_labels_indices.csv" \
    "$repo/test_wavs/$w"
done

rm -rf "$repo"
|
||||
10
.github/workflows/linux.yaml
vendored
10
.github/workflows/linux.yaml
vendored
@@ -15,6 +15,7 @@ on:
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
- '.github/scripts/test-online-ctc.sh'
|
||||
- '.github/scripts/test-offline-tts.sh'
|
||||
- '.github/scripts/test-audio-tagging.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -32,6 +33,7 @@ on:
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
- '.github/scripts/test-online-ctc.sh'
|
||||
- '.github/scripts/test-offline-tts.sh'
|
||||
- '.github/scripts/test-audio-tagging.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -124,6 +126,14 @@ jobs:
|
||||
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
|
||||
path: build/bin/*
|
||||
|
||||
- name: Test Audio tagging
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline-audio-tagging
|
||||
|
||||
.github/scripts/test-audio-tagging.sh
|
||||
|
||||
- name: Test online CTC
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
10
.github/workflows/macos.yaml
vendored
10
.github/workflows/macos.yaml
vendored
@@ -15,6 +15,7 @@ on:
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
- '.github/scripts/test-offline-tts.sh'
|
||||
- '.github/scripts/test-online-ctc.sh'
|
||||
- '.github/scripts/test-audio-tagging.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -31,6 +32,7 @@ on:
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
- '.github/scripts/test-offline-tts.sh'
|
||||
- '.github/scripts/test-online-ctc.sh'
|
||||
- '.github/scripts/test-audio-tagging.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -103,6 +105,14 @@ jobs:
|
||||
otool -L build/bin/sherpa-onnx
|
||||
otool -l build/bin/sherpa-onnx
|
||||
|
||||
- name: Test Audio tagging
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline-audio-tagging
|
||||
|
||||
.github/scripts/test-audio-tagging.sh
|
||||
|
||||
- name: Test C API
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
10
.github/workflows/windows-x64.yaml
vendored
10
.github/workflows/windows-x64.yaml
vendored
@@ -14,6 +14,7 @@ on:
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
- '.github/scripts/test-online-ctc.sh'
|
||||
- '.github/scripts/test-offline-tts.sh'
|
||||
- '.github/scripts/test-audio-tagging.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -28,6 +29,7 @@ on:
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
- '.github/scripts/test-online-ctc.sh'
|
||||
- '.github/scripts/test-offline-tts.sh'
|
||||
- '.github/scripts/test-audio-tagging.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -70,6 +72,14 @@ jobs:
|
||||
|
||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||
|
||||
- name: Test Audio tagging
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
export EXE=sherpa-onnx-offline-audio-tagging.exe
|
||||
|
||||
.github/scripts/test-audio-tagging.sh
|
||||
|
||||
- name: Test C API
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
9
.github/workflows/windows-x86.yaml
vendored
9
.github/workflows/windows-x86.yaml
vendored
@@ -14,6 +14,7 @@ on:
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
- '.github/scripts/test-offline-tts.sh'
|
||||
- '.github/scripts/test-online-ctc.sh'
|
||||
- '.github/scripts/test-audio-tagging.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -28,6 +29,7 @@ on:
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
- '.github/scripts/test-offline-tts.sh'
|
||||
- '.github/scripts/test-online-ctc.sh'
|
||||
- '.github/scripts/test-audio-tagging.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -85,6 +87,13 @@ jobs:
|
||||
# export EXE=sherpa-onnx-offline-language-identification.exe
|
||||
#
|
||||
# .github/scripts/test-spoken-language-identification.sh
|
||||
- name: Test Audio tagging
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
export EXE=sherpa-onnx-offline-audio-tagging.exe
|
||||
|
||||
.github/scripts/test-audio-tagging.sh
|
||||
|
||||
- name: Test online CTC
|
||||
shell: bash
|
||||
|
||||
@@ -46,6 +46,7 @@ def enable_alsa():
|
||||
def get_binaries():
|
||||
binaries = [
|
||||
"sherpa-onnx",
|
||||
"sherpa-onnx-offline-audio-tagging",
|
||||
"sherpa-onnx-keyword-spotter",
|
||||
"sherpa-onnx-microphone",
|
||||
"sherpa-onnx-microphone-offline",
|
||||
|
||||
2
go-api-examples/vad-asr-paraformer/.gitignore
vendored
Normal file
2
go-api-examples/vad-asr-paraformer/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go.sum
|
||||
vad-asr-paraformer
|
||||
@@ -4,7 +4,7 @@ const sherpa_onnx = require('sherpa-onnx');
|
||||
|
||||
function createOfflineTts() {
|
||||
let offlineTtsVitsModelConfig = {
|
||||
model: './vits-icefall-zh-aishell3/vits-aishell3.onnx',
|
||||
model: './vits-icefall-zh-aishell3/model.onnx',
|
||||
lexicon: './vits-icefall-zh-aishell3/lexicon.txt',
|
||||
tokens: './vits-icefall-zh-aishell3/tokens.txt',
|
||||
dataDir: '',
|
||||
|
||||
@@ -111,6 +111,16 @@ list(APPEND sources
|
||||
speaker-embedding-manager.cc
|
||||
)
|
||||
|
||||
# audio tagging
|
||||
list(APPEND sources
|
||||
audio-tagging-impl.cc
|
||||
audio-tagging-label-file.cc
|
||||
audio-tagging-model-config.cc
|
||||
audio-tagging.cc
|
||||
offline-zipformer-audio-tagging-model-config.cc
|
||||
offline-zipformer-audio-tagging-model.cc
|
||||
)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_TTS)
|
||||
list(APPEND sources
|
||||
lexicon.cc
|
||||
@@ -193,6 +203,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
|
||||
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
|
||||
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
|
||||
add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
|
||||
add_executable(sherpa-onnx-offline-audio-tagging sherpa-onnx-offline-audio-tagging.cc)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_TTS)
|
||||
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
|
||||
@@ -204,6 +215,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
|
||||
sherpa-onnx-offline
|
||||
sherpa-onnx-offline-parallel
|
||||
sherpa-onnx-offline-language-identification
|
||||
sherpa-onnx-offline-audio-tagging
|
||||
)
|
||||
if(SHERPA_ONNX_ENABLE_TTS)
|
||||
list(APPEND main_exes
|
||||
|
||||
23
sherpa-onnx/csrc/audio-tagging-impl.cc
Normal file
23
sherpa-onnx/csrc/audio-tagging-impl.cc
Normal file
@@ -0,0 +1,23 @@
|
||||
// sherpa-onnx/csrc/audio-tagging-impl.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
|
||||
const AudioTaggingConfig &config) {
|
||||
if (!config.model.zipformer.model.empty()) {
|
||||
return std::make_unique<AudioTaggingZipformerImpl>(config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOG(
|
||||
"Please specify an audio tagging model! Return a null pointer");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
29
sherpa-onnx/csrc/audio-tagging-impl.h
Normal file
29
sherpa-onnx/csrc/audio-tagging-impl.h
Normal file
@@ -0,0 +1,29 @@
|
||||
// sherpa-onnx/csrc/audio-tagging-impl.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class AudioTaggingImpl {
|
||||
public:
|
||||
virtual ~AudioTaggingImpl() = default;
|
||||
|
||||
static std::unique_ptr<AudioTaggingImpl> Create(
|
||||
const AudioTaggingConfig &config);
|
||||
|
||||
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
|
||||
|
||||
virtual std::vector<AudioEvent> Compute(OfflineStream *s,
|
||||
int32_t top_k = -1) const = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_
|
||||
70
sherpa-onnx/csrc/audio-tagging-label-file.cc
Normal file
70
sherpa-onnx/csrc/audio-tagging-label-file.cc
Normal file
@@ -0,0 +1,70 @@
|
||||
// sherpa-onnx/csrc/audio-tagging-label-file.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-label-file.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Open the label file `filename` and hand the stream to Init().
AudioTaggingLabels::AudioTaggingLabels(const std::string &filename) {
  std::ifstream label_stream(filename);

  Init(label_stream);
}
|
||||
|
||||
// Format of a label file
|
||||
/*
|
||||
index,mid,display_name
|
||||
0,/m/09x0r,"Speech"
|
||||
1,/m/05zppz,"Male speech, man speaking"
|
||||
*/
|
||||
void AudioTaggingLabels::Init(std::istream &is) {
|
||||
std::string line;
|
||||
std::getline(is, line); // skip the header
|
||||
|
||||
std::string index;
|
||||
std::string tmp;
|
||||
std::string name;
|
||||
|
||||
while (std::getline(is, line)) {
|
||||
index.clear();
|
||||
name.clear();
|
||||
std::istringstream input2(line);
|
||||
|
||||
std::getline(input2, index, ',');
|
||||
std::getline(input2, tmp, ',');
|
||||
std::getline(input2, name);
|
||||
|
||||
std::size_t pos{};
|
||||
int32_t i = std::stoi(index, &pos);
|
||||
if (index.size() == 0 || pos != index.size()) {
|
||||
SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (i != names_.size()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Index should be sorted and contiguous. Expected index: %d, given: "
|
||||
"%d.",
|
||||
static_cast<int32_t>(names_.size()), i);
|
||||
}
|
||||
if (name.empty() || name.front() != '"' || name.back() != '"') {
|
||||
SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
names_.emplace_back(name.begin() + 1, name.end() - 1);
|
||||
}
|
||||
}
|
||||
|
||||
const std::string &AudioTaggingLabels::GetEventName(int32_t index) const {
|
||||
return names_.at(index);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
31
sherpa-onnx/csrc/audio-tagging-label-file.h
Normal file
31
sherpa-onnx/csrc/audio-tagging-label-file.h
Normal file
@@ -0,0 +1,31 @@
|
||||
// sherpa-onnx/csrc/audio-tagging-label-file.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_
|
||||
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_
|
||||
|
||||
#include <istream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Holds the event names read from a label file, addressable by index.
class AudioTaggingLabels {
 public:
  explicit AudioTaggingLabels(const std::string &filename);

  // Return the event name for the given index.
  // The returned reference is valid as long as this object is alive
  const std::string &GetEventName(int32_t index) const;

  // Number of event names that have been loaded.
  int32_t NumEventClasses() const {
    return static_cast<int32_t>(names_.size());
  }

 private:
  // Parse the label file content from `is` into names_.
  void Init(std::istream &is);

  std::vector<std::string> names_;
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_
|
||||
42
sherpa-onnx/csrc/audio-tagging-model-config.cc
Normal file
42
sherpa-onnx/csrc/audio-tagging-model-config.cc
Normal file
@@ -0,0 +1,42 @@
|
||||
// sherpa-onnx/csrc/audio-tagging-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Register the model-specific options first, then the options shared by
// all audio tagging models, with the command line parser `po`.
void AudioTaggingModelConfig::Register(ParseOptions *po) {
  zipformer.Register(po);

  // Options common to every model type.
  po->Register("num-threads", &num_threads,
               "Number of threads to run the neural network");
  po->Register("debug", &debug,
               "true to print model information while loading it.");
  po->Register("provider", &provider,
               "Specify a provider to use: cpu, cuda, coreml");
}
|
||||
|
||||
bool AudioTaggingModelConfig::Validate() const {
|
||||
if (!zipformer.model.empty() && !zipformer.Validate()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Render all fields as a single "AudioTaggingModelConfig(...)" string.
std::string AudioTaggingModelConfig::ToString() const {
  const char *debug_str = debug ? "True" : "False";

  std::ostringstream os;
  os << "AudioTaggingModelConfig("
     << "zipformer=" << zipformer.ToString() << ", "
     << "num_threads=" << num_threads << ", "
     << "debug=" << debug_str << ", "
     << "provider=\"" << provider << "\")";

  return os.str();
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
39
sherpa-onnx/csrc/audio-tagging-model-config.h
Normal file
39
sherpa-onnx/csrc/audio-tagging-model-config.h
Normal file
@@ -0,0 +1,39 @@
|
||||
// sherpa-onnx/csrc/audio-tagging-model-config.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct AudioTaggingModelConfig {
|
||||
struct OfflineZipformerAudioTaggingModelConfig zipformer;
|
||||
|
||||
int32_t num_threads = 1;
|
||||
bool debug = false;
|
||||
std::string provider = "cpu";
|
||||
|
||||
AudioTaggingModelConfig() = default;
|
||||
|
||||
AudioTaggingModelConfig(
|
||||
const OfflineZipformerAudioTaggingModelConfig &zipformer,
|
||||
int32_t num_threads, bool debug, const std::string &provider)
|
||||
: zipformer(zipformer),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
provider(provider) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_
|
||||
95
sherpa-onnx/csrc/audio-tagging-zipformer-impl.h
Normal file
95
sherpa-onnx/csrc/audio-tagging-zipformer-impl.h
Normal file
@@ -0,0 +1,95 @@
|
||||
// sherpa-onnx/csrc/audio-tagging-zipformer-impl.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
|
||||
|
||||
#include <array>
#include <memory>
#include <utility>
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
|
||||
#include "sherpa-onnx/csrc/audio-tagging-label-file.h"
|
||||
#include "sherpa-onnx/csrc/audio-tagging.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/math.h"
|
||||
#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class AudioTaggingZipformerImpl : public AudioTaggingImpl {
|
||||
public:
|
||||
explicit AudioTaggingZipformerImpl(const AudioTaggingConfig &config)
|
||||
: config_(config), model_(config.model), labels_(config.labels) {
|
||||
if (model_.NumEventClasses() != labels_.NumEventClasses()) {
|
||||
SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
|
||||
model_.NumEventClasses(), labels_.NumEventClasses());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>();
|
||||
}
|
||||
|
||||
std::vector<AudioEvent> Compute(OfflineStream *s,
|
||||
int32_t top_k = -1) const override {
|
||||
if (top_k < 0) {
|
||||
top_k = config_.top_k;
|
||||
}
|
||||
|
||||
int32_t num_event_classes = model_.NumEventClasses();
|
||||
|
||||
if (top_k > num_event_classes) {
|
||||
top_k = num_event_classes;
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
// WARNING(fangjun): It is fixed to 80 for all models from icefall
|
||||
int32_t feat_dim = 80;
|
||||
std::vector<float> f = s->GetFrames();
|
||||
|
||||
int32_t num_frames = f.size() / feat_dim;
|
||||
|
||||
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
|
||||
|
||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
|
||||
shape.data(), shape.size());
|
||||
|
||||
int64_t x_length_scalar = num_frames;
|
||||
std::array<int64_t, 1> x_length_shape = {1};
|
||||
Ort::Value x_length =
|
||||
Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1,
|
||||
x_length_shape.data(), x_length_shape.size());
|
||||
|
||||
Ort::Value probs = model_.Forward(std::move(x), std::move(x_length));
|
||||
|
||||
const float *p = probs.GetTensorData<float>();
|
||||
|
||||
std::vector<int32_t> top_k_indexes = TopkIndex(p, num_event_classes, top_k);
|
||||
|
||||
std::vector<AudioEvent> ans(top_k);
|
||||
|
||||
int32_t i = 0;
|
||||
|
||||
for (int32_t index : top_k_indexes) {
|
||||
ans[i].name = labels_.GetEventName(index);
|
||||
ans[i].index = index;
|
||||
ans[i].prob = p[index];
|
||||
i += 1;
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
AudioTaggingConfig config_;
|
||||
OfflineZipformerAudioTaggingModel model_;
|
||||
AudioTaggingLabels labels_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
|
||||
75
sherpa-onnx/csrc/audio-tagging.cc
Normal file
75
sherpa-onnx/csrc/audio-tagging.cc
Normal file
@@ -0,0 +1,75 @@
|
||||
// sherpa-onnx/csrc/audio-tagging.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Render the event as "AudioEvent(name=..., index=..., prob=...)".
std::string AudioEvent::ToString() const {
  std::ostringstream os;

  os << "AudioEvent("
     << "name=\"" << name << "\", "
     << "index=" << index << ", "
     << "prob=" << prob << ")";

  return os.str();
}
|
||||
|
||||
// Register the model options, followed by --labels and --top-k, with `po`.
void AudioTaggingConfig::Register(ParseOptions *po) {
  model.Register(po);

  po->Register("labels", &labels, "Event label file");

  po->Register("top-k", &top_k, "Top k events to return in the result");
}
|
||||
|
||||
bool AudioTaggingConfig::Validate() const {
|
||||
if (!model.Validate()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (top_k < 1) {
|
||||
SHERPA_ONNX_LOGE("--top-k should be >= 1. Given: %d", top_k);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (labels.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --labels");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(labels)) {
|
||||
SHERPA_ONNX_LOGE("--labels %s does not exist", labels.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
// Render all fields as a single "AudioTaggingConfig(...)" string.
std::string AudioTaggingConfig::ToString() const {
  std::ostringstream os;

  os << "AudioTaggingConfig("
     << "model=" << model.ToString() << ", "
     << "labels=\"" << labels << "\", "
     << "top_k=" << top_k << ")";

  return os.str();
}
|
||||
|
||||
// Create the implementation selected by `config`.
AudioTagging::AudioTagging(const AudioTaggingConfig &config)
    : impl_(AudioTaggingImpl::Create(config)) {}

// Defined in this file, which includes audio-tagging-impl.h (the header
// only forward-declares AudioTaggingImpl).
AudioTagging::~AudioTagging() = default;
|
||||
|
||||
std::unique_ptr<OfflineStream> AudioTagging::CreateStream() const {
|
||||
return impl_->CreateStream();
|
||||
}
|
||||
|
||||
// Forward to the implementation. `top_k` is passed through unchanged.
std::vector<AudioEvent> AudioTagging::Compute(OfflineStream *s,
                                              int32_t top_k /*= -1*/) const {
  std::vector<AudioEvent> events = impl_->Compute(s, top_k);
  return events;
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
65
sherpa-onnx/csrc/audio-tagging.h
Normal file
65
sherpa-onnx/csrc/audio-tagging.h
Normal file
@@ -0,0 +1,65 @@
|
||||
// sherpa-onnx/csrc/audio-tagging.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_
|
||||
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct AudioTaggingConfig {
|
||||
AudioTaggingModelConfig model;
|
||||
std::string labels;
|
||||
|
||||
int32_t top_k = 5;
|
||||
|
||||
AudioTaggingConfig() = default;
|
||||
|
||||
AudioTaggingConfig(const AudioTaggingModelConfig &model,
|
||||
const std::string &labels, int32_t top_k)
|
||||
: model(model), labels(labels), top_k(top_k) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
// A single detected audio event.
//
// The numeric members carry default initializers so that a
// default-constructed AudioEvent never holds indeterminate values.
struct AudioEvent {
  std::string name;   // name of the event
  int32_t index = 0;  // index of the event in the label file
  float prob = 0;     // probability of the event

  // Return a human readable representation of this event.
  std::string ToString() const;
};
|
||||
|
||||
class AudioTaggingImpl;
|
||||
|
||||
class AudioTagging {
|
||||
public:
|
||||
explicit AudioTagging(const AudioTaggingConfig &config);
|
||||
|
||||
~AudioTagging();
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const;
|
||||
|
||||
// If top_k is -1, then config.top_k is used.
|
||||
// Otherwise, config.top_k is ignored
|
||||
//
|
||||
// Return top_k AudioEvent. ans[0].prob is the largest of all returned events.
|
||||
std::vector<AudioEvent> Compute(OfflineStream *s, int32_t top_k = -1) const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<AudioTaggingImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_
|
||||
@@ -97,8 +97,8 @@ void LogSoftmax(T *in, int32_t w, int32_t h) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SubtractBlank(T *in, int32_t w, int32_t h,
|
||||
int32_t blank_idx, float blank_penalty) {
|
||||
void SubtractBlank(T *in, int32_t w, int32_t h, int32_t blank_idx,
|
||||
float blank_penalty) {
|
||||
for (int32_t i = 0; i != h; ++i) {
|
||||
in[blank_idx] -= blank_penalty;
|
||||
in += w;
|
||||
@@ -116,8 +116,7 @@ std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
|
||||
});
|
||||
|
||||
int32_t k_num = std::min<int32_t>(size, topk);
|
||||
std::vector<int32_t> index(vec_index.begin(), vec_index.begin() + k_num);
|
||||
return index;
|
||||
return {vec_index.begin(), vec_index.begin() + k_num};
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -234,7 +234,7 @@ OfflineStream::OfflineStream(
|
||||
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||
|
||||
OfflineStream::OfflineStream(WhisperTag tag,
|
||||
ContextGraphPtr context_graph /*= nullptr*/)
|
||||
ContextGraphPtr context_graph /*= {}*/)
|
||||
: impl_(std::make_unique<Impl>(tag, context_graph)) {}
|
||||
|
||||
OfflineStream::~OfflineStream() = default;
|
||||
|
||||
@@ -71,10 +71,9 @@ struct WhisperTag {};
|
||||
class OfflineStream {
|
||||
public:
|
||||
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
|
||||
ContextGraphPtr context_graph = nullptr);
|
||||
ContextGraphPtr context_graph = {});
|
||||
|
||||
explicit OfflineStream(WhisperTag tag,
|
||||
ContextGraphPtr context_graph = nullptr);
|
||||
explicit OfflineStream(WhisperTag tag, ContextGraphPtr context_graph = {});
|
||||
~OfflineStream();
|
||||
|
||||
/**
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Register --zipformer-model with the command line parser `po`.
void OfflineZipformerAudioTaggingModelConfig::Register(ParseOptions *po) {
  po->Register("zipformer-model", &model,
               "Path to zipformer model for audio tagging");
}
|
||||
|
||||
bool OfflineZipformerAudioTaggingModelConfig::Validate() const {
|
||||
if (model.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --zipformer-model");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(model)) {
|
||||
SHERPA_ONNX_LOGE("--zipformer-model: %s does not exist", model.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Render the config as "OfflineZipformerAudioTaggingModelConfig(model=...)".
std::string OfflineZipformerAudioTaggingModelConfig::ToString() const {
  std::ostringstream os;

  os << "OfflineZipformerAudioTaggingModelConfig("
     << "model=\"" << model << "\")";

  return os.str();
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -0,0 +1,29 @@
|
||||
// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model-config.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineZipformerAudioTaggingModelConfig {
|
||||
std::string model;
|
||||
|
||||
OfflineZipformerAudioTaggingModelConfig() = default;
|
||||
|
||||
explicit OfflineZipformerAudioTaggingModelConfig(const std::string &model)
|
||||
: model(model) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_
|
||||
118
sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc
Normal file
118
sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc
Normal file
@@ -0,0 +1,118 @@
|
||||
// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Private implementation of OfflineZipformerAudioTaggingModel. It owns the
// ONNX Runtime session together with the input/output names of the model.
class OfflineZipformerAudioTaggingModel::Impl {
 public:
  // Read the model file named in `config` and create the session from it.
  explicit Impl(const AudioTaggingModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto model_bytes = ReadFile(config_.zipformer.model);
    Init(model_bytes.data(), model_bytes.size());
  }

#if __ANDROID_API__ >= 9
  // Same as above, but the model is read through the asset manager `mgr`.
  Impl(AAssetManager *mgr, const AudioTaggingModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto model_bytes = ReadFile(mgr, config_.zipformer.model);
    Init(model_bytes.data(), model_bytes.size());
  }
#endif

  // Run the session on the two inputs and return its first output.
  Ort::Value Forward(Ort::Value features, Ort::Value features_length) {
    std::array<Ort::Value, 2> inputs = {std::move(features),
                                        std::move(features_length)};

    auto outputs =
        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
                   output_names_ptr_.data(), output_names_ptr_.size());

    return std::move(outputs[0]);
  }

  int32_t NumEventClasses() const { return num_event_classes_; }

  OrtAllocator *Allocator() const { return allocator_; }

 private:
  // Create the session from the in-memory model and cache the input/output
  // names as well as the number of event classes.
  void Init(void *model_data, size_t model_data_length) {
    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    // get num_event_classes from the output[0].shape,
    // which is (N, num_event_classes)
    auto output_info = sess_->GetOutputTypeInfo(0);
    const auto output_shape =
        output_info.GetTensorTypeAndShapeInfo().GetShape();
    num_event_classes_ = output_shape[1];
  }

 private:
  AudioTaggingModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  int32_t num_event_classes_ = 0;
};
|
||||
|
||||
// Construct the model from the file named in `config`.
OfflineZipformerAudioTaggingModel::OfflineZipformerAudioTaggingModel(
    const AudioTaggingModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

#if __ANDROID_API__ >= 9
// Construct the model by reading it through the Android asset manager.
OfflineZipformerAudioTaggingModel::OfflineZipformerAudioTaggingModel(
    AAssetManager *mgr, const AudioTaggingModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}
#endif

// Defined here, where Impl is a complete type.
OfflineZipformerAudioTaggingModel::~OfflineZipformerAudioTaggingModel() =
    default;
|
||||
|
||||
// Forward both tensors to the implementation and return its result.
Ort::Value OfflineZipformerAudioTaggingModel::Forward(
    Ort::Value features, Ort::Value features_length) const {
  Ort::Value probs =
      impl_->Forward(std::move(features), std::move(features_length));
  return probs;
}
|
||||
|
||||
int32_t OfflineZipformerAudioTaggingModel::NumEventClasses() const {
|
||||
return impl_->NumEventClasses();
|
||||
}
|
||||
|
||||
// Allocator held by the implementation.
OrtAllocator *OfflineZipformerAudioTaggingModel::Allocator() const {
  OrtAllocator *allocator = impl_->Allocator();
  return allocator;
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
64
sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h
Normal file
64
sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h
Normal file
@@ -0,0 +1,64 @@
|
||||
// sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** This class implements the zipformer audio tagging model (audioset recipe)
|
||||
* from icefall.
|
||||
*
|
||||
* See
|
||||
* https://github.com/k2-fsa/icefall/blob/master/egs/audioset/AT/zipformer/export-onnx.py
|
||||
*/
|
||||
class OfflineZipformerAudioTaggingModel {
|
||||
public:
|
||||
explicit OfflineZipformerAudioTaggingModel(
|
||||
const AudioTaggingModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineZipformerAudioTaggingModel(AAssetManager *mgr,
|
||||
const AudioTaggingModelConfig &config);
|
||||
#endif
|
||||
|
||||
~OfflineZipformerAudioTaggingModel();
|
||||
|
||||
/** Run the forward method of the model.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C).
|
||||
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||
* valid frames in `features` before padding.
|
||||
* Its dtype is int64_t.
|
||||
*
|
||||
* @return Return a tensor
|
||||
* - probs: A 2-D tensor of shape (N, num_event_classes).
|
||||
*/
|
||||
Ort::Value Forward(Ort::Value features, Ort::Value features_length) const;
|
||||
|
||||
/** Return the number of event classes of the model
|
||||
*/
|
||||
int32_t NumEventClasses() const;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
|
||||
@@ -140,9 +140,11 @@ Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
#if SHERPA_ONNX_ENABLE_TTS
|
||||
// Session options for an offline TTS model, built from the `num_threads` and
// `provider` fields of `config`.
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) {
  const auto &provider = config.provider;
  const auto num_threads = config.num_threads;
  return GetSessionOptionsImpl(num_threads, provider);
}
|
||||
#endif
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const SpeakerEmbeddingExtractorConfig &config) {
|
||||
@@ -154,4 +156,8 @@ Ort::SessionOptions GetSessionOptions(
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
// Session options for an audio tagging model, built from the `num_threads`
// and `provider` fields of `config`.
Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) {
  const auto &provider = config.provider;
  const auto num_threads = config.num_threads;
  return GetSessionOptionsImpl(num_threads, provider);
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -6,15 +6,19 @@
|
||||
#define SHERPA_ONNX_CSRC_SESSION_H_
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification.h"
|
||||
#include "sherpa-onnx/csrc/vad-model-config.h"
|
||||
|
||||
#if SHERPA_ONNX_ENABLE_TTS
|
||||
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
|
||||
#endif
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config);
|
||||
@@ -27,7 +31,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const VadModelConfig &config);
|
||||
|
||||
#if SHERPA_ONNX_ENABLE_TTS
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
|
||||
#endif
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const SpeakerEmbeddingExtractorConfig &config);
|
||||
@@ -35,6 +41,8 @@ Ort::SessionOptions GetSessionOptions(
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const SpokenLanguageIdentificationConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SESSION_H_
|
||||
|
||||
97
sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc
Normal file
97
sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc
Normal file
@@ -0,0 +1,97 @@
|
||||
// sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#include <stdio.h>
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
|
||||
// Command-line entry point: runs audio tagging on a single wave file.
//
// Registers the AudioTaggingConfig options with the option parser, expects
// exactly one positional argument (the wave file), feeds the whole file to
// the tagger and prints every returned event plus timing statistics to
// stderr.
//
// Exit behaviour:
//  - exit(EXIT_FAILURE) when the number of positional arguments is not 1
//  - returns -1 when the config fails validation or the wave cannot be read
//  - returns 0 otherwise
int32_t main(int32_t argc, char *argv[]) {
|
||||
const char *kUsageMessage = R"usage(
|
||||
Audio tagging from a file.
|
||||
|
||||
Usage:
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
|
||||
tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
|
||||
rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
|
||||
|
||||
./bin/sherpa-onnx-offline-audio-tagging \
|
||||
--zipformer-model=./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.onnx \
|
||||
--labels=./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv \
|
||||
sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/0.wav
|
||||
|
||||
Input wave files should be of single channel, 16-bit PCM encoded wave file; its
|
||||
sampling rate can be arbitrary and does not need to be 16kHz.
|
||||
|
||||
Please see
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
|
||||
for more models.
|
||||
)usage";
|
||||
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
sherpa_onnx::AudioTaggingConfig config;
|
||||
// Let the config add its own command-line options before parsing.
config.Register(&po);
|
||||
po.Read(argc, argv);
|
||||
|
||||
// Exactly one positional argument (the wave file) is accepted.
if (po.NumArgs() != 1) {
|
||||
fprintf(stderr, "\nError: Please provide 1 wave file\n\n");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
// Echo the parsed configuration before validating it.
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
if (!config.Validate()) {
|
||||
fprintf(stderr, "Errors in config!\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
sherpa_onnx::AudioTagging tagger(config);
|
||||
// The single positional argument is fetched with index 1.
std::string wav_filename = po.GetArg(1);
|
||||
|
||||
// Filled in by ReadWave() below.
int32_t sampling_rate = -1;
|
||||
|
||||
bool is_ok = false;
|
||||
const std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
|
||||
|
||||
if (!is_ok) {
|
||||
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Number of samples divided by the sampling rate; used below for the RTF.
const float duration = samples.size() / static_cast<float>(sampling_rate);
|
||||
|
||||
fprintf(stderr, "Start to compute\n");
|
||||
// The timed region covers stream creation, AcceptWaveform() and Compute().
const auto begin = std::chrono::steady_clock::now();
|
||||
|
||||
auto stream = tagger.CreateStream();
|
||||
|
||||
stream->AcceptWaveform(sampling_rate, samples.data(), samples.size());
|
||||
|
||||
auto results = tagger.Compute(stream.get());
|
||||
const auto end = std::chrono::steady_clock::now();
|
||||
fprintf(stderr, "Done\n");
|
||||
|
||||
// Running index printed in front of each event.
int32_t i = 0;
|
||||
|
||||
for (const auto &event : results) {
|
||||
fprintf(stderr, "%d: %s\n", i, event.ToString().c_str());
|
||||
i += 1;
|
||||
}
|
||||
|
||||
// Whole milliseconds divided by 1000 gives the elapsed time in seconds.
float elapsed_seconds =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
||||
.count() /
|
||||
1000.;
|
||||
// Real time factor: processing time relative to the wave duration.
float rtf = elapsed_seconds / duration;
|
||||
fprintf(stderr, "Num threads: %d\n", config.model.num_threads);
|
||||
fprintf(stderr, "Wave duration: %.3f\n", duration);
|
||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
|
||||
elapsed_seconds, duration, rtf);
|
||||
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user