Support Ukrainian VITS models from coqui-ai/TTS (#469)
This commit is contained in:
@@ -431,15 +431,12 @@ void CNonStreamingTextToSpeechDlg::Init() {
|
||||
ok = false;
|
||||
}
|
||||
|
||||
if (!Exists("./lexicon.txt") && !Exists("./espeak-ng-data/phontab")) {
|
||||
error_message += "Cannot find espeak-ng-data directory or ./lexicon.txt\r\n";
|
||||
ok = false;
|
||||
}
|
||||
|
||||
if (!Exists("./tokens.txt")) {
|
||||
error_message += "Cannot find ./tokens.txt\r\n";
|
||||
ok = false;
|
||||
}
|
||||
// it is OK to leave lexicon.txt and espeak-ng-data empty
|
||||
// since models using characters don't need them
|
||||
|
||||
if (!ok) {
|
||||
generate_btn_.EnableWindow(FALSE);
|
||||
@@ -470,7 +467,7 @@ void CNonStreamingTextToSpeechDlg::Init() {
|
||||
config.model.vits.model = "./model.onnx";
|
||||
if (Exists("./espeak-ng-data/phontab")) {
|
||||
config.model.vits.data_dir = "./espeak-ng-data";
|
||||
} else {
|
||||
} else if (Exists("./lexicon.txt")) {
|
||||
config.model.vits.lexicon = "./lexicon.txt";
|
||||
}
|
||||
config.model.vits.tokens = "./tokens.txt";
|
||||
|
||||
@@ -41,6 +41,7 @@ set(sources
|
||||
offline-transducer-model-config.cc
|
||||
offline-transducer-model.cc
|
||||
offline-transducer-modified-beam-search-decoder.cc
|
||||
offline-tts-character-frontend.cc
|
||||
offline-wenet-ctc-model-config.cc
|
||||
offline-wenet-ctc-model.cc
|
||||
offline-whisper-greedy-search-decoder.cc
|
||||
|
||||
191
sherpa-onnx/csrc/offline-tts-character-frontend.cc
Normal file
191
sherpa-onnx/csrc/offline-tts-character-frontend.cc
Normal file
@@ -0,0 +1,191 @@
|
||||
// sherpa-onnx/csrc/offline-tts-character-frontend.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include <strstream>
|
||||
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <codecvt>
|
||||
#include <fstream>
|
||||
#include <locale>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-character-frontend.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static std::unordered_map<char32_t, int32_t> ReadTokens(std::istream &is) {
|
||||
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
|
||||
std::unordered_map<char32_t, int32_t> token2id;
|
||||
|
||||
std::string line;
|
||||
|
||||
std::string sym;
|
||||
std::u32string s;
|
||||
int32_t id;
|
||||
while (std::getline(is, line)) {
|
||||
std::istringstream iss(line);
|
||||
iss >> sym;
|
||||
if (iss.eof()) {
|
||||
id = atoi(sym.c_str());
|
||||
sym = " ";
|
||||
} else {
|
||||
iss >> id;
|
||||
}
|
||||
|
||||
// eat the trailing \r\n on windows
|
||||
iss >> std::ws;
|
||||
if (!iss.eof()) {
|
||||
SHERPA_ONNX_LOGE("Error when reading tokens: %s", line.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Form models from coqui-ai/TTS, we have saved the IDs of the following
|
||||
// symbols in OfflineTtsVitsModelMetaData, so it is safe to skip them here.
|
||||
if (sym == "<PAD>" || sym == "<EOS>" || sym == "<BOS>" || sym == "<BLNK>") {
|
||||
continue;
|
||||
}
|
||||
|
||||
s = conv.from_bytes(sym);
|
||||
if (s.size() != 1) {
|
||||
SHERPA_ONNX_LOGE("Error when reading tokens at Line %s. size: %d",
|
||||
line.c_str(), static_cast<int32_t>(s.size()));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
char32_t c = s[0];
|
||||
|
||||
if (token2id.count(c)) {
|
||||
SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d",
|
||||
sym.c_str(), line.c_str(), token2id.at(c));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
token2id.insert({c, id});
|
||||
}
|
||||
|
||||
return token2id;
|
||||
}
|
||||
|
||||
OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend(
|
||||
const std::string &tokens, const OfflineTtsVitsModelMetaData &meta_data)
|
||||
: meta_data_(meta_data) {
|
||||
std::ifstream is(tokens);
|
||||
token2id_ = ReadTokens(is);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend(
|
||||
AAssetManager *mgr, const std::string &tokens,
|
||||
const OfflineTtsVitsModelMetaData &meta_data)
|
||||
: meta_data_(meta_data) {
|
||||
auto buf = ReadFile(mgr, tokens);
|
||||
std::istrstream is(buf.data(), buf.size());
|
||||
token2id_ = ReadTokens(is);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
std::vector<std::vector<int64_t>>
|
||||
OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
|
||||
const std::string &_text, const std::string &voice /*= ""*/) const {
|
||||
// see
|
||||
// https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87
|
||||
int32_t use_eos_bos = meta_data_.use_eos_bos;
|
||||
int32_t bos_id = meta_data_.bos_id;
|
||||
int32_t eos_id = meta_data_.eos_id;
|
||||
int32_t blank_id = meta_data_.blank_id;
|
||||
int32_t add_blank = meta_data_.add_blank;
|
||||
|
||||
std::string text(_text.size(), 0);
|
||||
std::transform(_text.begin(), _text.end(), text.begin(),
|
||||
[](auto c) { return std::tolower(c); });
|
||||
|
||||
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
|
||||
std::u32string s = conv.from_bytes(text);
|
||||
|
||||
std::vector<std::vector<int64_t>> ans;
|
||||
|
||||
std::vector<int64_t> this_sentence;
|
||||
if (add_blank) {
|
||||
if (use_eos_bos) {
|
||||
this_sentence.push_back(bos_id);
|
||||
}
|
||||
|
||||
this_sentence.push_back(blank_id);
|
||||
|
||||
for (char32_t c : s) {
|
||||
if (token2id_.count(c)) {
|
||||
this_sentence.push_back(token2id_.at(c));
|
||||
this_sentence.push_back(blank_id);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Skip unknown character. Unicode codepoint: \\U+%04x.",
|
||||
static_cast<uint32_t>(c));
|
||||
}
|
||||
|
||||
if (c == '.' || c == ':' || c == '?' || c == '!') {
|
||||
// end of a sentence
|
||||
if (use_eos_bos) {
|
||||
this_sentence.push_back(eos_id);
|
||||
}
|
||||
|
||||
ans.push_back(std::move(this_sentence));
|
||||
|
||||
// re-initialize this_sentence
|
||||
if (use_eos_bos) {
|
||||
this_sentence.push_back(bos_id);
|
||||
}
|
||||
this_sentence.push_back(blank_id);
|
||||
}
|
||||
}
|
||||
|
||||
if (use_eos_bos) {
|
||||
this_sentence.push_back(eos_id);
|
||||
}
|
||||
|
||||
if (this_sentence.size() > 1 + use_eos_bos) {
|
||||
ans.push_back(std::move(this_sentence));
|
||||
}
|
||||
} else {
|
||||
// not adding blank
|
||||
if (use_eos_bos) {
|
||||
this_sentence.push_back(bos_id);
|
||||
}
|
||||
|
||||
for (char32_t c : s) {
|
||||
if (token2id_.count(c)) {
|
||||
this_sentence.push_back(token2id_.at(c));
|
||||
}
|
||||
|
||||
if (c == '.' || c == ':' || c == '?' || c == '!') {
|
||||
// end of a sentence
|
||||
if (use_eos_bos) {
|
||||
this_sentence.push_back(eos_id);
|
||||
}
|
||||
|
||||
ans.push_back(std::move(this_sentence));
|
||||
|
||||
// re-initialize this_sentence
|
||||
if (use_eos_bos) {
|
||||
this_sentence.push_back(bos_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (this_sentence.size() > 1) {
|
||||
ans.push_back(std::move(this_sentence));
|
||||
}
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
54
sherpa-onnx/csrc/offline-tts-character-frontend.h
Normal file
54
sherpa-onnx/csrc/offline-tts-character-frontend.h
Normal file
@@ -0,0 +1,54 @@
|
||||
// sherpa-onnx/csrc/offline-tts-character-frontend.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_CHARACTER_FRONTEND_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_CHARACTER_FRONTEND_H_
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineTtsCharacterFrontend : public OfflineTtsFrontend {
|
||||
public:
|
||||
OfflineTtsCharacterFrontend(const std::string &tokens,
|
||||
const OfflineTtsVitsModelMetaData &meta_data);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineTtsCharacterFrontend(AAssetManager *mgr, const std::string &tokens,
|
||||
const OfflineTtsVitsModelMetaData &meta_data);
|
||||
|
||||
#endif
|
||||
/** Convert a string to token IDs.
|
||||
*
|
||||
* @param text The input text.
|
||||
* Example 1: "This is the first sample sentence; this is the
|
||||
* second one." Example 2: "这是第一句。这是第二句。"
|
||||
* @param voice Optional. It is for espeak-ng.
|
||||
*
|
||||
* @return Return a vector-of-vector of token IDs. Each subvector contains
|
||||
* a sentence that can be processed independently.
|
||||
* If a frontend does not support splitting the text into
|
||||
* sentences, the resulting vector contains only one subvector.
|
||||
*/
|
||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string &voice = "") const override;
|
||||
|
||||
private:
|
||||
OfflineTtsVitsModelMetaData meta_data_;
|
||||
std::unordered_map<char32_t, int32_t> token2id_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_CHARACTER_FRONTEND_H_
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "kaldifst/csrc/text-normalizer.h"
|
||||
#include "sherpa-onnx/csrc/lexicon.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-character-frontend.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-vits-model.h"
|
||||
@@ -116,7 +117,9 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
return {};
|
||||
}
|
||||
|
||||
if (meta_data.add_blank && config_.model.vits.data_dir.empty()) {
|
||||
// TODO(fangjun): add blank inside the frontend, not here
|
||||
if (meta_data.add_blank && config_.model.vits.data_dir.empty() &&
|
||||
meta_data.frontend != "characters") {
|
||||
for (auto &k : x) {
|
||||
k = AddBlank(k);
|
||||
}
|
||||
@@ -195,12 +198,22 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
void InitFrontend(AAssetManager *mgr) {
|
||||
const auto &meta_data = model_->GetMetaData();
|
||||
|
||||
if ((meta_data.is_piper || meta_data.is_coqui) &&
|
||||
!config_.model.vits.data_dir.empty()) {
|
||||
if (meta_data.frontend == "characters") {
|
||||
frontend_ = std::make_unique<OfflineTtsCharacterFrontend>(
|
||||
mgr, config_.model.vits.tokens, meta_data);
|
||||
} else if ((meta_data.is_piper || meta_data.is_coqui) &&
|
||||
!config_.model.vits.data_dir.empty()) {
|
||||
frontend_ = std::make_unique<PiperPhonemizeLexicon>(
|
||||
mgr, config_.model.vits.tokens, config_.model.vits.data_dir,
|
||||
meta_data);
|
||||
} else {
|
||||
if (config_.model.vits.lexicon.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Not a model using characters as modeling unit. Please provide "
|
||||
"--vits-lexicon if you leave --vits-data-dir empty");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
frontend_ = std::make_unique<Lexicon>(
|
||||
mgr, config_.model.vits.lexicon, config_.model.vits.tokens,
|
||||
meta_data.punctuations, meta_data.language, config_.model.debug);
|
||||
@@ -211,12 +224,21 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
void InitFrontend() {
|
||||
const auto &meta_data = model_->GetMetaData();
|
||||
|
||||
if ((meta_data.is_piper || meta_data.is_coqui) &&
|
||||
!config_.model.vits.data_dir.empty()) {
|
||||
if (meta_data.frontend == "characters") {
|
||||
frontend_ = std::make_unique<OfflineTtsCharacterFrontend>(
|
||||
config_.model.vits.tokens, meta_data);
|
||||
} else if ((meta_data.is_piper || meta_data.is_coqui) &&
|
||||
!config_.model.vits.data_dir.empty()) {
|
||||
frontend_ = std::make_unique<PiperPhonemizeLexicon>(
|
||||
config_.model.vits.tokens, config_.model.vits.data_dir,
|
||||
model_->GetMetaData());
|
||||
} else {
|
||||
if (config_.model.vits.lexicon.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Not a model using characters as modeling unit. Please provide "
|
||||
"--vits-lexicon if you leave --vits-data-dir empty");
|
||||
exit(-1);
|
||||
}
|
||||
frontend_ = std::make_unique<Lexicon>(
|
||||
config_.model.vits.lexicon, config_.model.vits.tokens,
|
||||
meta_data.punctuations, meta_data.language, config_.model.debug);
|
||||
|
||||
@@ -44,19 +44,7 @@ bool OfflineTtsVitsModelConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (data_dir.empty()) {
|
||||
if (lexicon.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Please provide --vits-lexicon if you leave --vits-data-dir empty");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(lexicon)) {
|
||||
SHERPA_ONNX_LOGE("--vits-lexicon: %s does not exist", lexicon.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
} else {
|
||||
if (!data_dir.empty()) {
|
||||
if (!FileExists(data_dir + "/phontab")) {
|
||||
SHERPA_ONNX_LOGE("%s/phontab does not exist. Skipping test",
|
||||
data_dir.c_str());
|
||||
|
||||
@@ -10,15 +10,14 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// If you are not sure what each field means, please
|
||||
// have a look of the Python file in the model directory that
|
||||
// you have downloaded.
|
||||
struct OfflineTtsVitsModelMetaData {
|
||||
int32_t sample_rate;
|
||||
int32_t sample_rate = 0;
|
||||
int32_t add_blank = 0;
|
||||
int32_t num_speakers = 0;
|
||||
|
||||
std::string punctuations;
|
||||
std::string language;
|
||||
std::string voice;
|
||||
|
||||
bool is_piper = false;
|
||||
bool is_coqui = false;
|
||||
|
||||
@@ -27,6 +26,12 @@ struct OfflineTtsVitsModelMetaData {
|
||||
int32_t bos_id = 0;
|
||||
int32_t eos_id = 0;
|
||||
int32_t use_eos_bos = 0;
|
||||
int32_t pad_id = 0;
|
||||
|
||||
std::string punctuations;
|
||||
std::string language;
|
||||
std::string voice;
|
||||
std::string frontend; // characters
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -87,13 +87,18 @@ class OfflineTtsVitsModel::Impl {
|
||||
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.punctuations,
|
||||
"punctuation", "");
|
||||
SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");
|
||||
|
||||
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.voice, "voice", "");
|
||||
|
||||
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.frontend, "frontend",
|
||||
"");
|
||||
|
||||
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.blank_id, "blank_id", 0);
|
||||
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.bos_id, "bos_id", 0);
|
||||
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.eos_id, "eos_id", 0);
|
||||
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.use_eos_bos,
|
||||
"use_eos_bos", 0);
|
||||
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.pad_id, "pad_id", 0);
|
||||
|
||||
std::string comment;
|
||||
SHERPA_ONNX_READ_META_DATA_STR(comment, "comment");
|
||||
@@ -142,16 +147,25 @@ class OfflineTtsVitsModel::Impl {
|
||||
Ort::Value sid_tensor =
|
||||
Ort::Value::CreateTensor(memory_info, &sid, 1, &sid_shape, 1);
|
||||
|
||||
int64_t lang_id_shape = 1;
|
||||
int64_t lang_id = 0;
|
||||
Ort::Value lang_id_tensor =
|
||||
Ort::Value::CreateTensor(memory_info, &lang_id, 1, &lang_id_shape, 1);
|
||||
|
||||
std::vector<Ort::Value> inputs;
|
||||
inputs.reserve(4);
|
||||
inputs.reserve(5);
|
||||
inputs.push_back(std::move(x));
|
||||
inputs.push_back(std::move(x_length));
|
||||
inputs.push_back(std::move(scales_tensor));
|
||||
|
||||
if (input_names_.size() == 4 && input_names_.back() == "sid") {
|
||||
if (input_names_.size() >= 4 && input_names_[3] == "sid") {
|
||||
inputs.push_back(std::move(sid_tensor));
|
||||
}
|
||||
|
||||
if (input_names_.size() >= 5 && input_names_[4] == "langid") {
|
||||
inputs.push_back(std::move(lang_id_tensor));
|
||||
}
|
||||
|
||||
auto out =
|
||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
|
||||
@@ -123,7 +123,6 @@ static std::vector<int64_t> CoquiPhonemesToIds(
|
||||
int32_t blank_id = meta_data.blank_id;
|
||||
int32_t add_blank = meta_data.add_blank;
|
||||
int32_t comma_id = token2id.at(',');
|
||||
SHERPA_ONNX_LOGE("comma id: %d", comma_id);
|
||||
|
||||
std::vector<int64_t> ans;
|
||||
if (add_blank) {
|
||||
|
||||
Reference in New Issue
Block a user