Add java api for hotwords (#319)

* Add java api

* support websocket

* Fix kotlin
This commit is contained in:
Wei Kang
2023-09-18 22:44:29 +08:00
committed by GitHub
parent 4dfc11066a
commit d7eab95439
9 changed files with 117 additions and 51 deletions

View File

@@ -1,4 +1,3 @@
ENTRY_POINT = ./
LIB_SRC_DIR := ../sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx
@@ -65,18 +64,22 @@ clean:
mkdir -p ./lib
runfile:
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile test.wav
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile
runhotwords:
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile hotwords.wav
runmic:
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic
runsrv:
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer ../build/lib/libsherpa-onnx-jni.so ./modeltest.cfg
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer $(shell pwd)/../build/lib/libsherpa-onnx-jni.so ./modeltest.cfg
runclient:
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient ../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient $(shell pwd)/../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32
runclienthotwords:
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient $(shell pwd)/../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./hotwords.wav 32
buildlib: $(LIB_FILES:.java=.class)

View File

@@ -12,6 +12,8 @@ num_threads=4
enable_endpoint_detection=true
decoding_method=modified_beam_search
max_active_paths=4
hotwords_file=
hotwords_score=1.5
lm_model=
lm_scale=0.5
model_type=zipformer

View File

@@ -36,6 +36,8 @@ if [ ! -d $repo ];then
git lfs pull --include "*.onnx"
ls -lh *.onnx
popd
ln -s $repo/test_wavs/0.wav hotwords.wav
fi
log $(pwd)
@@ -64,3 +66,9 @@ cd ../java-api-examples
make all
make runfile
echo "礼 拜 二" > hotwords.txt
sed -i 's/hotwords_file=/hotwords_file=hotwords.txt/g' modeltest.cfg
make runhotwords

View File

@@ -49,6 +49,8 @@ public class DecodeFile {
float rule3MinUtteranceLength = 20F;
String decodingMethod = "greedy_search";
int maxActivePaths = 4;
String hotwordsFile = "";
float hotwordsScore = 1.5F;
String lm_model = "";
float lm_scale = 0.5F;
String modelType = "zipformer";
@@ -69,6 +71,8 @@ public class DecodeFile {
lm_model,
lm_scale,
maxActivePaths,
hotwordsFile,
hotwordsScore,
modelType);
streamObj = rcgOjb.createStream();
} catch (Exception e) {
@@ -158,7 +162,7 @@ public class DecodeFile {
try {
String appDir = System.getProperty("user.dir");
System.out.println("appdir=" + appDir);
String fileName = appDir + "/test.wav";
String fileName = appDir + "/" + args[0];
String cfgPath = appDir + "/modeltest.cfg";
String soPath = appDir + "/../build/lib/libsherpa-onnx-jni.so";
OnlineRecognizer.setSoPath(soPath);

View File

@@ -140,8 +140,6 @@ public class AsrWebsocketServer extends WebSocketServer {
}
}
public boolean streamQueueFind(WebSocket conn) {
return streamQueue.contains(conn);
}
@@ -151,16 +149,16 @@ public class AsrWebsocketServer extends WebSocketServer {
rcgOjb = new OnlineRecognizer(cfgPath);
// size of stream thread pool
int streamThreadNum = Integer.valueOf(cfgMap.get("stream_thread_num"));
int streamThreadNum = Integer.valueOf(cfgMap.getOrDefault("stream_thread_num", "16"));
// size of decoder thread pool
int decoderThreadNum = Integer.valueOf(cfgMap.get("decoder_thread_num"));
int decoderThreadNum = Integer.valueOf(cfgMap.getOrDefault("decoder_thread_num", "16"));
// time(ms) idle for decoder thread when no job
int decoderTimeIdle = Integer.valueOf(cfgMap.get("decoder_time_idle"));
int decoderTimeIdle = Integer.valueOf(cfgMap.getOrDefault("decoder_time_idle", "200"));
// size of streams for parallel decoding
int parallelDecoderNum = Integer.valueOf(cfgMap.get("parallel_decoder_num"));
int parallelDecoderNum = Integer.valueOf(cfgMap.getOrDefault("parallel_decoder_num", "16"));
// time(ms) out for connection data
int deocderTimeOut = Integer.valueOf(cfgMap.get("deocder_time_out"));
int deocderTimeOut = Integer.valueOf(cfgMap.getOrDefault("deocder_time_out", "30000"));
// create stream threads
for (int i = 0; i < streamThreadNum; i++) {
@@ -218,13 +216,13 @@ public class AsrWebsocketServer extends WebSocketServer {
String soPath = args[0];
String cfgPath = args[1];
OnlineRecognizer.setSoPath(soPath);
logger.info("readProperties");
Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath);
int port = Integer.valueOf(cfgMap.get("port"));
int port = Integer.valueOf(cfgMap.getOrDefault("port", "8890"));
int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num"));
int connectionThreadNum = Integer.valueOf(cfgMap.getOrDefault("connection_thread_num", "16"));
AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum);
logger.info("initModelWithCfg");
s.initModelWithCfg(cfgMap, cfgPath);