Add non-streaming ASR (#92)
This commit is contained in:
@@ -19,7 +19,9 @@ namespace sherpa_onnx {
|
||||
void FeatureExtractorConfig::Register(ParseOptions *po) {
|
||||
po->Register("sample-rate", &sampling_rate,
|
||||
"Sampling rate of the input waveform. Must match the one "
|
||||
"expected by the model.");
|
||||
"expected by the model. Note: You can have a different "
|
||||
"sample rate for the input waveform. We will do resampling "
|
||||
"inside the feature extractor");
|
||||
|
||||
po->Register("feat-dim", &feature_dim,
|
||||
"Feature dimension. Must match the one expected by the model.");
|
||||
@@ -30,8 +32,7 @@ std::string FeatureExtractorConfig::ToString() const {
|
||||
|
||||
os << "FeatureExtractorConfig(";
|
||||
os << "sampling_rate=" << sampling_rate << ", ";
|
||||
os << "feature_dim=" << feature_dim << ", ";
|
||||
os << "max_feature_vectors=" << max_feature_vectors << ")";
|
||||
os << "feature_dim=" << feature_dim << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
@@ -43,8 +44,6 @@ class FeatureExtractor::Impl {
|
||||
opts_.frame_opts.snip_edges = false;
|
||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||
|
||||
opts_.frame_opts.max_feature_vectors = config.max_feature_vectors;
|
||||
|
||||
opts_.mel_opts.num_bins = config.feature_dim;
|
||||
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
@@ -95,7 +94,7 @@ class FeatureExtractor::Impl {
|
||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
}
|
||||
|
||||
void InputFinished() {
|
||||
void InputFinished() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
fbank_->InputFinished();
|
||||
}
|
||||
@@ -110,12 +109,21 @@ class FeatureExtractor::Impl {
|
||||
return fbank_->IsLastFrame(frame);
|
||||
}
|
||||
|
||||
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
|
||||
if (frame_index + n > NumFramesReady()) {
|
||||
fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady());
|
||||
std::vector<float> GetFrames(int32_t frame_index, int32_t n) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (frame_index + n > fbank_->NumFramesReady()) {
|
||||
SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n,
|
||||
fbank_->NumFramesReady());
|
||||
exit(-1);
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
||||
int32_t discard_num = frame_index - last_frame_index_;
|
||||
if (discard_num < 0) {
|
||||
SHERPA_ONNX_LOGE("last_frame_index_: %d, frame_index_: %d",
|
||||
last_frame_index_, frame_index);
|
||||
exit(-1);
|
||||
}
|
||||
fbank_->Pop(discard_num);
|
||||
|
||||
int32_t feature_dim = fbank_->Dim();
|
||||
std::vector<float> features(feature_dim * n);
|
||||
@@ -128,12 +136,9 @@ class FeatureExtractor::Impl {
|
||||
p += feature_dim;
|
||||
}
|
||||
|
||||
return features;
|
||||
}
|
||||
last_frame_index_ = frame_index;
|
||||
|
||||
void Reset() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
return features;
|
||||
}
|
||||
|
||||
// Reports the configured number of mel bins (opts_.mel_opts.num_bins) as the
// feature dimension.
int32_t FeatureDim() const {
  const int32_t num_bins = opts_.mel_opts.num_bins;
  return num_bins;
}
|
||||
@@ -143,6 +148,7 @@ class FeatureExtractor::Impl {
|
||||
knf::FbankOptions opts_;
|
||||
mutable std::mutex mutex_;
|
||||
std::unique_ptr<LinearResample> resampler_;
|
||||
int32_t last_frame_index_ = 0;
|
||||
};
|
||||
|
||||
FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
|
||||
@@ -151,11 +157,11 @@ FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
|
||||
// Compiler-generated destructor, defined out of line in this file.
// NOTE(review): presumably defined here so that Impl is a complete type at
// the point where impl_ is destroyed -- confirm against the header.
FeatureExtractor::~FeatureExtractor() = default;
|
||||
|
||||
void FeatureExtractor::AcceptWaveform(int32_t sampling_rate,
|
||||
const float *waveform, int32_t n) {
|
||||
const float *waveform, int32_t n) const {
|
||||
impl_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
}
|
||||
|
||||
// Delegates to Impl::InputFinished(); no other work is done here.
void FeatureExtractor::InputFinished() {
  impl_->InputFinished();
}
|
||||
// Delegates to Impl::InputFinished(); no other work is done here.
void FeatureExtractor::InputFinished() const {
  impl_->InputFinished();
}
|
||||
|
||||
int32_t FeatureExtractor::NumFramesReady() const {
|
||||
return impl_->NumFramesReady();
|
||||
@@ -170,8 +176,6 @@ std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index,
|
||||
return impl_->GetFrames(frame_index, n);
|
||||
}
|
||||
|
||||
// Delegates to Impl::Reset(); no other work is done here.
void FeatureExtractor::Reset() {
  impl_->Reset();
}
|
||||
|
||||
// Returns whatever Impl::FeatureDim() reports.
int32_t FeatureExtractor::FeatureDim() const {
  const int32_t dim = impl_->FeatureDim();
  return dim;
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user