Add C++ and Python API for Dolphin CTC models (#2085)
This commit is contained in:
@@ -118,6 +118,19 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
}
|
||||
|
||||
if (!config_.model_config.dolphin.model.empty()) {
|
||||
config_.feat_config.low_freq = 0;
|
||||
config_.feat_config.high_freq = 8000;
|
||||
config_.feat_config.remove_dc_offset = false;
|
||||
config_.feat_config.dither = 0;
|
||||
config_.feat_config.preemph_coeff = 0;
|
||||
config_.feat_config.window_type = "hann";
|
||||
config_.feat_config.feature_dim = 80;
|
||||
config_.feat_config.is_librosa = true;
|
||||
config_.feat_config.frame_length_ms = 31.25; // 16000/512 = 31.25
|
||||
config_.feat_config.snip_edges = false;
|
||||
}
|
||||
|
||||
if (!config_.model_config.wenet_ctc.model.empty()) {
|
||||
// WeNet CTC models assume input samples are in the range
|
||||
// [-32768, 32767], so we set normalize_samples to false
|
||||
@@ -157,7 +170,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
|
||||
config_.decoding_method.c_str());
|
||||
exit(-1);
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,7 +179,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
|
||||
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||
if (!model_->SupportBatchProcessing()) {
|
||||
if (!model_->SupportBatchProcessing() || (n == 1)) {
|
||||
// If the model does not support batch process,
|
||||
// we process each stream independently.
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
@@ -190,6 +203,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
std::vector<float> f = ss[i]->GetFrames();
|
||||
|
||||
int32_t num_frames = f.size() / feat_dim;
|
||||
|
||||
model_->NormalizeFeatures(f.data(), num_frames, feat_dim);
|
||||
|
||||
features_vec[i] = std::move(f);
|
||||
|
||||
features_length_vec[i] = num_frames;
|
||||
@@ -241,6 +257,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
|
||||
int32_t num_frames = f.size() / feat_dim;
|
||||
|
||||
model_->NormalizeFeatures(f.data(), num_frames, feat_dim);
|
||||
|
||||
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
|
||||
|
||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
|
||||
|
||||
Reference in New Issue
Block a user