From 3f7e0c23acd0cfacdd37d323a8562f6b5568044a Mon Sep 17 00:00:00 2001 From: manyeyes <32889020+manyeyes@users.noreply.github.com> Date: Sun, 2 Apr 2023 13:17:43 +0800 Subject: [PATCH] adding a python api for offline decode (#110) --- python-api-examples/offline-decode-files.py | 240 ++++++++++++++++++ sherpa-onnx/csrc/offline-recognizer.h | 15 +- sherpa-onnx/csrc/offline-stream.h | 16 +- sherpa-onnx/python/csrc/CMakeLists.txt | 5 + .../python/csrc/offline-model-config.cc | 36 +++ .../python/csrc/offline-model-config.h | 16 ++ .../csrc/offline-paraformer-model-config.cc | 24 ++ .../csrc/offline-paraformer-model-config.h | 16 ++ sherpa-onnx/python/csrc/offline-recognizer.cc | 43 ++++ sherpa-onnx/python/csrc/offline-recognizer.h | 16 ++ sherpa-onnx/python/csrc/offline-stream.cc | 61 +++++ sherpa-onnx/python/csrc/offline-stream.h | 16 ++ .../csrc/offline-transducer-model-config.cc | 28 ++ .../csrc/offline-transducer-model-config.h | 16 ++ sherpa-onnx/python/csrc/sherpa-onnx.cc | 11 + sherpa-onnx/python/sherpa_onnx/__init__.py | 1 + .../python/sherpa_onnx/offline_recognizer.py | 167 ++++++++++++ 17 files changed, 712 insertions(+), 15 deletions(-) create mode 100644 python-api-examples/offline-decode-files.py create mode 100644 sherpa-onnx/python/csrc/offline-model-config.cc create mode 100644 sherpa-onnx/python/csrc/offline-model-config.h create mode 100644 sherpa-onnx/python/csrc/offline-paraformer-model-config.cc create mode 100644 sherpa-onnx/python/csrc/offline-paraformer-model-config.h create mode 100644 sherpa-onnx/python/csrc/offline-recognizer.cc create mode 100644 sherpa-onnx/python/csrc/offline-recognizer.h create mode 100644 sherpa-onnx/python/csrc/offline-stream.cc create mode 100644 sherpa-onnx/python/csrc/offline-stream.h create mode 100644 sherpa-onnx/python/csrc/offline-transducer-model-config.cc create mode 100644 sherpa-onnx/python/csrc/offline-transducer-model-config.h create mode 100644 sherpa-onnx/python/sherpa_onnx/offline_recognizer.py diff 
--git a/python-api-examples/offline-decode-files.py b/python-api-examples/offline-decode-files.py new file mode 100644 index 00000000..ed08c393 --- /dev/null +++ b/python-api-examples/offline-decode-files.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 by manyeyes + +""" +This file demonstrates how to use sherpa-onnx Python API to transcribe +file(s) with a non-streaming model. + +paraformer Usage: + ./python-api-examples/offline-decode-files.py \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/paraformer.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + /path/to/0.wav \ + /path/to/1.wav + +transducer Usage: + ./python-api-examples/offline-decode-files.py \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + /path/to/0.wav \ + /path/to/1.wav + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/index.html +to install sherpa-onnx and to download the pre-trained models +used in this file. 
+"""
+import argparse
+import time
+import wave
+from pathlib import Path
+from typing import Tuple
+
+import numpy as np
+import sherpa_onnx
+
+def get_args():
+    """Parse the command-line options and the list of wave files to decode."""
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        help="Path to tokens.txt",
+    )
+
+    parser.add_argument(
+        "--encoder",
+        default="",
+        type=str,
+        help="Path to the encoder model",
+    )
+
+    parser.add_argument(
+        "--decoder",
+        default="",
+        type=str,
+        help="Path to the decoder model",
+    )
+
+    parser.add_argument(
+        "--joiner",
+        default="",
+        type=str,
+        help="Path to the joiner model",
+    )
+
+    parser.add_argument(
+        "--paraformer",
+        default="",
+        type=str,
+        help="Path to the paraformer model",
+    )
+
+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=1,
+        help="Number of threads for neural network computation",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="Valid values are greedy_search and modified_beam_search",
+    )
+    parser.add_argument(
+        "--debug",
+        type=lambda v: v.lower() in ("1", "true", "yes"),  # bool("false") is True
+        default=False,
+        help="True to show debug messages",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="Sample rate of the feature extractor. Must match the one expected by the model. Note: The input sound files can have a different sample rate from this argument.",
+    )
+
+    parser.add_argument(
+        "--feature-dim",
+        type=int,
+        default=80,
+        help="Feature dimension. Must match the one expected by the model",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to decode. Each file must be of WAVE "
+        "format with a single channel, and each sample has 16-bit, "
+        "i.e., int16_t. "
+        "The sample rate of the file can be arbitrary and does not need to "
+        "be 16 kHz",
+    )
+
+    return parser.parse_args()
+
+
+def assert_file_exists(filename: str):
+    assert Path(filename).is_file(), (
+        f"{filename} does not exist!\n"
+        "Please refer to "
+        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
+    )
+
+
+def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
+    """
+    Args:
+      wave_filename:
+        Path to a wave file. It should be single channel and each sample should
+        be 16-bit. Its sample rate does not need to be 16kHz.
+    Returns:
+      Return a tuple containing:
+       - A 1-D array of dtype np.float32 containing the samples, which are
+       normalized to the range [-1, 1].
+       - sample rate of the wave file
+    """
+
+    with wave.open(wave_filename) as f:
+        assert f.getnchannels() == 1, f.getnchannels()
+        assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
+        num_samples = f.getnframes()
+        samples = f.readframes(num_samples)
+        samples_int16 = np.frombuffer(samples, dtype=np.int16)
+        samples_float32 = samples_int16.astype(np.float32)
+
+        samples_float32 = samples_float32 / 32768  # scale int16 to [-1, 1)
+        return samples_float32, f.getframerate()
+
+def main():
+    """Decode all input files in one batch, then print transcripts and RTF."""
+    args = get_args()
+    assert_file_exists(args.tokens)
+    assert args.num_threads > 0, args.num_threads
+    if len(args.encoder) > 0:  # transducer model; --paraformer must be unset
+        assert_file_exists(args.encoder)
+        assert_file_exists(args.decoder)
+        assert_file_exists(args.joiner)
+        assert len(args.paraformer) == 0, args.paraformer
+        recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
+            encoder=args.encoder,
+            decoder=args.decoder,
+            joiner=args.joiner,
+            tokens=args.tokens,
+            num_threads=args.num_threads,
+            sample_rate=args.sample_rate,
+            feature_dim=args.feature_dim,
+            decoding_method=args.decoding_method,
+            debug=args.debug
+        )
+    else:
+        assert_file_exists(args.paraformer)
+        recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
+            paraformer=args.paraformer,
+            tokens=args.tokens,
+            num_threads=args.num_threads,
+            sample_rate=args.sample_rate,
+            feature_dim=args.feature_dim,
+            decoding_method=args.decoding_method,
+            debug=args.debug
+        )
+
+    print("Started!")
+    start_time = time.time()
+
+    streams = []
+    total_duration = 0
+    for wave_filename in args.sound_files:
+        assert_file_exists(wave_filename)
+        samples, sample_rate = read_wave(wave_filename)
+        duration = len(samples) / sample_rate
+        total_duration += duration
+
+        s = recognizer.create_stream()
+        s.accept_waveform(sample_rate, samples)
+
+        tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
+        s.accept_waveform(sample_rate, tail_paddings)
+
+        streams.append(s)
+
+    recognizer.decode_streams(streams)
+    results = [s.result.text for s in streams]
+    end_time = time.time()
+    print("Done!")
+
+    for wave_filename, result in zip(args.sound_files, results):
+        print(f"{wave_filename}\n{result}")
+        print("-" * 10)
+
+    elapsed_seconds = end_time - start_time
+    rtf = elapsed_seconds / total_duration  # all files, not only the last one
+    print(f"num_threads: {args.num_threads}")
+    print(f"decoding_method: {args.decoding_method}")
+    print(f"Wave duration: {total_duration:.3f} s")
+    print(f"Elapsed time: {elapsed_seconds:.3f} s")
+    print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/sherpa-onnx/csrc/offline-recognizer.h b/sherpa-onnx/csrc/offline-recognizer.h index 321e2950..9cfd7d58 100644 --- a/sherpa-onnx/csrc/offline-recognizer.h +++ b/sherpa-onnx/csrc/offline-recognizer.h @@ -16,20 +16,7 @@ namespace sherpa_onnx { -struct OfflineRecognitionResult { - // Recognition results. - // For English, it consists of space separated words. - // For Chinese, it consists of Chinese words without spaces. - std::string text; - - // Decoded results at the token level. - // For instance, for BPE-based models it consists of a list of BPE tokens. - std::vector tokens; - - /// timestamps.size() == tokens.size() - /// timestamps[i] records the time in seconds when tokens[i] is decoded. 
- std::vector timestamps; -}; +struct OfflineRecognitionResult; struct OfflineRecognizerConfig { OfflineFeatureExtractorConfig feat_config; diff --git a/sherpa-onnx/csrc/offline-stream.h b/sherpa-onnx/csrc/offline-stream.h index 7686d641..ba0798bf 100644 --- a/sherpa-onnx/csrc/offline-stream.h +++ b/sherpa-onnx/csrc/offline-stream.h @@ -13,7 +13,21 @@ #include "sherpa-onnx/csrc/parse-options.h" namespace sherpa_onnx { -struct OfflineRecognitionResult; + +struct OfflineRecognitionResult { + // Recognition results. + // For English, it consists of space separated words. + // For Chinese, it consists of Chinese words without spaces. + std::string text; + + // Decoded results at the token level. + // For instance, for BPE-based models it consists of a list of BPE tokens. + std::vector tokens; + + /// timestamps.size() == tokens.size() + /// timestamps[i] records the time in seconds when tokens[i] is decoded. + std::vector timestamps; +}; struct OfflineFeatureExtractorConfig { // Sampling rate used by the feature extractor. 
If it is different from diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 88c484c2..560adfcb 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -4,6 +4,11 @@ pybind11_add_module(_sherpa_onnx display.cc endpoint.cc features.cc + offline-model-config.cc + offline-paraformer-model-config.cc + offline-recognizer.cc + offline-stream.cc + offline-transducer-model-config.cc online-recognizer.cc online-stream.cc online-transducer-model-config.cc diff --git a/sherpa-onnx/python/csrc/offline-model-config.cc b/sherpa-onnx/python/csrc/offline-model-config.cc new file mode 100644 index 00000000..522e16d6 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-model-config.cc @@ -0,0 +1,36 @@ +// sherpa-onnx/python/csrc/offline-model-config.cc +// +// Copyright (c) 2023 by manyeyes + +#include "sherpa-onnx/python/csrc/offline-model-config.h" + +#include +#include + +#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" +#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" + +#include "sherpa-onnx/csrc/offline-model-config.h" + +namespace sherpa_onnx { + +void PybindOfflineModelConfig(py::module *m) { + PybindOfflineTransducerModelConfig(m); + PybindOfflineParaformerModelConfig(m); + + using PyClass = OfflineModelConfig; + py::class_(*m, "OfflineModelConfig") + .def(py::init(), + py::arg("transducer"), py::arg("paraformer"), py::arg("tokens"), + py::arg("num_threads"), py::arg("debug") = false) + .def_readwrite("transducer", &PyClass::transducer) + .def_readwrite("paraformer", &PyClass::paraformer) + .def_readwrite("tokens", &PyClass::tokens) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-model-config.h b/sherpa-onnx/python/csrc/offline-model-config.h new file mode 100644 index 
00000000..61fab88e --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-model-config.h +// +// Copyright (c) 2023 by manyeyes + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/offline-paraformer-model-config.cc b/sherpa-onnx/python/csrc/offline-paraformer-model-config.cc new file mode 100644 index 00000000..ca5c7d24 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-paraformer-model-config.cc @@ -0,0 +1,24 @@ +// sherpa-onnx/python/csrc/offline-paraformer-model-config.cc +// +// Copyright (c) 2023 by manyeyes + +#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" + + +#include +#include + +#include "sherpa-onnx/csrc/offline-paraformer-model-config.h" + +namespace sherpa_onnx { + +void PybindOfflineParaformerModelConfig(py::module *m) { + using PyClass = OfflineParaformerModelConfig; + py::class_(*m, "OfflineParaformerModelConfig") + .def(py::init(), + py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-paraformer-model-config.h b/sherpa-onnx/python/csrc/offline-paraformer-model-config.h new file mode 100644 index 00000000..d4862fa3 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-paraformer-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-paraformer-model-config.h +// +// Copyright (c) 2023 by manyeyes + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void 
PybindOfflineParaformerModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/offline-recognizer.cc b/sherpa-onnx/python/csrc/offline-recognizer.cc new file mode 100644 index 00000000..7365acf1 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-recognizer.cc @@ -0,0 +1,43 @@ +// sherpa-onnx/python/csrc/offline-recognizer.cc +// +// Copyright (c) 2023 by manyeyes + +#include "sherpa-onnx/python/csrc/offline-recognizer.h" + +#include +#include + +#include "sherpa-onnx/csrc/offline-recognizer.h" + +namespace sherpa_onnx { + + + +static void PybindOfflineRecognizerConfig(py::module *m) { + using PyClass = OfflineRecognizerConfig; + py::class_(*m, "OfflineRecognizerConfig") + .def(py::init(), + py::arg("feat_config"), py::arg("model_config"), + py::arg("decoding_method")) + .def_readwrite("feat_config", &PyClass::feat_config) + .def_readwrite("model_config", &PyClass::model_config) + .def_readwrite("decoding_method", &PyClass::decoding_method) + .def("__str__", &PyClass::ToString); +} + +void PybindOfflineRecognizer(py::module *m) { + PybindOfflineRecognizerConfig(m); + + using PyClass = OfflineRecognizer; + py::class_(*m, "OfflineRecognizer") + .def(py::init(), py::arg("config")) + .def("create_stream", &PyClass::CreateStream) + .def("decode_stream", &PyClass::DecodeStream) + .def("decode_streams", + [](PyClass &self, std::vector ss) { + self.DecodeStreams(ss.data(), ss.size()); + }); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-recognizer.h b/sherpa-onnx/python/csrc/offline-recognizer.h new file mode 100644 index 00000000..a9ac019c --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-recognizer.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-recognizer.h +// +// Copyright (c) 2023 by manyeyes + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_RECOGNIZER_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_RECOGNIZER_H_ + +#include 
"sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineRecognizer(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_RECOGNIZER_H_ diff --git a/sherpa-onnx/python/csrc/offline-stream.cc b/sherpa-onnx/python/csrc/offline-stream.cc new file mode 100644 index 00000000..be989aca --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-stream.cc @@ -0,0 +1,61 @@ +// sherpa-onnx/python/csrc/offline-stream.cc +// +// Copyright (c) 2023 by manyeyes + +#include "sherpa-onnx/python/csrc/offline-stream.h" + +#include "sherpa-onnx/csrc/offline-stream.h" + +namespace sherpa_onnx { + +constexpr const char *kAcceptWaveformUsage = R"( +Process audio samples. + +Args: + sample_rate: + Sample rate of the input samples. If it is different from the one + expected by the model, we will do resampling inside. + waveform: + A 1-D float32 tensor containing audio samples. It must be normalized + to the range [-1, 1]. +)"; + +static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT + using PyClass = OfflineRecognitionResult; + py::class_(*m, "OfflineRecognitionResult") + .def_property_readonly("text", + [](const PyClass &self) { return self.text; }) + .def_property_readonly("tokens", + [](const PyClass &self) { return self.tokens; }) + .def_property_readonly( + "timestamps", [](const PyClass &self) { return self.timestamps; }); +} + + +static void PybindOfflineFeatureExtractorConfig(py::module *m) { + using PyClass = OfflineFeatureExtractorConfig; + py::class_(*m, "OfflineFeatureExtractorConfig") + .def(py::init(), py::arg("sampling_rate") = 16000, + py::arg("feature_dim") = 80) + .def_readwrite("sampling_rate", &PyClass::sampling_rate) + .def_readwrite("feature_dim", &PyClass::feature_dim) + .def("__str__", &PyClass::ToString); +} + + +void PybindOfflineStream(py::module *m) { + PybindOfflineFeatureExtractorConfig(m); + PybindOfflineRecognitionResult(m); + + using PyClass = OfflineStream; + py::class_(*m, "OfflineStream") + .def( + 
"accept_waveform", + [](PyClass &self, float sample_rate, py::array_t waveform) { + self.AcceptWaveform(sample_rate, waveform.data(), waveform.size()); + }, + py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage) + .def_property_readonly("result", &PyClass::GetResult); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-stream.h b/sherpa-onnx/python/csrc/offline-stream.h new file mode 100644 index 00000000..367c9e10 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-stream.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-stream.h +// +// Copyright (c) 2023 by manyeyes + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_STREAM_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_STREAM_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineStream(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_STREAM_H_ diff --git a/sherpa-onnx/python/csrc/offline-transducer-model-config.cc b/sherpa-onnx/python/csrc/offline-transducer-model-config.cc new file mode 100644 index 00000000..9d599903 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-transducer-model-config.cc @@ -0,0 +1,28 @@ +// sherpa-onnx/python/csrc/offline-transducer-model-config.cc +// +// Copyright (c) 2023 by manyeyes + +#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" + + +#include +#include + +#include "sherpa-onnx/csrc/offline-transducer-model-config.h" + +namespace sherpa_onnx { + +void PybindOfflineTransducerModelConfig(py::module *m) { + using PyClass = OfflineTransducerModelConfig; + py::class_(*m, "OfflineTransducerModelConfig") + .def(py::init(), + py::arg("encoder_filename"), py::arg("decoder_filename"), + py::arg("joiner_filename")) + .def_readwrite("encoder_filename", &PyClass::encoder_filename) + .def_readwrite("decoder_filename", &PyClass::decoder_filename) + .def_readwrite("joiner_filename", &PyClass::joiner_filename) + .def("__str__", &PyClass::ToString); +} + +} // namespace 
sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-transducer-model-config.h b/sherpa-onnx/python/csrc/offline-transducer-model-config.h new file mode 100644 index 00000000..ce8d333c --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-transducer-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-transducer-model-config.h +// +// Copyright (c) 2023 by manyeyes + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineTransducerModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 5e5886d4..b235d47f 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -11,10 +11,17 @@ #include "sherpa-onnx/python/csrc/online-stream.h" #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" +#include "sherpa-onnx/python/csrc/offline-model-config.h" +#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" +#include "sherpa-onnx/python/csrc/offline-recognizer.h" +#include "sherpa-onnx/python/csrc/offline-stream.h" +#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" + namespace sherpa_onnx { PYBIND11_MODULE(_sherpa_onnx, m) { m.doc() = "pybind11 binding of sherpa-onnx"; + PybindFeatures(&m); PybindOnlineTransducerModelConfig(&m); PybindOnlineStream(&m); @@ -22,6 +29,10 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindOnlineRecognizer(&m); PybindDisplay(&m); + + PybindOfflineStream(&m); + PybindOfflineModelConfig(&m); + PybindOfflineRecognizer(&m); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index 13a98c53..80938304 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ 
b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -1,3 +1,4 @@ from _sherpa_onnx import Display from .online_recognizer import OnlineRecognizer +from .offline_recognizer import OfflineRecognizer diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py new file mode 100644 index 00000000..f84ad977 --- /dev/null +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -0,0 +1,167 @@
+# Copyright (c) 2023 by manyeyes
+from pathlib import Path
+from typing import List
+
+from _sherpa_onnx import (
+    OfflineFeatureExtractorConfig,
+    OfflineRecognizer as _Recognizer,
+    OfflineRecognizerConfig,
+    OfflineStream,
+    OfflineModelConfig,
+    OfflineTransducerModelConfig,
+    OfflineParaformerModelConfig,
+)
+
+
+def _assert_file_exists(f: str):
+    assert Path(f).is_file(), f"{f} does not exist"
+
+
+class OfflineRecognizer(object):
+    """A class for offline speech recognition."""
+
+    @classmethod
+    def from_transducer(
+        cls,
+        encoder: str,
+        decoder: str,
+        joiner: str,
+        tokens: str,
+        num_threads: int,
+        sample_rate: int = 16000,
+        feature_dim: int = 80,
+        decoding_method: str = "greedy_search",
+        debug: bool = False,
+    ) -> "OfflineRecognizer":
+        """
+        Please refer to
+        `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
+        to download pre-trained models for different languages, e.g., Chinese,
+        English, etc.
+
+        Args:
+          tokens:
+            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
+            columns::
+
+                symbol integer_id
+
+          encoder:
+            Path to ``encoder.onnx``.
+          decoder:
+            Path to ``decoder.onnx``.
+          joiner:
+            Path to ``joiner.onnx``.
+          num_threads:
+            Number of threads for neural network computation.
+          sample_rate:
+            Sample rate of the training data used to train the model.
+          feature_dim:
+            Dimension of the feature used to train the model.
+          decoding_method:
+            Valid values are greedy_search, modified_beam_search.
+          debug:
+            True to show debug messages.
+        """
+        self = cls.__new__(cls)  # build the instance without calling __init__
+        model_config = OfflineModelConfig(
+            transducer=OfflineTransducerModelConfig(
+                encoder_filename=encoder,
+                decoder_filename=decoder,
+                joiner_filename=joiner
+            ),
+            paraformer=OfflineParaformerModelConfig(
+                model=""  # left empty: this factory configures a transducer
+            ),
+            tokens=tokens,
+            num_threads=num_threads,
+            debug=debug
+        )
+
+        feat_config = OfflineFeatureExtractorConfig(
+            sampling_rate=sample_rate,
+            feature_dim=feature_dim,
+        )
+
+        recognizer_config = OfflineRecognizerConfig(
+            feat_config=feat_config,
+            model_config=model_config,
+            decoding_method=decoding_method,
+        )
+        self.recognizer = _Recognizer(recognizer_config)
+        return self
+
+    @classmethod
+    def from_paraformer(
+        cls,
+        paraformer: str,
+        tokens: str,
+        num_threads: int,
+        sample_rate: int = 16000,
+        feature_dim: int = 80,
+        decoding_method: str = "greedy_search",
+        debug: bool = False,
+    ) -> "OfflineRecognizer":
+        """
+        Please refer to
+        `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
+        to download pre-trained models for different languages, e.g., Chinese,
+        English, etc.
+
+        Args:
+          tokens:
+            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
+            columns::
+
+                symbol integer_id
+
+          paraformer:
+            Path to ``paraformer.onnx``.
+          num_threads:
+            Number of threads for neural network computation.
+          sample_rate:
+            Sample rate of the training data used to train the model.
+          feature_dim:
+            Dimension of the feature used to train the model.
+          decoding_method:
+            Valid values are greedy_search, modified_beam_search.
+          debug:
+            True to show debug messages.
+        """
+        self = cls.__new__(cls)  # build the instance without calling __init__
+        model_config = OfflineModelConfig(
+            transducer=OfflineTransducerModelConfig(
+                encoder_filename="",  # left empty: this factory configures paraformer
+                decoder_filename="",
+                joiner_filename=""
+            ),
+            paraformer=OfflineParaformerModelConfig(
+                model=paraformer
+            ),
+            tokens=tokens,
+            num_threads=num_threads,
+            debug=debug
+        )
+
+        feat_config = OfflineFeatureExtractorConfig(
+            sampling_rate=sample_rate,
+            feature_dim=feature_dim,
+        )
+
+        recognizer_config = OfflineRecognizerConfig(
+            feat_config=feat_config,
+            model_config=model_config,
+            decoding_method=decoding_method,
+        )
+        self.recognizer = _Recognizer(recognizer_config)
+        return self
+
+    def create_stream(self) -> OfflineStream:
+        return self.recognizer.create_stream()
+
+    def decode_stream(self, s: OfflineStream) -> None:
+        self.recognizer.decode_stream(s)
+
+    def decode_streams(self, ss: List[OfflineStream]) -> None:
+        self.recognizer.decode_streams(ss)
+