Add C++ runtime for MeloTTS (#1138)
This commit is contained in:
@@ -422,10 +422,10 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig(
|
||||
|
||||
void SherpaOnnxOfflineRecognizerSetConfig(
|
||||
const SherpaOnnxOfflineRecognizer *recognizer,
|
||||
const SherpaOnnxOfflineRecognizerConfig *config){
|
||||
const SherpaOnnxOfflineRecognizerConfig *config) {
|
||||
sherpa_onnx::OfflineRecognizerConfig recognizer_config =
|
||||
convertConfig(config);
|
||||
recognizer->impl->SetConfig(recognizer_config);
|
||||
recognizer->impl->SetConfig(recognizer_config);
|
||||
}
|
||||
|
||||
void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) {
|
||||
@@ -478,7 +478,7 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
|
||||
pText[text.size()] = 0;
|
||||
r->text = pText;
|
||||
|
||||
//lang
|
||||
// lang
|
||||
const auto &lang = result.lang;
|
||||
char *c_lang = new char[lang.size() + 1];
|
||||
std::copy(lang.begin(), lang.end(), c_lang);
|
||||
@@ -1317,7 +1317,7 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches(
|
||||
}
|
||||
delete[] r->matches;
|
||||
delete r;
|
||||
};
|
||||
}
|
||||
|
||||
int32_t SherpaOnnxSpeakerEmbeddingManagerVerify(
|
||||
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
|
||||
|
||||
@@ -496,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams(
|
||||
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
|
||||
const char *text;
|
||||
|
||||
// Pointer to continuous memory which holds timestamps
|
||||
// Pointer to continuous memory which holds timestamps
|
||||
//
|
||||
// It is NULL if the model does not support timestamps
|
||||
float *timestamps;
|
||||
@@ -525,9 +525,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
|
||||
*/
|
||||
const char *json;
|
||||
|
||||
//return recognized language
|
||||
// return recognized language
|
||||
const char *lang;
|
||||
|
||||
} SherpaOnnxOfflineRecognizerResult;
|
||||
|
||||
/// Get the result of the offline stream.
|
||||
|
||||
@@ -142,7 +142,9 @@ if(SHERPA_ONNX_ENABLE_TTS)
|
||||
list(APPEND sources
|
||||
jieba-lexicon.cc
|
||||
lexicon.cc
|
||||
melo-tts-lexicon.cc
|
||||
offline-tts-character-frontend.cc
|
||||
offline-tts-frontend.cc
|
||||
offline-tts-impl.cc
|
||||
offline-tts-model-config.cc
|
||||
offline-tts-vits-model-config.cc
|
||||
|
||||
@@ -33,7 +33,7 @@ TEST(CppJieBa, Case1) {
|
||||
std::vector<std::string> words;
|
||||
std::vector<cppjieba::Word> jiebawords;
|
||||
|
||||
std::string s = "他来到了网易杭研大厦";
|
||||
std::string s = "他来到了网易杭研大厦。How are you?";
|
||||
std::cout << s << std::endl;
|
||||
std::cout << "[demo] Cut With HMM" << std::endl;
|
||||
jieba.Cut(s, words, true);
|
||||
|
||||
@@ -17,6 +17,7 @@ namespace sherpa_onnx {
|
||||
|
||||
// implemented in ./lexicon.cc
|
||||
std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is);
|
||||
|
||||
std::vector<int32_t> ConvertTokensToIds(
|
||||
const std::unordered_map<std::string, int32_t> &token2id,
|
||||
const std::vector<std::string> &tokens);
|
||||
@@ -53,8 +54,7 @@ class JiebaLexicon::Impl {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
|
||||
const std::string &text) const {
|
||||
std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &text) const {
|
||||
// see
|
||||
// https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244
|
||||
std::regex punct_re{":|、|;"};
|
||||
@@ -87,7 +87,7 @@ class JiebaLexicon::Impl {
|
||||
SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> ans;
|
||||
std::vector<TokenIDs> ans;
|
||||
std::vector<int64_t> this_sentence;
|
||||
|
||||
int32_t blank = token2id_.at(" ");
|
||||
@@ -217,7 +217,7 @@ JiebaLexicon::JiebaLexicon(const std::string &lexicon,
|
||||
: impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
|
||||
debug)) {}
|
||||
|
||||
std::vector<std::vector<int64_t>> JiebaLexicon::ConvertTextToTokenIds(
|
||||
std::vector<TokenIDs> JiebaLexicon::ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string & /*unused_voice = ""*/) const {
|
||||
return impl_->ConvertTextToTokenIds(text);
|
||||
}
|
||||
|
||||
@@ -10,11 +10,6 @@
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"
|
||||
|
||||
@@ -27,13 +22,7 @@ class JiebaLexicon : public OfflineTtsFrontend {
|
||||
const std::string &dict_dir,
|
||||
const OfflineTtsVitsModelMetaData &meta_data, bool debug);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
JiebaLexicon(AAssetManager *mgr, const std::string &lexicon,
|
||||
const std::string &tokens, const std::string &dict_dir,
|
||||
const OfflineTtsVitsModelMetaData &meta_data);
|
||||
#endif
|
||||
|
||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
|
||||
std::vector<TokenIDs> ConvertTextToTokenIds(
|
||||
const std::string &text,
|
||||
const std::string &unused_voice = "") const override;
|
||||
|
||||
|
||||
@@ -172,7 +172,7 @@ Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
|
||||
}
|
||||
#endif
|
||||
|
||||
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds(
|
||||
std::vector<TokenIDs> Lexicon::ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string & /*voice*/ /*= ""*/) const {
|
||||
switch (language_) {
|
||||
case Language::kChinese:
|
||||
@@ -187,7 +187,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds(
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
|
||||
std::vector<TokenIDs> Lexicon::ConvertTextToTokenIdsChinese(
|
||||
const std::string &_text) const {
|
||||
std::string text(_text);
|
||||
ToLowerCase(&text);
|
||||
@@ -209,7 +209,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> ans;
|
||||
std::vector<TokenIDs> ans;
|
||||
std::vector<int64_t> this_sentence;
|
||||
|
||||
int32_t blank = -1;
|
||||
@@ -288,7 +288,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese(
|
||||
std::vector<TokenIDs> Lexicon::ConvertTextToTokenIdsNotChinese(
|
||||
const std::string &_text) const {
|
||||
std::string text(_text);
|
||||
ToLowerCase(&text);
|
||||
@@ -311,7 +311,7 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese(
|
||||
|
||||
int32_t blank = token2id_.at(" ");
|
||||
|
||||
std::vector<std::vector<int64_t>> ans;
|
||||
std::vector<TokenIDs> ans;
|
||||
std::vector<int64_t> this_sentence;
|
||||
|
||||
for (const auto &w : words) {
|
||||
|
||||
@@ -36,14 +36,14 @@ class Lexicon : public OfflineTtsFrontend {
|
||||
const std::string &language, bool debug = false);
|
||||
#endif
|
||||
|
||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
|
||||
std::vector<TokenIDs> ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string &voice = "") const override;
|
||||
|
||||
private:
|
||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsNotChinese(
|
||||
std::vector<TokenIDs> ConvertTextToTokenIdsNotChinese(
|
||||
const std::string &text) const;
|
||||
|
||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsChinese(
|
||||
std::vector<TokenIDs> ConvertTextToTokenIdsChinese(
|
||||
const std::string &text) const;
|
||||
|
||||
void InitLanguage(const std::string &lang);
|
||||
|
||||
266
sherpa-onnx/csrc/melo-tts-lexicon.cc
Normal file
266
sherpa-onnx/csrc/melo-tts-lexicon.cc
Normal file
@@ -0,0 +1,266 @@
|
||||
// sherpa-onnx/csrc/melo-tts-lexicon.cc
|
||||
//
|
||||
// Copyright (c) 2022-2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/melo-tts-lexicon.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <regex> // NOLINT
|
||||
#include <utility>
|
||||
|
||||
#include "cppjieba/Jieba.hpp"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// implemented in ./lexicon.cc
|
||||
std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is);
|
||||
|
||||
std::vector<int32_t> ConvertTokensToIds(
|
||||
const std::unordered_map<std::string, int32_t> &token2id,
|
||||
const std::vector<std::string> &tokens);
|
||||
|
||||
class MeloTtsLexicon::Impl {
|
||||
public:
|
||||
Impl(const std::string &lexicon, const std::string &tokens,
|
||||
const std::string &dict_dir,
|
||||
const OfflineTtsVitsModelMetaData &meta_data, bool debug)
|
||||
: meta_data_(meta_data), debug_(debug) {
|
||||
std::string dict = dict_dir + "/jieba.dict.utf8";
|
||||
std::string hmm = dict_dir + "/hmm_model.utf8";
|
||||
std::string user_dict = dict_dir + "/user.dict.utf8";
|
||||
std::string idf = dict_dir + "/idf.utf8";
|
||||
std::string stop_word = dict_dir + "/stop_words.utf8";
|
||||
|
||||
AssertFileExists(dict);
|
||||
AssertFileExists(hmm);
|
||||
AssertFileExists(user_dict);
|
||||
AssertFileExists(idf);
|
||||
AssertFileExists(stop_word);
|
||||
|
||||
jieba_ =
|
||||
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
|
||||
|
||||
{
|
||||
std::ifstream is(tokens);
|
||||
InitTokens(is);
|
||||
}
|
||||
|
||||
{
|
||||
std::ifstream is(lexicon);
|
||||
InitLexicon(is);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &_text) const {
|
||||
std::string text = ToLowerCase(_text);
|
||||
// see
|
||||
// https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244
|
||||
std::regex punct_re{":|、|;"};
|
||||
std::string s = std::regex_replace(text, punct_re, ",");
|
||||
|
||||
std::regex punct_re2("。");
|
||||
s = std::regex_replace(s, punct_re2, ".");
|
||||
|
||||
std::regex punct_re3("?");
|
||||
s = std::regex_replace(s, punct_re3, "?");
|
||||
|
||||
std::regex punct_re4("!");
|
||||
s = std::regex_replace(s, punct_re4, "!");
|
||||
|
||||
std::vector<std::string> words;
|
||||
bool is_hmm = true;
|
||||
jieba_->Cut(text, words, is_hmm);
|
||||
|
||||
if (debug_) {
|
||||
SHERPA_ONNX_LOGE("input text: %s", text.c_str());
|
||||
SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str());
|
||||
|
||||
std::ostringstream os;
|
||||
std::string sep = "";
|
||||
for (const auto &w : words) {
|
||||
os << sep << w;
|
||||
sep = "_";
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
|
||||
}
|
||||
|
||||
std::vector<TokenIDs> ans;
|
||||
TokenIDs this_sentence;
|
||||
|
||||
int32_t blank = token2id_.at("_");
|
||||
for (const auto &w : words) {
|
||||
auto ids = ConvertWordToIds(w);
|
||||
if (ids.tokens.empty()) {
|
||||
SHERPA_ONNX_LOGE("Ignore OOV '%s'", w.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
this_sentence.tokens.insert(this_sentence.tokens.end(),
|
||||
ids.tokens.begin(), ids.tokens.end());
|
||||
this_sentence.tones.insert(this_sentence.tones.end(), ids.tones.begin(),
|
||||
ids.tones.end());
|
||||
|
||||
if (w == "." || w == "!" || w == "?" || w == ",") {
|
||||
ans.push_back(std::move(this_sentence));
|
||||
this_sentence = {};
|
||||
}
|
||||
} // for (const auto &w : words)
|
||||
|
||||
if (!this_sentence.tokens.empty()) {
|
||||
ans.push_back(std::move(this_sentence));
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
TokenIDs ConvertWordToIds(const std::string &w) const {
|
||||
if (word2ids_.count(w)) {
|
||||
return word2ids_.at(w);
|
||||
}
|
||||
|
||||
if (token2id_.count(w)) {
|
||||
return {{token2id_.at(w)}, {0}};
|
||||
}
|
||||
|
||||
TokenIDs ans;
|
||||
|
||||
std::vector<std::string> words = SplitUtf8(w);
|
||||
for (const auto &word : words) {
|
||||
if (word2ids_.count(word)) {
|
||||
auto ids = ConvertWordToIds(word);
|
||||
ans.tokens.insert(ans.tokens.end(), ids.tokens.begin(),
|
||||
ids.tokens.end());
|
||||
ans.tones.insert(ans.tones.end(), ids.tones.begin(), ids.tones.end());
|
||||
}
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
void InitTokens(std::istream &is) {
|
||||
token2id_ = ReadTokens(is);
|
||||
token2id_[" "] = token2id_["_"];
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> puncts = {
|
||||
{",", ","}, {".", "。"}, {"!", "!"}, {"?", "?"}};
|
||||
|
||||
for (const auto &p : puncts) {
|
||||
if (token2id_.count(p.first) && !token2id_.count(p.second)) {
|
||||
token2id_[p.second] = token2id_[p.first];
|
||||
}
|
||||
|
||||
if (!token2id_.count(p.first) && token2id_.count(p.second)) {
|
||||
token2id_[p.first] = token2id_[p.second];
|
||||
}
|
||||
}
|
||||
|
||||
if (!token2id_.count("、") && token2id_.count(",")) {
|
||||
token2id_["、"] = token2id_[","];
|
||||
}
|
||||
}
|
||||
|
||||
void InitLexicon(std::istream &is) {
|
||||
std::string word;
|
||||
std::vector<std::string> token_list;
|
||||
|
||||
std::vector<std::string> phone_list;
|
||||
std::vector<int64_t> tone_list;
|
||||
|
||||
std::string line;
|
||||
std::string phone;
|
||||
int32_t line_num = 0;
|
||||
|
||||
while (std::getline(is, line)) {
|
||||
++line_num;
|
||||
|
||||
std::istringstream iss(line);
|
||||
|
||||
token_list.clear();
|
||||
phone_list.clear();
|
||||
tone_list.clear();
|
||||
|
||||
iss >> word;
|
||||
ToLowerCase(&word);
|
||||
|
||||
if (word2ids_.count(word)) {
|
||||
SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.",
|
||||
word.c_str(), line_num, line.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
while (iss >> phone) {
|
||||
token_list.push_back(std::move(phone));
|
||||
}
|
||||
|
||||
if ((token_list.size() & 1) != 0) {
|
||||
SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t num_phones = token_list.size() / 2;
|
||||
phone_list.reserve(num_phones);
|
||||
tone_list.reserve(num_phones);
|
||||
|
||||
for (int32_t i = 0; i != num_phones; ++i) {
|
||||
phone_list.push_back(std::move(token_list[i]));
|
||||
tone_list.push_back(std::stoi(token_list[i + num_phones], nullptr));
|
||||
if (tone_list.back() < 0 || tone_list.back() > 50) {
|
||||
SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int32_t> ids = ConvertTokensToIds(token2id_, phone_list);
|
||||
if (ids.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ids.size() != num_phones) {
|
||||
SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
std::vector<int64_t> ids64{ids.begin(), ids.end()};
|
||||
|
||||
word2ids_.insert(
|
||||
{std::move(word), TokenIDs{std::move(ids64), std::move(tone_list)}});
|
||||
}
|
||||
|
||||
word2ids_["呣"] = word2ids_["母"];
|
||||
word2ids_["嗯"] = word2ids_["恩"];
|
||||
}
|
||||
|
||||
private:
|
||||
// lexicon.txt is saved in word2ids_
|
||||
std::unordered_map<std::string, TokenIDs> word2ids_;
|
||||
|
||||
// tokens.txt is saved in token2id_
|
||||
std::unordered_map<std::string, int32_t> token2id_;
|
||||
|
||||
OfflineTtsVitsModelMetaData meta_data_;
|
||||
|
||||
std::unique_ptr<cppjieba::Jieba> jieba_;
|
||||
bool debug_ = false;
|
||||
};
|
||||
|
||||
MeloTtsLexicon::~MeloTtsLexicon() = default;
|
||||
|
||||
MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon,
|
||||
const std::string &tokens,
|
||||
const std::string &dict_dir,
|
||||
const OfflineTtsVitsModelMetaData &meta_data,
|
||||
bool debug)
|
||||
: impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
|
||||
debug)) {}
|
||||
|
||||
std::vector<TokenIDs> MeloTtsLexicon::ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string & /*unused_voice = ""*/) const {
|
||||
return impl_->ConvertTextToTokenIds(text);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
36
sherpa-onnx/csrc/melo-tts-lexicon.h
Normal file
36
sherpa-onnx/csrc/melo-tts-lexicon.h
Normal file
@@ -0,0 +1,36 @@
|
||||
// sherpa-onnx/csrc/melo-tts-lexicon.h
|
||||
//
|
||||
// Copyright (c) 2022-2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_
|
||||
#define SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class MeloTtsLexicon : public OfflineTtsFrontend {
|
||||
public:
|
||||
~MeloTtsLexicon() override;
|
||||
MeloTtsLexicon(const std::string &lexicon, const std::string &tokens,
|
||||
const std::string &dict_dir,
|
||||
const OfflineTtsVitsModelMetaData &meta_data, bool debug);
|
||||
|
||||
std::vector<TokenIDs> ConvertTextToTokenIds(
|
||||
const std::string &text,
|
||||
const std::string &unused_voice = "") const override;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_
|
||||
@@ -94,8 +94,7 @@ OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend(
|
||||
|
||||
#endif
|
||||
|
||||
std::vector<std::vector<int64_t>>
|
||||
OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
|
||||
std::vector<TokenIDs> OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
|
||||
const std::string &_text, const std::string & /*voice = ""*/) const {
|
||||
// see
|
||||
// https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87
|
||||
@@ -112,7 +111,7 @@ OfflineTtsCharacterFrontend::ConvertTextToTokenIds(
|
||||
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
|
||||
std::u32string s = conv.from_bytes(text);
|
||||
|
||||
std::vector<std::vector<int64_t>> ans;
|
||||
std::vector<TokenIDs> ans;
|
||||
|
||||
std::vector<int64_t> this_sentence;
|
||||
if (add_blank) {
|
||||
|
||||
@@ -41,7 +41,7 @@ class OfflineTtsCharacterFrontend : public OfflineTtsFrontend {
|
||||
* If a frontend does not support splitting the text into
|
||||
* sentences, the resulting vector contains only one subvector.
|
||||
*/
|
||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
|
||||
std::vector<TokenIDs> ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string &voice = "") const override;
|
||||
|
||||
private:
|
||||
|
||||
34
sherpa-onnx/csrc/offline-tts-frontend.cc
Normal file
34
sherpa-onnx/csrc/offline-tts-frontend.cc
Normal file
@@ -0,0 +1,34 @@
|
||||
// sherpa-onnx/csrc/offline-tts-frontend.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::string TokenIDs::ToString() const {
|
||||
std::ostringstream os;
|
||||
os << "TokenIDs(";
|
||||
os << "tokens=[";
|
||||
std::string sep;
|
||||
for (auto i : tokens) {
|
||||
os << sep << i;
|
||||
sep = ", ";
|
||||
}
|
||||
os << "], ";
|
||||
|
||||
os << "tones=[";
|
||||
sep = {};
|
||||
for (auto i : tones) {
|
||||
os << sep << i;
|
||||
sep = ", ";
|
||||
}
|
||||
os << "]";
|
||||
os << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -8,8 +8,28 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct TokenIDs {
|
||||
TokenIDs() = default;
|
||||
|
||||
/*implicit*/ TokenIDs(const std::vector<int64_t> &tokens) // NOLINT
|
||||
: tokens{tokens} {}
|
||||
|
||||
TokenIDs(const std::vector<int64_t> &tokens,
|
||||
const std::vector<int64_t> &tones)
|
||||
: tokens{tokens}, tones{tones} {}
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
std::vector<int64_t> tokens;
|
||||
|
||||
// Used only in MeloTTS
|
||||
std::vector<int64_t> tones;
|
||||
};
|
||||
|
||||
class OfflineTtsFrontend {
|
||||
public:
|
||||
virtual ~OfflineTtsFrontend() = default;
|
||||
@@ -26,7 +46,7 @@ class OfflineTtsFrontend {
|
||||
* If a frontend does not support splitting the text into sentences,
|
||||
* the resulting vector contains only one subvector.
|
||||
*/
|
||||
virtual std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
|
||||
virtual std::vector<TokenIDs> ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string &voice = "") const = 0;
|
||||
};
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "sherpa-onnx/csrc/jieba-lexicon.h"
|
||||
#include "sherpa-onnx/csrc/lexicon.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/melo-tts-lexicon.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-character-frontend.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-impl.h"
|
||||
@@ -174,26 +175,47 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> x =
|
||||
std::vector<TokenIDs> token_ids =
|
||||
frontend_->ConvertTextToTokenIds(text, meta_data.voice);
|
||||
|
||||
if (x.empty() || (x.size() == 1 && x[0].empty())) {
|
||||
if (token_ids.empty() ||
|
||||
(token_ids.size() == 1 && token_ids[0].tokens.empty())) {
|
||||
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> x;
|
||||
std::vector<std::vector<int64_t>> tones;
|
||||
|
||||
x.reserve(token_ids.size());
|
||||
|
||||
for (auto &i : token_ids) {
|
||||
x.push_back(std::move(i.tokens));
|
||||
}
|
||||
|
||||
if (!token_ids[0].tones.empty()) {
|
||||
tones.reserve(token_ids.size());
|
||||
for (auto &i : token_ids) {
|
||||
tones.push_back(std::move(i.tones));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(fangjun): add blank inside the frontend, not here
|
||||
if (meta_data.add_blank && config_.model.vits.data_dir.empty() &&
|
||||
meta_data.frontend != "characters") {
|
||||
for (auto &k : x) {
|
||||
k = AddBlank(k);
|
||||
}
|
||||
|
||||
for (auto &k : tones) {
|
||||
k = AddBlank(k);
|
||||
}
|
||||
}
|
||||
|
||||
int32_t x_size = static_cast<int32_t>(x.size());
|
||||
|
||||
if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) {
|
||||
auto ans = Process(x, sid, speed);
|
||||
auto ans = Process(x, tones, sid, speed);
|
||||
if (callback) {
|
||||
callback(ans.samples.data(), ans.samples.size(), 1.0);
|
||||
}
|
||||
@@ -202,9 +224,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
|
||||
// the input text is too long, we process sentences within it in batches
|
||||
// to avoid OOM. Batch size is config_.max_num_sentences
|
||||
std::vector<std::vector<int64_t>> batch;
|
||||
std::vector<std::vector<int64_t>> batch_x;
|
||||
std::vector<std::vector<int64_t>> batch_tones;
|
||||
|
||||
int32_t batch_size = config_.max_num_sentences;
|
||||
batch.reserve(config_.max_num_sentences);
|
||||
batch_x.reserve(config_.max_num_sentences);
|
||||
batch_tones.reserve(config_.max_num_sentences);
|
||||
int32_t num_batches = x_size / batch_size;
|
||||
|
||||
if (config_.model.debug) {
|
||||
@@ -221,12 +246,17 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
int32_t k = 0;
|
||||
|
||||
for (int32_t b = 0; b != num_batches && should_continue; ++b) {
|
||||
batch.clear();
|
||||
batch_x.clear();
|
||||
batch_tones.clear();
|
||||
for (int32_t i = 0; i != batch_size; ++i, ++k) {
|
||||
batch.push_back(std::move(x[k]));
|
||||
batch_x.push_back(std::move(x[k]));
|
||||
|
||||
if (!tones.empty()) {
|
||||
batch_tones.push_back(std::move(tones[k]));
|
||||
}
|
||||
}
|
||||
|
||||
auto audio = Process(batch, sid, speed);
|
||||
auto audio = Process(batch_x, batch_tones, sid, speed);
|
||||
ans.sample_rate = audio.sample_rate;
|
||||
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
|
||||
audio.samples.end());
|
||||
@@ -239,14 +269,19 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
}
|
||||
}
|
||||
|
||||
batch.clear();
|
||||
batch_x.clear();
|
||||
batch_tones.clear();
|
||||
while (k < static_cast<int32_t>(x.size()) && should_continue) {
|
||||
batch.push_back(std::move(x[k]));
|
||||
batch_x.push_back(std::move(x[k]));
|
||||
if (!tones.empty()) {
|
||||
batch_tones.push_back(std::move(tones[k]));
|
||||
}
|
||||
|
||||
++k;
|
||||
}
|
||||
|
||||
if (!batch.empty()) {
|
||||
auto audio = Process(batch, sid, speed);
|
||||
if (!batch_x.empty()) {
|
||||
auto audio = Process(batch_x, batch_tones, sid, speed);
|
||||
ans.sample_rate = audio.sample_rate;
|
||||
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
|
||||
audio.samples.end());
|
||||
@@ -308,6 +343,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
if (meta_data.frontend == "characters") {
|
||||
frontend_ = std::make_unique<OfflineTtsCharacterFrontend>(
|
||||
config_.model.vits.tokens, meta_data);
|
||||
} else if (meta_data.jieba && !config_.model.vits.dict_dir.empty() &&
|
||||
meta_data.is_melo_tts) {
|
||||
frontend_ = std::make_unique<MeloTtsLexicon>(
|
||||
config_.model.vits.lexicon, config_.model.vits.tokens,
|
||||
config_.model.vits.dict_dir, model_->GetMetaData(),
|
||||
config_.model.debug);
|
||||
} else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) {
|
||||
frontend_ = std::make_unique<JiebaLexicon>(
|
||||
config_.model.vits.lexicon, config_.model.vits.tokens,
|
||||
@@ -344,6 +385,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
}
|
||||
|
||||
GeneratedAudio Process(const std::vector<std::vector<int64_t>> &tokens,
|
||||
const std::vector<std::vector<int64_t>> &tones,
|
||||
int32_t sid, float speed) const {
|
||||
int32_t num_tokens = 0;
|
||||
for (const auto &k : tokens) {
|
||||
@@ -356,6 +398,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
x.insert(x.end(), k.begin(), k.end());
|
||||
}
|
||||
|
||||
std::vector<int64_t> tone_list;
|
||||
if (!tones.empty()) {
|
||||
tone_list.reserve(num_tokens);
|
||||
for (const auto &k : tones) {
|
||||
tone_list.insert(tone_list.end(), k.begin(), k.end());
|
||||
}
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
@@ -363,7 +413,20 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
Ort::Value x_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
|
||||
|
||||
Ort::Value audio = model_->Run(std::move(x_tensor), sid, speed);
|
||||
Ort::Value tones_tensor{nullptr};
|
||||
if (!tones.empty()) {
|
||||
tones_tensor = Ort::Value::CreateTensor(memory_info, tone_list.data(),
|
||||
tone_list.size(), x_shape.data(),
|
||||
x_shape.size());
|
||||
}
|
||||
|
||||
Ort::Value audio{nullptr};
|
||||
if (tones.empty()) {
|
||||
audio = model_->Run(std::move(x_tensor), sid, speed);
|
||||
} else {
|
||||
audio =
|
||||
model_->Run(std::move(x_tensor), std::move(tones_tensor), sid, speed);
|
||||
}
|
||||
|
||||
std::vector<int64_t> audio_shape =
|
||||
audio.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
@@ -21,6 +21,7 @@ struct OfflineTtsVitsModelMetaData {
|
||||
bool is_piper = false;
|
||||
bool is_coqui = false;
|
||||
bool is_icefall = false;
|
||||
bool is_melo_tts = false;
|
||||
|
||||
// for Chinese TTS models from
|
||||
// https://github.com/Plachtaa/VITS-fast-fine-tuning
|
||||
@@ -33,6 +34,10 @@ struct OfflineTtsVitsModelMetaData {
|
||||
int32_t use_eos_bos = 0;
|
||||
int32_t pad_id = 0;
|
||||
|
||||
// for melo tts
|
||||
int32_t speaker_id = 0;
|
||||
int32_t version = 0;
|
||||
|
||||
std::string punctuations;
|
||||
std::string language;
|
||||
std::string voice;
|
||||
|
||||
@@ -45,6 +45,64 @@ class OfflineTtsVitsModel::Impl {
|
||||
return RunVits(std::move(x), sid, speed);
|
||||
}
|
||||
|
||||
Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid, float speed) {
|
||||
// For MeloTTS, we hardcode sid to the one contained in the meta data
|
||||
sid = meta_data_.speaker_id;
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::vector<int64_t> x_shape = x.GetTensorTypeAndShapeInfo().GetShape();
|
||||
if (x_shape[0] != 1) {
|
||||
SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d",
|
||||
static_cast<int32_t>(x_shape[0]));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int64_t len = x_shape[1];
|
||||
int64_t len_shape = 1;
|
||||
|
||||
Ort::Value x_length =
|
||||
Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);
|
||||
|
||||
int64_t scale_shape = 1;
|
||||
float noise_scale = config_.vits.noise_scale;
|
||||
float length_scale = config_.vits.length_scale;
|
||||
float noise_scale_w = config_.vits.noise_scale_w;
|
||||
|
||||
if (speed != 1 && speed > 0) {
|
||||
length_scale = 1. / speed;
|
||||
}
|
||||
|
||||
Ort::Value noise_scale_tensor =
|
||||
Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1);
|
||||
|
||||
Ort::Value length_scale_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, &length_scale, 1, &scale_shape, 1);
|
||||
|
||||
Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, &noise_scale_w, 1, &scale_shape, 1);
|
||||
|
||||
Ort::Value sid_tensor =
|
||||
Ort::Value::CreateTensor(memory_info, &sid, 1, &scale_shape, 1);
|
||||
|
||||
std::vector<Ort::Value> inputs;
|
||||
inputs.reserve(7);
|
||||
inputs.push_back(std::move(x));
|
||||
inputs.push_back(std::move(x_length));
|
||||
inputs.push_back(std::move(tones));
|
||||
inputs.push_back(std::move(sid_tensor));
|
||||
inputs.push_back(std::move(noise_scale_tensor));
|
||||
inputs.push_back(std::move(length_scale_tensor));
|
||||
inputs.push_back(std::move(noise_scale_w_tensor));
|
||||
|
||||
auto out =
|
||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
|
||||
return std::move(out[0]);
|
||||
}
|
||||
|
||||
const OfflineTtsVitsModelMetaData &GetMetaData() const { return meta_data_; }
|
||||
|
||||
private:
|
||||
@@ -83,6 +141,10 @@ class OfflineTtsVitsModel::Impl {
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
|
||||
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.add_blank, "add_blank",
|
||||
0);
|
||||
|
||||
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.speaker_id, "speaker_id",
|
||||
0);
|
||||
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.version, "version", 0);
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "n_speakers");
|
||||
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.punctuations,
|
||||
"punctuation", "");
|
||||
@@ -115,6 +177,22 @@ class OfflineTtsVitsModel::Impl {
|
||||
if (comment.find("icefall") != std::string::npos) {
|
||||
meta_data_.is_icefall = true;
|
||||
}
|
||||
|
||||
if (comment.find("melo") != std::string::npos) {
|
||||
meta_data_.is_melo_tts = true;
|
||||
int32_t expected_version = 2;
|
||||
if (meta_data_.version < expected_version) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Please download the latest MeloTTS model and retry. Current "
|
||||
"version: %d. Expected version: %d",
|
||||
meta_data_.version, expected_version);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// NOTE(fangjun):
|
||||
// version 0 is the first version
|
||||
// version 2: add jieba=1 to the metadata
|
||||
}
|
||||
}
|
||||
|
||||
Ort::Value RunVitsPiperOrCoqui(Ort::Value x, int64_t sid, float speed) {
|
||||
@@ -269,6 +347,12 @@ Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/,
|
||||
return impl_->Run(std::move(x), sid, speed);
|
||||
}
|
||||
|
||||
Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, Ort::Value tones,
|
||||
int64_t sid /*= 0*/,
|
||||
float speed /*= 1.0*/) {
|
||||
return impl_->Run(std::move(x), std::move(tones), sid, speed);
|
||||
}
|
||||
|
||||
const OfflineTtsVitsModelMetaData &OfflineTtsVitsModel::GetMetaData() const {
|
||||
return impl_->GetMetaData();
|
||||
}
|
||||
|
||||
@@ -40,6 +40,10 @@ class OfflineTtsVitsModel {
|
||||
*/
|
||||
Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0);
|
||||
|
||||
// This is for MeloTTS
|
||||
Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid = 0,
|
||||
float speed = 1.0);
|
||||
|
||||
const OfflineTtsVitsModelMetaData &GetMetaData() const;
|
||||
|
||||
private:
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
||||
@@ -36,7 +36,6 @@ class OfflineWhisperDecoder {
|
||||
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
|
||||
|
||||
virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0;
|
||||
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -12,7 +12,8 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineWhisperGreedySearchDecoder::SetConfig(const OfflineWhisperModelConfig &config) {
|
||||
void OfflineWhisperGreedySearchDecoder::SetConfig(
|
||||
const OfflineWhisperModelConfig &config) {
|
||||
config_ = config;
|
||||
}
|
||||
|
||||
@@ -135,9 +136,9 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
|
||||
const auto &id2lang = model_->GetID2Lang();
|
||||
if (id2lang.count(initial_tokens[1])) {
|
||||
ans[0].lang = id2lang.at(initial_tokens[1]);
|
||||
ans[0].lang = id2lang.at(initial_tokens[1]);
|
||||
} else {
|
||||
ans[0].lang = "";
|
||||
ans[0].lang = "";
|
||||
}
|
||||
|
||||
ans[0].tokens = std::move(predicted_tokens);
|
||||
|
||||
@@ -153,15 +153,21 @@ Ort::Value View(Ort::Value *v) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T /*= float*/>
|
||||
void Print1D(Ort::Value *v) {
|
||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||
const float *d = v->GetTensorData<float>();
|
||||
const T *d = v->GetTensorData<T>();
|
||||
std::ostringstream os;
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(shape[0]); ++i) {
|
||||
fprintf(stderr, "%.3f ", d[i]);
|
||||
os << *d << " ";
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
os << "\n";
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
template void Print1D<int64_t>(Ort::Value *v);
|
||||
template void Print1D<float>(Ort::Value *v);
|
||||
|
||||
template <typename T /*= float*/>
|
||||
void Print2D(Ort::Value *v) {
|
||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
@@ -69,6 +69,7 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);
|
||||
Ort::Value View(Ort::Value *v);
|
||||
|
||||
// Print a 1-D tensor to stderr
|
||||
template <typename T = float>
|
||||
void Print1D(Ort::Value *v);
|
||||
|
||||
// Print a 2-D tensor to stderr
|
||||
|
||||
@@ -214,7 +214,7 @@ PiperPhonemizeLexicon::PiperPhonemizeLexicon(
|
||||
}
|
||||
#endif
|
||||
|
||||
std::vector<std::vector<int64_t>> PiperPhonemizeLexicon::ConvertTextToTokenIds(
|
||||
std::vector<TokenIDs> PiperPhonemizeLexicon::ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string &voice /*= ""*/) const {
|
||||
piper::eSpeakPhonemeConfig config;
|
||||
|
||||
@@ -232,7 +232,7 @@ std::vector<std::vector<int64_t>> PiperPhonemizeLexicon::ConvertTextToTokenIds(
|
||||
piper::phonemize_eSpeak(text, config, phonemes);
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> ans;
|
||||
std::vector<TokenIDs> ans;
|
||||
|
||||
std::vector<int64_t> phoneme_ids;
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ class PiperPhonemizeLexicon : public OfflineTtsFrontend {
|
||||
const OfflineTtsVitsModelMetaData &meta_data);
|
||||
#endif
|
||||
|
||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
|
||||
std::vector<TokenIDs> ConvertTextToTokenIds(
|
||||
const std::string &text, const std::string &voice = "") const override;
|
||||
|
||||
private:
|
||||
|
||||
@@ -31,8 +31,8 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
|
||||
api.ReleaseStatus(status);
|
||||
}
|
||||
|
||||
static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
const std::string &provider_str,
|
||||
static Ort::SessionOptions GetSessionOptionsImpl(
|
||||
int32_t num_threads, const std::string &provider_str,
|
||||
const ProviderConfig *provider_config = nullptr) {
|
||||
Provider p = StringToProvider(provider_str);
|
||||
|
||||
@@ -67,8 +67,9 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
}
|
||||
case Provider::kTRT: {
|
||||
if (provider_config == nullptr) {
|
||||
SHERPA_ONNX_LOGE("Tensorrt support for Online models ony,"
|
||||
"Must be extended for offline and others");
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Tensorrt support for Online models ony,"
|
||||
"Must be extended for offline and others");
|
||||
exit(1);
|
||||
}
|
||||
auto trt_config = provider_config->trt_config;
|
||||
@@ -84,29 +85,27 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
std::to_string(trt_config.trt_max_partition_iterations);
|
||||
auto trt_min_subgraph_size =
|
||||
std::to_string(trt_config.trt_min_subgraph_size);
|
||||
auto trt_fp16_enable =
|
||||
std::to_string(trt_config.trt_fp16_enable);
|
||||
auto trt_fp16_enable = std::to_string(trt_config.trt_fp16_enable);
|
||||
auto trt_detailed_build_log =
|
||||
std::to_string(trt_config.trt_detailed_build_log);
|
||||
auto trt_engine_cache_enable =
|
||||
std::to_string(trt_config.trt_engine_cache_enable);
|
||||
auto trt_timing_cache_enable =
|
||||
std::to_string(trt_config.trt_timing_cache_enable);
|
||||
auto trt_dump_subgraphs =
|
||||
std::to_string(trt_config.trt_dump_subgraphs);
|
||||
auto trt_dump_subgraphs = std::to_string(trt_config.trt_dump_subgraphs);
|
||||
std::vector<TrtPairs> trt_options = {
|
||||
{"device_id", device_id.c_str()},
|
||||
{"trt_max_workspace_size", trt_max_workspace_size.c_str()},
|
||||
{"trt_max_partition_iterations", trt_max_partition_iterations.c_str()},
|
||||
{"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
|
||||
{"trt_fp16_enable", trt_fp16_enable.c_str()},
|
||||
{"trt_detailed_build_log", trt_detailed_build_log.c_str()},
|
||||
{"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},
|
||||
{"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
|
||||
{"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
|
||||
{"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
|
||||
{"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}
|
||||
};
|
||||
{"device_id", device_id.c_str()},
|
||||
{"trt_max_workspace_size", trt_max_workspace_size.c_str()},
|
||||
{"trt_max_partition_iterations",
|
||||
trt_max_partition_iterations.c_str()},
|
||||
{"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
|
||||
{"trt_fp16_enable", trt_fp16_enable.c_str()},
|
||||
{"trt_detailed_build_log", trt_detailed_build_log.c_str()},
|
||||
{"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},
|
||||
{"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
|
||||
{"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
|
||||
{"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
|
||||
{"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}};
|
||||
// ToDo : Trt configs
|
||||
// "trt_int8_enable"
|
||||
// "trt_int8_use_native_calibration_table"
|
||||
@@ -151,9 +150,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
|
||||
if (provider_config != nullptr) {
|
||||
options.device_id = provider_config->device;
|
||||
options.cudnn_conv_algo_search =
|
||||
OrtCudnnConvAlgoSearch(provider_config->cuda_config
|
||||
.cudnn_conv_algo_search);
|
||||
options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch(
|
||||
provider_config->cuda_config.cudnn_conv_algo_search);
|
||||
} else {
|
||||
options.device_id = 0;
|
||||
// Default OrtCudnnConvAlgoSearchExhaustive is extremely slow
|
||||
@@ -219,22 +217,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads,
|
||||
config.provider_config.provider, &config.provider_config);
|
||||
config.provider_config.provider,
|
||||
&config.provider_config);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
|
||||
const std::string &model_type) {
|
||||
const std::string &model_type) {
|
||||
/*
|
||||
Transducer models : Only encoder will run with tensorrt,
|
||||
decoder and joiner will run with cuda
|
||||
*/
|
||||
if(config.provider_config.provider == "trt" &&
|
||||
if (config.provider_config.provider == "trt" &&
|
||||
(model_type == "decoder" || model_type == "joiner")) {
|
||||
return GetSessionOptionsImpl(config.num_threads,
|
||||
"cuda", &config.provider_config);
|
||||
return GetSessionOptionsImpl(config.num_threads, "cuda",
|
||||
&config.provider_config);
|
||||
}
|
||||
return GetSessionOptionsImpl(config.num_threads,
|
||||
config.provider_config.provider, &config.provider_config);
|
||||
config.provider_config.provider,
|
||||
&config.provider_config);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
#ifndef SHERPA_ONNX_CSRC_SESSION_H_
|
||||
#define SHERPA_ONNX_CSRC_SESSION_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
@@ -25,7 +27,7 @@ namespace sherpa_onnx {
|
||||
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
|
||||
const std::string &model_type);
|
||||
const std::string &model_type);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config);
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "Eigen/Dense"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
TEST(UTF8, Case1) {
|
||||
std::string hello = "你好, 早上好!世界. hello!。Hallo";
|
||||
std::string hello = "你好, 早上好!世界. hello!。Hallo! how are you?";
|
||||
std::vector<std::string> ss = SplitUtf8(hello);
|
||||
for (const auto &s : ss) {
|
||||
std::cout << s << "\n";
|
||||
|
||||
Reference in New Issue
Block a user