Added progress for callback of tts generator (#712)

Co-authored-by: leohwang <leohwang@360converter.com>
This commit is contained in:
Leo Huang
2024-03-28 17:12:20 +08:00
committed by GitHub
parent de655e838e
commit 638f48f47a
8 changed files with 51 additions and 40 deletions

View File

@@ -807,16 +807,10 @@ int32_t SherpaOnnxOfflineTtsNumSpeakers(const SherpaOnnxOfflineTts *tts) {
return tts->impl->NumSpeakers(); return tts->impl->NumSpeakers();
} }
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerate( static const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateInternal(
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid,
float speed) { float speed, std::function<void(const float *, int32_t, float)> callback)
return SherpaOnnxOfflineTtsGenerateWithCallback(tts, text, sid, speed, {
nullptr);
}
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback(
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
SherpaOnnxGeneratedAudioCallback callback) {
sherpa_onnx::GeneratedAudio audio = sherpa_onnx::GeneratedAudio audio =
tts->impl->Generate(text, sid, speed, callback); tts->impl->Generate(text, sid, speed, callback);
@@ -836,30 +830,39 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback(
return ans; return ans;
} }
// Generates audio for `text` with speaker `sid` at the given `speed` without
// reporting intermediate chunks to the caller.
//
// Returns whatever SherpaOnnxOfflineTtsGenerateInternal() returns.
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerate(
    const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid,
    float speed) {
  // A default-constructed std::function is empty, which is exactly what a
  // nullptr argument converts to.
  const std::function<void(const float *, int32_t, float)> no_callback;
  return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed,
                                              no_callback);
}
// Generates audio for `text` and hands every generated chunk to `callback`
// as (samples, n). The progress value produced by the generator is dropped
// because this callback type has no parameter for it.
//
// `callback` may be NULL, in which case no callback is installed at all.
//
// Returns whatever SherpaOnnxOfflineTtsGenerateInternal() returns.
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback(
    const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
    SherpaOnnxGeneratedAudioCallback callback) {
  if (callback == nullptr) {
    // A lambda wrapper is always a non-empty std::function, so wrapping a
    // NULL pointer would make the generator call through it and crash.
    // Forward an empty callback instead.
    return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed,
                                                nullptr);
  }

  auto wrapper = [callback](const float *samples, int32_t n,
                            float /*progress*/) { callback(samples, n); };

  return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper);
}
// Generates audio for `text` and hands every generated chunk to `callback`
// as (samples, n, progress).
//
// `callback` may be NULL, in which case no callback is installed at all.
//
// Returns whatever SherpaOnnxOfflineTtsGenerateInternal() returns.
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithProgressCallback(
    const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
    SherpaOnnxGeneratedAudioProgressCallback callback) {
  // The C callback already has the (samples, n, progress) signature the
  // internal helper expects, so no adapter lambda is needed. Passing the
  // pointer directly also keeps a NULL callback harmless: std::function
  // built from a null function pointer is empty, whereas a lambda wrapper
  // would always be non-empty and would call through the NULL pointer.
  return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, callback);
}
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg( const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg(
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) { SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) {
auto wrapper = [callback, arg](const float *samples, int32_t n) { auto wrapper = [callback, arg](const float *samples, int32_t n, float /*progress*/) {
callback(samples, n, arg); callback(samples, n, arg);
}; };
sherpa_onnx::GeneratedAudio audio = return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper );
tts->impl->Generate(text, sid, speed, wrapper);
if (audio.samples.empty()) {
return nullptr;
}
SherpaOnnxGeneratedAudio *ans = new SherpaOnnxGeneratedAudio;
float *samples = new float[audio.samples.size()];
std::copy(audio.samples.begin(), audio.samples.end(), samples);
ans->samples = samples;
ans->n = audio.samples.size();
ans->sample_rate = audio.sample_rate;
return ans;
} }
void SherpaOnnxDestroyOfflineTtsGeneratedAudio( void SherpaOnnxDestroyOfflineTtsGeneratedAudio(

View File

@@ -768,6 +768,9 @@ typedef void (*SherpaOnnxGeneratedAudioCallback)(const float *samples,
typedef void (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples, typedef void (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples,
int32_t n, void *arg); int32_t n, void *arg);
// Callback that receives a chunk of generated audio plus a progress value.
//
//  - samples: pointer to the generated samples of the current chunk. The
//             buffer is freed once the callback returns, so copy the data
//             if it is needed afterwards.
//  - n:       number of samples pointed to by `samples`.
//  - p:       progress reported by the generator; the VITS implementation
//             passes 1.0 with the last (or only) chunk.
typedef void (*SherpaOnnxGeneratedAudioProgressCallback)(const float *samples,
int32_t n, float p);
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts; SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts;
// Create an instance of offline TTS. The user has to use DestroyOfflineTts() // Create an instance of offline TTS. The user has to use DestroyOfflineTts()

View File

@@ -134,7 +134,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) { if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) {
auto ans = Process(x, sid, speed); auto ans = Process(x, sid, speed);
if (callback) { if (callback) {
callback(ans.samples.data(), ans.samples.size()); callback(ans.samples.data(), ans.samples.size(), 1.0);
} }
return ans; return ans;
} }
@@ -168,7 +168,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
ans.samples.insert(ans.samples.end(), audio.samples.begin(), ans.samples.insert(ans.samples.end(), audio.samples.begin(),
audio.samples.end()); audio.samples.end());
if (callback) { if (callback) {
callback(audio.samples.data(), audio.samples.size()); callback(audio.samples.data(), audio.samples.size(), b * 1.0 / num_batches);
// Caution(fangjun): audio is freed when the callback returns, so users // Caution(fangjun): audio is freed when the callback returns, so users
// should copy the data if they want to access the data after // should copy the data if they want to access the data after
// the callback returns to avoid segmentation fault. // the callback returns to avoid segmentation fault.
@@ -187,7 +187,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
ans.samples.insert(ans.samples.end(), audio.samples.begin(), ans.samples.insert(ans.samples.end(), audio.samples.begin(),
audio.samples.end()); audio.samples.end());
if (callback) { if (callback) {
callback(audio.samples.data(), audio.samples.size()); callback(audio.samples.data(), audio.samples.size(), 1.0);
// Caution(fangjun): audio is freed when the callback returns, so users // Caution(fangjun): audio is freed when the callback returns, so users
// should copy the data if they want to access the data after // should copy the data if they want to access the data after
// the callback returns to avoid segmentation fault. // the callback returns to avoid segmentation fault.

View File

@@ -55,7 +55,7 @@ struct GeneratedAudio {
class OfflineTtsImpl; class OfflineTtsImpl;
using GeneratedAudioCallback = using GeneratedAudioCallback =
std::function<void(const float * /*samples*/, int32_t /*n*/)>; std::function<void(const float * /*samples*/, int32_t /*n*/, float /*progress*/)>;
class OfflineTts { class OfflineTts {
public: public:

View File

@@ -47,7 +47,7 @@ static void Handler(int32_t /*sig*/) {
fprintf(stderr, "\nCaught Ctrl + C. Exiting\n"); fprintf(stderr, "\nCaught Ctrl + C. Exiting\n");
} }
static void AudioGeneratedCallback(const float *s, int32_t n) { static void AudioGeneratedCallback(const float *s, int32_t n, float /*progress*/) {
if (n > 0) { if (n > 0) {
Samples samples; Samples samples;
samples.data = std::vector<float>{s, s + n}; samples.data = std::vector<float>{s, s + n};

View File

@@ -9,6 +9,11 @@
#include "sherpa-onnx/csrc/parse-options.h" #include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-writer.h" #include "sherpa-onnx/csrc/wave-writer.h"
// Callback passed to OfflineTts::Generate(): prints the number of samples in
// the chunk it was given and the reported progress value to stdout. The
// sample data itself is not inspected.
void audioCallback(const float *samples, int32_t n, float progress) {
  // A float is promoted to double when passed through "...", so the explicit
  // cast only spells out what %f already receives.
  fprintf(stdout, "sample=%d, progress=%f\n", n,
          static_cast<double>(progress));
}
int main(int32_t argc, char *argv[]) { int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage( const char *kUsageMessage = R"usage(
Offline text-to-speech with sherpa-onnx Offline text-to-speech with sherpa-onnx
@@ -74,7 +79,7 @@ or details.
sherpa_onnx::OfflineTts tts(config); sherpa_onnx::OfflineTts tts(config);
const auto begin = std::chrono::steady_clock::now(); const auto begin = std::chrono::steady_clock::now();
auto audio = tts.Generate(po.GetArg(1), sid); auto audio = tts.Generate(po.GetArg(1), sid, 1.0, audioCallback);
const auto end = std::chrono::steady_clock::now(); const auto end = std::chrono::steady_clock::now();
if (audio.samples.empty()) { if (audio.samples.empty()) {

View File

@@ -797,7 +797,7 @@ class SherpaOnnxOfflineTts {
GeneratedAudio Generate( GeneratedAudio Generate(
const std::string &text, int64_t sid = 0, float speed = 1.0, const std::string &text, int64_t sid = 0, float speed = 1.0,
std::function<void(const float *, int32_t)> callback = nullptr) const { std::function<void(const float *, int32_t, float)> callback = nullptr) const {
return tts_.Generate(text, sid, speed, callback); return tts_.Generate(text, sid, speed, callback);
} }
@@ -1314,8 +1314,8 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
const char *p_text = env->GetStringUTFChars(text, nullptr); const char *p_text = env->GetStringUTFChars(text, nullptr);
SHERPA_ONNX_LOGE("string is: %s", p_text); SHERPA_ONNX_LOGE("string is: %s", p_text);
std::function<void(const float *, int32_t)> callback_wrapper = std::function<void(const float *, int32_t, float)> callback_wrapper =
[env, callback](const float *samples, int32_t n) { [env, callback](const float *samples, int32_t n, float /*p*/) {
jclass cls = env->GetObjectClass(callback); jclass cls = env->GetObjectClass(callback);
jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V"); jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V");

View File

@@ -55,14 +55,14 @@ void PybindOfflineTts(py::module *m) {
.def( .def(
"generate", "generate",
[](const PyClass &self, const std::string &text, int64_t sid, [](const PyClass &self, const std::string &text, int64_t sid,
float speed, std::function<void(py::array_t<float>)> callback) float speed, std::function<void(py::array_t<float>, float)> callback)
-> GeneratedAudio { -> GeneratedAudio {
if (!callback) { if (!callback) {
return self.Generate(text, sid, speed); return self.Generate(text, sid, speed);
} }
std::function<void(const float *, int32_t)> callback_wrapper = std::function<void(const float *, int32_t, float)> callback_wrapper =
[callback](const float *samples, int32_t n) { [callback](const float *samples, int32_t n, float progress) {
// CAUTION(fangjun): we have to copy samples since it is // CAUTION(fangjun): we have to copy samples since it is
// freed once the call back returns. // freed once the call back returns.
@@ -72,7 +72,7 @@ void PybindOfflineTts(py::module *m) {
py::buffer_info buf = array.request(); py::buffer_info buf = array.request();
auto p = static_cast<float *>(buf.ptr); auto p = static_cast<float *>(buf.ptr);
std::copy(samples, samples + n, p); std::copy(samples, samples + n, p);
callback(array); callback(array, progress);
}; };
return self.Generate(text, sid, speed, callback_wrapper); return self.Generate(text, sid, speed, callback_wrapper);