Add non-streaming websocket server for python (#259)
.github/workflows/test-pip-install.yaml (vendored, 16 changed lines)
@@ -23,12 +23,12 @@ permissions:
 jobs:
   test_pip_install:
     runs-on: ${{ matrix.os }}
-    name: Test pip install on ${{ matrix.os }}
+    name: ${{ matrix.os }} ${{ matrix.python-version }}
     strategy:
       fail-fast: false
       matrix:
         os: [ubuntu-latest, windows-latest, macos-latest]
-        python-version: ["3.8", "3.9", "3.10"]
+        python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]

     steps:
       - uses: actions/checkout@v2

@@ -50,3 +50,15 @@ jobs:
       run: |
         python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)"
         python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)"
+
+        sherpa-onnx --help
+        sherpa-onnx-offline --help
+
+        sherpa-onnx-microphone --help
+        sherpa-onnx-microphone-offline --help
+
+        sherpa-onnx-offline-websocket-server --help
+        sherpa-onnx-offline-websocket-client --help
+
+        sherpa-onnx-online-websocket-server --help
+        sherpa-onnx-online-websocket-client --help
.github/workflows/test-python-offline-websocket-server.yaml (vendored, new file, 174 lines)
@@ -0,0 +1,174 @@
name: Python offline websocket server

on:
  push:
    branches:
      - master
  pull_request:
    branches:
      - master

concurrency:
  group: python-offline-websocket-server-${{ github.ref }}
  cancel-in-progress: true

permissions:
  contents: read

jobs:
  python_offline_websocket_server:
    runs-on: ${{ matrix.os }}
    name: ${{ matrix.os }} ${{ matrix.python-version }} ${{ matrix.model_type }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, windows-latest, macos-latest]
        python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
        model_type: ["transducer", "paraformer", "nemo_ctc", "whisper"]

    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Python dependencies
        shell: bash
        run: |
          python3 -m pip install --upgrade pip numpy

      - name: Install sherpa-onnx
        shell: bash
        run: |
          python3 -m pip install --no-deps --verbose .
          python3 -m pip install websockets

      - name: Start server for transducer models
        if: matrix.model_type == 'transducer'
        shell: bash
        run: |
          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-2023-06-26
          cd sherpa-onnx-zipformer-en-2023-06-26
          git lfs pull --include "*.onnx"
          cd ..

          python3 ./python-api-examples/non_streaming_server.py \
            --encoder ./sherpa-onnx-zipformer-en-2023-06-26/encoder-epoch-99-avg-1.onnx \
            --decoder ./sherpa-onnx-zipformer-en-2023-06-26/decoder-epoch-99-avg-1.onnx \
            --joiner ./sherpa-onnx-zipformer-en-2023-06-26/joiner-epoch-99-avg-1.onnx \
            --tokens ./sherpa-onnx-zipformer-en-2023-06-26/tokens.txt &

          echo "Sleep 10 seconds to wait for the server to start"
          sleep 10

      - name: Start client for transducer models
        if: matrix.model_type == 'transducer'
        shell: bash
        run: |
          python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
            ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/0.wav \
            ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/1.wav \
            ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/8k.wav

          python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
            ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/0.wav \
            ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/1.wav \
            ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/8k.wav

      - name: Start server for paraformer models
        if: matrix.model_type == 'paraformer'
        shell: bash
        run: |
          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
          cd sherpa-onnx-paraformer-zh-2023-03-28
          git lfs pull --include "*.onnx"
          cd ..

          python3 ./python-api-examples/non_streaming_server.py \
            --paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \
            --tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt &

          echo "Sleep 10 seconds to wait for the server to start"
          sleep 10

      - name: Start client for paraformer models
        if: matrix.model_type == 'paraformer'
        shell: bash
        run: |
          python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \
            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \
            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \
            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav

          python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \
            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \
            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \
            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav

      - name: Start server for nemo_ctc models
        if: matrix.model_type == 'nemo_ctc'
        shell: bash
        run: |
          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-medium
          cd sherpa-onnx-nemo-ctc-en-conformer-medium
          git lfs pull --include "*.onnx"
          cd ..

          python3 ./python-api-examples/non_streaming_server.py \
            --nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
            --tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt &

          echo "Sleep 10 seconds to wait for the server to start"
          sleep 10

      - name: Start client for nemo_ctc models
        if: matrix.model_type == 'nemo_ctc'
        shell: bash
        run: |
          python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
            ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/0.wav \
            ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
            ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav

          python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
            ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/0.wav \
            ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
            ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav

      - name: Start server for whisper models
        if: matrix.model_type == 'whisper'
        shell: bash
        run: |
          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en
          cd sherpa-onnx-whisper-tiny.en
          git lfs pull --include "*.onnx"
          cd ..

          python3 ./python-api-examples/non_streaming_server.py \
            --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \
            --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
            --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt &

          echo "Sleep 10 seconds to wait for the server to start"
          sleep 10

      - name: Start client for whisper models
        if: matrix.model_type == 'whisper'
        shell: bash
        run: |
          python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
            ./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \
            ./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \
            ./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav

          python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
            ./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \
            ./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \
            ./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav
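For context, a minimal hedged sketch (not part of this commit) of the wire protocol these offline clients speak, as described in non_streaming_server.py further below: an 8-byte little-endian header carrying the sample rate and the payload size in bytes, followed by raw float32 samples, and finally the text message "Done".

```python
# Minimal offline-client sketch. Assumptions: the server from this commit is
# running on localhost:6006 and `samples` is a 1-D float32 numpy array.
import asyncio

import numpy as np
import websockets


async def decode(samples: np.ndarray, sample_rate: int = 16000) -> str:
    async with websockets.connect("ws://localhost:6006") as ws:
        # 8-byte header: sample rate, then payload size in bytes (both
        # little-endian int32), followed by the raw float32 samples.
        buf = sample_rate.to_bytes(4, byteorder="little")
        buf += (samples.size * 4).to_bytes(4, byteorder="little")
        buf += samples.tobytes()
        await ws.send(buf)

        result = await ws.recv()  # the decoded text (or "<EMPTY>")
        await ws.send("Done")     # tell the server we are finished
        return result


print(asyncio.run(decode(np.zeros(16000, dtype=np.float32))))
```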
.github/workflows/test-python-online-websocket-server.yaml (vendored, new file, 73 lines)
@@ -0,0 +1,73 @@
name: Python online websocket server

on:
  push:
    branches:
      - master
  pull_request:
    branches:
      - master

concurrency:
  group: python-online-websocket-server-${{ github.ref }}
  cancel-in-progress: true

permissions:
  contents: read

jobs:
  python_online_websocket_server:
    runs-on: ${{ matrix.os }}
    name: ${{ matrix.os }} ${{ matrix.python-version }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, windows-latest, macos-latest]
        python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
        model_type: ["transducer"]

    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Python dependencies
        shell: bash
        run: |
          python3 -m pip install --upgrade pip numpy

      - name: Install sherpa-onnx
        shell: bash
        run: |
          python3 -m pip install --no-deps --verbose .
          python3 -m pip install websockets

      - name: Start server for transducer models
        if: matrix.model_type == 'transducer'
        shell: bash
        run: |
          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26
          cd sherpa-onnx-streaming-zipformer-en-2023-06-26
          git lfs pull --include "*.onnx"
          cd ..

          python3 ./python-api-examples/streaming_server.py \
            --encoder ./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-128.onnx \
            --decoder ./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-128.onnx \
            --joiner ./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-128.onnx \
            --tokens ./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt &
          echo "Sleep 10 seconds to wait for the server to start"
          sleep 10

      - name: Start client for transducer models
        if: matrix.model_type == 'transducer'
        shell: bash
        run: |
          python3 ./python-api-examples/online-websocket-client-decode-file.py \
            ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav
CMakeLists.txt

@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-onnx)

-set(SHERPA_ONNX_VERSION "1.7.1")
+set(SHERPA_ONNX_VERSION "1.7.2")

 # Disable warning about
 #
c-api-examples/README.md (new file, 9 lines)
@@ -0,0 +1,9 @@
# Introduction

This folder contains C API examples for [sherpa-onnx][sherpa-onnx].

Please refer to the documentation
https://k2-fsa.github.io/sherpa/onnx/c-api/index.html
for details.

[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
dotnet-examples/README.md (new file, 9 lines)
@@ -0,0 +1,9 @@
# Introduction

This folder contains C# API examples for [sherpa-onnx][sherpa-onnx].

Please refer to the documentation
https://k2-fsa.github.io/sherpa/onnx/csharp-api/index.html
for details.

[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
go-api-examples/README.md (new file, 9 lines)
@@ -0,0 +1,9 @@
# Introduction

This folder contains Go API examples for [sherpa-onnx][sherpa-onnx].

Please refer to the documentation
https://k2-fsa.github.io/sherpa/onnx/go-api/index.html
for details.

[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
python-api-examples/non_streaming_server.py (new executable file, 835 lines)
@@ -0,0 +1,835 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corp.
"""
A server for non-streaming speech recognition. Non-streaming means you send all
the content of the audio at once for recognition.

It supports multiple clients sending at the same time.

Usage:
    ./non_streaming_server.py --help

Please refer to

https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html

for pre-trained models to download.

Usage examples:

(1) Use a non-streaming transducer model

cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-2023-06-26
cd sherpa-onnx-zipformer-en-2023-06-26
git lfs pull --include "*.onnx"
cd ..

python3 ./python-api-examples/non_streaming_server.py \
  --encoder ./sherpa-onnx-zipformer-en-2023-06-26/encoder-epoch-99-avg-1.onnx \
  --decoder ./sherpa-onnx-zipformer-en-2023-06-26/decoder-epoch-99-avg-1.onnx \
  --joiner ./sherpa-onnx-zipformer-en-2023-06-26/joiner-epoch-99-avg-1.onnx \
  --tokens ./sherpa-onnx-zipformer-en-2023-06-26/tokens.txt

(2) Use a non-streaming paraformer

cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
cd sherpa-onnx-paraformer-zh-2023-03-28
git lfs pull --include "*.onnx"
cd ..

python3 ./python-api-examples/non_streaming_server.py \
  --paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \
  --tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt

(3) Use a non-streaming CTC model from NeMo

cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-medium
cd sherpa-onnx-nemo-ctc-en-conformer-medium
git lfs pull --include "*.onnx"
cd ..

python3 ./python-api-examples/non_streaming_server.py \
  --nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
  --tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt

(4) Use a Whisper model

cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en
cd sherpa-onnx-whisper-tiny.en
git lfs pull --include "*.onnx"
cd ..

python3 ./python-api-examples/non_streaming_server.py \
  --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \
  --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
  --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt

----

To use a certificate so that you can use https, please use

python3 ./python-api-examples/non_streaming_server.py \
  --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \
  --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
  --certificate=/path/to/your/cert.pem

If you don't have a certificate, please run:

    cd ./python-api-examples/web
    ./generate-certificate.py

It will generate 3 files, one of which is the required `cert.pem`.
"""  # noqa

import argparse
import asyncio
import http
import logging
import socket
import ssl
import sys
import warnings
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import sherpa_onnx

import websockets

from http_server import HttpServer


def setup_logger(
    log_filename: str,
    log_level: str = "info",
    use_console: bool = True,
) -> None:
    """Set up the log level and log file.

    Args:
      log_filename:
        The filename to save the log.
      log_level:
        The log level to use, e.g., "debug", "info", "warning", "error",
        "critical"
      use_console:
        True to also print logs to the console.
    """
    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    log_filename = f"{log_filename}-{date_time}.txt"

    Path(log_filename).parent.mkdir(parents=True, exist_ok=True)

    level = logging.ERROR
    if log_level == "debug":
        level = logging.DEBUG
    elif log_level == "info":
        level = logging.INFO
    elif log_level == "warning":
        level = logging.WARNING
    elif log_level == "critical":
        level = logging.CRITICAL

    logging.basicConfig(
        filename=log_filename,
        format=formatter,
        level=level,
        filemode="w",
    )
    if use_console:
        console = logging.StreamHandler()
        console.setLevel(level)
        console.setFormatter(logging.Formatter(formatter))
        logging.getLogger("").addHandler(console)


def add_transducer_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--encoder",
        default="",
        type=str,
        help="Path to the transducer encoder model",
    )

    parser.add_argument(
        "--decoder",
        default="",
        type=str,
        help="Path to the transducer decoder model",
    )

    parser.add_argument(
        "--joiner",
        default="",
        type=str,
        help="Path to the transducer joiner model",
    )


def add_paraformer_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--paraformer",
        default="",
        type=str,
        help="Path to the model.onnx from Paraformer",
    )


def add_nemo_ctc_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--nemo-ctc",
        default="",
        type=str,
        help="Path to the model.onnx from NeMo CTC",
    )


def add_whisper_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--whisper-encoder",
        default="",
        type=str,
        help="Path to the whisper encoder model",
    )

    parser.add_argument(
        "--whisper-decoder",
        default="",
        type=str,
        help="Path to the whisper decoder model",
    )


def add_model_args(parser: argparse.ArgumentParser):
    add_transducer_model_args(parser)
    add_paraformer_model_args(parser)
    add_nemo_ctc_model_args(parser)
    add_whisper_model_args(parser)

    parser.add_argument(
        "--tokens",
        type=str,
        help="Path to tokens.txt",
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        default=2,
        help="Number of threads to run the neural network model",
    )

    parser.add_argument(
        "--provider",
        type=str,
        default="cpu",
        help="Valid values: cpu, cuda, coreml",
    )


def add_feature_config_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="Sample rate of the data used to train the model.",
    )

    parser.add_argument(
        "--feat-dim",
        type=int,
        default=80,
        help="Feature dimension of the model",
    )


def add_decoding_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Decoding method to use. Currently supported methods are:
        - greedy_search
        - modified_beam_search (for transducer models only)
        """,
    )

    add_modified_beam_search_args(parser)


def add_modified_beam_search_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--max-active-paths",
        type=int,
        default=4,
        help="""Used only when --decoding-method is modified_beam_search.
        It specifies the number of active paths to keep during decoding.
        """,
    )


def check_args(args):
    if not Path(args.tokens).is_file():
        raise ValueError(f"{args.tokens} does not exist")

    if args.decoding_method not in (
        "greedy_search",
        "modified_beam_search",
    ):
        raise ValueError(f"Unsupported decoding method {args.decoding_method}")

    if args.decoding_method == "modified_beam_search":
        assert args.max_active_paths > 0, args.max_active_paths
        assert Path(args.encoder).is_file(), args.encoder
        assert Path(args.decoder).is_file(), args.decoder
        assert Path(args.joiner).is_file(), args.joiner


def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    add_model_args(parser)
    add_feature_config_args(parser)
    add_decoding_args(parser)

    parser.add_argument(
        "--port",
        type=int,
        default=6006,
        help="The server will listen on this port",
    )

    parser.add_argument(
        "--max-batch-size",
        type=int,
        default=25,
        help="""Max batch size for computation. Note that if there are not
        enough requests in the queue, it waits for max_wait_ms time. After
        that, even if there are not enough requests, it still sends the
        available requests in the queue for computation.
        """,
    )

    parser.add_argument(
        "--max-wait-ms",
        type=float,
        default=5,
        help="""Max time in milliseconds to wait to build batches for
        inference. If there are not enough requests in the feature queue to
        build a batch of max_batch_size, it waits up to this time before
        fetching the available requests for computation.
        """,
    )

    parser.add_argument(
        "--nn-pool-size",
        type=int,
        default=1,
        help="Number of threads for NN computation and decoding.",
    )

    parser.add_argument(
        "--max-message-size",
        type=int,
        default=(1 << 20),
        help="""Max message size in bytes.
        The max size per message cannot exceed this limit.
        """,
    )

    parser.add_argument(
        "--max-queue-size",
        type=int,
        default=32,
        help="Max number of messages in the queue for each connection.",
    )

    parser.add_argument(
        "--max-active-connections",
        type=int,
        default=500,
        help="""Maximum number of active connections. The server refuses
        to accept new connections once the current number of active
        connections equals this limit.
        """,
    )

    parser.add_argument(
        "--certificate",
        type=str,
        help="""Path to the X.509 certificate. You need it only if you want to
        use a secure websocket connection, i.e., use wss:// instead of ws://.
        You can use ./web/generate-certificate.py
        to generate the certificate `cert.pem`.
        Note: ./web/generate-certificate.py generates three files, but you
        only need to pass the generated cert.pem to this option.
        """,
    )

    parser.add_argument(
        "--doc-root",
        type=str,
        default="./python-api-examples/web",
        help="Path to the web root",
    )

    return parser.parse_args()


class NonStreamingServer:
    def __init__(
        self,
        recognizer: sherpa_onnx.OfflineRecognizer,
        max_batch_size: int,
        max_wait_ms: float,
        nn_pool_size: int,
        max_message_size: int,
        max_queue_size: int,
        max_active_connections: int,
        doc_root: str,
        certificate: Optional[str] = None,
    ):
        """
        Args:
          recognizer:
            An instance of sherpa_onnx.OfflineRecognizer.
          max_batch_size:
            Max batch size for inference.
          max_wait_ms:
            Max wait time in milliseconds in order to build a batch of
            `max_batch_size`.
          nn_pool_size:
            Number of threads for the thread pool that is used for NN
            computation and decoding.
          max_message_size:
            Max size in bytes per message.
          max_queue_size:
            Max number of messages in the queue for each connection.
          max_active_connections:
            Max number of active connections. Once the number of active
            clients equals this limit, the server refuses to accept new
            connections.
          doc_root:
            Path to the directory where files like index.html for the HTTP
            server are located.
          certificate:
            Optional. If not None, it will use a secure websocket.
            You can use ./web/generate-certificate.py to generate
            it (the default generated filename is `cert.pem`).
        """
        self.recognizer = recognizer

        self.certificate = certificate
        self.http_server = HttpServer(doc_root)

        self.nn_pool = ThreadPoolExecutor(
            max_workers=nn_pool_size,
            thread_name_prefix="nn",
        )

        self.stream_queue = asyncio.Queue()

        self.max_wait_ms = max_wait_ms
        self.max_batch_size = max_batch_size
        self.max_message_size = max_message_size
        self.max_queue_size = max_queue_size
        self.max_active_connections = max_active_connections

        self.current_active_connections = 0
        self.sample_rate = int(recognizer.config.feat_config.sampling_rate)

    async def process_request(
        self,
        path: str,
        request_headers: websockets.Headers,
    ) -> Optional[Tuple[http.HTTPStatus, websockets.Headers, bytes]]:
        if "sec-websocket-key" not in request_headers:
            # This is a normal HTTP request
            if path == "/":
                path = "/index.html"
            if path[-1] == "?":
                path = path[:-1]

            if path == "/streaming_record.html":
                response = r"""
<!doctype html><html><head>
<title>Speech recognition with next-gen Kaldi</title><body>
<h2>Only
<a href="/upload.html">/upload.html</a>
and
<a href="/offline_record.html">/offline_record.html</a>
are available for the non-streaming server.</h2>
<br/>
<br/>
Go back to <a href="/upload.html">/upload.html</a>
or <a href="/offline_record.html">/offline_record.html</a>
</body></head></html>
"""
                found = True
                mime_type = "text/html"
            else:
                found, response, mime_type = self.http_server.process_request(path)
            if isinstance(response, str):
                response = response.encode("utf-8")

            if not found:
                status = http.HTTPStatus.NOT_FOUND
            else:
                status = http.HTTPStatus.OK
            header = {"Content-Type": mime_type}
            return status, header, response

        if self.current_active_connections < self.max_active_connections:
            self.current_active_connections += 1
            return None

        # Refuse new connections
        status = http.HTTPStatus.SERVICE_UNAVAILABLE  # 503
        header = {"Hint": "The server is overloaded. Please retry later."}
        response = b"The server is busy. Please retry later."

        return status, header, response

    async def run(self, port: int):
        logging.info("started")

        task = asyncio.create_task(self.stream_consumer_task())

        if self.certificate:
            logging.info(f"Using certificate: {self.certificate}")
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            ssl_context.load_cert_chain(self.certificate)
        else:
            ssl_context = None
            logging.info("No certificate provided")

        async with websockets.serve(
            self.handle_connection,
            host="",
            port=port,
            max_size=self.max_message_size,
            max_queue=self.max_queue_size,
            process_request=self.process_request,
            ssl=ssl_context,
        ):
            ip_list = ["localhost"]
            if ssl_context:
                ip_list += ["0.0.0.0", "127.0.0.1"]
                ip_list.append(socket.gethostbyname(socket.gethostname()))

            proto = "http://" if ssl_context is None else "https://"
            s = "Please visit one of the following addresses:\n\n"
            for p in ip_list:
                s += " " + proto + p + f":{port}" "\n"
            logging.info(s)

            await asyncio.Future()  # run forever

        await task  # not reachable

    async def recv_audio_samples(
        self,
        socket: websockets.WebSocketServerProtocol,
    ) -> Tuple[Optional[np.ndarray], Optional[float]]:
        """Receive a tensor from the client.

        The message from the client is a **bytes** buffer.

        The first message can be either "Done", meaning the client won't send
        anything in the future, or a buffer containing at least 8 bytes.
        The first 4 bytes in little endian specify the sample
        rate of the audio samples; the second 4 bytes in little endian specify
        the number of bytes in the audio file, which will be sent by the
        client in the subsequent messages.
        Since there is a limit on the message size imposed by the websocket
        protocol, the client may send the audio file in multiple messages if
        the audio file is very large.

        The second and remaining messages contain audio samples.

        Please refer to ./offline-websocket-client-decode-files-paralell.py
        and ./offline-websocket-client-decode-files-sequential.py
        for how the client sends the messages.

        Args:
          socket:
            The socket for communicating with the client.
        Returns:
          Return a tuple containing:
            - a 1-D np.float32 array containing the audio samples
            - the sample rate of the audio samples
          or return (None, None), indicating the end of the utterance.
        """
        header = await socket.recv()
        if header == "Done":
            return None, None

        assert len(header) >= 8, (
            "The first message should contain at least 8 bytes. "
            + f"Given {len(header)}"
        )

        sample_rate = int.from_bytes(header[:4], "little", signed=True)
        expected_num_bytes = int.from_bytes(header[4:8], "little", signed=True)

        received = []
        num_received_bytes = 0
        if len(header) > 8:
            received.append(header[8:])
            num_received_bytes += len(header) - 8

        if num_received_bytes < expected_num_bytes:
            async for message in socket:
                received.append(message)
                num_received_bytes += len(message)
                if num_received_bytes >= expected_num_bytes:
                    break

        assert num_received_bytes == expected_num_bytes, (
            num_received_bytes,
            expected_num_bytes,
        )

        samples = b"".join(received)
        array = np.frombuffer(samples, dtype=np.float32)
        return array, sample_rate

    async def stream_consumer_task(self):
        """This function extracts streams from the queue, batches them up,
        and sends them to the neural network model for computation and
        decoding.
        """
        while True:
            if self.stream_queue.empty():
                await asyncio.sleep(self.max_wait_ms / 1000)
                continue

            batch = []
            try:
                while len(batch) < self.max_batch_size:
                    item = self.stream_queue.get_nowait()

                    batch.append(item)
            except asyncio.QueueEmpty:
                pass
            stream_list = [b[0] for b in batch]
            future_list = [b[1] for b in batch]

            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                self.nn_pool,
                self.recognizer.decode_streams,
                stream_list,
            )

            for f in future_list:
                self.stream_queue.task_done()
                f.set_result(None)

    async def compute_and_decode(
        self,
        stream: sherpa_onnx.OfflineStream,
    ) -> None:
        """Put the stream into the queue and wait for it to be processed by
        the consumer task.

        Args:
          stream:
            The stream to be processed. Note: It is changed in-place.
        """
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        await self.stream_queue.put((stream, future))
        await future

    async def handle_connection(
        self,
        socket: websockets.WebSocketServerProtocol,
    ):
        """Receive audio samples from the client, process them, and send
        the decoding result back to the client.

        Args:
          socket:
            The socket for communicating with the client.
        """
        try:
            await self.handle_connection_impl(socket)
        except websockets.exceptions.ConnectionClosedError:
            logging.info(f"{socket.remote_address} disconnected")
        finally:
            # Decrement so that the server can accept new connections
            self.current_active_connections -= 1

            logging.info(
                f"Disconnected: {socket.remote_address}. "
                f"Number of connections: {self.current_active_connections}/{self.max_active_connections}"  # noqa
            )

    async def handle_connection_impl(
        self,
        socket: websockets.WebSocketServerProtocol,
    ):
        """Receive audio samples from the client, process them, and send
        decoding results back to the client.

        Args:
          socket:
            The socket for communicating with the client.
        """
        logging.info(
            f"Connected: {socket.remote_address}. "
            f"Number of connections: {self.current_active_connections}/{self.max_active_connections}"  # noqa
        )

        while True:
            stream = self.recognizer.create_stream()
            samples, sample_rate = await self.recv_audio_samples(socket)
            if samples is None:
                break
            # stream.accept_waveform() runs in the main thread
            stream.accept_waveform(sample_rate, samples)

            await self.compute_and_decode(stream)
            result = stream.result.text
            logging.info(f"result: {result}")

            if result:
                await socket.send(result)
            else:
                # If the result is an empty string, send something to the
                # client. Otherwise, socket.send() is a no-op and the client
                # will wait for a reply indefinitely.
                await socket.send("<EMPTY>")


def assert_file_exists(filename: str):
    assert Path(filename).is_file(), (
        f"{filename} does not exist!\n"
        "Please refer to "
        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
    )


def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
    if args.encoder:
        assert len(args.paraformer) == 0, args.paraformer
        assert len(args.nemo_ctc) == 0, args.nemo_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder

        assert_file_exists(args.encoder)
        assert_file_exists(args.decoder)
        assert_file_exists(args.joiner)

        recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
            encoder=args.encoder,
            decoder=args.decoder,
            joiner=args.joiner,
            tokens=args.tokens,
            num_threads=args.num_threads,
            sample_rate=args.sample_rate,
            feature_dim=args.feat_dim,
            decoding_method=args.decoding_method,
            max_active_paths=args.max_active_paths,
        )
    elif args.paraformer:
        assert len(args.nemo_ctc) == 0, args.nemo_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder

        assert_file_exists(args.paraformer)

        recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
            paraformer=args.paraformer,
            tokens=args.tokens,
            num_threads=args.num_threads,
            sample_rate=args.sample_rate,
            feature_dim=args.feat_dim,
            decoding_method=args.decoding_method,
        )
    elif args.nemo_ctc:
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder

        assert_file_exists(args.nemo_ctc)

        recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
            model=args.nemo_ctc,
            tokens=args.tokens,
            num_threads=args.num_threads,
            sample_rate=args.sample_rate,
            feature_dim=args.feat_dim,
            decoding_method=args.decoding_method,
        )
    elif args.whisper_encoder:
        assert_file_exists(args.whisper_encoder)
        assert_file_exists(args.whisper_decoder)

        recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
            encoder=args.whisper_encoder,
            decoder=args.whisper_decoder,
            tokens=args.tokens,
            num_threads=args.num_threads,
            decoding_method=args.decoding_method,
        )
    else:
        raise ValueError("Please specify at least one model")

    return recognizer


def main():
    args = get_args()
    logging.info(vars(args))
    check_args(args)

    recognizer = create_recognizer(args)

    port = args.port
    max_wait_ms = args.max_wait_ms
    max_batch_size = args.max_batch_size
    nn_pool_size = args.nn_pool_size
    max_message_size = args.max_message_size
    max_queue_size = args.max_queue_size
    max_active_connections = args.max_active_connections
    certificate = args.certificate
    doc_root = args.doc_root

    if certificate and not Path(certificate).is_file():
        raise ValueError(f"{certificate} does not exist")

    if not Path(doc_root).is_dir():
        raise ValueError(f"Directory {doc_root} does not exist")

    non_streaming_server = NonStreamingServer(
        recognizer=recognizer,
        max_wait_ms=max_wait_ms,
        max_batch_size=max_batch_size,
        nn_pool_size=nn_pool_size,
        max_message_size=max_message_size,
        max_queue_size=max_queue_size,
        max_active_connections=max_active_connections,
        certificate=certificate,
        doc_root=doc_root,
    )
    asyncio.run(non_streaming_server.run(port))


if __name__ == "__main__":
    log_filename = "log/log-non-streaming-server"
    setup_logger(log_filename)
    main()
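The core concurrency idea in the server above is worth isolating: each connection handler enqueues a `(stream, future)` pair and awaits the future, while a single consumer task drains the queue, decodes a whole batch in a thread pool, and then resolves every future. A condensed sketch of just that pattern (names are illustrative, not the server's API):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List


async def consumer(
    queue: asyncio.Queue,
    pool: ThreadPoolExecutor,
    decode_batch: Callable[[List], None],
    max_batch_size: int = 25,
):
    while True:
        stream, future = await queue.get()
        batch = [(stream, future)]
        # Opportunistically grab whatever else is already waiting.
        while not queue.empty() and len(batch) < max_batch_size:
            batch.append(queue.get_nowait())

        loop = asyncio.get_running_loop()
        # Heavy NN work runs off the event loop, on the thread pool.
        await loop.run_in_executor(pool, decode_batch, [b[0] for b in batch])

        for _, f in batch:
            f.set_result(None)  # wake up every waiting handler


async def submit(queue: asyncio.Queue, stream):
    future = asyncio.get_running_loop().create_future()
    await queue.put((stream, future))
    await future  # resolved once the batch containing `stream` is decoded
```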
@@ -119,7 +119,13 @@ async def run
     buf += (samples.size * 4).to_bytes(4, byteorder="little")
     buf += samples.tobytes()

-    await websocket.send(buf)
+    payload_len = 10240
+    while len(buf) > payload_len:
+        await websocket.send(buf[:payload_len])
+        buf = buf[payload_len:]
+
+    if buf:
+        await websocket.send(buf)

     decoding_results = await websocket.recv()
     logging.info(f"{wave_filename}\n{decoding_results}")
@@ -116,11 +116,18 @@ async def run
     assert isinstance(sample_rate, int)
     assert samples.dtype == np.float32, samples.dtype
     assert samples.ndim == 1, samples.ndim

     buf = sample_rate.to_bytes(4, byteorder="little")  # 4 bytes
     buf += (samples.size * 4).to_bytes(4, byteorder="little")
     buf += samples.tobytes()

-    await websocket.send(buf)
+    payload_len = 10240
+    while len(buf) > payload_len:
+        await websocket.send(buf[:payload_len])
+        buf = buf[payload_len:]
+
+    if buf:
+        await websocket.send(buf)

     decoding_results = await websocket.recv()
     print(decoding_results)
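The two client hunks above replace a single `websocket.send(buf)` with a loop that sends at most 10240 bytes per message. A plausible motivation (inferred, not stated in the diff): the server caps the per-message size via `--max-message-size`, which defaults to 1 << 20 bytes, so a long recording pushed as one websocket message could be rejected, while 10 KiB chunks always fit; the server reassembles them until it has received the byte count announced in the header.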
@@ -15,10 +15,9 @@ Usage:

 (Note: You have to first start the server before starting the client)

-You can find the server at
+You can find the C++ server at
 https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc
-
-Note: The server is implemented in C++.
+or use the Python server ./python-api-examples/streaming_server.py

 There is also a C++ version of the client. Please see
 https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc
@@ -115,7 +114,8 @@ async def receive_results(socket: websockets.WebSocketServerProtocol):
             last_message = message
             logging.info(message)
         else:
-            return last_message
+            break
+    return last_message


 async def run(
@@ -142,6 +142,7 @@ async def run

         await websocket.send(d)

+        # Simulate streaming. You can remove the sleep if you want
         await asyncio.sleep(seconds_per_message)  # in seconds

         start += samples_per_message
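The comment added above documents the pacing trick used by the streaming client. A minimal sketch (illustrative names, not the repo's client verbatim) of the same idea: send a fixed-size slice of float32 samples, then sleep for the equivalent audio duration so the server sees data at roughly real-time rate.

```python
import asyncio

import numpy as np
import websockets


async def stream_wave(samples: np.ndarray, sample_rate: int = 16000):
    samples_per_message = 8000  # 0.5 seconds of 16 kHz audio per message
    seconds_per_message = samples_per_message / sample_rate

    async with websockets.connect("ws://localhost:6006") as ws:
        start = 0
        while start < samples.size:
            await ws.send(samples[start:start + samples_per_message].tobytes())
            # Simulate streaming; remove the sleep to send as fast as possible.
            await asyncio.sleep(seconds_per_message)
            start += samples_per_message
        await ws.send("Done")
```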
@@ -12,10 +12,9 @@ Usage:

 (Note: You have to first start the server before starting the client)

-You can find the server at
+You can find the C++ server at
 https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc
-
-Note: The server is implemented in C++.
+or use the Python server ./python-api-examples/streaming_server.py

 There is also a C++ version of the client. Please see
 https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc
python-api-examples/streaming_server.py

@@ -13,11 +13,37 @@ Usage:

 Example:

+(1) Without a certificate
+
+python3 ./python-api-examples/streaming_server.py \
+  --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
+  --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
+  --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
+  --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
+
+(2) With a certificate
+
+(a) Generate a certificate first:
+
+    cd python-api-examples/web
+    ./generate-certificate.py
+    cd ../..
+
+(b) Start the server
+
+python3 ./python-api-examples/streaming_server.py \
+  --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
+  --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
+  --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
+  --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
+  --certificate ./python-api-examples/web/cert.pem
+
 Please refer to
 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
 to download pre-trained models.

 The model in the above help messages is from
 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
 """

 import argparse
@@ -35,6 +61,7 @@ from typing import List, Optional, Tuple

 import numpy as np
 import sherpa_onnx
 import websockets

+from http_server import HttpServer

@@ -269,8 +296,8 @@ def get_args():
     parser.add_argument(
         "--num-threads",
         type=int,
-        default=1,
-        help="Sets the number of threads used for interop parallelism (e.g. in JIT interpreter) on CPU.",
+        default=2,
+        help="Number of threads to run the neural network model",
     )

     parser.add_argument(
@@ -278,8 +305,10 @@ def get_args():
         type=str,
         help="""Path to the X.509 certificate. You need it only if you want to
         use a secure websocket connection, i.e., use wss:// instead of ws://.
-        You can use sherpa/bin/web/generate-certificate.py
+        You can use ./web/generate-certificate.py
         to generate the certificate `cert.pem`.
+        Note ./web/generate-certificate.py will generate three files but you
+        only need to pass the generated cert.pem to this option.
         """,
     )
@@ -287,7 +316,7 @@ def get_args():
         "--doc-root",
         type=str,
         default="./python-api-examples/web",
-        help="""Path to the web root""",
+        help="Path to the web root",
     )

     return parser.parse_args()
@@ -299,9 +328,9 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
         encoder=args.encoder_model,
         decoder=args.decoder_model,
         joiner=args.joiner_model,
-        num_threads=1,
-        sample_rate=16000,
-        feature_dim=80,
+        num_threads=args.num_threads,
+        sample_rate=args.sample_rate,
+        feature_dim=args.feat_dim,
         decoding_method=args.decoding_method,
         max_active_paths=args.num_active_paths,
         enable_endpoint_detection=args.use_endpoint != 0,
@@ -359,7 +388,7 @@ class StreamingServer(object):
             server locate.
           certificate:
             Optional. If not None, it will use secure websocket.
-            You can use ./sherpa/bin/web/generate-certificate.py to generate
+            You can use ./web/generate-certificate.py to generate
             it (the default generated filename is `cert.pem`).
         """
         self.recognizer = recognizer
@@ -373,6 +402,7 @@ class StreamingServer(object):
         )

         self.stream_queue = asyncio.Queue()
+
         self.max_wait_ms = max_wait_ms
         self.max_batch_size = max_batch_size
         self.max_message_size = max_message_size
@@ -382,11 +412,10 @@ class StreamingServer(object):
         self.current_active_connections = 0

         self.sample_rate = int(recognizer.config.feat_config.sampling_rate)
-        self.decoding_method = recognizer.config.decoding_method

     async def stream_consumer_task(self):
         """This function extracts streams from the queue, batches them up, sends
-        them to the RNN-T model for computation and decoding.
+        them to the neural network model for computation and decoding.
         """
         while True:
             if self.stream_queue.empty():
@@ -442,7 +471,22 @@ class StreamingServer(object):
             # This is a normal HTTP request
             if path == "/":
                 path = "/index.html"
-            found, response, mime_type = self.http_server.process_request(path)
+
+            if path in ("/upload.html", "/offline_record.html"):
+                response = r"""
+<!doctype html><html><head>
+<title>Speech recognition with next-gen Kaldi</title><body>
+<h2>Only /streaming_record.html is available for the streaming server.</h2>
+<br/>
+<br/>
+Go back to <a href="/streaming_record.html">/streaming_record.html</a>
+</body></head></html>
+"""
+                found = True
+                mime_type = "text/html"
+            else:
+                found, response, mime_type = self.http_server.process_request(path)

             if isinstance(response, str):
                 response = response.encode("utf-8")
@@ -484,12 +528,21 @@ class StreamingServer(object):
             process_request=self.process_request,
             ssl=ssl_context,
         ):
-            ip_list = ["0.0.0.0", "localhost", "127.0.0.1"]
-            ip_list.append(socket.gethostbyname(socket.gethostname()))
+            ip_list = ["localhost"]
+            if ssl_context:
+                ip_list += ["0.0.0.0", "127.0.0.1"]
+                ip_list.append(socket.gethostbyname(socket.gethostname()))
+
             proto = "http://" if ssl_context is None else "https://"
             s = "Please visit one of the following addresses:\n\n"
             for p in ip_list:
                 s += " " + proto + p + f":{port}" "\n"
+
+            if not ssl_context:
+                s += "\nSince you are not providing a certificate, you cannot "
+                s += "use your microphone from within the browser using "
+                s += "public IP addresses. Only localhost can be used. "
+                s += "You also cannot use 0.0.0.0 or 127.0.0.1"
+
             logging.info(s)

             await asyncio.Future()  # run forever
@@ -525,7 +578,7 @@ class StreamingServer(object):
         socket: websockets.WebSocketServerProtocol,
     ):
         """Receive audio samples from the client, process it, and send
-        deocoding result back to the client.
+        decoding result back to the client.

         Args:
           socket:
@@ -560,8 +613,6 @@ class StreamingServer(object):
                 self.recognizer.reset(stream)
                 segment += 1

-            print(message)
-
             await socket.send(json.dumps(message))

         tail_padding = np.zeros(int(self.sample_rate * 0.3)).astype(np.float32)
@@ -583,7 +634,7 @@ class StreamingServer(object):
         self,
         socket: websockets.WebSocketServerProtocol,
     ) -> Optional[np.ndarray]:
-        """Receives a tensor from the client.
+        """Receive a tensor from the client.

         Each message contains either a bytes buffer containing audio samples
         in 16 kHz or contains "Done" meaning the end of utterance.
@@ -660,6 +711,6 @@ def main():


 if __name__ == "__main__":
-    log_filename = "log/log-streaming-zipformer"
+    log_filename = "log/log-streaming-server"
     setup_logger(log_filename)
     main()
@@ -1,34 +0,0 @@
# How to use

```bash
git clone https://github.com/k2-fsa/sherpa

cd sherpa/sherpa/bin/web
python3 -m http.server 6009
```
and then go to <http://localhost:6009>

You will see a page like the following screenshot:

(screenshot omitted)

If your server is listening at the port *6006* with address **localhost**,
then you can either click **Upload**, **Streaming_Record**, or **Offline_Record** to play with it.

## File descriptions

### ./css/bootstrap.min.css

It is downloaded from https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/css/bootstrap.min.css

### ./js/jquery-3.6.0.min.js

It is downloaded from https://code.jquery.com/jquery-3.6.0.min.js

### ./js/popper.min.js

It is downloaded from https://cdn.jsdelivr.net/npm/popper.js@1.14.7/dist/umd/popper.min.js

### ./js/bootstrap.min.js

It is downloaded from https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/js/bootstrap.min.js
python-api-examples/web/generate-certificate.py

@@ -35,8 +35,8 @@ Otherwise, you may get the following error from within your browser:


 def cert_gen(
-    emailAddress="https://github.com/k2-fsa/k2",
-    commonName="sherpa",
+    emailAddress="https://github.com/k2-fsa/sherpa-onnx",
+    commonName="sherpa-onnx",
     countryName="CN",
     localityName="k2-fsa",
     stateOrProvinceName="k2-fsa",
@@ -70,17 +70,13 @@ def cert_gen(
     cert.set_pubkey(k)
     cert.sign(k, "sha512")
     with open(CERT_FILE, "wt") as f:
-        f.write(
-            crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8")
-        )
+        f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8"))
     with open(KEY_FILE, "wt") as f:
         f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8"))

     with open(ALL_IN_ONE_FILE, "wt") as f:
         f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8"))
-        f.write(
-            crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8")
-        )
+        f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8"))
     print(f"Generated {CERT_FILE}")
     print(f"Generated {KEY_FILE}")
     print(f"Generated {ALL_IN_ONE_FILE}")
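Since generate-certificate.py produces a self-signed cert.pem, a client that wants to verify the wss:// connection has to trust that file explicitly. A hedged sketch (not part of this commit) of one way to do that:

```python
import asyncio
import ssl

import websockets

ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.load_verify_locations("cert.pem")  # the file generated above
ssl_context.check_hostname = False  # a self-signed cert has no real hostname


async def connect():
    # Connect over wss:// using the explicitly trusted certificate.
    return await websockets.connect("wss://localhost:6006", ssl=ssl_context)
```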
@@ -53,7 +53,7 @@
 </ul>

 Code is available at
-<a href="https://github.com/k2-fsa/sherpa"> https://github.com/k2-fsa/sherpa</a>
+<a href="https://github.com/k2-fsa/sherpa-onnx"> https://github.com/k2-fsa/sherpa-onnx</a>

 <!-- Optional JavaScript -->
 <!-- jQuery first, then Popper.js, then Bootstrap JS -->
@@ -60,6 +60,7 @@ const soundClips = document.getElementById('sound-clips');
 const canvas = document.getElementById('canvas');
 const mainSection = document.querySelector('.container');

+recordBtn.disabled = true;
 stopBtn.disabled = true;

 window.onload = (event) => {
@@ -95,9 +96,10 @@ clearBtn.onclick = function() {
 };

 function send_header(n) {
-  const header = new ArrayBuffer(4);
-  new DataView(header).setInt32(0, n, true /* littleEndian */);
-  socket.send(new Int32Array(header, 0, 1));
+  const header = new ArrayBuffer(8);
+  new DataView(header).setInt32(0, expectedSampleRate, true /* littleEndian */);
+  new DataView(header).setInt32(4, n, true /* littleEndian */);
+  socket.send(new Int32Array(header, 0, 2));
 }

 // copied/modified from https://mdn.github.io/web-dictaphone/
@@ -88,6 +88,7 @@ const canvas = document.getElementById('canvas');
 const mainSection = document.querySelector('.container');

 stopBtn.disabled = true;
+recordBtn.disabled = true;

 let audioCtx;
 const canvasCtx = canvas.getContext('2d');
@@ -74,9 +74,11 @@ connectBtn.onclick = function() {
 };

 function send_header(n) {
-  const header = new ArrayBuffer(4);
-  new DataView(header).setInt32(0, n, true /* littleEndian */);
-  socket.send(new Int32Array(header, 0, 1));
+  const header = new ArrayBuffer(8);
+  // assume the uploaded wave is 16000 Hz
+  new DataView(header).setInt32(0, 16000, true /* littleEndian */);
+  new DataView(header).setInt32(4, n, true /* littleEndian */);
+  socket.send(new Int32Array(header, 0, 2));
 }

 function onFileChange() {
@@ -33,9 +33,9 @@
             <button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button>
         </div>
         <span class="input-group-text" id="ws-protocol">ws://</span>
-        <input type="text" id="server-ip" class="form-control" placeholder="Sherpa server IP, e.g., localhost" aria-label="sherpa server IP">
+        <input type="text" id="server-ip" class="form-control" placeholder="Sherpa-onnx server IP, e.g., localhost" aria-label="sherpa-onnx server IP">
         <span class="input-group-text">:</span>
-        <input type="text" id="server-port" class="form-control" placeholder="Sherpa server port, e.g., 6006" aria-label="sherpa server port">
+        <input type="text" id="server-port" class="form-control" placeholder="Sherpa-onnx server port, e.g., 6006" aria-label="sherpa-onnx server port">
     </div>

     <div class="row">
@@ -33,9 +33,9 @@
             <button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button>
         </div>
         <span class="input-group-text" id="ws-protocol">ws://</span>
-        <input type="text" id="server-ip" class="form-control" placeholder="Sherpa server IP, e.g., localhost" aria-label="sherpa server IP">
+        <input type="text" id="server-ip" class="form-control" placeholder="Sherpa-onnx server IP, e.g., localhost" aria-label="sherpa-onnx server IP">
        <span class="input-group-text">:</span>
-        <input type="text" id="server-port" class="form-control" placeholder="Sherpa server port, e.g., 6006" aria-label="sherpa server port">
+        <input type="text" id="server-port" class="form-control" placeholder="Sherpa-onnx server port, e.g., 6006" aria-label="sherpa-onnx server port">
     </div>

     <div class="row">
@@ -32,9 +32,9 @@
             <button class="btn btn-block btn-primary" type="button" id="connect">Click me to connect</button>
         </div>
         <span class="input-group-text" id="ws-protocol">ws://</span>
-        <input type="text" id="server-ip" class="form-control" placeholder="Sherpa server IP, e.g., localhost" aria-label="sherpa server IP">
+        <input type="text" id="server-ip" class="form-control" placeholder="Sherpa-onnx server IP, e.g., localhost" aria-label="sherpa-onnx server IP">
         <span class="input-group-text">:</span>
-        <input type="text" id="server-port" class="form-control" placeholder="Sherpa server port, e.g., 6006" aria-label="sherpa server port">
+        <input type="text" id="server-port" class="form-control" placeholder="Sherpa-onnx server port, e.g., 6006" aria-label="sherpa-onnx server port">
     </div>

     <form>
sherpa_onnx/__init__.py

@@ -1,12 +1,7 @@
 from typing import Dict, List, Optional

-from _sherpa_onnx import Display
+from _sherpa_onnx import Display, OfflineStream, OnlineStream

-from .online_recognizer import OnlineRecognizer
-from .online_recognizer import OnlineStream
 from .offline_recognizer import OfflineRecognizer
-
+from .online_recognizer import OnlineRecognizer
 from .utils import encode_contexts
sherpa_onnx/offline_recognizer.py

@@ -41,6 +41,7 @@ class OfflineRecognizer(object):
         sample_rate: int = 16000,
         feature_dim: int = 80,
         decoding_method: str = "greedy_search",
+        max_active_paths: int = 4,
         context_score: float = 1.5,
         debug: bool = False,
         provider: str = "cpu",

@@ -72,6 +73,9 @@ class OfflineRecognizer(object):
             Dimension of the feature used to train the model.
           decoding_method:
             Valid values: greedy_search, modified_beam_search.
+          max_active_paths:
+            Maximum number of active paths to keep. Used only when
+            decoding_method is modified_beam_search.
           debug:
             True to show debug messages.
           provider:

@@ -103,6 +107,7 @@ class OfflineRecognizer(object):
             context_score=context_score,
         )
         self.recognizer = _Recognizer(recognizer_config)
+        self.config = recognizer_config
         return self

     @classmethod

@@ -166,6 +171,7 @@ class OfflineRecognizer(object):
             decoding_method=decoding_method,
         )
         self.recognizer = _Recognizer(recognizer_config)
+        self.config = recognizer_config
         return self

     @classmethod

@@ -229,6 +235,7 @@ class OfflineRecognizer(object):
             decoding_method=decoding_method,
         )
         self.recognizer = _Recognizer(recognizer_config)
+        self.config = recognizer_config
         return self

     @classmethod

@@ -291,6 +298,7 @@ class OfflineRecognizer(object):
             decoding_method=decoding_method,
         )
         self.recognizer = _Recognizer(recognizer_config)
+        self.config = recognizer_config
         return self

     def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
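The repeated `self.config = recognizer_config` additions are what let the new websocket servers recover model parameters from the recognizer instance itself. A brief usage sketch (the paths below are placeholders, not real files):

```python
import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
    paraformer="model.int8.onnx",  # placeholder path
    tokens="tokens.txt",           # placeholder path
)

# The server reads the expected sampling rate back off the stored config
# instead of threading it through separately:
sample_rate = int(recognizer.config.feat_config.sampling_rate)
```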