Add C++ runtime for MeloTTS (#1138)
This commit is contained in:
10
.github/workflows/export-melo-tts-to-onnx.yaml
vendored
10
.github/workflows/export-melo-tts-to-onnx.yaml
vendored
@@ -63,10 +63,16 @@ jobs:
|
||||
echo "pwd: $PWD"
|
||||
ls -lh ../scripts/melo-tts
|
||||
|
||||
rm -rf ./
|
||||
|
||||
cp -v ../scripts/melo-tts/*.onnx .
|
||||
cp -v ../scripts/melo-tts/lexicon.txt .
|
||||
cp -v ../scripts/melo-tts/tokens.txt .
|
||||
cp -v ../scripts/melo-tts/README.md .
|
||||
|
||||
curl -SL -O https://raw.githubusercontent.com/myshell-ai/MeloTTS/main/LICENSE
|
||||
|
||||
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/new_heteronym.fst
|
||||
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst
|
||||
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst
|
||||
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst
|
||||
@@ -77,6 +83,10 @@ jobs:
|
||||
git lfs track "*.onnx"
|
||||
git add .
|
||||
|
||||
ls -lh
|
||||
|
||||
git status
|
||||
|
||||
git commit -m "add models"
|
||||
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/vits-melo-tts-zh_en main || true
|
||||
|
||||
|
||||
10
.github/workflows/windows-x64-jni.yaml
vendored
10
.github/workflows/windows-x64-jni.yaml
vendored
@@ -39,10 +39,14 @@ jobs:
|
||||
cd build
|
||||
cmake \
|
||||
-A x64 \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D BUILD_SHARED_LIBS=ON \
|
||||
-DBUILD_SHARED_LIBS=ON \
|
||||
-D SHERPA_ONNX_ENABLE_JNI=ON \
|
||||
-D CMAKE_INSTALL_PREFIX=./install \
|
||||
-DCMAKE_INSTALL_PREFIX=./install \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \
|
||||
-DBUILD_ESPEAK_NG_EXE=OFF \
|
||||
-DSHERPA_ONNX_BUILD_C_API_EXAMPLES=OFF \
|
||||
-DSHERPA_ONNX_ENABLE_BINARY=ON \
|
||||
..
|
||||
|
||||
- name: Build sherpa-onnx for windows
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
## 1.10.16
|
||||
|
||||
* Support zh-en TTS model from MeloTTS.
|
||||
|
||||
## 1.10.15
|
||||
|
||||
* Downgrade onnxruntime from v1.18.1 to v1.17.1
|
||||
|
||||
@@ -11,7 +11,7 @@ project(sherpa-onnx)
|
||||
# ./nodejs-addon-examples
|
||||
# ./dart-api-examples/
|
||||
# ./CHANGELOG.md
|
||||
set(SHERPA_ONNX_VERSION "1.10.15")
|
||||
set(SHERPA_ONNX_VERSION "1.10.16")
|
||||
|
||||
# Disable warning about
|
||||
#
|
||||
|
||||
@@ -10,7 +10,7 @@ environment:
|
||||
|
||||
# Add regular dependencies here.
|
||||
dependencies:
|
||||
sherpa_onnx: ^1.10.15
|
||||
sherpa_onnx: ^1.10.16
|
||||
path: ^1.9.0
|
||||
args: ^2.5.0
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ environment:
|
||||
|
||||
# Add regular dependencies here.
|
||||
dependencies:
|
||||
sherpa_onnx: ^1.10.15
|
||||
sherpa_onnx: ^1.10.16
|
||||
path: ^1.9.0
|
||||
args: ^2.5.0
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ environment:
|
||||
|
||||
# Add regular dependencies here.
|
||||
dependencies:
|
||||
sherpa_onnx: ^1.10.15
|
||||
sherpa_onnx: ^1.10.16
|
||||
path: ^1.9.0
|
||||
args: ^2.5.0
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ environment:
|
||||
sdk: ^3.4.0
|
||||
|
||||
dependencies:
|
||||
sherpa_onnx: ^1.10.15
|
||||
sherpa_onnx: ^1.10.16
|
||||
path: ^1.9.0
|
||||
args: ^2.5.0
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ description: >
|
||||
|
||||
publish_to: 'none'
|
||||
|
||||
version: 1.10.14
|
||||
version: 1.10.16
|
||||
|
||||
topics:
|
||||
- speech-recognition
|
||||
@@ -30,7 +30,7 @@ dependencies:
|
||||
record: ^5.1.0
|
||||
url_launcher: ^6.2.6
|
||||
|
||||
sherpa_onnx: ^1.10.15
|
||||
sherpa_onnx: ^1.10.16
|
||||
# sherpa_onnx:
|
||||
# path: ../../flutter/sherpa_onnx
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ description: >
|
||||
|
||||
publish_to: 'none' # Remove this line if you wish to publish to pub.dev
|
||||
|
||||
version: 1.0.0
|
||||
version: 1.10.16
|
||||
|
||||
environment:
|
||||
sdk: '>=3.4.0 <4.0.0'
|
||||
@@ -17,7 +17,7 @@ dependencies:
|
||||
cupertino_icons: ^1.0.6
|
||||
path_provider: ^2.1.3
|
||||
path: ^1.9.0
|
||||
sherpa_onnx: ^1.10.15
|
||||
sherpa_onnx: ^1.10.16
|
||||
url_launcher: ^6.2.6
|
||||
audioplayers: ^5.0.0
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ topics:
|
||||
- voice-activity-detection
|
||||
|
||||
# remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec
|
||||
version: 1.10.15
|
||||
version: 1.10.16
|
||||
|
||||
homepage: https://github.com/k2-fsa/sherpa-onnx
|
||||
|
||||
@@ -30,19 +30,19 @@ dependencies:
|
||||
flutter:
|
||||
sdk: flutter
|
||||
|
||||
sherpa_onnx_android: ^1.10.15
|
||||
sherpa_onnx_android: ^1.10.16
|
||||
# path: ../sherpa_onnx_android
|
||||
|
||||
sherpa_onnx_macos: ^1.10.15
|
||||
sherpa_onnx_macos: ^1.10.16
|
||||
# path: ../sherpa_onnx_macos
|
||||
|
||||
sherpa_onnx_linux: ^1.10.15
|
||||
sherpa_onnx_linux: ^1.10.16
|
||||
# path: ../sherpa_onnx_linux
|
||||
#
|
||||
sherpa_onnx_windows: ^1.10.15
|
||||
sherpa_onnx_windows: ^1.10.16
|
||||
# path: ../sherpa_onnx_windows
|
||||
|
||||
sherpa_onnx_ios: ^1.10.15
|
||||
sherpa_onnx_ios: ^1.10.16
|
||||
# sherpa_onnx_ios:
|
||||
# path: ../sherpa_onnx_ios
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
# https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c
|
||||
Pod::Spec.new do |s|
|
||||
s.name = 'sherpa_onnx_ios'
|
||||
s.version = '1.10.15'
|
||||
s.version = '1.10.16'
|
||||
s.summary = 'A new Flutter FFI plugin project.'
|
||||
s.description = <<-DESC
|
||||
A new Flutter FFI plugin project.
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#
|
||||
Pod::Spec.new do |s|
|
||||
s.name = 'sherpa_onnx_macos'
|
||||
s.version = '1.10.15'
|
||||
s.version = '1.10.16'
|
||||
s.summary = 'sherpa-onnx Flutter FFI plugin project.'
|
||||
s.description = <<-DESC
|
||||
sherpa-onnx Flutter FFI plugin project.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"dependencies": {
|
||||
"sherpa-onnx-node": "^1.10.15"
|
||||
"sherpa-onnx-node": "^1.10.16"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,6 +78,10 @@ sed -i.bak s/"lang = null"/"lang = \"$lang_iso_639_3\""/ ./TtsEngine.kt
|
||||
git diff
|
||||
popd
|
||||
|
||||
if [[ $model_dir == vits-melo-tts-zh_en ]]; then
|
||||
lang=zh_en
|
||||
fi
|
||||
|
||||
for arch in arm64-v8a armeabi-v7a x86_64 x86; do
|
||||
log "------------------------------------------------------------"
|
||||
log "build tts apk for $arch"
|
||||
|
||||
@@ -76,6 +76,10 @@ sed -i.bak s/"modelName = null"/"modelName = \"$model_name\""/ ./MainActivity.kt
|
||||
git diff
|
||||
popd
|
||||
|
||||
if [[ $model_dir == vits-melo-tts-zh_en ]]; then
|
||||
lang=zh_en
|
||||
fi
|
||||
|
||||
for arch in arm64-v8a armeabi-v7a x86_64 x86; do
|
||||
log "------------------------------------------------------------"
|
||||
log "build tts apk for $arch"
|
||||
|
||||
@@ -312,6 +312,11 @@ def get_vits_models() -> List[TtsModel]:
|
||||
model_name="vits-zh-hf-fanchen-wnj.onnx",
|
||||
lang="zh",
|
||||
),
|
||||
TtsModel(
|
||||
model_dir="vits-melo-tts-zh_en",
|
||||
model_name="model.onnx",
|
||||
lang="zh",
|
||||
),
|
||||
TtsModel(
|
||||
model_dir="vits-zh-hf-fanchen-C",
|
||||
model_name="vits-zh-hf-fanchen-C.onnx",
|
||||
@@ -339,18 +344,21 @@ def get_vits_models() -> List[TtsModel]:
|
||||
),
|
||||
]
|
||||
|
||||
rule_fsts = ["phone.fst", "date.fst", "number.fst", "new_heteronym.fst"]
|
||||
rule_fsts = ["phone.fst", "date.fst", "number.fst"]
|
||||
for m in chinese_models:
|
||||
s = [f"{m.model_dir}/{r}" for r in rule_fsts]
|
||||
if "vits-zh-hf" in m.model_dir or "sherpa-onnx-vits-zh-ll" == m.model_dir:
|
||||
if (
|
||||
"vits-zh-hf" in m.model_dir
|
||||
or "sherpa-onnx-vits-zh-ll" == m.model_dir
|
||||
or "melo-tts" in m.model_dir
|
||||
):
|
||||
s = s[:-1]
|
||||
m.dict_dir = m.model_dir + "/dict"
|
||||
else:
|
||||
m.rule_fars = f"{m.model_dir}/rule.far"
|
||||
|
||||
m.rule_fsts = ",".join(s)
|
||||
|
||||
if "vits-zh-hf" not in m.model_dir and "zh-ll" not in m.model_dir:
|
||||
m.rule_fars = f"{m.model_dir}/rule.far"
|
||||
|
||||
all_models = chinese_models + [
|
||||
TtsModel(
|
||||
model_dir="vits-cantonese-hf-xiaomaiiwn",
|
||||
|
||||
@@ -17,7 +17,7 @@ topics:
|
||||
- voice-activity-detection
|
||||
|
||||
# remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec
|
||||
version: 1.10.15
|
||||
version: 1.10.16
|
||||
|
||||
homepage: https://github.com/k2-fsa/sherpa-onnx
|
||||
|
||||
|
||||
@@ -6,9 +6,6 @@ from typing import List, Optional
|
||||
|
||||
import jinja2
|
||||
|
||||
# pip install iso639-lang
|
||||
from iso639 import Lang
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -37,13 +34,6 @@ class TtsModel:
|
||||
data_dir: Optional[str] = None
|
||||
dict_dir: Optional[str] = None
|
||||
is_char: bool = False
|
||||
lang_iso_639_3: str = ""
|
||||
|
||||
|
||||
def convert_lang_to_iso_639_3(models: List[TtsModel]):
|
||||
for m in models:
|
||||
if m.lang_iso_639_3 == "":
|
||||
m.lang_iso_639_3 = Lang(m.lang).pt3
|
||||
|
||||
|
||||
def get_coqui_models() -> List[TtsModel]:
|
||||
@@ -312,6 +302,11 @@ def get_vits_models() -> List[TtsModel]:
|
||||
model_name="vits-zh-hf-fanchen-wnj.onnx",
|
||||
lang="zh",
|
||||
),
|
||||
TtsModel(
|
||||
model_dir="vits-melo-tts-zh_en",
|
||||
model_name="model.onnx",
|
||||
lang="zh_en",
|
||||
),
|
||||
TtsModel(
|
||||
model_dir="vits-zh-hf-fanchen-C",
|
||||
model_name="vits-zh-hf-fanchen-C.onnx",
|
||||
@@ -332,26 +327,33 @@ def get_vits_models() -> List[TtsModel]:
|
||||
model_name="vits-zh-hf-fanchen-unity.onnx",
|
||||
lang="zh",
|
||||
),
|
||||
TtsModel(
|
||||
model_dir="sherpa-onnx-vits-zh-ll",
|
||||
model_name="model.onnx",
|
||||
lang="zh",
|
||||
),
|
||||
]
|
||||
|
||||
rule_fsts = ["phone.fst", "date.fst", "number.fst", "new_heteronym.fst"]
|
||||
rule_fsts = ["phone.fst", "date.fst", "number.fst"]
|
||||
for m in chinese_models:
|
||||
s = [f"{m.model_dir}/{r}" for r in rule_fsts]
|
||||
if "vits-zh-hf" in m.model_dir:
|
||||
if (
|
||||
"vits-zh-hf" in m.model_dir
|
||||
or "sherpa-onnx-vits-zh-ll" == m.model_dir
|
||||
or "melo-tts" in m.model_dir
|
||||
):
|
||||
s = s[:-1]
|
||||
m.dict_dir = m.model_dir + "/dict"
|
||||
else:
|
||||
m.rule_fars = f"{m.model_dir}/rule.far"
|
||||
|
||||
m.rule_fsts = ",".join(s)
|
||||
|
||||
if "vits-zh-hf" not in m.model_dir:
|
||||
m.rule_fars = f"{m.model_dir}/rule.far"
|
||||
|
||||
all_models = chinese_models + [
|
||||
TtsModel(
|
||||
model_dir="vits-cantonese-hf-xiaomaiiwn",
|
||||
model_name="vits-cantonese-hf-xiaomaiiwn.onnx",
|
||||
lang="cantonese",
|
||||
lang_iso_639_3="yue",
|
||||
rule_fsts="vits-cantonese-hf-xiaomaiiwn/rule.fst",
|
||||
),
|
||||
# English (US)
|
||||
@@ -374,7 +376,6 @@ def main():
|
||||
all_model_list += get_piper_models()
|
||||
all_model_list += get_mimic3_models()
|
||||
all_model_list += get_coqui_models()
|
||||
convert_lang_to_iso_639_3(all_model_list)
|
||||
|
||||
num_models = len(all_model_list)
|
||||
|
||||
|
||||
6
scripts/melo-tts/README.md
Normal file
6
scripts/melo-tts/README.md
Normal file
@@ -0,0 +1,6 @@
|
||||
# Introduction
|
||||
|
||||
Models in this directory are converted from
|
||||
https://github.com/myshell-ai/MeloTTS
|
||||
|
||||
Note there is only a single female speaker in the model.
|
||||
@@ -8,7 +8,6 @@ from melo.text import language_id_map, language_tone_start_map
|
||||
from melo.text.chinese import pinyin_to_symbol_map
|
||||
from melo.text.english import eng_dict, refine_syllables
|
||||
from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict
|
||||
from melo.text.symbols import language_tone_start_map
|
||||
|
||||
for k, v in pinyin_to_symbol_map.items():
|
||||
if isinstance(v, list):
|
||||
@@ -82,6 +81,7 @@ def generate_tokens(symbol_list):
|
||||
def generate_lexicon():
|
||||
word_dict = pinyin_dict.pinyin_dict
|
||||
phrases = phrases_dict.phrases_dict
|
||||
eng_dict["kaldi"] = [["K", "AH0"], ["L", "D", "IH0"]]
|
||||
with open("lexicon.txt", "w", encoding="utf-8") as f:
|
||||
for word in eng_dict:
|
||||
phones, tones = refine_syllables(eng_dict[word])
|
||||
@@ -237,9 +237,11 @@ def main():
|
||||
meta_data = {
|
||||
"model_type": "melo-vits",
|
||||
"comment": "melo",
|
||||
"version": 2,
|
||||
"language": "Chinese + English",
|
||||
"add_blank": int(model.hps.data.add_blank),
|
||||
"n_speakers": 1,
|
||||
"jieba": 1,
|
||||
"sample_rate": model.hps.data.sampling_rate,
|
||||
"bert_dim": 1024,
|
||||
"ja_bert_dim": 768,
|
||||
|
||||
@@ -12,7 +12,7 @@ function install() {
|
||||
cd MeloTTS
|
||||
pip install -r ./requirements.txt
|
||||
|
||||
pip install soundfile onnx onnxruntime
|
||||
pip install soundfile onnx==1.15.0 onnxruntime==1.16.3
|
||||
|
||||
python3 -m unidic download
|
||||
popd
|
||||
|
||||
@@ -135,28 +135,11 @@ class OnnxModel:
|
||||
def main():
|
||||
lexicon = Lexicon(lexion_filename="./lexicon.txt", tokens_filename="./tokens.txt")
|
||||
|
||||
text = "永远相信,美好的事情即将发生。"
|
||||
text = "这是一个使用 next generation kaldi 的 text to speech 中英文例子. Thank you! 你觉得如何呢? are you ok? Fantastic! How about you?"
|
||||
s = jieba.cut(text, HMM=True)
|
||||
|
||||
phones, tones = lexicon.convert(s)
|
||||
|
||||
en_text = "how are you ?".split()
|
||||
|
||||
phones_en, tones_en = lexicon.convert(en_text)
|
||||
phones += [0]
|
||||
tones += [0]
|
||||
|
||||
phones += phones_en
|
||||
tones += tones_en
|
||||
|
||||
text = "多音字测试, 银行,行不行?长沙长大"
|
||||
s = jieba.cut(text, HMM=True)
|
||||
|
||||
phones2, tones2 = lexicon.convert(s)
|
||||
|
||||
phones += phones2
|
||||
tones += tones2
|
||||
|
||||
model = OnnxModel("./model.onnx")
|
||||
|
||||
if model.add_blank:
|
||||
|
||||
@@ -422,10 +422,10 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig(
|
||||
|
||||
void SherpaOnnxOfflineRecognizerSetConfig(
|
||||
const SherpaOnnxOfflineRecognizer *recognizer,
|
||||
const SherpaOnnxOfflineRecognizerConfig *config){
|
||||
const SherpaOnnxOfflineRecognizerConfig *config) {
|
||||
sherpa_onnx::OfflineRecognizerConfig recognizer_config =
|
||||
convertConfig(config);
|
||||
recognizer->impl->SetConfig(recognizer_config);
|
||||
recognizer->impl->SetConfig(recognizer_config);
|
||||
}
|
||||
|
||||
void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) {
|
||||
@@ -478,7 +478,7 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
|
||||
pText[text.size()] = 0;
|
||||
r->text = pText;
|
||||
|
||||
//lang
|
||||
// lang
|
||||
const auto &lang = result.lang;
|
||||
char *c_lang = new char[lang.size() + 1];
|
||||
std::copy(lang.begin(), lang.end(), c_lang);
|
||||
@@ -1317,7 +1317,7 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches(
|
||||
}
|
||||
delete[] r->matches;
|
||||
delete r;
|
||||
};
|
||||
}
|
||||
|
||||
int32_t SherpaOnnxSpeakerEmbeddingManagerVerify(
|
||||
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
|
||||
|
||||
@@ -496,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams(
|
||||
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
|
||||
const char *text;
|
||||
|
||||
// Pointer to continuous memory which holds timestamps
|
||||
// Pointer to continuous memory which holds timestamps
|
||||
//
|
||||
// It is NULL if the model does not support timestamps
|
||||
float *timestamps;
|
||||
@@ -525,9 +525,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
|
||||
*/
|
||||
const char *json;
|
||||
|
||||
//return recognized language
|
||||
// return recognized language
|
||||
const char *lang;
|
||||
|
||||
} SherpaOnnxOfflineRecognizerResult;
|
||||
|
||||
/// Get the result of the offline stream.
|
||||
|
||||
@@ -142,7 +142,9 @@ if(SHERPA_ONNX_ENABLE_TTS)
|
||||
list(APPEND sources
|
||||
jieba-lexicon.cc
|
||||
lexicon.cc
|
||||
melo-tts-lexicon.cc
|
||||
offline-tts-character-frontend.cc
|
||||
offline-tts-frontend.cc
|
||||
offline-tts-impl.cc
|
||||
offline-tts-model-config.cc
|
||||
offline-tts-vits-model-config.cc
|
||||
|
||||
@@ -33,7 +33,7 @@ TEST(CppJieBa, Case1) {
|
||||
std::vector<std::string> words;
|
||||
std::vector<cppjieba::Word> jiebawords;
|
||||
|
||||
std::string s = "他来到了网易杭研大厦";
|
||||
std::string s = "他来到了网易杭研大厦。How are you?";
|
||||
std::cout << s << std::endl;
|
||||
std::cout << "[demo] Cut With HMM" << std::endl;
|
||||
jieba.Cut(s, words, true);
|
||||
|
||||
@@ -17,6 +17,7 @@ namespace sherpa_onnx {
|
||||
|
||||
// implemented in ./lexicon.cc
|
||||
std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is);
|
||||
|
||||
std::vector<int32_t> ConvertTokensToIds(
|
||||
const std::unordered_map<std::string, int32_t> &token2id,
|
||||
const std::vector<std::string> &tokens);
|
||||
@@ -53,8 +54,7 @@ class JiebaLexicon::Impl {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
|
||||
const std::string &text) const {
|
||||
std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &text) const {
|
||||
// see
|
||||
// https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244
|
||||
std::regex punct_re{":|、|;"};
|
||||
@@ -87,7 +87,7 @@ class JiebaLexicon::Impl {
|
||||
SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> ans;
|
||||
std::vector<TokenIDs> ans;
|
||||
std::vector<int64_t> this_sentence;
|
||||
|
||||
int32_t blank = token2id_.at(" ");
|
||||
@@ -217,7 +217,7 @@ JiebaLexicon::JiebaLexicon(const std::string &lexicon,
|
||||
: impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
|
||||
debug)) {}
|
||||
|
||||
std::vector<std::vector<int64_t>> JiebaLexicon::ConvertTextToTokenIds(
|
||||
std::vector<TokenIDs> JiebaLexicon::ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string & /*unused_voice = ""*/) const {
|
||||
return impl_->ConvertTextToTokenIds(text);
|
||||
}
|
||||
|
||||
@@ -10,11 +10,6 @@
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"
|
||||
|
||||
@@ -27,13 +22,7 @@ class JiebaLexicon : public OfflineTtsFrontend {
|
||||
const std::string &dict_dir,
|
||||
const OfflineTtsVitsModelMetaData &meta_data, bool debug);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
JiebaLexicon(AAssetManager *mgr, const std::string &lexicon,
|
||||
const std::string &tokens, const std::string &dict_dir,
|
||||
const OfflineTtsVitsModelMetaData &meta_data);
|
||||
#endif
|
||||
|
||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
|
||||
std::vector<TokenIDs> ConvertTextToTokenIds(
|
||||
const std::string &text,
|
||||
const std::string &unused_voice = "") const override;
|
||||
|
||||
|
||||
@@ -172,7 +172,7 @@ Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
|
||||
}
|
||||
#endif
|
||||
|
||||
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds(
|
||||
std::vector<TokenIDs> Lexicon::ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string & /*voice*/ /*= ""*/) const {
|
||||
switch (language_) {
|
||||
case Language::kChinese:
|
||||
@@ -187,7 +187,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds(
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
|
||||
std::vector<TokenIDs> Lexicon::ConvertTextToTokenIdsChinese(
|
||||
const std::string &_text) const {
|
||||
std::string text(_text);
|
||||
ToLowerCase(&text);
|
||||
@@ -209,7 +209,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> ans;
|
||||
std::vector<TokenIDs> ans;
|
||||
std::vector<int64_t> this_sentence;
|
||||
|
||||
int32_t blank = -1;
|
||||
@@ -288,7 +288,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese(
|
||||
std::vector<TokenIDs> Lexicon::ConvertTextToTokenIdsNotChinese(
|
||||
const std::string &_text) const {
|
||||
std::string text(_text);
|
||||
ToLowerCase(&text);
|
||||
@@ -311,7 +311,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese(
|
||||
|
||||
int32_t blank = token2id_.at(" ");
|
||||
|
||||
std::vector<std::vector<int64_t>> ans;
|
||||
std::vector<TokenIDs> ans;
|
||||
std::vector<int64_t> this_sentence;
|
||||
|
||||
for (const auto &w : words) {
|
||||
|
||||
@@ -36,14 +36,14 @@ class Lexicon : public OfflineTtsFrontend {
|
||||
const std::string &language, bool debug = false);
|
||||
#endif
|
||||
|
||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
|
||||
std::vector<TokenIDs> ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string &voice = "") const override;
|
||||
|
||||
private:
|
||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsNotChinese(
|
||||
std::vector<TokenIDs> ConvertTextToTokenIdsNotChinese(
|
||||
const std::string &text) const;
|
||||
|
||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsChinese(
|
||||
std::vector<TokenIDs> ConvertTextToTokenIdsChinese(
|
||||
const std::string &text) const;
|
||||
|
||||
void InitLanguage(const std::string &lang);
|
||||
|
||||
266
sherpa-onnx/csrc/melo-tts-lexicon.cc
Normal file
266
sherpa-onnx/csrc/melo-tts-lexicon.cc
Normal file
@@ -0,0 +1,266 @@
|
||||
// sherpa-onnx/csrc/melo-tts-lexicon.cc
|
||||
//
|
||||
// Copyright (c) 2022-2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/melo-tts-lexicon.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <regex> // NOLINT
|
||||
#include <utility>
|
||||
|
||||
#include "cppjieba/Jieba.hpp"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// implemented in ./lexicon.cc
|
||||
std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is);
|
||||
|
||||
std::vector<int32_t> ConvertTokensToIds(
|
||||
const std::unordered_map<std::string, int32_t> &token2id,
|
||||
const std::vector<std::string> &tokens);
|
||||
|
||||
class MeloTtsLexicon::Impl {
|
||||
public:
|
||||
Impl(const std::string &lexicon, const std::string &tokens,
|
||||
const std::string &dict_dir,
|
||||
const OfflineTtsVitsModelMetaData &meta_data, bool debug)
|
||||
: meta_data_(meta_data), debug_(debug) {
|
||||
std::string dict = dict_dir + "/jieba.dict.utf8";
|
||||
std::string hmm = dict_dir + "/hmm_model.utf8";
|
||||
std::string user_dict = dict_dir + "/user.dict.utf8";
|
||||
std::string idf = dict_dir + "/idf.utf8";
|
||||
std::string stop_word = dict_dir + "/stop_words.utf8";
|
||||
|
||||
AssertFileExists(dict);
|
||||
AssertFileExists(hmm);
|
||||
AssertFileExists(user_dict);
|
||||
AssertFileExists(idf);
|
||||
AssertFileExists(stop_word);
|
||||
|
||||
jieba_ =
|
||||
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
|
||||
|
||||
{
|
||||
std::ifstream is(tokens);
|
||||
InitTokens(is);
|
||||
}
|
||||
|
||||
{
|
||||
std::ifstream is(lexicon);
|
||||
InitLexicon(is);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &_text) const {
|
||||
std::string text = ToLowerCase(_text);
|
||||
// see
|
||||
// https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244
|
||||
std::regex punct_re{":|、|;"};
|
||||
std::string s = std::regex_replace(text, punct_re, ",");
|
||||
|
||||
std::regex punct_re2("。");
|
||||
s = std::regex_replace(s, punct_re2, ".");
|
||||
|
||||
std::regex punct_re3("?");
|
||||
s = std::regex_replace(s, punct_re3, "?");
|
||||
|
||||
std::regex punct_re4("!");
|
||||
s = std::regex_replace(s, punct_re4, "!");
|
||||
|
||||
std::vector<std::string> words;
|
||||
bool is_hmm = true;
|
||||
jieba_->Cut(text, words, is_hmm);
|
||||
|
||||
if (debug_) {
|
||||
SHERPA_ONNX_LOGE("input text: %s", text.c_str());
|
||||
SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str());
|
||||
|
||||
std::ostringstream os;
|
||||
std::string sep = "";
|
||||
for (const auto &w : words) {
|
||||
os << sep << w;
|
||||
sep = "_";
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
|
||||
}
|
||||
|
||||
std::vector<TokenIDs> ans;
|
||||
TokenIDs this_sentence;
|
||||
|
||||
int32_t blank = token2id_.at("_");
|
||||
for (const auto &w : words) {
|
||||
auto ids = ConvertWordToIds(w);
|
||||
if (ids.tokens.empty()) {
|
||||
SHERPA_ONNX_LOGE("Ignore OOV '%s'", w.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
this_sentence.tokens.insert(this_sentence.tokens.end(),
|
||||
ids.tokens.begin(), ids.tokens.end());
|
||||
this_sentence.tones.insert(this_sentence.tones.end(), ids.tones.begin(),
|
||||
ids.tones.end());
|
||||
|
||||
if (w == "." || w == "!" || w == "?" || w == ",") {
|
||||
ans.push_back(std::move(this_sentence));
|
||||
this_sentence = {};
|
||||
}
|
||||
} // for (const auto &w : words)
|
||||
|
||||
if (!this_sentence.tokens.empty()) {
|
||||
ans.push_back(std::move(this_sentence));
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
TokenIDs ConvertWordToIds(const std::string &w) const {
|
||||
if (word2ids_.count(w)) {
|
||||
return word2ids_.at(w);
|
||||
}
|
||||
|
||||
if (token2id_.count(w)) {
|
||||
return {{token2id_.at(w)}, {0}};
|
||||
}
|
||||
|
||||
TokenIDs ans;
|
||||
|
||||
std::vector<std::string> words = SplitUtf8(w);
|
||||
for (const auto &word : words) {
|
||||
if (word2ids_.count(word)) {
|
||||
auto ids = ConvertWordToIds(word);
|
||||
ans.tokens.insert(ans.tokens.end(), ids.tokens.begin(),
|
||||
ids.tokens.end());
|
||||
ans.tones.insert(ans.tones.end(), ids.tones.begin(), ids.tones.end());
|
||||
}
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
void InitTokens(std::istream &is) {
|
||||
token2id_ = ReadTokens(is);
|
||||
token2id_[" "] = token2id_["_"];
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> puncts = {
|
||||
{",", ","}, {".", "。"}, {"!", "!"}, {"?", "?"}};
|
||||
|
||||
for (const auto &p : puncts) {
|
||||
if (token2id_.count(p.first) && !token2id_.count(p.second)) {
|
||||
token2id_[p.second] = token2id_[p.first];
|
||||
}
|
||||
|
||||
if (!token2id_.count(p.first) && token2id_.count(p.second)) {
|
||||
token2id_[p.first] = token2id_[p.second];
|
||||
}
|
||||
}
|
||||
|
||||
if (!token2id_.count("、") && token2id_.count(",")) {
|
||||
token2id_["、"] = token2id_[","];
|
||||
}
|
||||
}
|
||||
|
||||
void InitLexicon(std::istream &is) {
|
||||
std::string word;
|
||||
std::vector<std::string> token_list;
|
||||
|
||||
std::vector<std::string> phone_list;
|
||||
std::vector<int64_t> tone_list;
|
||||
|
||||
std::string line;
|
||||
std::string phone;
|
||||
int32_t line_num = 0;
|
||||
|
||||
while (std::getline(is, line)) {
|
||||
++line_num;
|
||||
|
||||
std::istringstream iss(line);
|
||||
|
||||
token_list.clear();
|
||||
phone_list.clear();
|
||||
tone_list.clear();
|
||||
|
||||
iss >> word;
|
||||
ToLowerCase(&word);
|
||||
|
||||
if (word2ids_.count(word)) {
|
||||
SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.",
|
||||
word.c_str(), line_num, line.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
while (iss >> phone) {
|
||||
token_list.push_back(std::move(phone));
|
||||
}
|
||||
|
||||
if ((token_list.size() & 1) != 0) {
|
||||
SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t num_phones = token_list.size() / 2;
|
||||
phone_list.reserve(num_phones);
|
||||
tone_list.reserve(num_phones);
|
||||
|
||||
for (int32_t i = 0; i != num_phones; ++i) {
|
||||
phone_list.push_back(std::move(token_list[i]));
|
||||
tone_list.push_back(std::stoi(token_list[i + num_phones], nullptr));
|
||||
if (tone_list.back() < 0 || tone_list.back() > 50) {
|
||||
SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int32_t> ids = ConvertTokensToIds(token2id_, phone_list);
|
||||
if (ids.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ids.size() != num_phones) {
|
||||
SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
std::vector<int64_t> ids64{ids.begin(), ids.end()};
|
||||
|
||||
word2ids_.insert(
|
||||
{std::move(word), TokenIDs{std::move(ids64), std::move(tone_list)}});
|
||||
}
|
||||
|
||||
word2ids_["呣"] = word2ids_["母"];
|
||||
word2ids_["嗯"] = word2ids_["恩"];
|
||||
}
|
||||
|
||||
private:
|
||||
// lexicon.txt is saved in word2ids_
|
||||
std::unordered_map<std::string, TokenIDs> word2ids_;
|
||||
|
||||
// tokens.txt is saved in token2id_
|
||||
std::unordered_map<std::string, int32_t> token2id_;
|
||||
|
||||
OfflineTtsVitsModelMetaData meta_data_;
|
||||
|
||||
std::unique_ptr<cppjieba::Jieba> jieba_;
|
||||
bool debug_ = false;
|
||||
};
|
||||
|
||||
MeloTtsLexicon::~MeloTtsLexicon() = default;
|
||||
|
||||
MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon,
|
||||
const std::string &tokens,
|
||||
const std::string &dict_dir,
|
||||
const OfflineTtsVitsModelMetaData &meta_data,
|
||||
bool debug)
|
||||
: impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
|
||||
debug)) {}
|
||||
|
||||
std::vector<TokenIDs> MeloTtsLexicon::ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string & /*unused_voice = ""*/) const {
|
||||
return impl_->ConvertTextToTokenIds(text);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
36
sherpa-onnx/csrc/melo-tts-lexicon.h
Normal file
36
sherpa-onnx/csrc/melo-tts-lexicon.h
Normal file
@@ -0,0 +1,36 @@
|
||||
// sherpa-onnx/csrc/melo-tts-lexicon.h
|
||||
//
|
||||
// Copyright (c) 2022-2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_
|
||||
#define SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class MeloTtsLexicon : public OfflineTtsFrontend {
|
||||
public:
|
||||
~MeloTtsLexicon() override;
|
||||
MeloTtsLexicon(const std::string &lexicon, const std::string &tokens,
|
||||
const std::string &dict_dir,
|
||||
const OfflineTtsVitsModelMetaData &meta_data, bool debug);
|
||||
|
||||
std::vector<TokenIDs> ConvertTextToTokenIds(
|
||||
const std::string &text,
|
||||
const std::string &unused_voice = "") const override;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_
|
||||
@@ -94,8 +94,7 @@ OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend(
|
||||
|
||||
#endif
|
||||
|
||||
std::vector<std::vector<int64_t>>
|
||||
OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
|
||||
std::vector<TokenIDs> OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
|
||||
const std::string &_text, const std::string & /*voice = ""*/) const {
|
||||
// see
|
||||
// https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87
|
||||
@@ -112,7 +111,7 @@ OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
|
||||
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
|
||||
std::u32string s = conv.from_bytes(text);
|
||||
|
||||
std::vector<std::vector<int64_t>> ans;
|
||||
std::vector<TokenIDs> ans;
|
||||
|
||||
std::vector<int64_t> this_sentence;
|
||||
if (add_blank) {
|
||||
|
||||
@@ -41,7 +41,7 @@ class OfflineTtsCharacterFrontend : public OfflineTtsFrontend {
|
||||
* If a frontend does not support splitting the text into
|
||||
* sentences, the resulting vector contains only one subvector.
|
||||
*/
|
||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
|
||||
std::vector<TokenIDs> ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string &voice = "") const override;
|
||||
|
||||
private:
|
||||
|
||||
34
sherpa-onnx/csrc/offline-tts-frontend.cc
Normal file
34
sherpa-onnx/csrc/offline-tts-frontend.cc
Normal file
@@ -0,0 +1,34 @@
|
||||
// sherpa-onnx/csrc/offline-tts-frontend.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::string TokenIDs::ToString() const {
|
||||
std::ostringstream os;
|
||||
os << "TokenIDs(";
|
||||
os << "tokens=[";
|
||||
std::string sep;
|
||||
for (auto i : tokens) {
|
||||
os << sep << i;
|
||||
sep = ", ";
|
||||
}
|
||||
os << "], ";
|
||||
|
||||
os << "tones=[";
|
||||
sep = {};
|
||||
for (auto i : tones) {
|
||||
os << sep << i;
|
||||
sep = ", ";
|
||||
}
|
||||
os << "]";
|
||||
os << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -8,8 +8,28 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct TokenIDs {
|
||||
TokenIDs() = default;
|
||||
|
||||
/*implicit*/ TokenIDs(const std::vector<int64_t> &tokens) // NOLINT
|
||||
: tokens{tokens} {}
|
||||
|
||||
TokenIDs(const std::vector<int64_t> &tokens,
|
||||
const std::vector<int64_t> &tones)
|
||||
: tokens{tokens}, tones{tones} {}
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
std::vector<int64_t> tokens;
|
||||
|
||||
// Used only in MeloTTS
|
||||
std::vector<int64_t> tones;
|
||||
};
|
||||
|
||||
class OfflineTtsFrontend {
|
||||
public:
|
||||
virtual ~OfflineTtsFrontend() = default;
|
||||
@@ -26,7 +46,7 @@ class OfflineTtsFrontend {
|
||||
* If a frontend does not support splitting the text into sentences,
|
||||
* the resulting vector contains only one subvector.
|
||||
*/
|
||||
virtual std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
|
||||
virtual std::vector<TokenIDs> ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string &voice = "") const = 0;
|
||||
};
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "sherpa-onnx/csrc/jieba-lexicon.h"
|
||||
#include "sherpa-onnx/csrc/lexicon.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/melo-tts-lexicon.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-character-frontend.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-impl.h"
|
||||
@@ -174,26 +175,47 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> x =
|
||||
std::vector<TokenIDs> token_ids =
|
||||
frontend_->ConvertTextToTokenIds(text, meta_data.voice);
|
||||
|
||||
if (x.empty() || (x.size() == 1 && x[0].empty())) {
|
||||
if (token_ids.empty() ||
|
||||
(token_ids.size() == 1 && token_ids[0].tokens.empty())) {
|
||||
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> x;
|
||||
std::vector<std::vector<int64_t>> tones;
|
||||
|
||||
x.reserve(token_ids.size());
|
||||
|
||||
for (auto &i : token_ids) {
|
||||
x.push_back(std::move(i.tokens));
|
||||
}
|
||||
|
||||
if (!token_ids[0].tones.empty()) {
|
||||
tones.reserve(token_ids.size());
|
||||
for (auto &i : token_ids) {
|
||||
tones.push_back(std::move(i.tones));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(fangjun): add blank inside the frontend, not here
|
||||
if (meta_data.add_blank && config_.model.vits.data_dir.empty() &&
|
||||
meta_data.frontend != "characters") {
|
||||
for (auto &k : x) {
|
||||
k = AddBlank(k);
|
||||
}
|
||||
|
||||
for (auto &k : tones) {
|
||||
k = AddBlank(k);
|
||||
}
|
||||
}
|
||||
|
||||
int32_t x_size = static_cast<int32_t>(x.size());
|
||||
|
||||
if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) {
|
||||
auto ans = Process(x, sid, speed);
|
||||
auto ans = Process(x, tones, sid, speed);
|
||||
if (callback) {
|
||||
callback(ans.samples.data(), ans.samples.size(), 1.0);
|
||||
}
|
||||
@@ -202,9 +224,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
|
||||
// the input text is too long, we process sentences within it in batches
|
||||
// to avoid OOM. Batch size is config_.max_num_sentences
|
||||
std::vector<std::vector<int64_t>> batch;
|
||||
std::vector<std::vector<int64_t>> batch_x;
|
||||
std::vector<std::vector<int64_t>> batch_tones;
|
||||
|
||||
int32_t batch_size = config_.max_num_sentences;
|
||||
batch.reserve(config_.max_num_sentences);
|
||||
batch_x.reserve(config_.max_num_sentences);
|
||||
batch_tones.reserve(config_.max_num_sentences);
|
||||
int32_t num_batches = x_size / batch_size;
|
||||
|
||||
if (config_.model.debug) {
|
||||
@@ -221,12 +246,17 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
int32_t k = 0;
|
||||
|
||||
for (int32_t b = 0; b != num_batches && should_continue; ++b) {
|
||||
batch.clear();
|
||||
batch_x.clear();
|
||||
batch_tones.clear();
|
||||
for (int32_t i = 0; i != batch_size; ++i, ++k) {
|
||||
batch.push_back(std::move(x[k]));
|
||||
batch_x.push_back(std::move(x[k]));
|
||||
|
||||
if (!tones.empty()) {
|
||||
batch_tones.push_back(std::move(tones[k]));
|
||||
}
|
||||
}
|
||||
|
||||
auto audio = Process(batch, sid, speed);
|
||||
auto audio = Process(batch_x, batch_tones, sid, speed);
|
||||
ans.sample_rate = audio.sample_rate;
|
||||
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
|
||||
audio.samples.end());
|
||||
@@ -239,14 +269,19 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
}
|
||||
}
|
||||
|
||||
batch.clear();
|
||||
batch_x.clear();
|
||||
batch_tones.clear();
|
||||
while (k < static_cast<int32_t>(x.size()) && should_continue) {
|
||||
batch.push_back(std::move(x[k]));
|
||||
batch_x.push_back(std::move(x[k]));
|
||||
if (!tones.empty()) {
|
||||
batch_tones.push_back(std::move(tones[k]));
|
||||
}
|
||||
|
||||
++k;
|
||||
}
|
||||
|
||||
if (!batch.empty()) {
|
||||
auto audio = Process(batch, sid, speed);
|
||||
if (!batch_x.empty()) {
|
||||
auto audio = Process(batch_x, batch_tones, sid, speed);
|
||||
ans.sample_rate = audio.sample_rate;
|
||||
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
|
||||
audio.samples.end());
|
||||
@@ -308,6 +343,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
if (meta_data.frontend == "characters") {
|
||||
frontend_ = std::make_unique<OfflineTtsCharacterFrontend>(
|
||||
config_.model.vits.tokens, meta_data);
|
||||
} else if (meta_data.jieba && !config_.model.vits.dict_dir.empty() &&
|
||||
meta_data.is_melo_tts) {
|
||||
frontend_ = std::make_unique<MeloTtsLexicon>(
|
||||
config_.model.vits.lexicon, config_.model.vits.tokens,
|
||||
config_.model.vits.dict_dir, model_->GetMetaData(),
|
||||
config_.model.debug);
|
||||
} else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) {
|
||||
frontend_ = std::make_unique<JiebaLexicon>(
|
||||
config_.model.vits.lexicon, config_.model.vits.tokens,
|
||||
@@ -344,6 +385,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
}
|
||||
|
||||
GeneratedAudio Process(const std::vector<std::vector<int64_t>> &tokens,
|
||||
const std::vector<std::vector<int64_t>> &tones,
|
||||
int32_t sid, float speed) const {
|
||||
int32_t num_tokens = 0;
|
||||
for (const auto &k : tokens) {
|
||||
@@ -356,6 +398,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
x.insert(x.end(), k.begin(), k.end());
|
||||
}
|
||||
|
||||
std::vector<int64_t> tone_list;
|
||||
if (!tones.empty()) {
|
||||
tone_list.reserve(num_tokens);
|
||||
for (const auto &k : tones) {
|
||||
tone_list.insert(tone_list.end(), k.begin(), k.end());
|
||||
}
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
@@ -363,7 +413,20 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
Ort::Value x_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
|
||||
|
||||
Ort::Value audio = model_->Run(std::move(x_tensor), sid, speed);
|
||||
Ort::Value tones_tensor{nullptr};
|
||||
if (!tones.empty()) {
|
||||
tones_tensor = Ort::Value::CreateTensor(memory_info, tone_list.data(),
|
||||
tone_list.size(), x_shape.data(),
|
||||
x_shape.size());
|
||||
}
|
||||
|
||||
Ort::Value audio{nullptr};
|
||||
if (tones.empty()) {
|
||||
audio = model_->Run(std::move(x_tensor), sid, speed);
|
||||
} else {
|
||||
audio =
|
||||
model_->Run(std::move(x_tensor), std::move(tones_tensor), sid, speed);
|
||||
}
|
||||
|
||||
std::vector<int64_t> audio_shape =
|
||||
audio.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
@@ -21,6 +21,7 @@ struct OfflineTtsVitsModelMetaData {
|
||||
bool is_piper = false;
|
||||
bool is_coqui = false;
|
||||
bool is_icefall = false;
|
||||
bool is_melo_tts = false;
|
||||
|
||||
// for Chinese TTS models from
|
||||
// https://github.com/Plachtaa/VITS-fast-fine-tuning
|
||||
@@ -33,6 +34,10 @@ struct OfflineTtsVitsModelMetaData {
|
||||
int32_t use_eos_bos = 0;
|
||||
int32_t pad_id = 0;
|
||||
|
||||
// for melo tts
|
||||
int32_t speaker_id = 0;
|
||||
int32_t version = 0;
|
||||
|
||||
std::string punctuations;
|
||||
std::string language;
|
||||
std::string voice;
|
||||
|
||||
@@ -45,6 +45,64 @@ class OfflineTtsVitsModel::Impl {
|
||||
return RunVits(std::move(x), sid, speed);
|
||||
}
|
||||
|
||||
Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid, float speed) {
|
||||
// For MeloTTS, we hardcode sid to the one contained in the meta data
|
||||
sid = meta_data_.speaker_id;
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::vector<int64_t> x_shape = x.GetTensorTypeAndShapeInfo().GetShape();
|
||||
if (x_shape[0] != 1) {
|
||||
SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d",
|
||||
static_cast<int32_t>(x_shape[0]));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int64_t len = x_shape[1];
|
||||
int64_t len_shape = 1;
|
||||
|
||||
Ort::Value x_length =
|
||||
Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);
|
||||
|
||||
int64_t scale_shape = 1;
|
||||
float noise_scale = config_.vits.noise_scale;
|
||||
float length_scale = config_.vits.length_scale;
|
||||
float noise_scale_w = config_.vits.noise_scale_w;
|
||||
|
||||
if (speed != 1 && speed > 0) {
|
||||
length_scale = 1. / speed;
|
||||
}
|
||||
|
||||
Ort::Value noise_scale_tensor =
|
||||
Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1);
|
||||
|
||||
Ort::Value length_scale_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, &length_scale, 1, &scale_shape, 1);
|
||||
|
||||
Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, &noise_scale_w, 1, &scale_shape, 1);
|
||||
|
||||
Ort::Value sid_tensor =
|
||||
Ort::Value::CreateTensor(memory_info, &sid, 1, &scale_shape, 1);
|
||||
|
||||
std::vector<Ort::Value> inputs;
|
||||
inputs.reserve(7);
|
||||
inputs.push_back(std::move(x));
|
||||
inputs.push_back(std::move(x_length));
|
||||
inputs.push_back(std::move(tones));
|
||||
inputs.push_back(std::move(sid_tensor));
|
||||
inputs.push_back(std::move(noise_scale_tensor));
|
||||
inputs.push_back(std::move(length_scale_tensor));
|
||||
inputs.push_back(std::move(noise_scale_w_tensor));
|
||||
|
||||
auto out =
|
||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
|
||||
return std::move(out[0]);
|
||||
}
|
||||
|
||||
const OfflineTtsVitsModelMetaData &GetMetaData() const { return meta_data_; }
|
||||
|
||||
private:
|
||||
@@ -83,6 +141,10 @@ class OfflineTtsVitsModel::Impl {
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
|
||||
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.add_blank, "add_blank",
|
||||
0);
|
||||
|
||||
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.speaker_id, "speaker_id",
|
||||
0);
|
||||
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.version, "version", 0);
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "n_speakers");
|
||||
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.punctuations,
|
||||
"punctuation", "");
|
||||
@@ -115,6 +177,22 @@ class OfflineTtsVitsModel::Impl {
|
||||
if (comment.find("icefall") != std::string::npos) {
|
||||
meta_data_.is_icefall = true;
|
||||
}
|
||||
|
||||
if (comment.find("melo") != std::string::npos) {
|
||||
meta_data_.is_melo_tts = true;
|
||||
int32_t expected_version = 2;
|
||||
if (meta_data_.version < expected_version) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Please download the latest MeloTTS model and retry. Current "
|
||||
"version: %d. Expected version: %d",
|
||||
meta_data_.version, expected_version);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// NOTE(fangjun):
|
||||
// version 0 is the first version
|
||||
// version 2: add jieba=1 to the metadata
|
||||
}
|
||||
}
|
||||
|
||||
Ort::Value RunVitsPiperOrCoqui(Ort::Value x, int64_t sid, float speed) {
|
||||
@@ -269,6 +347,12 @@ Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/,
|
||||
return impl_->Run(std::move(x), sid, speed);
|
||||
}
|
||||
|
||||
Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, Ort::Value tones,
|
||||
int64_t sid /*= 0*/,
|
||||
float speed /*= 1.0*/) {
|
||||
return impl_->Run(std::move(x), std::move(tones), sid, speed);
|
||||
}
|
||||
|
||||
const OfflineTtsVitsModelMetaData &OfflineTtsVitsModel::GetMetaData() const {
|
||||
return impl_->GetMetaData();
|
||||
}
|
||||
|
||||
@@ -40,6 +40,10 @@ class OfflineTtsVitsModel {
|
||||
*/
|
||||
Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0);
|
||||
|
||||
// This is for MeloTTS
|
||||
Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid = 0,
|
||||
float speed = 1.0);
|
||||
|
||||
const OfflineTtsVitsModelMetaData &GetMetaData() const;
|
||||
|
||||
private:
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
||||
@@ -36,7 +36,6 @@ class OfflineWhisperDecoder {
|
||||
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
|
||||
|
||||
virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0;
|
||||
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -12,7 +12,8 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineWhisperGreedySearchDecoder::SetConfig(const OfflineWhisperModelConfig &config) {
|
||||
void OfflineWhisperGreedySearchDecoder::SetConfig(
|
||||
const OfflineWhisperModelConfig &config) {
|
||||
config_ = config;
|
||||
}
|
||||
|
||||
@@ -135,9 +136,9 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
|
||||
const auto &id2lang = model_->GetID2Lang();
|
||||
if (id2lang.count(initial_tokens[1])) {
|
||||
ans[0].lang = id2lang.at(initial_tokens[1]);
|
||||
ans[0].lang = id2lang.at(initial_tokens[1]);
|
||||
} else {
|
||||
ans[0].lang = "";
|
||||
ans[0].lang = "";
|
||||
}
|
||||
|
||||
ans[0].tokens = std::move(predicted_tokens);
|
||||
|
||||
@@ -153,15 +153,21 @@ Ort::Value View(Ort::Value *v) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T /*= float*/>
|
||||
void Print1D(Ort::Value *v) {
|
||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||
const float *d = v->GetTensorData<float>();
|
||||
const T *d = v->GetTensorData<T>();
|
||||
std::ostringstream os;
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
|
||||
fprintf(stderr, "%.3f ", d[i]);
|
||||
os << *d << " ";
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
os << "\n";
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
template void Print1D<int64_t>(Ort::Value *v);
|
||||
template void Print1D<float>(Ort::Value *v);
|
||||
|
||||
template <typename T /*= float*/>
|
||||
void Print2D(Ort::Value *v) {
|
||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
@@ -69,6 +69,7 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);
|
||||
Ort::Value View(Ort::Value *v);
|
||||
|
||||
// Print a 1-D tensor to stderr
|
||||
template <typename T = float>
|
||||
void Print1D(Ort::Value *v);
|
||||
|
||||
// Print a 2-D tensor to stderr
|
||||
|
||||
@@ -214,7 +214,7 @@ PiperPhonemizeLexicon::PiperPhonemizeLexicon(
|
||||
}
|
||||
#endif
|
||||
|
||||
std::vector<std::vector<int64_t>> PiperPhonemizeLexicon::ConvertTextToTokenIds(
|
||||
std::vector<TokenIDs> PiperPhonemizeLexicon::ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string &voice /*= ""*/) const {
|
||||
piper::eSpeakPhonemeConfig config;
|
||||
|
||||
@@ -232,7 +232,7 @@ std::vector<std::vector<int64_t>> PiperPhonemizeLexicon::ConvertTextToTokenIds(
|
||||
piper::phonemize_eSpeak(text, config, phonemes);
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> ans;
|
||||
std::vector<TokenIDs> ans;
|
||||
|
||||
std::vector<int64_t> phoneme_ids;
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ class PiperPhonemizeLexicon : public OfflineTtsFrontend {
|
||||
const OfflineTtsVitsModelMetaData &meta_data);
|
||||
#endif
|
||||
|
||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
|
||||
std::vector<TokenIDs> ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string &voice = "") const override;
|
||||
|
||||
private:
|
||||
|
||||
@@ -31,8 +31,8 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
|
||||
api.ReleaseStatus(status);
|
||||
}
|
||||
|
||||
static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
const std::string &provider_str,
|
||||
static Ort::SessionOptions GetSessionOptionsImpl(
|
||||
int32_t num_threads, const std::string &provider_str,
|
||||
const ProviderConfig *provider_config = nullptr) {
|
||||
Provider p = StringToProvider(provider_str);
|
||||
|
||||
@@ -67,8 +67,9 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
}
|
||||
case Provider::kTRT: {
|
||||
if (provider_config == nullptr) {
|
||||
SHERPA_ONNX_LOGE("Tensorrt support for Online models ony,"
|
||||
"Must be extended for offline and others");
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Tensorrt support for Online models ony,"
|
||||
"Must be extended for offline and others");
|
||||
exit(1);
|
||||
}
|
||||
auto trt_config = provider_config->trt_config;
|
||||
@@ -84,29 +85,27 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
std::to_string(trt_config.trt_max_partition_iterations);
|
||||
auto trt_min_subgraph_size =
|
||||
std::to_string(trt_config.trt_min_subgraph_size);
|
||||
auto trt_fp16_enable =
|
||||
std::to_string(trt_config.trt_fp16_enable);
|
||||
auto trt_fp16_enable = std::to_string(trt_config.trt_fp16_enable);
|
||||
auto trt_detailed_build_log =
|
||||
std::to_string(trt_config.trt_detailed_build_log);
|
||||
auto trt_engine_cache_enable =
|
||||
std::to_string(trt_config.trt_engine_cache_enable);
|
||||
auto trt_timing_cache_enable =
|
||||
std::to_string(trt_config.trt_timing_cache_enable);
|
||||
auto trt_dump_subgraphs =
|
||||
std::to_string(trt_config.trt_dump_subgraphs);
|
||||
auto trt_dump_subgraphs = std::to_string(trt_config.trt_dump_subgraphs);
|
||||
std::vector<TrtPairs> trt_options = {
|
||||
{"device_id", device_id.c_str()},
|
||||
{"trt_max_workspace_size", trt_max_workspace_size.c_str()},
|
||||
{"trt_max_partition_iterations", trt_max_partition_iterations.c_str()},
|
||||
{"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
|
||||
{"trt_fp16_enable", trt_fp16_enable.c_str()},
|
||||
{"trt_detailed_build_log", trt_detailed_build_log.c_str()},
|
||||
{"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},
|
||||
{"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
|
||||
{"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
|
||||
{"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
|
||||
{"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}
|
||||
};
|
||||
{"device_id", device_id.c_str()},
|
||||
{"trt_max_workspace_size", trt_max_workspace_size.c_str()},
|
||||
{"trt_max_partition_iterations",
|
||||
trt_max_partition_iterations.c_str()},
|
||||
{"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
|
||||
{"trt_fp16_enable", trt_fp16_enable.c_str()},
|
||||
{"trt_detailed_build_log", trt_detailed_build_log.c_str()},
|
||||
{"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},
|
||||
{"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
|
||||
{"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
|
||||
{"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
|
||||
{"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}};
|
||||
// ToDo : Trt configs
|
||||
// "trt_int8_enable"
|
||||
// "trt_int8_use_native_calibration_table"
|
||||
@@ -151,9 +150,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
|
||||
if (provider_config != nullptr) {
|
||||
options.device_id = provider_config->device;
|
||||
options.cudnn_conv_algo_search =
|
||||
OrtCudnnConvAlgoSearch(provider_config->cuda_config
|
||||
.cudnn_conv_algo_search);
|
||||
options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch(
|
||||
provider_config->cuda_config.cudnn_conv_algo_search);
|
||||
} else {
|
||||
options.device_id = 0;
|
||||
// Default OrtCudnnConvAlgoSearchExhaustive is extremely slow
|
||||
@@ -219,22 +217,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads,
|
||||
config.provider_config.provider, &config.provider_config);
|
||||
config.provider_config.provider,
|
||||
&config.provider_config);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
|
||||
const std::string &model_type) {
|
||||
const std::string &model_type) {
|
||||
/*
|
||||
Transducer models : Only encoder will run with tensorrt,
|
||||
decoder and joiner will run with cuda
|
||||
*/
|
||||
if(config.provider_config.provider == "trt" &&
|
||||
if (config.provider_config.provider == "trt" &&
|
||||
(model_type == "decoder" || model_type == "joiner")) {
|
||||
return GetSessionOptionsImpl(config.num_threads,
|
||||
"cuda", &config.provider_config);
|
||||
return GetSessionOptionsImpl(config.num_threads, "cuda",
|
||||
&config.provider_config);
|
||||
}
|
||||
return GetSessionOptionsImpl(config.num_threads,
|
||||
config.provider_config.provider, &config.provider_config);
|
||||
config.provider_config.provider,
|
||||
&config.provider_config);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
#ifndef SHERPA_ONNX_CSRC_SESSION_H_
|
||||
#define SHERPA_ONNX_CSRC_SESSION_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
@@ -25,7 +27,7 @@ namespace sherpa_onnx {
|
||||
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
|
||||
const std::string &model_type);
|
||||
const std::string &model_type);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config);
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "Eigen/Dense"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
TEST(UTF8, Case1) {
|
||||
std::string hello = "你好, 早上好!世界. hello!。Hallo";
|
||||
std::string hello = "你好, 早上好!世界. hello!。Hallo! how are you?";
|
||||
std::vector<std::string> ss = SplitUtf8(hello);
|
||||
for (const auto &s : ss) {
|
||||
std::cout << s << "\n";
|
||||
|
||||
Reference in New Issue
Block a user