Add java api for hotwords (#319)
* Add java api * support websocket * Fix kotlin
This commit is contained in:
@@ -53,6 +53,8 @@ data class OnlineRecognizerConfig(
|
||||
var enableEndpoint: Boolean = true,
|
||||
var decodingMethod: String = "greedy_search",
|
||||
var maxActivePaths: Int = 4,
|
||||
var hotwordsFile: String = "",
|
||||
var hotwordsScore: Float = 1.5f,
|
||||
)
|
||||
|
||||
class SherpaOnnx(
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
ENTRY_POINT = ./
|
||||
|
||||
LIB_SRC_DIR := ../sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx
|
||||
@@ -65,18 +64,22 @@ clean:
|
||||
mkdir -p ./lib
|
||||
|
||||
runfile:
|
||||
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile test.wav
|
||||
|
||||
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile
|
||||
runhotwords:
|
||||
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile hotwords.wav
|
||||
|
||||
runmic:
|
||||
|
||||
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic
|
||||
|
||||
runsrv:
|
||||
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer ../build/lib/libsherpa-onnx-jni.so ./modeltest.cfg
|
||||
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer $(shell pwd)/../build/lib/libsherpa-onnx-jni.so ./modeltest.cfg
|
||||
|
||||
runclient:
|
||||
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient ../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32
|
||||
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient $(shell pwd)/../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32
|
||||
|
||||
runclienthotwords:
|
||||
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient $(shell pwd)/../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./hotwords.wav 32
|
||||
|
||||
buildlib: $(LIB_FILES:.java=.class)
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ num_threads=4
|
||||
enable_endpoint_detection=true
|
||||
decoding_method=modified_beam_search
|
||||
max_active_paths=4
|
||||
hotwords_file=
|
||||
hotwords_score=1.5
|
||||
lm_model=
|
||||
lm_scale=0.5
|
||||
model_type=zipformer
|
||||
|
||||
@@ -36,6 +36,8 @@ if [ ! -d $repo ];then
|
||||
git lfs pull --include "*.onnx"
|
||||
ls -lh *.onnx
|
||||
popd
|
||||
ln -s $repo/test_wavs/0.wav hotwords.wav
|
||||
|
||||
fi
|
||||
|
||||
log $(pwd)
|
||||
@@ -64,3 +66,9 @@ cd ../java-api-examples
|
||||
make all
|
||||
|
||||
make runfile
|
||||
|
||||
echo "礼 拜 二" > hotwords.txt
|
||||
|
||||
sed -i 's/hotwords_file=/hotwords_file=hotwords.txt/g' modeltest.cfg
|
||||
|
||||
make runhotwords
|
||||
|
||||
@@ -49,6 +49,8 @@ public class DecodeFile {
|
||||
float rule3MinUtteranceLength = 20F;
|
||||
String decodingMethod = "greedy_search";
|
||||
int maxActivePaths = 4;
|
||||
String hotwordsFile = "";
|
||||
float hotwordsScore = 1.5F;
|
||||
String lm_model = "";
|
||||
float lm_scale = 0.5F;
|
||||
String modelType = "zipformer";
|
||||
@@ -69,6 +71,8 @@ public class DecodeFile {
|
||||
lm_model,
|
||||
lm_scale,
|
||||
maxActivePaths,
|
||||
hotwordsFile,
|
||||
hotwordsScore,
|
||||
modelType);
|
||||
streamObj = rcgOjb.createStream();
|
||||
} catch (Exception e) {
|
||||
@@ -158,7 +162,7 @@ public class DecodeFile {
|
||||
try {
|
||||
String appDir = System.getProperty("user.dir");
|
||||
System.out.println("appdir=" + appDir);
|
||||
String fileName = appDir + "/test.wav";
|
||||
String fileName = appDir + "/" + args[0];
|
||||
String cfgPath = appDir + "/modeltest.cfg";
|
||||
String soPath = appDir + "/../build/lib/libsherpa-onnx-jni.so";
|
||||
OnlineRecognizer.setSoPath(soPath);
|
||||
|
||||
@@ -140,8 +140,6 @@ public class AsrWebsocketServer extends WebSocketServer {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
public boolean streamQueueFind(WebSocket conn) {
|
||||
return streamQueue.contains(conn);
|
||||
}
|
||||
@@ -151,16 +149,16 @@ public class AsrWebsocketServer extends WebSocketServer {
|
||||
|
||||
rcgOjb = new OnlineRecognizer(cfgPath);
|
||||
// size of stream thread pool
|
||||
int streamThreadNum = Integer.valueOf(cfgMap.get("stream_thread_num"));
|
||||
int streamThreadNum = Integer.valueOf(cfgMap.getOrDefault("stream_thread_num", "16"));
|
||||
// size of decoder thread pool
|
||||
int decoderThreadNum = Integer.valueOf(cfgMap.get("decoder_thread_num"));
|
||||
int decoderThreadNum = Integer.valueOf(cfgMap.getOrDefault("decoder_thread_num", "16"));
|
||||
|
||||
// idle time (ms) for a decoder thread when there is no job
|
||||
int decoderTimeIdle = Integer.valueOf(cfgMap.get("decoder_time_idle"));
|
||||
int decoderTimeIdle = Integer.valueOf(cfgMap.getOrDefault("decoder_time_idle", "200"));
|
||||
// number of streams for parallel decoding
|
||||
int parallelDecoderNum = Integer.valueOf(cfgMap.get("parallel_decoder_num"));
|
||||
int parallelDecoderNum = Integer.valueOf(cfgMap.getOrDefault("parallel_decoder_num", "16"));
|
||||
// timeout (ms) for connection data
|
||||
int deocderTimeOut = Integer.valueOf(cfgMap.get("deocder_time_out"));
|
||||
int deocderTimeOut = Integer.valueOf(cfgMap.getOrDefault("deocder_time_out", "30000"));
|
||||
|
||||
// create stream threads
|
||||
for (int i = 0; i < streamThreadNum; i++) {
|
||||
@@ -218,13 +216,13 @@ public class AsrWebsocketServer extends WebSocketServer {
|
||||
|
||||
String soPath = args[0];
|
||||
String cfgPath = args[1];
|
||||
|
||||
|
||||
OnlineRecognizer.setSoPath(soPath);
|
||||
logger.info("readProperties");
|
||||
Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath);
|
||||
int port = Integer.valueOf(cfgMap.get("port"));
|
||||
int port = Integer.valueOf(cfgMap.getOrDefault("port", "8890"));
|
||||
|
||||
int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num"));
|
||||
int connectionThreadNum = Integer.valueOf(cfgMap.getOrDefault("connection_thread_num", "16"));
|
||||
AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum);
|
||||
logger.info("initModelWithCfg");
|
||||
s.initModelWithCfg(cfgMap, cfgPath);
|
||||
|
||||
@@ -44,38 +44,48 @@ public class OnlineRecognizer {
|
||||
public OnlineRecognizer(String modelCfgPath) {
|
||||
Map<String, String> proMap = this.readProperties(modelCfgPath);
|
||||
try {
|
||||
int sampleRate = Integer.parseInt(proMap.get("sample_rate").trim());
|
||||
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
|
||||
this.sampleRate = sampleRate;
|
||||
EndpointRule rule1 =
|
||||
new EndpointRule(
|
||||
false, Float.parseFloat(proMap.get("rule1_min_trailing_silence").trim()), 0.0F);
|
||||
false,
|
||||
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
|
||||
0.0F);
|
||||
EndpointRule rule2 =
|
||||
new EndpointRule(
|
||||
true, Float.parseFloat(proMap.get("rule2_min_trailing_silence").trim()), 0.0F);
|
||||
true,
|
||||
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
|
||||
0.0F);
|
||||
EndpointRule rule3 =
|
||||
new EndpointRule(
|
||||
false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
|
||||
false,
|
||||
0.0F,
|
||||
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
|
||||
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
|
||||
|
||||
OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(proMap.get("encoder").trim(), proMap.get("decoder").trim());
|
||||
OnlineParaformerModelConfig modelParaCfg =
|
||||
new OnlineParaformerModelConfig(
|
||||
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
|
||||
OnlineTransducerModelConfig modelTranCfg =
|
||||
new OnlineTransducerModelConfig(
|
||||
proMap.get("encoder").trim(),
|
||||
proMap.get("decoder").trim(),
|
||||
proMap.get("joiner").trim());
|
||||
proMap.getOrDefault("encoder", "").trim(),
|
||||
proMap.getOrDefault("decoder", "").trim(),
|
||||
proMap.getOrDefault("joiner", "").trim());
|
||||
OnlineModelConfig modelCfg =
|
||||
new OnlineModelConfig(
|
||||
proMap.get("tokens").trim(),
|
||||
Integer.parseInt(proMap.get("num_threads").trim()),
|
||||
proMap.getOrDefault("tokens", "").trim(),
|
||||
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
|
||||
false,
|
||||
proMap.get("model_type").trim(),
|
||||
proMap.getOrDefault("model_type", "zipformer").trim(),
|
||||
modelParaCfg,
|
||||
modelTranCfg);
|
||||
FeatureConfig featConfig =
|
||||
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
||||
new FeatureConfig(
|
||||
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
|
||||
OnlineLMConfig onlineLmConfig =
|
||||
new OnlineLMConfig(
|
||||
proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
|
||||
proMap.getOrDefault("lm_model", "").trim(),
|
||||
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
|
||||
|
||||
OnlineRecognizerConfig rcgCfg =
|
||||
new OnlineRecognizerConfig(
|
||||
@@ -83,9 +93,11 @@ public class OnlineRecognizer {
|
||||
modelCfg,
|
||||
endCfg,
|
||||
onlineLmConfig,
|
||||
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
|
||||
proMap.get("decoding_method").trim(),
|
||||
Integer.parseInt(proMap.get("max_active_paths").trim()));
|
||||
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
|
||||
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
|
||||
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
|
||||
proMap.getOrDefault("hotwords_file", "").trim(),
|
||||
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
|
||||
// Create a new recognizer. The first parameter is kept for the Android asset_manager (__ANDROID_API__ >= 9).
|
||||
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
|
||||
|
||||
@@ -98,41 +110,49 @@ public class OnlineRecognizer {
|
||||
public OnlineRecognizer(Object assetManager, String modelCfgPath) {
|
||||
Map<String, String> proMap = this.readProperties(modelCfgPath);
|
||||
try {
|
||||
int sampleRate = Integer.parseInt(proMap.get("sample_rate").trim());
|
||||
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
|
||||
this.sampleRate = sampleRate;
|
||||
EndpointRule rule1 =
|
||||
new EndpointRule(
|
||||
false, Float.parseFloat(proMap.get("rule1_min_trailing_silence").trim()), 0.0F);
|
||||
false,
|
||||
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
|
||||
0.0F);
|
||||
EndpointRule rule2 =
|
||||
new EndpointRule(
|
||||
true, Float.parseFloat(proMap.get("rule2_min_trailing_silence").trim()), 0.0F);
|
||||
true,
|
||||
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
|
||||
0.0F);
|
||||
EndpointRule rule3 =
|
||||
new EndpointRule(
|
||||
false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
|
||||
false,
|
||||
0.0F,
|
||||
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
|
||||
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
|
||||
OnlineParaformerModelConfig modelParaCfg =
|
||||
new OnlineParaformerModelConfig(
|
||||
proMap.get("encoder").trim(), proMap.get("decoder").trim());
|
||||
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
|
||||
OnlineTransducerModelConfig modelTranCfg =
|
||||
new OnlineTransducerModelConfig(
|
||||
proMap.get("encoder").trim(),
|
||||
proMap.get("decoder").trim(),
|
||||
proMap.get("joiner").trim());
|
||||
proMap.getOrDefault("encoder", "").trim(),
|
||||
proMap.getOrDefault("decoder", "").trim(),
|
||||
proMap.getOrDefault("joiner", "").trim());
|
||||
|
||||
OnlineModelConfig modelCfg =
|
||||
new OnlineModelConfig(
|
||||
proMap.get("tokens").trim(),
|
||||
Integer.parseInt(proMap.get("num_threads").trim()),
|
||||
proMap.getOrDefault("tokens", "").trim(),
|
||||
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
|
||||
false,
|
||||
proMap.get("model_type").trim(),
|
||||
proMap.getOrDefault("model_type", "zipformer").trim(),
|
||||
modelParaCfg,
|
||||
modelTranCfg);
|
||||
FeatureConfig featConfig =
|
||||
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
||||
new FeatureConfig(
|
||||
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
|
||||
|
||||
OnlineLMConfig onlineLmConfig =
|
||||
new OnlineLMConfig(
|
||||
proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
|
||||
proMap.getOrDefault("lm_model", "").trim(),
|
||||
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
|
||||
|
||||
OnlineRecognizerConfig rcgCfg =
|
||||
new OnlineRecognizerConfig(
|
||||
@@ -140,9 +160,11 @@ public class OnlineRecognizer {
|
||||
modelCfg,
|
||||
endCfg,
|
||||
onlineLmConfig,
|
||||
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
|
||||
proMap.get("decoding_method").trim(),
|
||||
Integer.parseInt(proMap.get("max_active_paths").trim()));
|
||||
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
|
||||
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
|
||||
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
|
||||
proMap.getOrDefault("hotwords_file", "").trim(),
|
||||
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
|
||||
// Create a new recognizer. The first parameter is kept for the Android asset_manager (__ANDROID_API__ >= 9).
|
||||
this.ptr = createOnlineRecognizer(assetManager, rcgCfg);
|
||||
|
||||
@@ -168,6 +190,8 @@ public class OnlineRecognizer {
|
||||
String lm_model,
|
||||
float lm_scale,
|
||||
int maxActivePaths,
|
||||
String hotwordsFile,
|
||||
float hotwordsScore,
|
||||
String modelType) {
|
||||
this.sampleRate = sampleRate;
|
||||
EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
|
||||
@@ -189,7 +213,9 @@ public class OnlineRecognizer {
|
||||
onlineLmConfig,
|
||||
enableEndpointDetection,
|
||||
decodingMethod,
|
||||
maxActivePaths);
|
||||
maxActivePaths,
|
||||
hotwordsFile,
|
||||
hotwordsScore);
|
||||
// Create a new recognizer. The first parameter is kept for the Android asset_manager (__ANDROID_API__ >= 9).
|
||||
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
|
||||
}
|
||||
@@ -211,7 +237,6 @@ public class OnlineRecognizer {
|
||||
String key = (String) en.nextElement();
|
||||
String Property = props.getProperty(key);
|
||||
proMap.put(key, Property);
|
||||
// System.out.println(key+"="+Property);
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
|
||||
@@ -12,6 +12,8 @@ public class OnlineRecognizerConfig {
|
||||
private final boolean enableEndpoint;
|
||||
private final String decodingMethod;
|
||||
private final int maxActivePaths;
|
||||
private final String hotwordsFile;
|
||||
private final float hotwordsScore;
|
||||
|
||||
public OnlineRecognizerConfig(
|
||||
FeatureConfig featConfig,
|
||||
@@ -20,7 +22,9 @@ public class OnlineRecognizerConfig {
|
||||
OnlineLMConfig lmConfig,
|
||||
boolean enableEndpoint,
|
||||
String decodingMethod,
|
||||
int maxActivePaths) {
|
||||
int maxActivePaths,
|
||||
String hotwordsFile,
|
||||
float hotwordsScore) {
|
||||
this.featConfig = featConfig;
|
||||
this.modelConfig = modelConfig;
|
||||
this.endpointConfig = endpointConfig;
|
||||
@@ -28,6 +32,8 @@ public class OnlineRecognizerConfig {
|
||||
this.enableEndpoint = enableEndpoint;
|
||||
this.decodingMethod = decodingMethod;
|
||||
this.maxActivePaths = maxActivePaths;
|
||||
this.hotwordsFile = hotwordsFile;
|
||||
this.hotwordsScore = hotwordsScore;
|
||||
}
|
||||
|
||||
public OnlineLMConfig getLmConfig() {
|
||||
|
||||
@@ -125,6 +125,15 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||
fid = env->GetFieldID(cls, "maxActivePaths", "I");
|
||||
ans.max_active_paths = env->GetIntField(config, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.hotwords_file = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(cls, "hotwordsScore", "F");
|
||||
ans.hotwords_score = env->GetFloatField(config, fid);
|
||||
|
||||
//---------- feat config ----------
|
||||
fid = env->GetFieldID(cls, "featConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
|
||||
@@ -293,6 +302,15 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
|
||||
fid = env->GetFieldID(cls, "maxActivePaths", "I");
|
||||
ans.max_active_paths = env->GetIntField(config, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.hotwords_file = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(cls, "hotwordsScore", "F");
|
||||
ans.hotwords_score = env->GetFloatField(config, fid);
|
||||
|
||||
//---------- feat config ----------
|
||||
fid = env->GetFieldID(cls, "featConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
|
||||
|
||||
Reference in New Issue
Block a user