Fix java tests.
This commit is contained in:
Fangjun Kuang
2024-02-26 13:49:37 +08:00
committed by GitHub
parent ee37d9bd92
commit fb04366179
15 changed files with 561 additions and 532 deletions

2
java-api-examples/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
lib
hs_err*

View File

@@ -9,10 +9,11 @@ LIB_FILES = \
$(LIB_SRC_DIR)/OnlineLMConfig.java \ $(LIB_SRC_DIR)/OnlineLMConfig.java \
$(LIB_SRC_DIR)/OnlineTransducerModelConfig.java \ $(LIB_SRC_DIR)/OnlineTransducerModelConfig.java \
$(LIB_SRC_DIR)/OnlineParaformerModelConfig.java \ $(LIB_SRC_DIR)/OnlineParaformerModelConfig.java \
$(LIB_SRC_DIR)/OnlineZipformer2CtcModelConfig.java \
$(LIB_SRC_DIR)/OnlineModelConfig.java \ $(LIB_SRC_DIR)/OnlineModelConfig.java \
$(LIB_SRC_DIR)/OnlineRecognizerConfig.java \ $(LIB_SRC_DIR)/OnlineRecognizerConfig.java \
$(LIB_SRC_DIR)/OnlineStream.java \ $(LIB_SRC_DIR)/OnlineStream.java \
$(LIB_SRC_DIR)/OnlineRecognizer.java \ $(LIB_SRC_DIR)/OnlineRecognizer.java
WEBSOCKET_DIR:= ./src/websocketsrv WEBSOCKET_DIR:= ./src/websocketsrv
WEBSOCKET_FILES = \ WEBSOCKET_FILES = \
@@ -63,7 +64,7 @@ clean:
mkdir -p $(BUILD_DIR) mkdir -p $(BUILD_DIR)
mkdir -p ./lib mkdir -p ./lib
runfile: runfile: packjar buildfile
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile test.wav java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile test.wav
runhotwords: runhotwords:
@@ -85,7 +86,6 @@ buildlib: $(LIB_FILES:.java=.class)
%.class: %.java %.class: %.java
$(JAVAC) -cp $(BUILD_DIR) -d $(BUILD_DIR) -encoding UTF-8 $< $(JAVAC) -cp $(BUILD_DIR) -d $(BUILD_DIR) -encoding UTF-8 $<
buildwebsocket: $(WEBSOCKET_FILES:.java=.class) buildwebsocket: $(WEBSOCKET_FILES:.java=.class)
@@ -95,7 +95,7 @@ buildwebsocket: $(WEBSOCKET_FILES:.java=.class)
$(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:../lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 $< $(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:../lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 $<
packjar: packjar: buildlib
jar cvfe lib/sherpaonnx.jar . -C $(BUILD_DIR) . jar cvfe lib/sherpaonnx.jar . -C $(BUILD_DIR) .
all: clean buildlib packjar buildfile buildmic downjar buildwebsocket all: clean buildlib packjar buildfile buildmic downjar buildwebsocket

2
sherpa-onnx/java-api/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
.idea
java-api.iml

View File

@@ -7,6 +7,7 @@ package com.k2fsa.sherpa.onnx;
public class OnlineModelConfig { public class OnlineModelConfig {
private final OnlineParaformerModelConfig paraformer; private final OnlineParaformerModelConfig paraformer;
private final OnlineTransducerModelConfig transducer; private final OnlineTransducerModelConfig transducer;
private final OnlineZipformer2CtcModelConfig zipformer2Ctc;
private final String tokens; private final String tokens;
private final int numThreads; private final int numThreads;
private final boolean debug; private final boolean debug;
@@ -19,7 +20,9 @@ public class OnlineModelConfig {
boolean debug, boolean debug,
String modelType, String modelType,
OnlineParaformerModelConfig paraformer, OnlineParaformerModelConfig paraformer,
OnlineTransducerModelConfig transducer) { OnlineTransducerModelConfig transducer,
OnlineZipformer2CtcModelConfig zipformer2Ctc
) {
this.tokens = tokens; this.tokens = tokens;
this.numThreads = numThreads; this.numThreads = numThreads;
@@ -27,6 +30,7 @@ public class OnlineModelConfig {
this.modelType = modelType; this.modelType = modelType;
this.paraformer = paraformer; this.paraformer = paraformer;
this.transducer = transducer; this.transducer = transducer;
this.zipformer2Ctc = zipformer2Ctc;
} }
public OnlineParaformerModelConfig getParaformer() { public OnlineParaformerModelConfig getParaformer() {

View File

@@ -32,8 +32,14 @@ usage example:
*/ */
package com.k2fsa.sherpa.onnx; package com.k2fsa.sherpa.onnx;
import java.io.*; import java.io.BufferedInputStream;
import java.util.*; import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
public class OnlineRecognizer { public class OnlineRecognizer {
private long ptr = 0; // this is the asr engine ptrss private long ptr = 0; // this is the asr engine ptrss
@@ -71,6 +77,7 @@ public class OnlineRecognizer {
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("encoder", "").trim(),
proMap.getOrDefault("decoder", "").trim(), proMap.getOrDefault("decoder", "").trim(),
proMap.getOrDefault("joiner", "").trim()); proMap.getOrDefault("joiner", "").trim());
OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig("");
OnlineModelConfig modelCfg = OnlineModelConfig modelCfg =
new OnlineModelConfig( new OnlineModelConfig(
proMap.getOrDefault("tokens", "").trim(), proMap.getOrDefault("tokens", "").trim(),
@@ -78,7 +85,7 @@ public class OnlineRecognizer {
false, false,
proMap.getOrDefault("model_type", "zipformer").trim(), proMap.getOrDefault("model_type", "zipformer").trim(),
modelParaCfg, modelParaCfg,
modelTranCfg); modelTranCfg, zipformer2CtcConfig);
FeatureConfig featConfig = FeatureConfig featConfig =
new FeatureConfig( new FeatureConfig(
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim())); sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
@@ -136,6 +143,7 @@ public class OnlineRecognizer {
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("encoder", "").trim(),
proMap.getOrDefault("decoder", "").trim(), proMap.getOrDefault("decoder", "").trim(),
proMap.getOrDefault("joiner", "").trim()); proMap.getOrDefault("joiner", "").trim());
OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig("");
OnlineModelConfig modelCfg = OnlineModelConfig modelCfg =
new OnlineModelConfig( new OnlineModelConfig(
@@ -144,7 +152,7 @@ public class OnlineRecognizer {
false, false,
proMap.getOrDefault("model_type", "zipformer").trim(), proMap.getOrDefault("model_type", "zipformer").trim(),
modelParaCfg, modelParaCfg,
modelTranCfg); modelTranCfg, zipformer2CtcConfig);
FeatureConfig featConfig = FeatureConfig featConfig =
new FeatureConfig( new FeatureConfig(
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim())); sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
@@ -201,8 +209,9 @@ public class OnlineRecognizer {
OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(encoder, decoder); OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(encoder, decoder);
OnlineTransducerModelConfig modelTranCfg = OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig(encoder, decoder, joiner); new OnlineTransducerModelConfig(encoder, decoder, joiner);
OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig("");
OnlineModelConfig modelCfg = OnlineModelConfig modelCfg =
new OnlineModelConfig(tokens, numThreads, false, modelType, modelParaCfg, modelTranCfg); new OnlineModelConfig(tokens, numThreads, false, modelType, modelParaCfg, modelTranCfg, zipformer2CtcConfig);
FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale); OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale);
OnlineRecognizerConfig rcgCfg = OnlineRecognizerConfig rcgCfg =
@@ -220,6 +229,32 @@ public class OnlineRecognizer {
this.ptr = createOnlineRecognizer(new Object(), rcgCfg); this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
} }
public static float[] readWavFile(String fileName) {
// read data from the filename
Object[] wavdata = readWave(fileName);
Object data = wavdata[0]; // data[0] is float data, data[1] sample rate
float[] floatData = (float[]) data;
return floatData;
}
// load the libsherpa-onnx-jni.so lib
public static void loadSoLib(String soPath) {
// load libsherpa-onnx-jni.so lib from the path
System.out.println("so lib path=" + soPath + "\n");
System.load(soPath.trim());
System.out.println("load so lib succeed\n");
}
public static void setSoPath(String soPath) {
OnlineRecognizer.loadSoLib(soPath);
OnlineStream.loadSoLib(soPath);
}
private static native Object[] readWave(String fileName); // static
private Map<String, String> readProperties(String modelCfgPath) { private Map<String, String> readProperties(String modelCfgPath) {
// read and parse config file // read and parse config file
Properties props = new Properties(); Properties props = new Properties();
@@ -302,30 +337,6 @@ public class OnlineRecognizer {
return stream; return stream;
} }
public static float[] readWavFile(String fileName) {
// read data from the filename
Object[] wavdata = readWave(fileName);
Object data = wavdata[0]; // data[0] is float data, data[1] sample rate
float[] floatData = (float[]) data;
return floatData;
}
// load the libsherpa-onnx-jni.so lib
public static void loadSoLib(String soPath) {
// load libsherpa-onnx-jni.so lib from the path
System.out.println("so lib path=" + soPath + "\n");
System.load(soPath.trim());
System.out.println("load so lib succeed\n");
}
public static void setSoPath(String soPath) {
OnlineRecognizer.loadSoLib(soPath);
OnlineStream.loadSoLib(soPath);
}
protected void finalize() throws Throwable { protected void finalize() throws Throwable {
release(); release();
} }
@@ -337,15 +348,13 @@ public class OnlineRecognizer {
this.ptr = 0; this.ptr = 0;
} }
// JNI interface libsherpa-onnx-jni.so
// stream release, you'd better call it manually if not use anymore // stream release, you'd better call it manually if not use anymore
public void releaseStream(OnlineStream s) { public void releaseStream(OnlineStream s) {
s.release(); s.release();
} }
// JNI interface libsherpa-onnx-jni.so
private static native Object[] readWave(String fileName); // static
private native String getResult(long ptr, long streamPtr); private native String getResult(long ptr, long streamPtr);
private native void decodeStream(long ptr, long streamPtr); private native void decodeStream(long ptr, long streamPtr);

View File

@@ -4,19 +4,22 @@
// Stream is used for feeding data to the asr engine // Stream is used for feeding data to the asr engine
package com.k2fsa.sherpa.onnx; package com.k2fsa.sherpa.onnx;
import java.io.*;
import java.util.*;
public class OnlineStream { public class OnlineStream {
private long ptr = 0; // this is the stream ptr private long ptr = 0; // this is the stream ptr
private int sampleRate = 16000; private int sampleRate = 16000;
// assign ptr to this stream in construction // assign ptr to this stream in construction
public OnlineStream(long ptr, int sampleRate) { public OnlineStream(long ptr, int sampleRate) {
this.ptr = ptr; this.ptr = ptr;
this.sampleRate = sampleRate; this.sampleRate = sampleRate;
} }
public static void loadSoLib(String soPath) {
// load .so lib from the path
System.load(soPath.trim()); // ("sherpa-onnx-jni-java");
}
public long getPtr() { public long getPtr() {
return ptr; return ptr;
} }
@@ -31,18 +34,13 @@ public class OnlineStream {
public void inputFinished() { public void inputFinished() {
// add some tail padding // add some tail padding
int padLen = (int) (this.sampleRate * 0.3); // 0.3 seconds at 16 kHz sample rate int padLen = (int) (this.sampleRate * 0.3); // 0.3 seconds at 16 kHz sample rate
float tailPaddings[] = new float[padLen]; // default value is 0 float[] tailPaddings = new float[padLen]; // default value is 0
acceptWaveform(this.ptr, this.sampleRate, tailPaddings); acceptWaveform(this.ptr, this.sampleRate, tailPaddings);
// tell the engine all data are feeded // tell the engine all data are feeded
inputFinished(this.ptr); inputFinished(this.ptr);
} }
public static void loadSoLib(String soPath) {
// load .so lib from the path
System.load(soPath.trim()); // ("sherpa-onnx-jni-java");
}
public void release() { public void release() {
// stream object must be release after used // stream object must be release after used
if (this.ptr == 0) return; if (this.ptr == 0) return;

View File

@@ -0,0 +1,14 @@
package com.k2fsa.sherpa.onnx;
public class OnlineZipformer2CtcModelConfig {
private final String model;
public OnlineZipformer2CtcModelConfig(String model) {
this.model = model;
}
public String getModel() {
return model;
}
}

View File

@@ -1522,8 +1522,8 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_delete(
SHERPA_ONNX_EXTERN_C SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset( JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
JNIEnv *env, jobject /*obj*/, JNIEnv *env, jobject /*obj*/, jlong ptr, jboolean recreate,
jlong ptr, jboolean recreate, jstring keywords) { jstring keywords) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr); auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
const char *p_keywords = env->GetStringUTFChars(keywords, nullptr); const char *p_keywords = env->GetStringUTFChars(keywords, nullptr);
model->Reset(recreate, p_keywords); model->Reset(recreate, p_keywords);