diff --git a/kotlin-api-examples/AudioTagging.kt b/kotlin-api-examples/AudioTagging.kt new file mode 100644 index 00000000..621cfb55 --- /dev/null +++ b/kotlin-api-examples/AudioTagging.kt @@ -0,0 +1,95 @@ +package com.k2fsa.sherpa.onnx + +import android.content.res.AssetManager +import android.util.Log + +private val TAG = "sherpa-onnx" + +data class OfflineZipformerAudioTaggingModelConfig ( + val model: String, +) + +data class AudioTaggingModelConfig ( + var zipformer: OfflineZipformerAudioTaggingModelConfig, + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", +) + +data class AudioTaggingConfig ( + var model: AudioTaggingModelConfig, + var labels: String, + var topK: Int = 5, +) + +data class AudioEvent ( + val name: String, + val index: Int, + val prob: Float, +) + +class AudioTagging( + assetManager: AssetManager? = null, + config: AudioTaggingConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + if(ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + fun createStream(): OfflineStream { + val p = createStream(ptr) + return OfflineStream(p) + } + + fun compute(stream: OfflineStream, topK: Int=-1): ArrayList { + var events :Array = compute(ptr, stream.ptr, topK) + val ans = ArrayList() + + for (e in events) { + val p :Array = e as Array + ans.add(AudioEvent( + name=p[0] as String, + index=p[1] as Int, + prob=p[2] as Float, + )) + } + + return ans + } + + private external fun newFromAsset( + assetManager: AssetManager, + config: AudioTaggingConfig, + ): Long + + private external fun newFromFile( + config: AudioTaggingConfig, + ): Long + + private external fun delete(ptr: Long) + + private external fun createStream(ptr: Long): Long + + private external fun compute(ptr: Long, streamPtr: Long, topK: Int): Array + + companion object { + init { + System.loadLibrary("sherpa-onnx-jni") + } + } +} diff --git a/kotlin-api-examples/Main.kt b/kotlin-api-examples/Main.kt index 4402c71c..bc82c699 100644 --- a/kotlin-api-examples/Main.kt +++ b/kotlin-api-examples/Main.kt @@ -7,12 +7,56 @@ fun callback(samples: FloatArray): Unit { } fun main() { + testAudioTagging() testSpeakerRecognition() testTts() testAsr("transducer") testAsr("zipformer2-ctc") } +fun testAudioTagging() { + val config = AudioTaggingConfig( + model=AudioTaggingModelConfig( + zipformer=OfflineZipformerAudioTaggingModelConfig( + model="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.int8.onnx", + ), + numThreads=1, + debug=true, + provider="cpu", + ), + labels="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv", + topK=5, + ) + val tagger = AudioTagging(assetManager=null, config=config) + + val testFiles = arrayOf( + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav", + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/2.wav", + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/3.wav", + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/4.wav", + ) + println("----------") + for (waveFilename in testFiles) { + val stream = tagger.createStream() + + val objArray = WaveReader.readWaveFromFile( + filename = waveFilename, + ) + val samples: FloatArray = objArray[0] as FloatArray + val sampleRate: Int = objArray[1] as Int + + stream.acceptWaveform(samples, sampleRate = sampleRate) + val events = tagger.compute(stream) + stream.release() + + println(waveFilename) + println(events) + println("----------") + } + + tagger.release() +} + fun computeEmbedding(extractor: SpeakerEmbeddingExtractor, filename: String): FloatArray { var objArray = WaveReader.readWaveFromFile( filename = filename, diff --git a/kotlin-api-examples/OfflineStream.kt b/kotlin-api-examples/OfflineStream.kt new file mode 100644 index 00000000..a4e650f8 --- /dev/null +++ b/kotlin-api-examples/OfflineStream.kt @@ -0,0 +1,24 @@ +package com.k2fsa.sherpa.onnx + +class OfflineStream(var ptr: Long) { + fun acceptWaveform(samples: FloatArray, sampleRate: Int) = + acceptWaveform(ptr, samples, sampleRate) + + protected fun finalize() { + if(ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) + private external fun delete(ptr: Long) + + companion object { + init { + System.loadLibrary("sherpa-onnx-jni") + } + } +} diff --git a/kotlin-api-examples/run.sh b/kotlin-api-examples/run.sh index 0a0b38d8..750a70e5 100755 --- a/kotlin-api-examples/run.sh +++ b/kotlin-api-examples/run.sh @@ -4,8 +4,7 @@ # Note: This scripts runs only on Linux and macOS, though sherpa-onnx # supports building JNI libs for Windows. -set -e - +set -ex cd .. mkdir -p build @@ -29,59 +28,93 @@ export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH cd ../kotlin-api-examples -if [ ! -f ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx ]; then - wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx -fi +function testSpeakerEmbeddingExtractor() { + if [ ! -f ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx ]; then + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx + fi -if [ ! -f ./speaker1_a_cn_16k.wav ]; then - wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav -fi + if [ ! -f ./speaker1_a_cn_16k.wav ]; then + wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav + fi -if [ ! -f ./speaker1_b_cn_16k.wav ]; then - wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav -fi + if [ ! -f ./speaker1_b_cn_16k.wav ]; then + wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav + fi -if [ ! -f ./speaker2_a_cn_16k.wav ]; then - wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav -fi + if [ ! -f ./speaker2_a_cn_16k.wav ]; then + wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav + fi +} -if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then - git lfs install - git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 -fi +function testAsr() { + if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then + git lfs install + git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 + fi -if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then - wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 - tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 - rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 -fi + if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + fi +} -if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then - wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 - tar xf vits-piper-en_US-amy-low.tar.bz2 - rm vits-piper-en_US-amy-low.tar.bz2 -fi +function testTts() { + if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 + tar xf vits-piper-en_US-amy-low.tar.bz2 + rm vits-piper-en_US-amy-low.tar.bz2 + fi +} -kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt faked-asset-manager.kt Tts.kt Speaker.kt faked-log.kt +function testAudioTagging() { + if [ ! -d sherpa-onnx-zipformer-audio-tagging-2024-04-09 ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 + tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 + rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 + fi +} + +function test() { + testAudioTagging + testSpeakerEmbeddingExtractor + testAsr + testTts +} + +test + +kotlinc-jvm -include-runtime -d main.jar \ + AudioTagging.kt \ + Main.kt \ + OfflineStream.kt \ + SherpaOnnx.kt \ + Speaker.kt \ + Tts.kt \ + WaveReader.kt \ + faked-asset-manager.kt \ + faked-log.kt ls -lh main.jar java -Djava.library.path=../build/lib -jar main.jar -# For two-pass +function testTwoPass() { + if [ ! -f ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder-epoch-99-avg-1.int8.onnx ]; then + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 + rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 + fi -if [ ! -f ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder-epoch-99-avg-1.int8.onnx ]; then - wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 - tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 - rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 -fi + if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 + tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2 + rm sherpa-onnx-whisper-tiny.en.tar.bz2 + fi -if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then - wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 - tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2 - rm sherpa-onnx-whisper-tiny.en.tar.bz2 -fi + kotlinc-jvm -include-runtime -d 2pass.jar test-2pass.kt WaveReader.kt SherpaOnnx2Pass.kt faked-asset-manager.kt + ls -lh 2pass.jar + java -Djava.library.path=../build/lib -jar 2pass.jar +} -kotlinc-jvm -include-runtime -d 2pass.jar test-2pass.kt WaveReader.kt SherpaOnnx2Pass.kt faked-asset-manager.kt -ls -lh 2pass.jar -java -Djava.library.path=../build/lib -jar 2pass.jar +testTwoPass diff --git a/sherpa-onnx/csrc/audio-tagging-impl.cc b/sherpa-onnx/csrc/audio-tagging-impl.cc index 33e8dbb7..37cd6faa 100644 --- a/sherpa-onnx/csrc/audio-tagging-impl.cc +++ b/sherpa-onnx/csrc/audio-tagging-impl.cc @@ -4,6 +4,11 @@ #include "sherpa-onnx/csrc/audio-tagging-impl.h" +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h" #include "sherpa-onnx/csrc/macros.h" @@ -20,4 +25,17 @@ std::unique_ptr AudioTaggingImpl::Create( return nullptr; } +#if __ANDROID_API__ >= 9 +std::unique_ptr AudioTaggingImpl::Create( + AAssetManager *mgr, const AudioTaggingConfig &config) { + if (!config.model.zipformer.model.empty()) { + return std::make_unique(mgr, config); + } + + SHERPA_ONNX_LOG( + "Please specify an audio tagging model! Return a null pointer"); + return nullptr; +} +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/audio-tagging-impl.h b/sherpa-onnx/csrc/audio-tagging-impl.h index e5e19245..ac6f2e50 100644 --- a/sherpa-onnx/csrc/audio-tagging-impl.h +++ b/sherpa-onnx/csrc/audio-tagging-impl.h @@ -7,6 +7,11 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "sherpa-onnx/csrc/audio-tagging.h" namespace sherpa_onnx { @@ -18,6 +23,11 @@ class AudioTaggingImpl { static std::unique_ptr Create( const AudioTaggingConfig &config); +#if __ANDROID_API__ >= 9 + static std::unique_ptr Create( + AAssetManager *mgr, const AudioTaggingConfig &config); +#endif + virtual std::unique_ptr CreateStream() const = 0; virtual std::vector Compute(OfflineStream *s, diff --git a/sherpa-onnx/csrc/audio-tagging-label-file.cc b/sherpa-onnx/csrc/audio-tagging-label-file.cc index 24846a17..f81e9670 100644 --- a/sherpa-onnx/csrc/audio-tagging-label-file.cc +++ b/sherpa-onnx/csrc/audio-tagging-label-file.cc @@ -8,7 +8,15 @@ #include #include +#if __ANDROID_API__ >= 9 +#include + +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { @@ -18,6 +26,15 @@ AudioTaggingLabels::AudioTaggingLabels(const std::string &filename) { Init(is); } +#if __ANDROID_API__ >= 9 +AudioTaggingLabels::AudioTaggingLabels(AAssetManager *mgr, + const std::string &filename) { + auto buf = ReadFile(mgr, filename); + std::istrstream is(buf.data(), buf.size()); + Init(is); +} +#endif + // Format of a label file /* index,mid,display_name diff --git a/sherpa-onnx/csrc/audio-tagging-label-file.h b/sherpa-onnx/csrc/audio-tagging-label-file.h index 9e71557f..c366972e 100644 --- a/sherpa-onnx/csrc/audio-tagging-label-file.h +++ b/sherpa-onnx/csrc/audio-tagging-label-file.h @@ -8,11 +8,19 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + namespace sherpa_onnx { class AudioTaggingLabels { public: explicit AudioTaggingLabels(const std::string &filename); +#if __ANDROID_API__ >= 9 + AudioTaggingLabels(AAssetManager *mgr, const std::string &filename); +#endif // Return the event name for the given index. // The returned reference is valid as long as this object is alive diff --git a/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h b/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h index 639f644c..65870dcc 100644 --- a/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h +++ b/sherpa-onnx/csrc/audio-tagging-zipformer-impl.h @@ -8,6 +8,11 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "sherpa-onnx/csrc/audio-tagging-impl.h" #include "sherpa-onnx/csrc/audio-tagging-label-file.h" #include "sherpa-onnx/csrc/audio-tagging.h" @@ -28,6 +33,20 @@ class AudioTaggingZipformerImpl : public AudioTaggingImpl { } } +#if __ANDROID_API__ >= 9 + explicit AudioTaggingZipformerImpl(AAssetManager *mgr, + const AudioTaggingConfig &config) + : config_(config), + model_(mgr, config.model), + labels_(mgr, config.labels) { + if (model_.NumEventClasses() != labels_.NumEventClasses()) { + SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)", + model_.NumEventClasses(), labels_.NumEventClasses()); + exit(-1); + } + } +#endif + std::unique_ptr CreateStream() const override { return std::make_unique(); } diff --git a/sherpa-onnx/csrc/audio-tagging.cc b/sherpa-onnx/csrc/audio-tagging.cc index 34d558dd..8fcb6ef4 100644 --- a/sherpa-onnx/csrc/audio-tagging.cc +++ b/sherpa-onnx/csrc/audio-tagging.cc @@ -4,6 +4,11 @@ #include "sherpa-onnx/csrc/audio-tagging.h" +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "sherpa-onnx/csrc/audio-tagging-impl.h" #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" @@ -61,6 +66,11 @@ std::string AudioTaggingConfig::ToString() const { AudioTagging::AudioTagging(const AudioTaggingConfig &config) : impl_(AudioTaggingImpl::Create(config)) {} +#if __ANDROID_API__ >= 9 +AudioTagging::AudioTagging(AAssetManager *mgr, const AudioTaggingConfig &config) + : impl_(AudioTaggingImpl::Create(mgr, config)) {} +#endif + AudioTagging::~AudioTagging() = default; std::unique_ptr AudioTagging::CreateStream() const { diff --git a/sherpa-onnx/csrc/audio-tagging.h b/sherpa-onnx/csrc/audio-tagging.h index 50cfea02..6f68e90f 100644 --- a/sherpa-onnx/csrc/audio-tagging.h +++ b/sherpa-onnx/csrc/audio-tagging.h @@ -8,6 +8,11 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "sherpa-onnx/csrc/audio-tagging-model-config.h" #include "sherpa-onnx/csrc/offline-stream.h" #include "sherpa-onnx/csrc/parse-options.h" @@ -46,6 +51,10 @@ class AudioTagging { public: explicit AudioTagging(const AudioTaggingConfig &config); +#if __ANDROID_API__ >= 9 + AudioTagging(AAssetManager *mgr, const AudioTaggingConfig &config); +#endif + ~AudioTagging(); std::unique_ptr CreateStream() const; diff --git a/sherpa-onnx/jni/AudioTagging.kt b/sherpa-onnx/jni/AudioTagging.kt new file mode 100644 index 00000000..f3d82779 --- /dev/null +++ b/sherpa-onnx/jni/AudioTagging.kt @@ -0,0 +1,84 @@ +package com.k2fsa.sherpa.onnx + +import android.content.res.AssetManager +import android.util.Log + +private val TAG = "sherpa-onnx" + +data class OfflineZipformerAudioTaggingModelConfig ( + val model: String, +) + +data class AudioTaggingModelConfig ( + var zipformer: OfflineZipformerAudioTaggingModelConfig, + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", +) + +data class AudioTaggingConfig ( + var model: AudioTaggingModelConfig, + var labels: String, + var topK: Int = 5, +) + +data class AudioEvent ( + val name: String, + val index: Int, + val prob: Float, +) + +class AudioTagging( + assetManager: AssetManager? = null, + config: AudioTaggingConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + if(ptr != 0) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + fun createStream(): OfflineStream { + val p = createStream(ptr) + return OfflineStream(p) + } + + // fun compute(stream: OfflineStream, topK: Int=-1): Array { + fun compute(stream: OfflineStream, topK: Int=-1): Array { + var events :Array = compute(ptr, stream.ptr, topK) + } + + private external fun newFromAsset( + assetManager: AssetManager, + config: AudioTaggingConfig, + ): Long + + private external fun newFromFile( + config: AudioTaggingConfig, + ): Long + + private external fun delete(ptr: Long) + + private external fun createStream(ptr: Long): Long + + private external fun compute(ptr: Long, streamPtr: Long, topK: Int): Array + + companion object { + init { + System.loadLibrary("sherpa-onnx-jni") + } + } +} diff --git a/sherpa-onnx/jni/CMakeLists.txt b/sherpa-onnx/jni/CMakeLists.txt index 29e98211..75b6a1bb 100644 --- a/sherpa-onnx/jni/CMakeLists.txt +++ b/sherpa-onnx/jni/CMakeLists.txt @@ -9,6 +9,10 @@ if(NOT DEFINED ANDROID_ABI) include_directories($ENV{JAVA_HOME}/include/darwin) endif() -add_library(sherpa-onnx-jni jni.cc) +add_library(sherpa-onnx-jni + audio-tagging.cc + jni.cc + offline-stream.cc +) target_link_libraries(sherpa-onnx-jni sherpa-onnx-core) install(TARGETS sherpa-onnx-jni DESTINATION lib) diff --git a/sherpa-onnx/jni/audio-tagging.cc b/sherpa-onnx/jni/audio-tagging.cc new file mode 100644 index 00000000..89fde8e5 --- /dev/null +++ b/sherpa-onnx/jni/audio-tagging.cc @@ -0,0 +1,126 @@ +// sherpa-onnx/jni/audio-tagging.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/audio-tagging.h" + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/jni/common.h" + +namespace sherpa_onnx { + +static AudioTaggingConfig GetAudioTaggingConfig(JNIEnv *env, jobject config) { + AudioTaggingConfig ans; + + jclass cls = env->GetObjectClass(config); + + jfieldID fid = env->GetFieldID( + cls, "model", "Lcom/k2fsa/sherpa/onnx/AudioTaggingModelConfig;"); + jobject model = env->GetObjectField(config, fid); + jclass model_cls = env->GetObjectClass(model); + + fid = env->GetFieldID( + model_cls, "zipformer", + "Lcom/k2fsa/sherpa/onnx/OfflineZipformerAudioTaggingModelConfig;"); + jobject zipformer = env->GetObjectField(model, fid); + jclass zipformer_cls = env->GetObjectClass(zipformer); + + fid = env->GetFieldID(zipformer_cls, "model", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(zipformer, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.model.zipformer.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_cls, "numThreads", "I"); + ans.model.num_threads = env->GetIntField(model, fid); + + fid = env->GetFieldID(model_cls, "debug", "Z"); + ans.model.debug = env->GetBooleanField(model, fid); + + fid = env->GetFieldID(model_cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.provider = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "labels", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.labels = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "topK", "I"); + ans.top_k = env->GetIntField(config, fid); + + return ans; +} + +} // namespace sherpa_onnx + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_newFromFile( + JNIEnv *env, jobject /*obj*/, jobject _config) { + auto config = sherpa_onnx::GetAudioTaggingConfig(env, _config); + SHERPA_ONNX_LOGE("audio tagging newFromFile config:\n%s", + config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + return 0; + } + + auto tagger = new sherpa_onnx::AudioTagging(config); + + return (jlong)tagger; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_delete( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + delete reinterpret_cast(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_createStream( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + auto tagger = reinterpret_cast(ptr); + std::unique_ptr s = tagger->CreateStream(); + + // The user is responsible to free the returned pointer. + // + // See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from + // ./offline-stream.cc + sherpa_onnx::OfflineStream *p = s.release(); + return (jlong)p; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_compute( + JNIEnv *env, jobject /*obj*/, jlong ptr, jlong streamPtr, jint top_k) { + auto tagger = reinterpret_cast(ptr); + auto stream = reinterpret_cast(streamPtr); + std::vector events = tagger->Compute(stream, top_k); + + // TODO(fangjun): Return an array of AudioEvent directly + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( + events.size(), env->FindClass("java/lang/Object"), nullptr); + + int32_t i = 0; + for (const auto &e : events) { + jobjectArray a = (jobjectArray)env->NewObjectArray( + 3, env->FindClass("java/lang/Object"), nullptr); + + // 0 name + // 1 index + // 2 prob + jstring js = env->NewStringUTF(e.name.c_str()); + env->SetObjectArrayElement(a, 0, js); + env->SetObjectArrayElement(a, 1, NewInteger(env, e.index)); + env->SetObjectArrayElement(a, 2, NewFloat(env, e.prob)); + + env->SetObjectArrayElement(obj_arr, i, a); + i += 1; + } + + return obj_arr; +} diff --git a/sherpa-onnx/jni/common.h b/sherpa-onnx/jni/common.h new file mode 100644 index 00000000..d06350f8 --- /dev/null +++ b/sherpa-onnx/jni/common.h @@ -0,0 +1,23 @@ +// sherpa-onnx/jni/common.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_JNI_COMMON_H_ +#define SHERPA_ONNX_JNI_COMMON_H_ + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +// If you use ndk, you can find "jni.h" inside +// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include +#include "jni.h" // NOLINT + +#define SHERPA_ONNX_EXTERN_C extern "C" + +// defined in jni.cc +jobject NewInteger(JNIEnv *env, int32_t value); +jobject NewFloat(JNIEnv *env, float value); + +#endif // SHERPA_ONNX_JNI_COMMON_H_ diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index 23596e97..6bb25a36 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -7,20 +7,11 @@ // TODO(fangjun): Add documentation to functions/methods in this file // and also show how to use them with kotlin, possibly with java. -// If you use ndk, you can find "jni.h" inside -// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include -#include "jni.h" // NOLINT - #include #include #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "sherpa-onnx/csrc/keyword-spotter.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-recognizer.h" @@ -31,13 +22,12 @@ #include "sherpa-onnx/csrc/voice-activity-detector.h" #include "sherpa-onnx/csrc/wave-reader.h" #include "sherpa-onnx/csrc/wave-writer.h" +#include "sherpa-onnx/jni/common.h" #if SHERPA_ONNX_ENABLE_TTS == 1 #include "sherpa-onnx/csrc/offline-tts.h" #endif -#define SHERPA_ONNX_EXTERN_C extern "C" - namespace sherpa_onnx { class SherpaOnnx { @@ -1224,12 +1214,18 @@ Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_allSpeakerNames( // see // https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables -static jobject NewInteger(JNIEnv *env, int32_t value) { +jobject NewInteger(JNIEnv *env, int32_t value) { jclass cls = env->FindClass("java/lang/Integer"); jmethodID constructor = env->GetMethodID(cls, "", "(I)V"); return env->NewObject(cls, constructor, value); } +jobject NewFloat(JNIEnv *env, float value) { + jclass cls = env->FindClass("java/lang/Float"); + jmethodID constructor = env->GetMethodID(cls, "", "(F)V"); + return env->NewObject(cls, constructor, value); +} + #if SHERPA_ONNX_ENABLE_TTS == 1 SHERPA_ONNX_EXTERN_C JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_new( diff --git a/sherpa-onnx/jni/offline-stream.cc b/sherpa-onnx/jni/offline-stream.cc new file mode 100644 index 00000000..a2644d25 --- /dev/null +++ b/sherpa-onnx/jni/offline-stream.cc @@ -0,0 +1,25 @@ +// sherpa-onnx/jni/offline-stream.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-stream.h" + +#include "sherpa-onnx/jni/common.h" + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineStream_delete( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + delete reinterpret_cast(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineStream_acceptWaveform( + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, + jint sample_rate) { + auto stream = reinterpret_cast(ptr); + + jfloat *p = env->GetFloatArrayElements(samples, nullptr); + jsize n = env->GetArrayLength(samples); + stream->AcceptWaveform(sample_rate, p, n); + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); +}