Add Python API (#31)
This commit is contained in:
8
.flake8
Normal file
8
.flake8
Normal file
@@ -0,0 +1,8 @@
|
||||
[flake8]
|
||||
show-source=true
|
||||
statistics=true
|
||||
max-line-length = 80
|
||||
|
||||
exclude =
|
||||
.git,
|
||||
./cmake,
|
||||
33
.github/scripts/test-python.sh
vendored
Executable file
33
.github/scripts/test-python.sh
vendored
Executable file
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
|
||||
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-lstm-en-2023-02-17
|
||||
|
||||
log "Start testing ${repo_url}"
|
||||
repo=$(basename $repo_url)
|
||||
log "Download pretrained model and test-data from $repo_url"
|
||||
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
popd
|
||||
|
||||
python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)"
|
||||
sherpa_onnx_version=$(python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)")
|
||||
|
||||
echo "sherpa_onnx version: $sherpa_onnx_version"
|
||||
|
||||
pwd
|
||||
ls -lh
|
||||
|
||||
ls -lh $repo
|
||||
|
||||
python3 python-api-examples/decode-file.py
|
||||
2
.github/workflows/macos.yaml
vendored
2
.github/workflows/macos.yaml
vendored
@@ -47,7 +47,7 @@ jobs:
|
||||
cd build
|
||||
cmake -D CMAKE_BUILD_TYPE=Release ..
|
||||
|
||||
- name: Build sherpa for macos
|
||||
- name: Build sherpa-onnx for macos
|
||||
shell: bash
|
||||
run: |
|
||||
cd build
|
||||
|
||||
62
.github/workflows/run-python-test.yaml
vendored
Normal file
62
.github/workflows/run-python-test.yaml
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
name: run-python-test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths:
|
||||
- '.github/workflows/run-python-test.yaml'
|
||||
- '.github/scripts/test-python.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
paths:
|
||||
- '.github/workflows/run-python-test.yaml'
|
||||
- '.github/scripts/test-python.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
|
||||
concurrency:
|
||||
group: run-python-test-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
run-python-test:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||
python-version: ["3.7", "3.8", "3.9", "3.10"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install Python dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
python3 -m pip install --upgrade pip numpy
|
||||
|
||||
- name: Install sherpa-onnx
|
||||
shell: bash
|
||||
run: |
|
||||
python3 setup.py install
|
||||
|
||||
- name: Test sherpa-onnx
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/test-python.sh
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -5,3 +5,6 @@ onnxruntime-*
|
||||
icefall-*
|
||||
run.sh
|
||||
sherpa-onnx-*
|
||||
__pycache__
|
||||
dist/
|
||||
sherpa_onnx.egg-info/
|
||||
|
||||
@@ -1,7 +1,19 @@
|
||||
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
||||
project(sherpa-onnx)
|
||||
|
||||
set(SHERPA_ONNX_VERSION "1.0")
|
||||
set(SHERPA_ONNX_VERSION "1.1")
|
||||
|
||||
# Disable warning about
|
||||
#
|
||||
# "The DOWNLOAD_EXTRACT_TIMESTAMP option was not given and policy CMP0135 is
|
||||
# not set.
|
||||
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
|
||||
cmake_policy(SET CMP0135 NEW)
|
||||
endif()
|
||||
|
||||
option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF)
|
||||
option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF)
|
||||
option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
|
||||
|
||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||
@@ -20,13 +32,20 @@ endif()
|
||||
set(CMAKE_INSTALL_RPATH ${SHERPA_ONNX_RPATH_ORIGIN})
|
||||
set(CMAKE_BUILD_RPATH ${SHERPA_ONNX_RPATH_ORIGIN})
|
||||
|
||||
set(BUILD_SHARED_LIBS ON)
|
||||
if(WIN32 AND BUILD_SHARED_LIBS)
|
||||
message(STATUS "Set BUILD_SHARED_LIBS to OFF for windows")
|
||||
set(BUILD_SHARED_LIBS OFF)
|
||||
endif()
|
||||
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
message(STATUS "No CMAKE_BUILD_TYPE given, default to Release")
|
||||
set(CMAKE_BUILD_TYPE Release)
|
||||
endif()
|
||||
|
||||
message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
|
||||
message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")
|
||||
message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}")
|
||||
message(STATUS "SHERPA_ONNX_ENABLE_PYTHON ${SHERPA_ONNX_ENABLE_PYTHON}")
|
||||
|
||||
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
@@ -37,4 +56,12 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
|
||||
include(kaldi-native-fbank)
|
||||
include(onnxruntime)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_PYTHON)
|
||||
include(pybind11)
|
||||
endif()
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_TESTS)
|
||||
enable_testing()
|
||||
endif()
|
||||
|
||||
add_subdirectory(sherpa-onnx)
|
||||
|
||||
0
cmake/__init__.py
Normal file
0
cmake/__init__.py
Normal file
124
cmake/cmake_extension.py
Normal file
124
cmake/cmake_extension.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# cmake/cmake_extension.py
|
||||
# Copyright (c) 2023 Xiaomi Corporation
|
||||
#
|
||||
# flake8: noqa
|
||||
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import setuptools
|
||||
from setuptools.command.build_ext import build_ext
|
||||
|
||||
|
||||
def is_for_pypi():
|
||||
ans = os.environ.get("SHERPA_ONNX_IS_FOR_PYPI", None)
|
||||
return ans is not None
|
||||
|
||||
|
||||
def is_macos():
|
||||
return platform.system() == "Darwin"
|
||||
|
||||
|
||||
def is_windows():
|
||||
return platform.system() == "Windows"
|
||||
|
||||
|
||||
try:
|
||||
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
||||
|
||||
class bdist_wheel(_bdist_wheel):
|
||||
def finalize_options(self):
|
||||
_bdist_wheel.finalize_options(self)
|
||||
# In this case, the generated wheel has a name in the form
|
||||
# sherpa-xxx-pyxx-none-any.whl
|
||||
if is_for_pypi() and not is_macos():
|
||||
self.root_is_pure = True
|
||||
else:
|
||||
# The generated wheel has a name ending with
|
||||
# -linux_x86_64.whl
|
||||
self.root_is_pure = False
|
||||
|
||||
except ImportError:
|
||||
bdist_wheel = None
|
||||
|
||||
|
||||
def cmake_extension(name, *args, **kwargs) -> setuptools.Extension:
|
||||
kwargs["language"] = "c++"
|
||||
sources = []
|
||||
return setuptools.Extension(name, sources, *args, **kwargs)
|
||||
|
||||
|
||||
class BuildExtension(build_ext):
|
||||
def build_extension(self, ext: setuptools.extension.Extension):
|
||||
# build/temp.linux-x86_64-3.8
|
||||
os.makedirs(self.build_temp, exist_ok=True)
|
||||
|
||||
# build/lib.linux-x86_64-3.8
|
||||
os.makedirs(self.build_lib, exist_ok=True)
|
||||
|
||||
install_dir = Path(self.build_lib).resolve() / "sherpa_onnx"
|
||||
|
||||
sherpa_onnx_dir = Path(__file__).parent.parent.resolve()
|
||||
|
||||
cmake_args = os.environ.get("SHERPA_ONNX_CMAKE_ARGS", "")
|
||||
make_args = os.environ.get("SHERPA_ONNX_MAKE_ARGS", "")
|
||||
system_make_args = os.environ.get("MAKEFLAGS", "")
|
||||
|
||||
if cmake_args == "":
|
||||
cmake_args = "-DCMAKE_BUILD_TYPE=Release"
|
||||
|
||||
extra_cmake_args = f" -DCMAKE_INSTALL_PREFIX={install_dir} "
|
||||
if not is_windows():
|
||||
extra_cmake_args += " -DBUILD_SHARED_LIBS=ON "
|
||||
else:
|
||||
extra_cmake_args += " -DBUILD_SHARED_LIBS=OFF "
|
||||
extra_cmake_args += " -DSHERPA_ONNX_ENABLE_PYTHON=ON "
|
||||
|
||||
if "PYTHON_EXECUTABLE" not in cmake_args:
|
||||
print(f"Setting PYTHON_EXECUTABLE to {sys.executable}")
|
||||
cmake_args += f" -DPYTHON_EXECUTABLE={sys.executable}"
|
||||
|
||||
cmake_args += extra_cmake_args
|
||||
|
||||
if is_windows():
|
||||
build_cmd = f"""
|
||||
cmake {cmake_args} -B {self.build_temp} -S {sherpa_onnx_dir}
|
||||
cmake --build {self.build_temp} --target install --config Release -- -m
|
||||
"""
|
||||
print(f"build command is:\n{build_cmd}")
|
||||
ret = os.system(
|
||||
f"cmake {cmake_args} -B {self.build_temp} -S {sherpa_onnx_dir}"
|
||||
)
|
||||
if ret != 0:
|
||||
raise Exception("Failed to configure sherpa")
|
||||
|
||||
ret = os.system(
|
||||
f"cmake --build {self.build_temp} --target install --config Release -- -m" # noqa
|
||||
)
|
||||
if ret != 0:
|
||||
raise Exception("Failed to build and install sherpa")
|
||||
else:
|
||||
if make_args == "" and system_make_args == "":
|
||||
print("for fast compilation, run:")
|
||||
print('export SHERPA_ONNX_MAKE_ARGS="-j"; python setup.py install')
|
||||
print('Setting make_args to "-j4"')
|
||||
make_args = "-j4"
|
||||
|
||||
build_cmd = f"""
|
||||
cd {self.build_temp}
|
||||
|
||||
cmake {cmake_args} {sherpa_onnx_dir}
|
||||
|
||||
make {make_args} install/strip
|
||||
"""
|
||||
print(f"build command is:\n{build_cmd}")
|
||||
|
||||
ret = os.system(build_cmd)
|
||||
if ret != 0:
|
||||
raise Exception(
|
||||
"\nBuild sherpa-onnx failed. Please check the error message.\n"
|
||||
"You can ask for help by creating an issue on GitHub.\n"
|
||||
"\nClick:\n\thttps://github.com/k2-fsa/sherpa-onnx/issues/new\n" # noqa
|
||||
)
|
||||
@@ -1,8 +1,8 @@
|
||||
function(download_kaldi_native_fbank)
|
||||
include(FetchContent)
|
||||
|
||||
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.12.tar.gz")
|
||||
set(kaldi_native_fbank_HASH "SHA256=8f4dfc3f6ddb1adcd9ac0ae87743ebc6cbcae147aacf9d46e76fa54134e12b44")
|
||||
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.13.tar.gz")
|
||||
set(kaldi_native_fbank_HASH "SHA256=1f4d228f9fe3e3e9f92a74a7eecd2489071a03982e4ba6d7c70fc5fa7444df57")
|
||||
|
||||
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
||||
@@ -11,11 +11,11 @@ function(download_kaldi_native_fbank)
|
||||
# If you don't have access to the Internet,
|
||||
# please pre-download kaldi-native-fbank
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/kaldi-native-fbank-1.12.tar.gz
|
||||
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.12.tar.gz
|
||||
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.12.tar.gz
|
||||
/tmp/kaldi-native-fbank-1.12.tar.gz
|
||||
/star-fj/fangjun/download/github/kaldi-native-fbank-1.12.tar.gz
|
||||
$ENV{HOME}/Downloads/kaldi-native-fbank-1.13.tar.gz
|
||||
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.13.tar.gz
|
||||
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.13.tar.gz
|
||||
/tmp/kaldi-native-fbank-1.13.tar.gz
|
||||
/star-fj/fangjun/download/github/kaldi-native-fbank-1.13.tar.gz
|
||||
)
|
||||
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
@@ -44,6 +44,7 @@ function(download_kaldi_native_fbank)
|
||||
INTERFACE
|
||||
${kaldi_native_fbank_SOURCE_DIR}/
|
||||
)
|
||||
install(TARGETS kaldi-native-fbank-core DESTINATION lib)
|
||||
endfunction()
|
||||
|
||||
download_kaldi_native_fbank()
|
||||
|
||||
@@ -85,6 +85,7 @@ function(download_onnxruntime)
|
||||
message(STATUS "location_onnxruntime: ${location_onnxruntime}")
|
||||
|
||||
add_library(onnxruntime SHARED IMPORTED)
|
||||
|
||||
set_target_properties(onnxruntime PROPERTIES
|
||||
IMPORTED_LOCATION ${location_onnxruntime}
|
||||
INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include"
|
||||
@@ -100,6 +101,18 @@ function(download_onnxruntime)
|
||||
${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE}
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
if(UNIX AND NOT APPLE)
|
||||
file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/lib*")
|
||||
elseif(APPLE)
|
||||
file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/lib*dylib")
|
||||
elseif(WIN32)
|
||||
file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/*.dll")
|
||||
endif()
|
||||
|
||||
message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}")
|
||||
install(FILES ${onnxruntime_lib_files} DESTINATION lib)
|
||||
endfunction()
|
||||
|
||||
download_onnxruntime()
|
||||
|
||||
38
cmake/pybind11.cmake
Normal file
38
cmake/pybind11.cmake
Normal file
@@ -0,0 +1,38 @@
|
||||
function(download_pybind11)
|
||||
include(FetchContent)
|
||||
|
||||
set(pybind11_URL "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.2.tar.gz")
|
||||
set(pybind11_HASH "SHA256=93bd1e625e43e03028a3ea7389bba5d3f9f2596abc074b068e70f4ef9b1314ae")
|
||||
|
||||
# If you don't have access to the Internet,
|
||||
# please pre-download pybind11
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/pybind11-2.10.2.tar.gz
|
||||
${PROJECT_SOURCE_DIR}/pybind11-2.10.2.tar.gz
|
||||
${PROJECT_BINARY_DIR}/pybind11-2.10.2.tar.gz
|
||||
/tmp/pybind11-2.10.2.tar.gz
|
||||
/star-fj/fangjun/download/github/pybind11-2.10.2.tar.gz
|
||||
)
|
||||
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
if(EXISTS ${f})
|
||||
set(pybind11_URL "file://${f}")
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
FetchContent_Declare(pybind11
|
||||
URL ${pybind11_URL}
|
||||
URL_HASH ${pybind11_HASH}
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(pybind11)
|
||||
if(NOT pybind11_POPULATED)
|
||||
message(STATUS "Downloading pybind11 from ${pybind11_URL}")
|
||||
FetchContent_Populate(pybind11)
|
||||
endif()
|
||||
message(STATUS "pybind11 is downloaded to ${pybind11_SOURCE_DIR}")
|
||||
add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR} EXCLUDE_FROM_ALL)
|
||||
endfunction()
|
||||
|
||||
download_pybind11()
|
||||
73
python-api-examples/decode-file.py
Normal file
73
python-api-examples/decode-file.py
Normal file
@@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
This file demonstrates how to use sherpa-onnx Python API to recognize
|
||||
a single file.
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/index.html
|
||||
to install sherpa-onnx and to download the pre-trained models
|
||||
used in this file.
|
||||
"""
|
||||
import wave
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import sherpa_onnx
|
||||
|
||||
|
||||
def main():
|
||||
sample_rate = 16000
|
||||
num_threads = 4
|
||||
recognizer = sherpa_onnx.OnlineRecognizer(
|
||||
tokens="./sherpa-onnx-lstm-en-2023-02-17/tokens.txt",
|
||||
encoder="./sherpa-onnx-lstm-en-2023-02-17/encoder-epoch-99-avg-1.onnx",
|
||||
decoder="./sherpa-onnx-lstm-en-2023-02-17/decoder-epoch-99-avg-1.onnx",
|
||||
joiner="./sherpa-onnx-lstm-en-2023-02-17/joiner-epoch-99-avg-1.onnx",
|
||||
num_threads=num_threads,
|
||||
sample_rate=sample_rate,
|
||||
feature_dim=80,
|
||||
)
|
||||
filename = "./sherpa-onnx-lstm-en-2023-02-17/test_wavs/1089-134686-0001.wav"
|
||||
with wave.open(filename) as f:
|
||||
assert f.getframerate() == sample_rate, f.getframerate()
|
||||
assert f.getnchannels() == 1, f.getnchannels()
|
||||
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
|
||||
num_samples = f.getnframes()
|
||||
samples = f.readframes(num_samples)
|
||||
samples_int16 = np.frombuffer(samples, dtype=np.int16)
|
||||
samples_float32 = samples_int16.astype(np.float32)
|
||||
|
||||
samples_float32 = samples_float32 / 32768
|
||||
|
||||
duration = len(samples_float32) / sample_rate
|
||||
|
||||
start_time = time.time()
|
||||
print("Started!")
|
||||
|
||||
stream = recognizer.create_stream()
|
||||
|
||||
stream.accept_waveform(sample_rate, samples_float32)
|
||||
|
||||
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
|
||||
stream.accept_waveform(sample_rate, tail_paddings)
|
||||
|
||||
stream.input_finished()
|
||||
|
||||
while recognizer.is_ready(stream):
|
||||
recognizer.decode_stream(stream)
|
||||
|
||||
print(recognizer.get_result(stream))
|
||||
|
||||
print("Done!")
|
||||
end_time = time.time()
|
||||
elapsed_seconds = end_time - start_time
|
||||
rtf = elapsed_seconds / duration
|
||||
print(f"num_threads: {num_threads}")
|
||||
print(f"Wave duration: {duration:.3f} s")
|
||||
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
75
setup.py
Normal file
75
setup.py
Normal file
@@ -0,0 +1,75 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import setuptools
|
||||
|
||||
from cmake.cmake_extension import (
|
||||
BuildExtension,
|
||||
bdist_wheel,
|
||||
cmake_extension,
|
||||
is_windows,
|
||||
)
|
||||
|
||||
|
||||
def read_long_description():
|
||||
with open("README.md", encoding="utf8") as f:
|
||||
readme = f.read()
|
||||
return readme
|
||||
|
||||
|
||||
def get_package_version():
|
||||
with open("CMakeLists.txt") as f:
|
||||
content = f.read()
|
||||
|
||||
match = re.search(r"set\(SHERPA_ONNX_VERSION (.*)\)", content)
|
||||
latest_version = match.group(1).strip('"')
|
||||
return latest_version
|
||||
|
||||
|
||||
package_name = "sherpa-onnx"
|
||||
|
||||
with open("sherpa-onnx/python/sherpa_onnx/__init__.py", "a") as f:
|
||||
f.write(f"__version__ = '{get_package_version()}'\n")
|
||||
|
||||
install_requires = [
|
||||
"numpy",
|
||||
]
|
||||
|
||||
setuptools.setup(
|
||||
name=package_name,
|
||||
python_requires=">=3.6",
|
||||
install_requires=install_requires,
|
||||
version=get_package_version(),
|
||||
author="The sherpa-onnx development team",
|
||||
author_email="dpovey@gmail.com",
|
||||
package_dir={
|
||||
"sherpa_onnx": "sherpa-onnx/python/sherpa_onnx",
|
||||
},
|
||||
packages=["sherpa_onnx"],
|
||||
url="https://github.com/k2-fsa/sherpa-onnx",
|
||||
long_description=read_long_description(),
|
||||
long_description_content_type="text/markdown",
|
||||
ext_modules=[cmake_extension("_sherpa_onnx")],
|
||||
cmdclass={"build_ext": BuildExtension, "bdist_wheel": bdist_wheel},
|
||||
zip_safe=False,
|
||||
classifiers=[
|
||||
"Programming Language :: C++",
|
||||
"Programming Language :: Python",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
],
|
||||
license="Apache licensed, as found in the LICENSE file",
|
||||
)
|
||||
|
||||
with open("sherpa-onnx/python/sherpa_onnx/__init__.py", "r") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
with open("sherpa-onnx/python/sherpa_onnx/__init__.py", "w") as f:
|
||||
for line in lines:
|
||||
if "__version__" in line:
|
||||
# skip __version__ = "x.x.x"
|
||||
continue
|
||||
f.write(line)
|
||||
@@ -1 +1,4 @@
|
||||
add_subdirectory(csrc)
|
||||
if(SHERPA_ONNX_ENABLE_PYTHON)
|
||||
add_subdirectory(python)
|
||||
endif()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
include_directories(${CMAKE_SOURCE_DIR})
|
||||
|
||||
add_executable(sherpa-onnx
|
||||
add_library(sherpa-onnx-core
|
||||
features.cc
|
||||
online-lstm-transducer-model.cc
|
||||
online-recognizer.cc
|
||||
@@ -9,15 +9,21 @@ add_executable(sherpa-onnx
|
||||
online-transducer-model-config.cc
|
||||
online-transducer-model.cc
|
||||
onnx-utils.cc
|
||||
sherpa-onnx.cc
|
||||
symbol-table.cc
|
||||
wave-reader.cc
|
||||
)
|
||||
|
||||
target_link_libraries(sherpa-onnx
|
||||
target_link_libraries(sherpa-onnx-core
|
||||
onnxruntime
|
||||
kaldi-native-fbank-core
|
||||
)
|
||||
|
||||
add_executable(sherpa-onnx-show-info show-onnx-info.cc)
|
||||
target_link_libraries(sherpa-onnx-show-info onnxruntime)
|
||||
add_executable(sherpa-onnx sherpa-onnx.cc)
|
||||
|
||||
target_link_libraries(sherpa-onnx sherpa-onnx-core)
|
||||
if(NOT WIN32)
|
||||
target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
|
||||
endif()
|
||||
|
||||
install(TARGETS sherpa-onnx-core DESTINATION lib)
|
||||
install(TARGETS sherpa-onnx DESTINATION bin)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/features.cc
|
||||
// sherpa-onnx/csrc/features.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/features.h
|
||||
// sherpa-onnx/csrc/features.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-lstm-transducer-model.cc
|
||||
// sherpa-onnx/csrc/online-lstm-transducer-model.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
|
||||
@@ -232,7 +232,7 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>>
|
||||
OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
|
||||
std::vector<Ort::Value> &states) {
|
||||
std::vector<Ort::Value> states) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-lstm-transducer-model.h
|
||||
// sherpa-onnx/csrc/online-lstm-transducer-model.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
|
||||
@@ -28,7 +28,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
|
||||
std::vector<Ort::Value> GetEncoderInitStates() override;
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
|
||||
Ort::Value features, std::vector<Ort::Value> &states) override;
|
||||
Ort::Value features, std::vector<Ort::Value> states) override;
|
||||
|
||||
Ort::Value BuildDecoderInput(
|
||||
const std::vector<OnlineTransducerDecoderResult> &results) override;
|
||||
|
||||
@@ -98,7 +98,7 @@ class OnlineRecognizer::Impl {
|
||||
|
||||
auto states = model_->StackStates(states_vec);
|
||||
|
||||
auto pair = model_->RunEncoder(std::move(x), states);
|
||||
auto pair = model_->RunEncoder(std::move(x), std::move(states));
|
||||
|
||||
decoder_->Decode(std::move(pair.first), &results);
|
||||
|
||||
|
||||
@@ -23,6 +23,13 @@ struct OnlineRecognizerConfig {
|
||||
OnlineTransducerModelConfig model_config;
|
||||
std::string tokens;
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
|
||||
const OnlineTransducerModelConfig &model_config,
|
||||
const std::string &tokens)
|
||||
: feat_config(feat_config), model_config(model_config), tokens(tokens) {}
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-transducer-decoder.h
|
||||
// sherpa-onnx/csrc/online-transducer-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-transducer-greedy-search-decoder.cc
|
||||
// sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-transducer-greedy-search-decoder.h
|
||||
// sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-transducer-model-config.cc
|
||||
// sherpa-onnx/csrc/online-transducer-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-transducer-model-config.h
|
||||
// sherpa-onnx/csrc/online-transducer-model-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
|
||||
@@ -15,6 +15,17 @@ struct OnlineTransducerModelConfig {
|
||||
int32_t num_threads;
|
||||
bool debug = false;
|
||||
|
||||
OnlineTransducerModelConfig() = default;
|
||||
OnlineTransducerModelConfig(const std::string &encoder_filename,
|
||||
const std::string &decoder_filename,
|
||||
const std::string &joiner_filename,
|
||||
int32_t num_threads, bool debug)
|
||||
: encoder_filename(encoder_filename),
|
||||
decoder_filename(decoder_filename),
|
||||
joiner_filename(joiner_filename),
|
||||
num_threads(num_threads),
|
||||
debug(debug) {}
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-transducer-model.cc
|
||||
// sherpa-onnx/csrc/online-transducer-model.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-transducer-model.h
|
||||
// sherpa-onnx/csrc/online-transducer-model.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_
|
||||
@@ -59,7 +59,7 @@ class OnlineTransducerModel {
|
||||
*/
|
||||
virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
|
||||
Ort::Value features,
|
||||
std::vector<Ort::Value> &states) = 0; // NOLINT
|
||||
std::vector<Ort::Value> states) = 0; // NOLINT
|
||||
|
||||
virtual Ort::Value BuildDecoderInput(
|
||||
const std::vector<OnlineTransducerDecoderResult> &results) = 0;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/onnx-utils.cc
|
||||
// sherpa-onnx/csrc/onnx-utils.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/onnx-utils.h
|
||||
// sherpa-onnx/csrc/onnx-utils.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
// sherpa-onnx/csrc/show-onnx-info.cc
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
int main() {
|
||||
std::cout << "ORT_API_VERSION: " << ORT_API_VERSION << "\n";
|
||||
std::vector<std::string> providers = Ort::GetAvailableProviders();
|
||||
std::ostringstream os;
|
||||
os << "Available providers: ";
|
||||
std::string sep = "";
|
||||
for (const auto &p : providers) {
|
||||
os << sep << p;
|
||||
sep = ", ";
|
||||
}
|
||||
std::cout << os.str() << "\n";
|
||||
return 0;
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa-onnx/csrc/symbol-table.cc
|
||||
// sherpa-onnx/csrc/symbol-table.h
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/wave-reader.cc
|
||||
// sherpa-onnx/csrc/wave-reader.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/wave-reader.h
|
||||
// sherpa-onnx/csrc/wave-reader.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
|
||||
5
sherpa-onnx/python/CMakeLists.txt
Normal file
5
sherpa-onnx/python/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
add_subdirectory(csrc)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_TESTS)
|
||||
add_subdirectory(tests)
|
||||
endif()
|
||||
29
sherpa-onnx/python/csrc/CMakeLists.txt
Normal file
29
sherpa-onnx/python/csrc/CMakeLists.txt
Normal file
@@ -0,0 +1,29 @@
|
||||
include_directories(${CMAKE_SOURCE_DIR})
|
||||
|
||||
pybind11_add_module(_sherpa_onnx
|
||||
features.cc
|
||||
online-transducer-model-config.cc
|
||||
sherpa-onnx.cc
|
||||
online-stream.cc
|
||||
online-recognizer.cc
|
||||
)
|
||||
|
||||
if(APPLE)
|
||||
execute_process(
|
||||
COMMAND "${PYTHON_EXECUTABLE}" -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())"
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE PYTHON_SITE_PACKAGE_DIR
|
||||
)
|
||||
message(STATUS "PYTHON_SITE_PACKAGE_DIR: ${PYTHON_SITE_PACKAGE_DIR}")
|
||||
target_link_libraries(_sherpa_onnx PRIVATE "-Wl,-rpath,${PYTHON_SITE_PACKAGE_DIR}")
|
||||
endif()
|
||||
|
||||
if(NOT WIN32)
|
||||
target_link_libraries(_sherpa_onnx PRIVATE "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/sherpa_onnx/lib")
|
||||
endif()
|
||||
|
||||
target_link_libraries(_sherpa_onnx PRIVATE sherpa-onnx-core)
|
||||
|
||||
install(TARGETS _sherpa_onnx
|
||||
DESTINATION ../
|
||||
)
|
||||
23
sherpa-onnx/python/csrc/features.cc
Normal file
23
sherpa-onnx/python/csrc/features.cc
Normal file
@@ -0,0 +1,23 @@
|
||||
// sherpa-onnx/python/csrc/features.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/features.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static void PybindFeatureExtractorConfig(py::module *m) {
|
||||
using PyClass = FeatureExtractorConfig;
|
||||
py::class_<PyClass>(*m, "FeatureExtractorConfig")
|
||||
.def(py::init<float, int32_t>(), py::arg("sampling_rate") = 16000,
|
||||
py::arg("feature_dim") = 80)
|
||||
.def_readwrite("sampling_rate", &PyClass::sampling_rate)
|
||||
.def_readwrite("feature_dim", &PyClass::feature_dim)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
void PybindFeatures(py::module *m) { PybindFeatureExtractorConfig(m); }
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/features.h
Normal file
16
sherpa-onnx/python/csrc/features.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/features.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_FEATURES_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_FEATURES_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindFeatures(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_FEATURES_H_
|
||||
49
sherpa-onnx/python/csrc/online-recognizer.cc
Normal file
49
sherpa-onnx/python/csrc/online-recognizer.cc
Normal file
@@ -0,0 +1,49 @@
|
||||
// sherpa-onnx/python/csrc/online-recognizer.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/python/csrc/online-recognizer.h"

#include <string>
#include <vector>

#include "sherpa-onnx/csrc/online-recognizer.h"

namespace sherpa_onnx {

// Exposes OnlineRecognizerResult to Python with a single read-only
// `text` property holding the recognized transcript.
static void PybindOnlineRecognizerResult(py::module *m) {
  using PyClass = OnlineRecognizerResult;
  py::class_<PyClass>(*m, "OnlineRecognizerResult")
      .def_property_readonly("text", [](PyClass &self) { return self.text; });
}

// Exposes OnlineRecognizerConfig to Python. It is constructible from the
// keyword arguments feat_config, model_config and tokens, and each field
// is readable/writable from Python.
static void PybindOnlineRecognizerConfig(py::module *m) {
  using PyClass = OnlineRecognizerConfig;
  py::class_<PyClass>(*m, "OnlineRecognizerConfig")
      .def(py::init<const FeatureExtractorConfig &,
                    const OnlineTransducerModelConfig &, const std::string &>(),
           py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"))
      .def_readwrite("feat_config", &PyClass::feat_config)
      .def_readwrite("model_config", &PyClass::model_config)
      .def_readwrite("tokens", &PyClass::tokens)
      // __str__ delegates to ToString() so `print(config)` is informative.
      .def("__str__", &PyClass::ToString);
}

// Registers OnlineRecognizerResult, OnlineRecognizerConfig, and
// OnlineRecognizer with the given pybind11 module.
void PybindOnlineRecognizer(py::module *m) {
  PybindOnlineRecognizerResult(m);
  PybindOnlineRecognizerConfig(m);

  using PyClass = OnlineRecognizer;
  py::class_<PyClass>(*m, "OnlineRecognizer")
      .def(py::init<const OnlineRecognizerConfig &>(), py::arg("config"))
      .def("create_stream", &PyClass::CreateStream)
      .def("is_ready", &PyClass::IsReady)
      .def("decode_stream", &PyClass::DecodeStream)
      // Adapt the (pointer, count) signature of DecodeStreams so that a
      // Python list of OnlineStream objects can be passed directly.
      .def("decode_streams",
           [](PyClass &self, std::vector<OnlineStream *> ss) {
             self.DecodeStreams(ss.data(), ss.size());
           })
      .def("get_result", &PyClass::GetResult);
}

}  // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/online-recognizer.h
Normal file
16
sherpa-onnx/python/csrc/online-recognizer.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/online-recognizer.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_RECOGNIZER_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_RECOGNIZER_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

// Registers OnlineRecognizer and its result/config classes with the
// given pybind11 module.
void PybindOnlineRecognizer(py::module *m);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_PYTHON_CSRC_ONLINE_RECOGNIZER_H_
|
||||
21
sherpa-onnx/python/csrc/online-stream.cc
Normal file
21
sherpa-onnx/python/csrc/online-stream.cc
Normal file
@@ -0,0 +1,21 @@
|
||||
// sherpa-onnx/python/csrc/online-stream.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/python/csrc/online-stream.h"

#include "sherpa-onnx/csrc/online-stream.h"

namespace sherpa_onnx {

// Registers OnlineStream with the given pybind11 module.
//
// No py::init is bound, so streams cannot be constructed directly from
// Python; they are obtained from OnlineRecognizer.create_stream().
void PybindOnlineStream(py::module *m) {
  using PyClass = OnlineStream;
  py::class_<PyClass>(*m, "OnlineStream")
      .def("accept_waveform",
           // NOTE(review): waveform.data()/waveform.size() are forwarded
           // as-is, which assumes `waveform` is a contiguous 1-D float
           // array — confirm callers always pass one (e.g., a float32
           // numpy vector of samples).
           [](PyClass &self, float sample_rate, py::array_t<float> waveform) {
             self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
           })
      .def("input_finished", &PyClass::InputFinished);
}

}  // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/online-stream.h
Normal file
16
sherpa-onnx/python/csrc/online-stream.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/online-stream.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_STREAM_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_STREAM_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

// Registers OnlineStream with the given pybind11 module.
void PybindOnlineStream(py::module *m);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_PYTHON_CSRC_ONLINE_STREAM_H_
|
||||
29
sherpa-onnx/python/csrc/online-transducer-model-config.cc
Normal file
29
sherpa-onnx/python/csrc/online-transducer-model-config.cc
Normal file
@@ -0,0 +1,29 @@
|
||||
// sherpa-onnx/python/csrc/online-transducer-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation

// Include the binding's own header first, consistent with the other
// binding files in this directory (e.g., online-recognizer.cc).
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"

#include <string>

#include "sherpa-onnx/csrc/online-transducer-model-config.h"

namespace sherpa_onnx {

// Exposes OnlineTransducerModelConfig to Python. It is constructible
// from keyword arguments (encoder_filename, decoder_filename,
// joiner_filename, num_threads, debug=False), and every field is
// readable/writable from Python.
void PybindOnlineTransducerModelConfig(py::module *m) {
  using PyClass = OnlineTransducerModelConfig;
  py::class_<PyClass>(*m, "OnlineTransducerModelConfig")
      .def(py::init<const std::string &, const std::string &,
                    const std::string &, int32_t, bool>(),
           py::arg("encoder_filename"), py::arg("decoder_filename"),
           py::arg("joiner_filename"), py::arg("num_threads"),
           py::arg("debug") = false)
      .def_readwrite("encoder_filename", &PyClass::encoder_filename)
      .def_readwrite("decoder_filename", &PyClass::decoder_filename)
      .def_readwrite("joiner_filename", &PyClass::joiner_filename)
      .def_readwrite("num_threads", &PyClass::num_threads)
      .def_readwrite("debug", &PyClass::debug)
      // __str__ delegates to ToString() so `print(config)` is informative.
      .def("__str__", &PyClass::ToString);
}

}  // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/online-transducer-model-config.h
Normal file
16
sherpa-onnx/python/csrc/online-transducer-model-config.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/online-transducer-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

// Registers OnlineTransducerModelConfig with the given pybind11 module.
void PybindOnlineTransducerModelConfig(py::module *m);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_PYTHON_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
|
||||
22
sherpa-onnx/python/csrc/sherpa-onnx.cc
Normal file
22
sherpa-onnx/python/csrc/sherpa-onnx.cc
Normal file
@@ -0,0 +1,22 @@
|
||||
// sherpa-onnx/python/csrc/sherpa-onnx.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

#include "sherpa-onnx/python/csrc/features.h"
#include "sherpa-onnx/python/csrc/online-recognizer.h"
#include "sherpa-onnx/python/csrc/online-stream.h"
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"

namespace sherpa_onnx {

// Entry point of the `_sherpa_onnx` Python extension module.
// Each Pybind* helper registers one group of classes with the module.
PYBIND11_MODULE(_sherpa_onnx, m) {
  m.doc() = "pybind11 binding of sherpa-onnx";
  PybindFeatures(&m);
  PybindOnlineTransducerModelConfig(&m);
  PybindOnlineStream(&m);
  PybindOnlineRecognizer(&m);
}

}  // namespace sherpa_onnx
|
||||
14
sherpa-onnx/python/csrc/sherpa-onnx.h
Normal file
14
sherpa-onnx/python/csrc/sherpa-onnx.h
Normal file
@@ -0,0 +1,14 @@
|
||||
// sherpa-onnx/python/csrc/sherpa-onnx.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_SHERPA_ONNX_H_
#define SHERPA_ONNX_PYTHON_CSRC_SHERPA_ONNX_H_

// Common pybind11 headers shared by all binding files in this directory;
// they include this header instead of pybind11 directly.
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

namespace py = pybind11;

#endif  // SHERPA_ONNX_PYTHON_CSRC_SHERPA_ONNX_H_
|
||||
8
sherpa-onnx/python/sherpa_onnx/__init__.py
Normal file
8
sherpa-onnx/python/sherpa_onnx/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# Re-export the pybind11 extension types at the package top level so that
# users can write e.g. `sherpa_onnx.OnlineStream`.
from _sherpa_onnx import FeatureExtractorConfig
from _sherpa_onnx import OnlineRecognizerConfig
from _sherpa_onnx import OnlineStream
from _sherpa_onnx import OnlineTransducerModelConfig

# High-level convenience wrapper around _sherpa_onnx.OnlineRecognizer.
from .online_recognizer import OnlineRecognizer
|
||||
96
sherpa-onnx/python/sherpa_onnx/online_recognizer.py
Normal file
96
sherpa-onnx/python/sherpa_onnx/online_recognizer.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from pathlib import Path
from typing import List

from _sherpa_onnx import (
    OnlineStream,
    OnlineTransducerModelConfig,
    FeatureExtractorConfig,
    OnlineRecognizerConfig,
)
from _sherpa_onnx import OnlineRecognizer as _Recognizer


def _assert_file_exists(f: str) -> None:
    # Fail fast with a clear message instead of a cryptic error from the
    # C++ layer when a model/token file path is wrong.
    assert Path(f).is_file(), f"{f} does not exist"


class OnlineRecognizer(object):
    """A class for streaming speech recognition."""

    def __init__(
        self,
        tokens: str,
        encoder: str,
        decoder: str,
        joiner: str,
        num_threads: int = 4,
        sample_rate: float = 16000,
        feature_dim: int = 80,
    ):
        """
        Please refer to
        `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
        to download pre-trained models for different languages, e.g., Chinese,
        English, etc.

        Args:
          tokens:
            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
            columns::

                symbol integer_id

          encoder:
            Path to ``encoder.onnx``.
          decoder:
            Path to ``decoder.onnx``.
          joiner:
            Path to ``joiner.onnx``.
          num_threads:
            Number of threads for neural network computation.
          sample_rate:
            Sample rate of the training data used to train the model.
          feature_dim:
            Dimension of the feature used to train the model.
        """
        _assert_file_exists(tokens)
        _assert_file_exists(encoder)
        _assert_file_exists(decoder)
        _assert_file_exists(joiner)

        assert num_threads > 0, num_threads

        model_config = OnlineTransducerModelConfig(
            encoder_filename=encoder,
            decoder_filename=decoder,
            joiner_filename=joiner,
            num_threads=num_threads,
        )

        feat_config = FeatureExtractorConfig(
            sampling_rate=sample_rate,
            feature_dim=feature_dim,
        )

        recognizer_config = OnlineRecognizerConfig(
            feat_config=feat_config,
            model_config=model_config,
            tokens=tokens,
        )

        # The underlying C++ recognizer that does all the real work.
        self.recognizer = _Recognizer(recognizer_config)

    def create_stream(self) -> OnlineStream:
        """Create a new stream for decoding one utterance."""
        return self.recognizer.create_stream()

    def decode_stream(self, s: OnlineStream) -> None:
        """Run one decoding step on the given stream."""
        self.recognizer.decode_stream(s)

    def decode_streams(self, ss: List[OnlineStream]) -> None:
        """Run one decoding step on a batch of streams in parallel."""
        self.recognizer.decode_streams(ss)

    def is_ready(self, s: OnlineStream) -> bool:
        """Return True if the stream has enough data to be decoded."""
        return self.recognizer.is_ready(s)

    def get_result(self, s: OnlineStream) -> str:
        """Return the recognition text decoded so far for the stream."""
        return self.recognizer.get_result(s).text
|
||||
27
sherpa-onnx/python/tests/CMakeLists.txt
Normal file
27
sherpa-onnx/python/tests/CMakeLists.txt
Normal file
@@ -0,0 +1,27 @@
|
||||
# Registers a single Python file under this directory as a ctest test named
# `<basename>_py`, running it with ${PYTHON_EXECUTABLE}. PYTHONPATH is set so
# that both the pure-Python package (sherpa-onnx/python) and the compiled
# `_sherpa_onnx` extension are importable.
#
# Example: sherpa_onnx_add_py_test(test_feature_extractor_config.py)
function(sherpa_onnx_add_py_test source)
  get_filename_component(name ${source} NAME_WE)
  set(name "${name}_py")

  add_test(NAME ${name}
    COMMAND
      "${PYTHON_EXECUTABLE}"
      "${CMAKE_CURRENT_SOURCE_DIR}/${source}"
  )

  # Parent of this tests/ directory, i.e., sherpa-onnx/python.
  get_filename_component(sherpa_onnx_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY)

  # Fix: the original referenced the undefined variable ${sherpa_path},
  # silently dropping the package directory from PYTHONPATH.
  set_property(TEST ${name}
    PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_onnx_path}:$<TARGET_FILE_DIR:_sherpa_onnx>:$ENV{PYTHONPATH}"
  )
endfunction()

# please sort the files in alphabetic order
set(py_test_files
  test_feature_extractor_config.py
  test_online_transducer_model_config.py
)

foreach(source IN LISTS py_test_files)
  sherpa_onnx_add_py_test(${source})
endforeach()
|
||||
|
||||
29
sherpa-onnx/python/tests/test_feature_extractor_config.py
Normal file
29
sherpa-onnx/python/tests/test_feature_extractor_config.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# sherpa-onnx/python/tests/test_feature_extractor_config.py
#
# Copyright (c) 2023 Xiaomi Corporation
#
# To run this single test, use
#
#  ctest --verbose -R test_feature_extractor_config_py

import unittest

import sherpa_onnx


class TestFeatureExtractorConfig(unittest.TestCase):
    def test_default_constructor(self):
        """Default config uses 16 kHz sample rate and 80-dim features."""
        config = sherpa_onnx.FeatureExtractorConfig()
        assert config.sampling_rate == 16000, config.sampling_rate
        assert config.feature_dim == 80, config.feature_dim
        print(config)

    def test_constructor(self):
        """Keyword arguments must override the defaults."""
        # Wrapped to stay within the 80-column limit enforced by .flake8.
        config = sherpa_onnx.FeatureExtractorConfig(
            sampling_rate=8000,
            feature_dim=40,
        )
        assert config.sampling_rate == 8000, config.sampling_rate
        assert config.feature_dim == 40, config.feature_dim
        print(config)


if __name__ == "__main__":
    unittest.main()
|
||||
@@ -0,0 +1,32 @@
|
||||
# sherpa-onnx/python/tests/test_online_transducer_model_config.py
#
# Copyright (c) 2023 Xiaomi Corporation
#
# To run this single test, use
#
#  ctest --verbose -R test_online_transducer_model_config_py

import unittest

import sherpa_onnx


class TestOnlineTransducerModelConfig(unittest.TestCase):
    def test_constructor(self):
        """A keyword-constructed config must retain every field verbatim."""
        config = sherpa_onnx.OnlineTransducerModelConfig(
            encoder_filename="encoder.onnx",
            decoder_filename="decoder.onnx",
            joiner_filename="joiner.onnx",
            num_threads=8,
            debug=True,
        )
        # Asserts wrapped to stay within the 80-column limit from .flake8.
        assert (
            config.encoder_filename == "encoder.onnx"
        ), config.encoder_filename
        assert (
            config.decoder_filename == "decoder.onnx"
        ), config.decoder_filename
        assert config.joiner_filename == "joiner.onnx", config.joiner_filename
        assert config.num_threads == 8, config.num_threads
        assert config.debug is True, config.debug
        print(config)


if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user