Refactor the JNI interface to make it more modular and maintainable (#802)

This commit is contained in:
Fangjun Kuang
2024-04-24 09:48:42 +08:00
committed by GitHub
parent dc5af04830
commit 9b67a476e6
116 changed files with 3502 additions and 3316 deletions

View File

@@ -1,4 +1,4 @@
@file:OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
@file:OptIn(ExperimentalMaterial3Api::class)
package com.k2fsa.sherpa.onnx.slid
@@ -9,11 +9,9 @@ import android.media.AudioFormat
import android.media.AudioRecord
import android.media.MediaRecorder
import android.util.Log
import androidx.compose.foundation.ExperimentalFoundationApi
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.ui.Modifier
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.height
@@ -31,6 +29,7 @@ import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
@@ -63,13 +62,13 @@ fun Home() {
}
private var audioRecord: AudioRecord? = null
private val sampleRateInHz = 16000
private const val sampleRateInHz = 16000
@Composable
fun MyApp(padding: PaddingValues) {
val activity = LocalContext.current as Activity
var isStarted by remember { mutableStateOf(false) }
var result by remember { mutableStateOf<String>("") }
var result by remember { mutableStateOf("") }
val onButtonClick: () -> Unit = {
isStarted = !isStarted
@@ -114,12 +113,12 @@ fun MyApp(padding: PaddingValues) {
}
Log.i(TAG, "Stop recording")
Log.i(TAG, "Start recognition")
val samples = Flatten(sampleList)
val samples = flatten(sampleList)
val stream = Slid.slid.createStream()
stream.acceptWaveform(samples, sampleRateInHz)
val lang = Slid.slid.compute(stream)
result = Slid.localeMap.get(lang) ?: lang
result = Slid.localeMap[lang] ?: lang
stream.release()
}
@@ -152,7 +151,7 @@ fun MyApp(padding: PaddingValues) {
}
}
fun Flatten(sampleList: ArrayList<FloatArray>): FloatArray {
fun flatten(sampleList: ArrayList<FloatArray>): FloatArray {
var totalSamples = 0
for (a in sampleList) {
totalSamples += a.size

View File

@@ -10,12 +10,9 @@ import androidx.activity.compose.setContent
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.tooling.preview.Preview
import androidx.core.app.ActivityCompat
import com.k2fsa.sherpa.onnx.SpokenLanguageIdentification
import com.k2fsa.sherpa.onnx.slid.ui.theme.SherpaOnnxSpokenLanguageIdentificationTheme
const val TAG = "sherpa-onnx"
@@ -32,6 +29,7 @@ class MainActivity : ComponentActivity() {
ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
Slid.initSlid(this.assets)
}
@Suppress("DEPRECATION")
@Deprecated("Deprecated in Java")
override fun onRequestPermissionsResult(

View File

@@ -1 +1 @@
../../../../../../../../../../SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/OfflineStream.kt
../../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineStream.kt

View File

@@ -1,102 +0,0 @@
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
import android.util.Log
private val TAG = "sherpa-onnx"
data class SpokenLanguageIdentificationWhisperConfig (
var encoder: String,
var decoder: String,
var tailPaddings: Int = -1,
)
data class SpokenLanguageIdentificationConfig (
var whisper: SpokenLanguageIdentificationWhisperConfig,
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
)
class SpokenLanguageIdentification (
assetManager: AssetManager? = null,
config: SpokenLanguageIdentificationConfig,
) {
private var ptr: Long
init {
ptr = if (assetManager != null) {
newFromAsset(assetManager, config)
} else {
newFromFile(config)
}
}
protected fun finalize() {
if (ptr != 0L) {
delete(ptr)
ptr = 0
}
}
fun release() = finalize()
fun createStream(): OfflineStream {
val p = createStream(ptr)
return OfflineStream(p)
}
fun compute(stream: OfflineStream) = compute(ptr, stream.ptr)
private external fun newFromAsset(
assetManager: AssetManager,
config: SpokenLanguageIdentificationConfig,
): Long
private external fun newFromFile(
config: SpokenLanguageIdentificationConfig,
): Long
private external fun delete(ptr: Long)
private external fun createStream(ptr: Long): Long
private external fun compute(ptr: Long, streamPtr: Long): String
companion object {
init {
System.loadLibrary("sherpa-onnx-jni")
}
}
}
// please refer to
// https://k2-fsa.github.io/sherpa/onnx/spolken-language-identification/pretrained_models.html#whisper
// to download more models
fun getSpokenLanguageIdentificationConfig(type: Int, numThreads: Int=1): SpokenLanguageIdentificationConfig? {
when (type) {
0 -> {
val modelDir = "sherpa-onnx-whisper-tiny"
return SpokenLanguageIdentificationConfig(
whisper = SpokenLanguageIdentificationWhisperConfig(
encoder = "$modelDir/tiny-encoder.int8.onnx",
decoder = "$modelDir/tiny-decoder.int8.onnx",
),
numThreads = numThreads,
debug = true,
)
}
1 -> {
val modelDir = "sherpa-onnx-whisper-base"
return SpokenLanguageIdentificationConfig(
whisper = SpokenLanguageIdentificationWhisperConfig(
encoder = "$modelDir/tiny-encoder.int8.onnx",
decoder = "$modelDir/tiny-decoder.int8.onnx",
),
numThreads = 1,
debug = true,
)
}
}
return null
}

View File

@@ -0,0 +1 @@
../../../../../../../../../../../sherpa-onnx/kotlin-api/SpokenLanguageIdentification.kt

View File

@@ -15,10 +15,10 @@ object Slid {
get() {
return _slid!!
}
val localeMap : Map<String, String>
get() {
return _localeMap
}
val localeMap: Map<String, String>
get() {
return _localeMap
}
fun initSlid(assetManager: AssetManager? = null, numThreads: Int = 1) {
synchronized(this) {
@@ -31,7 +31,7 @@ object Slid {
}
if (_localeMap.isEmpty()) {
val allLang = Locale.getISOLanguages();
val allLang = Locale.getISOLanguages()
for (lang in allLang) {
val locale = Locale(lang)
_localeMap[lang] = locale.displayName