add online-recognizer (#29)

This commit is contained in:
Fangjun Kuang
2023-02-19 12:45:38 +08:00
committed by GitHub
parent d4b0c0590a
commit ebc3b47fb8
11 changed files with 267 additions and 61 deletions

View File

@@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR})
add_executable(sherpa-onnx add_executable(sherpa-onnx
features.cc features.cc
online-lstm-transducer-model.cc online-lstm-transducer-model.cc
online-recognizer.cc
online-stream.cc online-stream.cc
online-transducer-greedy-search-decoder.cc online-transducer-greedy-search-decoder.cc
online-transducer-model-config.cc online-transducer-model-config.cc

View File

@@ -7,12 +7,23 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <sstream>
#include <vector> #include <vector>
#include "kaldi-native-fbank/csrc/online-feature.h" #include "kaldi-native-fbank/csrc/online-feature.h"
namespace sherpa_onnx { namespace sherpa_onnx {
// Render this configuration as a single-line, human-readable string of the
// form "FeatureExtractorConfig(sampling_rate=<value>, feature_dim=<value>)".
std::string FeatureExtractorConfig::ToString() const {
  std::ostringstream out;

  out << "FeatureExtractorConfig("
      << "sampling_rate=" << sampling_rate << ", "
      << "feature_dim=" << feature_dim << ")";

  return out.str();
}
class FeatureExtractor::Impl { class FeatureExtractor::Impl {
public: public:
explicit Impl(const FeatureExtractorConfig &config) { explicit Impl(const FeatureExtractorConfig &config) {

View File

@@ -6,6 +6,7 @@
#define SHERPA_ONNX_CSRC_FEATURES_H_ #define SHERPA_ONNX_CSRC_FEATURES_H_
#include <memory> #include <memory>
#include <string>
#include <vector> #include <vector>
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -13,6 +14,8 @@ namespace sherpa_onnx {
struct FeatureExtractorConfig { struct FeatureExtractorConfig {
float sampling_rate = 16000; float sampling_rate = 16000;
int32_t feature_dim = 80; int32_t feature_dim = 80;
std::string ToString() const;
}; };
class FeatureExtractor { class FeatureExtractor {

View File

@@ -0,0 +1,136 @@
// sherpa-onnx/csrc/online-recognizer.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-recognizer.h"
#include <assert.h>
#include <memory>
#include <sstream>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
#include "sherpa-onnx/csrc/symbol-table.h"
namespace sherpa_onnx {
// Turn a decoder result into a recognizer result.
//
// Every token ID in `src.tokens` is looked up in `sym_table` and the
// resulting symbols are concatenated, in order and without any separator,
// to form the returned text.
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
                                      const SymbolTable &sym_table) {
  OnlineRecognizerResult result;

  for (auto token : src.tokens) {
    result.text += sym_table[token];
  }

  return result;
}
// Render this configuration as a single-line, human-readable string. The
// nested feature and model configurations are rendered through their own
// ToString() methods; `tokens` is printed inside double quotes.
std::string OnlineRecognizerConfig::ToString() const {
  std::ostringstream out;

  out << "OnlineRecognizerConfig("
      << "feat_config=" << feat_config.ToString() << ", "
      << "model_config=" << model_config.ToString() << ", "
      << "tokens=\"" << tokens << "\")";

  return out.str();
}
// Implementation behind OnlineRecognizer; the public methods forward here.
//
// It owns the transducer model, a greedy-search decoder that works on that
// model, and the symbol table used to turn token IDs into text.
class OnlineRecognizer::Impl {
 public:
  // Creates the model from `config.model_config`, constructs the symbol table
  // from `config.tokens`, and then creates a greedy-search decoder for the
  // model.
  explicit Impl(const OnlineRecognizerConfig &config)
      : config_(config),
        model_(OnlineTransducerModel::Create(config.model_config)),
        sym_(config.tokens) {
    // The decoder only receives a raw pointer to the model. `model_` is
    // declared before `decoder_`, so the model is destroyed after the decoder.
    OnlineTransducerModel *model = model_.get();
    decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(model);
  }

  // Returns a new stream built from `config_.feat_config`, initialised with
  // the decoder's empty result and the model's initial encoder states.
  std::unique_ptr<OnlineStream> CreateStream() const {
    auto new_stream = std::make_unique<OnlineStream>(config_.feat_config);

    new_stream->SetResult(decoder_->GetEmptyResult());
    new_stream->SetStates(model_->GetEncoderInitStates());

    return new_stream;
  }

  // Returns true when the number of ready frames in `s` is strictly greater
  // than the number of already processed frames plus one chunk.
  bool IsReady(OnlineStream *s) const {
    const auto frames_required =
        s->GetNumProcessedFrames() + model_->ChunkSize();

    return frames_required < s->NumFramesReady();
  }

  // Decodes one chunk of the stream `ss[0]` and stores the updated encoder
  // states and decoding result back into that stream.
  //
  // Only n == 1 is implemented: for any other `n` a message is printed to
  // stderr and the process terminates via exit(-1).
  //
  // The stream is expected to be ready (see IsReady()); this is verified with
  // assert() only.
  void DecodeStreams(OnlineStream **ss, int32_t n) {
    if (n != 1) {
      fprintf(stderr, "only n == 1 is implemented\n");
      exit(-1);
    }

    OnlineStream *s = ss[0];
    assert(IsReady(s));

    int32_t chunk_size = model_->ChunkSize();
    int32_t chunk_shift = model_->ChunkShift();
    int32_t feature_dim = s->FeatureDim();

    // Shape of the encoder input: a batch of one chunk.
    std::array<int64_t, 3> input_shape{1, chunk_size, feature_dim};

    auto cpu_memory =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // Fetch the chunk that starts at the first unprocessed frame, and only
    // afterwards advance the processed-frame counter by the chunk shift.
    std::vector<float> chunk =
        s->GetFrames(s->GetNumProcessedFrames(), chunk_size);
    s->GetNumProcessedFrames() += chunk_shift;

    // The tensor is created on top of `chunk`'s buffer, so `chunk` has to stay
    // alive until the encoder has run.
    Ort::Value input =
        Ort::Value::CreateTensor(cpu_memory, chunk.data(), chunk.size(),
                                 input_shape.data(), input_shape.size());

    // `first` is the encoder output, `second` the updated encoder states.
    auto encoder_result = model_->RunEncoder(std::move(input), s->GetStates());
    s->SetStates(std::move(encoder_result.second));

    // Decode() works on a vector of results; wrap the stream's current result
    // in a one-element vector and write the updated entry back afterwards.
    std::vector<OnlineTransducerDecoderResult> hyps;
    hyps.push_back(s->GetResult());

    decoder_->Decode(std::move(encoder_result.first), &hyps);

    s->SetResult(hyps.front());
  }

  // Returns the text recognized so far for `s`. The stream's stored result is
  // left untouched: the leading blanks are stripped from a copy.
  OnlineRecognizerResult GetResult(OnlineStream *s) {
    OnlineTransducerDecoderResult stripped = s->GetResult();
    decoder_->StripLeadingBlanks(&stripped);

    return Convert(stripped, sym_);
  }

 private:
  OnlineRecognizerConfig config_;
  std::unique_ptr<OnlineTransducerModel> model_;
  std::unique_ptr<OnlineTransducerDecoder> decoder_;
  SymbolTable sym_;
};
// Construct the recognizer by building its private implementation from
// `config`.
OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
// Defaulted here, in the .cc file, where Impl is a complete type; that is
// required to destroy the std::unique_ptr<Impl> member.
OnlineRecognizer::~OnlineRecognizer() = default;
// Forward to Impl::CreateStream().
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
return impl_->CreateStream();
}
// Forward to Impl::IsReady().
bool OnlineRecognizer::IsReady(OnlineStream *s) const {
return impl_->IsReady(s);
}
// Forward to Impl::DecodeStreams(), which currently implements only n == 1.
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) {
impl_->DecodeStreams(ss, n);
}
// Forward to Impl::GetResult().
OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) {
return impl_->GetResult(s);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,65 @@
// sherpa-onnx/csrc/online-recognizer.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
#include <memory>
#include <string>
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
namespace sherpa_onnx {
/// Recognition result returned by OnlineRecognizer::GetResult().
struct OnlineRecognizerResult {
/// Recognized text: the symbols of the decoded tokens, concatenated in order.
std::string text;
};
/// Configuration of an OnlineRecognizer.
struct OnlineRecognizerConfig {
/// Feature extractor settings; used for every stream the recognizer creates.
FeatureExtractorConfig feat_config;
/// Settings used to create the transducer model.
OnlineTransducerModelConfig model_config;
/// Used to construct the symbol table that maps token IDs to text.
/// NOTE(review): presumably the path of a tokens file -- confirm against
/// SymbolTable.
std::string tokens;
/// Return a single-line, human-readable description of this configuration.
std::string ToString() const;
};
/** Recognizer that decodes `OnlineStream` objects chunk by chunk.
 *
 * Typical usage: create a stream with `CreateStream()`, feed audio to the
 * stream, call `DecodeStream()` while `IsReady()` returns true, and read the
 * text with `GetResult()`.
 */
class OnlineRecognizer {
 public:
  /// Create a recognizer from the given configuration.
  explicit OnlineRecognizer(const OnlineRecognizerConfig &config);

  /// Declared here and defined in the .cc file, where `Impl` is complete.
  ~OnlineRecognizer();

  /// Create a stream for decoding.
  std::unique_ptr<OnlineStream> CreateStream() const;

  /**
   * Return true if the given stream has enough frames for decoding.
   * Return false otherwise.
   */
  bool IsReady(OnlineStream *s) const;

  /** Decode a single stream.
   *
   * Equivalent to calling `DecodeStreams()` with an array that contains only
   * `s`.
   */
  void DecodeStream(OnlineStream *s) { DecodeStreams(&s, 1); }

  /** Decode multiple streams in parallel
   *
   * @param ss Pointer array containing streams to be decoded.
   * @param n Number of streams in `ss`. The current implementation supports
   *          only `n == 1`; for any other value it prints an error message
   *          and terminates the process.
   */
  void DecodeStreams(OnlineStream **ss, int32_t n);

  /// Return the text recognized so far for the given stream.
  OnlineRecognizerResult GetResult(OnlineStream *s);

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_

View File

@@ -4,6 +4,7 @@
#include "sherpa-onnx/csrc/online-stream.h" #include "sherpa-onnx/csrc/online-stream.h"
#include <memory> #include <memory>
#include <utility>
#include <vector> #include <vector>
#include "sherpa-onnx/csrc/features.h" #include "sherpa-onnx/csrc/features.h"
@@ -41,10 +42,17 @@ class OnlineStream::Impl {
int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); } int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
// Store `states`, replacing the ones currently held. The argument is taken by
// value and moved into the member.
void SetStates(std::vector<Ort::Value> states) {
states_ = std::move(states);
}
// Return a mutable reference to the stored states.
std::vector<Ort::Value> &GetStates() { return states_; }
private: private:
FeatureExtractor feat_extractor_; FeatureExtractor feat_extractor_;
int32_t num_processed_frames_ = 0; // before subsampling int32_t num_processed_frames_ = 0; // before subsampling
OnlineTransducerDecoderResult result_; OnlineTransducerDecoderResult result_;
std::vector<Ort::Value> states_;
}; };
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/) OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
@@ -86,4 +94,12 @@ const OnlineTransducerDecoderResult &OnlineStream::GetResult() const {
return impl_->GetResult(); return impl_->GetResult();
} }
// Hand `states` over to the implementation, which stores them in place of the
// ones it currently holds.
void OnlineStream::SetStates(std::vector<Ort::Value> states) {
impl_->SetStates(std::move(states));
}
// Return a mutable reference to the states stored in the implementation.
std::vector<Ort::Value> &OnlineStream::GetStates() {
return impl_->GetStates();
}
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -8,6 +8,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/features.h" #include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h"
@@ -63,6 +64,9 @@ class OnlineStream {
void SetResult(const OnlineTransducerDecoderResult &r); void SetResult(const OnlineTransducerDecoderResult &r);
const OnlineTransducerDecoderResult &GetResult() const; const OnlineTransducerDecoderResult &GetResult() const;
/// Store `states` in this stream, replacing the ones currently held.
void SetStates(std::vector<Ort::Value> states);
/// Return a mutable reference to the states held by this stream.
std::vector<Ort::Value> &GetStates();
private: private:
class Impl; class Impl;
std::unique_ptr<Impl> impl_; std::unique_ptr<Impl> impl_;

View File

@@ -26,13 +26,14 @@ class OnlineTransducerDecoder {
* to the beginning of the decoding result, which will be * to the beginning of the decoding result, which will be
* stripped by calling `StripLeadingBlanks()`. * stripped by calling `StripLeadingBlanks()`.
*/ */
virtual OnlineTransducerDecoderResult GetEmptyResult() = 0; virtual OnlineTransducerDecoderResult GetEmptyResult() const = 0;
/** Strip blanks added by `GetEmptyResult()`. /** Strip blanks added by `GetEmptyResult()`.
* *
* @param r It is changed in-place. * @param r It is changed in-place.
*/ */
virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) {} virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) const {
}
/** Run transducer beam search given the output from the encoder model. /** Run transducer beam search given the output from the encoder model.
* *

View File

@@ -33,7 +33,7 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
} }
OnlineTransducerDecoderResult OnlineTransducerDecoderResult
OnlineTransducerGreedySearchDecoder::GetEmptyResult() { OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
int32_t context_size = model_->ContextSize(); int32_t context_size = model_->ContextSize();
int32_t blank_id = 0; // always 0 int32_t blank_id = 0; // always 0
OnlineTransducerDecoderResult r; OnlineTransducerDecoderResult r;
@@ -43,7 +43,7 @@ OnlineTransducerGreedySearchDecoder::GetEmptyResult() {
} }
void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks( void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
OnlineTransducerDecoderResult *r) { OnlineTransducerDecoderResult *r) const {
int32_t context_size = model_->ContextSize(); int32_t context_size = model_->ContextSize();
auto start = r->tokens.begin() + context_size; auto start = r->tokens.begin() + context_size;

View File

@@ -17,9 +17,9 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model) explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model)
: model_(model) {} : model_(model) {}
OnlineTransducerDecoderResult GetEmptyResult() override; OnlineTransducerDecoderResult GetEmptyResult() const override;
void StripLeadingBlanks(OnlineTransducerDecoderResult *r) override; void StripLeadingBlanks(OnlineTransducerDecoderResult *r) const override;
void Decode(Ort::Value encoder_out, void Decode(Ort::Value encoder_out,
std::vector<OnlineTransducerDecoderResult> *result) override; std::vector<OnlineTransducerDecoderResult> *result) override;

View File

@@ -8,6 +8,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-stream.h" #include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h"
@@ -35,35 +36,26 @@ for a list of pre-trained models to download.
return 0; return 0;
} }
std::string tokens = argv[1]; sherpa_onnx::OnlineRecognizerConfig config;
sherpa_onnx::OnlineTransducerModelConfig config;
config.debug = false; config.tokens = argv[1];
config.encoder_filename = argv[2];
config.decoder_filename = argv[3]; config.model_config.debug = false;
config.joiner_filename = argv[4]; config.model_config.encoder_filename = argv[2];
config.model_config.decoder_filename = argv[3];
config.model_config.joiner_filename = argv[4];
std::string wav_filename = argv[5]; std::string wav_filename = argv[5];
config.num_threads = 2; config.model_config.num_threads = 2;
if (argc == 7) { if (argc == 7) {
config.num_threads = atoi(argv[6]); config.model_config.num_threads = atoi(argv[6]);
} }
fprintf(stderr, "%s\n", config.ToString().c_str()); fprintf(stderr, "%s\n", config.ToString().c_str());
auto model = sherpa_onnx::OnlineTransducerModel::Create(config); sherpa_onnx::OnlineRecognizer recognizer(config);
sherpa_onnx::SymbolTable sym(tokens); float expected_sampling_rate = config.feat_config.sampling_rate;
Ort::AllocatorWithDefaultOptions allocator;
int32_t chunk_size = model->ChunkSize();
int32_t chunk_shift = model->ChunkShift();
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::vector<Ort::Value> states = model->GetEncoderInitStates();
float expected_sampling_rate = 16000;
bool is_ok = false; bool is_ok = false;
std::vector<float> samples = std::vector<float> samples =
@@ -82,44 +74,21 @@ for a list of pre-trained models to download.
auto begin = std::chrono::steady_clock::now(); auto begin = std::chrono::steady_clock::now();
fprintf(stderr, "Started\n"); fprintf(stderr, "Started\n");
sherpa_onnx::OnlineStream stream; auto s = recognizer.CreateStream();
stream.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size()); s->AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
std::vector<float> tail_paddings( std::vector<float> tail_paddings(
static_cast<int>(0.2 * expected_sampling_rate)); static_cast<int>(0.2 * expected_sampling_rate));
stream.AcceptWaveform(expected_sampling_rate, tail_paddings.data(), s->AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
tail_paddings.size()); tail_paddings.size());
stream.InputFinished(); s->InputFinished();
int32_t num_frames = stream.NumFramesReady(); while (recognizer.IsReady(s.get())) {
int32_t feature_dim = stream.FeatureDim(); recognizer.DecodeStream(s.get());
std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
sherpa_onnx::OnlineTransducerGreedySearchDecoder decoder(model.get());
std::vector<sherpa_onnx::OnlineTransducerDecoderResult> result = {
decoder.GetEmptyResult()};
while (stream.NumFramesReady() - stream.GetNumProcessedFrames() >
chunk_size) {
std::vector<float> features =
stream.GetFrames(stream.GetNumProcessedFrames(), chunk_size);
stream.GetNumProcessedFrames() += chunk_shift;
Ort::Value x =
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
x_shape.data(), x_shape.size());
auto pair = model->RunEncoder(std::move(x), states);
states = std::move(pair.second);
decoder.Decode(std::move(pair.first), &result);
}
decoder.StripLeadingBlanks(&result[0]);
const auto &hyp = result[0].tokens;
std::string text;
for (auto t : hyp) {
text += sym[t];
} }
std::string text = recognizer.GetResult(s.get()).text;
fprintf(stderr, "Done!\n"); fprintf(stderr, "Done!\n");
fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(), fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(),
@@ -131,7 +100,7 @@ for a list of pre-trained models to download.
.count() / .count() /
1000.; 1000.;
fprintf(stderr, "num threads: %d\n", config.num_threads); fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
float rtf = elapsed_seconds / duration; float rtf = elapsed_seconds / duration;