Add non-streaming websocket server for python (#259)

This commit is contained in:
Fangjun Kuang
2023-08-11 15:56:24 +08:00
committed by GitHub
parent 6c0f002825
commit b094868fb8
24 changed files with 1247 additions and 92 deletions

View File

@@ -13,11 +13,37 @@ Usage:
Example:
(1) Without a certificate
python3 ./python-api-examples/streaming_server.py \
--encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
--decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
--joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
--tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
(2) With a certificate
(a) Generate a certificate first:
cd python-api-examples/web
./generate-certificate.py
cd ../..
(b) Start the server
python3 ./python-api-examples/streaming_server.py \
--encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
--decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
--joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
--tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
--certificate ./python-api-examples/web/cert.pem
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
to download pre-trained models.
The model in the above help messages is from
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
"""
import argparse
@@ -35,6 +61,7 @@ from typing import List, Optional, Tuple
import numpy as np
import sherpa_onnx
import websockets
from http_server import HttpServer
@@ -269,8 +296,8 @@ def get_args():
parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Sets the number of threads used for interop parallelism (e.g. in JIT interpreter) on CPU.",
default=2,
help="Number of threads to run the neural network model",
)
parser.add_argument(
@@ -278,8 +305,10 @@ def get_args():
type=str,
help="""Path to the X.509 certificate. You need it only if you want to
use a secure websocket connection, i.e., use wss:// instead of ws://.
You can use sherpa/bin/web/generate-certificate.py
You can use ./web/generate-certificate.py
to generate the certificate `cert.pem`.
Note ./web/generate-certificate.py will generate three files but you
only need to pass the generated cert.pem to this option.
""",
)
@@ -287,7 +316,7 @@ def get_args():
"--doc-root",
type=str,
default="./python-api-examples/web",
help="""Path to the web root""",
help="Path to the web root",
)
return parser.parse_args()
@@ -299,9 +328,9 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
encoder=args.encoder_model,
decoder=args.decoder_model,
joiner=args.joiner_model,
num_threads=1,
sample_rate=16000,
feature_dim=80,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
max_active_paths=args.num_active_paths,
enable_endpoint_detection=args.use_endpoint != 0,
@@ -359,7 +388,7 @@ class StreamingServer(object):
server locate.
certificate:
Optional. If not None, it will use secure websocket.
You can use ./sherpa/bin/web/generate-certificate.py to generate
You can use ./web/generate-certificate.py to generate
it (the default generated filename is `cert.pem`).
"""
self.recognizer = recognizer
@@ -373,6 +402,7 @@ class StreamingServer(object):
)
self.stream_queue = asyncio.Queue()
self.max_wait_ms = max_wait_ms
self.max_batch_size = max_batch_size
self.max_message_size = max_message_size
@@ -382,11 +412,10 @@ class StreamingServer(object):
self.current_active_connections = 0
self.sample_rate = int(recognizer.config.feat_config.sampling_rate)
self.decoding_method = recognizer.config.decoding_method
async def stream_consumer_task(self):
"""This function extracts streams from the queue, batches them up, sends
them to the RNN-T model for computation and decoding.
them to the neural network model for computation and decoding.
"""
while True:
if self.stream_queue.empty():
@@ -442,7 +471,22 @@ class StreamingServer(object):
# This is a normal HTTP request
if path == "/":
path = "/index.html"
found, response, mime_type = self.http_server.process_request(path)
if path in ("/upload.html", "/offline_record.html"):
response = r"""
<!doctype html><html><head>
<title>Speech recognition with next-gen Kaldi</title><body>
<h2>Only /streaming_record.html is available for the streaming server.<h2>
<br/>
<br/>
Go back to <a href="/streaming_record.html">/streaming_record.html</a>
</body></head></html>
"""
found = True
mime_type = "text/html"
else:
found, response, mime_type = self.http_server.process_request(path)
if isinstance(response, str):
response = response.encode("utf-8")
@@ -484,12 +528,21 @@ class StreamingServer(object):
process_request=self.process_request,
ssl=ssl_context,
):
ip_list = ["0.0.0.0", "localhost", "127.0.0.1"]
ip_list.append(socket.gethostbyname(socket.gethostname()))
ip_list = ["localhost"]
if ssl_context:
ip_list += ["0.0.0.0", "127.0.0.1"]
ip_list.append(socket.gethostbyname(socket.gethostname()))
proto = "http://" if ssl_context is None else "https://"
s = "Please visit one of the following addresses:\n\n"
for p in ip_list:
s += " " + proto + p + f":{port}" "\n"
if not ssl_context:
s += "\nSince you are not providing a certificate, you cannot "
s += "use your microphone from within the browser using "
s += "public IP addresses. Only localhost can be used."
s += "You also cannot use 0.0.0.0 or 127.0.0.1"
logging.info(s)
await asyncio.Future() # run forever
@@ -525,7 +578,7 @@ class StreamingServer(object):
socket: websockets.WebSocketServerProtocol,
):
"""Receive audio samples from the client, process it, and send
deocoding result back to the client.
decoding result back to the client.
Args:
socket:
@@ -560,8 +613,6 @@ class StreamingServer(object):
self.recognizer.reset(stream)
segment += 1
print(message)
await socket.send(json.dumps(message))
tail_padding = np.zeros(int(self.sample_rate * 0.3)).astype(np.float32)
@@ -583,7 +634,7 @@ class StreamingServer(object):
self,
socket: websockets.WebSocketServerProtocol,
) -> Optional[np.ndarray]:
"""Receives a tensor from the client.
"""Receive a tensor from the client.
Each message contains either a bytes buffer of audio samples
at 16 kHz or "Done", which marks the end of the utterance.
@@ -660,6 +711,6 @@ def main():
if __name__ == "__main__":
log_filename = "log/log-streaming-zipformer"
log_filename = "log/log-streaming-server"
setup_logger(log_filename)
main()