Enable to stop TTS generation (#1041)

This commit is contained in:
Fangjun Kuang
2024-06-22 18:18:36 +08:00
committed by GitHub
parent 96ab843173
commit 9dd0e03568
32 changed files with 249 additions and 70 deletions

View File

@@ -935,7 +935,7 @@ int32_t SherpaOnnxOfflineTtsNumSpeakers(const SherpaOnnxOfflineTts *tts) {
static const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateInternal(
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
std::function<void(const float *, int32_t, float)> callback) {
std::function<int32_t(const float *, int32_t, float)> callback) {
sherpa_onnx::GeneratedAudio audio =
tts->impl->Generate(text, sid, speed, callback);
@@ -965,7 +965,9 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback(
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
SherpaOnnxGeneratedAudioCallback callback) {
auto wrapper = [callback](const float *samples, int32_t n,
float /*progress*/) { callback(samples, n); };
float /*progress*/) {
return callback(samples, n);
};
return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper);
}
@@ -975,7 +977,7 @@ SherpaOnnxOfflineTtsGenerateWithProgressCallback(
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
SherpaOnnxGeneratedAudioProgressCallback callback) {
auto wrapper = [callback](const float *samples, int32_t n, float progress) {
callback(samples, n, progress);
return callback(samples, n, progress);
};
return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper);
}
@@ -985,7 +987,7 @@ const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg(
SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) {
auto wrapper = [callback, arg](const float *samples, int32_t n,
float /*progress*/) {
callback(samples, n, arg);
return callback(samples, n, arg);
};
return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper);

View File

@@ -850,14 +850,17 @@ SHERPA_ONNX_API typedef struct SherpaOnnxGeneratedAudio {
int32_t sample_rate;
} SherpaOnnxGeneratedAudio;
typedef void (*SherpaOnnxGeneratedAudioCallback)(const float *samples,
int32_t n);
// If the callback returns 0, then it stops generating
// If the callback returns 1, then it keeps generating
typedef int32_t (*SherpaOnnxGeneratedAudioCallback)(const float *samples,
int32_t n);
typedef void (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples,
int32_t n, void *arg);
typedef int32_t (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples,
int32_t n,
void *arg);
typedef void (*SherpaOnnxGeneratedAudioProgressCallback)(const float *samples,
int32_t n, float p);
typedef int32_t (*SherpaOnnxGeneratedAudioProgressCallback)(
const float *samples, int32_t n, float p);
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts;

View File

@@ -216,9 +216,11 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
GeneratedAudio ans;
int32_t should_continue = 1;
int32_t k = 0;
for (int32_t b = 0; b != num_batches; ++b) {
for (int32_t b = 0; b != num_batches && should_continue; ++b) {
batch.clear();
for (int32_t i = 0; i != batch_size; ++i, ++k) {
batch.push_back(std::move(x[k]));
@@ -229,8 +231,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
audio.samples.end());
if (callback) {
callback(audio.samples.data(), audio.samples.size(),
b * 1.0 / num_batches);
should_continue = callback(audio.samples.data(), audio.samples.size(),
b * 1.0 / num_batches);
// Caution(fangjun): audio is freed when the callback returns, so users
// should copy the data if they want to access the data after
// the callback returns to avoid segmentation fault.
@@ -238,7 +240,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
}
batch.clear();
while (k < static_cast<int32_t>(x.size())) {
while (k < static_cast<int32_t>(x.size()) && should_continue) {
batch.push_back(std::move(x[k]));
++k;
}

View File

@@ -59,7 +59,9 @@ struct GeneratedAudio {
class OfflineTtsImpl;
using GeneratedAudioCallback = std::function<void(
// If the callback returns 0, then it stop generating
// if the callback returns 1, then it keeps generating
using GeneratedAudioCallback = std::function<int32_t(
const float * /*samples*/, int32_t /*n*/, float /*progress*/)>;
class OfflineTts {

View File

@@ -44,13 +44,20 @@ static void Handler(int32_t /*sig*/) {
fprintf(stderr, "\nCaught Ctrl + C. Exiting\n");
}
static void AudioGeneratedCallback(const float *s, int32_t n,
float /*progress*/) {
static int32_t AudioGeneratedCallback(const float *s, int32_t n,
float /*progress*/) {
if (n > 0) {
std::lock_guard<std::mutex> lock(g_buffer.mutex);
g_buffer.samples.push({s, s + n});
g_cv.notify_all();
}
if (g_killed) {
return 0; // stop generating
}
// continue generating
return 1;
}
static void StartPlayback(const std::string &device_name, int32_t sample_rate) {

View File

@@ -47,8 +47,8 @@ static void Handler(int32_t /*sig*/) {
fprintf(stderr, "\nCaught Ctrl + C. Exiting\n");
}
static void AudioGeneratedCallback(const float *s, int32_t n,
float /*progress*/) {
static int32_t AudioGeneratedCallback(const float *s, int32_t n,
float /*progress*/) {
if (n > 0) {
Samples samples;
samples.data = std::vector<float>{s, s + n};
@@ -57,6 +57,12 @@ static void AudioGeneratedCallback(const float *s, int32_t n,
g_buffer.samples.push(std::move(samples));
g_started = true;
}
if (g_killed) {
return 0; // stop generating
}
// continue generating
return 1;
}
static int PlayCallback(const void * /*in*/, void *out,

View File

@@ -9,8 +9,9 @@
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-writer.h"
void audioCallback(const float * /*samples*/, int32_t n, float progress) {
int32_t audioCallback(const float * /*samples*/, int32_t n, float progress) {
printf("sample=%d, progress=%f\n", n, progress);
return 1;
}
int main(int32_t argc, char *argv[]) {

View File

@@ -1,3 +1,7 @@
## 1.10.1
* Enable to stop TTS generation
## 1.10.0
* Add inverse text normalization

View File

@@ -326,7 +326,7 @@ typedef SherpaOnnxDestroyOfflineTtsGeneratedAudioNative = Void Function(
typedef SherpaOnnxDestroyOfflineTtsGeneratedAudio = void Function(
Pointer<SherpaOnnxGeneratedAudio>);
typedef SherpaOnnxGeneratedAudioCallbackNative = Void Function(
typedef SherpaOnnxGeneratedAudioCallbackNative = Int Function(
Pointer<Float>, Int32);
typedef SherpaOnnxOfflineTtsGenerateWithCallbackNative

View File

@@ -149,7 +149,7 @@ class OfflineTts {
{required String text,
int sid = 0,
double speed = 1.0,
required void Function(Float32List samples) callback}) {
required int Function(Float32List samples) callback}) {
// see
// https://github.com/dart-lang/sdk/issues/54276#issuecomment-1846109285
// https://stackoverflow.com/questions/69537440/callbacks-in-dart-dartffi-only-supports-calling-static-dart-functions-from-nat
@@ -159,8 +159,8 @@ class OfflineTts {
(Pointer<Float> samples, int n) {
final s = samples.asTypedList(n);
final newSamples = Float32List.fromList(s);
callback(newSamples);
});
return callback(newSamples);
}, exceptionalReturn: 0);
final Pointer<Utf8> textPtr = text.toNativeUtf8();
final p = SherpaOnnxBindings.offlineTtsGenerateWithCallback

View File

@@ -186,14 +186,42 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
const char *p_text = env->GetStringUTFChars(text, nullptr);
SHERPA_ONNX_LOGE("string is: %s", p_text);
std::function<void(const float *, int32_t, float)> callback_wrapper =
std::function<int32_t(const float *, int32_t, float)> callback_wrapper =
[env, callback](const float *samples, int32_t n, float /*progress*/) {
jclass cls = env->GetObjectClass(callback);
jmethodID mid = env->GetMethodID(cls, "invoke", "([F)V");
#if 0
// this block is for debugging only
// see also
// https://jnjosh.com/posts/kotlinfromcpp/
jmethodID classMethodId =
env->GetMethodID(cls, "getClass", "()Ljava/lang/Class;");
jobject klassObj = env->CallObjectMethod(callback, classMethodId);
auto klassObject = env->GetObjectClass(klassObj);
auto nameMethodId =
env->GetMethodID(klassObject, "getName", "()Ljava/lang/String;");
jstring classString =
(jstring)env->CallObjectMethod(klassObj, nameMethodId);
auto className = env->GetStringUTFChars(classString, NULL);
SHERPA_ONNX_LOGE("name is: %s", className);
env->ReleaseStringUTFChars(classString, className);
#endif
jmethodID mid =
env->GetMethodID(cls, "invoke", "([F)Ljava/lang/Integer;");
if (mid == nullptr) {
SHERPA_ONNX_LOGE("Failed to get the callback. Ignore it.");
return 1;
}
jfloatArray samples_arr = env->NewFloatArray(n);
env->SetFloatArrayRegion(samples_arr, 0, n, samples);
env->CallVoidMethod(callback, mid, samples_arr);
jobject should_continue =
env->CallObjectMethod(callback, mid, samples_arr);
jclass jklass = env->GetObjectClass(should_continue);
jmethodID int_value_mid = env->GetMethodID(jklass, "intValue", "()I");
return env->CallIntMethod(should_continue, int_value_mid);
};
auto audio = reinterpret_cast<sherpa_onnx::OfflineTts *>(ptr)->Generate(

View File

@@ -57,13 +57,13 @@ void PybindOfflineTts(py::module *m) {
"generate",
[](const PyClass &self, const std::string &text, int64_t sid,
float speed,
std::function<void(py::array_t<float>, float)> callback)
std::function<int32_t(py::array_t<float>, float)> callback)
-> GeneratedAudio {
if (!callback) {
return self.Generate(text, sid, speed);
}
std::function<void(const float *, int32_t, float)>
std::function<int32_t(const float *, int32_t, float)>
callback_wrapper = [callback](const float *samples, int32_t n,
float progress) {
// CAUTION(fangjun): we have to copy samples since it is
@@ -75,7 +75,7 @@ void PybindOfflineTts(py::module *m) {
py::buffer_info buf = array.request();
auto p = static_cast<float *>(buf.ptr);
std::copy(samples, samples + n, p);
callback(array, progress);
return callback(array, progress);
};
return self.Generate(text, sid, speed, callback_wrapper);