diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 759b09c3..c89c1832 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR}) add_executable(sherpa-onnx features.cc online-lstm-transducer-model.cc + online-recognizer.cc online-stream.cc online-transducer-greedy-search-decoder.cc online-transducer-model-config.cc diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc index 3201f4fe..da7074f8 100644 --- a/sherpa-onnx/csrc/features.cc +++ b/sherpa-onnx/csrc/features.cc @@ -7,12 +7,23 @@ #include #include #include // NOLINT +#include #include #include "kaldi-native-fbank/csrc/online-feature.h" namespace sherpa_onnx { +std::string FeatureExtractorConfig::ToString() const { + std::ostringstream os; + + os << "FeatureExtractorConfig("; + os << "sampling_rate=" << sampling_rate << ", "; + os << "feature_dim=" << feature_dim << ")"; + + return os.str(); +} + class FeatureExtractor::Impl { public: explicit Impl(const FeatureExtractorConfig &config) { diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h index 807bd48c..59f07188 100644 --- a/sherpa-onnx/csrc/features.h +++ b/sherpa-onnx/csrc/features.h @@ -6,6 +6,7 @@ #define SHERPA_ONNX_CSRC_FEATURES_H_ #include +#include #include namespace sherpa_onnx { @@ -13,6 +14,8 @@ namespace sherpa_onnx { struct FeatureExtractorConfig { float sampling_rate = 16000; int32_t feature_dim = 80; + + std::string ToString() const; }; class FeatureExtractor { diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc new file mode 100644 index 00000000..3a7b42cf --- /dev/null +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -0,0 +1,136 @@ +// sherpa-onnx/csrc/online-recognizer.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-recognizer.h" + +#include + +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/online-transducer-decoder.h" +#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" +#include "sherpa-onnx/csrc/online-transducer-model.h" +#include "sherpa-onnx/csrc/symbol-table.h" + +namespace sherpa_onnx { + +static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, + const SymbolTable &sym_table) { + std::string text; + for (auto t : src.tokens) { + text += sym_table[t]; + } + + OnlineRecognizerResult ans; + ans.text = std::move(text); + return ans; +} + +std::string OnlineRecognizerConfig::ToString() const { + std::ostringstream os; + + os << "OnlineRecognizerConfig("; + os << "feat_config=" << feat_config.ToString() << ", "; + os << "model_config=" << model_config.ToString() << ", "; + os << "tokens=\"" << tokens << "\")"; + + return os.str(); +} + +class OnlineRecognizer::Impl { + public: + explicit Impl(const OnlineRecognizerConfig &config) + : config_(config), + model_(OnlineTransducerModel::Create(config.model_config)), + sym_(config.tokens) { + decoder_ = + std::make_unique(model_.get()); + } + + std::unique_ptr CreateStream() const { + auto stream = std::make_unique(config_.feat_config); + stream->SetResult(decoder_->GetEmptyResult()); + stream->SetStates(model_->GetEncoderInitStates()); + return stream; + } + + bool IsReady(OnlineStream *s) const { + return s->GetNumProcessedFrames() + model_->ChunkSize() < + s->NumFramesReady(); + } + + void DecodeStreams(OnlineStream **ss, int32_t n) { + if (n != 1) { + fprintf(stderr, "only n == 1 is implemented\n"); + exit(-1); + } + OnlineStream *s = ss[0]; + assert(IsReady(s)); + + int32_t chunk_size = model_->ChunkSize(); + int32_t chunk_shift = model_->ChunkShift(); + + int32_t feature_dim = s->FeatureDim(); + + std::array x_shape{1, chunk_size, feature_dim}; + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::vector features = + s->GetFrames(s->GetNumProcessedFrames(), chunk_size); + + s->GetNumProcessedFrames() += chunk_shift; + + Ort::Value x = + Ort::Value::CreateTensor(memory_info, features.data(), features.size(), + x_shape.data(), x_shape.size()); + + auto pair = model_->RunEncoder(std::move(x), s->GetStates()); + + s->SetStates(std::move(pair.second)); + std::vector results = {s->GetResult()}; + + decoder_->Decode(std::move(pair.first), &results); + s->SetResult(results[0]); + } + + OnlineRecognizerResult GetResult(OnlineStream *s) { + OnlineTransducerDecoderResult decoder_result = s->GetResult(); + decoder_->StripLeadingBlanks(&decoder_result); + + return Convert(decoder_result, sym_); + } + + private: + OnlineRecognizerConfig config_; + std::unique_ptr model_; + std::unique_ptr decoder_; + SymbolTable sym_; +}; + +OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config) + : impl_(std::make_unique(config)) {} +OnlineRecognizer::~OnlineRecognizer() = default; + +std::unique_ptr OnlineRecognizer::CreateStream() const { + return impl_->CreateStream(); +} + +bool OnlineRecognizer::IsReady(OnlineStream *s) const { + return impl_->IsReady(s); +} + +void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) { + impl_->DecodeStreams(ss, n); +} + +OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) { + return impl_->GetResult(s); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h new file mode 100644 index 00000000..f9622452 --- /dev/null +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -0,0 +1,65 @@ +// sherpa-onnx/csrc/online-recognizer.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_ + +#include +#include + +#include "sherpa-onnx/csrc/features.h" +#include "sherpa-onnx/csrc/online-stream.h" +#include "sherpa-onnx/csrc/online-transducer-model-config.h" + +namespace sherpa_onnx { + +struct OnlineRecognizerResult { + std::string text; +}; + +struct OnlineRecognizerConfig { + FeatureExtractorConfig feat_config; + OnlineTransducerModelConfig model_config; + std::string tokens; + + std::string ToString() const; +}; + +class OnlineRecognizer { + public: + explicit OnlineRecognizer(const OnlineRecognizerConfig &config); + ~OnlineRecognizer(); + + /// Create a stream for decoding. + std::unique_ptr CreateStream() const; + + /** + * Return true if the given stream has enough frames for decoding. + * Return false otherwise + */ + bool IsReady(OnlineStream *s) const; + + /** Decode a single stream. */ + void DecodeStream(OnlineStream *s) { + OnlineStream *ss[1] = {s}; + DecodeStreams(ss, 1); + } + + /** Decode multiple streams in parallel + * + * @param ss Pointer array containing streams to be decoded. + * @param n Number of streams in `ss`. + */ + void DecodeStreams(OnlineStream **ss, int32_t n); + + OnlineRecognizerResult GetResult(OnlineStream *s); + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_ diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc index 66f835a5..e54219c7 100644 --- a/sherpa-onnx/csrc/online-stream.cc +++ b/sherpa-onnx/csrc/online-stream.cc @@ -4,6 +4,7 @@ #include "sherpa-onnx/csrc/online-stream.h" #include +#include #include #include "sherpa-onnx/csrc/features.h" @@ -41,10 +42,17 @@ class OnlineStream::Impl { int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); } + void SetStates(std::vector states) { + states_ = std::move(states); + } + + std::vector &GetStates() { return states_; } + private: FeatureExtractor feat_extractor_; int32_t num_processed_frames_ = 0; // before subsampling OnlineTransducerDecoderResult result_; + std::vector states_; }; OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/) @@ -86,4 +94,12 @@ const OnlineTransducerDecoderResult &OnlineStream::GetResult() const { return impl_->GetResult(); } +void OnlineStream::SetStates(std::vector states) { + impl_->SetStates(std::move(states)); +} + +std::vector &OnlineStream::GetStates() { + return impl_->GetStates(); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h index bf470fb7..a945aa32 100644 --- a/sherpa-onnx/csrc/online-stream.h +++ b/sherpa-onnx/csrc/online-stream.h @@ -8,6 +8,7 @@ #include #include +#include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/features.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h" @@ -63,6 +64,9 @@ class OnlineStream { void SetResult(const OnlineTransducerDecoderResult &r); const OnlineTransducerDecoderResult &GetResult() const; + void SetStates(std::vector states); + std::vector &GetStates(); + private: class Impl; std::unique_ptr impl_; diff --git a/sherpa-onnx/csrc/online-transducer-decoder.h b/sherpa-onnx/csrc/online-transducer-decoder.h index d4dfd109..1c72fd1b 100644 --- a/sherpa-onnx/csrc/online-transducer-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-decoder.h @@ -26,13 +26,14 @@ class OnlineTransducerDecoder { * to the beginning of the decoding result, which will be * stripped by calling `StripPrecedingBlanks()`. */ - virtual OnlineTransducerDecoderResult GetEmptyResult() = 0; + virtual OnlineTransducerDecoderResult GetEmptyResult() const = 0; /** Strip blanks added by `GetEmptyResult()`. * * @param r It is changed in-place. */ - virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) {} + virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) const { + } /** Run transducer beam search given the output from the encoder model. * diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index 9ef41a1e..e1aafbca 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -33,7 +33,7 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) { } OnlineTransducerDecoderResult -OnlineTransducerGreedySearchDecoder::GetEmptyResult() { +OnlineTransducerGreedySearchDecoder::GetEmptyResult() const { int32_t context_size = model_->ContextSize(); int32_t blank_id = 0; // always 0 OnlineTransducerDecoderResult r; @@ -43,7 +43,7 @@ OnlineTransducerGreedySearchDecoder::GetEmptyResult() { } void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks( - OnlineTransducerDecoderResult *r) { + OnlineTransducerDecoderResult *r) const { int32_t context_size = model_->ContextSize(); auto start = r->tokens.begin() + context_size; diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h index 26e35238..23b507f2 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h @@ -17,9 +17,9 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model) : model_(model) {} - OnlineTransducerDecoderResult GetEmptyResult() override; + OnlineTransducerDecoderResult GetEmptyResult() const override; - void StripLeadingBlanks(OnlineTransducerDecoderResult *r) override; + void StripLeadingBlanks(OnlineTransducerDecoderResult *r) const override; void Decode(Ort::Value encoder_out, std::vector *result) override; diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index 89c3098c..7b838b53 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -8,6 +8,7 @@ #include #include +#include "sherpa-onnx/csrc/online-recognizer.h" #include "sherpa-onnx/csrc/online-stream.h" #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h" @@ -35,35 +36,26 @@ for a list of pre-trained models to download. return 0; } - std::string tokens = argv[1]; - sherpa_onnx::OnlineTransducerModelConfig config; - config.debug = false; - config.encoder_filename = argv[2]; - config.decoder_filename = argv[3]; - config.joiner_filename = argv[4]; + sherpa_onnx::OnlineRecognizerConfig config; + + config.tokens = argv[1]; + + config.model_config.debug = false; + config.model_config.encoder_filename = argv[2]; + config.model_config.decoder_filename = argv[3]; + config.model_config.joiner_filename = argv[4]; + std::string wav_filename = argv[5]; - config.num_threads = 2; + config.model_config.num_threads = 2; if (argc == 7) { - config.num_threads = atoi(argv[6]); + config.model_config.num_threads = atoi(argv[6]); } fprintf(stderr, "%s\n", config.ToString().c_str()); - auto model = sherpa_onnx::OnlineTransducerModel::Create(config); + sherpa_onnx::OnlineRecognizer recognizer(config); - sherpa_onnx::SymbolTable sym(tokens); - - Ort::AllocatorWithDefaultOptions allocator; - - int32_t chunk_size = model->ChunkSize(); - int32_t chunk_shift = model->ChunkShift(); - - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - - std::vector states = model->GetEncoderInitStates(); - - float expected_sampling_rate = 16000; + float expected_sampling_rate = config.feat_config.sampling_rate; bool is_ok = false; std::vector samples = @@ -82,44 +74,21 @@ for a list of pre-trained models to download. auto begin = std::chrono::steady_clock::now(); fprintf(stderr, "Started\n"); - sherpa_onnx::OnlineStream stream; - stream.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size()); + auto s = recognizer.CreateStream(); + s->AcceptWaveform(expected_sampling_rate, samples.data(), samples.size()); std::vector tail_paddings( static_cast(0.2 * expected_sampling_rate)); - stream.AcceptWaveform(expected_sampling_rate, tail_paddings.data(), - tail_paddings.size()); - stream.InputFinished(); + s->AcceptWaveform(expected_sampling_rate, tail_paddings.data(), + tail_paddings.size()); + s->InputFinished(); - int32_t num_frames = stream.NumFramesReady(); - int32_t feature_dim = stream.FeatureDim(); - - std::array x_shape{1, chunk_size, feature_dim}; - - sherpa_onnx::OnlineTransducerGreedySearchDecoder decoder(model.get()); - std::vector result = { - decoder.GetEmptyResult()}; - while (stream.NumFramesReady() - stream.GetNumProcessedFrames() > - chunk_size) { - std::vector features = - stream.GetFrames(stream.GetNumProcessedFrames(), chunk_size); - stream.GetNumProcessedFrames() += chunk_shift; - - Ort::Value x = - Ort::Value::CreateTensor(memory_info, features.data(), features.size(), - x_shape.data(), x_shape.size()); - - auto pair = model->RunEncoder(std::move(x), states); - states = std::move(pair.second); - decoder.Decode(std::move(pair.first), &result); - } - decoder.StripLeadingBlanks(&result[0]); - const auto &hyp = result[0].tokens; - std::string text; - for (auto t : hyp) { - text += sym[t]; + while (recognizer.IsReady(s.get())) { + recognizer.DecodeStream(s.get()); } + std::string text = recognizer.GetResult(s.get()).text; + fprintf(stderr, "Done!\n"); fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(), @@ -131,7 +100,7 @@ for a list of pre-trained models to download. .count() / 1000.; - fprintf(stderr, "num threads: %d\n", config.num_threads); + fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); float rtf = elapsed_seconds / duration;