Support contextual-biasing for streaming model (#184)
* Support contextual-biasing for streaming model * The whole pipeline runs normally * Fix comments
This commit is contained in:
@@ -13,8 +13,9 @@ namespace sherpa_onnx {
|
||||
|
||||
class OnlineStream::Impl {
|
||||
public:
|
||||
explicit Impl(const FeatureExtractorConfig &config)
|
||||
: feat_extractor_(config) {}
|
||||
explicit Impl(const FeatureExtractorConfig &config,
|
||||
ContextGraphPtr context_graph)
|
||||
: feat_extractor_(config), context_graph_(context_graph) {}
|
||||
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
|
||||
@@ -54,16 +55,21 @@ class OnlineStream::Impl {
|
||||
|
||||
std::vector<Ort::Value> &GetStates() { return states_; }
|
||||
|
||||
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
|
||||
|
||||
private:
|
||||
FeatureExtractor feat_extractor_;
|
||||
/// For contextual-biasing
|
||||
ContextGraphPtr context_graph_;
|
||||
int32_t num_processed_frames_ = 0; // before subsampling
|
||||
int32_t start_frame_index_ = 0; // never reset
|
||||
OnlineTransducerDecoderResult result_;
|
||||
std::vector<Ort::Value> states_;
|
||||
};
|
||||
|
||||
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
|
||||
ContextGraphPtr context_graph /*= nullptr */)
|
||||
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||
|
||||
OnlineStream::~OnlineStream() = default;
|
||||
|
||||
@@ -109,4 +115,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
|
||||
return impl_->GetStates();
|
||||
}
|
||||
|
||||
const ContextGraphPtr &OnlineStream::GetContextGraph() const {
|
||||
return impl_->GetContextGraph();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user