Support VITS VCTK models (#367)

* Support VITS VCTK models

* Release v1.8.1
This commit is contained in:
Fangjun Kuang
2023-10-16 17:22:30 +08:00
committed by GitHub
parent d01682d968
commit 9efe69720d
16 changed files with 332 additions and 31 deletions

View File

@@ -18,7 +18,8 @@ class OfflineTtsImpl {
static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);
virtual GeneratedAudio Generate(const std::string &text) const = 0;
virtual GeneratedAudio Generate(const std::string &text,
int64_t sid = 0) const = 0;
};
} // namespace sherpa_onnx

View File

@@ -23,7 +23,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
model_->Punctuations()) {}
GeneratedAudio Generate(const std::string &text) const override {
GeneratedAudio Generate(const std::string &text,
int64_t sid = 0) const override {
std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
if (x.empty()) {
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
@@ -47,7 +48,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
Ort::Value x_tensor = Ort::Value::CreateTensor(
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
Ort::Value audio = model_->Run(std::move(x_tensor));
Ort::Value audio = model_->Run(std::move(x_tensor), sid);
std::vector<int64_t> audio_shape =
audio.GetTensorTypeAndShapeInfo().GetShape();

View File

@@ -13,6 +13,11 @@ void OfflineTtsVitsModelConfig::Register(ParseOptions *po) {
po->Register("vits-model", &model, "Path to VITS model");
po->Register("vits-lexicon", &lexicon, "Path to lexicon.txt for VITS models");
po->Register("vits-tokens", &tokens, "Path to tokens.txt for VITS models");
po->Register("vits-noise-scale", &noise_scale, "noise_scale for VITS models");
po->Register("vits-noise-scale-w", &noise_scale_w,
"noise_scale_w for VITS models");
po->Register("vits-length-scale", &length_scale,
"length_scale for VITS models");
}
bool OfflineTtsVitsModelConfig::Validate() const {
@@ -55,7 +60,10 @@ std::string OfflineTtsVitsModelConfig::ToString() const {
os << "OfflineTtsVitsModelConfig(";
os << "model=\"" << model << "\", ";
os << "lexicon=\"" << lexicon << "\", ";
os << "tokens=\"" << tokens << "\")";
os << "tokens=\"" << tokens << "\", ";
os << "noise_scale=" << noise_scale << ", ";
os << "noise_scale_w=" << noise_scale_w << ", ";
os << "length_scale=" << length_scale << ")";
return os.str();
}

View File

@@ -16,12 +16,26 @@ struct OfflineTtsVitsModelConfig {
std::string lexicon;
std::string tokens;
float noise_scale = 0.667;
float noise_scale_w = 0.8;
float length_scale = 1;
// used only for multi-speaker models, e.g., vctk speech dataset.
// Not applicable for single-speaker models, e.g., ljspeech dataset
OfflineTtsVitsModelConfig() = default;
OfflineTtsVitsModelConfig(const std::string &model,
const std::string &lexicon,
const std::string &tokens)
: model(model), lexicon(lexicon), tokens(tokens) {}
const std::string &tokens,
float noise_scale = 0.667,
float noise_scale_w = 0.8, float length_scale = 1)
: model(model),
lexicon(lexicon),
tokens(tokens),
noise_scale(noise_scale),
noise_scale_w(noise_scale_w),
length_scale(length_scale) {}
void Register(ParseOptions *po);
bool Validate() const;

View File

@@ -26,7 +26,7 @@ class OfflineTtsVitsModel::Impl {
Init(buf.data(), buf.size());
}
Ort::Value Run(Ort::Value x) {
Ort::Value Run(Ort::Value x, int64_t sid) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
@@ -44,20 +44,33 @@ class OfflineTtsVitsModel::Impl {
Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);
int64_t scale_shape = 1;
float noise_scale = 1;
float length_scale = 1;
float noise_scale_w = 1;
float noise_scale = config_.vits.noise_scale;
float length_scale = config_.vits.length_scale;
float noise_scale_w = config_.vits.noise_scale_w;
Ort::Value noise_scale_tensor =
Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1);
Ort::Value length_scale_tensor = Ort::Value::CreateTensor(
memory_info, &length_scale, 1, &scale_shape, 1);
Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor(
memory_info, &noise_scale_w, 1, &scale_shape, 1);
std::array<Ort::Value, 5> inputs = {
std::move(x), std::move(x_length), std::move(noise_scale_tensor),
std::move(length_scale_tensor), std::move(noise_scale_w_tensor)};
Ort::Value sid_tensor =
Ort::Value::CreateTensor(memory_info, &sid, 1, &scale_shape, 1);
std::vector<Ort::Value> inputs;
inputs.reserve(6);
inputs.push_back(std::move(x));
inputs.push_back(std::move(x_length));
inputs.push_back(std::move(noise_scale_tensor));
inputs.push_back(std::move(length_scale_tensor));
inputs.push_back(std::move(noise_scale_w_tensor));
if (input_names_.size() == 6 && input_names_.back() == "sid") {
inputs.push_back(std::move(sid_tensor));
}
auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
@@ -93,6 +106,7 @@ class OfflineTtsVitsModel::Impl {
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers");
SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
}
@@ -112,6 +126,7 @@ class OfflineTtsVitsModel::Impl {
int32_t sample_rate_;
int32_t add_blank_;
int32_t n_speakers_;
std::string punctuations_;
};
@@ -120,8 +135,8 @@ OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
OfflineTtsVitsModel::~OfflineTtsVitsModel() = default;
Ort::Value OfflineTtsVitsModel::Run(Ort::Value x) {
return impl_->Run(std::move(x));
Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/) {
return impl_->Run(std::move(x), sid);
}
int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); }

View File

@@ -22,10 +22,14 @@ class OfflineTtsVitsModel {
/** Run the model.
*
* @param x A int64 tensor of shape (1, num_tokens)
// @param sid Speaker ID. Used only for multi-speaker models, e.g., models
// trained using the VCTK dataset. It is not used for
// single-speaker models, e.g., models trained using the ljspeech
// dataset.
* @return Return a float32 tensor containing audio samples. You can flatten
* it to a 1-D tensor.
*/
Ort::Value Run(Ort::Value x);
Ort::Value Run(Ort::Value x, int64_t sid = 0);
// Sample rate of the generated audio
int32_t SampleRate() const;

View File

@@ -28,8 +28,9 @@ OfflineTts::OfflineTts(const OfflineTtsConfig &config)
OfflineTts::~OfflineTts() = default;
GeneratedAudio OfflineTts::Generate(const std::string &text) const {
return impl_->Generate(text);
GeneratedAudio OfflineTts::Generate(const std::string &text,
int64_t sid /*=0*/) const {
return impl_->Generate(text, sid);
}
} // namespace sherpa_onnx

View File

@@ -39,7 +39,11 @@ class OfflineTts {
~OfflineTts();
explicit OfflineTts(const OfflineTtsConfig &config);
// @param text A string containing words separated by spaces
GeneratedAudio Generate(const std::string &text) const;
// @param sid Speaker ID. Used only for multi-speaker models, e.g., models
// trained using the VCTK dataset. It is not used for
// single-speaker models, e.g., models trained using the ljspeech
// dataset.
GeneratedAudio Generate(const std::string &text, int64_t sid = 0) const;
private:
std::unique_ptr<OfflineTtsImpl> impl_;

View File

@@ -13,11 +13,12 @@ int main(int32_t argc, char *argv[]) {
Offline text-to-speech with sherpa-onnx
./bin/sherpa-onnx-offline-tts \
--vits-model /path/to/model.onnx \
--vits-lexicon /path/to/lexicon.txt \
--vits-tokens /path/to/tokens.txt
--output-filename ./generated.wav \
'some text within single quotes'
--vits-model=/path/to/model.onnx \
--vits-lexicon=/path/to/lexicon.txt \
--vits-tokens=/path/to/tokens.txt \
--sid=0 \
--output-filename=./generated.wav \
'some text within single quotes on linux/macos or use double quotes on windows'
It will generate a file ./generated.wav as specified by --output-filename.
@@ -33,15 +34,27 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
--vits-model=./vits-ljs.onnx \
--vits-lexicon=./lexicon.txt \
--vits-tokens=./tokens.txt \
--sid=0 \
--output-filename=./generated.wav \
'liliana, the most beautiful and lovely assistant of our team!'
Please see
https://k2-fsa.github.io/sherpa/onnx/tts/index.html
for details.
)usage";
sherpa_onnx::ParseOptions po(kUsageMessage);
std::string output_filename = "./generated.wav";
int32_t sid = 0;
po.Register("output-filename", &output_filename,
"Path to save the generated audio");
po.Register("sid", &sid,
"Speaker ID. Used only for multi-speaker models, e.g., models "
"trained using the VCTK dataset. Not used for single-speaker "
"models, e.g., models trained using the LJSpeech dataset");
sherpa_onnx::OfflineTtsConfig config;
config.Register(&po);
@@ -67,7 +80,7 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
}
sherpa_onnx::OfflineTts tts(config);
auto audio = tts.Generate(po.GetArg(1));
auto audio = tts.Generate(po.GetArg(1), sid);
bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate,
audio.samples.data(), audio.samples.size());
@@ -76,7 +89,8 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
exit(EXIT_FAILURE);
}
fprintf(stderr, "The text is: %s\n", po.GetArg(1).c_str());
fprintf(stderr, "The text is: %s. Speaker ID: %d\n", po.GetArg(1).c_str(),
sid);
fprintf(stderr, "Saved to %s successfully!\n", output_filename.c_str());
return 0;