#!/usr/bin/env python3 # Copyright 2022-2023 Xiaomi Corp. """ A server for non-streaming speech recognition. Non-streaming means you send all the content of the audio at once for recognition. It supports multiple clients sending at the same time. Usage: ./non_streaming_server.py --help Please refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html for pre-trained models to download. Usage examples: (1) Use a non-streaming transducer model cd /path/to/sherpa-onnx GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-2023-06-26 cd sherpa-onnx-zipformer-en-2023-06-26 git lfs pull --include "*.onnx" cd .. python3 ./python-api-examples/non_streaming_server.py \ --encoder ./sherpa-onnx-zipformer-en-2023-06-26/encoder-epoch-99-avg-1.onnx \ --decoder ./sherpa-onnx-zipformer-en-2023-06-26/decoder-epoch-99-avg-1.onnx \ --joiner ./sherpa-onnx-zipformer-en-2023-06-26/joiner-epoch-99-avg-1.onnx \ --tokens ./sherpa-onnx-zipformer-en-2023-06-26/tokens.txt (2) Use a non-streaming paraformer cd /path/to/sherpa-onnx GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en cd sherpa-onnx-paraformer-bilingual-zh-en/ git lfs pull --include "*.onnx" cd .. python3 ./python-api-examples/non_streaming_server.py \ --paraformer ./sherpa-onnx-paraformer-bilingual-zh-en/model.int8.onnx \ --tokens ./sherpa-onnx-paraformer-bilingual-zh-en/tokens.txt (3) Use a non-streaming CTC model from NeMo cd /path/to/sherpa-onnx GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-medium cd sherpa-onnx-nemo-ctc-en-conformer-medium git lfs pull --include "*.onnx" cd .. python3 ./python-api-examples/non_streaming_server.py \ --nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \ --tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt (4) Use a non-streaming CTC model from WeNet cd /path/to/sherpa-onnx GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech cd sherpa-onnx-zh-wenet-wenetspeech git lfs pull --include "*.onnx" cd .. python3 ./python-api-examples/non_streaming_server.py \ --wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ --tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt (5) Use a Whisper model cd /path/to/sherpa-onnx GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en cd sherpa-onnx-whisper-tiny.en git lfs pull --include "*.onnx" cd .. python3 ./python-api-examples/non_streaming_server.py \ --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \ --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \ --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt (5) Use a tdnn model of the yesno recipe from icefall cd /path/to/sherpa-onnx GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno cd sherpa-onnx-tdnn-yesno git lfs pull --include "*.onnx" python3 ./python-api-examples/non_streaming_server.py \ --sample-rate=8000 \ --feat-dim=23 \ --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt ---- To use a certificate so that you can use https, please use python3 ./python-api-examples/non_streaming_server.py \ --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \ --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \ --certificate=/path/to/your/cert.pem If you don't have a certificate, please run: cd ./python-api-examples/web ./generate-certificate.py It will generate 3 files, one of which is the required `cert.pem`. """ # noqa import argparse import asyncio import http import logging import socket import ssl import sys import warnings from concurrent.futures import ThreadPoolExecutor from datetime import datetime from pathlib import Path from typing import Optional, Tuple import numpy as np import sherpa_onnx import websockets from http_server import HttpServer def setup_logger( log_filename: str, log_level: str = "info", use_console: bool = True, ) -> None: """Setup log level. Args: log_filename: The filename to save the log. log_level: The log level to use, e.g., "debug", "info", "warning", "error", "critical" use_console: True to also print logs to console. """ now = datetime.now() date_time = now.strftime("%Y-%m-%d-%H-%M-%S") formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" log_filename = f"{log_filename}-{date_time}.txt" Path(log_filename).parent.mkdir(parents=True, exist_ok=True) level = logging.ERROR if log_level == "debug": level = logging.DEBUG elif log_level == "info": level = logging.INFO elif log_level == "warning": level = logging.WARNING elif log_level == "critical": level = logging.CRITICAL logging.basicConfig( filename=log_filename, format=formatter, level=level, filemode="w", ) if use_console: console = logging.StreamHandler() console.setLevel(level) console.setFormatter(logging.Formatter(formatter)) logging.getLogger("").addHandler(console) def add_transducer_model_args(parser: argparse.ArgumentParser): parser.add_argument( "--encoder", default="", type=str, help="Path to the transducer encoder model", ) parser.add_argument( "--decoder", default="", type=str, help="Path to the transducer decoder model", ) parser.add_argument( "--joiner", default="", type=str, help="Path to the transducer joiner model", ) def add_paraformer_model_args(parser: argparse.ArgumentParser): parser.add_argument( "--paraformer", default="", type=str, help="Path to the model.onnx from Paraformer", ) def add_nemo_ctc_model_args(parser: argparse.ArgumentParser): parser.add_argument( "--nemo-ctc", default="", type=str, help="Path to the model.onnx from NeMo CTC", ) def add_wenet_ctc_model_args(parser: argparse.ArgumentParser): parser.add_argument( "--wenet-ctc", default="", type=str, help="Path to the model.onnx from WeNet CTC", ) def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser): parser.add_argument( "--tdnn-model", default="", type=str, help="Path to the model.onnx for the tdnn model of the yesno recipe", ) def add_whisper_model_args(parser: argparse.ArgumentParser): parser.add_argument( "--whisper-encoder", default="", type=str, help="Path to whisper encoder model", ) parser.add_argument( "--whisper-decoder", default="", type=str, help="Path to whisper decoder model", ) parser.add_argument( "--whisper-language", default="", type=str, help="""It specifies the spoken language in the input audio file. Example values: en, fr, de, zh, jp. Available languages for multilingual models can be found at https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 If not specified, we infer the language from the input audio file. """, ) parser.add_argument( "--whisper-task", default="transcribe", choices=["transcribe", "translate"], type=str, help="""For multilingual models, if you specify translate, the output will be in English. """, ) parser.add_argument( "--whisper-tail-paddings", default=-1, type=int, help="""Number of tail padding frames. We have removed the 30-second constraint from whisper, so you need to choose the amount of tail padding frames by yourself. Use -1 to use a default value for tail padding. """, ) def add_model_args(parser: argparse.ArgumentParser): add_transducer_model_args(parser) add_paraformer_model_args(parser) add_nemo_ctc_model_args(parser) add_wenet_ctc_model_args(parser) add_tdnn_ctc_model_args(parser) add_whisper_model_args(parser) parser.add_argument( "--tokens", type=str, help="Path to tokens.txt", ) parser.add_argument( "--num-threads", type=int, default=2, help="Number of threads to run the neural network model", ) parser.add_argument( "--provider", type=str, default="cpu", help="Valid values: cpu, cuda, coreml", ) def add_feature_config_args(parser: argparse.ArgumentParser): parser.add_argument( "--sample-rate", type=int, default=16000, help="Sample rate of the data used to train the model. ", ) parser.add_argument( "--feat-dim", type=int, default=80, help="Feature dimension of the model", ) def add_decoding_args(parser: argparse.ArgumentParser): parser.add_argument( "--decoding-method", type=str, default="greedy_search", help="""Decoding method to use. Current supported methods are: - greedy_search - modified_beam_search (for transducer models only) """, ) add_modified_beam_search_args(parser) def add_modified_beam_search_args(parser: argparse.ArgumentParser): parser.add_argument( "--max-active-paths", type=int, default=4, help="""Used only when --decoding-method is modified_beam_search. It specifies number of active paths to keep during decoding. """, ) def add_hotwords_args(parser: argparse.ArgumentParser): parser.add_argument( "--hotwords-file", type=str, default="", help=""" The file containing hotwords, one words/phrases per line, and for each phrase the bpe/cjkchar are separated by a space. For example: ▁HE LL O ▁WORLD 你 好 世 界 """, ) parser.add_argument( "--hotwords-score", type=float, default=1.5, help=""" The hotword score of each token for biasing word/phrase. Used only if --hotwords-file is given. """, ) def add_blank_penalty_args(parser: argparse.ArgumentParser): parser.add_argument( "--blank-penalty", type=float, default=0.0, help=""" The penalty applied on blank symbol during decoding. Note: It is a positive value that would be applied to logits like this `logits[:, 0] -= blank_penalty` (suppose logits.shape is [batch_size, vocab] and blank id is 0). """, ) def check_args(args): if not Path(args.tokens).is_file(): raise ValueError(f"{args.tokens} does not exist") if args.decoding_method not in ( "greedy_search", "modified_beam_search", ): raise ValueError(f"Unsupported decoding method {args.decoding_method}") if args.decoding_method == "modified_beam_search": assert args.num_active_paths > 0, args.num_active_paths assert Path(args.encoder).is_file(), args.encoder assert Path(args.decoder).is_file(), args.decoder assert Path(args.joiner).is_file(), args.joiner if args.hotwords_file != "": assert args.decoding_method == "modified_beam_search", args.decoding_method assert Path(args.hotwords_file).is_file(), args.hotwords_file def get_args(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) add_model_args(parser) add_feature_config_args(parser) add_decoding_args(parser) add_hotwords_args(parser) add_blank_penalty_args(parser) parser.add_argument( "--port", type=int, default=6006, help="The server will listen on this port", ) parser.add_argument( "--max-batch-size", type=int, default=3, help="""Max batch size for computation. Note if there are not enough requests in the queue, it will wait for max_wait_ms time. After that, even if there are not enough requests, it still sends the available requests in the queue for computation. """, ) parser.add_argument( "--max-wait-ms", type=float, default=5, help="""Max time in millisecond to wait to build batches for inference. If there are not enough requests in the feature queue to build a batch of max_batch_size, it waits up to this time before fetching available requests for computation. """, ) parser.add_argument( "--nn-pool-size", type=int, default=1, help="Number of threads for NN computation and decoding.", ) parser.add_argument( "--max-message-size", type=int, default=(1 << 20), help="""Max message size in bytes. The max size per message cannot exceed this limit. """, ) parser.add_argument( "--max-queue-size", type=int, default=32, help="Max number of messages in the queue for each connection.", ) parser.add_argument( "--max-active-connections", type=int, default=200, help="""Maximum number of active connections. The server will refuse to accept new connections once the current number of active connections equals to this limit. """, ) parser.add_argument( "--certificate", type=str, help="""Path to the X.509 certificate. You need it only if you want to use a secure websocket connection, i.e., use wss:// instead of ws://. You can use ./web/generate-certificate.py to generate the certificate `cert.pem`. Note ./web/generate-certificate.py will generate three files but you only need to pass the generated cert.pem to this option. """, ) parser.add_argument( "--doc-root", type=str, default="./python-api-examples/web", help="Path to the web root", ) return parser.parse_args() class NonStreamingServer: def __init__( self, recognizer: sherpa_onnx.OfflineRecognizer, max_batch_size: int, max_wait_ms: float, nn_pool_size: int, max_message_size: int, max_queue_size: int, max_active_connections: int, doc_root: str, certificate: Optional[str] = None, ): """ Args: recognizer: An instance of the sherpa_onnx.OfflineRecognizer. max_batch_size: Max batch size for inference. max_wait_ms: Max wait time in milliseconds in order to build a batch of `max_batch_size`. nn_pool_size: Number of threads for the thread pool that is used for NN computation and decoding. max_message_size: Max size in bytes per message. max_queue_size: Max number of messages in the queue for each connection. max_active_connections: Max number of active connections. Once number of active client equals to this limit, the server refuses to accept new connections. doc_root: Path to the directory where files like index.html for the HTTP server locate. certificate: Optional. If not None, it will use secure websocket. You can use ./web/generate-certificate.py to generate it (the default generated filename is `cert.pem`). """ self.recognizer = recognizer self.certificate = certificate self.http_server = HttpServer(doc_root) self.nn_pool_size = nn_pool_size self.nn_pool = ThreadPoolExecutor( max_workers=nn_pool_size, thread_name_prefix="nn", ) self.stream_queue = asyncio.Queue() self.max_wait_ms = max_wait_ms self.max_batch_size = max_batch_size self.max_message_size = max_message_size self.max_queue_size = max_queue_size self.max_active_connections = max_active_connections self.current_active_connections = 0 self.sample_rate = int(recognizer.config.feat_config.sampling_rate) async def process_request( self, path: str, request_headers: websockets.Headers, ) -> Optional[Tuple[http.HTTPStatus, websockets.Headers, bytes]]: if "sec-websocket-key" not in request_headers: # This is a normal HTTP request if path == "/": path = "/index.html" if path[-1] == "?": path = path[:-1] if path == "/streaming_record.html": response = r"""