diff --git a/kotlin-api-examples/OfflineSpeechDenoiser.kt b/kotlin-api-examples/OfflineSpeechDenoiser.kt new file mode 120000 index 00000000..bc25f609 --- /dev/null +++ b/kotlin-api-examples/OfflineSpeechDenoiser.kt @@ -0,0 +1 @@ +../sherpa-onnx/kotlin-api/OfflineSpeechDenoiser.kt \ No newline at end of file diff --git a/kotlin-api-examples/run.sh b/kotlin-api-examples/run.sh index f7ebd110..f0614ae5 100755 --- a/kotlin-api-examples/run.sh +++ b/kotlin-api-examples/run.sh @@ -371,6 +371,31 @@ function testOfflineSpeakerDiarization() { java -Djava.library.path=../build/lib -jar $out_filename } +function testOfflineSpeechDenoiser() { + if [ ! -f ./gtcrn_simple.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx + fi + + if [ ! -f ./inp_16k.wav ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/inp_16k.wav + fi + + out_filename=test_offline_speech_denoiser.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_offline_speech_denoiser.kt \ + OfflineSpeechDenoiser.kt \ + WaveReader.kt \ + faked-asset-manager.kt \ + faked-log.kt + + ls -lh $out_filename + + java -Djava.library.path=../build/lib -jar $out_filename + + ls -lh *.wav +} + +testOfflineSpeechDenoiser testOfflineSpeakerDiarization testSpeakerEmbeddingExtractor testOnlineAsr diff --git a/kotlin-api-examples/test_offline_speech_denoiser.kt b/kotlin-api-examples/test_offline_speech_denoiser.kt new file mode 100644 index 00000000..1d767c21 --- /dev/null +++ b/kotlin-api-examples/test_offline_speech_denoiser.kt @@ -0,0 +1,41 @@ +package com.k2fsa.sherpa.onnx +// Please download test files in this script from +// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models + +fun main() { + test() +} + +fun test() { + val denoiser = createOfflineSpeechDenoiser() + + val waveFilename = "./inp_16k.wav"; + + val objArray = WaveReader.readWaveFromFile( + filename = waveFilename, + ) + val samples: FloatArray = objArray[0] as FloatArray + val sampleRate: Int = objArray[1] as Int + + val denoised = denoiser.run(samples, sampleRate); + denoised.save(filename="./enhanced-16k.wav") + println("saved to ./enhanced-16k.wav") +} + +fun createOfflineSpeechDenoiser(): OfflineSpeechDenoiser { + val config = OfflineSpeechDenoiserConfig( + model = OfflineSpeechDenoiserModelConfig( + gtcrn = OfflineSpeechDenoiserGtcrnModelConfig( + model = "./gtcrn_simple.onnx" + ), + provider = "cpu", + numThreads = 1, + ), + ) + + println(config) + + return OfflineSpeechDenoiser(config = config) +} + + diff --git a/sherpa-onnx/jni/CMakeLists.txt b/sherpa-onnx/jni/CMakeLists.txt index a18bc93b..9e1c4952 100644 --- a/sherpa-onnx/jni/CMakeLists.txt +++ b/sherpa-onnx/jni/CMakeLists.txt @@ -16,6 +16,7 @@ set(sources keyword-spotter.cc offline-punctuation.cc offline-recognizer.cc + offline-speech-denoiser.cc offline-stream.cc online-punctuation.cc online-recognizer.cc diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index 955b4e30..42fd0d72 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -25,23 +25,6 @@ jobject NewFloat(JNIEnv *env, float value) { return env->NewObject(cls, constructor, value); } -SHERPA_ONNX_EXTERN_C -JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl( - JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples, - jint sample_rate) { - const char *p_filename = env->GetStringUTFChars(filename, nullptr); - - jfloat *p = env->GetFloatArrayElements(samples, nullptr); - jsize n = env->GetArrayLength(samples); - - bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n); - - env->ReleaseStringUTFChars(filename, p_filename); - env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); - - return ok; -} - #if 0 SHERPA_ONNX_EXTERN_C JNIEXPORT void JNICALL diff --git a/sherpa-onnx/jni/offline-speech-denoiser.cc b/sherpa-onnx/jni/offline-speech-denoiser.cc new file mode 100644 index 00000000..d1377c8d --- /dev/null +++ b/sherpa-onnx/jni/offline-speech-denoiser.cc @@ -0,0 +1,158 @@ +// sherpa-onnx/jni/offline-speech-denoiser.cc +// +// Copyright (c) 2025 Xiaomi Corporation +#include "sherpa-onnx/csrc/offline-speech-denoiser.h" + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/wave-writer.h" +#include "sherpa-onnx/jni/common.h" + +namespace sherpa_onnx { +static OfflineSpeechDenoiserConfig GetOfflineSpeechDenoiserConfig( + JNIEnv *env, jobject config) { + OfflineSpeechDenoiserConfig ans; + + jclass cls = env->GetObjectClass(config); + jfieldID fid; + + fid = env->GetFieldID( + cls, "model", "Lcom/k2fsa/sherpa/onnx/OfflineSpeechDenoiserModelConfig;"); + jobject model = env->GetObjectField(config, fid); + jclass model_config_cls = env->GetObjectClass(model); + + fid = env->GetFieldID( + model_config_cls, "gtcrn", + "Lcom/k2fsa/sherpa/onnx/OfflineSpeechDenoiserGtcrnModelConfig;"); + jobject gtcrn = env->GetObjectField(model, fid); + jclass gtcrn_cls = env->GetObjectClass(gtcrn); + + fid = env->GetFieldID(gtcrn_cls, "model", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(gtcrn, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.model.gtcrn.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "numThreads", "I"); + ans.model.num_threads = env->GetIntField(model, fid); + + fid = env->GetFieldID(model_config_cls, "debug", "Z"); + ans.model.debug = env->GetBooleanField(model, fid); + + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.provider = p; + env->ReleaseStringUTFChars(s, p); + + return ans; +} + +} // namespace sherpa_onnx + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_newFromAsset( + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; + } +#endif + auto config = sherpa_onnx::GetOfflineSpeechDenoiserConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + auto speech_denoiser = new sherpa_onnx::OfflineSpeechDenoiser( +#if __ANDROID_API__ >= 9 + mgr, +#endif + config); + + return (jlong)speech_denoiser; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_newFromFile(JNIEnv *env, + jobject /*obj*/, + jobject _config) { + return SafeJNI( + env, "OfflineSpeechDenoiser_newFromFile", + [&]() -> jlong { + auto config = sherpa_onnx::GetOfflineSpeechDenoiserConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + } + + auto speech_denoiser = new sherpa_onnx::OfflineSpeechDenoiser(config); + return reinterpret_cast(speech_denoiser); + }, + 0L); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_delete( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + delete reinterpret_cast(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jint JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_getSampleRate(JNIEnv * /*env*/, + jobject /*obj*/, + jlong ptr) { + return reinterpret_cast(ptr) + ->GetSampleRate(); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobject JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_run( + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, + jint sample_rate) { + auto speech_denoiser = + reinterpret_cast(ptr); + + jfloat *p = env->GetFloatArrayElements(samples, nullptr); + jsize n = env->GetArrayLength(samples); + auto denoised = speech_denoiser->Run(p, n, sample_rate); + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); + + jclass cls = env->FindClass("com/k2fsa/sherpa/onnx/DenoisedAudio"); + if (cls == nullptr) { + SHERPA_ONNX_LOGE("Failed to get class for DenoisedAudio"); + return nullptr; + } + + // https://javap.yawk.at/ + jmethodID constructor = env->GetMethodID(cls, "", "([FI)V"); + if (constructor == nullptr) { + SHERPA_ONNX_LOGE("Failed to get constructor for DenoisedAudio"); + return nullptr; + } + + jfloatArray samples_arr = env->NewFloatArray(denoised.samples.size()); + env->SetFloatArrayRegion(samples_arr, 0, denoised.samples.size(), + denoised.samples.data()); + + return env->NewObject(cls, constructor, samples_arr, denoised.sample_rate); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_DenoisedAudio_saveImpl( + JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples, + jint sample_rate) { + const char *p_filename = env->GetStringUTFChars(filename, nullptr); + + jfloat *p = env->GetFloatArrayElements(samples, nullptr); + jsize n = env->GetArrayLength(samples); + + bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n); + + env->ReleaseStringUTFChars(filename, p_filename); + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); + + return ok; +} diff --git a/sherpa-onnx/jni/offline-tts.cc b/sherpa-onnx/jni/offline-tts.cc index e3441525..bb609cab 100644 --- a/sherpa-onnx/jni/offline-tts.cc +++ b/sherpa-onnx/jni/offline-tts.cc @@ -5,6 +5,7 @@ #include "sherpa-onnx/csrc/offline-tts.h" #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/wave-writer.h" #include "sherpa-onnx/jni/common.h" namespace sherpa_onnx { @@ -340,3 +341,20 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl( return obj_arr; } + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl( + JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples, + jint sample_rate) { + const char *p_filename = env->GetStringUTFChars(filename, nullptr); + + jfloat *p = env->GetFloatArrayElements(samples, nullptr); + jsize n = env->GetArrayLength(samples); + + bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n); + + env->ReleaseStringUTFChars(filename, p_filename); + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); + + return ok; +} diff --git a/sherpa-onnx/kotlin-api/OfflineSpeechDenoiser.kt b/sherpa-onnx/kotlin-api/OfflineSpeechDenoiser.kt new file mode 100644 index 00000000..f5e66f20 --- /dev/null +++ b/sherpa-onnx/kotlin-api/OfflineSpeechDenoiser.kt @@ -0,0 +1,82 @@ +package com.k2fsa.sherpa.onnx + +import android.content.res.AssetManager + +data class OfflineSpeechDenoiserGtcrnModelConfig( + var model: String = "", +) + +data class OfflineSpeechDenoiserModelConfig( + var gtcrn: OfflineSpeechDenoiserGtcrnModelConfig = OfflineSpeechDenoiserGtcrnModelConfig(), + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", +) + +data class OfflineSpeechDenoiserConfig( + var model: OfflineSpeechDenoiserModelConfig = OfflineSpeechDenoiserModelConfig(), +) + +class DenoisedAudio( + val samples: FloatArray, + val sampleRate: Int, +) { + fun save(filename: String) = + saveImpl(filename = filename, samples = samples, sampleRate = sampleRate) + + private external fun saveImpl( + filename: String, + samples: FloatArray, + sampleRate: Int + ): Boolean +} + +class OfflineSpeechDenoiser( + assetManager: AssetManager? = null, + config: OfflineSpeechDenoiserConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + fun run(samples: FloatArray, sampleRate: Int) = run(ptr, samples, sampleRate) + + val sampleRate + get() = getSampleRate(ptr) + + private external fun newFromAsset( + assetManager: AssetManager, + config: OfflineSpeechDenoiserConfig, + ): Long + + private external fun newFromFile( + config: OfflineSpeechDenoiserConfig, + ): Long + + private external fun delete(ptr: Long) + + private external fun run(ptr: Long, samples: FloatArray, sampleRate: Int): DenoisedAudio + + private external fun getSampleRate(ptr: Long): Int + + companion object { + init { + System.loadLibrary("sherpa-onnx-jni") + } + } +}