Add C++ and Python API for Dolphin CTC models (#2085)

This commit is contained in:
Fangjun Kuang
2025-04-02 19:09:00 +08:00
committed by GitHub
parent 1316719e23
commit 0de7e1b9f0
27 changed files with 671 additions and 26 deletions

View File

@@ -118,6 +118,19 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}
}
if (!config_.model_config.dolphin.model.empty()) {
config_.feat_config.low_freq = 0;
config_.feat_config.high_freq = 8000;
config_.feat_config.remove_dc_offset = false;
config_.feat_config.dither = 0;
config_.feat_config.preemph_coeff = 0;
config_.feat_config.window_type = "hann";
config_.feat_config.feature_dim = 80;
config_.feat_config.is_librosa = true;
config_.feat_config.frame_length_ms = 31.25; // 16000/512 = 31.25
config_.feat_config.snip_edges = false;
}
if (!config_.model_config.wenet_ctc.model.empty()) {
// WeNet CTC models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
@@ -157,7 +170,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
} else {
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
config_.decoding_method.c_str());
exit(-1);
SHERPA_ONNX_EXIT(-1);
}
}
@@ -166,7 +179,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
if (!model_->SupportBatchProcessing()) {
if (!model_->SupportBatchProcessing() || (n == 1)) {
// If the model does not support batch process,
// we process each stream independently.
for (int32_t i = 0; i != n; ++i) {
@@ -190,6 +203,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
std::vector<float> f = ss[i]->GetFrames();
int32_t num_frames = f.size() / feat_dim;
model_->NormalizeFeatures(f.data(), num_frames, feat_dim);
features_vec[i] = std::move(f);
features_length_vec[i] = num_frames;
@@ -241,6 +257,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
int32_t num_frames = f.size() / feat_dim;
model_->NormalizeFeatures(f.data(), num_frames, feat_dim);
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),