// sherpa-onnx/csrc/offline-recognizer-canary-impl.h // // Copyright (c) 2025 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_ #include #include #include #include #include #include #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-canary-model.h" #include "sherpa-onnx/csrc/offline-recognizer-impl.h" #include "sherpa-onnx/csrc/offline-recognizer.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/symbol-table.h" #include "sherpa-onnx/csrc/utils.h" namespace sherpa_onnx { class OfflineRecognizerCanaryImpl : public OfflineRecognizerImpl { public: explicit OfflineRecognizerCanaryImpl(const OfflineRecognizerConfig &config) : OfflineRecognizerImpl(config), config_(config), symbol_table_(config_.model_config.tokens), model_(std::make_unique(config_.model_config)) { PostInit(); } template explicit OfflineRecognizerCanaryImpl(Manager *mgr, const OfflineRecognizerConfig &config) : OfflineRecognizerImpl(mgr, config), config_(config), symbol_table_(mgr, config_.model_config.tokens), model_( std::make_unique(mgr, config_.model_config)) { PostInit(); } std::unique_ptr CreateStream() const override { return std::make_unique(config_.feat_config); } void DecodeStreams(OfflineStream **ss, int32_t n) const override { for (int32_t i = 0; i < n; ++i) { DecodeStream(ss[i]); } } void DecodeStream(OfflineStream *s) const { auto meta = model_->GetModelMetadata(); auto enc_out = RunEncoder(s); Ort::Value enc_states = std::move(enc_out[0]); Ort::Value enc_mask = std::move(enc_out[2]); // enc_out[1] is discarded std::vector decoder_input = GetInitialDecoderInput(); auto decoder_states = model_->GetInitialDecoderStates(); Ort::Value logits{nullptr}; for (int32_t i = 0; i < decoder_input.size(); ++i) { std::tie(logits, decoder_states) = RunDecoder(decoder_input[i], i, std::move(decoder_states), View(&enc_states), View(&enc_mask)); } int32_t max_token_id = GetMaxTokenId(&logits); int32_t eos = symbol_table_["<|endoftext|>"]; int32_t num_feature_frames = enc_states.GetTensorTypeAndShapeInfo().GetShape()[1] * meta.subsampling_factor; std::vector tokens = {max_token_id}; // Assume 30 tokens per second. It is to avoid the following for loop // running indefinitely. int32_t num_tokens = static_cast(num_feature_frames / 100.0 * 30) + 1; for (int32_t i = 1; i <= num_tokens; ++i) { if (tokens.back() == eos) { break; } std::tie(logits, decoder_states) = RunDecoder(tokens.back(), i, std::move(decoder_states), View(&enc_states), View(&enc_mask)); tokens.push_back(GetMaxTokenId(&logits)); } // remove the last eos token tokens.pop_back(); auto r = Convert(tokens); r.text = ApplyInverseTextNormalization(std::move(r.text)); r.text = ApplyHomophoneReplacer(std::move(r.text)); s->SetResult(r); } OfflineRecognizerConfig GetConfig() const override { return config_; } void SetConfig(const OfflineRecognizerConfig &config) override { config_.model_config.canary.src_lang = config.model_config.canary.src_lang; config_.model_config.canary.tgt_lang = config.model_config.canary.tgt_lang; config_.model_config.canary.use_pnc = config.model_config.canary.use_pnc; // we don't change the config_ in the base class } private: OfflineRecognitionResult Convert(const std::vector &tokens) const { OfflineRecognitionResult r; r.tokens.reserve(tokens.size()); std::string text; for (auto i : tokens) { if (!symbol_table_.Contains(i)) { continue; } const auto &s = symbol_table_[i]; text += s; r.tokens.push_back(s); } r.text = std::move(text); return r; } int32_t GetMaxTokenId(Ort::Value *logits) const { // logits is of shape (1, 1, vocab_size) auto meta = model_->GetModelMetadata(); const float *p_logits = logits->GetTensorData(); int32_t max_token_id = static_cast(std::distance( p_logits, std::max_element(p_logits, p_logits + meta.vocab_size))); return max_token_id; } std::vector RunEncoder(OfflineStream *s) const { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); int32_t feat_dim = config_.feat_config.feature_dim; std::vector f = s->GetFrames(); int32_t num_frames = f.size() / feat_dim; std::array shape = {1, num_frames, feat_dim}; Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(), shape.data(), shape.size()); int64_t x_length_scalar = num_frames; std::array x_length_shape = {1}; Ort::Value x_length = Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1, x_length_shape.data(), x_length_shape.size()); return model_->ForwardEncoder(std::move(x), std::move(x_length)); } std::pair> RunDecoder( int32_t token, int32_t pos, std::vector decoder_states, Ort::Value enc_states, Ort::Value enc_mask) const { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); std::array shape = {1, 2}; std::array _decoder_input = {token, pos}; Ort::Value decoder_input = Ort::Value::CreateTensor( memory_info, _decoder_input.data(), _decoder_input.size(), shape.data(), shape.size()); return model_->ForwardDecoder(std::move(decoder_input), std::move(decoder_states), std::move(enc_states), std::move(enc_mask)); } // see // https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/nemo/canary/test_180m_flash.py#L242 std::vector GetInitialDecoderInput() const { auto canary_config = config_.model_config.canary; const auto &meta = model_->GetModelMetadata(); std::vector decoder_input(9); decoder_input[0] = symbol_table_["<|startofcontext|>"]; decoder_input[1] = symbol_table_["<|startoftranscript|>"]; decoder_input[2] = symbol_table_["<|emo:undefined|>"]; if (canary_config.src_lang.empty() || !meta.lang2id.count(canary_config.src_lang)) { decoder_input[3] = meta.lang2id.at("en"); } else { decoder_input[3] = meta.lang2id.at(canary_config.src_lang); } if (canary_config.tgt_lang.empty() || !meta.lang2id.count(canary_config.tgt_lang)) { decoder_input[4] = meta.lang2id.at("en"); } else { decoder_input[4] = meta.lang2id.at(canary_config.tgt_lang); } if (canary_config.use_pnc) { decoder_input[5] = symbol_table_["<|pnc|>"]; } else { decoder_input[5] = symbol_table_["<|nopnc|>"]; } decoder_input[6] = symbol_table_["<|noitn|>"]; decoder_input[7] = symbol_table_["<|notimestamp|>"]; decoder_input[8] = symbol_table_["<|nodiarize|>"]; return decoder_input; } private: void PostInit() { auto &meta = model_->GetModelMetadata(); config_.feat_config.feature_dim = meta.feat_dim; config_.feat_config.nemo_normalize_type = meta.normalize_type; config_.feat_config.dither = 0; config_.feat_config.remove_dc_offset = false; config_.feat_config.low_freq = 0; config_.feat_config.window_type = "hann"; config_.feat_config.is_librosa = true; meta.lang2id["en"] = symbol_table_["<|en|>"]; meta.lang2id["es"] = symbol_table_["<|es|>"]; meta.lang2id["de"] = symbol_table_["<|de|>"]; meta.lang2id["fr"] = symbol_table_["<|fr|>"]; if (symbol_table_.NumSymbols() != meta.vocab_size) { SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)", symbol_table_.NumSymbols(), meta.vocab_size); SHERPA_ONNX_EXIT(-1); } } private: OfflineRecognizerConfig config_; SymbolTable symbol_table_; std::unique_ptr model_; }; } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_