Add C++ runtime and Python API for NeMo Canary models (#2352)

This commit is contained in:
Fangjun Kuang
2025-07-07 17:03:49 +08:00
committed by GitHub
parent f8d957a24b
commit 0e738c356c
24 changed files with 1091 additions and 8 deletions

View File

@@ -25,6 +25,8 @@ set(sources
jieba.cc
keyword-spotter-impl.cc
keyword-spotter.cc
offline-canary-model-config.cc
offline-canary-model.cc
offline-ctc-fst-decoder-config.cc
offline-ctc-fst-decoder.cc
offline-ctc-greedy-search-decoder.cc
@@ -50,7 +52,6 @@ set(sources
offline-rnn-lm.cc
offline-sense-voice-model-config.cc
offline-sense-voice-model.cc
offline-source-separation-impl.cc
offline-source-separation-model-config.cc
offline-source-separation-spleeter-model-config.cc
@@ -58,7 +59,6 @@ set(sources
offline-source-separation-uvr-model-config.cc
offline-source-separation-uvr-model.cc
offline-source-separation.cc
offline-stream.cc
offline-tdnn-ctc-model.cc
offline-tdnn-model-config.cc

View File

@@ -0,0 +1,86 @@
// sherpa-onnx/csrc/offline-canary-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-canary-model-config.h"
#include <sstream>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
// Registers the command-line flags for NeMo Canary models with the
// given option parser. Each flag writes directly into the corresponding
// member of this struct.
void OfflineCanaryModelConfig::Register(ParseOptions *po) {
  po->Register("canary-encoder", &encoder,
               "Path to onnx encoder of Canary, e.g., encoder.int8.onnx");

  po->Register("canary-decoder", &decoder,
               "Path to onnx decoder of Canary, e.g., decoder.int8.onnx");

  po->Register("canary-src-lang", &src_lang,
               "Valid values: en, de, es, fr. If empty, default to use en");

  po->Register("canary-tgt-lang", &tgt_lang,
               "Valid values: en, de, es, fr. If empty, default to use en");

  po->Register("canary-use-pnc", &use_pnc,
               "true to enable punctuations and casing. false to disable them");
}
// Checks that both model files are configured and exist on disk and that
// the optional language fields, when non-empty, are among the languages
// supported by Canary. Returns false (after logging) on the first problem
// found; returns true if the configuration is usable.
bool OfflineCanaryModelConfig::Validate() const {
  if (encoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --canary-encoder");
    return false;
  }

  if (!FileExists(encoder)) {
    SHERPA_ONNX_LOGE("Canary encoder file '%s' does not exist",
                     encoder.c_str());
    return false;
  }

  if (decoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --canary-decoder");
    return false;
  }

  if (!FileExists(decoder)) {
    SHERPA_ONNX_LOGE("Canary decoder file '%s' does not exist",
                     decoder.c_str());
    return false;
  }

  // Languages supported by Canary; shared by both --canary-src-lang and
  // --canary-tgt-lang. An empty value means "use en" and is always valid.
  auto is_supported_lang = [](const std::string &lang) {
    return lang == "en" || lang == "de" || lang == "es" || lang == "fr";
  };

  if (!src_lang.empty() && !is_supported_lang(src_lang)) {
    SHERPA_ONNX_LOGE("Please use en, de, es, or fr for --canary-src-lang");
    return false;
  }

  if (!tgt_lang.empty() && !is_supported_lang(tgt_lang)) {
    SHERPA_ONNX_LOGE("Please use en, de, es, or fr for --canary-tgt-lang");
    return false;
  }

  return true;
}
// Returns a human-readable, Python-repr-like description of this config,
// e.g., OfflineCanaryModelConfig(encoder="...", ..., use_pnc=True)
std::string OfflineCanaryModelConfig::ToString() const {
  std::ostringstream ss;

  ss << "OfflineCanaryModelConfig("
     << "encoder=\"" << encoder << "\", "
     << "decoder=\"" << decoder << "\", "
     << "src_lang=\"" << src_lang << "\", "
     << "tgt_lang=\"" << tgt_lang << "\", "
     << "use_pnc=" << (use_pnc ? "True" : "False") << ")";

  return ss.str();
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,47 @@
// sherpa-onnx/csrc/offline-canary-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
// Configuration for NeMo Canary models (encoder/decoder ONNX pair plus
// decoding options). Exposed on the command line via Register() with the
// "canary-" flag prefix.
struct OfflineCanaryModelConfig {
  // Path to the onnx encoder model, e.g., encoder.int8.onnx
  std::string encoder;
  // Path to the onnx decoder model, e.g., decoder.int8.onnx
  std::string decoder;

  // en, de, es, fr, or leave it empty to use en
  std::string src_lang;

  // en, de, es, fr, or leave it empty to use en
  std::string tgt_lang;

  // true to enable punctuations and casing
  // false to disable punctuations and casing
  bool use_pnc = true;

  OfflineCanaryModelConfig() = default;

  OfflineCanaryModelConfig(const std::string &encoder,
                           const std::string &decoder,
                           const std::string &src_lang,
                           const std::string &tgt_lang, bool use_pnc)
      : encoder(encoder),
        decoder(decoder),
        src_lang(src_lang),
        tgt_lang(tgt_lang),
        use_pnc(use_pnc) {}

  // Registers the "canary-*" command-line flags for the fields above.
  void Register(ParseOptions *po);

  // Returns true if the model files exist and the language options are
  // valid; logs the reason and returns false otherwise.
  bool Validate() const;

  // Returns a human-readable description of this config.
  std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_

View File

@@ -0,0 +1,23 @@
// sherpa-onnx/csrc/offline-canary-model-meta-data.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>
namespace sherpa_onnx {
// Metadata describing a NeMo Canary model. Most fields are read from the
// metadata section of the ONNX encoder model; lang2id is filled in by the
// recognizer after tokens.txt has been loaded.
struct OfflineCanaryModelMetaData {
  // Vocabulary size of the model. 0 until it is read from the model
  // metadata; it must match the number of lines in tokens.txt.
  // (Initialized explicitly so a default-constructed object never carries
  // an indeterminate value.)
  int32_t vocab_size = 0;

  // Ratio of input feature frames to encoder output frames.
  int32_t subsampling_factor = 8;

  // Dimension of the input acoustic features.
  int32_t feat_dim = 120;

  // Feature normalization method, e.g., per_feature; may be empty.
  std::string normalize_type;

  // Maps a language code (en, de, es, fr) to its token id in tokens.txt.
  std::unordered_map<std::string, int32_t> lang2id;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_

View File

@@ -0,0 +1,264 @@
// sherpa-onnx/csrc/offline-canary-model.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-canary-model.h"
#include <algorithm>
#include <cmath>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include "sherpa-onnx/csrc/offline-canary-model-meta-data.h"
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
// Pimpl body of OfflineCanaryModel. Owns the two onnxruntime sessions
// (encoder and auto-regressive decoder) plus the metadata read from the
// encoder model.
class OfflineCanaryModel::Impl {
 public:
  // Loads the encoder and decoder models from the file system.
  explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(config.canary.encoder);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.canary.decoder);
      InitDecoder(buf.data(), buf.size());
    }
  }

  // Loads the encoder and decoder models through a platform resource
  // manager (AAssetManager on Android, NativeResourceManager on OHOS).
  template <typename Manager>
  Impl(Manager *mgr, const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(mgr, config.canary.encoder);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.canary.decoder);
      InitDecoder(buf.data(), buf.size());
    }
  }

  // Runs the encoder session on (features, features_length) and returns
  // its outputs unchanged, in the order the model declares them.
  std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
                                         Ort::Value features_length) {
    std::array<Ort::Value, 2> encoder_inputs = {std::move(features),
                                                std::move(features_length)};

    auto encoder_out = encoder_sess_->Run(
        {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
        encoder_inputs.size(), encoder_output_names_ptr_.data(),
        encoder_output_names_ptr_.size());

    return encoder_out;
  }

  // Runs one decoder step. Input order expected by the model:
  //   tokens, <decoder states...>, encoder_states, enc_mask
  // Output order: logits first, followed by the updated decoder states.
  std::pair<Ort::Value, std::vector<Ort::Value>> ForwardDecoder(
      Ort::Value tokens, std::vector<Ort::Value> decoder_states,
      Ort::Value encoder_states, Ort::Value enc_mask) {
    std::vector<Ort::Value> decoder_inputs;
    // 3 fixed inputs (tokens, encoder_states, enc_mask) + the states
    decoder_inputs.reserve(3 + decoder_states.size());

    decoder_inputs.push_back(std::move(tokens));
    for (auto &s : decoder_states) {
      decoder_inputs.push_back(std::move(s));
    }
    decoder_inputs.push_back(std::move(encoder_states));
    decoder_inputs.push_back(std::move(enc_mask));

    auto decoder_outputs = decoder_sess_->Run(
        {}, decoder_input_names_ptr_.data(), decoder_inputs.data(),
        decoder_inputs.size(), decoder_output_names_ptr_.data(),
        decoder_output_names_ptr_.size());

    Ort::Value logits = std::move(decoder_outputs[0]);

    // Everything after outputs[0] (the logits, already moved out above)
    // is a new decoder state tensor.
    std::vector<Ort::Value> output_decoder_states;
    output_decoder_states.reserve(decoder_states.size());

    int32_t i = 0;
    for (auto &s : decoder_outputs) {
      i += 1;
      if (i == 1) {
        // skip the logits output
        continue;
      }
      output_decoder_states.push_back(std::move(s));
    }

    return {std::move(logits), std::move(output_decoder_states)};
  }

  // Returns 6 empty float tensors of shape (1, 0, 1024) to seed the
  // decoder before the first step.
  // NOTE(review): 6 and 1024 presumably match the exported decoder's
  // state count and hidden size — confirm against the export script.
  std::vector<Ort::Value> GetInitialDecoderStates() {
    std::array<int64_t, 3> shape{1, 0, 1024};
    std::vector<Ort::Value> ans;
    ans.reserve(6);
    for (int32_t i = 0; i < 6; ++i) {
      Ort::Value state = Ort::Value::CreateTensor<float>(
          Allocator(), shape.data(), shape.size());
      ans.push_back(std::move(state));
    }

    return ans;
  }

  // Allocator used for creating tensors owned by this model.
  OrtAllocator *Allocator() { return allocator_; }

  const OfflineCanaryModelMetaData &GetModelMetadata() const { return meta_; }

  // Non-const overload: lets the caller fill in fields (e.g., lang2id)
  // after construction.
  OfflineCanaryModelMetaData &GetModelMetadata() { return meta_; }

 private:
  // Creates the encoder session and reads model metadata (model type,
  // vocab size, normalization, subsampling factor, feature dim) from it.
  // Exits the process if the model is not an EncDecMultiTaskModel.
  void InitEncoder(void *model_data, size_t model_data_length) {
    encoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
                  &encoder_input_names_ptr_);

    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
                   &encoder_output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      os << "---encoder---\n";
      PrintModelMetadata(os, meta_data);
#if __OHOS__
      SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
#else
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
#endif
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    std::string model_type;
    SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
    if (model_type != "EncDecMultiTaskModel") {
      SHERPA_ONNX_LOGE(
          "Expected model type 'EncDecMultiTaskModel'. Given: '%s'",
          model_type.c_str());
      SHERPA_ONNX_EXIT(-1);
    }

    SHERPA_ONNX_READ_META_DATA(meta_.vocab_size, "vocab_size");
    SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(meta_.normalize_type,
                                               "normalize_type");
    SHERPA_ONNX_READ_META_DATA(meta_.subsampling_factor, "subsampling_factor");
    SHERPA_ONNX_READ_META_DATA(meta_.feat_dim, "feat_dim");
  }

  // Creates the decoder session and caches its input/output names.
  void InitDecoder(void *model_data, size_t model_data_length) {
    decoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(decoder_sess_.get(), &decoder_input_names_,
                  &decoder_input_names_ptr_);

    GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
                   &decoder_output_names_ptr_);
  }

 private:
  OfflineCanaryModelMetaData meta_;
  OfflineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> encoder_sess_;
  std::unique_ptr<Ort::Session> decoder_sess_;

  // The *_ptr_ vectors hold borrowed c_str() pointers into the matching
  // string vectors; the strings must outlive the pointer vectors.
  std::vector<std::string> encoder_input_names_;
  std::vector<const char *> encoder_input_names_ptr_;

  std::vector<std::string> encoder_output_names_;
  std::vector<const char *> encoder_output_names_ptr_;

  std::vector<std::string> decoder_input_names_;
  std::vector<const char *> decoder_input_names_ptr_;

  std::vector<std::string> decoder_output_names_;
  std::vector<const char *> decoder_output_names_ptr_;
};
OfflineCanaryModel::OfflineCanaryModel(const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

template <typename Manager>
OfflineCanaryModel::OfflineCanaryModel(Manager *mgr,
                                       const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}

OfflineCanaryModel::~OfflineCanaryModel() = default;

// The public methods below simply forward to the pimpl object; see
// OfflineCanaryModel::Impl for the actual logic.
std::vector<Ort::Value> OfflineCanaryModel::ForwardEncoder(
    Ort::Value features, Ort::Value features_length) const {
  return impl_->ForwardEncoder(std::move(features), std::move(features_length));
}

std::pair<Ort::Value, std::vector<Ort::Value>>
OfflineCanaryModel::ForwardDecoder(Ort::Value tokens,
                                   std::vector<Ort::Value> decoder_states,
                                   Ort::Value encoder_states,
                                   Ort::Value enc_mask) const {
  return impl_->ForwardDecoder(std::move(tokens), std::move(decoder_states),
                               std::move(encoder_states), std::move(enc_mask));
}

std::vector<Ort::Value> OfflineCanaryModel::GetInitialDecoderStates() const {
  return impl_->GetInitialDecoderStates();
}

OrtAllocator *OfflineCanaryModel::Allocator() const {
  return impl_->Allocator();
}

const OfflineCanaryModelMetaData &OfflineCanaryModel::GetModelMetadata() const {
  return impl_->GetModelMetadata();
}

OfflineCanaryModelMetaData &OfflineCanaryModel::GetModelMetadata() {
  return impl_->GetModelMetadata();
}

// Explicit instantiations of the resource-manager constructor for the
// platforms that load models via an asset/resource manager.
#if __ANDROID_API__ >= 9
template OfflineCanaryModel::OfflineCanaryModel(
    AAssetManager *mgr, const OfflineModelConfig &config);
#endif

#if __OHOS__
template OfflineCanaryModel::OfflineCanaryModel(
    NativeResourceManager *mgr, const OfflineModelConfig &config);
#endif
} // namespace sherpa_onnx

View File

@@ -0,0 +1,81 @@
// sherpa-onnx/csrc/offline-canary-model.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-canary-model-meta-data.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
namespace sherpa_onnx {
// see
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/nemo/canary/test_180m_flash.py
// Wrapper around the ONNX encoder and decoder of a NeMo Canary model.
// Thread-compatible; callers serialize access per instance.
class OfflineCanaryModel {
 public:
  explicit OfflineCanaryModel(const OfflineModelConfig &config);

  // Constructor for platforms that load models through a resource
  // manager, e.g., AAssetManager on Android or NativeResourceManager
  // on HarmonyOS.
  template <typename Manager>
  OfflineCanaryModel(Manager *mgr, const OfflineModelConfig &config);

  ~OfflineCanaryModel();

  /** Run the encoder.
   *
   * @param features  A tensor of shape (N, T, C) of dtype float32.
   * @param features_length  A 1-D tensor of shape (N,) containing number of
   *                         valid frames in `features` before padding.
   *                         Its dtype is int64_t.
   *
   * @return Return a vector containing:
   *  - encoder_states: A 3-D tensor of shape (N, T', encoder_dim)
   *  - encoder_len: A 1-D tensor of shape (N,) containing number
   *                 of frames in `encoder_out` before padding.
   *                 Its dtype is int64_t
   *  - enc_mask: A 2-D tensor of shape (N, T') with dtype bool
   */
  std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
                                         Ort::Value features_length) const;

  /** Run the decoder model.
   *
   * @param tokens A int32 tensor of shape (N, num_tokens)
   * @param decoder_states std::vector<Ort::Value>
   * @param encoder_states Output from ForwardEncoder()
   * @param enc_mask Output from ForwardEncoder()
   *
   * @return Return a pair:
   *
   *  - logits A 3-D tensor of shape (N, num_words, vocab_size)
   *  - new_decoder_states: Can be used as input for ForwardDecoder()
   */
  std::pair<Ort::Value, std::vector<Ort::Value>> ForwardDecoder(
      Ort::Value tokens, std::vector<Ort::Value> decoder_states,
      Ort::Value encoder_states, Ort::Value enc_mask) const;

  // The return value can be used as input for ForwardDecoder()
  std::vector<Ort::Value> GetInitialDecoderStates() const;

  /** Return an allocator for allocating memory
   */
  OrtAllocator *Allocator() const;

  // Metadata read from the encoder model. The non-const overload exists
  // so the recognizer can fill in lang2id after loading tokens.txt.
  const OfflineCanaryModelMetaData &GetModelMetadata() const;

  OfflineCanaryModelMetaData &GetModelMetadata();

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_

View File

@@ -22,6 +22,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
sense_voice.Register(po);
moonshine.Register(po);
dolphin.Register(po);
canary.Register(po);
po->Register("telespeech-ctc", &telespeech_ctc,
"Path to model.onnx for telespeech ctc");
@@ -114,6 +115,10 @@ bool OfflineModelConfig::Validate() const {
return dolphin.Validate();
}
if (!canary.encoder.empty()) {
return canary.Validate();
}
if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
telespeech_ctc.c_str());
@@ -142,6 +147,7 @@ std::string OfflineModelConfig::ToString() const {
os << "sense_voice=" << sense_voice.ToString() << ", ";
os << "moonshine=" << moonshine.ToString() << ", ";
os << "dolphin=" << dolphin.ToString() << ", ";
os << "canary=" << canary.ToString() << ", ";
os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";

View File

@@ -6,6 +6,7 @@
#include <string>
#include "sherpa-onnx/csrc/offline-canary-model-config.h"
#include "sherpa-onnx/csrc/offline-dolphin-model-config.h"
#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
@@ -32,6 +33,7 @@ struct OfflineModelConfig {
OfflineSenseVoiceModelConfig sense_voice;
OfflineMoonshineModelConfig moonshine;
OfflineDolphinModelConfig dolphin;
OfflineCanaryModelConfig canary;
std::string telespeech_ctc;
std::string tokens;
@@ -65,6 +67,7 @@ struct OfflineModelConfig {
const OfflineSenseVoiceModelConfig &sense_voice,
const OfflineMoonshineModelConfig &moonshine,
const OfflineDolphinModelConfig &dolphin,
const OfflineCanaryModelConfig &canary,
const std::string &telespeech_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type,
@@ -81,6 +84,7 @@ struct OfflineModelConfig {
sense_voice(sense_voice),
moonshine(moonshine),
dolphin(dolphin),
canary(canary),
telespeech_ctc(telespeech_ctc),
tokens(tokens),
num_threads(num_threads),

View File

@@ -0,0 +1,261 @@
// sherpa-onnx/csrc/offline-recognizer-canary-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
#include <algorithm>
#include <ios>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-canary-model.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/utils.h"
namespace sherpa_onnx {
// Offline recognizer for NeMo Canary models.
//
// Decoding is auto-regressive greedy search: the encoder output is
// computed once per stream and the decoder is then invoked one token at
// a time until it emits <|endoftext|> or a safety budget on the number
// of generated tokens is exhausted.
class OfflineRecognizerCanaryImpl : public OfflineRecognizerImpl {
 public:
  explicit OfflineRecognizerCanaryImpl(const OfflineRecognizerConfig &config)
      : OfflineRecognizerImpl(config),
        config_(config),
        symbol_table_(config_.model_config.tokens),
        model_(std::make_unique<OfflineCanaryModel>(config_.model_config)) {
    PostInit();
  }

  template <typename Manager>
  explicit OfflineRecognizerCanaryImpl(Manager *mgr,
                                       const OfflineRecognizerConfig &config)
      : OfflineRecognizerImpl(mgr, config),
        config_(config),
        symbol_table_(mgr, config_.model_config.tokens),
        model_(
            std::make_unique<OfflineCanaryModel>(mgr, config_.model_config)) {
    PostInit();
  }

  std::unique_ptr<OfflineStream> CreateStream() const override {
    return std::make_unique<OfflineStream>(config_.feat_config);
  }

  // Decodes the streams one by one; batch decoding is not implemented.
  void DecodeStreams(OfflineStream **ss, int32_t n) const override {
    for (int32_t i = 0; i != n; ++i) {
      DecodeStream(ss[i]);
    }
  }

  // Greedy-search decoding of a single stream.
  void DecodeStream(OfflineStream *s) const {
    // const ref: metadata contains strings and a map; copying it per
    // stream would be wasteful
    const auto &meta = model_->GetModelMetadata();

    auto enc_out = RunEncoder(s);
    Ort::Value enc_states = std::move(enc_out[0]);
    Ort::Value enc_mask = std::move(enc_out[2]);
    // enc_out[1] (encoder output lengths) is not needed by the decoder

    // Feed the task prompt (languages, punctuation, etc.) one token at a
    // time; only the logits from the last prompt token matter.
    std::vector<int32_t> decoder_input = GetInitialDecoderInput();
    auto decoder_states = model_->GetInitialDecoderStates();
    Ort::Value logits{nullptr};
    for (int32_t i = 0; i != static_cast<int32_t>(decoder_input.size()); ++i) {
      // Manual unpacking instead of std::tie: this header does not
      // include <tuple>, and std::tie cannot bind move-only Ort::Value
      // members without an extra step anyway.
      auto out = RunDecoder(decoder_input[i], i, std::move(decoder_states),
                            View(&enc_states), View(&enc_mask));
      logits = std::move(out.first);
      decoder_states = std::move(out.second);
    }

    int32_t max_token_id = GetMaxTokenId(&logits);
    int32_t eos = symbol_table_["<|endoftext|>"];

    int32_t num_feature_frames =
        static_cast<int32_t>(
            enc_states.GetTensorTypeAndShapeInfo().GetShape()[1]) *
        meta.subsampling_factor;

    std::vector<int32_t> tokens = {max_token_id};

    // Assume at most 30 tokens per second of audio (feature frames are
    // 10 ms each). This bounds the loop below so it cannot run forever
    // if the model never emits eos.
    int32_t num_tokens =
        static_cast<int32_t>(num_feature_frames / 100.0 * 30) + 1;

    for (int32_t i = 1; i <= num_tokens; ++i) {
      if (tokens.back() == eos) {
        break;
      }

      auto out = RunDecoder(tokens.back(), i, std::move(decoder_states),
                            View(&enc_states), View(&enc_mask));
      logits = std::move(out.first);
      decoder_states = std::move(out.second);

      tokens.push_back(GetMaxTokenId(&logits));
    }

    // Remove the trailing eos token. If the token budget was exhausted
    // before eos was generated, the last token is real output and must
    // be kept.
    if (!tokens.empty() && tokens.back() == eos) {
      tokens.pop_back();
    }

    auto r = Convert(tokens);
    r.text = ApplyInverseTextNormalization(std::move(r.text));
    r.text = ApplyHomophoneReplacer(std::move(r.text));
    s->SetResult(r);
  }

  OfflineRecognizerConfig GetConfig() const override { return config_; }

  void SetConfig(const OfflineRecognizerConfig &config) override {
    // Only the per-request Canary options may change at runtime.
    config_.model_config.canary.src_lang = config.model_config.canary.src_lang;
    config_.model_config.canary.tgt_lang = config.model_config.canary.tgt_lang;
    config_.model_config.canary.use_pnc = config.model_config.canary.use_pnc;
    // we don't change the config_ in the base class
  }

 private:
  // Maps token ids to their symbols and concatenates them into the
  // recognition result. Ids missing from tokens.txt are skipped.
  OfflineRecognitionResult Convert(const std::vector<int32_t> &tokens) const {
    OfflineRecognitionResult r;
    r.tokens.reserve(tokens.size());

    std::string text;
    for (auto i : tokens) {
      if (!symbol_table_.Contains(i)) {
        continue;
      }

      const auto &s = symbol_table_[i];
      text += s;
      r.tokens.push_back(s);
    }

    r.text = std::move(text);

    return r;
  }

  // Argmax over the vocabulary dimension of the last decoder step.
  int32_t GetMaxTokenId(Ort::Value *logits) const {
    // logits is of shape (1, 1, vocab_size)
    const auto &meta = model_->GetModelMetadata();
    const float *p_logits = logits->GetTensorData<float>();

    return static_cast<int32_t>(std::distance(
        p_logits, std::max_element(p_logits, p_logits + meta.vocab_size)));
  }

  // Builds the (1, T, C) feature tensor plus its length tensor from the
  // stream and runs the encoder.
  std::vector<Ort::Value> RunEncoder(OfflineStream *s) const {
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    int32_t feat_dim = config_.feat_config.feature_dim;
    std::vector<float> f = s->GetFrames();

    int32_t num_frames = f.size() / feat_dim;

    std::array<int64_t, 3> shape = {1, num_frames, feat_dim};

    Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
                                            shape.data(), shape.size());

    int64_t x_length_scalar = num_frames;
    std::array<int64_t, 1> x_length_shape = {1};
    Ort::Value x_length =
        Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1,
                                 x_length_shape.data(), x_length_shape.size());

    return model_->ForwardEncoder(std::move(x), std::move(x_length));
  }

  // Runs one decoder step for a single (token, position) pair.
  std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
      int32_t token, int32_t pos, std::vector<Ort::Value> decoder_states,
      Ort::Value enc_states, Ort::Value enc_mask) const {
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    std::array<int64_t, 2> shape = {1, 2};
    std::array<int32_t, 2> _decoder_input = {token, pos};

    Ort::Value decoder_input = Ort::Value::CreateTensor(
        memory_info, _decoder_input.data(), _decoder_input.size(), shape.data(),
        shape.size());

    return model_->ForwardDecoder(std::move(decoder_input),
                                  std::move(decoder_states),
                                  std::move(enc_states), std::move(enc_mask));
  }

  // Builds the 9-token task prompt (context/transcript markers, source
  // and target language, punctuation flag, itn/timestamp/diarize flags).
  // Unknown or empty languages fall back to English.
  // see
  // https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/nemo/canary/test_180m_flash.py#L242
  std::vector<int32_t> GetInitialDecoderInput() const {
    // const ref: avoid copying the config struct on every call
    const auto &canary_config = config_.model_config.canary;
    const auto &meta = model_->GetModelMetadata();

    std::vector<int32_t> decoder_input(9);
    decoder_input[0] = symbol_table_["<|startofcontext|>"];
    decoder_input[1] = symbol_table_["<|startoftranscript|>"];
    decoder_input[2] = symbol_table_["<|emo:undefined|>"];

    if (canary_config.src_lang.empty() ||
        !meta.lang2id.count(canary_config.src_lang)) {
      decoder_input[3] = meta.lang2id.at("en");
    } else {
      decoder_input[3] = meta.lang2id.at(canary_config.src_lang);
    }

    if (canary_config.tgt_lang.empty() ||
        !meta.lang2id.count(canary_config.tgt_lang)) {
      decoder_input[4] = meta.lang2id.at("en");
    } else {
      decoder_input[4] = meta.lang2id.at(canary_config.tgt_lang);
    }

    if (canary_config.use_pnc) {
      decoder_input[5] = symbol_table_["<|pnc|>"];
    } else {
      decoder_input[5] = symbol_table_["<|nopnc|>"];
    }

    decoder_input[6] = symbol_table_["<|noitn|>"];
    decoder_input[7] = symbol_table_["<|notimestamp|>"];
    decoder_input[8] = symbol_table_["<|nodiarize|>"];

    return decoder_input;
  }

 private:
  // Applies model metadata to the feature extractor config, fills in the
  // language-id map from tokens.txt, and sanity-checks the vocabulary
  // size. Called from both constructors.
  void PostInit() {
    auto &meta = model_->GetModelMetadata();

    config_.feat_config.feature_dim = meta.feat_dim;
    config_.feat_config.nemo_normalize_type = meta.normalize_type;

    config_.feat_config.dither = 0;
    config_.feat_config.remove_dc_offset = false;
    config_.feat_config.low_freq = 0;
    config_.feat_config.window_type = "hann";
    config_.feat_config.is_librosa = true;

    meta.lang2id["en"] = symbol_table_["<|en|>"];
    meta.lang2id["es"] = symbol_table_["<|es|>"];
    meta.lang2id["de"] = symbol_table_["<|de|>"];
    meta.lang2id["fr"] = symbol_table_["<|fr|>"];

    if (symbol_table_.NumSymbols() != meta.vocab_size) {
      SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
                       symbol_table_.NumSymbols(), meta.vocab_size);
      SHERPA_ONNX_EXIT(-1);
    }
  }

 private:
  OfflineRecognizerConfig config_;
  SymbolTable symbol_table_;
  std::unique_ptr<OfflineCanaryModel> model_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_

View File

@@ -39,7 +39,7 @@ static OfflineRecognitionResult Convert(
r.tokens.push_back(s);
}
r.text = text;
r.text = std::move(text);
return r;
}

View File

@@ -24,6 +24,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-canary-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h"
@@ -66,6 +67,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
}
if (!config.model_config.canary.encoder.empty()) {
return std::make_unique<OfflineRecognizerCanaryImpl>(config);
}
// TODO(fangjun): Refactor it. We only need to use model type for the
// following models:
// 1. transducer and nemo_transducer
@@ -252,6 +257,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
}
if (!config.model_config.canary.encoder.empty()) {
return std::make_unique<OfflineRecognizerCanaryImpl>(mgr, config);
}
// TODO(fangjun): Refactor it. We only need to use model type for the
// following models:
// 1. transducer and nemo_transducer

View File

@@ -183,6 +183,10 @@ Ort::Value View(Ort::Value *v) {
return Ort::Value::CreateTensor(
memory_info, v->GetTensorMutableData<float>(),
type_and_shape.GetElementCount(), shape.data(), shape.size());
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
return Ort::Value::CreateTensor(
memory_info, v->GetTensorMutableData<bool>(),
type_and_shape.GetElementCount(), shape.data(), shape.size());
default:
fprintf(stderr, "Unsupported type: %d\n",
static_cast<int32_t>(type_and_shape.GetElementType()));