code refactoring and add CI (#11)

This commit is contained in:
Fangjun Kuang
2022-10-12 11:27:05 +08:00
committed by GitHub
parent d9b84d5526
commit 77ccd625b8
9 changed files with 267 additions and 121 deletions

View File

@@ -1,13 +1,8 @@
add_executable(online-fbank-test online-fbank-test.cc)
target_link_libraries(online-fbank-test kaldi-native-fbank-core)
include_directories(
${ONNXRUNTIME_ROOTDIR}/include/onnxruntime/core/session/
${ONNXRUNTIME_ROOTDIR}/include/onnxruntime/core/providers/tensorrt/
)
include_directories(
${KALDINATIVEIO}
)
include_directories(${CMAKE_SOURCE_DIR})
add_executable(sherpa-onnx main.cpp)
target_link_libraries(sherpa-onnx onnxruntime kaldi-native-fbank-core kaldi_native_io_core)
target_link_libraries(sherpa-onnx
onnxruntime
kaldi-native-fbank-core
kaldi_native_io_core
)

View File

@@ -1,101 +1,99 @@
#include <vector>
#include <iostream>
#include <algorithm>
#include <time.h>
#include <math.h>
#include <fstream>
#include <iostream>
#include <math.h>
#include <time.h>
#include <vector>
#include "fbank_features.h"
#include "rnnt_beam_search.h"
#include "sherpa-onnx/csrc/fbank_features.h"
#include "sherpa-onnx/csrc/rnnt_beam_search.h"
#include "kaldi-native-fbank/csrc/online-feature.h"
int main(int argc, char *argv[]) {
char *encoder_path = argv[1];
char *decoder_path = argv[2];
char *joiner_path = argv[3];
char *joiner_encoder_proj_path = argv[4];
char *joiner_decoder_proj_path = argv[5];
char *token_path = argv[6];
std::string search_method = argv[7];
char *filename = argv[8];
int main(int argc, char* argv[]) {
char* encoder_path = argv[1];
char* decoder_path = argv[2];
char* joiner_path = argv[3];
char* joiner_encoder_proj_path = argv[4];
char* joiner_decoder_proj_path = argv[5];
char* token_path = argv[6];
std::string search_method = argv[7];
char* filename = argv[8];
// General parameters
int numberOfThreads = 16;
// General parameters
int numberOfThreads = 16;
// Initialize fbanks
knf::FbankOptions opts;
opts.frame_opts.dither = 0;
opts.frame_opts.samp_freq = 16000;
opts.frame_opts.frame_shift_ms = 10.0f;
opts.frame_opts.frame_length_ms = 25.0f;
opts.mel_opts.num_bins = 80;
opts.frame_opts.window_type = "povey";
opts.frame_opts.snip_edges = false;
knf::OnlineFbank fbank(opts);
// Initialize fbanks
knf::FbankOptions opts;
opts.frame_opts.dither = 0;
opts.frame_opts.samp_freq = 16000;
opts.frame_opts.frame_shift_ms = 10.0f;
opts.frame_opts.frame_length_ms = 25.0f;
opts.mel_opts.num_bins = 80;
opts.frame_opts.window_type = "povey";
opts.frame_opts.snip_edges = false;
knf::OnlineFbank fbank(opts);
// set session opts
// https://onnxruntime.ai/docs/performance/tune-performance.html
session_options.SetIntraOpNumThreads(numberOfThreads);
session_options.SetInterOpNumThreads(numberOfThreads);
session_options.SetGraphOptimizationLevel(
GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
session_options.SetLogSeverityLevel(4);
session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
// set session opts
// https://onnxruntime.ai/docs/performance/tune-performance.html
session_options.SetIntraOpNumThreads(numberOfThreads);
session_options.SetInterOpNumThreads(numberOfThreads);
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
session_options.SetLogSeverityLevel(4);
session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
api.CreateTensorRTProviderOptions(&tensorrt_options);
std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api.ReleaseTensorRTProviderOptions)> rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions);
api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast<OrtSessionOptions*>(session_options), rel_trt_options.get());
api.CreateTensorRTProviderOptions(&tensorrt_options);
std::unique_ptr<OrtTensorRTProviderOptionsV2,
decltype(api.ReleaseTensorRTProviderOptions)>
rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions);
api.SessionOptionsAppendExecutionProvider_TensorRT_V2(
static_cast<OrtSessionOptions *>(session_options), rel_trt_options.get());
// Define model
auto model = get_model(
encoder_path,
decoder_path,
joiner_path,
joiner_encoder_proj_path,
joiner_decoder_proj_path,
token_path
);
std::vector<std::string> filename_list {
filename
};
// Define model
auto model =
get_model(encoder_path, decoder_path, joiner_path,
joiner_encoder_proj_path, joiner_decoder_proj_path, token_path);
for (auto filename : filename_list){
std::cout << filename << std::endl;
auto samples = readWav(filename, true);
int numSamples = samples.NumCols();
std::vector<std::string> filename_list{filename};
auto features = ComputeFeatures(fbank, opts, samples);
for (auto filename : filename_list) {
std::cout << filename << std::endl;
auto samples = readWav(filename, true);
int numSamples = samples.NumCols();
auto tic = std::chrono::high_resolution_clock::now();
auto features = ComputeFeatures(fbank, opts, samples);
// # === Encoder Out === #
int num_frames = features.size() / opts.mel_opts.num_bins;
auto encoder_out = model.encoder_forward(features,
std::vector<int64_t> {num_frames},
std::vector<int64_t> {1, num_frames, 80},
std::vector<int64_t> {1},
memory_info);
auto tic = std::chrono::high_resolution_clock::now();
// # === Search === #
std::vector<std::vector<int32_t>> hyps;
if (search_method == "greedy")
hyps = GreedySearch(&model, &encoder_out);
else{
std::cout << "wrong search method!" << std::endl;
exit(0);
}
auto results = hyps2result(model.tokens_map, hyps);
// # === Encoder Out === #
int num_frames = features.size() / opts.mel_opts.num_bins;
auto encoder_out =
model.encoder_forward(features, std::vector<int64_t>{num_frames},
std::vector<int64_t>{1, num_frames, 80},
std::vector<int64_t>{1}, memory_info);
// # === Print Elapsed Time === #
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - tic);
std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds" << std::endl;
std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000) << std::endl;
print_hyps(hyps);
std::cout << results[0] << std::endl;
// # === Search === #
std::vector<std::vector<int32_t>> hyps;
if (search_method == "greedy")
hyps = GreedySearch(&model, &encoder_out);
else {
std::cout << "wrong search method!" << std::endl;
exit(0);
}
auto results = hyps2result(model.tokens_map, hyps);
return 0;
// # === Print Elapsed Time === #
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - tic);
std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds"
<< std::endl;
std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000)
<< std::endl;
print_hyps(hyps);
std::cout << results[0] << std::endl;
}
return 0;
}

View File

@@ -61,7 +61,6 @@ std::vector<std::vector<int32_t>> GreedySearch(
auto projected_encoder_out = model->joiner_encoder_proj_forward(encoder_out_vector,
std::vector<int64_t> {encoder_out_dim1, encoder_out_dim2},
memory_info);
Ort::Value &projected_encoder_out_tensor = projected_encoder_out[0];
int projected_encoder_out_dim1 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[0];
int projected_encoder_out_dim2 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
@@ -78,12 +77,12 @@ std::vector<std::vector<int32_t>> GreedySearch(
auto logits = model->joiner_forward(cur_encoder_out,
projected_decoder_out_vector,
std::vector<int64_t> {1, 1, 1, projected_encoder_out_dim2},
std::vector<int64_t> {1, 1, 1, projected_decoder_out_dim},
std::vector<int64_t> {1, projected_encoder_out_dim2},
std::vector<int64_t> {1, projected_decoder_out_dim},
memory_info);
Ort::Value &logits_tensor = logits[0];
int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[3];
int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
auto logits_vector = ortVal2Vector(logits_tensor, logits_dim);
int max_indices = static_cast<int>(std::distance(logits_vector.begin(), std::max_element(logits_vector.begin(), logits_vector.end())));

View File

@@ -1,5 +1,5 @@
#include <iostream>
#include <onnxruntime_cxx_api.h>
#include "onnxruntime_cxx_api.h"
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
const auto& api = Ort::GetApi();