Refactor node-addon-api to remove duplicate. (#873)

This commit is contained in:
Fangjun Kuang
2024-05-14 10:08:11 +08:00
committed by GitHub
parent 939fdd942c
commit 0895b64850
7 changed files with 191 additions and 570 deletions

View File

@@ -3,7 +3,8 @@
// Copyright (c) 2024 Xiaomi Corporation
#include <sstream>
#include "napi.h" // NOLINT
#include "macros.h" // NOLINT
#include "napi.h" // NOLINT
#include "sherpa-onnx/c-api/c-api.h"
/*
{
@@ -14,26 +15,19 @@
};
*/
SherpaOnnxFeatureConfig GetFeatureConfig(Napi::Object obj) {
SherpaOnnxFeatureConfig config;
memset(&config, 0, sizeof(config));
SherpaOnnxFeatureConfig c;
memset(&c, 0, sizeof(c));
if (!obj.Has("featConfig") || !obj.Get("featConfig").IsObject()) {
return config;
return c;
}
Napi::Object featConfig = obj.Get("featConfig").As<Napi::Object>();
Napi::Object o = obj.Get("featConfig").As<Napi::Object>();
if (featConfig.Has("sampleRate") && featConfig.Get("sampleRate").IsNumber()) {
config.sample_rate =
featConfig.Get("sampleRate").As<Napi::Number>().Int32Value();
}
SHERPA_ONNX_ASSIGN_ATTR_INT32(sample_rate, sampleRate);
SHERPA_ONNX_ASSIGN_ATTR_INT32(feature_dim, featureDim);
if (featConfig.Has("featureDim") && featConfig.Get("featureDim").IsNumber()) {
config.feature_dim =
featConfig.Get("featureDim").As<Napi::Number>().Int32Value();
}
return config;
return c;
}
/*
{
@@ -47,192 +41,103 @@ SherpaOnnxFeatureConfig GetFeatureConfig(Napi::Object obj) {
static SherpaOnnxOnlineTransducerModelConfig GetOnlineTransducerModelConfig(
Napi::Object obj) {
SherpaOnnxOnlineTransducerModelConfig config;
memset(&config, 0, sizeof(config));
SherpaOnnxOnlineTransducerModelConfig c;
memset(&c, 0, sizeof(c));
if (!obj.Has("transducer") || !obj.Get("transducer").IsObject()) {
return config;
return c;
}
Napi::Object o = obj.Get("transducer").As<Napi::Object>();
if (o.Has("encoder") && o.Get("encoder").IsString()) {
Napi::String encoder = o.Get("encoder").As<Napi::String>();
std::string s = encoder.Utf8Value();
char *p = new char[s.size() + 1];
std::copy(s.begin(), s.end(), p);
p[s.size()] = 0;
SHERPA_ONNX_ASSIGN_ATTR_STR(encoder, encoder);
SHERPA_ONNX_ASSIGN_ATTR_STR(decoder, decoder);
SHERPA_ONNX_ASSIGN_ATTR_STR(joiner, joiner);
config.encoder = p;
}
if (o.Has("decoder") && o.Get("decoder").IsString()) {
Napi::String decoder = o.Get("decoder").As<Napi::String>();
std::string s = decoder.Utf8Value();
char *p = new char[s.size() + 1];
std::copy(s.begin(), s.end(), p);
p[s.size()] = 0;
config.decoder = p;
}
if (o.Has("joiner") && o.Get("joiner").IsString()) {
Napi::String joiner = o.Get("joiner").As<Napi::String>();
std::string s = joiner.Utf8Value();
char *p = new char[s.size() + 1];
std::copy(s.begin(), s.end(), p);
p[s.size()] = 0;
config.joiner = p;
}
return config;
return c;
}
static SherpaOnnxOnlineZipformer2CtcModelConfig
GetOnlineZipformer2CtcModelConfig(Napi::Object obj) {
SherpaOnnxOnlineZipformer2CtcModelConfig config;
memset(&config, 0, sizeof(config));
SherpaOnnxOnlineZipformer2CtcModelConfig c;
memset(&c, 0, sizeof(c));
if (!obj.Has("zipformer2Ctc") || !obj.Get("zipformer2Ctc").IsObject()) {
return config;
return c;
}
Napi::Object o = obj.Get("zipformer2Ctc").As<Napi::Object>();
if (o.Has("model") && o.Get("model").IsString()) {
Napi::String model = o.Get("model").As<Napi::String>();
std::string s = model.Utf8Value();
char *p = new char[s.size() + 1];
std::copy(s.begin(), s.end(), p);
p[s.size()] = 0;
SHERPA_ONNX_ASSIGN_ATTR_STR(model, model);
config.model = p;
}
return config;
return c;
}
static SherpaOnnxOnlineParaformerModelConfig GetOnlineParaformerModelConfig(
Napi::Object obj) {
SherpaOnnxOnlineParaformerModelConfig config;
memset(&config, 0, sizeof(config));
SherpaOnnxOnlineParaformerModelConfig c;
memset(&c, 0, sizeof(c));
if (!obj.Has("paraformer") || !obj.Get("paraformer").IsObject()) {
return config;
return c;
}
Napi::Object o = obj.Get("paraformer").As<Napi::Object>();
if (o.Has("encoder") && o.Get("encoder").IsString()) {
Napi::String encoder = o.Get("encoder").As<Napi::String>();
std::string s = encoder.Utf8Value();
char *p = new char[s.size() + 1];
std::copy(s.begin(), s.end(), p);
p[s.size()] = 0;
SHERPA_ONNX_ASSIGN_ATTR_STR(encoder, encoder);
SHERPA_ONNX_ASSIGN_ATTR_STR(decoder, decoder);
config.encoder = p;
}
if (o.Has("decoder") && o.Get("decoder").IsString()) {
Napi::String decoder = o.Get("decoder").As<Napi::String>();
std::string s = decoder.Utf8Value();
char *p = new char[s.size() + 1];
std::copy(s.begin(), s.end(), p);
p[s.size()] = 0;
config.decoder = p;
}
return config;
return c;
}
static SherpaOnnxOnlineModelConfig GetOnlineModelConfig(Napi::Object obj) {
SherpaOnnxOnlineModelConfig config;
memset(&config, 0, sizeof(config));
SherpaOnnxOnlineModelConfig c;
memset(&c, 0, sizeof(c));
if (!obj.Has("modelConfig") || !obj.Get("modelConfig").IsObject()) {
return config;
return c;
}
Napi::Object o = obj.Get("modelConfig").As<Napi::Object>();
config.transducer = GetOnlineTransducerModelConfig(o);
config.paraformer = GetOnlineParaformerModelConfig(o);
config.zipformer2_ctc = GetOnlineZipformer2CtcModelConfig(o);
c.transducer = GetOnlineTransducerModelConfig(o);
c.paraformer = GetOnlineParaformerModelConfig(o);
c.zipformer2_ctc = GetOnlineZipformer2CtcModelConfig(o);
if (o.Has("tokens") && o.Get("tokens").IsString()) {
Napi::String tokens = o.Get("tokens").As<Napi::String>();
std::string s = tokens.Utf8Value();
char *p = new char[s.size() + 1];
std::copy(s.begin(), s.end(), p);
p[s.size()] = 0;
config.tokens = p;
}
if (o.Has("numThreads") && o.Get("numThreads").IsNumber()) {
config.num_threads = o.Get("numThreads").As<Napi::Number>().Int32Value();
}
if (o.Has("provider") && o.Get("provider").IsString()) {
Napi::String provider = o.Get("provider").As<Napi::String>();
std::string s = provider.Utf8Value();
char *p = new char[s.size() + 1];
std::copy(s.begin(), s.end(), p);
p[s.size()] = 0;
config.provider = p;
}
SHERPA_ONNX_ASSIGN_ATTR_STR(tokens, tokens);
SHERPA_ONNX_ASSIGN_ATTR_INT32(num_threads, numThreads);
SHERPA_ONNX_ASSIGN_ATTR_STR(provider, provider);
if (o.Has("debug") &&
(o.Get("debug").IsNumber() || o.Get("debug").IsBoolean())) {
if (o.Get("debug").IsBoolean()) {
config.debug = o.Get("debug").As<Napi::Boolean>().Value();
c.debug = o.Get("debug").As<Napi::Boolean>().Value();
} else {
config.debug = o.Get("debug").As<Napi::Number>().Int32Value();
c.debug = o.Get("debug").As<Napi::Number>().Int32Value();
}
}
if (o.Has("modelType") && o.Get("modelType").IsString()) {
Napi::String model_type = o.Get("modelType").As<Napi::String>();
std::string s = model_type.Utf8Value();
char *p = new char[s.size() + 1];
std::copy(s.begin(), s.end(), p);
p[s.size()] = 0;
SHERPA_ONNX_ASSIGN_ATTR_STR(model_type, modelType);
config.model_type = p;
}
return config;
return c;
}
static SherpaOnnxOnlineCtcFstDecoderConfig GetCtcFstDecoderConfig(
Napi::Object obj) {
SherpaOnnxOnlineCtcFstDecoderConfig config;
memset(&config, 0, sizeof(config));
SherpaOnnxOnlineCtcFstDecoderConfig c;
memset(&c, 0, sizeof(c));
if (!obj.Has("ctcFstDecoderConfig") ||
!obj.Get("ctcFstDecoderConfig").IsObject()) {
return config;
return c;
}
Napi::Object o = obj.Get("ctcFstDecoderConfig").As<Napi::Object>();
if (o.Has("graph") && o.Get("graph").IsString()) {
Napi::String graph = o.Get("graph").As<Napi::String>();
std::string s = graph.Utf8Value();
char *p = new char[s.size() + 1];
std::copy(s.begin(), s.end(), p);
p[s.size()] = 0;
SHERPA_ONNX_ASSIGN_ATTR_STR(graph, graph);
SHERPA_ONNX_ASSIGN_ATTR_INT32(max_active, maxActive);
config.graph = p;
}
if (o.Has("maxActive") && o.Get("maxActive").IsNumber()) {
config.max_active = o.Get("maxActive").As<Napi::Number>().Int32Value();
}
return config;
return c;
}
static Napi::External<SherpaOnnxOnlineRecognizer> CreateOnlineRecognizerWrapper(
@@ -254,75 +159,36 @@ static Napi::External<SherpaOnnxOnlineRecognizer> CreateOnlineRecognizerWrapper(
return {};
}
Napi::Object config = info[0].As<Napi::Object>();
Napi::Object o = info[0].As<Napi::Object>();
SherpaOnnxOnlineRecognizerConfig c;
memset(&c, 0, sizeof(c));
c.feat_config = GetFeatureConfig(config);
c.model_config = GetOnlineModelConfig(config);
c.feat_config = GetFeatureConfig(o);
c.model_config = GetOnlineModelConfig(o);
if (config.Has("decodingMethod") && config.Get("decodingMethod").IsString()) {
Napi::String decoding_method =
config.Get("decodingMethod").As<Napi::String>();
std::string s = decoding_method.Utf8Value();
char *p = new char[s.size() + 1];
std::copy(s.begin(), s.end(), p);
p[s.size()] = 0;
c.decoding_method = p;
}
if (config.Has("maxActivePaths") && config.Get("maxActivePaths").IsNumber()) {
c.max_active_paths =
config.Get("maxActivePaths").As<Napi::Number>().Int32Value();
}
SHERPA_ONNX_ASSIGN_ATTR_STR(decoding_method, decodingMethod);
SHERPA_ONNX_ASSIGN_ATTR_INT32(max_active_paths, maxActivePaths);
// enableEndpoint can be either a boolean or an integer
if (config.Has("enableEndpoint") &&
(config.Get("enableEndpoint").IsNumber() ||
config.Get("enableEndpoint").IsBoolean())) {
if (config.Get("enableEndpoint").IsNumber()) {
if (o.Has("enableEndpoint") && (o.Get("enableEndpoint").IsNumber() ||
o.Get("enableEndpoint").IsBoolean())) {
if (o.Get("enableEndpoint").IsNumber()) {
c.enable_endpoint =
config.Get("enableEndpoint").As<Napi::Number>().Int32Value();
o.Get("enableEndpoint").As<Napi::Number>().Int32Value();
} else {
c.enable_endpoint =
config.Get("enableEndpoint").As<Napi::Boolean>().Value();
c.enable_endpoint = o.Get("enableEndpoint").As<Napi::Boolean>().Value();
}
}
if (config.Has("rule1MinTrailingSilence") &&
config.Get("rule1MinTrailingSilence").IsNumber()) {
c.rule1_min_trailing_silence =
config.Get("rule1MinTrailingSilence").As<Napi::Number>().FloatValue();
}
SHERPA_ONNX_ASSIGN_ATTR_FLOAT(rule1_min_trailing_silence,
rule1MinTrailingSilence);
SHERPA_ONNX_ASSIGN_ATTR_FLOAT(rule2_min_trailing_silence,
rule2MinTrailingSilence);
SHERPA_ONNX_ASSIGN_ATTR_FLOAT(rule3_min_utterance_length,
rule3MinUtteranceLength);
SHERPA_ONNX_ASSIGN_ATTR_STR(hotwords_file, hotwordsFile);
SHERPA_ONNX_ASSIGN_ATTR_FLOAT(hotwords_score, hotwordsScore);
if (config.Has("rule2MinTrailingSilence") &&
config.Get("rule2MinTrailingSilence").IsNumber()) {
c.rule2_min_trailing_silence =
config.Get("rule2MinTrailingSilence").As<Napi::Number>().FloatValue();
}
if (config.Has("rule3MinUtteranceLength") &&
config.Get("rule3MinUtteranceLength").IsNumber()) {
c.rule3_min_utterance_length =
config.Get("rule3MinUtteranceLength").As<Napi::Number>().FloatValue();
}
if (config.Has("hotwordsFile") && config.Get("hotwordsFile").IsString()) {
Napi::String hotwords_file = config.Get("hotwordsFile").As<Napi::String>();
std::string s = hotwords_file.Utf8Value();
char *p = new char[s.size() + 1];
std::copy(s.begin(), s.end(), p);
p[s.size()] = 0;
c.hotwords_file = p;
}
if (config.Has("hotwordsScore") && config.Get("hotwordsScore").IsNumber()) {
c.hotwords_score =
config.Get("hotwordsScore").As<Napi::Number>().FloatValue();
}
c.ctc_fst_decoder_config = GetCtcFstDecoderConfig(config);
c.ctc_fst_decoder_config = GetCtcFstDecoderConfig(o);
SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&c);