Add Kotlin API for Matcha-TTS models. (#1668)

This commit is contained in:
Fangjun Kuang
2024-12-31 19:20:52 +08:00
committed by GitHub
parent 0a43e9c879
commit 3422b9388d
9 changed files with 117 additions and 9 deletions

View File

@@ -1727,11 +1727,15 @@ const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
auto p = new SherpaOnnxOnlinePunctuation;
try {
sherpa_onnx::OnlinePunctuationConfig punctuation_config;
punctuation_config.model.cnn_bilstm = SHERPA_ONNX_OR(config->model.cnn_bilstm, "");
punctuation_config.model.bpe_vocab = SHERPA_ONNX_OR(config->model.bpe_vocab, "");
punctuation_config.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1);
punctuation_config.model.cnn_bilstm =
SHERPA_ONNX_OR(config->model.cnn_bilstm, "");
punctuation_config.model.bpe_vocab =
SHERPA_ONNX_OR(config->model.bpe_vocab, "");
punctuation_config.model.num_threads =
SHERPA_ONNX_OR(config->model.num_threads, 1);
punctuation_config.model.debug = config->model.debug;
punctuation_config.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu");
punctuation_config.model.provider =
SHERPA_ONNX_OR(config->model.provider, "cpu");
p->impl =
std::make_unique<sherpa_onnx::OnlinePunctuation>(punctuation_config);

View File

@@ -1381,12 +1381,14 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationConfig {
SherpaOnnxOnlinePunctuationModelConfig model;
} SherpaOnnxOnlinePunctuationConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation SherpaOnnxOnlinePunctuation;
SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation
SherpaOnnxOnlinePunctuation;
// Create an online punctuation processor. The user has to invoke
// SherpaOnnxDestroyOnlinePunctuation() to free the returned pointer
// to avoid memory leak
SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *
SherpaOnnxCreateOnlinePunctuation(
const SherpaOnnxOnlinePunctuationConfig *config);
// Free a pointer returned by SherpaOnnxCreateOnlinePunctuation()

View File

@@ -155,7 +155,7 @@ class JiebaLexicon::Impl {
this_sentence.insert(this_sentence.end(), ids.begin(), ids.end());
if (w == "" || w == "" || w == "" || w == "") {
if (IsPunct(w)) {
ans.emplace_back(std::move(this_sentence));
this_sentence = {};
}

View File

@@ -20,6 +20,7 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
jobject model = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model);
// vits
fid = env->GetFieldID(model_config_cls, "vits",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;");
jobject vits = env->GetObjectField(model, fid);
@@ -64,6 +65,54 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(vits_cls, "lengthScale", "F");
ans.model.vits.length_scale = env->GetFloatField(vits, fid);
// matcha
fid = env->GetFieldID(model_config_cls, "matcha",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsMatchaModelConfig;");
jobject matcha = env->GetObjectField(model, fid);
jclass matcha_cls = env->GetObjectClass(matcha);
fid = env->GetFieldID(matcha_cls, "acousticModel", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.acoustic_model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(matcha_cls, "vocoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.vocoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(matcha_cls, "lexicon", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.lexicon = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(matcha_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(matcha_cls, "dataDir", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.data_dir = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(matcha_cls, "dictDir", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.dict_dir = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(matcha_cls, "noiseScale", "F");
ans.model.matcha.noise_scale = env->GetFloatField(matcha, fid);
fid = env->GetFieldID(matcha_cls, "lengthScale", "F");
ans.model.matcha.length_scale = env->GetFloatField(matcha, fid);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model.num_threads = env->GetIntField(model, fid);

View File

@@ -14,8 +14,20 @@ data class OfflineTtsVitsModelConfig(
var lengthScale: Float = 1.0f,
)
/**
 * Settings for a Matcha-TTS model. An instance is held by the `matcha`
 * property of [OfflineTtsModelConfig].
 *
 * The JNI function `GetOfflineTtsConfig` looks up every property below by its
 * exact name and JVM type (`Ljava/lang/String;` or `F`) and copies the value
 * into the native `model.matcha` config. Renaming or retyping a property
 * therefore requires the matching change in the native lookup.
 *
 * @property acousticModel Copied to native `acoustic_model`; presumably the
 *   path to the acoustic model file (TODO confirm). Defaults to "".
 * @property vocoder Copied to native `vocoder`; presumably the path to the
 *   vocoder model file (TODO confirm). Defaults to "".
 * @property lexicon Copied to native `lexicon`. Defaults to "".
 * @property tokens Copied to native `tokens`. Defaults to "".
 * @property dataDir Copied to native `data_dir`. Defaults to "".
 * @property dictDir Copied to native `dict_dir`. Defaults to "".
 * @property noiseScale Copied to native `noise_scale`. Defaults to 1.0.
 * @property lengthScale Copied to native `length_scale`. Defaults to 1.0.
 */
data class OfflineTtsMatchaModelConfig(
var acousticModel: String = "",
var vocoder: String = "",
var lexicon: String = "",
var tokens: String = "",
var dataDir: String = "",
var dictDir: String = "",
var noiseScale: Float = 1.0f,
var lengthScale: Float = 1.0f,
)
data class OfflineTtsModelConfig(
var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(),
var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(),
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",