add batch processing to sherpa-onnx (#166)
This commit is contained in:
@@ -5,6 +5,8 @@
|
|||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <chrono> // NOLINT
|
#include <chrono> // NOLINT
|
||||||
|
#include <iomanip>
|
||||||
|
#include <iostream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@@ -14,6 +16,12 @@
|
|||||||
#include "sherpa-onnx/csrc/parse-options.h"
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
std::unique_ptr<sherpa_onnx::OnlineStream> online_stream;
|
||||||
|
float duration;
|
||||||
|
float elapsed_seconds;
|
||||||
|
} Stream;
|
||||||
|
|
||||||
int main(int32_t argc, char *argv[]) {
|
int main(int32_t argc, char *argv[]) {
|
||||||
const char *kUsageMessage = R"usage(
|
const char *kUsageMessage = R"usage(
|
||||||
Usage:
|
Usage:
|
||||||
@@ -61,7 +69,11 @@ for a list of pre-trained models to download.
|
|||||||
|
|
||||||
sherpa_onnx::OnlineRecognizer recognizer(config);
|
sherpa_onnx::OnlineRecognizer recognizer(config);
|
||||||
|
|
||||||
float duration = 0;
|
std::vector<Stream> ss;
|
||||||
|
|
||||||
|
const auto begin = std::chrono::steady_clock::now();
|
||||||
|
std::vector<float> durations;
|
||||||
|
|
||||||
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
|
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
|
||||||
const std::string wav_filename = po.GetArg(i);
|
const std::string wav_filename = po.GetArg(i);
|
||||||
int32_t sampling_rate = -1;
|
int32_t sampling_rate = -1;
|
||||||
@@ -74,16 +86,9 @@ for a list of pre-trained models to download.
|
|||||||
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate);
|
|
||||||
|
|
||||||
const float duration = samples.size() / static_cast<float>(sampling_rate);
|
const float duration = samples.size() / static_cast<float>(sampling_rate);
|
||||||
|
|
||||||
fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
|
|
||||||
fprintf(stderr, "wav duration (s): %.3f\n", duration);
|
|
||||||
|
|
||||||
fprintf(stderr, "Started\n");
|
|
||||||
const auto begin = std::chrono::steady_clock::now();
|
|
||||||
|
|
||||||
auto s = recognizer.CreateStream();
|
auto s = recognizer.CreateStream();
|
||||||
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
|
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
|
||||||
|
|
||||||
@@ -94,33 +99,46 @@ for a list of pre-trained models to download.
|
|||||||
|
|
||||||
// Call InputFinished() to indicate that no audio samples are available
|
// Call InputFinished() to indicate that no audio samples are available
|
||||||
s->InputFinished();
|
s->InputFinished();
|
||||||
|
ss.push_back({ std::move(s), duration, 0 });
|
||||||
while (recognizer.IsReady(s.get())) {
|
|
||||||
recognizer.DecodeStream(s.get());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string text = recognizer.GetResult(s.get()).AsJsonString();
|
std::vector<sherpa_onnx::OnlineStream *> ready_streams;
|
||||||
|
for (;;) {
|
||||||
|
ready_streams.clear();
|
||||||
|
for (auto &s : ss) {
|
||||||
|
const auto p_ss = s.online_stream.get();
|
||||||
|
if (recognizer.IsReady(p_ss)) {
|
||||||
|
ready_streams.push_back(p_ss);
|
||||||
|
} else if (s.elapsed_seconds == 0) {
|
||||||
const auto end = std::chrono::steady_clock::now();
|
const auto end = std::chrono::steady_clock::now();
|
||||||
const float elapsed_seconds =
|
const float elapsed_seconds =
|
||||||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
||||||
.count() / 1000.;
|
.count() / 1000.;
|
||||||
|
s.elapsed_seconds = elapsed_seconds;
|
||||||
fprintf(stderr, "Done!\n");
|
}
|
||||||
fprintf(stderr,
|
|
||||||
"Recognition result for %s:\n%s\n",
|
|
||||||
wav_filename.c_str(), text.c_str());
|
|
||||||
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
|
|
||||||
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
|
|
||||||
if (config.decoding_method == "modified_beam_search") {
|
|
||||||
fprintf(stderr, "max active paths: %d\n", config.max_active_paths);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
if (ready_streams.empty()) {
|
||||||
const float rtf = elapsed_seconds / duration;
|
break;
|
||||||
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
|
|
||||||
elapsed_seconds, duration, rtf);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
recognizer.DecodeStreams(ready_streams.data(), ready_streams.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostringstream os;
|
||||||
|
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
|
||||||
|
const auto &s = ss[i - 1];
|
||||||
|
const float rtf = s.elapsed_seconds / s.duration;
|
||||||
|
|
||||||
|
os << po.GetArg(i) << "\n";
|
||||||
|
os << std::setprecision(2) << "Elapsed seconds: " << s.elapsed_seconds
|
||||||
|
<< ", Real time factor (RTF): " << rtf << "\n";
|
||||||
|
const auto r = recognizer.GetResult(s.online_stream.get());
|
||||||
|
os << r.text << "\n";
|
||||||
|
os << r.AsJsonString() << "\n\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cerr << os.str();
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user