Support VITS VCTK models (#367)
* Support VITS VCTK models
* Release v1.8.1
CMakeLists.txt
@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-onnx)

-set(SHERPA_ONNX_VERSION "1.8.0")
+set(SHERPA_ONNX_VERSION "1.8.1")

 # Disable warning about
 #
@@ -58,6 +58,16 @@ def get_args():
         help="Path to save generated wave",
     )

+    parser.add_argument(
+        "--sid",
+        type=int,
+        default=0,
+        help="""Speaker ID. Used only for multi-speaker models, e.g.
+        models trained using the VCTK dataset. Not used for single-speaker
+        models, e.g., models trained using the LJ speech dataset.
+        """,
+    )
+
     parser.add_argument(
         "--debug",
         type=bool,
@@ -105,7 +115,7 @@ def main():
         )
     )
     tts = sherpa_onnx.OfflineTts(tts_config)
-    audio = tts.generate(args.text)
+    audio = tts.generate(args.text, sid=args.sid)
     sf.write(
         args.output_filename,
         audio.samples,
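For orientation, the end-to-end Python usage after this hunk looks roughly as follows. This is a sketch only: the config wrapper classes and their keyword arguments are assumed from the example script and from the pybind11 changes later in this commit.

    import sherpa_onnx
    import soundfile as sf

    # Assumed config wrappers; field names follow the pybind bindings in this PR.
    tts_config = sherpa_onnx.OfflineTtsConfig(
        model=sherpa_onnx.OfflineTtsModelConfig(
            vits=sherpa_onnx.OfflineTtsVitsModelConfig(
                model="./vits-vctk.onnx",
                lexicon="./lexicon.txt",
                tokens="./tokens-vctk.txt",
            ),
        ),
    )
    tts = sherpa_onnx.OfflineTts(tts_config)
    audio = tts.generate("hello world", sid=10)  # sid picks one of the VCTK speakers
    sf.write("./generated.wav", audio.samples, samplerate=audio.sample_rate)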
scripts/vits/.gitignore (vendored, 1 line changed)
@@ -1 +1,2 @@
 tokens-ljs.txt
+tokens-vctk.txt
@@ -191,6 +191,7 @@ def main():
         "comment": "ljspeech",
         "language": "English",
         "add_blank": int(hps.data.add_blank),
+        "n_speakers": int(hps.data.n_speakers),
         "sample_rate": hps.data.sampling_rate,
         "punctuation": " ".join(list(_punctuation)),
     }
scripts/vits/export-onnx-vctk.py (new executable file, 222 lines)
@@ -0,0 +1,222 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang)

"""
This script converts vits models trained using the VCTK dataset.

Usage:

(1) Download vits

cd /Users/fangjun/open-source
git clone https://github.com/jaywalnut310/vits

(2) Download pre-trained models from
https://huggingface.co/csukuangfj/vits-vctk/tree/main

wget https://huggingface.co/csukuangfj/vits-vctk/resolve/main/pretrained_vctk.pth

(3) Run this file

./export-onnx-vctk.py \
  --config ~/open-source/vits/configs/vctk_base.json \
  --checkpoint ~/open-source/icefall-models/vits-vctk/pretrained_vctk.pth

It will generate the following two files:

$ ls -lh *.onnx
-rw-r--r--  1 fangjun  staff    37M Oct 16 10:57 vits-vctk.int8.onnx
-rw-r--r--  1 fangjun  staff   116M Oct 16 10:57 vits-vctk.onnx
"""
import sys

# Please change this line to point to the vits directory.
# You can download vits from
# https://github.com/jaywalnut310/vits
sys.path.insert(0, "/Users/fangjun/open-source/vits")  # noqa

import argparse
from pathlib import Path
from typing import Dict, Any

import commons
import onnx
import torch
import utils
from models import SynthesizerTrn
from onnxruntime.quantization import QuantType, quantize_dynamic
from text import text_to_sequence
from text.symbols import symbols
from text.symbols import _punctuation

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="""Path to vctk_base.json.
        You can find it at
        https://huggingface.co/csukuangfj/vits-vctk/resolve/main/vctk_base.json
        """,
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="""Path to the checkpoint file.
        You can find it at
        https://huggingface.co/csukuangfj/vits-vctk/resolve/main/pretrained_vctk.pth
        """,
    )

    return parser.parse_args()

class OnnxModel(torch.nn.Module):
    def __init__(self, model: SynthesizerTrn):
        super().__init__()
        self.model = model

    def forward(
        self,
        x,
        x_lengths,
        noise_scale=1,
        length_scale=1,
        noise_scale_w=1.0,
        sid=0,
        max_len=None,
    ):
        return self.model.infer(
            x=x,
            x_lengths=x_lengths,
            sid=sid,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            max_len=max_len,
        )[0]

def get_text(text, hps):
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm

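# Note: commons.intersperse() from the vits repo inserts the blank token
# (id 0) between every pair of token IDs and at both ends, so n tokens
# become 2*n + 1 entries, e.g. [5, 9, 3] -> [0, 5, 0, 9, 0, 3, 0].
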
def check_args(args):
    assert Path(args.config).is_file(), args.config
    assert Path(args.checkpoint).is_file(), args.checkpoint

def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)

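# Note: the metadata written by add_meta_data() can be verified after
# export with onnxruntime, e.g.:
#
#   import onnxruntime
#   sess = onnxruntime.InferenceSession(
#       "vits-vctk.onnx", providers=["CPUExecutionProvider"]
#   )
#   print(sess.get_modelmeta().custom_metadata_map)
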
def generate_tokens():
    with open("tokens-vctk.txt", "w", encoding="utf-8") as f:
        for i, s in enumerate(symbols):
            f.write(f"{s} {i}\n")
    print("Generated tokens-vctk.txt")

@torch.no_grad()
def main():
    args = get_args()
    check_args(args)

    generate_tokens()

    hps = utils.get_hparams_from_file(args.config)

    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    _ = net_g.eval()

    _ = utils.load_checkpoint(args.checkpoint, net_g, None)

    x = get_text("Liliana is the most beautiful assistant", hps)
    x = x.unsqueeze(0)

    x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
    noise_scale = torch.tensor([1], dtype=torch.float32)
    length_scale = torch.tensor([1], dtype=torch.float32)
    noise_scale_w = torch.tensor([1], dtype=torch.float32)
    sid = torch.tensor([0], dtype=torch.int64)

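    # Note: the tensors above are example inputs used only to trace the model
    # for ONNX export; the exported graph takes x, x_length, noise_scale,
    # length_scale, noise_scale_w, and sid as runtime inputs.
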
    model = OnnxModel(net_g)

    opset_version = 13

    filename = "vits-vctk.onnx"

    torch.onnx.export(
        model,
        (x, x_length, noise_scale, length_scale, noise_scale_w, sid),
        filename,
        opset_version=opset_version,
        input_names=[
            "x",
            "x_length",
            "noise_scale",
            "length_scale",
            "noise_scale_w",
            "sid",
        ],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "N", 1: "L"},  # n_audio is also known as batch_size
            "x_length": {0: "N"},
            "y": {0: "N", 2: "L"},
        },
    )
    meta_data = {
        "model_type": "vits",
        "comment": "vctk",
        "language": "English",
        "add_blank": int(hps.data.add_blank),
        "n_speakers": int(hps.data.n_speakers),
        "sample_rate": hps.data.sampling_rate,
        "punctuation": " ".join(list(_punctuation)),
    }
    print("meta_data", meta_data)
    add_meta_data(filename=filename, meta_data=meta_data)

    print("Generate int8 quantization models")

    filename_int8 = "vits-vctk.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )

    print(f"Saved to {filename} and {filename_int8}")


if __name__ == "__main__":
    main()
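The exported model can also be exercised directly with onnxruntime, independent of sherpa-onnx. A minimal sketch: the token IDs below are placeholders (real IDs come from tokens-vctk.txt via the lexicon), while the input/output names follow the torch.onnx.export call above.

    import numpy as np
    import onnxruntime

    sess = onnxruntime.InferenceSession(
        "vits-vctk.onnx", providers=["CPUExecutionProvider"]
    )
    x = np.array([[0, 50, 0, 156, 0]], dtype=np.int64)  # placeholder token IDs, shape (N, L)
    y = sess.run(
        ["y"],
        {
            "x": x,
            "x_length": np.array([x.shape[1]], dtype=np.int64),
            "noise_scale": np.array([0.667], dtype=np.float32),
            "length_scale": np.array([1.0], dtype=np.float32),
            "noise_scale_w": np.array([0.8], dtype=np.float32),
            "sid": np.array([10], dtype=np.int64),  # VCTK speaker index
        },
    )[0]
    # y has shape (N, 1, num_samples); flatten it to get the waveform.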
@@ -18,7 +18,8 @@ class OfflineTtsImpl {

   static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);

-  virtual GeneratedAudio Generate(const std::string &text) const = 0;
+  virtual GeneratedAudio Generate(const std::string &text,
+                                  int64_t sid = 0) const = 0;
 };

 }  // namespace sherpa_onnx
@@ -23,7 +23,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
         lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
                  model_->Punctuations()) {}

-  GeneratedAudio Generate(const std::string &text) const override {
+  GeneratedAudio Generate(const std::string &text,
+                          int64_t sid = 0) const override {
     std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
     if (x.empty()) {
       SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
@@ -47,7 +48,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
     Ort::Value x_tensor = Ort::Value::CreateTensor(
         memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());

-    Ort::Value audio = model_->Run(std::move(x_tensor));
+    Ort::Value audio = model_->Run(std::move(x_tensor), sid);

     std::vector<int64_t> audio_shape =
         audio.GetTensorTypeAndShapeInfo().GetShape();
@@ -13,6 +13,11 @@ void OfflineTtsVitsModelConfig::Register(ParseOptions *po) {
   po->Register("vits-model", &model, "Path to VITS model");
   po->Register("vits-lexicon", &lexicon, "Path to lexicon.txt for VITS models");
   po->Register("vits-tokens", &tokens, "Path to tokens.txt for VITS models");
+  po->Register("vits-noise-scale", &noise_scale, "noise_scale for VITS models");
+  po->Register("vits-noise-scale-w", &noise_scale_w,
+               "noise_scale_w for VITS models");
+  po->Register("vits-length-scale", &length_scale,
+               "length_scale for VITS models");
 }

 bool OfflineTtsVitsModelConfig::Validate() const {
@@ -55,7 +60,10 @@ std::string OfflineTtsVitsModelConfig::ToString() const {
   os << "OfflineTtsVitsModelConfig(";
   os << "model=\"" << model << "\", ";
   os << "lexicon=\"" << lexicon << "\", ";
-  os << "tokens=\"" << tokens << "\")";
+  os << "tokens=\"" << tokens << "\", ";
+  os << "noise_scale=" << noise_scale << ", ";
+  os << "noise_scale_w=" << noise_scale_w << ", ";
+  os << "length_scale=" << length_scale << ")";

   return os.str();
 }
@@ -16,12 +16,26 @@ struct OfflineTtsVitsModelConfig {
   std::string lexicon;
   std::string tokens;

+  float noise_scale = 0.667;
+  float noise_scale_w = 0.8;
+  float length_scale = 1;
+
+  // Used only for multi-speaker models, e.g., the VCTK speech dataset.
+  // Not applicable for single-speaker models, e.g., the ljspeech dataset.
+
   OfflineTtsVitsModelConfig() = default;

   OfflineTtsVitsModelConfig(const std::string &model,
                             const std::string &lexicon,
-                            const std::string &tokens)
-      : model(model), lexicon(lexicon), tokens(tokens) {}
+                            const std::string &tokens,
+                            float noise_scale = 0.667,
+                            float noise_scale_w = 0.8, float length_scale = 1)
+      : model(model),
+        lexicon(lexicon),
+        tokens(tokens),
+        noise_scale(noise_scale),
+        noise_scale_w(noise_scale_w),
+        length_scale(length_scale) {}

   void Register(ParseOptions *po);
   bool Validate() const;
@@ -26,7 +26,7 @@ class OfflineTtsVitsModel::Impl {
     Init(buf.data(), buf.size());
   }

-  Ort::Value Run(Ort::Value x) {
+  Ort::Value Run(Ort::Value x, int64_t sid) {
     auto memory_info =
         Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
@@ -44,20 +44,33 @@ class OfflineTtsVitsModel::Impl {
         Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);

     int64_t scale_shape = 1;
-    float noise_scale = 1;
-    float length_scale = 1;
-    float noise_scale_w = 1;
+    float noise_scale = config_.vits.noise_scale;
+    float length_scale = config_.vits.length_scale;
+    float noise_scale_w = config_.vits.noise_scale_w;

     Ort::Value noise_scale_tensor =
         Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1);

     Ort::Value length_scale_tensor = Ort::Value::CreateTensor(
         memory_info, &length_scale, 1, &scale_shape, 1);

     Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor(
         memory_info, &noise_scale_w, 1, &scale_shape, 1);

-    std::array<Ort::Value, 5> inputs = {
-        std::move(x), std::move(x_length), std::move(noise_scale_tensor),
-        std::move(length_scale_tensor), std::move(noise_scale_w_tensor)};
+    Ort::Value sid_tensor =
+        Ort::Value::CreateTensor(memory_info, &sid, 1, &scale_shape, 1);
+
+    std::vector<Ort::Value> inputs;
+    inputs.reserve(6);
+    inputs.push_back(std::move(x));
+    inputs.push_back(std::move(x_length));
+    inputs.push_back(std::move(noise_scale_tensor));
+    inputs.push_back(std::move(length_scale_tensor));
+    inputs.push_back(std::move(noise_scale_w_tensor));
+
+    if (input_names_.size() == 6 && input_names_.back() == "sid") {
+      inputs.push_back(std::move(sid_tensor));
+    }

     auto out =
         sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
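The input_names_ check above is what keeps the runtime backward compatible: single-speaker models exported before this commit declare only five inputs, so the sid tensor is fed only when the graph actually has such an input. A Python analogue of the same logic (a sketch, assuming an onnxruntime session sess and numpy arrays as in the export-script example above):

    feeds = {"x": x, "x_length": x_length, "noise_scale": noise_scale,
             "length_scale": length_scale, "noise_scale_w": noise_scale_w}
    if any(i.name == "sid" for i in sess.get_inputs()):
        feeds["sid"] = np.array([sid], dtype=np.int64)
    audio = sess.run(["y"], feeds)[0]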
@@ -93,6 +106,7 @@ class OfflineTtsVitsModel::Impl {
     Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
     SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
     SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
+    SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers");
     SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
   }

@@ -112,6 +126,7 @@ class OfflineTtsVitsModel::Impl {

   int32_t sample_rate_;
   int32_t add_blank_;
+  int32_t n_speakers_;
   std::string punctuations_;
 };
@@ -120,8 +135,8 @@ OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)

 OfflineTtsVitsModel::~OfflineTtsVitsModel() = default;

-Ort::Value OfflineTtsVitsModel::Run(Ort::Value x) {
-  return impl_->Run(std::move(x));
+Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/) {
+  return impl_->Run(std::move(x), sid);
 }

 int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); }
@@ -22,10 +22,14 @@ class OfflineTtsVitsModel {
   /** Run the model.
    *
    * @param x An int64 tensor of shape (1, num_tokens)
+   * @param sid Speaker ID. Used only for multi-speaker models, e.g., models
+   *            trained using the VCTK dataset. It is not used for
+   *            single-speaker models, e.g., models trained using the ljspeech
+   *            dataset.
    * @return Return a float32 tensor containing audio samples. You can flatten
    *         it to a 1-D tensor.
    */
-  Ort::Value Run(Ort::Value x);
+  Ort::Value Run(Ort::Value x, int64_t sid = 0);

   // Sample rate of the generated audio
   int32_t SampleRate() const;
@@ -28,8 +28,9 @@ OfflineTts::OfflineTts(const OfflineTtsConfig &config)

 OfflineTts::~OfflineTts() = default;

-GeneratedAudio OfflineTts::Generate(const std::string &text) const {
-  return impl_->Generate(text);
+GeneratedAudio OfflineTts::Generate(const std::string &text,
+                                    int64_t sid /*=0*/) const {
+  return impl_->Generate(text, sid);
 }

 }  // namespace sherpa_onnx
@@ -39,7 +39,11 @@ class OfflineTts {
   ~OfflineTts();
   explicit OfflineTts(const OfflineTtsConfig &config);

   // @param text A string containing words separated by spaces
-  GeneratedAudio Generate(const std::string &text) const;
+  // @param sid Speaker ID. Used only for multi-speaker models, e.g., models
+  //            trained using the VCTK dataset. It is not used for
+  //            single-speaker models, e.g., models trained using the ljspeech
+  //            dataset.
+  GeneratedAudio Generate(const std::string &text, int64_t sid = 0) const;

  private:
   std::unique_ptr<OfflineTtsImpl> impl_;
@@ -13,11 +13,12 @@ int main(int32_t argc, char *argv[]) {
 Offline text-to-speech with sherpa-onnx

 ./bin/sherpa-onnx-offline-tts \
- --vits-model /path/to/model.onnx \
- --vits-lexicon /path/to/lexicon.txt \
- --vits-tokens /path/to/tokens.txt
- --output-filename ./generated.wav \
- 'some text within single quotes'
+ --vits-model=/path/to/model.onnx \
+ --vits-lexicon=/path/to/lexicon.txt \
+ --vits-tokens=/path/to/tokens.txt \
+ --sid=0 \
+ --output-filename=./generated.wav \
+ 'some text within single quotes on linux/macos or use double quotes on windows'

 It will generate a file ./generated.wav as specified by --output-filename.
@@ -33,15 +34,27 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
   --vits-model=./vits-ljs.onnx \
   --vits-lexicon=./lexicon.txt \
   --vits-tokens=./tokens.txt \
+  --sid=0 \
   --output-filename=./generated.wav \
   'liliana, the most beautiful and lovely assistant of our team!'

 Please see
 https://k2-fsa.github.io/sherpa/onnx/tts/index.html
 for details.
 )usage";

   sherpa_onnx::ParseOptions po(kUsageMessage);
   std::string output_filename = "./generated.wav";
+  int32_t sid = 0;

   po.Register("output-filename", &output_filename,
               "Path to save the generated audio");

+  po.Register("sid", &sid,
+              "Speaker ID. Used only for multi-speaker models, e.g., models "
+              "trained using the VCTK dataset. Not used for single-speaker "
+              "models, e.g., models trained using the LJSpeech dataset");

   sherpa_onnx::OfflineTtsConfig config;

   config.Register(&po);
@@ -67,7 +80,7 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
   }

   sherpa_onnx::OfflineTts tts(config);
-  auto audio = tts.Generate(po.GetArg(1));
+  auto audio = tts.Generate(po.GetArg(1), sid);

   bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate,
                                    audio.samples.data(), audio.samples.size());
@@ -76,7 +89,8 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
     exit(EXIT_FAILURE);
   }

-  fprintf(stderr, "The text is: %s\n", po.GetArg(1).c_str());
+  fprintf(stderr, "The text is: %s. Speaker ID: %d\n", po.GetArg(1).c_str(),
+          sid);
   fprintf(stderr, "Saved to %s successfully!\n", output_filename.c_str());

   return 0;
@@ -16,11 +16,16 @@ void PybindOfflineTtsVitsModelConfig(py::module *m) {
   py::class_<PyClass>(*m, "OfflineTtsVitsModelConfig")
       .def(py::init<>())
       .def(py::init<const std::string &, const std::string &,
-                    const std::string &>(),
-           py::arg("model"), py::arg("lexicon"), py::arg("tokens"))
+                    const std::string &, float, float, float>(),
+           py::arg("model"), py::arg("lexicon"), py::arg("tokens"),
+           py::arg("noise_scale") = 0.667, py::arg("noise_scale_w") = 0.8,
+           py::arg("length_scale") = 1.0)
       .def_readwrite("model", &PyClass::model)
       .def_readwrite("lexicon", &PyClass::lexicon)
       .def_readwrite("tokens", &PyClass::tokens)
+      .def_readwrite("noise_scale", &PyClass::noise_scale)
+      .def_readwrite("noise_scale_w", &PyClass::noise_scale_w)
+      .def_readwrite("length_scale", &PyClass::length_scale)
       .def("__str__", &PyClass::ToString);
 }
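With these bindings the new fields are reachable from Python; a quick sketch (the values shown are the defaults):

    import sherpa_onnx

    vits = sherpa_onnx.OfflineTtsVitsModelConfig(
        model="./vits-vctk.onnx",
        lexicon="./lexicon.txt",
        tokens="./tokens-vctk.txt",
        noise_scale=0.667,
        noise_scale_w=0.8,
        length_scale=1.0,  # values > 1 slow speech down; values < 1 speed it up
    )
    print(vits)  # __str__ is bound to ToString()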
@@ -40,7 +40,7 @@ void PybindOfflineTts(py::module *m) {
   using PyClass = OfflineTts;
   py::class_<PyClass>(*m, "OfflineTts")
       .def(py::init<const OfflineTtsConfig &>(), py::arg("config"))
-      .def("generate", &PyClass::Generate);
+      .def("generate", &PyClass::Generate, py::arg("text"), py::arg("sid") = 0);
 }

 }  // namespace sherpa_onnx