diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index f2b5fbd7..759b09c3 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR}) add_executable(sherpa-onnx features.cc online-lstm-transducer-model.cc + online-stream.cc online-transducer-greedy-search-decoder.cc online-transducer-model-config.cc online-transducer-model.cc diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h index e65a64ba..807bd48c 100644 --- a/sherpa-onnx/csrc/features.h +++ b/sherpa-onnx/csrc/features.h @@ -11,16 +11,12 @@ namespace sherpa_onnx { struct FeatureExtractorConfig { - int32_t sampling_rate = 16000; + float sampling_rate = 16000; int32_t feature_dim = 80; }; class FeatureExtractor { public: - /** - * @param sampling_rate Sampling rate of the data used to train the model. - * @param feature_dim Dimension of the features used to train the model. - */ explicit FeatureExtractor(const FeatureExtractorConfig &config = {}); ~FeatureExtractor(); @@ -32,16 +28,19 @@ class FeatureExtractor { */ void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n); - // InputFinished() tells the class you won't be providing any - // more waveform. This will help flush out the last frame or two - // of features, in the case where snip-edges == false; it also - // affects the return value of IsLastFrame(). + /** + * InputFinished() tells the class you won't be providing any + * more waveform. This will help flush out the last frame or two + * of features, in the case where snip-edges == false; it also + * affects the return value of IsLastFrame(). + */ void InputFinished(); int32_t NumFramesReady() const; - // Note: IsLastFrame() will only ever return true if you have called - // InputFinished() (and this frame is the last frame). + /** Note: IsLastFrame() will only ever return true if you have called + * InputFinished() (and this frame is the last frame). + */ bool IsLastFrame(int32_t frame) const; /** Get n frames starting from the given frame index. diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc new file mode 100644 index 00000000..66f835a5 --- /dev/null +++ b/sherpa-onnx/csrc/online-stream.cc @@ -0,0 +1,89 @@ +// sherpa-onnx/csrc/online-stream.cc +// +// Copyright (c) 2023 Xiaomi Corporation +#include "sherpa-onnx/csrc/online-stream.h" + +#include +#include + +#include "sherpa-onnx/csrc/features.h" + +namespace sherpa_onnx { + +class OnlineStream::Impl { + public: + explicit Impl(const FeatureExtractorConfig &config) + : feat_extractor_(config) {} + + void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) { + feat_extractor_.AcceptWaveform(sampling_rate, waveform, n); + } + + void InputFinished() { feat_extractor_.InputFinished(); } + + int32_t NumFramesReady() const { return feat_extractor_.NumFramesReady(); } + + bool IsLastFrame(int32_t frame) const { + return feat_extractor_.IsLastFrame(frame); + } + + std::vector GetFrames(int32_t frame_index, int32_t n) const { + return feat_extractor_.GetFrames(frame_index, n); + } + + void Reset() { feat_extractor_.Reset(); } + + int32_t &GetNumProcessedFrames() { return num_processed_frames_; } + + void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } + + const OnlineTransducerDecoderResult &GetResult() const { return result_; } + + int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); } + + private: + FeatureExtractor feat_extractor_; + int32_t num_processed_frames_ = 0; // before subsampling + OnlineTransducerDecoderResult result_; +}; + +OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/) + : impl_(std::make_unique(config)) {} + +OnlineStream::~OnlineStream() = default; + +void OnlineStream::AcceptWaveform(float sampling_rate, const float *waveform, + int32_t n) { + impl_->AcceptWaveform(sampling_rate, waveform, n); +} + +void OnlineStream::InputFinished() { impl_->InputFinished(); } + +int32_t OnlineStream::NumFramesReady() const { return impl_->NumFramesReady(); } + +bool OnlineStream::IsLastFrame(int32_t frame) const { + return impl_->IsLastFrame(frame); +} + +std::vector OnlineStream::GetFrames(int32_t frame_index, + int32_t n) const { + return impl_->GetFrames(frame_index, n); +} + +void OnlineStream::Reset() { impl_->Reset(); } + +int32_t OnlineStream::FeatureDim() const { return impl_->FeatureDim(); } + +int32_t &OnlineStream::GetNumProcessedFrames() { + return impl_->GetNumProcessedFrames(); +} + +void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) { + impl_->SetResult(r); +} + +const OnlineTransducerDecoderResult &OnlineStream::GetResult() const { + return impl_->GetResult(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h new file mode 100644 index 00000000..bf470fb7 --- /dev/null +++ b/sherpa-onnx/csrc/online-stream.h @@ -0,0 +1,73 @@ +// sherpa-onnx/csrc/online-stream.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_STREAM_H_ +#define SHERPA_ONNX_CSRC_ONLINE_STREAM_H_ + +#include +#include + +#include "sherpa-onnx/csrc/features.h" +#include "sherpa-onnx/csrc/online-transducer-decoder.h" + +namespace sherpa_onnx { + +class OnlineStream { + public: + explicit OnlineStream(const FeatureExtractorConfig &config = {}); + ~OnlineStream(); + + /** + @param sampling_rate The sampling_rate of the input waveform. Should match + the one expected by the feature extractor. + @param waveform Pointer to a 1-D array of size n + @param n Number of entries in waveform + */ + void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n); + + /** + * InputFinished() tells the class you won't be providing any + * more waveform. This will help flush out the last frame or two + * of features, in the case where snip-edges == false; it also + * affects the return value of IsLastFrame(). + */ + void InputFinished(); + + int32_t NumFramesReady() const; + + /** Note: IsLastFrame() will only ever return true if you have called + * InputFinished() (and this frame is the last frame). + */ + bool IsLastFrame(int32_t frame) const; + + /** Get n frames starting from the given frame index. + * + * @param frame_index The starting frame index + * @param n Number of frames to get. + * @return Return a 2-D tensor of shape (n, feature_dim). + * which is flattened into a 1-D vector (flattened in in row major) + */ + std::vector GetFrames(int32_t frame_index, int32_t n) const; + + void Reset(); + + int32_t FeatureDim() const; + + // Return a reference to the number of processed frames so far. + // Initially, it is 0. It is always less than NumFramesReady(). + // + // The returned reference is valid as long as this object is alive. + int32_t &GetNumProcessedFrames(); + + void SetResult(const OnlineTransducerDecoderResult &r); + const OnlineTransducerDecoderResult &GetResult() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_STREAM_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index 6338317f..89c3098c 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -8,8 +8,7 @@ #include #include -#include "kaldi-native-fbank/csrc/online-feature.h" -#include "sherpa-onnx/csrc/features.h" +#include "sherpa-onnx/csrc/online-stream.h" #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model.h" @@ -64,7 +63,7 @@ for a list of pre-trained models to download. std::vector states = model->GetEncoderInitStates(); - int32_t expected_sampling_rate = 16000; + float expected_sampling_rate = 16000; bool is_ok = false; std::vector samples = @@ -75,7 +74,7 @@ for a list of pre-trained models to download. return -1; } - float duration = samples.size() / static_cast(expected_sampling_rate); + float duration = samples.size() / expected_sampling_rate; fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); fprintf(stderr, "wav duration (s): %.3f\n", duration); @@ -83,32 +82,33 @@ for a list of pre-trained models to download. auto begin = std::chrono::steady_clock::now(); fprintf(stderr, "Started\n"); - sherpa_onnx::FeatureExtractor feat_extractor; - feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(), - samples.size()); + sherpa_onnx::OnlineStream stream; + stream.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size()); std::vector tail_paddings( static_cast(0.2 * expected_sampling_rate)); - feat_extractor.AcceptWaveform(expected_sampling_rate, tail_paddings.data(), - tail_paddings.size()); - feat_extractor.InputFinished(); + stream.AcceptWaveform(expected_sampling_rate, tail_paddings.data(), + tail_paddings.size()); + stream.InputFinished(); - int32_t num_frames = feat_extractor.NumFramesReady(); - int32_t feature_dim = feat_extractor.FeatureDim(); + int32_t num_frames = stream.NumFramesReady(); + int32_t feature_dim = stream.FeatureDim(); std::array x_shape{1, chunk_size, feature_dim}; sherpa_onnx::OnlineTransducerGreedySearchDecoder decoder(model.get()); std::vector result = { decoder.GetEmptyResult()}; - - for (int32_t start = 0; start + chunk_size < num_frames; - start += chunk_shift) { - std::vector features = feat_extractor.GetFrames(start, chunk_size); + while (stream.NumFramesReady() - stream.GetNumProcessedFrames() > + chunk_size) { + std::vector features = + stream.GetFrames(stream.GetNumProcessedFrames(), chunk_size); + stream.GetNumProcessedFrames() += chunk_shift; Ort::Value x = Ort::Value::CreateTensor(memory_info, features.data(), features.size(), x_shape.data(), x_shape.size()); + auto pair = model->RunEncoder(std::move(x), states); states = std::move(pair.second); decoder.Decode(std::move(pair.first), &result); @@ -116,8 +116,8 @@ for a list of pre-trained models to download. decoder.StripLeadingBlanks(&result[0]); const auto &hyp = result[0].tokens; std::string text; - for (size_t i = model->ContextSize(); i != hyp.size(); ++i) { - text += sym[hyp[i]]; + for (auto t : hyp) { + text += sym[t]; } fprintf(stderr, "Done!\n");