Code refactoring (#74)
* Don't reset model state and feature extractor on endpointing * support passing decoding_method from commandline * Add modified_beam_search to Python API * fix C API example * Fix style issues
This commit is contained in:
@@ -9,10 +9,11 @@ set(sources
|
||||
online-lstm-transducer-model.cc
|
||||
online-recognizer.cc
|
||||
online-stream.cc
|
||||
online-transducer-decoder.cc
|
||||
online-transducer-greedy-search-decoder.cc
|
||||
online-transducer-model-config.cc
|
||||
online-transducer-modified-beam-search-decoder.cc
|
||||
online-transducer-model.cc
|
||||
online-transducer-modified-beam-search-decoder.cc
|
||||
online-zipformer-transducer-model.cc
|
||||
onnx-utils.cc
|
||||
parse-options.cc
|
||||
|
||||
@@ -12,9 +12,16 @@ namespace sherpa_onnx {
|
||||
|
||||
class Display {
|
||||
public:
|
||||
explicit Display(int32_t max_word_per_line = 60)
|
||||
: max_word_per_line_(max_word_per_line) {}
|
||||
|
||||
void Print(int32_t segment_id, const std::string &s) {
|
||||
#ifdef _MSC_VER
|
||||
fprintf(stderr, "%d:%s\n", segment_id, s.c_str());
|
||||
if (segment_id != -1) {
|
||||
fprintf(stderr, "%d:%s\n", segment_id, s.c_str());
|
||||
} else {
|
||||
fprintf(stderr, "%s\n", s.c_str());
|
||||
}
|
||||
return;
|
||||
#endif
|
||||
if (last_segment_ == segment_id) {
|
||||
@@ -27,7 +34,9 @@ class Display {
|
||||
num_previous_lines_ = 0;
|
||||
}
|
||||
|
||||
fprintf(stderr, "\r%d:", segment_id);
|
||||
if (segment_id != -1) {
|
||||
fprintf(stderr, "\r%d:", segment_id);
|
||||
}
|
||||
|
||||
int32_t i = 0;
|
||||
for (size_t n = 0; n < s.size();) {
|
||||
@@ -69,7 +78,7 @@ class Display {
|
||||
void GoUpOneLine() const { fprintf(stderr, "\033[1A\r"); }
|
||||
|
||||
private:
|
||||
int32_t max_word_per_line_ = 60;
|
||||
int32_t max_word_per_line_;
|
||||
int32_t num_previous_lines_ = 0;
|
||||
int32_t last_segment_ = -1;
|
||||
};
|
||||
|
||||
@@ -28,7 +28,8 @@ std::string FeatureExtractorConfig::ToString() const {
|
||||
|
||||
os << "FeatureExtractorConfig(";
|
||||
os << "sampling_rate=" << sampling_rate << ", ";
|
||||
os << "feature_dim=" << feature_dim << ")";
|
||||
os << "feature_dim=" << feature_dim << ", ";
|
||||
os << "max_feature_vectors=" << max_feature_vectors << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
@@ -40,9 +41,7 @@ class FeatureExtractor::Impl {
|
||||
opts_.frame_opts.snip_edges = false;
|
||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||
|
||||
// cache 100 seconds of feature frames, which is more than enough
|
||||
// for real needs
|
||||
opts_.frame_opts.max_feature_vectors = 100 * 100;
|
||||
opts_.frame_opts.max_feature_vectors = config.max_feature_vectors;
|
||||
|
||||
opts_.mel_opts.num_bins = config.feature_dim;
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ namespace sherpa_onnx {
|
||||
struct FeatureExtractorConfig {
|
||||
float sampling_rate = 16000;
|
||||
int32_t feature_dim = 80;
|
||||
int32_t max_feature_vectors = -1;
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ namespace sherpa_onnx {
|
||||
|
||||
struct Hypothesis {
|
||||
// The predicted tokens so far. Newly predicted tokens are appended.
|
||||
std::vector<int32_t> ys;
|
||||
std::vector<int64_t> ys;
|
||||
|
||||
// timestamps[i] contains the frame number after subsampling
|
||||
// on which ys[i] is decoded.
|
||||
@@ -30,7 +30,7 @@ struct Hypothesis {
|
||||
int32_t num_trailing_blanks = 0;
|
||||
|
||||
Hypothesis() = default;
|
||||
Hypothesis(const std::vector<int32_t> &ys, double log_prob)
|
||||
Hypothesis(const std::vector<int64_t> &ys, double log_prob)
|
||||
: ys(ys), log_prob(log_prob) {}
|
||||
|
||||
// If two Hypotheses have the same `Key`, then they contain
|
||||
|
||||
@@ -43,7 +43,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
"True to enable endpoint detection. False to disable it.");
|
||||
po->Register("max-active-paths", &max_active_paths,
|
||||
"beam size used in modified beam search.");
|
||||
po->Register("decoding-mothod", &decoding_method,
|
||||
po->Register("decoding-method", &decoding_method,
|
||||
"decoding method,"
|
||||
"now support greedy_search and modified_beam_search.");
|
||||
}
|
||||
@@ -59,8 +59,8 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "feat_config=" << feat_config.ToString() << ", ";
|
||||
os << "model_config=" << model_config.ToString() << ", ";
|
||||
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
||||
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ",";
|
||||
os << "max_active_paths=" << max_active_paths << ",";
|
||||
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "decoding_method=\"" << decoding_method << "\")";
|
||||
|
||||
return os.str();
|
||||
@@ -187,16 +187,14 @@ class OnlineRecognizer::Impl {
|
||||
}
|
||||
|
||||
void Reset(OnlineStream *s) const {
|
||||
// reset result, neural network model state, and
|
||||
// the feature extractor state
|
||||
|
||||
// reset result
|
||||
// we keep the decoder_out
|
||||
decoder_->UpdateDecoderOut(&s->GetResult());
|
||||
Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
|
||||
s->SetResult(decoder_->GetEmptyResult());
|
||||
s->GetResult().decoder_out = std::move(decoder_out);
|
||||
|
||||
// reset neural network model state
|
||||
s->SetStates(model_->GetEncoderInitStates());
|
||||
|
||||
// reset feature extractor
|
||||
// Note: We only update counters. The underlying audio samples
|
||||
// are not discarded.
|
||||
s->Reset();
|
||||
}
|
||||
|
||||
|
||||
@@ -33,21 +33,26 @@ struct OnlineRecognizerConfig {
|
||||
OnlineTransducerModelConfig model_config;
|
||||
EndpointConfig endpoint_config;
|
||||
bool enable_endpoint = true;
|
||||
int32_t max_active_paths = 4;
|
||||
|
||||
std::string decoding_method = "modified_beam_search";
|
||||
std::string decoding_method = "greedy_search";
|
||||
// now supports modified_beam_search and greedy_search
|
||||
|
||||
int32_t max_active_paths = 4; // used only for modified_beam_search
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
|
||||
const OnlineTransducerModelConfig &model_config,
|
||||
const EndpointConfig &endpoint_config,
|
||||
bool enable_endpoint)
|
||||
bool enable_endpoint,
|
||||
const std::string &decoding_method,
|
||||
int32_t max_active_paths)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
endpoint_config(endpoint_config),
|
||||
enable_endpoint(enable_endpoint) {}
|
||||
enable_endpoint(enable_endpoint),
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -22,18 +22,21 @@ class OnlineStream::Impl {
|
||||
|
||||
void InputFinished() { feat_extractor_.InputFinished(); }
|
||||
|
||||
int32_t NumFramesReady() const { return feat_extractor_.NumFramesReady(); }
|
||||
int32_t NumFramesReady() const {
|
||||
return feat_extractor_.NumFramesReady() - start_frame_index_;
|
||||
}
|
||||
|
||||
bool IsLastFrame(int32_t frame) const {
|
||||
return feat_extractor_.IsLastFrame(frame);
|
||||
}
|
||||
|
||||
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
|
||||
return feat_extractor_.GetFrames(frame_index, n);
|
||||
return feat_extractor_.GetFrames(frame_index + start_frame_index_, n);
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
feat_extractor_.Reset();
|
||||
// we don't reset the feature extractor
|
||||
start_frame_index_ += num_processed_frames_;
|
||||
num_processed_frames_ = 0;
|
||||
}
|
||||
|
||||
@@ -41,7 +44,7 @@ class OnlineStream::Impl {
|
||||
|
||||
void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
|
||||
|
||||
const OnlineTransducerDecoderResult &GetResult() const { return result_; }
|
||||
OnlineTransducerDecoderResult &GetResult() { return result_; }
|
||||
|
||||
int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
|
||||
|
||||
@@ -54,6 +57,7 @@ class OnlineStream::Impl {
|
||||
private:
|
||||
FeatureExtractor feat_extractor_;
|
||||
int32_t num_processed_frames_ = 0; // before subsampling
|
||||
int32_t start_frame_index_ = 0; // never reset
|
||||
OnlineTransducerDecoderResult result_;
|
||||
std::vector<Ort::Value> states_;
|
||||
};
|
||||
@@ -93,7 +97,7 @@ void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) {
|
||||
impl_->SetResult(r);
|
||||
}
|
||||
|
||||
const OnlineTransducerDecoderResult &OnlineStream::GetResult() const {
|
||||
OnlineTransducerDecoderResult &OnlineStream::GetResult() {
|
||||
return impl_->GetResult();
|
||||
}
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ class OnlineStream {
|
||||
int32_t &GetNumProcessedFrames();
|
||||
|
||||
void SetResult(const OnlineTransducerDecoderResult &r);
|
||||
const OnlineTransducerDecoderResult &GetResult() const;
|
||||
OnlineTransducerDecoderResult &GetResult();
|
||||
|
||||
void SetStates(std::vector<Ort::Value> states);
|
||||
std::vector<Ort::Value> &GetStates();
|
||||
|
||||
60
sherpa-onnx/csrc/online-transducer-decoder.cc
Normal file
60
sherpa-onnx/csrc/online-transducer-decoder.cc
Normal file
@@ -0,0 +1,60 @@
|
||||
// sherpa-onnx/csrc/online-transducer-decoder.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
|
||||
const OnlineTransducerDecoderResult &other)
|
||||
: OnlineTransducerDecoderResult() {
|
||||
*this = other;
|
||||
}
|
||||
|
||||
OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
|
||||
const OnlineTransducerDecoderResult &other) {
|
||||
if (this == &other) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
tokens = other.tokens;
|
||||
num_trailing_blanks = other.num_trailing_blanks;
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
if (other.decoder_out) {
|
||||
decoder_out = Clone(allocator, &other.decoder_out);
|
||||
}
|
||||
|
||||
hyps = other.hyps;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
|
||||
OnlineTransducerDecoderResult &&other)
|
||||
: OnlineTransducerDecoderResult() {
|
||||
*this = std::move(other);
|
||||
}
|
||||
|
||||
OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
|
||||
OnlineTransducerDecoderResult &&other) {
|
||||
if (this == &other) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
tokens = std::move(other.tokens);
|
||||
num_trailing_blanks = other.num_trailing_blanks;
|
||||
decoder_out = std::move(other.decoder_out);
|
||||
hyps = std::move(other.hyps);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -19,8 +19,24 @@ struct OnlineTransducerDecoderResult {
|
||||
/// number of trailing blank frames decoded so far
|
||||
int32_t num_trailing_blanks = 0;
|
||||
|
||||
// Cache decoder_out for endpointing
|
||||
Ort::Value decoder_out;
|
||||
|
||||
// used only in modified_beam_search
|
||||
Hypotheses hyps;
|
||||
|
||||
OnlineTransducerDecoderResult()
|
||||
: tokens{}, num_trailing_blanks(0), decoder_out{nullptr}, hyps{} {}
|
||||
|
||||
OnlineTransducerDecoderResult(const OnlineTransducerDecoderResult &other);
|
||||
|
||||
OnlineTransducerDecoderResult &operator=(
|
||||
const OnlineTransducerDecoderResult &other);
|
||||
|
||||
OnlineTransducerDecoderResult(OnlineTransducerDecoderResult &&other);
|
||||
|
||||
OnlineTransducerDecoderResult &operator=(
|
||||
OnlineTransducerDecoderResult &&other);
|
||||
};
|
||||
|
||||
class OnlineTransducerDecoder {
|
||||
@@ -53,6 +69,9 @@ class OnlineTransducerDecoder {
|
||||
*/
|
||||
virtual void Decode(Ort::Value encoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) = 0;
|
||||
|
||||
// used for endpointing. We need to keep decoder_out after reset
|
||||
virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -13,6 +13,43 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static void UseCachedDecoderOut(
|
||||
const std::vector<OnlineTransducerDecoderResult> &results,
|
||||
Ort::Value *decoder_out) {
|
||||
std::vector<int64_t> shape =
|
||||
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||
float *dst = decoder_out->GetTensorMutableData<float>();
|
||||
for (const auto &r : results) {
|
||||
if (r.decoder_out) {
|
||||
const float *src = r.decoder_out.GetTensorData<float>();
|
||||
std::copy(src, src + shape[1], dst);
|
||||
}
|
||||
dst += shape[1];
|
||||
}
|
||||
}
|
||||
|
||||
static void UpdateCachedDecoderOut(
|
||||
OrtAllocator *allocator, const Ort::Value *decoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *results) {
|
||||
std::vector<int64_t> shape =
|
||||
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
std::array<int64_t, 2> v_shape{1, shape[1]};
|
||||
|
||||
const float *src = decoder_out->GetTensorData<float>();
|
||||
for (auto &r : *results) {
|
||||
if (!r.decoder_out) {
|
||||
r.decoder_out = Ort::Value::CreateTensor<float>(allocator, v_shape.data(),
|
||||
v_shape.size());
|
||||
}
|
||||
|
||||
float *dst = r.decoder_out.GetTensorMutableData<float>();
|
||||
std::copy(src, src + shape[1], dst);
|
||||
src += shape[1];
|
||||
}
|
||||
}
|
||||
|
||||
OnlineTransducerDecoderResult
|
||||
OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
|
||||
int32_t context_size = model_->ContextSize();
|
||||
@@ -53,6 +90,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
|
||||
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
UseCachedDecoderOut(*result, &decoder_out);
|
||||
|
||||
for (int32_t t = 0; t != num_frames; ++t) {
|
||||
Ort::Value cur_encoder_out =
|
||||
@@ -77,10 +115,12 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
}
|
||||
}
|
||||
if (emitted) {
|
||||
decoder_input = model_->BuildDecoderInput(*result);
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
|
||||
decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
}
|
||||
}
|
||||
|
||||
UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -13,6 +13,29 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static void UseCachedDecoderOut(
|
||||
const std::vector<int32_t> &hyps_num_split,
|
||||
const std::vector<OnlineTransducerDecoderResult> &results,
|
||||
int32_t context_size, Ort::Value *decoder_out) {
|
||||
std::vector<int64_t> shape =
|
||||
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
float *dst = decoder_out->GetTensorMutableData<float>();
|
||||
|
||||
int32_t batch_size = static_cast<int32_t>(results.size());
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
int32_t num_hyps = hyps_num_split[i + 1] - hyps_num_split[i];
|
||||
if (num_hyps > 1 || !results[i].decoder_out) {
|
||||
dst += num_hyps * shape[1];
|
||||
continue;
|
||||
}
|
||||
|
||||
const float *src = results[i].decoder_out.GetTensorData<float>();
|
||||
std::copy(src, src + shape[1], dst);
|
||||
dst += shape[1];
|
||||
}
|
||||
}
|
||||
|
||||
static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
|
||||
const std::vector<int32_t> &hyps_num_split) {
|
||||
std::vector<int64_t> cur_encoder_out_shape =
|
||||
@@ -50,7 +73,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
|
||||
int32_t context_size = model_->ContextSize();
|
||||
int32_t blank_id = 0; // always 0
|
||||
OnlineTransducerDecoderResult r;
|
||||
std::vector<int32_t> blanks(context_size, blank_id);
|
||||
std::vector<int64_t> blanks(context_size, blank_id);
|
||||
Hypotheses blank_hyp({{blanks, 0}});
|
||||
r.hyps = std::move(blank_hyp);
|
||||
return r;
|
||||
@@ -110,6 +133,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(prev);
|
||||
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
if (t == 0) {
|
||||
UseCachedDecoderOut(hyps_num_split, *result, model_->ContextSize(),
|
||||
&decoder_out);
|
||||
}
|
||||
|
||||
Ort::Value cur_encoder_out =
|
||||
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
|
||||
@@ -147,8 +174,23 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
}
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
(*result)[b].hyps = std::move(cur[b]);
|
||||
auto &hyps = cur[b];
|
||||
auto best_hyp = hyps.GetMostProbable(true);
|
||||
|
||||
(*result)[b].hyps = std::move(hyps);
|
||||
(*result)[b].tokens = std::move(best_hyp.ys);
|
||||
(*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks;
|
||||
}
|
||||
}
|
||||
|
||||
void OnlineTransducerModifiedBeamSearchDecoder::UpdateDecoderOut(
|
||||
OnlineTransducerDecoderResult *result) {
|
||||
if (result->tokens.size() == model_->ContextSize()) {
|
||||
result->decoder_out = Ort::Value{nullptr};
|
||||
return;
|
||||
}
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput({*result});
|
||||
result->decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -27,6 +27,8 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
||||
void Decode(Ort::Value encoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) override;
|
||||
|
||||
void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;
|
||||
|
||||
private:
|
||||
OnlineTransducerModel *model_; // Not owned
|
||||
int32_t max_active_paths_;
|
||||
|
||||
@@ -21,7 +21,7 @@ static void Handler(int sig) {
|
||||
}
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
if (argc < 6 || argc > 7) {
|
||||
if (argc < 6 || argc > 8) {
|
||||
const char *usage = R"usage(
|
||||
Usage:
|
||||
./bin/sherpa-onnx-alsa \
|
||||
@@ -30,7 +30,10 @@ Usage:
|
||||
/path/to/decoder.onnx \
|
||||
/path/to/joiner.onnx \
|
||||
device_name \
|
||||
[num_threads]
|
||||
[num_threads [decoding_method]]
|
||||
|
||||
Default value for num_threads is 2.
|
||||
Valid values for decoding_method: greedy_search (default), modified_beam_search.
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
@@ -79,6 +82,11 @@ as the device_name.
|
||||
config.model_config.num_threads = atoi(argv[6]);
|
||||
}
|
||||
|
||||
if (argc == 8) {
|
||||
config.decoding_method = argv[7];
|
||||
}
|
||||
config.max_active_paths = 4;
|
||||
|
||||
config.enable_endpoint = true;
|
||||
|
||||
config.endpoint_config.rule1.min_trailing_silence = 2.4;
|
||||
|
||||
@@ -36,7 +36,7 @@ static void Handler(int32_t sig) {
|
||||
}
|
||||
|
||||
int32_t main(int32_t argc, char *argv[]) {
|
||||
if (argc < 5 || argc > 6) {
|
||||
if (argc < 5 || argc > 7) {
|
||||
const char *usage = R"usage(
|
||||
Usage:
|
||||
./bin/sherpa-onnx-microphone \
|
||||
@@ -44,7 +44,10 @@ Usage:
|
||||
/path/to/encoder.onnx\
|
||||
/path/to/decoder.onnx\
|
||||
/path/to/joiner.onnx\
|
||||
[num_threads]
|
||||
[num_threads [decoding_method]]
|
||||
|
||||
Default value for num_threads is 2.
|
||||
Valid values for decoding_method: greedy_search (default), modified_beam_search.
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
@@ -70,6 +73,11 @@ for a list of pre-trained models to download.
|
||||
config.model_config.num_threads = atoi(argv[5]);
|
||||
}
|
||||
|
||||
if (argc == 7) {
|
||||
config.decoding_method = argv[6];
|
||||
}
|
||||
config.max_active_paths = 4;
|
||||
|
||||
config.enable_endpoint = true;
|
||||
|
||||
config.endpoint_config.rule1.min_trailing_silence = 2.4;
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
if (argc < 6 || argc > 7) {
|
||||
if (argc < 6 || argc > 8) {
|
||||
const char *usage = R"usage(
|
||||
Usage:
|
||||
./bin/sherpa-onnx \
|
||||
@@ -22,7 +22,10 @@ Usage:
|
||||
/path/to/encoder.onnx \
|
||||
/path/to/decoder.onnx \
|
||||
/path/to/joiner.onnx \
|
||||
/path/to/foo.wav [num_threads]
|
||||
/path/to/foo.wav [num_threads [decoding_method]]
|
||||
|
||||
Default value for num_threads is 2.
|
||||
Valid values for decoding_method: greedy_search (default), modified_beam_search.
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
@@ -45,9 +48,15 @@ for a list of pre-trained models to download.
|
||||
std::string wav_filename = argv[5];
|
||||
|
||||
config.model_config.num_threads = 2;
|
||||
if (argc == 7) {
|
||||
if (argc == 7 && atoi(argv[6]) > 0) {
|
||||
config.model_config.num_threads = atoi(argv[6]);
|
||||
}
|
||||
|
||||
if (argc == 8) {
|
||||
config.decoding_method = argv[7];
|
||||
}
|
||||
config.max_active_paths = 4;
|
||||
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
sherpa_onnx::OnlineRecognizer recognizer(config);
|
||||
@@ -98,6 +107,7 @@ for a list of pre-trained models to download.
|
||||
1000.;
|
||||
|
||||
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
|
||||
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
|
||||
|
||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||
float rtf = elapsed_seconds / duration;
|
||||
|
||||
Reference in New Issue
Block a user