// sherpa-onnx/csrc/online-stream.cc // // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/online-stream.h" #include #include #include #include "sherpa-onnx/csrc/features.h" namespace sherpa_onnx { class OnlineStream::Impl { public: explicit Impl(const FeatureExtractorConfig &config, ContextGraphPtr context_graph) : feat_extractor_(config), context_graph_(context_graph) {} void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { feat_extractor_.AcceptWaveform(sampling_rate, waveform, n); } void InputFinished() const { feat_extractor_.InputFinished(); } int32_t NumFramesReady() const { return feat_extractor_.NumFramesReady() - start_frame_index_; } bool IsLastFrame(int32_t frame) const { return feat_extractor_.IsLastFrame(frame); } std::vector GetFrames(int32_t frame_index, int32_t n) const { return feat_extractor_.GetFrames(frame_index + start_frame_index_, n); } void Reset() { // we don't reset the feature extractor start_frame_index_ += num_processed_frames_; num_processed_frames_ = 0; } int32_t &GetNumProcessedFrames() { return num_processed_frames_; } int32_t GetNumFramesSinceStart() const { return start_frame_index_; } int32_t &GetCurrentSegment() { return segment_; } void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } OnlineTransducerDecoderResult &GetResult() { return result_; } void SetParaformerResult(const OnlineParaformerDecoderResult &r) { paraformer_result_ = r; } OnlineParaformerDecoderResult &GetParaformerResult() { return paraformer_result_; } int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); } void SetStates(std::vector states) { states_ = std::move(states); } std::vector &GetStates() { return states_; } const ContextGraphPtr &GetContextGraph() const { return context_graph_; } std::vector &GetParaformerFeatCache() { return paraformer_feat_cache_; } std::vector &GetParaformerEncoderOutCache() { return paraformer_encoder_out_cache_; } std::vector &GetParaformerAlphaCache() { return paraformer_alpha_cache_; } private: FeatureExtractor feat_extractor_; /// For contextual-biasing ContextGraphPtr context_graph_; int32_t num_processed_frames_ = 0; // before subsampling int32_t start_frame_index_ = 0; // never reset int32_t segment_ = 0; OnlineTransducerDecoderResult result_; std::vector states_; std::vector paraformer_feat_cache_; std::vector paraformer_encoder_out_cache_; std::vector paraformer_alpha_cache_; OnlineParaformerDecoderResult paraformer_result_; }; OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/, ContextGraphPtr context_graph /*= nullptr */) : impl_(std::make_unique(config, context_graph)) {} OnlineStream::~OnlineStream() = default; void OnlineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) const { impl_->AcceptWaveform(sampling_rate, waveform, n); } void OnlineStream::InputFinished() const { impl_->InputFinished(); } int32_t OnlineStream::NumFramesReady() const { return impl_->NumFramesReady(); } bool OnlineStream::IsLastFrame(int32_t frame) const { return impl_->IsLastFrame(frame); } std::vector OnlineStream::GetFrames(int32_t frame_index, int32_t n) const { return impl_->GetFrames(frame_index, n); } void OnlineStream::Reset() { impl_->Reset(); } int32_t OnlineStream::FeatureDim() const { return impl_->FeatureDim(); } int32_t &OnlineStream::GetNumProcessedFrames() { return impl_->GetNumProcessedFrames(); } int32_t OnlineStream::GetNumFramesSinceStart() const { return impl_->GetNumFramesSinceStart(); } int32_t &OnlineStream::GetCurrentSegment() { return impl_->GetCurrentSegment(); } void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) { impl_->SetResult(r); } OnlineTransducerDecoderResult &OnlineStream::GetResult() { return impl_->GetResult(); } void OnlineStream::SetParaformerResult(const OnlineParaformerDecoderResult &r) { impl_->SetParaformerResult(r); } OnlineParaformerDecoderResult &OnlineStream::GetParaformerResult() { return impl_->GetParaformerResult(); } void OnlineStream::SetStates(std::vector states) { impl_->SetStates(std::move(states)); } std::vector &OnlineStream::GetStates() { return impl_->GetStates(); } const ContextGraphPtr &OnlineStream::GetContextGraph() const { return impl_->GetContextGraph(); } std::vector &OnlineStream::GetParaformerFeatCache() { return impl_->GetParaformerFeatCache(); } std::vector &OnlineStream::GetParaformerEncoderOutCache() { return impl_->GetParaformerEncoderOutCache(); } std::vector &OnlineStream::GetParaformerAlphaCache() { return impl_->GetParaformerAlphaCache(); } } // namespace sherpa_onnx