Support CED models (#792)
This commit is contained in:
@@ -117,6 +117,7 @@ list(APPEND sources
|
||||
audio-tagging-label-file.cc
|
||||
audio-tagging-model-config.cc
|
||||
audio-tagging.cc
|
||||
offline-ced-model.cc
|
||||
offline-zipformer-audio-tagging-model-config.cc
|
||||
offline-zipformer-audio-tagging-model.cc
|
||||
)
|
||||
|
||||
111
sherpa-onnx/csrc/audio-tagging-ced-impl.h
Normal file
111
sherpa-onnx/csrc/audio-tagging-ced-impl.h
Normal file
@@ -0,0 +1,111 @@
|
||||
// sherpa-onnx/csrc/audio-tagging-ced-impl.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
|
||||
#include "sherpa-onnx/csrc/audio-tagging-label-file.h"
|
||||
#include "sherpa-onnx/csrc/audio-tagging.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/math.h"
|
||||
#include "sherpa-onnx/csrc/offline-ced-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class AudioTaggingCEDImpl : public AudioTaggingImpl {
|
||||
public:
|
||||
explicit AudioTaggingCEDImpl(const AudioTaggingConfig &config)
|
||||
: config_(config), model_(config.model), labels_(config.labels) {
|
||||
if (model_.NumEventClasses() != labels_.NumEventClasses()) {
|
||||
SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
|
||||
model_.NumEventClasses(), labels_.NumEventClasses());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
explicit AudioTaggingCEDImpl(AAssetManager *mgr,
|
||||
const AudioTaggingConfig &config)
|
||||
: config_(config),
|
||||
model_(mgr, config.model),
|
||||
labels_(mgr, config.labels) {
|
||||
if (model_.NumEventClasses() != labels_.NumEventClasses()) {
|
||||
SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
|
||||
model_.NumEventClasses(), labels_.NumEventClasses());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>(CEDTag{});
|
||||
}
|
||||
|
||||
std::vector<AudioEvent> Compute(OfflineStream *s,
|
||||
int32_t top_k = -1) const override {
|
||||
if (top_k < 0) {
|
||||
top_k = config_.top_k;
|
||||
}
|
||||
|
||||
int32_t num_event_classes = model_.NumEventClasses();
|
||||
|
||||
if (top_k > num_event_classes) {
|
||||
top_k = num_event_classes;
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
// WARNING(fangjun): It is fixed to 64 for CED models
|
||||
int32_t feat_dim = 64;
|
||||
std::vector<float> f = s->GetFrames();
|
||||
|
||||
int32_t num_frames = f.size() / feat_dim;
|
||||
assert(feat_dim * num_frames == f.size());
|
||||
|
||||
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
|
||||
|
||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
|
||||
shape.data(), shape.size());
|
||||
|
||||
Ort::Value probs = model_.Forward(std::move(x));
|
||||
|
||||
const float *p = probs.GetTensorData<float>();
|
||||
|
||||
std::vector<int32_t> top_k_indexes = TopkIndex(p, num_event_classes, top_k);
|
||||
|
||||
std::vector<AudioEvent> ans(top_k);
|
||||
|
||||
int32_t i = 0;
|
||||
|
||||
for (int32_t index : top_k_indexes) {
|
||||
ans[i].name = labels_.GetEventName(index);
|
||||
ans[i].index = index;
|
||||
ans[i].prob = p[index];
|
||||
i += 1;
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
AudioTaggingConfig config_;
|
||||
OfflineCEDModel model_;
|
||||
AudioTaggingLabels labels_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-ced-impl.h"
|
||||
#include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
@@ -20,6 +21,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
|
||||
const AudioTaggingConfig &config) {
|
||||
if (!config.model.zipformer.model.empty()) {
|
||||
return std::make_unique<AudioTaggingZipformerImpl>(config);
|
||||
} else if (!config.model.ced.empty()) {
|
||||
return std::make_unique<AudioTaggingCEDImpl>(config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOG(
|
||||
@@ -32,6 +35,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
|
||||
AAssetManager *mgr, const AudioTaggingConfig &config) {
|
||||
if (!config.model.zipformer.model.empty()) {
|
||||
return std::make_unique<AudioTaggingZipformerImpl>(mgr, config);
|
||||
} else if (!config.model.ced.empty()) {
|
||||
return std::make_unique<AudioTaggingCEDImpl>(mgr, config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOG(
|
||||
|
||||
@@ -4,11 +4,18 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void AudioTaggingModelConfig::Register(ParseOptions *po) {
|
||||
zipformer.Register(po);
|
||||
|
||||
po->Register("ced-model", &ced,
|
||||
"Path to CED model. Only need to pass one of --zipformer-model "
|
||||
"or --ced-model");
|
||||
|
||||
po->Register("num-threads", &num_threads,
|
||||
"Number of threads to run the neural network");
|
||||
|
||||
@@ -24,6 +31,16 @@ bool AudioTaggingModelConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ced.empty() && !FileExists(ced)) {
|
||||
SHERPA_ONNX_LOGE("CED model file %s does not exist", ced.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (zipformer.model.empty() && ced.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide either --zipformer-model or --ced-model");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -32,6 +49,7 @@ std::string AudioTaggingModelConfig::ToString() const {
|
||||
|
||||
os << "AudioTaggingModelConfig(";
|
||||
os << "zipformer=" << zipformer.ToString() << ", ";
|
||||
os << "ced=\"" << ced << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\")";
|
||||
|
||||
@@ -13,6 +13,7 @@ namespace sherpa_onnx {
|
||||
|
||||
struct AudioTaggingModelConfig {
|
||||
struct OfflineZipformerAudioTaggingModelConfig zipformer;
|
||||
std::string ced;
|
||||
|
||||
int32_t num_threads = 1;
|
||||
bool debug = false;
|
||||
@@ -22,8 +23,10 @@ struct AudioTaggingModelConfig {
|
||||
|
||||
AudioTaggingModelConfig(
|
||||
const OfflineZipformerAudioTaggingModelConfig &zipformer,
|
||||
int32_t num_threads, bool debug, const std::string &provider)
|
||||
const std::string &ced, int32_t num_threads, bool debug,
|
||||
const std::string &provider)
|
||||
: zipformer(zipformer),
|
||||
ced(ced),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
provider(provider) {}
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@@ -72,6 +74,8 @@ class AudioTaggingZipformerImpl : public AudioTaggingImpl {
|
||||
|
||||
int32_t num_frames = f.size() / feat_dim;
|
||||
|
||||
assert(feat_dim * num_frames == f.size());
|
||||
|
||||
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
|
||||
|
||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
|
||||
|
||||
@@ -24,7 +24,8 @@ void FeatureExtractorConfig::Register(ParseOptions *po) {
|
||||
"inside the feature extractor");
|
||||
|
||||
po->Register("feat-dim", &feature_dim,
|
||||
"Feature dimension. Must match the one expected by the model.");
|
||||
"Feature dimension. Must match the one expected by the model. "
|
||||
"Not used by whisper and CED models");
|
||||
|
||||
po->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins");
|
||||
|
||||
|
||||
112
sherpa-onnx/csrc/offline-ced-model.cc
Normal file
112
sherpa-onnx/csrc/offline-ced-model.cc
Normal file
@@ -0,0 +1,112 @@
|
||||
// sherpa-onnx/csrc/offline-ced-model.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-ced-model.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
#include "sherpa-onnx/csrc/transpose.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Private implementation of OfflineCEDModel. It owns the ONNX Runtime
// session of a CED model together with the input/output names and the number
// of event classes read from the model.
class OfflineCEDModel::Impl {
 public:
  // Reads the model from the file config.ced and creates the session.
  explicit Impl(const AudioTaggingModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto buf = ReadFile(config_.ced);
    Init(buf.data(), buf.size());
  }

#if __ANDROID_API__ >= 9
  // Same as above, but reads config.ced through the Android asset manager.
  Impl(AAssetManager *mgr, const AudioTaggingModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto buf = ReadFile(mgr, config_.ced);
    Init(buf.data(), buf.size());
  }
#endif

  // Runs the model on `features` and returns its first output.
  //
  // NOTE(review): Transpose12 presumably swaps axes 1 and 2, i.e. turns the
  // caller's (N, T, C) layout into (N, C, T) -- confirm against transpose.h.
  Ort::Value Forward(Ort::Value features) {
    features = Transpose12(allocator_, &features);

    auto ans = sess_->Run({}, input_names_ptr_.data(), &features, 1,
                          output_names_ptr_.data(), output_names_ptr_.size());
    return std::move(ans[0]);
  }

  // Number of event classes read from the model in Init(). It is 0 if the
  // declared output shape could not be interpreted.
  int32_t NumEventClasses() const { return num_event_classes_; }

  OrtAllocator *Allocator() const { return allocator_; }

 private:
  // Creates the session from the in-memory model and caches the input names,
  // the output names and the number of event classes.
  void Init(void *model_data, size_t model_data_length) {
    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    // get num_event_classes from the output[0].shape,
    // which is (N, num_event_classes)
    std::vector<int64_t> shape =
        sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();

    if (shape.size() < 2) {
      // Indexing shape[1] would read out of bounds. Leave the class count at
      // 0 so that callers comparing it with their label table can report the
      // mismatch.
      SHERPA_ONNX_LOGE(
          "Expect output[0] of the CED model to have at least 2 dims. "
          "Given: %d",
          static_cast<int32_t>(shape.size()));
      return;
    }

    num_event_classes_ = static_cast<int32_t>(shape[1]);
  }

 private:
  AudioTaggingModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  int32_t num_event_classes_ = 0;
};
|
||||
|
||||
// Creates the private Impl from `config`; Impl's constructor reads the file
// config.ced and builds the ONNX Runtime session.
OfflineCEDModel::OfflineCEDModel(const AudioTaggingModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
// Android variant: creates the private Impl, which reads config.ced through
// the asset manager `mgr`.
OfflineCEDModel::OfflineCEDModel(AAssetManager *mgr,
|
||||
const AudioTaggingModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
// Defaulted here rather than in the header: Impl is a complete type in this
// file, which std::unique_ptr<Impl> needs in order to destroy it.
OfflineCEDModel::~OfflineCEDModel() = default;
|
||||
|
||||
// Delegates to Impl::Forward, moving `features` into it, and returns the
// tensor produced there.
Ort::Value OfflineCEDModel::Forward(Ort::Value features) const {
|
||||
return impl_->Forward(std::move(features));
|
||||
}
|
||||
|
||||
// Returns the class count that Impl cached when the model was loaded.
int32_t OfflineCEDModel::NumEventClasses() const {
|
||||
return impl_->NumEventClasses();
|
||||
}
|
||||
|
||||
// Exposes the allocator held by Impl; the returned pointer is not owned by
// the caller.
OrtAllocator *OfflineCEDModel::Allocator() const { return impl_->Allocator(); }
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
56
sherpa-onnx/csrc/offline-ced-model.h
Normal file
56
sherpa-onnx/csrc/offline-ced-model.h
Normal file
@@ -0,0 +1,56 @@
|
||||
// sherpa-onnx/csrc/offline-ced-model.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** This class implements the CED model from
|
||||
* https://github.com/RicherMans/CED/blob/main/export_onnx.py
|
||||
*/
|
||||
// Public handle for a CED audio tagging model. All state lives in the
// forward-declared Impl class, which is held through impl_.
class OfflineCEDModel {
|
||||
public:
|
||||
// Builds the model from `config`; the model path is taken from config.ced.
explicit OfflineCEDModel(const AudioTaggingModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
// Android variant that additionally receives the asset manager `mgr`, used
// for reading the model.
OfflineCEDModel(AAssetManager *mgr, const AudioTaggingModelConfig &config);
|
||||
#endif
|
||||
|
||||
// Declared here and defined in the .cc file, because Impl is only
// forward-declared in this header.
~OfflineCEDModel();
|
||||
|
||||
/** Run the forward method of the model.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C).
|
||||
*
|
||||
* @return Return a tensor
|
||||
* - probs: A 2-D tensor of shape (N, num_event_classes).
|
||||
*/
|
||||
// `features` is taken by value, so callers hand over ownership with
// std::move.
Ort::Value Forward(Ort::Value features) const;
|
||||
|
||||
/** Return the number of event classes of the model
|
||||
*/
|
||||
// The value is determined once, when the model is loaded.
int32_t NumEventClasses() const;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
// The returned pointer refers to an allocator held by the implementation.
OrtAllocator *Allocator() const;
|
||||
|
||||
private:
|
||||
// Pimpl: keeps the ONNX Runtime session details out of this header.
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
|
||||
@@ -92,15 +92,32 @@ class OfflineStream::Impl {
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
}
|
||||
|
||||
Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph)
|
||||
: context_graph_(context_graph) {
|
||||
explicit Impl(WhisperTag /*tag*/) {
|
||||
config_.normalize_samples = true;
|
||||
opts_.frame_opts.samp_freq = 16000;
|
||||
opts_.mel_opts.num_bins = 80;
|
||||
opts_.mel_opts.num_bins = 80; // not used
|
||||
whisper_fbank_ =
|
||||
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
|
||||
}
|
||||
|
||||
explicit Impl(CEDTag /*tag*/) {
  // Feature extraction settings for CED models, see
  // https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py

  // Framing options.
  auto &frame = opts_.frame_opts;
  frame.samp_freq = 16000;  // fixed to 16000
  frame.frame_length_ms = 32;
  frame.window_type = "hann";
  frame.snip_edges = false;
  frame.remove_dc_offset = false;
  frame.preemph_coeff = 0;
  frame.dither = 0;

  // Mel filter bank options.
  auto &mel = opts_.mel_opts;
  mel.high_freq = 8000;
  mel.num_bins = 64;

  // The extractor takes the options as they are at this point, so it is
  // created last.
  fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
|
||||
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||
if (config_.normalize_samples) {
|
||||
AcceptWaveformImpl(sampling_rate, waveform, n);
|
||||
@@ -233,9 +250,10 @@ OfflineStream::OfflineStream(
|
||||
ContextGraphPtr context_graph /*= nullptr*/)
|
||||
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||
|
||||
OfflineStream::OfflineStream(WhisperTag tag,
|
||||
ContextGraphPtr context_graph /*= {}*/)
|
||||
: impl_(std::make_unique<Impl>(tag, context_graph)) {}
|
||||
OfflineStream::OfflineStream(WhisperTag tag)
|
||||
: impl_(std::make_unique<Impl>(tag)) {}
|
||||
|
||||
OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {}
|
||||
|
||||
OfflineStream::~OfflineStream() = default;
|
||||
|
||||
|
||||
@@ -67,13 +67,15 @@ struct OfflineFeatureExtractorConfig {
|
||||
};
|
||||
|
||||
struct WhisperTag {};
|
||||
struct CEDTag {};
|
||||
|
||||
class OfflineStream {
|
||||
public:
|
||||
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
|
||||
ContextGraphPtr context_graph = {});
|
||||
|
||||
explicit OfflineStream(WhisperTag tag, ContextGraphPtr context_graph = {});
|
||||
explicit OfflineStream(WhisperTag tag);
|
||||
explicit OfflineStream(CEDTag tag);
|
||||
~OfflineStream();
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user