Add Kotlin and Java API for Kokoro TTS models (#1728)

This commit is contained in:
Fangjun Kuang
2025-01-17 17:36:13 +08:00
committed by GitHub
parent 3a1de0bfc1
commit 99cef4198b
18 changed files with 548 additions and 39 deletions

View File

@@ -185,6 +185,7 @@ class MainActivity : AppCompatActivity() {
var modelName: String?
var acousticModelName: String?
var vocoder: String?
var voices: String?
var ruleFsts: String?
var ruleFars: String?
var lexicon: String?
@@ -205,6 +206,10 @@ class MainActivity : AppCompatActivity() {
vocoder = null
// Matcha -- end
// For Kokoro -- begin
voices = null
// For Kokoro -- end
modelDir = null
ruleFsts = null
@@ -269,6 +274,13 @@ class MainActivity : AppCompatActivity() {
// vocoder = "hifigan_v2.onnx"
// dataDir = "matcha-icefall-en_US-ljspeech/espeak-ng-data"
// Example 9
// kokoro-en-v0_19
// modelDir = "kokoro-en-v0_19"
// modelName = "model.onnx"
// voices = "voices.bin"
// dataDir = "kokoro-en-v0_19/espeak-ng-data"
if (dataDir != null) {
val newDir = copyDataDir(dataDir!!)
dataDir = "$newDir/$dataDir"
@@ -285,6 +297,7 @@ class MainActivity : AppCompatActivity() {
modelName = modelName ?: "",
acousticModelName = acousticModelName ?: "",
vocoder = vocoder ?: "",
voices = voices ?: "",
lexicon = lexicon ?: "",
dataDir = dataDir ?: "",
dictDir = dictDir ?: "",

View File

@@ -47,7 +47,7 @@ fun getSampleText(lang: String): String {
}
"eng" -> {
text = "This is a text-to-speech engine using next generation Kaldi"
text = "How are you doing today? This is a text-to-speech engine using next generation Kaldi"
}
"est" -> {

View File

@@ -3,6 +3,10 @@
package com.k2fsa.sherpa.onnx.tts.engine
import PreferenceHelper
import android.media.AudioAttributes
import android.media.AudioFormat
import android.media.AudioManager
import android.media.AudioTrack
import android.media.MediaPlayer
import android.net.Uri
import android.os.Bundle
@@ -36,7 +40,13 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.unit.dp
import com.k2fsa.sherpa.onnx.tts.engine.ui.theme.SherpaOnnxTtsEngineTheme
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.io.File
import kotlin.time.TimeSource
const val TAG = "sherpa-onnx-tts-engine"
@@ -45,9 +55,26 @@ class MainActivity : ComponentActivity() {
private val ttsViewModel: TtsViewModel by viewModels()
private var mediaPlayer: MediaPlayer? = null
// see
// https://developer.android.com/reference/kotlin/android/media/AudioTrack
private lateinit var track: AudioTrack
private var stopped: Boolean = false
private var samplesChannel = Channel<FloatArray>()
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
Log.i(TAG, "Start to initialize TTS")
TtsEngine.createTts(this)
Log.i(TAG, "Finish initializing TTS")
Log.i(TAG, "Start to initialize AudioTrack")
initAudioTrack()
Log.i(TAG, "Finish initializing AudioTrack")
val preferenceHelper = PreferenceHelper(this)
setContent {
SherpaOnnxTtsEngineTheme {
@@ -77,6 +104,11 @@ class MainActivity : ComponentActivity() {
val testTextContent = getSampleText(TtsEngine.lang ?: "")
var testText by remember { mutableStateOf(testTextContent) }
var startEnabled by remember { mutableStateOf(true) }
var playEnabled by remember { mutableStateOf(false) }
var rtfText by remember {
mutableStateOf("")
}
val numSpeakers = TtsEngine.tts!!.numSpeakers()
if (numSpeakers > 1) {
@@ -119,52 +151,117 @@ class MainActivity : ComponentActivity() {
Row {
Button(
modifier = Modifier.padding(20.dp),
enabled = startEnabled,
modifier = Modifier.padding(5.dp),
onClick = {
Log.i(TAG, "Clicked, text: $testText")
if (testText.isBlank() || testText.isEmpty()) {
Toast.makeText(
applicationContext,
"Please input a test sentence",
"Please input some text to generate",
Toast.LENGTH_SHORT
).show()
} else {
val audio = TtsEngine.tts!!.generate(
text = testText,
sid = TtsEngine.speakerId,
speed = TtsEngine.speed,
)
startEnabled = false
playEnabled = false
stopped = false
val filename =
application.filesDir.absolutePath + "/generated.wav"
val ok =
audio.samples.isNotEmpty() && audio.save(
filename
)
track.pause()
track.flush()
track.play()
rtfText = ""
Log.i(TAG, "Started with text $testText")
if (ok) {
stopMediaPlayer()
mediaPlayer = MediaPlayer.create(
applicationContext,
Uri.fromFile(File(filename))
)
mediaPlayer?.start()
} else {
Log.i(TAG, "Failed to generate or save audio")
samplesChannel = Channel<FloatArray>()
CoroutineScope(Dispatchers.IO).launch {
for (samples in samplesChannel) {
track.write(
samples,
0,
samples.size,
AudioTrack.WRITE_BLOCKING
)
if (stopped) {
break
}
}
}
CoroutineScope(Dispatchers.Default).launch {
val timeSource = TimeSource.Monotonic
val startTime = timeSource.markNow()
val audio =
TtsEngine.tts!!.generateWithCallback(
text = testText,
sid = TtsEngine.speakerId,
speed = TtsEngine.speed,
callback = ::callback,
)
val elapsed =
startTime.elapsedNow().inWholeMilliseconds.toFloat() / 1000;
val audioDuration =
audio.samples.size / TtsEngine.tts!!.sampleRate()
.toFloat()
// Summary shown under the buttons once generation finishes.
// The argument order must follow the labels in the format string:
// "Elapsed" gets the measured wall-clock time and "Audio duration" gets the
// length of the generated audio; RTF is elapsed / audioDuration.
val RTF = String.format(
    "Number of threads: %d\nElapsed: %.3f s\nAudio duration: %.3f s\nRTF: %.3f/%.3f = %.3f",
    TtsEngine.tts!!.config.model.numThreads,
    elapsed,
    audioDuration,
    elapsed,
    audioDuration,
    elapsed / audioDuration
)
samplesChannel.close()
val filename =
application.filesDir.absolutePath + "/generated.wav"
val ok =
audio.samples.isNotEmpty() && audio.save(
filename
)
if (ok) {
withContext(Dispatchers.Main) {
startEnabled = true
playEnabled = true
rtfText = RTF
}
}
}.start()
}
}) {
Text("Test")
Text("Start")
}
Button(
modifier = Modifier.padding(20.dp),
modifier = Modifier.padding(5.dp),
enabled = playEnabled,
onClick = {
TtsEngine.speakerId = 0
TtsEngine.speed = 1.0f
testText = ""
stopped = true
track.pause()
track.flush()
onClickPlay()
}) {
Text("Reset")
Text("Play")
}
Button(
modifier = Modifier.padding(5.dp),
onClick = {
onClickStop()
startEnabled = true
}) {
Text("Stop")
}
}
if (rtfText.isNotEmpty()) {
Row {
Text(rtfText)
}
}
}
@@ -185,4 +282,63 @@ class MainActivity : ComponentActivity() {
mediaPlayer?.release()
mediaPlayer = null
}
/**
 * Plays `generated.wav` from the application's files directory.
 *
 * [stopMediaPlayer] is called before a new [MediaPlayer] is created for the file.
 * Playback is only started when [MediaPlayer.create] returned a non-null player.
 */
private fun onClickPlay() {
    val filesDirPath = application.filesDir.absolutePath
    val wavUri = Uri.fromFile(File("$filesDirPath/generated.wav"))

    stopMediaPlayer()
    mediaPlayer = MediaPlayer.create(applicationContext, wavUri)
        ?.also { player -> player.start() }
}
/**
 * Handler for the "Stop" button.
 *
 * Sets the `stopped` flag, pauses and flushes the streaming [track], and then
 * calls [stopMediaPlayer].
 */
private fun onClickStop() {
    stopped = true
    track.run {
        pause()
        flush()
    }
    stopMediaPlayer()
}
/**
 * Callback passed to `generateWithCallback`; this function is called from C++.
 *
 * @param samples the samples handed over by the generator for this invocation
 * @return 1 while `stopped` is false, 0 once it is true
 *   (presumably 0 tells the native side to stop generating; confirm against the JNI code)
 */
private fun callback(samples: FloatArray): Int {
    if (stopped) {
        track.stop()
        Log.i(TAG, " return 0")
        return 0
    }

    // The send below runs asynchronously, so it works on a private copy of the array.
    // Presumably the native side may reuse `samples` after this returns; TODO confirm.
    val chunk = samples.copyOf()

    // NOTE(review): every invocation launches its own coroutine to do the send.
    // Confirm that chunks cannot overtake each other on the way into samplesChannel.
    CoroutineScope(Dispatchers.IO).launch { samplesChannel.send(chunk) }
    return 1
}
/**
 * Creates the streaming [AudioTrack] stored in [track] and starts it.
 *
 * The track is configured as mono, float PCM, at the sample rate reported by
 * `TtsEngine.tts`, with the buffer size returned by [AudioTrack.getMinBufferSize].
 */
private fun initAudioTrack() {
    val sampleRate = TtsEngine.tts!!.sampleRate()
    val minBufferSize = AudioTrack.getMinBufferSize(
        sampleRate,
        AudioFormat.CHANNEL_OUT_MONO,
        AudioFormat.ENCODING_PCM_FLOAT
    )
    Log.i(TAG, "sampleRate: $sampleRate, buffLength: $minBufferSize")

    val audioAttributes = AudioAttributes.Builder().apply {
        setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
        setUsage(AudioAttributes.USAGE_MEDIA)
    }.build()

    val audioFormat = AudioFormat.Builder().apply {
        setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
        setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
        setSampleRate(sampleRate)
    }.build()

    track = AudioTrack(
        audioAttributes,
        audioFormat,
        minBufferSize,
        AudioTrack.MODE_STREAM,
        AudioManager.AUDIO_SESSION_ID_GENERATE
    )
    track.play()
}
}

View File

@@ -41,8 +41,9 @@ object TtsEngine {
private var modelDir: String? = null
private var modelName: String? = null
private var acousticModelName: String? = null
private var vocoder: String? = null
private var acousticModelName: String? = null // for matcha tts
private var vocoder: String? = null // for matcha tts
private var voices: String? = null // for kokoro
private var ruleFsts: String? = null
private var ruleFars: String? = null
private var lexicon: String? = null
@@ -64,6 +65,10 @@ object TtsEngine {
vocoder = null
// For Matcha -- end
// For Kokoro -- begin
voices = null
// For Kokoro -- end
modelDir = null
ruleFsts = null
ruleFars = null
@@ -139,6 +144,14 @@ object TtsEngine {
// vocoder = "hifigan_v2.onnx"
// dataDir = "matcha-icefall-en_US-ljspeech/espeak-ng-data"
// lang = "eng"
// Example 9
// kokoro-en-v0_19
// modelDir = "kokoro-en-v0_19"
// modelName = "model.onnx"
// voices = "voices.bin"
// dataDir = "kokoro-en-v0_19/espeak-ng-data"
// lang = "eng"
}
fun createTts(context: Context) {
@@ -167,6 +180,7 @@ object TtsEngine {
modelName = modelName ?: "",
acousticModelName = acousticModelName ?: "",
vocoder = vocoder ?: "",
voices = voices ?: "",
lexicon = lexicon ?: "",
dataDir = dataDir ?: "",
dictDir = dictDir ?: "",

View File

@@ -1,3 +1,3 @@
<resources>
<string name="app_name">TTS Engine</string>
<string name="app_name">TTS Engine: Next-gen Kaldi</string>
</resources>