From ea09d5fbc5888050c672f0746aa464a8e52b6d1a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 19 Feb 2023 19:36:03 +0800 Subject: [PATCH] Add Python API (#31) --- .flake8 | 8 ++ .github/scripts/test-python.sh | 33 +++++ .github/workflows/macos.yaml | 2 +- .github/workflows/run-python-test.yaml | 62 +++++++++ .gitignore | 3 + CMakeLists.txt | 31 ++++- cmake/__init__.py | 0 cmake/cmake_extension.py | 124 ++++++++++++++++++ cmake/kaldi-native-fbank.cmake | 15 ++- cmake/onnxruntime.cmake | 13 ++ cmake/pybind11.cmake | 38 ++++++ python-api-examples/decode-file.py | 73 +++++++++++ setup.py | 75 +++++++++++ sherpa-onnx/CMakeLists.txt | 3 + sherpa-onnx/csrc/CMakeLists.txt | 16 ++- sherpa-onnx/csrc/features.cc | 2 +- sherpa-onnx/csrc/features.h | 2 +- .../csrc/online-lstm-transducer-model.cc | 4 +- .../csrc/online-lstm-transducer-model.h | 4 +- sherpa-onnx/csrc/online-recognizer.cc | 2 +- sherpa-onnx/csrc/online-recognizer.h | 7 + sherpa-onnx/csrc/online-transducer-decoder.h | 2 +- ...online-transducer-greedy-search-decoder.cc | 2 +- .../online-transducer-greedy-search-decoder.h | 2 +- .../csrc/online-transducer-model-config.cc | 2 +- .../csrc/online-transducer-model-config.h | 13 +- sherpa-onnx/csrc/online-transducer-model.cc | 2 +- sherpa-onnx/csrc/online-transducer-model.h | 4 +- sherpa-onnx/csrc/onnx-utils.cc | 2 +- sherpa-onnx/csrc/onnx-utils.h | 2 +- sherpa-onnx/csrc/show-onnx-info.cc | 22 ---- sherpa-onnx/csrc/symbol-table.h | 2 +- sherpa-onnx/csrc/wave-reader.cc | 2 +- sherpa-onnx/csrc/wave-reader.h | 2 +- sherpa-onnx/python/CMakeLists.txt | 5 + sherpa-onnx/python/csrc/CMakeLists.txt | 29 ++++ sherpa-onnx/python/csrc/features.cc | 23 ++++ sherpa-onnx/python/csrc/features.h | 16 +++ sherpa-onnx/python/csrc/online-recognizer.cc | 49 +++++++ sherpa-onnx/python/csrc/online-recognizer.h | 16 +++ sherpa-onnx/python/csrc/online-stream.cc | 21 +++ sherpa-onnx/python/csrc/online-stream.h | 16 +++ .../csrc/online-transducer-model-config.cc | 29 ++++ 
.../csrc/online-transducer-model-config.h | 16 +++ sherpa-onnx/python/csrc/sherpa-onnx.cc | 22 ++++ sherpa-onnx/python/csrc/sherpa-onnx.h | 14 ++ sherpa-onnx/python/sherpa_onnx/__init__.py | 8 ++ .../python/sherpa_onnx/online_recognizer.py | 96 ++++++++++++++ sherpa-onnx/python/tests/CMakeLists.txt | 27 ++++ .../tests/test_feature_extractor_config.py | 29 ++++ .../test_online_transducer_model_config.py | 32 +++++ 51 files changed, 967 insertions(+), 57 deletions(-) create mode 100644 .flake8 create mode 100755 .github/scripts/test-python.sh create mode 100644 .github/workflows/run-python-test.yaml create mode 100644 cmake/__init__.py create mode 100644 cmake/cmake_extension.py create mode 100644 cmake/pybind11.cmake create mode 100644 python-api-examples/decode-file.py create mode 100644 setup.py delete mode 100644 sherpa-onnx/csrc/show-onnx-info.cc create mode 100644 sherpa-onnx/python/CMakeLists.txt create mode 100644 sherpa-onnx/python/csrc/CMakeLists.txt create mode 100644 sherpa-onnx/python/csrc/features.cc create mode 100644 sherpa-onnx/python/csrc/features.h create mode 100644 sherpa-onnx/python/csrc/online-recognizer.cc create mode 100644 sherpa-onnx/python/csrc/online-recognizer.h create mode 100644 sherpa-onnx/python/csrc/online-stream.cc create mode 100644 sherpa-onnx/python/csrc/online-stream.h create mode 100644 sherpa-onnx/python/csrc/online-transducer-model-config.cc create mode 100644 sherpa-onnx/python/csrc/online-transducer-model-config.h create mode 100644 sherpa-onnx/python/csrc/sherpa-onnx.cc create mode 100644 sherpa-onnx/python/csrc/sherpa-onnx.h create mode 100644 sherpa-onnx/python/sherpa_onnx/__init__.py create mode 100644 sherpa-onnx/python/sherpa_onnx/online_recognizer.py create mode 100644 sherpa-onnx/python/tests/CMakeLists.txt create mode 100644 sherpa-onnx/python/tests/test_feature_extractor_config.py create mode 100644 sherpa-onnx/python/tests/test_online_transducer_model_config.py diff --git a/.flake8 b/.flake8 new file mode 100644 
index 00000000..87510d70 --- /dev/null +++ b/.flake8 @@ -0,0 +1,8 @@ +[flake8] +show-source=true +statistics=true +max-line-length = 80 + +exclude = + .git, + ./cmake, diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh new file mode 100755 index 00000000..c5d9accb --- /dev/null +++ b/.github/scripts/test-python.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + + +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-lstm-en-2023-02-17 + +log "Start testing ${repo_url}" +repo=$(basename $repo_url) +log "Download pretrained model and test-data from $repo_url" + +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +pushd $repo +git lfs pull --include "*.onnx" +popd + +python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)" +sherpa_onnx_version=$(python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)") + +echo "sherpa_onnx version: $sherpa_onnx_version" + +pwd +ls -lh + +ls -lh $repo + +python3 python-api-examples/decode-file.py diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index d6897552..7f53317d 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -47,7 +47,7 @@ jobs: cd build cmake -D CMAKE_BUILD_TYPE=Release .. 
- - name: Build sherpa for macos + - name: Build sherpa-onnx for macos shell: bash run: | cd build diff --git a/.github/workflows/run-python-test.yaml b/.github/workflows/run-python-test.yaml new file mode 100644 index 00000000..07fd6f14 --- /dev/null +++ b/.github/workflows/run-python-test.yaml @@ -0,0 +1,62 @@ +name: run-python-test + +on: + push: + branches: + - master + paths: + - '.github/workflows/run-python-test.yaml' + - '.github/scripts/test-python.sh' + - 'CMakeLists.txt' + - 'cmake/**' + - 'sherpa-onnx/csrc/*' + pull_request: + branches: + - master + paths: + - '.github/workflows/run-python-test.yaml' + - '.github/scripts/test-python.sh' + - 'CMakeLists.txt' + - 'cmake/**' + - 'sherpa-onnx/csrc/*' + +concurrency: + group: run-python-test-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + run-python-test: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.7", "3.8", "3.9", "3.10"] + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Python dependencies + shell: bash + run: | + python3 -m pip install --upgrade pip numpy + + - name: Install sherpa-onnx + shell: bash + run: | + python3 setup.py install + + - name: Test sherpa-onnx + shell: bash + run: | + .github/scripts/test-python.sh diff --git a/.gitignore b/.gitignore index 066c5948..0c1f4139 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ onnxruntime-* icefall-* run.sh sherpa-onnx-* +__pycache__ +dist/ +sherpa_onnx.egg-info/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 266e3532..27e43708 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,19 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.0") +set(SHERPA_ONNX_VERSION "1.1") + +# Disable warning about +# +# 
"The DOWNLOAD_EXTRACT_TIMESTAMP option was not given and policy CMP0135 is +# not set. +if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") + cmake_policy(SET CMP0135 NEW) +endif() + +option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF) +option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF) +option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") @@ -20,13 +32,20 @@ endif() set(CMAKE_INSTALL_RPATH ${SHERPA_ONNX_RPATH_ORIGIN}) set(CMAKE_BUILD_RPATH ${SHERPA_ONNX_RPATH_ORIGIN}) -set(BUILD_SHARED_LIBS ON) +if(WIN32 AND BUILD_SHARED_LIBS) + message(STATUS "Set BUILD_SHARED_LIBS to OFF for windows") + set(BUILD_SHARED_LIBS OFF) +endif() if(NOT CMAKE_BUILD_TYPE) message(STATUS "No CMAKE_BUILD_TYPE given, default to Release") set(CMAKE_BUILD_TYPE Release) endif() + message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") +message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}") +message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}") +message(STATUS "SHERPA_ONNX_ENABLE_PYTHON ${SHERPA_ONNX_ENABLE_PYTHON}") set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") set(CMAKE_CXX_EXTENSIONS OFF) @@ -37,4 +56,12 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) include(kaldi-native-fbank) include(onnxruntime) +if(SHERPA_ONNX_ENABLE_PYTHON) + include(pybind11) +endif() + +if(SHERPA_ONNX_ENABLE_TESTS) + enable_testing() +endif() + add_subdirectory(sherpa-onnx) diff --git a/cmake/__init__.py b/cmake/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py new file mode 100644 index 00000000..f5c3ccc0 --- /dev/null +++ b/cmake/cmake_extension.py @@ -0,0 +1,124 @@ +# cmake/cmake_extension.py +# Copyright (c) 2023 Xiaomi Corporation +# +# flake8: noqa + +import os +import platform +import sys +from pathlib import Path + +import setuptools 
+from setuptools.command.build_ext import build_ext + + +def is_for_pypi(): + ans = os.environ.get("SHERPA_ONNX_IS_FOR_PYPI", None) + return ans is not None + + +def is_macos(): + return platform.system() == "Darwin" + + +def is_windows(): + return platform.system() == "Windows" + + +try: + from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + + class bdist_wheel(_bdist_wheel): + def finalize_options(self): + _bdist_wheel.finalize_options(self) + # In this case, the generated wheel has a name in the form + # sherpa-xxx-pyxx-none-any.whl + if is_for_pypi() and not is_macos(): + self.root_is_pure = True + else: + # The generated wheel has a name ending with + # -linux_x86_64.whl + self.root_is_pure = False + +except ImportError: + bdist_wheel = None + + +def cmake_extension(name, *args, **kwargs) -> setuptools.Extension: + kwargs["language"] = "c++" + sources = [] + return setuptools.Extension(name, sources, *args, **kwargs) + + +class BuildExtension(build_ext): + def build_extension(self, ext: setuptools.extension.Extension): + # build/temp.linux-x86_64-3.8 + os.makedirs(self.build_temp, exist_ok=True) + + # build/lib.linux-x86_64-3.8 + os.makedirs(self.build_lib, exist_ok=True) + + install_dir = Path(self.build_lib).resolve() / "sherpa_onnx" + + sherpa_onnx_dir = Path(__file__).parent.parent.resolve() + + cmake_args = os.environ.get("SHERPA_ONNX_CMAKE_ARGS", "") + make_args = os.environ.get("SHERPA_ONNX_MAKE_ARGS", "") + system_make_args = os.environ.get("MAKEFLAGS", "") + + if cmake_args == "": + cmake_args = "-DCMAKE_BUILD_TYPE=Release" + + extra_cmake_args = f" -DCMAKE_INSTALL_PREFIX={install_dir} " + if not is_windows(): + extra_cmake_args += " -DBUILD_SHARED_LIBS=ON " + else: + extra_cmake_args += " -DBUILD_SHARED_LIBS=OFF " + extra_cmake_args += " -DSHERPA_ONNX_ENABLE_PYTHON=ON " + + if "PYTHON_EXECUTABLE" not in cmake_args: + print(f"Setting PYTHON_EXECUTABLE to {sys.executable}") + cmake_args += f" -DPYTHON_EXECUTABLE={sys.executable}" + + cmake_args 
+= extra_cmake_args + + if is_windows(): + build_cmd = f""" + cmake {cmake_args} -B {self.build_temp} -S {sherpa_onnx_dir} + cmake --build {self.build_temp} --target install --config Release -- -m + """ + print(f"build command is:\n{build_cmd}") + ret = os.system( + f"cmake {cmake_args} -B {self.build_temp} -S {sherpa_onnx_dir}" + ) + if ret != 0: + raise Exception("Failed to configure sherpa") + + ret = os.system( + f"cmake --build {self.build_temp} --target install --config Release -- -m" # noqa + ) + if ret != 0: + raise Exception("Failed to build and install sherpa") + else: + if make_args == "" and system_make_args == "": + print("for fast compilation, run:") + print('export SHERPA_ONNX_MAKE_ARGS="-j"; python setup.py install') + print('Setting make_args to "-j4"') + make_args = "-j4" + + build_cmd = f""" + cd {self.build_temp} + + cmake {cmake_args} {sherpa_onnx_dir} + + make {make_args} install/strip + """ + print(f"build command is:\n{build_cmd}") + + ret = os.system(build_cmd) + if ret != 0: + raise Exception( + "\nBuild sherpa-onnx failed. 
Please check the error message.\n" + "You can ask for help by creating an issue on GitHub.\n" + "\nClick:\n\thttps://github.com/k2-fsa/sherpa-onnx/issues/new\n" # noqa + ) diff --git a/cmake/kaldi-native-fbank.cmake b/cmake/kaldi-native-fbank.cmake index 3450a9f1..5faa0a59 100644 --- a/cmake/kaldi-native-fbank.cmake +++ b/cmake/kaldi-native-fbank.cmake @@ -1,8 +1,8 @@ function(download_kaldi_native_fbank) include(FetchContent) - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.12.tar.gz") - set(kaldi_native_fbank_HASH "SHA256=8f4dfc3f6ddb1adcd9ac0ae87743ebc6cbcae147aacf9d46e76fa54134e12b44") + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.13.tar.gz") + set(kaldi_native_fbank_HASH "SHA256=1f4d228f9fe3e3e9f92a74a7eecd2489071a03982e4ba6d7c70fc5fa7444df57") set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) @@ -11,11 +11,11 @@ function(download_kaldi_native_fbank) # If you don't have access to the Internet, # please pre-download kaldi-native-fbank set(possible_file_locations - $ENV{HOME}/Downloads/kaldi-native-fbank-1.12.tar.gz - ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.12.tar.gz - ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.12.tar.gz - /tmp/kaldi-native-fbank-1.12.tar.gz - /star-fj/fangjun/download/github/kaldi-native-fbank-1.12.tar.gz + $ENV{HOME}/Downloads/kaldi-native-fbank-1.13.tar.gz + ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.13.tar.gz + ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.13.tar.gz + /tmp/kaldi-native-fbank-1.13.tar.gz + /star-fj/fangjun/download/github/kaldi-native-fbank-1.13.tar.gz ) foreach(f IN LISTS possible_file_locations) @@ -44,6 +44,7 @@ function(download_kaldi_native_fbank) INTERFACE ${kaldi_native_fbank_SOURCE_DIR}/ ) + install(TARGETS kaldi-native-fbank-core DESTINATION lib) endfunction() download_kaldi_native_fbank() diff --git a/cmake/onnxruntime.cmake 
b/cmake/onnxruntime.cmake index 9325b536..a62e42a1 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -85,6 +85,7 @@ function(download_onnxruntime) message(STATUS "location_onnxruntime: ${location_onnxruntime}") add_library(onnxruntime SHARED IMPORTED) + set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${location_onnxruntime} INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include" @@ -100,6 +101,18 @@ function(download_onnxruntime) ${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE} ) endif() + + + if(UNIX AND NOT APPLE) + file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/lib*") + elseif(APPLE) + file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/lib*dylib") + elseif(WIN32) + file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/*.dll") + endif() + + message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}") + install(FILES ${onnxruntime_lib_files} DESTINATION lib) endfunction() download_onnxruntime() diff --git a/cmake/pybind11.cmake b/cmake/pybind11.cmake new file mode 100644 index 00000000..a6e9c449 --- /dev/null +++ b/cmake/pybind11.cmake @@ -0,0 +1,38 @@ +function(download_pybind11) + include(FetchContent) + + set(pybind11_URL "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.2.tar.gz") + set(pybind11_HASH "SHA256=93bd1e625e43e03028a3ea7389bba5d3f9f2596abc074b068e70f4ef9b1314ae") + + # If you don't have access to the Internet, + # please pre-download pybind11 + set(possible_file_locations + $ENV{HOME}/Downloads/pybind11-2.10.2.tar.gz + ${PROJECT_SOURCE_DIR}/pybind11-2.10.2.tar.gz + ${PROJECT_BINARY_DIR}/pybind11-2.10.2.tar.gz + /tmp/pybind11-2.10.2.tar.gz + /star-fj/fangjun/download/github/pybind11-2.10.2.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(pybind11_URL "file://${f}") + break() + endif() + endforeach() + + FetchContent_Declare(pybind11 + URL ${pybind11_URL} + URL_HASH ${pybind11_HASH} + ) + + FetchContent_GetProperties(pybind11) 
+ if(NOT pybind11_POPULATED) + message(STATUS "Downloading pybind11 from ${pybind11_URL}") + FetchContent_Populate(pybind11) + endif() + message(STATUS "pybind11 is downloaded to ${pybind11_SOURCE_DIR}") + add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR} EXCLUDE_FROM_ALL) +endfunction() + +download_pybind11() diff --git a/python-api-examples/decode-file.py b/python-api-examples/decode-file.py new file mode 100644 index 00000000..5bc2288e --- /dev/null +++ b/python-api-examples/decode-file.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 + +""" +This file demonstrates how to use sherpa-onnx Python API to recognize +a single file. + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/index.html +to install sherpa-onnx and to download the pre-trained models +used in this file. +""" +import wave +import time + +import numpy as np +import sherpa_onnx + + +def main(): + sample_rate = 16000 + num_threads = 4 + recognizer = sherpa_onnx.OnlineRecognizer( + tokens="./sherpa-onnx-lstm-en-2023-02-17/tokens.txt", + encoder="./sherpa-onnx-lstm-en-2023-02-17/encoder-epoch-99-avg-1.onnx", + decoder="./sherpa-onnx-lstm-en-2023-02-17/decoder-epoch-99-avg-1.onnx", + joiner="./sherpa-onnx-lstm-en-2023-02-17/joiner-epoch-99-avg-1.onnx", + num_threads=num_threads, + sample_rate=sample_rate, + feature_dim=80, + ) + filename = "./sherpa-onnx-lstm-en-2023-02-17/test_wavs/1089-134686-0001.wav" + with wave.open(filename) as f: + assert f.getframerate() == sample_rate, f.getframerate() + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + + duration = len(samples_float32) / sample_rate + + start_time = time.time() + print("Started!") + + stream = recognizer.create_stream() + + 
stream.accept_waveform(sample_rate, samples_float32) + + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) + stream.accept_waveform(sample_rate, tail_paddings) + + stream.input_finished() + + while recognizer.is_ready(stream): + recognizer.decode_stream(stream) + + print(recognizer.get_result(stream)) + + print("Done!") + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / duration + print(f"num_threads: {num_threads}") + print(f"Wave duration: {duration:.3f} s") + print(f"Elapsed time: {elapsed_seconds:.3f} s") + print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..c8cf295f --- /dev/null +++ b/setup.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 + +import os +import re +import sys +from pathlib import Path + +import setuptools + +from cmake.cmake_extension import ( + BuildExtension, + bdist_wheel, + cmake_extension, + is_windows, +) + + +def read_long_description(): + with open("README.md", encoding="utf8") as f: + readme = f.read() + return readme + + +def get_package_version(): + with open("CMakeLists.txt") as f: + content = f.read() + + match = re.search(r"set\(SHERPA_ONNX_VERSION (.*)\)", content) + latest_version = match.group(1).strip('"') + return latest_version + + +package_name = "sherpa-onnx" + +with open("sherpa-onnx/python/sherpa_onnx/__init__.py", "a") as f: + f.write(f"__version__ = '{get_package_version()}'\n") + +install_requires = [ + "numpy", +] + +setuptools.setup( + name=package_name, + python_requires=">=3.6", + install_requires=install_requires, + version=get_package_version(), + author="The sherpa-onnx development team", + author_email="dpovey@gmail.com", + package_dir={ + "sherpa_onnx": "sherpa-onnx/python/sherpa_onnx", + }, + packages=["sherpa_onnx"], + url="https://github.com/k2-fsa/sherpa-onnx", + long_description=read_long_description(), + 
long_description_content_type="text/markdown", + ext_modules=[cmake_extension("_sherpa_onnx")], + cmdclass={"build_ext": BuildExtension, "bdist_wheel": bdist_wheel}, + zip_safe=False, + classifiers=[ + "Programming Language :: C++", + "Programming Language :: Python", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + license="Apache licensed, as found in the LICENSE file", +) + +with open("sherpa-onnx/python/sherpa_onnx/__init__.py", "r") as f: + lines = f.readlines() + +with open("sherpa-onnx/python/sherpa_onnx/__init__.py", "w") as f: + for line in lines: + if "__version__" in line: + # skip __version__ = "x.x.x" + continue + f.write(line) diff --git a/sherpa-onnx/CMakeLists.txt b/sherpa-onnx/CMakeLists.txt index 86735ca2..257eaba1 100644 --- a/sherpa-onnx/CMakeLists.txt +++ b/sherpa-onnx/CMakeLists.txt @@ -1 +1,4 @@ add_subdirectory(csrc) +if(SHERPA_ONNX_ENABLE_PYTHON) + add_subdirectory(python) +endif() diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index c89c1832..45c4eef2 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -1,6 +1,6 @@ include_directories(${CMAKE_SOURCE_DIR}) -add_executable(sherpa-onnx +add_library(sherpa-onnx-core features.cc online-lstm-transducer-model.cc online-recognizer.cc @@ -9,15 +9,21 @@ add_executable(sherpa-onnx online-transducer-model-config.cc online-transducer-model.cc onnx-utils.cc - sherpa-onnx.cc symbol-table.cc wave-reader.cc ) -target_link_libraries(sherpa-onnx +target_link_libraries(sherpa-onnx-core onnxruntime kaldi-native-fbank-core ) -add_executable(sherpa-onnx-show-info show-onnx-info.cc) -target_link_libraries(sherpa-onnx-show-info onnxruntime) +add_executable(sherpa-onnx sherpa-onnx.cc) + +target_link_libraries(sherpa-onnx sherpa-onnx-core) +if(NOT WIN32) + target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib") +endif() + +install(TARGETS sherpa-onnx-core DESTINATION lib) +install(TARGETS sherpa-onnx 
DESTINATION bin) diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc index da7074f8..520ddfb7 100644 --- a/sherpa-onnx/csrc/features.cc +++ b/sherpa-onnx/csrc/features.cc @@ -1,4 +1,4 @@ -// sherpa/csrc/features.cc +// sherpa-onnx/csrc/features.cc // // Copyright (c) 2023 Xiaomi Corporation diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h index 59f07188..bbe6e628 100644 --- a/sherpa-onnx/csrc/features.h +++ b/sherpa-onnx/csrc/features.h @@ -1,4 +1,4 @@ -// sherpa/csrc/features.h +// sherpa-onnx/csrc/features.h // // Copyright (c) 2023 Xiaomi Corporation diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.cc b/sherpa-onnx/csrc/online-lstm-transducer-model.cc index 20876791..022a3376 100644 --- a/sherpa-onnx/csrc/online-lstm-transducer-model.cc +++ b/sherpa-onnx/csrc/online-lstm-transducer-model.cc @@ -1,4 +1,4 @@ -// sherpa/csrc/online-lstm-transducer-model.cc +// sherpa-onnx/csrc/online-lstm-transducer-model.cc // // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" @@ -232,7 +232,7 @@ std::vector OnlineLstmTransducerModel::GetEncoderInitStates() { std::pair> OnlineLstmTransducerModel::RunEncoder(Ort::Value features, - std::vector &states) { + std::vector states) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.h b/sherpa-onnx/csrc/online-lstm-transducer-model.h index 5fc23260..c24bfca4 100644 --- a/sherpa-onnx/csrc/online-lstm-transducer-model.h +++ b/sherpa-onnx/csrc/online-lstm-transducer-model.h @@ -1,4 +1,4 @@ -// sherpa/csrc/online-lstm-transducer-model.h +// sherpa-onnx/csrc/online-lstm-transducer-model.h // // Copyright (c) 2023 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_ @@ -28,7 +28,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel { std::vector GetEncoderInitStates() override; std::pair> RunEncoder( - Ort::Value 
features, std::vector &states) override; + Ort::Value features, std::vector states) override; Ort::Value BuildDecoderInput( const std::vector &results) override; diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 2f0fdbf4..29aeca16 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -98,7 +98,7 @@ class OnlineRecognizer::Impl { auto states = model_->StackStates(states_vec); - auto pair = model_->RunEncoder(std::move(x), states); + auto pair = model_->RunEncoder(std::move(x), std::move(states)); decoder_->Decode(std::move(pair.first), &results); diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index f9622452..0d85d38c 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -23,6 +23,13 @@ struct OnlineRecognizerConfig { OnlineTransducerModelConfig model_config; std::string tokens; + OnlineRecognizerConfig() = default; + + OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, + const OnlineTransducerModelConfig &model_config, + const std::string &tokens) + : feat_config(feat_config), model_config(model_config), tokens(tokens) {} + std::string ToString() const; }; diff --git a/sherpa-onnx/csrc/online-transducer-decoder.h b/sherpa-onnx/csrc/online-transducer-decoder.h index 1c72fd1b..92f4eeaa 100644 --- a/sherpa-onnx/csrc/online-transducer-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-decoder.h @@ -1,4 +1,4 @@ -// sherpa/csrc/online-transducer-decoder.h +// sherpa-onnx/csrc/online-transducer-decoder.h // // Copyright (c) 2023 Xiaomi Corporation diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index e628cc7c..1776805c 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -1,4 +1,4 @@ -// 
sherpa/csrc/online-transducer-greedy-search-decoder.cc +// sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc // // Copyright (c) 2023 Xiaomi Corporation diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h index 23b507f2..f7fa7ddf 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h @@ -1,4 +1,4 @@ -// sherpa/csrc/online-transducer-greedy-search-decoder.h +// sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h // // Copyright (c) 2023 Xiaomi Corporation diff --git a/sherpa-onnx/csrc/online-transducer-model-config.cc b/sherpa-onnx/csrc/online-transducer-model-config.cc index 5fcb09ee..c3ed6a54 100644 --- a/sherpa-onnx/csrc/online-transducer-model-config.cc +++ b/sherpa-onnx/csrc/online-transducer-model-config.cc @@ -1,4 +1,4 @@ -// sherpa/csrc/online-transducer-model-config.cc +// sherpa-onnx/csrc/online-transducer-model-config.cc // // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/online-transducer-model-config.h" diff --git a/sherpa-onnx/csrc/online-transducer-model-config.h b/sherpa-onnx/csrc/online-transducer-model-config.h index ca2e5dbc..34af8547 100644 --- a/sherpa-onnx/csrc/online-transducer-model-config.h +++ b/sherpa-onnx/csrc/online-transducer-model-config.h @@ -1,4 +1,4 @@ -// sherpa/csrc/online-transducer-model-config.h +// sherpa-onnx/csrc/online-transducer-model-config.h // // Copyright (c) 2023 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ @@ -15,6 +15,17 @@ struct OnlineTransducerModelConfig { int32_t num_threads; bool debug = false; + OnlineTransducerModelConfig() = default; + OnlineTransducerModelConfig(const std::string &encoder_filename, + const std::string &decoder_filename, + const std::string &joiner_filename, + int32_t num_threads, bool debug) + : encoder_filename(encoder_filename), + 
decoder_filename(decoder_filename), + joiner_filename(joiner_filename), + num_threads(num_threads), + debug(debug) {} + std::string ToString() const; }; diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc index 27af24e7..14eaf16b 100644 --- a/sherpa-onnx/csrc/online-transducer-model.cc +++ b/sherpa-onnx/csrc/online-transducer-model.cc @@ -1,4 +1,4 @@ -// sherpa/csrc/online-transducer-model.cc +// sherpa-onnx/csrc/online-transducer-model.cc // // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/online-transducer-model.h" diff --git a/sherpa-onnx/csrc/online-transducer-model.h b/sherpa-onnx/csrc/online-transducer-model.h index 8f33b818..baed186a 100644 --- a/sherpa-onnx/csrc/online-transducer-model.h +++ b/sherpa-onnx/csrc/online-transducer-model.h @@ -1,4 +1,4 @@ -// sherpa/csrc/online-transducer-model.h +// sherpa-onnx/csrc/online-transducer-model.h // // Copyright (c) 2023 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_ @@ -59,7 +59,7 @@ class OnlineTransducerModel { */ virtual std::pair> RunEncoder( Ort::Value features, - std::vector &states) = 0; // NOLINT + std::vector states) = 0; // NOLINT virtual Ort::Value BuildDecoderInput( const std::vector &results) = 0; diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index 9105fab0..8a9cf301 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -1,4 +1,4 @@ -// sherpa/csrc/onnx-utils.cc +// sherpa-onnx/csrc/onnx-utils.cc // // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/onnx-utils.h" diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h index 2307722b..38a7b143 100644 --- a/sherpa-onnx/csrc/onnx-utils.h +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -1,4 +1,4 @@ -// sherpa/csrc/onnx-utils.h +// sherpa-onnx/csrc/onnx-utils.h // // Copyright (c) 2023 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_ diff --git 
a/sherpa-onnx/csrc/show-onnx-info.cc b/sherpa-onnx/csrc/show-onnx-info.cc deleted file mode 100644 index ef2766c7..00000000 --- a/sherpa-onnx/csrc/show-onnx-info.cc +++ /dev/null @@ -1,22 +0,0 @@ -// sherpa-onnx/csrc/show-onnx-info.cc -// -// Copyright (c) 2022-2023 Xiaomi Corporation - -#include -#include - -#include "onnxruntime_cxx_api.h" // NOLINT - -int main() { - std::cout << "ORT_API_VERSION: " << ORT_API_VERSION << "\n"; - std::vector providers = Ort::GetAvailableProviders(); - std::ostringstream os; - os << "Available providers: "; - std::string sep = ""; - for (const auto &p : providers) { - os << sep << p; - sep = ", "; - } - std::cout << os.str() << "\n"; - return 0; -} diff --git a/sherpa-onnx/csrc/symbol-table.h b/sherpa-onnx/csrc/symbol-table.h index fdcde41e..0e1b74a9 100644 --- a/sherpa-onnx/csrc/symbol-table.h +++ b/sherpa-onnx/csrc/symbol-table.h @@ -1,4 +1,4 @@ -// sherpa-onnx/csrc/symbol-table.cc +// sherpa-onnx/csrc/symbol-table.h // // Copyright (c) 2022-2023 Xiaomi Corporation diff --git a/sherpa-onnx/csrc/wave-reader.cc b/sherpa-onnx/csrc/wave-reader.cc index cdc80f81..2223641d 100644 --- a/sherpa-onnx/csrc/wave-reader.cc +++ b/sherpa-onnx/csrc/wave-reader.cc @@ -1,4 +1,4 @@ -// sherpa/csrc/wave-reader.cc +// sherpa-onnx/csrc/wave-reader.cc // // Copyright (c) 2023 Xiaomi Corporation diff --git a/sherpa-onnx/csrc/wave-reader.h b/sherpa-onnx/csrc/wave-reader.h index fb5c68c1..dfec9807 100644 --- a/sherpa-onnx/csrc/wave-reader.h +++ b/sherpa-onnx/csrc/wave-reader.h @@ -1,4 +1,4 @@ -// sherpa/csrc/wave-reader.h +// sherpa-onnx/csrc/wave-reader.h // // Copyright (c) 2023 Xiaomi Corporation diff --git a/sherpa-onnx/python/CMakeLists.txt b/sherpa-onnx/python/CMakeLists.txt new file mode 100644 index 00000000..f433c471 --- /dev/null +++ b/sherpa-onnx/python/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(csrc) + +if(SHERPA_ONNX_ENABLE_TESTS) + add_subdirectory(tests) +endif() diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt 
b/sherpa-onnx/python/csrc/CMakeLists.txt new file mode 100644 index 00000000..e2efa3f8 --- /dev/null +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -0,0 +1,29 @@ +include_directories(${CMAKE_SOURCE_DIR}) + +pybind11_add_module(_sherpa_onnx + features.cc + online-transducer-model-config.cc + sherpa-onnx.cc + online-stream.cc + online-recognizer.cc +) + +if(APPLE) + execute_process( + COMMAND "${PYTHON_EXECUTABLE}" -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())" + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE PYTHON_SITE_PACKAGE_DIR + ) + message(STATUS "PYTHON_SITE_PACKAGE_DIR: ${PYTHON_SITE_PACKAGE_DIR}") + target_link_libraries(_sherpa_onnx PRIVATE "-Wl,-rpath,${PYTHON_SITE_PACKAGE_DIR}") +endif() + +if(NOT WIN32) + target_link_libraries(_sherpa_onnx PRIVATE "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/sherpa_onnx/lib") +endif() + +target_link_libraries(_sherpa_onnx PRIVATE sherpa-onnx-core) + +install(TARGETS _sherpa_onnx + DESTINATION ../ +) diff --git a/sherpa-onnx/python/csrc/features.cc b/sherpa-onnx/python/csrc/features.cc new file mode 100644 index 00000000..6458f5cc --- /dev/null +++ b/sherpa-onnx/python/csrc/features.cc @@ -0,0 +1,23 @@ +// sherpa-onnx/python/csrc/features.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/features.h" + +#include "sherpa-onnx/csrc/features.h" + +namespace sherpa_onnx { + +static void PybindFeatureExtractorConfig(py::module *m) { + using PyClass = FeatureExtractorConfig; + py::class_(*m, "FeatureExtractorConfig") + .def(py::init(), py::arg("sampling_rate") = 16000, + py::arg("feature_dim") = 80) + .def_readwrite("sampling_rate", &PyClass::sampling_rate) + .def_readwrite("feature_dim", &PyClass::feature_dim) + .def("__str__", &PyClass::ToString); +} + +void PybindFeatures(py::module *m) { PybindFeatureExtractorConfig(m); } + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/features.h b/sherpa-onnx/python/csrc/features.h new file mode 100644 
index 00000000..2e599d79 --- /dev/null +++ b/sherpa-onnx/python/csrc/features.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/features.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_FEATURES_H_ +#define SHERPA_ONNX_PYTHON_CSRC_FEATURES_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindFeatures(py::module *m); + +} + +#endif  // SHERPA_ONNX_PYTHON_CSRC_FEATURES_H_ diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc new file mode 100644 index 00000000..52c74f23 --- /dev/null +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -0,0 +1,49 @@ +// sherpa-onnx/python/csrc/online-recognizer.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/online-recognizer.h" + +#include +#include + +#include "sherpa-onnx/csrc/online-recognizer.h" + +namespace sherpa_onnx { + +static void PybindOnlineRecognizerResult(py::module *m) { + using PyClass = OnlineRecognizerResult; + py::class_(*m, "OnlineRecognizerResult") + .def_property_readonly("text", [](PyClass &self) { return self.text; }); +} + +static void PybindOnlineRecognizerConfig(py::module *m) { + using PyClass = OnlineRecognizerConfig; + py::class_(*m, "OnlineRecognizerConfig") + .def(py::init(), + py::arg("feat_config"), py::arg("model_config"), py::arg("tokens")) + .def_readwrite("feat_config", &PyClass::feat_config) + .def_readwrite("model_config", &PyClass::model_config) + .def_readwrite("tokens", &PyClass::tokens) + .def("__str__", &PyClass::ToString); +} + +void PybindOnlineRecognizer(py::module *m) { + PybindOnlineRecognizerResult(m); + PybindOnlineRecognizerConfig(m); + + using PyClass = OnlineRecognizer; + py::class_(*m, "OnlineRecognizer") + .def(py::init(), py::arg("config")) + .def("create_stream", &PyClass::CreateStream) + .def("is_ready", &PyClass::IsReady) + .def("decode_stream", &PyClass::DecodeStream) + .def("decode_streams", + [](PyClass &self, 
std::vector ss) { + self.DecodeStreams(ss.data(), ss.size()); + }) + .def("get_result", &PyClass::GetResult); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/online-recognizer.h b/sherpa-onnx/python/csrc/online-recognizer.h new file mode 100644 index 00000000..0e652c7f --- /dev/null +++ b/sherpa-onnx/python/csrc/online-recognizer.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/online-recognizer.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_RECOGNIZER_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_RECOGNIZER_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOnlineRecognizer(py::module *m); + +} + +#endif  // SHERPA_ONNX_PYTHON_CSRC_ONLINE_RECOGNIZER_H_ diff --git a/sherpa-onnx/python/csrc/online-stream.cc b/sherpa-onnx/python/csrc/online-stream.cc new file mode 100644 index 00000000..06a46a59 --- /dev/null +++ b/sherpa-onnx/python/csrc/online-stream.cc @@ -0,0 +1,21 @@ +// sherpa-onnx/python/csrc/online-stream.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/online-stream.h" + +#include "sherpa-onnx/csrc/online-stream.h" + +namespace sherpa_onnx { + +void PybindOnlineStream(py::module *m) { + using PyClass = OnlineStream; + py::class_(*m, "OnlineStream") + .def("accept_waveform", + [](PyClass &self, float sample_rate, py::array_t waveform) { + self.AcceptWaveform(sample_rate, waveform.data(), waveform.size()); + }) + .def("input_finished", &PyClass::InputFinished); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/online-stream.h b/sherpa-onnx/python/csrc/online-stream.h new file mode 100644 index 00000000..9d88f0e8 --- /dev/null +++ b/sherpa-onnx/python/csrc/online-stream.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/online-stream.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_STREAM_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_STREAM_H_ + +#include 
"sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOnlineStream(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_STREAM_H_ diff --git a/sherpa-onnx/python/csrc/online-transducer-model-config.cc b/sherpa-onnx/python/csrc/online-transducer-model-config.cc new file mode 100644 index 00000000..16ff3133 --- /dev/null +++ b/sherpa-onnx/python/csrc/online-transducer-model-config.cc @@ -0,0 +1,29 @@ +// sherpa-onnx/python/csrc/online-transducer-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-transducer-model-config.h" + +#include + +#include "sherpa-onnx/python/csrc/online-transducer-model-config.h" + +namespace sherpa_onnx { + +void PybindOnlineTransducerModelConfig(py::module *m) { + using PyClass = OnlineTransducerModelConfig; + py::class_(*m, "OnlineTransducerModelConfig") + .def(py::init(), + py::arg("encoder_filename"), py::arg("decoder_filename"), + py::arg("joiner_filename"), py::arg("num_threads"), + py::arg("debug") = false) + .def_readwrite("encoder_filename", &PyClass::encoder_filename) + .def_readwrite("decoder_filename", &PyClass::decoder_filename) + .def_readwrite("joiner_filename", &PyClass::joiner_filename) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/online-transducer-model-config.h b/sherpa-onnx/python/csrc/online-transducer-model-config.h new file mode 100644 index 00000000..bbe13559 --- /dev/null +++ b/sherpa-onnx/python/csrc/online-transducer-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/online-transducer-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void 
PybindOnlineTransducerModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc new file mode 100644 index 00000000..cca04c09 --- /dev/null +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -0,0 +1,22 @@ +// sherpa-onnx/python/csrc/sherpa-onnx.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +#include "sherpa-onnx/python/csrc/features.h" +#include "sherpa-onnx/python/csrc/online-recognizer.h" +#include "sherpa-onnx/python/csrc/online-stream.h" +#include "sherpa-onnx/python/csrc/online-transducer-model-config.h" + +namespace sherpa_onnx { + +PYBIND11_MODULE(_sherpa_onnx, m) { + m.doc() = "pybind11 binding of sherpa-onnx"; + PybindFeatures(&m); + PybindOnlineTransducerModelConfig(&m); + PybindOnlineStream(&m); + PybindOnlineRecognizer(&m); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.h b/sherpa-onnx/python/csrc/sherpa-onnx.h new file mode 100644 index 00000000..d0aa9f96 --- /dev/null +++ b/sherpa-onnx/python/csrc/sherpa-onnx.h @@ -0,0 +1,14 @@ +// sherpa-onnx/python/csrc/sherpa-onnx.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_SHERPA_ONNX_H_ +#define SHERPA_ONNX_PYTHON_CSRC_SHERPA_ONNX_H_ + +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +#endif // SHERPA_ONNX_PYTHON_CSRC_SHERPA_ONNX_H_ diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py new file mode 100644 index 00000000..60a03cc2 --- /dev/null +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -0,0 +1,8 @@ +from _sherpa_onnx import ( + FeatureExtractorConfig, + OnlineRecognizerConfig, + OnlineStream, + OnlineTransducerModelConfig, +) + +from .online_recognizer import OnlineRecognizer diff --git 
a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py new file mode 100644 index 00000000..90ba196f --- /dev/null +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -0,0 +1,96 @@ +from pathlib import Path +from typing import List + +from _sherpa_onnx import ( + OnlineStream, + OnlineTransducerModelConfig, + FeatureExtractorConfig, + OnlineRecognizerConfig, +) +from _sherpa_onnx import OnlineRecognizer as _Recognizer + + +def _assert_file_exists(f: str): + assert Path(f).is_file(), f"{f} does not exist" + + +class OnlineRecognizer(object): + """A class for streaming speech recognition.""" + + def __init__( + self, + tokens: str, + encoder: str, + decoder: str, + joiner: str, + num_threads: int = 4, + sample_rate: float = 16000, + feature_dim: int = 80, + ): + """ + Please refer to + ``_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + encoder: + Path to ``encoder.onnx``. + decoder: + Path to ``decoder.onnx``. + joiner: + Path to ``joiner.onnx``. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. 
+ """ + _assert_file_exists(tokens) + _assert_file_exists(encoder) + _assert_file_exists(decoder) + _assert_file_exists(joiner) + + assert num_threads > 0, num_threads + + model_config = OnlineTransducerModelConfig( + encoder_filename=encoder, + decoder_filename=decoder, + joiner_filename=joiner, + num_threads=num_threads, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + recognizer_config = OnlineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + tokens=tokens, + ) + + self.recognizer = _Recognizer(recognizer_config) + + def create_stream(self): + return self.recognizer.create_stream() + + def decode_stream(self, s: OnlineStream): + self.recognizer.decode_stream(s) + + def decode_streams(self, ss: List[OnlineStream]): + self.recognizer.decode_streams(ss) + + def is_ready(self, s: OnlineStream) -> bool: + return self.recognizer.is_ready(s) + + def get_result(self, s: OnlineStream) -> str: + return self.recognizer.get_result(s).text diff --git a/sherpa-onnx/python/tests/CMakeLists.txt b/sherpa-onnx/python/tests/CMakeLists.txt new file mode 100644 index 00000000..c53a09f1 --- /dev/null +++ b/sherpa-onnx/python/tests/CMakeLists.txt @@ -0,0 +1,27 @@ +function(sherpa_onnx_add_py_test source) + get_filename_component(name ${source} NAME_WE) + set(name "${name}_py") + + add_test(NAME ${name} + COMMAND + "${PYTHON_EXECUTABLE}" + "${CMAKE_CURRENT_SOURCE_DIR}/${source}" + ) + + get_filename_component(sherpa_onnx_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY) + + set_property(TEST ${name} + PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_path}:$:$ENV{PYTHONPATH}" + ) +endfunction() + +# please sort the files in alphabetic order +set(py_test_files + test_feature_extractor_config.py + test_online_transducer_model_config.py +) + +foreach(source IN LISTS py_test_files) + sherpa_onnx_add_py_test(${source}) +endforeach() + diff --git a/sherpa-onnx/python/tests/test_feature_extractor_config.py 
b/sherpa-onnx/python/tests/test_feature_extractor_config.py new file mode 100644 index 00000000..e12f808a --- /dev/null +++ b/sherpa-onnx/python/tests/test_feature_extractor_config.py @@ -0,0 +1,29 @@ +# sherpa-onnx/python/tests/test_feature_extractor_config.py +# +# Copyright (c) 2023 Xiaomi Corporation +# +# To run this single test, use +# +# ctest --verbose -R test_feature_extractor_config_py + +import unittest + +import sherpa_onnx + + +class TestFeatureExtractorConfig(unittest.TestCase): + def test_default_constructor(self): + config = sherpa_onnx.FeatureExtractorConfig() + assert config.sampling_rate == 16000, config.sampling_rate + assert config.feature_dim == 80, config.feature_dim + print(config) + + def test_constructor(self): + config = sherpa_onnx.FeatureExtractorConfig(sampling_rate=8000, feature_dim=40) + assert config.sampling_rate == 8000, config.sampling_rate + assert config.feature_dim == 40, config.feature_dim + print(config) + + +if __name__ == "__main__": + unittest.main() diff --git a/sherpa-onnx/python/tests/test_online_transducer_model_config.py b/sherpa-onnx/python/tests/test_online_transducer_model_config.py new file mode 100644 index 00000000..1b9010db --- /dev/null +++ b/sherpa-onnx/python/tests/test_online_transducer_model_config.py @@ -0,0 +1,32 @@ +# sherpa-onnx/python/tests/test_online_transducer_model_config.py +# +# Copyright (c) 2023 Xiaomi Corporation +# +# To run this single test, use +# +# ctest --verbose -R test_online_transducer_model_config_py + +import unittest + +import sherpa_onnx + + +class TestOnlineTransducerModelConfig(unittest.TestCase): + def test_constructor(self): + config = sherpa_onnx.OnlineTransducerModelConfig( + encoder_filename="encoder.onnx", + decoder_filename="decoder.onnx", + joiner_filename="joiner.onnx", + num_threads=8, + debug=True, + ) + assert config.encoder_filename == "encoder.onnx", config.encoder_filename + assert config.decoder_filename == "decoder.onnx", config.decoder_filename + assert 
config.joiner_filename == "joiner.onnx", config.joiner_filename + assert config.num_threads == 8, config.num_threads + assert config.debug is True, config.debug + print(config) + + +if __name__ == "__main__": + unittest.main()