Add Kotlin API for speech enhancement GTCRN models (#2008)

This commit is contained in:
Fangjun Kuang
2025-03-16 10:41:01 +08:00
committed by GitHub
parent c972554ad1
commit ed8e6c9aed
8 changed files with 326 additions and 17 deletions

View File

@@ -16,6 +16,7 @@ set(sources
keyword-spotter.cc
offline-punctuation.cc
offline-recognizer.cc
offline-speech-denoiser.cc
offline-stream.cc
online-punctuation.cc
online-recognizer.cc

View File

@@ -25,23 +25,6 @@ jobject NewFloat(JNIEnv *env, float value) {
return env->NewObject(cls, constructor, value);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl(
JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
jint sample_rate) {
const char *p_filename = env->GetStringUTFChars(filename, nullptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n);
env->ReleaseStringUTFChars(filename, p_filename);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
return ok;
}
#if 0
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL

View File

@@ -0,0 +1,158 @@
// sherpa-onnx/jni/offline-speech-denoiser.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/wave-writer.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {
static OfflineSpeechDenoiserConfig GetOfflineSpeechDenoiserConfig(
JNIEnv *env, jobject config) {
OfflineSpeechDenoiserConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
fid = env->GetFieldID(
cls, "model", "Lcom/k2fsa/sherpa/onnx/OfflineSpeechDenoiserModelConfig;");
jobject model = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model);
fid = env->GetFieldID(
model_config_cls, "gtcrn",
"Lcom/k2fsa/sherpa/onnx/OfflineSpeechDenoiserGtcrnModelConfig;");
jobject gtcrn = env->GetObjectField(model, fid);
jclass gtcrn_cls = env->GetObjectClass(gtcrn);
fid = env->GetFieldID(gtcrn_cls, "model", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(gtcrn, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.model.gtcrn.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model.num_threads = env->GetIntField(model, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model.debug = env->GetBooleanField(model, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.provider = p;
env->ReleaseStringUTFChars(s, p);
return ans;
}
} // namespace sherpa_onnx
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_newFromAsset(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetOfflineSpeechDenoiserConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto speech_denoiser = new sherpa_onnx::OfflineSpeechDenoiser(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)speech_denoiser;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_newFromFile(JNIEnv *env,
jobject /*obj*/,
jobject _config) {
return SafeJNI(
env, "OfflineSpeechDenoiser_newFromFile",
[&]() -> jlong {
auto config = sherpa_onnx::GetOfflineSpeechDenoiserConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
if (!config.Validate()) {
SHERPA_ONNX_LOGE("Errors found in config!");
}
auto speech_denoiser = new sherpa_onnx::OfflineSpeechDenoiser(config);
return reinterpret_cast<jlong>(speech_denoiser);
},
0L);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_delete(
JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
delete reinterpret_cast<sherpa_onnx::OfflineSpeechDenoiser *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_getSampleRate(JNIEnv * /*env*/,
jobject /*obj*/,
jlong ptr) {
return reinterpret_cast<sherpa_onnx::OfflineSpeechDenoiser *>(ptr)
->GetSampleRate();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobject JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_run(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
jint sample_rate) {
auto speech_denoiser =
reinterpret_cast<sherpa_onnx::OfflineSpeechDenoiser *>(ptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
auto denoised = speech_denoiser->Run(p, n, sample_rate);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
jclass cls = env->FindClass("com/k2fsa/sherpa/onnx/DenoisedAudio");
if (cls == nullptr) {
SHERPA_ONNX_LOGE("Failed to get class for DenoisedAudio");
return nullptr;
}
// https://javap.yawk.at/
jmethodID constructor = env->GetMethodID(cls, "<init>", "([FI)V");
if (constructor == nullptr) {
SHERPA_ONNX_LOGE("Failed to get constructor for DenoisedAudio");
return nullptr;
}
jfloatArray samples_arr = env->NewFloatArray(denoised.samples.size());
env->SetFloatArrayRegion(samples_arr, 0, denoised.samples.size(),
denoised.samples.data());
return env->NewObject(cls, constructor, samples_arr, denoised.sample_rate);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_DenoisedAudio_saveImpl(
JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
jint sample_rate) {
const char *p_filename = env->GetStringUTFChars(filename, nullptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n);
env->ReleaseStringUTFChars(filename, p_filename);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
return ok;
}

View File

@@ -5,6 +5,7 @@
#include "sherpa-onnx/csrc/offline-tts.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/wave-writer.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {
@@ -340,3 +341,20 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
return obj_arr;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl(
JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
jint sample_rate) {
const char *p_filename = env->GetStringUTFChars(filename, nullptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n);
env->ReleaseStringUTFChars(filename, p_filename);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
return ok;
}

View File

@@ -0,0 +1,82 @@
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
data class OfflineSpeechDenoiserGtcrnModelConfig(
var model: String = "",
)
data class OfflineSpeechDenoiserModelConfig(
var gtcrn: OfflineSpeechDenoiserGtcrnModelConfig = OfflineSpeechDenoiserGtcrnModelConfig(),
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
)
data class OfflineSpeechDenoiserConfig(
var model: OfflineSpeechDenoiserModelConfig = OfflineSpeechDenoiserModelConfig(),
)
class DenoisedAudio(
val samples: FloatArray,
val sampleRate: Int,
) {
fun save(filename: String) =
saveImpl(filename = filename, samples = samples, sampleRate = sampleRate)
private external fun saveImpl(
filename: String,
samples: FloatArray,
sampleRate: Int
): Boolean
}
class OfflineSpeechDenoiser(
assetManager: AssetManager? = null,
config: OfflineSpeechDenoiserConfig,
) {
private var ptr: Long
init {
ptr = if (assetManager != null) {
newFromAsset(assetManager, config)
} else {
newFromFile(config)
}
}
protected fun finalize() {
if (ptr != 0L) {
delete(ptr)
ptr = 0
}
}
fun release() = finalize()
fun run(samples: FloatArray, sampleRate: Int) = run(ptr, samples, sampleRate)
val sampleRate
get() = getSampleRate(ptr)
private external fun newFromAsset(
assetManager: AssetManager,
config: OfflineSpeechDenoiserConfig,
): Long
private external fun newFromFile(
config: OfflineSpeechDenoiserConfig,
): Long
private external fun delete(ptr: Long)
private external fun run(ptr: Long, samples: FloatArray, sampleRate: Int): DenoisedAudio
private external fun getSampleRate(ptr: Long): Int
companion object {
init {
System.loadLibrary("sherpa-onnx-jni")
}
}
}