// enginex_bi_series-sherpa-onnx/kotlin-api-examples/Main.kt

package com.k2fsa.sherpa.onnx

import android.content.res.AssetManager

// Invoked by the TTS engine as audio samples are generated.
fun callback(samples: FloatArray) {
    println("callback got called with ${samples.size} samples")
}

fun main() {
    testSpokenLanguageIdentification()
    testAudioTagging()
    testSpeakerRecognition()
    testTts()
    testAsr("transducer")
    testAsr("zipformer2-ctc")
}
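
// Note: each test below expects its model directory to be unpacked in the
// current working directory (for example ./sherpa-onnx-whisper-tiny or
// ./vits-piper-en_US-amy-low); see the download links in the comments of the
// individual tests.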

fun testSpokenLanguageIdentification() {
    val config = SpokenLanguageIdentificationConfig(
        whisper = SpokenLanguageIdentificationWhisperConfig(
            encoder = "./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx",
            decoder = "./sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx",
            tailPaddings = 33,
        ),
        numThreads = 1,
        debug = true,
        provider = "cpu",
    )
    val slid = SpokenLanguageIdentification(assetManager = null, config = config)

    val testFiles = arrayOf(
        "./spoken-language-identification-test-wavs/ar-arabic.wav",
        "./spoken-language-identification-test-wavs/bg-bulgarian.wav",
        "./spoken-language-identification-test-wavs/de-german.wav",
    )
    for (waveFilename in testFiles) {
        val objArray = WaveReader.readWaveFromFile(filename = waveFilename)
        val samples: FloatArray = objArray[0] as FloatArray
        val sampleRate: Int = objArray[1] as Int

        val stream = slid.createStream()
        stream.acceptWaveform(samples, sampleRate = sampleRate)
        val lang = slid.compute(stream)
        stream.release()

        println(waveFilename)
        println(lang)
    }
    slid.release()
}
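
// Optional sketch, not part of the upstream example: the identifier is
// assumed to return an ISO 639-1 code such as "de"; if so, java.util.Locale
// can map it to a readable language name.
fun languageName(isoCode: String): String {
    val name = java.util.Locale(isoCode).getDisplayLanguage(java.util.Locale.ENGLISH)
    return if (name.isEmpty()) isoCode else name // fall back to the raw code
}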

fun testAudioTagging() {
    val config = AudioTaggingConfig(
        model = AudioTaggingModelConfig(
            zipformer = OfflineZipformerAudioTaggingModelConfig(
                model = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.int8.onnx",
            ),
            numThreads = 1,
            debug = true,
            provider = "cpu",
        ),
        labels = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv",
        topK = 5,
    )
    val tagger = AudioTagging(assetManager = null, config = config)

    val testFiles = arrayOf(
        "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav",
        "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/2.wav",
        "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/3.wav",
        "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/4.wav",
    )
    println("----------")
    for (waveFilename in testFiles) {
        val stream = tagger.createStream()
        val objArray = WaveReader.readWaveFromFile(filename = waveFilename)
        val samples: FloatArray = objArray[0] as FloatArray
        val sampleRate: Int = objArray[1] as Int

        stream.acceptWaveform(samples, sampleRate = sampleRate)
        val events = tagger.compute(stream)
        stream.release()

        println(waveFilename)
        println(events)
        println("----------")
    }
    tagger.release()
}
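
// A minimal timing sketch, not in the upstream example: wraps one tagging
// call in kotlin.system.measureTimeMillis to get a rough feel for latency.
// It assumes the same AudioTagging/WaveReader flow used above.
fun timeTagging(tagger: AudioTagging, waveFilename: String) {
    val objArray = WaveReader.readWaveFromFile(filename = waveFilename)
    val samples = objArray[0] as FloatArray
    val sampleRate = objArray[1] as Int
    val elapsedMs = kotlin.system.measureTimeMillis {
        val stream = tagger.createStream()
        stream.acceptWaveform(samples, sampleRate = sampleRate)
        tagger.compute(stream)
        stream.release()
    }
    println("tagging $waveFilename took $elapsedMs ms")
}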

// Reads a wave file and returns the speaker embedding computed from it.
fun computeEmbedding(extractor: SpeakerEmbeddingExtractor, filename: String): FloatArray {
    val objArray = WaveReader.readWaveFromFile(filename = filename)
    val samples: FloatArray = objArray[0] as FloatArray
    val sampleRate: Int = objArray[1] as Int

    val stream = extractor.createStream()
    stream.acceptWaveform(sampleRate = sampleRate, samples = samples)
    stream.inputFinished()
    check(extractor.isReady(stream))

    val embedding = extractor.compute(stream)
    stream.release()
    return embedding
}
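
// A small pure-Kotlin sketch, not part of the upstream API: cosine similarity
// between two embeddings, the measure SpeakerEmbeddingManager is assumed to
// use internally when matching against its registered speakers. For instance,
// cosineSimilarity(embedding1a, embedding1b) in testSpeakerRecognition below
// should score well above cosineSimilarity(embedding1a, embedding2a).
fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
    require(a.size == b.size) { "embedding dimensions differ" }
    var dot = 0.0f
    var normA = 0.0f
    var normB = 0.0f
    for (i in a.indices) {
        dot += a[i] * b[i]
        normA += a[i] * a[i]
        normB += b[i] * b[i]
    }
    // Returns NaN for an all-zero embedding; fine for an illustrative sketch.
    return dot / (kotlin.math.sqrt(normA) * kotlin.math.sqrt(normB))
}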

fun testSpeakerRecognition() {
    val config = SpeakerEmbeddingExtractorConfig(
        model = "./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx",
    )
    val extractor = SpeakerEmbeddingExtractor(config = config)

    val embedding1a = computeEmbedding(extractor, "./speaker1_a_cn_16k.wav")
    val embedding2a = computeEmbedding(extractor, "./speaker2_a_cn_16k.wav")
    val embedding1b = computeEmbedding(extractor, "./speaker1_b_cn_16k.wav")

    var manager = SpeakerEmbeddingManager(extractor.dim())
    var ok = manager.add(name = "speaker1", embedding = embedding1a)
    check(ok)

    ok = manager.add(name = "speaker2", embedding = embedding2a)
    check(ok)

    var name = manager.search(embedding = embedding1b, threshold = 0.5f)
    check(name == "speaker1")
    manager.release()

    // Register several embeddings under one name, then search again.
    manager = SpeakerEmbeddingManager(extractor.dim())
    val embeddingList = mutableListOf(embedding1a, embedding1b)
    ok = manager.add(name = "s1", embedding = embeddingList.toTypedArray())
    check(ok)

    name = manager.search(embedding = embedding1b, threshold = 0.5f)
    check(name == "s1")

    // An unknown speaker yields an empty name.
    name = manager.search(embedding = embedding2a, threshold = 0.5f)
    check(name.isEmpty())
    manager.release()
}

fun testTts() {
    // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
    // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
    val config = OfflineTtsConfig(
        model = OfflineTtsModelConfig(
            vits = OfflineTtsVitsModelConfig(
                model = "./vits-piper-en_US-amy-low/en_US-amy-low.onnx",
                tokens = "./vits-piper-en_US-amy-low/tokens.txt",
                dataDir = "./vits-piper-en_US-amy-low/espeak-ng-data",
            ),
            numThreads = 1,
            debug = true,
        )
    )
    val tts = OfflineTts(config = config)

    val audio = tts.generateWithCallback(
        text = "“Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”",
        callback = ::callback,
    )
    audio.save(filename = "test-en.wav")
    tts.release()
}
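
// A hedged sketch: the duration of a generated clip follows directly from the
// raw samples and the sample rate. How you obtain those from the returned
// audio object depends on the bindings; the Kotlin wrapper is assumed to
// expose them alongside save().
fun audioDurationSeconds(samples: FloatArray, sampleRate: Int): Float =
    samples.size.toFloat() / sampleRate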

fun testAsr(type: String) {
    val featConfig = FeatureConfig(
        sampleRate = 16000,
        featureDim = 80,
    )

    val waveFilename: String
    val modelConfig: OnlineModelConfig = when (type) {
        "transducer" -> {
            waveFilename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav"
            // Please refer to
            // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
            // to download pre-trained models.
            OnlineModelConfig(
                transducer = OnlineTransducerModelConfig(
                    encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
                    decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
                    joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
                ),
                tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
                numThreads = 1,
                debug = false,
            )
        }
        "zipformer2-ctc" -> {
            waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav"
            OnlineModelConfig(
                zipformer2Ctc = OnlineZipformer2CtcModelConfig(
                    model = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx",
                ),
                tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt",
                numThreads = 1,
                debug = false,
            )
        }
        else -> throw IllegalArgumentException(type)
    }

    val config = OnlineRecognizerConfig(
        modelConfig = modelConfig,
        lmConfig = OnlineLMConfig(),
        featConfig = featConfig,
        endpointConfig = EndpointConfig(),
        enableEndpoint = true,
        decodingMethod = "greedy_search",
        maxActivePaths = 4,
    )
    val model = SherpaOnnx(config = config)

    val objArray = WaveReader.readWaveFromFile(filename = waveFilename)
    val samples: FloatArray = objArray[0] as FloatArray
    val sampleRate: Int = objArray[1] as Int

    model.acceptWaveform(samples, sampleRate = sampleRate)
    while (model.isReady()) {
        model.decode()
    }

    // Feed 0.5 seconds of silence so the last frames are flushed through the model.
    val tailPaddings = FloatArray((sampleRate * 0.5).toInt())
    model.acceptWaveform(tailPaddings, sampleRate = sampleRate)
    model.inputFinished()
    while (model.isReady()) {
        model.decode()
    }
    println("results: ${model.text}")
}
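
// A chunked-feeding sketch using only the calls exercised above, assuming the
// same SherpaOnnx streaming API. Instead of pushing the whole file at once,
// it feeds 100 ms blocks and decodes as soon as enough frames are buffered,
// which is closer to how a live microphone source would drive the recognizer.
fun decodeInChunks(model: SherpaOnnx, samples: FloatArray, sampleRate: Int): String {
    val chunkSize = sampleRate / 10 // 100 ms per chunk
    var offset = 0
    while (offset < samples.size) {
        val end = minOf(offset + chunkSize, samples.size)
        model.acceptWaveform(samples.copyOfRange(offset, end), sampleRate = sampleRate)
        while (model.isReady()) {
            model.decode()
        }
        offset = end
    }
    model.inputFinished()
    while (model.isReady()) {
        model.decode()
    }
    return model.text
}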