Add non-streaming websocket server for python (#259)
This commit is contained in:
@@ -13,11 +13,37 @@ Usage:
|
||||
|
||||
Example:
|
||||
|
||||
(1) Without a certificate
|
||||
|
||||
python3 ./python-api-examples/streaming_server.py \
|
||||
--encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
|
||||
--tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
|
||||
|
||||
(2) With a certificate
|
||||
|
||||
(a) Generate a certificate first:
|
||||
|
||||
cd python-api-examples/web
|
||||
./generate-certificate.py
|
||||
cd ../..
|
||||
|
||||
(b) Start the server
|
||||
|
||||
python3 ./python-api-examples/streaming_server.py \
|
||||
--encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
|
||||
--tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
|
||||
--certificate ./python-api-examples/web/cert.pem
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
|
||||
to download pre-trained models.
|
||||
|
||||
The model in the above help messages is from
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -35,6 +61,7 @@ from typing import List, Optional, Tuple
|
||||
import numpy as np
|
||||
import sherpa_onnx
|
||||
import websockets
|
||||
|
||||
from http_server import HttpServer
|
||||
|
||||
|
||||
@@ -269,8 +296,8 @@ def get_args():
|
||||
parser.add_argument(
|
||||
"--num-threads",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Sets the number of threads used for interop parallelism (e.g. in JIT interpreter) on CPU.",
|
||||
default=2,
|
||||
help="Number of threads to run the neural network model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -278,8 +305,10 @@ def get_args():
|
||||
type=str,
|
||||
help="""Path to the X.509 certificate. You need it only if you want to
|
||||
use a secure websocket connection, i.e., use wss:// instead of ws://.
|
||||
You can use sherpa/bin/web/generate-certificate.py
|
||||
You can use ./web/generate-certificate.py
|
||||
to generate the certificate `cert.pem`.
|
||||
Note ./web/generate-certificate.py will generate three files but you
|
||||
only need to pass the generated cert.pem to this option.
|
||||
""",
|
||||
)
|
||||
|
||||
@@ -287,7 +316,7 @@ def get_args():
|
||||
"--doc-root",
|
||||
type=str,
|
||||
default="./python-api-examples/web",
|
||||
help="""Path to the web root""",
|
||||
help="Path to the web root",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
@@ -299,9 +328,9 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
|
||||
encoder=args.encoder_model,
|
||||
decoder=args.decoder_model,
|
||||
joiner=args.joiner_model,
|
||||
num_threads=1,
|
||||
sample_rate=16000,
|
||||
feature_dim=80,
|
||||
num_threads=args.num_threads,
|
||||
sample_rate=args.sample_rate,
|
||||
feature_dim=args.feat_dim,
|
||||
decoding_method=args.decoding_method,
|
||||
max_active_paths=args.num_active_paths,
|
||||
enable_endpoint_detection=args.use_endpoint != 0,
|
||||
@@ -359,7 +388,7 @@ class StreamingServer(object):
|
||||
server locate.
|
||||
certificate:
|
||||
Optional. If not None, it will use secure websocket.
|
||||
You can use ./sherpa/bin/web/generate-certificate.py to generate
|
||||
You can use ./web/generate-certificate.py to generate
|
||||
it (the default generated filename is `cert.pem`).
|
||||
"""
|
||||
self.recognizer = recognizer
|
||||
@@ -373,6 +402,7 @@ class StreamingServer(object):
|
||||
)
|
||||
|
||||
self.stream_queue = asyncio.Queue()
|
||||
|
||||
self.max_wait_ms = max_wait_ms
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_message_size = max_message_size
|
||||
@@ -382,11 +412,10 @@ class StreamingServer(object):
|
||||
self.current_active_connections = 0
|
||||
|
||||
self.sample_rate = int(recognizer.config.feat_config.sampling_rate)
|
||||
self.decoding_method = recognizer.config.decoding_method
|
||||
|
||||
async def stream_consumer_task(self):
|
||||
"""This function extracts streams from the queue, batches them up, sends
|
||||
them to the RNN-T model for computation and decoding.
|
||||
them to the neural network model for computation and decoding.
|
||||
"""
|
||||
while True:
|
||||
if self.stream_queue.empty():
|
||||
@@ -442,7 +471,22 @@ class StreamingServer(object):
|
||||
# This is a normal HTTP request
|
||||
if path == "/":
|
||||
path = "/index.html"
|
||||
found, response, mime_type = self.http_server.process_request(path)
|
||||
|
||||
if path in ("/upload.html", "/offline_record.html"):
|
||||
response = r"""
|
||||
<!doctype html><html><head>
|
||||
<title>Speech recognition with next-gen Kaldi</title></head><body>
|
||||
<h2>Only /streaming_record.html is available for the streaming server.</h2>
|
||||
<br/>
|
||||
<br/>
|
||||
Go back to <a href="/streaming_record.html">/streaming_record.html</a>
|
||||
</body></html>
|
||||
"""
|
||||
found = True
|
||||
mime_type = "text/html"
|
||||
else:
|
||||
found, response, mime_type = self.http_server.process_request(path)
|
||||
|
||||
if isinstance(response, str):
|
||||
response = response.encode("utf-8")
|
||||
|
||||
@@ -484,12 +528,21 @@ class StreamingServer(object):
|
||||
process_request=self.process_request,
|
||||
ssl=ssl_context,
|
||||
):
|
||||
ip_list = ["0.0.0.0", "localhost", "127.0.0.1"]
|
||||
ip_list.append(socket.gethostbyname(socket.gethostname()))
|
||||
ip_list = ["localhost"]
|
||||
if ssl_context:
|
||||
ip_list += ["0.0.0.0", "127.0.0.1"]
|
||||
ip_list.append(socket.gethostbyname(socket.gethostname()))
|
||||
proto = "http://" if ssl_context is None else "https://"
|
||||
s = "Please visit one of the following addresses:\n\n"
|
||||
for p in ip_list:
|
||||
s += " " + proto + p + f":{port}" "\n"
|
||||
|
||||
if not ssl_context:
|
||||
s += "\nSince you are not providing a certificate, you cannot "
|
||||
s += "use your microphone from within the browser using "
|
||||
s += "public IP addresses. Only localhost can be used. "
|
||||
s += "You also cannot use 0.0.0.0 or 127.0.0.1"
|
||||
|
||||
logging.info(s)
|
||||
|
||||
await asyncio.Future() # run forever
|
||||
@@ -525,7 +578,7 @@ class StreamingServer(object):
|
||||
socket: websockets.WebSocketServerProtocol,
|
||||
):
|
||||
"""Receive audio samples from the client, process it, and send
|
||||
deocoding result back to the client.
|
||||
decoding result back to the client.
|
||||
|
||||
Args:
|
||||
socket:
|
||||
@@ -560,8 +613,6 @@ class StreamingServer(object):
|
||||
self.recognizer.reset(stream)
|
||||
segment += 1
|
||||
|
||||
print(message)
|
||||
|
||||
await socket.send(json.dumps(message))
|
||||
|
||||
tail_padding = np.zeros(int(self.sample_rate * 0.3)).astype(np.float32)
|
||||
@@ -583,7 +634,7 @@ class StreamingServer(object):
|
||||
self,
|
||||
socket: websockets.WebSocketServerProtocol,
|
||||
) -> Optional[np.ndarray]:
|
||||
"""Receives a tensor from the client.
|
||||
"""Receive a tensor from the client.
|
||||
|
||||
Each message contains either a bytes buffer containing audio samples
|
||||
in 16 kHz or contains "Done" meaning the end of utterance.
|
||||
@@ -660,6 +711,6 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
log_filename = "log/log-streaming-zipformer"
|
||||
log_filename = "log/log-streaming-server"
|
||||
setup_logger(log_filename)
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user