Code refactoring (#74)

* Don't reset model state and feature extractor on endpointing

* support passing decoding_method from commandline

* Add modified_beam_search to Python API

* fix C API example

* Fix style issues
This commit is contained in:
Fangjun Kuang
2023-03-03 12:10:59 +08:00
committed by GitHub
parent c241f93c40
commit 7f72c13d9a
34 changed files with 744 additions and 374 deletions

View File

@@ -9,10 +9,11 @@ set(sources
online-lstm-transducer-model.cc
online-recognizer.cc
online-stream.cc
online-transducer-decoder.cc
online-transducer-greedy-search-decoder.cc
online-transducer-model-config.cc
online-transducer-modified-beam-search-decoder.cc
online-transducer-model.cc
online-transducer-modified-beam-search-decoder.cc
online-zipformer-transducer-model.cc
onnx-utils.cc
parse-options.cc

View File

@@ -12,9 +12,16 @@ namespace sherpa_onnx {
class Display {
public:
explicit Display(int32_t max_word_per_line = 60)
: max_word_per_line_(max_word_per_line) {}
void Print(int32_t segment_id, const std::string &s) {
#ifdef _MSC_VER
fprintf(stderr, "%d:%s\n", segment_id, s.c_str());
if (segment_id != -1) {
fprintf(stderr, "%d:%s\n", segment_id, s.c_str());
} else {
fprintf(stderr, "%s\n", s.c_str());
}
return;
#endif
if (last_segment_ == segment_id) {
@@ -27,7 +34,9 @@ class Display {
num_previous_lines_ = 0;
}
fprintf(stderr, "\r%d:", segment_id);
if (segment_id != -1) {
fprintf(stderr, "\r%d:", segment_id);
}
int32_t i = 0;
for (size_t n = 0; n < s.size();) {
@@ -69,7 +78,7 @@ class Display {
void GoUpOneLine() const { fprintf(stderr, "\033[1A\r"); }
private:
int32_t max_word_per_line_ = 60;
int32_t max_word_per_line_;
int32_t num_previous_lines_ = 0;
int32_t last_segment_ = -1;
};

View File

@@ -28,7 +28,8 @@ std::string FeatureExtractorConfig::ToString() const {
os << "FeatureExtractorConfig(";
os << "sampling_rate=" << sampling_rate << ", ";
os << "feature_dim=" << feature_dim << ")";
os << "feature_dim=" << feature_dim << ", ";
os << "max_feature_vectors=" << max_feature_vectors << ")";
return os.str();
}
@@ -40,9 +41,7 @@ class FeatureExtractor::Impl {
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.samp_freq = config.sampling_rate;
// cache 100 seconds of feature frames, which is more than enough
// for real needs
opts_.frame_opts.max_feature_vectors = 100 * 100;
opts_.frame_opts.max_feature_vectors = config.max_feature_vectors;
opts_.mel_opts.num_bins = config.feature_dim;

View File

@@ -16,6 +16,7 @@ namespace sherpa_onnx {
struct FeatureExtractorConfig {
float sampling_rate = 16000;
int32_t feature_dim = 80;
int32_t max_feature_vectors = -1;
std::string ToString() const;

View File

@@ -18,7 +18,7 @@ namespace sherpa_onnx {
struct Hypothesis {
// The predicted tokens so far. Newly predicted tokens are appended.
std::vector<int32_t> ys;
std::vector<int64_t> ys;
// timestamps[i] contains the frame number after subsampling
// on which ys[i] is decoded.
@@ -30,7 +30,7 @@ struct Hypothesis {
int32_t num_trailing_blanks = 0;
Hypothesis() = default;
Hypothesis(const std::vector<int32_t> &ys, double log_prob)
Hypothesis(const std::vector<int64_t> &ys, double log_prob)
: ys(ys), log_prob(log_prob) {}
// If two Hypotheses have the same `Key`, then they contain

View File

@@ -43,7 +43,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"True to enable endpoint detection. False to disable it.");
po->Register("max-active-paths", &max_active_paths,
"beam size used in modified beam search.");
po->Register("decoding-mothod", &decoding_method,
po->Register("decoding-method", &decoding_method,
"decoding method,"
"now support greedy_search and modified_beam_search.");
}
@@ -59,8 +59,8 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "feat_config=" << feat_config.ToString() << ", ";
os << "model_config=" << model_config.ToString() << ", ";
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ",";
os << "max_active_paths=" << max_active_paths << ",";
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
os << "max_active_paths=" << max_active_paths << ", ";
os << "decoding_method=\"" << decoding_method << "\")";
return os.str();
@@ -187,16 +187,14 @@ class OnlineRecognizer::Impl {
}
void Reset(OnlineStream *s) const {
// reset result, neural network model state, and
// the feature extractor state
// reset result
// we keep the decoder_out
decoder_->UpdateDecoderOut(&s->GetResult());
Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
s->SetResult(decoder_->GetEmptyResult());
s->GetResult().decoder_out = std::move(decoder_out);
// reset neural network model state
s->SetStates(model_->GetEncoderInitStates());
// reset feature extractor
// Note: We only update counters. The underlying audio samples
// are not discarded.
s->Reset();
}

View File

@@ -33,21 +33,26 @@ struct OnlineRecognizerConfig {
OnlineTransducerModelConfig model_config;
EndpointConfig endpoint_config;
bool enable_endpoint = true;
int32_t max_active_paths = 4;
std::string decoding_method = "modified_beam_search";
std::string decoding_method = "greedy_search";
// now support modified_beam_search and greedy_search
int32_t max_active_paths = 4; // used only for modified_beam_search
OnlineRecognizerConfig() = default;
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
const OnlineTransducerModelConfig &model_config,
const EndpointConfig &endpoint_config,
bool enable_endpoint)
bool enable_endpoint,
const std::string &decoding_method,
int32_t max_active_paths)
: feat_config(feat_config),
model_config(model_config),
endpoint_config(endpoint_config),
enable_endpoint(enable_endpoint) {}
enable_endpoint(enable_endpoint),
decoding_method(decoding_method),
max_active_paths(max_active_paths) {}
void Register(ParseOptions *po);
bool Validate() const;

View File

@@ -22,18 +22,21 @@ class OnlineStream::Impl {
void InputFinished() { feat_extractor_.InputFinished(); }
int32_t NumFramesReady() const { return feat_extractor_.NumFramesReady(); }
int32_t NumFramesReady() const {
return feat_extractor_.NumFramesReady() - start_frame_index_;
}
bool IsLastFrame(int32_t frame) const {
return feat_extractor_.IsLastFrame(frame);
}
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
return feat_extractor_.GetFrames(frame_index, n);
return feat_extractor_.GetFrames(frame_index + start_frame_index_, n);
}
void Reset() {
feat_extractor_.Reset();
// we don't reset the feature extractor
start_frame_index_ += num_processed_frames_;
num_processed_frames_ = 0;
}
@@ -41,7 +44,7 @@ class OnlineStream::Impl {
void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
const OnlineTransducerDecoderResult &GetResult() const { return result_; }
OnlineTransducerDecoderResult &GetResult() { return result_; }
int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
@@ -54,6 +57,7 @@ class OnlineStream::Impl {
private:
FeatureExtractor feat_extractor_;
int32_t num_processed_frames_ = 0; // before subsampling
int32_t start_frame_index_ = 0; // never reset
OnlineTransducerDecoderResult result_;
std::vector<Ort::Value> states_;
};
@@ -93,7 +97,7 @@ void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) {
impl_->SetResult(r);
}
const OnlineTransducerDecoderResult &OnlineStream::GetResult() const {
OnlineTransducerDecoderResult &OnlineStream::GetResult() {
return impl_->GetResult();
}

View File

@@ -63,7 +63,7 @@ class OnlineStream {
int32_t &GetNumProcessedFrames();
void SetResult(const OnlineTransducerDecoderResult &r);
const OnlineTransducerDecoderResult &GetResult() const;
OnlineTransducerDecoderResult &GetResult();
void SetStates(std::vector<Ort::Value> states);
std::vector<Ort::Value> &GetStates();

View File

@@ -0,0 +1,60 @@
// sherpa-onnx/csrc/online-transducer-decoder.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
// Copy constructor. Delegates to the default constructor first (so that
// decoder_out starts out as a null Ort::Value) and then reuses the copy
// assignment operator below to do the actual deep copy.
OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
const OnlineTransducerDecoderResult &other)
: OnlineTransducerDecoderResult() {
*this = other;
}
// Copy assignment. Ort::Value is not copyable with "=", so the cached
// decoder_out tensor is deep-copied via Clone() instead; all other members
// use their own copy semantics.
OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
const OnlineTransducerDecoderResult &other) {
// Self-assignment guard.
if (this == &other) {
return *this;
}
tokens = other.tokens;
num_trailing_blanks = other.num_trailing_blanks;
Ort::AllocatorWithDefaultOptions allocator;
// decoder_out may be a null Ort::Value (e.g., for a freshly created
// result); only clone when it actually holds a tensor.
// NOTE(review): if other.decoder_out is null, this->decoder_out keeps
// whatever tensor it already held — confirm that is intended.
if (other.decoder_out) {
decoder_out = Clone(allocator, &other.decoder_out);
}
hyps = other.hyps;
return *this;
}
// Move constructor. Delegates to the default constructor and then to the
// move assignment operator below.
OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
OnlineTransducerDecoderResult &&other)
: OnlineTransducerDecoderResult() {
*this = std::move(other);
}
// Move assignment. Steals tokens, decoder_out, and hyps from `other`;
// num_trailing_blanks is a plain int32_t and is simply copied. `other` is
// left in a valid but unspecified state.
OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
OnlineTransducerDecoderResult &&other) {
// Self-assignment guard.
if (this == &other) {
return *this;
}
tokens = std::move(other.tokens);
num_trailing_blanks = other.num_trailing_blanks;
decoder_out = std::move(other.decoder_out);
hyps = std::move(other.hyps);
return *this;
}
} // namespace sherpa_onnx

View File

@@ -19,8 +19,24 @@ struct OnlineTransducerDecoderResult {
/// number of trailing blank frames decoded so far
int32_t num_trailing_blanks = 0;
// Cache decoder_out for endpointing
Ort::Value decoder_out;
// used only in modified beam_search
Hypotheses hyps;
OnlineTransducerDecoderResult()
: tokens{}, num_trailing_blanks(0), decoder_out{nullptr}, hyps{} {}
OnlineTransducerDecoderResult(const OnlineTransducerDecoderResult &other);
OnlineTransducerDecoderResult &operator=(
const OnlineTransducerDecoderResult &other);
OnlineTransducerDecoderResult(OnlineTransducerDecoderResult &&other);
OnlineTransducerDecoderResult &operator=(
OnlineTransducerDecoderResult &&other);
};
class OnlineTransducerDecoder {
@@ -53,6 +69,9 @@ class OnlineTransducerDecoder {
*/
virtual void Decode(Ort::Value encoder_out,
std::vector<OnlineTransducerDecoderResult> *result) = 0;
// used for endpointing. We need to keep decoder_out after reset
virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
};
} // namespace sherpa_onnx

View File

@@ -13,6 +13,43 @@
namespace sherpa_onnx {
// Overwrite rows of the freshly computed, batched decoder_out with the
// per-stream decoder_out cached in each result (when present). Each result
// owns one row; shape[1] is the row width, so `dst` advances one row per
// result even when that result has no cached tensor to copy.
// Used for endpointing: after a reset the cached decoder_out is restored
// instead of recomputing it from scratch.
static void UseCachedDecoderOut(
const std::vector<OnlineTransducerDecoderResult> &results,
Ort::Value *decoder_out) {
std::vector<int64_t> shape =
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
float *dst = decoder_out->GetTensorMutableData<float>();
for (const auto &r : results) {
if (r.decoder_out) {
const float *src = r.decoder_out.GetTensorData<float>();
std::copy(src, src + shape[1], dst);
}
// Advance to the next stream's row regardless of whether we copied.
dst += shape[1];
}
}
// Copy each row of the batched decoder_out back into the per-result cache
// (a 1 x shape[1] float tensor), allocating the cache tensor on first use.
// This is the counterpart of UseCachedDecoderOut() above and runs after
// decoding so that the latest decoder output survives an endpoint reset.
static void UpdateCachedDecoderOut(
OrtAllocator *allocator, const Ort::Value *decoder_out,
std::vector<OnlineTransducerDecoderResult> *results) {
std::vector<int64_t> shape =
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
// NOTE(review): memory_info is created but never used in this function —
// confirm whether it can be removed.
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
// Cache tensor shape: one row of the batched output.
std::array<int64_t, 2> v_shape{1, shape[1]};
const float *src = decoder_out->GetTensorData<float>();
for (auto &r : *results) {
// Lazily allocate the per-result cache tensor.
if (!r.decoder_out) {
r.decoder_out = Ort::Value::CreateTensor<float>(allocator, v_shape.data(),
v_shape.size());
}
float *dst = r.decoder_out.GetTensorMutableData<float>();
std::copy(src, src + shape[1], dst);
src += shape[1];
}
}
OnlineTransducerDecoderResult
OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
int32_t context_size = model_->ContextSize();
@@ -53,6 +90,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
UseCachedDecoderOut(*result, &decoder_out);
for (int32_t t = 0; t != num_frames; ++t) {
Ort::Value cur_encoder_out =
@@ -77,10 +115,12 @@ void OnlineTransducerGreedySearchDecoder::Decode(
}
}
if (emitted) {
decoder_input = model_->BuildDecoderInput(*result);
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
decoder_out = model_->RunDecoder(std::move(decoder_input));
}
}
UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result);
}
} // namespace sherpa_onnx

View File

@@ -13,6 +13,29 @@
namespace sherpa_onnx {
// Modified-beam-search variant of UseCachedDecoderOut(). The batched
// decoder_out holds one row per hypothesis; hyps_num_split gives the prefix
// offsets that partition those rows by stream, so stream i owns rows
// [hyps_num_split[i], hyps_num_split[i+1]). Only streams that currently
// have exactly one hypothesis AND a cached tensor get their row overwritten;
// all other streams' rows are skipped (dst still advances past them).
// NOTE(review): the context_size parameter is never used in this body —
// confirm whether it can be dropped from the signature.
static void UseCachedDecoderOut(
const std::vector<int32_t> &hyps_num_split,
const std::vector<OnlineTransducerDecoderResult> &results,
int32_t context_size, Ort::Value *decoder_out) {
std::vector<int64_t> shape =
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
float *dst = decoder_out->GetTensorMutableData<float>();
int32_t batch_size = static_cast<int32_t>(results.size());
for (int32_t i = 0; i != batch_size; ++i) {
// Number of active hypotheses belonging to stream i.
int32_t num_hyps = hyps_num_split[i + 1] - hyps_num_split[i];
if (num_hyps > 1 || !results[i].decoder_out) {
// Skip this stream's rows without copying.
dst += num_hyps * shape[1];
continue;
}
const float *src = results[i].decoder_out.GetTensorData<float>();
std::copy(src, src + shape[1], dst);
dst += shape[1];
}
}
static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
const std::vector<int32_t> &hyps_num_split) {
std::vector<int64_t> cur_encoder_out_shape =
@@ -50,7 +73,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
int32_t context_size = model_->ContextSize();
int32_t blank_id = 0; // always 0
OnlineTransducerDecoderResult r;
std::vector<int32_t> blanks(context_size, blank_id);
std::vector<int64_t> blanks(context_size, blank_id);
Hypotheses blank_hyp({{blanks, 0}});
r.hyps = std::move(blank_hyp);
return r;
@@ -110,6 +133,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
Ort::Value decoder_input = model_->BuildDecoderInput(prev);
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
if (t == 0) {
UseCachedDecoderOut(hyps_num_split, *result, model_->ContextSize(),
&decoder_out);
}
Ort::Value cur_encoder_out =
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
@@ -147,8 +174,23 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
}
for (int32_t b = 0; b != batch_size; ++b) {
(*result)[b].hyps = std::move(cur[b]);
auto &hyps = cur[b];
auto best_hyp = hyps.GetMostProbable(true);
(*result)[b].hyps = std::move(hyps);
(*result)[b].tokens = std::move(best_hyp.ys);
(*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks;
}
}
// Refresh the cached decoder_out for endpointing. If the result holds only
// the initial blank context (tokens.size() == context size, i.e., nothing
// has been emitted yet), the cache is cleared to a null Ort::Value;
// otherwise the decoder network is run on the current tokens and its
// output is cached in result->decoder_out.
void OnlineTransducerModifiedBeamSearchDecoder::UpdateDecoderOut(
OnlineTransducerDecoderResult *result) {
if (result->tokens.size() == model_->ContextSize()) {
result->decoder_out = Ort::Value{nullptr};
return;
}
Ort::Value decoder_input = model_->BuildDecoderInput({*result});
result->decoder_out = model_->RunDecoder(std::move(decoder_input));
}
} // namespace sherpa_onnx

View File

@@ -27,6 +27,8 @@ class OnlineTransducerModifiedBeamSearchDecoder
void Decode(Ort::Value encoder_out,
std::vector<OnlineTransducerDecoderResult> *result) override;
void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;
private:
OnlineTransducerModel *model_; // Not owned
int32_t max_active_paths_;

View File

@@ -21,7 +21,7 @@ static void Handler(int sig) {
}
int main(int32_t argc, char *argv[]) {
if (argc < 6 || argc > 7) {
if (argc < 6 || argc > 8) {
const char *usage = R"usage(
Usage:
./bin/sherpa-onnx-alsa \
@@ -30,7 +30,10 @@ Usage:
/path/to/decoder.onnx \
/path/to/joiner.onnx \
device_name \
[num_threads]
[num_threads [decoding_method]]
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search (default), modified_beam_search.
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
@@ -79,6 +82,11 @@ as the device_name.
config.model_config.num_threads = atoi(argv[6]);
}
if (argc == 8) {
config.decoding_method = argv[7];
}
config.max_active_paths = 4;
config.enable_endpoint = true;
config.endpoint_config.rule1.min_trailing_silence = 2.4;

View File

@@ -36,7 +36,7 @@ static void Handler(int32_t sig) {
}
int32_t main(int32_t argc, char *argv[]) {
if (argc < 5 || argc > 6) {
if (argc < 5 || argc > 7) {
const char *usage = R"usage(
Usage:
./bin/sherpa-onnx-microphone \
@@ -44,7 +44,10 @@ Usage:
/path/to/encoder.onnx\
/path/to/decoder.onnx\
/path/to/joiner.onnx\
[num_threads]
[num_threads [decoding_method]]
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search (default), modified_beam_search.
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
@@ -70,6 +73,11 @@ for a list of pre-trained models to download.
config.model_config.num_threads = atoi(argv[5]);
}
if (argc == 7) {
config.decoding_method = argv[6];
}
config.max_active_paths = 4;
config.enable_endpoint = true;
config.endpoint_config.rule1.min_trailing_silence = 2.4;

View File

@@ -14,7 +14,7 @@
#include "sherpa-onnx/csrc/wave-reader.h"
int main(int32_t argc, char *argv[]) {
if (argc < 6 || argc > 7) {
if (argc < 6 || argc > 8) {
const char *usage = R"usage(
Usage:
./bin/sherpa-onnx \
@@ -22,7 +22,10 @@ Usage:
/path/to/encoder.onnx \
/path/to/decoder.onnx \
/path/to/joiner.onnx \
/path/to/foo.wav [num_threads]
/path/to/foo.wav [num_threads [decoding_method]]
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search (default), modified_beam_search.
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
@@ -45,9 +48,15 @@ for a list of pre-trained models to download.
std::string wav_filename = argv[5];
config.model_config.num_threads = 2;
if (argc == 7) {
if (argc == 7 && atoi(argv[6]) > 0) {
config.model_config.num_threads = atoi(argv[6]);
}
if (argc == 8) {
config.decoding_method = argv[7];
}
config.max_active_paths = 4;
fprintf(stderr, "%s\n", config.ToString().c_str());
sherpa_onnx::OnlineRecognizer recognizer(config);
@@ -98,6 +107,7 @@ for a list of pre-trained models to download.
1000.;
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
float rtf = elapsed_seconds / duration;