Add C++ runtime for *streaming* faster conformer transducer from NeMo. (#889)
Co-authored-by: sangeet2020 <15uec053@gmail.com>
This commit is contained in:
@@ -7,13 +7,28 @@
|
||||
#include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h"
|
||||
#include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
|
||||
#include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
|
||||
#include "sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
|
||||
const OnlineRecognizerConfig &config) {
|
||||
|
||||
if (!config.model_config.transducer.encoder.empty()) {
|
||||
return std::make_unique<OnlineRecognizerTransducerImpl>(config);
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
||||
|
||||
auto decoder_model = ReadFile(config.model_config.transducer.decoder);
|
||||
auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});
|
||||
|
||||
size_t node_count = sess->GetOutputCount();
|
||||
|
||||
if (node_count == 1) {
|
||||
return std::make_unique<OnlineRecognizerTransducerImpl>(config);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Running streaming Nemo transducer model");
|
||||
return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(config);
|
||||
}
|
||||
}
|
||||
|
||||
if (!config.model_config.paraformer.encoder.empty()) {
|
||||
@@ -34,7 +49,18 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
|
||||
std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
|
||||
AAssetManager *mgr, const OnlineRecognizerConfig &config) {
|
||||
if (!config.model_config.transducer.encoder.empty()) {
|
||||
return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
||||
|
||||
auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder);
|
||||
auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});
|
||||
|
||||
size_t node_count = sess->GetOutputCount();
|
||||
|
||||
if (node_count == 1) {
|
||||
return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
|
||||
} else {
|
||||
return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(mgr, config);
|
||||
}
|
||||
}
|
||||
|
||||
if (!config.model_config.paraformer.encoder.empty()) {
|
||||
|
||||
Reference in New Issue
Block a user