Add JNI support for spoken language identification (#782)
This commit is contained in:
@@ -1,84 +0,0 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
import android.util.Log
|
||||
|
||||
private val TAG = "sherpa-onnx"
|
||||
|
||||
data class OfflineZipformerAudioTaggingModelConfig (
|
||||
val model: String,
|
||||
)
|
||||
|
||||
data class AudioTaggingModelConfig (
|
||||
var zipformer: OfflineZipformerAudioTaggingModelConfig,
|
||||
var numThreads: Int = 1,
|
||||
var debug: Boolean = false,
|
||||
var provider: String = "cpu",
|
||||
)
|
||||
|
||||
data class AudioTaggingConfig (
|
||||
var model: AudioTaggingModelConfig,
|
||||
var labels: String,
|
||||
var topK: Int = 5,
|
||||
)
|
||||
|
||||
data class AudioEvent (
|
||||
val name: String,
|
||||
val index: Int,
|
||||
val prob: Float,
|
||||
)
|
||||
|
||||
class AudioTagging(
|
||||
assetManager: AssetManager? = null,
|
||||
config: AudioTaggingConfig,
|
||||
) {
|
||||
private var ptr: Long
|
||||
|
||||
init {
|
||||
ptr = if (assetManager != null) {
|
||||
newFromAsset(assetManager, config)
|
||||
} else {
|
||||
newFromFile(config)
|
||||
}
|
||||
}
|
||||
|
||||
protected fun finalize() {
|
||||
if(ptr != 0) {
|
||||
delete(ptr)
|
||||
ptr = 0
|
||||
}
|
||||
}
|
||||
|
||||
fun release() = finalize()
|
||||
|
||||
fun createStream(): OfflineStream {
|
||||
val p = createStream(ptr)
|
||||
return OfflineStream(p)
|
||||
}
|
||||
|
||||
// fun compute(stream: OfflineStream, topK: Int=-1): Array<AudioEvent> {
|
||||
fun compute(stream: OfflineStream, topK: Int=-1): Array<Any> {
|
||||
var events :Array<Any> = compute(ptr, stream.ptr, topK)
|
||||
}
|
||||
|
||||
private external fun newFromAsset(
|
||||
assetManager: AssetManager,
|
||||
config: AudioTaggingConfig,
|
||||
): Long
|
||||
|
||||
private external fun newFromFile(
|
||||
config: AudioTaggingConfig,
|
||||
): Long
|
||||
|
||||
private external fun delete(ptr: Long)
|
||||
|
||||
private external fun createStream(ptr: Long): Long
|
||||
|
||||
private external fun compute(ptr: Long, streamPtr: Long, topK: Int): Array<Any>
|
||||
|
||||
companion object {
|
||||
init {
|
||||
System.loadLibrary("sherpa-onnx-jni")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ add_library(sherpa-onnx-jni
|
||||
audio-tagging.cc
|
||||
jni.cc
|
||||
offline-stream.cc
|
||||
spoken-language-identification.cc
|
||||
)
|
||||
target_link_libraries(sherpa-onnx-jni sherpa-onnx-core)
|
||||
install(TARGETS sherpa-onnx-jni DESTINATION lib)
|
||||
|
||||
104
sherpa-onnx/jni/spoken-language-identification.cc
Normal file
104
sherpa-onnx/jni/spoken-language-identification.cc
Normal file
@@ -0,0 +1,104 @@
|
||||
// sherpa-onnx/jni/spoken-language-identification.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/jni/common.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static SpokenLanguageIdentificationConfig GetSpokenLanguageIdentificationConfig(
|
||||
JNIEnv *env, jobject config) {
|
||||
SpokenLanguageIdentificationConfig ans;
|
||||
|
||||
jclass cls = env->GetObjectClass(config);
|
||||
jfieldID fid = env->GetFieldID(
|
||||
cls, "whisper",
|
||||
"Lcom/k2fsa/sherpa/onnx/SpokenLanguageIdentificationWhisperConfig;");
|
||||
|
||||
jobject whisper = env->GetObjectField(config, fid);
|
||||
jclass whisper_cls = env->GetObjectClass(whisper);
|
||||
|
||||
fid = env->GetFieldID(whisper_cls, "encoder", "Ljava/lang/String;");
|
||||
|
||||
jstring s = (jstring)env->GetObjectField(whisper, fid);
|
||||
const char *p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.whisper.encoder = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(whisper_cls, "decoder", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(whisper, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.whisper.decoder = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(whisper_cls, "tailPaddings", "I");
|
||||
ans.whisper.tail_paddings = env->GetIntField(whisper, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "numThreads", "I");
|
||||
ans.num_threads = env->GetIntField(config, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "debug", "Z");
|
||||
ans.debug = env->GetBooleanField(config, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.provider = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_newFromFile(
|
||||
JNIEnv *env, jobject /*obj*/, jobject _config) {
|
||||
auto config =
|
||||
sherpa_onnx::GetSpokenLanguageIdentificationConfig(env, _config);
|
||||
SHERPA_ONNX_LOGE("SpokenLanguageIdentification newFromFile config:\n%s",
|
||||
config.ToString().c_str());
|
||||
|
||||
if (!config.Validate()) {
|
||||
SHERPA_ONNX_LOGE("Errors found in config!");
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto tagger = new sherpa_onnx::SpokenLanguageIdentification(config);
|
||||
|
||||
return (jlong)tagger;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_createStream(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
auto slid =
|
||||
reinterpret_cast<sherpa_onnx::SpokenLanguageIdentification *>(ptr);
|
||||
std::unique_ptr<sherpa_onnx::OfflineStream> s = slid->CreateStream();
|
||||
|
||||
// The user is responsible to free the returned pointer.
|
||||
//
|
||||
// See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from
|
||||
// ./offline-stream.cc
|
||||
sherpa_onnx::OfflineStream *p = s.release();
|
||||
return (jlong)p;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_compute(JNIEnv *env,
|
||||
jobject /*obj*/,
|
||||
jlong ptr,
|
||||
jlong s_ptr) {
|
||||
sherpa_onnx::SpokenLanguageIdentification *slid =
|
||||
reinterpret_cast<sherpa_onnx::SpokenLanguageIdentification *>(ptr);
|
||||
sherpa_onnx::OfflineStream *s =
|
||||
reinterpret_cast<sherpa_onnx::OfflineStream *>(s_ptr);
|
||||
std::string lang = slid->Compute(s);
|
||||
return env->NewStringUTF(lang.c_str());
|
||||
}
|
||||
Reference in New Issue
Block a user