Add Python API (#31)
This commit is contained in:
27
sherpa-onnx/python/tests/CMakeLists.txt
Normal file
27
sherpa-onnx/python/tests/CMakeLists.txt
Normal file
@@ -0,0 +1,27 @@
|
||||
function(sherpa_onnx_add_py_test source)
|
||||
get_filename_component(name ${source} NAME_WE)
|
||||
set(name "${name}_py")
|
||||
|
||||
add_test(NAME ${name}
|
||||
COMMAND
|
||||
"${PYTHON_EXECUTABLE}"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/${source}"
|
||||
)
|
||||
|
||||
get_filename_component(sherpa_onnx_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY)
|
||||
|
||||
set_property(TEST ${name}
|
||||
PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_path}:$<TARGET_FILE_DIR:_sherpa_onnx>:$ENV{PYTHONPATH}"
|
||||
)
|
||||
endfunction()
|
||||
|
||||
# please sort the files in alphabetic order
|
||||
set(py_test_files
|
||||
test_feature_extractor_config.py
|
||||
test_online_transducer_model_config.py
|
||||
)
|
||||
|
||||
foreach(source IN LISTS py_test_files)
|
||||
sherpa_onnx_add_py_test(${source})
|
||||
endforeach()
|
||||
|
||||
29
sherpa-onnx/python/tests/test_feature_extractor_config.py
Normal file
29
sherpa-onnx/python/tests/test_feature_extractor_config.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# sherpa-onnx/python/tests/test_feature_extractor_config.py
|
||||
#
|
||||
# Copyright (c) 2023 Xiaomi Corporation
|
||||
#
|
||||
# To run this single test, use
|
||||
#
|
||||
# ctest --verbose -R test_feature_extractor_config_py
|
||||
|
||||
import unittest
|
||||
|
||||
import sherpa_onnx
|
||||
|
||||
|
||||
class TestFeatureExtractorConfig(unittest.TestCase):
|
||||
def test_default_constructor(self):
|
||||
config = sherpa_onnx.FeatureExtractorConfig()
|
||||
assert config.sampling_rate == 16000, config.sampling_rate
|
||||
assert config.feature_dim == 80, config.feature_dim
|
||||
print(config)
|
||||
|
||||
def test_constructor(self):
|
||||
config = sherpa_onnx.FeatureExtractorConfig(sampling_rate=8000, feature_dim=40)
|
||||
assert config.sampling_rate == 8000, config.sampling_rate
|
||||
assert config.feature_dim == 40, config.feature_dim
|
||||
print(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,32 @@
|
||||
# sherpa-onnx/python/tests/test_online_transducer_model_config.py
|
||||
#
|
||||
# Copyright (c) 2023 Xiaomi Corporation
|
||||
#
|
||||
# To run this single test, use
|
||||
#
|
||||
# ctest --verbose -R test_online_transducer_model_config_py
|
||||
|
||||
import unittest
|
||||
|
||||
import sherpa_onnx
|
||||
|
||||
|
||||
class TestOnlineTransducerModelConfig(unittest.TestCase):
|
||||
def test_constructor(self):
|
||||
config = sherpa_onnx.OnlineTransducerModelConfig(
|
||||
encoder_filename="encoder.onnx",
|
||||
decoder_filename="decoder.onnx",
|
||||
joiner_filename="joiner.onnx",
|
||||
num_threads=8,
|
||||
debug=True,
|
||||
)
|
||||
assert config.encoder_filename == "encoder.onnx", config.encoder_filename
|
||||
assert config.decoder_filename == "decoder.onnx", config.decoder_filename
|
||||
assert config.joiner_filename == "joiner.onnx", config.joiner_filename
|
||||
assert config.num_threads == 8, config.num_threads
|
||||
assert config.debug is True, config.debug
|
||||
print(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user