Support playing as it is generating for Android (#477)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
import android.media.MediaPlayer
|
||||
import android.media.*
|
||||
import android.net.Uri
|
||||
import android.os.Bundle
|
||||
import android.util.Log
|
||||
@@ -23,6 +23,10 @@ class MainActivity : AppCompatActivity() {
|
||||
private lateinit var generate: Button
|
||||
private lateinit var play: Button
|
||||
|
||||
// see
|
||||
// https://developer.android.com/reference/kotlin/android/media/AudioTrack
|
||||
private lateinit var track: AudioTrack
|
||||
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
setContentView(R.layout.activity_main)
|
||||
@@ -31,6 +35,10 @@ class MainActivity : AppCompatActivity() {
|
||||
initTts()
|
||||
Log.i(TAG, "Finish initializing TTS")
|
||||
|
||||
Log.i(TAG, "Start to initialize AudioTrack")
|
||||
initAudioTrack()
|
||||
Log.i(TAG, "Finish initializing AudioTrack")
|
||||
|
||||
text = findViewById(R.id.text)
|
||||
sid = findViewById(R.id.sid)
|
||||
speed = findViewById(R.id.speed)
|
||||
@@ -51,6 +59,33 @@ class MainActivity : AppCompatActivity() {
|
||||
play.isEnabled = false
|
||||
}
|
||||
|
||||
/**
 * Creates the streaming [AudioTrack] stored in [track] and starts it.
 *
 * The track is configured for mono, 32-bit float PCM at the sample rate
 * returned by `tts.sampleRate()`.
 */
private fun initAudioTrack() {
    val sampleRate = tts.sampleRate()

    // The AudioTrack constructor takes the buffer size in BYTES and requires
    // it to be a whole number of frames. One mono ENCODING_PCM_FLOAT frame is
    // 4 bytes, so size the buffer in frames first and convert to bytes.
    val bytesPerFrame = 4
    val desiredBytes = (sampleRate / 10) * bytesPerFrame // about 100 ms of audio
    val minBytes = AudioTrack.getMinBufferSize(
        sampleRate,
        AudioFormat.CHANNEL_OUT_MONO,
        AudioFormat.ENCODING_PCM_FLOAT
    )
    // getMinBufferSize() returns a negative error code on failure; maxOf()
    // then simply keeps the 100 ms estimate.
    val bufLength = maxOf(desiredBytes, minBytes)
    Log.i(TAG, "sampleRate: ${sampleRate}, buffLength: ${bufLength}")

    val attr = AudioAttributes.Builder()
        .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
        .setUsage(AudioAttributes.USAGE_MEDIA)
        .build()

    val format = AudioFormat.Builder()
        .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
        .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
        .setSampleRate(sampleRate)
        .build()

    track = AudioTrack(
        attr, format, bufLength, AudioTrack.MODE_STREAM,
        AudioManager.AUDIO_SESSION_ID_GENERATE
    )
    track.play()
}
|
||||
|
||||
/**
 * Writes a chunk of generated samples to [track] for playback.
 *
 * This function is called from C++: it is passed as the `callback` argument of
 * `tts.generateWithCallback` in `onClickGenerate`.
 *
 * @param samples float PCM samples to write to the track.
 */
private fun callback(samples: FloatArray) {
    // write() returns the number of samples written, or a negative error code.
    val written = track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
    if (written < 0) {
        Log.e(TAG, "AudioTrack.write failed with error code $written")
    }
}
|
||||
|
||||
private fun onClickGenerate() {
|
||||
val sidInt = sid.text.toString().toIntOrNull()
|
||||
if (sidInt == null || sidInt < 0) {
|
||||
@@ -79,16 +114,28 @@ class MainActivity : AppCompatActivity() {
|
||||
return
|
||||
}
|
||||
|
||||
play.isEnabled = false
|
||||
val audio = tts.generate(text = textStr, sid = sidInt, speed = speedFloat)
|
||||
track.pause()
|
||||
track.flush()
|
||||
track.play()
|
||||
|
||||
val filename = application.filesDir.absolutePath + "/generated.wav"
|
||||
val ok = audio.samples.size > 0 && audio.save(filename)
|
||||
if (ok) {
|
||||
play.isEnabled = true
|
||||
// Play automatically after generation
|
||||
onClickPlay()
|
||||
}
|
||||
play.isEnabled = false
|
||||
Thread {
|
||||
val audio = tts.generateWithCallback(
|
||||
text = textStr,
|
||||
sid = sidInt,
|
||||
speed = speedFloat,
|
||||
callback = this::callback
|
||||
)
|
||||
|
||||
val filename = application.filesDir.absolutePath + "/generated.wav"
|
||||
val ok = audio.samples.size > 0 && audio.save(filename)
|
||||
if (ok) {
|
||||
runOnUiThread {
|
||||
play.isEnabled = true
|
||||
track.stop()
|
||||
}
|
||||
}
|
||||
}.start()
|
||||
}
|
||||
|
||||
private fun onClickPlay() {
|
||||
|
||||
@@ -54,6 +54,8 @@ class OfflineTts(
|
||||
}
|
||||
}
|
||||
|
||||
/** Returns the sample rate reported by the native TTS object behind `ptr`. */
fun sampleRate(): Int = getSampleRate(ptr)
|
||||
|
||||
fun generate(
|
||||
text: String,
|
||||
sid: Int = 0,
|
||||
@@ -66,6 +68,19 @@ class OfflineTts(
|
||||
)
|
||||
}
|
||||
|
||||
/**
 * Generates audio for [text] and returns it, handing [callback] to the native
 * layer, which invokes it with arrays of generated samples.
 *
 * @param text the text to synthesize.
 * @param sid speaker ID forwarded to the native layer.
 * @param speed speech speed forwarded to the native layer.
 * @param callback receives generated samples from the native layer.
 * @return the samples and sample rate returned by the native layer.
 */
fun generateWithCallback(
    text: String,
    sid: Int = 0,
    speed: Float = 1.0f,
    callback: (samples: FloatArray) -> Unit
): GeneratedAudio {
    val objArray = generateWithCallbackImpl(
        ptr,
        text = text,
        sid = sid,
        speed = speed,
        callback = callback
    )
    // The native side returns [FloatArray samples, Int sampleRate].
    return GeneratedAudio(
        samples = objArray[0] as FloatArray,
        sampleRate = objArray[1] as Int
    )
}
|
||||
|
||||
fun allocate(assetManager: AssetManager? = null) {
|
||||
if (ptr == 0L) {
|
||||
if (assetManager != null) {
|
||||
@@ -97,6 +112,7 @@ class OfflineTts(
|
||||
): Long
|
||||
|
||||
// Native: deletes the C++ SherpaOnnxOfflineTts object that `ptr` points to.
private external fun delete(ptr: Long)
|
||||
// Native: returns SampleRate() of the C++ SherpaOnnxOfflineTts object that `ptr` points to.
private external fun getSampleRate(ptr: Long): Int
|
||||
|
||||
// The returned array has two entries:
|
||||
// - the first entry is an 1-D float array containing audio samples.
|
||||
@@ -109,6 +125,14 @@ class OfflineTts(
|
||||
speed: Float = 1.0f
|
||||
): Array<Any>
|
||||
|
||||
// Native: generates audio for `text` with the C++ TTS object that `ptr` points to.
// The native side calls `callback` (its `invoke(float[])` method) with arrays of
// generated samples, and returns a two-entry array:
//  - entry 0 is a FloatArray containing the audio samples,
//  - entry 1 is an Int containing the sample rate.
external fun generateWithCallbackImpl(
|
||||
ptr: Long,
|
||||
text: String,
|
||||
sid: Int = 0,
|
||||
speed: Float = 1.0f,
|
||||
callback: (samples: FloatArray) -> Unit
|
||||
): Array<Any>
|
||||
|
||||
companion object {
|
||||
init {
|
||||
System.loadLibrary("sherpa-onnx-jni")
|
||||
|
||||
@@ -2,6 +2,10 @@ package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
|
||||
/** Test callback: prints how many samples were passed in. */
fun callback(samples: FloatArray) {
    println("callback got called with ${samples.size} samples")
}
|
||||
|
||||
fun main() {
|
||||
testTts()
|
||||
testAsr()
|
||||
@@ -22,7 +26,7 @@ fun testTts() {
|
||||
)
|
||||
)
|
||||
val tts = OfflineTts(config=config)
|
||||
val audio = tts.generate(text="“Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”")
|
||||
val audio = tts.generateWithCallback(text="“Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”", callback=::callback)
|
||||
audio.save(filename="test-en.wav")
|
||||
}
|
||||
|
||||
|
||||
@@ -172,57 +172,57 @@ def get_vits_models() -> List[TtsModel]:
|
||||
lang="zh",
|
||||
rule_fsts="vits-zh-aishell3/rule.fst",
|
||||
),
|
||||
TtsModel(
|
||||
model_dir="vits-zh-hf-doom",
|
||||
model_name="doom.onnx",
|
||||
lang="zh",
|
||||
rule_fsts="vits-zh-hf-doom/rule.fst",
|
||||
),
|
||||
TtsModel(
|
||||
model_dir="vits-zh-hf-echo",
|
||||
model_name="echo.onnx",
|
||||
lang="zh",
|
||||
rule_fsts="vits-zh-hf-echo/rule.fst",
|
||||
),
|
||||
TtsModel(
|
||||
model_dir="vits-zh-hf-zenyatta",
|
||||
model_name="zenyatta.onnx",
|
||||
lang="zh",
|
||||
rule_fsts="vits-zh-hf-zenyatta/rule.fst",
|
||||
),
|
||||
TtsModel(
|
||||
model_dir="vits-zh-hf-abyssinvoker",
|
||||
model_name="abyssinvoker.onnx",
|
||||
lang="zh",
|
||||
rule_fsts="vits-zh-hf-abyssinvoker/rule.fst",
|
||||
),
|
||||
TtsModel(
|
||||
model_dir="vits-zh-hf-keqing",
|
||||
model_name="keqing.onnx",
|
||||
lang="zh",
|
||||
rule_fsts="vits-zh-hf-keqing/rule.fst",
|
||||
),
|
||||
TtsModel(
|
||||
model_dir="vits-zh-hf-eula",
|
||||
model_name="eula.onnx",
|
||||
lang="zh",
|
||||
rule_fsts="vits-zh-hf-eula/rule.fst",
|
||||
),
|
||||
TtsModel(
|
||||
model_dir="vits-zh-hf-bronya",
|
||||
model_name="bronya.onnx",
|
||||
lang="zh",
|
||||
rule_fsts="vits-zh-hf-bronya/rule.fst",
|
||||
),
|
||||
TtsModel(
|
||||
model_dir="vits-zh-hf-theresa",
|
||||
model_name="theresa.onnx",
|
||||
lang="zh",
|
||||
rule_fsts="vits-zh-hf-theresa/rule.fst",
|
||||
),
|
||||
# TtsModel(
|
||||
# model_dir="vits-zh-hf-doom",
|
||||
# model_name="doom.onnx",
|
||||
# lang="zh",
|
||||
# rule_fsts="vits-zh-hf-doom/rule.fst",
|
||||
# ),
|
||||
# TtsModel(
|
||||
# model_dir="vits-zh-hf-echo",
|
||||
# model_name="echo.onnx",
|
||||
# lang="zh",
|
||||
# rule_fsts="vits-zh-hf-echo/rule.fst",
|
||||
# ),
|
||||
# TtsModel(
|
||||
# model_dir="vits-zh-hf-zenyatta",
|
||||
# model_name="zenyatta.onnx",
|
||||
# lang="zh",
|
||||
# rule_fsts="vits-zh-hf-zenyatta/rule.fst",
|
||||
# ),
|
||||
# TtsModel(
|
||||
# model_dir="vits-zh-hf-abyssinvoker",
|
||||
# model_name="abyssinvoker.onnx",
|
||||
# lang="zh",
|
||||
# rule_fsts="vits-zh-hf-abyssinvoker/rule.fst",
|
||||
# ),
|
||||
# TtsModel(
|
||||
# model_dir="vits-zh-hf-keqing",
|
||||
# model_name="keqing.onnx",
|
||||
# lang="zh",
|
||||
# rule_fsts="vits-zh-hf-keqing/rule.fst",
|
||||
# ),
|
||||
# TtsModel(
|
||||
# model_dir="vits-zh-hf-eula",
|
||||
# model_name="eula.onnx",
|
||||
# lang="zh",
|
||||
# rule_fsts="vits-zh-hf-eula/rule.fst",
|
||||
# ),
|
||||
# TtsModel(
|
||||
# model_dir="vits-zh-hf-bronya",
|
||||
# model_name="bronya.onnx",
|
||||
# lang="zh",
|
||||
# rule_fsts="vits-zh-hf-bronya/rule.fst",
|
||||
# ),
|
||||
# TtsModel(
|
||||
# model_dir="vits-zh-hf-theresa",
|
||||
# model_name="theresa.onnx",
|
||||
# lang="zh",
|
||||
# rule_fsts="vits-zh-hf-theresa/rule.fst",
|
||||
# ),
|
||||
# English (US)
|
||||
TtsModel(model_dir="vits-vctk", model_name="vits-vctk.onnx", lang="en"),
|
||||
TtsModel(model_dir="vits-ljs", model_name="vits-ljs.onnx", lang="en"),
|
||||
# TtsModel(model_dir="vits-ljs", model_name="vits-ljs.onnx", lang="en"),
|
||||
# fmt: on
|
||||
]
|
||||
|
||||
@@ -238,8 +238,8 @@ def main():
|
||||
template = environment.from_string(s)
|
||||
d = dict()
|
||||
|
||||
# all_model_list = get_vits_models()
|
||||
all_model_list = get_piper_models()
|
||||
all_model_list = get_vits_models()
|
||||
all_model_list += get_piper_models()
|
||||
all_model_list += get_coqui_models()
|
||||
|
||||
num_models = len(all_model_list)
|
||||
|
||||
@@ -11,13 +11,15 @@
|
||||
// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
|
||||
#include "jni.h" // NOLINT
|
||||
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <strstream>
|
||||
#include <utility>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
#include <fstream>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
@@ -502,11 +504,14 @@ class SherpaOnnxOfflineTts {
|
||||
explicit SherpaOnnxOfflineTts(const OfflineTtsConfig &config)
|
||||
: tts_(config) {}
|
||||
|
||||
GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
|
||||
float speed = 1.0) const {
|
||||
return tts_.Generate(text, sid, speed);
|
||||
// Generates audio for `text` by forwarding all arguments to the wrapped
// OfflineTts.
//
// @param text     Text to synthesize.
// @param sid      Speaker ID; defaults to 0.
// @param speed    Speech speed; defaults to 1.0.
// @param callback Optional callable taking (samples, n); defaults to nullptr.
// @return The audio returned by OfflineTts::Generate().
GeneratedAudio Generate(
    const std::string &text, int64_t sid = 0, float speed = 1.0,
    std::function<void(const float *, int32_t)> callback = nullptr) const {
  // `callback` was taken by value, so move it on instead of copying it again.
  return tts_.Generate(text, sid, speed, std::move(callback));
}
|
||||
|
||||
// Returns the sample rate reported by the wrapped OfflineTts.
int32_t SampleRate() const {
  return tts_.SampleRate();
}
|
||||
|
||||
private:
|
||||
OfflineTts tts_;
|
||||
};
|
||||
@@ -628,6 +633,13 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_delete(
|
||||
delete reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr);
|
||||
}
|
||||
|
||||
// JNI entry point for OfflineTts.getSampleRate(): treats `ptr` as a
// SherpaOnnxOfflineTts pointer and returns its SampleRate().
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getSampleRate(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  const auto *tts =
      reinterpret_cast<const sherpa_onnx::SherpaOnnxOfflineTts *>(ptr);
  return tts->SampleRate();
}
|
||||
|
||||
// see
|
||||
// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
|
||||
static jobject NewInteger(JNIEnv *env, int32_t value) {
|
||||
@@ -663,6 +675,43 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateImpl(JNIEnv *env, jobject /*obj*/,
|
||||
return obj_arr;
|
||||
}
|
||||
|
||||
// JNI entry point for OfflineTts.generateWithCallbackImpl().
//
// Generates audio for `text` with the SherpaOnnxOfflineTts object that `ptr`
// points to. `callback` is a Java object whose `invoke(float[])` method is
// called with each array of samples handed to the native callback.
//
// Returns a 2-element Object[]: element 0 is a float[] with the samples and
// element 1 is an Integer with the sample rate. Returns nullptr with a Java
// exception pending if a JNI call fails or the callback throws.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jstring text, jint sid,
    jfloat speed, jobject callback) {
  const char *p_text = env->GetStringUTFChars(text, nullptr);
  if (p_text == nullptr) {
    // GetStringUTFChars() failed; a Java exception is already pending.
    return nullptr;
  }
  SHERPA_ONNX_LOGE("string is: %s", p_text);

  // Resolve invoke(float[]) once instead of on every callback invocation.
  jclass cls = env->GetObjectClass(callback);
  jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V");
  env->DeleteLocalRef(cls);
  if (mid == nullptr) {
    // GetMethodID() failed; a Java exception is already pending.
    env->ReleaseStringUTFChars(text, p_text);
    return nullptr;
  }

  std::function<void(const float *, int32_t)> callback_wrapper =
      [env, callback, mid](const float *samples, int32_t n) {
        // An earlier invocation threw; making further JNI calls with a
        // pending exception is not allowed, so skip the remaining chunks.
        if (env->ExceptionCheck()) {
          return;
        }

        jfloatArray samples_arr = env->NewFloatArray(n);
        if (samples_arr == nullptr) {
          return;  // allocation failed; a Java exception is pending
        }

        env->SetFloatArrayRegion(samples_arr, 0, n, samples);
        env->CallVoidMethod(callback, mid, samples_arr);

        // Local references are otherwise only freed when this native method
        // returns, so release each chunk here to avoid accumulating one
        // reference per callback invocation.
        env->DeleteLocalRef(samples_arr);
      };

  auto audio =
      reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr)->Generate(
          p_text, sid, speed, callback_wrapper);

  env->ReleaseStringUTFChars(text, p_text);

  if (env->ExceptionCheck()) {
    // The callback threw; let the exception propagate to the Java caller.
    return nullptr;
  }

  jfloatArray samples_arr = env->NewFloatArray(audio.samples.size());
  env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(),
                           audio.samples.data());

  jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
      2, env->FindClass("java/lang/Object"), nullptr);

  env->SetObjectArrayElement(obj_arr, 0, samples_arr);
  env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate));

  return obj_arr;
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl(
|
||||
JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
|
||||
|
||||
Reference in New Issue
Block a user