Add Kotlin and Java API for Kokoro TTS models (#1728)
This commit is contained in:
@@ -185,6 +185,7 @@ class MainActivity : AppCompatActivity() {
|
||||
var modelName: String?
|
||||
var acousticModelName: String?
|
||||
var vocoder: String?
|
||||
var voices: String?
|
||||
var ruleFsts: String?
|
||||
var ruleFars: String?
|
||||
var lexicon: String?
|
||||
@@ -205,6 +206,10 @@ class MainActivity : AppCompatActivity() {
|
||||
vocoder = null
|
||||
// Matcha -- end
|
||||
|
||||
// For Kokoro -- begin
|
||||
voices = null
|
||||
// For Kokoro -- end
|
||||
|
||||
|
||||
modelDir = null
|
||||
ruleFsts = null
|
||||
@@ -269,6 +274,13 @@ class MainActivity : AppCompatActivity() {
|
||||
// vocoder = "hifigan_v2.onnx"
|
||||
// dataDir = "matcha-icefall-en_US-ljspeech/espeak-ng-data"
|
||||
|
||||
// Example 9
|
||||
// kokoro-en-v0_19
|
||||
// modelDir = "kokoro-en-v0_19"
|
||||
// modelName = "model.onnx"
|
||||
// voices = "voices.bin"
|
||||
// dataDir = "kokoro-en-v0_19/espeak-ng-data"
|
||||
|
||||
if (dataDir != null) {
|
||||
val newDir = copyDataDir(dataDir!!)
|
||||
dataDir = "$newDir/$dataDir"
|
||||
@@ -285,6 +297,7 @@ class MainActivity : AppCompatActivity() {
|
||||
modelName = modelName ?: "",
|
||||
acousticModelName = acousticModelName ?: "",
|
||||
vocoder = vocoder ?: "",
|
||||
voices = voices ?: "",
|
||||
lexicon = lexicon ?: "",
|
||||
dataDir = dataDir ?: "",
|
||||
dictDir = dictDir ?: "",
|
||||
|
||||
@@ -47,7 +47,7 @@ fun getSampleText(lang: String): String {
|
||||
}
|
||||
|
||||
"eng" -> {
|
||||
text = "This is a text-to-speech engine using next generation Kaldi"
|
||||
text = "How are you doing today? This is a text-to-speech engine using next generation Kaldi"
|
||||
}
|
||||
|
||||
"est" -> {
|
||||
|
||||
@@ -3,6 +3,10 @@
|
||||
package com.k2fsa.sherpa.onnx.tts.engine
|
||||
|
||||
import PreferenceHelper
|
||||
import android.media.AudioAttributes
|
||||
import android.media.AudioFormat
|
||||
import android.media.AudioManager
|
||||
import android.media.AudioTrack
|
||||
import android.media.MediaPlayer
|
||||
import android.net.Uri
|
||||
import android.os.Bundle
|
||||
@@ -36,7 +40,13 @@ import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.text.input.KeyboardType
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.k2fsa.sherpa.onnx.tts.engine.ui.theme.SherpaOnnxTtsEngineTheme
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.channels.Channel
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.io.File
|
||||
import kotlin.time.TimeSource
|
||||
|
||||
const val TAG = "sherpa-onnx-tts-engine"
|
||||
|
||||
@@ -45,9 +55,26 @@ class MainActivity : ComponentActivity() {
|
||||
private val ttsViewModel: TtsViewModel by viewModels()
|
||||
|
||||
private var mediaPlayer: MediaPlayer? = null
|
||||
|
||||
// see
|
||||
// https://developer.android.com/reference/kotlin/android/media/AudioTrack
|
||||
private lateinit var track: AudioTrack
|
||||
|
||||
private var stopped: Boolean = false
|
||||
|
||||
private var samplesChannel = Channel<FloatArray>()
|
||||
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
|
||||
Log.i(TAG, "Start to initialize TTS")
|
||||
TtsEngine.createTts(this)
|
||||
Log.i(TAG, "Finish initializing TTS")
|
||||
|
||||
Log.i(TAG, "Start to initialize AudioTrack")
|
||||
initAudioTrack()
|
||||
Log.i(TAG, "Finish initializing AudioTrack")
|
||||
|
||||
val preferenceHelper = PreferenceHelper(this)
|
||||
setContent {
|
||||
SherpaOnnxTtsEngineTheme {
|
||||
@@ -77,6 +104,11 @@ class MainActivity : ComponentActivity() {
|
||||
val testTextContent = getSampleText(TtsEngine.lang ?: "")
|
||||
|
||||
var testText by remember { mutableStateOf(testTextContent) }
|
||||
var startEnabled by remember { mutableStateOf(true) }
|
||||
var playEnabled by remember { mutableStateOf(false) }
|
||||
var rtfText by remember {
|
||||
mutableStateOf("")
|
||||
}
|
||||
|
||||
val numSpeakers = TtsEngine.tts!!.numSpeakers()
|
||||
if (numSpeakers > 1) {
|
||||
@@ -119,52 +151,117 @@ class MainActivity : ComponentActivity() {
|
||||
|
||||
Row {
|
||||
Button(
|
||||
modifier = Modifier.padding(20.dp),
|
||||
enabled = startEnabled,
|
||||
modifier = Modifier.padding(5.dp),
|
||||
onClick = {
|
||||
Log.i(TAG, "Clicked, text: $testText")
|
||||
if (testText.isBlank() || testText.isEmpty()) {
|
||||
Toast.makeText(
|
||||
applicationContext,
|
||||
"Please input a test sentence",
|
||||
"Please input some text to generate",
|
||||
Toast.LENGTH_SHORT
|
||||
).show()
|
||||
} else {
|
||||
val audio = TtsEngine.tts!!.generate(
|
||||
text = testText,
|
||||
sid = TtsEngine.speakerId,
|
||||
speed = TtsEngine.speed,
|
||||
)
|
||||
startEnabled = false
|
||||
playEnabled = false
|
||||
stopped = false
|
||||
|
||||
val filename =
|
||||
application.filesDir.absolutePath + "/generated.wav"
|
||||
val ok =
|
||||
audio.samples.isNotEmpty() && audio.save(
|
||||
filename
|
||||
)
|
||||
track.pause()
|
||||
track.flush()
|
||||
track.play()
|
||||
rtfText = ""
|
||||
Log.i(TAG, "Started with text $testText")
|
||||
|
||||
if (ok) {
|
||||
stopMediaPlayer()
|
||||
mediaPlayer = MediaPlayer.create(
|
||||
applicationContext,
|
||||
Uri.fromFile(File(filename))
|
||||
)
|
||||
mediaPlayer?.start()
|
||||
} else {
|
||||
Log.i(TAG, "Failed to generate or save audio")
|
||||
samplesChannel = Channel<FloatArray>()
|
||||
|
||||
CoroutineScope(Dispatchers.IO).launch {
|
||||
for (samples in samplesChannel) {
|
||||
track.write(
|
||||
samples,
|
||||
0,
|
||||
samples.size,
|
||||
AudioTrack.WRITE_BLOCKING
|
||||
)
|
||||
if (stopped) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run synthesis on Dispatchers.Default. Generated chunks are handed to
// ::callback while generateWithCallback() is running; once it returns, the
// timing statistics are formatted, the audio is saved to generated.wav and
// the UI state is updated on the main dispatcher.
CoroutineScope(Dispatchers.Default).launch {
    val timeSource = TimeSource.Monotonic
    val startTime = timeSource.markNow()

    val audio = TtsEngine.tts!!.generateWithCallback(
        text = testText,
        sid = TtsEngine.speakerId,
        speed = TtsEngine.speed,
        callback = ::callback,
    )

    // Wall-clock time spent in generateWithCallback(), in seconds.
    val elapsed = startTime.elapsedNow().inWholeMilliseconds.toFloat() / 1000

    // Length of the generated audio in seconds: sample count / sample rate.
    val audioDuration = audio.samples.size / TtsEngine.tts!!.sampleRate().toFloat()

    // The arguments must follow the order of the placeholders:
    // threads, elapsed, audio duration, then "elapsed/audioDuration = RTF".
    // (Previously elapsed and audioDuration were swapped for the first two
    // float placeholders, so the "Elapsed" and "Audio duration" labels
    // showed each other's values.)
    val RTF = String.format(
        "Number of threads: %d\nElapsed: %.3f s\nAudio duration: %.3f s\nRTF: %.3f/%.3f = %.3f",
        TtsEngine.tts!!.config.model.numThreads,
        elapsed,
        audioDuration,
        elapsed,
        audioDuration,
        elapsed / audioDuration
    )

    // No more chunks will be produced; closing the channel lets the
    // `for (samples in samplesChannel)` consumer loop terminate.
    samplesChannel.close()

    val filename = application.filesDir.absolutePath + "/generated.wav"

    // save() is only attempted when there is something to write.
    val ok = audio.samples.isNotEmpty() && audio.save(filename)

    if (ok) {
        // Compose state is updated from the main dispatcher.
        withContext(Dispatchers.Main) {
            startEnabled = true
            playEnabled = true
            rtfText = RTF
        }
    }
}.start()
|
||||
}
|
||||
}) {
|
||||
Text("Test")
|
||||
Text("Start")
|
||||
}
|
||||
|
||||
Button(
|
||||
modifier = Modifier.padding(20.dp),
|
||||
modifier = Modifier.padding(5.dp),
|
||||
enabled = playEnabled,
|
||||
onClick = {
|
||||
TtsEngine.speakerId = 0
|
||||
TtsEngine.speed = 1.0f
|
||||
testText = ""
|
||||
stopped = true
|
||||
track.pause()
|
||||
track.flush()
|
||||
onClickPlay()
|
||||
}) {
|
||||
Text("Reset")
|
||||
Text("Play")
|
||||
}
|
||||
|
||||
Button(
|
||||
modifier = Modifier.padding(5.dp),
|
||||
onClick = {
|
||||
onClickStop()
|
||||
startEnabled = true
|
||||
}) {
|
||||
Text("Stop")
|
||||
}
|
||||
}
|
||||
if (rtfText.isNotEmpty()) {
|
||||
Row {
|
||||
Text(rtfText)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -185,4 +282,63 @@ class MainActivity : ComponentActivity() {
|
||||
mediaPlayer?.release()
|
||||
mediaPlayer = null
|
||||
}
|
||||
|
||||
/**
 * Plays the file `generated.wav` from the application's files directory.
 *
 * Any player that is currently held is stopped via [stopMediaPlayer] before a
 * new one is created. If [MediaPlayer.create] yields `null`, nothing is played.
 */
private fun onClickPlay() {
    val wavPath = application.filesDir.absolutePath + "/generated.wav"
    val wavUri = Uri.fromFile(File(wavPath))

    stopMediaPlayer()

    val player = MediaPlayer.create(applicationContext, wavUri)
    mediaPlayer = player
    player?.start()
}
|
||||
|
||||
/**
 * Stops all playback.
 *
 * Sets the `stopped` flag (read by [callback] and by the coroutine that
 * writes samples to the track), pauses and flushes the [AudioTrack], and
 * finally stops the [MediaPlayer] through [stopMediaPlayer].
 */
private fun onClickStop() {
    stopped = true

    with(track) {
        pause()
        flush()
    }

    stopMediaPlayer()
}
|
||||
|
||||
// this function is called from C++
|
||||
/**
 * Receives one chunk of generated samples.
 *
 * While playback has not been stopped, a copy of [samples] is sent to
 * `samplesChannel` from a newly launched IO coroutine and 1 is returned.
 * Once `stopped` is set, the track is stopped and 0 is returned instead.
 *
 * NOTE(review): every chunk is sent from its own coroutine, so the delivery
 * order presumably depends on how those coroutines are scheduled -- confirm
 * that chunks cannot be reordered in practice.
 *
 * @param samples the chunk of audio samples handed over by the generator
 * @return 1 if the chunk was queued, 0 if playback has been stopped
 */
private fun callback(samples: FloatArray): Int {
    if (stopped) {
        track.stop()
        Log.i(TAG, " return 0")
        return 0
    }

    // Copy before handing off: the send happens asynchronously, after this
    // function has already returned.
    val chunk = samples.copyOf()
    CoroutineScope(Dispatchers.IO).launch {
        samplesChannel.send(chunk)
    }
    return 1
}
|
||||
|
||||
/**
 * Creates the streaming [AudioTrack] stored in `track` and starts it.
 *
 * The track is mono, float PCM, uses the sample rate reported by the TTS
 * engine, and is sized with the minimum buffer size for that configuration.
 */
private fun initAudioTrack() {
    val sampleRate = TtsEngine.tts!!.sampleRate()
    val bufLength = AudioTrack.getMinBufferSize(
        sampleRate,
        AudioFormat.CHANNEL_OUT_MONO,
        AudioFormat.ENCODING_PCM_FLOAT
    )
    Log.i(TAG, "sampleRate: $sampleRate, buffLength: $bufLength")

    val attributesBuilder = AudioAttributes.Builder()
    attributesBuilder.setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
    attributesBuilder.setUsage(AudioAttributes.USAGE_MEDIA)
    val audioAttributes = attributesBuilder.build()

    val formatBuilder = AudioFormat.Builder()
    formatBuilder.setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
    formatBuilder.setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
    formatBuilder.setSampleRate(sampleRate)
    val audioFormat = formatBuilder.build()

    track = AudioTrack(
        audioAttributes,
        audioFormat,
        bufLength,
        AudioTrack.MODE_STREAM,
        AudioManager.AUDIO_SESSION_ID_GENERATE
    )
    track.play()
}
|
||||
}
|
||||
|
||||
@@ -41,8 +41,9 @@ object TtsEngine {
|
||||
|
||||
private var modelDir: String? = null
|
||||
private var modelName: String? = null
|
||||
private var acousticModelName: String? = null
|
||||
private var vocoder: String? = null
|
||||
private var acousticModelName: String? = null // for matcha tts
|
||||
private var vocoder: String? = null // for matcha tts
|
||||
private var voices: String? = null // for kokoro
|
||||
private var ruleFsts: String? = null
|
||||
private var ruleFars: String? = null
|
||||
private var lexicon: String? = null
|
||||
@@ -64,6 +65,10 @@ object TtsEngine {
|
||||
vocoder = null
|
||||
// For Matcha -- end
|
||||
|
||||
// For Kokoro -- begin
|
||||
voices = null
|
||||
// For Kokoro -- end
|
||||
|
||||
modelDir = null
|
||||
ruleFsts = null
|
||||
ruleFars = null
|
||||
@@ -139,6 +144,14 @@ object TtsEngine {
|
||||
// vocoder = "hifigan_v2.onnx"
|
||||
// dataDir = "matcha-icefall-en_US-ljspeech/espeak-ng-data"
|
||||
// lang = "eng"
|
||||
|
||||
// Example 9
|
||||
// kokoro-en-v0_19
|
||||
// modelDir = "kokoro-en-v0_19"
|
||||
// modelName = "model.onnx"
|
||||
// voices = "voices.bin"
|
||||
// dataDir = "kokoro-en-v0_19/espeak-ng-data"
|
||||
// lang = "eng"
|
||||
}
|
||||
|
||||
fun createTts(context: Context) {
|
||||
@@ -167,6 +180,7 @@ object TtsEngine {
|
||||
modelName = modelName ?: "",
|
||||
acousticModelName = acousticModelName ?: "",
|
||||
vocoder = vocoder ?: "",
|
||||
voices = voices ?: "",
|
||||
lexicon = lexicon ?: "",
|
||||
dataDir = dataDir ?: "",
|
||||
dictDir = dictDir ?: "",
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
<resources>
|
||||
<string name="app_name">TTS Engine</string>
|
||||
<string name="app_name">TTS Engine: Next-gen Kaldi</string>
|
||||
</resources>
|
||||
Reference in New Issue
Block a user