add offline websocket server/client (#98)
This commit is contained in:
2
.flake8
2
.flake8
@@ -1,7 +1,7 @@
|
||||
[flake8]
|
||||
show-source=true
|
||||
statistics=true
|
||||
max-line-length = 80
|
||||
max-line-length = 120
|
||||
|
||||
exclude =
|
||||
.git,
|
||||
|
||||
7
.github/scripts/test-python.sh
vendored
7
.github/scripts/test-python.sh
vendored
@@ -30,9 +30,12 @@ ls -lh
|
||||
|
||||
ls -lh $repo
|
||||
|
||||
python3 ./python-api-examples/decode-file.py \
|
||||
python3 ./python-api-examples/online-decode-files.py \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder=$repo/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner=$repo/joiner-epoch-99-avg-1.onnx \
|
||||
--wave-filename=$repo/test_wavs/4.wav
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/2.wav \
|
||||
$repo/test_wavs/3.wav
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -45,3 +45,4 @@ paraformer-onnxruntime-python-example
|
||||
run-sherpa-onnx-offline-paraformer.sh
|
||||
run-sherpa-onnx-offline-transducer.sh
|
||||
sherpa-onnx-paraformer-zh-2023-03-28
|
||||
run-offline-websocket-server-paraformer.sh
|
||||
|
||||
158
python-api-examples/offline-websocket-client-decode-files-paralell.py
Executable file
158
python-api-examples/offline-websocket-client-decode-files-paralell.py
Executable file
@@ -0,0 +1,158 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
"""
|
||||
A websocket client for sherpa-onnx-offline-websocket-server
|
||||
|
||||
This file shows how to transcribe multiple
|
||||
files in parallel. We create a separate connection for transcribing each file.
|
||||
|
||||
Usage:
|
||||
./offline-websocket-client-decode-files-parallel.py \
|
||||
--server-addr localhost \
|
||||
--server-port 6006 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav \
|
||||
/path/to/16kHz.wav \
|
||||
/path/to/8kHz.wav
|
||||
|
||||
(Note: You have to first start the server before starting the client)
|
||||
|
||||
You can find the server at
|
||||
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/offline-websocket-server.cc
|
||||
|
||||
Note: The server is implemented in C++.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import wave
|
||||
from typing import Tuple
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ImportError:
|
||||
print("please run:")
|
||||
print("")
|
||||
print(" pip install websockets")
|
||||
print("")
|
||||
print("before you run this script")
|
||||
print("")
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_args():
    """Parse the command-line arguments of this client.

    Returns:
      An argparse.Namespace with the attributes ``server_addr`` (str),
      ``server_port`` (int) and ``sound_files`` (a non-empty list of str).
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--server-addr",
        type=str,
        default="localhost",
        help="Address of the server",
    )

    parser.add_argument(
        "--server-port",
        type=int,
        default=6006,
        help="Port of the server",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        # Adjacent literals are concatenated as-is, so every piece needs its
        # own trailing space (the original rendered "WAVEformat").
        help="The input sound file(s) to decode. Each file must be of WAVE "
        "format with a single channel, and each sample has 16-bit, "
        "i.e., int16_t. "
        "The sample rate of the file can be arbitrary and does not need to "
        "be 16 kHz",
    )

    return parser.parse_args()
|
||||
|
||||
|
||||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Load a mono, 16-bit wave file as normalized float32 samples.

    Args:
      wave_filename:
        Path to the wave file. It must have exactly one channel and 2-byte
        samples; any sample rate is accepted.
    Returns:
      Return a tuple containing:
       - A 1-D np.float32 array with the samples divided by 32768, i.e.,
         scaled to the range [-1, 1].
       - The sample rate of the wave file.
    """
    with wave.open(wave_filename) as wav:
        num_channels = wav.getnchannels()
        assert num_channels == 1, num_channels

        sample_width = wav.getsampwidth()  # it is in bytes
        assert sample_width == 2, sample_width

        raw_frames = wav.readframes(wav.getnframes())
        sample_rate = wav.getframerate()

    pcm_int16 = np.frombuffer(raw_frames, dtype=np.int16)
    return pcm_int16.astype(np.float32) / 32768, sample_rate
|
||||
|
||||
|
||||
async def run(
    server_addr: str,
    server_port: int,
    wave_filename: str,
):
    """Send one wave file over its own connection and log the transcript.

    The message sent is: the sample rate as a 4-byte little-endian integer,
    the number of sample bytes as a 4-byte little-endian integer, and then
    the float32 samples. After the result is received, the text message
    "Done" tells the server that no more data will follow.

    Args:
      server_addr:
        Host name or IP address of the server.
      server_port:
        Port the server listens on.
      wave_filename:
        Path to the wave file to transcribe.
    """
    async with websockets.connect(
        f"ws://{server_addr}:{server_port}"
    ) as websocket:  # noqa
        logging.info(f"Sending {wave_filename}")
        samples, sample_rate = read_wave(wave_filename)
        assert isinstance(sample_rate, int)
        assert samples.dtype == np.float32, samples.dtype
        # ndarray has no attribute `dim`; using it as the assertion message
        # would raise AttributeError instead of the intended AssertionError.
        assert samples.ndim == 1, samples.ndim

        buf = sample_rate.to_bytes(4, byteorder="little")  # 4 bytes
        buf += (samples.size * 4).to_bytes(4, byteorder="little")
        buf += samples.tobytes()

        await websocket.send(buf)

        decoding_results = await websocket.recv()
        logging.info(f"{wave_filename}\n{decoding_results}")

        # to signal that the client has sent all the data
        await websocket.send("Done")
|
||||
|
||||
|
||||
async def main():
    """Start one `run` task per input file and wait for all of them."""
    args = get_args()
    logging.info(vars(args))

    # One task, and therefore one connection, per file so that the files
    # are transcribed concurrently.
    pending = [
        asyncio.create_task(
            run(
                server_addr=args.server_addr,
                server_port=args.server_port,
                wave_filename=path,
            )
        )
        for path in args.sound_files
    ]

    await asyncio.gather(*pending)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Log lines carry: timestamp, level, [file:line] and the message.
    log_format = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"  # noqa
    logging.basicConfig(format=log_format, level=logging.INFO)
    asyncio.run(main())
|
||||
152
python-api-examples/offline-websocket-client-decode-files-sequential.py
Executable file
152
python-api-examples/offline-websocket-client-decode-files-sequential.py
Executable file
@@ -0,0 +1,152 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
"""
|
||||
A websocket client for sherpa-onnx-offline-websocket-server
|
||||
|
||||
This file shows how to use a single connection to transcribe multiple
|
||||
files sequentially.
|
||||
|
||||
Usage:
|
||||
./offline-websocket-client-decode-files-sequential.py \
|
||||
--server-addr localhost \
|
||||
--server-port 6006 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav \
|
||||
/path/to/16kHz.wav \
|
||||
/path/to/8kHz.wav
|
||||
|
||||
(Note: You have to first start the server before starting the client)
|
||||
|
||||
You can find the server at
|
||||
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/offline-websocket-server.cc
|
||||
|
||||
Note: The server is implemented in C++.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import wave
|
||||
from typing import List, Tuple
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ImportError:
|
||||
print("please run:")
|
||||
print("")
|
||||
print(" pip install websockets")
|
||||
print("")
|
||||
print("before you run this script")
|
||||
print("")
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_args():
    """Parse the command-line arguments of this client.

    Returns:
      An argparse.Namespace with the attributes ``server_addr`` (str),
      ``server_port`` (int) and ``sound_files`` (a non-empty list of str).
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--server-addr",
        type=str,
        default="localhost",
        help="Address of the server",
    )

    parser.add_argument(
        "--server-port",
        type=int,
        default=6006,
        help="Port of the server",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        # Adjacent literals are concatenated as-is, so every piece needs its
        # own trailing space (the original rendered "WAVEformat").
        help="The input sound file(s) to decode. Each file must be of WAVE "
        "format with a single channel, and each sample has 16-bit, "
        "i.e., int16_t. "
        "The sample rate of the file can be arbitrary and does not need to "
        "be 16 kHz",
    )

    return parser.parse_args()
|
||||
|
||||
|
||||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Load a mono, 16-bit wave file as normalized float32 samples.

    Args:
      wave_filename:
        Path to the wave file. It must have exactly one channel and 2-byte
        samples; any sample rate is accepted.
    Returns:
      Return a tuple containing:
       - A 1-D np.float32 array with the samples divided by 32768, i.e.,
         scaled to the range [-1, 1].
       - The sample rate of the wave file.
    """
    with wave.open(wave_filename) as wav:
        num_channels = wav.getnchannels()
        assert num_channels == 1, num_channels

        sample_width = wav.getsampwidth()  # it is in bytes
        assert sample_width == 2, sample_width

        raw_frames = wav.readframes(wav.getnframes())
        sample_rate = wav.getframerate()

    pcm_int16 = np.frombuffer(raw_frames, dtype=np.int16)
    return pcm_int16.astype(np.float32) / 32768, sample_rate
|
||||
|
||||
|
||||
async def run(
    server_addr: str,
    server_port: int,
    sound_files: List[str],
):
    """Transcribe several wave files one after another over one connection.

    For each file the message sent is: the sample rate as a 4-byte
    little-endian integer, the number of sample bytes as a 4-byte
    little-endian integer, and then the float32 samples. The result is
    received and printed before the next file is sent. Once all files are
    processed, the text message "Done" tells the server that no more data
    will follow.

    Args:
      server_addr:
        Host name or IP address of the server.
      server_port:
        Port the server listens on.
      sound_files:
        Paths of the wave files to transcribe, in order.
    """
    async with websockets.connect(
        f"ws://{server_addr}:{server_port}"
    ) as websocket:  # noqa
        for wave_filename in sound_files:
            logging.info(f"Sending {wave_filename}")
            samples, sample_rate = read_wave(wave_filename)
            assert isinstance(sample_rate, int)
            assert samples.dtype == np.float32, samples.dtype
            # ndarray has no attribute `dim`; using it as the assertion
            # message would raise AttributeError instead of AssertionError.
            assert samples.ndim == 1, samples.ndim

            buf = sample_rate.to_bytes(4, byteorder="little")  # 4 bytes
            buf += (samples.size * 4).to_bytes(4, byteorder="little")
            buf += samples.tobytes()

            await websocket.send(buf)

            decoding_results = await websocket.recv()
            print(decoding_results)

        # to signal that the client has sent all the data
        await websocket.send("Done")
|
||||
|
||||
|
||||
async def main():
    """Parse the arguments and transcribe all files over one connection."""
    args = get_args()
    logging.info(vars(args))

    await run(
        server_addr=args.server_addr,
        server_port=args.server_port,
        sound_files=args.sound_files,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Log lines carry: timestamp, level, [file:line] and the message.
    log_format = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"  # noqa
    logging.basicConfig(format=log_format, level=logging.INFO)
    asyncio.run(main())
|
||||
@@ -1,8 +1,15 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
This file demonstrates how to use sherpa-onnx Python API to recognize
|
||||
a single file.
|
||||
This file demonstrates how to use sherpa-onnx Python API to transcribe
|
||||
file(s) with a streaming model.
|
||||
|
||||
Usage:
|
||||
./online-decode-files.py \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav \
|
||||
/path/to/16kHz.wav \
|
||||
/path/to/8kHz.wav
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/index.html
|
||||
@@ -13,17 +20,12 @@ import argparse
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import sherpa_onnx
|
||||
|
||||
|
||||
def assert_file_exists(filename: str):
|
||||
assert Path(
|
||||
filename
|
||||
).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@@ -68,26 +70,58 @@ def get_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--wave-filename",
|
||||
"sound_files",
|
||||
type=str,
|
||||
help="""Path to the wave filename.
|
||||
Should have a single channel with 16-bit samples.
|
||||
It does not need to be 16kHz. It can have any sampling rate.
|
||||
""",
|
||||
nargs="+",
|
||||
help="The input sound file(s) to decode. Each file must be of WAVE"
|
||||
"format with a single channel, and each sample has 16-bit, "
|
||||
"i.e., int16_t. "
|
||||
"The sample rate of the file can be arbitrary and does not need to "
|
||||
"be 16 kHz",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def assert_file_exists(filename: str):
    """Assert that `filename` is an existing regular file.

    Args:
      filename:
        Path to check.
    Raises:
      AssertionError: If the path is not a file. The message points to the
        page from which pretrained models can be downloaded.
    """
    hint = (
        f"{filename} does not exist!\n"
        "Please refer to "
        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
    )
    assert Path(filename).is_file(), hint
|
||||
|
||||
|
||||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Load a mono, 16-bit wave file as normalized float32 samples.

    Args:
      wave_filename:
        Path to the wave file. It must have exactly one channel and 2-byte
        samples; any sample rate is accepted.
    Returns:
      Return a tuple containing:
       - A 1-D np.float32 array with the samples divided by 32768, i.e.,
         scaled to the range [-1, 1].
       - The sample rate of the wave file.
    """
    with wave.open(wave_filename) as wav:
        num_channels = wav.getnchannels()
        assert num_channels == 1, num_channels

        sample_width = wav.getsampwidth()  # it is in bytes
        assert sample_width == 2, sample_width

        raw_frames = wav.readframes(wav.getnframes())
        sample_rate = wav.getframerate()

    pcm_int16 = np.frombuffer(raw_frames, dtype=np.int16)
    return pcm_int16.astype(np.float32) / 32768, sample_rate
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
assert_file_exists(args.encoder)
|
||||
assert_file_exists(args.decoder)
|
||||
assert_file_exists(args.joiner)
|
||||
assert_file_exists(args.tokens)
|
||||
if not Path(args.wave_filename).is_file():
|
||||
print(f"{args.wave_filename} does not exist!")
|
||||
return
|
||||
|
||||
recognizer = sherpa_onnx.OnlineRecognizer(
|
||||
tokens=args.tokens,
|
||||
@@ -99,42 +133,44 @@ def main():
|
||||
feature_dim=80,
|
||||
decoding_method=args.decoding_method,
|
||||
)
|
||||
with wave.open(args.wave_filename) as f:
|
||||
# If the wave file has a different sampling rate from the one
|
||||
# expected by the model (16 kHz in our case), we will do
|
||||
# resampling inside sherpa-onnx
|
||||
wave_file_sample_rate = f.getframerate()
|
||||
|
||||
assert f.getnchannels() == 1, f.getnchannels()
|
||||
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
|
||||
num_samples = f.getnframes()
|
||||
samples = f.readframes(num_samples)
|
||||
samples_int16 = np.frombuffer(samples, dtype=np.int16)
|
||||
samples_float32 = samples_int16.astype(np.float32)
|
||||
|
||||
samples_float32 = samples_float32 / 32768
|
||||
|
||||
duration = len(samples_float32) / wave_file_sample_rate
|
||||
|
||||
start_time = time.time()
|
||||
print("Started!")
|
||||
start_time = time.time()
|
||||
|
||||
stream = recognizer.create_stream()
|
||||
streams = []
|
||||
total_duration = 0
|
||||
for wave_filename in args.sound_files:
|
||||
assert_file_exists(wave_filename)
|
||||
samples, sample_rate = read_wave(wave_filename)
|
||||
duration = len(samples) / sample_rate
|
||||
total_duration += duration
|
||||
|
||||
stream.accept_waveform(wave_file_sample_rate, samples_float32)
|
||||
s = recognizer.create_stream()
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
|
||||
tail_paddings = np.zeros(int(0.2 * wave_file_sample_rate), dtype=np.float32)
|
||||
stream.accept_waveform(wave_file_sample_rate, tail_paddings)
|
||||
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
|
||||
s.accept_waveform(sample_rate, tail_paddings)
|
||||
|
||||
stream.input_finished()
|
||||
s.input_finished()
|
||||
|
||||
while recognizer.is_ready(stream):
|
||||
recognizer.decode_stream(stream)
|
||||
streams.append(s)
|
||||
|
||||
print(recognizer.get_result(stream))
|
||||
|
||||
print("Done!")
|
||||
while True:
|
||||
ready_list = []
|
||||
for s in streams:
|
||||
if recognizer.is_ready(s):
|
||||
ready_list.append(s)
|
||||
if len(ready_list) == 0:
|
||||
break
|
||||
recognizer.decode_streams(ready_list)
|
||||
results = [recognizer.get_result(s) for s in streams]
|
||||
end_time = time.time()
|
||||
print("Done!")
|
||||
|
||||
for wave_filename, result in zip(args.sound_files, results):
|
||||
print(f"{wave_filename}\n{result}")
|
||||
print("-" * 10)
|
||||
|
||||
elapsed_seconds = end_time - start_time
|
||||
# All files are decoded within the timed region, so the real-time factor
# must be relative to the accumulated audio length; `duration` only holds
# the length of the last file read in the loop.
rtf = elapsed_seconds / total_duration
|
||||
print(f"num_threads: {args.num_threads}")
|
||||
@@ -27,7 +27,6 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websoc
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import wave
|
||||
|
||||
try:
|
||||
|
||||
5
python-api-examples/online-websocket-client-microphone.py
Normal file → Executable file
5
python-api-examples/online-websocket-client-microphone.py
Normal file → Executable file
@@ -24,13 +24,12 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websoc
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import sounddevice as sd
|
||||
except ImportError as e:
|
||||
except ImportError:
|
||||
print("Please install sounddevice first. You can use")
|
||||
print()
|
||||
print(" pip install sounddevice")
|
||||
@@ -134,7 +133,7 @@ async def run(
|
||||
await websocket.send(indata.tobytes())
|
||||
|
||||
decoding_results = await receive_task
|
||||
print("\nFinal result is:\n{decoding_results}")
|
||||
print(f"\nFinal result is:\n{decoding_results}")
|
||||
|
||||
|
||||
async def main():
|
||||
|
||||
@@ -13,7 +13,7 @@ from pathlib import Path
|
||||
|
||||
try:
|
||||
import sounddevice as sd
|
||||
except ImportError as e:
|
||||
except ImportError:
|
||||
print("Please install sounddevice first. You can use")
|
||||
print()
|
||||
print(" pip install sounddevice")
|
||||
@@ -25,9 +25,11 @@ import sherpa_onnx
|
||||
|
||||
|
||||
def assert_file_exists(filename: str):
|
||||
assert Path(
|
||||
filename
|
||||
).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
|
||||
assert Path(filename).is_file(), (
|
||||
f"{filename} does not exist!\n"
|
||||
"Please refer to "
|
||||
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
|
||||
)
|
||||
|
||||
|
||||
def get_args():
|
||||
|
||||
@@ -12,7 +12,7 @@ from pathlib import Path
|
||||
|
||||
try:
|
||||
import sounddevice as sd
|
||||
except ImportError as e:
|
||||
except ImportError:
|
||||
print("Please install sounddevice first. You can use")
|
||||
print()
|
||||
print(" pip install sounddevice")
|
||||
@@ -24,9 +24,11 @@ import sherpa_onnx
|
||||
|
||||
|
||||
def assert_file_exists(filename: str):
|
||||
assert Path(
|
||||
filename
|
||||
).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
|
||||
assert Path(filename).is_file(), (
|
||||
f"{filename} does not exist!\n"
|
||||
"Please refer to "
|
||||
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
|
||||
)
|
||||
|
||||
|
||||
def get_args():
|
||||
|
||||
@@ -128,7 +128,6 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
|
||||
)
|
||||
target_link_libraries(sherpa-onnx-online-websocket-server sherpa-onnx-core)
|
||||
|
||||
|
||||
add_executable(sherpa-onnx-online-websocket-client
|
||||
online-websocket-client.cc
|
||||
)
|
||||
@@ -142,6 +141,17 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
|
||||
target_compile_options(sherpa-onnx-online-websocket-client PRIVATE -Wno-deprecated-declarations)
|
||||
endif()
|
||||
|
||||
# For offline websocket
|
||||
add_executable(sherpa-onnx-offline-websocket-server
|
||||
offline-websocket-server-impl.cc
|
||||
offline-websocket-server.cc
|
||||
)
|
||||
target_link_libraries(sherpa-onnx-offline-websocket-server sherpa-onnx-core)
|
||||
|
||||
if(NOT WIN32)
|
||||
target_link_libraries(sherpa-onnx-offline-websocket-server -pthread)
|
||||
target_compile_options(sherpa-onnx-offline-websocket-server PRIVATE -Wno-deprecated-declarations)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
285
sherpa-onnx/csrc/offline-websocket-server-impl.cc
Normal file
285
sherpa-onnx/csrc/offline-websocket-server-impl.cc
Normal file
@@ -0,0 +1,285 @@
|
||||
// sherpa-onnx/csrc/offline-websocket-server-impl.cc
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-websocket-server-impl.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Registers the recognizer options and the decoder-level options
// --max-batch-size and --max-utterance-length with `po`.
void OfflineWebsocketDecoderConfig::Register(ParseOptions *po) {
  recognizer_config.Register(po);

  const char *batch_size_doc = "Max batch size for decoding.";

  const char *utterance_length_doc =
      "Max utterance length in seconds. If we receive an utterance "
      "longer than this value, we will reject the connection. "
      "If you have enough memory, you can select a large value for it.";

  po->Register("max-batch-size", &max_batch_size, batch_size_doc);
  po->Register("max-utterance-length", &max_utterance_length,
               utterance_length_doc);
}
|
||||
|
||||
void OfflineWebsocketDecoderConfig::Validate() const {
|
||||
if (!recognizer_config.Validate()) {
|
||||
SHERPA_ONNX_LOGE("Error in recongizer config");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (max_batch_size <= 0) {
|
||||
SHERPA_ONNX_LOGE("Expect --max-batch-size > 0. Given: %d", max_batch_size);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (max_utterance_length <= 0) {
|
||||
SHERPA_ONNX_LOGE("Expect --max-utterance-length > 0. Given: %f",
|
||||
max_utterance_length);
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
// Copies the decoder config from `server`, keeps the (borrowed) server
// pointer and builds the recognizer from the copied config.
OfflineWebsocketDecoder::OfflineWebsocketDecoder(OfflineWebsocketServer *server)
    : config_(server->GetConfig().decoder_config),
      server_(server),
      // recognizer_ is built from config_, which is initialized above.
      recognizer_(config_.recognizer_config) {}
|
||||
|
||||
// Appends a (connection handle, received audio) pair to the queue of
// pending requests. The queue is protected by mutex_.
void OfflineWebsocketDecoder::Push(connection_hdl hdl, ConnectionDataPtr d) {
  std::lock_guard<std::mutex> guard(mutex_);
  streams_.push_back({hdl, d});
}
|
||||
|
||||
void OfflineWebsocketDecoder::Decode() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
if (streams_.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
int32_t size =
|
||||
std::min(static_cast<int32_t>(streams_.size()), config_.max_batch_size);
|
||||
SHERPA_ONNX_LOGE("size: %d", size);
|
||||
|
||||
// We first lock the mutex for streams_, take items from it, and then
|
||||
// unlock the mutex; in doing so we don't need to lock the mutex to
|
||||
// access hdl and connection_data later.
|
||||
std::vector<connection_hdl> handles(size);
|
||||
|
||||
// Store connection_data here to prevent the data from being freed
|
||||
// while we are still using it.
|
||||
std::vector<ConnectionDataPtr> connection_data(size);
|
||||
|
||||
std::vector<const float *> samples(size);
|
||||
std::vector<int32_t> samples_length(size);
|
||||
std::vector<std::unique_ptr<OfflineStream>> ss(size);
|
||||
std::vector<OfflineStream *> p_ss(size);
|
||||
|
||||
for (int32_t i = 0; i != size; ++i) {
|
||||
auto &p = streams_.front();
|
||||
handles[i] = p.first;
|
||||
connection_data[i] = p.second;
|
||||
streams_.pop_front();
|
||||
|
||||
auto sample_rate = connection_data[i]->sample_rate;
|
||||
auto samples =
|
||||
reinterpret_cast<const float *>(&connection_data[i]->data[0]);
|
||||
auto num_samples = connection_data[i]->expected_byte_size / sizeof(float);
|
||||
auto s = recognizer_.CreateStream();
|
||||
s->AcceptWaveform(sample_rate, samples, num_samples);
|
||||
|
||||
ss[i] = std::move(s);
|
||||
p_ss[i] = ss[i].get();
|
||||
}
|
||||
|
||||
lock.unlock();
|
||||
|
||||
// Note: DecodeStreams is thread-safe
|
||||
recognizer_.DecodeStreams(p_ss.data(), size);
|
||||
|
||||
for (int32_t i = 0; i != size; ++i) {
|
||||
connection_hdl hdl = handles[i];
|
||||
asio::post(server_->GetConnectionContext(),
|
||||
[this, hdl, text = ss[i]->GetResult().text]() {
|
||||
websocketpp::lib::error_code ec;
|
||||
server_->GetServer().send(
|
||||
hdl, text, websocketpp::frame::opcode::text, ec);
|
||||
if (ec) {
|
||||
server_->GetServer().get_alog().write(
|
||||
websocketpp::log::alevel::app, ec.message());
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Registers the decoder options and the --log-file option with `po`.
void OfflineWebsocketServerConfig::Register(ParseOptions *po) {
  decoder_config.Register(po);

  const char *log_file_doc =
      "Path to the log file. Logs are "
      "appended to this file";

  po->Register("log-file", &log_file, log_file_doc);
}
|
||||
|
||||
// Validates the server configuration by delegating to the decoder
// configuration; the server has no checks of its own.
void OfflineWebsocketServerConfig::Validate() const {
  decoder_config.Validate();
}
|
||||
|
||||
// Sets up the websocket server on top of two io_contexts.
//
// @param io_conn  io_context the websocket endpoint is initialized with.
// @param io_work  io_context that decoding jobs are posted to.
// @param config   Server configuration; config.log_file is opened in
//                 append mode.
OfflineWebsocketServer::OfflineWebsocketServer(
    asio::io_context &io_conn,  // NOLINT
    asio::io_context &io_work,  // NOLINT
    const OfflineWebsocketServerConfig &config)
    : io_conn_(io_conn),
      io_work_(io_work),
      config_(config),
      log_(config.log_file, std::ios::app),
      tee_(std::cout, log_),
      decoder_(this) {
  SetupLog();

  server_.init_asio(&io_conn_);

  // Forward the endpoint callbacks to the corresponding member functions.
  auto on_open = [this](connection_hdl hdl) { OnOpen(hdl); };
  auto on_close = [this](connection_hdl hdl) { OnClose(hdl); };
  auto on_message = [this](connection_hdl hdl, server::message_ptr msg) {
    OnMessage(hdl, msg);
  };

  server_.set_open_handler(on_open);
  server_.set_close_handler(on_close);
  server_.set_message_handler(on_message);
}
|
||||
|
||||
// Restricts access logging to connect/disconnect events and routes both
// the access log and the error log of the endpoint to tee_.
void OfflineWebsocketServer::SetupLog() {
  server_.clear_access_channels(websocketpp::log::alevel::all);
  server_.set_access_channels(websocketpp::log::alevel::connect |
                              websocketpp::log::alevel::disconnect);

  // tee_ is constructed from std::cout and the log file stream, and both
  // loggers write to it.
  auto *sink = &tee_;
  server_.get_alog().set_ostream(sink);
  server_.get_elog().set_ostream(sink);
}
|
||||
|
||||
// Called when a client connects: creates an empty ConnectionData for the
// new connection and logs the number of active connections.
void OfflineWebsocketServer::OnOpen(connection_hdl hdl) {
  std::lock_guard<std::mutex> guard(mutex_);
  connections_.emplace(hdl, std::make_shared<ConnectionData>());

  auto num_active = static_cast<int32_t>(connections_.size());
  SHERPA_ONNX_LOGE("Number of active connections: %d", num_active);
}
|
||||
|
||||
// Called when a connection is closed: drops the state kept for it and
// logs the number of active connections.
void OfflineWebsocketServer::OnClose(connection_hdl hdl) {
  std::lock_guard<std::mutex> guard(mutex_);
  connections_.erase(hdl);

  auto num_active = static_cast<int32_t>(connections_.size());
  SHERPA_ONNX_LOGE("Number of active connections: %d", num_active);
}
|
||||
|
||||
// Handles one message from a client.
//
// Text messages: "Done" closes the connection normally; any other text
// closes it with an "Invalid payload" reason.
//
// Binary messages carry the audio. The first message of an utterance
// starts with an 8-byte header (int32 sample rate, int32 number of sample
// bytes), followed by float samples; later messages carry only samples.
// Once all announced bytes have arrived, the utterance is handed to the
// decoder and the per-connection state is reset so that the client can
// send another utterance on the same connection.
//
// Header values come from the client and are validated before they are
// used for sizing or copying.
void OfflineWebsocketServer::OnMessage(connection_hdl hdl,
                                       server::message_ptr msg) {
  ConnectionDataPtr connection_data;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = connections_.find(hdl);
    if (it == connections_.end()) {
      // No state is kept for this handle, so there is nothing to fill.
      return;
    }
    connection_data = it->second;
  }

  const std::string &payload = msg->get_payload();

  switch (msg->get_opcode()) {
    case websocketpp::frame::opcode::text:
      if (payload == "Done") {
        // The client will not send any more data. We can close the
        // connection now.
        Close(hdl, websocketpp::close::status::normal, "Done");
      } else {
        Close(hdl, websocketpp::close::status::normal,
              std::string("Invalid payload: ") + payload);
      }
      break;

    case websocketpp::frame::opcode::binary: {
      auto p = reinterpret_cast<const int8_t *>(payload.data());

      if (connection_data->expected_byte_size == 0) {
        // First message of an utterance: it must contain the header.
        if (payload.size() < 8) {
          Close(hdl, websocketpp::close::status::normal,
                "Payload is too short");
          break;
        }

        // The header is read into locals and committed to connection_data
        // only after validation, so a rejected header leaves no stale
        // state behind. The raw bytes are read as host int32_t values.
        int32_t sample_rate = *reinterpret_cast<const int32_t *>(p);
        int32_t expected_byte_size = *reinterpret_cast<const int32_t *>(p + 4);

        if (sample_rate <= 0 || expected_byte_size < 0) {
          Close(hdl, websocketpp::close::status::normal,
                "Invalid sample rate or byte size in the header");
          break;
        }

        // Computed in double: the product can exceed the int32_t range
        // for large sample rates.
        double max_byte_size =
            static_cast<double>(decoder_.GetConfig().max_utterance_length) *
            sample_rate * sizeof(float);
        if (expected_byte_size > max_byte_size) {
          float num_samples = expected_byte_size / sizeof(float);

          float duration = num_samples / sample_rate;

          std::ostringstream os;
          os << "Max utterance length is configured to "
             << decoder_.GetConfig().max_utterance_length
             << " seconds, received length is " << duration << " seconds. "
             << "Payload is too large!";
          Close(hdl, websocketpp::close::status::message_too_big, os.str());
          break;
        }

        int64_t num_audio_bytes = static_cast<int64_t>(payload.size()) - 8;
        if (num_audio_bytes > expected_byte_size) {
          // Copying would write past the end of the buffer.
          Close(hdl, websocketpp::close::status::message_too_big,
                "Received more bytes than announced in the header");
          break;
        }

        connection_data->sample_rate = sample_rate;
        connection_data->expected_byte_size = expected_byte_size;
        connection_data->data.resize(expected_byte_size);
        std::copy(payload.begin() + 8, payload.end(),
                  connection_data->data.data());
        connection_data->cur = static_cast<int32_t>(num_audio_bytes);
      } else {
        int64_t remaining =
            static_cast<int64_t>(connection_data->expected_byte_size) -
            connection_data->cur;
        if (static_cast<int64_t>(payload.size()) > remaining) {
          // Copying would write past the end of the buffer.
          connection_data->Clear();
          Close(hdl, websocketpp::close::status::message_too_big,
                "Received more bytes than announced in the header");
          break;
        }

        std::copy(payload.begin(), payload.end(),
                  connection_data->data.data() + connection_data->cur);
        connection_data->cur += payload.size();
      }

      if (connection_data->expected_byte_size == connection_data->cur) {
        auto d = std::make_shared<ConnectionData>(std::move(*connection_data));

        // Clear it so that we can handle the next audio file from the client.
        // The client can send multiple audio files for recognition without
        // the need to create another connection.
        connection_data->Clear();

        decoder_.Push(hdl, d);

        asio::post(io_work_, [this]() { decoder_.Decode(); });
      }
      break;
    }

    default:
      // Unexpected message, ignore it
      break;
  }
}
|
||||
|
||||
// Closes the connection `hdl` with the given status code and reason, and
// writes what happened (including a failure to close) to the access log.
void OfflineWebsocketServer::Close(connection_hdl hdl,
                                   websocketpp::close::status::value code,
                                   const std::string &reason) {
  auto con = server_.get_con_from_hdl(hdl);

  std::ostringstream os;
  os << "Closing " << con->get_remote_endpoint() << " with reason: " << reason
     << "\n";

  websocketpp::lib::error_code ec;
  server_.close(hdl, code, reason, ec);
  if (ec) {
    // Note the trailing space: the endpoint follows immediately.
    os << "Failed to close " << con->get_remote_endpoint() << ". "
       << ec.message() << "\n";
  }
  server_.get_alog().write(websocketpp::log::alevel::app, os.str());
}
|
||||
|
||||
// Starts listening on the given IPv4 port and begins accepting
// connections. Address reuse is enabled before listening.
void OfflineWebsocketServer::Run(uint16_t port) {
  server_.set_reuse_addr(true);

  const auto protocol = asio::ip::tcp::v4();
  server_.listen(protocol, port);

  server_.start_accept();
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
205
sherpa-onnx/csrc/offline-websocket-server-impl.h
Normal file
205
sherpa-onnx/csrc/offline-websocket-server-impl.h
Normal file
@@ -0,0 +1,205 @@
|
||||
// sherpa-onnx/csrc/offline-websocket-server-impl.h
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_
|
||||
|
||||
#include <deque>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
#include "sherpa-onnx/csrc/tee-stream.h"
|
||||
#include "websocketpp/config/asio_no_tls.hpp" // TODO(fangjun): support TLS
|
||||
#include "websocketpp/server.hpp"
|
||||
|
||||
using server = websocketpp::server<websocketpp::config::asio>;
|
||||
using connection_hdl = websocketpp::connection_hdl;
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** Communication protocol
 *
 * The client sends a byte stream to the server. The first 4 bytes, in little
 * endian, hold the sample rate of the audio data that the client will send.
 * The next 4 bytes, in little endian, hold the total size in bytes of the
 * audio samples the client will send. The remaining bytes are the audio
 * samples. Each audio sample is a float occupying 4 bytes and is normalized
 * into the range [-1, 1].
 *
 * The byte stream can be broken into an arbitrary number of messages.
 * We require that the first message has at least 8 bytes so that
 * we can get `sample_rate` and `expected_byte_size` from the first message.
 */
struct ConnectionData {
  // Sample rate of the audio samples sent by the client.
  // Initialized like the other fields so that it is never read while
  // indeterminate before the first message has been parsed.
  int32_t sample_rate = 0;

  // Number of expected bytes sent from the client
  int32_t expected_byte_size = 0;

  // Number of bytes received so far
  int32_t cur = 0;

  // It saves the received samples from the client.
  // We will **reinterpret_cast** it to float.
  // We expect that data.size() == expected_byte_size
  std::vector<int8_t> data;

  // Reset all fields so that the object can be reused for the next
  // utterance sent over the same connection.
  void Clear() {
    sample_rate = 0;
    expected_byte_size = 0;
    cur = 0;
    data.clear();
  }
};

// Shared ownership handle for the per-connection receive state.
using ConnectionDataPtr = std::shared_ptr<ConnectionData>;
|
||||
|
||||
struct OfflineWebsocketDecoderConfig {
|
||||
OfflineRecognizerConfig recognizer_config;
|
||||
|
||||
int32_t max_batch_size = 5;
|
||||
|
||||
float max_utterance_length = 300; // seconds
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
void Validate() const;
|
||||
};
|
||||
|
||||
class OfflineWebsocketServer;
|
||||
|
||||
class OfflineWebsocketDecoder {
|
||||
public:
|
||||
/**
|
||||
* @param config Configuration for the decoder.
|
||||
* @param server **Borrowed** from outside.
|
||||
*/
|
||||
explicit OfflineWebsocketDecoder(OfflineWebsocketServer *server);
|
||||
|
||||
/** Insert received data to the queue for decoding.
|
||||
*
|
||||
* @param hdl A handle to the connection. We can use it to send the result
|
||||
* back to the client once it finishes decoding.
|
||||
* @param d The received data
|
||||
*/
|
||||
void Push(connection_hdl hdl, ConnectionDataPtr d);
|
||||
|
||||
/** It is called by one of the work thread.
|
||||
*/
|
||||
void Decode();
|
||||
|
||||
const OfflineWebsocketDecoderConfig &GetConfig() const { return config_; }
|
||||
|
||||
private:
|
||||
OfflineWebsocketDecoderConfig config_;
|
||||
|
||||
/** When we have received all the data from the client, we put it into
|
||||
* this queue; the worker threads will get items from this queue for
|
||||
* decoding.
|
||||
*
|
||||
* Number of items to take from this queue is determined by
|
||||
* `--max-batch-size`. If there are not enough items in the queue, we won't
|
||||
* wait and take whatever we have for decoding.
|
||||
*/
|
||||
std::mutex mutex_;
|
||||
std::deque<std::pair<connection_hdl, ConnectionDataPtr>> streams_;
|
||||
|
||||
OfflineWebsocketServer *server_; // Not owned
|
||||
OfflineRecognizer recognizer_;
|
||||
};
|
||||
|
||||
struct OfflineWebsocketServerConfig {
|
||||
OfflineWebsocketDecoderConfig decoder_config;
|
||||
std::string log_file = "./log.txt";
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
void Validate() const;
|
||||
};
|
||||
|
||||
class OfflineWebsocketServer {
|
||||
public:
|
||||
OfflineWebsocketServer(asio::io_context &io_conn, // NOLINT
|
||||
asio::io_context &io_work, // NOLINT
|
||||
const OfflineWebsocketServerConfig &config);
|
||||
|
||||
asio::io_context &GetConnectionContext() { return io_conn_; }
|
||||
server &GetServer() { return server_; }
|
||||
|
||||
void Run(uint16_t port);
|
||||
|
||||
const OfflineWebsocketServerConfig &GetConfig() const { return config_; }
|
||||
|
||||
private:
|
||||
void SetupLog();
|
||||
|
||||
// When a websocket client is connected, it will invoke this method
|
||||
// (Not for HTTP)
|
||||
void OnOpen(connection_hdl hdl);
|
||||
|
||||
// When a websocket client is disconnected, it will invoke this method
|
||||
void OnClose(connection_hdl hdl);
|
||||
|
||||
// When a message is received from a websocket client, this method will
|
||||
// be invoked.
|
||||
//
|
||||
// The protocol between the client and the server is as follows:
|
||||
//
|
||||
// (1) The client connects to the server
|
||||
// (2) The client starts to send binary byte stream to the server.
|
||||
// The byte stream can be broken into multiple messages or it can
|
||||
// be put into a single message.
|
||||
// The first message has to contain at least 8 bytes. The first
|
||||
// 4 bytes in little endian contains a int32_t indicating the
|
||||
// sampling rate. The next 4 bytes in little endian contains a int32_t
|
||||
// indicating total number of bytes of samples the client will send.
|
||||
// We assume each sample is a float containing 4 bytes and has been
|
||||
// normalized to the range [-1, 1].
|
||||
// (4) When the server receives all the samples from the client, it will
|
||||
// start to decode them. Once decoded, the server sends a text message
|
||||
// to the client containing the decoded results
|
||||
// (5) After receiving the decoded results from the server, if the client has
|
||||
// another audio file to send, it repeats (2), (3), (4)
|
||||
// (6) If the client has no more audio files to decode, the client sends a
|
||||
// text message containing "Done" to the server and closes the connection
|
||||
// (7) The server receives a text message "Done" and closes the connection
|
||||
//
|
||||
// Note:
|
||||
// (a) All models in icefall use features extracted from audio samples
|
||||
// normalized to the range [-1, 1]. Please send normalized audio samples
|
||||
// if you use models from icefall.
|
||||
// (b) Only sound files with a single channel is supported
|
||||
// (c) Only audio samples are sent. For instance, if we want to decode
|
||||
// a WAVE file, the RIFF header of the WAVE is not sent.
|
||||
void OnMessage(connection_hdl hdl, server::message_ptr msg);
|
||||
|
||||
// Close a websocket connection with given code and reason
|
||||
void Close(connection_hdl hdl, websocketpp::close::status::value code,
|
||||
const std::string &reason);
|
||||
|
||||
private:
|
||||
asio::io_context &io_conn_;
|
||||
asio::io_context &io_work_;
|
||||
server server_;
|
||||
|
||||
std::map<connection_hdl, ConnectionDataPtr, std::owner_less<connection_hdl>>
|
||||
connections_;
|
||||
std::mutex mutex_;
|
||||
|
||||
OfflineWebsocketServerConfig config_;
|
||||
|
||||
std::ofstream log_;
|
||||
TeeStream tee_;
|
||||
|
||||
OfflineWebsocketDecoder decoder_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_
|
||||
120
sherpa-onnx/csrc/offline-websocket-server.cc
Normal file
120
sherpa-onnx/csrc/offline-websocket-server.cc
Normal file
@@ -0,0 +1,120 @@
|
||||
// sherpa-onnx/csrc/offline-websocket-server.cc
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#include "asio.hpp"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-websocket-server-impl.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
static constexpr const char *kUsageMessage = R"(
|
||||
Automatic speech recognition with sherpa-onnx using websocket.
|
||||
|
||||
Usage:
|
||||
|
||||
./bin/sherpa-onnx-offline-websocket-server --help
|
||||
|
||||
(1) For transducer models
|
||||
|
||||
./bin/sherpa-onnx-offline-websocket-server \
|
||||
--port=6006 \
|
||||
--num-work-threads=5 \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--encoder=/path/to/encoder.onnx \
|
||||
--decoder=/path/to/decoder.onnx \
|
||||
--joiner=/path/to/joiner.onnx \
|
||||
--log-file=./log.txt \
|
||||
--max-batch-size=5
|
||||
|
||||
(2) For Paraformer
|
||||
|
||||
./bin/sherpa-onnx-offline-websocket-server \
|
||||
--port=6006 \
|
||||
--num-work-threads=5 \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--paraformer=/path/to/model.onnx \
|
||||
--log-file=./log.txt \
|
||||
--max-batch-size=5
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
for a list of pre-trained models to download.
|
||||
)";
|
||||
|
||||
int32_t main(int32_t argc, char *argv[]) {
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
|
||||
sherpa_onnx::OfflineWebsocketServerConfig config;
|
||||
|
||||
// the server will listen on this port
|
||||
int32_t port = 6006;
|
||||
|
||||
// size of the thread pool for handling network connections
|
||||
int32_t num_io_threads = 1;
|
||||
|
||||
// size of the thread pool for neural network computation and decoding
|
||||
int32_t num_work_threads = 3;
|
||||
|
||||
po.Register("num-io-threads", &num_io_threads,
|
||||
"Thread pool size for network connections.");
|
||||
|
||||
po.Register("num-work-threads", &num_work_threads,
|
||||
"Thread pool size for for neural network "
|
||||
"computation and decoding.");
|
||||
|
||||
po.Register("port", &port, "The port on which the server will listen.");
|
||||
|
||||
config.Register(&po);
|
||||
|
||||
if (argc == 1) {
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
po.Read(argc, argv);
|
||||
|
||||
if (po.NumArgs() != 0) {
|
||||
SHERPA_ONNX_LOGE("Unrecognized positional arguments!");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
config.Validate();
|
||||
|
||||
asio::io_context io_conn; // for network connections
|
||||
asio::io_context io_work; // for neural network and decoding
|
||||
|
||||
sherpa_onnx::OfflineWebsocketServer server(io_conn, io_work, config);
|
||||
server.Run(port);
|
||||
|
||||
SHERPA_ONNX_LOGE("Started!");
|
||||
SHERPA_ONNX_LOGE("Listening on: %d", port);
|
||||
SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads);
|
||||
|
||||
// give some work to do for the io_work pool
|
||||
auto work_guard = asio::make_work_guard(io_work);
|
||||
|
||||
std::vector<std::thread> io_threads;
|
||||
|
||||
// decrement since the main thread is also used for network communications
|
||||
for (int32_t i = 0; i < num_io_threads - 1; ++i) {
|
||||
io_threads.emplace_back([&io_conn]() { io_conn.run(); });
|
||||
}
|
||||
|
||||
std::vector<std::thread> work_threads;
|
||||
for (int32_t i = 0; i < num_work_threads; ++i) {
|
||||
work_threads.emplace_back([&io_work]() { io_work.run(); });
|
||||
}
|
||||
|
||||
io_conn.run();
|
||||
|
||||
for (auto &t : io_threads) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
for (auto &t : work_threads) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -76,6 +76,7 @@ int32_t main(int32_t argc, char *argv[]) {
|
||||
sherpa_onnx::OnlineWebsocketServer server(io_conn, io_work, config);
|
||||
server.Run(port);
|
||||
|
||||
SHERPA_ONNX_LOGE("Started!");
|
||||
SHERPA_ONNX_LOGE("Listening on: %d", port);
|
||||
SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user