Support VITS VCTK models (#367)
* Support VITS VCTK models * Release v1.8.1
This commit is contained in:
@@ -18,7 +18,8 @@ class OfflineTtsImpl {
|
||||
|
||||
static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);
|
||||
|
||||
virtual GeneratedAudio Generate(const std::string &text) const = 0;
|
||||
virtual GeneratedAudio Generate(const std::string &text,
|
||||
int64_t sid = 0) const = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -23,7 +23,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
|
||||
model_->Punctuations()) {}
|
||||
|
||||
GeneratedAudio Generate(const std::string &text) const override {
|
||||
GeneratedAudio Generate(const std::string &text,
|
||||
int64_t sid = 0) const override {
|
||||
std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
|
||||
if (x.empty()) {
|
||||
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
|
||||
@@ -47,7 +48,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
Ort::Value x_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
|
||||
|
||||
Ort::Value audio = model_->Run(std::move(x_tensor));
|
||||
Ort::Value audio = model_->Run(std::move(x_tensor), sid);
|
||||
|
||||
std::vector<int64_t> audio_shape =
|
||||
audio.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
@@ -13,6 +13,11 @@ void OfflineTtsVitsModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("vits-model", &model, "Path to VITS model");
|
||||
po->Register("vits-lexicon", &lexicon, "Path to lexicon.txt for VITS models");
|
||||
po->Register("vits-tokens", &tokens, "Path to tokens.txt for VITS models");
|
||||
po->Register("vits-noise-scale", &noise_scale, "noise_scale for VITS models");
|
||||
po->Register("vits-noise-scale-w", &noise_scale_w,
|
||||
"noise_scale_w for VITS models");
|
||||
po->Register("vits-length-scale", &length_scale,
|
||||
"length_scale for VITS models");
|
||||
}
|
||||
|
||||
bool OfflineTtsVitsModelConfig::Validate() const {
|
||||
@@ -55,7 +60,10 @@ std::string OfflineTtsVitsModelConfig::ToString() const {
|
||||
os << "OfflineTtsVitsModelConfig(";
|
||||
os << "model=\"" << model << "\", ";
|
||||
os << "lexicon=\"" << lexicon << "\", ";
|
||||
os << "tokens=\"" << tokens << "\")";
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "noise_scale=" << noise_scale << ", ";
|
||||
os << "noise_scale_w=" << noise_scale_w << ", ";
|
||||
os << "length_scale=" << length_scale << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -16,12 +16,26 @@ struct OfflineTtsVitsModelConfig {
|
||||
std::string lexicon;
|
||||
std::string tokens;
|
||||
|
||||
float noise_scale = 0.667;
|
||||
float noise_scale_w = 0.8;
|
||||
float length_scale = 1;
|
||||
|
||||
// used only for multi-speaker models, e.g, vctk speech dataset.
|
||||
// Not applicable for single-speaker models, e.g., ljspeech dataset
|
||||
|
||||
OfflineTtsVitsModelConfig() = default;
|
||||
|
||||
OfflineTtsVitsModelConfig(const std::string &model,
|
||||
const std::string &lexicon,
|
||||
const std::string &tokens)
|
||||
: model(model), lexicon(lexicon), tokens(tokens) {}
|
||||
const std::string &tokens,
|
||||
float noise_scale = 0.667,
|
||||
float noise_scale_w = 0.8, float length_scale = 1)
|
||||
: model(model),
|
||||
lexicon(lexicon),
|
||||
tokens(tokens),
|
||||
noise_scale(noise_scale),
|
||||
noise_scale_w(noise_scale_w),
|
||||
length_scale(length_scale) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -26,7 +26,7 @@ class OfflineTtsVitsModel::Impl {
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
Ort::Value Run(Ort::Value x) {
|
||||
Ort::Value Run(Ort::Value x, int64_t sid) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
@@ -44,20 +44,33 @@ class OfflineTtsVitsModel::Impl {
|
||||
Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);
|
||||
|
||||
int64_t scale_shape = 1;
|
||||
float noise_scale = 1;
|
||||
float length_scale = 1;
|
||||
float noise_scale_w = 1;
|
||||
float noise_scale = config_.vits.noise_scale;
|
||||
float length_scale = config_.vits.length_scale;
|
||||
float noise_scale_w = config_.vits.noise_scale_w;
|
||||
|
||||
Ort::Value noise_scale_tensor =
|
||||
Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1);
|
||||
|
||||
Ort::Value length_scale_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, &length_scale, 1, &scale_shape, 1);
|
||||
|
||||
Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, &noise_scale_w, 1, &scale_shape, 1);
|
||||
|
||||
std::array<Ort::Value, 5> inputs = {
|
||||
std::move(x), std::move(x_length), std::move(noise_scale_tensor),
|
||||
std::move(length_scale_tensor), std::move(noise_scale_w_tensor)};
|
||||
Ort::Value sid_tensor =
|
||||
Ort::Value::CreateTensor(memory_info, &sid, 1, &scale_shape, 1);
|
||||
|
||||
std::vector<Ort::Value> inputs;
|
||||
inputs.reserve(6);
|
||||
inputs.push_back(std::move(x));
|
||||
inputs.push_back(std::move(x_length));
|
||||
inputs.push_back(std::move(noise_scale_tensor));
|
||||
inputs.push_back(std::move(length_scale_tensor));
|
||||
inputs.push_back(std::move(noise_scale_w_tensor));
|
||||
|
||||
if (input_names_.size() == 6 && input_names_.back() == "sid") {
|
||||
inputs.push_back(std::move(sid_tensor));
|
||||
}
|
||||
|
||||
auto out =
|
||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
@@ -93,6 +106,7 @@ class OfflineTtsVitsModel::Impl {
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
|
||||
SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
|
||||
SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers");
|
||||
SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
|
||||
}
|
||||
|
||||
@@ -112,6 +126,7 @@ class OfflineTtsVitsModel::Impl {
|
||||
|
||||
int32_t sample_rate_;
|
||||
int32_t add_blank_;
|
||||
int32_t n_speakers_;
|
||||
std::string punctuations_;
|
||||
};
|
||||
|
||||
@@ -120,8 +135,8 @@ OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
|
||||
|
||||
OfflineTtsVitsModel::~OfflineTtsVitsModel() = default;
|
||||
|
||||
Ort::Value OfflineTtsVitsModel::Run(Ort::Value x) {
|
||||
return impl_->Run(std::move(x));
|
||||
Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/) {
|
||||
return impl_->Run(std::move(x), sid);
|
||||
}
|
||||
|
||||
int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); }
|
||||
|
||||
@@ -22,10 +22,14 @@ class OfflineTtsVitsModel {
|
||||
/** Run the model.
|
||||
*
|
||||
* @param x A int64 tensor of shape (1, num_tokens)
|
||||
// @param sid Speaker ID. Used only for multi-speaker models, e.g., models
|
||||
// trained using the VCTK dataset. It is not used for
|
||||
// single-speaker models, e.g., models trained using the ljspeech
|
||||
// dataset.
|
||||
* @return Return a float32 tensor containing audio samples. You can flatten
|
||||
* it to a 1-D tensor.
|
||||
*/
|
||||
Ort::Value Run(Ort::Value x);
|
||||
Ort::Value Run(Ort::Value x, int64_t sid = 0);
|
||||
|
||||
// Sample rate of the generated audio
|
||||
int32_t SampleRate() const;
|
||||
|
||||
@@ -28,8 +28,9 @@ OfflineTts::OfflineTts(const OfflineTtsConfig &config)
|
||||
|
||||
OfflineTts::~OfflineTts() = default;
|
||||
|
||||
GeneratedAudio OfflineTts::Generate(const std::string &text) const {
|
||||
return impl_->Generate(text);
|
||||
GeneratedAudio OfflineTts::Generate(const std::string &text,
|
||||
int64_t sid /*=0*/) const {
|
||||
return impl_->Generate(text, sid);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -39,7 +39,11 @@ class OfflineTts {
|
||||
~OfflineTts();
|
||||
explicit OfflineTts(const OfflineTtsConfig &config);
|
||||
// @param text A string containing words separated by spaces
|
||||
GeneratedAudio Generate(const std::string &text) const;
|
||||
// @param sid Speaker ID. Used only for multi-speaker models, e.g., models
|
||||
// trained using the VCTK dataset. It is not used for
|
||||
// single-speaker models, e.g., models trained using the ljspeech
|
||||
// dataset.
|
||||
GeneratedAudio Generate(const std::string &text, int64_t sid = 0) const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<OfflineTtsImpl> impl_;
|
||||
|
||||
@@ -13,11 +13,12 @@ int main(int32_t argc, char *argv[]) {
|
||||
Offline text-to-speech with sherpa-onnx
|
||||
|
||||
./bin/sherpa-onnx-offline-tts \
|
||||
--vits-model /path/to/model.onnx \
|
||||
--vits-lexicon /path/to/lexicon.txt \
|
||||
--vits-tokens /path/to/tokens.txt
|
||||
--output-filename ./generated.wav \
|
||||
'some text within single quotes'
|
||||
--vits-model=/path/to/model.onnx \
|
||||
--vits-lexicon=/path/to/lexicon.txt \
|
||||
--vits-tokens=/path/to/tokens.txt \
|
||||
--sid=0 \
|
||||
--output-filename=./generated.wav \
|
||||
'some text within single quotes on linux/macos or use double quotes on windows'
|
||||
|
||||
It will generate a file ./generated.wav as specified by --output-filename.
|
||||
|
||||
@@ -33,15 +34,27 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
|
||||
--vits-model=./vits-ljs.onnx \
|
||||
--vits-lexicon=./lexicon.txt \
|
||||
--vits-tokens=./tokens.txt \
|
||||
--sid=0 \
|
||||
--output-filename=./generated.wav \
|
||||
'liliana, the most beautiful and lovely assistant of our team!'
|
||||
|
||||
Please see
|
||||
https://k2-fsa.github.io/sherpa/onnx/tts/index.html
|
||||
or detailes.
|
||||
)usage";
|
||||
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
std::string output_filename = "./generated.wav";
|
||||
int32_t sid = 0;
|
||||
|
||||
po.Register("output-filename", &output_filename,
|
||||
"Path to save the generated audio");
|
||||
|
||||
po.Register("sid", &sid,
|
||||
"Speaker ID. Used only for multi-speaker models, e.g., models "
|
||||
"trained using the VCTK dataset. Not used for single-speaker "
|
||||
"models, e.g., models trained using the LJSpeech dataset");
|
||||
|
||||
sherpa_onnx::OfflineTtsConfig config;
|
||||
|
||||
config.Register(&po);
|
||||
@@ -67,7 +80,7 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
|
||||
}
|
||||
|
||||
sherpa_onnx::OfflineTts tts(config);
|
||||
auto audio = tts.Generate(po.GetArg(1));
|
||||
auto audio = tts.Generate(po.GetArg(1), sid);
|
||||
|
||||
bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate,
|
||||
audio.samples.data(), audio.samples.size());
|
||||
@@ -76,7 +89,8 @@ wget https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
fprintf(stderr, "The text is: %s\n", po.GetArg(1).c_str());
|
||||
fprintf(stderr, "The text is: %s. Speaker ID: %d\n", po.GetArg(1).c_str(),
|
||||
sid);
|
||||
fprintf(stderr, "Saved to %s successfully!\n", output_filename.c_str());
|
||||
|
||||
return 0;
|
||||
|
||||
Reference in New Issue
Block a user