Support English for MeloTTS models. (#1134)

This commit is contained in:
Fangjun Kuang
2024-07-15 19:49:22 +08:00
committed by GitHub
parent fa07bbc176
commit 95485411fa
5 changed files with 99 additions and 39 deletions

View File

@@ -6,9 +6,13 @@ import torch
from melo.api import TTS
from melo.text import language_id_map, language_tone_start_map
from melo.text.chinese import pinyin_to_symbol_map
from melo.text.english import eng_dict, refine_syllables
from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict
from melo.text.symbols import language_tone_start_map
for k, v in pinyin_to_symbol_map.items():
if isinstance(v, list):
break
pinyin_to_symbol_map[k] = v.split()
@@ -79,6 +83,16 @@ def generate_lexicon():
word_dict = pinyin_dict.pinyin_dict
phrases = phrases_dict.phrases_dict
with open("lexicon.txt", "w", encoding="utf-8") as f:
for word in eng_dict:
phones, tones = refine_syllables(eng_dict[word])
tones = [t + language_tone_start_map["EN"] for t in tones]
tones = [str(t) for t in tones]
phones = " ".join(phones)
tones = " ".join(tones)
f.write(f"{word.lower()} {phones} {tones}\n")
for key in word_dict:
if not (0x4E00 <= key <= 0x9FA5):
continue
@@ -125,15 +139,13 @@ class ModelWrapper(torch.nn.Module):
def __init__(self, model: "SynthesizerTrn"):
super().__init__()
self.model = model
self.lang_id = language_id_map[model.language]
def forward(
self,
x,
x_lengths,
tones,
lang_id,
bert,
ja_bert,
sid,
noise_scale,
length_scale,
@@ -147,7 +159,11 @@ class ModelWrapper(torch.nn.Module):
lang_id: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
sid: an integer
"""
return self.model.infer(
bert = torch.zeros(x.shape[0], 1024, x.shape[1], dtype=torch.float32)
ja_bert = torch.zeros(x.shape[0], 768, x.shape[1], dtype=torch.float32)
lang_id = torch.zeros_like(x)
lang_id[:, 1::2] = self.lang_id
return self.model.model.infer(
x=x,
x_lengths=x_lengths,
sid=sid,
@@ -169,7 +185,7 @@ def main():
generate_tokens(model.hps["symbols"])
torch_model = ModelWrapper(model.model)
torch_model = ModelWrapper(model)
opset_version = 13
x = torch.randint(low=0, high=10, size=(60,), dtype=torch.int64)
@@ -177,19 +193,13 @@ def main():
x_lengths = torch.tensor([x.size(0)], dtype=torch.int64)
sid = torch.tensor([1], dtype=torch.int64)
tones = torch.zeros_like(x)
lang_id = torch.ones_like(x)
noise_scale = torch.tensor([1.0], dtype=torch.float32)
length_scale = torch.tensor([1.0], dtype=torch.float32)
noise_scale_w = torch.tensor([1.0], dtype=torch.float32)
bert = torch.zeros(1024, x.shape[0], dtype=torch.float32)
ja_bert = torch.zeros(768, x.shape[0], dtype=torch.float32)
x = x.unsqueeze(0)
tones = tones.unsqueeze(0)
lang_id = lang_id.unsqueeze(0)
bert = bert.unsqueeze(0)
ja_bert = ja_bert.unsqueeze(0)
filename = "model.onnx"
@@ -199,9 +209,6 @@ def main():
x,
x_lengths,
tones,
lang_id,
bert,
ja_bert,
sid,
noise_scale,
length_scale,
@@ -213,9 +220,6 @@ def main():
"x",
"x_lengths",
"tones",
"lang_id",
"bert",
"ja_bert",
"sid",
"noise_scale",
"length_scale",
@@ -226,9 +230,6 @@ def main():
"x": {0: "N", 1: "L"},
"x_lengths": {0: "N"},
"tones": {0: "N", 1: "L"},
"lang_id": {0: "N", 1: "L"},
"bert": {0: "N", 2: "L"},
"ja_bert": {0: "N", 2: "L"},
"y": {0: "N", 1: "S", 2: "T"},
},
)