This repository has been archived on 2025-08-26. You can view files and clone it, but cannot push or open issues or pull requests.
Files
enginex-mr_series-sherpa-onnx/sherpa-onnx/csrc/offline-recognizer-canary-impl.h

262 lines
8.4 KiB
C++

// sherpa-onnx/csrc/offline-recognizer-canary-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_
#include <algorithm>
#include <array>
#include <cstdint>
#include <ios>
#include <iterator>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-canary-model.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/utils.h"
namespace sherpa_onnx {
// Offline (non-streaming) recognizer backed by an OfflineCanaryModel.
//
// Each stream is decoded greedily: the encoder is run once on all feature
// frames of the stream, a fixed 9-token prompt (language, punctuation, ...
// control tokens) is fed to the decoder one token at a time, and then the
// arg-max token is fed back repeatedly until <|endoftext|> is produced or a
// length-dependent token budget is exhausted.
class OfflineRecognizerCanaryImpl : public OfflineRecognizerImpl {
 public:
  // Loads the token table and the model from the paths given in
  // config.model_config, then finalizes the feature-extractor settings
  // (see PostInit()).
  explicit OfflineRecognizerCanaryImpl(const OfflineRecognizerConfig &config)
      : OfflineRecognizerImpl(config),
        config_(config),
        symbol_table_(config_.model_config.tokens),
        model_(std::make_unique<OfflineCanaryModel>(config_.model_config)) {
    PostInit();
  }

  // Same as above, but the token table and the model are read through
  // the resource manager `mgr`.
  template <typename Manager>
  explicit OfflineRecognizerCanaryImpl(Manager *mgr,
                                       const OfflineRecognizerConfig &config)
      : OfflineRecognizerImpl(mgr, config),
        config_(config),
        symbol_table_(mgr, config_.model_config.tokens),
        model_(
            std::make_unique<OfflineCanaryModel>(mgr, config_.model_config)) {
    PostInit();
  }

  // Creates a stream that uses the feature configuration prepared by
  // PostInit().
  std::unique_ptr<OfflineStream> CreateStream() const override {
    return std::make_unique<OfflineStream>(config_.feat_config);
  }

  // Decodes the n streams in ss one after another; there is no batching.
  void DecodeStreams(OfflineStream **ss, int32_t n) const override {
    for (int32_t i = 0; i < n; ++i) {
      DecodeStream(ss[i]);
    }
  }

  // Decodes a single stream and stores the recognition result in it.
  //
  // @param s  The stream to decode. Its frames are read via GetFrames() and
  //           the result is written back via SetResult().
  void DecodeStream(OfflineStream *s) const {
    // Bind by reference: the metadata holds a map, so copying it for every
    // stream is wasted work.
    const auto &meta = model_->GetModelMetadata();

    auto enc_out = RunEncoder(s);
    Ort::Value enc_states = std::move(enc_out[0]);
    Ort::Value enc_mask = std::move(enc_out[2]);
    // enc_out[1] is discarded

    const std::vector<int32_t> decoder_input = GetInitialDecoderInput();
    auto decoder_states = model_->GetInitialDecoderStates();
    Ort::Value logits{nullptr};

    // Feed the prompt one token at a time; only the logits produced by the
    // last prompt token are used below.
    const auto prompt_len = static_cast<int32_t>(decoder_input.size());
    for (int32_t i = 0; i < prompt_len; ++i) {
      std::tie(logits, decoder_states) =
          RunDecoder(decoder_input[i], i, std::move(decoder_states),
                     View(&enc_states), View(&enc_mask));
    }

    const int32_t eos = symbol_table_["<|endoftext|>"];

    const auto num_feature_frames = static_cast<int32_t>(
        enc_states.GetTensorTypeAndShapeInfo().GetShape()[1] *
        meta.subsampling_factor);

    std::vector<int32_t> tokens = {GetMaxTokenId(&logits)};

    // Assume 30 tokens per second. It is to avoid the following for loop
    // running indefinitely.
    const int32_t num_tokens =
        static_cast<int32_t>(num_feature_frames / 100.0 * 30) + 1;

    // NOTE(review): the position passed to the decoder restarts at 1 here
    // instead of continuing from prompt_len. Left unchanged; confirm against
    // the model export script that this is intended.
    for (int32_t i = 1; i <= num_tokens; ++i) {
      if (tokens.back() == eos) {
        break;
      }

      std::tie(logits, decoder_states) =
          RunDecoder(tokens.back(), i, std::move(decoder_states),
                     View(&enc_states), View(&enc_mask));

      tokens.push_back(GetMaxTokenId(&logits));
    }

    // Remove the trailing eos token. If the loop stopped because the token
    // budget ran out, the last token is regular output and must be kept.
    if (tokens.back() == eos) {
      tokens.pop_back();
    }

    auto r = Convert(tokens);
    r.text = ApplyInverseTextNormalization(std::move(r.text));
    r.text = ApplyHomophoneReplacer(std::move(r.text));
    s->SetResult(r);
  }

  // Returns a copy of the configuration currently used by this recognizer.
  OfflineRecognizerConfig GetConfig() const override { return config_; }

  // Updates only the canary-specific runtime options (source language,
  // target language, punctuation/capitalization). All other fields of
  // `config` are ignored.
  void SetConfig(const OfflineRecognizerConfig &config) override {
    config_.model_config.canary.src_lang = config.model_config.canary.src_lang;
    config_.model_config.canary.tgt_lang = config.model_config.canary.tgt_lang;
    config_.model_config.canary.use_pnc = config.model_config.canary.use_pnc;

    // we don't change the config_ in the base class
  }

 private:
  // Maps token ids to their symbols. Ids that are not in the symbol table
  // are skipped. r.tokens holds the individual symbols and r.text is their
  // concatenation.
  OfflineRecognitionResult Convert(const std::vector<int32_t> &tokens) const {
    OfflineRecognitionResult r;
    r.tokens.reserve(tokens.size());

    std::string text;
    for (auto i : tokens) {
      if (!symbol_table_.Contains(i)) {
        continue;
      }

      const auto &s = symbol_table_[i];
      text += s;
      r.tokens.push_back(s);
    }

    r.text = std::move(text);
    return r;
  }

  // Returns the index of the largest value among the first vocab_size
  // entries of logits.
  int32_t GetMaxTokenId(Ort::Value *logits) const {
    // logits is of shape (1, 1, vocab_size)
    //
    // Called once per generated token, so take the metadata by reference
    // rather than copying it.
    const auto &meta = model_->GetModelMetadata();

    const float *p_logits = logits->GetTensorData<float>();

    return static_cast<int32_t>(std::distance(
        p_logits, std::max_element(p_logits, p_logits + meta.vocab_size)));
  }

  // Runs the encoder on all frames of the stream and returns the outputs of
  // OfflineCanaryModel::ForwardEncoder() unchanged.
  std::vector<Ort::Value> RunEncoder(OfflineStream *s) const {
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    const int32_t feat_dim = config_.feat_config.feature_dim;

    // The tensors below wrap these local buffers, so they have to stay
    // alive until ForwardEncoder() has returned.
    std::vector<float> f = s->GetFrames();
    const auto num_frames = static_cast<int32_t>(f.size() / feat_dim);

    std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
    Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
                                            shape.data(), shape.size());

    int64_t x_length_scalar = num_frames;
    std::array<int64_t, 1> x_length_shape = {1};
    Ort::Value x_length =
        Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1,
                                 x_length_shape.data(), x_length_shape.size());

    return model_->ForwardEncoder(std::move(x), std::move(x_length));
  }

  // Runs one decoder step.
  //
  // @param token           Id of the token fed to the decoder.
  // @param pos             Position value passed along with the token.
  // @param decoder_states  States from the previous step; consumed.
  // @param enc_states      Encoder output states.
  // @param enc_mask        Encoder output mask.
  // @return The result of OfflineCanaryModel::ForwardDecoder(), bound by the
  //         callers to (logits, next decoder states).
  std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
      int32_t token, int32_t pos, std::vector<Ort::Value> decoder_states,
      Ort::Value enc_states, Ort::Value enc_mask) const {
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // The decoder input is a (1, 2) int32 tensor holding {token, pos}.
    std::array<int64_t, 2> shape = {1, 2};
    std::array<int32_t, 2> token_and_pos = {token, pos};

    Ort::Value decoder_input = Ort::Value::CreateTensor(
        memory_info, token_and_pos.data(), token_and_pos.size(), shape.data(),
        shape.size());

    return model_->ForwardDecoder(std::move(decoder_input),
                                  std::move(decoder_states),
                                  std::move(enc_states), std::move(enc_mask));
  }

  // Builds the 9-token prompt that is fed to the decoder before generation.
  //
  // see
  // https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/nemo/canary/test_180m_flash.py#L242
  std::vector<int32_t> GetInitialDecoderInput() const {
    const auto &canary_config = config_.model_config.canary;
    const auto &meta = model_->GetModelMetadata();

    // Returns the token id of `lang`, falling back to English when `lang`
    // is empty or not one of the languages registered in PostInit().
    const auto lang_token = [&meta](const std::string &lang) -> int32_t {
      if (!lang.empty()) {
        const auto it = meta.lang2id.find(lang);
        if (it != meta.lang2id.end()) {
          return it->second;
        }
      }

      return meta.lang2id.at("en");
    };

    std::vector<int32_t> decoder_input(9);
    decoder_input[0] = symbol_table_["<|startofcontext|>"];
    decoder_input[1] = symbol_table_["<|startoftranscript|>"];
    decoder_input[2] = symbol_table_["<|emo:undefined|>"];
    decoder_input[3] = lang_token(canary_config.src_lang);
    decoder_input[4] = lang_token(canary_config.tgt_lang);
    decoder_input[5] = canary_config.use_pnc ? symbol_table_["<|pnc|>"]
                                             : symbol_table_["<|nopnc|>"];
    decoder_input[6] = symbol_table_["<|noitn|>"];
    decoder_input[7] = symbol_table_["<|notimestamp|>"];
    decoder_input[8] = symbol_table_["<|nodiarize|>"];

    return decoder_input;
  }

  // Finishes construction: overrides the feature-extractor settings with the
  // values required by the model, registers the supported language tokens in
  // the model metadata, and terminates the program if tokens.txt does not
  // match the vocabulary size of the model.
  void PostInit() {
    auto &meta = model_->GetModelMetadata();

    config_.feat_config.feature_dim = meta.feat_dim;
    config_.feat_config.nemo_normalize_type = meta.normalize_type;
    config_.feat_config.dither = 0;
    config_.feat_config.remove_dc_offset = false;
    config_.feat_config.low_freq = 0;
    config_.feat_config.window_type = "hann";
    config_.feat_config.is_librosa = true;

    meta.lang2id["en"] = symbol_table_["<|en|>"];
    meta.lang2id["es"] = symbol_table_["<|es|>"];
    meta.lang2id["de"] = symbol_table_["<|de|>"];
    meta.lang2id["fr"] = symbol_table_["<|fr|>"];

    if (symbol_table_.NumSymbols() != meta.vocab_size) {
      SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
                       symbol_table_.NumSymbols(), meta.vocab_size);
      SHERPA_ONNX_EXIT(-1);
    }
  }

 private:
  // NOTE: the declaration order matters. The constructors initialize
  // symbol_table_ and model_ from config_.
  OfflineRecognizerConfig config_;
  SymbolTable symbol_table_;
  std::unique_ptr<OfflineCanaryModel> model_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_