diff --git a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt index 50caa49f..4cf88e15 100644 --- a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt +++ b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt @@ -1,7 +1,7 @@ package com.k2fsa.sherpa.onnx import android.content.res.AssetManager -import android.media.MediaPlayer +import android.media.* import android.net.Uri import android.os.Bundle import android.util.Log @@ -23,6 +23,10 @@ class MainActivity : AppCompatActivity() { private lateinit var generate: Button private lateinit var play: Button + // see + // https://developer.android.com/reference/kotlin/android/media/AudioTrack + private lateinit var track: AudioTrack + override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) setContentView(R.layout.activity_main) @@ -31,6 +35,10 @@ class MainActivity : AppCompatActivity() { initTts() Log.i(TAG, "Finish initializing TTS") + Log.i(TAG, "Start to initialize AudioTrack") + initAudioTrack() + Log.i(TAG, "Finish initializing AudioTrack") + text = findViewById(R.id.text) sid = findViewById(R.id.sid) speed = findViewById(R.id.speed) @@ -51,6 +59,33 @@ class MainActivity : AppCompatActivity() { play.isEnabled = false } + private fun initAudioTrack() { + val sampleRate = tts.sampleRate() + val bufLength = (sampleRate * 0.1).toInt() + Log.i(TAG, "sampleRate: ${sampleRate}, buffLength: ${bufLength}") + + val attr = AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH) + .setUsage(AudioAttributes.USAGE_MEDIA) + .build() + + val format = AudioFormat.Builder() + .setEncoding(AudioFormat.ENCODING_PCM_FLOAT) + .setChannelMask(AudioFormat.CHANNEL_OUT_MONO) + .setSampleRate(sampleRate) + .build() + + track = AudioTrack( + attr, format, bufLength, AudioTrack.MODE_STREAM, + AudioManager.AUDIO_SESSION_ID_GENERATE 
+ ) + track.play() + } + + // this function is called from C++ + private fun callback(samples: FloatArray) { + track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING) + } + private fun onClickGenerate() { val sidInt = sid.text.toString().toIntOrNull() if (sidInt == null || sidInt < 0) { @@ -79,16 +114,28 @@ class MainActivity : AppCompatActivity() { return } - play.isEnabled = false - val audio = tts.generate(text = textStr, sid = sidInt, speed = speedFloat) + track.pause() + track.flush() + track.play() - val filename = application.filesDir.absolutePath + "/generated.wav" - val ok = audio.samples.size > 0 && audio.save(filename) - if (ok) { - play.isEnabled = true - // Play automatically after generation - onClickPlay() - } + play.isEnabled = false + Thread { + val audio = tts.generateWithCallback( + text = textStr, + sid = sidInt, + speed = speedFloat, + callback = this::callback + ) + + val filename = application.filesDir.absolutePath + "/generated.wav" + val ok = audio.samples.size > 0 && audio.save(filename) + if (ok) { + runOnUiThread { + play.isEnabled = true + track.stop() + } + } + }.start() } private fun onClickPlay() { diff --git a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt index d08e803a..bb500440 100644 --- a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt +++ b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt @@ -54,6 +54,8 @@ class OfflineTts( } } + fun sampleRate() = getSampleRate(ptr) + fun generate( text: String, sid: Int = 0, @@ -66,6 +68,19 @@ class OfflineTts( ) } + fun generateWithCallback( + text: String, + sid: Int = 0, + speed: Float = 1.0f, + callback: (samples: FloatArray) -> Unit + ): GeneratedAudio { + var objArray = generateWithCallbackImpl(ptr, text = text, sid = sid, speed = speed, callback=callback) + return GeneratedAudio( + samples = objArray[0] as FloatArray, + sampleRate = objArray[1] 
as Int + ) + } + fun allocate(assetManager: AssetManager? = null) { if (ptr == 0L) { if (assetManager != null) { @@ -97,6 +112,7 @@ class OfflineTts( ): Long private external fun delete(ptr: Long) + private external fun getSampleRate(ptr: Long): Int // The returned array has two entries: // - the first entry is an 1-D float array containing audio samples. @@ -109,6 +125,14 @@ class OfflineTts( speed: Float = 1.0f ): Array<Any> + external fun generateWithCallbackImpl( + ptr: Long, + text: String, + sid: Int = 0, + speed: Float = 1.0f, + callback: (samples: FloatArray) -> Unit + ): Array<Any> + companion object { init { System.loadLibrary("sherpa-onnx-jni") diff --git a/kotlin-api-examples/Main.kt b/kotlin-api-examples/Main.kt index 4d7b5ff6..2d6f9e6f 100644 --- a/kotlin-api-examples/Main.kt +++ b/kotlin-api-examples/Main.kt @@ -2,6 +2,10 @@ package com.k2fsa.sherpa.onnx import android.content.res.AssetManager +fun callback(samples: FloatArray): Unit { + println("callback got called with ${samples.size} samples"); +} + fun main() { testTts() testAsr() @@ -22,7 +26,7 @@ fun testTts() { ) ) val tts = OfflineTts(config=config) - val audio = tts.generate(text="“Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”") + val audio = tts.generateWithCallback(text="“Today as always, men fall into two groups: slaves and free men. 
Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”", callback=::callback) audio.save(filename="test-en.wav") } diff --git a/scripts/apk/generate-tts-apk-script.py b/scripts/apk/generate-tts-apk-script.py index 41f2e6c2..25471c8a 100755 --- a/scripts/apk/generate-tts-apk-script.py +++ b/scripts/apk/generate-tts-apk-script.py @@ -172,57 +172,57 @@ def get_vits_models() -> List[TtsModel]: lang="zh", rule_fsts="vits-zh-aishell3/rule.fst", ), - TtsModel( - model_dir="vits-zh-hf-doom", - model_name="doom.onnx", - lang="zh", - rule_fsts="vits-zh-hf-doom/rule.fst", - ), - TtsModel( - model_dir="vits-zh-hf-echo", - model_name="echo.onnx", - lang="zh", - rule_fsts="vits-zh-hf-echo/rule.fst", - ), - TtsModel( - model_dir="vits-zh-hf-zenyatta", - model_name="zenyatta.onnx", - lang="zh", - rule_fsts="vits-zh-hf-zenyatta/rule.fst", - ), - TtsModel( - model_dir="vits-zh-hf-abyssinvoker", - model_name="abyssinvoker.onnx", - lang="zh", - rule_fsts="vits-zh-hf-abyssinvoker/rule.fst", - ), - TtsModel( - model_dir="vits-zh-hf-keqing", - model_name="keqing.onnx", - lang="zh", - rule_fsts="vits-zh-hf-keqing/rule.fst", - ), - TtsModel( - model_dir="vits-zh-hf-eula", - model_name="eula.onnx", - lang="zh", - rule_fsts="vits-zh-hf-eula/rule.fst", - ), - TtsModel( - model_dir="vits-zh-hf-bronya", - model_name="bronya.onnx", - lang="zh", - rule_fsts="vits-zh-hf-bronya/rule.fst", - ), - TtsModel( - model_dir="vits-zh-hf-theresa", - model_name="theresa.onnx", - lang="zh", - rule_fsts="vits-zh-hf-theresa/rule.fst", - ), + # TtsModel( + # model_dir="vits-zh-hf-doom", + # model_name="doom.onnx", + # lang="zh", + # rule_fsts="vits-zh-hf-doom/rule.fst", + # ), + # TtsModel( + # model_dir="vits-zh-hf-echo", + # model_name="echo.onnx", + # lang="zh", + # rule_fsts="vits-zh-hf-echo/rule.fst", + # ), + # TtsModel( + # model_dir="vits-zh-hf-zenyatta", + # model_name="zenyatta.onnx", + # lang="zh", + # 
rule_fsts="vits-zh-hf-zenyatta/rule.fst", + # ), + # TtsModel( + # model_dir="vits-zh-hf-abyssinvoker", + # model_name="abyssinvoker.onnx", + # lang="zh", + # rule_fsts="vits-zh-hf-abyssinvoker/rule.fst", + # ), + # TtsModel( + # model_dir="vits-zh-hf-keqing", + # model_name="keqing.onnx", + # lang="zh", + # rule_fsts="vits-zh-hf-keqing/rule.fst", + # ), + # TtsModel( + # model_dir="vits-zh-hf-eula", + # model_name="eula.onnx", + # lang="zh", + # rule_fsts="vits-zh-hf-eula/rule.fst", + # ), + # TtsModel( + # model_dir="vits-zh-hf-bronya", + # model_name="bronya.onnx", + # lang="zh", + # rule_fsts="vits-zh-hf-bronya/rule.fst", + # ), + # TtsModel( + # model_dir="vits-zh-hf-theresa", + # model_name="theresa.onnx", + # lang="zh", + # rule_fsts="vits-zh-hf-theresa/rule.fst", + # ), # English (US) TtsModel(model_dir="vits-vctk", model_name="vits-vctk.onnx", lang="en"), - TtsModel(model_dir="vits-ljs", model_name="vits-ljs.onnx", lang="en"), + # TtsModel(model_dir="vits-ljs", model_name="vits-ljs.onnx", lang="en"), # fmt: on ] @@ -238,8 +238,8 @@ def main(): template = environment.from_string(s) d = dict() - # all_model_list = get_vits_models() - all_model_list = get_piper_models() + all_model_list = get_vits_models() + all_model_list += get_piper_models() all_model_list += get_coqui_models() num_models = len(all_model_list) diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index a3c7952c..0e2c3794 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -11,13 +11,15 @@ // android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include #include "jni.h" // NOLINT +#include +#include #include #include + #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" #include "android/asset_manager_jni.h" #endif -#include #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-recognizer.h" @@ -502,11 +504,14 @@ class SherpaOnnxOfflineTts { explicit SherpaOnnxOfflineTts(const OfflineTtsConfig &config) : tts_(config) {} - 
GeneratedAudio Generate(const std::string &text, int64_t sid = 0, - float speed = 1.0) const { - return tts_.Generate(text, sid, speed); + GeneratedAudio Generate( + const std::string &text, int64_t sid = 0, float speed = 1.0, + std::function<void(const float *, int32_t)> callback = nullptr) const { + return tts_.Generate(text, sid, speed, callback); } + int32_t SampleRate() const { return tts_.SampleRate(); } + private: OfflineTts tts_; }; @@ -628,6 +633,13 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_delete( delete reinterpret_cast<SherpaOnnxOfflineTts *>(ptr); } +SHERPA_ONNX_EXTERN_C +JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getSampleRate( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + return reinterpret_cast<SherpaOnnxOfflineTts *>(ptr) + ->SampleRate(); +} + // see // https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables static jobject NewInteger(JNIEnv *env, int32_t value) { @@ -663,6 +675,43 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateImpl(JNIEnv *env, jobject /*obj*/, return obj_arr; } +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl( + JNIEnv *env, jobject /*obj*/, jlong ptr, jstring text, jint sid, + jfloat speed, jobject callback) { + const char *p_text = env->GetStringUTFChars(text, nullptr); + SHERPA_ONNX_LOGE("string is: %s", p_text); + + std::function<void(const float *, int32_t)> callback_wrapper = + [env, callback](const float *samples, int32_t n) { + jclass cls = env->GetObjectClass(callback); + jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V"); + + jfloatArray samples_arr = env->NewFloatArray(n); + env->SetFloatArrayRegion(samples_arr, 0, n, samples); + env->CallVoidMethod(callback, mid, samples_arr); + }; + + auto audio = + reinterpret_cast<SherpaOnnxOfflineTts *>(ptr)->Generate( + p_text, sid, speed, callback_wrapper); + + jfloatArray samples_arr = env->NewFloatArray(audio.samples.size()); + env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(), + audio.samples.data()); + + jobjectArray obj_arr = 
(jobjectArray)env->NewObjectArray( + 2, env->FindClass("java/lang/Object"), nullptr); + + env->SetObjectArrayElement(obj_arr, 0, samples_arr); + env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate)); + + env->ReleaseStringUTFChars(text, p_text); + + return obj_arr; +} + SHERPA_ONNX_EXTERN_C JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl( JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,