Support CED models (#792)

This commit is contained in:
Fangjun Kuang
2024-04-19 15:20:37 +08:00
committed by GitHub
parent d97a283dbb
commit c1608b3524
33 changed files with 605 additions and 46 deletions

View File

@@ -117,6 +117,7 @@ list(APPEND sources
audio-tagging-label-file.cc
audio-tagging-model-config.cc
audio-tagging.cc
offline-ced-model.cc
offline-zipformer-audio-tagging-model-config.cc
offline-zipformer-audio-tagging-model.cc
)

View File

@@ -0,0 +1,111 @@
// sherpa-onnx/csrc/audio-tagging-ced-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
#include <assert.h>

#include <array>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
#include "sherpa-onnx/csrc/audio-tagging-label-file.h"
#include "sherpa-onnx/csrc/audio-tagging.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/offline-ced-model.h"
namespace sherpa_onnx {
// Audio-tagging implementation backed by a CED ONNX model
// (https://github.com/RicherMans/CED). Selected by AudioTaggingImpl::Create()
// when config.model.ced is non-empty.
class AudioTaggingCEDImpl : public AudioTaggingImpl {
 public:
  // Loads the model and the label file from the filesystem.
  // Terminates the process if the model's output dimension does not match
  // the number of entries in the label file.
  explicit AudioTaggingCEDImpl(const AudioTaggingConfig &config)
      : config_(config), model_(config.model), labels_(config.labels) {
    if (model_.NumEventClasses() != labels_.NumEventClasses()) {
      SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
                       model_.NumEventClasses(), labels_.NumEventClasses());
      exit(-1);
    }
  }
#if __ANDROID_API__ >= 9
  // Same as above, but reads the model and label file from the Android
  // APK's assets via the asset manager.
  explicit AudioTaggingCEDImpl(AAssetManager *mgr,
                               const AudioTaggingConfig &config)
      : config_(config),
        model_(mgr, config.model),
        labels_(mgr, config.labels) {
    if (model_.NumEventClasses() != labels_.NumEventClasses()) {
      SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
                       model_.NumEventClasses(), labels_.NumEventClasses());
      exit(-1);
    }
  }
#endif
  // Creates a stream whose feature extractor is configured for CED models
  // (see the CEDTag constructor of OfflineStream).
  std::unique_ptr<OfflineStream> CreateStream() const override {
    return std::make_unique<OfflineStream>(CEDTag{});
  }
  // Runs the model on the features accumulated in `s` and returns the
  // `top_k` most probable audio events (indexes chosen by TopkIndex()).
  // A negative top_k means "use config_.top_k"; it is further clamped to
  // the model's number of event classes.
  std::vector<AudioEvent> Compute(OfflineStream *s,
                                  int32_t top_k = -1) const override {
    if (top_k < 0) {
      top_k = config_.top_k;
    }
    int32_t num_event_classes = model_.NumEventClasses();
    if (top_k > num_event_classes) {
      top_k = num_event_classes;
    }
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
    // WARNING(fangjun): It is fixed to 64 for CED models
    int32_t feat_dim = 64;
    std::vector<float> f = s->GetFrames();
    int32_t num_frames = f.size() / feat_dim;
    assert(feat_dim * num_frames == f.size());
    // Input tensor of shape (N=1, T, C=64). Note: `x` borrows f.data(), so
    // `f` must outlive the Forward() call below (it does).
    std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
    Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
                                            shape.data(), shape.size());
    Ort::Value probs = model_.Forward(std::move(x));
    const float *p = probs.GetTensorData<float>();
    std::vector<int32_t> top_k_indexes = TopkIndex(p, num_event_classes, top_k);
    std::vector<AudioEvent> ans(top_k);
    int32_t i = 0;
    for (int32_t index : top_k_indexes) {
      ans[i].name = labels_.GetEventName(index);
      ans[i].index = index;
      ans[i].prob = p[index];
      i += 1;
    }
    return ans;
  }
 private:
  AudioTaggingConfig config_;
  OfflineCEDModel model_;      // the CED ONNX model
  AudioTaggingLabels labels_;  // maps class index -> human-readable event name
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_

View File

@@ -11,6 +11,7 @@
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/audio-tagging-ced-impl.h"
#include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h"
#include "sherpa-onnx/csrc/macros.h"
@@ -20,6 +21,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
const AudioTaggingConfig &config) {
if (!config.model.zipformer.model.empty()) {
return std::make_unique<AudioTaggingZipformerImpl>(config);
} else if (!config.model.ced.empty()) {
return std::make_unique<AudioTaggingCEDImpl>(config);
}
SHERPA_ONNX_LOG(
@@ -32,6 +35,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
AAssetManager *mgr, const AudioTaggingConfig &config) {
if (!config.model.zipformer.model.empty()) {
return std::make_unique<AudioTaggingZipformerImpl>(mgr, config);
} else if (!config.model.ced.empty()) {
return std::make_unique<AudioTaggingCEDImpl>(mgr, config);
}
SHERPA_ONNX_LOG(

View File

@@ -4,11 +4,18 @@
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void AudioTaggingModelConfig::Register(ParseOptions *po) {
zipformer.Register(po);
po->Register("ced-model", &ced,
"Path to CED model. Only need to pass one of --zipformer-model "
"or --ced-model");
po->Register("num-threads", &num_threads,
"Number of threads to run the neural network");
@@ -24,6 +31,16 @@ bool AudioTaggingModelConfig::Validate() const {
return false;
}
if (!ced.empty() && !FileExists(ced)) {
SHERPA_ONNX_LOGE("CED model file %s does not exist", ced.c_str());
return false;
}
if (zipformer.model.empty() && ced.empty()) {
SHERPA_ONNX_LOGE("Please provide either --zipformer-model or --ced-model");
return false;
}
return true;
}
@@ -32,6 +49,7 @@ std::string AudioTaggingModelConfig::ToString() const {
os << "AudioTaggingModelConfig(";
os << "zipformer=" << zipformer.ToString() << ", ";
os << "ced=\"" << ced << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\")";

View File

@@ -13,6 +13,7 @@ namespace sherpa_onnx {
struct AudioTaggingModelConfig {
struct OfflineZipformerAudioTaggingModelConfig zipformer;
std::string ced;
int32_t num_threads = 1;
bool debug = false;
@@ -22,8 +23,10 @@ struct AudioTaggingModelConfig {
AudioTaggingModelConfig(
const OfflineZipformerAudioTaggingModelConfig &zipformer,
int32_t num_threads, bool debug, const std::string &provider)
const std::string &ced, int32_t num_threads, bool debug,
const std::string &provider)
: zipformer(zipformer),
ced(ced),
num_threads(num_threads),
debug(debug),
provider(provider) {}

View File

@@ -4,6 +4,8 @@
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
#include <assert.h>
#include <memory>
#include <utility>
#include <vector>
@@ -72,6 +74,8 @@ class AudioTaggingZipformerImpl : public AudioTaggingImpl {
int32_t num_frames = f.size() / feat_dim;
assert(feat_dim * num_frames == f.size());
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),

View File

@@ -24,7 +24,8 @@ void FeatureExtractorConfig::Register(ParseOptions *po) {
"inside the feature extractor");
po->Register("feat-dim", &feature_dim,
"Feature dimension. Must match the one expected by the model.");
"Feature dimension. Must match the one expected by the model. "
"Not used by whisper and CED models");
po->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins");

View File

@@ -0,0 +1,112 @@
// sherpa-onnx/csrc/offline-ced-model.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-ced-model.h"

#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
// Pimpl class that owns the onnxruntime session for a CED model and hides
// all onnxruntime details from the public OfflineCEDModel interface.
class OfflineCEDModel::Impl {
 public:
  // Loads the model file given by config.ced from the filesystem.
  explicit Impl(const AudioTaggingModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto buf = ReadFile(config_.ced);
    Init(buf.data(), buf.size());
  }
#if __ANDROID_API__ >= 9
  // Same as above, but reads the model from the Android APK's assets.
  Impl(AAssetManager *mgr, const AudioTaggingModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto buf = ReadFile(mgr, config_.ced);
    Init(buf.data(), buf.size());
  }
#endif
  // Runs the model. `features` has shape (N, T, C); axes 1 and 2 are swapped
  // (see Transpose12) before the session runs. Returns the session's first
  // output, a 2-D tensor of shape (N, num_event_classes).
  Ort::Value Forward(Ort::Value features) {
    features = Transpose12(allocator_, &features);
    auto ans = sess_->Run({}, input_names_ptr_.data(), &features, 1,
                          output_names_ptr_.data(), output_names_ptr_.size());
    return std::move(ans[0]);
  }
  int32_t NumEventClasses() const { return num_event_classes_; }
  OrtAllocator *Allocator() const { return allocator_; }
 private:
  // Creates the Ort::Session from an in-memory model and caches the
  // input/output names and the number of event classes.
  void Init(void *model_data, size_t model_data_length) {
    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);
    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
    // get meta data; printed only in debug mode
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }
    // get num_event_classes from the output[0].shape,
    // which is (N, num_event_classes)
    num_event_classes_ =
        sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1];
  }
 private:
  AudioTaggingModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;
  std::unique_ptr<Ort::Session> sess_;
  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;  // c_str() views into input_names_
  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;  // c_str() views into output_names_
  int32_t num_event_classes_ = 0;  // read from output[0]'s static shape in Init()
};
// Constructs from a model file on the filesystem (path in config.ced).
OfflineCEDModel::OfflineCEDModel(const AudioTaggingModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
// Constructs from a model stored in the Android APK's assets.
OfflineCEDModel::OfflineCEDModel(AAssetManager *mgr,
                                 const AudioTaggingModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
// Defined out-of-line so ~unique_ptr<Impl> sees the complete Impl type.
OfflineCEDModel::~OfflineCEDModel() = default;
// The public methods below simply forward to the pimpl.
Ort::Value OfflineCEDModel::Forward(Ort::Value features) const {
  return impl_->Forward(std::move(features));
}
int32_t OfflineCEDModel::NumEventClasses() const {
  return impl_->NumEventClasses();
}
OrtAllocator *OfflineCEDModel::Allocator() const { return impl_->Allocator(); }
} // namespace sherpa_onnx

View File

@@ -0,0 +1,56 @@
// sherpa-onnx/csrc/offline-ced-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
#include <memory>
#include <utility>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
namespace sherpa_onnx {
/** This class implements the CED audio-tagging model from
 * https://github.com/RicherMans/CED/blob/main/export_onnx.py
 *
 * It is a thin pimpl wrapper; all onnxruntime details live in Impl
 * (see offline-ced-model.cc).
 */
class OfflineCEDModel {
 public:
  /// Loads the model file given by config.ced from the filesystem.
  explicit OfflineCEDModel(const AudioTaggingModelConfig &config);
#if __ANDROID_API__ >= 9
  /// Same as above, but reads the model from the Android APK's assets.
  OfflineCEDModel(AAssetManager *mgr, const AudioTaggingModelConfig &config);
#endif
  ~OfflineCEDModel();
  /** Run the forward method of the model.
   *
   * @param features A tensor of shape (N, T, C).
   *
   * @return Return a tensor
   *  - probs: A 2-D tensor of shape (N, num_event_classes).
   */
  Ort::Value Forward(Ort::Value features) const;
  /** Return the number of event classes of the model
   */
  int32_t NumEventClasses() const;
  /** Return an allocator for allocating memory
   */
  OrtAllocator *Allocator() const;
 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_

View File

@@ -92,15 +92,32 @@ class OfflineStream::Impl {
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph)
: context_graph_(context_graph) {
explicit Impl(WhisperTag /*tag*/) {
config_.normalize_samples = true;
opts_.frame_opts.samp_freq = 16000;
opts_.mel_opts.num_bins = 80;
opts_.mel_opts.num_bins = 80; // not used
whisper_fbank_ =
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
}
  // Fbank settings required by CED models; they mirror
  // https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
  explicit Impl(CEDTag /*tag*/) {
    opts_.frame_opts.frame_length_ms = 32;
    opts_.frame_opts.dither = 0;
    opts_.frame_opts.preemph_coeff = 0;
    opts_.frame_opts.remove_dc_offset = false;
    opts_.frame_opts.window_type = "hann";
    opts_.frame_opts.snip_edges = false;
    opts_.frame_opts.samp_freq = 16000;  // fixed to 16000
    opts_.mel_opts.num_bins = 64;  // CED models expect 64-dim fbank features
    opts_.mel_opts.high_freq = 8000;
    fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
  }
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
if (config_.normalize_samples) {
AcceptWaveformImpl(sampling_rate, waveform, n);
@@ -233,9 +250,10 @@ OfflineStream::OfflineStream(
ContextGraphPtr context_graph /*= nullptr*/)
: impl_(std::make_unique<Impl>(config, context_graph)) {}
OfflineStream::OfflineStream(WhisperTag tag,
ContextGraphPtr context_graph /*= {}*/)
: impl_(std::make_unique<Impl>(tag, context_graph)) {}
// Creates a stream whose features are extracted the way whisper models expect.
OfflineStream::OfflineStream(WhisperTag tag)
    : impl_(std::make_unique<Impl>(tag)) {}
// Creates a stream whose features are extracted the way CED models expect.
OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {}
OfflineStream::~OfflineStream() = default;

View File

@@ -67,13 +67,15 @@ struct OfflineFeatureExtractorConfig {
};
// Empty tag types used only to select the model-specific feature-extractor
// configuration in the corresponding OfflineStream constructors.
struct WhisperTag {};
struct CEDTag {};
class OfflineStream {
public:
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
ContextGraphPtr context_graph = {});
explicit OfflineStream(WhisperTag tag, ContextGraphPtr context_graph = {});
explicit OfflineStream(WhisperTag tag);
explicit OfflineStream(CEDTag tag);
~OfflineStream();
/**