Export the English TTS model from MeloTTS (#1509)
This commit is contained in:
@@ -152,10 +152,6 @@
|
||||
#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \
|
||||
do { \
|
||||
auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \
|
||||
if (value.empty()) { \
|
||||
SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
\
|
||||
dst = std::move(value); \
|
||||
} while (0)
|
||||
|
||||
@@ -48,6 +48,20 @@ class MeloTtsLexicon::Impl {
|
||||
}
|
||||
}
|
||||
|
||||
Impl(const std::string &lexicon, const std::string &tokens,
|
||||
const OfflineTtsVitsModelMetaData &meta_data, bool debug)
|
||||
: meta_data_(meta_data), debug_(debug) {
|
||||
{
|
||||
std::ifstream is(tokens);
|
||||
InitTokens(is);
|
||||
}
|
||||
|
||||
{
|
||||
std::ifstream is(lexicon);
|
||||
InitLexicon(is);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &_text) const {
|
||||
std::string text = ToLowerCase(_text);
|
||||
// see
|
||||
@@ -65,21 +79,39 @@ class MeloTtsLexicon::Impl {
|
||||
s = std::regex_replace(s, punct_re4, "!");
|
||||
|
||||
std::vector<std::string> words;
|
||||
bool is_hmm = true;
|
||||
jieba_->Cut(text, words, is_hmm);
|
||||
if (jieba_) {
|
||||
bool is_hmm = true;
|
||||
jieba_->Cut(text, words, is_hmm);
|
||||
|
||||
if (debug_) {
|
||||
SHERPA_ONNX_LOGE("input text: %s", text.c_str());
|
||||
SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str());
|
||||
if (debug_) {
|
||||
SHERPA_ONNX_LOGE("input text: %s", text.c_str());
|
||||
SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str());
|
||||
|
||||
std::ostringstream os;
|
||||
std::string sep = "";
|
||||
for (const auto &w : words) {
|
||||
os << sep << w;
|
||||
sep = "_";
|
||||
std::ostringstream os;
|
||||
std::string sep = "";
|
||||
for (const auto &w : words) {
|
||||
os << sep << w;
|
||||
sep = "_";
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
|
||||
}
|
||||
} else {
|
||||
words = SplitUtf8(text);
|
||||
|
||||
SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
|
||||
if (debug_) {
|
||||
fprintf(stderr, "Input text in string (lowercase): %s\n", text.c_str());
|
||||
fprintf(stderr, "Input text in bytes (lowercase):");
|
||||
for (uint8_t c : text) {
|
||||
fprintf(stderr, " %02x", c);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "After splitting to words:");
|
||||
for (const auto &w : words) {
|
||||
fprintf(stderr, " %s", w.c_str());
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<TokenIDs> ans;
|
||||
@@ -241,6 +273,7 @@ class MeloTtsLexicon::Impl {
|
||||
{std::move(word), TokenIDs{std::move(ids64), std::move(tone_list)}});
|
||||
}
|
||||
|
||||
// For Chinese+English MeloTTS
|
||||
word2ids_["呣"] = word2ids_["母"];
|
||||
word2ids_["嗯"] = word2ids_["恩"];
|
||||
}
|
||||
@@ -268,6 +301,12 @@ MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon,
|
||||
: impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
|
||||
debug)) {}
|
||||
|
||||
MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon,
|
||||
const std::string &tokens,
|
||||
const OfflineTtsVitsModelMetaData &meta_data,
|
||||
bool debug)
|
||||
: impl_(std::make_unique<Impl>(lexicon, tokens, meta_data, debug)) {}
|
||||
|
||||
std::vector<TokenIDs> MeloTtsLexicon::ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string & /*unused_voice = ""*/) const {
|
||||
return impl_->ConvertTextToTokenIds(text);
|
||||
|
||||
@@ -22,6 +22,9 @@ class MeloTtsLexicon : public OfflineTtsFrontend {
|
||||
const std::string &dict_dir,
|
||||
const OfflineTtsVitsModelMetaData &meta_data, bool debug);
|
||||
|
||||
MeloTtsLexicon(const std::string &lexicon, const std::string &tokens,
|
||||
const OfflineTtsVitsModelMetaData &meta_data, bool debug);
|
||||
|
||||
std::vector<TokenIDs> ConvertTextToTokenIds(
|
||||
const std::string &text,
|
||||
const std::string &unused_voice = "") const override;
|
||||
|
||||
@@ -349,6 +349,10 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
config_.model.vits.lexicon, config_.model.vits.tokens,
|
||||
config_.model.vits.dict_dir, model_->GetMetaData(),
|
||||
config_.model.debug);
|
||||
} else if (meta_data.is_melo_tts && meta_data.language == "English") {
|
||||
frontend_ = std::make_unique<MeloTtsLexicon>(
|
||||
config_.model.vits.lexicon, config_.model.vits.tokens,
|
||||
model_->GetMetaData(), config_.model.debug);
|
||||
} else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) {
|
||||
frontend_ = std::make_unique<JiebaLexicon>(
|
||||
config_.model.vits.lexicon, config_.model.vits.tokens,
|
||||
|
||||
@@ -46,8 +46,10 @@ class OfflineTtsVitsModel::Impl {
|
||||
}
|
||||
|
||||
Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid, float speed) {
|
||||
// For MeloTTS, we hardcode sid to the one contained in the meta data
|
||||
sid = meta_data_.speaker_id;
|
||||
if (meta_data_.num_speakers == 1) {
|
||||
// For MeloTTS, we hardcode sid to the one contained in the meta data
|
||||
sid = meta_data_.speaker_id;
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
@@ -408,10 +408,10 @@ std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
|
||||
// For other versions, we may need to change it
|
||||
#if ORT_API_VERSION >= 12
|
||||
auto v = meta_data.LookupCustomMetadataMapAllocated(key, allocator);
|
||||
return v.get();
|
||||
return v ? v.get() : "";
|
||||
#else
|
||||
auto v = meta_data.LookupCustomMetadataMap(key, allocator);
|
||||
std::string ans = v;
|
||||
std::string ans = v ? v : "";
|
||||
allocator->Free(allocator, v);
|
||||
return ans;
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user