Add streaming CTC ASR APIs for node-addon-api (#867)

This commit is contained in:
Fangjun Kuang
2024-05-13 11:58:25 +08:00
committed by GitHub
parent db85b2c1d8
commit 384f96c40f
15 changed files with 443 additions and 29 deletions

View File

@@ -89,6 +89,30 @@ static SherpaOnnxOnlineTransducerModelConfig GetOnlineTransducerModelConfig(
return config;
}
static SherpaOnnxOnlineZipformer2CtcModelConfig
GetOnlineZipformer2CtcModelConfig(Napi::Object obj) {
SherpaOnnxOnlineZipformer2CtcModelConfig config;
memset(&config, 0, sizeof(config));
if (!obj.Has("zipformer2Ctc") || !obj.Get("zipformer2Ctc").IsObject()) {
return config;
}
Napi::Object o = obj.Get("zipformer2Ctc").As<Napi::Object>();
if (o.Has("model") && o.Get("model").IsString()) {
Napi::String model = o.Get("model").As<Napi::String>();
std::string s = model.Utf8Value();
char *p = new char[s.size() + 1];
std::copy(s.begin(), s.end(), p);
p[s.size()] = 0;
config.model = p;
}
return config;
}
static SherpaOnnxOnlineModelConfig GetOnlineModelConfig(Napi::Object obj) {
SherpaOnnxOnlineModelConfig config;
memset(&config, 0, sizeof(config));
@@ -100,6 +124,7 @@ static SherpaOnnxOnlineModelConfig GetOnlineModelConfig(Napi::Object obj) {
Napi::Object o = obj.Get("modelConfig").As<Napi::Object>();
config.transducer = GetOnlineTransducerModelConfig(o);
config.zipformer2_ctc = GetOnlineZipformer2CtcModelConfig(o);
if (o.Has("tokens") && o.Get("tokens").IsString()) {
Napi::String tokens = o.Get("tokens").As<Napi::String>();
@@ -147,6 +172,35 @@ static SherpaOnnxOnlineModelConfig GetOnlineModelConfig(Napi::Object obj) {
return config;
}
static SherpaOnnxOnlineCtcFstDecoderConfig GetCtcFstDecoderConfig(
Napi::Object obj) {
SherpaOnnxOnlineCtcFstDecoderConfig config;
memset(&config, 0, sizeof(config));
if (!obj.Has("ctcFstDecoderConfig") ||
!obj.Get("ctcFstDecoderConfig").IsObject()) {
return config;
}
Napi::Object o = obj.Get("ctcFstDecoderConfig").As<Napi::Object>();
if (o.Has("graph") && o.Get("graph").IsString()) {
Napi::String graph = o.Get("graph").As<Napi::String>();
std::string s = graph.Utf8Value();
char *p = new char[s.size() + 1];
std::copy(s.begin(), s.end(), p);
p[s.size()] = 0;
config.graph = p;
}
if (o.Has("maxActive") && o.Get("maxActive").IsNumber()) {
config.max_active = o.Get("maxActive").As<Napi::Number>().Int32Value();
}
return config;
}
static Napi::External<SherpaOnnxOnlineRecognizer> CreateOnlineRecognizerWrapper(
const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
@@ -234,6 +288,8 @@ static Napi::External<SherpaOnnxOnlineRecognizer> CreateOnlineRecognizerWrapper(
config.Get("hotwordsScore").As<Napi::Number>().FloatValue();
}
c.ctc_fst_decoder_config = GetCtcFstDecoderConfig(config);
#if 0
printf("encoder: %s\n", c.model_config.transducer.encoder
? c.model_config.transducer.encoder
@@ -277,6 +333,10 @@ static Napi::External<SherpaOnnxOnlineRecognizer> CreateOnlineRecognizerWrapper(
delete[] c.model_config.transducer.joiner;
}
if (c.model_config.zipformer2_ctc.model) {
delete[] c.model_config.zipformer2_ctc.model;
}
if (c.model_config.tokens) {
delete[] c.model_config.tokens;
}
@@ -297,6 +357,10 @@ static Napi::External<SherpaOnnxOnlineRecognizer> CreateOnlineRecognizerWrapper(
delete[] c.hotwords_file;
}
if (c.ctc_fst_decoder_config.graph) {
delete[] c.ctc_fst_decoder_config.graph;
}
if (!recognizer) {
Napi::TypeError::New(env, "Please check your config!")
.ThrowAsJavaScriptException();