Add Koltin and Java API for Kokoro TTS models (#1728)
This commit is contained in:
@@ -35,6 +35,7 @@ java_files += OfflineRecognizerResult.java
|
||||
java_files += OfflineStream.java
|
||||
java_files += OfflineRecognizer.java
|
||||
|
||||
java_files += OfflineTtsKokoroModelConfig.java
|
||||
java_files += OfflineTtsMatchaModelConfig.java
|
||||
java_files += OfflineTtsVitsModelConfig.java
|
||||
java_files += OfflineTtsModelConfig.java
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
// Copyright 2025 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
public class OfflineTtsKokoroModelConfig {
|
||||
private final String model;
|
||||
private final String voices;
|
||||
private final String tokens;
|
||||
private final String dataDir;
|
||||
private final float lengthScale;
|
||||
|
||||
private OfflineTtsKokoroModelConfig(Builder builder) {
|
||||
this.model = builder.model;
|
||||
this.voices = builder.voices;
|
||||
this.tokens = builder.tokens;
|
||||
this.dataDir = builder.dataDir;
|
||||
this.lengthScale = builder.lengthScale;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public String getModel() {
|
||||
return model;
|
||||
}
|
||||
|
||||
public String getVoices() {
|
||||
return voices;
|
||||
}
|
||||
|
||||
public String getTokens() {
|
||||
return tokens;
|
||||
}
|
||||
|
||||
public String getDataDir() {
|
||||
return dataDir;
|
||||
}
|
||||
|
||||
public float getLengthScale() {
|
||||
return lengthScale;
|
||||
}
|
||||
|
||||
|
||||
public static class Builder {
|
||||
private String model = "";
|
||||
private String voices = "";
|
||||
private String tokens = "";
|
||||
private String dataDir = "";
|
||||
private float lengthScale = 1.0f;
|
||||
|
||||
public OfflineTtsKokoroModelConfig build() {
|
||||
return new OfflineTtsKokoroModelConfig(this);
|
||||
}
|
||||
|
||||
public Builder setModel(String model) {
|
||||
this.model = model;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setVoices(String voices) {
|
||||
this.voices = voices;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setTokens(String tokens) {
|
||||
this.tokens = tokens;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setDataDir(String dataDir) {
|
||||
this.dataDir = dataDir;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setLengthScale(float lengthScale) {
|
||||
this.lengthScale = lengthScale;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ package com.k2fsa.sherpa.onnx;
|
||||
public class OfflineTtsModelConfig {
|
||||
private final OfflineTtsVitsModelConfig vits;
|
||||
private final OfflineTtsMatchaModelConfig matcha;
|
||||
private final OfflineTtsKokoroModelConfig kokoro;
|
||||
private final int numThreads;
|
||||
private final boolean debug;
|
||||
private final String provider;
|
||||
@@ -12,6 +13,7 @@ public class OfflineTtsModelConfig {
|
||||
private OfflineTtsModelConfig(Builder builder) {
|
||||
this.vits = builder.vits;
|
||||
this.matcha = builder.matcha;
|
||||
this.kokoro = builder.kokoro;
|
||||
this.numThreads = builder.numThreads;
|
||||
this.debug = builder.debug;
|
||||
this.provider = builder.provider;
|
||||
@@ -29,9 +31,14 @@ public class OfflineTtsModelConfig {
|
||||
return matcha;
|
||||
}
|
||||
|
||||
public OfflineTtsKokoroModelConfig getKokoro() {
|
||||
return kokoro;
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private OfflineTtsVitsModelConfig vits = OfflineTtsVitsModelConfig.builder().build();
|
||||
private OfflineTtsMatchaModelConfig matcha = OfflineTtsMatchaModelConfig.builder().build();
|
||||
private OfflineTtsKokoroModelConfig kokoro = OfflineTtsKokoroModelConfig.builder().build();
|
||||
private int numThreads = 1;
|
||||
private boolean debug = true;
|
||||
private String provider = "cpu";
|
||||
@@ -50,6 +57,11 @@ public class OfflineTtsModelConfig {
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setKokoro(OfflineTtsKokoroModelConfig kokoro) {
|
||||
this.kokoro = kokoro;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setNumThreads(int numThreads) {
|
||||
this.numThreads = numThreads;
|
||||
return this;
|
||||
|
||||
@@ -113,6 +113,39 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
|
||||
fid = env->GetFieldID(matcha_cls, "lengthScale", "F");
|
||||
ans.model.matcha.length_scale = env->GetFloatField(matcha, fid);
|
||||
|
||||
// kokoro
|
||||
fid = env->GetFieldID(model_config_cls, "kokoro",
|
||||
"Lcom/k2fsa/sherpa/onnx/OfflineTtsKokoroModelConfig;");
|
||||
jobject kokoro = env->GetObjectField(model, fid);
|
||||
jclass kokoro_cls = env->GetObjectClass(kokoro);
|
||||
|
||||
fid = env->GetFieldID(kokoro_cls, "model", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(kokoro, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model.kokoro.model = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(kokoro_cls, "voices", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(kokoro, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model.kokoro.voices = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(kokoro_cls, "tokens", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(kokoro, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model.kokoro.tokens = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(kokoro_cls, "dataDir", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(kokoro, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model.kokoro.data_dir = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(kokoro_cls, "lengthScale", "F");
|
||||
ans.model.kokoro.length_scale = env->GetFloatField(kokoro, fid);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
|
||||
ans.model.num_threads = env->GetIntField(model, fid);
|
||||
|
||||
@@ -273,8 +306,8 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
|
||||
return env->CallIntMethod(should_continue, int_value_mid);
|
||||
};
|
||||
|
||||
auto audio = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->Generate(
|
||||
p_text, sid, speed, callback_wrapper);
|
||||
auto tts = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr);
|
||||
auto audio = tts->Generate(p_text, sid, speed, callback_wrapper);
|
||||
|
||||
jfloatArray samples_arr = env->NewFloatArray(audio.samples.size());
|
||||
env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(),
|
||||
|
||||
@@ -25,9 +25,18 @@ data class OfflineTtsMatchaModelConfig(
|
||||
var lengthScale: Float = 1.0f,
|
||||
)
|
||||
|
||||
data class OfflineTtsKokoroModelConfig(
|
||||
var model: String = "",
|
||||
var voices: String = "",
|
||||
var tokens: String = "",
|
||||
var dataDir: String = "",
|
||||
var lengthScale: Float = 1.0f,
|
||||
)
|
||||
|
||||
data class OfflineTtsModelConfig(
|
||||
var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(),
|
||||
var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(),
|
||||
var kokoro: OfflineTtsKokoroModelConfig = OfflineTtsKokoroModelConfig(),
|
||||
var numThreads: Int = 1,
|
||||
var debug: Boolean = false,
|
||||
var provider: String = "cpu",
|
||||
@@ -176,12 +185,32 @@ fun getOfflineTtsConfig(
|
||||
modelName: String, // for VITS
|
||||
acousticModelName: String, // for Matcha
|
||||
vocoder: String, // for Matcha
|
||||
voices: String, // for Kokoro
|
||||
lexicon: String,
|
||||
dataDir: String,
|
||||
dictDir: String,
|
||||
ruleFsts: String,
|
||||
ruleFars: String
|
||||
ruleFars: String,
|
||||
numThreads: Int? = null
|
||||
): OfflineTtsConfig {
|
||||
// For Matcha TTS, please set
|
||||
// acousticModelName, vocoder
|
||||
|
||||
// For Kokoro TTS, please set
|
||||
// modelName, voices
|
||||
|
||||
// For VITS, please set
|
||||
// modelName
|
||||
|
||||
val numberOfThreads = if (numThreads != null) {
|
||||
numThreads
|
||||
} else if (voices.isNotEmpty()) {
|
||||
// for Kokoro TTS models, we use more threads
|
||||
4
|
||||
} else {
|
||||
2
|
||||
}
|
||||
|
||||
if (modelName.isEmpty() && acousticModelName.isEmpty()) {
|
||||
throw IllegalArgumentException("Please specify a TTS model")
|
||||
}
|
||||
@@ -193,7 +222,8 @@ fun getOfflineTtsConfig(
|
||||
if (acousticModelName.isNotEmpty() && vocoder.isEmpty()) {
|
||||
throw IllegalArgumentException("Please provide vocoder for Matcha TTS")
|
||||
}
|
||||
val vits = if (modelName.isNotEmpty()) {
|
||||
|
||||
val vits = if (modelName.isNotEmpty() && voices.isEmpty()) {
|
||||
OfflineTtsVitsModelConfig(
|
||||
model = "$modelDir/$modelName",
|
||||
lexicon = "$modelDir/$lexicon",
|
||||
@@ -218,11 +248,23 @@ fun getOfflineTtsConfig(
|
||||
OfflineTtsMatchaModelConfig()
|
||||
}
|
||||
|
||||
val kokoro = if (voices.isNotEmpty()) {
|
||||
OfflineTtsKokoroModelConfig(
|
||||
model = "$modelDir/$modelName",
|
||||
voices = "$modelDir/$voices",
|
||||
tokens = "$modelDir/tokens.txt",
|
||||
dataDir = dataDir,
|
||||
)
|
||||
} else {
|
||||
OfflineTtsKokoroModelConfig()
|
||||
}
|
||||
|
||||
return OfflineTtsConfig(
|
||||
model = OfflineTtsModelConfig(
|
||||
vits = vits,
|
||||
matcha = matcha,
|
||||
numThreads = 2,
|
||||
kokoro = kokoro,
|
||||
numThreads = numberOfThreads,
|
||||
debug = true,
|
||||
provider = "cpu",
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user