Add Kotlin API for Matcha-TTS models. (#1668)
This commit is contained in:
5
.github/workflows/jni.yaml
vendored
5
.github/workflows/jni.yaml
vendored
@@ -75,3 +75,8 @@ jobs:
|
|||||||
|
|
||||||
cd ./kotlin-api-examples
|
cd ./kotlin-api-examples
|
||||||
./run.sh
|
./run.sh
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: tts-files-${{ matrix.os }}
|
||||||
|
path: kotlin-api-examples/test-*.wav
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -125,3 +125,4 @@ sherpa-onnx-moonshine-tiny-en-int8
|
|||||||
sherpa-onnx-moonshine-base-en-int8
|
sherpa-onnx-moonshine-base-en-int8
|
||||||
harmony-os/SherpaOnnxHar/sherpa_onnx/LICENSE
|
harmony-os/SherpaOnnxHar/sherpa_onnx/LICENSE
|
||||||
harmony-os/SherpaOnnxHar/sherpa_onnx/CHANGELOG.md
|
harmony-os/SherpaOnnxHar/sherpa_onnx/CHANGELOG.md
|
||||||
|
matcha-icefall-zh-baker
|
||||||
|
|||||||
@@ -105,6 +105,16 @@ function testTts() {
|
|||||||
rm vits-piper-en_US-amy-low.tar.bz2
|
rm vits-piper-en_US-amy-low.tar.bz2
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [ ! -f ./matcha-icefall-zh-baker/model-steps-3.onnx ]; then
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2
|
||||||
|
tar xvf matcha-icefall-zh-baker.tar.bz2
|
||||||
|
rm matcha-icefall-zh-baker.tar.bz2
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f ./hifigan_v2.onnx ]; then
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx
|
||||||
|
fi
|
||||||
|
|
||||||
out_filename=test_tts.jar
|
out_filename=test_tts.jar
|
||||||
kotlinc-jvm -include-runtime -d $out_filename \
|
kotlinc-jvm -include-runtime -d $out_filename \
|
||||||
test_tts.kt \
|
test_tts.kt \
|
||||||
|
|||||||
@@ -1,10 +1,35 @@
|
|||||||
package com.k2fsa.sherpa.onnx
|
package com.k2fsa.sherpa.onnx
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
testTts()
|
testVits()
|
||||||
|
testMatcha()
|
||||||
}
|
}
|
||||||
|
|
||||||
fun testTts() {
|
/**
 * Runs the Matcha-TTS example: builds an [OfflineTtsConfig] that points at the
 * matcha-icefall-zh-baker model files and the hifigan_v2 vocoder, synthesizes
 * one Chinese test sentence, and writes the result to `test-zh.wav`.
 *
 * All paths are relative to the current working directory, so the model
 * archive and the vocoder must already be downloaded/extracted there.
 *
 * The TTS object is released in a `finally` block so that `release()` runs
 * even when generation or saving throws.
 */
fun testMatcha() {
  // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
  // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2
  val config = OfflineTtsConfig(
    model=OfflineTtsModelConfig(
      matcha=OfflineTtsMatchaModelConfig(
        acousticModel="./matcha-icefall-zh-baker/model-steps-3.onnx",
        vocoder="./hifigan_v2.onnx",
        tokens="./matcha-icefall-zh-baker/tokens.txt",
        lexicon="./matcha-icefall-zh-baker/lexicon.txt",
        dictDir="./matcha-icefall-zh-baker/dict",
      ),
      numThreads=1,
      debug=true,
    ),
    // Comma-separated list of rule FST files passed through as a single string.
    ruleFsts="./matcha-icefall-zh-baker/phone.fst,./matcha-icefall-zh-baker/date.fst,./matcha-icefall-zh-baker/number.fst",
  )
  val tts = OfflineTts(config=config)
  try {
    val audio = tts.generateWithCallback(text="某某银行的副行长和一些行政领导表示,他们去过长江和长白山; 经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。", callback=::callback)
    audio.save(filename="test-zh.wav")
  } finally {
    // Always release the TTS object, including on failure paths.
    tts.release()
  }
  println("Saved to test-zh.wav")
}
|
||||||
|
|
||||||
|
fun testVits() {
|
||||||
// see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
|
// see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
|
||||||
// https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
|
// https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
|
||||||
var config = OfflineTtsConfig(
|
var config = OfflineTtsConfig(
|
||||||
|
|||||||
@@ -1727,11 +1727,15 @@ const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
|
|||||||
auto p = new SherpaOnnxOnlinePunctuation;
|
auto p = new SherpaOnnxOnlinePunctuation;
|
||||||
try {
|
try {
|
||||||
sherpa_onnx::OnlinePunctuationConfig punctuation_config;
|
sherpa_onnx::OnlinePunctuationConfig punctuation_config;
|
||||||
punctuation_config.model.cnn_bilstm = SHERPA_ONNX_OR(config->model.cnn_bilstm, "");
|
punctuation_config.model.cnn_bilstm =
|
||||||
punctuation_config.model.bpe_vocab = SHERPA_ONNX_OR(config->model.bpe_vocab, "");
|
SHERPA_ONNX_OR(config->model.cnn_bilstm, "");
|
||||||
punctuation_config.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1);
|
punctuation_config.model.bpe_vocab =
|
||||||
|
SHERPA_ONNX_OR(config->model.bpe_vocab, "");
|
||||||
|
punctuation_config.model.num_threads =
|
||||||
|
SHERPA_ONNX_OR(config->model.num_threads, 1);
|
||||||
punctuation_config.model.debug = config->model.debug;
|
punctuation_config.model.debug = config->model.debug;
|
||||||
punctuation_config.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu");
|
punctuation_config.model.provider =
|
||||||
|
SHERPA_ONNX_OR(config->model.provider, "cpu");
|
||||||
|
|
||||||
p->impl =
|
p->impl =
|
||||||
std::make_unique<sherpa_onnx::OnlinePunctuation>(punctuation_config);
|
std::make_unique<sherpa_onnx::OnlinePunctuation>(punctuation_config);
|
||||||
|
|||||||
@@ -1381,12 +1381,14 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationConfig {
|
|||||||
SherpaOnnxOnlinePunctuationModelConfig model;
|
SherpaOnnxOnlinePunctuationModelConfig model;
|
||||||
} SherpaOnnxOnlinePunctuationConfig;
|
} SherpaOnnxOnlinePunctuationConfig;
|
||||||
|
|
||||||
SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation SherpaOnnxOnlinePunctuation;
|
SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation
|
||||||
|
SherpaOnnxOnlinePunctuation;
|
||||||
|
|
||||||
// Create an online punctuation processor. The user has to invoke
|
// Create an online punctuation processor. The user has to invoke
|
||||||
// SherpaOnnxDestroyOnlinePunctuation() to free the returned pointer
|
// SherpaOnnxDestroyOnlinePunctuation() to free the returned pointer
|
||||||
// to avoid memory leak
|
// to avoid memory leak
|
||||||
SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
|
SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *
|
||||||
|
SherpaOnnxCreateOnlinePunctuation(
|
||||||
const SherpaOnnxOnlinePunctuationConfig *config);
|
const SherpaOnnxOnlinePunctuationConfig *config);
|
||||||
|
|
||||||
// Free a pointer returned by SherpaOnnxCreateOnlinePunctuation()
|
// Free a pointer returned by SherpaOnnxCreateOnlinePunctuation()
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ class JiebaLexicon::Impl {
|
|||||||
|
|
||||||
this_sentence.insert(this_sentence.end(), ids.begin(), ids.end());
|
this_sentence.insert(this_sentence.end(), ids.begin(), ids.end());
|
||||||
|
|
||||||
if (w == "。" || w == "!" || w == "?" || w == ",") {
|
if (IsPunct(w)) {
|
||||||
ans.emplace_back(std::move(this_sentence));
|
ans.emplace_back(std::move(this_sentence));
|
||||||
this_sentence = {};
|
this_sentence = {};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
|
|||||||
jobject model = env->GetObjectField(config, fid);
|
jobject model = env->GetObjectField(config, fid);
|
||||||
jclass model_config_cls = env->GetObjectClass(model);
|
jclass model_config_cls = env->GetObjectClass(model);
|
||||||
|
|
||||||
|
// vits
|
||||||
fid = env->GetFieldID(model_config_cls, "vits",
|
fid = env->GetFieldID(model_config_cls, "vits",
|
||||||
"Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;");
|
"Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;");
|
||||||
jobject vits = env->GetObjectField(model, fid);
|
jobject vits = env->GetObjectField(model, fid);
|
||||||
@@ -64,6 +65,54 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
|
|||||||
fid = env->GetFieldID(vits_cls, "lengthScale", "F");
|
fid = env->GetFieldID(vits_cls, "lengthScale", "F");
|
||||||
ans.model.vits.length_scale = env->GetFloatField(vits, fid);
|
ans.model.vits.length_scale = env->GetFloatField(vits, fid);
|
||||||
|
|
||||||
|
// matcha
|
||||||
|
fid = env->GetFieldID(model_config_cls, "matcha",
|
||||||
|
"Lcom/k2fsa/sherpa/onnx/OfflineTtsMatchaModelConfig;");
|
||||||
|
jobject matcha = env->GetObjectField(model, fid);
|
||||||
|
jclass matcha_cls = env->GetObjectClass(matcha);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(matcha_cls, "acousticModel", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(matcha, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model.matcha.acoustic_model = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(matcha_cls, "vocoder", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(matcha, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model.matcha.vocoder = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(matcha_cls, "lexicon", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(matcha, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model.matcha.lexicon = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(matcha_cls, "tokens", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(matcha, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model.matcha.tokens = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(matcha_cls, "dataDir", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(matcha, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model.matcha.data_dir = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(matcha_cls, "dictDir", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(matcha, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model.matcha.dict_dir = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(matcha_cls, "noiseScale", "F");
|
||||||
|
ans.model.matcha.noise_scale = env->GetFloatField(matcha, fid);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(matcha_cls, "lengthScale", "F");
|
||||||
|
ans.model.matcha.length_scale = env->GetFloatField(matcha, fid);
|
||||||
|
|
||||||
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
|
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
|
||||||
ans.model.num_threads = env->GetIntField(model, fid);
|
ans.model.num_threads = env->GetIntField(model, fid);
|
||||||
|
|
||||||
|
|||||||
@@ -14,8 +14,20 @@ data class OfflineTtsVitsModelConfig(
|
|||||||
var lengthScale: Float = 1.0f,
|
var lengthScale: Float = 1.0f,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/**
 * Settings for a Matcha-TTS model.
 *
 * The native (JNI) side looks these properties up by name and type
 * (`String` for the paths, `Float` for the scales), so property names and
 * types must stay in sync with it.
 *
 * @property acousticModel Path to the acoustic model file (e.g. an `.onnx` file).
 * @property vocoder Path to the vocoder model file.
 * @property lexicon Path to the lexicon file.
 * @property tokens Path to the tokens file.
 * @property dataDir Directory path forwarded to the native `data_dir` setting;
 *   empty by default. NOTE(review): exact contents expected there are not
 *   visible here — confirm against the native config.
 * @property dictDir Directory path forwarded to the native `dict_dir` setting.
 * @property noiseScale Forwarded to the native `noise_scale` setting; defaults to 1.0.
 * @property lengthScale Forwarded to the native `length_scale` setting; defaults to 1.0.
 */
data class OfflineTtsMatchaModelConfig(
  var acousticModel: String = "",
  var vocoder: String = "",
  var lexicon: String = "",
  var tokens: String = "",
  var dataDir: String = "",
  var dictDir: String = "",
  var noiseScale: Float = 1.0f,
  var lengthScale: Float = 1.0f,
)
|
||||||
|
|
||||||
data class OfflineTtsModelConfig(
|
data class OfflineTtsModelConfig(
|
||||||
var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(),
|
var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(),
|
||||||
|
var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(),
|
||||||
var numThreads: Int = 1,
|
var numThreads: Int = 1,
|
||||||
var debug: Boolean = false,
|
var debug: Boolean = false,
|
||||||
var provider: String = "cpu",
|
var provider: String = "cpu",
|
||||||
|
|||||||
Reference in New Issue
Block a user