2023-02-22 21:14:57 +08:00
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
2023-12-09 16:36:38 +08:00
/**
 * Chunk callback passed to `generateWithCallback` in [testTts].
 *
 * It only prints how many samples it received.
 *
 * @param samples the samples handed over by the caller for this invocation.
 */
fun callback(samples: FloatArray) = println(" callback got called with ${samples.size} samples ")
2023-02-22 21:14:57 +08:00
/**
 * Entry point that runs the demo smoke tests in order: speaker recognition,
 * offline TTS, then streaming ASR for each supported model type.
 */
fun main() {
    testSpeakerRecognition()
    testTts()

    val asrModelTypes = listOf(" transducer ", " zipformer2-ctc ")
    for (modelType in asrModelTypes) {
        testAsr(modelType)
    }
}
2024-01-23 16:50:52 +08:00
/**
 * Computes a speaker embedding for a single wave file.
 *
 * All samples of the file are fed into a fresh stream created by [extractor].
 * The stream is released on every path, including when the readiness check fails.
 *
 * @param extractor extractor used to create the stream and compute the embedding.
 * @param filename path of the wave file, read with `WaveReader.readWaveFromFile`.
 * @return the embedding returned by `extractor.compute`.
 * @throws IllegalStateException if the extractor is not ready after the input was finished.
 */
fun computeEmbedding(extractor: SpeakerEmbeddingExtractor, filename: String): FloatArray {
    // readWaveFromFile returns an untyped array: [0] holds the samples, [1] the sample rate.
    val objArray = WaveReader.readWaveFromFile(
        filename = filename,
    )
    val samples = objArray[0] as FloatArray
    val sampleRate = objArray[1] as Int

    val stream = extractor.createStream()
    try {
        stream.acceptWaveform(sampleRate = sampleRate, samples = samples)
        // Signal that no more audio will follow for this stream.
        stream.inputFinished()
        check(extractor.isReady(stream)) { "Speaker embedding extractor is not ready for $filename" }
        return extractor.compute(stream)
    } finally {
        // Previously the stream leaked whenever the check above threw.
        stream.release()
    }
}
/**
 * Smoke test for speaker identification with `SpeakerEmbeddingManager`.
 *
 * Scenario 1 enrolls one embedding each for two speakers. It then checks that a
 * second utterance of speaker 1 is matched to speaker 1.
 *
 * Scenario 2 enrolls two embeddings under a single name. It checks that an utterance
 * of that speaker is matched, and that the other speaker's utterance is not:
 * `search` returns an empty name.
 *
 * @throws IllegalStateException if any enrollment or lookup check fails.
 */
fun testSpeakerRecognition() {
    val config = SpeakerEmbeddingExtractorConfig(
        model = " ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx ",
    )
    val extractor = SpeakerEmbeddingExtractor(config = config)

    val embedding1a = computeEmbedding(extractor, " ./speaker1_a_cn_16k.wav ")
    val embedding2a = computeEmbedding(extractor, " ./speaker2_a_cn_16k.wav ")
    val embedding1b = computeEmbedding(extractor, " ./speaker1_b_cn_16k.wav ")

    // Scenario 1: one enrolled embedding per speaker.
    val manager = SpeakerEmbeddingManager(extractor.dim())
    try {
        // Each add() result is checked on its own so a failed second enrollment is detected.
        check(manager.add(name = " speaker1 ", embedding = embedding1a)) { "failed to add speaker1" }
        check(manager.add(name = " speaker2 ", embedding = embedding2a)) { "failed to add speaker2" }

        val found = manager.search(embedding = embedding1b, threshold = 0.5f)
        check(found == " speaker1 ") { "expected speaker1, got '$found'" }
    } finally {
        manager.release()
    }

    // Scenario 2: several embeddings enrolled under a single name.
    val multiManager = SpeakerEmbeddingManager(extractor.dim())
    try {
        val embeddings = arrayOf(embedding1a, embedding1b)
        check(multiManager.add(name = " s1 ", embedding = embeddings)) { "failed to add s1" }

        val sameSpeaker = multiManager.search(embedding = embedding1b, threshold = 0.5f)
        check(sameSpeaker == " s1 ") { "expected s1, got '$sameSpeaker'" }

        // An empty name means no enrolled speaker matched.
        val otherSpeaker = multiManager.search(embedding = embedding2a, threshold = 0.5f)
        check(otherSpeaker.isEmpty()) { "expected no match, got '$otherSpeaker'" }
    } finally {
        multiManager.release()
    }
}
2023-10-23 12:31:54 +08:00
/**
 * Smoke test for offline text-to-speech with a Piper VITS voice.
 *
 * Synthesizes a fixed English quotation via `generateWithCallback`, passing
 * [callback] as the chunk callback, and saves the resulting audio to a wave file.
 */
fun testTts() {
    // Available models: https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
    // This test uses: https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
    val vitsConfig = OfflineTtsVitsModelConfig(
        model = " ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ",
        tokens = " ./vits-piper-en_US-amy-low/tokens.txt ",
        dataDir = " ./vits-piper-en_US-amy-low/espeak-ng-data ",
    )
    val ttsConfig = OfflineTtsConfig(
        model = OfflineTtsModelConfig(
            vits = vitsConfig,
            numThreads = 1,
            debug = true,
        ),
    )
    val tts = OfflineTts(config = ttsConfig)

    val text = " “Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.” "
    tts.generateWithCallback(text = text, callback = ::callback)
        .save(filename = " test-en.wav ")
}
2023-12-22 13:46:33 +08:00
/**
 * Smoke test for streaming speech recognition with `SherpaOnnx`.
 *
 * The test runs in these steps:
 * 1. Decodes one test wave file.
 * 2. Appends 0.5 s of zero-valued samples and finishes the input.
 * 3. Decodes the remainder and prints the recognized text.
 *
 * Pre-trained models can be downloaded from
 * https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
 *
 * @param type model family to test. It must match one of the `when` branches below
 *   (transducer or zipformer2-ctc).
 * @throws IllegalArgumentException if [type] matches none of the supported branches.
 */
fun testAsr(type: String) {
    val featConfig = FeatureConfig(
        sampleRate = 16000,
        featureDim = 80,
    )

    // Each branch yields the model config together with its matching test wave,
    // so the two values can never get out of sync.
    val (modelConfig, waveFilename) = when (type) {
        " transducer " -> OnlineModelConfig(
            transducer = OnlineTransducerModelConfig(
                encoder = " ./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx ",
                decoder = " ./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx ",
                joiner = " ./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx ",
            ),
            tokens = " ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ",
            numThreads = 1,
            debug = false,
        ) to " ./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav "

        " zipformer2-ctc " -> OnlineModelConfig(
            zipformer2Ctc = OnlineZipformer2CtcModelConfig(
                model = " ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx ",
            ),
            tokens = " ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt ",
            numThreads = 1,
            debug = false,
        ) to " ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav "

        else -> throw IllegalArgumentException(type)
    }

    val endpointConfig = EndpointConfig()
    val lmConfig = OnlineLMConfig()
    val config = OnlineRecognizerConfig(
        modelConfig = modelConfig,
        lmConfig = lmConfig,
        featConfig = featConfig,
        endpointConfig = endpointConfig,
        enableEndpoint = true,
        decodingMethod = " greedy_search ",
        maxActivePaths = 4,
    )
    val model = SherpaOnnx(
        config = config,
    )

    // readWaveFromFile returns an untyped array: [0] holds the samples, [1] the sample rate.
    val objArray = WaveReader.readWaveFromFile(
        filename = waveFilename,
    )
    val samples = objArray[0] as FloatArray
    val sampleRate = objArray[1] as Int

    // Calls decode() for as long as the recognizer reports it is ready.
    fun decodeWhileReady() {
        while (model.isReady()) {
            model.decode()
        }
    }

    model.acceptWaveform(samples, sampleRate = sampleRate)
    decodeWhileReady()

    // 0.5 seconds of zero-valued samples. Presumably this lets the end of the
    // utterance be flushed through the model before the input is finished.
    val tailPaddings = FloatArray((sampleRate * 0.5).toInt())
    model.acceptWaveform(tailPaddings, sampleRate = sampleRate)
    model.inputFinished()
    decodeWhileReady()

    println(" results: ${model.text} ")
}