// Listing metadata: 102 lines, 3.5 KiB, C++
#include <math.h>
#include <time.h>

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "kaldi-native-fbank/csrc/online-feature.h"

#include "fbank_features.h"
#include "rnnt_beam_search.h"

int main(int argc, char* argv[]) {
|
|
char* encoder_path = argv[1];
|
|
char* decoder_path = argv[2];
|
|
char* joiner_path = argv[3];
|
|
char* joiner_encoder_proj_path = argv[4];
|
|
char* joiner_decoder_proj_path = argv[5];
|
|
char* token_path = argv[6];
|
|
std::string search_method = argv[7];
|
|
char* filename = argv[8];
|
|
|
|
// General parameters
|
|
int numberOfThreads = 16;
|
|
|
|
// Initialize fbanks
|
|
knf::FbankOptions opts;
|
|
opts.frame_opts.dither = 0;
|
|
opts.frame_opts.samp_freq = 16000;
|
|
opts.frame_opts.frame_shift_ms = 10.0f;
|
|
opts.frame_opts.frame_length_ms = 25.0f;
|
|
opts.mel_opts.num_bins = 80;
|
|
opts.frame_opts.window_type = "povey";
|
|
opts.frame_opts.snip_edges = false;
|
|
knf::OnlineFbank fbank(opts);
|
|
|
|
// set session opts
|
|
// https://onnxruntime.ai/docs/performance/tune-performance.html
|
|
session_options.SetIntraOpNumThreads(numberOfThreads);
|
|
session_options.SetInterOpNumThreads(numberOfThreads);
|
|
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
|
|
session_options.SetLogSeverityLevel(4);
|
|
session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
|
|
|
|
api.CreateTensorRTProviderOptions(&tensorrt_options);
|
|
std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api.ReleaseTensorRTProviderOptions)> rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions);
|
|
api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast<OrtSessionOptions*>(session_options), rel_trt_options.get());
|
|
|
|
// Define model
|
|
auto model = get_model(
|
|
encoder_path,
|
|
decoder_path,
|
|
joiner_path,
|
|
joiner_encoder_proj_path,
|
|
joiner_decoder_proj_path,
|
|
token_path
|
|
);
|
|
|
|
std::vector<std::string> filename_list {
|
|
filename
|
|
};
|
|
|
|
for (auto filename : filename_list){
|
|
std::cout << filename << std::endl;
|
|
auto samples = readWav(filename, true);
|
|
int numSamples = samples.NumCols();
|
|
|
|
auto features = ComputeFeatures(fbank, opts, samples);
|
|
|
|
auto tic = std::chrono::high_resolution_clock::now();
|
|
|
|
// # === Encoder Out === #
|
|
int num_frames = features.size() / opts.mel_opts.num_bins;
|
|
auto encoder_out = model.encoder_forward(features,
|
|
std::vector<int64_t> {num_frames},
|
|
std::vector<int64_t> {1, num_frames, 80},
|
|
std::vector<int64_t> {1},
|
|
memory_info);
|
|
|
|
// # === Search === #
|
|
std::vector<std::vector<int32_t>> hyps;
|
|
if (search_method == "greedy")
|
|
hyps = GreedySearch(&model, &encoder_out);
|
|
else{
|
|
std::cout << "wrong search method!" << std::endl;
|
|
exit(0);
|
|
}
|
|
auto results = hyps2result(model.tokens_map, hyps);
|
|
|
|
// # === Print Elapsed Time === #
|
|
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - tic);
|
|
std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds" << std::endl;
|
|
std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000) << std::endl;
|
|
|
|
print_hyps(hyps);
|
|
std::cout << results[0] << std::endl;
|
|
}
|
|
|
|
return 0;
|
|
}
|