Support playing as it is generating for Android (#477)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
import android.media.MediaPlayer
|
||||
import android.media.*
|
||||
import android.net.Uri
|
||||
import android.os.Bundle
|
||||
import android.util.Log
|
||||
@@ -23,6 +23,10 @@ class MainActivity : AppCompatActivity() {
|
||||
private lateinit var generate: Button
|
||||
private lateinit var play: Button
|
||||
|
||||
// see
|
||||
// https://developer.android.com/reference/kotlin/android/media/AudioTrack
|
||||
private lateinit var track: AudioTrack
|
||||
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
setContentView(R.layout.activity_main)
|
||||
@@ -31,6 +35,10 @@ class MainActivity : AppCompatActivity() {
|
||||
initTts()
|
||||
Log.i(TAG, "Finish initializing TTS")
|
||||
|
||||
Log.i(TAG, "Start to initialize AudioTrack")
|
||||
initAudioTrack()
|
||||
Log.i(TAG, "Finish initializing AudioTrack")
|
||||
|
||||
text = findViewById(R.id.text)
|
||||
sid = findViewById(R.id.sid)
|
||||
speed = findViewById(R.id.speed)
|
||||
@@ -51,6 +59,33 @@ class MainActivity : AppCompatActivity() {
|
||||
play.isEnabled = false
|
||||
}
|
||||
|
||||
/**
 * Creates the streaming [AudioTrack] used to play samples while they are
 * still being generated, and starts it right away.
 *
 * The track is configured as mono, 32-bit float PCM, at the sample rate
 * reported by `tts.sampleRate()`.
 */
private fun initAudioTrack() {
    val sampleRate = tts.sampleRate()

    // AudioTrack takes its buffer size in BYTES, not in samples. With
    // ENCODING_PCM_FLOAT every sample occupies 4 bytes, so ~100 ms of audio
    // needs sampleRate * 0.1 * 4 bytes.
    val bytesPerSample = 4
    val desiredBufLength = (sampleRate * 0.1).toInt() * bytesPerSample

    // Never go below the platform minimum for this configuration.
    // getMinBufferSize() returns a negative error code on failure, in which
    // case maxOf() simply keeps the desired size.
    val minBufLength = AudioTrack.getMinBufferSize(
        sampleRate,
        AudioFormat.CHANNEL_OUT_MONO,
        AudioFormat.ENCODING_PCM_FLOAT
    )
    val bufLength = maxOf(minBufLength, desiredBufLength)
    Log.i(TAG, "sampleRate: ${sampleRate}, buffLength: ${bufLength}")

    val attr = AudioAttributes.Builder()
        .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
        .setUsage(AudioAttributes.USAGE_MEDIA)
        .build()

    val format = AudioFormat.Builder()
        .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
        .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
        .setSampleRate(sampleRate)
        .build()

    track = AudioTrack(
        attr, format, bufLength, AudioTrack.MODE_STREAM,
        AudioManager.AUDIO_SESSION_ID_GENERATE
    )
    track.play()
}
|
||||
|
||||
// this function is called from C++
|
||||
/**
 * Writes one chunk of generated samples to [track] for playback.
 *
 * The write is blocking (`AudioTrack.WRITE_BLOCKING`). A negative return
 * value from `AudioTrack.write` is an error code; it is logged instead of
 * being silently dropped.
 *
 * @param samples audio samples to enqueue; an empty array is ignored.
 */
private fun callback(samples: FloatArray) {
    if (samples.isEmpty()) return

    val written = track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
    if (written < 0) {
        Log.e(TAG, "AudioTrack.write failed with error code: $written")
    }
}
|
||||
|
||||
private fun onClickGenerate() {
|
||||
val sidInt = sid.text.toString().toIntOrNull()
|
||||
if (sidInt == null || sidInt < 0) {
|
||||
@@ -79,16 +114,28 @@ class MainActivity : AppCompatActivity() {
|
||||
return
|
||||
}
|
||||
|
||||
play.isEnabled = false
|
||||
val audio = tts.generate(text = textStr, sid = sidInt, speed = speedFloat)
|
||||
track.pause()
|
||||
track.flush()
|
||||
track.play()
|
||||
|
||||
val filename = application.filesDir.absolutePath + "/generated.wav"
|
||||
val ok = audio.samples.size > 0 && audio.save(filename)
|
||||
if (ok) {
|
||||
play.isEnabled = true
|
||||
// Play automatically after generation
|
||||
onClickPlay()
|
||||
}
|
||||
play.isEnabled = false
|
||||
Thread {
|
||||
val audio = tts.generateWithCallback(
|
||||
text = textStr,
|
||||
sid = sidInt,
|
||||
speed = speedFloat,
|
||||
callback = this::callback
|
||||
)
|
||||
|
||||
val filename = application.filesDir.absolutePath + "/generated.wav"
|
||||
val ok = audio.samples.size > 0 && audio.save(filename)
|
||||
if (ok) {
|
||||
runOnUiThread {
|
||||
play.isEnabled = true
|
||||
track.stop()
|
||||
}
|
||||
}
|
||||
}.start()
|
||||
}
|
||||
|
||||
private fun onClickPlay() {
|
||||
|
||||
@@ -54,6 +54,8 @@ class OfflineTts(
|
||||
}
|
||||
}
|
||||
|
||||
/** Returns the result of the native `getSampleRate` call for this instance's `ptr`. */
fun sampleRate(): Int = getSampleRate(ptr)
|
||||
|
||||
fun generate(
|
||||
text: String,
|
||||
sid: Int = 0,
|
||||
@@ -66,6 +68,19 @@ class OfflineTts(
|
||||
)
|
||||
}
|
||||
|
||||
/**
 * Generates audio for [text] via the native `generateWithCallbackImpl`,
 * forwarding [callback] to it.
 *
 * @param text the text to synthesize.
 * @param sid passed through to the native call (presumably the speaker ID — TODO confirm).
 * @param speed passed through to the native call (presumably a speech-rate factor — TODO confirm).
 * @param callback handed to the native implementation; presumably invoked
 *   with chunks of samples as they are produced — verify against the JNI code.
 * @return a [GeneratedAudio] built from the two entries of the native result:
 *   entry 0 as the samples and entry 1 as the sample rate.
 */
fun generateWithCallback(
    text: String,
    sid: Int = 0,
    speed: Float = 1.0f,
    callback: (samples: FloatArray) -> Unit
): GeneratedAudio {
    val objArray = generateWithCallbackImpl(
        ptr,
        text = text,
        sid = sid,
        speed = speed,
        callback = callback
    )
    return GeneratedAudio(
        samples = objArray[0] as FloatArray,
        sampleRate = objArray[1] as Int
    )
}
|
||||
|
||||
fun allocate(assetManager: AssetManager? = null) {
|
||||
if (ptr == 0L) {
|
||||
if (assetManager != null) {
|
||||
@@ -97,6 +112,7 @@ class OfflineTts(
|
||||
): Long
|
||||
|
||||
private external fun delete(ptr: Long)
|
||||
// Native (JNI) query for the sample rate associated with the handle `ptr`
// (presumably in Hz — TODO confirm against the native implementation).
private external fun getSampleRate(ptr: Long): Int
|
||||
|
||||
// The returned array has two entries:
|
||||
// - the first entry is a 1-D float array containing audio samples.
|
||||
@@ -109,6 +125,14 @@ class OfflineTts(
|
||||
speed: Float = 1.0f
|
||||
): Array<Any>
|
||||
|
||||
// Native (JNI) counterpart of generateWithCallback().
// generateWithCallback() reads the returned array as:
//  - [0]: a FloatArray of audio samples
//  - [1]: an Int sample rate
external fun generateWithCallbackImpl(
    ptr: Long,
    text: String,
    sid: Int = 0,
    speed: Float = 1.0f,
    callback: (samples: FloatArray) -> Unit
): Array<Any>
|
||||
|
||||
companion object {
|
||||
init {
|
||||
System.loadLibrary("sherpa-onnx-jni")
|
||||
|
||||
Reference in New Issue
Block a user