Refactor TTS Android code to support jieba for Chinese TTS models (#800)

This commit is contained in:
Fangjun Kuang
2024-04-22 17:21:05 +08:00
committed by GitHub
parent 494cb5c733
commit 7f3b9ffe5d
40 changed files with 352 additions and 285 deletions

View File

@@ -32,7 +32,7 @@ bool AudioTaggingModelConfig::Validate() const {
}
if (!ced.empty() && !FileExists(ced)) {
SHERPA_ONNX_LOGE("CED model file %s does not exist", ced.c_str());
SHERPA_ONNX_LOGE("CED model file '%s' does not exist", ced.c_str());
return false;
}

View File

@@ -48,7 +48,7 @@ bool AudioTaggingConfig::Validate() const {
}
if (!FileExists(labels)) {
SHERPA_ONNX_LOGE("--labels %s does not exist", labels.c_str());
SHERPA_ONNX_LOGE("--labels '%s' does not exist", labels.c_str());
return false;
}

View File

@@ -7,7 +7,7 @@
#include <fstream>
#include <string>
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
@@ -17,7 +17,7 @@ bool FileExists(const std::string &filename) {
void AssertFileExists(const std::string &filename) {
if (!FileExists(filename)) {
SHERPA_ONNX_LOG(FATAL) << filename << " does not exist!";
SHERPA_ONNX_LOGE("filename '%s' does not exist", filename.c_str());
exit(-1);
}
}

View File

@@ -146,6 +146,14 @@ class JiebaLexicon::Impl {
if (token2id_.count(p.first) && !token2id_.count(p.second)) {
token2id_[p.second] = token2id_[p.first];
}
if (!token2id_.count(p.first) && token2id_.count(p.second)) {
token2id_[p.first] = token2id_[p.second];
}
}
if (!token2id_.count("") && token2id_.count("")) {
token2id_[""] = token2id_[""];
}
}

View File

@@ -101,7 +101,8 @@ bool KeywordSpotterConfig::Validate() const {
// Solution: take keyword_file variable is directly
// parsed as a string of keywords
if (!std::ifstream(keywords_file.c_str()).good()) {
SHERPA_ONNX_LOGE("Keywords file %s does not exist.", keywords_file.c_str());
SHERPA_ONNX_LOGE("Keywords file '%s' does not exist.",
keywords_file.c_str());
return false;
}
#endif

View File

@@ -34,7 +34,7 @@ void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) {
bool OfflineCtcFstDecoderConfig::Validate() const {
if (!graph.empty() && !FileExists(graph)) {
SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str());
SHERPA_ONNX_LOGE("graph: '%s' does not exist", graph.c_str());
return false;
}
return true;

View File

@@ -22,7 +22,7 @@ void OfflineLMConfig::Register(ParseOptions *po) {
bool OfflineLMConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
SHERPA_ONNX_LOGE("'%s' does not exist", model.c_str());
return false;
}

View File

@@ -16,7 +16,7 @@ void OfflineNemoEncDecCtcModelConfig::Register(ParseOptions *po) {
bool OfflineNemoEncDecCtcModelConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("NeMo model: %s does not exist", model.c_str());
SHERPA_ONNX_LOGE("NeMo model: '%s' does not exist", model.c_str());
return false;
}

View File

@@ -15,7 +15,7 @@ void OfflineParaformerModelConfig::Register(ParseOptions *po) {
bool OfflineParaformerModelConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("Paraformer model %s does not exist", model.c_str());
SHERPA_ONNX_LOGE("Paraformer model '%s' does not exist", model.c_str());
return false;
}

View File

@@ -18,19 +18,19 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) {
bool OfflineTransducerModelConfig::Validate() const {
if (!FileExists(encoder_filename)) {
SHERPA_ONNX_LOGE("transducer encoder: %s does not exist",
SHERPA_ONNX_LOGE("transducer encoder: '%s' does not exist",
encoder_filename.c_str());
return false;
}
if (!FileExists(decoder_filename)) {
SHERPA_ONNX_LOGE("transducer decoder: %s does not exist",
SHERPA_ONNX_LOGE("transducer decoder: '%s' does not exist",
decoder_filename.c_str());
return false;
}
if (!FileExists(joiner_filename)) {
SHERPA_ONNX_LOGE("transducer joiner: %s does not exist",
SHERPA_ONNX_LOGE("transducer joiner: '%s' does not exist",
joiner_filename.c_str());
return false;
}

View File

@@ -35,7 +35,7 @@ bool OfflineTtsVitsModelConfig::Validate() const {
}
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("--vits-model: %s does not exist", model.c_str());
SHERPA_ONNX_LOGE("--vits-model: '%s' does not exist", model.c_str());
return false;
}
@@ -45,31 +45,31 @@ bool OfflineTtsVitsModelConfig::Validate() const {
}
if (!FileExists(tokens)) {
SHERPA_ONNX_LOGE("--vits-tokens: %s does not exist", tokens.c_str());
SHERPA_ONNX_LOGE("--vits-tokens: '%s' does not exist", tokens.c_str());
return false;
}
if (!data_dir.empty()) {
if (!FileExists(data_dir + "/phontab")) {
SHERPA_ONNX_LOGE("%s/phontab does not exist. Skipping test",
SHERPA_ONNX_LOGE("'%s/phontab' does not exist. Skipping test",
data_dir.c_str());
return false;
}
if (!FileExists(data_dir + "/phonindex")) {
SHERPA_ONNX_LOGE("%s/phonindex does not exist. Skipping test",
SHERPA_ONNX_LOGE("'%s/phonindex' does not exist. Skipping test",
data_dir.c_str());
return false;
}
if (!FileExists(data_dir + "/phondata")) {
SHERPA_ONNX_LOGE("%s/phondata does not exist. Skipping test",
SHERPA_ONNX_LOGE("'%s/phondata' does not exist. Skipping test",
data_dir.c_str());
return false;
}
if (!FileExists(data_dir + "/intonations")) {
SHERPA_ONNX_LOGE("%s/intonations does not exist.", data_dir.c_str());
SHERPA_ONNX_LOGE("'%s/intonations' does not exist.", data_dir.c_str());
return false;
}
}
@@ -82,7 +82,8 @@ bool OfflineTtsVitsModelConfig::Validate() const {
for (const auto &f : required_files) {
if (!FileExists(dict_dir + "/" + f)) {
SHERPA_ONNX_LOGE("%s/%s does not exist.", data_dir.c_str(), f.c_str());
SHERPA_ONNX_LOGE("'%s/%s' does not exist.", data_dir.c_str(),
f.c_str());
return false;
}
}

View File

@@ -42,7 +42,7 @@ bool OfflineTtsConfig::Validate() const {
SplitStringToVector(rule_fsts, ",", false, &files);
for (const auto &f : files) {
if (!FileExists(f)) {
SHERPA_ONNX_LOGE("Rule fst %s does not exist. ", f.c_str());
SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str());
return false;
}
}
@@ -53,7 +53,7 @@ bool OfflineTtsConfig::Validate() const {
SplitStringToVector(rule_fars, ",", false, &files);
for (const auto &f : files) {
if (!FileExists(f)) {
SHERPA_ONNX_LOGE("Rule far %s does not exist. ", f.c_str());
SHERPA_ONNX_LOGE("Rule far '%s' does not exist. ", f.c_str());
return false;
}
}

View File

@@ -18,7 +18,7 @@ void OfflineWenetCtcModelConfig::Register(ParseOptions *po) {
bool OfflineWenetCtcModelConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("WeNet model: %s does not exist", model.c_str());
SHERPA_ONNX_LOGE("WeNet model: '%s' does not exist", model.c_str());
return false;
}

View File

@@ -48,7 +48,8 @@ bool OfflineWhisperModelConfig::Validate() const {
}
if (!FileExists(encoder)) {
SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
SHERPA_ONNX_LOGE("whisper encoder file '%s' does not exist",
encoder.c_str());
return false;
}
@@ -58,7 +59,8 @@ bool OfflineWhisperModelConfig::Validate() const {
}
if (!FileExists(decoder)) {
SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
SHERPA_ONNX_LOGE("whisper decoder file '%s' does not exist",
decoder.c_str());
return false;
}

View File

@@ -21,7 +21,7 @@ bool OfflineZipformerAudioTaggingModelConfig::Validate() const {
}
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("--zipformer-model: %s does not exist", model.c_str());
SHERPA_ONNX_LOGE("--zipformer-model: '%s' does not exist", model.c_str());
return false;
}

View File

@@ -15,7 +15,7 @@ void OfflineZipformerCtcModelConfig::Register(ParseOptions *po) {
bool OfflineZipformerCtcModelConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("zipformer CTC model file %s does not exist",
SHERPA_ONNX_LOGE("zipformer CTC model file '%s' does not exist",
model.c_str());
return false;
}

View File

@@ -31,7 +31,7 @@ void OnlineCtcFstDecoderConfig::Register(ParseOptions *po) {
bool OnlineCtcFstDecoderConfig::Validate() const {
if (!graph.empty() && !FileExists(graph)) {
SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str());
SHERPA_ONNX_LOGE("graph: '%s' does not exist", graph.c_str());
return false;
}
return true;

View File

@@ -22,7 +22,7 @@ void OnlineLMConfig::Register(ParseOptions *po) {
bool OnlineLMConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
SHERPA_ONNX_LOGE("'%s' does not exist", model.c_str());
return false;
}

View File

@@ -45,7 +45,7 @@ bool OnlineModelConfig::Validate() const {
}
if (!FileExists(tokens)) {
SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str());
SHERPA_ONNX_LOGE("tokens: '%s' does not exist", tokens.c_str());
return false;
}

View File

@@ -18,12 +18,12 @@ void OnlineParaformerModelConfig::Register(ParseOptions *po) {
bool OnlineParaformerModelConfig::Validate() const {
if (!FileExists(encoder)) {
SHERPA_ONNX_LOGE("Paraformer encoder %s does not exist", encoder.c_str());
SHERPA_ONNX_LOGE("Paraformer encoder '%s' does not exist", encoder.c_str());
return false;
}
if (!FileExists(decoder)) {
SHERPA_ONNX_LOGE("Paraformer decoder %s does not exist", decoder.c_str());
SHERPA_ONNX_LOGE("Paraformer decoder '%s' does not exist", decoder.c_str());
return false;
}

View File

@@ -18,17 +18,19 @@ void OnlineTransducerModelConfig::Register(ParseOptions *po) {
bool OnlineTransducerModelConfig::Validate() const {
if (!FileExists(encoder)) {
SHERPA_ONNX_LOGE("transducer encoder: %s does not exist", encoder.c_str());
SHERPA_ONNX_LOGE("transducer encoder: '%s' does not exist",
encoder.c_str());
return false;
}
if (!FileExists(decoder)) {
SHERPA_ONNX_LOGE("transducer decoder: %s does not exist", decoder.c_str());
SHERPA_ONNX_LOGE("transducer decoder: '%s' does not exist",
decoder.c_str());
return false;
}
if (!FileExists(joiner)) {
SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner.c_str());
SHERPA_ONNX_LOGE("joiner: '%s' does not exist", joiner.c_str());
return false;
}

View File

@@ -21,7 +21,7 @@ void OnlineWenetCtcModelConfig::Register(ParseOptions *po) {
bool OnlineWenetCtcModelConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("WeNet CTC model %s does not exist", model.c_str());
SHERPA_ONNX_LOGE("WeNet CTC model '%s' does not exist", model.c_str());
return false;
}

View File

@@ -22,7 +22,8 @@ bool OnlineZipformer2CtcModelConfig::Validate() const {
}
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("--zipformer2-ctc-model %s does not exist", model.c_str());
SHERPA_ONNX_LOGE("--zipformer2-ctc-model '%s' does not exist",
model.c_str());
return false;
}

View File

@@ -44,7 +44,8 @@ bool SileroVadModelConfig::Validate() const {
}
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("Silero vad model file %s does not exist", model.c_str());
SHERPA_ONNX_LOGE("Silero vad model file '%s' does not exist",
model.c_str());
return false;
}

View File

@@ -31,7 +31,7 @@ bool SpeakerEmbeddingExtractorConfig::Validate() const {
}
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("--speaker-embedding-model: %s does not exist",
SHERPA_ONNX_LOGE("--speaker-embedding-model: '%s' does not exist",
model.c_str());
return false;
}

View File

@@ -43,7 +43,8 @@ bool SpokenLanguageIdentificationWhisperConfig::Validate() const {
}
if (!FileExists(encoder)) {
SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
SHERPA_ONNX_LOGE("whisper encoder file '%s' does not exist",
encoder.c_str());
return false;
}
@@ -53,7 +54,8 @@ bool SpokenLanguageIdentificationWhisperConfig::Validate() const {
}
if (!FileExists(decoder)) {
SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
SHERPA_ONNX_LOGE("whisper decoder file '%s' does not exist",
decoder.c_str());
return false;
}

View File

@@ -9,11 +9,20 @@ if(NOT DEFINED ANDROID_ABI)
include_directories($ENV{JAVA_HOME}/include/darwin)
endif()
add_library(sherpa-onnx-jni
set(sources
audio-tagging.cc
jni.cc
offline-stream.cc
spoken-language-identification.cc
)
if(SHERPA_ONNX_ENABLE_TTS)
list(APPEND sources
offline-tts.cc
)
endif()
add_library(sherpa-onnx-jni ${sources})
target_link_libraries(sherpa-onnx-jni sherpa-onnx-core)
install(TARGETS sherpa-onnx-jni DESTINATION lib)

View File

@@ -24,10 +24,6 @@
#include "sherpa-onnx/csrc/wave-writer.h"
#include "sherpa-onnx/jni/common.h"
#if SHERPA_ONNX_ENABLE_TTS == 1
#include "sherpa-onnx/csrc/offline-tts.h"
#endif
namespace sherpa_onnx {
class SherpaOnnx {
@@ -775,113 +771,6 @@ static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) {
return ans;
}
#if SHERPA_ONNX_ENABLE_TTS == 1
class SherpaOnnxOfflineTts {
public:
#if __ANDROID_API__ >= 9
SherpaOnnxOfflineTts(AAssetManager *mgr, const OfflineTtsConfig &config)
: tts_(mgr, config) {}
#endif
explicit SherpaOnnxOfflineTts(const OfflineTtsConfig &config)
: tts_(config) {}
GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
float speed = 1.0,
std::function<void(const float *, int32_t, float)>
callback = nullptr) const {
return tts_.Generate(text, sid, speed, callback);
}
int32_t SampleRate() const { return tts_.SampleRate(); }
int32_t NumSpeakers() const { return tts_.NumSpeakers(); }
private:
OfflineTts tts_;
};
static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
OfflineTtsConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
fid = env->GetFieldID(cls, "model",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsModelConfig;");
jobject model = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model);
fid = env->GetFieldID(model_config_cls, "vits",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;");
jobject vits = env->GetObjectField(model, fid);
jclass vits_cls = env->GetObjectClass(vits);
fid = env->GetFieldID(vits_cls, "model", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(vits, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.model.vits.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(vits_cls, "lexicon", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(vits, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.vits.lexicon = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(vits_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(vits, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.vits.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(vits_cls, "dataDir", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(vits, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.vits.data_dir = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(vits_cls, "noiseScale", "F");
ans.model.vits.noise_scale = env->GetFloatField(vits, fid);
fid = env->GetFieldID(vits_cls, "noiseScaleW", "F");
ans.model.vits.noise_scale_w = env->GetFloatField(vits, fid);
fid = env->GetFieldID(vits_cls, "lengthScale", "F");
ans.model.vits.length_scale = env->GetFloatField(vits, fid);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model.num_threads = env->GetIntField(model, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model.debug = env->GetBooleanField(model, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.provider = p;
env->ReleaseStringUTFChars(s, p);
// for ruleFsts
fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.rule_fsts = p;
env->ReleaseStringUTFChars(s, p);
// for ruleFars
fid = env->GetFieldID(cls, "ruleFars", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.rule_fars = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "maxNumSentences", "I");
ans.max_num_sentences = env->GetIntField(config, fid);
return ans;
}
#endif
} // namespace sherpa_onnx
SHERPA_ONNX_EXTERN_C
@@ -1226,128 +1115,6 @@ jobject NewFloat(JNIEnv *env, float value) {
return env->NewObject(cls, constructor, value);
}
#if SHERPA_ONNX_ENABLE_TTS == 1
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_new(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto tts = new sherpa_onnx::SherpaOnnxOfflineTts(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)tts;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
if (!config.Validate()) {
SHERPA_ONNX_LOGE("Errors found in config!");
}
auto tts = new sherpa_onnx::SherpaOnnxOfflineTts(config);
return (jlong)tts;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_delete(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
delete reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getSampleRate(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
return reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr)
->SampleRate();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getNumSpeakers(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
return reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr)
->NumSpeakers();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineTts_generateImpl(JNIEnv *env, jobject /*obj*/,
jlong ptr, jstring text,
jint sid, jfloat speed) {
const char *p_text = env->GetStringUTFChars(text, nullptr);
SHERPA_ONNX_LOGE("string is: %s", p_text);
auto audio =
reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr)->Generate(
p_text, sid, speed);
jfloatArray samples_arr = env->NewFloatArray(audio.samples.size());
env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(),
audio.samples.data());
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
2, env->FindClass("java/lang/Object"), nullptr);
env->SetObjectArrayElement(obj_arr, 0, samples_arr);
env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate));
env->ReleaseStringUTFChars(text, p_text);
return obj_arr;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring text, jint sid,
jfloat speed, jobject callback) {
const char *p_text = env->GetStringUTFChars(text, nullptr);
SHERPA_ONNX_LOGE("string is: %s", p_text);
std::function<void(const float *, int32_t, float)> callback_wrapper =
[env, callback](const float *samples, int32_t n, float /*p*/) {
jclass cls = env->GetObjectClass(callback);
jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V");
jfloatArray samples_arr = env->NewFloatArray(n);
env->SetFloatArrayRegion(samples_arr, 0, n, samples);
env->CallVoidMethod(callback, mid, samples_arr);
};
auto audio =
reinterpret_cast<sherpa_onnx::SherpaOnnxOfflineTts *>(ptr)->Generate(
p_text, sid, speed, callback_wrapper);
jfloatArray samples_arr = env->NewFloatArray(audio.samples.size());
env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(),
audio.samples.data());
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
2, env->FindClass("java/lang/Object"), nullptr);
env->SetObjectArrayElement(obj_arr, 0, samples_arr);
env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate));
env->ReleaseStringUTFChars(text, p_text);
return obj_arr;
}
#endif
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl(
JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,

View File

@@ -0,0 +1,215 @@
// sherpa-onnx/jni/offline-tts.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-tts.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {
static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
OfflineTtsConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
fid = env->GetFieldID(cls, "model",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsModelConfig;");
jobject model = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model);
fid = env->GetFieldID(model_config_cls, "vits",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;");
jobject vits = env->GetObjectField(model, fid);
jclass vits_cls = env->GetObjectClass(vits);
fid = env->GetFieldID(vits_cls, "model", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(vits, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.model.vits.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(vits_cls, "lexicon", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(vits, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.vits.lexicon = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(vits_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(vits, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.vits.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(vits_cls, "dataDir", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(vits, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.vits.data_dir = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(vits_cls, "dictDir", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(vits, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.vits.dict_dir = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(vits_cls, "noiseScale", "F");
ans.model.vits.noise_scale = env->GetFloatField(vits, fid);
fid = env->GetFieldID(vits_cls, "noiseScaleW", "F");
ans.model.vits.noise_scale_w = env->GetFloatField(vits, fid);
fid = env->GetFieldID(vits_cls, "lengthScale", "F");
ans.model.vits.length_scale = env->GetFloatField(vits, fid);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model.num_threads = env->GetIntField(model, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model.debug = env->GetBooleanField(model, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.provider = p;
env->ReleaseStringUTFChars(s, p);
// for ruleFsts
fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.rule_fsts = p;
env->ReleaseStringUTFChars(s, p);
// for ruleFars
fid = env->GetFieldID(cls, "ruleFars", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.rule_fars = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "maxNumSentences", "I");
ans.max_num_sentences = env->GetIntField(config, fid);
return ans;
}
} // namespace sherpa_onnx
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newForAsset(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto tts = new sherpa_onnx::OfflineTts(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)tts;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
if (!config.Validate()) {
SHERPA_ONNX_LOGE("Errors found in config!");
}
auto tts = new sherpa_onnx::OfflineTts(config);
return (jlong)tts;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_delete(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
delete reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getSampleRate(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
return reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->SampleRate();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getNumSpeakers(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
return reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->NumSpeakers();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineTts_generateImpl(JNIEnv *env, jobject /*obj*/,
jlong ptr, jstring text,
jint sid, jfloat speed) {
const char *p_text = env->GetStringUTFChars(text, nullptr);
SHERPA_ONNX_LOGE("string is: %s", p_text);
auto audio = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->Generate(
p_text, sid, speed);
jfloatArray samples_arr = env->NewFloatArray(audio.samples.size());
env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(),
audio.samples.data());
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
2, env->FindClass("java/lang/Object"), nullptr);
env->SetObjectArrayElement(obj_arr, 0, samples_arr);
env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate));
env->ReleaseStringUTFChars(text, p_text);
return obj_arr;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring text, jint sid,
jfloat speed, jobject callback) {
const char *p_text = env->GetStringUTFChars(text, nullptr);
SHERPA_ONNX_LOGE("string is: %s", p_text);
std::function<void(const float *, int32_t, float)> callback_wrapper =
[env, callback](const float *samples, int32_t n, float /*progress*/) {
jclass cls = env->GetObjectClass(callback);
jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V");
jfloatArray samples_arr = env->NewFloatArray(n);
env->SetFloatArrayRegion(samples_arr, 0, n, samples);
env->CallVoidMethod(callback, mid, samples_arr);
};
auto audio = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->Generate(
p_text, sid, speed, callback_wrapper);
jfloatArray samples_arr = env->NewFloatArray(audio.samples.size());
env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(),
audio.samples.data());
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
2, env->FindClass("java/lang/Object"), nullptr);
env->SetObjectArrayElement(obj_arr, 0, samples_arr);
env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate));
env->ReleaseStringUTFChars(text, p_text);
return obj_arr;
}