Support Chinese vits models (#368)
This commit is contained in:
44
.github/scripts/test-python.sh
vendored
44
.github/scripts/test-python.sh
vendored
@@ -9,6 +9,10 @@ log() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log "Offline TTS test"
|
log "Offline TTS test"
|
||||||
|
# test waves are saved in ./tts
|
||||||
|
mkdir ./tts
|
||||||
|
|
||||||
|
log "vits-ljs test"
|
||||||
|
|
||||||
wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx
|
wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx
|
||||||
wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt
|
wget -qq https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt
|
||||||
@@ -18,14 +22,48 @@ python3 ./python-api-examples/offline-tts.py \
|
|||||||
--vits-model=./vits-ljs.onnx \
|
--vits-model=./vits-ljs.onnx \
|
||||||
--vits-lexicon=./lexicon.txt \
|
--vits-lexicon=./lexicon.txt \
|
||||||
--vits-tokens=./tokens.txt \
|
--vits-tokens=./tokens.txt \
|
||||||
--output-filename=./tts.wav \
|
--output-filename=./tts/vits-ljs.wav \
|
||||||
'liliana, the most beautiful and lovely assistant of our team!'
|
'liliana, the most beautiful and lovely assistant of our team!'
|
||||||
|
|
||||||
ls -lh ./tts.wav
|
ls -lh ./tts
|
||||||
file ./tts.wav
|
|
||||||
|
|
||||||
rm -v vits-ljs.onnx ./lexicon.txt ./tokens.txt
|
rm -v vits-ljs.onnx ./lexicon.txt ./tokens.txt
|
||||||
|
|
||||||
|
log "vits-vctk test"
|
||||||
|
wget -qq https://huggingface.co/csukuangfj/vits-vctk/resolve/main/vits-vctk.onnx
|
||||||
|
wget -qq https://huggingface.co/csukuangfj/vits-vctk/resolve/main/lexicon.txt
|
||||||
|
wget -qq https://huggingface.co/csukuangfj/vits-vctk/resolve/main/tokens.txt
|
||||||
|
|
||||||
|
for sid in 0 10 90; do
|
||||||
|
python3 ./python-api-examples/offline-tts.py \
|
||||||
|
--vits-model=./vits-vctk.onnx \
|
||||||
|
--vits-lexicon=./lexicon.txt \
|
||||||
|
--vits-tokens=./tokens.txt \
|
||||||
|
--sid=$sid \
|
||||||
|
--output-filename=./tts/vits-vctk-${sid}.wav \
|
||||||
|
'liliana, the most beautiful and lovely assistant of our team!'
|
||||||
|
done
|
||||||
|
|
||||||
|
rm -v vits-vctk.onnx ./lexicon.txt ./tokens.txt
|
||||||
|
|
||||||
|
log "vits-zh-aishell3"
|
||||||
|
|
||||||
|
wget -qq https://huggingface.co/csukuangfj/vits-zh-aishell3/resolve/main/vits-aishell3.onnx
|
||||||
|
wget -qq https://huggingface.co/csukuangfj/vits-zh-aishell3/resolve/main/lexicon.txt
|
||||||
|
wget -qq https://huggingface.co/csukuangfj/vits-zh-aishell3/resolve/main/tokens.txt
|
||||||
|
|
||||||
|
for sid in 0 10 90; do
|
||||||
|
python3 ./python-api-examples/offline-tts.py \
|
||||||
|
--vits-model=./vits-aishell3.onnx \
|
||||||
|
--vits-lexicon=./lexicon.txt \
|
||||||
|
--vits-tokens=./tokens.txt \
|
||||||
|
--sid=$sid \
|
||||||
|
--output-filename=./tts/vits-aishell3-${sid}.wav \
|
||||||
|
'林美丽最美丽'
|
||||||
|
done
|
||||||
|
|
||||||
|
rm -v vits-aishell3.onnx ./lexicon.txt ./tokens.txt
|
||||||
|
|
||||||
mkdir -p /tmp/icefall-models
|
mkdir -p /tmp/icefall-models
|
||||||
dir=/tmp/icefall-models
|
dir=/tmp/icefall-models
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/run-python-test.yaml
vendored
2
.github/workflows/run-python-test.yaml
vendored
@@ -69,4 +69,4 @@ jobs:
|
|||||||
- uses: actions/upload-artifact@v3
|
- uses: actions/upload-artifact@v3
|
||||||
with:
|
with:
|
||||||
name: tts-generated-test-files
|
name: tts-generated-test-files
|
||||||
path: tts.wav
|
path: tts
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
||||||
project(sherpa-onnx)
|
project(sherpa-onnx)
|
||||||
|
|
||||||
set(SHERPA_ONNX_VERSION "1.8.1")
|
set(SHERPA_ONNX_VERSION "1.8.2")
|
||||||
|
|
||||||
# Disable warning about
|
# Disable warning about
|
||||||
#
|
#
|
||||||
@@ -175,6 +175,8 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
|
|||||||
include(asio)
|
include(asio)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
include(utfcpp)
|
||||||
|
|
||||||
add_subdirectory(sherpa-onnx)
|
add_subdirectory(sherpa-onnx)
|
||||||
|
|
||||||
if(SHERPA_ONNX_ENABLE_C_API)
|
if(SHERPA_ONNX_ENABLE_C_API)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ function(download_kaldi_decoder)
|
|||||||
set(kaldi_decoder_HASH "SHA256=98bf445a5b7961ccf3c3522317d900054eaadb6a9cdcf4531e7d9caece94a56d")
|
set(kaldi_decoder_HASH "SHA256=98bf445a5b7961ccf3c3522317d900054eaadb6a9cdcf4531e7d9caece94a56d")
|
||||||
|
|
||||||
set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
||||||
set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
set(KALDI_DECODER_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
|
||||||
set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
# If you don't have access to the Internet,
|
# If you don't have access to the Internet,
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
function(download_kaldi_native_fbank)
|
function(download_kaldi_native_fbank)
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.1.tar.gz")
|
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.5.tar.gz")
|
||||||
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.1.tar.gz")
|
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.5.tar.gz")
|
||||||
set(kaldi_native_fbank_HASH "SHA256=c7676f319fa97e8c8bca6018792de120895dcfe122fa9b4bff00f8f9165348e7")
|
set(kaldi_native_fbank_HASH "SHA256=dce0cb3bc6fece5d8053d8780cb4ce22da57cb57ebec332641661521a0425283")
|
||||||
|
|
||||||
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||||
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
||||||
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
|
|||||||
# If you don't have access to the Internet,
|
# If you don't have access to the Internet,
|
||||||
# please pre-download kaldi-native-fbank
|
# please pre-download kaldi-native-fbank
|
||||||
set(possible_file_locations
|
set(possible_file_locations
|
||||||
$ENV{HOME}/Downloads/kaldi-native-fbank-1.18.1.tar.gz
|
$ENV{HOME}/Downloads/kaldi-native-fbank-1.18.5.tar.gz
|
||||||
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.1.tar.gz
|
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.5.tar.gz
|
||||||
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.1.tar.gz
|
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.5.tar.gz
|
||||||
/tmp/kaldi-native-fbank-1.18.1.tar.gz
|
/tmp/kaldi-native-fbank-1.18.5.tar.gz
|
||||||
/star-fj/fangjun/download/github/kaldi-native-fbank-1.18.1.tar.gz
|
/star-fj/fangjun/download/github/kaldi-native-fbank-1.18.5.tar.gz
|
||||||
)
|
)
|
||||||
|
|
||||||
foreach(f IN LISTS possible_file_locations)
|
foreach(f IN LISTS possible_file_locations)
|
||||||
|
|||||||
45
cmake/utfcpp.cmake
Normal file
45
cmake/utfcpp.cmake
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# Fetch the utfcpp sources and put their top-level directory on the include
# path. Only the headers are used: the add_subdirectory() call below is
# intentionally left commented out.
#
# A tarball that already exists at one of the local paths listed in
# possible_file_locations is used instead of the remote URLs.
function(download_utfcpp)
  include(FetchContent)

  set(utfcpp_URL  "https://github.com/nemtrif/utfcpp/archive/refs/tags/v3.2.5.tar.gz")
  set(utfcpp_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/utfcpp-3.2.5.tar.gz")
  set(utfcpp_HASH "SHA256=14fd1b3c466814cb4c40771b7f207b61d2c7a0aa6a5e620ca05c00df27f25afd")

  # If you don't have access to the Internet,
  # please pre-download utfcpp
  set(possible_file_locations
    $ENV{HOME}/Downloads/utfcpp-3.2.5.tar.gz
    ${PROJECT_SOURCE_DIR}/utfcpp-3.2.5.tar.gz
    ${PROJECT_BINARY_DIR}/utfcpp-3.2.5.tar.gz
    /tmp/utfcpp-3.2.5.tar.gz
    /star-fj/fangjun/download/github/utfcpp-3.2.5.tar.gz
  )

  foreach(f IN LISTS possible_file_locations)
    # Quote the path so it is always passed to EXISTS as exactly one argument.
    if(EXISTS "${f}")
      set(utfcpp_URL  "${f}")
      file(TO_CMAKE_PATH "${utfcpp_URL}" utfcpp_URL)
      message(STATUS "Found local downloaded utfcpp: ${utfcpp_URL}")
      # A local copy was found, so drop the fallback mirror.
      set(utfcpp_URL2)
      break()
    endif()
  endforeach()

  # The URL variables are deliberately expanded unquoted: when utfcpp_URL2 has
  # been cleared above it must disappear from the argument list instead of
  # becoming an empty URL.
  FetchContent_Declare(utfcpp
    URL
      ${utfcpp_URL}
      ${utfcpp_URL2}
    URL_HASH          ${utfcpp_HASH}
  )

  FetchContent_GetProperties(utfcpp)
  if(NOT utfcpp_POPULATED)
    message(STATUS "Downloading utfcpp from ${utfcpp_URL}")
    FetchContent_Populate(utfcpp)
  endif()
  message(STATUS "utfcpp is downloaded to ${utfcpp_SOURCE_DIR}")
  # add_subdirectory(${utfcpp_SOURCE_DIR} ${utfcpp_BINARY_DIR} EXCLUDE_FROM_ALL)
  include_directories(${utfcpp_SOURCE_DIR})
endfunction()
|
||||||
|
|
||||||
|
download_utfcpp()
|
||||||
@@ -20,9 +20,14 @@ python3 ./python-api-examples/offline-tts.py \
|
|||||||
--vits-tokens=./tokens.txt \
|
--vits-tokens=./tokens.txt \
|
||||||
--output-filename=./generated.wav \
|
--output-filename=./generated.wav \
|
||||||
'liliana, the most beautiful and lovely assistant of our team!'
|
'liliana, the most beautiful and lovely assistant of our team!'
|
||||||
|
|
||||||
|
Please see
|
||||||
|
https://k2-fsa.github.io/sherpa/onnx/tts/index.html
|
||||||
|
for details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import time
|
||||||
|
|
||||||
import sherpa_onnx
|
import sherpa_onnx
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
@@ -115,7 +120,14 @@ def main():
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
tts = sherpa_onnx.OfflineTts(tts_config)
|
tts = sherpa_onnx.OfflineTts(tts_config)
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
audio = tts.generate(args.text, sid=args.sid)
|
audio = tts.generate(args.text, sid=args.sid)
|
||||||
|
end = time.time()
|
||||||
|
elapsed_seconds = end - start
|
||||||
|
audio_duration = len(audio.samples) / audio.sample_rate
|
||||||
|
real_time_factor = elapsed_seconds / audio_duration
|
||||||
|
|
||||||
sf.write(
|
sf.write(
|
||||||
args.output_filename,
|
args.output_filename,
|
||||||
audio.samples,
|
audio.samples,
|
||||||
@@ -124,6 +136,9 @@ def main():
|
|||||||
)
|
)
|
||||||
print(f"Saved to {args.output_filename}")
|
print(f"Saved to {args.output_filename}")
|
||||||
print(f"The text is '{args.text}'")
|
print(f"The text is '{args.text}'")
|
||||||
|
print(f"Elapsed seconds: {elapsed_seconds:.3f}")
|
||||||
|
print(f"Audio duration in seconds: {audio_duration:.3f}")
|
||||||
|
print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -331,6 +331,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
|
|||||||
stack-test.cc
|
stack-test.cc
|
||||||
transpose-test.cc
|
transpose-test.cc
|
||||||
unbind-test.cc
|
unbind-test.cc
|
||||||
|
utfcpp-test.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
function(sherpa_onnx_add_test source)
|
function(sherpa_onnx_add_test source)
|
||||||
|
|||||||
@@ -76,9 +76,105 @@ static std::vector<int32_t> ConvertTokensToIds(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
|
Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
|
||||||
const std::string &punctuations) {
|
const std::string &punctuations, const std::string &language) {
|
||||||
|
InitLanguage(language);
|
||||||
|
InitTokens(tokens);
|
||||||
|
InitLexicon(lexicon);
|
||||||
|
InitPunctuations(punctuations);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
|
||||||
|
const std::string &text) const {
|
||||||
|
switch (language_) {
|
||||||
|
case Language::kEnglish:
|
||||||
|
return ConvertTextToTokenIdsEnglish(text);
|
||||||
|
case Language::kChinese:
|
||||||
|
return ConvertTextToTokenIdsChinese(text);
|
||||||
|
default:
|
||||||
|
SHERPA_ONNX_LOGE("Unknonw language: %d", static_cast<int32_t>(language_));
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
|
||||||
|
const std::string &text) const {
|
||||||
|
std::vector<std::string> words = SplitUtf8(text);
|
||||||
|
|
||||||
|
std::vector<int64_t> ans;
|
||||||
|
|
||||||
|
ans.push_back(token2id_.at("sil"));
|
||||||
|
|
||||||
|
for (const auto &w : words) {
|
||||||
|
if (!word2ids_.count(w)) {
|
||||||
|
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &token_ids = word2ids_.at(w);
|
||||||
|
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
|
||||||
|
}
|
||||||
|
ans.push_back(token2id_.at("sil"));
|
||||||
|
ans.push_back(token2id_.at("eos"));
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
|
||||||
|
const std::string &_text) const {
|
||||||
|
std::string text(_text);
|
||||||
|
ToLowerCase(&text);
|
||||||
|
|
||||||
|
std::vector<std::string> words = SplitUtf8(text);
|
||||||
|
|
||||||
|
std::vector<int64_t> ans;
|
||||||
|
for (const auto &w : words) {
|
||||||
|
if (punctuations_.count(w)) {
|
||||||
|
ans.push_back(token2id_.at(w));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!word2ids_.count(w)) {
|
||||||
|
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &token_ids = word2ids_.at(w);
|
||||||
|
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
|
||||||
|
if (blank_ != -1) {
|
||||||
|
ans.push_back(blank_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (blank_ != -1 && !ans.empty()) {
|
||||||
|
// remove the last blank
|
||||||
|
ans.resize(ans.size() - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Lexicon::InitTokens(const std::string &tokens) {
|
||||||
token2id_ = ReadTokens(tokens);
|
token2id_ = ReadTokens(tokens);
|
||||||
blank_ = token2id_.at(" ");
|
if (token2id_.count(" ")) {
|
||||||
|
blank_ = token2id_.at(" ");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set language_ from a language name, compared case-insensitively.
//
// @param _lang "english" or "chinese" in any letter case. Any other value is
//              logged as an error and terminates the process with status -1.
void Lexicon::InitLanguage(const std::string &_lang) {
  std::string normalized = _lang;
  ToLowerCase(&normalized);

  if (normalized == "english") {
    language_ = Language::kEnglish;
    return;
  }

  if (normalized == "chinese") {
    language_ = Language::kChinese;
    return;
  }

  // Report the name exactly as the caller supplied it.
  SHERPA_ONNX_LOGE("Unknown language: %s", _lang.c_str());
  exit(-1);
}
|
||||||
|
|
||||||
|
void Lexicon::InitLexicon(const std::string &lexicon) {
|
||||||
std::ifstream is(lexicon);
|
std::ifstream is(lexicon);
|
||||||
|
|
||||||
std::string word;
|
std::string word;
|
||||||
@@ -109,8 +205,9 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
|
|||||||
}
|
}
|
||||||
word2ids_.insert({std::move(word), std::move(ids)});
|
word2ids_.insert({std::move(word), std::move(ids)});
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// process punctuations
|
void Lexicon::InitPunctuations(const std::string &punctuations) {
|
||||||
std::vector<std::string> punctuation_list;
|
std::vector<std::string> punctuation_list;
|
||||||
SplitStringToVector(punctuations, " ", false, &punctuation_list);
|
SplitStringToVector(punctuations, " ", false, &punctuation_list);
|
||||||
for (auto &s : punctuation_list) {
|
for (auto &s : punctuation_list) {
|
||||||
@@ -118,46 +215,4 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
|
|
||||||
const std::string &_text) const {
|
|
||||||
std::string text(_text);
|
|
||||||
ToLowerCase(&text);
|
|
||||||
|
|
||||||
std::vector<std::string> words;
|
|
||||||
SplitStringToVector(text, " ", false, &words);
|
|
||||||
|
|
||||||
std::vector<int64_t> ans;
|
|
||||||
for (auto w : words) {
|
|
||||||
std::vector<int64_t> prefix;
|
|
||||||
while (!w.empty() && punctuations_.count(std::string(1, w[0]))) {
|
|
||||||
// if w begins with a punctuation
|
|
||||||
prefix.push_back(token2id_.at(std::string(1, w[0])));
|
|
||||||
w = std::string(w.begin() + 1, w.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int64_t> suffix;
|
|
||||||
while (!w.empty() && punctuations_.count(std::string(1, w.back()))) {
|
|
||||||
suffix.push_back(token2id_.at(std::string(1, w.back())));
|
|
||||||
w = std::string(w.begin(), w.end() - 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!word2ids_.count(w)) {
|
|
||||||
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto &token_ids = word2ids_.at(w);
|
|
||||||
ans.insert(ans.end(), prefix.begin(), prefix.end());
|
|
||||||
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
|
|
||||||
ans.insert(ans.end(), suffix.rbegin(), suffix.rend());
|
|
||||||
ans.push_back(blank_);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!ans.empty()) {
|
|
||||||
ans.resize(ans.size() - 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
return ans;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -13,18 +13,40 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// TODO(fangjun): Refactor it to an abstract class
|
||||||
class Lexicon {
|
class Lexicon {
|
||||||
public:
|
public:
|
||||||
Lexicon(const std::string &lexicon, const std::string &tokens,
|
Lexicon(const std::string &lexicon, const std::string &tokens,
|
||||||
const std::string &punctuations);
|
const std::string &punctuations, const std::string &language);
|
||||||
|
|
||||||
std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const;
|
std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<int64_t> ConvertTextToTokenIdsEnglish(
|
||||||
|
const std::string &text) const;
|
||||||
|
|
||||||
|
std::vector<int64_t> ConvertTextToTokenIdsChinese(
|
||||||
|
const std::string &text) const;
|
||||||
|
|
||||||
|
void InitLanguage(const std::string &lang);
|
||||||
|
void InitTokens(const std::string &tokens);
|
||||||
|
void InitLexicon(const std::string &lexicon);
|
||||||
|
void InitPunctuations(const std::string &punctuations);
|
||||||
|
|
||||||
|
private:
|
||||||
|
enum class Language {
|
||||||
|
kEnglish,
|
||||||
|
kChinese,
|
||||||
|
kUnknown,
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
|
std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
|
||||||
std::unordered_set<std::string> punctuations_;
|
std::unordered_set<std::string> punctuations_;
|
||||||
std::unordered_map<std::string, int32_t> token2id_;
|
std::unordered_map<std::string, int32_t> token2id_;
|
||||||
int32_t blank_; // ID for the blank token
|
int32_t blank_ = -1; // ID for the blank token
|
||||||
|
Language language_;
|
||||||
|
//
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
|||||||
explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config)
|
explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config)
|
||||||
: model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
|
: model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
|
||||||
lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
|
lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
|
||||||
model_->Punctuations()) {}
|
model_->Punctuations(), model_->Language()) {}
|
||||||
|
|
||||||
GeneratedAudio Generate(const std::string &text,
|
GeneratedAudio Generate(const std::string &text,
|
||||||
int64_t sid = 0) const override {
|
int64_t sid = 0) const override {
|
||||||
|
|||||||
@@ -84,6 +84,7 @@ class OfflineTtsVitsModel::Impl {
|
|||||||
bool AddBlank() const { return add_blank_; }
|
bool AddBlank() const { return add_blank_; }
|
||||||
|
|
||||||
std::string Punctuations() const { return punctuations_; }
|
std::string Punctuations() const { return punctuations_; }
|
||||||
|
std::string Language() const { return language_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Init(void *model_data, size_t model_data_length) {
|
void Init(void *model_data, size_t model_data_length) {
|
||||||
@@ -108,6 +109,7 @@ class OfflineTtsVitsModel::Impl {
|
|||||||
SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
|
SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
|
||||||
SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers");
|
SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers");
|
||||||
SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
|
SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
|
||||||
|
SHERPA_ONNX_READ_META_DATA_STR(language_, "language");
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@@ -128,6 +130,7 @@ class OfflineTtsVitsModel::Impl {
|
|||||||
int32_t add_blank_;
|
int32_t add_blank_;
|
||||||
int32_t n_speakers_;
|
int32_t n_speakers_;
|
||||||
std::string punctuations_;
|
std::string punctuations_;
|
||||||
|
std::string language_;
|
||||||
};
|
};
|
||||||
|
|
||||||
OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
|
OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
|
||||||
@@ -147,4 +150,6 @@ std::string OfflineTtsVitsModel::Punctuations() const {
|
|||||||
return impl_->Punctuations();
|
return impl_->Punctuations();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); }
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ class OfflineTtsVitsModel {
|
|||||||
bool AddBlank() const;
|
bool AddBlank() const;
|
||||||
|
|
||||||
std::string Punctuations() const;
|
std::string Punctuations() const;
|
||||||
|
std::string Language() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class Impl;
|
class Impl;
|
||||||
|
|||||||
@@ -8,12 +8,16 @@
|
|||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cctype>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "source/utf8.h"
|
||||||
|
|
||||||
// This file is copied/modified from
|
// This file is copied/modified from
|
||||||
// https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.cc
|
// https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.cc
|
||||||
|
|
||||||
@@ -158,4 +162,57 @@ template bool SplitStringToFloats(const std::string &full, const char *delim,
|
|||||||
bool omit_empty_strings,
|
bool omit_empty_strings,
|
||||||
std::vector<double> *out);
|
std::vector<double> *out);
|
||||||
|
|
||||||
|
std::vector<std::string> SplitUtf8(const std::string &text) {
|
||||||
|
char *begin = const_cast<char *>(text.c_str());
|
||||||
|
char *end = begin + text.size();
|
||||||
|
|
||||||
|
std::vector<std::string> ans;
|
||||||
|
std::string buf;
|
||||||
|
|
||||||
|
while (begin < end) {
|
||||||
|
uint32_t code = utf8::next(begin, end);
|
||||||
|
|
||||||
|
// 1. is punctuation
|
||||||
|
if (std::ispunct(code)) {
|
||||||
|
if (!buf.empty()) {
|
||||||
|
ans.push_back(std::move(buf));
|
||||||
|
}
|
||||||
|
|
||||||
|
char s[5] = {0};
|
||||||
|
utf8::append(code, s);
|
||||||
|
ans.push_back(s);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. is space
|
||||||
|
if (std::isspace(code)) {
|
||||||
|
if (!buf.empty()) {
|
||||||
|
ans.push_back(std::move(buf));
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. is alpha
|
||||||
|
if (std::isalpha(code)) {
|
||||||
|
buf.push_back(code);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!buf.empty()) {
|
||||||
|
ans.push_back(std::move(buf));
|
||||||
|
}
|
||||||
|
|
||||||
|
// for others
|
||||||
|
|
||||||
|
char s[5] = {0};
|
||||||
|
utf8::append(code, s);
|
||||||
|
ans.push_back(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!buf.empty()) {
|
||||||
|
ans.push_back(std::move(buf));
|
||||||
|
}
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -119,6 +119,8 @@ bool SplitStringToFloats(const std::string &full, const char *delim,
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
bool ConvertStringToReal(const std::string &str, T *out);
|
bool ConvertStringToReal(const std::string &str, T *out);
|
||||||
|
|
||||||
|
std::vector<std::string> SplitUtf8(const std::string &text);
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
#endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_
|
#endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_
|
||||||
|
|||||||
21
sherpa-onnx/csrc/utfcpp-test.cc
Normal file
21
sherpa-onnx/csrc/utfcpp-test.cc
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
// sherpa-onnx/csrc/utfcpp-test.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include <cctype>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Smoke test for SplitUtf8(): split a string mixing Chinese characters,
// English words and punctuation, and print each resulting piece on its own
// line. Nothing is asserted.
TEST(UTF8, Case1) {
  const std::string hello = "你好, 早上好!世界. hello!。Hallo";
  for (const auto &piece : SplitUtf8(hello)) {
    std::cout << piece << "\n";
  }
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
Reference in New Issue
Block a user