Enable to stop TTS generation (#1041)
This commit is contained in:
@@ -216,9 +216,11 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
|
||||
GeneratedAudio ans;
|
||||
|
||||
int32_t should_continue = 1;
|
||||
|
||||
int32_t k = 0;
|
||||
|
||||
for (int32_t b = 0; b != num_batches; ++b) {
|
||||
for (int32_t b = 0; b != num_batches && should_continue; ++b) {
|
||||
batch.clear();
|
||||
for (int32_t i = 0; i != batch_size; ++i, ++k) {
|
||||
batch.push_back(std::move(x[k]));
|
||||
@@ -229,8 +231,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
|
||||
audio.samples.end());
|
||||
if (callback) {
|
||||
callback(audio.samples.data(), audio.samples.size(),
|
||||
b * 1.0 / num_batches);
|
||||
should_continue = callback(audio.samples.data(), audio.samples.size(),
|
||||
b * 1.0 / num_batches);
|
||||
// Caution(fangjun): audio is freed when the callback returns, so users
|
||||
// should copy the data if they want to access the data after
|
||||
// the callback returns to avoid segmentation fault.
|
||||
@@ -238,7 +240,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
}
|
||||
|
||||
batch.clear();
|
||||
while (k < static_cast<int32_t>(x.size())) {
|
||||
while (k < static_cast<int32_t>(x.size()) && should_continue) {
|
||||
batch.push_back(std::move(x[k]));
|
||||
++k;
|
||||
}
|
||||
|
||||
@@ -59,7 +59,9 @@ struct GeneratedAudio {
|
||||
|
||||
class OfflineTtsImpl;
|
||||
|
||||
using GeneratedAudioCallback = std::function<void(
|
||||
// If the callback returns 0, then it stop generating
|
||||
// if the callback returns 1, then it keeps generating
|
||||
using GeneratedAudioCallback = std::function<int32_t(
|
||||
const float * /*samples*/, int32_t /*n*/, float /*progress*/)>;
|
||||
|
||||
class OfflineTts {
|
||||
|
||||
@@ -44,13 +44,20 @@ static void Handler(int32_t /*sig*/) {
|
||||
fprintf(stderr, "\nCaught Ctrl + C. Exiting\n");
|
||||
}
|
||||
|
||||
static void AudioGeneratedCallback(const float *s, int32_t n,
|
||||
float /*progress*/) {
|
||||
static int32_t AudioGeneratedCallback(const float *s, int32_t n,
|
||||
float /*progress*/) {
|
||||
if (n > 0) {
|
||||
std::lock_guard<std::mutex> lock(g_buffer.mutex);
|
||||
g_buffer.samples.push({s, s + n});
|
||||
g_cv.notify_all();
|
||||
}
|
||||
|
||||
if (g_killed) {
|
||||
return 0; // stop generating
|
||||
}
|
||||
|
||||
// continue generating
|
||||
return 1;
|
||||
}
|
||||
|
||||
static void StartPlayback(const std::string &device_name, int32_t sample_rate) {
|
||||
|
||||
@@ -47,8 +47,8 @@ static void Handler(int32_t /*sig*/) {
|
||||
fprintf(stderr, "\nCaught Ctrl + C. Exiting\n");
|
||||
}
|
||||
|
||||
static void AudioGeneratedCallback(const float *s, int32_t n,
|
||||
float /*progress*/) {
|
||||
static int32_t AudioGeneratedCallback(const float *s, int32_t n,
|
||||
float /*progress*/) {
|
||||
if (n > 0) {
|
||||
Samples samples;
|
||||
samples.data = std::vector<float>{s, s + n};
|
||||
@@ -57,6 +57,12 @@ static void AudioGeneratedCallback(const float *s, int32_t n,
|
||||
g_buffer.samples.push(std::move(samples));
|
||||
g_started = true;
|
||||
}
|
||||
if (g_killed) {
|
||||
return 0; // stop generating
|
||||
}
|
||||
|
||||
// continue generating
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int PlayCallback(const void * /*in*/, void *out,
|
||||
|
||||
@@ -9,8 +9,9 @@
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
#include "sherpa-onnx/csrc/wave-writer.h"
|
||||
|
||||
void audioCallback(const float * /*samples*/, int32_t n, float progress) {
|
||||
int32_t audioCallback(const float * /*samples*/, int32_t n, float progress) {
|
||||
printf("sample=%d, progress=%f\n", n, progress);
|
||||
return 1;
|
||||
}
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
|
||||
Reference in New Issue
Block a user