Add Kotlin API for audio tagging (#770)
This commit is contained in:
95
kotlin-api-examples/AudioTagging.kt
Normal file
95
kotlin-api-examples/AudioTagging.kt
Normal file
@@ -0,0 +1,95 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
import android.util.Log
|
||||
|
||||
private val TAG = "sherpa-onnx"
|
||||
|
||||
data class OfflineZipformerAudioTaggingModelConfig (
|
||||
val model: String,
|
||||
)
|
||||
|
||||
data class AudioTaggingModelConfig (
|
||||
var zipformer: OfflineZipformerAudioTaggingModelConfig,
|
||||
var numThreads: Int = 1,
|
||||
var debug: Boolean = false,
|
||||
var provider: String = "cpu",
|
||||
)
|
||||
|
||||
data class AudioTaggingConfig (
|
||||
var model: AudioTaggingModelConfig,
|
||||
var labels: String,
|
||||
var topK: Int = 5,
|
||||
)
|
||||
|
||||
data class AudioEvent (
|
||||
val name: String,
|
||||
val index: Int,
|
||||
val prob: Float,
|
||||
)
|
||||
|
||||
class AudioTagging(
|
||||
assetManager: AssetManager? = null,
|
||||
config: AudioTaggingConfig,
|
||||
) {
|
||||
private var ptr: Long
|
||||
|
||||
init {
|
||||
ptr = if (assetManager != null) {
|
||||
newFromAsset(assetManager, config)
|
||||
} else {
|
||||
newFromFile(config)
|
||||
}
|
||||
}
|
||||
|
||||
protected fun finalize() {
|
||||
if(ptr != 0L) {
|
||||
delete(ptr)
|
||||
ptr = 0
|
||||
}
|
||||
}
|
||||
|
||||
fun release() = finalize()
|
||||
|
||||
fun createStream(): OfflineStream {
|
||||
val p = createStream(ptr)
|
||||
return OfflineStream(p)
|
||||
}
|
||||
|
||||
fun compute(stream: OfflineStream, topK: Int=-1): ArrayList<AudioEvent> {
|
||||
var events :Array<Any> = compute(ptr, stream.ptr, topK)
|
||||
val ans = ArrayList<AudioEvent>()
|
||||
|
||||
for (e in events) {
|
||||
val p :Array<Any> = e as Array<Any>
|
||||
ans.add(AudioEvent(
|
||||
name=p[0] as String,
|
||||
index=p[1] as Int,
|
||||
prob=p[2] as Float,
|
||||
))
|
||||
}
|
||||
|
||||
return ans
|
||||
}
|
||||
|
||||
private external fun newFromAsset(
|
||||
assetManager: AssetManager,
|
||||
config: AudioTaggingConfig,
|
||||
): Long
|
||||
|
||||
private external fun newFromFile(
|
||||
config: AudioTaggingConfig,
|
||||
): Long
|
||||
|
||||
private external fun delete(ptr: Long)
|
||||
|
||||
private external fun createStream(ptr: Long): Long
|
||||
|
||||
private external fun compute(ptr: Long, streamPtr: Long, topK: Int): Array<Any>
|
||||
|
||||
companion object {
|
||||
init {
|
||||
System.loadLibrary("sherpa-onnx-jni")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7,12 +7,56 @@ fun callback(samples: FloatArray): Unit {
|
||||
}
|
||||
|
||||
fun main() {
|
||||
testAudioTagging()
|
||||
testSpeakerRecognition()
|
||||
testTts()
|
||||
testAsr("transducer")
|
||||
testAsr("zipformer2-ctc")
|
||||
}
|
||||
|
||||
fun testAudioTagging() {
|
||||
val config = AudioTaggingConfig(
|
||||
model=AudioTaggingModelConfig(
|
||||
zipformer=OfflineZipformerAudioTaggingModelConfig(
|
||||
model="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.int8.onnx",
|
||||
),
|
||||
numThreads=1,
|
||||
debug=true,
|
||||
provider="cpu",
|
||||
),
|
||||
labels="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv",
|
||||
topK=5,
|
||||
)
|
||||
val tagger = AudioTagging(assetManager=null, config=config)
|
||||
|
||||
val testFiles = arrayOf(
|
||||
"./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav",
|
||||
"./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/2.wav",
|
||||
"./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/3.wav",
|
||||
"./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/4.wav",
|
||||
)
|
||||
println("----------")
|
||||
for (waveFilename in testFiles) {
|
||||
val stream = tagger.createStream()
|
||||
|
||||
val objArray = WaveReader.readWaveFromFile(
|
||||
filename = waveFilename,
|
||||
)
|
||||
val samples: FloatArray = objArray[0] as FloatArray
|
||||
val sampleRate: Int = objArray[1] as Int
|
||||
|
||||
stream.acceptWaveform(samples, sampleRate = sampleRate)
|
||||
val events = tagger.compute(stream)
|
||||
stream.release()
|
||||
|
||||
println(waveFilename)
|
||||
println(events)
|
||||
println("----------")
|
||||
}
|
||||
|
||||
tagger.release()
|
||||
}
|
||||
|
||||
fun computeEmbedding(extractor: SpeakerEmbeddingExtractor, filename: String): FloatArray {
|
||||
var objArray = WaveReader.readWaveFromFile(
|
||||
filename = filename,
|
||||
|
||||
24
kotlin-api-examples/OfflineStream.kt
Normal file
24
kotlin-api-examples/OfflineStream.kt
Normal file
@@ -0,0 +1,24 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
class OfflineStream(var ptr: Long) {
|
||||
fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
|
||||
acceptWaveform(ptr, samples, sampleRate)
|
||||
|
||||
protected fun finalize() {
|
||||
if(ptr != 0L) {
|
||||
delete(ptr)
|
||||
ptr = 0
|
||||
}
|
||||
}
|
||||
|
||||
fun release() = finalize()
|
||||
|
||||
private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
|
||||
private external fun delete(ptr: Long)
|
||||
|
||||
companion object {
|
||||
init {
|
||||
System.loadLibrary("sherpa-onnx-jni")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,8 +4,7 @@
|
||||
# Note: This scripts runs only on Linux and macOS, though sherpa-onnx
|
||||
# supports building JNI libs for Windows.
|
||||
|
||||
set -e
|
||||
|
||||
set -ex
|
||||
|
||||
cd ..
|
||||
mkdir -p build
|
||||
@@ -29,59 +28,93 @@ export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH
|
||||
|
||||
cd ../kotlin-api-examples
|
||||
|
||||
if [ ! -f ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx ]; then
|
||||
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
|
||||
fi
|
||||
function testSpeakerEmbeddingExtractor() {
|
||||
if [ ! -f ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx ]; then
|
||||
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
|
||||
fi
|
||||
|
||||
if [ ! -f ./speaker1_a_cn_16k.wav ]; then
|
||||
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav
|
||||
fi
|
||||
if [ ! -f ./speaker1_a_cn_16k.wav ]; then
|
||||
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav
|
||||
fi
|
||||
|
||||
if [ ! -f ./speaker1_b_cn_16k.wav ]; then
|
||||
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav
|
||||
fi
|
||||
if [ ! -f ./speaker1_b_cn_16k.wav ]; then
|
||||
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav
|
||||
fi
|
||||
|
||||
if [ ! -f ./speaker2_a_cn_16k.wav ]; then
|
||||
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav
|
||||
fi
|
||||
if [ ! -f ./speaker2_a_cn_16k.wav ]; then
|
||||
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav
|
||||
fi
|
||||
}
|
||||
|
||||
if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then
|
||||
git lfs install
|
||||
git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
|
||||
fi
|
||||
function testAsr() {
|
||||
if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then
|
||||
git lfs install
|
||||
git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
|
||||
fi
|
||||
|
||||
if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then
|
||||
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
|
||||
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
|
||||
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
|
||||
fi
|
||||
if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then
|
||||
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
|
||||
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
|
||||
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
|
||||
fi
|
||||
}
|
||||
|
||||
if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then
|
||||
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
|
||||
tar xf vits-piper-en_US-amy-low.tar.bz2
|
||||
rm vits-piper-en_US-amy-low.tar.bz2
|
||||
fi
|
||||
function testTts() {
|
||||
if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then
|
||||
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
|
||||
tar xf vits-piper-en_US-amy-low.tar.bz2
|
||||
rm vits-piper-en_US-amy-low.tar.bz2
|
||||
fi
|
||||
}
|
||||
|
||||
kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt faked-asset-manager.kt Tts.kt Speaker.kt faked-log.kt
|
||||
function testAudioTagging() {
|
||||
if [ ! -d sherpa-onnx-zipformer-audio-tagging-2024-04-09 ]; then
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
|
||||
tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
|
||||
rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
|
||||
fi
|
||||
}
|
||||
|
||||
function test() {
|
||||
testAudioTagging
|
||||
testSpeakerEmbeddingExtractor
|
||||
testAsr
|
||||
testTts
|
||||
}
|
||||
|
||||
test
|
||||
|
||||
kotlinc-jvm -include-runtime -d main.jar \
|
||||
AudioTagging.kt \
|
||||
Main.kt \
|
||||
OfflineStream.kt \
|
||||
SherpaOnnx.kt \
|
||||
Speaker.kt \
|
||||
Tts.kt \
|
||||
WaveReader.kt \
|
||||
faked-asset-manager.kt \
|
||||
faked-log.kt
|
||||
|
||||
ls -lh main.jar
|
||||
|
||||
java -Djava.library.path=../build/lib -jar main.jar
|
||||
|
||||
# For two-pass
|
||||
function testTwoPass() {
|
||||
if [ ! -f ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder-epoch-99-avg-1.int8.onnx ]; then
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
|
||||
tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
|
||||
rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
|
||||
fi
|
||||
|
||||
if [ ! -f ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder-epoch-99-avg-1.int8.onnx ]; then
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
|
||||
tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
|
||||
rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
|
||||
fi
|
||||
if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
|
||||
tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2
|
||||
rm sherpa-onnx-whisper-tiny.en.tar.bz2
|
||||
fi
|
||||
|
||||
if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
|
||||
tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2
|
||||
rm sherpa-onnx-whisper-tiny.en.tar.bz2
|
||||
fi
|
||||
kotlinc-jvm -include-runtime -d 2pass.jar test-2pass.kt WaveReader.kt SherpaOnnx2Pass.kt faked-asset-manager.kt
|
||||
ls -lh 2pass.jar
|
||||
java -Djava.library.path=../build/lib -jar 2pass.jar
|
||||
}
|
||||
|
||||
kotlinc-jvm -include-runtime -d 2pass.jar test-2pass.kt WaveReader.kt SherpaOnnx2Pass.kt faked-asset-manager.kt
|
||||
ls -lh 2pass.jar
|
||||
java -Djava.library.path=../build/lib -jar 2pass.jar
|
||||
testTwoPass
|
||||
|
||||
@@ -4,6 +4,11 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
@@ -20,4 +25,17 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
|
||||
AAssetManager *mgr, const AudioTaggingConfig &config) {
|
||||
if (!config.model.zipformer.model.empty()) {
|
||||
return std::make_unique<AudioTaggingZipformerImpl>(mgr, config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOG(
|
||||
"Please specify an audio tagging model! Return a null pointer");
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -7,6 +7,11 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -18,6 +23,11 @@ class AudioTaggingImpl {
|
||||
static std::unique_ptr<AudioTaggingImpl> Create(
|
||||
const AudioTaggingConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
static std::unique_ptr<AudioTaggingImpl> Create(
|
||||
AAssetManager *mgr, const AudioTaggingConfig &config);
|
||||
#endif
|
||||
|
||||
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
|
||||
|
||||
virtual std::vector<AudioEvent> Compute(OfflineStream *s,
|
||||
|
||||
@@ -8,7 +8,15 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include <strstream>
|
||||
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -18,6 +26,15 @@ AudioTaggingLabels::AudioTaggingLabels(const std::string &filename) {
|
||||
Init(is);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
AudioTaggingLabels::AudioTaggingLabels(AAssetManager *mgr,
|
||||
const std::string &filename) {
|
||||
auto buf = ReadFile(mgr, filename);
|
||||
std::istrstream is(buf.data(), buf.size());
|
||||
Init(is);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Format of a label file
|
||||
/*
|
||||
index,mid,display_name
|
||||
|
||||
@@ -8,11 +8,19 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class AudioTaggingLabels {
|
||||
public:
|
||||
explicit AudioTaggingLabels(const std::string &filename);
|
||||
#if __ANDROID_API__ >= 9
|
||||
AudioTaggingLabels(AAssetManager *mgr, const std::string &filename);
|
||||
#endif
|
||||
|
||||
// Return the event name for the given index.
|
||||
// The returned reference is valid as long as this object is alive
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
|
||||
#include "sherpa-onnx/csrc/audio-tagging-label-file.h"
|
||||
#include "sherpa-onnx/csrc/audio-tagging.h"
|
||||
@@ -28,6 +33,20 @@ class AudioTaggingZipformerImpl : public AudioTaggingImpl {
|
||||
}
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
explicit AudioTaggingZipformerImpl(AAssetManager *mgr,
|
||||
const AudioTaggingConfig &config)
|
||||
: config_(config),
|
||||
model_(mgr, config.model),
|
||||
labels_(mgr, config.labels) {
|
||||
if (model_.NumEventClasses() != labels_.NumEventClasses()) {
|
||||
SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
|
||||
model_.NumEventClasses(), labels_.NumEventClasses());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>();
|
||||
}
|
||||
|
||||
@@ -4,6 +4,11 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging.h"
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
@@ -61,6 +66,11 @@ std::string AudioTaggingConfig::ToString() const {
|
||||
AudioTagging::AudioTagging(const AudioTaggingConfig &config)
|
||||
: impl_(AudioTaggingImpl::Create(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
AudioTagging::AudioTagging(AAssetManager *mgr, const AudioTaggingConfig &config)
|
||||
: impl_(AudioTaggingImpl::Create(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
AudioTagging::~AudioTagging() = default;
|
||||
|
||||
std::unique_ptr<OfflineStream> AudioTagging::CreateStream() const {
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
@@ -46,6 +51,10 @@ class AudioTagging {
|
||||
public:
|
||||
explicit AudioTagging(const AudioTaggingConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
AudioTagging(AAssetManager *mgr, const AudioTaggingConfig &config);
|
||||
#endif
|
||||
|
||||
~AudioTagging();
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const;
|
||||
|
||||
84
sherpa-onnx/jni/AudioTagging.kt
Normal file
84
sherpa-onnx/jni/AudioTagging.kt
Normal file
@@ -0,0 +1,84 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
import android.util.Log
|
||||
|
||||
private val TAG = "sherpa-onnx"
|
||||
|
||||
data class OfflineZipformerAudioTaggingModelConfig (
|
||||
val model: String,
|
||||
)
|
||||
|
||||
data class AudioTaggingModelConfig (
|
||||
var zipformer: OfflineZipformerAudioTaggingModelConfig,
|
||||
var numThreads: Int = 1,
|
||||
var debug: Boolean = false,
|
||||
var provider: String = "cpu",
|
||||
)
|
||||
|
||||
data class AudioTaggingConfig (
|
||||
var model: AudioTaggingModelConfig,
|
||||
var labels: String,
|
||||
var topK: Int = 5,
|
||||
)
|
||||
|
||||
data class AudioEvent (
|
||||
val name: String,
|
||||
val index: Int,
|
||||
val prob: Float,
|
||||
)
|
||||
|
||||
class AudioTagging(
|
||||
assetManager: AssetManager? = null,
|
||||
config: AudioTaggingConfig,
|
||||
) {
|
||||
private var ptr: Long
|
||||
|
||||
init {
|
||||
ptr = if (assetManager != null) {
|
||||
newFromAsset(assetManager, config)
|
||||
} else {
|
||||
newFromFile(config)
|
||||
}
|
||||
}
|
||||
|
||||
protected fun finalize() {
|
||||
if(ptr != 0) {
|
||||
delete(ptr)
|
||||
ptr = 0
|
||||
}
|
||||
}
|
||||
|
||||
fun release() = finalize()
|
||||
|
||||
fun createStream(): OfflineStream {
|
||||
val p = createStream(ptr)
|
||||
return OfflineStream(p)
|
||||
}
|
||||
|
||||
// fun compute(stream: OfflineStream, topK: Int=-1): Array<AudioEvent> {
|
||||
fun compute(stream: OfflineStream, topK: Int=-1): Array<Any> {
|
||||
var events :Array<Any> = compute(ptr, stream.ptr, topK)
|
||||
}
|
||||
|
||||
private external fun newFromAsset(
|
||||
assetManager: AssetManager,
|
||||
config: AudioTaggingConfig,
|
||||
): Long
|
||||
|
||||
private external fun newFromFile(
|
||||
config: AudioTaggingConfig,
|
||||
): Long
|
||||
|
||||
private external fun delete(ptr: Long)
|
||||
|
||||
private external fun createStream(ptr: Long): Long
|
||||
|
||||
private external fun compute(ptr: Long, streamPtr: Long, topK: Int): Array<Any>
|
||||
|
||||
companion object {
|
||||
init {
|
||||
System.loadLibrary("sherpa-onnx-jni")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,10 @@ if(NOT DEFINED ANDROID_ABI)
|
||||
include_directories($ENV{JAVA_HOME}/include/darwin)
|
||||
endif()
|
||||
|
||||
add_library(sherpa-onnx-jni jni.cc)
|
||||
add_library(sherpa-onnx-jni
|
||||
audio-tagging.cc
|
||||
jni.cc
|
||||
offline-stream.cc
|
||||
)
|
||||
target_link_libraries(sherpa-onnx-jni sherpa-onnx-core)
|
||||
install(TARGETS sherpa-onnx-jni DESTINATION lib)
|
||||
|
||||
126
sherpa-onnx/jni/audio-tagging.cc
Normal file
126
sherpa-onnx/jni/audio-tagging.cc
Normal file
@@ -0,0 +1,126 @@
|
||||
// sherpa-onnx/jni/audio-tagging.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/jni/common.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static AudioTaggingConfig GetAudioTaggingConfig(JNIEnv *env, jobject config) {
|
||||
AudioTaggingConfig ans;
|
||||
|
||||
jclass cls = env->GetObjectClass(config);
|
||||
|
||||
jfieldID fid = env->GetFieldID(
|
||||
cls, "model", "Lcom/k2fsa/sherpa/onnx/AudioTaggingModelConfig;");
|
||||
jobject model = env->GetObjectField(config, fid);
|
||||
jclass model_cls = env->GetObjectClass(model);
|
||||
|
||||
fid = env->GetFieldID(
|
||||
model_cls, "zipformer",
|
||||
"Lcom/k2fsa/sherpa/onnx/OfflineZipformerAudioTaggingModelConfig;");
|
||||
jobject zipformer = env->GetObjectField(model, fid);
|
||||
jclass zipformer_cls = env->GetObjectClass(zipformer);
|
||||
|
||||
fid = env->GetFieldID(zipformer_cls, "model", "Ljava/lang/String;");
|
||||
jstring s = (jstring)env->GetObjectField(zipformer, fid);
|
||||
const char *p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model.zipformer.model = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_cls, "numThreads", "I");
|
||||
ans.model.num_threads = env->GetIntField(model, fid);
|
||||
|
||||
fid = env->GetFieldID(model_cls, "debug", "Z");
|
||||
ans.model.debug = env->GetBooleanField(model, fid);
|
||||
|
||||
fid = env->GetFieldID(model_cls, "provider", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model.provider = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(cls, "labels", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.labels = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(cls, "topK", "I");
|
||||
ans.top_k = env->GetIntField(config, fid);
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_newFromFile(
|
||||
JNIEnv *env, jobject /*obj*/, jobject _config) {
|
||||
auto config = sherpa_onnx::GetAudioTaggingConfig(env, _config);
|
||||
SHERPA_ONNX_LOGE("audio tagging newFromFile config:\n%s",
|
||||
config.ToString().c_str());
|
||||
|
||||
if (!config.Validate()) {
|
||||
SHERPA_ONNX_LOGE("Errors found in config!");
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto tagger = new sherpa_onnx::AudioTagging(config);
|
||||
|
||||
return (jlong)tagger;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_delete(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
delete reinterpret_cast<sherpa_onnx::AudioTagging *>(ptr);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_createStream(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
auto tagger = reinterpret_cast<sherpa_onnx::AudioTagging *>(ptr);
|
||||
std::unique_ptr<sherpa_onnx::OfflineStream> s = tagger->CreateStream();
|
||||
|
||||
// The user is responsible to free the returned pointer.
|
||||
//
|
||||
// See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from
|
||||
// ./offline-stream.cc
|
||||
sherpa_onnx::OfflineStream *p = s.release();
|
||||
return (jlong)p;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_compute(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr, jlong streamPtr, jint top_k) {
|
||||
auto tagger = reinterpret_cast<sherpa_onnx::AudioTagging *>(ptr);
|
||||
auto stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(streamPtr);
|
||||
std::vector<sherpa_onnx::AudioEvent> events = tagger->Compute(stream, top_k);
|
||||
|
||||
// TODO(fangjun): Return an array of AudioEvent directly
|
||||
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
|
||||
events.size(), env->FindClass("java/lang/Object"), nullptr);
|
||||
|
||||
int32_t i = 0;
|
||||
for (const auto &e : events) {
|
||||
jobjectArray a = (jobjectArray)env->NewObjectArray(
|
||||
3, env->FindClass("java/lang/Object"), nullptr);
|
||||
|
||||
// 0 name
|
||||
// 1 index
|
||||
// 2 prob
|
||||
jstring js = env->NewStringUTF(e.name.c_str());
|
||||
env->SetObjectArrayElement(a, 0, js);
|
||||
env->SetObjectArrayElement(a, 1, NewInteger(env, e.index));
|
||||
env->SetObjectArrayElement(a, 2, NewFloat(env, e.prob));
|
||||
|
||||
env->SetObjectArrayElement(obj_arr, i, a);
|
||||
i += 1;
|
||||
}
|
||||
|
||||
return obj_arr;
|
||||
}
|
||||
23
sherpa-onnx/jni/common.h
Normal file
23
sherpa-onnx/jni/common.h
Normal file
@@ -0,0 +1,23 @@
|
||||
// sherpa-onnx/jni/common.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_JNI_COMMON_H_
|
||||
#define SHERPA_ONNX_JNI_COMMON_H_
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
// If you use ndk, you can find "jni.h" inside
|
||||
// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
|
||||
#include "jni.h" // NOLINT
|
||||
|
||||
#define SHERPA_ONNX_EXTERN_C extern "C"
|
||||
|
||||
// defined in jni.cc
|
||||
jobject NewInteger(JNIEnv *env, int32_t value);
|
||||
jobject NewFloat(JNIEnv *env, float value);
|
||||
|
||||
#endif // SHERPA_ONNX_JNI_COMMON_H_
|
||||
@@ -7,20 +7,11 @@
|
||||
// TODO(fangjun): Add documentation to functions/methods in this file
|
||||
// and also show how to use them with kotlin, possibly with java.
|
||||
|
||||
// If you use ndk, you can find "jni.h" inside
|
||||
// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
|
||||
#include "jni.h" // NOLINT
|
||||
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <strstream>
|
||||
#include <utility>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
@@ -31,13 +22,12 @@
|
||||
#include "sherpa-onnx/csrc/voice-activity-detector.h"
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
#include "sherpa-onnx/csrc/wave-writer.h"
|
||||
#include "sherpa-onnx/jni/common.h"
|
||||
|
||||
#if SHERPA_ONNX_ENABLE_TTS == 1
|
||||
#include "sherpa-onnx/csrc/offline-tts.h"
|
||||
#endif
|
||||
|
||||
#define SHERPA_ONNX_EXTERN_C extern "C"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class SherpaOnnx {
|
||||
@@ -1224,12 +1214,18 @@ Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_allSpeakerNames(
|
||||
|
||||
// see
|
||||
// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
|
||||
static jobject NewInteger(JNIEnv *env, int32_t value) {
|
||||
jobject NewInteger(JNIEnv *env, int32_t value) {
|
||||
jclass cls = env->FindClass("java/lang/Integer");
|
||||
jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V");
|
||||
return env->NewObject(cls, constructor, value);
|
||||
}
|
||||
|
||||
jobject NewFloat(JNIEnv *env, float value) {
|
||||
jclass cls = env->FindClass("java/lang/Float");
|
||||
jmethodID constructor = env->GetMethodID(cls, "<init>", "(F)V");
|
||||
return env->NewObject(cls, constructor, value);
|
||||
}
|
||||
|
||||
#if SHERPA_ONNX_ENABLE_TTS == 1
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_new(
|
||||
|
||||
25
sherpa-onnx/jni/offline-stream.cc
Normal file
25
sherpa-onnx/jni/offline-stream.cc
Normal file
@@ -0,0 +1,25 @@
|
||||
// sherpa-onnx/jni/offline-stream.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
|
||||
#include "sherpa-onnx/jni/common.h"
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineStream_delete(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
delete reinterpret_cast<sherpa_onnx::OfflineStream *>(ptr);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineStream_acceptWaveform(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
|
||||
jint sample_rate) {
|
||||
auto stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(ptr);
|
||||
|
||||
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
|
||||
jsize n = env->GetArrayLength(samples);
|
||||
stream->AcceptWaveform(sample_rate, p, n);
|
||||
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
|
||||
}
|
||||
Reference in New Issue
Block a user