Add C++ runtime and Python API for NeMo Canary models (#2352)
This commit is contained in:
7
.github/scripts/test-python.sh
vendored
7
.github/scripts/test-python.sh
vendored
@@ -8,6 +8,13 @@ log() {
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "test nemo canary"
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
|
||||
tar xvf sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
|
||||
rm sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
|
||||
python3 ./python-api-examples/offline-nemo-canary-decode-files.py
|
||||
rm -rf sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
|
||||
|
||||
log "test spleeter"
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2
|
||||
|
||||
110
python-api-examples/offline-nemo-canary-decode-files.py
Normal file
110
python-api-examples/offline-nemo-canary-decode-files.py
Normal file
@@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
This file shows how to use a non-streaming Canary model from NeMo
|
||||
to decode files.
|
||||
|
||||
Please download model files from
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||
|
||||
|
||||
The example model supports 4 languages and it is converted from
|
||||
https://huggingface.co/nvidia/canary-180m-flash
|
||||
|
||||
It supports automatic speech-to-text recognition (ASR) in 4 languages
|
||||
(English, German, French, Spanish) and translation from English to
|
||||
German/French/Spanish and from German/French/Spanish to English with or
|
||||
without punctuation and capitalization (PnC).
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import sherpa_onnx
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def create_recognizer():
|
||||
encoder = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/encoder.int8.onnx"
|
||||
decoder = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/decoder.int8.onnx"
|
||||
tokens = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/tokens.txt"
|
||||
|
||||
en_wav = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/test_wavs/en.wav"
|
||||
de_wav = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8/test_wavs/de.wav"
|
||||
|
||||
if not Path(encoder).is_file() or not Path(en_wav).is_file():
|
||||
raise ValueError(
|
||||
"""Please download model files from
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||
"""
|
||||
)
|
||||
return (
|
||||
sherpa_onnx.OfflineRecognizer.from_nemo_canary(
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
tokens=tokens,
|
||||
debug=True,
|
||||
),
|
||||
en_wav,
|
||||
de_wav,
|
||||
)
|
||||
|
||||
|
||||
def decode(recognizer, samples, sample_rate, src_lang, tgt_lang):
|
||||
stream = recognizer.create_stream()
|
||||
stream.accept_waveform(sample_rate, samples)
|
||||
|
||||
recognizer.recognizer.set_config(
|
||||
config=sherpa_onnx.OfflineRecognizerConfig(
|
||||
model_config=sherpa_onnx.OfflineModelConfig(
|
||||
canary=sherpa_onnx.OfflineCanaryModelConfig(
|
||||
src_lang=src_lang,
|
||||
tgt_lang=tgt_lang,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
recognizer.decode_stream(stream)
|
||||
return stream.result.text
|
||||
|
||||
|
||||
def main():
|
||||
recognizer, en_wav, de_wav = create_recognizer()
|
||||
|
||||
en_audio, en_sample_rate = sf.read(en_wav, dtype="float32", always_2d=True)
|
||||
en_audio = en_audio[:, 0] # only use the first channel
|
||||
|
||||
de_audio, de_sample_rate = sf.read(de_wav, dtype="float32", always_2d=True)
|
||||
de_audio = de_audio[:, 0] # only use the first channel
|
||||
|
||||
en_wav_en_result = decode(
|
||||
recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang="en"
|
||||
)
|
||||
en_wav_es_result = decode(
|
||||
recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang="es"
|
||||
)
|
||||
en_wav_de_result = decode(
|
||||
recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang="de"
|
||||
)
|
||||
en_wav_fr_result = decode(
|
||||
recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang="fr"
|
||||
)
|
||||
|
||||
de_wav_en_result = decode(
|
||||
recognizer, de_audio, de_sample_rate, src_lang="de", tgt_lang="en"
|
||||
)
|
||||
de_wav_de_result = decode(
|
||||
recognizer, de_audio, de_sample_rate, src_lang="de", tgt_lang="de"
|
||||
)
|
||||
|
||||
print("en_wav_en_result", en_wav_en_result)
|
||||
print("en_wav_es_result", en_wav_es_result)
|
||||
print("en_wav_de_result", en_wav_de_result)
|
||||
print("en_wav_fr_result", en_wav_fr_result)
|
||||
print("-" * 10)
|
||||
print("de_wav_en_result", de_wav_en_result)
|
||||
print("de_wav_de_result", de_wav_de_result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -281,9 +281,14 @@ def export_decoder(canary_model):
|
||||
|
||||
|
||||
def export_tokens(canary_model):
|
||||
underline = "▁"
|
||||
with open("./tokens.txt", "w", encoding="utf-8") as f:
|
||||
for i in range(canary_model.tokenizer.vocab_size):
|
||||
s = canary_model.tokenizer.ids_to_text([i])
|
||||
|
||||
if s[0] == " ":
|
||||
s = underline + s[1:]
|
||||
|
||||
f.write(f"{s} {i}\n")
|
||||
print("Saved to tokens.txt")
|
||||
|
||||
|
||||
@@ -289,7 +289,13 @@ def main():
|
||||
tokens.append(t)
|
||||
print("len(tokens)", len(tokens))
|
||||
print("tokens", tokens)
|
||||
|
||||
text = "".join([id2token[i] for i in tokens])
|
||||
|
||||
underline = "▁"
|
||||
# underline = b"\xe2\x96\x81".decode()
|
||||
|
||||
text = text.replace(underline, " ").strip()
|
||||
print("text:", text)
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <utility>
|
||||
|
||||
namespace sherpa_onnx::cxx {
|
||||
|
||||
|
||||
@@ -25,6 +25,8 @@ set(sources
|
||||
jieba.cc
|
||||
keyword-spotter-impl.cc
|
||||
keyword-spotter.cc
|
||||
offline-canary-model-config.cc
|
||||
offline-canary-model.cc
|
||||
offline-ctc-fst-decoder-config.cc
|
||||
offline-ctc-fst-decoder.cc
|
||||
offline-ctc-greedy-search-decoder.cc
|
||||
@@ -50,7 +52,6 @@ set(sources
|
||||
offline-rnn-lm.cc
|
||||
offline-sense-voice-model-config.cc
|
||||
offline-sense-voice-model.cc
|
||||
|
||||
offline-source-separation-impl.cc
|
||||
offline-source-separation-model-config.cc
|
||||
offline-source-separation-spleeter-model-config.cc
|
||||
@@ -58,7 +59,6 @@ set(sources
|
||||
offline-source-separation-uvr-model-config.cc
|
||||
offline-source-separation-uvr-model.cc
|
||||
offline-source-separation.cc
|
||||
|
||||
offline-stream.cc
|
||||
offline-tdnn-ctc-model.cc
|
||||
offline-tdnn-model-config.cc
|
||||
|
||||
86
sherpa-onnx/csrc/offline-canary-model-config.cc
Normal file
86
sherpa-onnx/csrc/offline-canary-model-config.cc
Normal file
@@ -0,0 +1,86 @@
|
||||
// sherpa-onnx/csrc/offline-canary-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-canary-model-config.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineCanaryModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("canary-encoder", &encoder,
|
||||
"Path to onnx encoder of Canary, e.g., encoder.int8.onnx");
|
||||
|
||||
po->Register("canary-decoder", &decoder,
|
||||
"Path to onnx decoder of Canary, e.g., decoder.int8.onnx");
|
||||
|
||||
po->Register("canary-src-lang", &src_lang,
|
||||
"Valid values: en, de, es, fr. If empty, default to use en");
|
||||
|
||||
po->Register("canary-tgt-lang", &tgt_lang,
|
||||
"Valid values: en, de, es, fr. If empty, default to use en");
|
||||
|
||||
po->Register("canary-use-pnc", &use_pnc,
|
||||
"true to enable punctuations and casing. false to disable them");
|
||||
}
|
||||
|
||||
bool OfflineCanaryModelConfig::Validate() const {
|
||||
if (encoder.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --canary-encoder");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(encoder)) {
|
||||
SHERPA_ONNX_LOGE("Canary encoder file '%s' does not exist",
|
||||
encoder.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (decoder.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --canary-decoder");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(decoder)) {
|
||||
SHERPA_ONNX_LOGE("Canary decoder file '%s' does not exist",
|
||||
decoder.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!src_lang.empty()) {
|
||||
if (src_lang != "en" && src_lang != "de" && src_lang != "es" &&
|
||||
src_lang != "fr") {
|
||||
SHERPA_ONNX_LOGE("Please use en, de, es, or fr for --canary-src-lang");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!tgt_lang.empty()) {
|
||||
if (tgt_lang != "en" && tgt_lang != "de" && tgt_lang != "es" &&
|
||||
tgt_lang != "fr") {
|
||||
SHERPA_ONNX_LOGE("Please use en, de, es, or fr for --canary-tgt-lang");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OfflineCanaryModelConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineCanaryModelConfig(";
|
||||
os << "encoder=\"" << encoder << "\", ";
|
||||
os << "decoder=\"" << decoder << "\", ";
|
||||
os << "src_lang=\"" << src_lang << "\", ";
|
||||
os << "tgt_lang=\"" << tgt_lang << "\", ";
|
||||
os << "use_pnc=" << (use_pnc ? "True" : "False") << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
47
sherpa-onnx/csrc/offline-canary-model-config.h
Normal file
47
sherpa-onnx/csrc/offline-canary-model-config.h
Normal file
@@ -0,0 +1,47 @@
|
||||
// sherpa-onnx/csrc/offline-canary-model-config.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineCanaryModelConfig {
|
||||
std::string encoder;
|
||||
std::string decoder;
|
||||
|
||||
// en, de, es, fr, or leave it empty to use en
|
||||
std::string src_lang;
|
||||
|
||||
// en, de, es, fr, or leave it empty to use en
|
||||
std::string tgt_lang;
|
||||
|
||||
// true to enable punctuations and casing
|
||||
// false to disable punctuations and casing
|
||||
bool use_pnc = true;
|
||||
|
||||
OfflineCanaryModelConfig() = default;
|
||||
OfflineCanaryModelConfig(const std::string &encoder,
|
||||
const std::string &decoder,
|
||||
const std::string &src_lang,
|
||||
const std::string &tgt_lang, bool use_pnc)
|
||||
: encoder(encoder),
|
||||
decoder(decoder),
|
||||
src_lang(src_lang),
|
||||
tgt_lang(tgt_lang),
|
||||
use_pnc(use_pnc) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
|
||||
23
sherpa-onnx/csrc/offline-canary-model-meta-data.h
Normal file
23
sherpa-onnx/csrc/offline-canary-model-meta-data.h
Normal file
@@ -0,0 +1,23 @@
|
||||
// sherpa-onnx/csrc/offline-canary-model-meta-data.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineCanaryModelMetaData {
|
||||
int32_t vocab_size;
|
||||
int32_t subsampling_factor = 8;
|
||||
int32_t feat_dim = 120;
|
||||
std::string normalize_type;
|
||||
std::unordered_map<std::string, int32_t> lang2id;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_
|
||||
264
sherpa-onnx/csrc/offline-canary-model.cc
Normal file
264
sherpa-onnx/csrc/offline-canary-model.cc
Normal file
@@ -0,0 +1,264 @@
|
||||
// sherpa-onnx/csrc/offline-canary-model.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-canary-model.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-canary-model-meta-data.h"
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineCanaryModel::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
auto buf = ReadFile(config.canary.encoder);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.canary.decoder);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Manager>
|
||||
Impl(Manager *mgr, const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.canary.encoder);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.canary.decoder);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
std::array<Ort::Value, 2> encoder_inputs = {std::move(features),
|
||||
std::move(features_length)};
|
||||
|
||||
auto encoder_out = encoder_sess_->Run(
|
||||
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
|
||||
encoder_inputs.size(), encoder_output_names_ptr_.data(),
|
||||
encoder_output_names_ptr_.size());
|
||||
|
||||
return encoder_out;
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> ForwardDecoder(
|
||||
Ort::Value tokens, std::vector<Ort::Value> decoder_states,
|
||||
Ort::Value encoder_states, Ort::Value enc_mask) {
|
||||
std::vector<Ort::Value> decoder_inputs;
|
||||
decoder_inputs.reserve(3 + decoder_states.size());
|
||||
|
||||
decoder_inputs.push_back(std::move(tokens));
|
||||
for (auto &s : decoder_states) {
|
||||
decoder_inputs.push_back(std::move(s));
|
||||
}
|
||||
|
||||
decoder_inputs.push_back(std::move(encoder_states));
|
||||
decoder_inputs.push_back(std::move(enc_mask));
|
||||
|
||||
auto decoder_outputs = decoder_sess_->Run(
|
||||
{}, decoder_input_names_ptr_.data(), decoder_inputs.data(),
|
||||
decoder_inputs.size(), decoder_output_names_ptr_.data(),
|
||||
decoder_output_names_ptr_.size());
|
||||
|
||||
Ort::Value logits = std::move(decoder_outputs[0]);
|
||||
|
||||
std::vector<Ort::Value> output_decoder_states;
|
||||
output_decoder_states.reserve(decoder_states.size());
|
||||
|
||||
int32_t i = 0;
|
||||
for (auto &s : decoder_outputs) {
|
||||
i += 1;
|
||||
if (i == 1) {
|
||||
continue;
|
||||
}
|
||||
output_decoder_states.push_back(std::move(s));
|
||||
}
|
||||
|
||||
return {std::move(logits), std::move(output_decoder_states)};
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> GetInitialDecoderStates() {
|
||||
std::array<int64_t, 3> shape{1, 0, 1024};
|
||||
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(6);
|
||||
for (int32_t i = 0; i < 6; ++i) {
|
||||
Ort::Value state = Ort::Value::CreateTensor<float>(
|
||||
Allocator(), shape.data(), shape.size());
|
||||
|
||||
ans.push_back(std::move(state));
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
OrtAllocator *Allocator() { return allocator_; }
|
||||
|
||||
const OfflineCanaryModelMetaData &GetModelMetadata() const { return meta_; }
|
||||
|
||||
OfflineCanaryModelMetaData &GetModelMetadata() { return meta_; }
|
||||
|
||||
private:
|
||||
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, model_data, model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||
&encoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
|
||||
&encoder_output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
os << "---encoder---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
#if __OHOS__
|
||||
SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
|
||||
#else
|
||||
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||
#endif
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
|
||||
std::string model_type;
|
||||
SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
|
||||
|
||||
if (model_type != "EncDecMultiTaskModel") {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Expected model type 'EncDecMultiTaskModel'. Given: '%s'",
|
||||
model_type.c_str());
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_READ_META_DATA(meta_.vocab_size, "vocab_size");
|
||||
SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(meta_.normalize_type,
|
||||
"normalize_type");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_.subsampling_factor, "subsampling_factor");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_.feat_dim, "feat_dim");
|
||||
}
|
||||
|
||||
void InitDecoder(void *model_data, size_t model_data_length) {
|
||||
decoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, model_data, model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||
&decoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
|
||||
&decoder_output_names_ptr_);
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineCanaryModelMetaData meta_;
|
||||
OfflineModelConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> encoder_sess_;
|
||||
std::unique_ptr<Ort::Session> decoder_sess_;
|
||||
|
||||
std::vector<std::string> encoder_input_names_;
|
||||
std::vector<const char *> encoder_input_names_ptr_;
|
||||
|
||||
std::vector<std::string> encoder_output_names_;
|
||||
std::vector<const char *> encoder_output_names_ptr_;
|
||||
|
||||
std::vector<std::string> decoder_input_names_;
|
||||
std::vector<const char *> decoder_input_names_ptr_;
|
||||
|
||||
std::vector<std::string> decoder_output_names_;
|
||||
std::vector<const char *> decoder_output_names_ptr_;
|
||||
};
|
||||
|
||||
OfflineCanaryModel::OfflineCanaryModel(const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
template <typename Manager>
|
||||
OfflineCanaryModel::OfflineCanaryModel(Manager *mgr,
|
||||
const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
|
||||
OfflineCanaryModel::~OfflineCanaryModel() = default;
|
||||
|
||||
std::vector<Ort::Value> OfflineCanaryModel::ForwardEncoder(
|
||||
Ort::Value features, Ort::Value features_length) const {
|
||||
return impl_->ForwardEncoder(std::move(features), std::move(features_length));
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>>
|
||||
OfflineCanaryModel::ForwardDecoder(Ort::Value tokens,
|
||||
std::vector<Ort::Value> decoder_states,
|
||||
Ort::Value encoder_states,
|
||||
Ort::Value enc_mask) const {
|
||||
return impl_->ForwardDecoder(std::move(tokens), std::move(decoder_states),
|
||||
std::move(encoder_states), std::move(enc_mask));
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> OfflineCanaryModel::GetInitialDecoderStates() const {
|
||||
return impl_->GetInitialDecoderStates();
|
||||
}
|
||||
|
||||
OrtAllocator *OfflineCanaryModel::Allocator() const {
|
||||
return impl_->Allocator();
|
||||
}
|
||||
|
||||
const OfflineCanaryModelMetaData &OfflineCanaryModel::GetModelMetadata() const {
|
||||
return impl_->GetModelMetadata();
|
||||
}
|
||||
OfflineCanaryModelMetaData &OfflineCanaryModel::GetModelMetadata() {
|
||||
return impl_->GetModelMetadata();
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
template OfflineCanaryModel::OfflineCanaryModel(
|
||||
AAssetManager *mgr, const OfflineModelConfig &config);
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
template OfflineCanaryModel::OfflineCanaryModel(
|
||||
NativeResourceManager *mgr, const OfflineModelConfig &config);
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
81
sherpa-onnx/csrc/offline-canary-model.h
Normal file
81
sherpa-onnx/csrc/offline-canary-model.h
Normal file
@@ -0,0 +1,81 @@
|
||||
// sherpa-onnx/csrc/offline-canary-model.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-canary-model-meta-data.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// see
|
||||
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/nemo/canary/test_180m_flash.py
|
||||
class OfflineCanaryModel {
|
||||
public:
|
||||
explicit OfflineCanaryModel(const OfflineModelConfig &config);
|
||||
|
||||
template <typename Manager>
|
||||
OfflineCanaryModel(Manager *mgr, const OfflineModelConfig &config);
|
||||
|
||||
~OfflineCanaryModel();
|
||||
|
||||
/** Run the encoder.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C) of dtype float32.
|
||||
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||
* valid frames in `features` before padding.
|
||||
* Its dtype is int64_t.
|
||||
*
|
||||
* @return Return a vector containing:
|
||||
* - encoder_states: A 3-D tensor of shape (N, T', encoder_dim)
|
||||
* - encoder_len: A 1-D tensor of shape (N,) containing number
|
||||
* of frames in `encoder_out` before padding.
|
||||
* Its dtype is int64_t
|
||||
* - enc_mask: A 2-D tensor of shape (N, T') with dtype bool
|
||||
*/
|
||||
std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
|
||||
Ort::Value features_length) const;
|
||||
|
||||
/** Run the decoder model.
|
||||
*
|
||||
* @param tokens A int32 tensor of shape (N, num_tokens)
|
||||
* @param decoder_states std::vector<Ort::Value>
|
||||
* @param encoder_states Output from ForwardEncoder()
|
||||
* @param enc_mask Output from ForwardEncoder()
|
||||
*
|
||||
* @return Return a pair:
|
||||
*
|
||||
* - logits A 3-D tensor of shape (N, num_words, vocab_size)
|
||||
* - new_decoder_states: Can be used as input for ForwardDecoder()
|
||||
*/
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> ForwardDecoder(
|
||||
Ort::Value tokens, std::vector<Ort::Value> decoder_states,
|
||||
Ort::Value encoder_states, Ort::Value enc_mask) const;
|
||||
|
||||
// The return value can be used as input for ForwardDecoder()
|
||||
std::vector<Ort::Value> GetInitialDecoderStates() const;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const;
|
||||
|
||||
const OfflineCanaryModelMetaData &GetModelMetadata() const;
|
||||
|
||||
OfflineCanaryModelMetaData &GetModelMetadata();
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_H_
|
||||
@@ -22,6 +22,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
||||
sense_voice.Register(po);
|
||||
moonshine.Register(po);
|
||||
dolphin.Register(po);
|
||||
canary.Register(po);
|
||||
|
||||
po->Register("telespeech-ctc", &telespeech_ctc,
|
||||
"Path to model.onnx for telespeech ctc");
|
||||
@@ -114,6 +115,10 @@ bool OfflineModelConfig::Validate() const {
|
||||
return dolphin.Validate();
|
||||
}
|
||||
|
||||
if (!canary.encoder.empty()) {
|
||||
return canary.Validate();
|
||||
}
|
||||
|
||||
if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
|
||||
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
|
||||
telespeech_ctc.c_str());
|
||||
@@ -142,6 +147,7 @@ std::string OfflineModelConfig::ToString() const {
|
||||
os << "sense_voice=" << sense_voice.ToString() << ", ";
|
||||
os << "moonshine=" << moonshine.ToString() << ", ";
|
||||
os << "dolphin=" << dolphin.ToString() << ", ";
|
||||
os << "canary=" << canary.ToString() << ", ";
|
||||
os << "telespeech_ctc=\"" << telespeech_ctc << "\", ";
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-canary-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-dolphin-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
|
||||
@@ -32,6 +33,7 @@ struct OfflineModelConfig {
|
||||
OfflineSenseVoiceModelConfig sense_voice;
|
||||
OfflineMoonshineModelConfig moonshine;
|
||||
OfflineDolphinModelConfig dolphin;
|
||||
OfflineCanaryModelConfig canary;
|
||||
std::string telespeech_ctc;
|
||||
|
||||
std::string tokens;
|
||||
@@ -65,6 +67,7 @@ struct OfflineModelConfig {
|
||||
const OfflineSenseVoiceModelConfig &sense_voice,
|
||||
const OfflineMoonshineModelConfig &moonshine,
|
||||
const OfflineDolphinModelConfig &dolphin,
|
||||
const OfflineCanaryModelConfig &canary,
|
||||
const std::string &telespeech_ctc,
|
||||
const std::string &tokens, int32_t num_threads, bool debug,
|
||||
const std::string &provider, const std::string &model_type,
|
||||
@@ -81,6 +84,7 @@ struct OfflineModelConfig {
|
||||
sense_voice(sense_voice),
|
||||
moonshine(moonshine),
|
||||
dolphin(dolphin),
|
||||
canary(canary),
|
||||
telespeech_ctc(telespeech_ctc),
|
||||
tokens(tokens),
|
||||
num_threads(num_threads),
|
||||
|
||||
261
sherpa-onnx/csrc/offline-recognizer-canary-impl.h
Normal file
261
sherpa-onnx/csrc/offline-recognizer-canary-impl.h
Normal file
@@ -0,0 +1,261 @@
|
||||
// sherpa-onnx/csrc/offline-recognizer-canary-impl.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <ios>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-canary-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineRecognizerCanaryImpl : public OfflineRecognizerImpl {
|
||||
public:
|
||||
explicit OfflineRecognizerCanaryImpl(const OfflineRecognizerConfig &config)
|
||||
: OfflineRecognizerImpl(config),
|
||||
config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineCanaryModel>(config_.model_config)) {
|
||||
PostInit();
|
||||
}
|
||||
|
||||
template <typename Manager>
|
||||
explicit OfflineRecognizerCanaryImpl(Manager *mgr,
|
||||
const OfflineRecognizerConfig &config)
|
||||
: OfflineRecognizerImpl(mgr, config),
|
||||
config_(config),
|
||||
symbol_table_(mgr, config_.model_config.tokens),
|
||||
model_(
|
||||
std::make_unique<OfflineCanaryModel>(mgr, config_.model_config)) {
|
||||
PostInit();
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>(config_.feat_config);
|
||||
}
|
||||
|
||||
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||
for (int32_t i = 0; i < n; ++i) {
|
||||
DecodeStream(ss[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void DecodeStream(OfflineStream *s) const {
|
||||
auto meta = model_->GetModelMetadata();
|
||||
auto enc_out = RunEncoder(s);
|
||||
Ort::Value enc_states = std::move(enc_out[0]);
|
||||
Ort::Value enc_mask = std::move(enc_out[2]);
|
||||
// enc_out[1] is discarded
|
||||
std::vector<int32_t> decoder_input = GetInitialDecoderInput();
|
||||
auto decoder_states = model_->GetInitialDecoderStates();
|
||||
Ort::Value logits{nullptr};
|
||||
|
||||
for (int32_t i = 0; i < decoder_input.size(); ++i) {
|
||||
std::tie(logits, decoder_states) =
|
||||
RunDecoder(decoder_input[i], i, std::move(decoder_states),
|
||||
View(&enc_states), View(&enc_mask));
|
||||
}
|
||||
|
||||
int32_t max_token_id = GetMaxTokenId(&logits);
|
||||
int32_t eos = symbol_table_["<|endoftext|>"];
|
||||
|
||||
int32_t num_feature_frames =
|
||||
enc_states.GetTensorTypeAndShapeInfo().GetShape()[1] *
|
||||
meta.subsampling_factor;
|
||||
|
||||
std::vector<int32_t> tokens = {max_token_id};
|
||||
|
||||
// Assume 30 tokens per second. It is to avoid the following for loop
|
||||
// running indefinitely.
|
||||
int32_t num_tokens =
|
||||
static_cast<int32_t>(num_feature_frames / 100.0 * 30) + 1;
|
||||
|
||||
for (int32_t i = 1; i <= num_tokens; ++i) {
|
||||
if (tokens.back() == eos) {
|
||||
break;
|
||||
}
|
||||
|
||||
std::tie(logits, decoder_states) =
|
||||
RunDecoder(tokens.back(), i, std::move(decoder_states),
|
||||
View(&enc_states), View(&enc_mask));
|
||||
tokens.push_back(GetMaxTokenId(&logits));
|
||||
}
|
||||
|
||||
// remove the last eos token
|
||||
tokens.pop_back();
|
||||
|
||||
auto r = Convert(tokens);
|
||||
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
r.text = ApplyHomophoneReplacer(std::move(r.text));
|
||||
|
||||
s->SetResult(r);
|
||||
}
|
||||
|
||||
OfflineRecognizerConfig GetConfig() const override { return config_; }
|
||||
|
||||
void SetConfig(const OfflineRecognizerConfig &config) override {
|
||||
config_.model_config.canary.src_lang = config.model_config.canary.src_lang;
|
||||
config_.model_config.canary.tgt_lang = config.model_config.canary.tgt_lang;
|
||||
config_.model_config.canary.use_pnc = config.model_config.canary.use_pnc;
|
||||
|
||||
// we don't change the config_ in the base class
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineRecognitionResult Convert(const std::vector<int32_t> &tokens) const {
|
||||
OfflineRecognitionResult r;
|
||||
r.tokens.reserve(tokens.size());
|
||||
|
||||
std::string text;
|
||||
for (auto i : tokens) {
|
||||
if (!symbol_table_.Contains(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &s = symbol_table_[i];
|
||||
text += s;
|
||||
r.tokens.push_back(s);
|
||||
}
|
||||
|
||||
r.text = std::move(text);
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
int32_t GetMaxTokenId(Ort::Value *logits) const {
|
||||
// logits is of shape (1, 1, vocab_size)
|
||||
auto meta = model_->GetModelMetadata();
|
||||
const float *p_logits = logits->GetTensorData<float>();
|
||||
|
||||
int32_t max_token_id = static_cast<int32_t>(std::distance(
|
||||
p_logits, std::max_element(p_logits, p_logits + meta.vocab_size)));
|
||||
|
||||
return max_token_id;
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> RunEncoder(OfflineStream *s) const {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
int32_t feat_dim = config_.feat_config.feature_dim;
|
||||
std::vector<float> f = s->GetFrames();
|
||||
|
||||
int32_t num_frames = f.size() / feat_dim;
|
||||
|
||||
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
|
||||
|
||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
|
||||
shape.data(), shape.size());
|
||||
|
||||
int64_t x_length_scalar = num_frames;
|
||||
std::array<int64_t, 1> x_length_shape = {1};
|
||||
Ort::Value x_length =
|
||||
Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1,
|
||||
x_length_shape.data(), x_length_shape.size());
|
||||
return model_->ForwardEncoder(std::move(x), std::move(x_length));
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
|
||||
int32_t token, int32_t pos, std::vector<Ort::Value> decoder_states,
|
||||
Ort::Value enc_states, Ort::Value enc_mask) const {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::array<int64_t, 2> shape = {1, 2};
|
||||
std::array<int32_t, 2> _decoder_input = {token, pos};
|
||||
|
||||
Ort::Value decoder_input = Ort::Value::CreateTensor(
|
||||
memory_info, _decoder_input.data(), _decoder_input.size(), shape.data(),
|
||||
shape.size());
|
||||
|
||||
return model_->ForwardDecoder(std::move(decoder_input),
|
||||
std::move(decoder_states),
|
||||
std::move(enc_states), std::move(enc_mask));
|
||||
}
|
||||
|
||||
// see
|
||||
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/nemo/canary/test_180m_flash.py#L242
|
||||
std::vector<int32_t> GetInitialDecoderInput() const {
|
||||
auto canary_config = config_.model_config.canary;
|
||||
const auto &meta = model_->GetModelMetadata();
|
||||
|
||||
std::vector<int32_t> decoder_input(9);
|
||||
decoder_input[0] = symbol_table_["<|startofcontext|>"];
|
||||
decoder_input[1] = symbol_table_["<|startoftranscript|>"];
|
||||
decoder_input[2] = symbol_table_["<|emo:undefined|>"];
|
||||
|
||||
if (canary_config.src_lang.empty() ||
|
||||
!meta.lang2id.count(canary_config.src_lang)) {
|
||||
decoder_input[3] = meta.lang2id.at("en");
|
||||
} else {
|
||||
decoder_input[3] = meta.lang2id.at(canary_config.src_lang);
|
||||
}
|
||||
|
||||
if (canary_config.tgt_lang.empty() ||
|
||||
!meta.lang2id.count(canary_config.tgt_lang)) {
|
||||
decoder_input[4] = meta.lang2id.at("en");
|
||||
} else {
|
||||
decoder_input[4] = meta.lang2id.at(canary_config.tgt_lang);
|
||||
}
|
||||
|
||||
if (canary_config.use_pnc) {
|
||||
decoder_input[5] = symbol_table_["<|pnc|>"];
|
||||
} else {
|
||||
decoder_input[5] = symbol_table_["<|nopnc|>"];
|
||||
}
|
||||
|
||||
decoder_input[6] = symbol_table_["<|noitn|>"];
|
||||
decoder_input[7] = symbol_table_["<|notimestamp|>"];
|
||||
decoder_input[8] = symbol_table_["<|nodiarize|>"];
|
||||
|
||||
return decoder_input;
|
||||
}
|
||||
|
||||
private:
|
||||
void PostInit() {
|
||||
auto &meta = model_->GetModelMetadata();
|
||||
config_.feat_config.feature_dim = meta.feat_dim;
|
||||
|
||||
config_.feat_config.nemo_normalize_type = meta.normalize_type;
|
||||
|
||||
config_.feat_config.dither = 0;
|
||||
config_.feat_config.remove_dc_offset = false;
|
||||
config_.feat_config.low_freq = 0;
|
||||
config_.feat_config.window_type = "hann";
|
||||
config_.feat_config.is_librosa = true;
|
||||
|
||||
meta.lang2id["en"] = symbol_table_["<|en|>"];
|
||||
meta.lang2id["es"] = symbol_table_["<|es|>"];
|
||||
meta.lang2id["de"] = symbol_table_["<|de|>"];
|
||||
meta.lang2id["fr"] = symbol_table_["<|fr|>"];
|
||||
|
||||
if (symbol_table_.NumSymbols() != meta.vocab_size) {
|
||||
SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
|
||||
symbol_table_.NumSymbols(), meta.vocab_size);
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineRecognizerConfig config_;
|
||||
SymbolTable symbol_table_;
|
||||
std::unique_ptr<OfflineCanaryModel> model_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
|
||||
@@ -39,7 +39,7 @@ static OfflineRecognitionResult Convert(
|
||||
r.tokens.push_back(s);
|
||||
}
|
||||
|
||||
r.text = text;
|
||||
r.text = std::move(text);
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-canary-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h"
|
||||
@@ -66,6 +67,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
|
||||
}
|
||||
|
||||
if (!config.model_config.canary.encoder.empty()) {
|
||||
return std::make_unique<OfflineRecognizerCanaryImpl>(config);
|
||||
}
|
||||
|
||||
// TODO(fangjun): Refactor it. We only need to use model type for the
|
||||
// following models:
|
||||
// 1. transducer and nemo_transducer
|
||||
@@ -252,6 +257,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
|
||||
}
|
||||
|
||||
if (!config.model_config.canary.encoder.empty()) {
|
||||
return std::make_unique<OfflineRecognizerCanaryImpl>(mgr, config);
|
||||
}
|
||||
|
||||
// TODO(fangjun): Refactor it. We only need to use model type for the
|
||||
// following models:
|
||||
// 1. transducer and nemo_transducer
|
||||
|
||||
@@ -183,6 +183,10 @@ Ort::Value View(Ort::Value *v) {
|
||||
return Ort::Value::CreateTensor(
|
||||
memory_info, v->GetTensorMutableData<float>(),
|
||||
type_and_shape.GetElementCount(), shape.data(), shape.size());
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
|
||||
return Ort::Value::CreateTensor(
|
||||
memory_info, v->GetTensorMutableData<bool>(),
|
||||
type_and_shape.GetElementCount(), shape.data(), shape.size());
|
||||
default:
|
||||
fprintf(stderr, "Unsupported type: %d\n",
|
||||
static_cast<int32_t>(type_and_shape.GetElementType()));
|
||||
|
||||
@@ -9,6 +9,7 @@ set(srcs
|
||||
features.cc
|
||||
homophone-replacer.cc
|
||||
keyword-spotter.cc
|
||||
offline-canary-model-config.cc
|
||||
offline-ctc-fst-decoder-config.cc
|
||||
offline-dolphin-model-config.cc
|
||||
offline-fire-red-asr-model-config.cc
|
||||
|
||||
30
sherpa-onnx/python/csrc/offline-canary-model-config.cc
Normal file
30
sherpa-onnx/python/csrc/offline-canary-model-config.cc
Normal file
@@ -0,0 +1,30 @@
|
||||
// sherpa-onnx/python/csrc/offline-canary-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-canary-model-config.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-canary-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineCanaryModelConfig(py::module *m) {
|
||||
using PyClass = OfflineCanaryModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineCanaryModelConfig")
|
||||
.def(py::init<const std::string &, const std::string &,
|
||||
const std::string &, const std::string &, bool>(),
|
||||
py::arg("encoder") = "", py::arg("decoder") = "",
|
||||
py::arg("src_lang") = "", py::arg("tgt_lang") = "",
|
||||
py::arg("use_pnc") = true)
|
||||
.def_readwrite("encoder", &PyClass::encoder)
|
||||
.def_readwrite("decoder", &PyClass::decoder)
|
||||
.def_readwrite("src_lang", &PyClass::src_lang)
|
||||
.def_readwrite("tgt_lang", &PyClass::tgt_lang)
|
||||
.def_readwrite("use_pnc", &PyClass::use_pnc)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/offline-canary-model-config.h
Normal file
16
sherpa-onnx/python/csrc/offline-canary-model-config.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/offline-canary-model-config.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineCanaryModelConfig(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-canary-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-dolphin-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
|
||||
@@ -34,6 +35,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
PybindOfflineSenseVoiceModelConfig(m);
|
||||
PybindOfflineMoonshineModelConfig(m);
|
||||
PybindOfflineDolphinModelConfig(m);
|
||||
PybindOfflineCanaryModelConfig(m);
|
||||
|
||||
using PyClass = OfflineModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineModelConfig")
|
||||
@@ -47,7 +49,8 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
const OfflineWenetCtcModelConfig &,
|
||||
const OfflineSenseVoiceModelConfig &,
|
||||
const OfflineMoonshineModelConfig &,
|
||||
const OfflineDolphinModelConfig &, const std::string &,
|
||||
const OfflineDolphinModelConfig &,
|
||||
const OfflineCanaryModelConfig &, const std::string &,
|
||||
const std::string &, int32_t, bool, const std::string &,
|
||||
const std::string &, const std::string &,
|
||||
const std::string &>(),
|
||||
@@ -62,8 +65,9 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
|
||||
py::arg("moonshine") = OfflineMoonshineModelConfig(),
|
||||
py::arg("dolphin") = OfflineDolphinModelConfig(),
|
||||
py::arg("telespeech_ctc") = "", py::arg("tokens"),
|
||||
py::arg("num_threads"), py::arg("debug") = false,
|
||||
py::arg("canary") = OfflineCanaryModelConfig(),
|
||||
py::arg("telespeech_ctc") = "", py::arg("tokens") = "",
|
||||
py::arg("num_threads") = 1, py::arg("debug") = false,
|
||||
py::arg("provider") = "cpu", py::arg("model_type") = "",
|
||||
py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
|
||||
.def_readwrite("transducer", &PyClass::transducer)
|
||||
@@ -77,6 +81,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
.def_readwrite("sense_voice", &PyClass::sense_voice)
|
||||
.def_readwrite("moonshine", &PyClass::moonshine)
|
||||
.def_readwrite("dolphin", &PyClass::dolphin)
|
||||
.def_readwrite("canary", &PyClass::canary)
|
||||
.def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc)
|
||||
.def_readwrite("tokens", &PyClass::tokens)
|
||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||
|
||||
@@ -19,7 +19,8 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
||||
const std::string &, int32_t, const std::string &, float,
|
||||
float, const std::string &, const std::string &,
|
||||
const HomophoneReplacerConfig &>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("feat_config") = FeatureExtractorConfig(),
|
||||
py::arg("model_config") = OfflineModelConfig(),
|
||||
py::arg("lm_config") = OfflineLMConfig(),
|
||||
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
|
||||
py::arg("decoding_method") = "greedy_search",
|
||||
@@ -61,6 +62,8 @@ void PybindOfflineRecognizer(py::module *m) {
|
||||
py::arg("hotwords"), py::call_guard<py::gil_scoped_release>())
|
||||
.def("decode_stream", &PyClass::DecodeStream, py::arg("s"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("set_config", &PyClass::SetConfig, py::arg("config"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"decode_streams",
|
||||
[](const PyClass &self, std::vector<OfflineStream *> ss) {
|
||||
|
||||
@@ -8,9 +8,22 @@ from _sherpa_onnx import (
|
||||
DenoisedAudio,
|
||||
FastClustering,
|
||||
FastClusteringConfig,
|
||||
FeatureExtractorConfig,
|
||||
HomophoneReplacerConfig,
|
||||
OfflineCanaryModelConfig,
|
||||
OfflineCtcFstDecoderConfig,
|
||||
OfflineDolphinModelConfig,
|
||||
OfflineFireRedAsrModelConfig,
|
||||
OfflineLMConfig,
|
||||
OfflineModelConfig,
|
||||
OfflineMoonshineModelConfig,
|
||||
OfflineNemoEncDecCtcModelConfig,
|
||||
OfflineParaformerModelConfig,
|
||||
OfflinePunctuation,
|
||||
OfflinePunctuationConfig,
|
||||
OfflinePunctuationModelConfig,
|
||||
OfflineRecognizerConfig,
|
||||
OfflineSenseVoiceModelConfig,
|
||||
OfflineSourceSeparation,
|
||||
OfflineSourceSeparationConfig,
|
||||
OfflineSourceSeparationModelConfig,
|
||||
@@ -27,13 +40,18 @@ from _sherpa_onnx import (
|
||||
OfflineSpeechDenoiserGtcrnModelConfig,
|
||||
OfflineSpeechDenoiserModelConfig,
|
||||
OfflineStream,
|
||||
OfflineTdnnModelConfig,
|
||||
OfflineTransducerModelConfig,
|
||||
OfflineTts,
|
||||
OfflineTtsConfig,
|
||||
OfflineTtsKokoroModelConfig,
|
||||
OfflineTtsMatchaModelConfig,
|
||||
OfflineTtsModelConfig,
|
||||
OfflineTtsVitsModelConfig,
|
||||
OfflineWenetCtcModelConfig,
|
||||
OfflineWhisperModelConfig,
|
||||
OfflineZipformerAudioTaggingModelConfig,
|
||||
OfflineZipformerCtcModelConfig,
|
||||
OnlinePunctuation,
|
||||
OnlinePunctuationConfig,
|
||||
OnlinePunctuationModelConfig,
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import List, Optional
|
||||
from _sherpa_onnx import (
|
||||
FeatureExtractorConfig,
|
||||
HomophoneReplacerConfig,
|
||||
OfflineCanaryModelConfig,
|
||||
OfflineCtcFstDecoderConfig,
|
||||
OfflineDolphinModelConfig,
|
||||
OfflineFireRedAsrModelConfig,
|
||||
@@ -425,7 +426,6 @@ class OfflineRecognizer(object):
|
||||
num_threads=num_threads,
|
||||
debug=debug,
|
||||
provider=provider,
|
||||
model_type="nemo_ctc",
|
||||
)
|
||||
|
||||
feat_config = FeatureExtractorConfig(
|
||||
@@ -690,6 +690,102 @@ class OfflineRecognizer(object):
|
||||
self.config = recognizer_config
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_nemo_canary(
|
||||
cls,
|
||||
encoder: str,
|
||||
decoder: str,
|
||||
tokens: str,
|
||||
src_lang: str = "en",
|
||||
tgt_lang: str = "en",
|
||||
num_threads: int = 1,
|
||||
sample_rate: int = 16000,
|
||||
feature_dim: int = 128, # not used
|
||||
decoding_method: str = "greedy_search", # not used
|
||||
debug: bool = False,
|
||||
provider: str = "cpu",
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
hr_dict_dir: str = "",
|
||||
hr_rule_fsts: str = "",
|
||||
hr_lexicon: str = "",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
`<https://k2-fsa.github.io/sherpa/onnx/nemo/index.html>`_
|
||||
to download pre-trained models for different languages.
|
||||
|
||||
Args:
|
||||
encoder:
|
||||
Path to ``encoder.onnx`` or ``encoder.int8.onnx``.
|
||||
decoder:
|
||||
Path to ``decoder.onnx`` or ``decoder.int8.onnx``.
|
||||
tokens:
|
||||
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
||||
columns::
|
||||
|
||||
symbol integer_id
|
||||
|
||||
src_lang:
|
||||
The language of the input audio. Valid values are: en, es, de, fr.
|
||||
If you leave it empty, it uses en internally.
|
||||
tgt_lang:
|
||||
The language of the output text. Valid values are: en, es, de, fr.
|
||||
If you leave it empty, it uses en internally.
|
||||
num_threads:
|
||||
Number of threads for neural network computation.
|
||||
sample_rate:
|
||||
Sample rate of the training data used to train the model. Not used
|
||||
feature_dim:
|
||||
Dimension of the feature used to train the model. Not used
|
||||
decoding_method:
|
||||
Valid values are greedy_search. Not used
|
||||
debug:
|
||||
True to show debug messages.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
rule_fsts:
|
||||
If not empty, it specifies fsts for inverse text normalization.
|
||||
If there are multiple fsts, they are separated by a comma.
|
||||
rule_fars:
|
||||
If not empty, it specifies fst archives for inverse text normalization.
|
||||
If there are multiple archives, they are separated by a comma.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
model_config = OfflineModelConfig(
|
||||
canary=OfflineCanaryModelConfig(
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
src_lang=src_lang,
|
||||
tgt_lang=tgt_lang,
|
||||
),
|
||||
tokens=tokens,
|
||||
num_threads=num_threads,
|
||||
debug=debug,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
feat_config = FeatureExtractorConfig(
|
||||
sampling_rate=sample_rate,
|
||||
feature_dim=feature_dim,
|
||||
)
|
||||
|
||||
recognizer_config = OfflineRecognizerConfig(
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
decoding_method=decoding_method,
|
||||
rule_fsts=rule_fsts,
|
||||
rule_fars=rule_fars,
|
||||
hr=HomophoneReplacerConfig(
|
||||
dict_dir=hr_dict_dir,
|
||||
lexicon=hr_lexicon,
|
||||
rule_fsts=hr_rule_fsts,
|
||||
),
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_whisper(
|
||||
cls,
|
||||
|
||||
Reference in New Issue
Block a user