#include #include #include #include #include #include #include "fbank_features.h" #include "rnnt_beam_search.h" #include "kaldi-native-fbank/csrc/online-feature.h" int main(int argc, char* argv[]) { char* encoder_path = argv[1]; char* decoder_path = argv[2]; char* joiner_path = argv[3]; char* joiner_encoder_proj_path = argv[4]; char* joiner_decoder_proj_path = argv[5]; char* token_path = argv[6]; std::string search_method = argv[7]; char* filename = argv[8]; // General parameters int numberOfThreads = 16; // Initialize fbanks knf::FbankOptions opts; opts.frame_opts.dither = 0; opts.frame_opts.samp_freq = 16000; opts.frame_opts.frame_shift_ms = 10.0f; opts.frame_opts.frame_length_ms = 25.0f; opts.mel_opts.num_bins = 80; opts.frame_opts.window_type = "povey"; opts.frame_opts.snip_edges = false; knf::OnlineFbank fbank(opts); // set session opts // https://onnxruntime.ai/docs/performance/tune-performance.html session_options.SetIntraOpNumThreads(numberOfThreads); session_options.SetInterOpNumThreads(numberOfThreads); session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); session_options.SetLogSeverityLevel(4); session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL); api.CreateTensorRTProviderOptions(&tensorrt_options); std::unique_ptr rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions); api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast(session_options), rel_trt_options.get()); // Define model auto model = get_model( encoder_path, decoder_path, joiner_path, joiner_encoder_proj_path, joiner_decoder_proj_path, token_path ); std::vector filename_list { filename }; for (auto filename : filename_list){ std::cout << filename << std::endl; auto samples = readWav(filename, true); int numSamples = samples.NumCols(); auto features = ComputeFeatures(fbank, opts, samples); auto tic = std::chrono::high_resolution_clock::now(); // # === Encoder Out === # int num_frames = features.size() / opts.mel_opts.num_bins; auto encoder_out = model.encoder_forward(features, std::vector {num_frames}, std::vector {1, num_frames, 80}, std::vector {1}, memory_info); // # === Search === # std::vector> hyps; if (search_method == "greedy") hyps = GreedySearch(&model, &encoder_out); else{ std::cout << "wrong search method!" << std::endl; exit(0); } auto results = hyps2result(model.tokens_map, hyps); // # === Print Elapsed Time === # auto elapsed = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - tic); std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds" << std::endl; std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000) << std::endl; print_hyps(hyps); std::cout << results[0] << std::endl; } return 0; }