Add Kotlin API for audio tagging (#770)

This commit is contained in:
Fangjun Kuang
2024-04-15 13:49:35 +08:00
committed by GitHub
parent 13730ecbd8
commit 5981adf454
17 changed files with 601 additions and 56 deletions

View File

@@ -0,0 +1,84 @@
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
import android.util.Log
// Tag string, presumably intended for android.util.Log calls; it is not
// referenced in the visible part of this file.
// Declared `const` because it is a compile-time String constant.
private const val TAG = "sherpa-onnx"
/**
 * Settings for the zipformer audio tagging model.
 *
 * @property model Forwarded as a string to the native configuration
 *   (presumably the path of the model file - confirm against the native side).
 */
data class OfflineZipformerAudioTaggingModelConfig(val model: String)
/**
 * Model-level settings for audio tagging.
 *
 * @property zipformer Settings of the zipformer model.
 * @property numThreads Thread count handed to the native layer; defaults to 1.
 * @property debug Debug flag handed to the native layer; defaults to false.
 * @property provider Provider name handed to the native layer; defaults to "cpu".
 */
data class AudioTaggingModelConfig(
    var zipformer: OfflineZipformerAudioTaggingModelConfig,
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
)
/**
 * Top-level configuration passed to [AudioTagging].
 *
 * @property model Model-level settings.
 * @property labels Forwarded as a string to the native configuration
 *   (presumably the path of a labels file - confirm against the native side).
 * @property topK Forwarded to the native configuration; defaults to 5.
 */
data class AudioTaggingConfig(
    var model: AudioTaggingModelConfig,
    var labels: String,
    var topK: Int = 5,
)
/**
 * A single audio event described by a name, an index and a probability.
 *
 * @property name Name of the event.
 * @property index Index of the event.
 * @property prob Probability of the event.
 */
data class AudioEvent(
    val name: String,
    val index: Int,
    val prob: Float,
)
/**
 * Kotlin wrapper around the native audio tagger exposed by the
 * "sherpa-onnx-jni" library.
 *
 * The native object is created in the initializer and referenced through an
 * opaque handle. Call [release] when the tagger is no longer needed.
 *
 * @param assetManager When non-null, the native object is created with
 *   `newFromAsset`; otherwise it is created with `newFromFile`.
 * @param config Configuration handed to the native constructor.
 */
class AudioTagging(
    assetManager: AssetManager? = null,
    config: AudioTaggingConfig,
) {
    // Opaque handle of the native object; 0 means "nothing to free".
    private var ptr: Long

    init {
        ptr = if (assetManager != null) {
            newFromAsset(assetManager, config)
        } else {
            newFromFile(config)
        }
    }

    /** Frees the native object, if any. Safe to call more than once. */
    protected fun finalize() {
        // `ptr` is a Long, so it has to be compared with a Long literal.
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    /** Frees the native object explicitly instead of waiting for finalization. */
    fun release() = finalize()

    /** Creates a new [OfflineStream] backed by a native stream of this tagger. */
    fun createStream(): OfflineStream {
        val p = createStream(ptr)
        return OfflineStream(p)
    }

    /**
     * Runs the tagger on [stream].
     *
     * @param stream Stream to run the tagger on.
     * @param topK Forwarded unchanged to the native layer.
     *   NOTE(review): -1 presumably means "use the value from the config" - confirm.
     * @return One entry per detected event. The native side fills every entry
     *   with a 3-element object array holding the name (String), the index
     *   (Integer) and the probability (Float), in that order.
     */
    // TODO: return Array<AudioEvent> once the native side builds AudioEvent objects.
    fun compute(stream: OfflineStream, topK: Int = -1): Array<Any> {
        // The original body stored the result in a local and never returned it.
        return compute(ptr, stream.ptr, topK)
    }

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: AudioTaggingConfig,
    ): Long

    private external fun newFromFile(
        config: AudioTaggingConfig,
    ): Long

    private external fun delete(ptr: Long)

    private external fun createStream(ptr: Long): Long

    private external fun compute(ptr: Long, streamPtr: Long, topK: Int): Array<Any>

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}

View File

@@ -9,6 +9,10 @@ if(NOT DEFINED ANDROID_ABI)
include_directories($ENV{JAVA_HOME}/include/darwin)
endif()
# NOTE(review): the next line looks like the pre-change form of the
# add_library() call that follows it (diff residue) - confirm that only one
# definition of the sherpa-onnx-jni target is kept.
add_library(sherpa-onnx-jni jni.cc)
# Build the JNI library from the listed source files.
add_library(sherpa-onnx-jni
audio-tagging.cc
jni.cc
offline-stream.cc
)
# Link the JNI library against sherpa-onnx-core and install it into lib.
target_link_libraries(sherpa-onnx-jni sherpa-onnx-core)
install(TARGETS sherpa-onnx-jni DESTINATION lib)

View File

@@ -0,0 +1,126 @@
// sherpa-onnx/jni/audio-tagging.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include <string>

#include "sherpa-onnx/csrc/audio-tagging.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {
static AudioTaggingConfig GetAudioTaggingConfig(JNIEnv *env, jobject config) {
AudioTaggingConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid = env->GetFieldID(
cls, "model", "Lcom/k2fsa/sherpa/onnx/AudioTaggingModelConfig;");
jobject model = env->GetObjectField(config, fid);
jclass model_cls = env->GetObjectClass(model);
fid = env->GetFieldID(
model_cls, "zipformer",
"Lcom/k2fsa/sherpa/onnx/OfflineZipformerAudioTaggingModelConfig;");
jobject zipformer = env->GetObjectField(model, fid);
jclass zipformer_cls = env->GetObjectClass(zipformer);
fid = env->GetFieldID(zipformer_cls, "model", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(zipformer, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.model.zipformer.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_cls, "numThreads", "I");
ans.model.num_threads = env->GetIntField(model, fid);
fid = env->GetFieldID(model_cls, "debug", "Z");
ans.model.debug = env->GetBooleanField(model, fid);
fid = env->GetFieldID(model_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "labels", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.labels = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "topK", "I");
ans.top_k = env->GetIntField(config, fid);
return ans;
}
} // namespace sherpa_onnx
// Creates a native AudioTagging object from a Kotlin AudioTaggingConfig.
//
// Returns the address of the new object as a jlong, or 0 when the
// configuration does not pass Validate().
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_newFromFile(
    JNIEnv *env, jobject /*obj*/, jobject _config) {
  auto cfg = sherpa_onnx::GetAudioTaggingConfig(env, _config);

  SHERPA_ONNX_LOGE("audio tagging newFromFile config:\n%s",
                   cfg.ToString().c_str());

  if (cfg.Validate()) {
    // Ownership of the object passes to the caller, which frees it through
    // Java_com_k2fsa_sherpa_onnx_AudioTagging_delete().
    return reinterpret_cast<jlong>(new sherpa_onnx::AudioTagging(cfg));
  }

  SHERPA_ONNX_LOGE("Errors found in config!");
  return 0;
}
// Destroys the native AudioTagging object whose address is `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_delete(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto *tagger = reinterpret_cast<sherpa_onnx::AudioTagging *>(ptr);
  delete tagger;
}
// Creates a new OfflineStream from the tagger whose address is `ptr` and
// returns the address of the stream as a jlong.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_createStream(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto *tagger = reinterpret_cast<sherpa_onnx::AudioTagging *>(ptr);

  // Ownership is handed to the caller, who must free the returned pointer.
  //
  // See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from
  // ./offline-stream.cc
  sherpa_onnx::OfflineStream *raw = tagger->CreateStream().release();

  return reinterpret_cast<jlong>(raw);
}
// Runs the tagger on a stream and returns the detected events.
//
// @param ptr Address of a sherpa_onnx::AudioTagging object.
// @param streamPtr Address of a sherpa_onnx::OfflineStream object.
// @param top_k Forwarded unchanged to AudioTagging::Compute().
// @return An Object[] with one entry per event. Every entry is itself an
//         Object[3] holding the name (String), the index (Integer) and the
//         probability (Float), in that order.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_compute(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jlong streamPtr, jint top_k) {
  auto tagger = reinterpret_cast<sherpa_onnx::AudioTagging *>(ptr);
  auto stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(streamPtr);
  std::vector<sherpa_onnx::AudioEvent> events = tagger->Compute(stream, top_k);

  // TODO(fangjun): Return an array of AudioEvent directly

  // Look the element class up once instead of once per event.
  jclass object_cls = env->FindClass("java/lang/Object");

  jobjectArray obj_arr =
      env->NewObjectArray(events.size(), object_cls, nullptr);

  int32_t i = 0;
  for (const auto &e : events) {
    jobjectArray a = env->NewObjectArray(3, object_cls, nullptr);

    // 0 name
    // 1 index
    // 2 prob
    jstring name = env->NewStringUTF(e.name.c_str());
    jobject index = NewInteger(env, e.index);
    jobject prob = NewFloat(env, e.prob);

    env->SetObjectArrayElement(a, 0, name);
    env->SetObjectArrayElement(a, 1, index);
    env->SetObjectArrayElement(a, 2, prob);
    env->SetObjectArrayElement(obj_arr, i, a);

    // The arrays now hold these objects. Drop the local references so that
    // the number of live local references does not grow with the number of
    // events.
    env->DeleteLocalRef(name);
    env->DeleteLocalRef(index);
    env->DeleteLocalRef(prob);
    env->DeleteLocalRef(a);

    i += 1;
  }

  env->DeleteLocalRef(object_cls);

  return obj_arr;
}

23
sherpa-onnx/jni/common.h Normal file
View File

@@ -0,0 +1,23 @@
// sherpa-onnx/jni/common.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_JNI_COMMON_H_
#define SHERPA_ONNX_JNI_COMMON_H_
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
// If you use ndk, you can find "jni.h" inside
// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
#include "jni.h" // NOLINT
// Expands to `extern "C"`; it prefixes the JNI entry points.
#define SHERPA_ONNX_EXTERN_C extern "C"
// defined in jni.cc
// Returns a new java.lang.Integer object constructed from `value`.
jobject NewInteger(JNIEnv *env, int32_t value);
// Returns a new java.lang.Float object constructed from `value`.
jobject NewFloat(JNIEnv *env, float value);
#endif // SHERPA_ONNX_JNI_COMMON_H_

View File

@@ -7,20 +7,11 @@
// TODO(fangjun): Add documentation to functions/methods in this file
// and also show how to use them with kotlin, possibly with java.
// If you use ndk, you can find "jni.h" inside
// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
#include "jni.h" // NOLINT
#include <fstream>
#include <functional>
#include <strstream>
#include <utility>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
@@ -31,13 +22,12 @@
#include "sherpa-onnx/csrc/voice-activity-detector.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#include "sherpa-onnx/csrc/wave-writer.h"
#include "sherpa-onnx/jni/common.h"
#if SHERPA_ONNX_ENABLE_TTS == 1
#include "sherpa-onnx/csrc/offline-tts.h"
#endif
#define SHERPA_ONNX_EXTERN_C extern "C"
namespace sherpa_onnx {
class SherpaOnnx {
@@ -1224,12 +1214,18 @@ Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_allSpeakerNames(
// see
// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
// Returns a new java.lang.Integer object constructed from `value`.
//
// The function has external linkage (no `static`) because it is declared in
// sherpa-onnx/jni/common.h and called from other translation units.
jobject NewInteger(JNIEnv *env, int32_t value) {
  jclass cls = env->FindClass("java/lang/Integer");
  jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V");
  jobject ans = env->NewObject(cls, constructor, value);

  // Callers may invoke this in a loop; do not leak one class reference per
  // call.
  env->DeleteLocalRef(cls);

  return ans;
}
// Returns a new java.lang.Float object constructed from `value`.
jobject NewFloat(JNIEnv *env, float value) {
  jclass cls = env->FindClass("java/lang/Float");
  jmethodID constructor = env->GetMethodID(cls, "<init>", "(F)V");
  jobject ans = env->NewObject(cls, constructor, value);

  // Callers may invoke this in a loop; do not leak one class reference per
  // call.
  env->DeleteLocalRef(cls);

  return ans;
}
#if SHERPA_ONNX_ENABLE_TTS == 1
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_new(

View File

@@ -0,0 +1,25 @@
// sherpa-onnx/jni/offline-stream.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-stream.h"
#include "sherpa-onnx/jni/common.h"
// Destroys the native OfflineStream object whose address is `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineStream_delete(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto *stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(ptr);
  delete stream;
}
// Feeds the contents of a Java float[] to a native OfflineStream.
//
// @param ptr Address of a sherpa_onnx::OfflineStream object.
// @param samples The float[] whose elements are passed to AcceptWaveform().
// @param sample_rate Passed to AcceptWaveform() together with the samples.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineStream_acceptWaveform(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
    jint sample_rate) {
  auto stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(ptr);

  jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  if (p == nullptr) {
    // GetFloatArrayElements() returns NULL when it fails; there is nothing
    // to feed and nothing to release in that case.
    return;
  }

  jsize n = env->GetArrayLength(samples);

  stream->AcceptWaveform(sample_rate, p, n);

  // The samples are only read, so JNI_ABORT skips copying them back.
  env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}