Support CoreML for macOS (#151)
@@ -1 +0,0 @@
-exclude_files=.*
@@ -35,8 +35,8 @@ set(sources
   offline-transducer-model.cc
   offline-transducer-modified-beam-search-decoder.cc
   online-conformer-transducer-model.cc
-  online-lm.cc
   online-lm-config.cc
+  online-lm.cc
   online-lstm-transducer-model.cc
   online-recognizer.cc
   online-rnn-lm.cc
@@ -48,9 +48,11 @@ set(sources
   online-transducer-modified-beam-search-decoder.cc
   online-zipformer-transducer-model.cc
   onnx-utils.cc
+  session.cc
   packed-sequence.cc
   pad-sequence.cc
   parse-options.cc
+  provider.cc
   resample.cc
   slice.cc
   stack.cc
@@ -22,6 +22,9 @@ void OfflineModelConfig::Register(ParseOptions *po) {

   po->Register("debug", &debug,
                "true to print model information while loading it.");
+
+  po->Register("provider", &provider,
+               "Specify a provider to use: cpu, cuda, coreml");
 }

 bool OfflineModelConfig::Validate() const {
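For context, here is a minimal sketch of how the new flag reaches the config. It assumes a Kaldi-style ParseOptions API (Register plus Read); the driver code itself is not part of this diff, and the usage string is a placeholder.

// Hypothetical driver: --provider flows from the command line into
// OfflineModelConfig. ParseOptions::Read(argc, argv) is assumed.
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"

int main(int argc, char *argv[]) {
  sherpa_onnx::ParseOptions po("Usage: ...");
  sherpa_onnx::OfflineModelConfig config;
  config.Register(&po);  // registers --provider along with the other flags
  po.Read(argc, argv);   // e.g. --provider=coreml
  // config.provider now holds "coreml"; it is consumed later by
  // GetSessionOptions() from the new session.cc below.
  return 0;
}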
@@ -55,7 +58,8 @@ std::string OfflineModelConfig::ToString() const {
   os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
-  os << "debug=" << (debug ? "True" : "False") << ")";
+  os << "debug=" << (debug ? "True" : "False") << ", ";
+  os << "provider=\"" << provider << "\")";

   return os.str();
 }
@@ -20,18 +20,21 @@ struct OfflineModelConfig {
   std::string tokens;
   int32_t num_threads = 2;
   bool debug = false;
+  std::string provider = "cpu";

   OfflineModelConfig() = default;
   OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
                      const OfflineParaformerModelConfig &paraformer,
                      const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
-                     const std::string &tokens, int32_t num_threads, bool debug)
+                     const std::string &tokens, int32_t num_threads, bool debug,
+                     const std::string &provider)
       : transducer(transducer),
         paraformer(paraformer),
         nemo_ctc(nemo_ctc),
         tokens(tokens),
         num_threads(num_threads),
-        debug(debug) {}
+        debug(debug),
+        provider(provider) {}

   void Register(ParseOptions *po);
   bool Validate() const;
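Call sites that construct the config directly must now pass a provider. A minimal sketch (the sub-configs are default-constructed and the file name is a placeholder):

// Sketch only; real code would fill in the model file names.
sherpa_onnx::OfflineTransducerModelConfig transducer;
sherpa_onnx::OfflineParaformerModelConfig paraformer;
sherpa_onnx::OfflineNemoEncDecCtcModelConfig nemo_ctc;

sherpa_onnx::OfflineModelConfig config(transducer, paraformer, nemo_ctc,
                                       /*tokens=*/"tokens.txt",
                                       /*num_threads=*/2,
                                       /*debug=*/false,
                                       /*provider=*/"coreml");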
@@ -6,6 +6,7 @@

 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 #include "sherpa-onnx/csrc/text-utils.h"
 #include "sherpa-onnx/csrc/transpose.h"

@@ -16,11 +17,8 @@ class OfflineNemoEncDecCtcModel::Impl {
   explicit Impl(const OfflineModelConfig &config)
       : config_(config),
         env_(ORT_LOGGING_LEVEL_ERROR),
-        sess_opts_{},
+        sess_opts_(GetSessionOptions(config)),
         allocator_{} {
-    sess_opts_.SetIntraOpNumThreads(config_.num_threads);
-    sess_opts_.SetInterOpNumThreads(config_.num_threads);
-
     Init();
   }
@@ -9,6 +9,7 @@

 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 #include "sherpa-onnx/csrc/text-utils.h"

 namespace sherpa_onnx {

@@ -18,11 +19,8 @@ class OfflineParaformerModel::Impl {
   explicit Impl(const OfflineModelConfig &config)
       : config_(config),
         env_(ORT_LOGGING_LEVEL_ERROR),
-        sess_opts_{},
+        sess_opts_(GetSessionOptions(config)),
         allocator_{} {
-    sess_opts_.SetIntraOpNumThreads(config_.num_threads);
-    sess_opts_.SetInterOpNumThreads(config_.num_threads);
-
     Init();
   }
@@ -11,6 +11,7 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-transducer-decoder.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"

 namespace sherpa_onnx {

@@ -19,10 +20,8 @@ class OfflineTransducerModel::Impl {
   explicit Impl(const OfflineModelConfig &config)
       : config_(config),
         env_(ORT_LOGGING_LEVEL_WARNING),
-        sess_opts_{},
+        sess_opts_(GetSessionOptions(config)),
         allocator_{} {
-    sess_opts_.SetIntraOpNumThreads(config.num_threads);
-    sess_opts_.SetInterOpNumThreads(config.num_threads);
     {
       auto buf = ReadFile(config.transducer.encoder_filename);
       InitEncoder(buf.data(), buf.size());
@@ -9,7 +9,6 @@

 #include <algorithm>
 #include <memory>
 #include <sstream>
-#include <iostream>
 #include <string>
 #include <utility>
 #include <vector>
@@ -24,6 +23,7 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 #include "sherpa-onnx/csrc/text-utils.h"
 #include "sherpa-onnx/csrc/unbind.h"
@@ -33,11 +33,8 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel(
     const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
       allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
@@ -59,11 +56,8 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel(
     AAssetManager *mgr, const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
      config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
       allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(mgr, config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
@@ -185,7 +179,7 @@ std::vector<std::vector<Ort::Value>>
 OnlineConformerTransducerModel::UnStackStates(
     const std::vector<Ort::Value> &states) const {
   const int32_t batch_size =
       states[0].GetTensorTypeAndShapeInfo().GetShape()[2];
   assert(states.size() == 2);

   std::vector<std::vector<Ort::Value>> ans(batch_size);
@@ -209,8 +203,8 @@ std::vector<Ort::Value> OnlineConformerTransducerModel::GetEncoderInitStates() {
   // https://github.com/k2-fsa/icefall/blob/86b0db6eb9c84d9bc90a71d92774fe2a7f73e6ab/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py#L203
   // for details
   constexpr int32_t kBatchSize = 1;
-  std::array<int64_t, 4> h_shape{
-      num_encoder_layers_, left_context_, kBatchSize, encoder_dim_};
+  std::array<int64_t, 4> h_shape{num_encoder_layers_, left_context_, kBatchSize,
+                                 encoder_dim_};
   Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
                                                  h_shape.size());
@@ -238,9 +232,7 @@ OnlineConformerTransducerModel::RunEncoder(Ort::Value features,
                                            std::vector<Ort::Value> states,
                                            Ort::Value processed_frames) {
   std::array<Ort::Value, 4> encoder_inputs = {
-      std::move(features),
-      std::move(states[0]),
-      std::move(states[1]),
+      std::move(features), std::move(states[0]), std::move(states[1]),
       std::move(processed_frames)};

   auto encoder_out = encoder_sess_->Run(
@@ -22,6 +22,7 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 #include "sherpa-onnx/csrc/unbind.h"

 namespace sherpa_onnx {
@@ -30,11 +31,8 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel(
     const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
       allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
@@ -56,11 +54,8 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel(
     AAssetManager *mgr, const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
       allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(mgr, config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
@@ -9,7 +9,6 @@

 #include <algorithm>
 #include <iomanip>
-#include <iostream>
 #include <memory>
 #include <sstream>
 #include <utility>
@@ -140,8 +139,8 @@ class OnlineRecognizer::Impl {
       decoder_ =
           std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
     } else {
-      fprintf(stderr, "Unsupported decoding method: %s\n",
-              config.decoding_method.c_str());
+      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
+                       config.decoding_method.c_str());
       exit(-1);
     }
   }
@@ -160,8 +159,8 @@ class OnlineRecognizer::Impl {
       decoder_ =
           std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
     } else {
-      fprintf(stderr, "Unsupported decoding method: %s\n",
-              config.decoding_method.c_str());
+      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
+                       config.decoding_method.c_str());
       exit(-1);
     }
   }
@@ -216,19 +215,16 @@ class OnlineRecognizer::Impl {
                                         x_shape.size());

     std::array<int64_t, 1> processed_frames_shape{
         static_cast<int64_t>(all_processed_frames.size())};

     Ort::Value processed_frames = Ort::Value::CreateTensor(
-        memory_info,
-        all_processed_frames.data(),
-        all_processed_frames.size(),
-        processed_frames_shape.data(),
-        processed_frames_shape.size());
+        memory_info, all_processed_frames.data(), all_processed_frames.size(),
+        processed_frames_shape.data(), processed_frames_shape.size());

     auto states = model_->StackStates(states_vec);

-    auto pair = model_->RunEncoder(
-        std::move(x), std::move(states), std::move(processed_frames));
+    auto pair = model_->RunEncoder(std::move(x), std::move(states),
+                                   std::move(processed_frames));

     decoder_->Decode(std::move(pair.first), &results);
@@ -17,19 +17,21 @@ struct OnlineTransducerModelConfig {
   std::string tokens;
   int32_t num_threads = 2;
   bool debug = false;
+  std::string provider = "cpu";

   OnlineTransducerModelConfig() = default;
   OnlineTransducerModelConfig(const std::string &encoder_filename,
                               const std::string &decoder_filename,
                               const std::string &joiner_filename,
                               const std::string &tokens, int32_t num_threads,
-                              bool debug)
+                              bool debug, const std::string &provider)
       : encoder_filename(encoder_filename),
         decoder_filename(decoder_filename),
         joiner_filename(joiner_filename),
         tokens(tokens),
         num_threads(num_threads),
-        debug(debug) {}
+        debug(debug),
+        provider(provider) {}

   void Register(ParseOptions *po);
   bool Validate() const;
@@ -10,7 +10,6 @@
 #endif

 #include <algorithm>
-#include <iostream>
 #include <memory>
 #include <sstream>
 #include <string>
@@ -23,6 +23,7 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 #include "sherpa-onnx/csrc/text-utils.h"
 #include "sherpa-onnx/csrc/unbind.h"
@@ -32,11 +33,8 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
     const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
       allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
@@ -58,11 +56,8 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
     AAssetManager *mgr, const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
       allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(mgr, config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
sherpa-onnx/csrc/provider.cc (new file, 29 lines)
@@ -0,0 +1,29 @@
// sherpa-onnx/csrc/provider.cc
//
// Copyright (c)  2023  Xiaomi Corporation

#include "sherpa-onnx/csrc/provider.h"

#include <algorithm>
#include <cctype>

#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

Provider StringToProvider(std::string s) {
  std::transform(s.cbegin(), s.cend(), s.begin(),
                 [](unsigned char c) { return std::tolower(c); });
  if (s == "cpu") {
    return Provider::kCPU;
  } else if (s == "cuda") {
    return Provider::kCUDA;
  } else if (s == "coreml") {
    return Provider::kCoreML;
  } else {
    SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
    return Provider::kCPU;
  }
}

}  // namespace sherpa_onnx
sherpa-onnx/csrc/provider.h (new file, 31 lines)
@@ -0,0 +1,31 @@
// sherpa-onnx/csrc/provider.h
//
// Copyright (c)  2023  Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_PROVIDER_H_
#define SHERPA_ONNX_CSRC_PROVIDER_H_

#include <string>

namespace sherpa_onnx {

// Please refer to
// https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java
// for a list of available providers
enum class Provider {
  kCPU = 0,     // CPUExecutionProvider
  kCUDA = 1,    // CUDAExecutionProvider
  kCoreML = 2,  // CoreMLExecutionProvider
};

/**
 * Convert a string to an enum.
 *
 * @param s We will convert it to lowercase before comparing.
 * @return Return an instance of Provider.
 */
Provider StringToProvider(std::string s);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_PROVIDER_H_
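The helper is deliberately forgiving. A quick usage sketch (illustration only, not part of the diff; "npu" is just an example of an unrecognized value):

#include "sherpa-onnx/csrc/provider.h"

void Demo() {
  using sherpa_onnx::Provider;
  using sherpa_onnx::StringToProvider;

  Provider a = StringToProvider("CoreML");  // case-insensitive -> kCoreML
  Provider b = StringToProvider("cuda");    // -> kCUDA
  Provider c = StringToProvider("npu");     // unknown -> logs and returns kCPU
  (void)a; (void)b; (void)c;
}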
sherpa-onnx/csrc/session.cc (new file, 60 lines)
@@ -0,0 +1,60 @@
// sherpa-onnx/csrc/session.cc
//
// Copyright (c)  2023  Xiaomi Corporation

#include "sherpa-onnx/csrc/session.h"

#include <string>
#include <utility>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/provider.h"

#if defined(__APPLE__)
#include "coreml_provider_factory.h"  // NOLINT
#endif

namespace sherpa_onnx {

static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
                                                 std::string provider_str) {
  Provider p = StringToProvider(std::move(provider_str));

  Ort::SessionOptions sess_opts;
  sess_opts.SetIntraOpNumThreads(num_threads);
  sess_opts.SetInterOpNumThreads(num_threads);

  switch (p) {
    case Provider::kCPU:
      break;  // nothing to do for the CPU provider
    case Provider::kCUDA: {
      OrtCUDAProviderOptions options;
      options.device_id = 0;
      // set more options on need
      sess_opts.AppendExecutionProvider_CUDA(options);
      break;
    }
    case Provider::kCoreML: {
#if defined(__APPLE__)
      uint32_t coreml_flags = 0;
      (void)OrtSessionOptionsAppendExecutionProvider_CoreML(sess_opts,
                                                            coreml_flags);
#else
      SHERPA_ONNX_LOGE("CoreML is for Apple only. Fallback to cpu!");
#endif
      break;
    }
  }

  return sess_opts;
}

Ort::SessionOptions GetSessionOptions(
    const OnlineTransducerModelConfig &config) {
  return GetSessionOptionsImpl(config.num_threads, config.provider);
}

Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
  return GetSessionOptionsImpl(config.num_threads, config.provider);
}

}  // namespace sherpa_onnx
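This is the pattern the model classes above switch to: build the options once from the config and hand them to the session. A minimal sketch, assuming a model already read into a memory buffer as elsewhere in this diff (the function name and buffer are placeholders; the in-memory Ort::Session constructor is the standard onnxruntime C++ API):

#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/session.h"

Ort::Session CreateSession(const sherpa_onnx::OnlineTransducerModelConfig &config,
                           Ort::Env &env, const std::vector<char> &model_data) {
  // Thread counts are set inside GetSessionOptionsImpl(), which is why the
  // per-model SetIntraOpNumThreads()/SetInterOpNumThreads() calls were removed.
  Ort::SessionOptions sess_opts = sherpa_onnx::GetSessionOptions(config);
  return Ort::Session(env, model_data.data(), model_data.size(), sess_opts);
}

Note that Ort::SessionOptions is movable but not copyable, so returning it by value from GetSessionOptionsImpl relies on move semantics.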
sherpa-onnx/csrc/session.h (new file, 21 lines)
@@ -0,0 +1,21 @@
// sherpa-onnx/csrc/session.h
//
// Copyright (c)  2023  Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_SESSION_H_
#define SHERPA_ONNX_CSRC_SESSION_H_

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"

namespace sherpa_onnx {

Ort::SessionOptions GetSessionOptions(
    const OnlineTransducerModelConfig &config);

Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_SESSION_H_
@@ -6,7 +6,6 @@

 #include <algorithm>
 #include <functional>
-#include <iostream>
 #include <numeric>
 #include <utility>
@@ -36,7 +35,7 @@ template <typename T /*=float*/>
 Ort::Value Stack(OrtAllocator *allocator,
                  const std::vector<const Ort::Value *> &values, int32_t dim) {
   std::vector<int64_t> v0_shape =
       values[0]->GetTensorTypeAndShapeInfo().GetShape();

   for (int32_t i = 1; i != static_cast<int32_t>(values.size()); ++i) {
     auto s = values[i]->GetTensorTypeAndShapeInfo().GetShape();
@@ -58,21 +57,17 @@ Ort::Value Stack(OrtAllocator *allocator,
   ans_shape.reserve(v0_shape.size() + 1);
   ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim);
   ans_shape.push_back(values.size());
-  ans_shape.insert(
-      ans_shape.end(),
-      v0_shape.data() + dim,
-      v0_shape.data() + v0_shape.size());
+  ans_shape.insert(ans_shape.end(), v0_shape.data() + dim,
+                   v0_shape.data() + v0_shape.size());

   auto leading_size = static_cast<int32_t>(std::accumulate(
       v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>()));

-  auto trailing_size = static_cast<int32_t>(
-      std::accumulate(v0_shape.begin() + dim,
-                      v0_shape.end(), 1,
-                      std::multiplies<int64_t>()));
+  auto trailing_size = static_cast<int32_t>(std::accumulate(
+      v0_shape.begin() + dim, v0_shape.end(), 1, std::multiplies<int64_t>()));

-  Ort::Value ans = Ort::Value::CreateTensor<T>(
-      allocator, ans_shape.data(), ans_shape.size());
+  Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
+                                               ans_shape.size());
   T *dst = ans.GetTensorMutableData<T>();

   for (int32_t i = 0; i != leading_size; ++i) {
@@ -88,14 +83,12 @@ Ort::Value Stack(OrtAllocator *allocator,
   return ans;
 }

-template Ort::Value Stack<float>(
-    OrtAllocator *allocator,
-    const std::vector<const Ort::Value *> &values,
-    int32_t dim);
+template Ort::Value Stack<float>(OrtAllocator *allocator,
+                                 const std::vector<const Ort::Value *> &values,
+                                 int32_t dim);

-template Ort::Value Stack<int64_t>(
-    OrtAllocator *allocator,
-    const std::vector<const Ort::Value *> &values,
-    int32_t dim);
+template Ort::Value Stack<int64_t>(
+    OrtAllocator *allocator, const std::vector<const Ort::Value *> &values,
+    int32_t dim);

 }  // namespace sherpa_onnx
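For reference, a small usage sketch of the reformatted template (illustration only; it assumes Stack is declared in a header matching the instantiations above):

// Stacking two [3] float tensors along dim 0 yields a [2, 3] tensor:
// the new axis is inserted at position dim.
Ort::AllocatorWithDefaultOptions allocator;

std::array<int64_t, 1> shape{3};
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());

std::vector<const Ort::Value *> values = {&a, &b};
Ort::Value stacked = sherpa_onnx::Stack<float>(allocator, values, /*dim=*/0);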
@@ -24,17 +24,19 @@ void PybindOfflineModelConfig(py::module *m) {
       .def(py::init<const OfflineTransducerModelConfig &,
                     const OfflineParaformerModelConfig &,
                     const OfflineNemoEncDecCtcModelConfig &,
-                    const std::string &, int32_t, bool>(),
+                    const std::string &, int32_t, bool, const std::string &>(),
            py::arg("transducer") = OfflineTransducerModelConfig(),
            py::arg("paraformer") = OfflineParaformerModelConfig(),
            py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
-           py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false)
+           py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
+           py::arg("provider") = "cpu")
       .def_readwrite("transducer", &PyClass::transducer)
       .def_readwrite("paraformer", &PyClass::paraformer)
       .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
       .def_readwrite("tokens", &PyClass::tokens)
       .def_readwrite("num_threads", &PyClass::num_threads)
       .def_readwrite("debug", &PyClass::debug)
+      .def_readwrite("provider", &PyClass::provider)
       .def("__str__", &PyClass::ToString);
 }
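With the binding updated, the new keyword is reachable from Python. A hedged sketch (it assumes OfflineModelConfig is re-exported at the package level, and the token file path is a placeholder):

from sherpa_onnx import OfflineModelConfig

config = OfflineModelConfig(
    tokens="tokens.txt",
    num_threads=2,
    provider="coreml",  # new keyword; defaults to "cpu"
)
print(config)  # __str__ now ends with provider="coreml"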
@@ -14,16 +14,19 @@ void PybindOnlineTransducerModelConfig(py::module *m) {
   using PyClass = OnlineTransducerModelConfig;
   py::class_<PyClass>(*m, "OnlineTransducerModelConfig")
       .def(py::init<const std::string &, const std::string &,
-                    const std::string &, const std::string &, int32_t, bool>(),
+                    const std::string &, const std::string &, int32_t, bool,
+                    const std::string &>(),
            py::arg("encoder_filename"), py::arg("decoder_filename"),
            py::arg("joiner_filename"), py::arg("tokens"),
-           py::arg("num_threads"), py::arg("debug") = false)
+           py::arg("num_threads"), py::arg("debug") = false,
+           py::arg("provider") = "cpu")
       .def_readwrite("encoder_filename", &PyClass::encoder_filename)
       .def_readwrite("decoder_filename", &PyClass::decoder_filename)
       .def_readwrite("joiner_filename", &PyClass::joiner_filename)
       .def_readwrite("tokens", &PyClass::tokens)
       .def_readwrite("num_threads", &PyClass::num_threads)
       .def_readwrite("debug", &PyClass::debug)
+      .def_readwrite("provider", &PyClass::provider)
       .def("__str__", &PyClass::ToString);
 }
@@ -40,6 +40,7 @@ class OfflineRecognizer(object):
         feature_dim: int = 80,
         decoding_method: str = "greedy_search",
         debug: bool = False,
+        provider: str = "cpu",
     ):
         """
         Please refer to

@@ -70,6 +71,8 @@ class OfflineRecognizer(object):
             Support only greedy_search for now.
           debug:
             True to show debug messages.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
         """
         self = cls.__new__(cls)
         model_config = OfflineModelConfig(

@@ -81,6 +84,7 @@ class OfflineRecognizer(object):
             tokens=tokens,
             num_threads=num_threads,
             debug=debug,
+            provider=provider,
         )

         feat_config = OfflineFeatureExtractorConfig(
@@ -39,6 +39,7 @@ class OnlineRecognizer(object):
         rule3_min_utterance_length: float = 20.0,
         decoding_method: str = "greedy_search",
         max_active_paths: int = 4,
+        provider: str = "cpu",
     ):
         """
         Please refer to

@@ -86,6 +87,8 @@ class OnlineRecognizer(object):
           max_active_paths:
             Use only when decoding_method is modified_beam_search. It specifies
             the maximum number of active paths during beam search.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
         """
         _assert_file_exists(tokens)
         _assert_file_exists(encoder)

@@ -100,6 +103,7 @@ class OnlineRecognizer(object):
             joiner_filename=joiner,
             tokens=tokens,
             num_threads=num_threads,
+            provider=provider,
         )

         feat_config = FeatureExtractorConfig(
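End to end, selecting CoreML from the Python API then looks roughly like this (model paths are placeholders; the keyword names follow the parameters shown above):

import sherpa_onnx

recognizer = sherpa_onnx.OnlineRecognizer(
    tokens="tokens.txt",
    encoder="encoder.onnx",
    decoder="decoder.onnx",
    joiner="joiner.onnx",
    num_threads=2,
    provider="coreml",  # logs and falls back to cpu on non-Apple platforms
)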