Refactor hotwords,support loading hotwords from file (#296)
This commit is contained in:
@@ -6,12 +6,14 @@ function(sherpa_onnx_add_py_test source)
|
||||
COMMAND
|
||||
"${PYTHON_EXECUTABLE}"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/${source}"
|
||||
WORKING_DIRECTORY
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
get_filename_component(sherpa_onnx_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY)
|
||||
|
||||
set_property(TEST ${name}
|
||||
PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_path}:$<TARGET_FILE_DIR:_sherpa_onnx>:$ENV{PYTHONPATH}"
|
||||
PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_onnx_path}:$<TARGET_FILE_DIR:_sherpa_onnx>:$ENV{PYTHONPATH}"
|
||||
)
|
||||
endfunction()
|
||||
|
||||
@@ -21,6 +23,7 @@ set(py_test_files
|
||||
test_offline_recognizer.py
|
||||
test_online_recognizer.py
|
||||
test_online_transducer_model_config.py
|
||||
test_text2token.py
|
||||
)
|
||||
|
||||
foreach(source IN LISTS py_test_files)
|
||||
|
||||
121
sherpa-onnx/python/tests/test_text2token.py
Normal file
121
sherpa-onnx/python/tests/test_text2token.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# sherpa-onnx/python/tests/test_text2token.py
|
||||
#
|
||||
# Copyright (c) 2023 Xiaomi Corporation
|
||||
#
|
||||
# To run this single test, use
|
||||
#
|
||||
# ctest --verbose -R test_text2token_py
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import sherpa_onnx
|
||||
|
||||
d = "/tmp/sherpa-test-data"
|
||||
# Please refer to
|
||||
# https://github.com/pkufool/sherpa-test-data
|
||||
# to download test data for testing
|
||||
|
||||
|
||||
class TestText2Token(unittest.TestCase):
|
||||
def test_bpe(self):
|
||||
tokens = f"{d}/text2token/tokens_en.txt"
|
||||
bpe_model = f"{d}/text2token/bpe_en.model"
|
||||
|
||||
if not Path(tokens).is_file() or not Path(bpe_model).is_file():
|
||||
print(
|
||||
f"No test data found, skipping test_bpe().\n"
|
||||
f"You can download the test data by: \n"
|
||||
f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data"
|
||||
)
|
||||
return
|
||||
|
||||
texts = ["HELLO WORLD", "I LOVE YOU"]
|
||||
encoded_texts = sherpa_onnx.text2token(
|
||||
texts,
|
||||
tokens=tokens,
|
||||
tokens_type="bpe",
|
||||
bpe_model=bpe_model,
|
||||
)
|
||||
assert encoded_texts == [
|
||||
["▁HE", "LL", "O", "▁WORLD"],
|
||||
["▁I", "▁LOVE", "▁YOU"],
|
||||
], encoded_texts
|
||||
|
||||
encoded_ids = sherpa_onnx.text2token(
|
||||
texts,
|
||||
tokens=tokens,
|
||||
tokens_type="bpe",
|
||||
bpe_model=bpe_model,
|
||||
output_ids=True,
|
||||
)
|
||||
assert encoded_ids == [[22, 58, 24, 425], [19, 370, 47]], encoded_ids
|
||||
|
||||
def test_cjkchar(self):
|
||||
tokens = f"{d}/text2token/tokens_cn.txt"
|
||||
|
||||
if not Path(tokens).is_file():
|
||||
print(
|
||||
f"No test data found, skipping test_cjkchar().\n"
|
||||
f"You can download the test data by: \n"
|
||||
f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data"
|
||||
)
|
||||
return
|
||||
|
||||
texts = ["世界人民大团结", "中国 VS 美国"]
|
||||
encoded_texts = sherpa_onnx.text2token(
|
||||
texts, tokens=tokens, tokens_type="cjkchar"
|
||||
)
|
||||
assert encoded_texts == [
|
||||
["世", "界", "人", "民", "大", "团", "结"],
|
||||
["中", "国", "V", "S", "美", "国"],
|
||||
], encoded_texts
|
||||
encoded_ids = sherpa_onnx.text2token(
|
||||
texts,
|
||||
tokens=tokens,
|
||||
tokens_type="cjkchar",
|
||||
output_ids=True,
|
||||
)
|
||||
assert encoded_ids == [
|
||||
[379, 380, 72, 874, 93, 1251, 489],
|
||||
[262, 147, 3423, 2476, 21, 147],
|
||||
], encoded_ids
|
||||
|
||||
def test_cjkchar_bpe(self):
|
||||
tokens = f"{d}/text2token/tokens_mix.txt"
|
||||
bpe_model = f"{d}/text2token/bpe_mix.model"
|
||||
|
||||
if not Path(tokens).is_file() or not Path(bpe_model).is_file():
|
||||
print(
|
||||
f"No test data found, skipping test_cjkchar_bpe().\n"
|
||||
f"You can download the test data by: \n"
|
||||
f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data"
|
||||
)
|
||||
return
|
||||
|
||||
texts = ["世界人民 GOES TOGETHER", "中国 GOES WITH 美国"]
|
||||
encoded_texts = sherpa_onnx.text2token(
|
||||
texts,
|
||||
tokens=tokens,
|
||||
tokens_type="cjkchar+bpe",
|
||||
bpe_model=bpe_model,
|
||||
)
|
||||
assert encoded_texts == [
|
||||
["世", "界", "人", "民", "▁GO", "ES", "▁TOGETHER"],
|
||||
["中", "国", "▁GO", "ES", "▁WITH", "美", "国"],
|
||||
], encoded_texts
|
||||
encoded_ids = sherpa_onnx.text2token(
|
||||
texts,
|
||||
tokens=tokens,
|
||||
tokens_type="cjkchar+bpe",
|
||||
bpe_model=bpe_model,
|
||||
output_ids=True,
|
||||
)
|
||||
assert encoded_ids == [
|
||||
[1368, 1392, 557, 680, 275, 178, 475],
|
||||
[685, 736, 275, 178, 179, 921, 736],
|
||||
], encoded_ids
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user