Added progress for callback of tts generator (#712)

Co-authored-by: leohwang <leohwang@360converter.com>
This commit is contained in:
Leo Huang
2024-03-28 17:12:20 +08:00
committed by GitHub
parent de655e838e
commit 638f48f47a
8 changed files with 51 additions and 40 deletions

View File

@@ -134,7 +134,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) {
auto ans = Process(x, sid, speed);
if (callback) {
callback(ans.samples.data(), ans.samples.size());
callback(ans.samples.data(), ans.samples.size(), 1.0);
}
return ans;
}
@@ -168,7 +168,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
audio.samples.end());
if (callback) {
callback(audio.samples.data(), audio.samples.size());
callback(audio.samples.data(), audio.samples.size(), b * 1.0 / num_batches);
// Caution(fangjun): audio is freed when the callback returns, so users
// should copy the data if they want to access the data after
// the callback returns to avoid segmentation fault.
@@ -187,7 +187,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
audio.samples.end());
if (callback) {
callback(audio.samples.data(), audio.samples.size());
callback(audio.samples.data(), audio.samples.size(), 1.0);
// Caution(fangjun): audio is freed when the callback returns, so users
// should copy the data if they want to access the data after
// the callback returns to avoid segmentation fault.

View File

@@ -55,7 +55,7 @@ struct GeneratedAudio {
class OfflineTtsImpl;
using GeneratedAudioCallback =
std::function<void(const float * /*samples*/, int32_t /*n*/)>;
std::function<void(const float * /*samples*/, int32_t /*n*/, float /*progress*/)>;
class OfflineTts {
public:

View File

@@ -47,7 +47,7 @@ static void Handler(int32_t /*sig*/) {
fprintf(stderr, "\nCaught Ctrl + C. Exiting\n");
}
static void AudioGeneratedCallback(const float *s, int32_t n) {
static void AudioGeneratedCallback(const float *s, int32_t n, float /*progress*/) {
if (n > 0) {
Samples samples;
samples.data = std::vector<float>{s, s + n};

View File

@@ -9,6 +9,11 @@
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-writer.h"
// Callback that prints one line to stdout per invocation, showing the number
// of samples it was given and the progress value it was given.
//
// @param samples  Pointer to the generated samples; not read by this function.
// @param n        Sample count reported by the caller for this invocation.
// @param progress Progress value reported by the caller; printed with %f.
void audioCallback(const float *samples, int32_t n, float progress) {
  // Make the conversions that the variadic call performs explicit.
  const int num_samples = static_cast<int>(n);
  const double reported_progress = static_cast<double>(progress);
  printf("sample=%d, progress=%f\n", num_samples, reported_progress);
}
int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage(
Offline text-to-speech with sherpa-onnx
@@ -74,7 +79,7 @@ or details.
sherpa_onnx::OfflineTts tts(config);
const auto begin = std::chrono::steady_clock::now();
auto audio = tts.Generate(po.GetArg(1), sid);
auto audio = tts.Generate(po.GetArg(1), sid, 1.0, audioCallback);
const auto end = std::chrono::steady_clock::now();
if (audio.samples.empty()) {