Refactor hotwords, support loading hotwords from file (#296)

This commit is contained in:
Wei Kang
2023-09-14 19:33:17 +08:00
committed by GitHub
parent 087367d7fe
commit 47184f9db7
34 changed files with 803 additions and 300 deletions

View File

@@ -6,12 +6,14 @@ function(sherpa_onnx_add_py_test source)
COMMAND
"${PYTHON_EXECUTABLE}"
"${CMAKE_CURRENT_SOURCE_DIR}/${source}"
WORKING_DIRECTORY
${CMAKE_CURRENT_SOURCE_DIR}
)
get_filename_component(sherpa_onnx_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY)
set_property(TEST ${name}
PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_path}:$<TARGET_FILE_DIR:_sherpa_onnx>:$ENV{PYTHONPATH}"
PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_onnx_path}:$<TARGET_FILE_DIR:_sherpa_onnx>:$ENV{PYTHONPATH}"
)
endfunction()
@@ -21,6 +23,7 @@ set(py_test_files
test_offline_recognizer.py
test_online_recognizer.py
test_online_transducer_model_config.py
test_text2token.py
)
foreach(source IN LISTS py_test_files)

View File

@@ -0,0 +1,121 @@
# sherpa-onnx/python/tests/test_text2token.py
#
# Copyright (c) 2023 Xiaomi Corporation
#
# To run this single test, use
#
# ctest --verbose -R test_text2token_py
import unittest
from pathlib import Path
import sherpa_onnx
d = "/tmp/sherpa-test-data"
# Please refer to
# https://github.com/pkufool/sherpa-test-data
# to download test data for testing
class TestText2Token(unittest.TestCase):
    """Exercise ``sherpa_onnx.text2token`` with bpe, cjkchar and cjkchar+bpe.

    Every test reads its token table (and BPE model, where applicable) from
    the module-level data directory ``d``. If a required file is missing the
    test is reported as *skipped* rather than silently passing.
    """

    def _require_test_data(self, test_name, *paths):
        """Skip the running test unless every entry of ``paths`` is a file.

        Args:
            test_name: Name of the calling test, used in the skip message.
            *paths: Paths of the data files the test depends on.

        Raises:
            unittest.SkipTest: If at least one path is not an existing file.
        """
        if all(Path(p).is_file() for p in paths):
            return
        self.skipTest(
            f"No test data found, skipping {test_name}().\n"
            "You can download the test data by: \n"
            "git clone https://github.com/pkufool/sherpa-test-data.git "
            "/tmp/sherpa-test-data"
        )

    def test_bpe(self):
        """English text is split into BPE pieces and mapped to their ids."""
        tokens = f"{d}/text2token/tokens_en.txt"
        bpe_model = f"{d}/text2token/bpe_en.model"
        self._require_test_data("test_bpe", tokens, bpe_model)

        texts = ["HELLO WORLD", "I LOVE YOU"]

        encoded_texts = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="bpe",
            bpe_model=bpe_model,
        )
        self.assertEqual(
            encoded_texts,
            [
                ["▁HE", "LL", "O", "▁WORLD"],
                ["▁I", "▁LOVE", "▁YOU"],
            ],
        )

        encoded_ids = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="bpe",
            bpe_model=bpe_model,
            output_ids=True,
        )
        self.assertEqual(encoded_ids, [[22, 58, 24, 425], [19, 370, 47]])

    def test_cjkchar(self):
        """With ``cjkchar`` each character of the input is its own token."""
        tokens = f"{d}/text2token/tokens_cn.txt"
        self._require_test_data("test_cjkchar", tokens)

        texts = ["世界人民大团结", "中国 VS 美国"]

        encoded_texts = sherpa_onnx.text2token(
            texts, tokens=tokens, tokens_type="cjkchar"
        )
        # One token per character of the input; spaces are not tokens.
        self.assertEqual(
            encoded_texts,
            [
                ["世", "界", "人", "民", "大", "团", "结"],
                ["中", "国", "V", "S", "美", "国"],
            ],
        )

        encoded_ids = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="cjkchar",
            output_ids=True,
        )
        self.assertEqual(
            encoded_ids,
            [
                [379, 380, 72, 874, 93, 1251, 489],
                [262, 147, 3423, 2476, 21, 147],
            ],
        )

    def test_cjkchar_bpe(self):
        """With ``cjkchar+bpe`` CJK characters and BPE pieces are mixed."""
        tokens = f"{d}/text2token/tokens_mix.txt"
        bpe_model = f"{d}/text2token/bpe_mix.model"
        self._require_test_data("test_cjkchar_bpe", tokens, bpe_model)

        texts = ["世界人民 GOES TOGETHER", "中国 GOES WITH 美国"]

        encoded_texts = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="cjkchar+bpe",
            bpe_model=bpe_model,
        )
        # CJK characters stay single tokens; English words become BPE pieces.
        self.assertEqual(
            encoded_texts,
            [
                ["世", "界", "人", "民", "▁GO", "ES", "▁TOGETHER"],
                ["中", "国", "▁GO", "ES", "▁WITH", "美", "国"],
            ],
        )

        encoded_ids = sherpa_onnx.text2token(
            texts,
            tokens=tokens,
            tokens_type="cjkchar+bpe",
            bpe_model=bpe_model,
            output_ids=True,
        )
        self.assertEqual(
            encoded_ids,
            [
                [1368, 1392, 557, 680, 275, 178, 475],
                [685, 736, 275, 178, 179, 921, 736],
            ],
        )
# Allow running this file directly (python test_text2token.py) in addition
# to running it through ctest.
if __name__ == "__main__":
    unittest.main()