diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index 44e4ce35..8d527e90 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -5,6 +5,8 @@ #include #include // NOLINT +#include +#include #include #include @@ -14,6 +16,12 @@ #include "sherpa-onnx/csrc/parse-options.h" #include "sherpa-onnx/csrc/wave-reader.h" +typedef struct { + std::unique_ptr online_stream; + float duration; + float elapsed_seconds; +} Stream; + int main(int32_t argc, char *argv[]) { const char *kUsageMessage = R"usage( Usage: @@ -61,29 +69,26 @@ for a list of pre-trained models to download. sherpa_onnx::OnlineRecognizer recognizer(config); - float duration = 0; + std::vector ss; + + const auto begin = std::chrono::steady_clock::now(); + std::vector durations; + for (int32_t i = 1; i <= po.NumArgs(); ++i) { const std::string wav_filename = po.GetArg(i); int32_t sampling_rate = -1; bool is_ok = false; const std::vector samples = - sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); if (!is_ok) { fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); return -1; } - fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate); const float duration = samples.size() / static_cast(sampling_rate); - fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); - fprintf(stderr, "wav duration (s): %.3f\n", duration); - - fprintf(stderr, "Started\n"); - const auto begin = std::chrono::steady_clock::now(); - auto s = recognizer.CreateStream(); s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); @@ -94,33 +99,46 @@ for a list of pre-trained models to download. // Call InputFinished() to indicate that no audio samples are available s->InputFinished(); - - while (recognizer.IsReady(s.get())) { - recognizer.DecodeStream(s.get()); - } - - const std::string text = recognizer.GetResult(s.get()).AsJsonString(); - - const auto end = std::chrono::steady_clock::now(); - const float elapsed_seconds = - std::chrono::duration_cast(end - begin) - .count() / 1000.; - - fprintf(stderr, "Done!\n"); - fprintf(stderr, - "Recognition result for %s:\n%s\n", - wav_filename.c_str(), text.c_str()); - fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); - fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str()); - if (config.decoding_method == "modified_beam_search") { - fprintf(stderr, "max active paths: %d\n", config.max_active_paths); - } - - fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); - const float rtf = elapsed_seconds / duration; - fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", - elapsed_seconds, duration, rtf); + ss.push_back({ std::move(s), duration, 0 }); } + std::vector ready_streams; + for (;;) { + ready_streams.clear(); + for (auto &s : ss) { + const auto p_ss = s.online_stream.get(); + if (recognizer.IsReady(p_ss)) { + ready_streams.push_back(p_ss); + } else if (s.elapsed_seconds == 0) { + const auto end = std::chrono::steady_clock::now(); + const float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / 1000.; + s.elapsed_seconds = elapsed_seconds; + } + } + + if (ready_streams.empty()) { + break; + } + + recognizer.DecodeStreams(ready_streams.data(), ready_streams.size()); + } + + std::ostringstream os; + for (int32_t i = 1; i <= po.NumArgs(); ++i) { + const auto &s = ss[i - 1]; + const float rtf = s.elapsed_seconds / s.duration; + + os << po.GetArg(i) << "\n"; + os << std::setprecision(2) << "Elapsed seconds: " << s.elapsed_seconds + << ", Real time factor (RTF): " << rtf << "\n"; + const auto r = recognizer.GetResult(s.online_stream.get()); + os << r.text << "\n"; + os << r.AsJsonString() << "\n\n"; + } + + std::cerr << os.str(); + return 0; }