Add JNI interface and Kotlin API examples for TTS. (#381)

This commit is contained in:
Fangjun Kuang
2023-10-23 12:31:54 +08:00
committed by GitHub
parent b582f6c115
commit 0fdb2044e1
15 changed files with 453 additions and 36 deletions

View File

@@ -10,7 +10,15 @@
#include <sstream>
#include <utility>
#if __ANDROID_API__ >= 9
#include <strstream>
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
@@ -22,11 +30,9 @@ static void ToLowerCase(std::string *in_out) {
// Note: We don't use SymbolTable here since tokens may contain a blank
// in the first column
static std::unordered_map<std::string, int32_t> ReadTokens(
const std::string &tokens) {
static std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is) {
std::unordered_map<std::string, int32_t> token2id;
std::ifstream is(tokens);
std::string line;
std::string sym;
@@ -80,11 +86,43 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
bool debug /*= false*/)
: debug_(debug) {
InitLanguage(language);
InitTokens(tokens);
InitLexicon(lexicon);
{
std::ifstream is(tokens);
InitTokens(is);
}
{
std::ifstream is(lexicon);
InitLexicon(is);
}
InitPunctuations(punctuations);
}
#if __ANDROID_API__ >= 9
// Builds a Lexicon whose tokens and lexicon files are read through an Android
// asset manager instead of the regular file system.
//
// @param mgr           Passed to ReadFile() to load `tokens` and `lexicon`.
// @param lexicon       Name of the lexicon file to read via `mgr`.
// @param tokens        Name of the tokens file to read via `mgr`.
// @param punctuations  Forwarded unchanged to InitPunctuations().
// @param language      Forwarded unchanged to InitLanguage().
// @param debug         Stored in debug_.
Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
                 const std::string &tokens, const std::string &punctuations,
                 const std::string &language, bool debug /*= false*/)
    : debug_(debug) {
  InitLanguage(language);

  // Note: std::istringstream is used instead of std::istrstream. The latter
  // is deprecated, and its (pointer, length) constructor requires the length
  // to be greater than zero, which an empty buffer would violate.
  {
    auto buf = ReadFile(mgr, tokens);
    const std::string content(buf.data(), buf.size());
    std::istringstream is(content);
    InitTokens(is);
  }

  {
    auto buf = ReadFile(mgr, lexicon);
    const std::string content(buf.data(), buf.size());
    std::istringstream is(content);
    InitLexicon(is);
  }

  InitPunctuations(punctuations);
}
#endif
std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
const std::string &text) const {
switch (language_) {
@@ -192,9 +230,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
return ans;
}
// Fills token2id_ with the table that ReadTokens() builds from the file
// named by `tokens`.
void Lexicon::InitTokens(const std::string &tokens) {
token2id_ = ReadTokens(tokens);
}
// Fills token2id_ with the table that ReadTokens() builds from the stream
// `is`. The caller owns the stream.
void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); }
void Lexicon::InitLanguage(const std::string &_lang) {
std::string lang(_lang);
@@ -209,9 +245,7 @@ void Lexicon::InitLanguage(const std::string &_lang) {
}
}
void Lexicon::InitLexicon(const std::string &lexicon) {
std::ifstream is(lexicon);
void Lexicon::InitLexicon(std::istream &is) {
std::string word;
std::vector<std::string> token_list;
std::string line;

View File

@@ -6,11 +6,17 @@
#define SHERPA_ONNX_CSRC_LEXICON_H_
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
namespace sherpa_onnx {
// TODO(fangjun): Refactor it to an abstract class
@@ -20,6 +26,12 @@ class Lexicon {
const std::string &punctuations, const std::string &language,
bool debug = false);
#if __ANDROID_API__ >= 9
// Same as the constructor above, except that it takes an additional
// AAssetManager. Only declared when __ANDROID_API__ >= 9.
//
// @param mgr  Android asset manager.
// @param lexicon  Name of the lexicon file.
// @param tokens  Name of the tokens file.
Lexicon(AAssetManager *mgr, const std::string &lexicon,
const std::string &tokens, const std::string &punctuations,
const std::string &language, bool debug = false);
#endif
std::vector<int64_t> ConvertTextToTokenIds(const std::string &text) const;
private:
@@ -30,8 +42,8 @@ class Lexicon {
const std::string &text) const;
void InitLanguage(const std::string &lang);
void InitTokens(const std::string &tokens);
void InitLexicon(const std::string &lexicon);
void InitTokens(std::istream &is);
void InitLexicon(std::istream &is);
void InitPunctuations(const std::string &punctuations);
private:

View File

@@ -16,4 +16,12 @@ std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
return std::make_unique<OfflineTtsVitsImpl>(config);
}
#if __ANDROID_API__ >= 9
// Factory overload that takes an Android asset manager.
//
// Forwards `mgr` and `config` to OfflineTtsVitsImpl; at present it always
// returns an OfflineTtsVitsImpl, whatever `config` contains.
std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
AAssetManager *mgr, const OfflineTtsConfig &config) {
// TODO(fangjun): Support other types
return std::make_unique<OfflineTtsVitsImpl>(mgr, config);
}
#endif
} // namespace sherpa_onnx

View File

@@ -8,6 +8,11 @@
#include <memory>
#include <string>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-tts.h"
namespace sherpa_onnx {
@@ -18,6 +23,11 @@ class OfflineTtsImpl {
static std::unique_ptr<OfflineTtsImpl> Create(const OfflineTtsConfig &config);
#if __ANDROID_API__ >= 9
// Overload of Create() that additionally takes an Android asset manager.
// Only declared when __ANDROID_API__ >= 9.
static std::unique_ptr<OfflineTtsImpl> Create(AAssetManager *mgr,
const OfflineTtsConfig &config);
#endif
virtual GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
float speed = 1.0) const = 0;
};

View File

@@ -9,6 +9,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-tts-impl.h"
@@ -24,6 +29,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
model_->Punctuations(), model_->Language(),
config.model.debug) {}
#if __ANDROID_API__ >= 9
// Constructs the VITS model and the lexicon, handing `mgr` to both.
//
// NOTE(review): the lexicon_ initializer calls model_->Punctuations() and
// model_->Language(), so model_ has to be initialized first. That depends on
// model_ being declared before lexicon_ in the class -- confirm.
OfflineTtsVitsImpl(AAssetManager *mgr, const OfflineTtsConfig &config)
: model_(std::make_unique<OfflineTtsVitsModel>(mgr, config.model)),
lexicon_(mgr, config.model.vits.lexicon, config.model.vits.tokens,
model_->Punctuations(), model_->Language(),
config.model.debug) {}
#endif
GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
float speed = 1.0) const override {
int32_t num_speakers = model_->NumSpeakers();

View File

@@ -26,6 +26,17 @@ class OfflineTtsVitsModel::Impl {
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
// Constructor that reads the model through an Android asset manager.
//
// Reads the file named by config.vits.model via ReadFile(mgr, ...) and
// passes the resulting in-memory bytes to Init().
Impl(AAssetManager *mgr, const OfflineTtsModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_WARNING),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config.vits.model);
Init(buf.data(), buf.size());
}
#endif
Ort::Value Run(Ort::Value x, int64_t sid, float speed) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
@@ -141,6 +152,12 @@ class OfflineTtsVitsModel::Impl {
// Creates the underlying Impl from `config`.
OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
// Creates the underlying Impl, forwarding the Android asset manager `mgr`
// together with `config`.
OfflineTtsVitsModel::OfflineTtsVitsModel(AAssetManager *mgr,
const OfflineTtsModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineTtsVitsModel::~OfflineTtsVitsModel() = default;
Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/,

View File

@@ -8,6 +8,11 @@
#include <memory>
#include <string>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
@@ -18,6 +23,9 @@ class OfflineTtsVitsModel {
~OfflineTtsVitsModel();
explicit OfflineTtsVitsModel(const OfflineTtsModelConfig &config);
#if __ANDROID_API__ >= 9
// Overload that additionally takes an Android asset manager.
// Only declared when __ANDROID_API__ >= 9.
OfflineTtsVitsModel(AAssetManager *mgr, const OfflineTtsModelConfig &config);
#endif
/** Run the model.
*

View File

@@ -26,6 +26,11 @@ std::string OfflineTtsConfig::ToString() const {
// Obtains the implementation object from OfflineTtsImpl::Create(config).
OfflineTts::OfflineTts(const OfflineTtsConfig &config)
: impl_(OfflineTtsImpl::Create(config)) {}
#if __ANDROID_API__ >= 9
// Obtains the implementation object from OfflineTtsImpl::Create(mgr, config),
// i.e. the overload that takes an Android asset manager.
OfflineTts::OfflineTts(AAssetManager *mgr, const OfflineTtsConfig &config)
: impl_(OfflineTtsImpl::Create(mgr, config)) {}
#endif
OfflineTts::~OfflineTts() = default;
GeneratedAudio OfflineTts::Generate(const std::string &text, int64_t sid /*=0*/,

View File

@@ -9,6 +9,11 @@
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
@@ -38,6 +43,11 @@ class OfflineTts {
public:
~OfflineTts();
explicit OfflineTts(const OfflineTtsConfig &config);
#if __ANDROID_API__ >= 9
// Overload that additionally takes an Android asset manager.
// Only declared when __ANDROID_API__ >= 9.
OfflineTts(AAssetManager *mgr, const OfflineTtsConfig &config);
#endif
// @param text A string containing words separated by spaces
// @param sid Speaker ID. Used only for multi-speaker models, e.g., models
// trained using the VCTK dataset. It is not used for

View File

@@ -7,12 +7,13 @@
#include <cassert>
#include <fstream>
#include <sstream>
#include <strstream>
#include "sherpa-onnx/csrc/base64-decode.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#if __ANDROID_API__ >= 9
#include <strstream>
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif