Add java websocket support (#137)

* add decode example for mic

* some changes to README.md

* add java websocket srv

* change readwav to static

* make some changes to code comments

* little change for readme.md

* fix bug about multiple threads

* made little modification

* add protocol in readme, removed static Queue and add lmConfig

---------

Co-authored-by: root <root@localhost.localdomain>
This commit is contained in:
zhaomingwork
2023-05-18 10:35:40 +08:00
committed by GitHub
parent 655c619bf3
commit b70d40f4ab
12 changed files with 853 additions and 19 deletions

View File

@@ -49,7 +49,8 @@ public class DecodeFile {
float rule3MinUtteranceLength = 20F;
String decodingMethod = "greedy_search";
int maxActivePaths = 4;
String lm_model="";
float lm_scale=0.5F;
rcgOjb =
new OnlineRecognizer(
tokens,
@@ -64,6 +65,8 @@ public class DecodeFile {
rule2MinTrailingSilence,
rule3MinUtteranceLength,
decodingMethod,
lm_model,
lm_scale,
maxActivePaths);
streamObj = rcgOjb.createStream();
} catch (Exception e) {

View File

@@ -0,0 +1,131 @@
/*
* // Copyright 2022-2023 by zhaomingwork
*/
// java AsrWebsocketClient
// usage: AsrWebsocketClient soPath srvIp srvPort wavPath numThreads
package websocketsrv;
import com.k2fsa.sherpa.onnx.OnlineRecognizer;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.*;
import java.util.Map;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft;
import org.java_websocket.handshake.ServerHandshake;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Example websocket client for the ASR websocket server.
 *
 * <p>When the connection opens, the client reads the wav file named by {@link #wavPath}, sends
 * all samples as one binary message of little-endian 32-bit floats, and then sends the text
 * message {@code "Done"} to mark the end of the audio. Every text message received from the server
 * is logged.
 */
public class AsrWebsocketClient extends WebSocketClient {
  private static final Logger logger = LoggerFactory.getLogger(AsrWebsocketClient.class);

  /** Not used by this example; kept because it is a public member of the class. */
  public static OnlineRecognizer rcgobj;

  /** Path of the wav file that every client instance streams to the server. */
  public static String wavPath;

  public AsrWebsocketClient(URI serverUri, Draft draft) {
    super(serverUri, draft);
  }

  public AsrWebsocketClient(URI serverURI) {
    super(serverURI);
  }

  public AsrWebsocketClient(URI serverUri, Map<String, String> httpHeaders) {
    super(serverUri, httpHeaders);
  }

  /**
   * Sends the whole wav file followed by the {@code "Done"} marker as soon as the connection is
   * open.
   */
  @Override
  public void onOpen(ServerHandshake handshakedata) {
    float[] floats = OnlineRecognizer.readWavFile(AsrWebsocketClient.wavPath);
    // 4 bytes per float sample, little endian on the wire.
    ByteBuffer buffer =
        ByteBuffer.allocate(Float.BYTES * floats.length).order(ByteOrder.LITTLE_ENDIAN);
    for (float f : floats) {
      buffer.putFloat(f);
    }
    // The backing array holds every sample written above. No rewind()/flip() is needed; calling
    // flip() after rewind() would even set the buffer limit to 0.
    send(buffer.array()); // send buf to server
    send("Done"); // send 'Done' means finished
  }

  @Override
  public void onMessage(String message) {
    logger.info("received: " + message);
  }

  @Override
  public void onClose(int code, String reason, boolean remote) {
    logger.info(
        "Connection closed by "
            + (remote ? "remote peer" : "us")
            + " Code: "
            + code
            + " Reason: "
            + reason);
  }

  @Override
  public void onError(Exception ex) {
    // if the error is fatal then onClose will be called additionally
    logger.error("websocket client error", ex);
  }

  /**
   * Starts {@code numThreads} clients that all stream the same wav file to the same server.
   *
   * @param args {@code soPath srvIp srvPort wavPath numThreads}
   * @throws URISyntaxException if {@code srvIp}/{@code srvPort} do not form a valid websocket URI
   */
  public static void main(String[] args) throws URISyntaxException {
    final String usage = "usage: AsrWebsocketClient soPath srvIp srvPort wavPath numThreads";
    if (args.length != 5) {
      System.out.println(usage);
      return;
    }
    String soPath = args[0];
    String srvIp = args[1];
    String srvPort = args[2];
    String wavPath = args[3];
    final int numThreads;
    try {
      numThreads = Integer.parseInt(args[4]);
    } catch (NumberFormatException e) {
      System.out.println("numThreads must be an integer, got: " + args[4]);
      System.out.println(usage);
      return;
    }
    System.out.println("serIp=" + srvIp + ",srvPort=" + srvPort + ",wavPath=" + wavPath);

    // Shared static state is configured once, before any client thread starts, instead of being
    // rewritten concurrently by every thread.
    OnlineRecognizer.setSoPath(soPath);
    AsrWebsocketClient.wavPath = wavPath;
    final URI serverUri = new URI("ws://" + srvIp + ":" + srvPort);

    for (int i = 0; i < numThreads; i++) {
      final int threadIndex = i + 1;
      Thread t =
          new Thread(
              () -> {
                try {
                  new AsrWebsocketClient(serverUri).connect();
                } catch (Exception e) {
                  logger.error("client thread " + threadIndex + " failed to connect", e);
                }
              });
      System.out.println("Thread" + threadIndex + " is running...");
      t.start();
    }
  }
}

View File

@@ -0,0 +1,248 @@
/*
* // Copyright 2022-2023 by zhaoming
*/
// java websocketServer
// usage: AsrWebsocketServer soPath modelCfgPath
package websocketsrv;
import com.k2fsa.sherpa.onnx.OnlineRecognizer;
import com.k2fsa.sherpa.onnx.OnlineStream;
import java.io.*;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.*;
import java.util.Collections;
import java.util.concurrent.*;
import java.util.concurrent.LinkedBlockingQueue;
import org.java_websocket.WebSocket;
import org.java_websocket.drafts.Draft;
import org.java_websocket.drafts.Draft_6455;
import org.java_websocket.handshake.ClientHandshake;
import org.java_websocket.server.WebSocketServer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Websocket ASR server.
 *
 * <p>AsrWebSocketServer has three thread pools: one pool for network io, one pool for asr streams
 * ({@link StreamThreadHandler}) and one pool for asr decoding ({@link DecoderThreadHandler}).
 *
 * <p>Protocol as implemented here: a binary message carries little-endian 32-bit float samples; the
 * text message {@code "Done"} marks the end of the audio for that connection.
 */
public class AsrWebsocketServer extends WebSocketServer {
  private static final Logger logger = LoggerFactory.getLogger(AsrWebsocketServer.class);

  // Queue between the network io thread pool and the stream thread pool, keyed by websocket
  private final LinkedBlockingQueue<WebSocket> streamQueue = new LinkedBlockingQueue<WebSocket>();
  // Queue of websockets waiting for decoding
  private final LinkedBlockingQueue<WebSocket> decoderQueue = new LinkedBlockingQueue<WebSocket>();
  // recognizer object, created by initModelWithCfg
  private OnlineRecognizer rcgOjb = null;
  // mapping between websocket connection and connection data
  private final ConcurrentHashMap<WebSocket, ConnectionData> connectionMap =
      new ConcurrentHashMap<WebSocket, ConnectionData>();

  /**
   * @param port server port
   * @param numThread number of threads used for network io
   */
  public AsrWebsocketServer(int port, int numThread) throws UnknownHostException {
    super(new InetSocketAddress(port), numThread);
  }

  public AsrWebsocketServer(InetSocketAddress address) {
    super(address);
  }

  public AsrWebsocketServer(int port, Draft_6455 draft) {
    super(new InetSocketAddress(port), Collections.<Draft>singletonList(draft));
  }

  @Override
  public void onOpen(WebSocket conn, ClientHandshake handshake) {}

  /** Forgets the connection data of a closed websocket. */
  @Override
  public void onClose(WebSocket conn, int code, String reason, boolean remote) {
    connectionMap.remove(conn);
    logger.info(
        conn
            + " remove one connection!, now connection number="
            + String.valueOf(connectionMap.size()));
  }

  /** Handles a text message; only {@code "Done"} (end of audio) is recognized. */
  @Override
  public void onMessage(WebSocket conn, String message) {
    try {
      // Constant-first comparison so a null message cannot throw.
      if (!"Done".equals(message)) {
        return;
      }
      ConnectionData connData = creatOrGetConnectionData(conn);
      if (connData == null) {
        return; // stream creation failed; already logged
      }
      connData.setEof(true);
      enqueueForStreaming(conn);
    } catch (RuntimeException e) {
      logger.error("failed to handle text message", e);
    }
  }

  /**
   * Returns the connection data for {@code conn}, creating it (with a fresh recognizer stream) on
   * first use.
   *
   * @return the connection data, or {@code null} if the stream could not be created
   */
  private ConnectionData creatOrGetConnectionData(WebSocket conn) {
    ConnectionData connData = connectionMap.get(conn);
    if (connData != null) {
      return connData;
    }
    try {
      OnlineStream stream = rcgOjb.createStream();
      ConnectionData created = new ConnectionData(conn, stream);
      // putIfAbsent keeps whichever entry was registered first, so already queued samples are
      // never replaced by a second, empty ConnectionData.
      ConnectionData existing = connectionMap.putIfAbsent(conn, created);
      connData = (existing != null) ? existing : created;
      // Logged only when a connection is first registered, not on every message.
      logger.info(
          describeRemote(conn)
              + " open one connection,, now connection number="
              + String.valueOf(connectionMap.size()));
    } catch (Exception e) {
      logger.error("failed to create connection data", e);
    }
    return connData;
  }

  /** Host address of the remote peer, or "unknown" when it is not available. */
  private static String describeRemote(WebSocket conn) {
    InetSocketAddress remote = conn.getRemoteSocketAddress();
    if (remote == null || remote.getAddress() == null) {
      return "unknown";
    }
    return remote.getAddress().getHostAddress();
  }

  /** Puts {@code conn} on the stream queue unless it is already waiting there. */
  private void enqueueForStreaming(WebSocket conn) {
    if (!streamQueueFind(conn)) {
      // The queue is unbounded, so offer() always succeeds and never blocks.
      streamQueue.offer(conn);
    }
  }

  /** Handles a binary message: little-endian float samples for this connection. */
  @Override
  public void onMessage(WebSocket conn, ByteBuffer blob) {
    try {
      FloatBuffer floatbuf = blob.order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
      if (!floatbuf.hasRemaining()) {
        return;
      }
      float[] arr = new float[floatbuf.remaining()];
      floatbuf.get(arr);
      ConnectionData connData = creatOrGetConnectionData(conn);
      if (connData == null) {
        return; // stream creation failed; already logged
      }
      connData.addSamplesToData(arr);
      enqueueForStreaming(conn);
    } catch (RuntimeException e) {
      logger.error("failed to handle binary message", e);
    }
  }

  /** @return true if {@code conn} is currently waiting in the stream queue */
  public boolean streamQueueFind(WebSocket conn) {
    return streamQueue.contains(conn);
  }

  /**
   * Creates the recognizer from {@code cfgPath} and starts the stream and decoder threads.
   *
   * <p>Required integer settings: {@code stream_thread_num}, {@code decoder_thread_num}, {@code
   * decoder_time_idle}, {@code parallel_decoder_num} and {@code deocder_time_out} (the spelling of
   * the last key is kept for compatibility with existing config files). Failures are logged and no
   * worker threads are started in that case.
   *
   * @param cfgMap settings parsed from the config file
   * @param cfgPath path of the config file handed to the recognizer
   */
  public void initModelWithCfg(Map<String, String> cfgMap, String cfgPath) {
    try {
      rcgOjb = new OnlineRecognizer(cfgPath);
      // size of stream thread pool
      int streamThreadNum = intSetting(cfgMap, "stream_thread_num");
      // size of decoder thread pool
      int decoderThreadNum = intSetting(cfgMap, "decoder_thread_num");
      // time(ms) idle for decoder thread when no job
      int decoderTimeIdle = intSetting(cfgMap, "decoder_time_idle");
      // size of streams for parallel decoding
      int parallelDecoderNum = intSetting(cfgMap, "parallel_decoder_num");
      // time(ms) out for connection data
      int deocderTimeOut = intSetting(cfgMap, "deocder_time_out");

      for (int i = 0; i < streamThreadNum; i++) {
        new StreamThreadHandler(streamQueue, decoderQueue, connectionMap).start();
      }
      for (int i = 0; i < decoderThreadNum; i++) {
        new DecoderThreadHandler(
                decoderQueue,
                connectionMap,
                rcgOjb,
                decoderTimeIdle,
                parallelDecoderNum,
                deocderTimeOut)
            .start();
      }
    } catch (Exception e) {
      logger.error("failed to initialize model with config " + cfgPath, e);
    }
  }

  /**
   * Reads one required integer setting.
   *
   * @throws IllegalArgumentException if the key is missing or its value is not an integer
   */
  private static int intSetting(Map<String, String> cfgMap, String key) {
    String raw = cfgMap.get(key);
    if (raw == null) {
      throw new IllegalArgumentException("missing required setting '" + key + "'");
    }
    try {
      return Integer.parseInt(raw.trim());
    } catch (NumberFormatException e) {
      throw new IllegalArgumentException(
          "setting '" + key + "' is not an integer: '" + raw + "'", e);
    }
  }

  /**
   * Reads a Java properties file into a map. Exits the JVM if the file does not exist; any other
   * failure is logged and the entries read so far (possibly none) are returned.
   *
   * @param CfgPath path of the properties file
   * @return key/value pairs of the file
   */
  public static Map<String, String> readProperties(String CfgPath) {
    Properties props = new Properties();
    Map<String, String> proMap = new HashMap<String, String>();
    try {
      File file = new File(CfgPath);
      if (!file.exists()) {
        logger.info(String.valueOf(CfgPath) + " cfg file not exists!");
        System.exit(0);
      }
      // try-with-resources so the file handle is always released
      try (InputStream in = new BufferedInputStream(new FileInputStream(file))) {
        props.load(in);
      }
      for (String key : props.stringPropertyNames()) {
        proMap.put(key, props.getProperty(key));
      }
    } catch (Exception e) {
      logger.error("failed to read config file " + CfgPath, e);
    }
    return proMap;
  }

  /**
   * Entry point.
   *
   * @param args {@code soPath modelCfgPath}
   */
  public static void main(String[] args) throws InterruptedException, IOException {
    if (args.length != 2) {
      logger.info("usage: AsrWebsocketServer soPath modelCfgPath");
      return;
    }
    String soPath = args[0];
    String cfgPath = args[1];
    OnlineRecognizer.setSoPath(soPath);
    Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath);
    final int port;
    final int connectionThreadNum;
    try {
      port = intSetting(cfgMap, "port");
      connectionThreadNum = intSetting(cfgMap, "connection_thread_num");
    } catch (IllegalArgumentException e) {
      logger.error("invalid configuration in " + cfgPath + ": " + e.getMessage());
      return;
    }
    AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum);
    s.initModelWithCfg(cfgMap, cfgPath);
    logger.info("Server started on port: " + s.getPort());
    s.start();
  }

  @Override
  public void onError(WebSocket conn, Exception ex) {
    // conn may be null: some errors, like a failed port binding, are not tied to one websocket
    logger.error("websocket server error", ex);
  }

  @Override
  public void onStart() {
    logger.info("Server started!");
    setConnectionLostTimeout(0);
    setConnectionLostTimeout(100);
  }
}

View File

@@ -0,0 +1,65 @@
/*
* // Copyright 2022-2023 by zhaoming
*/
// ConnectionData acts as a bridge between the different thread pools
package websocketsrv;
import com.k2fsa.sherpa.onnx.OnlineStream;
import java.time.LocalDateTime;
import java.util.LinkedList;
import java.util.Queue;
import java.util.concurrent.*;
import org.java_websocket.WebSocket;
/**
 * Per-connection state shared between the network io threads, the stream threads and the decoder
 * threads.
 *
 * <p>Because several threads read and write the same instance, the sample queue is a concurrent
 * queue and the scalar fields are {@code volatile}.
 */
public class ConnectionData {
  private final WebSocket webSocket; // the websocket for this connection data
  private final OnlineStream stream; // connection stream
  // binary data received from the client; appended and polled by different threads
  private final Queue<float[]> queueSamples = new ConcurrentLinkedQueue<float[]>();
  private volatile boolean eof = false; // connection data is done
  // time this connection was last handled; starts at creation time so it is never null
  private volatile LocalDateTime lastHandleTime = LocalDateTime.now();

  public ConnectionData(WebSocket webSocket, OnlineStream stream) {
    this.webSocket = webSocket;
    this.stream = stream;
  }

  /**
   * Appends one chunk of samples received from the client.
   *
   * @param samples chunk to queue; must not be null
   */
  public void addSamplesToData(float[] samples) {
    this.queueSamples.add(samples);
  }

  public LocalDateTime getLastHandleTime() {
    return this.lastHandleTime;
  }

  public void setLastHandleTime(LocalDateTime now) {
    this.lastHandleTime = now;
  }

  /** @return true once the client has signalled the end of its audio */
  public boolean getEof() {
    return this.eof;
  }

  public void setEof(boolean eof) {
    this.eof = eof;
  }

  public WebSocket getWebSocket() {
    return this.webSocket;
  }

  /** @return the live queue of pending sample chunks (not a copy) */
  public Queue<float[]> getQueueSamples() {
    return this.queueSamples;
  }

  public OnlineStream getStream() {
    return this.stream;
  }
}

View File

@@ -0,0 +1,173 @@
/*
* // Copyright 2022-2023 by zhaoming
*/
// java DecoderThreadHandler
package websocketsrv;
import com.k2fsa.sherpa.onnx.OnlineRecognizer;
import com.k2fsa.sherpa.onnx.OnlineStream;
import java.nio.*;
import java.nio.charset.StandardCharsets;
import java.time.LocalDateTime;
import java.util.*;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.LinkedBlockingQueue;
import org.java_websocket.WebSocket;
import org.java_websocket.drafts.Draft;
import org.java_websocket.framing.Framedata;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Decoder worker thread.
 *
 * <p>Each iteration sleeps for the idle time, pulls up to {@code parallelDecoderNum} ready
 * connections from the decoder queue, decodes their streams in one batch, sends the current result
 * of each stream back as a JSON text frame ({@code {"text":"...","eof":true|false}}), and then
 * either re-queues the connection or schedules its websocket for closing.
 */
public class DecoderThreadHandler extends Thread {
  private static final Logger logger = LoggerFactory.getLogger(DecoderThreadHandler.class);
  // delay before a finished websocket is closed, as data may still be in processing
  private static final long CLOSE_DELAY_MS = 5000; // 5 seconds
  // One shared daemon timer for all delayed closes instead of a new timer thread per close.
  private static final Timer CLOSE_TIMER = new Timer("asr-websocket-close", true);

  // Websocket Queue that is waiting for decoding
  private LinkedBlockingQueue<WebSocket> decoderQueue;
  // the mapping between websocket and connection data
  private ConcurrentHashMap<WebSocket, ConnectionData> connMap;
  private OnlineRecognizer rcgOjb = null; // recognizer object
  // connection data list for this thread to decode in parallel
  private List<ConnectionData> connDataList = new ArrayList<ConnectionData>();
  private int parallelDecoderNum = 10; // parallel decoding number
  private int deocderTimeIdle = 10; // idle time(ms) when no job
  private int deocderTimeOut = 3000; // if it is timeout(ms), the connection data will be removed

  /**
   * @param decoderQueue websockets waiting for decoding
   * @param connMap mapping between websocket and connection data
   * @param rcgOjb recognizer used for decoding
   * @param deocderTimeIdle sleep time (ms) at the start of every loop iteration
   * @param parallelDecoderNum maximum number of streams decoded in one batch
   * @param deocderTimeOut time (ms) after which a handled connection is closed
   */
  public DecoderThreadHandler(
      LinkedBlockingQueue<WebSocket> decoderQueue,
      ConcurrentHashMap<WebSocket, ConnectionData> connMap,
      OnlineRecognizer rcgOjb,
      int deocderTimeIdle,
      int parallelDecoderNum,
      int deocderTimeOut) {
    this.decoderQueue = decoderQueue;
    this.connMap = connMap;
    this.rcgOjb = rcgOjb;
    this.deocderTimeIdle = deocderTimeIdle;
    this.parallelDecoderNum = parallelDecoderNum;
    this.deocderTimeOut = deocderTimeOut;
  }

  @Override
  public void run() {
    while (true) {
      try {
        // time(ms) idle between iterations
        Thread.sleep(deocderTimeIdle);
        // clear data list for this thread
        connDataList.clear();
        if (rcgOjb == null) continue;
        collectReadyConnections();
        decodeBatch();
        sendResults();
        requeueOrClose();
      } catch (InterruptedException e) {
        // Restore the flag and stop instead of spinning in the loop.
        Thread.currentThread().interrupt();
        logger.info("decoder thread interrupted, exiting");
        return;
      } catch (Exception e) {
        logger.error("decoder loop iteration failed", e);
      }
    }
  }

  /** Fills {@code connDataList} with up to {@code parallelDecoderNum} ready connections. */
  private void collectReadyConnections() {
    while (connDataList.size() < parallelDecoderNum) {
      // poll() instead of isEmpty()+take(): with several decoder threads another thread may drain
      // the queue between the check and the take, which would block this thread indefinitely.
      WebSocket conn = decoderQueue.poll();
      if (conn == null) {
        break;
      }
      // if the websocket was closed, its data is gone; skip it
      ConnectionData connData = connMap.get(conn);
      if (connData == null) {
        continue;
      }
      try {
        if (rcgOjb.isReady(connData.getStream())) {
          connDataList.add(connData);
          // change the handled time for this connection data
          connData.setLastHandleTime(LocalDateTime.now());
        }
      } catch (Exception e) {
        logger.error("failed to query stream readiness", e);
      }
    }
  }

  /** Decodes all collected streams in one parallel call. */
  private void decodeBatch() throws Exception {
    if (connDataList.isEmpty()) {
      return;
    }
    OnlineStream[] streams = new OnlineStream[connDataList.size()];
    for (int i = 0; i < streams.length; i++) {
      streams[i] = connDataList.get(i).getStream();
    }
    rcgOjb.decodeStreams(streams);
  }

  /** Sends the current recognition result of every collected connection to its client. */
  private void sendResults() {
    for (ConnectionData connData : connDataList) {
      // A failure on one connection must not prevent the others from being served.
      try {
        OnlineStream stream = connData.getStream();
        WebSocket webSocket = connData.getWebSocket();
        String txtResult = rcgOjb.getResult(stream);
        if (txtResult == null || txtResult.isEmpty()) {
          continue;
        }
        boolean isEof = connData.getEof() && !rcgOjb.isReady(stream);
        String jsonResult =
            "{\"text\":\"" + escapeJson(txtResult) + "\",\"eof\":" + String.valueOf(isEof) + "}";
        if (webSocket.isOpen()) {
          // create a TEXT Frame for sending back the json result
          Draft draft = webSocket.getDraft();
          List<Framedata> frames = draft.createFrames(jsonResult, false);
          webSocket.sendFrame(frames);
        }
      } catch (Exception e) {
        logger.error("failed to send result", e);
      }
    }
  }

  /** Re-queues streams that are still ready and schedules finished ones for closing. */
  private void requeueOrClose() {
    for (ConnectionData connData : connDataList) {
      try {
        OnlineStream stream = connData.getStream();
        final WebSocket webSocket = connData.getWebSocket();
        // if the stream is still ready, put it to decoder Queue again for next decoding
        if (rcgOjb.isReady(stream)) {
          decoderQueue.offer(webSocket); // unbounded queue: never blocks
        }
        // the duration between last handled time and now
        java.time.Duration duration =
            java.time.Duration.between(connData.getLastHandleTime(), LocalDateTime.now());
        // close the websocket if 1) data is done and stream not ready; 2) or data is time out;
        // 3) or connection is closed
        boolean finished =
            connData.getEof()
                && !rcgOjb.isReady(stream)
                && connData.getQueueSamples().isEmpty();
        if (finished || duration.toMillis() > deocderTimeOut || !webSocket.isOpen()) {
          logger.info("close websocket!!!");
          scheduleClose(webSocket);
        }
      } catch (Exception e) {
        logger.error("failed to finish connection handling", e);
      }
    }
  }

  /** Closes {@code webSocket} after {@link #CLOSE_DELAY_MS}. */
  private static void scheduleClose(final WebSocket webSocket) {
    CLOSE_TIMER.schedule(
        new TimerTask() {
          @Override
          public void run() {
            try {
              webSocket.close();
            } catch (RuntimeException e) {
              // An exception escaping a TimerTask would cancel the shared timer.
              logger.error("failed to close websocket", e);
            }
          }
        },
        CLOSE_DELAY_MS);
  }

  /** Escapes {@code text} so it can be embedded in a JSON string literal. */
  private static String escapeJson(String text) {
    StringBuilder sb = new StringBuilder(text.length() + 16);
    for (int i = 0; i < text.length(); i++) {
      char c = text.charAt(i);
      switch (c) {
        case '"':
          sb.append("\\\"");
          break;
        case '\\':
          sb.append("\\\\");
          break;
        case '\n':
          sb.append("\\n");
          break;
        case '\r':
          sb.append("\\r");
          break;
        case '\t':
          sb.append("\\t");
          break;
        default:
          if (c < 0x20) {
            sb.append(String.format("\\u%04x", (int) c));
          } else {
            sb.append(c);
          }
      }
    }
    return sb.toString();
  }
}

View File

@@ -0,0 +1,67 @@
/*
* // Copyright 2022-2023 by zhaoming
*/
// java StreamThreadHandler
package websocketsrv;
import com.k2fsa.sherpa.onnx.OnlineStream;
import java.nio.*;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.LinkedBlockingQueue;
import org.java_websocket.WebSocket;
/**
 * Stream worker thread.
 *
 * <p>Takes websockets from the stream queue, feeds all samples queued for that connection into its
 * recognizer stream, signals end of input once the client is done, and hands the connection over
 * to the decoder queue.
 */
public class StreamThreadHandler extends Thread {
  // Connections whose stream has already been told that input is finished. Shared by all stream
  // threads so inputFinished() is called once per connection; weak keys let closed connections be
  // garbage collected.
  private static final Set<WebSocket> inputFinishedConns =
      Collections.synchronizedSet(
          Collections.newSetFromMap(new WeakHashMap<WebSocket, Boolean>()));

  // Queue between the network io thread pool and the stream thread pool, keyed by websocket
  private LinkedBlockingQueue<WebSocket> streamQueue;
  // Queue of websockets waiting for decoding
  private LinkedBlockingQueue<WebSocket> decoderQueue;
  // mapping between websocket connection and connection data
  private ConcurrentHashMap<WebSocket, ConnectionData> connMap;

  /**
   * @param streamQueue websockets with newly received data or an end-of-audio marker
   * @param decoderQueue websockets waiting for decoding
   * @param connMap mapping between websocket connection and connection data
   */
  public StreamThreadHandler(
      LinkedBlockingQueue<WebSocket> streamQueue,
      LinkedBlockingQueue<WebSocket> decoderQueue,
      ConcurrentHashMap<WebSocket, ConnectionData> connMap) {
    this.streamQueue = streamQueue;
    this.decoderQueue = decoderQueue;
    this.connMap = connMap;
  }

  @Override
  public void run() {
    while (true) {
      try {
        // fetch one websocket from queue (blocks until one is available)
        WebSocket conn = this.streamQueue.take();
        // get the connection data according to websocket
        ConnectionData connData = connMap.get(conn);
        if (connData == null) {
          // the connection was closed and removed from the map in the meantime
          continue;
        }
        OnlineStream stream = connData.getStream();
        boolean streamChanged = false;

        // put all received binary data into the stream
        float[] samples;
        while ((samples = connData.getQueueSamples().poll()) != null) {
          stream.acceptWaveform(samples);
          streamChanged = true;
        }

        // If the client is done, finish the input even when no samples were pending: the
        // end-of-audio marker can arrive after the last samples have already been consumed.
        if (connData.getEof() && inputFinishedConns.add(conn)) {
          try {
            stream.inputFinished();
            streamChanged = true;
          } catch (Exception e) {
            inputFinishedConns.remove(conn); // allow a later retry
            throw e;
          }
        }

        // add this websocket to decoder Queue if not in the Queue
        if (streamChanged && !decoderQueue.contains(conn)) {
          decoderQueue.put(conn);
        }
      } catch (InterruptedException e) {
        // Restore the flag and stop instead of spinning in the loop.
        Thread.currentThread().interrupt();
        return;
      } catch (Exception e) {
        e.printStackTrace();
      }
    }
  }
}