Add java websocket support (#137)
* add decode example for mic * some changes to README.md * add java websocket srv * change to readwav to static * make some changes to code comments * little change for readme.md * fix bug about multiple threads * made little modification * add protocol in readme, removed static Queue and add lmConfig --------- Co-authored-by: root <root@localhost.localdomain>
This commit is contained in:
@@ -7,11 +7,21 @@ LIB_FILES = \
|
|||||||
$(LIB_SRC_DIR)/EndpointRule.java \
|
$(LIB_SRC_DIR)/EndpointRule.java \
|
||||||
$(LIB_SRC_DIR)/EndpointConfig.java \
|
$(LIB_SRC_DIR)/EndpointConfig.java \
|
||||||
$(LIB_SRC_DIR)/FeatureConfig.java \
|
$(LIB_SRC_DIR)/FeatureConfig.java \
|
||||||
|
$(LIB_SRC_DIR)/OnlineLMConfig.java \
|
||||||
$(LIB_SRC_DIR)/OnlineTransducerModelConfig.java \
|
$(LIB_SRC_DIR)/OnlineTransducerModelConfig.java \
|
||||||
$(LIB_SRC_DIR)/OnlineRecognizerConfig.java \
|
$(LIB_SRC_DIR)/OnlineRecognizerConfig.java \
|
||||||
$(LIB_SRC_DIR)/OnlineStream.java \
|
$(LIB_SRC_DIR)/OnlineStream.java \
|
||||||
$(LIB_SRC_DIR)/OnlineRecognizer.java \
|
$(LIB_SRC_DIR)/OnlineRecognizer.java \
|
||||||
|
|
||||||
|
WEBSOCKET_DIR:= ./src/websocketsrv
|
||||||
|
WEBSOCKET_FILES = \
|
||||||
|
$(WEBSOCKET_DIR)/ConnectionData.java \
|
||||||
|
$(WEBSOCKET_DIR)/DecoderThreadHandler.java \
|
||||||
|
$(WEBSOCKET_DIR)/StreamThreadHandler.java \
|
||||||
|
$(WEBSOCKET_DIR)/AsrWebsocketServer.java \
|
||||||
|
$(WEBSOCKET_DIR)/AsrWebsocketClient.java \
|
||||||
|
|
||||||
|
|
||||||
LIB_BUILD_DIR = ./lib
|
LIB_BUILD_DIR = ./lib
|
||||||
|
|
||||||
|
|
||||||
@@ -39,7 +49,13 @@ buildmic:
|
|||||||
|
|
||||||
rebuild: clean all
|
rebuild: clean all
|
||||||
|
|
||||||
.PHONY: clean run
|
.PHONY: clean run downjar
|
||||||
|
|
||||||
|
downjar:
|
||||||
|
wget https://repo1.maven.org/maven2/org/slf4j/slf4j-api/1.7.25/slf4j-api-1.7.25.jar -P ./lib/
|
||||||
|
wget https://repo1.maven.org/maven2/org/slf4j/slf4j-simple/1.7.25/slf4j-simple-1.7.25.jar -P ./lib/
|
||||||
|
wget https://github.com/TooTallNate/Java-WebSocket/releases/download/v1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/
|
||||||
|
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
rm -frv $(BUILD_DIR)/*
|
rm -frv $(BUILD_DIR)/*
|
||||||
@@ -56,6 +72,12 @@ runmic:
|
|||||||
|
|
||||||
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic
|
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic
|
||||||
|
|
||||||
|
runsrv:
|
||||||
|
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer /sherpa-onnx/20230515/zhaoming/sherpa-onnx/build/lib/libsherpa-onnx-jni.so ./modelconfig.cfg
|
||||||
|
|
||||||
|
runclient:
|
||||||
|
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient /sherpa-onnx/20230515/zhaoming/sherpa-onnx/build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32
|
||||||
|
|
||||||
buildlib: $(LIB_FILES:.java=.class)
|
buildlib: $(LIB_FILES:.java=.class)
|
||||||
|
|
||||||
|
|
||||||
@@ -63,10 +85,19 @@ buildlib: $(LIB_FILES:.java=.class)
|
|||||||
|
|
||||||
$(JAVAC) -cp $(BUILD_DIR) -d $(BUILD_DIR) -encoding UTF-8 $<
|
$(JAVAC) -cp $(BUILD_DIR) -d $(BUILD_DIR) -encoding UTF-8 $<
|
||||||
|
|
||||||
|
buildwebsocket: $(WEBSOCKET_FILES:.java=.class)
|
||||||
|
|
||||||
|
|
||||||
|
%.class: %.java
|
||||||
|
|
||||||
|
$(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:../lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 $<
|
||||||
|
|
||||||
packjar:
|
packjar:
|
||||||
jar cvfe lib/sherpaonnx.jar . -C $(BUILD_DIR) .
|
jar cvfe lib/sherpaonnx.jar . -C $(BUILD_DIR) .
|
||||||
|
|
||||||
all: clean buildlib packjar buildfile buildmic
|
all: clean buildlib packjar buildfile buildmic downjar buildwebsocket
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
--------------
|
--------------
|
||||||
|
|
||||||
Java wrapper `com.k2fsa.sherpa.onnx.OnlineRecognizer` for `sherpa-onnx`. Java is a cross-platform language; you can build jni .so lib according to your system, and then use the same java api for all your platform.
|
Java wrapper `com.k2fsa.sherpa.onnx.OnlineRecognizer` for `sherpa-onnx`. Java is a cross-platform language; you can build jni .so lib according to your system, and then use the same java api for all your platform.
|
||||||
|
A multi-threaded websocket server is now supported.
|
||||||
|
|
||||||
```xml
|
```xml
|
||||||
Depend on:
|
Depend on:
|
||||||
@@ -35,7 +36,7 @@ Example for Ubuntu 18.04 LTS, Openjdk 1.8.0_362:
|
|||||||
|
|
||||||
3.Config model config.cfg
|
3.Config model config.cfg
|
||||||
-------------------------
|
-------------------------
|
||||||
|
/**change model path in config.cfg according to your env**/
|
||||||
```xml
|
```xml
|
||||||
#model config
|
#model config
|
||||||
sample_rate=16000
|
sample_rate=16000
|
||||||
@@ -51,6 +52,21 @@ Example for Ubuntu 18.04 LTS, Openjdk 1.8.0_362:
|
|||||||
enable_endpoint_detection=false
|
enable_endpoint_detection=false
|
||||||
decoding_method=greedy_search
|
decoding_method=greedy_search
|
||||||
max_active_paths=4
|
max_active_paths=4
|
||||||
|
|
||||||
|
#websocket server config
|
||||||
|
port=8890
|
||||||
|
#number of threads in the pool for network io
|
||||||
|
connection_thread_num=16
|
||||||
|
#number of threads in the pool for streams
|
||||||
|
stream_thread_num=16
|
||||||
|
#number of threads in the pool for the decoder
|
||||||
|
decoder_thread_num=16
|
||||||
|
#size of streams for parallel decoding
|
||||||
|
parallel_decoder_num=16
|
||||||
|
#idle time(ms) for a decoder thread when there is no job
|
||||||
|
decoder_time_idle=10
|
||||||
|
#timeout(ms) for connection data
|
||||||
|
deocder_time_out=3000
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -114,5 +130,58 @@ Build package path: /sherpa-onnx/java-api-examples/lib/sherpaonnx.jar
|
|||||||
make runmic
|
make runmic
|
||||||
```
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
6.WebSocket Server
|
||||||
|
----------
|
||||||
|
|
||||||
|
The websocket server supports multiple threads.
|
||||||
|
6.0 Protocol for communication
|
||||||
|
1) The client connects to the server
|
||||||
|
```shell
|
||||||
|
ws client -> srv ws address
|
||||||
|
ws address example: ws://127.0.0.1:8890/
|
||||||
|
```
|
||||||
|
2) client send 16k pcm_s16le binary stream data to server
|
||||||
|
```shell
|
||||||
|
PCM sampleRate 16000
|
||||||
|
single channel
|
||||||
|
sampleSize 16bit
|
||||||
|
little endian
|
||||||
|
type short
|
||||||
|
```
|
||||||
|
3) The client sends the text "Done" to the server once all data has been sent
|
||||||
|
```shell
|
||||||
|
ws_socket.send("Done")
|
||||||
|
```
|
||||||
|
4) The client receives a json message from the server whenever the asr engine has decoded new text
|
||||||
|
```shell
|
||||||
|
json example: {"text":"甚至出现交易几乎停滞的情况","eof":false}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
6.1 Build
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd sherpa-onnx/java-api-examples
|
||||||
|
make all
|
||||||
|
```
|
||||||
|
|
||||||
|
6.2 Run srv example
|
||||||
|
|
||||||
|
usage: AsrWebsocketServer soPath modelCfgPath
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make runsrv /**change path in Makefile according to your env**/
|
||||||
|
```
|
||||||
|
|
||||||
|
6.3 Run multiple threads client example
|
||||||
|
|
||||||
|
usage: AsrWebsocketClient soPath srvIp srvPort wavPath numThreads
|
||||||
|
|
||||||
|
json result example: {"text":"甚至出现交易几乎停滞的情况","eof":true}
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make runclient /**change path in Makefile according to your env**/
|
||||||
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,17 @@ decoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-
|
|||||||
joiner=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx
|
joiner=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx
|
||||||
tokens=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
|
tokens=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
|
||||||
num_threads=4
|
num_threads=4
|
||||||
enable_endpoint_detection=false
|
enable_endpoint_detection=true
|
||||||
decoding_method=modified_beam_search
|
decoding_method=modified_beam_search
|
||||||
max_active_paths=4
|
max_active_paths=4
|
||||||
|
lm_model=
|
||||||
|
lm_scale=0.5
|
||||||
|
|
||||||
|
#websocket server config
|
||||||
|
port=8890
|
||||||
|
connection_thread_num=16
|
||||||
|
stream_thread_num=16
|
||||||
|
decoder_thread_num=16
|
||||||
|
parallel_decoder_num=16
|
||||||
|
decoder_time_idle=200
|
||||||
|
deocder_time_out=30000
|
||||||
|
|||||||
@@ -49,7 +49,8 @@ public class DecodeFile {
|
|||||||
float rule3MinUtteranceLength = 20F;
|
float rule3MinUtteranceLength = 20F;
|
||||||
String decodingMethod = "greedy_search";
|
String decodingMethod = "greedy_search";
|
||||||
int maxActivePaths = 4;
|
int maxActivePaths = 4;
|
||||||
|
String lm_model="";
|
||||||
|
float lm_scale=0.5F;
|
||||||
rcgOjb =
|
rcgOjb =
|
||||||
new OnlineRecognizer(
|
new OnlineRecognizer(
|
||||||
tokens,
|
tokens,
|
||||||
@@ -64,6 +65,8 @@ public class DecodeFile {
|
|||||||
rule2MinTrailingSilence,
|
rule2MinTrailingSilence,
|
||||||
rule3MinUtteranceLength,
|
rule3MinUtteranceLength,
|
||||||
decodingMethod,
|
decodingMethod,
|
||||||
|
lm_model,
|
||||||
|
lm_scale,
|
||||||
maxActivePaths);
|
maxActivePaths);
|
||||||
streamObj = rcgOjb.createStream();
|
streamObj = rcgOjb.createStream();
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
|
|||||||
131
java-api-examples/src/websocketsrv/AsrWebsocketClient.java
Executable file
131
java-api-examples/src/websocketsrv/AsrWebsocketClient.java
Executable file
@@ -0,0 +1,131 @@
|
|||||||
|
/*
|
||||||
|
* // Copyright 2022-2023 by zhaomingwork
|
||||||
|
*/
|
||||||
|
// java AsrWebsocketClient
|
||||||
|
// usage: AsrWebsocketClient soPath srvIp srvPort wavPath numThreads
|
||||||
|
package websocketsrv;
|
||||||
|
|
||||||
|
import com.k2fsa.sherpa.onnx.OnlineRecognizer;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.net.URISyntaxException;
|
||||||
|
import java.nio.*;
|
||||||
|
import java.util.Map;
|
||||||
|
import org.java_websocket.client.WebSocketClient;
|
||||||
|
import org.java_websocket.drafts.Draft;
|
||||||
|
import org.java_websocket.handshake.ServerHandshake;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
/** This example demonstrates how to connect to websocket server. */
|
||||||
|
public class AsrWebsocketClient extends WebSocketClient {
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(AsrWebsocketClient.class);
|
||||||
|
|
||||||
|
public AsrWebsocketClient(URI serverUri, Draft draft) {
|
||||||
|
super(serverUri, draft);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AsrWebsocketClient(URI serverURI) {
|
||||||
|
super(serverURI);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AsrWebsocketClient(URI serverUri, Map<String, String> httpHeaders) {
|
||||||
|
super(serverUri, httpHeaders);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onOpen(ServerHandshake handshakedata) {
|
||||||
|
|
||||||
|
float[] floats = OnlineRecognizer.readWavFile(AsrWebsocketClient.wavPath);
|
||||||
|
ByteBuffer buffer =
|
||||||
|
ByteBuffer.allocate(4 * floats.length)
|
||||||
|
.order(ByteOrder.LITTLE_ENDIAN); // float is sizeof 4. allocate enough buffer
|
||||||
|
|
||||||
|
for (float f : floats) {
|
||||||
|
buffer.putFloat(f);
|
||||||
|
}
|
||||||
|
buffer.rewind();
|
||||||
|
buffer.flip();
|
||||||
|
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||||
|
|
||||||
|
send(buffer.array()); // send buf to server
|
||||||
|
send("Done"); // send 'Done' means finished
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onMessage(String message) {
|
||||||
|
|
||||||
|
logger.info("received: " + message);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onClose(int code, String reason, boolean remote) {
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Connection closed by "
|
||||||
|
+ (remote ? "remote peer" : "us")
|
||||||
|
+ " Code: "
|
||||||
|
+ code
|
||||||
|
+ " Reason: "
|
||||||
|
+ reason);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onError(Exception ex) {
|
||||||
|
ex.printStackTrace();
|
||||||
|
// if the error is fatal then onClose will be called additionally
|
||||||
|
}
|
||||||
|
|
||||||
|
public static OnlineRecognizer rcgobj;
|
||||||
|
public static String wavPath;
|
||||||
|
|
||||||
|
public static void main(String[] args) throws URISyntaxException {
|
||||||
|
|
||||||
|
if (args.length != 5) {
|
||||||
|
System.out.println("usage: AsrWebsocketClient soPath srvIp srvPort wavPath numThreads");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
String soPath = args[0];
|
||||||
|
String srvIp = args[1];
|
||||||
|
String srvPort = args[2];
|
||||||
|
String wavPath = args[3];
|
||||||
|
int numThreads = Integer.parseInt(args[4]);
|
||||||
|
System.out.println("serIp=" + srvIp + ",srvPort=" + srvPort + ",wavPath=" + wavPath);
|
||||||
|
|
||||||
|
class ClientThread implements Runnable {
|
||||||
|
|
||||||
|
String soPath;
|
||||||
|
String srvIp;
|
||||||
|
String srvPort;
|
||||||
|
String wavPath;
|
||||||
|
|
||||||
|
ClientThread(String soPath, String srvIp, String srvPort, String wavPath) {
|
||||||
|
this.soPath = soPath;
|
||||||
|
this.srvIp = srvIp;
|
||||||
|
this.srvPort = srvPort;
|
||||||
|
this.wavPath = wavPath;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void run() {
|
||||||
|
try {
|
||||||
|
|
||||||
|
OnlineRecognizer.setSoPath(soPath);
|
||||||
|
|
||||||
|
AsrWebsocketClient.wavPath = wavPath;
|
||||||
|
|
||||||
|
String wsAddress = "ws://" + srvIp + ":" + srvPort;
|
||||||
|
AsrWebsocketClient c = new AsrWebsocketClient(new URI(wsAddress));
|
||||||
|
|
||||||
|
c.connect();
|
||||||
|
} catch (Exception e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < numThreads; i++) {
|
||||||
|
System.out.println("Thread1 is running...");
|
||||||
|
Thread t = new Thread(new ClientThread(soPath, srvIp, srvPort, wavPath));
|
||||||
|
t.start();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
248
java-api-examples/src/websocketsrv/AsrWebsocketServer.java
Executable file
248
java-api-examples/src/websocketsrv/AsrWebsocketServer.java
Executable file
@@ -0,0 +1,248 @@
|
|||||||
|
/*
|
||||||
|
* // Copyright 2022-2023 by zhaoming
|
||||||
|
*/
|
||||||
|
// java websocketServer
|
||||||
|
// usage: AsrWebsocketServer soPath modelCfgPath
|
||||||
|
package websocketsrv;
|
||||||
|
|
||||||
|
import com.k2fsa.sherpa.onnx.OnlineRecognizer;
|
||||||
|
import com.k2fsa.sherpa.onnx.OnlineStream;
|
||||||
|
import java.io.*;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.net.InetSocketAddress;
|
||||||
|
import java.net.UnknownHostException;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.ByteOrder;
|
||||||
|
import java.nio.FloatBuffer;
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.concurrent.*;
|
||||||
|
import java.util.concurrent.LinkedBlockingQueue;
|
||||||
|
import org.java_websocket.WebSocket;
|
||||||
|
import org.java_websocket.drafts.Draft;
|
||||||
|
import org.java_websocket.drafts.Draft_6455;
|
||||||
|
import org.java_websocket.handshake.ClientHandshake;
|
||||||
|
import org.java_websocket.server.WebSocketServer;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AsrWebSocketServer has three threads pools, one pool for network io, one pool for asr stream and
|
||||||
|
* one pool for asr decoder.
|
||||||
|
*/
|
||||||
|
public class AsrWebsocketServer extends WebSocketServer {
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(AsrWebsocketServer.class);
|
||||||
|
// Queue between io network io thread pool and stream thread pool, use websocket as the key
|
||||||
|
private LinkedBlockingQueue<WebSocket> streamQueue = new LinkedBlockingQueue<WebSocket>();
|
||||||
|
// Queue waiting for deocdeing, use websocket as the key
|
||||||
|
private LinkedBlockingQueue<WebSocket> decoderQueue = new LinkedBlockingQueue<WebSocket>();
|
||||||
|
|
||||||
|
// recogizer object
|
||||||
|
private OnlineRecognizer rcgOjb = null;
|
||||||
|
|
||||||
|
// mapping between websocket connection and connection data
|
||||||
|
private ConcurrentHashMap<WebSocket, ConnectionData> connectionMap =
|
||||||
|
new ConcurrentHashMap<WebSocket, ConnectionData>();
|
||||||
|
|
||||||
|
public AsrWebsocketServer(int port, int numThread) throws UnknownHostException {
|
||||||
|
// server port and num of threads for network io
|
||||||
|
super(new InetSocketAddress(port), numThread);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AsrWebsocketServer(InetSocketAddress address) {
|
||||||
|
super(address);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AsrWebsocketServer(int port, Draft_6455 draft) {
|
||||||
|
super(new InetSocketAddress(port), Collections.<Draft>singletonList(draft));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onOpen(WebSocket conn, ClientHandshake handshake) {}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onClose(WebSocket conn, int code, String reason, boolean remote) {
|
||||||
|
connectionMap.remove(conn);
|
||||||
|
logger.info(
|
||||||
|
conn
|
||||||
|
+ " remove one connection!, now connection number="
|
||||||
|
+ String.valueOf(connectionMap.size()));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onMessage(WebSocket conn, String message) {
|
||||||
|
// this is text message
|
||||||
|
try {
|
||||||
|
// if rec "Done" msg from client
|
||||||
|
if (message.equals("Done")) {
|
||||||
|
ConnectionData connData = creatOrGetConnectionData(conn);
|
||||||
|
connData.setEof(true);
|
||||||
|
if (!streamQueueFind(conn)) {
|
||||||
|
streamQueue.put(conn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private ConnectionData creatOrGetConnectionData(WebSocket conn) {
|
||||||
|
// create a new connection data if not in connection map or return the existed one
|
||||||
|
|
||||||
|
ConnectionData connData = null;
|
||||||
|
try {
|
||||||
|
if (!connectionMap.containsKey(conn)) {
|
||||||
|
OnlineStream stream = rcgOjb.createStream();
|
||||||
|
connData = new ConnectionData(conn, stream);
|
||||||
|
connectionMap.put(conn, connData);
|
||||||
|
} else {
|
||||||
|
connData = connectionMap.get(conn);
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
conn.getRemoteSocketAddress().getAddress().getHostAddress()
|
||||||
|
+ " open one connection,, now connection number="
|
||||||
|
+ String.valueOf(connectionMap.size()));
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
System.err.println(e);
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
return connData;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onMessage(WebSocket conn, ByteBuffer blob) {
|
||||||
|
try {
|
||||||
|
|
||||||
|
// for handle binary data
|
||||||
|
blob.order(ByteOrder.LITTLE_ENDIAN); // set little endian
|
||||||
|
|
||||||
|
// set to float
|
||||||
|
FloatBuffer floatbuf = blob.asFloatBuffer();
|
||||||
|
|
||||||
|
if (floatbuf.capacity() > 0) {
|
||||||
|
// allocate memory for float data
|
||||||
|
float[] arr = new float[floatbuf.capacity()];
|
||||||
|
|
||||||
|
floatbuf.get(arr);
|
||||||
|
ConnectionData connData = creatOrGetConnectionData(conn);
|
||||||
|
// put websocket to stream queue with binary type==1
|
||||||
|
connData.addSamplesToData(arr);
|
||||||
|
|
||||||
|
if (!streamQueueFind(conn)) {
|
||||||
|
streamQueue.put(conn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
public boolean streamQueueFind(WebSocket conn) {
|
||||||
|
return streamQueue.contains(conn);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void initModelWithCfg(Map<String, String> cfgMap, String cfgPath) {
|
||||||
|
try {
|
||||||
|
|
||||||
|
rcgOjb = new OnlineRecognizer(cfgPath);
|
||||||
|
// size of stream thread pool
|
||||||
|
int streamThreadNum = Integer.valueOf(cfgMap.get("stream_thread_num"));
|
||||||
|
// size of decoder thread pool
|
||||||
|
int decoderThreadNum = Integer.valueOf(cfgMap.get("decoder_thread_num"));
|
||||||
|
|
||||||
|
// time(ms) idle for decoder thread when no job
|
||||||
|
int decoderTimeIdle = Integer.valueOf(cfgMap.get("decoder_time_idle"));
|
||||||
|
// size of streams for parallel decoding
|
||||||
|
int parallelDecoderNum = Integer.valueOf(cfgMap.get("parallel_decoder_num"));
|
||||||
|
// time(ms) out for connection data
|
||||||
|
int deocderTimeOut = Integer.valueOf(cfgMap.get("deocder_time_out"));
|
||||||
|
|
||||||
|
// create stream threads
|
||||||
|
for (int i = 0; i < streamThreadNum; i++) {
|
||||||
|
new StreamThreadHandler(streamQueue, decoderQueue, connectionMap).start();
|
||||||
|
}
|
||||||
|
// create decoder threads
|
||||||
|
for (int i = 0; i < decoderThreadNum; i++) {
|
||||||
|
new DecoderThreadHandler(
|
||||||
|
decoderQueue,
|
||||||
|
connectionMap,
|
||||||
|
rcgOjb,
|
||||||
|
decoderTimeIdle,
|
||||||
|
parallelDecoderNum,
|
||||||
|
deocderTimeOut)
|
||||||
|
.start();
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
System.err.println(e);
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Map<String, String> readProperties(String CfgPath) {
|
||||||
|
// read and parse config file
|
||||||
|
Properties props = new Properties();
|
||||||
|
Map<String, String> proMap = new HashMap<String, String>();
|
||||||
|
try {
|
||||||
|
|
||||||
|
File file = new File(CfgPath);
|
||||||
|
if (!file.exists()) {
|
||||||
|
logger.info(String.valueOf(CfgPath) + " cfg file not exists!");
|
||||||
|
System.exit(0);
|
||||||
|
}
|
||||||
|
InputStream in = new BufferedInputStream(new FileInputStream(CfgPath));
|
||||||
|
props.load(in);
|
||||||
|
Enumeration en = props.propertyNames();
|
||||||
|
while (en.hasMoreElements()) {
|
||||||
|
String key = (String) en.nextElement();
|
||||||
|
String Property = props.getProperty(key);
|
||||||
|
proMap.put(key, Property);
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
return proMap;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void main(String[] args) throws InterruptedException, IOException {
|
||||||
|
if (args.length != 2) {
|
||||||
|
logger.info("usage: AsrWebsocketServer soPath modelCfgPath");
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
String soPath = args[0];
|
||||||
|
String cfgPath = args[1];
|
||||||
|
|
||||||
|
OnlineRecognizer.setSoPath(soPath);
|
||||||
|
|
||||||
|
Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath);
|
||||||
|
int port = Integer.valueOf(cfgMap.get("port"));
|
||||||
|
|
||||||
|
int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num"));
|
||||||
|
AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum);
|
||||||
|
s.initModelWithCfg(cfgMap, cfgPath);
|
||||||
|
logger.info("Server started on port: " + s.getPort());
|
||||||
|
s.start();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onError(WebSocket conn, Exception ex) {
|
||||||
|
ex.printStackTrace();
|
||||||
|
if (conn != null) {
|
||||||
|
// some errors like port binding failed may not be assignable to a specific websocket
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onStart() {
|
||||||
|
logger.info("Server started!");
|
||||||
|
setConnectionLostTimeout(0);
|
||||||
|
setConnectionLostTimeout(100);
|
||||||
|
}
|
||||||
|
}
|
||||||
65
java-api-examples/src/websocketsrv/ConnectionData.java
Executable file
65
java-api-examples/src/websocketsrv/ConnectionData.java
Executable file
@@ -0,0 +1,65 @@
|
|||||||
|
/*
|
||||||
|
* // Copyright 2022-2023 by zhaoming
|
||||||
|
*/
|
||||||
|
// connection data act as a bridge between different threads pools
|
||||||
|
|
||||||
|
package websocketsrv;
|
||||||
|
|
||||||
|
import com.k2fsa.sherpa.onnx.OnlineStream;
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
import java.util.LinkedList;
|
||||||
|
import java.util.Queue;
|
||||||
|
import java.util.concurrent.*;
|
||||||
|
import org.java_websocket.WebSocket;
|
||||||
|
|
||||||
|
public class ConnectionData {
|
||||||
|
|
||||||
|
private WebSocket webSocket; // the websocket for this connection data
|
||||||
|
|
||||||
|
private OnlineStream stream; // connection stream
|
||||||
|
|
||||||
|
private Queue<float[]> queueSamples =
|
||||||
|
new LinkedList<float[]>(); // binary data rec from the client
|
||||||
|
|
||||||
|
private boolean eof = false; // connection data is done
|
||||||
|
|
||||||
|
private LocalDateTime lastHandleTime; // used for time out in ms
|
||||||
|
|
||||||
|
public ConnectionData(WebSocket webSocket, OnlineStream stream) {
|
||||||
|
this.webSocket = webSocket;
|
||||||
|
|
||||||
|
this.stream = stream;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addSamplesToData(float[] samples) {
|
||||||
|
this.queueSamples.add(samples);
|
||||||
|
}
|
||||||
|
|
||||||
|
public LocalDateTime getLastHandleTime() {
|
||||||
|
return this.lastHandleTime;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setLastHandleTime(LocalDateTime now) {
|
||||||
|
this.lastHandleTime = now;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean getEof() {
|
||||||
|
return this.eof;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setEof(boolean eof) {
|
||||||
|
this.eof = eof;
|
||||||
|
}
|
||||||
|
|
||||||
|
public WebSocket getWebSocket() {
|
||||||
|
return this.webSocket;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Queue<float[]> getQueueSamples() {
|
||||||
|
return this.queueSamples;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OnlineStream getStream() {
|
||||||
|
return this.stream;
|
||||||
|
}
|
||||||
|
}
|
||||||
173
java-api-examples/src/websocketsrv/DecoderThreadHandler.java
Executable file
173
java-api-examples/src/websocketsrv/DecoderThreadHandler.java
Executable file
@@ -0,0 +1,173 @@
|
|||||||
|
/*
|
||||||
|
* // Copyright 2022-2023 by zhaoming
|
||||||
|
*/
|
||||||
|
// java DecoderThreadHandler
|
||||||
|
package websocketsrv;
|
||||||
|
|
||||||
|
import com.k2fsa.sherpa.onnx.OnlineRecognizer;
|
||||||
|
import com.k2fsa.sherpa.onnx.OnlineStream;
|
||||||
|
import java.nio.*;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.*;
|
||||||
|
import java.util.concurrent.LinkedBlockingQueue;
|
||||||
|
import org.java_websocket.WebSocket;
|
||||||
|
import org.java_websocket.drafts.Draft;
|
||||||
|
import org.java_websocket.framing.Framedata;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
public class DecoderThreadHandler extends Thread {
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(DecoderThreadHandler.class);
|
||||||
|
// Websocket Queue that waiting for decoding
|
||||||
|
private LinkedBlockingQueue<WebSocket> decoderQueue;
|
||||||
|
// the mapping between websocket and connection data
|
||||||
|
private ConcurrentHashMap<WebSocket, ConnectionData> connMap;
|
||||||
|
|
||||||
|
private OnlineRecognizer rcgOjb = null; // recgnizer object
|
||||||
|
|
||||||
|
// connection data list for this thread to decode in parallel
|
||||||
|
private List<ConnectionData> connDataList = new ArrayList<ConnectionData>();
|
||||||
|
|
||||||
|
private int parallelDecoderNum = 10; // parallel decoding number
|
||||||
|
private int deocderTimeIdle = 10; // idle time(ms) when no job
|
||||||
|
private int deocderTimeOut = 3000; // if it is timeout(ms), the connection data will be removed
|
||||||
|
|
||||||
|
public DecoderThreadHandler(
|
||||||
|
LinkedBlockingQueue<WebSocket> decoderQueue,
|
||||||
|
ConcurrentHashMap<WebSocket, ConnectionData> connMap,
|
||||||
|
OnlineRecognizer rcgOjb,
|
||||||
|
int deocderTimeIdle,
|
||||||
|
int parallelDecoderNum,
|
||||||
|
int deocderTimeOut) {
|
||||||
|
this.decoderQueue = decoderQueue;
|
||||||
|
this.connMap = connMap;
|
||||||
|
this.rcgOjb = rcgOjb;
|
||||||
|
this.deocderTimeIdle = deocderTimeIdle;
|
||||||
|
this.parallelDecoderNum = parallelDecoderNum;
|
||||||
|
this.deocderTimeOut = deocderTimeOut;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void run() {
|
||||||
|
while (true) {
|
||||||
|
try {
|
||||||
|
// time(ms) idle if there is no job
|
||||||
|
|
||||||
|
Thread.sleep(deocderTimeIdle);
|
||||||
|
// clear data list for this threads
|
||||||
|
connDataList.clear();
|
||||||
|
if (rcgOjb == null) continue;
|
||||||
|
|
||||||
|
// loop for total decoder Queue
|
||||||
|
while (!decoderQueue.isEmpty()) {
|
||||||
|
|
||||||
|
// get websocket
|
||||||
|
WebSocket conn = decoderQueue.take();
|
||||||
|
// get connection data according to websocket
|
||||||
|
ConnectionData connData = connMap.get(conn);
|
||||||
|
|
||||||
|
// if the websocket closed, continue
|
||||||
|
if (connData == null) continue;
|
||||||
|
// get the stream
|
||||||
|
OnlineStream stream = connData.getStream();
|
||||||
|
|
||||||
|
// put to decoder list if 1) stream is ready; 2) and
|
||||||
|
// size not > parallelDecoderNum
|
||||||
|
if ((rcgOjb.isReady(stream) && connDataList.size() < parallelDecoderNum)) {
|
||||||
|
|
||||||
|
// add to this thread's decoder list
|
||||||
|
connDataList.add(connData);
|
||||||
|
// change the handled time for this connection data
|
||||||
|
connData.setLastHandleTime(LocalDateTime.now());
|
||||||
|
}
|
||||||
|
// break when decoder list size >= parallelDecoderNum
|
||||||
|
if (connDataList.size() >= parallelDecoderNum) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if decoder data list for this thread >0
|
||||||
|
if (connDataList.size() > 0) {
|
||||||
|
|
||||||
|
// create a stream array for parallel decoding
|
||||||
|
OnlineStream[] arr = new OnlineStream[connDataList.size()];
|
||||||
|
for (int i = 0; i < connDataList.size(); i++) {
|
||||||
|
|
||||||
|
arr[i] = connDataList.get(i).getStream();
|
||||||
|
}
|
||||||
|
|
||||||
|
// parallel decoding
|
||||||
|
rcgOjb.decodeStreams(arr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// get result for each connection
|
||||||
|
for (ConnectionData connData : connDataList) {
|
||||||
|
|
||||||
|
OnlineStream stream = connData.getStream();
|
||||||
|
WebSocket webSocket = connData.getWebSocket();
|
||||||
|
|
||||||
|
String txtResult = rcgOjb.getResult(stream);
|
||||||
|
|
||||||
|
// decode text in utf-8
|
||||||
|
byte[] utf8Data = txtResult.getBytes(StandardCharsets.UTF_8);
|
||||||
|
|
||||||
|
boolean isEof = (connData.getEof() == true && !rcgOjb.isReady(stream));
|
||||||
|
// result
|
||||||
|
if (utf8Data.length > 0) {
|
||||||
|
|
||||||
|
String jsonResult =
|
||||||
|
"{\"text\":\"" + txtResult + "\",\"eof\":" + String.valueOf(isEof) + "\"}";
|
||||||
|
|
||||||
|
if (webSocket.isOpen()) {
|
||||||
|
// create a TEXT Frame for send back json result
|
||||||
|
Draft draft = webSocket.getDraft();
|
||||||
|
List<Framedata> frames = null;
|
||||||
|
frames = draft.createFrames(jsonResult, false);
|
||||||
|
// send to client
|
||||||
|
webSocket.sendFrame(frames);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// loop for each connection data in this thread
|
||||||
|
for (ConnectionData connData : connDataList) {
|
||||||
|
OnlineStream stream = connData.getStream();
|
||||||
|
WebSocket webSocket = connData.getWebSocket();
|
||||||
|
// if the stream is still ready, put it to decoder Queue again for next decoding
|
||||||
|
if (rcgOjb.isReady(stream)) {
|
||||||
|
decoderQueue.put(webSocket);
|
||||||
|
}
|
||||||
|
// the duration between last handled time and now
|
||||||
|
java.time.Duration duration =
|
||||||
|
java.time.Duration.between(connData.getLastHandleTime(), LocalDateTime.now());
|
||||||
|
// close the websocket if 1) data is done and stream not ready; 2) or data is time out;
|
||||||
|
// 3) or
|
||||||
|
// connection is closed
|
||||||
|
if ((connData.getEof() == true
|
||||||
|
&& !rcgOjb.isReady(stream)
|
||||||
|
&& connData.getQueueSamples().isEmpty())
|
||||||
|
|| duration.toMillis() > deocderTimeOut
|
||||||
|
|| !connData.getWebSocket().isOpen()) {
|
||||||
|
|
||||||
|
logger.info("close websocket!!!");
|
||||||
|
|
||||||
|
// delay close web socket as data may still in processing
|
||||||
|
Timer timer = new Timer();
|
||||||
|
timer.schedule(
|
||||||
|
new TimerTask() {
|
||||||
|
public void run() {
|
||||||
|
|
||||||
|
webSocket.close();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
5000); // 5 seconds
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
67
java-api-examples/src/websocketsrv/StreamThreadHandler.java
Executable file
67
java-api-examples/src/websocketsrv/StreamThreadHandler.java
Executable file
@@ -0,0 +1,67 @@
|
|||||||
|
/*
|
||||||
|
* // Copyright 2022-2023 by zhaoming
|
||||||
|
*/
|
||||||
|
// java StreamThreadHandler
|
||||||
|
package websocketsrv;
|
||||||
|
|
||||||
|
import com.k2fsa.sherpa.onnx.OnlineStream;
|
||||||
|
import java.nio.*;
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.concurrent.*;
|
||||||
|
import java.util.concurrent.LinkedBlockingQueue;
|
||||||
|
import org.java_websocket.WebSocket;
|
||||||
|
// thread for processing stream
|
||||||
|
|
||||||
|
public class StreamThreadHandler extends Thread {
|
||||||
|
// Queue between io network io thread pool and stream thread pool, use websocket as the key
|
||||||
|
private LinkedBlockingQueue<WebSocket> streamQueue;
|
||||||
|
// Queue waiting for deocdeing, use websocket as the key
|
||||||
|
private LinkedBlockingQueue<WebSocket> decoderQueue;
|
||||||
|
// mapping between websocket connection and connection data
|
||||||
|
private ConcurrentHashMap<WebSocket, ConnectionData> connMap;
|
||||||
|
|
||||||
|
public StreamThreadHandler(
|
||||||
|
LinkedBlockingQueue<WebSocket> streamQueue,
|
||||||
|
LinkedBlockingQueue<WebSocket> decoderQueue,
|
||||||
|
ConcurrentHashMap<WebSocket, ConnectionData> connMap) {
|
||||||
|
this.streamQueue = streamQueue;
|
||||||
|
this.decoderQueue = decoderQueue;
|
||||||
|
this.connMap = connMap;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void run() {
|
||||||
|
while (true) {
|
||||||
|
try {
|
||||||
|
// fetch one websocket from queue
|
||||||
|
WebSocket conn = (WebSocket) this.streamQueue.take();
|
||||||
|
// get the connection data according to websocket
|
||||||
|
ConnectionData connData = connMap.get(conn);
|
||||||
|
OnlineStream stream = connData.getStream();
|
||||||
|
|
||||||
|
// handle received binary data
|
||||||
|
if (!connData.getQueueSamples().isEmpty()) {
|
||||||
|
// loop to put all received binary data to stream
|
||||||
|
while (!connData.getQueueSamples().isEmpty()) {
|
||||||
|
|
||||||
|
float[] samples = connData.getQueueSamples().poll();
|
||||||
|
|
||||||
|
stream.acceptWaveform(samples);
|
||||||
|
}
|
||||||
|
// if data is finished
|
||||||
|
if (connData.getEof() == true) {
|
||||||
|
|
||||||
|
stream.inputFinished();
|
||||||
|
}
|
||||||
|
// add this websocket to decoder Queue if not in the Queue
|
||||||
|
if (!decoderQueue.contains(conn)) {
|
||||||
|
|
||||||
|
decoderQueue.put(conn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
/*
|
||||||
|
* // Copyright 2022-2023 by zhaoming
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.k2fsa.sherpa.onnx;
|
||||||
|
|
||||||
|
/**
 * Immutable holder for the language-model settings of an online recognizer: the model value and
 * the scale value supplied at construction time.
 */
public class OnlineLMConfig {
  // NOTE(review): field names may be looked up by name from native code - confirm before renaming
  private final String model;
  private final float scale;

  /**
   * @param model the language-model value to store (returned unchanged by {@link #getModel()})
   * @param scale the scale value to store (returned unchanged by {@link #getScale()})
   */
  public OnlineLMConfig(String model, float scale) {
    this.scale = scale;
    this.model = model;
  }

  /** @return the scale passed to the constructor */
  public float getScale() {
    return this.scale;
  }

  /** @return the model passed to the constructor */
  public String getModel() {
    return this.model;
  }
}
|
||||||
@@ -65,11 +65,14 @@ public class OnlineRecognizer {
|
|||||||
false);
|
false);
|
||||||
FeatureConfig featConfig =
|
FeatureConfig featConfig =
|
||||||
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
||||||
OnlineRecognizerConfig rcgCfg =
|
OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim()));
|
||||||
|
|
||||||
|
OnlineRecognizerConfig rcgCfg =
|
||||||
new OnlineRecognizerConfig(
|
new OnlineRecognizerConfig(
|
||||||
featConfig,
|
featConfig,
|
||||||
modelCfg,
|
modelCfg,
|
||||||
endCfg,
|
endCfg,
|
||||||
|
onlineLmConfig,
|
||||||
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
|
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
|
||||||
proMap.get("decoding_method").trim(),
|
proMap.get("decoding_method").trim(),
|
||||||
Integer.parseInt(proMap.get("max_active_paths").trim()));
|
Integer.parseInt(proMap.get("max_active_paths").trim()));
|
||||||
@@ -107,11 +110,15 @@ public class OnlineRecognizer {
|
|||||||
false);
|
false);
|
||||||
FeatureConfig featConfig =
|
FeatureConfig featConfig =
|
||||||
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
|
||||||
OnlineRecognizerConfig rcgCfg =
|
|
||||||
|
OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim()));
|
||||||
|
|
||||||
|
OnlineRecognizerConfig rcgCfg =
|
||||||
new OnlineRecognizerConfig(
|
new OnlineRecognizerConfig(
|
||||||
featConfig,
|
featConfig,
|
||||||
modelCfg,
|
modelCfg,
|
||||||
endCfg,
|
endCfg,
|
||||||
|
onlineLmConfig,
|
||||||
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
|
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
|
||||||
proMap.get("decoding_method").trim(),
|
proMap.get("decoding_method").trim(),
|
||||||
Integer.parseInt(proMap.get("max_active_paths").trim()));
|
Integer.parseInt(proMap.get("max_active_paths").trim()));
|
||||||
@@ -137,6 +144,8 @@ public class OnlineRecognizer {
|
|||||||
float rule2MinTrailingSilence,
|
float rule2MinTrailingSilence,
|
||||||
float rule3MinUtteranceLength,
|
float rule3MinUtteranceLength,
|
||||||
String decodingMethod,
|
String decodingMethod,
|
||||||
|
String lm_model,
|
||||||
|
float lm_scale,
|
||||||
int maxActivePaths) {
|
int maxActivePaths) {
|
||||||
this.sampleRate = sampleRate;
|
this.sampleRate = sampleRate;
|
||||||
EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
|
EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
|
||||||
@@ -146,14 +155,10 @@ public class OnlineRecognizer {
|
|||||||
OnlineTransducerModelConfig modelCfg =
|
OnlineTransducerModelConfig modelCfg =
|
||||||
new OnlineTransducerModelConfig(encoder, decoder, joiner, tokens, numThreads, false);
|
new OnlineTransducerModelConfig(encoder, decoder, joiner, tokens, numThreads, false);
|
||||||
FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
|
FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
|
||||||
OnlineRecognizerConfig rcgCfg =
|
OnlineLMConfig onlineLmConfig=new OnlineLMConfig(lm_model,lm_scale);
|
||||||
|
OnlineRecognizerConfig rcgCfg =
|
||||||
new OnlineRecognizerConfig(
|
new OnlineRecognizerConfig(
|
||||||
featConfig,
|
featConfig, modelCfg, endCfg, onlineLmConfig,enableEndpointDetection, decodingMethod, maxActivePaths);
|
||||||
modelCfg,
|
|
||||||
endCfg,
|
|
||||||
enableEndpointDetection,
|
|
||||||
decodingMethod,
|
|
||||||
maxActivePaths);
|
|
||||||
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
|
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
|
||||||
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
|
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
|
||||||
}
|
}
|
||||||
@@ -241,7 +246,7 @@ public class OnlineRecognizer {
|
|||||||
return stream;
|
return stream;
|
||||||
}
|
}
|
||||||
|
|
||||||
public float[] readWavFile(String fileName) {
|
public static float[] readWavFile(String fileName) {
|
||||||
// read data from the filename
|
// read data from the filename
|
||||||
Object[] wavdata = readWave(fileName);
|
Object[] wavdata = readWave(fileName);
|
||||||
Object data = wavdata[0]; // data[0] is float data, data[1] sample rate
|
Object data = wavdata[0]; // data[0] is float data, data[1] sample rate
|
||||||
@@ -281,7 +286,7 @@ public class OnlineRecognizer {
|
|||||||
}
|
}
|
||||||
// JNI interface libsherpa-onnx-jni.so
|
// JNI interface libsherpa-onnx-jni.so
|
||||||
|
|
||||||
private native Object[] readWave(String fileName);
|
private static native Object[] readWave(String fileName); // static
|
||||||
|
|
||||||
private native String getResult(long ptr, long streamPtr);
|
private native String getResult(long ptr, long streamPtr);
|
||||||
|
|
||||||
|
|||||||
@@ -8,25 +8,33 @@ public class OnlineRecognizerConfig {
|
|||||||
private final FeatureConfig featConfig;
|
private final FeatureConfig featConfig;
|
||||||
private final OnlineTransducerModelConfig modelConfig;
|
private final OnlineTransducerModelConfig modelConfig;
|
||||||
private final EndpointConfig endpointConfig;
|
private final EndpointConfig endpointConfig;
|
||||||
|
private final OnlineLMConfig lmConfig;
|
||||||
private final boolean enableEndpoint;
|
private final boolean enableEndpoint;
|
||||||
private final String decodingMethod;
|
private final String decodingMethod;
|
||||||
private final int maxActivePaths;
|
private final int maxActivePaths;
|
||||||
|
|
||||||
|
|
||||||
public OnlineRecognizerConfig(
|
public OnlineRecognizerConfig(
|
||||||
FeatureConfig featConfig,
|
FeatureConfig featConfig,
|
||||||
OnlineTransducerModelConfig modelConfig,
|
OnlineTransducerModelConfig modelConfig,
|
||||||
EndpointConfig endpointConfig,
|
EndpointConfig endpointConfig,
|
||||||
|
OnlineLMConfig lmConfig,
|
||||||
boolean enableEndpoint,
|
boolean enableEndpoint,
|
||||||
String decodingMethod,
|
String decodingMethod,
|
||||||
int maxActivePaths) {
|
int maxActivePaths) {
|
||||||
this.featConfig = featConfig;
|
this.featConfig = featConfig;
|
||||||
this.modelConfig = modelConfig;
|
this.modelConfig = modelConfig;
|
||||||
this.endpointConfig = endpointConfig;
|
this.endpointConfig = endpointConfig;
|
||||||
|
this.lmConfig = lmConfig;
|
||||||
this.enableEndpoint = enableEndpoint;
|
this.enableEndpoint = enableEndpoint;
|
||||||
this.decodingMethod = decodingMethod;
|
this.decodingMethod = decodingMethod;
|
||||||
this.maxActivePaths = maxActivePaths;
|
this.maxActivePaths = maxActivePaths;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public OnlineLMConfig getLmConfig() {
|
||||||
|
return lmConfig;
|
||||||
|
}
|
||||||
|
|
||||||
public FeatureConfig getFeatConfig() {
|
public FeatureConfig getFeatConfig() {
|
||||||
return featConfig;
|
return featConfig;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user