Add a progress argument to the TTS generator callback (#712)
Co-authored-by: leohwang <leohwang@360converter.com>
This commit is contained in:
@@ -807,16 +807,10 @@ int32_t SherpaOnnxOfflineTtsNumSpeakers(const SherpaOnnxOfflineTts *tts) {
|
|||||||
return tts->impl->NumSpeakers();
|
return tts->impl->NumSpeakers();
|
||||||
}
|
}
|
||||||
|
|
||||||
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerate(
|
static const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateInternal(
|
||||||
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid,
|
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid,
|
||||||
float speed) {
|
float speed, std::function<void(const float *, int32_t, float)> callback)
|
||||||
return SherpaOnnxOfflineTtsGenerateWithCallback(tts, text, sid, speed,
|
{
|
||||||
nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback(
|
|
||||||
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
|
|
||||||
SherpaOnnxGeneratedAudioCallback callback) {
|
|
||||||
sherpa_onnx::GeneratedAudio audio =
|
sherpa_onnx::GeneratedAudio audio =
|
||||||
tts->impl->Generate(text, sid, speed, callback);
|
tts->impl->Generate(text, sid, speed, callback);
|
||||||
|
|
||||||
@@ -836,30 +830,39 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback(
|
|||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generates audio for `text` without notifying the caller about intermediate
// chunks. All work is delegated to SherpaOnnxOfflineTtsGenerateInternal() with
// an empty callback, and that helper's result is returned unchanged.
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerate(
    const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid,
    float speed) {
  // A default-constructed std::function has no target, exactly like one
  // constructed from nullptr.
  const std::function<void(const float *, int32_t, float)> no_callback;
  return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed,
                                              no_callback);
}
|
||||||
|
|
||||||
|
// Generates audio for `text` and, when `callback` is non-null, passes the
// generated samples to it as `callback(samples, n)`. The progress value
// produced by the generator is dropped because this callback type has no
// parameter for it.
//
// `callback` may be null, in which case generation runs without any
// notification. Returns whatever SherpaOnnxOfflineTtsGenerateInternal()
// returns.
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback(
    const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
    SherpaOnnxGeneratedAudioCallback callback) {
  if (!callback) {
    // Wrapping a null function pointer in a lambda would yield a non-empty
    // std::function, so the generator's `if (callback)` guard would pass and
    // the lambda would then call through a null pointer.
    return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, nullptr);
  }

  auto wrapper = [callback](const float *samples, int32_t n,
                            float /*progress*/) { callback(samples, n); };

  return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper);
}
|
||||||
|
|
||||||
|
// Generates audio for `text` and, when `callback` is non-null, passes the
// generated samples to it as `callback(samples, n, progress)`.
//
// `callback` may be null, in which case generation runs without any
// notification. Returns whatever SherpaOnnxOfflineTtsGenerateInternal()
// returns.
const SherpaOnnxGeneratedAudio *
SherpaOnnxOfflineTtsGenerateWithProgressCallback(
    const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
    SherpaOnnxGeneratedAudioProgressCallback callback) {
  // The C callback already has the (samples, n, progress) signature the
  // internal helper expects, so it is handed over directly. A std::function
  // built from a null function pointer is empty, which keeps the generator's
  // `if (callback)` guard effective; an unconditional lambda wrapper would
  // defeat that guard and call through a null pointer.
  return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, callback);
}
|
||||||
|
|
||||||
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg(
|
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg(
|
||||||
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
|
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
|
||||||
SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) {
|
SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) {
|
||||||
auto wrapper = [callback, arg](const float *samples, int32_t n) {
|
auto wrapper = [callback, arg](const float *samples, int32_t n, float /*progress*/) {
|
||||||
callback(samples, n, arg);
|
callback(samples, n, arg);
|
||||||
};
|
};
|
||||||
|
|
||||||
sherpa_onnx::GeneratedAudio audio =
|
return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper );
|
||||||
tts->impl->Generate(text, sid, speed, wrapper);
|
|
||||||
|
|
||||||
if (audio.samples.empty()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
SherpaOnnxGeneratedAudio *ans = new SherpaOnnxGeneratedAudio;
|
|
||||||
|
|
||||||
float *samples = new float[audio.samples.size()];
|
|
||||||
std::copy(audio.samples.begin(), audio.samples.end(), samples);
|
|
||||||
|
|
||||||
ans->samples = samples;
|
|
||||||
ans->n = audio.samples.size();
|
|
||||||
ans->sample_rate = audio.sample_rate;
|
|
||||||
|
|
||||||
return ans;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SherpaOnnxDestroyOfflineTtsGeneratedAudio(
|
void SherpaOnnxDestroyOfflineTtsGeneratedAudio(
|
||||||
|
|||||||
@@ -768,6 +768,9 @@ typedef void (*SherpaOnnxGeneratedAudioCallback)(const float *samples,
|
|||||||
typedef void (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples,
|
typedef void (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples,
|
||||||
int32_t n, void *arg);
|
int32_t n, void *arg);
|
||||||
|
|
||||||
|
// Callback invoked with newly generated audio samples and a progress value.
//
// @param samples  Pointer to `n` generated samples. NOTE(review): the
//                 generator's own comments say the buffer is freed once the
//                 callback returns, so copy the data if it is needed later.
// @param n        Number of samples in `samples`.
// @param progress Progress reported by the generator. NOTE(review): the
//                 visible call sites pass 1.0 for the final chunk and
//                 b / num_batches for intermediate ones; confirm the exact
//                 range against the implementation.
typedef void (*SherpaOnnxGeneratedAudioProgressCallback)(const float *samples,
                                                         int32_t n,
                                                         float progress);
|
||||||
|
|
||||||
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts;
|
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts;
|
||||||
|
|
||||||
// Create an instance of offline TTS. The user has to use DestroyOfflineTts()
|
// Create an instance of offline TTS. The user has to use DestroyOfflineTts()
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
|||||||
if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) {
|
if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) {
|
||||||
auto ans = Process(x, sid, speed);
|
auto ans = Process(x, sid, speed);
|
||||||
if (callback) {
|
if (callback) {
|
||||||
callback(ans.samples.data(), ans.samples.size());
|
callback(ans.samples.data(), ans.samples.size(), 1.0);
|
||||||
}
|
}
|
||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
@@ -168,7 +168,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
|||||||
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
|
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
|
||||||
audio.samples.end());
|
audio.samples.end());
|
||||||
if (callback) {
|
if (callback) {
|
||||||
callback(audio.samples.data(), audio.samples.size());
|
callback(audio.samples.data(), audio.samples.size(), b * 1.0 / num_batches);
|
||||||
// Caution(fangjun): audio is freed when the callback returns, so users
|
// Caution(fangjun): audio is freed when the callback returns, so users
|
||||||
// should copy the data if they want to access the data after
|
// should copy the data if they want to access the data after
|
||||||
// the callback returns to avoid segmentation fault.
|
// the callback returns to avoid segmentation fault.
|
||||||
@@ -187,7 +187,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
|||||||
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
|
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
|
||||||
audio.samples.end());
|
audio.samples.end());
|
||||||
if (callback) {
|
if (callback) {
|
||||||
callback(audio.samples.data(), audio.samples.size());
|
callback(audio.samples.data(), audio.samples.size(), 1.0);
|
||||||
// Caution(fangjun): audio is freed when the callback returns, so users
|
// Caution(fangjun): audio is freed when the callback returns, so users
|
||||||
// should copy the data if they want to access the data after
|
// should copy the data if they want to access the data after
|
||||||
// the callback returns to avoid segmentation fault.
|
// the callback returns to avoid segmentation fault.
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ struct GeneratedAudio {
|
|||||||
class OfflineTtsImpl;
|
class OfflineTtsImpl;
|
||||||
|
|
||||||
using GeneratedAudioCallback =
|
using GeneratedAudioCallback =
|
||||||
std::function<void(const float * /*samples*/, int32_t /*n*/)>;
|
std::function<void(const float * /*samples*/, int32_t /*n*/, float /*progress*/)>;
|
||||||
|
|
||||||
class OfflineTts {
|
class OfflineTts {
|
||||||
public:
|
public:
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ static void Handler(int32_t /*sig*/) {
|
|||||||
fprintf(stderr, "\nCaught Ctrl + C. Exiting\n");
|
fprintf(stderr, "\nCaught Ctrl + C. Exiting\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
static void AudioGeneratedCallback(const float *s, int32_t n) {
|
static void AudioGeneratedCallback(const float *s, int32_t n, float /*progress*/) {
|
||||||
if (n > 0) {
|
if (n > 0) {
|
||||||
Samples samples;
|
Samples samples;
|
||||||
samples.data = std::vector<float>{s, s + n};
|
samples.data = std::vector<float>{s, s + n};
|
||||||
|
|||||||
@@ -9,6 +9,11 @@
|
|||||||
#include "sherpa-onnx/csrc/parse-options.h"
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
#include "sherpa-onnx/csrc/wave-writer.h"
|
#include "sherpa-onnx/csrc/wave-writer.h"
|
||||||
|
|
||||||
|
// Callback handed to OfflineTts::Generate() from main(): prints the number of
// samples in the chunk it was called with together with the reported
// progress. The sample data itself is not inspected.
void audioCallback(const float * /*samples*/, int32_t n, float progress) {
  // %d expects an int, so convert explicitly rather than rely on int32_t
  // being a typedef for int.
  printf("sample=%d, progress=%f\n", static_cast<int>(n), progress);
}
|
||||||
|
|
||||||
int main(int32_t argc, char *argv[]) {
|
int main(int32_t argc, char *argv[]) {
|
||||||
const char *kUsageMessage = R"usage(
|
const char *kUsageMessage = R"usage(
|
||||||
Offline text-to-speech with sherpa-onnx
|
Offline text-to-speech with sherpa-onnx
|
||||||
@@ -74,7 +79,7 @@ or details.
|
|||||||
sherpa_onnx::OfflineTts tts(config);
|
sherpa_onnx::OfflineTts tts(config);
|
||||||
|
|
||||||
const auto begin = std::chrono::steady_clock::now();
|
const auto begin = std::chrono::steady_clock::now();
|
||||||
auto audio = tts.Generate(po.GetArg(1), sid);
|
auto audio = tts.Generate(po.GetArg(1), sid, 1.0, audioCallback);
|
||||||
const auto end = std::chrono::steady_clock::now();
|
const auto end = std::chrono::steady_clock::now();
|
||||||
|
|
||||||
if (audio.samples.empty()) {
|
if (audio.samples.empty()) {
|
||||||
|
|||||||
@@ -797,7 +797,7 @@ class SherpaOnnxOfflineTts {
|
|||||||
|
|
||||||
GeneratedAudio Generate(
|
GeneratedAudio Generate(
|
||||||
const std::string &text, int64_t sid = 0, float speed = 1.0,
|
const std::string &text, int64_t sid = 0, float speed = 1.0,
|
||||||
std::function<void(const float *, int32_t)> callback = nullptr) const {
|
std::function<void(const float *, int32_t, float)> callback = nullptr) const {
|
||||||
return tts_.Generate(text, sid, speed, callback);
|
return tts_.Generate(text, sid, speed, callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1314,8 +1314,8 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
|
|||||||
const char *p_text = env->GetStringUTFChars(text, nullptr);
|
const char *p_text = env->GetStringUTFChars(text, nullptr);
|
||||||
SHERPA_ONNX_LOGE("string is: %s", p_text);
|
SHERPA_ONNX_LOGE("string is: %s", p_text);
|
||||||
|
|
||||||
std::function<void(const float *, int32_t)> callback_wrapper =
|
std::function<void(const float *, int32_t, float)> callback_wrapper =
|
||||||
[env, callback](const float *samples, int32_t n) {
|
[env, callback](const float *samples, int32_t n, float /*p*/) {
|
||||||
jclass cls = env->GetObjectClass(callback);
|
jclass cls = env->GetObjectClass(callback);
|
||||||
jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V");
|
jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V");
|
||||||
|
|
||||||
|
|||||||
@@ -55,14 +55,14 @@ void PybindOfflineTts(py::module *m) {
|
|||||||
.def(
|
.def(
|
||||||
"generate",
|
"generate",
|
||||||
[](const PyClass &self, const std::string &text, int64_t sid,
|
[](const PyClass &self, const std::string &text, int64_t sid,
|
||||||
float speed, std::function<void(py::array_t<float>)> callback)
|
float speed, std::function<void(py::array_t<float>, float)> callback)
|
||||||
-> GeneratedAudio {
|
-> GeneratedAudio {
|
||||||
if (!callback) {
|
if (!callback) {
|
||||||
return self.Generate(text, sid, speed);
|
return self.Generate(text, sid, speed);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::function<void(const float *, int32_t)> callback_wrapper =
|
std::function<void(const float *, int32_t, float)> callback_wrapper =
|
||||||
[callback](const float *samples, int32_t n) {
|
[callback](const float *samples, int32_t n, float progress) {
|
||||||
// CAUTION(fangjun): we have to copy samples since it is
|
// CAUTION(fangjun): we have to copy samples since it is
|
||||||
// freed once the call back returns.
|
// freed once the call back returns.
|
||||||
|
|
||||||
@@ -72,7 +72,7 @@ void PybindOfflineTts(py::module *m) {
|
|||||||
py::buffer_info buf = array.request();
|
py::buffer_info buf = array.request();
|
||||||
auto p = static_cast<float *>(buf.ptr);
|
auto p = static_cast<float *>(buf.ptr);
|
||||||
std::copy(samples, samples + n, p);
|
std::copy(samples, samples + n, p);
|
||||||
callback(array);
|
callback(array, progress);
|
||||||
};
|
};
|
||||||
|
|
||||||
return self.Generate(text, sid, speed, callback_wrapper);
|
return self.Generate(text, sid, speed, callback_wrapper);
|
||||||
|
|||||||
Reference in New Issue
Block a user