Add Kotlin API for speech enhancement GTCRN models (#2008)
This commit is contained in:
1
kotlin-api-examples/OfflineSpeechDenoiser.kt
Symbolic link
1
kotlin-api-examples/OfflineSpeechDenoiser.kt
Symbolic link
@@ -0,0 +1 @@
|
||||
../sherpa-onnx/kotlin-api/OfflineSpeechDenoiser.kt
|
||||
@@ -371,6 +371,31 @@ function testOfflineSpeakerDiarization() {
|
||||
java -Djava.library.path=../build/lib -jar $out_filename
|
||||
}
|
||||
|
||||
# Build and run the Kotlin offline speech denoiser example.
#
# Downloads the GTCRN model and the sample input wave from the sherpa-onnx
# "speech-enhancement-models" release when they are not already present in the
# current directory, compiles the example into a runnable jar, runs it, and
# lists the wave files in the current directory.
function testOfflineSpeechDenoiser() {
  local asset
  for asset in gtcrn_simple.onnx inp_16k.wav; do
    if [ ! -f ./$asset ]; then
      curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/$asset
    fi
  done

  out_filename=test_offline_speech_denoiser.jar
  kotlinc-jvm -include-runtime -d $out_filename \
    test_offline_speech_denoiser.kt \
    OfflineSpeechDenoiser.kt \
    WaveReader.kt \
    faked-asset-manager.kt \
    faked-log.kt

  ls -lh $out_filename

  java -Djava.library.path=../build/lib -jar $out_filename

  ls -lh *.wav
}
|
||||
|
||||
testOfflineSpeechDenoiser
|
||||
testOfflineSpeakerDiarization
|
||||
testSpeakerEmbeddingExtractor
|
||||
testOnlineAsr
|
||||
|
||||
41
kotlin-api-examples/test_offline_speech_denoiser.kt
Normal file
41
kotlin-api-examples/test_offline_speech_denoiser.kt
Normal file
@@ -0,0 +1,41 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
// Please download the test files used by this script from
|
||||
// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
|
||||
|
||||
/** Program entry point: runs the speech denoiser example. */
fun main() = test()
|
||||
|
||||
/**
 * Denoises `./inp_16k.wav` and writes the result to `./enhanced-16k.wav`.
 *
 * Prints a success message only when the wave file was actually written;
 * otherwise prints a failure message.
 */
fun test() {
    val denoiser = createOfflineSpeechDenoiser()

    val waveFilename = "./inp_16k.wav"

    // readWaveFromFile returns an array: [0] holds the samples, [1] the sample rate
    val objArray = WaveReader.readWaveFromFile(
        filename = waveFilename,
    )
    val samples: FloatArray = objArray[0] as FloatArray
    val sampleRate: Int = objArray[1] as Int

    val denoised = denoiser.run(samples, sampleRate)

    // The denoised samples are held by the returned object, so the native
    // denoiser can be freed right away instead of waiting for finalization.
    denoiser.release()

    val outputFilename = "./enhanced-16k.wav"
    if (denoised.save(filename = outputFilename)) {
        println("saved to $outputFilename")
    } else {
        println("Failed to save $outputFilename")
    }
}
|
||||
|
||||
/**
 * Creates an [OfflineSpeechDenoiser] configured with the GTCRN model at
 * `./gtcrn_simple.onnx`, the `"cpu"` provider and one thread.
 *
 * The configuration is printed before the denoiser is constructed.
 */
fun createOfflineSpeechDenoiser(): OfflineSpeechDenoiser {
    val gtcrnConfig = OfflineSpeechDenoiserGtcrnModelConfig(
        model = "./gtcrn_simple.onnx"
    )
    val modelConfig = OfflineSpeechDenoiserModelConfig(
        gtcrn = gtcrnConfig,
        provider = "cpu",
        numThreads = 1,
    )
    val config = OfflineSpeechDenoiserConfig(
        model = modelConfig,
    )

    println(config)

    return OfflineSpeechDenoiser(config = config)
}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ set(sources
|
||||
keyword-spotter.cc
|
||||
offline-punctuation.cc
|
||||
offline-recognizer.cc
|
||||
offline-speech-denoiser.cc
|
||||
offline-stream.cc
|
||||
online-punctuation.cc
|
||||
online-recognizer.cc
|
||||
|
||||
@@ -25,23 +25,6 @@ jobject NewFloat(JNIEnv *env, float value) {
|
||||
return env->NewObject(cls, constructor, value);
|
||||
}
|
||||
|
||||
// Native implementation of GeneratedAudio.saveImpl: passes the given samples
// and sample rate to sherpa_onnx::WriteWave for `filename` and returns its
// result.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl(
    JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
    jint sample_rate) {
  const char *path = env->GetStringUTFChars(filename, nullptr);

  jsize num_samples = env->GetArrayLength(samples);
  jfloat *data = env->GetFloatArrayElements(samples, nullptr);

  bool ok = sherpa_onnx::WriteWave(path, sample_rate, data, num_samples);

  // JNI_ABORT: the samples were only read, so nothing has to be copied back
  env->ReleaseFloatArrayElements(samples, data, JNI_ABORT);
  env->ReleaseStringUTFChars(filename, path);

  return ok;
}
|
||||
|
||||
#if 0
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL
|
||||
|
||||
158
sherpa-onnx/jni/offline-speech-denoiser.cc
Normal file
158
sherpa-onnx/jni/offline-speech-denoiser.cc
Normal file
@@ -0,0 +1,158 @@
|
||||
// sherpa-onnx/jni/offline-speech-denoiser.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/wave-writer.h"
|
||||
#include "sherpa-onnx/jni/common.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
// Copies the fields of a Kotlin OfflineSpeechDenoiserConfig object into the
// corresponding C++ OfflineSpeechDenoiserConfig.
//
// Fields read: model.gtcrn.model, model.numThreads, model.debug and
// model.provider.
static OfflineSpeechDenoiserConfig GetOfflineSpeechDenoiserConfig(
    JNIEnv *env, jobject config) {
  OfflineSpeechDenoiserConfig ans;

  // config.model
  jclass config_cls = env->GetObjectClass(config);
  jfieldID model_fid = env->GetFieldID(
      config_cls, "model",
      "Lcom/k2fsa/sherpa/onnx/OfflineSpeechDenoiserModelConfig;");
  jobject model = env->GetObjectField(config, model_fid);
  jclass model_config_cls = env->GetObjectClass(model);

  // config.model.gtcrn
  jfieldID gtcrn_fid = env->GetFieldID(
      model_config_cls, "gtcrn",
      "Lcom/k2fsa/sherpa/onnx/OfflineSpeechDenoiserGtcrnModelConfig;");
  jobject gtcrn = env->GetObjectField(model, gtcrn_fid);
  jclass gtcrn_cls = env->GetObjectClass(gtcrn);

  // config.model.gtcrn.model
  jfieldID gtcrn_model_fid =
      env->GetFieldID(gtcrn_cls, "model", "Ljava/lang/String;");
  jstring gtcrn_model =
      static_cast<jstring>(env->GetObjectField(gtcrn, gtcrn_model_fid));
  const char *gtcrn_model_chars =
      env->GetStringUTFChars(gtcrn_model, nullptr);
  // Assigning to the std::string copies the characters, so the JNI buffer can
  // be released immediately afterwards.
  ans.model.gtcrn.model = gtcrn_model_chars;
  env->ReleaseStringUTFChars(gtcrn_model, gtcrn_model_chars);

  // config.model.numThreads
  jfieldID num_threads_fid =
      env->GetFieldID(model_config_cls, "numThreads", "I");
  ans.model.num_threads = env->GetIntField(model, num_threads_fid);

  // config.model.debug
  jfieldID debug_fid = env->GetFieldID(model_config_cls, "debug", "Z");
  ans.model.debug = env->GetBooleanField(model, debug_fid);

  // config.model.provider
  jfieldID provider_fid =
      env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
  jstring provider =
      static_cast<jstring>(env->GetObjectField(model, provider_fid));
  const char *provider_chars = env->GetStringUTFChars(provider, nullptr);
  ans.model.provider = provider_chars;
  env->ReleaseStringUTFChars(provider, provider_chars);

  return ans;
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
// Creates a native OfflineSpeechDenoiser from a Kotlin config object and
// returns its address as a jlong handle.
//
// When __ANDROID_API__ >= 9 the Java asset manager is converted to a native
// AAssetManager and passed to the constructor; 0 is returned if that
// conversion fails. In other builds only the config is used.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_newFromAsset(
    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  if (mgr == nullptr) {
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
    return 0;
  }
#endif

  auto config = sherpa_onnx::GetOfflineSpeechDenoiserConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

#if __ANDROID_API__ >= 9
  auto *denoiser = new sherpa_onnx::OfflineSpeechDenoiser(mgr, config);
#else
  auto *denoiser = new sherpa_onnx::OfflineSpeechDenoiser(config);
#endif

  return reinterpret_cast<jlong>(denoiser);
}
|
||||
|
||||
// Creates a native OfflineSpeechDenoiser from a Kotlin config object and
// returns its address as a jlong handle.
//
// Returns 0 when the config fails validation, so no denoiser is ever built
// from an invalid config. The Kotlin wrapper treats 0 as "no native object".
//
// NOTE(review): SafeJNI is declared outside this view; the trailing 0L is
// presumably the value it returns when the lambda fails — confirm in
// sherpa-onnx/jni/common.h.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_newFromFile(JNIEnv *env,
                                                             jobject /*obj*/,
                                                             jobject _config) {
  return SafeJNI(
      env, "OfflineSpeechDenoiser_newFromFile",
      [&]() -> jlong {
        auto config = sherpa_onnx::GetOfflineSpeechDenoiserConfig(env, _config);
        SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

        if (!config.Validate()) {
          SHERPA_ONNX_LOGE("Errors found in config!");
          // Stop here: previously execution fell through and constructed the
          // denoiser from the invalid config.
          return 0;
        }

        auto speech_denoiser = new sherpa_onnx::OfflineSpeechDenoiser(config);
        return reinterpret_cast<jlong>(speech_denoiser);
      },
      0L);
}
|
||||
|
||||
// Destroys the native OfflineSpeechDenoiser whose address is stored in `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_delete(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  auto *denoiser = reinterpret_cast<sherpa_onnx::OfflineSpeechDenoiser *>(ptr);
  delete denoiser;
}
|
||||
|
||||
// Returns GetSampleRate() of the native OfflineSpeechDenoiser whose address
// is stored in `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_getSampleRate(JNIEnv * /*env*/,
                                                               jobject /*obj*/,
                                                               jlong ptr) {
  auto *denoiser = reinterpret_cast<sherpa_onnx::OfflineSpeechDenoiser *>(ptr);
  return denoiser->GetSampleRate();
}
|
||||
|
||||
// Runs the native denoiser referenced by `ptr` on `samples` and returns a new
// com.k2fsa.sherpa.onnx.DenoisedAudio holding the output samples and output
// sample rate.
//
// Returns nullptr if the DenoisedAudio class or its (float[], int)
// constructor cannot be found.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobject JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_run(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
    jint sample_rate) {
  auto *denoiser = reinterpret_cast<sherpa_onnx::OfflineSpeechDenoiser *>(ptr);

  jfloat *input = env->GetFloatArrayElements(samples, nullptr);
  jsize num_input = env->GetArrayLength(samples);
  auto denoised = denoiser->Run(input, num_input, sample_rate);
  // JNI_ABORT: the input was only read, so nothing has to be copied back
  env->ReleaseFloatArrayElements(samples, input, JNI_ABORT);

  jclass audio_cls = env->FindClass("com/k2fsa/sherpa/onnx/DenoisedAudio");
  if (audio_cls == nullptr) {
    SHERPA_ONNX_LOGE("Failed to get class for DenoisedAudio");
    return nullptr;
  }

  // "([FI)V" is the JNI signature of a constructor taking (float[], int).
  // https://javap.yawk.at/
  jmethodID ctor = env->GetMethodID(audio_cls, "<init>", "([FI)V");
  if (ctor == nullptr) {
    SHERPA_ONNX_LOGE("Failed to get constructor for DenoisedAudio");
    return nullptr;
  }

  // Copy the denoised samples into a fresh Java float[]
  jfloatArray output = env->NewFloatArray(denoised.samples.size());
  env->SetFloatArrayRegion(output, 0, denoised.samples.size(),
                           denoised.samples.data());

  return env->NewObject(audio_cls, ctor, output, denoised.sample_rate);
}
|
||||
|
||||
// Native implementation of DenoisedAudio.saveImpl: hands the samples and the
// sample rate to sherpa_onnx::WriteWave for `filename` and returns its result.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_DenoisedAudio_saveImpl(
    JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
    jint sample_rate) {
  const char *out_path = env->GetStringUTFChars(filename, nullptr);

  jsize count = env->GetArrayLength(samples);
  jfloat *buffer = env->GetFloatArrayElements(samples, nullptr);

  bool written = sherpa_onnx::WriteWave(out_path, sample_rate, buffer, count);

  // JNI_ABORT: the array was only read, so no copy-back is needed
  env->ReleaseFloatArrayElements(samples, buffer, JNI_ABORT);
  env->ReleaseStringUTFChars(filename, out_path);

  return written;
}
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "sherpa-onnx/csrc/offline-tts.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/wave-writer.h"
|
||||
#include "sherpa-onnx/jni/common.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -340,3 +341,20 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
|
||||
|
||||
return obj_arr;
|
||||
}
|
||||
|
||||
// Native implementation of GeneratedAudio.saveImpl: forwards the samples and
// the sample rate to sherpa_onnx::WriteWave for `filename` and returns its
// result.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl(
    JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
    jint sample_rate) {
  const char *wave_path = env->GetStringUTFChars(filename, nullptr);

  jsize sample_count = env->GetArrayLength(samples);
  jfloat *sample_data = env->GetFloatArrayElements(samples, nullptr);

  bool success =
      sherpa_onnx::WriteWave(wave_path, sample_rate, sample_data, sample_count);

  // JNI_ABORT: the samples were not modified, so nothing is copied back
  env->ReleaseFloatArrayElements(samples, sample_data, JNI_ABORT);
  env->ReleaseStringUTFChars(filename, wave_path);

  return success;
}
|
||||
|
||||
82
sherpa-onnx/kotlin-api/OfflineSpeechDenoiser.kt
Normal file
82
sherpa-onnx/kotlin-api/OfflineSpeechDenoiser.kt
Normal file
@@ -0,0 +1,82 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
|
||||
/**
 * Configuration of the GTCRN speech denoising model.
 *
 * @property model Location of the GTCRN model file, e.g. `./gtcrn_simple.onnx`.
 *   Defaults to the empty string.
 */
data class OfflineSpeechDenoiserGtcrnModelConfig(
    var model: String = "",
)
|
||||
|
||||
/**
 * Model-level settings of the speech denoiser. The native layer reads every
 * property by name.
 *
 * @property gtcrn Settings of the GTCRN model.
 * @property numThreads Thread count forwarded to the native model config.
 * @property debug Debug flag forwarded to the native model config.
 * @property provider Provider name forwarded to the native model config.
 */
data class OfflineSpeechDenoiserModelConfig(
    var gtcrn: OfflineSpeechDenoiserGtcrnModelConfig =
        OfflineSpeechDenoiserGtcrnModelConfig(),
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
)
|
||||
|
||||
/**
 * Top-level configuration passed to [OfflineSpeechDenoiser].
 *
 * @property model Model-level settings.
 */
data class OfflineSpeechDenoiserConfig(
    var model: OfflineSpeechDenoiserModelConfig =
        OfflineSpeechDenoiserModelConfig(),
)
|
||||
|
||||
/**
 * Audio returned by [OfflineSpeechDenoiser.run].
 *
 * @property samples The denoised samples.
 * @property sampleRate Sample rate of [samples].
 */
class DenoisedAudio(
    val samples: FloatArray,
    val sampleRate: Int,
) {
    /**
     * Writes [samples] to [filename] through the native implementation.
     *
     * @return the Boolean result reported by the native writer.
     */
    fun save(filename: String): Boolean {
        return saveImpl(filename, samples, sampleRate)
    }

    private external fun saveImpl(
        filename: String,
        samples: FloatArray,
        sampleRate: Int,
    ): Boolean
}
|
||||
|
||||
/**
 * Kotlin wrapper around the native speech denoiser in `sherpa-onnx-jni`.
 *
 * The native object is created through `newFromAsset` when [assetManager] is
 * given and through `newFromFile` otherwise. Call [release] to free it.
 */
class OfflineSpeechDenoiser(
    assetManager: AssetManager? = null,
    config: OfflineSpeechDenoiserConfig,
) {
    // Handle of the native object; reset to 0 once it has been deleted
    private var ptr: Long =
        if (assetManager == null) {
            newFromFile(config)
        } else {
            newFromAsset(assetManager, config)
        }

    /** Deletes the native object unless the handle is already 0. */
    protected fun finalize() {
        if (ptr == 0L) {
            return
        }
        delete(ptr)
        ptr = 0
    }

    /** Frees the native object; further calls are no-ops. */
    fun release() {
        finalize()
    }

    /** Runs the native denoiser on [samples] and returns the denoised audio. */
    fun run(samples: FloatArray, sampleRate: Int): DenoisedAudio =
        run(ptr, samples, sampleRate)

    /** Sample rate reported by the native denoiser. */
    val sampleRate: Int
        get() = getSampleRate(ptr)

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: OfflineSpeechDenoiserConfig,
    ): Long

    private external fun newFromFile(
        config: OfflineSpeechDenoiserConfig,
    ): Long

    private external fun delete(ptr: Long)

    private external fun run(ptr: Long, samples: FloatArray, sampleRate: Int): DenoisedAudio

    private external fun getSampleRate(ptr: Long): Int

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
Reference in New Issue
Block a user