Add on-device text-to-speech (TTS) demo for HarmonyOS (#1590)

This commit is contained in:
Fangjun Kuang
2024-12-04 14:27:12 +08:00
committed by GitHub
parent 47a2dd4cf8
commit 74a8735f7a
61 changed files with 1930 additions and 117 deletions

View File

@@ -1,4 +1,8 @@
export { readWave, readWaveFromBinary } from "libsherpa_onnx.so";
export {
listRawfileDir,
readWave,
readWaveFromBinary,
} from "libsherpa_onnx.so";
export {
CircularBuffer,

View File

@@ -4,7 +4,7 @@
"externalNativeOptions": {
"path": "./src/main/cpp/CMakeLists.txt",
"arguments": "",
"cppFlags": "",
"cppFlags": "-std=c++17",
"abiFilters": [
"arm64-v8a",
"x86_64",

View File

@@ -2,6 +2,10 @@
cmake_minimum_required(VERSION 3.13.0)
project(myNpmLib)
if (NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ version to use")
endif()
# Disable warning about
#
# "The DOWNLOAD_EXTRACT_TIMESTAMP option was not given and policy CMP0135 is
@@ -46,6 +50,7 @@ add_library(sherpa_onnx SHARED
speaker-identification.cc
spoken-language-identification.cc
streaming-asr.cc
utils.cc
vad.cc
wave-reader.cc
wave-writer.cc

View File

@@ -213,12 +213,13 @@ static Napi::Number OfflineTtsNumSpeakersWrapper(
return Napi::Number::New(env, num_speakers);
}
// synchronous version
static Napi::Object OfflineTtsGenerateWrapper(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
if (info.Length() != 2) {
std::ostringstream os;
os << "Expect only 1 argument. Given: " << info.Length();
os << "Expect only 2 arguments. Given: " << info.Length();
Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();
@@ -298,8 +299,8 @@ static Napi::Object OfflineTtsGenerateWrapper(const Napi::CallbackInfo &info) {
int32_t sid = obj.Get("sid").As<Napi::Number>().Int32Value();
float speed = obj.Get("speed").As<Napi::Number>().FloatValue();
const SherpaOnnxGeneratedAudio *audio =
SherpaOnnxOfflineTtsGenerate(tts, text.c_str(), sid, speed);
const SherpaOnnxGeneratedAudio *audio;
audio = SherpaOnnxOfflineTtsGenerate(tts, text.c_str(), sid, speed);
if (enable_external_buffer) {
Napi::ArrayBuffer arrayBuffer = Napi::ArrayBuffer::New(
@@ -334,6 +335,256 @@ static Napi::Object OfflineTtsGenerateWrapper(const Napi::CallbackInfo &info) {
}
}
// Payload for one chunk of generated audio that is handed to the JavaScript
// progress callback.
//
// Instances are allocated in TtsGenerateWorker::Execute, passed to
// InvokeJsCallback through the thread-safe function, and deleted by
// ~TtsGenerateWorker.
struct TtsCallbackData {
  // Copy of the samples belonging to this chunk.
  std::vector<float> samples;

  // Progress value reported by the synthesis callback. Initialized to zero so
  // that a default-constructed instance never carries an indeterminate value.
  float progress = 0.0f;

  // Set to true by InvokeJsCallback after the JavaScript callback has run.
  bool processed = false;

  // Set by InvokeJsCallback: true when the JavaScript callback did not return
  // a non-zero number. The synthesis callback checks it before queuing more
  // chunks.
  bool cancelled = false;
};
// see
// https://github.com/nodejs/node-addon-examples/blob/main/src/6-threadsafe-function/typed_threadsafe_function/node-addon-api/clock.cc
// Call-JS handler of the TSFN type: invoked for each queued TtsCallbackData.
//
// Copies the chunk into a freshly allocated Float32Array, calls the user
// callback with an object {samples, progress}, and then records on `data`
// whether the callback ran and whether generation should be cancelled.
//
// @param env      N-API environment; nothing is done when it is null.
// @param callback JavaScript function to call; nothing is done when it is null.
// @param context  Persistent reference whose value is used as `this`.
// @param data     Chunk to deliver; `processed` and `cancelled` are updated.
void InvokeJsCallback(Napi::Env env, Napi::Function callback,
                      Napi::Reference<Napi::Value> *context,
                      TtsCallbackData *data) {
  if (env == nullptr || callback == nullptr) {
    return;
  }

  const auto num_samples = data->samples.size();

  Napi::ArrayBuffer buffer =
      Napi::ArrayBuffer::New(env, sizeof(float) * num_samples);
  Napi::Float32Array chunk =
      Napi::Float32Array::New(env, num_samples, buffer, 0);
  std::copy(data->samples.begin(), data->samples.end(), chunk.Data());

  Napi::Object payload = Napi::Object::New(env);
  payload.Set(Napi::String::New(env, "samples"), chunk);
  payload.Set(Napi::String::New(env, "progress"), data->progress);

  auto ret = callback.Call(context->Value(), {payload});
  data->processed = true;

  // Only a non-zero number means "keep going"; anything else cancels.
  const bool keep_going =
      ret.IsNumber() && ret.As<Napi::Number>().Int32Value();
  data->cancelled = !keep_going;
}
// Thread-safe function type used to hand TtsCallbackData chunks to
// InvokeJsCallback. Its context type is a reference to a JavaScript value,
// which InvokeJsCallback uses as the receiver of the user callback.
using TSFN = Napi::TypedThreadSafeFunction<Napi::Reference<Napi::Value>,
TtsCallbackData, InvokeJsCallback>;
class TtsGenerateWorker : public Napi::AsyncWorker {
public:
TtsGenerateWorker(const Napi::Env &env, TSFN tsfn, SherpaOnnxOfflineTts *tts,
const std::string &text, float speed, int32_t sid,
bool use_external_buffer)
: tsfn_(tsfn),
Napi::AsyncWorker{env, "TtsGenerateWorker"},
deferred_(env),
tts_(tts),
text_(text),
speed_(speed),
sid_(sid),
use_external_buffer_(use_external_buffer) {}
Napi::Promise Promise() { return deferred_.Promise(); }
~TtsGenerateWorker() {
for (auto d : data_list_) {
delete d;
}
}
protected:
void Execute() override {
auto callback = [](const float *samples, int32_t n, float progress,
void *arg) -> int32_t {
TtsGenerateWorker *_this = reinterpret_cast<TtsGenerateWorker *>(arg);
for (auto d : _this->data_list_) {
if (d->cancelled) {
OH_LOG_INFO(LOG_APP, "TtsGenerate is cancelled");
return 0;
}
}
auto data = new TtsCallbackData;
data->samples = std::vector<float>{samples, samples + n};
data->progress = progress;
_this->data_list_.push_back(data);
_this->tsfn_.NonBlockingCall(data);
return 1;
};
audio_ = SherpaOnnxOfflineTtsGenerateWithProgressCallbackWithArg(
tts_, text_.c_str(), sid_, speed_, callback, this);
tsfn_.Release();
}
void OnOK() override {
Napi::Env env = deferred_.Env();
Napi::Object ans = Napi::Object::New(env);
if (use_external_buffer_) {
Napi::ArrayBuffer arrayBuffer = Napi::ArrayBuffer::New(
env, const_cast<float *>(audio_->samples), sizeof(float) * audio_->n,
[](Napi::Env /*env*/, void * /*data*/,
const SherpaOnnxGeneratedAudio *hint) {
SherpaOnnxDestroyOfflineTtsGeneratedAudio(hint);
},
audio_);
Napi::Float32Array float32Array =
Napi::Float32Array::New(env, audio_->n, arrayBuffer, 0);
ans.Set(Napi::String::New(env, "samples"), float32Array);
ans.Set(Napi::String::New(env, "sampleRate"), audio_->sample_rate);
} else {
// don't use external buffer
Napi::ArrayBuffer arrayBuffer =
Napi::ArrayBuffer::New(env, sizeof(float) * audio_->n);
Napi::Float32Array float32Array =
Napi::Float32Array::New(env, audio_->n, arrayBuffer, 0);
std::copy(audio_->samples, audio_->samples + audio_->n,
float32Array.Data());
ans.Set(Napi::String::New(env, "samples"), float32Array);
ans.Set(Napi::String::New(env, "sampleRate"), audio_->sample_rate);
SherpaOnnxDestroyOfflineTtsGeneratedAudio(audio_);
}
deferred_.Resolve(ans);
}
private:
TSFN tsfn_;
Napi::Promise::Deferred deferred_;
SherpaOnnxOfflineTts *tts_;
std::string text_;
float speed_;
int32_t sid_;
bool use_external_buffer_;
const SherpaOnnxGeneratedAudio *audio_;
std::vector<TtsCallbackData *> data_list_;
};
// JavaScript entry point for asynchronous TTS generation.
//
// Expects exactly two arguments:
//   0: an external wrapping a SherpaOnnxOfflineTts pointer
//   1: an object with the fields
//        text  (string, required)
//        sid   (number, required)
//        speed (number, required)
//        enableExternalBuffer (boolean, optional, defaults to true)
//        callback (function, optional) -- receives each generated chunk
//
// On invalid input a TypeError is thrown and an empty value is returned.
// Otherwise a TtsGenerateWorker is queued and its promise is returned.
static Napi::Object OfflineTtsGenerateAsyncWrapper(
    const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();

  if (info.Length() != 2) {
    std::ostringstream os;
    os << "Expect only 2 arguments. Given: " << info.Length();

    Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();

    return {};
  }

  if (!info[0].IsExternal()) {
    Napi::TypeError::New(env, "Argument 0 should be an offline tts pointer.")
        .ThrowAsJavaScriptException();

    return {};
  }

  SherpaOnnxOfflineTts *tts =
      info[0].As<Napi::External<SherpaOnnxOfflineTts>>().Data();

  if (!info[1].IsObject()) {
    Napi::TypeError::New(env, "Argument 1 should be an object")
        .ThrowAsJavaScriptException();

    return {};
  }

  Napi::Object obj = info[1].As<Napi::Object>();

  if (!obj.Has("text")) {
    Napi::TypeError::New(env, "The argument object should have a field text")
        .ThrowAsJavaScriptException();

    return {};
  }

  if (!obj.Get("text").IsString()) {
    Napi::TypeError::New(env, "The object['text'] should be a string")
        .ThrowAsJavaScriptException();

    return {};
  }

  // `sid` and `speed` share the same validation; checked in this order so the
  // first reported error matches the field order above.
  static const char *const kNumberFields[] = {"sid", "speed"};
  for (const char *field : kNumberFields) {
    if (!obj.Has(field)) {
      Napi::TypeError::New(
          env, std::string("The argument object should have a field ") + field)
          .ThrowAsJavaScriptException();

      return {};
    }

    if (!obj.Get(field).IsNumber()) {
      Napi::TypeError::New(
          env, std::string("The object['") + field + "'] should be a number")
          .ThrowAsJavaScriptException();

      return {};
    }
  }

  bool enable_external_buffer = true;
  if (obj.Has("enableExternalBuffer") &&
      obj.Get("enableExternalBuffer").IsBoolean()) {
    enable_external_buffer =
        obj.Get("enableExternalBuffer").As<Napi::Boolean>().Value();
  }

  std::string text = obj.Get("text").As<Napi::String>().Utf8Value();
  int32_t sid = obj.Get("sid").As<Napi::Number>().Int32Value();
  float speed = obj.Get("speed").As<Napi::Number>().FloatValue();

  // Left empty when the caller did not supply a callback function.
  Napi::Function cb;
  if (obj.Has("callback") && obj.Get("callback").IsFunction()) {
    cb = obj.Get("callback").As<Napi::Function>();
  }

  // Heap-allocated so it can outlive this call; deleted by the finalizer
  // passed to TSFN::New below.
  auto context =
      new Napi::Reference<Napi::Value>(Napi::Persistent(info.This()));

  TSFN tsfn = TSFN::New(
      env,
      cb,                 // JavaScript function called asynchronously
      "TtsGenerateFunc",  // Name
      0,                  // Unlimited queue
      1,                  // Only one thread will use this initially
      context,
      [](Napi::Env, void *, Napi::Reference<Napi::Value> *ctx) { delete ctx; });

  TtsGenerateWorker *worker = new TtsGenerateWorker(
      env, tsfn, tts, text, speed, sid, enable_external_buffer);
  worker->Queue();
  return worker->Promise();
}
void InitNonStreamingTts(Napi::Env env, Napi::Object exports) {
exports.Set(Napi::String::New(env, "createOfflineTts"),
Napi::Function::New(env, CreateOfflineTtsWrapper));
@@ -346,4 +597,7 @@ void InitNonStreamingTts(Napi::Env env, Napi::Object exports) {
exports.Set(Napi::String::New(env, "offlineTtsGenerate"),
Napi::Function::New(env, OfflineTtsGenerateWrapper));
exports.Set(Napi::String::New(env, "offlineTtsGenerateAsync"),
Napi::Function::New(env, OfflineTtsGenerateAsyncWrapper));
}

View File

@@ -27,6 +27,10 @@ void InitKeywordSpotting(Napi::Env env, Napi::Object exports);
void InitNonStreamingSpeakerDiarization(Napi::Env env, Napi::Object exports);
#if __OHOS__
void InitUtils(Napi::Env env, Napi::Object exports);
#endif
Napi::Object Init(Napi::Env env, Napi::Object exports) {
InitStreamingAsr(env, exports);
InitNonStreamingAsr(env, exports);
@@ -41,7 +45,15 @@ Napi::Object Init(Napi::Env env, Napi::Object exports) {
InitKeywordSpotting(env, exports);
InitNonStreamingSpeakerDiarization(env, exports);
#if __OHOS__
InitUtils(env, exports);
#endif
return exports;
}
#if __OHOS__
NODE_API_MODULE(sherpa_onnx, Init)
#else
NODE_API_MODULE(addon, Init)
#endif

View File

@@ -1,3 +1,5 @@
export const listRawfileDir: (mgr: object, dir: string) => Array<string>;
// Function types cannot carry parameter initializers (TS2371), so the
// previously written `= true` is expressed as an optional parameter instead.
// Callers may still pass or omit `enableExternalBuffer` exactly as before.
export const readWave: (filename: string, enableExternalBuffer?: boolean) => {samples: Float32Array, sampleRate: number};
export const readWaveFromBinary: (data: Uint8Array, enableExternalBuffer?: boolean) => {samples: Float32Array, sampleRate: number};
export const createCircularBuffer: (capacity: number) => object;
@@ -37,4 +39,11 @@ export const getOnlineStreamResultAsJson: (handle: object, streamHandle: object)
export const createOfflineTts: (config: object, mgr?: object) => object;
export const getOfflineTtsNumSpeakers: (handle: object) => number;
export const getOfflineTtsSampleRate: (handle: object) => number;
export const offlineTtsGenerate: (handle: object, input: object) => object;
/**
 * Result object of the TTS generate bindings: returned by
 * `offlineTtsGenerate` and used as the resolved value of
 * `offlineTtsGenerateAsync`.
 */
export type TtsOutput = {
// Generated audio samples.
samples: Float32Array;
// Sample rate reported together with the generated samples.
sampleRate: number;
};
export const offlineTtsGenerate: (handle: object, input: object) => TtsOutput;
export const offlineTtsGenerateAsync: (handle: object, input: object) => Promise<TtsOutput>;

View File

@@ -0,0 +1,76 @@
// Copyright (c) 2024 Xiaomi Corporation
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "macros.h" // NOLINT
#include "napi.h" // NOLINT
// Recursively collects the paths of all non-directory entries below the raw
// directory `d`.
//
// @param mgr Native resource manager used to open and query raw directories.
// @param d   Directory to list. When it is empty, entry names are returned
//            as-is; otherwise every returned path is prefixed with `d + "/"`.
// @return The collected paths. Empty when `mgr` is null, the directory cannot
//         be opened, or it has no entries.
static std::vector<std::string> GetFilenames(NativeResourceManager *mgr,
                                             const std::string &d) {
  std::vector<std::string> ans;

  if (mgr == nullptr) {
    return ans;
  }

  std::unique_ptr<RawDir, decltype(&OH_ResourceManager_CloseRawDir)> raw_dir(
      OH_ResourceManager_OpenRawDir(mgr, d.c_str()),
      &OH_ResourceManager_CloseRawDir);

  // Opening can fail (e.g. for a non-existent directory); do not query a
  // null handle.
  if (!raw_dir) {
    return ans;
  }

  int count = OH_ResourceManager_GetRawFileCount(raw_dir.get());
  if (count <= 0) {
    // Also guards reserve() against a negative count.
    return ans;
  }

  ans.reserve(count);

  for (int32_t i = 0; i < count; ++i) {
    const char *name = OH_ResourceManager_GetRawFileName(raw_dir.get(), i);
    if (name == nullptr) {
      // Constructing std::string from a null pointer is undefined behaviour.
      continue;
    }

    // Build the entry's path once and reuse it for the directory test, the
    // recursion and the result.
    std::string path = d.empty() ? std::string(name) : d + "/" + name;

    if (OH_ResourceManager_IsRawDir(mgr, path.c_str())) {
      auto files = GetFilenames(mgr, path);
      for (auto &f : files) {
        ans.push_back(std::move(f));
      }
    } else {
      ans.push_back(std::move(path));
    }
  }

  return ans;
}
// JavaScript entry point that lists all raw files below a directory.
//
// Expects exactly two arguments:
//   0: the JavaScript resource manager object
//   1: the directory to list (string)
//
// Returns an array of path strings produced by GetFilenames(). On invalid
// input a TypeError is thrown and an empty value is returned.
static Napi::Array ListRawFileDir(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();

  if (info.Length() != 2) {
    std::ostringstream os;
    os << "Expect only 2 arguments. Given: " << info.Length();

    Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();

    return {};
  }

  // Validate the cheap argument first so that no native resource manager is
  // created for a call that is going to fail anyway.
  if (!info[1].IsString()) {
    Napi::TypeError::New(env, "Argument 1 should be a string")
        .ThrowAsJavaScriptException();

    return {};
  }

  std::unique_ptr<NativeResourceManager,
                  decltype(&OH_ResourceManager_ReleaseNativeResourceManager)>
      mgr(OH_ResourceManager_InitNativeResourceManager(env, info[0]),
          &OH_ResourceManager_ReleaseNativeResourceManager);

  if (!mgr) {
    Napi::TypeError::New(env, "Argument 0 should be a resource manager")
        .ThrowAsJavaScriptException();

    return {};
  }

  std::string dir = info[1].As<Napi::String>().Utf8Value();

  auto files = GetFilenames(mgr.get(), dir);

  Napi::Array ans = Napi::Array::New(env, files.size());
  // Unsigned index: matches the array index type and avoids a signed/unsigned
  // comparison against size().
  for (uint32_t i = 0; i != files.size(); ++i) {
    ans[i] = Napi::String::New(env, files[i]);
  }

  return ans;
}
// Registers this file's utility function on `exports` under the name
// "listRawfileDir".
void InitUtils(Napi::Env env, Napi::Object exports) {
  Napi::String key = Napi::String::New(env, "listRawfileDir");
  Napi::Function fn = Napi::Function::New(env, ListRawFileDir);

  exports.Set(key, fn);
}

View File

@@ -3,6 +3,7 @@ import {
getOfflineTtsNumSpeakers,
getOfflineTtsSampleRate,
offlineTtsGenerate,
offlineTtsGenerateAsync,
} from "libsherpa_onnx.so";
export class OfflineTtsVitsModelConfig {
@@ -16,14 +17,14 @@ export class OfflineTtsVitsModelConfig {
public lengthScale: number = 1.0;
}
export class OfflineTtsModelConfig{
export class OfflineTtsModelConfig {
public vits: OfflineTtsVitsModelConfig = new OfflineTtsVitsModelConfig();
public numThreads: number = 1;
public debug: boolean = false;
public provider: string = 'cpu';
}
export class OfflineTtsConfig{
export class OfflineTtsConfig {
public model: OfflineTtsModelConfig = new OfflineTtsModelConfig();
public ruleFsts: string = '';
public ruleFars: string = '';
@@ -35,17 +36,24 @@ export class TtsOutput {
public sampleRate: number = 0;
}
/**
 * Argument passed to `TtsInput.callback`.
 */
interface TtsCallbackData {
// Audio samples delivered with this callback invocation.
samples: Float32Array;
// Progress value delivered with this callback invocation.
progress: number;
}
/**
 * Request object passed to `OfflineTts.generate` and
 * `OfflineTts.generateAsync`, which forward it unchanged to the native
 * bindings.
 */
export class TtsInput {
// Text to synthesize.
public text: string = '';
// Speaker id.
public sid: number = 0;
// Speed; defaults to 1.0.
public speed: number = 1.0;
// Optional callback that receives generated data together with a progress
// value. It is declared to return a number.
// NOTE(review): the native async wrapper appears to treat a non-zero number as
// "continue" and any other return value as "cancel" -- confirm.
public callback?: (data: TtsCallbackData) => number;
}
export class OfflineTts {
private handle: object;
public config: OfflineTtsConfig;
public numSpeakers: number;
public sampleRate: number;
private handle: object;
constructor(config: OfflineTtsConfig, mgr?: object) {
this.handle = createOfflineTts(config, mgr);
this.config = config;
@@ -63,4 +71,8 @@ export class OfflineTts {
/**
 * Passes `input` and this instance's native handle to `offlineTtsGenerate`
 * and returns the result as a `TtsOutput`.
 */
generate(input: TtsInput): TtsOutput {
  const output = offlineTtsGenerate(this.handle, input) as TtsOutput;
  return output;
}
/**
 * Passes `input` and this instance's native handle to
 * `offlineTtsGenerateAsync` and returns the promise it produces.
 */
generateAsync(input: TtsInput): Promise<TtsOutput> {
  const pending = offlineTtsGenerateAsync(this.handle, input);
  return pending;
}
}

View File

@@ -57,7 +57,6 @@ export class CircularBuffer {
/**
 * Forwards `samples` to the native `circularBufferPush` binding for this
 * buffer's handle.
 *
 * No logging happens here: interpolating a Float32Array into a template
 * string serializes every sample, which is far too costly per push.
 *
 * @param samples Samples to push, as a Float32Array.
 */
push(samples: Float32Array) {
  circularBufferPush(this.handle, samples);
}