diff --git a/java-api-examples/Makefile b/java-api-examples/Makefile index cdc17503..3404b1e3 100755 --- a/java-api-examples/Makefile +++ b/java-api-examples/Makefile @@ -7,11 +7,21 @@ LIB_FILES = \ $(LIB_SRC_DIR)/EndpointRule.java \ $(LIB_SRC_DIR)/EndpointConfig.java \ $(LIB_SRC_DIR)/FeatureConfig.java \ + $(LIB_SRC_DIR)/OnlineLMConfig.java \ $(LIB_SRC_DIR)/OnlineTransducerModelConfig.java \ $(LIB_SRC_DIR)/OnlineRecognizerConfig.java \ $(LIB_SRC_DIR)/OnlineStream.java \ $(LIB_SRC_DIR)/OnlineRecognizer.java \ +WEBSOCKET_DIR:= ./src/websocketsrv +WEBSOCKET_FILES = \ + $(WEBSOCKET_DIR)/ConnectionData.java \ + $(WEBSOCKET_DIR)/DecoderThreadHandler.java \ + $(WEBSOCKET_DIR)/StreamThreadHandler.java \ + $(WEBSOCKET_DIR)/AsrWebsocketServer.java \ + $(WEBSOCKET_DIR)/AsrWebsocketClient.java \ + + LIB_BUILD_DIR = ./lib @@ -39,7 +49,13 @@ buildmic: rebuild: clean all -.PHONY: clean run +.PHONY: clean run downjar + +downjar: + wget https://repo1.maven.org/maven2/org/slf4j/slf4j-api/1.7.25/slf4j-api-1.7.25.jar -P ./lib/ + wget https://repo1.maven.org/maven2/org/slf4j/slf4j-simple/1.7.25/slf4j-simple-1.7.25.jar -P ./lib/ + wget https://github.com/TooTallNate/Java-WebSocket/releases/download/v1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/ + clean: rm -frv $(BUILD_DIR)/* @@ -56,6 +72,12 @@ runmic: java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic +runsrv: + java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer /sherpa-onnx/20230515/zhaoming/sherpa-onnx/build/lib/libsherpa-onnx-jni.so ./modelconfig.cfg + +runclient: + java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient /sherpa-onnx/20230515/zhaoming/sherpa-onnx/build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32 + buildlib: $(LIB_FILES:.java=.class) @@ -63,10 +85,19 @@ buildlib: $(LIB_FILES:.java=.class) $(JAVAC) -cp $(BUILD_DIR) -d $(BUILD_DIR) -encoding UTF-8 $< +buildwebsocket: $(WEBSOCKET_FILES:.java=.class) + + +%.class: %.java + + $(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:../lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 $< + packjar: jar cvfe lib/sherpaonnx.jar . -C $(BUILD_DIR) . -all: clean buildlib packjar buildfile buildmic +all: clean buildlib packjar buildfile buildmic downjar buildwebsocket + + diff --git a/java-api-examples/README.md b/java-api-examples/README.md index 8f1bc5af..97968cc8 100755 --- a/java-api-examples/README.md +++ b/java-api-examples/README.md @@ -2,6 +2,7 @@ -------------- Java wrapper `com.k2fsa.sherpa.onnx.OnlineRecognizer` for `sherpa-onnx`. Java is a cross-platform language; you can build jni .so lib according to your system, and then use the same java api for all your platform. +now support multiple threads for websocket server ```xml Depend on: @@ -35,10 +36,10 @@ Example for Ubuntu 18.04 LTS, Openjdk 1.8.0_362: 3.Config model config.cfg ------------------------- - +/**change model path in config.cfg according to your env**/ ```xml - #model config - sample_rate=16000 + #model config + sample_rate=16000 feature_dim=80 rule1_min_trailing_silence=2.4 rule2_min_trailing_silence=1.2 @@ -51,6 +52,21 @@ Example for Ubuntu 18.04 LTS, Openjdk 1.8.0_362: enable_endpoint_detection=false decoding_method=greedy_search max_active_paths=4 + + #websocket server config + port=8890 + #number of threads pool for network io + connection_thread_num=16 + #number of threads pool for stream + stream_thread_num=16 + #number of threads pool for decoder + decoder_thread_num=16 + #size of streams for parallel decoding + parallel_decoder_num=16 + #time(ms) idle for decoder thread when no job + decoder_time_idle=10 + #time(ms) out for connection data + deocder_time_out=3000 ``` --- @@ -114,5 +130,58 @@ Build package path: /sherpa-onnx/java-api-examples/lib/sherpaonnx.jar make runmic ``` +--- +6.WebSocket Server +---------- + +support multiple threads for websocket server +6.0 Protocol for communication +1) client connect to server +```shell + ws client -> srv ws address + ws address example: ws://127.0.0.1:8889/ +``` +2) client send 16k pcm_s16le binary stream data to server +```shell + PCM sampleRate 16000 + single channel + sampleSize 16bit + little endian + type short +``` +3) client send "Done" text to server when all data is sent +```shell + ws_socket.send("Done") +``` +4) client will receive json message from server whenever asr engine decoded new text +```shell + json example: {"text":"甚至出现交易几乎停滞的情况","eof":false"} +``` + + +6.1 Build + +```bash + cd sherpa-onnx/java-api-examples + make all +``` + +6.2 Run srv example + +usage: AsrWebsocketServer soPath modelCfgPath + +```bash + make runsrv /**change path in Makefile according to your env**/ +``` + +6.3 Run multiple threads client example + +usage: AsrWebsocketClient soPath srvIp srvPort wavPath numThreads + +json result example: {"text":"甚至出现交易几乎停滞的情况","eof":"true"} + +```bash + make runclient /**change path in Makefile according to your env**/ +``` diff --git a/java-api-examples/modelconfig.cfg b/java-api-examples/modelconfig.cfg index 0b09c7e2..2e280778 100755 --- a/java-api-examples/modelconfig.cfg +++ b/java-api-examples/modelconfig.cfg @@ -9,6 +9,17 @@ decoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh- joiner=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx tokens=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt num_threads=4 -enable_endpoint_detection=false +enable_endpoint_detection=true decoding_method=modified_beam_search -max_active_paths=4 \ No newline at end of file +max_active_paths=4 +lm_model= +lm_scale=0.5 + +#websocket server config +port=8890 +connection_thread_num=16 +stream_thread_num=16 +decoder_thread_num=16 +parallel_decoder_num=16 +decoder_time_idle=200 +deocder_time_out=30000 diff --git a/java-api-examples/src/DecodeFile.java b/java-api-examples/src/DecodeFile.java index c5849703..afbe3365 100644 --- a/java-api-examples/src/DecodeFile.java +++ b/java-api-examples/src/DecodeFile.java @@ -49,7 +49,8 @@ public class DecodeFile { float rule3MinUtteranceLength = 20F; String decodingMethod = "greedy_search"; int maxActivePaths = 4; - + String lm_model=""; + float lm_scale=0.5F; rcgOjb = new OnlineRecognizer( tokens, @@ -64,6 +65,8 @@ public class DecodeFile { rule2MinTrailingSilence, rule3MinUtteranceLength, decodingMethod, + lm_model, + lm_scale, maxActivePaths); streamObj = rcgOjb.createStream(); } catch (Exception e) { diff --git a/java-api-examples/src/websocketsrv/AsrWebsocketClient.java b/java-api-examples/src/websocketsrv/AsrWebsocketClient.java new file mode 100755 index 00000000..efa4245b --- /dev/null +++ b/java-api-examples/src/websocketsrv/AsrWebsocketClient.java @@ -0,0 +1,131 @@ +/* + * // Copyright 2022-2023 by zhaomingwork + */ +// java AsrWebsocketClient +// usage: AsrWebsocketClient soPath srvIp srvPort wavPath numThreads +package websocketsrv; + +import com.k2fsa.sherpa.onnx.OnlineRecognizer; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.*; +import java.util.Map; +import org.java_websocket.client.WebSocketClient; +import org.java_websocket.drafts.Draft; +import org.java_websocket.handshake.ServerHandshake; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** This example demonstrates how to connect to websocket server. */ +public class AsrWebsocketClient extends WebSocketClient { + private static final Logger logger = LoggerFactory.getLogger(AsrWebsocketClient.class); + + public AsrWebsocketClient(URI serverUri, Draft draft) { + super(serverUri, draft); + } + + public AsrWebsocketClient(URI serverURI) { + super(serverURI); + } + + public AsrWebsocketClient(URI serverUri, Map httpHeaders) { + super(serverUri, httpHeaders); + } + + @Override + public void onOpen(ServerHandshake handshakedata) { + + float[] floats = OnlineRecognizer.readWavFile(AsrWebsocketClient.wavPath); + ByteBuffer buffer = + ByteBuffer.allocate(4 * floats.length) + .order(ByteOrder.LITTLE_ENDIAN); // float is sizeof 4. allocate enough buffer + + for (float f : floats) { + buffer.putFloat(f); + } + buffer.rewind(); + buffer.flip(); + buffer.order(ByteOrder.LITTLE_ENDIAN); + + send(buffer.array()); // send buf to server + send("Done"); // send 'Done' means finished + } + + @Override + public void onMessage(String message) { + + logger.info("received: " + message); + } + + @Override + public void onClose(int code, String reason, boolean remote) { + + logger.info( + "Connection closed by " + + (remote ? "remote peer" : "us") + + " Code: " + + code + + " Reason: " + + reason); + } + + @Override + public void onError(Exception ex) { + ex.printStackTrace(); + // if the error is fatal then onClose will be called additionally + } + + public static OnlineRecognizer rcgobj; + public static String wavPath; + + public static void main(String[] args) throws URISyntaxException { + + if (args.length != 5) { + System.out.println("usage: AsrWebsocketClient soPath srvIp srvPort wavPath numThreads"); + return; + } + + String soPath = args[0]; + String srvIp = args[1]; + String srvPort = args[2]; + String wavPath = args[3]; + int numThreads = Integer.parseInt(args[4]); + System.out.println("serIp=" + srvIp + ",srvPort=" + srvPort + ",wavPath=" + wavPath); + + class ClientThread implements Runnable { + + String soPath; + String srvIp; + String srvPort; + String wavPath; + + ClientThread(String soPath, String srvIp, String srvPort, String wavPath) { + this.soPath = soPath; + this.srvIp = srvIp; + this.srvPort = srvPort; + this.wavPath = wavPath; + } + + public void run() { + try { + + OnlineRecognizer.setSoPath(soPath); + + AsrWebsocketClient.wavPath = wavPath; + + String wsAddress = "ws://" + srvIp + ":" + srvPort; + AsrWebsocketClient c = new AsrWebsocketClient(new URI(wsAddress)); + + c.connect(); + } catch (Exception e) { + e.printStackTrace(); + } + } + } + for (int i = 0; i < numThreads; i++) { + System.out.println("Thread1 is running..."); + Thread t = new Thread(new ClientThread(soPath, srvIp, srvPort, wavPath)); + t.start(); + } + } +} diff --git a/java-api-examples/src/websocketsrv/AsrWebsocketServer.java b/java-api-examples/src/websocketsrv/AsrWebsocketServer.java new file mode 100755 index 00000000..879e9a9f --- /dev/null +++ b/java-api-examples/src/websocketsrv/AsrWebsocketServer.java @@ -0,0 +1,248 @@ +/* + * // Copyright 2022-2023 by zhaoming + */ +// java websocketServer +// usage: AsrWebsocketServer soPath modelCfgPath +package websocketsrv; + +import com.k2fsa.sherpa.onnx.OnlineRecognizer; +import com.k2fsa.sherpa.onnx.OnlineStream; +import java.io.*; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.util.*; +import java.util.Collections; +import java.util.concurrent.*; +import java.util.concurrent.LinkedBlockingQueue; +import org.java_websocket.WebSocket; +import org.java_websocket.drafts.Draft; +import org.java_websocket.drafts.Draft_6455; +import org.java_websocket.handshake.ClientHandshake; +import org.java_websocket.server.WebSocketServer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * AsrWebSocketServer has three threads pools, one pool for network io, one pool for asr stream and + * one pool for asr decoder. + */ +public class AsrWebsocketServer extends WebSocketServer { + private static final Logger logger = LoggerFactory.getLogger(AsrWebsocketServer.class); + // Queue between io network io thread pool and stream thread pool, use websocket as the key + private LinkedBlockingQueue streamQueue = new LinkedBlockingQueue(); + // Queue waiting for deocdeing, use websocket as the key + private LinkedBlockingQueue decoderQueue = new LinkedBlockingQueue(); + + // recogizer object + private OnlineRecognizer rcgOjb = null; + + // mapping between websocket connection and connection data + private ConcurrentHashMap connectionMap = + new ConcurrentHashMap(); + + public AsrWebsocketServer(int port, int numThread) throws UnknownHostException { + // server port and num of threads for network io + super(new InetSocketAddress(port), numThread); + } + + public AsrWebsocketServer(InetSocketAddress address) { + super(address); + } + + public AsrWebsocketServer(int port, Draft_6455 draft) { + super(new InetSocketAddress(port), Collections.singletonList(draft)); + } + + @Override + public void onOpen(WebSocket conn, ClientHandshake handshake) {} + + @Override + public void onClose(WebSocket conn, int code, String reason, boolean remote) { + connectionMap.remove(conn); + logger.info( + conn + + " remove one connection!, now connection number=" + + String.valueOf(connectionMap.size())); + } + + @Override + public void onMessage(WebSocket conn, String message) { + // this is text message + try { + // if rec "Done" msg from client + if (message.equals("Done")) { + ConnectionData connData = creatOrGetConnectionData(conn); + connData.setEof(true); + if (!streamQueueFind(conn)) { + streamQueue.put(conn); + } + } + + } catch (Exception e) { + e.printStackTrace(); + } + } + + private ConnectionData creatOrGetConnectionData(WebSocket conn) { + // create a new connection data if not in connection map or return the existed one + + ConnectionData connData = null; + try { + if (!connectionMap.containsKey(conn)) { + OnlineStream stream = rcgOjb.createStream(); + connData = new ConnectionData(conn, stream); + connectionMap.put(conn, connData); + } else { + connData = connectionMap.get(conn); + } + + logger.info( + conn.getRemoteSocketAddress().getAddress().getHostAddress() + + " open one connection,, now connection number=" + + String.valueOf(connectionMap.size())); + + } catch (Exception e) { + System.err.println(e); + e.printStackTrace(); + } + return connData; + } + + @Override + public void onMessage(WebSocket conn, ByteBuffer blob) { + try { + + // for handle binary data + blob.order(ByteOrder.LITTLE_ENDIAN); // set little endian + + // set to float + FloatBuffer floatbuf = blob.asFloatBuffer(); + + if (floatbuf.capacity() > 0) { + // allocate memory for float data + float[] arr = new float[floatbuf.capacity()]; + + floatbuf.get(arr); + ConnectionData connData = creatOrGetConnectionData(conn); + // put websocket to stream queue with binary type==1 + connData.addSamplesToData(arr); + + if (!streamQueueFind(conn)) { + streamQueue.put(conn); + } + } + } catch (Exception e) { + e.printStackTrace(); + } + } + + + + public boolean streamQueueFind(WebSocket conn) { + return streamQueue.contains(conn); + } + + public void initModelWithCfg(Map cfgMap, String cfgPath) { + try { + + rcgOjb = new OnlineRecognizer(cfgPath); + // size of stream thread pool + int streamThreadNum = Integer.valueOf(cfgMap.get("stream_thread_num")); + // size of decoder thread pool + int decoderThreadNum = Integer.valueOf(cfgMap.get("decoder_thread_num")); + + // time(ms) idle for decoder thread when no job + int decoderTimeIdle = Integer.valueOf(cfgMap.get("decoder_time_idle")); + // size of streams for parallel decoding + int parallelDecoderNum = Integer.valueOf(cfgMap.get("parallel_decoder_num")); + // time(ms) out for connection data + int deocderTimeOut = Integer.valueOf(cfgMap.get("deocder_time_out")); + + // create stream threads + for (int i = 0; i < streamThreadNum; i++) { + new StreamThreadHandler(streamQueue, decoderQueue, connectionMap).start(); + } + // create decoder threads + for (int i = 0; i < decoderThreadNum; i++) { + new DecoderThreadHandler( + decoderQueue, + connectionMap, + rcgOjb, + decoderTimeIdle, + parallelDecoderNum, + deocderTimeOut) + .start(); + } + } catch (Exception e) { + System.err.println(e); + e.printStackTrace(); + } + } + + public static Map readProperties(String CfgPath) { + // read and parse config file + Properties props = new Properties(); + Map proMap = new HashMap(); + try { + + File file = new File(CfgPath); + if (!file.exists()) { + logger.info(String.valueOf(CfgPath) + " cfg file not exists!"); + System.exit(0); + } + InputStream in = new BufferedInputStream(new FileInputStream(CfgPath)); + props.load(in); + Enumeration en = props.propertyNames(); + while (en.hasMoreElements()) { + String key = (String) en.nextElement(); + String Property = props.getProperty(key); + proMap.put(key, Property); + } + + } catch (Exception e) { + e.printStackTrace(); + } + return proMap; + } + + public static void main(String[] args) throws InterruptedException, IOException { + if (args.length != 2) { + logger.info("usage: AsrWebsocketServer soPath modelCfgPath"); + + return; + } + + String soPath = args[0]; + String cfgPath = args[1]; + + OnlineRecognizer.setSoPath(soPath); + + Map cfgMap = AsrWebsocketServer.readProperties(cfgPath); + int port = Integer.valueOf(cfgMap.get("port")); + + int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num")); + AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum); + s.initModelWithCfg(cfgMap, cfgPath); + logger.info("Server started on port: " + s.getPort()); + s.start(); + } + + @Override + public void onError(WebSocket conn, Exception ex) { + ex.printStackTrace(); + if (conn != null) { + // some errors like port binding failed may not be assignable to a specific websocket + } + } + + @Override + public void onStart() { + logger.info("Server started!"); + setConnectionLostTimeout(0); + setConnectionLostTimeout(100); + } +} diff --git a/java-api-examples/src/websocketsrv/ConnectionData.java b/java-api-examples/src/websocketsrv/ConnectionData.java new file mode 100755 index 00000000..a43d9e72 --- /dev/null +++ b/java-api-examples/src/websocketsrv/ConnectionData.java @@ -0,0 +1,65 @@ +/* + * // Copyright 2022-2023 by zhaoming + */ +// connection data act as a bridge between different threads pools + +package websocketsrv; + +import com.k2fsa.sherpa.onnx.OnlineStream; +import java.time.LocalDateTime; +import java.util.LinkedList; +import java.util.Queue; +import java.util.concurrent.*; +import org.java_websocket.WebSocket; + +public class ConnectionData { + + private WebSocket webSocket; // the websocket for this connection data + + private OnlineStream stream; // connection stream + + private Queue queueSamples = + new LinkedList(); // binary data rec from the client + + private boolean eof = false; // connection data is done + + private LocalDateTime lastHandleTime; // used for time out in ms + + public ConnectionData(WebSocket webSocket, OnlineStream stream) { + this.webSocket = webSocket; + + this.stream = stream; + } + + public void addSamplesToData(float[] samples) { + this.queueSamples.add(samples); + } + + public LocalDateTime getLastHandleTime() { + return this.lastHandleTime; + } + + public void setLastHandleTime(LocalDateTime now) { + this.lastHandleTime = now; + } + + public boolean getEof() { + return this.eof; + } + + public void setEof(boolean eof) { + this.eof = eof; + } + + public WebSocket getWebSocket() { + return this.webSocket; + } + + public Queue getQueueSamples() { + return this.queueSamples; + } + + public OnlineStream getStream() { + return this.stream; + } +} diff --git a/java-api-examples/src/websocketsrv/DecoderThreadHandler.java b/java-api-examples/src/websocketsrv/DecoderThreadHandler.java new file mode 100755 index 00000000..c44c9890 --- /dev/null +++ b/java-api-examples/src/websocketsrv/DecoderThreadHandler.java @@ -0,0 +1,173 @@ +/* + * // Copyright 2022-2023 by zhaoming + */ +// java DecoderThreadHandler +package websocketsrv; + +import com.k2fsa.sherpa.onnx.OnlineRecognizer; +import com.k2fsa.sherpa.onnx.OnlineStream; +import java.nio.*; +import java.nio.charset.StandardCharsets; +import java.time.LocalDateTime; +import java.util.*; +import java.util.List; +import java.util.concurrent.*; +import java.util.concurrent.LinkedBlockingQueue; +import org.java_websocket.WebSocket; +import org.java_websocket.drafts.Draft; +import org.java_websocket.framing.Framedata; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class DecoderThreadHandler extends Thread { + private static final Logger logger = LoggerFactory.getLogger(DecoderThreadHandler.class); + // Websocket Queue that waiting for decoding + private LinkedBlockingQueue decoderQueue; + // the mapping between websocket and connection data + private ConcurrentHashMap connMap; + + private OnlineRecognizer rcgOjb = null; // recgnizer object + + // connection data list for this thread to decode in parallel + private List connDataList = new ArrayList(); + + private int parallelDecoderNum = 10; // parallel decoding number + private int deocderTimeIdle = 10; // idle time(ms) when no job + private int deocderTimeOut = 3000; // if it is timeout(ms), the connection data will be removed + + public DecoderThreadHandler( + LinkedBlockingQueue decoderQueue, + ConcurrentHashMap connMap, + OnlineRecognizer rcgOjb, + int deocderTimeIdle, + int parallelDecoderNum, + int deocderTimeOut) { + this.decoderQueue = decoderQueue; + this.connMap = connMap; + this.rcgOjb = rcgOjb; + this.deocderTimeIdle = deocderTimeIdle; + this.parallelDecoderNum = parallelDecoderNum; + this.deocderTimeOut = deocderTimeOut; + } + + public void run() { + while (true) { + try { + // time(ms) idle if there is no job + + Thread.sleep(deocderTimeIdle); + // clear data list for this threads + connDataList.clear(); + if (rcgOjb == null) continue; + + // loop for total decoder Queue + while (!decoderQueue.isEmpty()) { + + // get websocket + WebSocket conn = decoderQueue.take(); + // get connection data according to websocket + ConnectionData connData = connMap.get(conn); + + // if the websocket closed, continue + if (connData == null) continue; + // get the stream + OnlineStream stream = connData.getStream(); + + // put to decoder list if 1) stream is ready; 2) and + // size not > parallelDecoderNum + if ((rcgOjb.isReady(stream) && connDataList.size() < parallelDecoderNum)) { + + // add to this thread's decoder list + connDataList.add(connData); + // change the handled time for this connection data + connData.setLastHandleTime(LocalDateTime.now()); + } + // break when decoder list size >= parallelDecoderNum + if (connDataList.size() >= parallelDecoderNum) { + break; + } + } + + // if decoder data list for this thread >0 + if (connDataList.size() > 0) { + + // create a stream array for parallel decoding + OnlineStream[] arr = new OnlineStream[connDataList.size()]; + for (int i = 0; i < connDataList.size(); i++) { + + arr[i] = connDataList.get(i).getStream(); + } + + // parallel decoding + rcgOjb.decodeStreams(arr); + } + + // get result for each connection + for (ConnectionData connData : connDataList) { + + OnlineStream stream = connData.getStream(); + WebSocket webSocket = connData.getWebSocket(); + + String txtResult = rcgOjb.getResult(stream); + + // decode text in utf-8 + byte[] utf8Data = txtResult.getBytes(StandardCharsets.UTF_8); + + boolean isEof = (connData.getEof() == true && !rcgOjb.isReady(stream)); + // result + if (utf8Data.length > 0) { + + String jsonResult = + "{\"text\":\"" + txtResult + "\",\"eof\":" + String.valueOf(isEof) + "\"}"; + + if (webSocket.isOpen()) { + // create a TEXT Frame for send back json result + Draft draft = webSocket.getDraft(); + List frames = null; + frames = draft.createFrames(jsonResult, false); + // send to client + webSocket.sendFrame(frames); + } + } + } + // loop for each connection data in this thread + for (ConnectionData connData : connDataList) { + OnlineStream stream = connData.getStream(); + WebSocket webSocket = connData.getWebSocket(); + // if the stream is still ready, put it to decoder Queue again for next decoding + if (rcgOjb.isReady(stream)) { + decoderQueue.put(webSocket); + } + // the duration between last handled time and now + java.time.Duration duration = + java.time.Duration.between(connData.getLastHandleTime(), LocalDateTime.now()); + // close the websocket if 1) data is done and stream not ready; 2) or data is time out; + // 3) or + // connection is closed + if ((connData.getEof() == true + && !rcgOjb.isReady(stream) + && connData.getQueueSamples().isEmpty()) + || duration.toMillis() > deocderTimeOut + || !connData.getWebSocket().isOpen()) { + + logger.info("close websocket!!!"); + + // delay close web socket as data may still in processing + Timer timer = new Timer(); + timer.schedule( + new TimerTask() { + public void run() { + + webSocket.close(); + } + }, + 5000); // 5 seconds + } + } + + } catch (Exception e) { + e.printStackTrace(); + } + } + } +} diff --git a/java-api-examples/src/websocketsrv/StreamThreadHandler.java b/java-api-examples/src/websocketsrv/StreamThreadHandler.java new file mode 100755 index 00000000..3550d768 --- /dev/null +++ b/java-api-examples/src/websocketsrv/StreamThreadHandler.java @@ -0,0 +1,67 @@ +/* + * // Copyright 2022-2023 by zhaoming + */ +// java StreamThreadHandler +package websocketsrv; + +import com.k2fsa.sherpa.onnx.OnlineStream; +import java.nio.*; +import java.util.*; +import java.util.concurrent.*; +import java.util.concurrent.LinkedBlockingQueue; +import org.java_websocket.WebSocket; +// thread for processing stream + +public class StreamThreadHandler extends Thread { + // Queue between io network io thread pool and stream thread pool, use websocket as the key + private LinkedBlockingQueue streamQueue; + // Queue waiting for deocdeing, use websocket as the key + private LinkedBlockingQueue decoderQueue; + // mapping between websocket connection and connection data + private ConcurrentHashMap connMap; + + public StreamThreadHandler( + LinkedBlockingQueue streamQueue, + LinkedBlockingQueue decoderQueue, + ConcurrentHashMap connMap) { + this.streamQueue = streamQueue; + this.decoderQueue = decoderQueue; + this.connMap = connMap; + } + + public void run() { + while (true) { + try { + // fetch one websocket from queue + WebSocket conn = (WebSocket) this.streamQueue.take(); + // get the connection data according to websocket + ConnectionData connData = connMap.get(conn); + OnlineStream stream = connData.getStream(); + + // handle received binary data + if (!connData.getQueueSamples().isEmpty()) { + // loop to put all received binary data to stream + while (!connData.getQueueSamples().isEmpty()) { + + float[] samples = connData.getQueueSamples().poll(); + + stream.acceptWaveform(samples); + } + // if data is finished + if (connData.getEof() == true) { + + stream.inputFinished(); + } + // add this websocket to decoder Queue if not in the Queue + if (!decoderQueue.contains(conn)) { + + decoderQueue.put(conn); + } + } + + } catch (Exception e) { + e.printStackTrace(); + } + } + } +} diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineLMConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineLMConfig.java new file mode 100644 index 00000000..7474a299 --- /dev/null +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineLMConfig.java @@ -0,0 +1,23 @@ +/* + * // Copyright 2022-2023 by zhaoming + */ + +package com.k2fsa.sherpa.onnx; + +public class OnlineLMConfig { + private final String model; + private final float scale; + + public OnlineLMConfig(String model, float scale) { + this.model = model; + this.scale = scale; + } + + public String getModel() { + return model; + } + + public float getScale() { + return scale; + } +} diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java index 7716fd5a..5658125d 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java @@ -65,11 +65,14 @@ public class OnlineRecognizer { false); FeatureConfig featConfig = new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); - OnlineRecognizerConfig rcgCfg = + OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim())); + + OnlineRecognizerConfig rcgCfg = new OnlineRecognizerConfig( featConfig, modelCfg, endCfg, + onlineLmConfig, Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()), proMap.get("decoding_method").trim(), Integer.parseInt(proMap.get("max_active_paths").trim())); @@ -107,11 +110,15 @@ public class OnlineRecognizer { false); FeatureConfig featConfig = new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); - OnlineRecognizerConfig rcgCfg = + + OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim())); + + OnlineRecognizerConfig rcgCfg = new OnlineRecognizerConfig( featConfig, modelCfg, endCfg, + onlineLmConfig, Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()), proMap.get("decoding_method").trim(), Integer.parseInt(proMap.get("max_active_paths").trim())); @@ -137,6 +144,8 @@ public class OnlineRecognizer { float rule2MinTrailingSilence, float rule3MinUtteranceLength, String decodingMethod, + String lm_model, + float lm_scale, int maxActivePaths) { this.sampleRate = sampleRate; EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); @@ -146,14 +155,10 @@ public class OnlineRecognizer { OnlineTransducerModelConfig modelCfg = new OnlineTransducerModelConfig(encoder, decoder, joiner, tokens, numThreads, false); FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); - OnlineRecognizerConfig rcgCfg = + OnlineLMConfig onlineLmConfig=new OnlineLMConfig(lm_model,lm_scale); + OnlineRecognizerConfig rcgCfg = new OnlineRecognizerConfig( - featConfig, - modelCfg, - endCfg, - enableEndpointDetection, - decodingMethod, - maxActivePaths); + featConfig, modelCfg, endCfg, onlineLmConfig,enableEndpointDetection, decodingMethod, maxActivePaths); // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 this.ptr = createOnlineRecognizer(new Object(), rcgCfg); } @@ -241,7 +246,7 @@ public class OnlineRecognizer { return stream; } - public float[] readWavFile(String fileName) { + public static float[] readWavFile(String fileName) { // read data from the filename Object[] wavdata = readWave(fileName); Object data = wavdata[0]; // data[0] is float data, data[1] sample rate @@ -281,7 +286,7 @@ public class OnlineRecognizer { } // JNI interface libsherpa-onnx-jni.so - private native Object[] readWave(String fileName); + private static native Object[] readWave(String fileName); // static private native String getResult(long ptr, long streamPtr); diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java index 3b8e05ec..2b96cb83 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java @@ -8,25 +8,33 @@ public class OnlineRecognizerConfig { private final FeatureConfig featConfig; private final OnlineTransducerModelConfig modelConfig; private final EndpointConfig endpointConfig; + private final OnlineLMConfig lmConfig; private final boolean enableEndpoint; private final String decodingMethod; private final int maxActivePaths; + public OnlineRecognizerConfig( FeatureConfig featConfig, OnlineTransducerModelConfig modelConfig, EndpointConfig endpointConfig, + OnlineLMConfig lmConfig, boolean enableEndpoint, String decodingMethod, int maxActivePaths) { this.featConfig = featConfig; this.modelConfig = modelConfig; this.endpointConfig = endpointConfig; + this.lmConfig = lmConfig; this.enableEndpoint = enableEndpoint; this.decodingMethod = decodingMethod; this.maxActivePaths = maxActivePaths; } + public OnlineLMConfig getLmConfig() { + return lmConfig; + } + public FeatureConfig getFeatConfig() { return featConfig; }