diff --git a/sherpa-onnx/kotlin-api/KeywordSpotter.kt b/sherpa-onnx/kotlin-api/KeywordSpotter.kt index 803762e5..ea214361 100644 --- a/sherpa-onnx/kotlin-api/KeywordSpotter.kt +++ b/sherpa-onnx/kotlin-api/KeywordSpotter.kt @@ -24,7 +24,7 @@ class KeywordSpotter( assetManager: AssetManager? = null, val config: KeywordSpotterConfig, ) { - private val ptr: Long + private var ptr: Long init { ptr = if (assetManager != null) { @@ -35,7 +35,10 @@ class KeywordSpotter( } protected fun finalize() { - delete(ptr) + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } } fun release() = finalize() diff --git a/sherpa-onnx/kotlin-api/OfflinePunctuation.kt b/sherpa-onnx/kotlin-api/OfflinePunctuation.kt index ec6aa08f..8dfcbd63 100644 --- a/sherpa-onnx/kotlin-api/OfflinePunctuation.kt +++ b/sherpa-onnx/kotlin-api/OfflinePunctuation.kt @@ -18,7 +18,7 @@ class OfflinePunctuation( assetManager: AssetManager? = null, config: OfflinePunctuationConfig, ) { - private val ptr: Long + private var ptr: Long init { ptr = if (assetManager != null) { @@ -29,7 +29,10 @@ class OfflinePunctuation( } protected fun finalize() { - delete(ptr) + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } } fun release() = finalize() diff --git a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt index 0003eb42..f48cc3fd 100644 --- a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt @@ -72,7 +72,7 @@ class OfflineRecognizer( assetManager: AssetManager? = null, config: OfflineRecognizerConfig, ) { - private val ptr: Long + private var ptr: Long init { ptr = if (assetManager != null) { @@ -83,7 +83,10 @@ class OfflineRecognizer( } protected fun finalize() { - delete(ptr) + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } } fun release() = finalize() @@ -102,7 +105,14 @@ class OfflineRecognizer( val lang = objArray[3] as String val emotion = objArray[4] as String val event = objArray[5] as String - return OfflineRecognizerResult(text = text, tokens = tokens, timestamps = timestamps, lang = lang, emotion = emotion, event = event) + return OfflineRecognizerResult( + text = text, + tokens = tokens, + timestamps = timestamps, + lang = lang, + emotion = emotion, + event = event + ) } fun decode(stream: OfflineStream) = decode(ptr, stream.ptr) diff --git a/sherpa-onnx/kotlin-api/OfflineStream.kt b/sherpa-onnx/kotlin-api/OfflineStream.kt index 49652e72..ab316c52 100644 --- a/sherpa-onnx/kotlin-api/OfflineStream.kt +++ b/sherpa-onnx/kotlin-api/OfflineStream.kt @@ -13,6 +13,14 @@ class OfflineStream(var ptr: Long) { fun release() = finalize() + fun use(block: (OfflineStream) -> Unit) { + try { + block(this) + } finally { + release() + } + } + private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) private external fun delete(ptr: Long) diff --git a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt index 4750e232..7ddefdf3 100644 --- a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt @@ -62,7 +62,7 @@ data class OnlineRecognizerConfig( var featConfig: FeatureConfig = FeatureConfig(), var modelConfig: OnlineModelConfig, var lmConfig: OnlineLMConfig = OnlineLMConfig(), - var ctcFstDecoderConfig : OnlineCtcFstDecoderConfig = OnlineCtcFstDecoderConfig(), + var ctcFstDecoderConfig: OnlineCtcFstDecoderConfig = OnlineCtcFstDecoderConfig(), var endpointConfig: EndpointConfig = EndpointConfig(), var enableEndpoint: Boolean = true, var decodingMethod: String = "greedy_search", @@ -85,7 +85,7 @@ class OnlineRecognizer( assetManager: AssetManager? = null, val config: OnlineRecognizerConfig, ) { - private val ptr: Long + private var ptr: Long init { ptr = if (assetManager != null) { @@ -96,7 +96,10 @@ class OnlineRecognizer( } protected fun finalize() { - delete(ptr) + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } } fun release() = finalize() diff --git a/sherpa-onnx/kotlin-api/OnlineStream.kt b/sherpa-onnx/kotlin-api/OnlineStream.kt index 6057fabd..cb15a9a4 100644 --- a/sherpa-onnx/kotlin-api/OnlineStream.kt +++ b/sherpa-onnx/kotlin-api/OnlineStream.kt @@ -15,10 +15,19 @@ class OnlineStream(var ptr: Long = 0) { fun release() = finalize() + fun use(block: (OnlineStream) -> Unit) { + try { + block(this) + } finally { + release() + } + } + private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) private external fun inputFinished(ptr: Long) private external fun delete(ptr: Long) + companion object { init { System.loadLibrary("sherpa-onnx-jni") diff --git a/sherpa-onnx/kotlin-api/Vad.kt b/sherpa-onnx/kotlin-api/Vad.kt index c842e035..182a23d5 100644 --- a/sherpa-onnx/kotlin-api/Vad.kt +++ b/sherpa-onnx/kotlin-api/Vad.kt @@ -19,11 +19,13 @@ data class VadModelConfig( var debug: Boolean = false, ) +class SpeechSegment(val start: Int, val samples: FloatArray) + class Vad( assetManager: AssetManager? = null, var config: VadModelConfig, ) { - private val ptr: Long + private var ptr: Long init { if (assetManager != null) { @@ -34,17 +36,23 @@ class Vad( } protected fun finalize() { - delete(ptr) + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } } + fun release() = finalize() + fun acceptWaveform(samples: FloatArray) = acceptWaveform(ptr, samples) fun empty(): Boolean = empty(ptr) fun pop() = pop(ptr) - // return an array containing - // [start: Int, samples: FloatArray] - fun front() = front(ptr) + fun front(): SpeechSegment { + val segment = front(ptr) + return SpeechSegment(segment[0] as Int, segment[1] as FloatArray) + } fun clear() = clear(ptr) diff --git a/sherpa-onnx/kotlin-api/WaveReader.kt b/sherpa-onnx/kotlin-api/WaveReader.kt index dca39984..9e4f5fbd 100644 --- a/sherpa-onnx/kotlin-api/WaveReader.kt +++ b/sherpa-onnx/kotlin-api/WaveReader.kt @@ -3,8 +3,49 @@ package com.k2fsa.sherpa.onnx import android.content.res.AssetManager +data class WaveData( + val samples: FloatArray, + val sampleRate: Int, +) { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + + other as WaveData + + if (!samples.contentEquals(other.samples)) return false + if (sampleRate != other.sampleRate) return false + + return true + } + + override fun hashCode(): Int { + var result = samples.contentHashCode() + result = 31 * result + sampleRate + return result + } +} + class WaveReader { companion object { + + fun readWave( + assetManager: AssetManager, + filename: String, + ): WaveData { + return readWaveFromAsset(assetManager, filename).let { + WaveData(it[0] as FloatArray, it[1] as Int) + } + } + + fun readWave( + filename: String, + ): WaveData { + return readWaveFromFile(filename).let { + WaveData(it[0] as FloatArray, it[1] as Int) + } + } + // Read a mono wave file asset // The returned array has two entries: // - the first entry contains an 1-D float array