Export the English TTS model from MeloTTS (#1509)

This commit is contained in:
Fangjun Kuang
2024-11-04 07:54:19 +08:00
committed by GitHub
parent 6ee8c99c5d
commit 4eeb336f59
11 changed files with 369 additions and 26 deletions

View File

@@ -152,10 +152,6 @@
// Read the custom metadata entry named `src_key` and store it in `dst`.
//
// As the name says, an empty value is accepted: `dst` is simply assigned
// whatever LookupCustomModelMetaData() returns, including the empty string.
// No error is logged and the process is not terminated.
//
// The expansion site must have `meta_data` and `allocator` in scope, since
// both are forwarded to LookupCustomModelMetaData().
//
// @param dst      Assignable from std::string; receives the looked-up value.
// @param src_key  Key of the metadata entry to look up.
#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key)           \
  do {                                                                     \
    auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
    dst = std::move(value);                                                \
  } while (0)

View File

@@ -48,6 +48,20 @@ class MeloTtsLexicon::Impl {
}
}
// Build the lexicon from a tokens file and a lexicon file only.
//
// This overload takes no dictionary directory and does not create a jieba
// segmenter in its body.
//
// @param lexicon    Path to the lexicon file, passed to InitLexicon().
// @param tokens     Path to the tokens file, passed to InitTokens().
// @param meta_data  Model metadata; copied into meta_data_.
// @param debug      Stored in debug_.
Impl(const std::string &lexicon, const std::string &tokens,
     const OfflineTtsVitsModelMetaData &meta_data, bool debug)
    : meta_data_(meta_data), debug_(debug) {
  // Tokens are loaded before the lexicon; each stream lives only for the
  // duration of its own loading step.
  {
    std::ifstream tokens_stream(tokens);
    InitTokens(tokens_stream);
  }

  {
    std::ifstream lexicon_stream(lexicon);
    InitLexicon(lexicon_stream);
  }
}
std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &_text) const {
std::string text = ToLowerCase(_text);
// see
@@ -65,21 +79,39 @@ class MeloTtsLexicon::Impl {
s = std::regex_replace(s, punct_re4, "!");
std::vector<std::string> words;
bool is_hmm = true;
jieba_->Cut(text, words, is_hmm);
if (jieba_) {
bool is_hmm = true;
jieba_->Cut(text, words, is_hmm);
if (debug_) {
SHERPA_ONNX_LOGE("input text: %s", text.c_str());
SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str());
if (debug_) {
SHERPA_ONNX_LOGE("input text: %s", text.c_str());
SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str());
std::ostringstream os;
std::string sep = "";
for (const auto &w : words) {
os << sep << w;
sep = "_";
std::ostringstream os;
std::string sep = "";
for (const auto &w : words) {
os << sep << w;
sep = "_";
}
SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
}
} else {
words = SplitUtf8(text);
SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
if (debug_) {
fprintf(stderr, "Input text in string (lowercase): %s\n", text.c_str());
fprintf(stderr, "Input text in bytes (lowercase):");
for (uint8_t c : text) {
fprintf(stderr, " %02x", c);
}
fprintf(stderr, "\n");
fprintf(stderr, "After splitting to words:");
for (const auto &w : words) {
fprintf(stderr, " %s", w.c_str());
}
fprintf(stderr, "\n");
}
}
std::vector<TokenIDs> ans;
@@ -241,6 +273,7 @@ class MeloTtsLexicon::Impl {
{std::move(word), TokenIDs{std::move(ids64), std::move(tone_list)}});
}
// For Chinese+English MeloTTS
word2ids_[""] = word2ids_[""];
word2ids_[""] = word2ids_[""];
}
@@ -268,6 +301,12 @@ MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon,
: impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
debug)) {}
MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon,
const std::string &tokens,
const OfflineTtsVitsModelMetaData &meta_data,
bool debug)
: impl_(std::make_unique<Impl>(lexicon, tokens, meta_data, debug)) {}
std::vector<TokenIDs> MeloTtsLexicon::ConvertTextToTokenIds(
const std::string &text, const std::string & /*unused_voice = ""*/) const {
return impl_->ConvertTextToTokenIds(text);

View File

@@ -22,6 +22,9 @@ class MeloTtsLexicon : public OfflineTtsFrontend {
const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data, bool debug);
MeloTtsLexicon(const std::string &lexicon, const std::string &tokens,
const OfflineTtsVitsModelMetaData &meta_data, bool debug);
std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text,
const std::string &unused_voice = "") const override;

View File

@@ -349,6 +349,10 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
config_.model.vits.lexicon, config_.model.vits.tokens,
config_.model.vits.dict_dir, model_->GetMetaData(),
config_.model.debug);
} else if (meta_data.is_melo_tts && meta_data.language == "English") {
frontend_ = std::make_unique<MeloTtsLexicon>(
config_.model.vits.lexicon, config_.model.vits.tokens,
model_->GetMetaData(), config_.model.debug);
} else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) {
frontend_ = std::make_unique<JiebaLexicon>(
config_.model.vits.lexicon, config_.model.vits.tokens,

View File

@@ -46,8 +46,10 @@ class OfflineTtsVitsModel::Impl {
}
Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid, float speed) {
// For MeloTTS, we hardcode sid to the one contained in the meta data
sid = meta_data_.speaker_id;
if (meta_data_.num_speakers == 1) {
// For MeloTTS, we hardcode sid to the one contained in the meta data
sid = meta_data_.speaker_id;
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

View File

@@ -408,10 +408,10 @@ std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
// For other versions, we may need to change it
#if ORT_API_VERSION >= 12
auto v = meta_data.LookupCustomMetadataMapAllocated(key, allocator);
return v.get();
return v ? v.get() : "";
#else
auto v = meta_data.LookupCustomMetadataMap(key, allocator);
std::string ans = v;
std::string ans = v ? v : "";
allocator->Free(allocator, v);
return ans;
#endif