Validate input sid (#369)
This commit is contained in:
@@ -20,7 +20,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
|
|||||||
option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
|
option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
|
||||||
option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON)
|
option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON)
|
||||||
option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
|
option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
|
||||||
option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is ON on Linux" ON)
|
option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is OFF on Linux" ON)
|
||||||
|
|
||||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||||
|
|||||||
@@ -124,6 +124,11 @@ def main():
|
|||||||
start = time.time()
|
start = time.time()
|
||||||
audio = tts.generate(args.text, sid=args.sid)
|
audio = tts.generate(args.text, sid=args.sid)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
|
|
||||||
|
if len(audio.samples) == 0:
|
||||||
|
print("Error in generating audios. Please read previous error messages.")
|
||||||
|
return
|
||||||
|
|
||||||
elapsed_seconds = end - start
|
elapsed_seconds = end - start
|
||||||
audio_duration = len(audio.samples) / audio.sample_rate
|
audio_duration = len(audio.samples) / audio.sample_rate
|
||||||
real_time_factor = elapsed_seconds / audio_duration
|
real_time_factor = elapsed_seconds / audio_duration
|
||||||
|
|||||||
@@ -104,9 +104,17 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
|
|||||||
|
|
||||||
std::vector<int64_t> ans;
|
std::vector<int64_t> ans;
|
||||||
|
|
||||||
ans.push_back(token2id_.at("sil"));
|
auto sil = token2id_.at("sil");
|
||||||
|
auto eos = token2id_.at("eos");
|
||||||
|
|
||||||
|
ans.push_back(sil);
|
||||||
|
|
||||||
for (const auto &w : words) {
|
for (const auto &w : words) {
|
||||||
|
if (punctuations_.count(w)) {
|
||||||
|
ans.push_back(sil);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
if (!word2ids_.count(w)) {
|
if (!word2ids_.count(w)) {
|
||||||
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
|
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
|
||||||
continue;
|
continue;
|
||||||
@@ -115,8 +123,8 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
|
|||||||
const auto &token_ids = word2ids_.at(w);
|
const auto &token_ids = word2ids_.at(w);
|
||||||
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
|
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
|
||||||
}
|
}
|
||||||
ans.push_back(token2id_.at("sil"));
|
ans.push_back(sil);
|
||||||
ans.push_back(token2id_.at("eos"));
|
ans.push_back(eos);
|
||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,6 +134,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
|
|||||||
ToLowerCase(&text);
|
ToLowerCase(&text);
|
||||||
|
|
||||||
std::vector<std::string> words = SplitUtf8(text);
|
std::vector<std::string> words = SplitUtf8(text);
|
||||||
|
int32_t blank = token2id_.at(" ");
|
||||||
|
|
||||||
std::vector<int64_t> ans;
|
std::vector<int64_t> ans;
|
||||||
for (const auto &w : words) {
|
for (const auto &w : words) {
|
||||||
@@ -141,12 +150,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
|
|||||||
|
|
||||||
const auto &token_ids = word2ids_.at(w);
|
const auto &token_ids = word2ids_.at(w);
|
||||||
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
|
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
|
||||||
if (blank_ != -1) {
|
ans.push_back(blank);
|
||||||
ans.push_back(blank_);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (blank_ != -1 && !ans.empty()) {
|
if (!ans.empty()) {
|
||||||
// remove the last blank
|
// remove the last blank
|
||||||
ans.resize(ans.size() - 1);
|
ans.resize(ans.size() - 1);
|
||||||
}
|
}
|
||||||
@@ -156,9 +163,6 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
|
|||||||
|
|
||||||
void Lexicon::InitTokens(const std::string &tokens) {
|
void Lexicon::InitTokens(const std::string &tokens) {
|
||||||
token2id_ = ReadTokens(tokens);
|
token2id_ = ReadTokens(tokens);
|
||||||
if (token2id_.count(" ")) {
|
|
||||||
blank_ = token2id_.at(" ");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Lexicon::InitLanguage(const std::string &_lang) {
|
void Lexicon::InitLanguage(const std::string &_lang) {
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ class Lexicon {
|
|||||||
std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
|
std::unordered_map<std::string, std::vector<int32_t>> word2ids_;
|
||||||
std::unordered_set<std::string> punctuations_;
|
std::unordered_set<std::string> punctuations_;
|
||||||
std::unordered_map<std::string, int32_t> token2id_;
|
std::unordered_map<std::string, int32_t> token2id_;
|
||||||
int32_t blank_ = -1; // ID for the blank token
|
|
||||||
Language language_;
|
Language language_;
|
||||||
//
|
//
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -25,6 +25,23 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
|||||||
|
|
||||||
GeneratedAudio Generate(const std::string &text,
|
GeneratedAudio Generate(const std::string &text,
|
||||||
int64_t sid = 0) const override {
|
int64_t sid = 0) const override {
|
||||||
|
int32_t num_speakers = model_->NumSpeakers();
|
||||||
|
if (num_speakers == 0 && sid != 0) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"This is a single-speaker model and supports only sid 0. Given sid: "
|
||||||
|
"%d",
|
||||||
|
sid);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"This model contains only %d speakers. sid should be in the range "
|
||||||
|
"[%d, %d]. Given: %d",
|
||||||
|
num_speakers, 0, num_speakers - 1, sid);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
|
std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
|
||||||
if (x.empty()) {
|
if (x.empty()) {
|
||||||
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
|
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
|
||||||
|
|||||||
@@ -85,6 +85,7 @@ class OfflineTtsVitsModel::Impl {
|
|||||||
|
|
||||||
std::string Punctuations() const { return punctuations_; }
|
std::string Punctuations() const { return punctuations_; }
|
||||||
std::string Language() const { return language_; }
|
std::string Language() const { return language_; }
|
||||||
|
int32_t NumSpeakers() const { return num_speakers_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Init(void *model_data, size_t model_data_length) {
|
void Init(void *model_data, size_t model_data_length) {
|
||||||
@@ -107,7 +108,7 @@ class OfflineTtsVitsModel::Impl {
|
|||||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||||
SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
|
SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
|
||||||
SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
|
SHERPA_ONNX_READ_META_DATA(add_blank_, "add_blank");
|
||||||
SHERPA_ONNX_READ_META_DATA(n_speakers_, "n_speakers");
|
SHERPA_ONNX_READ_META_DATA(num_speakers_, "n_speakers");
|
||||||
SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
|
SHERPA_ONNX_READ_META_DATA_STR(punctuations_, "punctuation");
|
||||||
SHERPA_ONNX_READ_META_DATA_STR(language_, "language");
|
SHERPA_ONNX_READ_META_DATA_STR(language_, "language");
|
||||||
}
|
}
|
||||||
@@ -128,7 +129,7 @@ class OfflineTtsVitsModel::Impl {
|
|||||||
|
|
||||||
int32_t sample_rate_;
|
int32_t sample_rate_;
|
||||||
int32_t add_blank_;
|
int32_t add_blank_;
|
||||||
int32_t n_speakers_;
|
int32_t num_speakers_;
|
||||||
std::string punctuations_;
|
std::string punctuations_;
|
||||||
std::string language_;
|
std::string language_;
|
||||||
};
|
};
|
||||||
@@ -152,4 +153,8 @@ std::string OfflineTtsVitsModel::Punctuations() const {
|
|||||||
|
|
||||||
std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); }
|
std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); }
|
||||||
|
|
||||||
|
int32_t OfflineTtsVitsModel::NumSpeakers() const {
|
||||||
|
return impl_->NumSpeakers();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ class OfflineTtsVitsModel {
|
|||||||
|
|
||||||
std::string Punctuations() const;
|
std::string Punctuations() const;
|
||||||
std::string Language() const;
|
std::string Language() const;
|
||||||
|
int32_t NumSpeakers() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class Impl;
|
class Impl;
|
||||||
|
|||||||
@@ -81,6 +81,12 @@ or detailes.
|
|||||||
|
|
||||||
sherpa_onnx::OfflineTts tts(config);
|
sherpa_onnx::OfflineTts tts(config);
|
||||||
auto audio = tts.Generate(po.GetArg(1), sid);
|
auto audio = tts.Generate(po.GetArg(1), sid);
|
||||||
|
if (audio.samples.empty()) {
|
||||||
|
fprintf(
|
||||||
|
stderr,
|
||||||
|
"Error in generating audios. Please read previous error messages.\n");
|
||||||
|
exit(EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
|
||||||
bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate,
|
bool ok = sherpa_onnx::WriteWave(output_filename, audio.sample_rate,
|
||||||
audio.samples.data(), audio.samples.size());
|
audio.samples.data(), audio.samples.size());
|
||||||
|
|||||||
Reference in New Issue
Block a user