add online-recognizer (#29)
This commit is contained in:
@@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR})
|
||||
add_executable(sherpa-onnx
|
||||
features.cc
|
||||
online-lstm-transducer-model.cc
|
||||
online-recognizer.cc
|
||||
online-stream.cc
|
||||
online-transducer-greedy-search-decoder.cc
|
||||
online-transducer-model-config.cc
|
||||
|
||||
@@ -7,12 +7,23 @@
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <mutex> // NOLINT
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Render this configuration as a single human-readable line of the form
// "FeatureExtractorConfig(sampling_rate=..., feature_dim=...)".
std::string FeatureExtractorConfig::ToString() const {
  std::ostringstream out;
  out << "FeatureExtractorConfig("
      << "sampling_rate=" << sampling_rate << ", "
      << "feature_dim=" << feature_dim << ")";
  return out.str();
}
|
||||
|
||||
class FeatureExtractor::Impl {
|
||||
public:
|
||||
explicit Impl(const FeatureExtractorConfig &config) {
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#define SHERPA_ONNX_CSRC_FEATURES_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -13,6 +14,8 @@ namespace sherpa_onnx {
|
||||
// Options for the feature extractor.
struct FeatureExtractorConfig {
  // Sampling rate of the input audio; defaults to 16000.
  float sampling_rate = 16000;

  // Dimension of each extracted feature frame; defaults to 80.
  int32_t feature_dim = 80;

  // Returns a printable, single-line description of this configuration.
  std::string ToString() const;
};
|
||||
|
||||
class FeatureExtractor {
|
||||
|
||||
136
sherpa-onnx/csrc/online-recognizer.cc
Normal file
136
sherpa-onnx/csrc/online-recognizer.cc
Normal file
@@ -0,0 +1,136 @@
|
||||
// sherpa-onnx/csrc/online-recognizer.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-recognizer.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <array>
#include <memory>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Turn a decoder result into a recognizer result by looking up every token
// ID of `src` in `sym_table` and concatenating the symbols in order.
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
                                      const SymbolTable &sym_table) {
  OnlineRecognizerResult ans;
  for (auto token : src.tokens) {
    ans.text += sym_table[token];
  }
  return ans;
}
|
||||
|
||||
// Render this configuration, including the nested feature and model
// configurations, as a single human-readable line.
std::string OnlineRecognizerConfig::ToString() const {
  // Evaluate the nested descriptions first, in the same order as they are
  // printed.
  const std::string feat_str = feat_config.ToString();
  const std::string model_str = model_config.ToString();

  std::ostringstream out;
  out << "OnlineRecognizerConfig("
      << "feat_config=" << feat_str << ", "
      << "model_config=" << model_str << ", "
      << "tokens=\"" << tokens << "\")";
  return out.str();
}
|
||||
|
||||
/// Private implementation of OnlineRecognizer.
///
/// Owns the transducer model, a greedy-search decoder that uses that model,
/// and the symbol table used to map token IDs to text.
class OnlineRecognizer::Impl {
 public:
  /// Creates the model from `config.model_config`, constructs the symbol
  /// table from `config.tokens` and builds a greedy-search decoder that
  /// borrows the model.
  explicit Impl(const OnlineRecognizerConfig &config)
      : config_(config),
        model_(OnlineTransducerModel::Create(config.model_config)),
        // model_ is declared before decoder_, so it is already initialized
        // when the decoder captures the raw pointer.
        decoder_(std::make_unique<OnlineTransducerGreedySearchDecoder>(
            model_.get())),
        sym_(config.tokens) {}

  /// Returns a new stream configured with `config_.feat_config`, seeded with
  /// the decoder's empty result and the model's initial encoder states.
  std::unique_ptr<OnlineStream> CreateStream() const {
    auto stream = std::make_unique<OnlineStream>(config_.feat_config);
    stream->SetResult(decoder_->GetEmptyResult());
    stream->SetStates(model_->GetEncoderInitStates());
    return stream;
  }

  /// Returns true when `s` holds strictly more than one chunk
  /// (`model_->ChunkSize()` frames) beyond the frames already processed.
  bool IsReady(OnlineStream *s) const {
    return s->GetNumProcessedFrames() + model_->ChunkSize() <
           s->NumFramesReady();
  }

  /// Decodes one chunk for each of the `n` streams in `ss`.
  ///
  /// Streams are handled one after another (no batching), so any `n` is
  /// accepted; `n <= 0` is a no-op. Previously every `n != 1` terminated the
  /// whole process via exit(-1).
  ///
  /// @param ss Array of `n` stream pointers. Each stream must satisfy
  ///           IsReady(); this is only checked by an assert.
  /// @param n  Number of entries in `ss`.
  void DecodeStreams(OnlineStream **ss, int32_t n) {
    for (int32_t i = 0; i < n; ++i) {
      DecodeOneStream(ss[i]);
    }
  }

  /// Returns the text recognized so far for `s`.
  ///
  /// Works on a copy of the stream's decoder result, so stripping the
  /// leading blanks does not modify the stream itself.
  OnlineRecognizerResult GetResult(OnlineStream *s) const {
    OnlineTransducerDecoderResult decoder_result = s->GetResult();
    decoder_->StripLeadingBlanks(&decoder_result);

    return Convert(decoder_result, sym_);
  }

 private:
  /// Runs the encoder and the decoder on the next chunk of `s` and stores
  /// the updated encoder states and decoding result back into `s`.
  void DecodeOneStream(OnlineStream *s) {
    assert(IsReady(s));

    const int32_t chunk_size = model_->ChunkSize();
    const int32_t chunk_shift = model_->ChunkShift();
    const int32_t feature_dim = s->FeatureDim();

    // Encoder input shape: (batch = 1, chunk_size, feature_dim).
    std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // `features` stays in scope until after the encoder has run; the tensor
    // below is created from its data pointer.
    std::vector<float> features =
        s->GetFrames(s->GetNumProcessedFrames(), chunk_size);

    // Advance by the shift, not by the chunk size.
    s->GetNumProcessedFrames() += chunk_shift;

    Ort::Value x =
        Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
                                 x_shape.data(), x_shape.size());

    // pair.first is the encoder output, pair.second the next encoder states.
    auto pair = model_->RunEncoder(std::move(x), s->GetStates());
    s->SetStates(std::move(pair.second));

    // The decoder works on a vector of results; wrap the stream's single
    // result, decode, then write it back.
    std::vector<OnlineTransducerDecoderResult> results = {s->GetResult()};
    decoder_->Decode(std::move(pair.first), &results);
    s->SetResult(results[0]);
  }

 private:
  OnlineRecognizerConfig config_;
  std::unique_ptr<OnlineTransducerModel> model_;
  std::unique_ptr<OnlineTransducerDecoder> decoder_;
  SymbolTable sym_;
};
|
||||
|
||||
OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
OnlineRecognizer::~OnlineRecognizer() = default;
|
||||
|
||||
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
|
||||
return impl_->CreateStream();
|
||||
}
|
||||
|
||||
// Forward to Impl::IsReady() to tell whether `s` can be decoded now.
bool OnlineRecognizer::IsReady(OnlineStream *s) const {
  const Impl &impl = *impl_;
  return impl.IsReady(s);
}
|
||||
|
||||
// Forward the `n` streams in `ss` to Impl::DecodeStreams().
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) {
  Impl &impl = *impl_;
  impl.DecodeStreams(ss, n);
}
|
||||
|
||||
// Forward to Impl::GetResult() to obtain the current result for `s`.
OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) {
  Impl &impl = *impl_;
  return impl.GetResult(s);
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
65
sherpa-onnx/csrc/online-recognizer.h
Normal file
65
sherpa-onnx/csrc/online-recognizer.h
Normal file
@@ -0,0 +1,65 @@
|
||||
// sherpa-onnx/csrc/online-recognizer.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/// Recognition result returned by OnlineRecognizer::GetResult().
struct OnlineRecognizerResult {
  /// The recognized text.
  std::string text;
};
|
||||
|
||||
struct OnlineRecognizerConfig {
|
||||
FeatureExtractorConfig feat_config;
|
||||
OnlineTransducerModelConfig model_config;
|
||||
std::string tokens;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
class OnlineRecognizer {
|
||||
public:
|
||||
explicit OnlineRecognizer(const OnlineRecognizerConfig &config);
|
||||
~OnlineRecognizer();
|
||||
|
||||
/// Create a stream for decoding.
|
||||
std::unique_ptr<OnlineStream> CreateStream() const;
|
||||
|
||||
/**
|
||||
* Return true if the given stream has enough frames for decoding.
|
||||
* Return false otherwise
|
||||
*/
|
||||
bool IsReady(OnlineStream *s) const;
|
||||
|
||||
/** Decode a single stream. */
|
||||
void DecodeStream(OnlineStream *s) {
|
||||
OnlineStream *ss[1] = {s};
|
||||
DecodeStreams(ss, 1);
|
||||
}
|
||||
|
||||
/** Decode multiple streams in parallel
|
||||
*
|
||||
* @param ss Pointer array containing streams to be decoded.
|
||||
* @param n Number of streams in `ss`.
|
||||
*/
|
||||
void DecodeStreams(OnlineStream **ss, int32_t n);
|
||||
|
||||
OnlineRecognizerResult GetResult(OnlineStream *s);
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
@@ -41,10 +42,17 @@ class OnlineStream::Impl {
|
||||
|
||||
int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
|
||||
|
||||
void SetStates(std::vector<Ort::Value> states) {
|
||||
states_ = std::move(states);
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> &GetStates() { return states_; }
|
||||
|
||||
private:
|
||||
FeatureExtractor feat_extractor_;
|
||||
int32_t num_processed_frames_ = 0; // before subsampling
|
||||
OnlineTransducerDecoderResult result_;
|
||||
std::vector<Ort::Value> states_;
|
||||
};
|
||||
|
||||
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
|
||||
@@ -86,4 +94,12 @@ const OnlineTransducerDecoderResult &OnlineStream::GetResult() const {
|
||||
return impl_->GetResult();
|
||||
}
|
||||
|
||||
// Take ownership of `states` and hand them to the implementation, which
// stores them as the stream's states.
void OnlineStream::SetStates(std::vector<Ort::Value> states) {
  auto &impl = *impl_;
  impl.SetStates(std::move(states));
}
|
||||
|
||||
std::vector<Ort::Value> &OnlineStream::GetStates() {
|
||||
return impl_->GetStates();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
|
||||
@@ -63,6 +64,9 @@ class OnlineStream {
|
||||
void SetResult(const OnlineTransducerDecoderResult &r);
|
||||
const OnlineTransducerDecoderResult &GetResult() const;
|
||||
|
||||
void SetStates(std::vector<Ort::Value> states);
|
||||
std::vector<Ort::Value> &GetStates();
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
|
||||
@@ -26,13 +26,14 @@ class OnlineTransducerDecoder {
|
||||
* to the beginning of the decoding result, which will be
|
||||
* stripped by calling `StripPrecedingBlanks()`.
|
||||
*/
|
||||
virtual OnlineTransducerDecoderResult GetEmptyResult() = 0;
|
||||
virtual OnlineTransducerDecoderResult GetEmptyResult() const = 0;
|
||||
|
||||
/** Strip blanks added by `GetEmptyResult()`.
|
||||
*
|
||||
* @param r It is changed in-place.
|
||||
*/
|
||||
virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) {}
|
||||
virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) const {
|
||||
}
|
||||
|
||||
/** Run transducer beam search given the output from the encoder model.
|
||||
*
|
||||
|
||||
@@ -33,7 +33,7 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
|
||||
}
|
||||
|
||||
OnlineTransducerDecoderResult
|
||||
OnlineTransducerGreedySearchDecoder::GetEmptyResult() {
|
||||
OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
|
||||
int32_t context_size = model_->ContextSize();
|
||||
int32_t blank_id = 0; // always 0
|
||||
OnlineTransducerDecoderResult r;
|
||||
@@ -43,7 +43,7 @@ OnlineTransducerGreedySearchDecoder::GetEmptyResult() {
|
||||
}
|
||||
|
||||
void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
|
||||
OnlineTransducerDecoderResult *r) {
|
||||
OnlineTransducerDecoderResult *r) const {
|
||||
int32_t context_size = model_->ContextSize();
|
||||
|
||||
auto start = r->tokens.begin() + context_size;
|
||||
|
||||
@@ -17,9 +17,9 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
||||
explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model)
|
||||
: model_(model) {}
|
||||
|
||||
OnlineTransducerDecoderResult GetEmptyResult() override;
|
||||
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
||||
|
||||
void StripLeadingBlanks(OnlineTransducerDecoderResult *r) override;
|
||||
void StripLeadingBlanks(OnlineTransducerDecoderResult *r) const override;
|
||||
|
||||
void Decode(Ort::Value encoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) override;
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/online-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
@@ -35,35 +36,26 @@ for a list of pre-trained models to download.
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string tokens = argv[1];
|
||||
sherpa_onnx::OnlineTransducerModelConfig config;
|
||||
config.debug = false;
|
||||
config.encoder_filename = argv[2];
|
||||
config.decoder_filename = argv[3];
|
||||
config.joiner_filename = argv[4];
|
||||
sherpa_onnx::OnlineRecognizerConfig config;
|
||||
|
||||
config.tokens = argv[1];
|
||||
|
||||
config.model_config.debug = false;
|
||||
config.model_config.encoder_filename = argv[2];
|
||||
config.model_config.decoder_filename = argv[3];
|
||||
config.model_config.joiner_filename = argv[4];
|
||||
|
||||
std::string wav_filename = argv[5];
|
||||
|
||||
config.num_threads = 2;
|
||||
config.model_config.num_threads = 2;
|
||||
if (argc == 7) {
|
||||
config.num_threads = atoi(argv[6]);
|
||||
config.model_config.num_threads = atoi(argv[6]);
|
||||
}
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
auto model = sherpa_onnx::OnlineTransducerModel::Create(config);
|
||||
sherpa_onnx::OnlineRecognizer recognizer(config);
|
||||
|
||||
sherpa_onnx::SymbolTable sym(tokens);
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
|
||||
int32_t chunk_size = model->ChunkSize();
|
||||
int32_t chunk_shift = model->ChunkShift();
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::vector<Ort::Value> states = model->GetEncoderInitStates();
|
||||
|
||||
float expected_sampling_rate = 16000;
|
||||
float expected_sampling_rate = config.feat_config.sampling_rate;
|
||||
|
||||
bool is_ok = false;
|
||||
std::vector<float> samples =
|
||||
@@ -82,44 +74,21 @@ for a list of pre-trained models to download.
|
||||
auto begin = std::chrono::steady_clock::now();
|
||||
fprintf(stderr, "Started\n");
|
||||
|
||||
sherpa_onnx::OnlineStream stream;
|
||||
stream.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
|
||||
auto s = recognizer.CreateStream();
|
||||
s->AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
|
||||
|
||||
std::vector<float> tail_paddings(
|
||||
static_cast<int>(0.2 * expected_sampling_rate));
|
||||
stream.AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
|
||||
tail_paddings.size());
|
||||
stream.InputFinished();
|
||||
s->AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
|
||||
tail_paddings.size());
|
||||
s->InputFinished();
|
||||
|
||||
int32_t num_frames = stream.NumFramesReady();
|
||||
int32_t feature_dim = stream.FeatureDim();
|
||||
|
||||
std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
|
||||
|
||||
sherpa_onnx::OnlineTransducerGreedySearchDecoder decoder(model.get());
|
||||
std::vector<sherpa_onnx::OnlineTransducerDecoderResult> result = {
|
||||
decoder.GetEmptyResult()};
|
||||
while (stream.NumFramesReady() - stream.GetNumProcessedFrames() >
|
||||
chunk_size) {
|
||||
std::vector<float> features =
|
||||
stream.GetFrames(stream.GetNumProcessedFrames(), chunk_size);
|
||||
stream.GetNumProcessedFrames() += chunk_shift;
|
||||
|
||||
Ort::Value x =
|
||||
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
|
||||
x_shape.data(), x_shape.size());
|
||||
|
||||
auto pair = model->RunEncoder(std::move(x), states);
|
||||
states = std::move(pair.second);
|
||||
decoder.Decode(std::move(pair.first), &result);
|
||||
}
|
||||
decoder.StripLeadingBlanks(&result[0]);
|
||||
const auto &hyp = result[0].tokens;
|
||||
std::string text;
|
||||
for (auto t : hyp) {
|
||||
text += sym[t];
|
||||
while (recognizer.IsReady(s.get())) {
|
||||
recognizer.DecodeStream(s.get());
|
||||
}
|
||||
|
||||
std::string text = recognizer.GetResult(s.get()).text;
|
||||
|
||||
fprintf(stderr, "Done!\n");
|
||||
|
||||
fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(),
|
||||
@@ -131,7 +100,7 @@ for a list of pre-trained models to download.
|
||||
.count() /
|
||||
1000.;
|
||||
|
||||
fprintf(stderr, "num threads: %d\n", config.num_threads);
|
||||
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
|
||||
|
||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||
float rtf = elapsed_seconds / duration;
|
||||
|
||||
Reference in New Issue
Block a user