diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 46055ebe..14550fc8 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -807,16 +807,10 @@ int32_t SherpaOnnxOfflineTtsNumSpeakers(const SherpaOnnxOfflineTts *tts) { return tts->impl->NumSpeakers(); } -const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerate( +static const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateInternal( const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, - float speed) { - return SherpaOnnxOfflineTtsGenerateWithCallback(tts, text, sid, speed, - nullptr); -} - -const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback( - const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, - SherpaOnnxGeneratedAudioCallback callback) { + float speed, std::function callback) +{ sherpa_onnx::GeneratedAudio audio = tts->impl->Generate(text, sid, speed, callback); @@ -836,30 +830,39 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback( return ans; } +const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerate( + const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, + float speed) { + return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, nullptr ); +} + +const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback( + const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, + SherpaOnnxGeneratedAudioCallback callback) { + auto wrapper = [callback](const float *samples, int32_t n, float /*progress*/) { + callback(samples, n ); + }; + + return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper ); +} + +const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithProgressCallback( + const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, + SherpaOnnxGeneratedAudioProgressCallback callback) { + auto wrapper = [callback](const float *samples, int32_t n, float progress) { + callback(samples, n, progress ); + }; + return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper ); +} + const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg( const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) { - auto wrapper = [callback, arg](const float *samples, int32_t n) { + auto wrapper = [callback, arg](const float *samples, int32_t n, float /*progress*/) { callback(samples, n, arg); }; - sherpa_onnx::GeneratedAudio audio = - tts->impl->Generate(text, sid, speed, wrapper); - - if (audio.samples.empty()) { - return nullptr; - } - - SherpaOnnxGeneratedAudio *ans = new SherpaOnnxGeneratedAudio; - - float *samples = new float[audio.samples.size()]; - std::copy(audio.samples.begin(), audio.samples.end(), samples); - - ans->samples = samples; - ans->n = audio.samples.size(); - ans->sample_rate = audio.sample_rate; - - return ans; + return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper ); } void SherpaOnnxDestroyOfflineTtsGeneratedAudio( diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 570cb8e8..8c86f353 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -768,6 +768,9 @@ typedef void (*SherpaOnnxGeneratedAudioCallback)(const float *samples, typedef void (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples, int32_t n, void *arg); +typedef void (*SherpaOnnxGeneratedAudioProgressCallback)(const float *samples, + int32_t n, float p); + SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts; // Create an instance of offline TTS. The user has to use DestroyOfflineTts() diff --git a/sherpa-onnx/csrc/offline-tts-vits-impl.h b/sherpa-onnx/csrc/offline-tts-vits-impl.h index 2fc0f309..cdd33e18 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-impl.h +++ b/sherpa-onnx/csrc/offline-tts-vits-impl.h @@ -134,7 +134,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) { auto ans = Process(x, sid, speed); if (callback) { - callback(ans.samples.data(), ans.samples.size()); + callback(ans.samples.data(), ans.samples.size(), 1.0); } return ans; } @@ -168,7 +168,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { ans.samples.insert(ans.samples.end(), audio.samples.begin(), audio.samples.end()); if (callback) { - callback(audio.samples.data(), audio.samples.size()); + callback(audio.samples.data(), audio.samples.size(), b * 1.0 / num_batches); // Caution(fangjun): audio is freed when the callback returns, so users // should copy the data if they want to access the data after // the callback returns to avoid segmentation fault. @@ -187,7 +187,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { ans.samples.insert(ans.samples.end(), audio.samples.begin(), audio.samples.end()); if (callback) { - callback(audio.samples.data(), audio.samples.size()); + callback(audio.samples.data(), audio.samples.size(), 1.0); // Caution(fangjun): audio is freed when the callback returns, so users // should copy the data if they want to access the data after // the callback returns to avoid segmentation fault. diff --git a/sherpa-onnx/csrc/offline-tts.h b/sherpa-onnx/csrc/offline-tts.h index f67b20b5..c39dfdae 100644 --- a/sherpa-onnx/csrc/offline-tts.h +++ b/sherpa-onnx/csrc/offline-tts.h @@ -55,7 +55,7 @@ struct GeneratedAudio { class OfflineTtsImpl; using GeneratedAudioCallback = - std::function; + std::function; class OfflineTts { public: diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-tts-play.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-tts-play.cc index 00742c48..c6dee345 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-tts-play.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-tts-play.cc @@ -47,7 +47,7 @@ static void Handler(int32_t /*sig*/) { fprintf(stderr, "\nCaught Ctrl + C. Exiting\n"); } -static void AudioGeneratedCallback(const float *s, int32_t n) { +static void AudioGeneratedCallback(const float *s, int32_t n, float /*progress*/) { if (n > 0) { Samples samples; samples.data = std::vector{s, s + n}; diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc index 07de34de..aeab20ff 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc @@ -9,6 +9,11 @@ #include "sherpa-onnx/csrc/parse-options.h" #include "sherpa-onnx/csrc/wave-writer.h" +void audioCallback(const float *samples, int32_t n, float progress) +{ + printf( "sample=%d, progress=%f\n", n, progress ); +} + int main(int32_t argc, char *argv[]) { const char *kUsageMessage = R"usage( Offline text-to-speech with sherpa-onnx @@ -74,7 +79,7 @@ or details. sherpa_onnx::OfflineTts tts(config); const auto begin = std::chrono::steady_clock::now(); - auto audio = tts.Generate(po.GetArg(1), sid); + auto audio = tts.Generate(po.GetArg(1), sid, 1.0, audioCallback); const auto end = std::chrono::steady_clock::now(); if (audio.samples.empty()) { diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index ee027c4c..5d874bc6 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -797,7 +797,7 @@ class SherpaOnnxOfflineTts { GeneratedAudio Generate( const std::string &text, int64_t sid = 0, float speed = 1.0, - std::function callback = nullptr) const { + std::function callback = nullptr) const { return tts_.Generate(text, sid, speed, callback); } @@ -1314,8 +1314,8 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl( const char *p_text = env->GetStringUTFChars(text, nullptr); SHERPA_ONNX_LOGE("string is: %s", p_text); - std::function callback_wrapper = - [env, callback](const float *samples, int32_t n) { + std::function callback_wrapper = + [env, callback](const float *samples, int32_t n, float /*p*/) { jclass cls = env->GetObjectClass(callback); jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V"); diff --git a/sherpa-onnx/python/csrc/offline-tts.cc b/sherpa-onnx/python/csrc/offline-tts.cc index 144001b0..82006330 100644 --- a/sherpa-onnx/python/csrc/offline-tts.cc +++ b/sherpa-onnx/python/csrc/offline-tts.cc @@ -55,14 +55,14 @@ void PybindOfflineTts(py::module *m) { .def( "generate", [](const PyClass &self, const std::string &text, int64_t sid, - float speed, std::function)> callback) + float speed, std::function, float)> callback) -> GeneratedAudio { if (!callback) { return self.Generate(text, sid, speed); } - std::function callback_wrapper = - [callback](const float *samples, int32_t n) { + std::function callback_wrapper = + [callback](const float *samples, int32_t n, float progress) { // CAUTION(fangjun): we have to copy samples since it is // freed once the call back returns. @@ -72,7 +72,7 @@ void PybindOfflineTts(py::module *m) { py::buffer_info buf = array.request(); auto p = static_cast(buf.ptr); std::copy(samples, samples + n, p); - callback(array); + callback(array, progress); }; return self.Generate(text, sid, speed, callback_wrapper);