diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index e3f6ff73..d3538480 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -86,6 +86,15 @@ set(sources wave-reader.cc ) +list(APPEND sources + lexicon.cc + offline-tts-impl.cc + offline-tts-model-config.cc + offline-tts-vits-model-config.cc + offline-tts-vits-model.cc + offline-tts.cc +) + if(SHERPA_ONNX_ENABLE_CHECK) list(APPEND sources log.cc) endif() @@ -135,23 +144,31 @@ endif() add_executable(sherpa-onnx sherpa-onnx.cc) add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) +add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) target_link_libraries(sherpa-onnx sherpa-onnx-core) target_link_libraries(sherpa-onnx-offline sherpa-onnx-core) target_link_libraries(sherpa-onnx-offline-parallel sherpa-onnx-core) +target_link_libraries(sherpa-onnx-offline-tts sherpa-onnx-core) if(NOT WIN32) target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib") target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib") target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib") target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib") + + target_link_libraries(sherpa-onnx-offline-parallel "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib") target_link_libraries(sherpa-onnx-offline-parallel "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib") + target_link_libraries(sherpa-onnx-offline-tts "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib") + target_link_libraries(sherpa-onnx-offline-tts "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib") + if(SHERPA_ONNX_ENABLE_PYTHON) target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib") target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib") target_link_libraries(sherpa-onnx-offline-parallel "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib") + target_link_libraries(sherpa-onnx-offline-tts "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib") endif() endif() @@ -170,6 +187,7 @@ install( sherpa-onnx sherpa-onnx-offline sherpa-onnx-offline-parallel + sherpa-onnx-offline-tts DESTINATION bin ) diff --git a/sherpa-onnx/csrc/lexicon.cc b/sherpa-onnx/csrc/lexicon.cc new file mode 100644 index 00000000..310fefa7 --- /dev/null +++ b/sherpa-onnx/csrc/lexicon.cc @@ -0,0 +1,157 @@ +// sherpa-onnx/csrc/lexicon.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/lexicon.h" + +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +static void ToLowerCase(std::string *in_out) { + std::transform(in_out->begin(), in_out->end(), in_out->begin(), + [](unsigned char c) { return std::tolower(c); }); +} + +// Note: We don't use SymbolTable here since tokens may contain a blank +// in the first column +static std::unordered_map ReadTokens( + const std::string &tokens) { + std::unordered_map token2id; + + std::ifstream is(tokens); + std::string line; + + std::string sym; + int32_t id; + while (std::getline(is, line)) { + std::istringstream iss(line); + iss >> sym; + if (iss.eof()) { + id = atoi(sym.c_str()); + sym = " "; + } else { + iss >> id; + } + + if (!iss.eof()) { + SHERPA_ONNX_LOGE("Error: %s", line.c_str()); + exit(-1); + } + +#if 0 + if (token2id.count(sym)) { + SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d", + sym.c_str(), line.c_str(), token2id.at(sym)); + exit(-1); + } +#endif + token2id.insert({sym, id}); + } + + return token2id; +} + +static std::vector ConvertTokensToIds( + const std::unordered_map &token2id, + const std::vector &tokens) { + std::vector ids; + ids.reserve(tokens.size()); + for (const auto &s : tokens) { + if (!token2id.count(s)) { + return {}; + } + int32_t id = token2id.at(s); + ids.push_back(id); + } + + return ids; +} + +Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, + const std::string &punctuations) { + token2id_ = ReadTokens(tokens); + std::ifstream is(lexicon); + + std::string word; + std::vector token_list; + std::string line; + std::string phone; + + while (std::getline(is, line)) { + std::istringstream iss(line); + + token_list.clear(); + + iss >> word; + ToLowerCase(&word); + + if (word2ids_.count(word)) { + SHERPA_ONNX_LOGE("Duplicated word: %s", word.c_str()); + return; + } + + while (iss >> phone) { + token_list.push_back(std::move(phone)); + } + + std::vector ids = ConvertTokensToIds(token2id_, token_list); + if (ids.empty()) { + continue; + } + word2ids_.insert({std::move(word), std::move(ids)}); + } + + // process punctuations + std::vector punctuation_list; + SplitStringToVector(punctuations, " ", false, &punctuation_list); + for (auto &s : punctuation_list) { + punctuations_.insert(std::move(s)); + } +} + +std::vector Lexicon::ConvertTextToTokenIds( + const std::string &_text) const { + std::string text(_text); + ToLowerCase(&text); + + std::vector words; + SplitStringToVector(text, " ", false, &words); + + std::vector ans; + for (auto w : words) { + std::vector prefix; + while (!w.empty() && punctuations_.count(std::string(1, w[0]))) { + // if w begins with a punctuation + prefix.push_back(token2id_.at(std::string(1, w[0]))); + w = std::string(w.begin() + 1, w.end()); + } + + std::vector suffix; + while (!w.empty() && punctuations_.count(std::string(1, w.back()))) { + suffix.push_back(token2id_.at(std::string(1, w.back()))); + w = std::string(w.begin(), w.end() - 1); + } + + if (!word2ids_.count(w)) { + SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str()); + continue; + } + + const auto &token_ids = word2ids_.at(w); + ans.insert(ans.end(), prefix.begin(), prefix.end()); + ans.insert(ans.end(), token_ids.begin(), token_ids.end()); + ans.insert(ans.end(), suffix.rbegin(), suffix.rend()); + } + + return ans; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/lexicon.h b/sherpa-onnx/csrc/lexicon.h new file mode 100644 index 00000000..2e746ada --- /dev/null +++ b/sherpa-onnx/csrc/lexicon.h @@ -0,0 +1,31 @@ +// sherpa-onnx/csrc/lexicon.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_LEXICON_H_ +#define SHERPA_ONNX_CSRC_LEXICON_H_ + +#include +#include +#include +#include +#include + +namespace sherpa_onnx { + +class Lexicon { + public: + Lexicon(const std::string &lexicon, const std::string &tokens, + const std::string &punctuations); + + std::vector ConvertTextToTokenIds(const std::string &text) const; + + private: + std::unordered_map> word2ids_; + std::unordered_set punctuations_; + std::unordered_map token2id_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_LEXICON_H_ diff --git a/sherpa-onnx/csrc/offline-tts-impl.cc b/sherpa-onnx/csrc/offline-tts-impl.cc new file mode 100644 index 00000000..f260499b --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-impl.cc @@ -0,0 +1,19 @@ +// sherpa-onnx/csrc/offline-tts-impl.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-tts-impl.h" + +#include + +#include "sherpa-onnx/csrc/offline-tts-vits-impl.h" + +namespace sherpa_onnx { + +std::unique_ptr OfflineTtsImpl::Create( + const OfflineTtsConfig &config) { + // TODO(fangjun): Support other types + return std::make_unique(config); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-impl.h b/sherpa-onnx/csrc/offline-tts-impl.h new file mode 100644 index 00000000..877dd11a --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-impl.h @@ -0,0 +1,26 @@ +// sherpa-onnx/csrc/offline-tts-impl.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_ + +#include +#include + +#include "sherpa-onnx/csrc/offline-tts.h" + +namespace sherpa_onnx { + +class OfflineTtsImpl { + public: + virtual ~OfflineTtsImpl() = default; + + static std::unique_ptr Create(const OfflineTtsConfig &config); + + virtual GeneratedAudio Generate(const std::string &text) const = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-tts-model-config.cc b/sherpa-onnx/csrc/offline-tts-model-config.cc new file mode 100644 index 00000000..f38c681a --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-model-config.cc @@ -0,0 +1,45 @@ +// sherpa-onnx/csrc/offline-tts-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-tts-model-config.h" + +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineTtsModelConfig::Register(ParseOptions *po) { + vits.Register(po); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool OfflineTtsModelConfig::Validate() const { + if (num_threads < 1) { + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); + return false; + } + + return vits.Validate(); +} + +std::string OfflineTtsModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineTtsModelConfig("; + os << "vits=" << vits.ToString() << ", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? "True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-model-config.h b/sherpa-onnx/csrc/offline-tts-model-config.h new file mode 100644 index 00000000..bee50ba1 --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-model-config.h @@ -0,0 +1,40 @@ +// sherpa-onnx/csrc/offline-tts-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/offline-tts-vits-model-config.h" +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineTtsModelConfig { + OfflineTtsVitsModelConfig vits; + + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + OfflineTtsModelConfig() = default; + + OfflineTtsModelConfig(const OfflineTtsVitsModelConfig &vits, + int32_t num_threads, bool debug, + const std::string &provider) + : vits(vits), + num_threads(num_threads), + debug(debug), + provider(provider) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-tts-vits-impl.h b/sherpa-onnx/csrc/offline-tts-vits-impl.h new file mode 100644 index 00000000..e9b94064 --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-vits-impl.h @@ -0,0 +1,77 @@ +// sherpa-onnx/csrc/offline-tts-vits-impl.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_ + +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/lexicon.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-tts-impl.h" +#include "sherpa-onnx/csrc/offline-tts-vits-model.h" + +namespace sherpa_onnx { + +class OfflineTtsVitsImpl : public OfflineTtsImpl { + public: + explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config) + : model_(std::make_unique(config.model)), + lexicon_(config.model.vits.lexicon, config.model.vits.tokens, + model_->Punctuations()) { + SHERPA_ONNX_LOGE("config: %s\n", config.ToString().c_str()); + } + + GeneratedAudio Generate(const std::string &text) const override { + std::vector x = lexicon_.ConvertTextToTokenIds(text); + if (x.empty()) { + SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str()); + return {}; + } + + if (model_->AddBlank()) { + std::vector buffer(x.size() * 2 + 1); + int32_t i = 1; + for (auto k : x) { + buffer[i] = k; + i += 2; + } + x = std::move(buffer); + } + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array x_shape = {1, static_cast(x.size())}; + Ort::Value x_tensor = Ort::Value::CreateTensor( + memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); + + Ort::Value audio = model_->Run(std::move(x_tensor)); + + std::vector audio_shape = + audio.GetTensorTypeAndShapeInfo().GetShape(); + + int64_t total = 1; + // The output shape may be (1, 1, total) or (1, total) or (total,) + for (auto i : audio_shape) { + total *= i; + } + + const float *p = audio.GetTensorData(); + + GeneratedAudio ans; + ans.sample_rate = model_->SampleRate(); + ans.samples = std::vector(p, p + total); + return ans; + } + + private: + std::unique_ptr model_; + Lexicon lexicon_; +}; + +} // namespace sherpa_onnx +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-tts-vits-model-config.cc b/sherpa-onnx/csrc/offline-tts-vits-model-config.cc new file mode 100644 index 00000000..5bcb7f0b --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-vits-model-config.cc @@ -0,0 +1,63 @@ +// sherpa-onnx/csrc/offline-tts-vits-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-tts-vits-model-config.h" + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineTtsVitsModelConfig::Register(ParseOptions *po) { + po->Register("vits-model", &model, "Path to VITS model"); + po->Register("vits-lexicon", &lexicon, "Path to lexicon.txt for VITS models"); + po->Register("vits-tokens", &tokens, "Path to tokens.txt for VITS models"); +} + +bool OfflineTtsVitsModelConfig::Validate() const { + if (model.empty()) { + SHERPA_ONNX_LOGE("Please provide --vits-model"); + return false; + } + + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("--vits-model: %s does not exist", model.c_str()); + return false; + } + + if (lexicon.empty()) { + SHERPA_ONNX_LOGE("Please provide --vits-lexicon"); + return false; + } + + if (!FileExists(lexicon)) { + SHERPA_ONNX_LOGE("--vits-lexicon: %s does not exist", lexicon.c_str()); + return false; + } + + if (tokens.empty()) { + SHERPA_ONNX_LOGE("Please provide --vits-tokens"); + return false; + } + + if (!FileExists(tokens)) { + SHERPA_ONNX_LOGE("--vits-tokens: %s does not exist", tokens.c_str()); + return false; + } + + return true; +} + +std::string OfflineTtsVitsModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineTtsVitsModelConfig("; + os << "model=\"" << model << "\", "; + os << "lexicon=\"" << lexicon << "\", "; + os << "tokens=\"" << tokens << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-vits-model-config.h b/sherpa-onnx/csrc/offline-tts-vits-model-config.h new file mode 100644 index 00000000..c8f09759 --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-vits-model-config.h @@ -0,0 +1,34 @@ +// sherpa-onnx/csrc/offline-tts-vits-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineTtsVitsModelConfig { + std::string model; + std::string lexicon; + std::string tokens; + + OfflineTtsVitsModelConfig() = default; + + OfflineTtsVitsModelConfig(const std::string &model, + const std::string &lexicon, + const std::string &tokens) + : model(model), lexicon(lexicon), tokens(tokens) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.cc b/sherpa-onnx/csrc/offline-tts-vits-model.cc new file mode 100644 index 00000000..2f636513 --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-vits-model.cc @@ -0,0 +1,135 @@ +// sherpa-onnx/csrc/offline-tts-vits-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-tts-vits-model.h" + +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" + +namespace sherpa_onnx { + +class OfflineTtsVitsModel::Impl { + public: + explicit Impl(const OfflineTtsModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_WARNING), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config.vits.model); + Init(buf.data(), buf.size()); + } + + Ort::Value Run(Ort::Value x) { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::vector x_shape = x.GetTensorTypeAndShapeInfo().GetShape(); + if (x_shape[0] != 1) { + SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d", + static_cast(x_shape[0])); + exit(-1); + } + + int64_t len = x_shape[1]; + int64_t len_shape = 1; + + Ort::Value x_length = + Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1); + + int64_t scale_shape = 1; + float noise_scale = 1; + float length_scale = 1; + float noise_scale_w = 1; + + Ort::Value noise_scale_tensor = + Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1); + Ort::Value length_scale_tensor = Ort::Value::CreateTensor( + memory_info, &length_scale, 1, &scale_shape, 1); + Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor( + memory_info, &noise_scale_w, 1, &scale_shape, 1); + + std::array inputs = { + std::move(x), std::move(x_length), std::move(noise_scale_tensor), + std::move(length_scale_tensor), std::move(noise_scale_w_tensor)}; + + auto out = + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + + return std::move(out[0]); + } + + int32_t SampleRate() const { return sample_rate_; } + + bool AddBlank() const { return add_blank_; } + + std::string Punctuations() const { return punctuations_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---vits model---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate"); + SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank"); + SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation"); + } + + private: + OfflineTtsModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + int32_t sample_rate_; + int32_t add_blank_; + std::string punctuations_; +}; + +OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config) + : impl_(std::make_unique(config)) {} + +OfflineTtsVitsModel::~OfflineTtsVitsModel() = default; + +Ort::Value OfflineTtsVitsModel::Run(Ort::Value x) { + return impl_->Run(std::move(x)); +} + +int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); } + +bool OfflineTtsVitsModel::AddBlank() const { return impl_->AddBlank(); } + +std::string OfflineTtsVitsModel::Punctuations() const { + return impl_->Punctuations(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.h b/sherpa-onnx/csrc/offline-tts-vits-model.h new file mode 100644 index 00000000..ca2c1c6b --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-vits-model.h @@ -0,0 +1,45 @@ +// sherpa-onnx/csrc/offline-tts-vits-model.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_ + +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-tts-model-config.h" + +namespace sherpa_onnx { + +class OfflineTtsVitsModel { + public: + ~OfflineTtsVitsModel(); + + explicit OfflineTtsVitsModel(const OfflineTtsModelConfig &config); + + /** Run the model. + * + * @param x A int64 tensor of shape (1, num_tokens) + * @return Return a float32 tensor containing audio samples. You can flatten + * it to a 1-D tensor. + */ + Ort::Value Run(Ort::Value x); + + // Sample rate of the generated audio + int32_t SampleRate() const; + + // true to insert a blank between each token + bool AddBlank() const; + + std::string Punctuations() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_ diff --git a/sherpa-onnx/csrc/offline-tts.cc b/sherpa-onnx/csrc/offline-tts.cc new file mode 100644 index 00000000..1154f2e4 --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts.cc @@ -0,0 +1,35 @@ +// sherpa-onnx/csrc/offline-tts.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-tts.h" + +#include + +#include "sherpa-onnx/csrc/offline-tts-impl.h" + +namespace sherpa_onnx { + +void OfflineTtsConfig::Register(ParseOptions *po) { model.Register(po); } + +bool OfflineTtsConfig::Validate() const { return model.Validate(); } + +std::string OfflineTtsConfig::ToString() const { + std::ostringstream os; + + os << "OfflineTtsConfig("; + os << "model=" << model.ToString() << ")"; + + return os.str(); +} + +OfflineTts::OfflineTts(const OfflineTtsConfig &config) + : impl_(OfflineTtsImpl::Create(config)) {} + +OfflineTts::~OfflineTts() = default; + +GeneratedAudio OfflineTts::Generate(const std::string &text) const { + return impl_->Generate(text); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts.h b/sherpa-onnx/csrc/offline-tts.h new file mode 100644 index 00000000..0d6ce668 --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts.h @@ -0,0 +1,50 @@ +// sherpa-onnx/csrc/offline-tts.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_H_ + +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/offline-tts-model-config.h" +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineTtsConfig { + OfflineTtsModelConfig model; + + OfflineTtsConfig() = default; + explicit OfflineTtsConfig(const OfflineTtsModelConfig &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +struct GeneratedAudio { + std::vector samples; + int32_t sample_rate; +}; + +class OfflineTtsImpl; + +class OfflineTts { + public: + ~OfflineTts(); + explicit OfflineTts(const OfflineTtsConfig &config); + // @param text A string containing words separated by spaces + GeneratedAudio Generate(const std::string &text) const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_H_ diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index e16fdaf7..23c8a9cb 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -87,4 +87,8 @@ Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) { return GetSessionOptionsImpl(config.num_threads, config.provider); } +Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) { + return GetSessionOptionsImpl(config.num_threads, config.provider); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index f0f25b23..25675fa2 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -8,6 +8,7 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/offline-lm-config.h" #include "sherpa-onnx/csrc/offline-model-config.h" +#include "sherpa-onnx/csrc/offline-tts-model-config.h" #include "sherpa-onnx/csrc/online-lm-config.h" #include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/vad-model-config.h" @@ -23,6 +24,8 @@ Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config); Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); + +Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_SESSION_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc new file mode 100644 index 00000000..18b520f4 --- /dev/null +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc @@ -0,0 +1,57 @@ +// sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include + +#include "sherpa-onnx/csrc/offline-tts.h" +#include "sherpa-onnx/csrc/parse-options.h" + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Offline text-to-speech with sherpa-onnx + +./bin/sherpa-onnx-offline-tts \ + --vits-model /path/to/model.onnx \ + --vits-lexicon /path/to/lexicon.txt \ + --vits-tokens /path/to/tokens.txt + 'some text within single quotes' + +It will generate a file test.wav. +)usage"; + + sherpa_onnx::ParseOptions po(kUsageMessage); + sherpa_onnx::OfflineTtsConfig config; + config.Register(&po); + po.Read(argc, argv); + + if (po.NumArgs() == 0) { + fprintf(stderr, "Error: Please provide the text to generate audio.\n\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + if (po.NumArgs() > 1) { + fprintf(stderr, + "Error: Accept only one positional argument. Please use single " + "quotes to wrap your text\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + exit(EXIT_FAILURE); + } + + sherpa_onnx::OfflineTts tts(config); + auto audio = tts.Generate(po.GetArg(1)); + + std::ofstream os("t.pcm", std::ios::binary); + os.write(reinterpret_cast(audio.samples.data()), + sizeof(float) * audio.samples.size()); + + // sox -t raw -r 22050 -b 32 -e floating-point -c 1 ./t.pcm ./t.wav + + return 0; +}