Support Chinese vits models (#368)

This commit is contained in:
Fangjun Kuang
2023-10-18 10:19:10 +08:00
committed by GitHub
parent 9efe69720d
commit 1ee79e3ff5
16 changed files with 326 additions and 62 deletions

View File

@@ -331,6 +331,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
stack-test.cc
transpose-test.cc
unbind-test.cc
utfcpp-test.cc
)
function(sherpa_onnx_add_test source)

View File

@@ -76,9 +76,105 @@ static std::vector<int32_t> ConvertTokensToIds(
}
// Builds a Lexicon for the given language ("English" or "Chinese",
// case-insensitive).
//
// @param lexicon       Path to the lexicon file (word -> token sequence).
// @param tokens        Path to the token table (token -> integer ID).
// @param punctuations  Space-separated list of punctuation symbols.
// @param language      Language name selecting the text frontend.
//
// Note: language must be initialized first, and tokens before the lexicon,
// since InitLexicon() looks tokens up in token2id_.
Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
                 const std::string &punctuations, const std::string &language) {
  InitLanguage(language);
  InitTokens(tokens);
  InitLexicon(lexicon);
  InitPunctuations(punctuations);
}
// Converts text to a sequence of token IDs, dispatching on the language
// selected in InitLanguage().
std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
    const std::string &text) const {
  switch (language_) {
    case Language::kEnglish:
      return ConvertTextToTokenIdsEnglish(text);
    case Language::kChinese:
      return ConvertTextToTokenIdsChinese(text);
    default:
      // InitLanguage() exits on unrecognized names, so this is reachable only
      // if language_ was somehow left as kUnknown.
      // (fixed typo: "Unknonw" -> "Unknown")
      SHERPA_ONNX_LOGE("Unknown language: %d", static_cast<int32_t>(language_));
      exit(-1);
  }

  return {};
}
// Chinese frontend: split the text into per-character / per-word pieces and
// map each piece through the lexicon. The result is wrapped in
// sil ... sil eos, matching the training recipe of the Chinese VITS models.
//
// Out-of-vocabulary pieces are logged and skipped.
// Note: token2id_.at("sil") / at("eos") throw if the token table lacks these
// entries — assumed present for Chinese models (TODO confirm).
std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
    const std::string &text) const {
  std::vector<std::string> words = SplitUtf8(text);

  std::vector<int64_t> ans;
  ans.push_back(token2id_.at("sil"));  // leading silence

  for (const auto &w : words) {
    // single lookup instead of count() followed by at()
    auto it = word2ids_.find(w);
    if (it == word2ids_.end()) {
      SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
      continue;
    }

    const auto &token_ids = it->second;
    ans.insert(ans.end(), token_ids.begin(), token_ids.end());
  }

  ans.push_back(token2id_.at("sil"));  // trailing silence
  ans.push_back(token2id_.at("eos"));  // end of sentence
  return ans;
}
// English frontend: lowercase the text, split it into words/punctuation with
// SplitUtf8(), then map each piece through the lexicon. A blank token is
// inserted after every word (when the token table defines one); known
// punctuation symbols map directly to their own IDs.
//
// Out-of-vocabulary words are logged and skipped.
std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
    const std::string &_text) const {
  std::string text(_text);
  ToLowerCase(&text);

  std::vector<std::string> words = SplitUtf8(text);

  std::vector<int64_t> ans;
  for (const auto &w : words) {
    if (punctuations_.count(w)) {
      ans.push_back(token2id_.at(w));
      continue;
    }

    // single lookup instead of count() followed by at()
    auto it = word2ids_.find(w);
    if (it == word2ids_.end()) {
      SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
      continue;
    }

    const auto &token_ids = it->second;
    ans.insert(ans.end(), token_ids.begin(), token_ids.end());
    if (blank_ != -1) {
      ans.push_back(blank_);
    }
  }

  // Remove the trailing blank — but only if the last ID actually is a blank.
  // Previously ans.resize(ans.size() - 1) ran unconditionally, which dropped
  // a final punctuation ID (or the last OOV-adjacent token) when the text
  // did not end with a word, e.g. "hello!".
  if (blank_ != -1 && !ans.empty() && ans.back() == blank_) {
    ans.pop_back();
  }

  return ans;
}
// Loads the token table and caches the ID of the blank token (a single
// space) when present.
//
// Fix: the previous text called token2id_.at(" ") unconditionally *before*
// the count(" ") guard, so it threw std::out_of_range for token tables
// without a space token (e.g. Chinese models) and the guard never ran.
// blank_ keeps its default of -1 when no blank token exists.
void Lexicon::InitTokens(const std::string &tokens) {
  token2id_ = ReadTokens(tokens);

  if (token2id_.count(" ")) {
    blank_ = token2id_.at(" ");
  }
}
// Sets language_ from a case-insensitive language name.
// Terminates the process when the name is not recognized.
void Lexicon::InitLanguage(const std::string &_lang) {
  std::string normalized(_lang);
  ToLowerCase(&normalized);

  if (normalized == "english") {
    language_ = Language::kEnglish;
    return;
  }

  if (normalized == "chinese") {
    language_ = Language::kChinese;
    return;
  }

  SHERPA_ONNX_LOGE("Unknown language: %s", _lang.c_str());
  exit(-1);
}
void Lexicon::InitLexicon(const std::string &lexicon) {
std::ifstream is(lexicon);
std::string word;
@@ -109,8 +205,9 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
}
word2ids_.insert({std::move(word), std::move(ids)});
}
}
// process punctuations
void Lexicon::InitPunctuations(const std::string &punctuations) {
std::vector<std::string> punctuation_list;
SplitStringToVector(punctuations, " ", false, &punctuation_list);
for (auto &s : punctuation_list) {
@@ -118,46 +215,4 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
}
}
// NOTE(review): this is the pre-refactor, English-only implementation that
// this diff shows as removed; it was superseded by the
// ConvertTextToTokenIdsEnglish / ConvertTextToTokenIdsChinese split.
// Code kept byte-for-byte; comments only.
//
// Converts space-separated text into token IDs:
//   - lowercases the input,
//   - peels leading/trailing punctuation off each word and emits the
//     punctuation's own token IDs before/after the word's IDs,
//   - skips out-of-vocabulary words with a log message,
//   - appends blank_ after every word and drops the final trailing blank.
std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
    const std::string &_text) const {
  std::string text(_text);
  ToLowerCase(&text);

  std::vector<std::string> words;
  SplitStringToVector(text, " ", false, &words);

  std::vector<int64_t> ans;

  for (auto w : words) {
    std::vector<int64_t> prefix;
    while (!w.empty() && punctuations_.count(std::string(1, w[0]))) {
      // if w begins with a punctuation
      // token2id_.at() throws if the punctuation symbol is missing from the
      // token table — assumed present for known punctuations (TODO confirm).
      prefix.push_back(token2id_.at(std::string(1, w[0])));
      w = std::string(w.begin() + 1, w.end());
    }

    std::vector<int64_t> suffix;
    // Trailing punctuation is collected back-to-front, hence the
    // rbegin()/rend() insertion below to restore original order.
    while (!w.empty() && punctuations_.count(std::string(1, w.back()))) {
      suffix.push_back(token2id_.at(std::string(1, w.back())));
      w = std::string(w.begin(), w.end() - 1);
    }

    if (!word2ids_.count(w)) {
      SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
      continue;
    }

    const auto &token_ids = word2ids_.at(w);
    ans.insert(ans.end(), prefix.begin(), prefix.end());
    ans.insert(ans.end(), token_ids.begin(), token_ids.end());
    ans.insert(ans.end(), suffix.rbegin(), suffix.rend());
    ans.push_back(blank_);
  }

  if (!ans.empty()) {
    // remove the last appended blank
    ans.resize(ans.size() - 1);
  }

  return ans;
}
} // namespace sherpa_onnx

View File

@@ -13,18 +13,40 @@
namespace sherpa_onnx {
// TODO(fangjun): Refactor it to an abstract class
// Maps text to token IDs via a pronunciation lexicon and a token table.
// The text frontend is selected by the "language" constructor argument
// ("English" or "Chinese", case-insensitive).
//
// Fixes to the rendered header: the diff residue duplicated the constructor
// declaration and the blank_ member (old + new lines stacked together),
// language_ was left uninitialized, and a stray empty "//" comment remained.
class Lexicon {
 public:
  Lexicon(const std::string &lexicon, const std::string &tokens,
          const std::string &punctuations, const std::string &language);

  std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const;

 private:
  std::vector<int64_t> ConvertTextToTokenIdsEnglish(
      const std::string &text) const;

  std::vector<int64_t> ConvertTextToTokenIdsChinese(
      const std::string &text) const;

  void InitLanguage(const std::string &lang);
  void InitTokens(const std::string &tokens);
  void InitLexicon(const std::string &lexicon);
  void InitPunctuations(const std::string &punctuations);

 private:
  enum class Language {
    kEnglish,
    kChinese,
    kUnknown,
  };

 private:
  // word -> sequence of token IDs from the lexicon file
  std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
  std::unordered_set<std::string> punctuations_;
  std::unordered_map<std::string, int32_t> token2id_;
  int32_t blank_ = -1;  // ID for the blank token; -1 when the table has none
  Language language_ = Language::kUnknown;  // set by InitLanguage()
};
} // namespace sherpa_onnx

View File

@@ -21,7 +21,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config)
: model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
model_->Punctuations()) {}
model_->Punctuations(), model_->Language()) {}
GeneratedAudio Generate(const std::string &text,
int64_t sid = 0) const override {

View File

@@ -84,6 +84,7 @@ class OfflineTtsVitsModel::Impl {
bool AddBlank() const { return add_blank_; }
std::string Punctuations() const { return punctuations_; }
std::string Language() const { return language_; }
private:
void Init(void *model_data, size_t model_data_length) {
@@ -108,6 +109,7 @@ class OfflineTtsVitsModel::Impl {
SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers");
SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
SHERPA_ONNX_READ_META_DATA_STR(language_, "language");
}
private:
@@ -128,6 +130,7 @@ class OfflineTtsVitsModel::Impl {
int32_t add_blank_;
int32_t n_speakers_;
std::string punctuations_;
std::string language_;
};
OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
@@ -147,4 +150,6 @@ std::string OfflineTtsVitsModel::Punctuations() const {
return impl_->Punctuations();
}
// Returns the "language" metadata string read from the model, forwarded
// from the pimpl.
std::string OfflineTtsVitsModel::Language() const {
  return impl_->Language();
}
} // namespace sherpa_onnx

View File

@@ -38,6 +38,7 @@ class OfflineTtsVitsModel {
bool AddBlank() const;
std::string Punctuations() const;
std::string Language() const;
private:
class Impl;

View File

@@ -8,12 +8,16 @@
#include <assert.h>
#include <algorithm>
#include <cctype>
#include <limits>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "source/utf8.h"
// This file is copied/modified from
// https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.cc
@@ -158,4 +162,57 @@ template bool SplitStringToFloats(const std::string &full, const char *delim,
bool omit_empty_strings,
std::vector<double> *out);
// Splits UTF-8 text into pieces suitable for lexicon lookup:
//   - runs of ASCII letters become one piece (an English word),
//   - ASCII punctuation becomes its own piece,
//   - ASCII whitespace separates pieces and is dropped,
//   - any other code point (e.g., a CJK character) becomes its own piece.
//
// Fixes:
//   - std::ispunct/isspace/isalpha have undefined behavior for arguments not
//     representable as unsigned char; the old code fed them full Unicode code
//     points (every CJK character). They are now gated on code < 0x80.
//   - the const_cast on text.c_str() was unnecessary; const char * is a valid
//     octet iterator for utf8::next().
//   - buf was reused after std::move without clear(); a moved-from string's
//     content is unspecified, so it is now cleared explicitly.
std::vector<std::string> SplitUtf8(const std::string &text) {
  const char *p = text.c_str();
  const char *end = p + text.size();

  std::vector<std::string> ans;
  std::string buf;  // pending run of ASCII letters

  // Moves the pending word, if any, into ans and resets the buffer.
  auto flush = [&]() {
    if (!buf.empty()) {
      ans.push_back(std::move(buf));
      buf.clear();
    }
  };

  while (p < end) {
    uint32_t code = utf8::next(p, end);
    bool is_ascii = code < 0x80;

    // 1. punctuation ends the pending word and is kept as its own piece
    if (is_ascii && std::ispunct(static_cast<int>(code))) {
      flush();
      char s[5] = {0};
      utf8::append(code, s);
      ans.push_back(s);
      continue;
    }

    // 2. whitespace ends the pending word and is dropped
    if (is_ascii && std::isspace(static_cast<int>(code))) {
      flush();
      continue;
    }

    // 3. an ASCII letter extends the pending word
    if (is_ascii && std::isalpha(static_cast<int>(code))) {
      buf.push_back(static_cast<char>(code));
      continue;
    }

    // 4. anything else (e.g., a CJK character) ends the pending word and is
    //    kept as its own piece
    flush();
    char s[5] = {0};
    utf8::append(code, s);
    ans.push_back(s);
  }

  flush();

  return ans;
}
} // namespace sherpa_onnx

View File

@@ -119,6 +119,8 @@ bool SplitStringToFloats(const std::string &full, const char *delim,
template <typename T>
bool ConvertStringToReal(const std::string &str, T *out);
std::vector<std::string> SplitUtf8(const std::string &text);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_

View File

@@ -0,0 +1,21 @@
// sherpa-onnx/csrc/utfcpp-test.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include <cctype>
#include <iostream>
#include <string>
#include <vector>

#include "gtest/gtest.h"

#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
// Exercises SplitUtf8 on mixed Chinese/English text with both ASCII and
// full-width punctuation.
//
// Fix: the test previously only printed the pieces and asserted nothing, so
// a regression could never fail it. It now checks basic properties of the
// split while still printing the pieces for manual inspection.
TEST(UTF8, Case1) {
  std::string hello = "你好, 早上好!世界. hello!。Hallo";
  std::vector<std::string> ss = SplitUtf8(hello);

  // The input clearly contains several pieces; an empty or one-piece result
  // would indicate a broken split.
  EXPECT_GT(ss.size(), 1u);
  for (const auto &s : ss) {
    EXPECT_FALSE(s.empty());
    std::cout << s << "\n";
  }
}
} // namespace sherpa_onnx