Add Kotlin and Java API for Kokoro TTS models (#1728)

This commit is contained in:
Fangjun Kuang
2025-01-17 17:36:13 +08:00
committed by GitHub
parent 3a1de0bfc1
commit 99cef4198b
18 changed files with 548 additions and 39 deletions

View File

@@ -35,6 +35,7 @@ java_files += OfflineRecognizerResult.java
java_files += OfflineStream.java
java_files += OfflineRecognizer.java
java_files += OfflineTtsKokoroModelConfig.java
java_files += OfflineTtsMatchaModelConfig.java
java_files += OfflineTtsVitsModelConfig.java
java_files += OfflineTtsModelConfig.java

View File

@@ -0,0 +1,80 @@
// Copyright 2025 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
/**
 * Immutable configuration for a Kokoro text-to-speech model.
 *
 * <p>Create instances with {@link #builder()}. Every string option defaults to the empty
 * string and {@code lengthScale} defaults to {@code 1.0f}.
 *
 * <p>The native (JNI) layer reads the private fields of this class by name ({@code model},
 * {@code voices}, {@code tokens}, {@code dataDir}, {@code lengthScale}), so the field names
 * and types must not be changed.
 *
 * <p>String options are never {@code null}: the builder stores a {@code null} argument as the
 * empty string, because the native reader converts each string field without a null check.
 */
public class OfflineTtsKokoroModelConfig {
    private final String model;
    private final String voices;
    private final String tokens;
    private final String dataDir;
    private final float lengthScale;

    private OfflineTtsKokoroModelConfig(Builder builder) {
        this.model = builder.model;
        this.voices = builder.voices;
        this.tokens = builder.tokens;
        this.dataDir = builder.dataDir;
        this.lengthScale = builder.lengthScale;
    }

    /**
     * Creates a builder pre-populated with the default values.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Maps {@code null} to the empty string so that string fields are always non-null. */
    private static String nullToEmpty(String value) {
        return value == null ? "" : value;
    }

    /** @return the configured model value; never {@code null}, empty when not set */
    public String getModel() {
        return model;
    }

    /** @return the configured voices value; never {@code null}, empty when not set */
    public String getVoices() {
        return voices;
    }

    /** @return the configured tokens value; never {@code null}, empty when not set */
    public String getTokens() {
        return tokens;
    }

    /** @return the configured data directory value; never {@code null}, empty when not set */
    public String getDataDir() {
        return dataDir;
    }

    /** @return the configured length scale; {@code 1.0f} when not set */
    public float getLengthScale() {
        return lengthScale;
    }

    /**
     * Fluent builder for {@link OfflineTtsKokoroModelConfig}.
     *
     * <p>Each setter returns this builder so calls can be chained. String setters treat
     * {@code null} as the empty string.
     */
    public static class Builder {
        private String model = "";
        private String voices = "";
        private String tokens = "";
        private String dataDir = "";
        private float lengthScale = 1.0f;

        /**
         * Builds a configuration from the values currently held by this builder.
         *
         * @return a new immutable configuration
         */
        public OfflineTtsKokoroModelConfig build() {
            return new OfflineTtsKokoroModelConfig(this);
        }

        /**
         * @param model the model value; {@code null} is stored as the empty string
         * @return this builder
         */
        public Builder setModel(String model) {
            this.model = nullToEmpty(model);
            return this;
        }

        /**
         * @param voices the voices value; {@code null} is stored as the empty string
         * @return this builder
         */
        public Builder setVoices(String voices) {
            this.voices = nullToEmpty(voices);
            return this;
        }

        /**
         * @param tokens the tokens value; {@code null} is stored as the empty string
         * @return this builder
         */
        public Builder setTokens(String tokens) {
            this.tokens = nullToEmpty(tokens);
            return this;
        }

        /**
         * @param dataDir the data directory value; {@code null} is stored as the empty string
         * @return this builder
         */
        public Builder setDataDir(String dataDir) {
            this.dataDir = nullToEmpty(dataDir);
            return this;
        }

        /**
         * @param lengthScale the length scale to store (not validated here)
         * @return this builder
         */
        public Builder setLengthScale(float lengthScale) {
            this.lengthScale = lengthScale;
            return this;
        }
    }
}

View File

@@ -5,6 +5,7 @@ package com.k2fsa.sherpa.onnx;
public class OfflineTtsModelConfig {
private final OfflineTtsVitsModelConfig vits;
private final OfflineTtsMatchaModelConfig matcha;
private final OfflineTtsKokoroModelConfig kokoro;
private final int numThreads;
private final boolean debug;
private final String provider;
@@ -12,6 +13,7 @@ public class OfflineTtsModelConfig {
private OfflineTtsModelConfig(Builder builder) {
this.vits = builder.vits;
this.matcha = builder.matcha;
this.kokoro = builder.kokoro;
this.numThreads = builder.numThreads;
this.debug = builder.debug;
this.provider = builder.provider;
@@ -29,9 +31,14 @@ public class OfflineTtsModelConfig {
return matcha;
}
/**
 * Returns the Kokoro model configuration held by this config.
 *
 * @return the value taken from the builder when this config was constructed
 */
public OfflineTtsKokoroModelConfig getKokoro() {
return kokoro;
}
public static class Builder {
private OfflineTtsVitsModelConfig vits = OfflineTtsVitsModelConfig.builder().build();
private OfflineTtsMatchaModelConfig matcha = OfflineTtsMatchaModelConfig.builder().build();
private OfflineTtsKokoroModelConfig kokoro = OfflineTtsKokoroModelConfig.builder().build();
private int numThreads = 1;
private boolean debug = true;
private String provider = "cpu";
@@ -50,6 +57,11 @@ public class OfflineTtsModelConfig {
return this;
}
/**
 * Sets the Kokoro model configuration.
 *
 * <p>NOTE(review): the argument is stored as-is, with no null check. The native reader
 * appears to call GetObjectClass on this field directly, so callers should presumably
 * never pass {@code null} - confirm.
 *
 * @param kokoro the Kokoro model configuration to use
 * @return this builder, so calls can be chained
 */
public Builder setKokoro(OfflineTtsKokoroModelConfig kokoro) {
this.kokoro = kokoro;
return this;
}
public Builder setNumThreads(int numThreads) {
this.numThreads = numThreads;
return this;

View File

@@ -113,6 +113,39 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(matcha_cls, "lengthScale", "F");
ans.model.matcha.length_scale = env->GetFloatField(matcha, fid);
// kokoro
fid = env->GetFieldID(model_config_cls, "kokoro",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsKokoroModelConfig;");
jobject kokoro = env->GetObjectField(model, fid);
jclass kokoro_cls = env->GetObjectClass(kokoro);
fid = env->GetFieldID(kokoro_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(kokoro, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.kokoro.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(kokoro_cls, "voices", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(kokoro, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.kokoro.voices = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(kokoro_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(kokoro, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.kokoro.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(kokoro_cls, "dataDir", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(kokoro, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.kokoro.data_dir = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(kokoro_cls, "lengthScale", "F");
ans.model.kokoro.length_scale = env->GetFloatField(kokoro, fid);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model.num_threads = env->GetIntField(model, fid);
@@ -273,8 +306,8 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
return env->CallIntMethod(should_continue, int_value_mid);
};
auto audio = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->Generate(
p_text, sid, speed, callback_wrapper);
auto tts = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr);
auto audio = tts->Generate(p_text, sid, speed, callback_wrapper);
jfloatArray samples_arr = env->NewFloatArray(audio.samples.size());
env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(),

View File

@@ -25,9 +25,18 @@ data class OfflineTtsMatchaModelConfig(
var lengthScale: Float = 1.0f,
)
/**
 * Configuration for a Kokoro TTS model.
 *
 * Every property has a default, so `OfflineTtsKokoroModelConfig()` yields an "unset"
 * configuration with empty strings.
 *
 * NOTE(review): the property names presumably have to match what the native layer looks up
 * by name - confirm before renaming any of them.
 *
 * @property model Model file path; empty by default.
 * @property voices Voices file path; empty by default.
 * @property tokens Tokens file path; empty by default.
 * @property dataDir Data directory; empty by default.
 * @property lengthScale Defaults to 1.0f. Presumably scales the duration of the generated
 *   speech - TODO confirm.
 */
data class OfflineTtsKokoroModelConfig(
var model: String = "",
var voices: String = "",
var tokens: String = "",
var dataDir: String = "",
var lengthScale: Float = 1.0f,
)
data class OfflineTtsModelConfig(
var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(),
var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(),
var kokoro: OfflineTtsKokoroModelConfig = OfflineTtsKokoroModelConfig(),
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
@@ -176,12 +185,32 @@ fun getOfflineTtsConfig(
modelName: String, // for VITS
acousticModelName: String, // for Matcha
vocoder: String, // for Matcha
voices: String, // for Kokoro
lexicon: String,
dataDir: String,
dictDir: String,
ruleFsts: String,
ruleFars: String
ruleFars: String,
numThreads: Int? = null
): OfflineTtsConfig {
// For Matcha TTS, please set
// acousticModelName, vocoder
// For Kokoro TTS, please set
// modelName, voices
// For VITS, please set
// modelName
val numberOfThreads = if (numThreads != null) {
numThreads
} else if (voices.isNotEmpty()) {
// for Kokoro TTS models, we use more threads
4
} else {
2
}
if (modelName.isEmpty() && acousticModelName.isEmpty()) {
throw IllegalArgumentException("Please specify a TTS model")
}
@@ -193,7 +222,8 @@ fun getOfflineTtsConfig(
if (acousticModelName.isNotEmpty() && vocoder.isEmpty()) {
throw IllegalArgumentException("Please provide vocoder for Matcha TTS")
}
val vits = if (modelName.isNotEmpty()) {
val vits = if (modelName.isNotEmpty() && voices.isEmpty()) {
OfflineTtsVitsModelConfig(
model = "$modelDir/$modelName",
lexicon = "$modelDir/$lexicon",
@@ -218,11 +248,23 @@ fun getOfflineTtsConfig(
OfflineTtsMatchaModelConfig()
}
val kokoro = if (voices.isNotEmpty()) {
OfflineTtsKokoroModelConfig(
model = "$modelDir/$modelName",
voices = "$modelDir/$voices",
tokens = "$modelDir/tokens.txt",
dataDir = dataDir,
)
} else {
OfflineTtsKokoroModelConfig()
}
return OfflineTtsConfig(
model = OfflineTtsModelConfig(
vits = vits,
matcha = matcha,
numThreads = 2,
kokoro = kokoro,
numThreads = numberOfThreads,
debug = true,
provider = "cpu",
),