Add C++ runtime and Python API for NeMo Canary models (#2352)
This commit is contained in:
@@ -25,6 +25,8 @@ set(sources
|
||||
jieba.cc
|
||||
keyword-spotter-impl.cc
|
||||
keyword-spotter.cc
|
||||
offline-canary-model-config.cc
|
||||
offline-canary-model.cc
|
||||
offline-ctc-fst-decoder-config.cc
|
||||
offline-ctc-fst-decoder.cc
|
||||
offline-ctc-greedy-search-decoder.cc
|
||||
@@ -50,7 +52,6 @@ set(sources
|
||||
offline-rnn-lm.cc
|
||||
offline-sense-voice-model-config.cc
|
||||
offline-sense-voice-model.cc
|
||||
|
||||
offline-source-separation-impl.cc
|
||||
offline-source-separation-model-config.cc
|
||||
offline-source-separation-spleeter-model-config.cc
|
||||
@@ -58,7 +59,6 @@ set(sources
|
||||
offline-source-separation-uvr-model-config.cc
|
||||
offline-source-separation-uvr-model.cc
|
||||
offline-source-separation.cc
|
||||
|
||||
offline-stream.cc
|
||||
offline-tdnn-ctc-model.cc
|
||||
offline-tdnn-model-config.cc
|
||||
|
||||
86
sherpa-onnx/csrc/offline-canary-model-config.cc
Normal file
86
sherpa-onnx/csrc/offline-canary-model-config.cc
Normal file
@@ -0,0 +1,86 @@
|
||||
// sherpa-onnx/csrc/offline-canary-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-canary-model-config.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineCanaryModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("canary-encoder", &encoder,
|
||||
"Path to onnx encoder of Canary, e.g., encoder.int8.onnx");
|
||||
|
||||
po->Register("canary-decoder", &decoder,
|
||||
"Path to onnx decoder of Canary, e.g., decoder.int8.onnx");
|
||||
|
||||
po->Register("canary-src-lang", &src_lang,
|
||||
"Valid values: en, de, es, fr. If empty, default to use en");
|
||||
|
||||
po->Register("canary-tgt-lang", &tgt_lang,
|
||||
"Valid values: en, de, es, fr. If empty, default to use en");
|
||||
|
||||
po->Register("canary-use-pnc", &use_pnc,
|
||||
"true to enable punctuations and casing. false to disable them");
|
||||
}
|
||||
|
||||
bool OfflineCanaryModelConfig::Validate() const {
|
||||
if (encoder.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --canary-encoder");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(encoder)) {
|
||||
SHERPA_ONNX_LOGE("Canary encoder file '%s' does not exist",
|
||||
encoder.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (decoder.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --canary-decoder");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(decoder)) {
|
||||
SHERPA_ONNX_LOGE("Canary decoder file '%s' does not exist",
|
||||
decoder.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!src_lang.empty()) {
|
||||
if (src_lang != "en" && src_lang != "de" && src_lang != "es" &&
|
||||
src_lang != "fr") {
|
||||
SHERPA_ONNX_LOGE("Please use en, de, es, or fr for --canary-src-lang");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!tgt_lang.empty()) {
|
||||
if (tgt_lang != "en" && tgt_lang != "de" && tgt_lang != "es" &&
|
||||
tgt_lang != "fr") {
|
||||
SHERPA_ONNX_LOGE("Please use en, de, es, or fr for --canary-tgt-lang");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OfflineCanaryModelConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineCanaryModelConfig(";
|
||||
os << "encoder=\"" << encoder << "\", ";
|
||||
os << "decoder=\"" << decoder << "\", ";
|
||||
os << "src_lang=\"" << src_lang << "\", ";
|
||||
os << "tgt_lang=\"" << tgt_lang << "\", ";
|
||||
os << "use_pnc=" << (use_pnc ? "True" : "False") << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
47
sherpa-onnx/csrc/offline-canary-model-config.h
Normal file
47
sherpa-onnx/csrc/offline-canary-model-config.h
Normal file
@@ -0,0 +1,47 @@
|
||||
// sherpa-onnx/csrc/offline-canary-model-config.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineCanaryModelConfig {
|
||||
std::string encoder;
|
||||
std::string decoder;
|
||||
|
||||
// en, de, es, fr, or leave it empty to use en
|
||||
std::string src_lang;
|
||||
|
||||
// en, de, es, fr, or leave it empty to use en
|
||||
std::string tgt_lang;
|
||||
|
||||
// true to enable punctuations and casing
|
||||
// false to disable punctuations and casing
|
||||
bool use_pnc = true;
|
||||
|
||||
OfflineCanaryModelConfig() = default;
|
||||
OfflineCanaryModelConfig(const std::string &encoder,
|
||||
const std::string &decoder,
|
||||
const std::string &src_lang,
|
||||
const std::string &tgt_lang, bool use_pnc)
|
||||
: encoder(encoder),
|
||||
decoder(decoder),
|
||||
src_lang(src_lang),
|
||||
tgt_lang(tgt_lang),
|
||||
use_pnc(use_pnc) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
|
||||
23
sherpa-onnx/csrc/offline-canary-model-meta-data.h
Normal file
23
sherpa-onnx/csrc/offline-canary-model-meta-data.h
Normal file
@@ -0,0 +1,23 @@
|
||||
// sherpa-onnx/csrc/offline-canary-model-meta-data.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineCanaryModelMetaData {
|
||||
int32_t vocab_size;
|
||||
int32_t subsampling_factor = 8;
|
||||
int32_t feat_dim = 120;
|
||||
std::string normalize_type;
|
||||
std::unordered_map<std::string, int32_t> lang2id;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_
|
||||
264
sherpa-onnx/csrc/offline-canary-model.cc
Normal file
264
sherpa-onnx/csrc/offline-canary-model.cc
Normal file
@@ -0,0 +1,264 @@
|
||||
// sherpa-onnx/csrc/offline-canary-model.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-canary-model.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-canary-model-meta-data.h"
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineCanaryModel::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
auto buf = ReadFile(config.canary.encoder);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.canary.decoder);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Manager>
|
||||
Impl(Manager *mgr, const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.canary.encoder);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.canary.decoder);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
std::array<Ort::Value, 2> encoder_inputs = {std::move(features),
|
||||
std::move(features_length)};
|
||||
|
||||
auto encoder_out = encoder_sess_->Run(
|
||||
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
|
||||
encoder_inputs.size(), encoder_output_names_ptr_.data(),
|
||||
encoder_output_names_ptr_.size());
|
||||
|
||||
return encoder_out;
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> ForwardDecoder(
|
||||
Ort::Value tokens, std::vector<Ort::Value> decoder_states,
|
||||
Ort::Value encoder_states, Ort::Value enc_mask) {
|
||||
std::vector<Ort::Value> decoder_inputs;
|
||||
decoder_inputs.reserve(3 + decoder_states.size());
|
||||
|
||||
decoder_inputs.push_back(std::move(tokens));
|
||||
for (auto &s : decoder_states) {
|
||||
decoder_inputs.push_back(std::move(s));
|
||||
}
|
||||
|
||||
decoder_inputs.push_back(std::move(encoder_states));
|
||||
decoder_inputs.push_back(std::move(enc_mask));
|
||||
|
||||
auto decoder_outputs = decoder_sess_->Run(
|
||||
{}, decoder_input_names_ptr_.data(), decoder_inputs.data(),
|
||||
decoder_inputs.size(), decoder_output_names_ptr_.data(),
|
||||
decoder_output_names_ptr_.size());
|
||||
|
||||
Ort::Value logits = std::move(decoder_outputs[0]);
|
||||
|
||||
std::vector<Ort::Value> output_decoder_states;
|
||||
output_decoder_states.reserve(decoder_states.size());
|
||||
|
||||
int32_t i = 0;
|
||||
for (auto &s : decoder_outputs) {
|
||||
i += 1;
|
||||
if (i == 1) {
|
||||
continue;
|
||||
}
|
||||
output_decoder_states.push_back(std::move(s));
|
||||
}
|
||||
|
||||
return {std::move(logits), std::move(output_decoder_states)};
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> GetInitialDecoderStates() {
|
||||
std::array<int64_t, 3> shape{1, 0, 1024};
|
||||
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(6);
|
||||
for (int32_t i = 0; i < 6; ++i) {
|
||||
Ort::Value state = Ort::Value::CreateTensor<float>(
|
||||
Allocator(), shape.data(), shape.size());
|
||||
|
||||
ans.push_back(std::move(state));
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
const OfflineCanaryModelMetaData &GetModelMetadata() const { return meta_; }
|
||||
|
||||
OfflineCanaryModelMetaData &GetModelMetadata() { return meta_; }
|
||||
|
||||
private:
|
||||
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, model_data, model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||
&encoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
|
||||
&encoder_output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
os << "---encoder---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
#if __OHOS__
|
||||
SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
|
||||
#else
|
||||
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||
#endif
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
|
||||
std::string model_type;
|
||||
SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
|
||||
|
||||
if (model_type != "EncDecMultiTaskModel") {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Expected model type 'EncDecMultiTaskModel'. Given: '%s'",
|
||||
model_type.c_str());
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_READ_META_DATA(meta_.vocab_size, "vocab_size");
|
||||
SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(meta_.normalize_type,
|
||||
"normalize_type");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_.subsampling_factor, "subsampling_factor");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_.feat_dim, "feat_dim");
|
||||
}
|
||||
|
||||
void InitDecoder(void *model_data, size_t model_data_length) {
|
||||
decoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, model_data, model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||
&decoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
|
||||
&decoder_output_names_ptr_);
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineCanaryModelMetaData meta_;
|
||||
OfflineModelConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> encoder_sess_;
|
||||
std::unique_ptr<Ort::Session> decoder_sess_;
|
||||
|
||||
std::vector<std::string> encoder_input_names_;
|
||||
std::vector<const char *> encoder_input_names_ptr_;
|
||||
|
||||
std::vector<std::string> encoder_output_names_;
|
||||
std::vector<const char *> encoder_output_names_ptr_;
|
||||
|
||||
std::vector<std::string> decoder_input_names_;
|
||||
std::vector<const char *> decoder_input_names_ptr_;
|
||||
|
||||
std::vector<std::string> decoder_output_names_;
|
||||
std::vector<const char *> decoder_output_names_ptr_;
|
||||
};
|
||||
|
||||
OfflineCanaryModel::OfflineCanaryModel(const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
template <typename Manager>
|
||||
OfflineCanaryModel::OfflineCanaryModel(Manager *mgr,
|
||||
const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
|
||||
OfflineCanaryModel::~OfflineCanaryModel() = default;
|
||||
|
||||
std::vector<Ort::Value> OfflineCanaryModel::ForwardEncoder(
|
||||
Ort::Value features, Ort::Value features_length) const {
|
||||
return impl_->ForwardEncoder(std::move(features), std::move(features_length));
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>>
|
||||
OfflineCanaryModel::ForwardDecoder(Ort::Value tokens,
|
||||
std::vector<Ort::Value> decoder_states,
|
||||
Ort::Value encoder_states,
|
||||
Ort::Value enc_mask) const {
|
||||
return impl_->ForwardDecoder(std::move(tokens), std::move(decoder_states),
|
||||
std::move(encoder_states), std::move(enc_mask));
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> OfflineCanaryModel::GetInitialDecoderStates() const {
|
||||
return impl_->GetInitialDecoderStates();
|
||||
}
|
||||
|
||||
OrtAllocator *OfflineCanaryModel::Allocator() const {
|
||||
return impl_->Allocator();
|
||||
}
|
||||
|
||||
const OfflineCanaryModelMetaData &OfflineCanaryModel::GetModelMetadata() const {
|
||||
return impl_->GetModelMetadata();
|
||||
}
|
||||
OfflineCanaryModelMetaData &OfflineCanaryModel::GetModelMetadata() {
|
||||
return impl_->GetModelMetadata();
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
template OfflineCanaryModel::OfflineCanaryModel(
|
||||
AAssetManager *mgr, const OfflineModelConfig &config);
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
template OfflineCanaryModel::OfflineCanaryModel(
|
||||
NativeResourceManager *mgr, const OfflineModelConfig &config);
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
81
sherpa-onnx/csrc/offline-canary-model.h
Normal file
81
sherpa-onnx/csrc/offline-canary-model.h
Normal file
@@ -0,0 +1,81 @@
|
||||
// sherpa-onnx/csrc/offline-canary-model.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-canary-model-meta-data.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// see
|
||||
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/nemo/canary/test_180m_flash.py
|
||||
class OfflineCanaryModel {
|
||||
public:
|
||||
explicit OfflineCanaryModel(const OfflineModelConfig &config);
|
||||
|
||||
template <typename Manager>
|
||||
OfflineCanaryModel(Manager *mgr, const OfflineModelConfig &config);
|
||||
|
||||
~OfflineCanaryModel();
|
||||
|
||||
/** Run the encoder.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C) of dtype float32.
|
||||
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||
* valid frames in `features` before padding.
|
||||
* Its dtype is int64_t.
|
||||
*
|
||||
* @return Return a vector containing:
|
||||
* - encoder_states: A 3-D tensor of shape (N, T', encoder_dim)
|
||||
* - encoder_len: A 1-D tensor of shape (N,) containing number
|
||||
* of frames in `encoder_out` before padding.
|
||||
* Its dtype is int64_t
|
||||
* - enc_mask: A 2-D tensor of shape (N, T') with dtype bool
|
||||
*/
|
||||
std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
|
||||
Ort::Value features_length) const;
|
||||
|
||||
/** Run the decoder model.
|
||||
*
|
||||
* @param tokens A int32 tensor of shape (N, num_tokens)
|
||||
* @param decoder_states std::vector<Ort::Value>
|
||||
* @param encoder_states Output from ForwardEncoder()
|
||||
* @param enc_mask Output from ForwardEncoder()
|
||||
*
|
||||
* @return Return a pair:
|
||||
*
|
||||
* - logits A 3-D tensor of shape (N, num_words, vocab_size)
|
||||
* - new_decoder_states: Can be used as input for ForwardDecoder()
|
||||
*/
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> ForwardDecoder(
|
||||
Ort::Value tokens, std::vector<Ort::Value> decoder_states,
|
||||
Ort::Value encoder_states, Ort::Value enc_mask) const;
|
||||
|
||||
// The return value can be used as input for ForwardDecoder()
|
||||
std::vector<Ort::Value> GetInitialDecoderStates() const;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const;
|
||||
|
||||
const OfflineCanaryModelMetaData &GetModelMetadata() const;
|
||||
|
||||
OfflineCanaryModelMetaData &GetModelMetadata();
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_
|
||||
@@ -22,6 +22,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
||||
sense_voice.Register(po);
|
||||
moonshine.Register(po);
|
||||
dolphin.Register(po);
|
||||
canary.Register(po);
|
||||
|
||||
po->Register("telespeech-ctc", &telespeech_ctc,
|
||||
"Path to model.onnx for telespeech ctc");
|
||||
@@ -114,6 +115,10 @@ bool OfflineModelConfig::Validate() const {
|
||||
return dolphin.Validate();
|
||||
}
|
||||
|
||||
if (!canary.encoder.empty()) {
|
||||
return canary.Validate();
|
||||
}
|
||||
|
||||
if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
|
||||
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
|
||||
telespeech_ctc.c_str());
|
||||
@@ -142,6 +147,7 @@ std::string OfflineModelConfig::ToString() const {
|
||||
os << "sense_voice=" << sense_voice.ToString() << ", ";
|
||||
os << "moonshine=" << moonshine.ToString() << ", ";
|
||||
os << "dolphin=" << dolphin.ToString() << ", ";
|
||||
os << "canary=" << canary.ToString() << ", ";
|
||||
os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-canary-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-dolphin-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
|
||||
@@ -32,6 +33,7 @@ struct OfflineModelConfig {
|
||||
OfflineSenseVoiceModelConfig sense_voice;
|
||||
OfflineMoonshineModelConfig moonshine;
|
||||
OfflineDolphinModelConfig dolphin;
|
||||
OfflineCanaryModelConfig canary;
|
||||
std::string telespeech_ctc;
|
||||
|
||||
std::string tokens;
|
||||
@@ -65,6 +67,7 @@ struct OfflineModelConfig {
|
||||
const OfflineSenseVoiceModelConfig &sense_voice,
|
||||
const OfflineMoonshineModelConfig &moonshine,
|
||||
const OfflineDolphinModelConfig &dolphin,
|
||||
const OfflineCanaryModelConfig &canary,
|
||||
const std::string &telespeech_ctc,
|
||||
const std::string &tokens, int32_t num_threads, bool debug,
|
||||
const std::string &provider, const std::string &model_type,
|
||||
@@ -81,6 +84,7 @@ struct OfflineModelConfig {
|
||||
sense_voice(sense_voice),
|
||||
moonshine(moonshine),
|
||||
dolphin(dolphin),
|
||||
canary(canary),
|
||||
telespeech_ctc(telespeech_ctc),
|
||||
tokens(tokens),
|
||||
num_threads(num_threads),
|
||||
|
||||
261
sherpa-onnx/csrc/offline-recognizer-canary-impl.h
Normal file
261
sherpa-onnx/csrc/offline-recognizer-canary-impl.h
Normal file
@@ -0,0 +1,261 @@
|
||||
// sherpa-onnx/csrc/offline-recognizer-canary-impl.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <ios>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-canary-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineRecognizerCanaryImpl : public OfflineRecognizerImpl {
|
||||
public:
|
||||
explicit OfflineRecognizerCanaryImpl(const OfflineRecognizerConfig &config)
|
||||
: OfflineRecognizerImpl(config),
|
||||
config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineCanaryModel>(config_.model_config)) {
|
||||
PostInit();
|
||||
}
|
||||
|
||||
template <typename Manager>
|
||||
explicit OfflineRecognizerCanaryImpl(Manager *mgr,
|
||||
const OfflineRecognizerConfig &config)
|
||||
: OfflineRecognizerImpl(mgr, config),
|
||||
config_(config),
|
||||
symbol_table_(mgr, config_.model_config.tokens),
|
||||
model_(
|
||||
std::make_unique<OfflineCanaryModel>(mgr, config_.model_config)) {
|
||||
PostInit();
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>(config_.feat_config);
|
||||
}
|
||||
|
||||
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||
for (int32_t i = 0; i < n; ++i) {
|
||||
DecodeStream(ss[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void DecodeStream(OfflineStream *s) const {
|
||||
auto meta = model_->GetModelMetadata();
|
||||
auto enc_out = RunEncoder(s);
|
||||
Ort::Value enc_states = std::move(enc_out[0]);
|
||||
Ort::Value enc_mask = std::move(enc_out[2]);
|
||||
// enc_out[1] is discarded
|
||||
std::vector<int32_t> decoder_input = GetInitialDecoderInput();
|
||||
auto decoder_states = model_->GetInitialDecoderStates();
|
||||
Ort::Value logits{nullptr};
|
||||
|
||||
for (int32_t i = 0; i < decoder_input.size(); ++i) {
|
||||
std::tie(logits, decoder_states) =
|
||||
RunDecoder(decoder_input[i], i, std::move(decoder_states),
|
||||
View(&enc_states), View(&enc_mask));
|
||||
}
|
||||
|
||||
int32_t max_token_id = GetMaxTokenId(&logits);
|
||||
int32_t eos = symbol_table_["<|endoftext|>"];
|
||||
|
||||
int32_t num_feature_frames =
|
||||
enc_states.GetTensorTypeAndShapeInfo().GetShape()[1] *
|
||||
meta.subsampling_factor;
|
||||
|
||||
std::vector<int32_t> tokens = {max_token_id};
|
||||
|
||||
// Assume 30 tokens per second. It is to avoid the following for loop
|
||||
// running indefinitely.
|
||||
int32_t num_tokens =
|
||||
static_cast<int32_t>(num_feature_frames / 100.0 * 30) + 1;
|
||||
|
||||
for (int32_t i = 1; i <= num_tokens; ++i) {
|
||||
if (tokens.back() == eos) {
|
||||
break;
|
||||
}
|
||||
|
||||
std::tie(logits, decoder_states) =
|
||||
RunDecoder(tokens.back(), i, std::move(decoder_states),
|
||||
View(&enc_states), View(&enc_mask));
|
||||
tokens.push_back(GetMaxTokenId(&logits));
|
||||
}
|
||||
|
||||
// remove the last eos token
|
||||
tokens.pop_back();
|
||||
|
||||
auto r = Convert(tokens);
|
||||
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
|
||||
s->SetResult(r);
|
||||
}
|
||||
|
||||
OfflineRecognizerConfig GetConfig() const override { return config_; }
|
||||
|
||||
void SetConfig(const OfflineRecognizerConfig &config) override {
|
||||
config_.model_config.canary.src_lang = config.model_config.canary.src_lang;
|
||||
config_.model_config.canary.tgt_lang = config.model_config.canary.tgt_lang;
|
||||
config_.model_config.canary.use_pnc = config.model_config.canary.use_pnc;
|
||||
|
||||
// we don't change the config_ in the base class
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineRecognitionResult Convert(const std::vector<int32_t> &tokens) const {
|
||||
OfflineRecognitionResult r;
|
||||
r.tokens.reserve(tokens.size());
|
||||
|
||||
std::string text;
|
||||
for (auto i : tokens) {
|
||||
if (!symbol_table_.Contains(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &s = symbol_table_[i];
|
||||
text += s;
|
||||
r.tokens.push_back(s);
|
||||
}
|
||||
|
||||
r.text = std::move(text);
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
int32_t GetMaxTokenId(Ort::Value *logits) const {
|
||||
// logits is of shape (1, 1, vocab_size)
|
||||
auto meta = model_->GetModelMetadata();
|
||||
const float *p_logits = logits->GetTensorData<float>();
|
||||
|
||||
int32_t max_token_id = static_cast<int32_t>(std::distance(
|
||||
p_logits, std::max_element(p_logits, p_logits + meta.vocab_size)));
|
||||
|
||||
return max_token_id;
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> RunEncoder(OfflineStream *s) const {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
int32_t feat_dim = config_.feat_config.feature_dim;
|
||||
std::vector<float> f = s->GetFrames();
|
||||
|
||||
int32_t num_frames = f.size() / feat_dim;
|
||||
|
||||
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
|
||||
|
||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
|
||||
shape.data(), shape.size());
|
||||
|
||||
int64_t x_length_scalar = num_frames;
|
||||
std::array<int64_t, 1> x_length_shape = {1};
|
||||
Ort::Value x_length =
|
||||
Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1,
|
||||
x_length_shape.data(), x_length_shape.size());
|
||||
return model_->ForwardEncoder(std::move(x), std::move(x_length));
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
|
||||
int32_t token, int32_t pos, std::vector<Ort::Value> decoder_states,
|
||||
Ort::Value enc_states, Ort::Value enc_mask) const {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::array<int64_t, 2> shape = {1, 2};
|
||||
std::array<int32_t, 2> _decoder_input = {token, pos};
|
||||
|
||||
Ort::Value decoder_input = Ort::Value::CreateTensor(
|
||||
memory_info, _decoder_input.data(), _decoder_input.size(), shape.data(),
|
||||
shape.size());
|
||||
|
||||
return model_->ForwardDecoder(std::move(decoder_input),
|
||||
std::move(decoder_states),
|
||||
std::move(enc_states), std::move(enc_mask));
|
||||
}
|
||||
|
||||
// see
|
||||
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/nemo/canary/test_180m_flash.py#L242
|
||||
std::vector<int32_t> GetInitialDecoderInput() const {
|
||||
auto canary_config = config_.model_config.canary;
|
||||
const auto &meta = model_->GetModelMetadata();
|
||||
|
||||
std::vector<int32_t> decoder_input(9);
|
||||
decoder_input[0] = symbol_table_["<|startofcontext|>"];
|
||||
decoder_input[1] = symbol_table_["<|startoftranscript|>"];
|
||||
decoder_input[2] = symbol_table_["<|emo:undefined|>"];
|
||||
|
||||
if (canary_config.src_lang.empty() ||
|
||||
!meta.lang2id.count(canary_config.src_lang)) {
|
||||
decoder_input[3] = meta.lang2id.at("en");
|
||||
} else {
|
||||
decoder_input[3] = meta.lang2id.at(canary_config.src_lang);
|
||||
}
|
||||
|
||||
if (canary_config.tgt_lang.empty() ||
|
||||
!meta.lang2id.count(canary_config.tgt_lang)) {
|
||||
decoder_input[4] = meta.lang2id.at("en");
|
||||
} else {
|
||||
decoder_input[4] = meta.lang2id.at(canary_config.tgt_lang);
|
||||
}
|
||||
|
||||
if (canary_config.use_pnc) {
|
||||
decoder_input[5] = symbol_table_["<|pnc|>"];
|
||||
} else {
|
||||
decoder_input[5] = symbol_table_["<|nopnc|>"];
|
||||
}
|
||||
|
||||
decoder_input[6] = symbol_table_["<|noitn|>"];
|
||||
decoder_input[7] = symbol_table_["<|notimestamp|>"];
|
||||
decoder_input[8] = symbol_table_["<|nodiarize|>"];
|
||||
|
||||
return decoder_input;
|
||||
}
|
||||
|
||||
private:
|
||||
void PostInit() {
|
||||
auto &meta = model_->GetModelMetadata();
|
||||
config_.feat_config.feature_dim = meta.feat_dim;
|
||||
|
||||
config_.feat_config.nemo_normalize_type = meta.normalize_type;
|
||||
|
||||
config_.feat_config.dither = 0;
|
||||
config_.feat_config.remove_dc_offset = false;
|
||||
config_.feat_config.low_freq = 0;
|
||||
config_.feat_config.window_type = "hann";
|
||||
config_.feat_config.is_librosa = true;
|
||||
|
||||
meta.lang2id["en"] = symbol_table_["<|en|>"];
|
||||
meta.lang2id["es"] = symbol_table_["<|es|>"];
|
||||
meta.lang2id["de"] = symbol_table_["<|de|>"];
|
||||
meta.lang2id["fr"] = symbol_table_["<|fr|>"];
|
||||
|
||||
if (symbol_table_.NumSymbols() != meta.vocab_size) {
|
||||
SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
|
||||
symbol_table_.NumSymbols(), meta.vocab_size);
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineRecognizerConfig config_;
|
||||
SymbolTable symbol_table_;
|
||||
std::unique_ptr<OfflineCanaryModel> model_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
|
||||
@@ -39,7 +39,7 @@ static OfflineRecognitionResult Convert(
|
||||
r.tokens.push_back(s);
|
||||
}
|
||||
|
||||
r.text = text;
|
||||
r.text = std::move(text);
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-canary-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h"
|
||||
@@ -66,6 +67,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
|
||||
}
|
||||
|
||||
if (!config.model_config.canary.encoder.empty()) {
|
||||
return std::make_unique<OfflineRecognizerCanaryImpl>(config);
|
||||
}
|
||||
|
||||
// TODO(fangjun): Refactor it. We only need to use model type for the
|
||||
// following models:
|
||||
// 1. transducer and nemo_transducer
|
||||
@@ -252,6 +257,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
|
||||
}
|
||||
|
||||
if (!config.model_config.canary.encoder.empty()) {
|
||||
return std::make_unique<OfflineRecognizerCanaryImpl>(mgr, config);
|
||||
}
|
||||
|
||||
// TODO(fangjun): Refactor it. We only need to use model type for the
|
||||
// following models:
|
||||
// 1. transducer and nemo_transducer
|
||||
|
||||
@@ -183,6 +183,10 @@ Ort::Value View(Ort::Value *v) {
|
||||
return Ort::Value::CreateTensor(
|
||||
memory_info, v->GetTensorMutableData<float>(),
|
||||
type_and_shape.GetElementCount(), shape.data(), shape.size());
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
|
||||
return Ort::Value::CreateTensor(
|
||||
memory_info, v->GetTensorMutableData<bool>(),
|
||||
type_and_shape.GetElementCount(), shape.data(), shape.size());
|
||||
default:
|
||||
fprintf(stderr, "Unsupported type: %d\n",
|
||||
static_cast<int32_t>(type_and_shape.GetElementType()));
|
||||
|
||||
Reference in New Issue
Block a user