Support CoreML for macOS (#151)

This commit is contained in:
Fangjun Kuang
2023-05-12 15:57:44 +08:00
committed by GitHub
parent de1880948b
commit cea718e3d8
22 changed files with 216 additions and 87 deletions

View File

@@ -9,7 +9,6 @@
#include <algorithm>
#include <memory>
#include <sstream>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
@@ -24,6 +23,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/unbind.h"
@@ -33,11 +33,8 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel(
const OnlineTransducerModelConfig &config)
: env_(ORT_LOGGING_LEVEL_WARNING),
config_(config),
sess_opts_{},
sess_opts_(GetSessionOptions(config)),
allocator_{} {
sess_opts_.SetIntraOpNumThreads(config.num_threads);
sess_opts_.SetInterOpNumThreads(config.num_threads);
{
auto buf = ReadFile(config.encoder_filename);
InitEncoder(buf.data(), buf.size());
@@ -59,11 +56,8 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel(
AAssetManager *mgr, const OnlineTransducerModelConfig &config)
: env_(ORT_LOGGING_LEVEL_WARNING),
config_(config),
sess_opts_{},
sess_opts_(GetSessionOptions(config)),
allocator_{} {
sess_opts_.SetIntraOpNumThreads(config.num_threads);
sess_opts_.SetInterOpNumThreads(config.num_threads);
{
auto buf = ReadFile(mgr, config.encoder_filename);
InitEncoder(buf.data(), buf.size());
@@ -185,7 +179,7 @@ std::vector<std::vector<Ort::Value>>
OnlineConformerTransducerModel::UnStackStates(
const std::vector<Ort::Value> &states) const {
const int32_t batch_size =
states[0].GetTensorTypeAndShapeInfo().GetShape()[2];
states[0].GetTensorTypeAndShapeInfo().GetShape()[2];
assert(states.size() == 2);
std::vector<std::vector<Ort::Value>> ans(batch_size);
@@ -209,8 +203,8 @@ std::vector<Ort::Value> OnlineConformerTransducerModel::GetEncoderInitStates() {
// https://github.com/k2-fsa/icefall/blob/86b0db6eb9c84d9bc90a71d92774fe2a7f73e6ab/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py#L203
// for details
constexpr int32_t kBatchSize = 1;
std::array<int64_t, 4> h_shape{
num_encoder_layers_, left_context_, kBatchSize, encoder_dim_};
std::array<int64_t, 4> h_shape{num_encoder_layers_, left_context_, kBatchSize,
encoder_dim_};
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
h_shape.size());
@@ -238,9 +232,7 @@ OnlineConformerTransducerModel::RunEncoder(Ort::Value features,
std::vector<Ort::Value> states,
Ort::Value processed_frames) {
std::array<Ort::Value, 4> encoder_inputs = {
std::move(features),
std::move(states[0]),
std::move(states[1]),
std::move(features), std::move(states[0]), std::move(states[1]),
std::move(processed_frames)};
auto encoder_out = encoder_sess_->Run(