// sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_

#include <algorithm>
#include <array>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/offline-whisper-model.h"
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
#include "sherpa-onnx/csrc/transpose.h"

namespace sherpa_onnx {

class SpokenLanguageIdentificationWhisperImpl
    : public SpokenLanguageIdentificationImpl {
 public:
  explicit SpokenLanguageIdentificationWhisperImpl(
      const SpokenLanguageIdentificationConfig &config)
      : config_(config),
        model_(std::make_unique<OfflineWhisperModel>(config)) {
    Check();
  }

#if __ANDROID_API__ >= 9
  SpokenLanguageIdentificationWhisperImpl(
      AAssetManager *mgr, const SpokenLanguageIdentificationConfig &config)
      : config_(config),
        model_(std::make_unique<OfflineWhisperModel>(mgr, config)) {
    Check();
  }
#endif

  std::unique_ptr<OfflineStream> CreateStream() const override {
    return std::make_unique<OfflineStream>(WhisperTag{});
  }

  std::string Compute(OfflineStream *s) const override {
    int32_t max_num_frames = 3000;
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    int32_t feat_dim = s->FeatureDim();
    std::vector<float> f = s->GetFrames();

    int32_t num_frames = f.size() / feat_dim;

    // We use 50 here so that there will be some zero tail padding frames.
    if (num_frames >= max_num_frames - 50) {
      SHERPA_ONNX_LOGE(
          "Only waves shorter than 30 seconds are supported. We process only "
          "the first 30 seconds and discard the remaining data");
      num_frames = max_num_frames - 50;
    }

    model_->NormalizeFeatures(f.data(), num_frames, feat_dim);

    // Note that 1000 is an empirical value. You can replace it with
    // other values, say, 100.
    //
    // Since we have removed the 30-second constraint, we need
    // tail_padding_frames so that whisper is able to detect the eot token.
    int32_t tail_padding_frames = 1000;

    if (config_.whisper.tail_paddings > 0) {
      tail_padding_frames = config_.whisper.tail_paddings;
    }

    int32_t actual_frames =
        std::min(num_frames + tail_padding_frames, max_num_frames);

    std::array<int64_t, 3> shape{1, actual_frames, feat_dim};

    Ort::Value mel = Ort::Value::CreateTensor<float>(
        model_->Allocator(), shape.data(), shape.size());

    // Copy the real frames into the tensor and zero-fill the tail padding.
    float *p_mel = mel.GetTensorMutableData<float>();
    std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel);

    std::fill_n(p_mel + num_frames * feat_dim,
                (actual_frames - num_frames) * feat_dim, 0);

    // Transpose12 swaps dims 1 and 2, turning (1, actual_frames, feat_dim)
    // into (1, feat_dim, actual_frames), the layout the whisper encoder
    // expects.
    mel = Transpose12(model_->Allocator(), &mel);

    try {
      auto cross_kv = model_->ForwardEncoder(std::move(mel));
      int32_t lang_id = model_->DetectLanguage(cross_kv.first, cross_kv.second);
      const auto &id2lang = model_->GetID2Lang();
      if (id2lang.count(lang_id)) {
        return id2lang.at(lang_id);
      } else {
        SHERPA_ONNX_LOGE("Unknown language ID: %d. Returning an empty string.",
                         lang_id);
        return "";
      }
    } catch (const Ort::Exception &ex) {
      SHERPA_ONNX_LOGE(
          "\n\nCaught exception:\n\n%s\n\nReturning an empty result. Number "
          "of input frames: %d, current tail "
          "paddings: %d. If you see many such exceptions, please consider "
          "using a larger --whisper-tail-paddings",
          ex.what(), num_frames, tail_padding_frames);
      return "";
    }
  }

 private:
  void Check() const {
    if (!model_->IsMultiLingual()) {
      SHERPA_ONNX_LOGE(
          "Only whisper multilingual models can be used for spoken language "
          "identification. Given: %s,%s",
          config_.whisper.encoder.c_str(), config_.whisper.decoder.c_str());
      exit(-1);
    }
  }

 private:
  SpokenLanguageIdentificationConfig config_;
  std::unique_ptr<OfflineWhisperModel> model_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_