This repository has been archived on 2025-08-26. You can view files and clone it, but cannot push or open issues or pull requests.
Files
enginex_bi_series-sherpa-onnx/sherpa-onnx/jni/offline-tts.cc
2025-01-17 17:36:13 +08:00

326 lines
12 KiB
C++

// sherpa-onnx/jni/offline-tts.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-tts.h"

#include <functional>
#include <string>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {
static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
OfflineTtsConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
fid = env->GetFieldID(cls, "model",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsModelConfig;");
jobject model = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model);
// vits
fid = env->GetFieldID(model_config_cls, "vits",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;");
jobject vits = env->GetObjectField(model, fid);
jclass vits_cls = env->GetObjectClass(vits);
fid = env->GetFieldID(vits_cls, "model", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(vits, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.model.vits.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(vits_cls, "lexicon", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(vits, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.vits.lexicon = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(vits_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(vits, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.vits.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(vits_cls, "dataDir", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(vits, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.vits.data_dir = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(vits_cls, "dictDir", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(vits, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.vits.dict_dir = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(vits_cls, "noiseScale", "F");
ans.model.vits.noise_scale = env->GetFloatField(vits, fid);
fid = env->GetFieldID(vits_cls, "noiseScaleW", "F");
ans.model.vits.noise_scale_w = env->GetFloatField(vits, fid);
fid = env->GetFieldID(vits_cls, "lengthScale", "F");
ans.model.vits.length_scale = env->GetFloatField(vits, fid);
// matcha
fid = env->GetFieldID(model_config_cls, "matcha",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsMatchaModelConfig;");
jobject matcha = env->GetObjectField(model, fid);
jclass matcha_cls = env->GetObjectClass(matcha);
fid = env->GetFieldID(matcha_cls, "acousticModel", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.acoustic_model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(matcha_cls, "vocoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.vocoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(matcha_cls, "lexicon", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.lexicon = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(matcha_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(matcha_cls, "dataDir", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.data_dir = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(matcha_cls, "dictDir", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.dict_dir = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(matcha_cls, "noiseScale", "F");
ans.model.matcha.noise_scale = env->GetFloatField(matcha, fid);
fid = env->GetFieldID(matcha_cls, "lengthScale", "F");
ans.model.matcha.length_scale = env->GetFloatField(matcha, fid);
// kokoro
fid = env->GetFieldID(model_config_cls, "kokoro",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsKokoroModelConfig;");
jobject kokoro = env->GetObjectField(model, fid);
jclass kokoro_cls = env->GetObjectClass(kokoro);
fid = env->GetFieldID(kokoro_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(kokoro, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.kokoro.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(kokoro_cls, "voices", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(kokoro, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.kokoro.voices = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(kokoro_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(kokoro, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.kokoro.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(kokoro_cls, "dataDir", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(kokoro, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.kokoro.data_dir = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(kokoro_cls, "lengthScale", "F");
ans.model.kokoro.length_scale = env->GetFloatField(kokoro, fid);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model.num_threads = env->GetIntField(model, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model.debug = env->GetBooleanField(model, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.provider = p;
env->ReleaseStringUTFChars(s, p);
// for ruleFsts
fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.rule_fsts = p;
env->ReleaseStringUTFChars(s, p);
// for ruleFars
fid = env->GetFieldID(cls, "ruleFars", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.rule_fars = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "maxNumSentences", "I");
ans.max_num_sentences = env->GetIntField(config, fid);
return ans;
}
} // namespace sherpa_onnx
// JNI entry point: creates an OfflineTts from the Java config `_config`.
//
// When __ANDROID_API__ >= 9 the native asset manager is obtained from
// `asset_manager` and handed to the OfflineTts constructor; if it cannot be
// obtained, 0 is returned before the config is even read.
//
// Returns the address of the heap-allocated OfflineTts as a jlong.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromAsset(
    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  if (mgr == nullptr) {
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
    return 0;
  }
#endif

  auto tts_config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", tts_config.ToString().c_str());

#if __ANDROID_API__ >= 9
  auto *tts = new sherpa_onnx::OfflineTts(mgr, tts_config);
#else
  auto *tts = new sherpa_onnx::OfflineTts(tts_config);
#endif

  return reinterpret_cast<jlong>(tts);
}
// JNI entry point: creates an OfflineTts from the Java config `_config`.
//
// The config is converted with GetOfflineTtsConfig(), logged and validated.
//
// Returns the address of the heap-allocated OfflineTts as a jlong, or 0 when
// config.Validate() reports errors. 0 is the same failure value that
// newFromAsset returns; no OfflineTts is built from a config that is already
// known to be invalid.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromFile(
    JNIEnv *env, jobject /*obj*/, jobject _config) {
  auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

  if (!config.Validate()) {
    SHERPA_ONNX_LOGE("Errors found in config!");
    return 0;
  }

  auto *tts = new sherpa_onnx::OfflineTts(config);
  return reinterpret_cast<jlong>(tts);
}
// JNI entry point: destroys the OfflineTts whose address is stored in `ptr`
// (a handle of the kind returned by newFromAsset/newFromFile).
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_delete(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  auto *tts = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr);
  delete tts;
}
// JNI entry point: returns SampleRate() of the OfflineTts whose address is
// stored in `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getSampleRate(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  auto *tts = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr);
  return tts->SampleRate();
}
// JNI entry point: returns NumSpeakers() of the OfflineTts whose address is
// stored in `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getNumSpeakers(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  auto *tts = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr);
  return tts->NumSpeakers();
}
// JNI entry point: runs Generate(text, sid, speed) on the OfflineTts whose
// address is stored in `ptr`.
//
// Returns a Java Object[2]:
//   [0] float[]  - the generated samples
//   [1] the sample rate, boxed with NewInteger()
//
// Returns nullptr when `text` is null, when the JVM cannot provide the UTF
// chars of `text`, or when one of the result arrays cannot be allocated. In
// the latter two cases the JVM has a Java exception pending, which surfaces
// in the Java caller once this function returns.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineTts_generateImpl(JNIEnv *env, jobject /*obj*/,
                                                   jlong ptr, jstring text,
                                                   jint sid, jfloat speed) {
  if (text == nullptr) {
    SHERPA_ONNX_LOGE("text is null");
    return nullptr;
  }

  const char *p_text = env->GetStringUTFChars(text, nullptr);
  if (p_text == nullptr) {
    return nullptr;
  }

  SHERPA_ONNX_LOGE("string is: %s", p_text);

  auto *tts = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr);
  auto audio = tts->Generate(p_text, sid, speed);

  // The text is not needed once generation has finished.
  env->ReleaseStringUTFChars(text, p_text);

  const auto num_samples = static_cast<jsize>(audio.samples.size());
  jfloatArray samples_arr = env->NewFloatArray(num_samples);
  if (samples_arr == nullptr) {
    return nullptr;
  }
  env->SetFloatArrayRegion(samples_arr, 0, num_samples, audio.samples.data());

  jobjectArray obj_arr =
      env->NewObjectArray(2, env->FindClass("java/lang/Object"), nullptr);
  if (obj_arr == nullptr) {
    return nullptr;
  }

  env->SetObjectArrayElement(obj_arr, 0, samples_arr);
  env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate));

  return obj_arr;
}
// JNI entry point: runs Generate(text, sid, speed, callback) on the
// OfflineTts whose address is stored in `ptr`.
//
// `callback` is a Java object with a method
//   Integer invoke(float[] samples)
// It is invoked with each chunk of samples that Generate() passes to the
// native callback; the int value of its result is forwarded to Generate().
// If `callback` has no such method, that is logged once, the resulting Java
// exception is cleared, and the native callback always returns 1.
//
// Returns a Java Object[2]:
//   [0] float[]  - the generated samples
//   [1] the sample rate, boxed with NewInteger()
//
// Returns nullptr when `text` is null, when the JVM cannot provide the UTF
// chars of `text`, when the Java callback threw, or when a result array
// cannot be allocated. Except for a null `text`, a Java exception is pending
// in these cases and surfaces in the Java caller once this function returns.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jstring text, jint sid,
    jfloat speed, jobject callback) {
  if (text == nullptr) {
    SHERPA_ONNX_LOGE("text is null");
    return nullptr;
  }

  const char *p_text = env->GetStringUTFChars(text, nullptr);
  if (p_text == nullptr) {
    return nullptr;
  }

  SHERPA_ONNX_LOGE("string is: %s", p_text);

  // Resolve the Java methods once, instead of on every callback invocation.
  jclass cls = env->GetObjectClass(callback);

#if 0
  // this block is for debugging only
  // see also
  // https://jnjosh.com/posts/kotlinfromcpp/
  jmethodID classMethodId =
      env->GetMethodID(cls, "getClass", "()Ljava/lang/Class;");
  jobject klassObj = env->CallObjectMethod(callback, classMethodId);
  auto klassObject = env->GetObjectClass(klassObj);
  auto nameMethodId =
      env->GetMethodID(klassObject, "getName", "()Ljava/lang/String;");
  jstring classString =
      (jstring)env->CallObjectMethod(klassObj, nameMethodId);
  auto className = env->GetStringUTFChars(classString, NULL);
  SHERPA_ONNX_LOGE("name is: %s", className);
  env->ReleaseStringUTFChars(classString, className);
#endif

  jmethodID mid = env->GetMethodID(cls, "invoke", "([F)Ljava/lang/Integer;");
  if (mid == nullptr) {
    SHERPA_ONNX_LOGE("Failed to get the callback. Ignore it.");
    // A failed lookup leaves a Java exception pending; clear it so that the
    // callback really is ignored and later JNI calls remain valid.
    env->ExceptionClear();
  }
  env->DeleteLocalRef(cls);

  jclass integer_cls = env->FindClass("java/lang/Integer");
  jmethodID int_value_mid = env->GetMethodID(integer_cls, "intValue", "()I");
  env->DeleteLocalRef(integer_cls);

  std::function<int32_t(const float *, int32_t, float)> callback_wrapper =
      [env, callback, mid, int_value_mid](const float *samples, int32_t n,
                                          float /*progress*/) -> int32_t {
    if (mid == nullptr) {
      // No usable Java callback.
      return 1;
    }

    if (env->ExceptionCheck()) {
      // An earlier invocation failed; do not call into the JVM again while
      // its exception is pending.
      return 0;
    }

    jfloatArray samples_arr = env->NewFloatArray(n);
    if (samples_arr == nullptr) {
      return 0;
    }
    env->SetFloatArrayRegion(samples_arr, 0, n, samples);

    jobject should_continue = env->CallObjectMethod(callback, mid, samples_arr);

    // This lambda may run many times during a single native call, so local
    // references are released here rather than piling up until the call
    // returns.
    env->DeleteLocalRef(samples_arr);

    if (env->ExceptionCheck() || should_continue == nullptr) {
      // The Java callback threw or returned null.
      // NOTE(review): 0 presumably asks Generate() to stop - confirm against
      // OfflineTts::Generate.
      return 0;
    }

    const jint ans = env->CallIntMethod(should_continue, int_value_mid);
    env->DeleteLocalRef(should_continue);

    return ans;
  };

  auto *tts = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr);
  auto audio = tts->Generate(p_text, sid, speed, callback_wrapper);

  // The text is not needed once generation has finished.
  env->ReleaseStringUTFChars(text, p_text);

  if (env->ExceptionCheck()) {
    // Let the exception raised inside the Java callback reach the caller.
    return nullptr;
  }

  const auto num_samples = static_cast<jsize>(audio.samples.size());
  jfloatArray samples_arr = env->NewFloatArray(num_samples);
  if (samples_arr == nullptr) {
    return nullptr;
  }
  env->SetFloatArrayRegion(samples_arr, 0, num_samples, audio.samples.data());

  jobjectArray obj_arr =
      env->NewObjectArray(2, env->FindClass("java/lang/Object"), nullptr);
  if (obj_arr == nullptr) {
    return nullptr;
  }

  env->SetObjectArrayElement(obj_arr, 0, samples_arr);
  env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate));

  return obj_arr;
}