// sherpa-onnx/csrc/offline-tts-vits-impl.h // // Copyright (c) 2023 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_ #include #include #include #include #if __ANDROID_API__ >= 9 #include #include "android/asset_manager.h" #include "android/asset_manager_jni.h" #endif #include "kaldifst/csrc/text-normalizer.h" #include "sherpa-onnx/csrc/lexicon.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-tts-frontend.h" #include "sherpa-onnx/csrc/offline-tts-impl.h" #include "sherpa-onnx/csrc/offline-tts-vits-model.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/piper-phonemize-lexicon.h" #include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { class OfflineTtsVitsImpl : public OfflineTtsImpl { public: explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config) : config_(config), model_(std::make_unique(config.model)) { InitFrontend(); if (!config.rule_fsts.empty()) { std::vector files; SplitStringToVector(config.rule_fsts, ",", false, &files); tn_list_.reserve(files.size()); for (const auto &f : files) { if (config.model.debug) { SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); } tn_list_.push_back(std::make_unique(f)); } } } #if __ANDROID_API__ >= 9 OfflineTtsVitsImpl(AAssetManager *mgr, const OfflineTtsConfig &config) : config_(config), model_(std::make_unique(mgr, config.model)) { InitFrontend(mgr); if (!config.rule_fsts.empty()) { std::vector files; SplitStringToVector(config.rule_fsts, ",", false, &files); tn_list_.reserve(files.size()); for (const auto &f : files) { if (config.model.debug) { SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); } auto buf = ReadFile(mgr, f); std::istrstream is(buf.data(), buf.size()); tn_list_.push_back(std::make_unique(is)); } } } #endif GeneratedAudio Generate(const std::string &_text, int64_t sid = 0, float speed = 1.0) const override { int32_t num_speakers = model_->NumSpeakers(); if (num_speakers == 0 && sid != 0) { SHERPA_ONNX_LOGE( "This is a single-speaker model and supports only sid 0. Given sid: " "%d. sid is ignored", static_cast(sid)); } if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) { SHERPA_ONNX_LOGE( "This model contains only %d speakers. sid should be in the range " "[%d, %d]. Given: %d. Use sid=0", num_speakers, 0, num_speakers - 1, static_cast(sid)); sid = 0; } std::string text = _text; if (config_.model.debug) { SHERPA_ONNX_LOGE("Raw text: %s", text.c_str()); } if (!tn_list_.empty()) { for (const auto &tn : tn_list_) { text = tn->Normalize(text); if (config_.model.debug) { SHERPA_ONNX_LOGE("After normalizing: %s", text.c_str()); } } } std::vector> x = frontend_->ConvertTextToTokenIds(text, model_->Voice()); if (x.empty() || (x.size() == 1 && x[0].empty())) { SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str()); return {}; } if (model_->AddBlank() && config_.model.vits.data_dir.empty()) { for (auto &k : x) { k = AddBlank(k); } } int32_t x_size = static_cast(x.size()); if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) { return Process(x, sid, speed); } // the input text is too long, we process sentences within it in batches // to avoid OOM. Batch size is config_.max_num_sentences std::vector> batch; int32_t batch_size = config_.max_num_sentences; batch.reserve(config_.max_num_sentences); int32_t num_batches = x_size / batch_size; if (config_.model.debug) { SHERPA_ONNX_LOGE( "Text is too long. Split it into %d batches. batch size: %d. Number " "of sentences: %d", num_batches, batch_size, x_size); } GeneratedAudio ans; int32_t k = 0; for (int32_t b = 0; b != num_batches; ++b) { batch.clear(); for (int32_t i = 0; i != batch_size; ++i, ++k) { batch.push_back(std::move(x[k])); } auto audio = Process(batch, sid, speed); ans.sample_rate = audio.sample_rate; ans.samples.insert(ans.samples.end(), audio.samples.begin(), audio.samples.end()); } batch.clear(); while (k < x.size()) { batch.push_back(std::move(x[k])); ++k; } if (!batch.empty()) { auto audio = Process(batch, sid, speed); ans.sample_rate = audio.sample_rate; ans.samples.insert(ans.samples.end(), audio.samples.begin(), audio.samples.end()); } return ans; } private: void InitFrontend(AAssetManager *mgr) { if (model_->IsPiper() && !config_.model.vits.data_dir.empty()) { frontend_ = std::make_unique( mgr, config_.model.vits.tokens, config_.model.vits.data_dir); } else { frontend_ = std::make_unique( mgr, config_.model.vits.lexicon, config_.model.vits.tokens, model_->Punctuations(), model_->Language(), config_.model.debug, model_->IsPiper()); } } void InitFrontend() { if (model_->IsPiper() && !config_.model.vits.data_dir.empty()) { frontend_ = std::make_unique( config_.model.vits.tokens, config_.model.vits.data_dir); } else { frontend_ = std::make_unique( config_.model.vits.lexicon, config_.model.vits.tokens, model_->Punctuations(), model_->Language(), config_.model.debug, model_->IsPiper()); } } std::vector AddBlank(const std::vector &x) const { // we assume the blank ID is 0 std::vector buffer(x.size() * 2 + 1); int32_t i = 1; for (auto k : x) { buffer[i] = k; i += 2; } return buffer; } GeneratedAudio Process(const std::vector> &tokens, int32_t sid, float speed) const { int32_t num_tokens = 0; for (const auto &k : tokens) { num_tokens += k.size(); } std::vector x; x.reserve(num_tokens); for (const auto &k : tokens) { x.insert(x.end(), k.begin(), k.end()); } auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); std::array x_shape = {1, static_cast(x.size())}; Ort::Value x_tensor = Ort::Value::CreateTensor( memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); Ort::Value audio = model_->Run(std::move(x_tensor), sid, speed); std::vector audio_shape = audio.GetTensorTypeAndShapeInfo().GetShape(); int64_t total = 1; // The output shape may be (1, 1, total) or (1, total) or (total,) for (auto i : audio_shape) { total *= i; } const float *p = audio.GetTensorData(); GeneratedAudio ans; ans.sample_rate = model_->SampleRate(); ans.samples = std::vector(p, p + total); return ans; } private: OfflineTtsConfig config_; std::unique_ptr model_; std::vector> tn_list_; std::unique_ptr frontend_; }; } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_