Add TTS with VITS (#360)
This commit is contained in:
@@ -86,6 +86,15 @@ set(sources
|
||||
wave-reader.cc
|
||||
)
|
||||
|
||||
list(APPEND sources
|
||||
lexicon.cc
|
||||
offline-tts-impl.cc
|
||||
offline-tts-model-config.cc
|
||||
offline-tts-vits-model-config.cc
|
||||
offline-tts-vits-model.cc
|
||||
offline-tts.cc
|
||||
)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_CHECK)
|
||||
list(APPEND sources log.cc)
|
||||
endif()
|
||||
@@ -135,23 +144,31 @@ endif()
|
||||
add_executable(sherpa-onnx sherpa-onnx.cc)
|
||||
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
|
||||
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
|
||||
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
|
||||
|
||||
|
||||
target_link_libraries(sherpa-onnx sherpa-onnx-core)
|
||||
target_link_libraries(sherpa-onnx-offline sherpa-onnx-core)
|
||||
target_link_libraries(sherpa-onnx-offline-parallel sherpa-onnx-core)
|
||||
target_link_libraries(sherpa-onnx-offline-tts sherpa-onnx-core)
|
||||
if(NOT WIN32)
|
||||
target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
|
||||
target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")
|
||||
|
||||
target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
|
||||
target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")
|
||||
|
||||
target_link_libraries(sherpa-onnx-offline-parallel "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
|
||||
target_link_libraries(sherpa-onnx-offline-parallel "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")
|
||||
|
||||
target_link_libraries(sherpa-onnx-offline-tts "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
|
||||
target_link_libraries(sherpa-onnx-offline-tts "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib")
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_PYTHON)
|
||||
target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
|
||||
target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
|
||||
target_link_libraries(sherpa-onnx-offline-parallel "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
|
||||
target_link_libraries(sherpa-onnx-offline-tts "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -170,6 +187,7 @@ install(
|
||||
sherpa-onnx
|
||||
sherpa-onnx-offline
|
||||
sherpa-onnx-offline-parallel
|
||||
sherpa-onnx-offline-tts
|
||||
DESTINATION
|
||||
bin
|
||||
)
|
||||
|
||||
157
sherpa-onnx/csrc/lexicon.cc
Normal file
157
sherpa-onnx/csrc/lexicon.cc
Normal file
@@ -0,0 +1,157 @@
|
||||
// sherpa-onnx/csrc/lexicon.cc
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/lexicon.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static void ToLowerCase(std::string *in_out) {
|
||||
std::transform(in_out->begin(), in_out->end(), in_out->begin(),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
}
|
||||
|
||||
// Note: We don't use SymbolTable here since tokens may contain a blank
|
||||
// in the first column
|
||||
static std::unordered_map<std::string, int32_t> ReadTokens(
|
||||
const std::string &tokens) {
|
||||
std::unordered_map<std::string, int32_t> token2id;
|
||||
|
||||
std::ifstream is(tokens);
|
||||
std::string line;
|
||||
|
||||
std::string sym;
|
||||
int32_t id;
|
||||
while (std::getline(is, line)) {
|
||||
std::istringstream iss(line);
|
||||
iss >> sym;
|
||||
if (iss.eof()) {
|
||||
id = atoi(sym.c_str());
|
||||
sym = " ";
|
||||
} else {
|
||||
iss >> id;
|
||||
}
|
||||
|
||||
if (!iss.eof()) {
|
||||
SHERPA_ONNX_LOGE("Error: %s", line.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
#if 0
|
||||
if (token2id.count(sym)) {
|
||||
SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d",
|
||||
sym.c_str(), line.c_str(), token2id.at(sym));
|
||||
exit(-1);
|
||||
}
|
||||
#endif
|
||||
token2id.insert({sym, id});
|
||||
}
|
||||
|
||||
return token2id;
|
||||
}
|
||||
|
||||
static std::vector<int32_t> ConvertTokensToIds(
|
||||
const std::unordered_map<std::string, int32_t> &token2id,
|
||||
const std::vector<std::string> &tokens) {
|
||||
std::vector<int32_t> ids;
|
||||
ids.reserve(tokens.size());
|
||||
for (const auto &s : tokens) {
|
||||
if (!token2id.count(s)) {
|
||||
return {};
|
||||
}
|
||||
int32_t id = token2id.at(s);
|
||||
ids.push_back(id);
|
||||
}
|
||||
|
||||
return ids;
|
||||
}
|
||||
|
||||
Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
|
||||
const std::string &punctuations) {
|
||||
token2id_ = ReadTokens(tokens);
|
||||
std::ifstream is(lexicon);
|
||||
|
||||
std::string word;
|
||||
std::vector<std::string> token_list;
|
||||
std::string line;
|
||||
std::string phone;
|
||||
|
||||
while (std::getline(is, line)) {
|
||||
std::istringstream iss(line);
|
||||
|
||||
token_list.clear();
|
||||
|
||||
iss >> word;
|
||||
ToLowerCase(&word);
|
||||
|
||||
if (word2ids_.count(word)) {
|
||||
SHERPA_ONNX_LOGE("Duplicated word: %s", word.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
while (iss >> phone) {
|
||||
token_list.push_back(std::move(phone));
|
||||
}
|
||||
|
||||
std::vector<int32_t> ids = ConvertTokensToIds(token2id_, token_list);
|
||||
if (ids.empty()) {
|
||||
continue;
|
||||
}
|
||||
word2ids_.insert({std::move(word), std::move(ids)});
|
||||
}
|
||||
|
||||
// process punctuations
|
||||
std::vector<std::string> punctuation_list;
|
||||
SplitStringToVector(punctuations, " ", false, &punctuation_list);
|
||||
for (auto &s : punctuation_list) {
|
||||
punctuations_.insert(std::move(s));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
|
||||
const std::string &_text) const {
|
||||
std::string text(_text);
|
||||
ToLowerCase(&text);
|
||||
|
||||
std::vector<std::string> words;
|
||||
SplitStringToVector(text, " ", false, &words);
|
||||
|
||||
std::vector<int64_t> ans;
|
||||
for (auto w : words) {
|
||||
std::vector<int64_t> prefix;
|
||||
while (!w.empty() && punctuations_.count(std::string(1, w[0]))) {
|
||||
// if w begins with a punctuation
|
||||
prefix.push_back(token2id_.at(std::string(1, w[0])));
|
||||
w = std::string(w.begin() + 1, w.end());
|
||||
}
|
||||
|
||||
std::vector<int64_t> suffix;
|
||||
while (!w.empty() && punctuations_.count(std::string(1, w.back()))) {
|
||||
suffix.push_back(token2id_.at(std::string(1, w.back())));
|
||||
w = std::string(w.begin(), w.end() - 1);
|
||||
}
|
||||
|
||||
if (!word2ids_.count(w)) {
|
||||
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &token_ids = word2ids_.at(w);
|
||||
ans.insert(ans.end(), prefix.begin(), prefix.end());
|
||||
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
|
||||
ans.insert(ans.end(), suffix.rbegin(), suffix.rend());
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
31
sherpa-onnx/csrc/lexicon.h
Normal file
31
sherpa-onnx/csrc/lexicon.h
Normal file
@@ -0,0 +1,31 @@
|
||||
// sherpa-onnx/csrc/lexicon.h
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_LEXICON_H_
|
||||
#define SHERPA_ONNX_CSRC_LEXICON_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class Lexicon {
|
||||
public:
|
||||
Lexicon(const std::string &lexicon, const std::string &tokens,
|
||||
const std::string &punctuations);
|
||||
|
||||
std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const;
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
|
||||
std::unordered_set<std::string> punctuations_;
|
||||
std::unordered_map<std::string, int32_t> token2id_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_LEXICON_H_
|
||||
19
sherpa-onnx/csrc/offline-tts-impl.cc
Normal file
19
sherpa-onnx/csrc/offline-tts-impl.cc
Normal file
@@ -0,0 +1,19 @@
|
||||
// sherpa-onnx/csrc/offline-tts-impl.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts-impl.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts-vits-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
|
||||
const OfflineTtsConfig &config) {
|
||||
// TODO(fangjun): Support other types
|
||||
return std::make_unique<OfflineTtsVitsImpl>(config);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
26
sherpa-onnx/csrc/offline-tts-impl.h
Normal file
26
sherpa-onnx/csrc/offline-tts-impl.h
Normal file
@@ -0,0 +1,26 @@
|
||||
// sherpa-onnx/csrc/offline-tts-impl.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineTtsImpl {
|
||||
public:
|
||||
virtual ~OfflineTtsImpl() = default;
|
||||
|
||||
static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);
|
||||
|
||||
virtual GeneratedAudio Generate(const std::string &text) const = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_
|
||||
45
sherpa-onnx/csrc/offline-tts-model-config.cc
Normal file
45
sherpa-onnx/csrc/offline-tts-model-config.cc
Normal file
@@ -0,0 +1,45 @@
|
||||
// sherpa-onnx/csrc/offline-tts-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineTtsModelConfig::Register(ParseOptions *po) {
|
||||
vits.Register(po);
|
||||
|
||||
po->Register("num-threads", &num_threads,
|
||||
"Number of threads to run the neural network");
|
||||
|
||||
po->Register("debug", &debug,
|
||||
"true to print model information while loading it.");
|
||||
|
||||
po->Register("provider", &provider,
|
||||
"Specify a provider to use: cpu, cuda, coreml");
|
||||
}
|
||||
|
||||
bool OfflineTtsModelConfig::Validate() const {
|
||||
if (num_threads < 1) {
|
||||
SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
|
||||
return false;
|
||||
}
|
||||
|
||||
return vits.Validate();
|
||||
}
|
||||
|
||||
std::string OfflineTtsModelConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineTtsModelConfig(";
|
||||
os << "vits=" << vits.ToString() << ", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
40
sherpa-onnx/csrc/offline-tts-model-config.h
Normal file
40
sherpa-onnx/csrc/offline-tts-model-config.h
Normal file
@@ -0,0 +1,40 @@
|
||||
// sherpa-onnx/csrc/offline-tts-model-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts-vits-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineTtsModelConfig {
|
||||
OfflineTtsVitsModelConfig vits;
|
||||
|
||||
int32_t num_threads = 1;
|
||||
bool debug = false;
|
||||
std::string provider = "cpu";
|
||||
|
||||
OfflineTtsModelConfig() = default;
|
||||
|
||||
OfflineTtsModelConfig(const OfflineTtsVitsModelConfig &vits,
|
||||
int32_t num_threads, bool debug,
|
||||
const std::string &provider)
|
||||
: vits(vits),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
provider(provider) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_
|
||||
77
sherpa-onnx/csrc/offline-tts-vits-impl.h
Normal file
77
sherpa-onnx/csrc/offline-tts-vits-impl.h
Normal file
@@ -0,0 +1,77 @@
|
||||
// sherpa-onnx/csrc/offline-tts-vits-impl.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/lexicon.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-vits-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
public:
|
||||
explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config)
|
||||
: model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
|
||||
lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
|
||||
model_->Punctuations()) {
|
||||
SHERPA_ONNX_LOGE("config: %s\n", config.ToString().c_str());
|
||||
}
|
||||
|
||||
GeneratedAudio Generate(const std::string &text) const override {
|
||||
std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
|
||||
if (x.empty()) {
|
||||
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
|
||||
return {};
|
||||
}
|
||||
|
||||
if (model_->AddBlank()) {
|
||||
std::vector<int64_t> buffer(x.size() * 2 + 1);
|
||||
int32_t i = 1;
|
||||
for (auto k : x) {
|
||||
buffer[i] = k;
|
||||
i += 2;
|
||||
}
|
||||
x = std::move(buffer);
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::array<int64_t, 2> x_shape = {1, static_cast<int32_t>(x.size())};
|
||||
Ort::Value x_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
|
||||
|
||||
Ort::Value audio = model_->Run(std::move(x_tensor));
|
||||
|
||||
std::vector<int64_t> audio_shape =
|
||||
audio.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
int64_t total = 1;
|
||||
// The output shape may be (1, 1, total) or (1, total) or (total,)
|
||||
for (auto i : audio_shape) {
|
||||
total *= i;
|
||||
}
|
||||
|
||||
const float *p = audio.GetTensorData<float>();
|
||||
|
||||
GeneratedAudio ans;
|
||||
ans.sample_rate = model_->SampleRate();
|
||||
ans.samples = std::vector<float>(p, p + total);
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<OfflineTtsVitsModel> model_;
|
||||
Lexicon lexicon_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_
|
||||
63
sherpa-onnx/csrc/offline-tts-vits-model-config.cc
Normal file
63
sherpa-onnx/csrc/offline-tts-vits-model-config.cc
Normal file
@@ -0,0 +1,63 @@
|
||||
// sherpa-onnx/csrc/offline-tts-vits-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts-vits-model-config.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineTtsVitsModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("vits-model", &model, "Path to VITS model");
|
||||
po->Register("vits-lexicon", &lexicon, "Path to lexicon.txt for VITS models");
|
||||
po->Register("vits-tokens", &tokens, "Path to tokens.txt for VITS models");
|
||||
}
|
||||
|
||||
bool OfflineTtsVitsModelConfig::Validate() const {
|
||||
if (model.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --vits-model");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(model)) {
|
||||
SHERPA_ONNX_LOGE("--vits-model: %s does not exist", model.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (lexicon.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --vits-lexicon");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(lexicon)) {
|
||||
SHERPA_ONNX_LOGE("--vits-lexicon: %s does not exist", lexicon.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tokens.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --vits-tokens");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(tokens)) {
|
||||
SHERPA_ONNX_LOGE("--vits-tokens: %s does not exist", tokens.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OfflineTtsVitsModelConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineTtsVitsModelConfig(";
|
||||
os << "model=\"" << model << "\", ";
|
||||
os << "lexicon=\"" << lexicon << "\", ";
|
||||
os << "tokens=\"" << tokens << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
34
sherpa-onnx/csrc/offline-tts-vits-model-config.h
Normal file
34
sherpa-onnx/csrc/offline-tts-vits-model-config.h
Normal file
@@ -0,0 +1,34 @@
|
||||
// sherpa-onnx/csrc/offline-tts-vits-model-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineTtsVitsModelConfig {
|
||||
std::string model;
|
||||
std::string lexicon;
|
||||
std::string tokens;
|
||||
|
||||
OfflineTtsVitsModelConfig() = default;
|
||||
|
||||
OfflineTtsVitsModelConfig(const std::string &model,
|
||||
const std::string &lexicon,
|
||||
const std::string &tokens)
|
||||
: model(model), lexicon(lexicon), tokens(tokens) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_
|
||||
135
sherpa-onnx/csrc/offline-tts-vits-model.cc
Normal file
135
sherpa-onnx/csrc/offline-tts-vits-model.cc
Normal file
@@ -0,0 +1,135 @@
|
||||
// sherpa-onnx/csrc/offline-tts-vits-model.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts-vits-model.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineTtsVitsModel::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineTtsModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_WARNING),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(config.vits.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
Ort::Value Run(Ort::Value x) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::vector<int64_t> x_shape = x.GetTensorTypeAndShapeInfo().GetShape();
|
||||
if (x_shape[0] != 1) {
|
||||
SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d",
|
||||
static_cast<int32_t>(x_shape[0]));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int64_t len = x_shape[1];
|
||||
int64_t len_shape = 1;
|
||||
|
||||
Ort::Value x_length =
|
||||
Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);
|
||||
|
||||
int64_t scale_shape = 1;
|
||||
float noise_scale = 1;
|
||||
float length_scale = 1;
|
||||
float noise_scale_w = 1;
|
||||
|
||||
Ort::Value noise_scale_tensor =
|
||||
Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1);
|
||||
Ort::Value length_scale_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, &length_scale, 1, &scale_shape, 1);
|
||||
Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, &noise_scale_w, 1, &scale_shape, 1);
|
||||
|
||||
std::array<Ort::Value, 5> inputs = {
|
||||
std::move(x), std::move(x_length), std::move(noise_scale_tensor),
|
||||
std::move(length_scale_tensor), std::move(noise_scale_w_tensor)};
|
||||
|
||||
auto out =
|
||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
|
||||
return std::move(out[0]);
|
||||
}
|
||||
|
||||
int32_t SampleRate() const { return sample_rate_; }
|
||||
|
||||
bool AddBlank() const { return add_blank_; }
|
||||
|
||||
std::string Punctuations() const { return punctuations_; }
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
|
||||
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
os << "---vits model---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
|
||||
SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
|
||||
SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineTtsModelConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> sess_;
|
||||
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<const char *> input_names_ptr_;
|
||||
|
||||
std::vector<std::string> output_names_;
|
||||
std::vector<const char *> output_names_ptr_;
|
||||
|
||||
int32_t sample_rate_;
|
||||
int32_t add_blank_;
|
||||
std::string punctuations_;
|
||||
};
|
||||
|
||||
OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
OfflineTtsVitsModel::~OfflineTtsVitsModel() = default;
|
||||
|
||||
Ort::Value OfflineTtsVitsModel::Run(Ort::Value x) {
|
||||
return impl_->Run(std::move(x));
|
||||
}
|
||||
|
||||
int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); }
|
||||
|
||||
bool OfflineTtsVitsModel::AddBlank() const { return impl_->AddBlank(); }
|
||||
|
||||
std::string OfflineTtsVitsModel::Punctuations() const {
|
||||
return impl_->Punctuations();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
45
sherpa-onnx/csrc/offline-tts-vits-model.h
Normal file
45
sherpa-onnx/csrc/offline-tts-vits-model.h
Normal file
@@ -0,0 +1,45 @@
|
||||
// sherpa-onnx/csrc/offline-tts-vits-model.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineTtsVitsModel {
|
||||
public:
|
||||
~OfflineTtsVitsModel();
|
||||
|
||||
explicit OfflineTtsVitsModel(const OfflineTtsModelConfig &config);
|
||||
|
||||
/** Run the model.
|
||||
*
|
||||
* @param x A int64 tensor of shape (1, num_tokens)
|
||||
* @return Return a float32 tensor containing audio samples. You can flatten
|
||||
* it to a 1-D tensor.
|
||||
*/
|
||||
Ort::Value Run(Ort::Value x);
|
||||
|
||||
// Sample rate of the generated audio
|
||||
int32_t SampleRate() const;
|
||||
|
||||
// true to insert a blank between each token
|
||||
bool AddBlank() const;
|
||||
|
||||
std::string Punctuations() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_
|
||||
35
sherpa-onnx/csrc/offline-tts.cc
Normal file
35
sherpa-onnx/csrc/offline-tts.cc
Normal file
@@ -0,0 +1,35 @@
|
||||
// sherpa-onnx/csrc/offline-tts.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineTtsConfig::Register(ParseOptions *po) { model.Register(po); }
|
||||
|
||||
bool OfflineTtsConfig::Validate() const { return model.Validate(); }
|
||||
|
||||
std::string OfflineTtsConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineTtsConfig(";
|
||||
os << "model=" << model.ToString() << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
OfflineTts::OfflineTts(const OfflineTtsConfig &config)
|
||||
: impl_(OfflineTtsImpl::Create(config)) {}
|
||||
|
||||
OfflineTts::~OfflineTts() = default;
|
||||
|
||||
GeneratedAudio OfflineTts::Generate(const std::string &text) const {
|
||||
return impl_->Generate(text);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
50
sherpa-onnx/csrc/offline-tts.h
Normal file
50
sherpa-onnx/csrc/offline-tts.h
Normal file
@@ -0,0 +1,50 @@
|
||||
// sherpa-onnx/csrc/offline-tts.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineTtsConfig {
|
||||
OfflineTtsModelConfig model;
|
||||
|
||||
OfflineTtsConfig() = default;
|
||||
explicit OfflineTtsConfig(const OfflineTtsModelConfig &model)
|
||||
: model(model) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
struct GeneratedAudio {
|
||||
std::vector<float> samples;
|
||||
int32_t sample_rate;
|
||||
};
|
||||
|
||||
class OfflineTtsImpl;
|
||||
|
||||
class OfflineTts {
|
||||
public:
|
||||
~OfflineTts();
|
||||
explicit OfflineTts(const OfflineTtsConfig &config);
|
||||
// @param text A string containing words separated by spaces
|
||||
GeneratedAudio Generate(const std::string &text) const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<OfflineTtsImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_H_
|
||||
@@ -87,4 +87,8 @@ Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/csrc/vad-model-config.h"
|
||||
@@ -23,6 +24,8 @@ Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config);
|
||||
Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const VadModelConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SESSION_H_
|
||||
|
||||
57
sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc
Normal file
57
sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc
Normal file
@@ -0,0 +1,57 @@
|
||||
// sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
const char *kUsageMessage = R"usage(
|
||||
Offline text-to-speech with sherpa-onnx
|
||||
|
||||
./bin/sherpa-onnx-offline-tts \
|
||||
--vits-model /path/to/model.onnx \
|
||||
--vits-lexicon /path/to/lexicon.txt \
|
||||
--vits-tokens /path/to/tokens.txt
|
||||
'some text within single quotes'
|
||||
|
||||
It will generate a file test.wav.
|
||||
)usage";
|
||||
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
sherpa_onnx::OfflineTtsConfig config;
|
||||
config.Register(&po);
|
||||
po.Read(argc, argv);
|
||||
|
||||
if (po.NumArgs() == 0) {
|
||||
fprintf(stderr, "Error: Please provide the text to generate audio.\n\n");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (po.NumArgs() > 1) {
|
||||
fprintf(stderr,
|
||||
"Error: Accept only one positional argument. Please use single "
|
||||
"quotes to wrap your text\n");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (!config.Validate()) {
|
||||
fprintf(stderr, "Errors in config!\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
sherpa_onnx::OfflineTts tts(config);
|
||||
auto audio = tts.Generate(po.GetArg(1));
|
||||
|
||||
std::ofstream os("t.pcm", std::ios::binary);
|
||||
os.write(reinterpret_cast<const char *>(audio.samples.data()),
|
||||
sizeof(float) * audio.samples.size());
|
||||
|
||||
// sox -t raw -r 22050 -b 32 -e floating-point -c 1 ./t.pcm ./t.wav
|
||||
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user