Added progress for callback of tts generator (#712)

Co-authored-by: leohwang <leohwang@360converter.com>
This commit is contained in:
Leo Huang
2024-03-28 17:12:20 +08:00
committed by GitHub
parent de655e838e
commit 638f48f47a
8 changed files with 51 additions and 40 deletions

View File

@@ -55,14 +55,14 @@ void PybindOfflineTts(py::module *m) {
.def(
"generate",
[](const PyClass &self, const std::string &text, int64_t sid,
float speed, std::function<void(py::array_t<float>)> callback)
float speed, std::function<void(py::array_t<float>, float)> callback)
-> GeneratedAudio {
if (!callback) {
return self.Generate(text, sid, speed);
}
std::function<void(const float *, int32_t)> callback_wrapper =
[callback](const float *samples, int32_t n) {
std::function<void(const float *, int32_t, float)> callback_wrapper =
[callback](const float *samples, int32_t n, float progress) {
// CAUTION(fangjun): we have to copy samples since it is
// freed once the call back returns.
@@ -72,7 +72,7 @@ void PybindOfflineTts(py::module *m) {
py::buffer_info buf = array.request();
auto p = static_cast<float *>(buf.ptr);
std::copy(samples, samples + n, p);
callback(array);
callback(array, progress);
};
return self.Generate(text, sid, speed, callback_wrapper);