Export Kokoro 1.0 to sherpa-onnx (#1788)
This commit is contained in:
212
scripts/kokoro/v1.0/test.py
Executable file
212
scripts/kokoro/v1.0/test.py
Executable file
@@ -0,0 +1,212 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
|
||||
import re
|
||||
import time
|
||||
from typing import Dict, List
|
||||
|
||||
import jieba
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from misaki import zh
|
||||
|
||||
try:
|
||||
from piper_phonemize import phonemize_espeak
|
||||
except Exception as ex:
|
||||
raise RuntimeError(
|
||||
f"{ex}\nPlease run\n"
|
||||
"pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
|
||||
)
|
||||
|
||||
|
||||
def show(filename):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.log_severity_level = 3
|
||||
sess = ort.InferenceSession(filename, session_opts)
|
||||
for i in sess.get_inputs():
|
||||
print(i)
|
||||
|
||||
print("-----")
|
||||
|
||||
for i in sess.get_outputs():
|
||||
print(i)
|
||||
|
||||
|
||||
"""
|
||||
NodeArg(name='tokens', type='tensor(int64)', shape=[1, 'sequence_length'])
|
||||
NodeArg(name='style', type='tensor(float)', shape=[1, 256])
|
||||
NodeArg(name='speed', type='tensor(float)', shape=[1])
|
||||
-----
|
||||
NodeArg(name='audio', type='tensor(float)', shape=['audio_length'])
|
||||
"""
|
||||
|
||||
|
||||
def load_voices(speaker_names: List[str], dim: List[int], voices_bin: str):
|
||||
embedding = (
|
||||
np.fromfile(voices_bin, dtype="uint8")
|
||||
.view(np.float32)
|
||||
.reshape(len(speaker_names), *dim)
|
||||
)
|
||||
print("embedding.shape", embedding.shape)
|
||||
ans = dict()
|
||||
for i in range(len(speaker_names)):
|
||||
ans[speaker_names[i]] = embedding[i]
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
def load_tokens(filename: str) -> Dict[str, int]:
|
||||
ans = dict()
|
||||
with open(filename, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
fields = line.strip().split()
|
||||
if len(fields) == 2:
|
||||
token, idx = fields
|
||||
ans[token] = int(idx)
|
||||
else:
|
||||
assert len(fields) == 1, (len(fields), line)
|
||||
ans[" "] = int(fields[0])
|
||||
return ans
|
||||
|
||||
|
||||
def load_lexicon(filename: str) -> Dict[str, List[str]]:
|
||||
ans = dict()
|
||||
for lexicon in filename.split(","):
|
||||
print(lexicon)
|
||||
with open(lexicon, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
w, tokens = line.strip().split(" ", maxsplit=1)
|
||||
ans[w] = "".join(tokens.split())
|
||||
return ans
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(self, model_filename: str, tokens: str, lexicon: str, voices_bin: str):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.session_opts = session_opts
|
||||
self.model = ort.InferenceSession(
|
||||
model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
self.token2id = load_tokens(tokens)
|
||||
self.word2tokens = load_lexicon(lexicon)
|
||||
|
||||
meta = self.model.get_modelmeta().custom_metadata_map
|
||||
print(meta)
|
||||
dim = list(map(int, meta["style_dim"].split(",")))
|
||||
speaker_names = meta["speaker_names"].split(",")
|
||||
self.voices = load_voices(
|
||||
speaker_names=speaker_names, dim=dim, voices_bin=voices_bin
|
||||
)
|
||||
self.sample_rate = int(meta["sample_rate"])
|
||||
print(list(self.voices.keys()))
|
||||
|
||||
self.sample_rate = 24000
|
||||
self.max_len = self.voices[next(iter(self.voices))].shape[0] - 1
|
||||
|
||||
def __call__(self, text: str, voice: str):
|
||||
punctuations = ';:,.!?-…()"“”'
|
||||
text = text.lower()
|
||||
g2p = zh.ZHG2P()
|
||||
|
||||
tokens = ""
|
||||
|
||||
for t in re.findall("[\u4E00-\u9FFF]+|[\u0000-\u007f]+", text):
|
||||
if ord(t[0]) < 0x7F:
|
||||
for w in t.split():
|
||||
while w:
|
||||
if w[0] in punctuations:
|
||||
tokens += w[0] + " "
|
||||
w = w[1:]
|
||||
continue
|
||||
|
||||
if w[-1] in punctuations:
|
||||
if w[:-1] in self.word2tokens:
|
||||
tokens += self.word2tokens[w[:-1]]
|
||||
tokens += w[-1]
|
||||
else:
|
||||
if w in self.word2tokens:
|
||||
tokens += self.word2tokens[w]
|
||||
else:
|
||||
print(f"Use espeak-ng for word {w}")
|
||||
tokens += "".join(phonemize_espeak(w, "en-us")[0])
|
||||
|
||||
tokens += " "
|
||||
break
|
||||
else:
|
||||
# Chinese
|
||||
for w in jieba.cut(t):
|
||||
if w in self.word2tokens:
|
||||
tokens += self.word2tokens[w]
|
||||
else:
|
||||
for i in w:
|
||||
if i in self.word2tokens:
|
||||
tokens += self.word2tokens[i]
|
||||
else:
|
||||
print(f"skip {i}")
|
||||
|
||||
token_ids = [self.token2id[i] for i in tokens]
|
||||
token_ids = token_ids[: self.max_len]
|
||||
|
||||
style = self.voices[voice][len(token_ids)]
|
||||
|
||||
token_ids = [0, *token_ids, 0]
|
||||
token_ids = np.array([token_ids], dtype=np.int64)
|
||||
|
||||
speed = np.array([1.0], dtype=np.float32)
|
||||
|
||||
audio = self.model.run(
|
||||
[
|
||||
self.model.get_outputs()[0].name,
|
||||
],
|
||||
{
|
||||
self.model.get_inputs()[0].name: token_ids,
|
||||
self.model.get_inputs()[1].name: style,
|
||||
self.model.get_inputs()[2].name: speed,
|
||||
},
|
||||
)[0]
|
||||
return audio
|
||||
|
||||
|
||||
def main():
|
||||
m = OnnxModel(
|
||||
model_filename="./kokoro.onnx",
|
||||
tokens="./tokens.txt",
|
||||
lexicon="./lexicon-gb-en.txt,./lexicon-zh.txt",
|
||||
voices_bin="./voices.bin",
|
||||
)
|
||||
text = "来听一听, 这个是什么口音? How are you doing? Are you ok? Thank you! 你觉得中英文说得如何呢?"
|
||||
|
||||
text = text.lower()
|
||||
|
||||
voice = "bf_alice"
|
||||
start = time.time()
|
||||
audio = m(text, voice=voice)
|
||||
end = time.time()
|
||||
|
||||
elapsed_seconds = end - start
|
||||
audio_duration = len(audio) / m.sample_rate
|
||||
real_time_factor = elapsed_seconds / audio_duration
|
||||
|
||||
filename = f"kokoro_v1.0_{voice}_zh_en.wav"
|
||||
sf.write(
|
||||
filename,
|
||||
audio,
|
||||
samplerate=m.sample_rate,
|
||||
subtype="PCM_16",
|
||||
)
|
||||
print(f" Saved to {filename}")
|
||||
print(f" Elapsed seconds: {elapsed_seconds:.3f}")
|
||||
print(f" Audio duration in seconds: {audio_duration:.3f}")
|
||||
print(f" RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user