Add C++ runtime for MeloTTS (#1138)

This commit is contained in:
Fangjun Kuang
2024-07-16 15:55:02 +08:00
committed by GitHub
parent 95485411fa
commit 960eb7529e
51 changed files with 693 additions and 156 deletions

View File

@@ -422,10 +422,10 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig(
void SherpaOnnxOfflineRecognizerSetConfig(
const SherpaOnnxOfflineRecognizer *recognizer,
const SherpaOnnxOfflineRecognizerConfig *config){
const SherpaOnnxOfflineRecognizerConfig *config) {
sherpa_onnx::OfflineRecognizerConfig recognizer_config =
convertConfig(config);
recognizer->impl->SetConfig(recognizer_config);
recognizer->impl->SetConfig(recognizer_config);
}
void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) {
@@ -478,7 +478,7 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
pText[text.size()] = 0;
r->text = pText;
//lang
// lang
const auto &lang = result.lang;
char *c_lang = new char[lang.size() + 1];
std::copy(lang.begin(), lang.end(), c_lang);
@@ -1317,7 +1317,7 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches(
}
delete[] r->matches;
delete r;
};
}
int32_t SherpaOnnxSpeakerEmbeddingManagerVerify(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,

View File

@@ -496,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams(
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
const char *text;
// Pointer to continuous memory which holds timestamps
// Pointer to continuous memory which holds timestamps
//
// It is NULL if the model does not support timestamps
float *timestamps;
@@ -525,9 +525,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
*/
const char *json;
//return recognized language
// return recognized language
const char *lang;
} SherpaOnnxOfflineRecognizerResult;
/// Get the result of the offline stream.

View File

@@ -142,7 +142,9 @@ if(SHERPA_ONNX_ENABLE_TTS)
list(APPEND sources
jieba-lexicon.cc
lexicon.cc
melo-tts-lexicon.cc
offline-tts-character-frontend.cc
offline-tts-frontend.cc
offline-tts-impl.cc
offline-tts-model-config.cc
offline-tts-vits-model-config.cc

View File

@@ -33,7 +33,7 @@ TEST(CppJieBa, Case1) {
std::vector<std::string> words;
std::vector<cppjieba::Word> jiebawords;
std::string s = "他来到了网易杭研大厦";
std::string s = "他来到了网易杭研大厦。How are you?";
std::cout << s << std::endl;
std::cout << "[demo] Cut With HMM" << std::endl;
jieba.Cut(s, words, true);

View File

@@ -17,6 +17,7 @@ namespace sherpa_onnx {
// implemented in ./lexicon.cc
std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is);
std::vector<int32_t> ConvertTokensToIds(
const std::unordered_map<std::string, int32_t> &token2id,
const std::vector<std::string> &tokens);
@@ -53,8 +54,7 @@ class JiebaLexicon::Impl {
}
}
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
const std::string &text) const {
std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &text) const {
// see
// https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244
std::regex punct_re{"|、|"};
@@ -87,7 +87,7 @@ class JiebaLexicon::Impl {
SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
}
std::vector<std::vector<int64_t>> ans;
std::vector<TokenIDs> ans;
std::vector<int64_t> this_sentence;
int32_t blank = token2id_.at(" ");
@@ -217,7 +217,7 @@ JiebaLexicon::JiebaLexicon(const std::string &lexicon,
: impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
debug)) {}
std::vector<std::vector<int64_t>> JiebaLexicon::ConvertTextToTokenIds(
std::vector<TokenIDs> JiebaLexicon::ConvertTextToTokenIds(
const std::string &text, const std::string & /*unused_voice = ""*/) const {
return impl_->ConvertTextToTokenIds(text);
}

View File

@@ -10,11 +10,6 @@
#include <unordered_map>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"
@@ -27,13 +22,7 @@ class JiebaLexicon : public OfflineTtsFrontend {
const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data, bool debug);
#if __ANDROID_API__ >= 9
JiebaLexicon(AAssetManager *mgr, const std::string &lexicon,
const std::string &tokens, const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data);
#endif
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text,
const std::string &unused_voice = "") const override;

View File

@@ -172,7 +172,7 @@ Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
}
#endif
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds(
std::vector<TokenIDs> Lexicon::ConvertTextToTokenIds(
const std::string &text, const std::string & /*voice*/ /*= ""*/) const {
switch (language_) {
case Language::kChinese:
@@ -187,7 +187,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds(
return {};
}
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
std::vector<TokenIDs> Lexicon::ConvertTextToTokenIdsChinese(
const std::string &_text) const {
std::string text(_text);
ToLowerCase(&text);
@@ -209,7 +209,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
fprintf(stderr, "\n");
}
std::vector<std::vector<int64_t>> ans;
std::vector<TokenIDs> ans;
std::vector<int64_t> this_sentence;
int32_t blank = -1;
@@ -288,7 +288,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
return ans;
}
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese(
std::vector<TokenIDs> Lexicon::ConvertTextToTokenIdsNotChinese(
const std::string &_text) const {
std::string text(_text);
ToLowerCase(&text);
@@ -311,7 +311,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese(
int32_t blank = token2id_.at(" ");
std::vector<std::vector<int64_t>> ans;
std::vector<TokenIDs> ans;
std::vector<int64_t> this_sentence;
for (const auto &w : words) {

View File

@@ -36,14 +36,14 @@ class Lexicon : public OfflineTtsFrontend {
const std::string &language, bool debug = false);
#endif
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text, const std::string &voice = "") const override;
private:
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsNotChinese(
std::vector<TokenIDs> ConvertTextToTokenIdsNotChinese(
const std::string &text) const;
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsChinese(
std::vector<TokenIDs> ConvertTextToTokenIdsChinese(
const std::string &text) const;
void InitLanguage(const std::string &lang);

View File

@@ -0,0 +1,266 @@
// sherpa-onnx/csrc/melo-tts-lexicon.cc
//
// Copyright (c) 2022-2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/melo-tts-lexicon.h"
#include <fstream>
#include <regex> // NOLINT
#include <utility>
#include "cppjieba/Jieba.hpp"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
// implemented in ./lexicon.cc
std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is);
std::vector<int32_t> ConvertTokensToIds(
const std::unordered_map<std::string, int32_t> &token2id,
const std::vector<std::string> &tokens);
class MeloTtsLexicon::Impl {
public:
Impl(const std::string &lexicon, const std::string &tokens,
const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data, bool debug)
: meta_data_(meta_data), debug_(debug) {
std::string dict = dict_dir + "/jieba.dict.utf8";
std::string hmm = dict_dir + "/hmm_model.utf8";
std::string user_dict = dict_dir + "/user.dict.utf8";
std::string idf = dict_dir + "/idf.utf8";
std::string stop_word = dict_dir + "/stop_words.utf8";
AssertFileExists(dict);
AssertFileExists(hmm);
AssertFileExists(user_dict);
AssertFileExists(idf);
AssertFileExists(stop_word);
jieba_ =
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
{
std::ifstream is(tokens);
InitTokens(is);
}
{
std::ifstream is(lexicon);
InitLexicon(is);
}
}
std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &_text) const {
std::string text = ToLowerCase(_text);
// see
// https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244
std::regex punct_re{"|、|"};
std::string s = std::regex_replace(text, punct_re, ",");
std::regex punct_re2("");
s = std::regex_replace(s, punct_re2, ".");
std::regex punct_re3("");
s = std::regex_replace(s, punct_re3, "?");
std::regex punct_re4("");
s = std::regex_replace(s, punct_re4, "!");
std::vector<std::string> words;
bool is_hmm = true;
jieba_->Cut(text, words, is_hmm);
if (debug_) {
SHERPA_ONNX_LOGE("input text: %s", text.c_str());
SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str());
std::ostringstream os;
std::string sep = "";
for (const auto &w : words) {
os << sep << w;
sep = "_";
}
SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
}
std::vector<TokenIDs> ans;
TokenIDs this_sentence;
int32_t blank = token2id_.at("_");
for (const auto &w : words) {
auto ids = ConvertWordToIds(w);
if (ids.tokens.empty()) {
SHERPA_ONNX_LOGE("Ignore OOV '%s'", w.c_str());
continue;
}
this_sentence.tokens.insert(this_sentence.tokens.end(),
ids.tokens.begin(), ids.tokens.end());
this_sentence.tones.insert(this_sentence.tones.end(), ids.tones.begin(),
ids.tones.end());
if (w == "." || w == "!" || w == "?" || w == ",") {
ans.push_back(std::move(this_sentence));
this_sentence = {};
}
} // for (const auto &w : words)
if (!this_sentence.tokens.empty()) {
ans.push_back(std::move(this_sentence));
}
return ans;
}
private:
TokenIDs ConvertWordToIds(const std::string &w) const {
if (word2ids_.count(w)) {
return word2ids_.at(w);
}
if (token2id_.count(w)) {
return {{token2id_.at(w)}, {0}};
}
TokenIDs ans;
std::vector<std::string> words = SplitUtf8(w);
for (const auto &word : words) {
if (word2ids_.count(word)) {
auto ids = ConvertWordToIds(word);
ans.tokens.insert(ans.tokens.end(), ids.tokens.begin(),
ids.tokens.end());
ans.tones.insert(ans.tones.end(), ids.tones.begin(), ids.tones.end());
}
}
return ans;
}
void InitTokens(std::istream &is) {
token2id_ = ReadTokens(is);
token2id_[" "] = token2id_["_"];
std::vector<std::pair<std::string, std::string>> puncts = {
{",", ""}, {".", ""}, {"!", ""}, {"?", ""}};
for (const auto &p : puncts) {
if (token2id_.count(p.first) && !token2id_.count(p.second)) {
token2id_[p.second] = token2id_[p.first];
}
if (!token2id_.count(p.first) && token2id_.count(p.second)) {
token2id_[p.first] = token2id_[p.second];
}
}
if (!token2id_.count("") && token2id_.count("")) {
token2id_[""] = token2id_[""];
}
}
void InitLexicon(std::istream &is) {
std::string word;
std::vector<std::string> token_list;
std::vector<std::string> phone_list;
std::vector<int64_t> tone_list;
std::string line;
std::string phone;
int32_t line_num = 0;
while (std::getline(is, line)) {
++line_num;
std::istringstream iss(line);
token_list.clear();
phone_list.clear();
tone_list.clear();
iss >> word;
ToLowerCase(&word);
if (word2ids_.count(word)) {
SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.",
word.c_str(), line_num, line.c_str());
continue;
}
while (iss >> phone) {
token_list.push_back(std::move(phone));
}
if ((token_list.size() & 1) != 0) {
SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
exit(-1);
}
int32_t num_phones = token_list.size() / 2;
phone_list.reserve(num_phones);
tone_list.reserve(num_phones);
for (int32_t i = 0; i != num_phones; ++i) {
phone_list.push_back(std::move(token_list[i]));
tone_list.push_back(std::stoi(token_list[i + num_phones], nullptr));
if (tone_list.back() < 0 || tone_list.back() > 50) {
SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
exit(-1);
}
}
std::vector<int32_t> ids = ConvertTokensToIds(token2id_, phone_list);
if (ids.empty()) {
continue;
}
if (ids.size() != num_phones) {
SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
exit(-1);
}
std::vector<int64_t> ids64{ids.begin(), ids.end()};
word2ids_.insert(
{std::move(word), TokenIDs{std::move(ids64), std::move(tone_list)}});
}
word2ids_[""] = word2ids_[""];
word2ids_[""] = word2ids_[""];
}
private:
// lexicon.txt is saved in word2ids_
std::unordered_map<std::string, TokenIDs> word2ids_;
// tokens.txt is saved in token2id_
std::unordered_map<std::string, int32_t> token2id_;
OfflineTtsVitsModelMetaData meta_data_;
std::unique_ptr<cppjieba::Jieba> jieba_;
bool debug_ = false;
};
MeloTtsLexicon::~MeloTtsLexicon() = default;
MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon,
const std::string &tokens,
const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data,
bool debug)
: impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
debug)) {}
std::vector<TokenIDs> MeloTtsLexicon::ConvertTextToTokenIds(
const std::string &text, const std::string & /*unused_voice = ""*/) const {
return impl_->ConvertTextToTokenIds(text);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,36 @@
// sherpa-onnx/csrc/melo-tts-lexicon.h
//
// Copyright (c) 2022-2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_
#define SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"
namespace sherpa_onnx {
class MeloTtsLexicon : public OfflineTtsFrontend {
public:
~MeloTtsLexicon() override;
MeloTtsLexicon(const std::string &lexicon, const std::string &tokens,
const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data, bool debug);
std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text,
const std::string &unused_voice = "") const override;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_

View File

@@ -94,8 +94,7 @@ OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend(
#endif
std::vector<std::vector<int64_t>>
OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
std::vector<TokenIDs> OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
const std::string &_text, const std::string & /*voice = ""*/) const {
// see
// https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87
@@ -112,7 +111,7 @@ OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
std::u32string s = conv.from_bytes(text);
std::vector<std::vector<int64_t>> ans;
std::vector<TokenIDs> ans;
std::vector<int64_t> this_sentence;
if (add_blank) {

View File

@@ -41,7 +41,7 @@ class OfflineTtsCharacterFrontend : public OfflineTtsFrontend {
* If a frontend does not support splitting the text into
* sentences, the resulting vector contains only one subvector.
*/
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text, const std::string &voice = "") const override;
private:

View File

@@ -0,0 +1,34 @@
// sherpa-onnx/csrc/offline-tts-frontend.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
#include <sstream>
#include <string>
namespace sherpa_onnx {
std::string TokenIDs::ToString() const {
std::ostringstream os;
os << "TokenIDs(";
os << "tokens=[";
std::string sep;
for (auto i : tokens) {
os << sep << i;
sep = ", ";
}
os << "], ";
os << "tones=[";
sep = {};
for (auto i : tones) {
os << sep << i;
sep = ", ";
}
os << "]";
os << ")";
return os.str();
}
} // namespace sherpa_onnx

View File

@@ -8,8 +8,28 @@
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
struct TokenIDs {
TokenIDs() = default;
/*implicit*/ TokenIDs(const std::vector<int64_t> &tokens) // NOLINT
: tokens{tokens} {}
TokenIDs(const std::vector<int64_t> &tokens,
const std::vector<int64_t> &tones)
: tokens{tokens}, tones{tones} {}
std::string ToString() const;
std::vector<int64_t> tokens;
// Used only in MeloTTS
std::vector<int64_t> tones;
};
class OfflineTtsFrontend {
public:
virtual ~OfflineTtsFrontend() = default;
@@ -26,7 +46,7 @@ class OfflineTtsFrontend {
* If a frontend does not support splitting the text into sentences,
* the resulting vector contains only one subvector.
*/
virtual std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
virtual std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text, const std::string &voice = "") const = 0;
};

View File

@@ -22,6 +22,7 @@
#include "sherpa-onnx/csrc/jieba-lexicon.h"
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/melo-tts-lexicon.h"
#include "sherpa-onnx/csrc/offline-tts-character-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-impl.h"
@@ -174,26 +175,47 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
}
}
std::vector<std::vector<int64_t>> x =
std::vector<TokenIDs> token_ids =
frontend_->ConvertTextToTokenIds(text, meta_data.voice);
if (x.empty() || (x.size() == 1 && x[0].empty())) {
if (token_ids.empty() ||
(token_ids.size() == 1 && token_ids[0].tokens.empty())) {
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
return {};
}
std::vector<std::vector<int64_t>> x;
std::vector<std::vector<int64_t>> tones;
x.reserve(token_ids.size());
for (auto &i : token_ids) {
x.push_back(std::move(i.tokens));
}
if (!token_ids[0].tones.empty()) {
tones.reserve(token_ids.size());
for (auto &i : token_ids) {
tones.push_back(std::move(i.tones));
}
}
// TODO(fangjun): add blank inside the frontend, not here
if (meta_data.add_blank && config_.model.vits.data_dir.empty() &&
meta_data.frontend != "characters") {
for (auto &k : x) {
k = AddBlank(k);
}
for (auto &k : tones) {
k = AddBlank(k);
}
}
int32_t x_size = static_cast<int32_t>(x.size());
if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) {
auto ans = Process(x, sid, speed);
auto ans = Process(x, tones, sid, speed);
if (callback) {
callback(ans.samples.data(), ans.samples.size(), 1.0);
}
@@ -202,9 +224,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
// the input text is too long, we process sentences within it in batches
// to avoid OOM. Batch size is config_.max_num_sentences
std::vector<std::vector<int64_t>> batch;
std::vector<std::vector<int64_t>> batch_x;
std::vector<std::vector<int64_t>> batch_tones;
int32_t batch_size = config_.max_num_sentences;
batch.reserve(config_.max_num_sentences);
batch_x.reserve(config_.max_num_sentences);
batch_tones.reserve(config_.max_num_sentences);
int32_t num_batches = x_size / batch_size;
if (config_.model.debug) {
@@ -221,12 +246,17 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
int32_t k = 0;
for (int32_t b = 0; b != num_batches && should_continue; ++b) {
batch.clear();
batch_x.clear();
batch_tones.clear();
for (int32_t i = 0; i != batch_size; ++i, ++k) {
batch.push_back(std::move(x[k]));
batch_x.push_back(std::move(x[k]));
if (!tones.empty()) {
batch_tones.push_back(std::move(tones[k]));
}
}
auto audio = Process(batch, sid, speed);
auto audio = Process(batch_x, batch_tones, sid, speed);
ans.sample_rate = audio.sample_rate;
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
audio.samples.end());
@@ -239,14 +269,19 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
}
}
batch.clear();
batch_x.clear();
batch_tones.clear();
while (k < static_cast<int32_t>(x.size()) && should_continue) {
batch.push_back(std::move(x[k]));
batch_x.push_back(std::move(x[k]));
if (!tones.empty()) {
batch_tones.push_back(std::move(tones[k]));
}
++k;
}
if (!batch.empty()) {
auto audio = Process(batch, sid, speed);
if (!batch_x.empty()) {
auto audio = Process(batch_x, batch_tones, sid, speed);
ans.sample_rate = audio.sample_rate;
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
audio.samples.end());
@@ -308,6 +343,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
if (meta_data.frontend == "characters") {
frontend_ = std::make_unique<OfflineTtsCharacterFrontend>(
config_.model.vits.tokens, meta_data);
} else if (meta_data.jieba && !config_.model.vits.dict_dir.empty() &&
meta_data.is_melo_tts) {
frontend_ = std::make_unique<MeloTtsLexicon>(
config_.model.vits.lexicon, config_.model.vits.tokens,
config_.model.vits.dict_dir, model_->GetMetaData(),
config_.model.debug);
} else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) {
frontend_ = std::make_unique<JiebaLexicon>(
config_.model.vits.lexicon, config_.model.vits.tokens,
@@ -344,6 +385,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
}
GeneratedAudio Process(const std::vector<std::vector<int64_t>> &tokens,
const std::vector<std::vector<int64_t>> &tones,
int32_t sid, float speed) const {
int32_t num_tokens = 0;
for (const auto &k : tokens) {
@@ -356,6 +398,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
x.insert(x.end(), k.begin(), k.end());
}
std::vector<int64_t> tone_list;
if (!tones.empty()) {
tone_list.reserve(num_tokens);
for (const auto &k : tones) {
tone_list.insert(tone_list.end(), k.begin(), k.end());
}
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
@@ -363,7 +413,20 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
Ort::Value x_tensor = Ort::Value::CreateTensor(
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
Ort::Value audio = model_->Run(std::move(x_tensor), sid, speed);
Ort::Value tones_tensor{nullptr};
if (!tones.empty()) {
tones_tensor = Ort::Value::CreateTensor(memory_info, tone_list.data(),
tone_list.size(), x_shape.data(),
x_shape.size());
}
Ort::Value audio{nullptr};
if (tones.empty()) {
audio = model_->Run(std::move(x_tensor), sid, speed);
} else {
audio =
model_->Run(std::move(x_tensor), std::move(tones_tensor), sid, speed);
}
std::vector<int64_t> audio_shape =
audio.GetTensorTypeAndShapeInfo().GetShape();

View File

@@ -21,6 +21,7 @@ struct OfflineTtsVitsModelMetaData {
bool is_piper = false;
bool is_coqui = false;
bool is_icefall = false;
bool is_melo_tts = false;
// for Chinese TTS models from
// https://github.com/Plachtaa/VITS-fast-fine-tuning
@@ -33,6 +34,10 @@ struct OfflineTtsVitsModelMetaData {
int32_t use_eos_bos = 0;
int32_t pad_id = 0;
// for melo tts
int32_t speaker_id = 0;
int32_t version = 0;
std::string punctuations;
std::string language;
std::string voice;

View File

@@ -45,6 +45,64 @@ class OfflineTtsVitsModel::Impl {
return RunVits(std::move(x), sid, speed);
}
Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid, float speed) {
// For MeloTTS, we hardcode sid to the one contained in the meta data
sid = meta_data_.speaker_id;
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::vector<int64_t> x_shape = x.GetTensorTypeAndShapeInfo().GetShape();
if (x_shape[0] != 1) {
SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d",
static_cast<int32_t>(x_shape[0]));
exit(-1);
}
int64_t len = x_shape[1];
int64_t len_shape = 1;
Ort::Value x_length =
Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);
int64_t scale_shape = 1;
float noise_scale = config_.vits.noise_scale;
float length_scale = config_.vits.length_scale;
float noise_scale_w = config_.vits.noise_scale_w;
if (speed != 1 && speed > 0) {
length_scale = 1. / speed;
}
Ort::Value noise_scale_tensor =
Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1);
Ort::Value length_scale_tensor = Ort::Value::CreateTensor(
memory_info, &length_scale, 1, &scale_shape, 1);
Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor(
memory_info, &noise_scale_w, 1, &scale_shape, 1);
Ort::Value sid_tensor =
Ort::Value::CreateTensor(memory_info, &sid, 1, &scale_shape, 1);
std::vector<Ort::Value> inputs;
inputs.reserve(7);
inputs.push_back(std::move(x));
inputs.push_back(std::move(x_length));
inputs.push_back(std::move(tones));
inputs.push_back(std::move(sid_tensor));
inputs.push_back(std::move(noise_scale_tensor));
inputs.push_back(std::move(length_scale_tensor));
inputs.push_back(std::move(noise_scale_w_tensor));
auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
return std::move(out[0]);
}
const OfflineTtsVitsModelMetaData &GetMetaData() const { return meta_data_; }
private:
@@ -83,6 +141,10 @@ class OfflineTtsVitsModel::Impl {
SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.add_blank, "add_blank",
0);
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.speaker_id, "speaker_id",
0);
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.version, "version", 0);
SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "n_speakers");
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.punctuations,
"punctuation", "");
@@ -115,6 +177,22 @@ class OfflineTtsVitsModel::Impl {
if (comment.find("icefall") != std::string::npos) {
meta_data_.is_icefall = true;
}
if (comment.find("melo") != std::string::npos) {
meta_data_.is_melo_tts = true;
int32_t expected_version = 2;
if (meta_data_.version < expected_version) {
SHERPA_ONNX_LOGE(
"Please download the latest MeloTTS model and retry. Current "
"version: %d. Expected version: %d",
meta_data_.version, expected_version);
exit(-1);
}
// NOTE(fangjun):
// version 0 is the first version
// version 2: add jieba=1 to the metadata
}
}
Ort::Value RunVitsPiperOrCoqui(Ort::Value x, int64_t sid, float speed) {
@@ -269,6 +347,12 @@ Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/,
return impl_->Run(std::move(x), sid, speed);
}
Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, Ort::Value tones,
int64_t sid /*= 0*/,
float speed /*= 1.0*/) {
return impl_->Run(std::move(x), std::move(tones), sid, speed);
}
const OfflineTtsVitsModelMetaData &OfflineTtsVitsModel::GetMetaData() const {
return impl_->GetMetaData();
}

View File

@@ -40,6 +40,10 @@ class OfflineTtsVitsModel {
*/
Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0);
// This is for MeloTTS
Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid = 0,
float speed = 1.0);
const OfflineTtsVitsModelMetaData &GetMetaData() const;
private:

View File

@@ -5,8 +5,8 @@
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
#include <vector>
#include <string>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
@@ -36,7 +36,6 @@ class OfflineWhisperDecoder {
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0;
};
} // namespace sherpa_onnx

View File

@@ -12,7 +12,8 @@
namespace sherpa_onnx {
void OfflineWhisperGreedySearchDecoder::SetConfig(const OfflineWhisperModelConfig &config) {
void OfflineWhisperGreedySearchDecoder::SetConfig(
const OfflineWhisperModelConfig &config) {
config_ = config;
}
@@ -135,9 +136,9 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
const auto &id2lang = model_->GetID2Lang();
if (id2lang.count(initial_tokens[1])) {
ans[0].lang = id2lang.at(initial_tokens[1]);
ans[0].lang = id2lang.at(initial_tokens[1]);
} else {
ans[0].lang = "";
ans[0].lang = "";
}
ans[0].tokens = std::move(predicted_tokens);

View File

@@ -153,15 +153,21 @@ Ort::Value View(Ort::Value *v) {
}
}
template <typename T /*= float*/>
void Print1D(Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
const float *d = v->GetTensorData<float>();
const T *d = v->GetTensorData<T>();
std::ostringstream os;
for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
fprintf(stderr, "%.3f ", d[i]);
os << *d << " ";
}
fprintf(stderr, "\n");
os << "\n";
fprintf(stderr, "%s\n", os.str().c_str());
}
template void Print1D<int64_t>(Ort::Value *v);
template void Print1D<float>(Ort::Value *v);
template <typename T /*= float*/>
void Print2D(Ort::Value *v) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();

View File

@@ -69,6 +69,7 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);
Ort::Value View(Ort::Value *v);
// Print a 1-D tensor to stderr
template <typename T = float>
void Print1D(Ort::Value *v);
// Print a 2-D tensor to stderr

View File

@@ -214,7 +214,7 @@ PiperPhonemizeLexicon::PiperPhonemizeLexicon(
}
#endif
std::vector<std::vector<int64_t>> PiperPhonemizeLexicon::ConvertTextToTokenIds(
std::vector<TokenIDs> PiperPhonemizeLexicon::ConvertTextToTokenIds(
const std::string &text, const std::string &voice /*= ""*/) const {
piper::eSpeakPhonemeConfig config;
@@ -232,7 +232,7 @@ std::vector<std::vector<int64_t>> PiperPhonemizeLexicon::ConvertTextToTokenIds(
piper::phonemize_eSpeak(text, config, phonemes);
}
std::vector<std::vector<int64_t>> ans;
std::vector<TokenIDs> ans;
std::vector<int64_t> phoneme_ids;

View File

@@ -30,7 +30,7 @@ class PiperPhonemizeLexicon : public OfflineTtsFrontend {
const OfflineTtsVitsModelMetaData &meta_data);
#endif
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text, const std::string &voice = "") const override;
private:

View File

@@ -31,8 +31,8 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
api.ReleaseStatus(status);
}
static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
const std::string &provider_str,
static Ort::SessionOptions GetSessionOptionsImpl(
int32_t num_threads, const std::string &provider_str,
const ProviderConfig *provider_config = nullptr) {
Provider p = StringToProvider(provider_str);
@@ -67,8 +67,9 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
}
case Provider::kTRT: {
if (provider_config == nullptr) {
SHERPA_ONNX_LOGE("Tensorrt support for Online models ony,"
"Must be extended for offline and others");
SHERPA_ONNX_LOGE(
"Tensorrt support for Online models ony,"
"Must be extended for offline and others");
exit(1);
}
auto trt_config = provider_config->trt_config;
@@ -84,29 +85,27 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
std::to_string(trt_config.trt_max_partition_iterations);
auto trt_min_subgraph_size =
std::to_string(trt_config.trt_min_subgraph_size);
auto trt_fp16_enable =
std::to_string(trt_config.trt_fp16_enable);
auto trt_fp16_enable = std::to_string(trt_config.trt_fp16_enable);
auto trt_detailed_build_log =
std::to_string(trt_config.trt_detailed_build_log);
auto trt_engine_cache_enable =
std::to_string(trt_config.trt_engine_cache_enable);
auto trt_timing_cache_enable =
std::to_string(trt_config.trt_timing_cache_enable);
auto trt_dump_subgraphs =
std::to_string(trt_config.trt_dump_subgraphs);
auto trt_dump_subgraphs = std::to_string(trt_config.trt_dump_subgraphs);
std::vector<TrtPairs> trt_options = {
{"device_id", device_id.c_str()},
{"trt_max_workspace_size", trt_max_workspace_size.c_str()},
{"trt_max_partition_iterations", trt_max_partition_iterations.c_str()},
{"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
{"trt_fp16_enable", trt_fp16_enable.c_str()},
{"trt_detailed_build_log", trt_detailed_build_log.c_str()},
{"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},
{"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
{"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
{"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
{"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}
};
{"device_id", device_id.c_str()},
{"trt_max_workspace_size", trt_max_workspace_size.c_str()},
{"trt_max_partition_iterations",
trt_max_partition_iterations.c_str()},
{"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
{"trt_fp16_enable", trt_fp16_enable.c_str()},
{"trt_detailed_build_log", trt_detailed_build_log.c_str()},
{"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},
{"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
{"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
{"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
{"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}};
// ToDo : Trt configs
// "trt_int8_enable"
// "trt_int8_use_native_calibration_table"
@@ -151,9 +150,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
if (provider_config != nullptr) {
options.device_id = provider_config->device;
options.cudnn_conv_algo_search =
OrtCudnnConvAlgoSearch(provider_config->cuda_config
.cudnn_conv_algo_search);
options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch(
provider_config->cuda_config.cudnn_conv_algo_search);
} else {
options.device_id = 0;
// Default OrtCudnnConvAlgoSearchExhaustive is extremely slow
@@ -219,22 +217,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads,
config.provider_config.provider, &config.provider_config);
config.provider_config.provider,
&config.provider_config);
}
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
const std::string &model_type) {
const std::string &model_type) {
/*
Transducer models : Only encoder will run with tensorrt,
decoder and joiner will run with cuda
*/
if(config.provider_config.provider == "trt" &&
if (config.provider_config.provider == "trt" &&
(model_type == "decoder" || model_type == "joiner")) {
return GetSessionOptionsImpl(config.num_threads,
"cuda", &config.provider_config);
return GetSessionOptionsImpl(config.num_threads, "cuda",
&config.provider_config);
}
return GetSessionOptionsImpl(config.num_threads,
config.provider_config.provider, &config.provider_config);
config.provider_config.provider,
&config.provider_config);
}
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {

View File

@@ -5,6 +5,8 @@
#ifndef SHERPA_ONNX_CSRC_SESSION_H_
#define SHERPA_ONNX_CSRC_SESSION_H_
#include <string>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
@@ -25,7 +27,7 @@ namespace sherpa_onnx {
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config);
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
const std::string &model_type);
const std::string &model_type);
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config);

View File

@@ -6,6 +6,7 @@
#include <algorithm>
#include <unordered_map>
#include <utility>
#include "Eigen/Dense"
#include "sherpa-onnx/csrc/macros.h"

View File

@@ -11,7 +11,7 @@
namespace sherpa_onnx {
TEST(UTF8, Case1) {
std::string hello = "你好, 早上好!世界. hello!。Hallo";
std::string hello = "你好, 早上好!世界. hello!。Hallo! how are you?";
std::vector<std::string> ss = SplitUtf8(hello);
for (const auto &s : ss) {
std::cout << s << "\n";