Add C++ and Python API for FireRedASR AED models (#1867)
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -133,3 +133,4 @@ lexicon.txt
|
|||||||
us_gold.json
|
us_gold.json
|
||||||
us_silver.json
|
us_silver.json
|
||||||
kokoro-multi-lang-v1_0
|
kokoro-multi-lang-v1_0
|
||||||
|
sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16
|
||||||
|
|||||||
75
python-api-examples/offline-fire-red-asr-decode-files.py
Normal file
75
python-api-examples/offline-fire-red-asr-decode-files.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file shows how to use a non-streaming FireRedAsr AED model from
|
||||||
|
https://github.com/FireRedTeam/FireRedASR
|
||||||
|
to decode files.
|
||||||
|
|
||||||
|
Please download model files from
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
|
||||||
|
|
||||||
|
For instance,
|
||||||
|
|
||||||
|
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
|
||||||
|
rm sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import sherpa_onnx
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
|
def create_recognizer():
    """Create a FireRedAsr offline recognizer and pick a test wave file.

    All paths are relative to the current working directory and point into
    the extracted ``sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16`` archive.

    Returns:
      A tuple ``(recognizer, test_wav)`` where ``recognizer`` is the object
      returned by ``sherpa_onnx.OfflineRecognizer.from_fire_red_asr`` and
      ``test_wav`` is the path of the wave file to decode.

    Raises:
      ValueError: If the encoder, decoder, tokens, or test wave file does not
        exist.
    """
    encoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx"
    decoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/decoder.int8.onnx"
    tokens = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/tokens.txt"
    test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/0.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/1.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/2.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/8k.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3-sichuan.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/4-tianjin.wav"
    # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/5-henan.wav"

    # Check every file that is handed to the recognizer (including tokens.txt)
    # so that a missing download is reported here with a helpful message
    # instead of failing later inside the recognizer constructor.
    required_files = (encoder, decoder, tokens, test_wav)
    if not all(Path(f).is_file() for f in required_files):
        raise ValueError(
            """Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
"""
        )

    return (
        sherpa_onnx.OfflineRecognizer.from_fire_red_asr(
            encoder=encoder,
            decoder=decoder,
            tokens=tokens,
            debug=True,
        ),
        test_wav,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Decode the selected test wave file and print the recognition result."""
    recognizer, wave_filename = create_recognizer()

    samples, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    # Keep only the first channel.
    # The result is a 1-D float32 numpy array normalized to the range [-1, 1].
    # Note that sample_rate does not need to be 16000 Hz.
    mono = samples[:, 0]

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, mono)
    recognizer.decode_stream(stream)

    print(wave_filename)
    print(stream.result)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -27,6 +27,9 @@ set(sources
|
|||||||
offline-ctc-fst-decoder.cc
|
offline-ctc-fst-decoder.cc
|
||||||
offline-ctc-greedy-search-decoder.cc
|
offline-ctc-greedy-search-decoder.cc
|
||||||
offline-ctc-model.cc
|
offline-ctc-model.cc
|
||||||
|
offline-fire-red-asr-greedy-search-decoder.cc
|
||||||
|
offline-fire-red-asr-model-config.cc
|
||||||
|
offline-fire-red-asr-model.cc
|
||||||
offline-lm-config.cc
|
offline-lm-config.cc
|
||||||
offline-lm.cc
|
offline-lm.cc
|
||||||
offline-model-config.cc
|
offline-model-config.cc
|
||||||
|
|||||||
39
sherpa-onnx/csrc/offline-fire-red-asr-decoder.h
Normal file
39
sherpa-onnx/csrc/offline-fire-red-asr-decoder.h
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-fire-red-asr-decoder.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
/// Holds the output of decoding a single utterance.
struct OfflineFireRedAsrDecoderResult {
  /// IDs of the tokens produced by the decoder.
  std::vector<int32_t> tokens;
};
|
||||||
|
|
||||||
|
class OfflineFireRedAsrDecoder {
|
||||||
|
public:
|
||||||
|
virtual ~OfflineFireRedAsrDecoder() = default;
|
||||||
|
|
||||||
|
/** Run beam search given the output from the FireRedAsr encoder model.
|
||||||
|
*
|
||||||
|
* @param n_layer_cross_k A 4-D tensor of shape
|
||||||
|
* (num_decoder_layers, N, T, d_model).
|
||||||
|
* @param n_layer_cross_v A 4-D tensor of shape
|
||||||
|
* (num_decoder_layers, N, T, d_model).
|
||||||
|
*
|
||||||
|
* @return Return a vector of size `N` containing the decoded results.
|
||||||
|
*/
|
||||||
|
virtual std::vector<OfflineFireRedAsrDecoderResult> Decode(
|
||||||
|
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_
|
||||||
@@ -0,0 +1,87 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <tuple>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Note: this functions works only for batch size == 1 at present
|
||||||
|
std::vector<OfflineFireRedAsrDecoderResult>
|
||||||
|
OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||||
|
Ort::Value cross_v) {
|
||||||
|
const auto &meta_data = model_->GetModelMetadata();
|
||||||
|
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
// For multilingual models, initial_tokens contains [sot, language, task]
|
||||||
|
// - language is English by default
|
||||||
|
// - task is transcribe by default
|
||||||
|
//
|
||||||
|
// For non-multilingual models, initial_tokens contains [sot]
|
||||||
|
std::array<int64_t, 2> token_shape = {1, 1};
|
||||||
|
int64_t token = meta_data.sos_id;
|
||||||
|
|
||||||
|
int32_t batch_size = 1;
|
||||||
|
|
||||||
|
Ort::Value tokens = Ort::Value::CreateTensor(
|
||||||
|
memory_info, &token, 1, token_shape.data(), token_shape.size());
|
||||||
|
|
||||||
|
std::array<int64_t, 1> offset_shape{1};
|
||||||
|
Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
|
||||||
|
model_->Allocator(), offset_shape.data(), offset_shape.size());
|
||||||
|
*(offset.GetTensorMutableData<int64_t>()) = 0;
|
||||||
|
|
||||||
|
std::vector<OfflineFireRedAsrDecoderResult> ans(1);
|
||||||
|
|
||||||
|
auto self_kv_cache = model_->GetInitialSelfKVCache();
|
||||||
|
|
||||||
|
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
|
||||||
|
Ort::Value>
|
||||||
|
decoder_out = {Ort::Value{nullptr},
|
||||||
|
std::move(self_kv_cache.first),
|
||||||
|
std::move(self_kv_cache.second),
|
||||||
|
std::move(cross_k),
|
||||||
|
std::move(cross_v),
|
||||||
|
std::move(offset)};
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < meta_data.max_len; ++i) {
|
||||||
|
decoder_out = model_->ForwardDecoder(View(&tokens),
|
||||||
|
std::move(std::get<1>(decoder_out)),
|
||||||
|
std::move(std::get<2>(decoder_out)),
|
||||||
|
std::move(std::get<3>(decoder_out)),
|
||||||
|
std::move(std::get<4>(decoder_out)),
|
||||||
|
std::move(std::get<5>(decoder_out)));
|
||||||
|
|
||||||
|
const auto &logits = std::get<0>(decoder_out);
|
||||||
|
const float *p_logits = logits.GetTensorData<float>();
|
||||||
|
|
||||||
|
auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
int32_t vocab_size = logits_shape[2];
|
||||||
|
|
||||||
|
int32_t max_token_id = static_cast<int32_t>(std::distance(
|
||||||
|
p_logits, std::max_element(p_logits, p_logits + vocab_size)));
|
||||||
|
if (max_token_id == meta_data.eos_id) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
ans[0].tokens.push_back(max_token_id);
|
||||||
|
|
||||||
|
token = max_token_id;
|
||||||
|
|
||||||
|
// increment offset
|
||||||
|
*(std::get<5>(decoder_out).GetTensorMutableData<int64_t>()) += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-fire-red-asr-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineFireRedAsrGreedySearchDecoder : public OfflineFireRedAsrDecoder {
|
||||||
|
public:
|
||||||
|
explicit OfflineFireRedAsrGreedySearchDecoder(OfflineFireRedAsrModel *model)
|
||||||
|
: model_(model) {}
|
||||||
|
|
||||||
|
std::vector<OfflineFireRedAsrDecoderResult> Decode(
|
||||||
|
Ort::Value cross_k, Ort::Value cross_v) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineFireRedAsrModel *model_; // not owned
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_
|
||||||
56
sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc
Normal file
56
sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-fire-red-asr-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Register the command-line options --fire-red-asr-encoder and
// --fire-red-asr-decoder, binding them to the members of this config.
void OfflineFireRedAsrModelConfig::Register(ParseOptions *po) {
  po->Register("fire-red-asr-encoder", &encoder,
               "Path to onnx encoder of FireRedAsr");
  po->Register("fire-red-asr-decoder", &decoder,
               "Path to onnx decoder of FireRedAsr");
}
|
||||||
|
|
||||||
|
bool OfflineFireRedAsrModelConfig::Validate() const {
|
||||||
|
if (encoder.empty()) {
|
||||||
|
SHERPA_ONNX_LOGE("Please provide --fire-red-asr-encoder");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!FileExists(encoder)) {
|
||||||
|
SHERPA_ONNX_LOGE("FireRedAsr encoder file '%s' does not exist",
|
||||||
|
encoder.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (decoder.empty()) {
|
||||||
|
SHERPA_ONNX_LOGE("Please provide --fire-red-asr-decoder");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!FileExists(decoder)) {
|
||||||
|
SHERPA_ONNX_LOGE("FireRedAsr decoder file '%s' does not exist",
|
||||||
|
decoder.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a human-readable description of this config, e.g.
//   OfflineFireRedAsrModelConfig(encoder="a.onnx", decoder="b.onnx")
std::string OfflineFireRedAsrModelConfig::ToString() const {
  std::ostringstream os;

  os << "OfflineFireRedAsrModelConfig("
     << "encoder=\"" << encoder << "\", "
     << "decoder=\"" << decoder << "\")";

  return os.str();
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
31
sherpa-onnx/csrc/offline-fire-red-asr-model-config.h
Normal file
31
sherpa-onnx/csrc/offline-fire-red-asr-model-config.h
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-fire-red-asr-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// see https://github.com/FireRedTeam/FireRedASR
|
||||||
|
struct OfflineFireRedAsrModelConfig {
|
||||||
|
std::string encoder;
|
||||||
|
std::string decoder;
|
||||||
|
|
||||||
|
OfflineFireRedAsrModelConfig() = default;
|
||||||
|
OfflineFireRedAsrModelConfig(const std::string &encoder,
|
||||||
|
const std::string &decoder)
|
||||||
|
: encoder(encoder), decoder(decoder) {}
|
||||||
|
|
||||||
|
void Register(ParseOptions *po);
|
||||||
|
bool Validate() const;
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
|
||||||
28
sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h
Normal file
28
sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
/// Meta data of a FireRedAsr model.
///
/// All scalar members are zero-initialized so that a default-constructed
/// instance never holds indeterminate values.
struct OfflineFireRedAsrModelMetaData {
  /// ID of the start-of-sequence token.
  int32_t sos_id = 0;

  /// ID of the end-of-sequence token.
  int32_t eos_id = 0;

  /// Maximum length; used as an upper bound on the number of decoding steps
  /// and as one dimension of the self-attention cache.
  int32_t max_len = 0;

  /// Number of decoder layers.
  int32_t num_decoder_layers = 0;

  /// Number of attention heads.
  int32_t num_head = 0;

  /// Dimension of each attention head.
  int32_t head_dim = 0;

  /// CMVN mean (presumably applied to the input features; TODO confirm
  /// against the code that consumes it).
  std::vector<float> mean;

  /// CMVN inverse standard deviation (same caveat as `mean`).
  std::vector<float> inv_stddev;
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_
|
||||||
256
sherpa-onnx/csrc/offline-fire-red-asr-model.cc
Normal file
256
sherpa-onnx/csrc/offline-fire-red-asr-model.cc
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-fire-red-asr-model.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if __OHOS__
|
||||||
|
#include "rawfile/raw_file_manager.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/session.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Private implementation of OfflineFireRedAsrModel. It owns the onnxruntime
// sessions of the encoder and the decoder and the meta data read from the
// encoder model.
class OfflineFireRedAsrModel::Impl {
 public:
  // Load the encoder and the decoder from the file paths given in
  // config.fire_red_asr.
  explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    // Each buffer lives in its own scope so that it is released as soon as
    // the corresponding session has been created.
    {
      auto buf = ReadFile(config.fire_red_asr.encoder);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.fire_red_asr.decoder);
      InitDecoder(buf.data(), buf.size());
    }
  }

  // Same as above, but the model files are read through a platform resource
  // manager (e.g. AAssetManager on Android, NativeResourceManager on OHOS).
  template <typename Manager>
  Impl(Manager *mgr, const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(mgr, config.fire_red_asr.encoder);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.fire_red_asr.decoder);
      InitDecoder(buf.data(), buf.size());
    }
  }

  // Run the encoder and return its first two outputs, i.e.
  // (n_layer_cross_k, n_layer_cross_v).
  std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features,
                                                   Ort::Value features_length) {
    std::array<Ort::Value, 2> inputs{std::move(features),
                                     std::move(features_length)};

    auto encoder_out = encoder_sess_->Run(
        {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
        encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());

    return {std::move(encoder_out[0]), std::move(encoder_out[1])};
  }

  // Run the decoder for one step.
  //
  // The returned tuple contains, in order: logits, the updated self k cache,
  // the updated self v cache (all three produced by the decoder), followed by
  // the input n_layer_cross_k, n_layer_cross_v and offset, which are moved
  // back to the caller unchanged so that they can be reused in the next step.
  std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
             Ort::Value>
  ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
                 Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
                 Ort::Value n_layer_cross_v, Ort::Value offset) {
    std::array<Ort::Value, 6> decoder_input = {std::move(tokens),
                                               std::move(n_layer_self_k_cache),
                                               std::move(n_layer_self_v_cache),
                                               std::move(n_layer_cross_k),
                                               std::move(n_layer_cross_v),
                                               std::move(offset)};

    auto decoder_out = decoder_sess_->Run(
        {}, decoder_input_names_ptr_.data(), decoder_input.data(),
        decoder_input.size(), decoder_output_names_ptr_.data(),
        decoder_output_names_ptr_.size());

    return std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value,
                      Ort::Value, Ort::Value>{
        std::move(decoder_out[0]),   std::move(decoder_out[1]),
        std::move(decoder_out[2]),   std::move(decoder_input[3]),
        std::move(decoder_input[4]), std::move(decoder_input[5])};
  }

  // Return zero-filled self-attention k and v caches for batch size 1. Each
  // has shape (num_decoder_layers, 1, max_len, num_head, head_dim).
  std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
    int32_t batch_size = 1;
    std::array<int64_t, 5> shape{meta_data_.num_decoder_layers, batch_size,
                                 meta_data_.max_len, meta_data_.num_head,
                                 meta_data_.head_dim};

    Ort::Value n_layer_self_k_cache = Ort::Value::CreateTensor<float>(
        Allocator(), shape.data(), shape.size());

    Ort::Value n_layer_self_v_cache = Ort::Value::CreateTensor<float>(
        Allocator(), shape.data(), shape.size());

    const int64_t n = shape[0] * shape[1] * shape[2] * shape[3] * shape[4];

    // std::fill_n comes from <algorithm>, which this file includes; it
    // replaces memset(), whose header <cstring> is not included here.
    std::fill_n(n_layer_self_k_cache.GetTensorMutableData<float>(), n, 0.0f);
    std::fill_n(n_layer_self_v_cache.GetTensorMutableData<float>(), n, 0.0f);

    return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
  }

  OrtAllocator *Allocator() { return allocator_; }

  const OfflineFireRedAsrModelMetaData &GetModelMetadata() const {
    return meta_data_;
  }

 private:
  // Create the encoder session from an in-memory model, cache its input and
  // output names, and read the model meta data into meta_data_.
  void InitEncoder(void *model_data, size_t model_data_length) {
    encoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
                  &encoder_input_names_ptr_);

    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
                   &encoder_output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      os << "---encoder---\n";
      PrintModelMetadata(os, meta_data);
#if __OHOS__
      SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
#else
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
#endif
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(meta_data_.num_decoder_layers,
                               "num_decoder_layers");
    SHERPA_ONNX_READ_META_DATA(meta_data_.num_head, "num_head");
    SHERPA_ONNX_READ_META_DATA(meta_data_.head_dim, "head_dim");
    SHERPA_ONNX_READ_META_DATA(meta_data_.sos_id, "sos");
    SHERPA_ONNX_READ_META_DATA(meta_data_.eos_id, "eos");
    SHERPA_ONNX_READ_META_DATA(meta_data_.max_len, "max_len");

    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.mean, "cmvn_mean");
    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.inv_stddev,
                                         "cmvn_inv_stddev");
  }

  // Create the decoder session from an in-memory model and cache its input
  // and output names. No meta data is read from the decoder.
  void InitDecoder(void *model_data, size_t model_data_length) {
    decoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(decoder_sess_.get(), &decoder_input_names_,
                  &decoder_input_names_ptr_);

    GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
                   &decoder_output_names_ptr_);
  }

 private:
  OfflineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> encoder_sess_;
  std::unique_ptr<Ort::Session> decoder_sess_;

  // The *_ptr_ vectors hold the C strings passed to Ort::Session::Run().
  // They are filled together with the matching std::string vectors by
  // GetInputNames()/GetOutputNames().
  std::vector<std::string> encoder_input_names_;
  std::vector<const char *> encoder_input_names_ptr_;

  std::vector<std::string> encoder_output_names_;
  std::vector<const char *> encoder_output_names_ptr_;

  std::vector<std::string> decoder_input_names_;
  std::vector<const char *> decoder_input_names_ptr_;

  std::vector<std::string> decoder_output_names_;
  std::vector<const char *> decoder_output_names_ptr_;

  OfflineFireRedAsrModelMetaData meta_data_;
};
|
||||||
|
|
||||||
|
OfflineFireRedAsrModel::OfflineFireRedAsrModel(const OfflineModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
template <typename Manager>
|
||||||
|
OfflineFireRedAsrModel::OfflineFireRedAsrModel(Manager *mgr,
|
||||||
|
const OfflineModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||||
|
|
||||||
|
OfflineFireRedAsrModel::~OfflineFireRedAsrModel() = default;
|
||||||
|
|
||||||
|
std::pair<Ort::Value, Ort::Value> OfflineFireRedAsrModel::ForwardEncoder(
|
||||||
|
Ort::Value features, Ort::Value features_length) const {
|
||||||
|
return impl_->ForwardEncoder(std::move(features), std::move(features_length));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
|
||||||
|
Ort::Value>
|
||||||
|
OfflineFireRedAsrModel::ForwardDecoder(Ort::Value tokens,
|
||||||
|
Ort::Value n_layer_self_k_cache,
|
||||||
|
Ort::Value n_layer_self_v_cache,
|
||||||
|
Ort::Value n_layer_cross_k,
|
||||||
|
Ort::Value n_layer_cross_v,
|
||||||
|
Ort::Value offset) const {
|
||||||
|
return impl_->ForwardDecoder(
|
||||||
|
std::move(tokens), std::move(n_layer_self_k_cache),
|
||||||
|
std::move(n_layer_self_v_cache), std::move(n_layer_cross_k),
|
||||||
|
std::move(n_layer_cross_v), std::move(offset));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<Ort::Value, Ort::Value>
|
||||||
|
OfflineFireRedAsrModel::GetInitialSelfKVCache() const {
|
||||||
|
return impl_->GetInitialSelfKVCache();
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtAllocator *OfflineFireRedAsrModel::Allocator() const {
|
||||||
|
return impl_->Allocator();
|
||||||
|
}
|
||||||
|
|
||||||
|
const OfflineFireRedAsrModelMetaData &OfflineFireRedAsrModel::GetModelMetadata()
|
||||||
|
const {
|
||||||
|
return impl_->GetModelMetadata();
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
template OfflineFireRedAsrModel::OfflineFireRedAsrModel(
|
||||||
|
AAssetManager *mgr, const OfflineModelConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if __OHOS__
|
||||||
|
template OfflineFireRedAsrModel::OfflineFireRedAsrModel(
|
||||||
|
NativeResourceManager *mgr, const OfflineModelConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
92
sherpa-onnx/csrc/offline-fire-red-asr-model.h
Normal file
92
sherpa-onnx/csrc/offline-fire-red-asr-model.h
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-fire-red-asr-model.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/offline-fire-red-asr-model-meta-data.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineFireRedAsrModel {
|
||||||
|
public:
|
||||||
|
explicit OfflineFireRedAsrModel(const OfflineModelConfig &config);
|
||||||
|
|
||||||
|
template <typename Manager>
|
||||||
|
OfflineFireRedAsrModel(Manager *mgr, const OfflineModelConfig &config);
|
||||||
|
|
||||||
|
~OfflineFireRedAsrModel();
|
||||||
|
|
||||||
|
/** Run the encoder model.
|
||||||
|
*
|
||||||
|
* @param features A tensor of shape (N, T, C).
|
||||||
|
* @param features_len A tensor of shape (N,) with dtype int64.
|
||||||
|
*
|
||||||
|
* @return Return a pair containing:
|
||||||
|
* - n_layer_cross_k: A 4-D tensor of shape
|
||||||
|
* (num_decoder_layers, N, T, d_model)
|
||||||
|
* - n_layer_cross_v: A 4-D tensor of shape
|
||||||
|
* (num_decoder_layers, N, T, d_model)
|
||||||
|
*/
|
||||||
|
std::pair<Ort::Value, Ort::Value> ForwardEncoder(
|
||||||
|
Ort::Value features, Ort::Value features_length) const;
|
||||||
|
|
||||||
|
/** Run the decoder model.
|
||||||
|
*
|
||||||
|
* @param tokens A int64 tensor of shape (N, num_words)
|
||||||
|
* @param n_layer_self_k_cache A 5-D tensor of shape
|
||||||
|
* (num_decoder_layers, N, max_len, num_head, head_dim).
|
||||||
|
* @param n_layer_self_v_cache A 5-D tensor of shape
|
||||||
|
* (num_decoder_layers, N, max_len, num_head, head_dim).
|
||||||
|
* @param n_layer_cross_k A 5-D tensor of shape
|
||||||
|
* (num_decoder_layers, N, T, d_model).
|
||||||
|
* @param n_layer_cross_v A 5-D tensor of shape
|
||||||
|
* (num_decoder_layers, N, T, d_model).
|
||||||
|
* @param offset A int64 tensor of shape (N,)
|
||||||
|
*
|
||||||
|
* @return Return a tuple containing 6 tensors:
|
||||||
|
*
|
||||||
|
* - logits A 3-D tensor of shape (N, num_words, vocab_size)
|
||||||
|
* - out_n_layer_self_k_cache Same shape as n_layer_self_k_cache
|
||||||
|
* - out_n_layer_self_v_cache Same shape as n_layer_self_v_cache
|
||||||
|
* - out_n_layer_cross_k Same as n_layer_cross_k
|
||||||
|
* - out_n_layer_cross_v Same as n_layer_cross_v
|
||||||
|
* - out_offset Same as offset
|
||||||
|
*/
|
||||||
|
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
|
||||||
|
Ort::Value>
|
||||||
|
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
|
||||||
|
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
|
||||||
|
Ort::Value n_layer_cross_v, Ort::Value offset) const;
|
||||||
|
|
||||||
|
/** Return the initial self kv cache in a pair
|
||||||
|
* - n_layer_self_k_cache A 5-D tensor of shape
|
||||||
|
* (num_decoder_layers, N, max_len, num_head, head_dim).
|
||||||
|
* - n_layer_self_v_cache A 5-D tensor of shape
|
||||||
|
* (num_decoder_layers, N, max_len, num_head, head_dim).
|
||||||
|
*/
|
||||||
|
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() const;
|
||||||
|
|
||||||
|
const OfflineFireRedAsrModelMetaData &GetModelMetadata() const;
|
||||||
|
|
||||||
|
/** Return an allocator for allocating memory
|
||||||
|
*/
|
||||||
|
OrtAllocator *Allocator() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_
|
||||||
@@ -15,6 +15,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
|||||||
paraformer.Register(po);
|
paraformer.Register(po);
|
||||||
nemo_ctc.Register(po);
|
nemo_ctc.Register(po);
|
||||||
whisper.Register(po);
|
whisper.Register(po);
|
||||||
|
fire_red_asr.Register(po);
|
||||||
tdnn.Register(po);
|
tdnn.Register(po);
|
||||||
zipformer_ctc.Register(po);
|
zipformer_ctc.Register(po);
|
||||||
wenet_ctc.Register(po);
|
wenet_ctc.Register(po);
|
||||||
@@ -38,7 +39,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
|||||||
po->Register("model-type", &model_type,
|
po->Register("model-type", &model_type,
|
||||||
"Specify it to reduce model initialization time. "
|
"Specify it to reduce model initialization time. "
|
||||||
"Valid values are: transducer, paraformer, nemo_ctc, whisper, "
|
"Valid values are: transducer, paraformer, nemo_ctc, whisper, "
|
||||||
"tdnn, zipformer2_ctc, telespeech_ctc."
|
"tdnn, zipformer2_ctc, telespeech_ctc, fire_red_asr."
|
||||||
"All other values lead to loading the model twice.");
|
"All other values lead to loading the model twice.");
|
||||||
po->Register("modeling-unit", &modeling_unit,
|
po->Register("modeling-unit", &modeling_unit,
|
||||||
"The modeling unit of the model, commonly used units are bpe, "
|
"The modeling unit of the model, commonly used units are bpe, "
|
||||||
@@ -84,6 +85,10 @@ bool OfflineModelConfig::Validate() const {
|
|||||||
return whisper.Validate();
|
return whisper.Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!fire_red_asr.encoder.empty()) {
|
||||||
|
return fire_red_asr.Validate();
|
||||||
|
}
|
||||||
|
|
||||||
if (!tdnn.model.empty()) {
|
if (!tdnn.model.empty()) {
|
||||||
return tdnn.Validate();
|
return tdnn.Validate();
|
||||||
}
|
}
|
||||||
@@ -125,6 +130,7 @@ std::string OfflineModelConfig::ToString() const {
|
|||||||
os << "paraformer=" << paraformer.ToString() << ", ";
|
os << "paraformer=" << paraformer.ToString() << ", ";
|
||||||
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
|
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
|
||||||
os << "whisper=" << whisper.ToString() << ", ";
|
os << "whisper=" << whisper.ToString() << ", ";
|
||||||
|
os << "fire_red_asr=" << fire_red_asr.ToString() << ", ";
|
||||||
os << "tdnn=" << tdnn.ToString() << ", ";
|
os << "tdnn=" << tdnn.ToString() << ", ";
|
||||||
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
|
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
|
||||||
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
|
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
|
#include "sherpa-onnx/csrc/offline-moonshine-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
||||||
@@ -23,6 +24,7 @@ struct OfflineModelConfig {
|
|||||||
OfflineParaformerModelConfig paraformer;
|
OfflineParaformerModelConfig paraformer;
|
||||||
OfflineNemoEncDecCtcModelConfig nemo_ctc;
|
OfflineNemoEncDecCtcModelConfig nemo_ctc;
|
||||||
OfflineWhisperModelConfig whisper;
|
OfflineWhisperModelConfig whisper;
|
||||||
|
OfflineFireRedAsrModelConfig fire_red_asr;
|
||||||
OfflineTdnnModelConfig tdnn;
|
OfflineTdnnModelConfig tdnn;
|
||||||
OfflineZipformerCtcModelConfig zipformer_ctc;
|
OfflineZipformerCtcModelConfig zipformer_ctc;
|
||||||
OfflineWenetCtcModelConfig wenet_ctc;
|
OfflineWenetCtcModelConfig wenet_ctc;
|
||||||
@@ -54,6 +56,7 @@ struct OfflineModelConfig {
|
|||||||
const OfflineParaformerModelConfig ¶former,
|
const OfflineParaformerModelConfig ¶former,
|
||||||
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
|
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
|
||||||
const OfflineWhisperModelConfig &whisper,
|
const OfflineWhisperModelConfig &whisper,
|
||||||
|
const OfflineFireRedAsrModelConfig &fire_red_asr,
|
||||||
const OfflineTdnnModelConfig &tdnn,
|
const OfflineTdnnModelConfig &tdnn,
|
||||||
const OfflineZipformerCtcModelConfig &zipformer_ctc,
|
const OfflineZipformerCtcModelConfig &zipformer_ctc,
|
||||||
const OfflineWenetCtcModelConfig &wenet_ctc,
|
const OfflineWenetCtcModelConfig &wenet_ctc,
|
||||||
@@ -68,6 +71,7 @@ struct OfflineModelConfig {
|
|||||||
paraformer(paraformer),
|
paraformer(paraformer),
|
||||||
nemo_ctc(nemo_ctc),
|
nemo_ctc(nemo_ctc),
|
||||||
whisper(whisper),
|
whisper(whisper),
|
||||||
|
fire_red_asr(fire_red_asr),
|
||||||
tdnn(tdnn),
|
tdnn(tdnn),
|
||||||
zipformer_ctc(zipformer_ctc),
|
zipformer_ctc(zipformer_ctc),
|
||||||
wenet_ctc(wenet_ctc),
|
wenet_ctc(wenet_ctc),
|
||||||
|
|||||||
158
sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h
Normal file
158
sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_
|
||||||
|
|
||||||
|
#include <algorithm>
#include <array>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-fire-red-asr-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-fire-red-asr-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||||
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
|
#include "sherpa-onnx/csrc/transpose.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Builds an OfflineRecognitionResult from the token IDs produced by the
// FireRedAsr decoder.
//
// @param src        Decoder output; only src.tokens is read.
// @param sym_table  Maps token IDs to their textual symbols.
// @return A result whose `tokens` holds the symbol of every ID found in
//         sym_table (IDs that are not in the table are silently dropped) and
//         whose `text` is those symbols concatenated without any separator.
static OfflineRecognitionResult Convert(
    const OfflineFireRedAsrDecoderResult &src, const SymbolTable &sym_table) {
  OfflineRecognitionResult result;
  result.tokens.reserve(src.tokens.size());

  std::string joined;
  for (auto token_id : src.tokens) {
    if (sym_table.Contains(token_id)) {
      const auto &symbol = sym_table[token_id];
      joined += symbol;
      result.tokens.push_back(symbol);
    }
  }

  result.text = std::move(joined);

  return result;
}
|
||||||
|
|
||||||
|
class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl {
|
||||||
|
public:
|
||||||
|
explicit OfflineRecognizerFireRedAsrImpl(
|
||||||
|
const OfflineRecognizerConfig &config)
|
||||||
|
: OfflineRecognizerImpl(config),
|
||||||
|
config_(config),
|
||||||
|
symbol_table_(config_.model_config.tokens),
|
||||||
|
model_(std::make_unique<OfflineFireRedAsrModel>(config.model_config)) {
|
||||||
|
Init();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Manager>
|
||||||
|
OfflineRecognizerFireRedAsrImpl(Manager *mgr,
|
||||||
|
const OfflineRecognizerConfig &config)
|
||||||
|
: OfflineRecognizerImpl(mgr, config),
|
||||||
|
config_(config),
|
||||||
|
symbol_table_(mgr, config_.model_config.tokens),
|
||||||
|
model_(std::make_unique<OfflineFireRedAsrModel>(mgr,
|
||||||
|
config.model_config)) {
|
||||||
|
Init();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Init() {
|
||||||
|
if (config_.decoding_method == "greedy_search") {
|
||||||
|
decoder_ =
|
||||||
|
std::make_unique<OfflineFireRedAsrGreedySearchDecoder>(model_.get());
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Only greedy_search is supported at present for FireRedAsr. Given %s",
|
||||||
|
config_.decoding_method.c_str());
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &meta_data = model_->GetModelMetadata();
|
||||||
|
|
||||||
|
config_.feat_config.normalize_samples = false;
|
||||||
|
config_.feat_config.high_freq = 0;
|
||||||
|
config_.feat_config.snip_edges = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||||
|
return std::make_unique<OfflineStream>(config_.feat_config);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||||
|
// batch decoding is not implemented yet
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
DecodeStream(ss[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OfflineRecognizerConfig GetConfig() const override { return config_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void DecodeStream(OfflineStream *s) const {
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
int32_t feat_dim = s->FeatureDim();
|
||||||
|
std::vector<float> f = s->GetFrames();
|
||||||
|
ApplyCMVN(&f);
|
||||||
|
|
||||||
|
int64_t num_frames = f.size() / feat_dim;
|
||||||
|
|
||||||
|
std::array<int64_t, 3> shape{1, num_frames, feat_dim};
|
||||||
|
|
||||||
|
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
|
||||||
|
shape.data(), shape.size());
|
||||||
|
|
||||||
|
int64_t len_shape = 1;
|
||||||
|
Ort::Value x_len =
|
||||||
|
Ort::Value::CreateTensor(memory_info, &num_frames, 1, &len_shape, 1);
|
||||||
|
|
||||||
|
auto cross_kv = model_->ForwardEncoder(std::move(x), std::move(x_len));
|
||||||
|
|
||||||
|
auto results =
|
||||||
|
decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
|
||||||
|
|
||||||
|
auto r = Convert(results[0], symbol_table_);
|
||||||
|
|
||||||
|
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||||
|
s->SetResult(r);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ApplyCMVN(std::vector<float> *v) const {
|
||||||
|
const auto &meta_data = model_->GetModelMetadata();
|
||||||
|
const auto &mean = meta_data.mean;
|
||||||
|
const auto &inv_stddev = meta_data.inv_stddev;
|
||||||
|
int32_t feat_dim = static_cast<int32_t>(mean.size());
|
||||||
|
int32_t num_frames = static_cast<int32_t>(v->size()) / feat_dim;
|
||||||
|
|
||||||
|
float *p = v->data();
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != num_frames; ++i) {
|
||||||
|
for (int32_t k = 0; k != feat_dim; ++k) {
|
||||||
|
p[k] = (p[k] - mean[k]) * inv_stddev[k];
|
||||||
|
}
|
||||||
|
|
||||||
|
p += feat_dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineRecognizerConfig config_;
|
||||||
|
SymbolTable symbol_table_;
|
||||||
|
std::unique_ptr<OfflineFireRedAsrModel> model_;
|
||||||
|
std::unique_ptr<OfflineFireRedAsrDecoder> decoder_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_
|
||||||
@@ -24,6 +24,7 @@
|
|||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h"
|
||||||
@@ -56,6 +57,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!config.model_config.fire_red_asr.encoder.empty()) {
|
||||||
|
return std::make_unique<OfflineRecognizerFireRedAsrImpl>(config);
|
||||||
|
}
|
||||||
|
|
||||||
if (!config.model_config.moonshine.preprocessor.empty()) {
|
if (!config.model_config.moonshine.preprocessor.empty()) {
|
||||||
return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
|
return std::make_unique<OfflineRecognizerMoonshineImpl>(config);
|
||||||
}
|
}
|
||||||
@@ -237,6 +242,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
|
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!config.model_config.fire_red_asr.encoder.empty()) {
|
||||||
|
return std::make_unique<OfflineRecognizerFireRedAsrImpl>(mgr, config);
|
||||||
|
}
|
||||||
|
|
||||||
if (!config.model_config.moonshine.preprocessor.empty()) {
|
if (!config.model_config.moonshine.preprocessor.empty()) {
|
||||||
return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
|
return std::make_unique<OfflineRecognizerMoonshineImpl>(mgr, config);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ set(srcs
|
|||||||
features.cc
|
features.cc
|
||||||
keyword-spotter.cc
|
keyword-spotter.cc
|
||||||
offline-ctc-fst-decoder-config.cc
|
offline-ctc-fst-decoder-config.cc
|
||||||
|
offline-fire-red-asr-model-config.cc
|
||||||
offline-lm-config.cc
|
offline-lm-config.cc
|
||||||
offline-model-config.cc
|
offline-model-config.cc
|
||||||
offline-moonshine-model-config.cc
|
offline-moonshine-model-config.cc
|
||||||
|
|||||||
24
sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.cc
Normal file
24
sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.cc
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
// sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-fire-red-asr-model-config.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Exposes OfflineFireRedAsrModelConfig to Python in module `m`.
//
// The Python class is constructed from the two keyword arguments `encoder`
// and `decoder`, offers both as read/write attributes, and uses ToString()
// for str().
void PybindOfflineFireRedAsrModelConfig(py::module *m) {
  using Config = OfflineFireRedAsrModelConfig;

  py::class_<Config> binding(*m, "OfflineFireRedAsrModelConfig");

  binding.def(py::init<const std::string &, const std::string &>(),
              py::arg("encoder"), py::arg("decoder"));
  binding.def_readwrite("encoder", &Config::encoder);
  binding.def_readwrite("decoder", &Config::decoder);
  binding.def("__str__", &Config::ToString);
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
16
sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h
Normal file
16
sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindOfflineFireRedAsrModelConfig(py::module *m);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_
|
||||||
@@ -8,6 +8,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-fire-red-asr-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-moonshine-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
||||||
@@ -25,6 +26,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
|||||||
PybindOfflineParaformerModelConfig(m);
|
PybindOfflineParaformerModelConfig(m);
|
||||||
PybindOfflineNemoEncDecCtcModelConfig(m);
|
PybindOfflineNemoEncDecCtcModelConfig(m);
|
||||||
PybindOfflineWhisperModelConfig(m);
|
PybindOfflineWhisperModelConfig(m);
|
||||||
|
PybindOfflineFireRedAsrModelConfig(m);
|
||||||
PybindOfflineTdnnModelConfig(m);
|
PybindOfflineTdnnModelConfig(m);
|
||||||
PybindOfflineZipformerCtcModelConfig(m);
|
PybindOfflineZipformerCtcModelConfig(m);
|
||||||
PybindOfflineWenetCtcModelConfig(m);
|
PybindOfflineWenetCtcModelConfig(m);
|
||||||
@@ -33,35 +35,38 @@ void PybindOfflineModelConfig(py::module *m) {
|
|||||||
|
|
||||||
using PyClass = OfflineModelConfig;
|
using PyClass = OfflineModelConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineModelConfig")
|
py::class_<PyClass>(*m, "OfflineModelConfig")
|
||||||
.def(
|
.def(py::init<const OfflineTransducerModelConfig &,
|
||||||
py::init<
|
const OfflineParaformerModelConfig &,
|
||||||
const OfflineTransducerModelConfig &,
|
const OfflineNemoEncDecCtcModelConfig &,
|
||||||
const OfflineParaformerModelConfig &,
|
const OfflineWhisperModelConfig &,
|
||||||
const OfflineNemoEncDecCtcModelConfig &,
|
const OfflineFireRedAsrModelConfig &,
|
||||||
const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &,
|
const OfflineTdnnModelConfig &,
|
||||||
const OfflineZipformerCtcModelConfig &,
|
const OfflineZipformerCtcModelConfig &,
|
||||||
const OfflineWenetCtcModelConfig &,
|
const OfflineWenetCtcModelConfig &,
|
||||||
const OfflineSenseVoiceModelConfig &,
|
const OfflineSenseVoiceModelConfig &,
|
||||||
const OfflineMoonshineModelConfig &, const std::string &,
|
const OfflineMoonshineModelConfig &, const std::string &,
|
||||||
const std::string &, int32_t, bool, const std::string &,
|
const std::string &, int32_t, bool, const std::string &,
|
||||||
const std::string &, const std::string &, const std::string &>(),
|
const std::string &, const std::string &,
|
||||||
py::arg("transducer") = OfflineTransducerModelConfig(),
|
const std::string &>(),
|
||||||
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
py::arg("transducer") = OfflineTransducerModelConfig(),
|
||||||
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
|
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
||||||
py::arg("whisper") = OfflineWhisperModelConfig(),
|
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
|
||||||
py::arg("tdnn") = OfflineTdnnModelConfig(),
|
py::arg("whisper") = OfflineWhisperModelConfig(),
|
||||||
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
|
py::arg("fire_red_asr") = OfflineFireRedAsrModelConfig(),
|
||||||
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
|
py::arg("tdnn") = OfflineTdnnModelConfig(),
|
||||||
py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
|
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
|
||||||
py::arg("moonshine") = OfflineMoonshineModelConfig(),
|
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
|
||||||
py::arg("telespeech_ctc") = "", py::arg("tokens"),
|
py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
|
||||||
py::arg("num_threads"), py::arg("debug") = false,
|
py::arg("moonshine") = OfflineMoonshineModelConfig(),
|
||||||
py::arg("provider") = "cpu", py::arg("model_type") = "",
|
py::arg("telespeech_ctc") = "", py::arg("tokens"),
|
||||||
py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
|
py::arg("num_threads"), py::arg("debug") = false,
|
||||||
|
py::arg("provider") = "cpu", py::arg("model_type") = "",
|
||||||
|
py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
|
||||||
.def_readwrite("transducer", &PyClass::transducer)
|
.def_readwrite("transducer", &PyClass::transducer)
|
||||||
.def_readwrite("paraformer", &PyClass::paraformer)
|
.def_readwrite("paraformer", &PyClass::paraformer)
|
||||||
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
|
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
|
||||||
.def_readwrite("whisper", &PyClass::whisper)
|
.def_readwrite("whisper", &PyClass::whisper)
|
||||||
|
.def_readwrite("fire_red_asr", &PyClass::fire_red_asr)
|
||||||
.def_readwrite("tdnn", &PyClass::tdnn)
|
.def_readwrite("tdnn", &PyClass::tdnn)
|
||||||
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
|
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
|
||||||
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
|
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import List, Optional
|
|||||||
from _sherpa_onnx import (
|
from _sherpa_onnx import (
|
||||||
FeatureExtractorConfig,
|
FeatureExtractorConfig,
|
||||||
OfflineCtcFstDecoderConfig,
|
OfflineCtcFstDecoderConfig,
|
||||||
|
OfflineFireRedAsrModelConfig,
|
||||||
OfflineLMConfig,
|
OfflineLMConfig,
|
||||||
OfflineModelConfig,
|
OfflineModelConfig,
|
||||||
OfflineMoonshineModelConfig,
|
OfflineMoonshineModelConfig,
|
||||||
@@ -571,6 +572,78 @@ class OfflineRecognizer(object):
|
|||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@classmethod
def from_fire_red_asr(
    cls,
    encoder: str,
    decoder: str,
    tokens: str,
    num_threads: int = 1,
    decoding_method: str = "greedy_search",
    debug: bool = False,
    provider: str = "cpu",
    rule_fsts: str = "",
    rule_fars: str = "",
):
    """
    Create a recognizer from a FireRedAsr encoder/decoder model pair.

    Pre-trained models of different sizes (e.g., xs, large, etc.) can be
    downloaded by following
    `<https://k2-fsa.github.io/sherpa/onnx/fire_red_asr/index.html>`_

    The feature extractor is always configured for 16 kHz audio and
    80-dimensional features.

    Args:
      encoder:
        Path to the encoder model.
      decoder:
        Path to the decoder model.
      tokens:
        Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
        columns::

            symbol integer_id
      num_threads:
        Number of threads for neural network computation.
      decoding_method:
        Valid values: greedy_search.
      debug:
        True to show debug messages.
      provider:
        onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
      rule_fsts:
        If not empty, it specifies fsts for inverse text normalization.
        If there are multiple fsts, they are separated by a comma.
      rule_fars:
        If not empty, it specifies fst archives for inverse text
        normalization. If there are multiple archives, they are separated
        by a comma.

    Returns:
      An instance of ``cls`` whose ``recognizer`` and ``config`` attributes
      are set.
    """
    # __init__ is bypassed on purpose; the attributes are assigned below.
    instance = cls.__new__(cls)

    fire_red_asr_config = OfflineFireRedAsrModelConfig(
        encoder=encoder,
        decoder=decoder,
    )
    model_config = OfflineModelConfig(
        fire_red_asr=fire_red_asr_config,
        tokens=tokens,
        num_threads=num_threads,
        debug=debug,
        provider=provider,
    )

    feat_config = FeatureExtractorConfig(sampling_rate=16000, feature_dim=80)

    recognizer_config = OfflineRecognizerConfig(
        feat_config=feat_config,
        model_config=model_config,
        decoding_method=decoding_method,
        rule_fsts=rule_fsts,
        rule_fars=rule_fars,
    )

    instance.recognizer = _Recognizer(recognizer_config)
    instance.config = recognizer_config
    return instance
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_moonshine(
|
def from_moonshine(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
Reference in New Issue
Block a user