Support contextual-biasing for streaming model (#184)

* Support contextual-biasing for streaming model

* The whole pipeline runs normally

* Fix comments
This commit is contained in:
Wei Kang
2023-06-30 16:46:24 +08:00
committed by GitHub
parent b2e0c4c9c2
commit 513dfaa552
10 changed files with 238 additions and 22 deletions

View File

@@ -13,8 +13,9 @@ namespace sherpa_onnx {
class OnlineStream::Impl {
public:
explicit Impl(const FeatureExtractorConfig &config)
: feat_extractor_(config) {}
explicit Impl(const FeatureExtractorConfig &config,
ContextGraphPtr context_graph)
: feat_extractor_(config), context_graph_(context_graph) {}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
@@ -54,16 +55,21 @@ class OnlineStream::Impl {
std::vector<Ort::Value> &GetStates() { return states_; }
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
private:
FeatureExtractor feat_extractor_;
/// For contextual-biasing
ContextGraphPtr context_graph_;
int32_t num_processed_frames_ = 0; // before subsampling
int32_t start_frame_index_ = 0; // never reset
OnlineTransducerDecoderResult result_;
std::vector<Ort::Value> states_;
};
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
: impl_(std::make_unique<Impl>(config)) {}
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
ContextGraphPtr context_graph /*= nullptr */)
: impl_(std::make_unique<Impl>(config, context_graph)) {}
OnlineStream::~OnlineStream() = default;
@@ -109,4 +115,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
return impl_->GetStates();
}
const ContextGraphPtr &OnlineStream::GetContextGraph() const {
return impl_->GetContextGraph();
}
} // namespace sherpa_onnx