Refactor hotwords, support loading hotwords from file (#296)
This commit is contained in:
@@ -16,17 +16,19 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
||||
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
|
||||
.def(py::init<const OfflineFeatureExtractorConfig &,
|
||||
const OfflineModelConfig &, const OfflineLMConfig &,
|
||||
const std::string &, int32_t, float>(),
|
||||
const std::string &, int32_t, const std::string &, float>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("lm_config") = OfflineLMConfig(),
|
||||
py::arg("decoding_method") = "greedy_search",
|
||||
py::arg("max_active_paths") = 4, py::arg("context_score") = 1.5)
|
||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||
py::arg("hotwords_score") = 1.5)
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
||||
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||
.def_readwrite("context_score", &PyClass::context_score)
|
||||
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
|
||||
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
@@ -40,11 +42,10 @@ void PybindOfflineRecognizer(py::module *m) {
|
||||
[](const PyClass &self) { return self.CreateStream(); })
|
||||
.def(
|
||||
"create_stream",
|
||||
[](PyClass &self,
|
||||
const std::vector<std::vector<int32_t>> &contexts_list) {
|
||||
return self.CreateStream(contexts_list);
|
||||
[](PyClass &self, const std::string &hotwords) {
|
||||
return self.CreateStream(hotwords);
|
||||
},
|
||||
py::arg("contexts_list"))
|
||||
py::arg("hotwords"))
|
||||
.def("decode_stream", &PyClass::DecodeStream)
|
||||
.def("decode_streams",
|
||||
[](const PyClass &self, std::vector<OfflineStream *> ss) {
|
||||
|
||||
@@ -21,8 +21,8 @@ void PybindOnlineModelConfig(py::module *m) {
|
||||
using PyClass = OnlineModelConfig;
|
||||
py::class_<PyClass>(*m, "OnlineModelConfig")
|
||||
.def(py::init<const OnlineTransducerModelConfig &,
|
||||
const OnlineParaformerModelConfig &, std::string &, int32_t,
|
||||
bool, const std::string &, const std::string &>(),
|
||||
const OnlineParaformerModelConfig &, const std::string &,
|
||||
int32_t, bool, const std::string &, const std::string &>(),
|
||||
py::arg("transducer") = OnlineTransducerModelConfig(),
|
||||
py::arg("paraformer") = OnlineParaformerModelConfig(),
|
||||
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
|
||||
|
||||
@@ -29,18 +29,20 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
||||
.def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
|
||||
const OnlineLMConfig &, const EndpointConfig &, bool,
|
||||
const std::string &, int32_t, float>(),
|
||||
const std::string &, int32_t, const std::string &, float>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
|
||||
py::arg("enable_endpoint"), py::arg("decoding_method"),
|
||||
py::arg("max_active_paths") = 4, py::arg("context_score") = 0)
|
||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||
py::arg("hotwords_score") = 0)
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
|
||||
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
|
||||
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
||||
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||
.def_readwrite("context_score", &PyClass::context_score)
|
||||
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
|
||||
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
@@ -55,11 +57,10 @@ void PybindOnlineRecognizer(py::module *m) {
|
||||
[](const PyClass &self) { return self.CreateStream(); })
|
||||
.def(
|
||||
"create_stream",
|
||||
[](PyClass &self,
|
||||
const std::vector<std::vector<int32_t>> &contexts_list) {
|
||||
return self.CreateStream(contexts_list);
|
||||
[](PyClass &self, const std::string &hotwords) {
|
||||
return self.CreateStream(hotwords);
|
||||
},
|
||||
py::arg("contexts_list"))
|
||||
py::arg("hotwords"))
|
||||
.def("is_ready", &PyClass::IsReady)
|
||||
.def("decode_stream", &PyClass::DecodeStream)
|
||||
.def("decode_streams",
|
||||
|
||||
@@ -4,4 +4,4 @@ from _sherpa_onnx import Display, OfflineStream, OnlineStream
|
||||
|
||||
from .offline_recognizer import OfflineRecognizer
|
||||
from .online_recognizer import OnlineRecognizer
|
||||
from .utils import encode_contexts
|
||||
from .utils import text2token
|
||||
|
||||
55
sherpa-onnx/python/sherpa_onnx/cli.py
Normal file
55
sherpa-onnx/python/sherpa_onnx/cli.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
import logging
|
||||
import click
|
||||
from pathlib import Path
|
||||
from sherpa_onnx import text2token
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
"""
|
||||
The shell entry point to sherpa-onnx.
|
||||
"""
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
|
||||
level=logging.INFO,
|
||||
)
|
||||
|
||||
|
||||
@cli.command(name="text2token")
|
||||
@click.argument("input", type=click.Path(exists=True, dir_okay=False))
|
||||
@click.argument("output", type=click.Path())
|
||||
@click.option(
|
||||
"--tokens",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The path to tokens.txt.",
|
||||
)
|
||||
@click.option(
|
||||
"--tokens-type",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
|
||||
)
|
||||
@click.option(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="The path to bpe.model. Only required when tokens-type is bpe or cjkchar+bpe.",
|
||||
)
|
||||
def encode_text(
|
||||
input: Path, output: Path, tokens: Path, tokens_type: str, bpe_model: Path
|
||||
):
|
||||
"""
|
||||
Encode the texts given by the INPUT to tokens and write the results to the OUTPUT.
|
||||
"""
|
||||
texts = []
|
||||
with open(input, "r", encoding="utf8") as f:
|
||||
for line in f:
|
||||
texts.append(line.strip())
|
||||
encoded_texts = text2token(
|
||||
texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model
|
||||
)
|
||||
with open(output, "w", encoding="utf8") as f:
|
||||
for txt in encoded_texts:
|
||||
f.write(" ".join(txt) + "\n")
|
||||
@@ -43,7 +43,8 @@ class OfflineRecognizer(object):
|
||||
feature_dim: int = 80,
|
||||
decoding_method: str = "greedy_search",
|
||||
max_active_paths: int = 4,
|
||||
context_score: float = 1.5,
|
||||
hotwords_file: str = "",
|
||||
hotwords_score: float = 1.5,
|
||||
debug: bool = False,
|
||||
provider: str = "cpu",
|
||||
):
|
||||
@@ -105,7 +106,8 @@ class OfflineRecognizer(object):
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
decoding_method=decoding_method,
|
||||
context_score=context_score,
|
||||
hotwords_file=hotwords_file,
|
||||
hotwords_score=hotwords_score,
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
@@ -379,11 +381,11 @@ class OfflineRecognizer(object):
|
||||
self.config = recognizer_config
|
||||
return self
|
||||
|
||||
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
|
||||
if contexts_list is None:
|
||||
def create_stream(self, hotwords: Optional[str] = None):
|
||||
if hotwords is None:
|
||||
return self.recognizer.create_stream()
|
||||
else:
|
||||
return self.recognizer.create_stream(contexts_list)
|
||||
return self.recognizer.create_stream(hotwords)
|
||||
|
||||
def decode_stream(self, s: OfflineStream):
|
||||
self.recognizer.decode_stream(s)
|
||||
|
||||
@@ -42,7 +42,8 @@ class OnlineRecognizer(object):
|
||||
rule3_min_utterance_length: float = 20.0,
|
||||
decoding_method: str = "greedy_search",
|
||||
max_active_paths: int = 4,
|
||||
context_score: float = 1.5,
|
||||
hotwords_score: float = 1.5,
|
||||
hotwords_file: str = "",
|
||||
provider: str = "cpu",
|
||||
model_type: str = "",
|
||||
):
|
||||
@@ -138,7 +139,8 @@ class OnlineRecognizer(object):
|
||||
enable_endpoint=enable_endpoint_detection,
|
||||
decoding_method=decoding_method,
|
||||
max_active_paths=max_active_paths,
|
||||
context_score=context_score,
|
||||
hotwords_score=hotwords_score,
|
||||
hotwords_file=hotwords_file,
|
||||
)
|
||||
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
@@ -248,11 +250,11 @@ class OnlineRecognizer(object):
|
||||
self.config = recognizer_config
|
||||
return self
|
||||
|
||||
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
|
||||
if contexts_list is None:
|
||||
def create_stream(self, hotwords: Optional[str] = None):
|
||||
if hotwords is None:
|
||||
return self.recognizer.create_stream()
|
||||
else:
|
||||
return self.recognizer.create_stream(contexts_list)
|
||||
return self.recognizer.create_stream(hotwords)
|
||||
|
||||
def decode_stream(self, s: OnlineStream):
|
||||
self.recognizer.decode_stream(s)
|
||||
|
||||
@@ -1,74 +1,95 @@
|
||||
from typing import Dict, List, Optional
|
||||
# Copyright (c) 2023 Xiaomi Corporation
|
||||
import re
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
|
||||
def encode_contexts(
|
||||
modeling_unit: str,
|
||||
contexts: List[str],
|
||||
sp: Optional["SentencePieceProcessor"] = None,
|
||||
tokens_table: Optional[Dict[str, int]] = None,
|
||||
) -> List[List[int]]:
|
||||
def text2token(
|
||||
texts: List[str],
|
||||
tokens: str,
|
||||
tokens_type: str = "cjkchar",
|
||||
bpe_model: Optional[str] = None,
|
||||
output_ids: bool = False,
|
||||
) -> List[List[Union[str, int]]]:
|
||||
"""
|
||||
Encode the given contexts (a list of string) to a list of a list of token ids.
|
||||
Encode the given texts (a list of string) to a list of a list of tokens.
|
||||
|
||||
Args:
|
||||
modeling_unit:
|
||||
The valid values are bpe, char, bpe+char.
|
||||
Note: char here means characters in CJK languages, not English like languages.
|
||||
contexts:
|
||||
texts:
|
||||
The given contexts list (a list of string).
|
||||
sp:
|
||||
An instance of SentencePieceProcessor.
|
||||
tokens_table:
|
||||
The tokens_table containing the tokens and the corresponding ids.
|
||||
tokens:
|
||||
The path of the tokens.txt.
|
||||
tokens_type:
|
||||
The valid values are cjkchar, bpe, cjkchar+bpe.
|
||||
bpe_model:
|
||||
The path of the bpe model. Only required when tokens_type is bpe or
|
||||
cjkchar+bpe.
|
||||
output_ids:
|
||||
True to output token ids otherwise tokens.
|
||||
Returns:
|
||||
Return the contexts_list, it is a list of a list of token ids.
|
||||
Return the encoded texts, it is a list of a list of token ids if output_ids
|
||||
is True, or it is a list of list of tokens.
|
||||
"""
|
||||
contexts_list = []
|
||||
if "bpe" in modeling_unit:
|
||||
assert sp is not None
|
||||
if "char" in modeling_unit:
|
||||
assert tokens_table is not None
|
||||
assert len(tokens_table) > 0, len(tokens_table)
|
||||
assert Path(tokens).is_file(), f"File not exists, {tokens}"
|
||||
tokens_table = {}
|
||||
with open(tokens, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
toks = line.strip().split()
|
||||
assert len(toks) == 2, len(toks)
|
||||
assert toks[0] not in tokens_table, f"Duplicate token: {toks} "
|
||||
tokens_table[toks[0]] = int(toks[1])
|
||||
|
||||
if "char" == modeling_unit:
|
||||
for context in contexts:
|
||||
assert ' ' not in context
|
||||
ids = [
|
||||
tokens_table[txt] if txt in tokens_table else tokens_table["<unk>"]
|
||||
for txt in context
|
||||
]
|
||||
contexts_list.append(ids)
|
||||
elif "bpe" == modeling_unit:
|
||||
contexts_list = sp.encode(contexts, out_type=int)
|
||||
if "bpe" in tokens_type:
|
||||
assert Path(bpe_model).is_file(), f"File not exists, {bpe_model}"
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(bpe_model)
|
||||
|
||||
texts_list: List[List[str]] = []
|
||||
|
||||
if tokens_type == "cjkchar":
|
||||
texts_list = [list("".join(text.split())) for text in texts]
|
||||
elif tokens_type == "bpe":
|
||||
texts_list = sp.encode(texts, out_type=str)
|
||||
else:
|
||||
assert modeling_unit == "bpe+char", modeling_unit
|
||||
|
||||
assert (
|
||||
tokens_type == "cjkchar+bpe"
|
||||
), f"Supported tokens_type are cjkchar, bpe, cjkchar+bpe, given {tokens_type}"
|
||||
# CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
pattern = re.compile(r"([\u4e00-\u9fff])")
|
||||
for context in contexts:
|
||||
for text in texts:
|
||||
# Example:
|
||||
# txt = "你好 IT'S OKAY 的"
|
||||
# chars = ["你", "好", " IT'S OKAY ", "的"]
|
||||
chars = pattern.split(context.upper())
|
||||
chars = pattern.split(text)
|
||||
mix_chars = [w for w in chars if len(w.strip()) > 0]
|
||||
ids = []
|
||||
text_list = []
|
||||
for ch_or_w in mix_chars:
|
||||
# ch_or_w is a single CJK character (i.e., "你"), do nothing.
|
||||
if pattern.fullmatch(ch_or_w) is not None:
|
||||
ids.append(
|
||||
tokens_table[ch_or_w]
|
||||
if ch_or_w in tokens_table
|
||||
else tokens_table["<unk>"]
|
||||
)
|
||||
text_list.append(ch_or_w)
|
||||
# ch_or_w contains non-CJK characters (i.e., " IT'S OKAY "),
|
||||
# encode ch_or_w using bpe_model.
|
||||
else:
|
||||
for p in sp.encode_as_pieces(ch_or_w):
|
||||
ids.append(
|
||||
tokens_table[p]
|
||||
if p in tokens_table
|
||||
else tokens_table["<unk>"]
|
||||
)
|
||||
contexts_list.append(ids)
|
||||
return contexts_list
|
||||
text_list += sp.encode_as_pieces(ch_or_w)
|
||||
texts_list.append(text_list)
|
||||
|
||||
result: List[List[Union[int, str]]] = []
|
||||
for text in texts_list:
|
||||
text_list = []
|
||||
contain_oov = False
|
||||
for txt in text:
|
||||
if txt in tokens_table:
|
||||
text_list.append(tokens_table[txt] if output_ids else txt)
|
||||
else:
|
||||
print(f"OOV token : {txt}, skipping text : {text}.")
|
||||
contain_oov = True
|
||||
break
|
||||
if contain_oov:
|
||||
continue
|
||||
else:
|
||||
result.append(text_list)
|
||||
return result
|
||||
|
||||
@@ -6,12 +6,14 @@ function(sherpa_onnx_add_py_test source)
|
||||
COMMAND
|
||||
"${PYTHON_EXECUTABLE}"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/${source}"
|
||||
WORKING_DIRECTORY
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
get_filename_component(sherpa_onnx_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY)
|
||||
|
||||
set_property(TEST ${name}
|
||||
PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_path}:$<TARGET_FILE_DIR:_sherpa_onnx>:$ENV{PYTHONPATH}"
|
||||
PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_onnx_path}:$<TARGET_FILE_DIR:_sherpa_onnx>:$ENV{PYTHONPATH}"
|
||||
)
|
||||
endfunction()
|
||||
|
||||
@@ -21,6 +23,7 @@ set(py_test_files
|
||||
test_offline_recognizer.py
|
||||
test_online_recognizer.py
|
||||
test_online_transducer_model_config.py
|
||||
test_text2token.py
|
||||
)
|
||||
|
||||
foreach(source IN LISTS py_test_files)
|
||||
|
||||
121
sherpa-onnx/python/tests/test_text2token.py
Normal file
121
sherpa-onnx/python/tests/test_text2token.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# sherpa-onnx/python/tests/test_text2token.py
|
||||
#
|
||||
# Copyright (c) 2023 Xiaomi Corporation
|
||||
#
|
||||
# To run this single test, use
|
||||
#
|
||||
# ctest --verbose -R test_text2token_py
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import sherpa_onnx
|
||||
|
||||
d = "/tmp/sherpa-test-data"
|
||||
# Please refer to
|
||||
# https://github.com/pkufool/sherpa-test-data
|
||||
# to download test data for testing
|
||||
|
||||
|
||||
class TestText2Token(unittest.TestCase):
|
||||
def test_bpe(self):
|
||||
tokens = f"{d}/text2token/tokens_en.txt"
|
||||
bpe_model = f"{d}/text2token/bpe_en.model"
|
||||
|
||||
if not Path(tokens).is_file() or not Path(bpe_model).is_file():
|
||||
print(
|
||||
f"No test data found, skipping test_bpe().\n"
|
||||
f"You can download the test data by: \n"
|
||||
f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data"
|
||||
)
|
||||
return
|
||||
|
||||
texts = ["HELLO WORLD", "I LOVE YOU"]
|
||||
encoded_texts = sherpa_onnx.text2token(
|
||||
texts,
|
||||
tokens=tokens,
|
||||
tokens_type="bpe",
|
||||
bpe_model=bpe_model,
|
||||
)
|
||||
assert encoded_texts == [
|
||||
["▁HE", "LL", "O", "▁WORLD"],
|
||||
["▁I", "▁LOVE", "▁YOU"],
|
||||
], encoded_texts
|
||||
|
||||
encoded_ids = sherpa_onnx.text2token(
|
||||
texts,
|
||||
tokens=tokens,
|
||||
tokens_type="bpe",
|
||||
bpe_model=bpe_model,
|
||||
output_ids=True,
|
||||
)
|
||||
assert encoded_ids == [[22, 58, 24, 425], [19, 370, 47]], encoded_ids
|
||||
|
||||
def test_cjkchar(self):
|
||||
tokens = f"{d}/text2token/tokens_cn.txt"
|
||||
|
||||
if not Path(tokens).is_file():
|
||||
print(
|
||||
f"No test data found, skipping test_cjkchar().\n"
|
||||
f"You can download the test data by: \n"
|
||||
f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data"
|
||||
)
|
||||
return
|
||||
|
||||
texts = ["世界人民大团结", "中国 VS 美国"]
|
||||
encoded_texts = sherpa_onnx.text2token(
|
||||
texts, tokens=tokens, tokens_type="cjkchar"
|
||||
)
|
||||
assert encoded_texts == [
|
||||
["世", "界", "人", "民", "大", "团", "结"],
|
||||
["中", "国", "V", "S", "美", "国"],
|
||||
], encoded_texts
|
||||
encoded_ids = sherpa_onnx.text2token(
|
||||
texts,
|
||||
tokens=tokens,
|
||||
tokens_type="cjkchar",
|
||||
output_ids=True,
|
||||
)
|
||||
assert encoded_ids == [
|
||||
[379, 380, 72, 874, 93, 1251, 489],
|
||||
[262, 147, 3423, 2476, 21, 147],
|
||||
], encoded_ids
|
||||
|
||||
def test_cjkchar_bpe(self):
|
||||
tokens = f"{d}/text2token/tokens_mix.txt"
|
||||
bpe_model = f"{d}/text2token/bpe_mix.model"
|
||||
|
||||
if not Path(tokens).is_file() or not Path(bpe_model).is_file():
|
||||
print(
|
||||
f"No test data found, skipping test_cjkchar_bpe().\n"
|
||||
f"You can download the test data by: \n"
|
||||
f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data"
|
||||
)
|
||||
return
|
||||
|
||||
texts = ["世界人民 GOES TOGETHER", "中国 GOES WITH 美国"]
|
||||
encoded_texts = sherpa_onnx.text2token(
|
||||
texts,
|
||||
tokens=tokens,
|
||||
tokens_type="cjkchar+bpe",
|
||||
bpe_model=bpe_model,
|
||||
)
|
||||
assert encoded_texts == [
|
||||
["世", "界", "人", "民", "▁GO", "ES", "▁TOGETHER"],
|
||||
["中", "国", "▁GO", "ES", "▁WITH", "美", "国"],
|
||||
], encoded_texts
|
||||
encoded_ids = sherpa_onnx.text2token(
|
||||
texts,
|
||||
tokens=tokens,
|
||||
tokens_type="cjkchar+bpe",
|
||||
bpe_model=bpe_model,
|
||||
output_ids=True,
|
||||
)
|
||||
assert encoded_ids == [
|
||||
[1368, 1392, 557, 680, 275, 178, 475],
|
||||
[685, 736, 275, 178, 179, 921, 736],
|
||||
], encoded_ids
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user