Enable to stop TTS generation (#1041)

This commit is contained in:
Fangjun Kuang
2024-06-22 18:18:36 +08:00
committed by GitHub
parent 96ab843173
commit 9dd0e03568
32 changed files with 249 additions and 70 deletions

View File

@@ -216,9 +216,11 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
GeneratedAudio ans;
int32_t should_continue = 1;
int32_t k = 0;
for (int32_t b = 0; b != num_batches; ++b) {
for (int32_t b = 0; b != num_batches && should_continue; ++b) {
batch.clear();
for (int32_t i = 0; i != batch_size; ++i, ++k) {
batch.push_back(std::move(x[k]));
@@ -229,8 +231,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
audio.samples.end());
if (callback) {
callback(audio.samples.data(), audio.samples.size(),
b * 1.0 / num_batches);
should_continue = callback(audio.samples.data(), audio.samples.size(),
b * 1.0 / num_batches);
// Caution(fangjun): audio is freed when the callback returns, so users
// should copy the data if they want to access the data after
// the callback returns to avoid segmentation fault.
@@ -238,7 +240,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
}
batch.clear();
while (k < static_cast<int32_t>(x.size())) {
while (k < static_cast<int32_t>(x.size()) && should_continue) {
batch.push_back(std::move(x[k]));
++k;
}

View File

@@ -59,7 +59,9 @@ struct GeneratedAudio {
class OfflineTtsImpl;
using GeneratedAudioCallback = std::function<void(
// Callback invoked with a chunk of generated samples, the number of
// samples in it, and the generation progress.
//
// If the callback returns 0, generation stops;
// if the callback returns 1, generation continues.
//
// Note: the samples buffer is freed once the callback returns, so copy the
// data inside the callback if it is needed afterwards.
using GeneratedAudioCallback = std::function<int32_t(
const float * /*samples*/, int32_t /*n*/, float /*progress*/)>;
class OfflineTts {

View File

@@ -44,13 +44,20 @@ static void Handler(int32_t /*sig*/) {
fprintf(stderr, "\nCaught Ctrl + C. Exiting\n");
}
static void AudioGeneratedCallback(const float *s, int32_t n,
float /*progress*/) {
// Receives a chunk of generated samples. A non-empty chunk [s, s + n) is
// pushed onto g_buffer.samples while holding g_buffer.mutex, and g_cv is
// notified afterwards.
//
// Returns 0 (stop generating) when g_killed is set, otherwise 1 (continue
// generating).
static int32_t AudioGeneratedCallback(const float *s, int32_t n,
                                      float /*progress*/) {
  const bool has_samples = n > 0;
  if (has_samples) {
    std::lock_guard<std::mutex> guard(g_buffer.mutex);
    g_buffer.samples.push({s, s + n});
    g_cv.notify_all();
  }

  // 0 -> stop generating, 1 -> continue generating
  return g_killed ? 0 : 1;
}
static void StartPlayback(const std::string &device_name, int32_t sample_rate) {

View File

@@ -47,8 +47,8 @@ static void Handler(int32_t /*sig*/) {
fprintf(stderr, "\nCaught Ctrl + C. Exiting\n");
}
static void AudioGeneratedCallback(const float *s, int32_t n,
float /*progress*/) {
static int32_t AudioGeneratedCallback(const float *s, int32_t n,
float /*progress*/) {
if (n > 0) {
Samples samples;
samples.data = std::vector<float>{s, s + n};
@@ -57,6 +57,12 @@ static void AudioGeneratedCallback(const float *s, int32_t n,
g_buffer.samples.push(std::move(samples));
g_started = true;
}
if (g_killed) {
return 0; // stop generating
}
// continue generating
return 1;
}
static int PlayCallback(const void * /*in*/, void *out,

View File

@@ -9,8 +9,9 @@
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-writer.h"
void audioCallback(const float * /*samples*/, int32_t n, float progress) {
// Logs the size of the generated chunk and the current progress to stdout.
// Always returns 1; it never returns 0.
int32_t audioCallback(const float * /*samples*/, int32_t n, float progress) {
  constexpr int32_t kContinue = 1;
  fprintf(stdout, "sample=%d, progress=%f\n", n, progress);
  return kContinue;
}
int main(int32_t argc, char *argv[]) {