Code refactoring (#74)
* Don't reset model state and feature extractor on endpointing * support passing decoding_method from commandline * Add modified_beam_search to Python API * fix C API example * Fix style issues
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/display.h"
|
||||
#include "sherpa-onnx/csrc/online-recognizer.h"
|
||||
|
||||
struct SherpaOnnxOnlineRecognizer {
|
||||
@@ -21,6 +22,10 @@ struct SherpaOnnxOnlineStream {
|
||||
: impl(std::move(p)) {}
|
||||
};
|
||||
|
||||
struct SherpaOnnxDisplay {
|
||||
std::unique_ptr<sherpa_onnx::Display> impl;
|
||||
};
|
||||
|
||||
SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
|
||||
const SherpaOnnxOnlineRecognizerConfig *config) {
|
||||
sherpa_onnx::OnlineRecognizerConfig recognizer_config;
|
||||
@@ -37,6 +42,9 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
|
||||
recognizer_config.model_config.num_threads = config->model_config.num_threads;
|
||||
recognizer_config.model_config.debug = config->model_config.debug;
|
||||
|
||||
recognizer_config.decoding_method = config->decoding_method;
|
||||
recognizer_config.max_active_paths = config->max_active_paths;
|
||||
|
||||
recognizer_config.enable_endpoint = config->enable_endpoint;
|
||||
|
||||
recognizer_config.endpoint_config.rule1.min_trailing_silence =
|
||||
@@ -124,3 +132,15 @@ int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
|
||||
SherpaOnnxOnlineStream *stream) {
|
||||
return recognizer->impl->IsEndpoint(stream->impl.get());
|
||||
}
|
||||
|
||||
/// Create a display handle wrapping a sherpa_onnx::Display.
///
/// @param max_word_per_line Forwarded unchanged to the sherpa_onnx::Display
///                          constructor.
/// @return A heap-allocated handle. The caller must release it with
///         DestroyDisplay().
SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line) {
  // Keep the handle in a unique_ptr while the inner Display is constructed,
  // so the handle is freed (instead of leaked) if that construction throws.
  auto ans = std::make_unique<SherpaOnnxDisplay>();
  ans->impl = std::make_unique<sherpa_onnx::Display>(max_word_per_line);

  // Ownership passes to the caller of the C API.
  return ans.release();
}
|
||||
|
||||
/// Free a display handle. Deleting the handle also destroys the Display it
/// owns through its unique_ptr member.
void DestroyDisplay(SherpaOnnxDisplay *display) {
  delete display;
}
|
||||
|
||||
/// Forward (idx, s) to Display::Print() on the Display owned by the handle.
///
/// @param display Handle whose Display receives the call.
/// @param idx     Passed as the first argument of Display::Print().
/// @param s       C string passed as the second argument of Display::Print().
void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s) {
  sherpa_onnx::Display &target = *display->impl;
  target.Print(idx, s);
}
|
||||
|
||||
@@ -48,6 +48,13 @@ typedef struct SherpaOnnxOnlineRecognizerConfig {
|
||||
SherpaOnnxFeatureConfig feat_config;
|
||||
SherpaOnnxOnlineTransducerModelConfig model_config;
|
||||
|
||||
/// Possible values are: greedy_search, modified_beam_search
|
||||
const char *decoding_method;
|
||||
|
||||
/// Used only when decoding_method is modified_beam_search
|
||||
/// Example value: 4
|
||||
int32_t max_active_paths;
|
||||
|
||||
/// 0 to disable endpoint detection.
|
||||
/// A non-zero value to enable endpoint detection.
|
||||
int32_t enable_endpoint;
|
||||
@@ -187,6 +194,18 @@ void InputFinished(SherpaOnnxOnlineStream *stream);
|
||||
int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
|
||||
SherpaOnnxOnlineStream *stream);
|
||||
|
||||
// for displaying results on Linux/macOS.
|
||||
typedef struct SherpaOnnxDisplay SherpaOnnxDisplay;
|
||||
|
||||
/// Create a display object. Must be freed using DestroyDisplay to avoid
|
||||
/// memory leak.
|
||||
SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line);
|
||||
|
||||
void DestroyDisplay(SherpaOnnxDisplay *display);
|
||||
|
||||
/// Print the result.
|
||||
void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* extern "C" */
|
||||
#endif
|
||||
|
||||
@@ -9,10 +9,11 @@ set(sources
|
||||
online-lstm-transducer-model.cc
|
||||
online-recognizer.cc
|
||||
online-stream.cc
|
||||
online-transducer-decoder.cc
|
||||
online-transducer-greedy-search-decoder.cc
|
||||
online-transducer-model-config.cc
|
||||
online-transducer-modified-beam-search-decoder.cc
|
||||
online-transducer-model.cc
|
||||
online-transducer-modified-beam-search-decoder.cc
|
||||
online-zipformer-transducer-model.cc
|
||||
onnx-utils.cc
|
||||
parse-options.cc
|
||||
|
||||
@@ -12,9 +12,16 @@ namespace sherpa_onnx {
|
||||
|
||||
class Display {
|
||||
public:
|
||||
explicit Display(int32_t max_word_per_line = 60)
|
||||
: max_word_per_line_(max_word_per_line) {}
|
||||
|
||||
void Print(int32_t segment_id, const std::string &s) {
|
||||
#ifdef _MSC_VER
|
||||
fprintf(stderr, "%d:%s\n", segment_id, s.c_str());
|
||||
if (segment_id != -1) {
|
||||
fprintf(stderr, "%d:%s\n", segment_id, s.c_str());
|
||||
} else {
|
||||
fprintf(stderr, "%s\n", s.c_str());
|
||||
}
|
||||
return;
|
||||
#endif
|
||||
if (last_segment_ == segment_id) {
|
||||
@@ -27,7 +34,9 @@ class Display {
|
||||
num_previous_lines_ = 0;
|
||||
}
|
||||
|
||||
fprintf(stderr, "\r%d:", segment_id);
|
||||
if (segment_id != -1) {
|
||||
fprintf(stderr, "\r%d:", segment_id);
|
||||
}
|
||||
|
||||
int32_t i = 0;
|
||||
for (size_t n = 0; n < s.size();) {
|
||||
@@ -69,7 +78,7 @@ class Display {
|
||||
void GoUpOneLine() const { fprintf(stderr, "\033[1A\r"); }
|
||||
|
||||
private:
|
||||
int32_t max_word_per_line_ = 60;
|
||||
int32_t max_word_per_line_;
|
||||
int32_t num_previous_lines_ = 0;
|
||||
int32_t last_segment_ = -1;
|
||||
};
|
||||
|
||||
@@ -28,7 +28,8 @@ std::string FeatureExtractorConfig::ToString() const {
|
||||
|
||||
os << "FeatureExtractorConfig(";
|
||||
os << "sampling_rate=" << sampling_rate << ", ";
|
||||
os << "feature_dim=" << feature_dim << ")";
|
||||
os << "feature_dim=" << feature_dim << ", ";
|
||||
os << "max_feature_vectors=" << max_feature_vectors << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
@@ -40,9 +41,7 @@ class FeatureExtractor::Impl {
|
||||
opts_.frame_opts.snip_edges = false;
|
||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||
|
||||
// cache 100 seconds of feature frames, which is more than enough
|
||||
// for real needs
|
||||
opts_.frame_opts.max_feature_vectors = 100 * 100;
|
||||
opts_.frame_opts.max_feature_vectors = config.max_feature_vectors;
|
||||
|
||||
opts_.mel_opts.num_bins = config.feature_dim;
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ namespace sherpa_onnx {
|
||||
struct FeatureExtractorConfig {
|
||||
float sampling_rate = 16000;
|
||||
int32_t feature_dim = 80;
|
||||
int32_t max_feature_vectors = -1;
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ namespace sherpa_onnx {
|
||||
|
||||
struct Hypothesis {
|
||||
// The predicted tokens so far. Newly predicted tokens are appended.
|
||||
std::vector<int32_t> ys;
|
||||
std::vector<int64_t> ys;
|
||||
|
||||
// timestamps[i] contains the frame number after subsampling
|
||||
// on which ys[i] is decoded.
|
||||
@@ -30,7 +30,7 @@ struct Hypothesis {
|
||||
int32_t num_trailing_blanks = 0;
|
||||
|
||||
Hypothesis() = default;
|
||||
Hypothesis(const std::vector<int32_t> &ys, double log_prob)
|
||||
Hypothesis(const std::vector<int64_t> &ys, double log_prob)
|
||||
: ys(ys), log_prob(log_prob) {}
|
||||
|
||||
// If two Hypotheses have the same `Key`, then they contain
|
||||
|
||||
@@ -43,7 +43,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
"True to enable endpoint detection. False to disable it.");
|
||||
po->Register("max-active-paths", &max_active_paths,
|
||||
"beam size used in modified beam search.");
|
||||
po->Register("decoding-mothod", &decoding_method,
|
||||
po->Register("decoding-method", &decoding_method,
|
||||
"decoding method,"
|
||||
"now support greedy_search and modified_beam_search.");
|
||||
}
|
||||
@@ -59,8 +59,8 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "feat_config=" << feat_config.ToString() << ", ";
|
||||
os << "model_config=" << model_config.ToString() << ", ";
|
||||
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
||||
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ",";
|
||||
os << "max_active_paths=" << max_active_paths << ",";
|
||||
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "decoding_method=\"" << decoding_method << "\")";
|
||||
|
||||
return os.str();
|
||||
@@ -187,16 +187,14 @@ class OnlineRecognizer::Impl {
|
||||
}
|
||||
|
||||
void Reset(OnlineStream *s) const {
|
||||
// reset result, neural network model state, and
|
||||
// the feature extractor state
|
||||
|
||||
// reset result
|
||||
// we keep the decoder_out
|
||||
decoder_->UpdateDecoderOut(&s->GetResult());
|
||||
Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
|
||||
s->SetResult(decoder_->GetEmptyResult());
|
||||
s->GetResult().decoder_out = std::move(decoder_out);
|
||||
|
||||
// reset neural network model state
|
||||
s->SetStates(model_->GetEncoderInitStates());
|
||||
|
||||
// reset feature extractor
|
||||
// Note: We only update counters. The underlying audio samples
|
||||
// are not discarded.
|
||||
s->Reset();
|
||||
}
|
||||
|
||||
|
||||
@@ -33,21 +33,26 @@ struct OnlineRecognizerConfig {
|
||||
OnlineTransducerModelConfig model_config;
|
||||
EndpointConfig endpoint_config;
|
||||
bool enable_endpoint = true;
|
||||
int32_t max_active_paths = 4;
|
||||
|
||||
std::string decoding_method = "modified_beam_search";
|
||||
std::string decoding_method = "greedy_search";
|
||||
// now support modified_beam_search and greedy_search
|
||||
|
||||
int32_t max_active_paths = 4; // used only for modified_beam_search
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
|
||||
const OnlineTransducerModelConfig &model_config,
|
||||
const EndpointConfig &endpoint_config,
|
||||
bool enable_endpoint)
|
||||
bool enable_endpoint,
|
||||
const std::string &decoding_method,
|
||||
int32_t max_active_paths)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
endpoint_config(endpoint_config),
|
||||
enable_endpoint(enable_endpoint) {}
|
||||
enable_endpoint(enable_endpoint),
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -22,18 +22,21 @@ class OnlineStream::Impl {
|
||||
|
||||
void InputFinished() { feat_extractor_.InputFinished(); }
|
||||
|
||||
int32_t NumFramesReady() const { return feat_extractor_.NumFramesReady(); }
|
||||
int32_t NumFramesReady() const {
|
||||
return feat_extractor_.NumFramesReady() - start_frame_index_;
|
||||
}
|
||||
|
||||
bool IsLastFrame(int32_t frame) const {
|
||||
return feat_extractor_.IsLastFrame(frame);
|
||||
}
|
||||
|
||||
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
|
||||
return feat_extractor_.GetFrames(frame_index, n);
|
||||
return feat_extractor_.GetFrames(frame_index + start_frame_index_, n);
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
feat_extractor_.Reset();
|
||||
// we don't reset the feature extractor
|
||||
start_frame_index_ += num_processed_frames_;
|
||||
num_processed_frames_ = 0;
|
||||
}
|
||||
|
||||
@@ -41,7 +44,7 @@ class OnlineStream::Impl {
|
||||
|
||||
void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
|
||||
|
||||
const OnlineTransducerDecoderResult &GetResult() const { return result_; }
|
||||
OnlineTransducerDecoderResult &GetResult() { return result_; }
|
||||
|
||||
int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
|
||||
|
||||
@@ -54,6 +57,7 @@ class OnlineStream::Impl {
|
||||
private:
|
||||
FeatureExtractor feat_extractor_;
|
||||
int32_t num_processed_frames_ = 0; // before subsampling
|
||||
int32_t start_frame_index_ = 0; // never reset
|
||||
OnlineTransducerDecoderResult result_;
|
||||
std::vector<Ort::Value> states_;
|
||||
};
|
||||
@@ -93,7 +97,7 @@ void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) {
|
||||
impl_->SetResult(r);
|
||||
}
|
||||
|
||||
const OnlineTransducerDecoderResult &OnlineStream::GetResult() const {
|
||||
OnlineTransducerDecoderResult &OnlineStream::GetResult() {
|
||||
return impl_->GetResult();
|
||||
}
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ class OnlineStream {
|
||||
int32_t &GetNumProcessedFrames();
|
||||
|
||||
void SetResult(const OnlineTransducerDecoderResult &r);
|
||||
const OnlineTransducerDecoderResult &GetResult() const;
|
||||
OnlineTransducerDecoderResult &GetResult();
|
||||
|
||||
void SetStates(std::vector<Ort::Value> states);
|
||||
std::vector<Ort::Value> &GetStates();
|
||||
|
||||
60
sherpa-onnx/csrc/online-transducer-decoder.cc
Normal file
60
sherpa-onnx/csrc/online-transducer-decoder.cc
Normal file
@@ -0,0 +1,60 @@
|
||||
// sherpa-onnx/csrc/online-transducer-decoder.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Copy constructor: start from the default (empty) state, then reuse copy
// assignment so the copy logic lives in one place.
OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
    const OnlineTransducerDecoderResult &other)
    : OnlineTransducerDecoderResult() {
  operator=(other);
}
|
||||
|
||||
// Copy assignment. tokens, num_trailing_blanks and hyps are copied directly;
// decoder_out is copied through Clone() when the source holds a value.
//
// Returns *this. Self-assignment is a no-op.
OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
    const OnlineTransducerDecoderResult &other) {
  if (this == &other) {
    return *this;
  }

  tokens = other.tokens;
  num_trailing_blanks = other.num_trailing_blanks;

  if (other.decoder_out) {
    Ort::AllocatorWithDefaultOptions allocator;
    decoder_out = Clone(allocator, &other.decoder_out);
  } else {
    // The source has no cached decoder_out; drop ours as well so that the
    // destination does not keep a stale tensor from before the assignment.
    decoder_out = Ort::Value{nullptr};
  }

  hyps = other.hyps;

  return *this;
}
|
||||
|
||||
// Move constructor: start from the default (empty) state, then reuse move
// assignment to take over the members of `other`.
OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
    OnlineTransducerDecoderResult &&other)
    : OnlineTransducerDecoderResult() {
  operator=(std::move(other));
}
|
||||
|
||||
// Move assignment: takes over tokens, decoder_out and hyps from `other` and
// copies the plain counter. Self-assignment is a no-op. Returns *this.
OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
    OnlineTransducerDecoderResult &&other) {
  if (this != &other) {
    num_trailing_blanks = other.num_trailing_blanks;
    tokens = std::move(other.tokens);
    decoder_out = std::move(other.decoder_out);
    hyps = std::move(other.hyps);
  }

  return *this;
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -19,8 +19,24 @@ struct OnlineTransducerDecoderResult {
|
||||
/// number of trailing blank frames decoded so far
|
||||
int32_t num_trailing_blanks = 0;
|
||||
|
||||
// Cache decoder_out for endpointing
|
||||
Ort::Value decoder_out;
|
||||
|
||||
// used only in modified beam_search
|
||||
Hypotheses hyps;
|
||||
|
||||
OnlineTransducerDecoderResult()
|
||||
: tokens{}, num_trailing_blanks(0), decoder_out{nullptr}, hyps{} {}
|
||||
|
||||
OnlineTransducerDecoderResult(const OnlineTransducerDecoderResult &other);
|
||||
|
||||
OnlineTransducerDecoderResult &operator=(
|
||||
const OnlineTransducerDecoderResult &other);
|
||||
|
||||
OnlineTransducerDecoderResult(OnlineTransducerDecoderResult &&other);
|
||||
|
||||
OnlineTransducerDecoderResult &operator=(
|
||||
OnlineTransducerDecoderResult &&other);
|
||||
};
|
||||
|
||||
class OnlineTransducerDecoder {
|
||||
@@ -53,6 +69,9 @@ class OnlineTransducerDecoder {
|
||||
*/
|
||||
virtual void Decode(Ort::Value encoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) = 0;
|
||||
|
||||
// used for endpointing. We need to keep decoder_out after reset
|
||||
virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -13,6 +13,43 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static void UseCachedDecoderOut(
|
||||
const std::vector<OnlineTransducerDecoderResult> &results,
|
||||
Ort::Value *decoder_out) {
|
||||
std::vector<int64_t> shape =
|
||||
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||
float *dst = decoder_out->GetTensorMutableData<float>();
|
||||
for (const auto &r : results) {
|
||||
if (r.decoder_out) {
|
||||
const float *src = r.decoder_out.GetTensorData<float>();
|
||||
std::copy(src, src + shape[1], dst);
|
||||
}
|
||||
dst += shape[1];
|
||||
}
|
||||
}
|
||||
|
||||
// Store each row of `decoder_out` into the matching entry of `results`.
//
// For the i-th result, a {1, shape[1]} float tensor is allocated with
// `allocator` if the result does not hold one yet, and row i of `decoder_out`
// is copied into it.
//
// NOTE(review): assumes decoder_out is 2-D with at least results->size()
// rows - confirm against the caller.
static void UpdateCachedDecoderOut(
    OrtAllocator *allocator, const Ort::Value *decoder_out,
    std::vector<OnlineTransducerDecoderResult> *results) {
  std::vector<int64_t> shape =
      decoder_out->GetTensorTypeAndShapeInfo().GetShape();
  std::array<int64_t, 2> v_shape{1, shape[1]};

  const float *src = decoder_out->GetTensorData<float>();
  for (auto &r : *results) {
    if (!r.decoder_out) {
      r.decoder_out = Ort::Value::CreateTensor<float>(allocator, v_shape.data(),
                                                      v_shape.size());
    }

    float *dst = r.decoder_out.GetTensorMutableData<float>();
    std::copy(src, src + shape[1], dst);
    src += shape[1];  // advance to the next row
  }
}
|
||||
|
||||
OnlineTransducerDecoderResult
|
||||
OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
|
||||
int32_t context_size = model_->ContextSize();
|
||||
@@ -53,6 +90,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
|
||||
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
UseCachedDecoderOut(*result, &decoder_out);
|
||||
|
||||
for (int32_t t = 0; t != num_frames; ++t) {
|
||||
Ort::Value cur_encoder_out =
|
||||
@@ -77,10 +115,12 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
}
|
||||
}
|
||||
if (emitted) {
|
||||
decoder_input = model_->BuildDecoderInput(*result);
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
|
||||
decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
}
|
||||
}
|
||||
|
||||
UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -13,6 +13,29 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static void UseCachedDecoderOut(
|
||||
const std::vector<int32_t> &hyps_num_split,
|
||||
const std::vector<OnlineTransducerDecoderResult> &results,
|
||||
int32_t context_size, Ort::Value *decoder_out) {
|
||||
std::vector<int64_t> shape =
|
||||
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
float *dst = decoder_out->GetTensorMutableData<float>();
|
||||
|
||||
int32_t batch_size = static_cast<int32_t>(results.size());
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
int32_t num_hyps = hyps_num_split[i + 1] - hyps_num_split[i];
|
||||
if (num_hyps > 1 || !results[i].decoder_out) {
|
||||
dst += num_hyps * shape[1];
|
||||
continue;
|
||||
}
|
||||
|
||||
const float *src = results[i].decoder_out.GetTensorData<float>();
|
||||
std::copy(src, src + shape[1], dst);
|
||||
dst += shape[1];
|
||||
}
|
||||
}
|
||||
|
||||
static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
|
||||
const std::vector<int32_t> &hyps_num_split) {
|
||||
std::vector<int64_t> cur_encoder_out_shape =
|
||||
@@ -50,7 +73,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
|
||||
int32_t context_size = model_->ContextSize();
|
||||
int32_t blank_id = 0; // always 0
|
||||
OnlineTransducerDecoderResult r;
|
||||
std::vector<int32_t> blanks(context_size, blank_id);
|
||||
std::vector<int64_t> blanks(context_size, blank_id);
|
||||
Hypotheses blank_hyp({{blanks, 0}});
|
||||
r.hyps = std::move(blank_hyp);
|
||||
return r;
|
||||
@@ -110,6 +133,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(prev);
|
||||
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
if (t == 0) {
|
||||
UseCachedDecoderOut(hyps_num_split, *result, model_->ContextSize(),
|
||||
&decoder_out);
|
||||
}
|
||||
|
||||
Ort::Value cur_encoder_out =
|
||||
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
|
||||
@@ -147,8 +174,23 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
}
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
(*result)[b].hyps = std::move(cur[b]);
|
||||
auto &hyps = cur[b];
|
||||
auto best_hyp = hyps.GetMostProbable(true);
|
||||
|
||||
(*result)[b].hyps = std::move(hyps);
|
||||
(*result)[b].tokens = std::move(best_hyp.ys);
|
||||
(*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks;
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh result->decoder_out from the tokens currently held in `result`.
//
// When the number of tokens equals the model's context size, decoder_out is
// cleared; otherwise the decoder is run on an input built from `result` and
// its output is stored in result->decoder_out.
void OnlineTransducerModifiedBeamSearchDecoder::UpdateDecoderOut(
    OnlineTransducerDecoderResult *result) {
  const bool only_context = result->tokens.size() == model_->ContextSize();
  if (only_context) {
    result->decoder_out = Ort::Value{nullptr};
  } else {
    Ort::Value decoder_input = model_->BuildDecoderInput({*result});
    result->decoder_out = model_->RunDecoder(std::move(decoder_input));
  }
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -27,6 +27,8 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
||||
void Decode(Ort::Value encoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) override;
|
||||
|
||||
void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;
|
||||
|
||||
private:
|
||||
OnlineTransducerModel *model_; // Not owned
|
||||
int32_t max_active_paths_;
|
||||
|
||||
@@ -21,7 +21,7 @@ static void Handler(int sig) {
|
||||
}
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
if (argc < 6 || argc > 7) {
|
||||
if (argc < 6 || argc > 8) {
|
||||
const char *usage = R"usage(
|
||||
Usage:
|
||||
./bin/sherpa-onnx-alsa \
|
||||
@@ -30,7 +30,10 @@ Usage:
|
||||
/path/to/decoder.onnx \
|
||||
/path/to/joiner.onnx \
|
||||
device_name \
|
||||
[num_threads]
|
||||
[num_threads [decoding_method]]
|
||||
|
||||
Default value for num_threads is 2.
|
||||
Valid values for decoding_method: greedy_search (default), modified_beam_search.
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
@@ -79,6 +82,11 @@ as the device_name.
|
||||
config.model_config.num_threads = atoi(argv[6]);
|
||||
}
|
||||
|
||||
if (argc == 8) {
|
||||
config.decoding_method = argv[7];
|
||||
}
|
||||
config.max_active_paths = 4;
|
||||
|
||||
config.enable_endpoint = true;
|
||||
|
||||
config.endpoint_config.rule1.min_trailing_silence = 2.4;
|
||||
|
||||
@@ -36,7 +36,7 @@ static void Handler(int32_t sig) {
|
||||
}
|
||||
|
||||
int32_t main(int32_t argc, char *argv[]) {
|
||||
if (argc < 5 || argc > 6) {
|
||||
if (argc < 5 || argc > 7) {
|
||||
const char *usage = R"usage(
|
||||
Usage:
|
||||
./bin/sherpa-onnx-microphone \
|
||||
@@ -44,7 +44,10 @@ Usage:
|
||||
/path/to/encoder.onnx\
|
||||
/path/to/decoder.onnx\
|
||||
/path/to/joiner.onnx\
|
||||
[num_threads]
|
||||
[num_threads [decoding_method]]
|
||||
|
||||
Default value for num_threads is 2.
|
||||
Valid values for decoding_method: greedy_search (default), modified_beam_search.
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
@@ -70,6 +73,11 @@ for a list of pre-trained models to download.
|
||||
config.model_config.num_threads = atoi(argv[5]);
|
||||
}
|
||||
|
||||
if (argc == 7) {
|
||||
config.decoding_method = argv[6];
|
||||
}
|
||||
config.max_active_paths = 4;
|
||||
|
||||
config.enable_endpoint = true;
|
||||
|
||||
config.endpoint_config.rule1.min_trailing_silence = 2.4;
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
if (argc < 6 || argc > 7) {
|
||||
if (argc < 6 || argc > 8) {
|
||||
const char *usage = R"usage(
|
||||
Usage:
|
||||
./bin/sherpa-onnx \
|
||||
@@ -22,7 +22,10 @@ Usage:
|
||||
/path/to/encoder.onnx \
|
||||
/path/to/decoder.onnx \
|
||||
/path/to/joiner.onnx \
|
||||
/path/to/foo.wav [num_threads]
|
||||
/path/to/foo.wav [num_threads [decoding_method]]
|
||||
|
||||
Default value for num_threads is 2.
|
||||
Valid values for decoding_method: greedy_search (default), modified_beam_search.
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
@@ -45,9 +48,15 @@ for a list of pre-trained models to download.
|
||||
std::string wav_filename = argv[5];
|
||||
|
||||
config.model_config.num_threads = 2;
|
||||
if (argc == 7) {
|
||||
if (argc == 7 && atoi(argv[6]) > 0) {
|
||||
config.model_config.num_threads = atoi(argv[6]);
|
||||
}
|
||||
|
||||
if (argc == 8) {
|
||||
config.decoding_method = argv[7];
|
||||
}
|
||||
config.max_active_paths = 4;
|
||||
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
sherpa_onnx::OnlineRecognizer recognizer(config);
|
||||
@@ -98,6 +107,7 @@ for a list of pre-trained models to download.
|
||||
1000.;
|
||||
|
||||
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
|
||||
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
|
||||
|
||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||
float rtf = elapsed_seconds / duration;
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
include_directories(${CMAKE_SOURCE_DIR})
|
||||
|
||||
pybind11_add_module(_sherpa_onnx
|
||||
display.cc
|
||||
endpoint.cc
|
||||
features.cc
|
||||
online-recognizer.cc
|
||||
online-stream.cc
|
||||
online-transducer-model-config.cc
|
||||
sherpa-onnx.cc
|
||||
endpoint.cc
|
||||
online-stream.cc
|
||||
online-recognizer.cc
|
||||
)
|
||||
|
||||
if(APPLE)
|
||||
|
||||
18
sherpa-onnx/python/csrc/display.cc
Normal file
18
sherpa-onnx/python/csrc/display.cc
Normal file
@@ -0,0 +1,18 @@
|
||||
// sherpa-onnx/python/csrc/display.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/display.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/display.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Register the Display class with the Python module `m` under the name
// "Display": a constructor taking max_word_per_line (default 60) and a
// print(idx, s) method bound to Display::Print.
void PybindDisplay(py::module *m) {
  py::class_<Display> display(*m, "Display");
  display.def(py::init<int32_t>(), py::arg("max_word_per_line") = 60);
  display.def("print", &Display::Print, py::arg("idx"), py::arg("s"));
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/display.h
Normal file
16
sherpa-onnx/python/csrc/display.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/display.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindDisplay(py::module *m);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
|
||||
@@ -11,10 +11,12 @@ namespace sherpa_onnx {
|
||||
static void PybindFeatureExtractorConfig(py::module *m) {
|
||||
using PyClass = FeatureExtractorConfig;
|
||||
py::class_<PyClass>(*m, "FeatureExtractorConfig")
|
||||
.def(py::init<float, int32_t>(), py::arg("sampling_rate") = 16000,
|
||||
py::arg("feature_dim") = 80)
|
||||
.def(py::init<float, int32_t, int32_t>(),
|
||||
py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,
|
||||
py::arg("max_feature_vectors") = -1)
|
||||
.def_readwrite("sampling_rate", &PyClass::sampling_rate)
|
||||
.def_readwrite("feature_dim", &PyClass::feature_dim)
|
||||
.def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
@@ -22,13 +22,16 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
||||
.def(py::init<const FeatureExtractorConfig &,
|
||||
const OnlineTransducerModelConfig &, const EndpointConfig &,
|
||||
bool>(),
|
||||
bool, const std::string &, int32_t>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("endpoint_config"), py::arg("enable_endpoint"))
|
||||
py::arg("endpoint_config"), py::arg("enable_endpoint"),
|
||||
py::arg("decoding_method"), py::arg("max_active_paths"))
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
|
||||
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
|
||||
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
||||
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
#include "sherpa-onnx/python/csrc/display.h"
|
||||
#include "sherpa-onnx/python/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/python/csrc/features.h"
|
||||
#include "sherpa-onnx/python/csrc/online-recognizer.h"
|
||||
@@ -19,6 +20,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
PybindOnlineStream(&m);
|
||||
PybindEndpoint(&m);
|
||||
PybindOnlineRecognizer(&m);
|
||||
|
||||
PybindDisplay(&m);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -1,9 +1,3 @@
|
||||
from _sherpa_onnx import (
|
||||
EndpointConfig,
|
||||
FeatureExtractorConfig,
|
||||
OnlineRecognizerConfig,
|
||||
OnlineStream,
|
||||
OnlineTransducerModelConfig,
|
||||
)
|
||||
from _sherpa_onnx import Display
|
||||
|
||||
from .online_recognizer import OnlineRecognizer
|
||||
|
||||
@@ -32,6 +32,9 @@ class OnlineRecognizer(object):
|
||||
rule1_min_trailing_silence: int = 2.4,
|
||||
rule2_min_trailing_silence: int = 1.2,
|
||||
rule3_min_utterance_length: int = 20,
|
||||
decoding_method: str = "greedy_search",
|
||||
max_active_paths: int = 4,
|
||||
max_feature_vectors: int = -1,
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -74,6 +77,14 @@ class OnlineRecognizer(object):
|
||||
Used only when enable_endpoint_detection is True. If the utterance
|
||||
length in seconds is larger than this value, we assume an endpoint
|
||||
is detected.
|
||||
decoding_method:
|
||||
Valid values are greedy_search, modified_beam_search.
|
||||
max_active_paths:
|
||||
Used only when decoding_method is modified_beam_search. It specifies
|
||||
the maximum number of active paths during beam search.
|
||||
max_feature_vectors:
|
||||
Number of feature vectors to cache. -1 means to cache all feature
|
||||
frames that have been processed.
|
||||
"""
|
||||
_assert_file_exists(tokens)
|
||||
_assert_file_exists(encoder)
|
||||
@@ -93,6 +104,7 @@ class OnlineRecognizer(object):
|
||||
feat_config = FeatureExtractorConfig(
|
||||
sampling_rate=sample_rate,
|
||||
feature_dim=feature_dim,
|
||||
max_feature_vectors=max_feature_vectors,
|
||||
)
|
||||
|
||||
endpoint_config = EndpointConfig(
|
||||
@@ -106,6 +118,8 @@ class OnlineRecognizer(object):
|
||||
model_config=model_config,
|
||||
endpoint_config=endpoint_config,
|
||||
enable_endpoint=enable_endpoint_detection,
|
||||
decoding_method=decoding_method,
|
||||
max_active_paths=max_active_paths,
|
||||
)
|
||||
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
|
||||
Reference in New Issue
Block a user