Support zipformer CTC ASR with whisper features. (#2319)
This commit is contained in:
@@ -19,34 +19,51 @@ class OnlineStream::Impl {
|
||||
: feat_extractor_(config), context_graph_(std::move(context_graph)) {}
|
||||
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
|
||||
}
|
||||
|
||||
// Forwards the end-of-input notification to the feature extractor.
// This variant does not take any lock.
void InputFinished() const {
  feat_extractor_.InputFinished();
}
|
||||
void InputFinished() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
feat_extractor_.InputFinished();
|
||||
}
|
||||
|
||||
int32_t NumFramesReady() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return feat_extractor_.NumFramesReady() - start_frame_index_;
|
||||
}
|
||||
|
||||
bool IsLastFrame(int32_t frame) const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return feat_extractor_.IsLastFrame(frame);
|
||||
}
|
||||
|
||||
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return feat_extractor_.GetFrames(frame_index + start_frame_index_, n);
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
// we don't reset the feature extractor
|
||||
start_frame_index_ += num_processed_frames_;
|
||||
num_processed_frames_ = 0;
|
||||
}
|
||||
|
||||
// Returns a mutable reference to num_processed_frames_; callers can both
// read and update the counter through it. No lock is taken.
int32_t &GetNumProcessedFrames() {
  int32_t &processed = num_processed_frames_;
  return processed;
}
|
||||
int32_t &GetNumProcessedFrames() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return num_processed_frames_;
|
||||
}
|
||||
|
||||
// Returns the current value of start_frame_index_ (the offset that Reset()
// advances). No lock is taken.
int32_t GetNumFramesSinceStart() const {
  return start_frame_index_;
}
|
||||
int32_t GetNumFramesSinceStart() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return start_frame_index_;
|
||||
}
|
||||
|
||||
// Returns a mutable reference to segment_; callers can both read and update
// the segment value through it. No lock is taken.
int32_t &GetCurrentSegment() {
  int32_t &current_segment = segment_;
  return current_segment;
}
|
||||
int32_t &GetCurrentSegment() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return segment_;
|
||||
}
|
||||
|
||||
// Stores a copy of `r` in result_, replacing the previous value.
//
// NOTE(review): no lock is taken here, unlike AcceptWaveform() -- confirm
// that result_ is never accessed from more than one thread at a time.
void SetResult(const OnlineTransducerDecoderResult &r) {
  result_ = r;
}
|
||||
|
||||
@@ -125,6 +142,7 @@ class OnlineStream::Impl {
|
||||
|
||||
private:
|
||||
FeatureExtractor feat_extractor_;
|
||||
// Locked by the member functions that touch feat_extractor_ and the frame
// counters. Declared mutable so that const member functions can lock it.
mutable std::mutex mutex_;
|
||||
/// For contextual-biasing
|
||||
ContextGraphPtr context_graph_;
|
||||
int32_t num_processed_frames_ = 0; // before subsampling
|
||||
|
||||
Reference in New Issue
Block a user