support reading rule FST for Android TTS (#410)

This commit is contained in:
Fangjun Kuang
2023-11-06 10:38:40 +08:00
committed by GitHub
parent 723e5265bb
commit 86baf43c6b
10 changed files with 143 additions and 25 deletions

View File

@@ -34,6 +34,11 @@ jobs:
with: with:
fetch-depth: 0 fetch-depth: 0
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ${{ matrix.os }}-android
- name: Display NDK HOME - name: Display NDK HOME
shell: bash shell: bash
run: | run: |
@@ -61,6 +66,10 @@ jobs:
- name: build APK - name: build APK
shell: bash shell: bash
run: | run: |
export CMAKE_CXX_COMPILER_LAUNCHER=ccache
export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
cmake --version
export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME
./build-apk-tts.sh ./build-apk-tts.sh
@@ -70,12 +79,14 @@ jobs:
ls -lh ./apks/ ls -lh ./apks/
du -h -d1 . du -h -d1 .
# - uses: actions/upload-artifact@v3 - uses: actions/upload-artifact@v3
# with: if: false
# name: tts-apk with:
# path: ./apks/*.apk name: tts-apk
path: ./apks/*.apk
- name: Publish to huggingface - name: Publish to huggingface
if: true
env: env:
HF_TOKEN: ${{ secrets.HF_TOKEN }} HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v2 uses: nick-fields/retry@v2
@@ -92,7 +103,9 @@ jobs:
git clone https://huggingface.co/csukuangfj/sherpa-onnx-apk huggingface git clone https://huggingface.co/csukuangfj/sherpa-onnx-apk huggingface
cd huggingface cd huggingface
git fetch
git pull git pull
git merge -m "merge remote" --ff origin main
mkdir -p tts mkdir -p tts
cp -v ../apks/*.apk ./tts/ cp -v ../apks/*.apk ./tts/

View File

@@ -28,6 +28,12 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ${{ matrix.os }}-android
- name: Display NDK HOME - name: Display NDK HOME
shell: bash shell: bash
run: | run: |
@@ -37,6 +43,10 @@ jobs:
- name: build APK - name: build APK
shell: bash shell: bash
run: | run: |
export CMAKE_CXX_COMPILER_LAUNCHER=ccache
export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
cmake --version
export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME
./build-apk-vad.sh ./build-apk-vad.sh
./build-apk-two-pass.sh ./build-apk-two-pass.sh

View File

@@ -101,12 +101,14 @@ class MainActivity : AppCompatActivity() {
fun initTts() { fun initTts() {
var modelDir :String? var modelDir :String?
var modelName :String? var modelName :String?
var ruleFsts: String?
// The purpose of such a design is to make the CI test easier // The purpose of such a design is to make the CI test easier
// Please see // Please see
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/apk/generate-tts-apk-script.py // https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/apk/generate-tts-apk-script.py
modelDir = null modelDir = null
modelName = null modelName = null
ruleFsts = null
// Example 1: // Example 1:
// modelDir = "vits-vctk" // modelDir = "vits-vctk"
@@ -116,7 +118,12 @@ class MainActivity : AppCompatActivity() {
// modelDir = "vits-piper-en_US-lessac-medium" // modelDir = "vits-piper-en_US-lessac-medium"
// modelName = "en_US-lessac-medium.onnx" // modelName = "en_US-lessac-medium.onnx"
val config = getOfflineTtsConfig(modelDir = modelDir!!, modelName = modelName!!)!! // Example 3:
// modelDir = "vits-zh-aishell3"
// modelName = "vits-aishell3.onnx"
// ruleFsts = "vits-zh-aishell3/rule.fst"
val config = getOfflineTtsConfig(modelDir = modelDir!!, modelName = modelName!!, ruleFsts = ruleFsts ?: "")!!
tts = OfflineTts(assetManager = application.assets, config = config) tts = OfflineTts(assetManager = application.assets, config = config)
} }
} }

View File

@@ -21,6 +21,7 @@ data class OfflineTtsModelConfig(
data class OfflineTtsConfig( data class OfflineTtsConfig(
var model: OfflineTtsModelConfig, var model: OfflineTtsModelConfig,
var ruleFsts: String = "",
) )
class GeneratedAudio( class GeneratedAudio(
@@ -116,7 +117,7 @@ class OfflineTts(
// please refer to // please refer to
// https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/index.html // https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/index.html
// to download models // to download models
fun getOfflineTtsConfig(modelDir: String, modelName: String): OfflineTtsConfig? { fun getOfflineTtsConfig(modelDir: String, modelName: String, ruleFsts: String): OfflineTtsConfig? {
return OfflineTtsConfig( return OfflineTtsConfig(
model = OfflineTtsModelConfig( model = OfflineTtsModelConfig(
vits = OfflineTtsVitsModelConfig( vits = OfflineTtsVitsModelConfig(
@@ -125,8 +126,9 @@ fun getOfflineTtsConfig(modelDir: String, modelName: String): OfflineTtsConfig?
tokens = "$modelDir/tokens.txt" tokens = "$modelDir/tokens.txt"
), ),
numThreads = 2, numThreads = 2,
debug = false, debug = true,
provider = "cpu", provider = "cpu",
) ),
ruleFsts=ruleFsts,
) )
} }

View File

@@ -1,18 +1,18 @@
function(download_kaldifst) function(download_kaldifst)
include(FetchContent) include(FetchContent)
set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.8.tar.gz") set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.9.tar.gz")
set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.8.tar.gz") set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.9.tar.gz")
set(kaldifst_HASH "SHA256=94613923568ef9a240ba1059b8b9dfe3082daad794934635d99e66248a6687b5") set(kaldifst_HASH "SHA256=8c653021491dca54c38ab659565edfab391418a79ae87099257863cd5664dd39")
# If you don't have access to the Internet, # If you don't have access to the Internet,
# please pre-download kaldifst # please pre-download kaldifst
set(possible_file_locations set(possible_file_locations
$ENV{HOME}/Downloads/kaldifst-1.7.8.tar.gz $ENV{HOME}/Downloads/kaldifst-1.7.9.tar.gz
${PROJECT_SOURCE_DIR}/kaldifst-1.7.8.tar.gz ${PROJECT_SOURCE_DIR}/kaldifst-1.7.9.tar.gz
${PROJECT_BINARY_DIR}/kaldifst-1.7.8.tar.gz ${PROJECT_BINARY_DIR}/kaldifst-1.7.9.tar.gz
/tmp/kaldifst-1.7.8.tar.gz /tmp/kaldifst-1.7.9.tar.gz
/star-fj/fangjun/download/github/kaldifst-1.7.8.tar.gz /star-fj/fangjun/download/github/kaldifst-1.7.9.tar.gz
) )
foreach(f IN LISTS possible_file_locations) foreach(f IN LISTS possible_file_locations)

View File

@@ -8,7 +8,7 @@
# Inside the $ANDROID_NDK directory, you can find a binary ndk-build # Inside the $ANDROID_NDK directory, you can find a binary ndk-build
# and some other files like the file "build/cmake/android.toolchain.cmake" # and some other files like the file "build/cmake/android.toolchain.cmake"
set -e set -ex
log() { log() {
# This function is from espnet # This function is from espnet
@@ -43,6 +43,7 @@ wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/$model_name
wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/lexicon.txt wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/lexicon.txt
wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/tokens.txt wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/tokens.txt
wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/MODEL_CARD 2>/dev/null || true wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/MODEL_CARD 2>/dev/null || true
wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/rule.fst 2>/dev/null || true
popd popd
# Now we are at the project root directory # Now we are at the project root directory
@@ -51,6 +52,11 @@ git checkout .
pushd android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx pushd android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx
sed -i.bak s/"modelDir = null"/"modelDir = \"$model_dir\""/ ./MainActivity.kt sed -i.bak s/"modelDir = null"/"modelDir = \"$model_dir\""/ ./MainActivity.kt
sed -i.bak s/"modelName = null"/"modelName = \"$model_name\""/ ./MainActivity.kt sed -i.bak s/"modelName = null"/"modelName = \"$model_name\""/ ./MainActivity.kt
{% if tts_model.rule_fsts %}
rule_fsts={{ tts_model.rule_fsts }}
sed -i.bak s%"ruleFsts = null"%"ruleFsts = \"$rule_fsts\""% ./MainActivity.kt
{% endif %}
git diff git diff
popd popd

View File

@@ -1,10 +1,10 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional
import jinja2 import jinja2
from typing import List
import argparse
def get_args(): def get_args():
@@ -29,12 +29,65 @@ class TtsModel:
model_dir: str model_dir: str
model_name: str model_name: str
lang: str # en, zh, fr, de, etc. lang: str # en, zh, fr, de, etc.
rule_fsts: Optional[List[str]] = (None,)
def get_all_models() -> List[TtsModel]: def get_all_models() -> List[TtsModel]:
return [ return [
# Chinese
TtsModel( TtsModel(
model_dir="vits-zh-aishell3", model_name="vits-aishell3.onnx", lang="zh" model_dir="vits-zh-aishell3",
model_name="vits-aishell3.onnx",
lang="zh",
rule_fsts="vits-zh-aishell3/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-doom",
model_name="doom.onnx",
lang="zh",
rule_fsts="vits-zh-hf-doom/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-echo",
model_name="echo.onnx",
lang="zh",
rule_fsts="vits-zh-hf-echo/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-zenyatta",
model_name="zenyatta.onnx",
lang="zh",
rule_fsts="vits-zh-hf-zenyatta/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-abyssinvoker",
model_name="abyssinvoker.onnx",
lang="zh",
rule_fsts="vits-zh-hf-abyssinvoker/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-keqing",
model_name="keqing.onnx",
lang="zh",
rule_fsts="vits-zh-hf-keqing/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-eula",
model_name="eula.onnx",
lang="zh",
rule_fsts="vits-zh-hf-eula/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-bronya",
model_name="bronya.onnx",
lang="zh",
rule_fsts="vits-zh-hf-bronya/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-theresa",
model_name="theresa.onnx",
lang="zh",
rule_fsts="vits-zh-hf-theresa/rule.fst",
), ),
# English (US) # English (US)
# fmt: off # fmt: off

View File

@@ -196,8 +196,14 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
std::vector<int64_t> ans; std::vector<int64_t> ans;
auto sil = token2id_.at("sil"); int32_t sil = -1;
auto eos = token2id_.at("eos"); int32_t eos = -1;
if (token2id_.count("sil")) {
sil = token2id_.at("sil");
eos = token2id_.at("eos");
} else {
sil = 0;
}
ans.push_back(sil); ans.push_back(sil);
@@ -216,7 +222,9 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
ans.insert(ans.end(), token_ids.begin(), token_ids.end()); ans.insert(ans.end(), token_ids.begin(), token_ids.end());
} }
ans.push_back(sil); ans.push_back(sil);
ans.push_back(eos); if (eos != -1) {
ans.push_back(eos);
}
return ans; return ans;
} }

View File

@@ -10,15 +10,17 @@
#include <vector> #include <vector>
#if __ANDROID_API__ >= 9 #if __ANDROID_API__ >= 9
#include <strstream>
#include "android/asset_manager.h" #include "android/asset_manager.h"
#include "android/asset_manager_jni.h" #include "android/asset_manager_jni.h"
#endif #endif
#include "kaldifst/csrc/text-normalizer.h" #include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/lexicon.h" #include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-tts-impl.h" #include "sherpa-onnx/csrc/offline-tts-impl.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model.h" #include "sherpa-onnx/csrc/offline-tts-vits-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h" #include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -52,7 +54,17 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
model_->Punctuations(), model_->Language(), config.model.debug, model_->Punctuations(), model_->Language(), config.model.debug,
model_->IsPiper()) { model_->IsPiper()) {
if (!config.rule_fsts.empty()) { if (!config.rule_fsts.empty()) {
SHERPA_ONNX_LOGE("TODO(fangjun): Implement rule FST for Android"); std::vector<std::string> files;
SplitStringToVector(config.rule_fsts, ",", false, &files);
tn_list_.reserve(files.size());
for (const auto &f : files) {
if (config.model.debug) {
SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
}
auto buf = ReadFile(mgr, f);
std::istrstream is(buf.data(), buf.size());
tn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(is));
}
} }
} }
#endif #endif

View File

@@ -566,6 +566,13 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
ans.model.provider = p; ans.model.provider = p;
env->ReleaseStringUTFChars(s, p); env->ReleaseStringUTFChars(s, p);
// for ruleFsts
fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.rule_fsts = p;
env->ReleaseStringUTFChars(s, p);
return ans; return ans;
} }