Support CoreML for macOS (#151)
This commit is contained in:
@@ -9,7 +9,6 @@
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@@ -24,6 +23,7 @@
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
#include "sherpa-onnx/csrc/unbind.h"
|
||||
|
||||
@@ -33,11 +33,8 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel(
|
||||
const OnlineTransducerModelConfig &config)
|
||||
: env_(ORT_LOGGING_LEVEL_WARNING),
|
||||
config_(config),
|
||||
sess_opts_{},
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.encoder_filename);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
@@ -59,11 +56,8 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel(
|
||||
AAssetManager *mgr, const OnlineTransducerModelConfig &config)
|
||||
: env_(ORT_LOGGING_LEVEL_WARNING),
|
||||
config_(config),
|
||||
sess_opts_{},
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.encoder_filename);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
@@ -185,7 +179,7 @@ std::vector<std::vector<Ort::Value>>
|
||||
OnlineConformerTransducerModel::UnStackStates(
|
||||
const std::vector<Ort::Value> &states) const {
|
||||
const int32_t batch_size =
|
||||
states[0].GetTensorTypeAndShapeInfo().GetShape()[2];
|
||||
states[0].GetTensorTypeAndShapeInfo().GetShape()[2];
|
||||
assert(states.size() == 2);
|
||||
|
||||
std::vector<std::vector<Ort::Value>> ans(batch_size);
|
||||
@@ -209,8 +203,8 @@ std::vector<Ort::Value> OnlineConformerTransducerModel::GetEncoderInitStates() {
|
||||
// https://github.com/k2-fsa/icefall/blob/86b0db6eb9c84d9bc90a71d92774fe2a7f73e6ab/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py#L203
|
||||
// for details
|
||||
constexpr int32_t kBatchSize = 1;
|
||||
std::array<int64_t, 4> h_shape{
|
||||
num_encoder_layers_, left_context_, kBatchSize, encoder_dim_};
|
||||
std::array<int64_t, 4> h_shape{num_encoder_layers_, left_context_, kBatchSize,
|
||||
encoder_dim_};
|
||||
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
|
||||
h_shape.size());
|
||||
|
||||
@@ -238,9 +232,7 @@ OnlineConformerTransducerModel::RunEncoder(Ort::Value features,
|
||||
std::vector<Ort::Value> states,
|
||||
Ort::Value processed_frames) {
|
||||
std::array<Ort::Value, 4> encoder_inputs = {
|
||||
std::move(features),
|
||||
std::move(states[0]),
|
||||
std::move(states[1]),
|
||||
std::move(features), std::move(states[0]), std::move(states[1]),
|
||||
std::move(processed_frames)};
|
||||
|
||||
auto encoder_out = encoder_sess_->Run(
|
||||
|
||||
Reference in New Issue
Block a user