diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt index 18576562..010be1f2 100644 --- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt +++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt @@ -53,6 +53,8 @@ data class OnlineRecognizerConfig( var enableEndpoint: Boolean = true, var decodingMethod: String = "greedy_search", var maxActivePaths: Int = 4, + var hotwordsFile: String = "", + var hotwordsScore: Float = 1.5f, ) class SherpaOnnx( diff --git a/java-api-examples/Makefile b/java-api-examples/Makefile index 9b8a18ef..4643ca74 100755 --- a/java-api-examples/Makefile +++ b/java-api-examples/Makefile @@ -1,4 +1,3 @@ - ENTRY_POINT = ./ LIB_SRC_DIR := ../sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx @@ -65,18 +64,22 @@ clean: mkdir -p ./lib runfile: + java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile test.wav - java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile +runhotwords: + java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile hotwords.wav runmic: - java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic runsrv: - java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer ../build/lib/libsherpa-onnx-jni.so ./modeltest.cfg + java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer $(shell pwd)/../build/lib/libsherpa-onnx-jni.so ./modeltest.cfg runclient: - java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient ../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32 + java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient $(shell pwd)/../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32 + +runclienthotwords: + java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient $(shell pwd)/../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./hotwords.wav 32 buildlib: $(LIB_FILES:.java=.class) diff --git a/java-api-examples/modelconfig.cfg b/java-api-examples/modelconfig.cfg index 032c749d..d1ed3b2d 100755 --- a/java-api-examples/modelconfig.cfg +++ b/java-api-examples/modelconfig.cfg @@ -12,6 +12,8 @@ num_threads=4 enable_endpoint_detection=true decoding_method=modified_beam_search max_active_paths=4 +hotwords_file= +hotwords_score=1.5 lm_model= lm_scale=0.5 model_type=zipformer diff --git a/java-api-examples/runtest.sh b/java-api-examples/runtest.sh index 7672319d..82f763c6 100755 --- a/java-api-examples/runtest.sh +++ b/java-api-examples/runtest.sh @@ -36,6 +36,8 @@ if [ ! -d $repo ];then git lfs pull --include "*.onnx" ls -lh *.onnx popd + ln -s $repo/test_wavs/0.wav hotwords.wav + fi log $(pwd) @@ -64,3 +66,9 @@ cd ../java-api-examples make all make runfile + +echo "礼 拜 二" > hotwords.txt + +sed -i 's/hotwords_file=/hotwords_file=hotwords.txt/g' modeltest.cfg + +make runhotwords diff --git a/java-api-examples/src/DecodeFile.java b/java-api-examples/src/DecodeFile.java index 344d83b7..c12cf3a8 100644 --- a/java-api-examples/src/DecodeFile.java +++ b/java-api-examples/src/DecodeFile.java @@ -49,6 +49,8 @@ public class DecodeFile { float rule3MinUtteranceLength = 20F; String decodingMethod = "greedy_search"; int maxActivePaths = 4; + String hotwordsFile = ""; + float hotwordsScore = 1.5F; String lm_model = ""; float lm_scale = 0.5F; String modelType = "zipformer"; @@ -69,6 +71,8 @@ public class DecodeFile { lm_model, lm_scale, maxActivePaths, + hotwordsFile, + hotwordsScore, modelType); streamObj = rcgOjb.createStream(); } catch (Exception e) { @@ -158,7 +162,7 @@ public class DecodeFile { try { String appDir = System.getProperty("user.dir"); System.out.println("appdir=" + appDir); - String fileName = appDir + "/test.wav"; + String fileName = appDir + "/" + args[0]; String cfgPath = appDir + "/modeltest.cfg"; String soPath = appDir + "/../build/lib/libsherpa-onnx-jni.so"; OnlineRecognizer.setSoPath(soPath); diff --git a/java-api-examples/src/websocketsrv/AsrWebsocketServer.java b/java-api-examples/src/websocketsrv/AsrWebsocketServer.java index 17421e14..d20bb05e 100755 --- a/java-api-examples/src/websocketsrv/AsrWebsocketServer.java +++ b/java-api-examples/src/websocketsrv/AsrWebsocketServer.java @@ -140,8 +140,6 @@ public class AsrWebsocketServer extends WebSocketServer { } } - - public boolean streamQueueFind(WebSocket conn) { return streamQueue.contains(conn); } @@ -151,16 +149,16 @@ public class AsrWebsocketServer extends WebSocketServer { rcgOjb = new OnlineRecognizer(cfgPath); // size of stream thread pool - int streamThreadNum = Integer.valueOf(cfgMap.get("stream_thread_num")); + int streamThreadNum = Integer.valueOf(cfgMap.getOrDefault("stream_thread_num", "16")); // size of decoder thread pool - int decoderThreadNum = Integer.valueOf(cfgMap.get("decoder_thread_num")); + int decoderThreadNum = Integer.valueOf(cfgMap.getOrDefault("decoder_thread_num", "16")); // time(ms) idle for decoder thread when no job - int decoderTimeIdle = Integer.valueOf(cfgMap.get("decoder_time_idle")); + int decoderTimeIdle = Integer.valueOf(cfgMap.getOrDefault("decoder_time_idle", "200")); // size of streams for parallel decoding - int parallelDecoderNum = Integer.valueOf(cfgMap.get("parallel_decoder_num")); + int parallelDecoderNum = Integer.valueOf(cfgMap.getOrDefault("parallel_decoder_num", "16")); // time(ms) out for connection data - int deocderTimeOut = Integer.valueOf(cfgMap.get("deocder_time_out")); + int deocderTimeOut = Integer.valueOf(cfgMap.getOrDefault("deocder_time_out", "30000")); // create stream threads for (int i = 0; i < streamThreadNum; i++) { @@ -218,13 +216,13 @@ public class AsrWebsocketServer extends WebSocketServer { String soPath = args[0]; String cfgPath = args[1]; - + OnlineRecognizer.setSoPath(soPath); logger.info("readProperties"); Map cfgMap = AsrWebsocketServer.readProperties(cfgPath); - int port = Integer.valueOf(cfgMap.get("port")); + int port = Integer.valueOf(cfgMap.getOrDefault("port", "8890")); - int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num")); + int connectionThreadNum = Integer.valueOf(cfgMap.getOrDefault("connection_thread_num", "16")); AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum); logger.info("initModelWithCfg"); s.initModelWithCfg(cfgMap, cfgPath); diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java index f461fc7a..d064c75d 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java @@ -44,38 +44,48 @@ public class OnlineRecognizer { public OnlineRecognizer(String modelCfgPath) { Map proMap = this.readProperties(modelCfgPath); try { - int sampleRate = Integer.parseInt(proMap.get("sample_rate").trim()); + int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim()); this.sampleRate = sampleRate; EndpointRule rule1 = new EndpointRule( - false, Float.parseFloat(proMap.get("rule1_min_trailing_silence").trim()), 0.0F); + false, + Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()), + 0.0F); EndpointRule rule2 = new EndpointRule( - true, Float.parseFloat(proMap.get("rule2_min_trailing_silence").trim()), 0.0F); + true, + Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()), + 0.0F); EndpointRule rule3 = new EndpointRule( - false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim())); + false, + 0.0F, + Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim())); EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); - OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(proMap.get("encoder").trim(), proMap.get("decoder").trim()); + OnlineParaformerModelConfig modelParaCfg = + new OnlineParaformerModelConfig( + proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim()); OnlineTransducerModelConfig modelTranCfg = new OnlineTransducerModelConfig( - proMap.get("encoder").trim(), - proMap.get("decoder").trim(), - proMap.get("joiner").trim()); + proMap.getOrDefault("encoder", "").trim(), + proMap.getOrDefault("decoder", "").trim(), + proMap.getOrDefault("joiner", "").trim()); OnlineModelConfig modelCfg = new OnlineModelConfig( - proMap.get("tokens").trim(), - Integer.parseInt(proMap.get("num_threads").trim()), + proMap.getOrDefault("tokens", "").trim(), + Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()), false, - proMap.get("model_type").trim(), + proMap.getOrDefault("model_type", "zipformer").trim(), modelParaCfg, modelTranCfg); FeatureConfig featConfig = - new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); + new FeatureConfig( + sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim())); OnlineLMConfig onlineLmConfig = new OnlineLMConfig( - proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim())); + proMap.getOrDefault("lm_model", "").trim(), + Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim())); OnlineRecognizerConfig rcgCfg = new OnlineRecognizerConfig( @@ -83,9 +93,11 @@ public class OnlineRecognizer { modelCfg, endCfg, onlineLmConfig, - Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()), - proMap.get("decoding_method").trim(), - Integer.parseInt(proMap.get("max_active_paths").trim())); + Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()), + proMap.getOrDefault("decoding_method", "modified_beam_search").trim(), + Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()), + proMap.getOrDefault("hotwords_file", "").trim(), + Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim())); // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 this.ptr = createOnlineRecognizer(new Object(), rcgCfg); @@ -98,41 +110,49 @@ public class OnlineRecognizer { public OnlineRecognizer(Object assetManager, String modelCfgPath) { Map proMap = this.readProperties(modelCfgPath); try { - int sampleRate = Integer.parseInt(proMap.get("sample_rate").trim()); + int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim()); this.sampleRate = sampleRate; EndpointRule rule1 = new EndpointRule( - false, Float.parseFloat(proMap.get("rule1_min_trailing_silence").trim()), 0.0F); + false, + Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()), + 0.0F); EndpointRule rule2 = new EndpointRule( - true, Float.parseFloat(proMap.get("rule2_min_trailing_silence").trim()), 0.0F); + true, + Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()), + 0.0F); EndpointRule rule3 = new EndpointRule( - false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim())); + false, + 0.0F, + Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim())); EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig( - proMap.get("encoder").trim(), proMap.get("decoder").trim()); + proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim()); OnlineTransducerModelConfig modelTranCfg = new OnlineTransducerModelConfig( - proMap.get("encoder").trim(), - proMap.get("decoder").trim(), - proMap.get("joiner").trim()); + proMap.getOrDefault("encoder", "").trim(), + proMap.getOrDefault("decoder", "").trim(), + proMap.getOrDefault("joiner", "").trim()); OnlineModelConfig modelCfg = new OnlineModelConfig( - proMap.get("tokens").trim(), - Integer.parseInt(proMap.get("num_threads").trim()), + proMap.getOrDefault("tokens", "").trim(), + Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()), false, - proMap.get("model_type").trim(), + proMap.getOrDefault("model_type", "zipformer").trim(), modelParaCfg, modelTranCfg); FeatureConfig featConfig = - new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); + new FeatureConfig( + sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim())); OnlineLMConfig onlineLmConfig = new OnlineLMConfig( - proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim())); + proMap.getOrDefault("lm_model", "").trim(), + Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim())); OnlineRecognizerConfig rcgCfg = new OnlineRecognizerConfig( @@ -140,9 +160,11 @@ public class OnlineRecognizer { modelCfg, endCfg, onlineLmConfig, - Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()), - proMap.get("decoding_method").trim(), - Integer.parseInt(proMap.get("max_active_paths").trim())); + Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()), + proMap.getOrDefault("decoding_method", "modified_beam_search").trim(), + Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()), + proMap.getOrDefault("hotwords_file", "").trim(), + Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim())); // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 this.ptr = createOnlineRecognizer(assetManager, rcgCfg); @@ -168,6 +190,8 @@ public class OnlineRecognizer { String lm_model, float lm_scale, int maxActivePaths, + String hotwordsFile, + float hotwordsScore, String modelType) { this.sampleRate = sampleRate; EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); @@ -189,7 +213,9 @@ public class OnlineRecognizer { onlineLmConfig, enableEndpointDetection, decodingMethod, - maxActivePaths); + maxActivePaths, + hotwordsFile, + hotwordsScore); // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 this.ptr = createOnlineRecognizer(new Object(), rcgCfg); } @@ -211,7 +237,6 @@ public class OnlineRecognizer { String key = (String) en.nextElement(); String Property = props.getProperty(key); proMap.put(key, Property); - // System.out.println(key+"="+Property); } } catch (Exception e) { diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java index 4462a708..0f1cdb81 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java @@ -12,6 +12,8 @@ public class OnlineRecognizerConfig { private final boolean enableEndpoint; private final String decodingMethod; private final int maxActivePaths; + private final String hotwordsFile; + private final float hotwordsScore; public OnlineRecognizerConfig( FeatureConfig featConfig, @@ -20,7 +22,9 @@ public class OnlineRecognizerConfig { OnlineLMConfig lmConfig, boolean enableEndpoint, String decodingMethod, - int maxActivePaths) { + int maxActivePaths, + String hotwordsFile, + float hotwordsScore) { this.featConfig = featConfig; this.modelConfig = modelConfig; this.endpointConfig = endpointConfig; @@ -28,6 +32,8 @@ public class OnlineRecognizerConfig { this.enableEndpoint = enableEndpoint; this.decodingMethod = decodingMethod; this.maxActivePaths = maxActivePaths; + this.hotwordsFile = hotwordsFile; + this.hotwordsScore = hotwordsScore; } public OnlineLMConfig getLmConfig() { diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index b546a0ed..785a4c48 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -125,6 +125,15 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { fid = env->GetFieldID(cls, "maxActivePaths", "I"); ans.max_active_paths = env->GetIntField(config, fid); + fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.hotwords_file = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "hotwordsScore", "F"); + ans.hotwords_score = env->GetFloatField(config, fid); + //---------- feat config ---------- fid = env->GetFieldID(cls, "featConfig", "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); @@ -293,6 +302,15 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { fid = env->GetFieldID(cls, "maxActivePaths", "I"); ans.max_active_paths = env->GetIntField(config, fid); + fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.hotwords_file = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "hotwordsScore", "F"); + ans.hotwords_score = env->GetFloatField(config, fid); + //---------- feat config ---------- fid = env->GetFieldID(cls, "featConfig", "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");