Use static libraries for MFC examples (#210)

This commit is contained in:
Fangjun Kuang
2023-07-13 14:52:43 +08:00
committed by GitHub
parent 10f132cfd6
commit bebc1f1398
18 changed files with 380 additions and 156 deletions

View File

@@ -9,11 +9,11 @@
#include <algorithm>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include <numeric>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
@@ -78,7 +78,7 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel(
#endif
void OnlineZipformer2TransducerModel::InitEncoder(void *model_data,
size_t model_data_length) {
size_t model_data_length) {
encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
model_data_length, sess_opts_);
@@ -130,7 +130,7 @@ void OnlineZipformer2TransducerModel::InitEncoder(void *model_data,
}
void OnlineZipformer2TransducerModel::InitDecoder(void *model_data,
size_t model_data_length) {
size_t model_data_length) {
decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
model_data_length, sess_opts_);
@@ -155,7 +155,7 @@ void OnlineZipformer2TransducerModel::InitDecoder(void *model_data,
}
void OnlineZipformer2TransducerModel::InitJoiner(void *model_data,
size_t model_data_length) {
size_t model_data_length) {
joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
model_data_length, sess_opts_);
@@ -252,7 +252,8 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
std::vector<std::vector<Ort::Value>>
OnlineZipformer2TransducerModel::UnStackStates(
const std::vector<Ort::Value> &states) const {
int32_t m = std::accumulate(num_encoder_layers_.begin(), num_encoder_layers_.end(), 0);
int32_t m = std::accumulate(num_encoder_layers_.begin(),
num_encoder_layers_.end(), 0);
assert(states.size() == m * 6 + 2);
int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
@@ -332,10 +333,12 @@ OnlineZipformer2TransducerModel::UnStackStates(
return ans;
}
std::vector<Ort::Value> OnlineZipformer2TransducerModel::GetEncoderInitStates() {
std::vector<Ort::Value>
OnlineZipformer2TransducerModel::GetEncoderInitStates() {
std::vector<Ort::Value> ans;
int32_t n = static_cast<int32_t>(encoder_dims_.size());
int32_t m = std::accumulate(num_encoder_layers_.begin(), num_encoder_layers_.end(), 0);
int32_t m = std::accumulate(num_encoder_layers_.begin(),
num_encoder_layers_.end(), 0);
ans.reserve(m * 6 + 2);
for (int32_t i = 0; i != n; ++i) {
@@ -354,7 +357,8 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::GetEncoderInitStates()
}
{
std::array<int64_t, 4> s{1, 1, left_context_len_[i], nonlin_attn_head_dim};
std::array<int64_t, 4> s{1, 1, left_context_len_[i],
nonlin_attn_head_dim};
auto v =
Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
@@ -378,7 +382,8 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::GetEncoderInitStates()
}
{
std::array<int64_t, 3> s{1, encoder_dims_[i], cnn_module_kernels_[i] / 2};
std::array<int64_t, 3> s{1, encoder_dims_[i],
cnn_module_kernels_[i] / 2};
auto v =
Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
@@ -386,7 +391,8 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::GetEncoderInitStates()
}
{
std::array<int64_t, 3> s{1, encoder_dims_[i], cnn_module_kernels_[i] / 2};
std::array<int64_t, 3> s{1, encoder_dims_[i],
cnn_module_kernels_[i] / 2};
auto v =
Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
@@ -413,8 +419,8 @@ std::vector<Ort::Value> OnlineZipformer2TransducerModel::GetEncoderInitStates()
std::pair<Ort::Value, std::vector<Ort::Value>>
OnlineZipformer2TransducerModel::RunEncoder(Ort::Value features,
std::vector<Ort::Value> states,
Ort::Value /* processed_frames */) {
std::vector<Ort::Value> states,
Ort::Value /* processed_frames */) {
std::vector<Ort::Value> encoder_inputs;
encoder_inputs.reserve(1 + states.size());
@@ -446,7 +452,7 @@ Ort::Value OnlineZipformer2TransducerModel::RunDecoder(
}
Ort::Value OnlineZipformer2TransducerModel::RunJoiner(Ort::Value encoder_out,
Ort::Value decoder_out) {
Ort::Value decoder_out) {
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
std::move(decoder_out)};
auto logit =