diff --git a/.github/scripts/test-offline-ctc.sh b/.github/scripts/test-offline-ctc.sh index ba2957dc..69a94804 100755 --- a/.github/scripts/test-offline-ctc.sh +++ b/.github/scripts/test-offline-ctc.sh @@ -8,6 +8,8 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +export GIT_CLONE_PROTECTION_ACTIVE=false + echo "EXE is $EXE" echo "PATH: $PATH" diff --git a/.github/scripts/test-offline-transducer.sh b/.github/scripts/test-offline-transducer.sh index 8e979e49..f43614f5 100755 --- a/.github/scripts/test-offline-transducer.sh +++ b/.github/scripts/test-offline-transducer.sh @@ -8,6 +8,8 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +export GIT_CLONE_PROTECTION_ACTIVE=false + echo "EXE is $EXE" echo "PATH: $PATH" diff --git a/.github/scripts/test-offline-tts.sh b/.github/scripts/test-offline-tts.sh index dca90b1d..e611c3d2 100755 --- a/.github/scripts/test-offline-tts.sh +++ b/.github/scripts/test-offline-tts.sh @@ -8,6 +8,8 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +export GIT_CLONE_PROTECTION_ACTIVE=false + echo "EXE is $EXE" echo "PATH: $PATH" diff --git a/.github/scripts/test-offline-whisper.sh b/.github/scripts/test-offline-whisper.sh index 81621e74..e2987b5d 100755 --- a/.github/scripts/test-offline-whisper.sh +++ b/.github/scripts/test-offline-whisper.sh @@ -8,6 +8,8 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +export GIT_CLONE_PROTECTION_ACTIVE=false + echo "EXE is $EXE" echo "PATH: $PATH" diff --git a/.github/scripts/test-online-ctc.sh b/.github/scripts/test-online-ctc.sh index c28d2b3c..a81b6ff2 100755 --- a/.github/scripts/test-online-ctc.sh +++ b/.github/scripts/test-online-ctc.sh @@ -8,6 +8,8 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +export GIT_CLONE_PROTECTION_ACTIVE=false + echo "EXE is 
$EXE" echo "PATH: $PATH" diff --git a/.github/scripts/test-online-paraformer.sh b/.github/scripts/test-online-paraformer.sh index 93574e3f..6b5d3cf6 100755 --- a/.github/scripts/test-online-paraformer.sh +++ b/.github/scripts/test-online-paraformer.sh @@ -8,6 +8,8 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +export GIT_CLONE_PROTECTION_ACTIVE=false + echo "EXE is $EXE" echo "PATH: $PATH" diff --git a/.github/scripts/test-online-transducer.sh b/.github/scripts/test-online-transducer.sh index 5a08d332..f8a76d76 100755 --- a/.github/scripts/test-online-transducer.sh +++ b/.github/scripts/test-online-transducer.sh @@ -8,6 +8,8 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +export GIT_CLONE_PROTECTION_ACTIVE=false + echo "EXE is $EXE" echo "PATH: $PATH" diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index a39d0c6b..e6d4c17a 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -8,6 +8,8 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +export GIT_CLONE_PROTECTION_ACTIVE=false + log "test online NeMo CTC" url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2 diff --git a/.github/scripts/test-spoken-language-identification.sh b/.github/scripts/test-spoken-language-identification.sh index 75d1364d..fec3d438 100755 --- a/.github/scripts/test-spoken-language-identification.sh +++ b/.github/scripts/test-spoken-language-identification.sh @@ -8,6 +8,8 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +export GIT_CLONE_PROTECTION_ACTIVE=false + echo "EXE is $EXE" echo "PATH: $PATH" diff --git a/CMakeLists.txt b/CMakeLists.txt index 99944872..a3e4ffef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -234,6 +234,7 @@ endif() 
include(kaldi-native-fbank) include(kaldi-decoder) include(onnxruntime) +include(simple-sentencepiece) set(ONNXRUNTIME_DIR ${onnxruntime_SOURCE_DIR}) message(STATUS "ONNXRUNTIME_DIR: ${ONNXRUNTIME_DIR}") diff --git a/build-ios-no-tts.sh b/build-ios-no-tts.sh index 9bbe5b26..fecdb6b5 100755 --- a/build-ios-no-tts.sh +++ b/build-ios-no-tts.sh @@ -126,7 +126,7 @@ echo "Generate xcframework" mkdir -p "build/simulator/lib" for f in libkaldi-native-fbank-core.a libsherpa-onnx-c-api.a libsherpa-onnx-core.a \ - libsherpa-onnx-fst.a libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a; do + libsherpa-onnx-fst.a libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a libssentencepiece_core.a; do lipo -create build/simulator_arm64/lib/${f} \ build/simulator_x86_64/lib/${f} \ -output build/simulator/lib/${f} @@ -140,7 +140,8 @@ libtool -static -o build/simulator/sherpa-onnx.a \ build/simulator/lib/libsherpa-onnx-core.a \ build/simulator/lib/libsherpa-onnx-fst.a \ build/simulator/lib/libsherpa-onnx-kaldifst-core.a \ - build/simulator/lib/libkaldi-decoder-core.a + build/simulator/lib/libkaldi-decoder-core.a \ + build/simulator/lib/libssentencepiece_core.a libtool -static -o build/os64/sherpa-onnx.a \ build/os64/lib/libkaldi-native-fbank-core.a \ @@ -148,7 +149,8 @@ libtool -static -o build/os64/sherpa-onnx.a \ build/os64/lib/libsherpa-onnx-core.a \ build/os64/lib/libsherpa-onnx-fst.a \ build/os64/lib/libsherpa-onnx-kaldifst-core.a \ - build/os64/lib/libkaldi-decoder-core.a + build/os64/lib/libkaldi-decoder-core.a \ + build/os64/lib/libssentencepiece_core.a rm -rf sherpa-onnx.xcframework diff --git a/build-ios.sh b/build-ios.sh index ac81b4fb..61b06b96 100755 --- a/build-ios.sh +++ b/build-ios.sh @@ -129,7 +129,7 @@ echo "Generate xcframework" mkdir -p "build/simulator/lib" for f in libkaldi-native-fbank-core.a libsherpa-onnx-c-api.a libsherpa-onnx-core.a \ - libsherpa-onnx-fstfar.a \ + libsherpa-onnx-fstfar.a libssentencepiece_core.a \ libsherpa-onnx-fst.a 
libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a \ libucd.a libpiper_phonemize.a libespeak-ng.a; do lipo -create build/simulator_arm64/lib/${f} \ @@ -150,6 +150,7 @@ libtool -static -o build/simulator/sherpa-onnx.a \ build/simulator/lib/libucd.a \ build/simulator/lib/libpiper_phonemize.a \ build/simulator/lib/libespeak-ng.a \ + build/simulator/lib/libssentencepiece_core.a libtool -static -o build/os64/sherpa-onnx.a \ build/os64/lib/libkaldi-native-fbank-core.a \ @@ -162,6 +163,7 @@ libtool -static -o build/os64/sherpa-onnx.a \ build/os64/lib/libucd.a \ build/os64/lib/libpiper_phonemize.a \ build/os64/lib/libespeak-ng.a \ + build/os64/lib/libssentencepiece_core.a rm -rf sherpa-onnx.xcframework diff --git a/build-swift-macos.sh b/build-swift-macos.sh index 1b1867c5..cebfa295 100755 --- a/build-swift-macos.sh +++ b/build-swift-macos.sh @@ -33,4 +33,5 @@ libtool -static -o ./install/lib/libsherpa-onnx.a \ ./install/lib/libkaldi-decoder-core.a \ ./install/lib/libucd.a \ ./install/lib/libpiper_phonemize.a \ - ./install/lib/libespeak-ng.a + ./install/lib/libespeak-ng.a \ + ./install/lib/libssentencepiece_core.a diff --git a/cmake/simple-sentencepiece.cmake b/cmake/simple-sentencepiece.cmake new file mode 100644 index 00000000..a845690a --- /dev/null +++ b/cmake/simple-sentencepiece.cmake @@ -0,0 +1,63 @@ +function(download_simple_sentencepiece) + include(FetchContent) + + set(simple-sentencepiece_URL "https://github.com/pkufool/simple-sentencepiece/archive/refs/tags/v0.7.tar.gz") + set(simple-sentencepiece_URL2 "https://hub.nauu.cf/pkufool/simple-sentencepiece/archive/refs/tags/v0.7.tar.gz") + set(simple-sentencepiece_HASH "SHA256=1748a822060a35baa9f6609f84efc8eb54dc0e74b9ece3d82367b7119fdc75af") + + # If you don't have access to the Internet, + # please pre-download simple-sentencepiece + set(possible_file_locations + $ENV{HOME}/Downloads/simple-sentencepiece-0.7.tar.gz + ${CMAKE_SOURCE_DIR}/simple-sentencepiece-0.7.tar.gz + 
${CMAKE_BINARY_DIR}/simple-sentencepiece-0.7.tar.gz + /tmp/simple-sentencepiece-0.7.tar.gz + /star-fj/fangjun/download/github/simple-sentencepiece-0.7.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(simple-sentencepiece_URL "${f}") + file(TO_CMAKE_PATH "${simple-sentencepiece_URL}" simple-sentencepiece_URL) + message(STATUS "Found local downloaded simple-sentencepiece: ${simple-sentencepiece_URL}") + set(simple-sentencepiece_URL2) + break() + endif() + endforeach() + + set(SBPE_ENABLE_TESTS OFF CACHE BOOL "" FORCE) + set(SBPE_BUILD_PYTHON OFF CACHE BOOL "" FORCE) + + FetchContent_Declare(simple-sentencepiece + URL + ${simple-sentencepiece_URL} + ${simple-sentencepiece_URL2} + URL_HASH + ${simple-sentencepiece_HASH} + ) + + FetchContent_GetProperties(simple-sentencepiece) + if(NOT simple-sentencepiece_POPULATED) + message(STATUS "Downloading simple-sentencepiece ${simple-sentencepiece_URL}") + FetchContent_Populate(simple-sentencepiece) + endif() + message(STATUS "simple-sentencepiece is downloaded to ${simple-sentencepiece_SOURCE_DIR}") + add_subdirectory(${simple-sentencepiece_SOURCE_DIR} ${simple-sentencepiece_BINARY_DIR} EXCLUDE_FROM_ALL) + + target_include_directories(ssentencepiece_core + PUBLIC + ${simple-sentencepiece_SOURCE_DIR}/ + ) + + if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32) + install(TARGETS ssentencepiece_core DESTINATION ..) + else() + install(TARGETS ssentencepiece_core DESTINATION lib) + endif() + + if(WIN32 AND BUILD_SHARED_LIBS) + install(TARGETS ssentencepiece_core DESTINATION bin) + endif() +endfunction() + +download_simple_sentencepiece() diff --git a/kotlin-api-examples/run.sh b/kotlin-api-examples/run.sh index 41a37ee3..e7946598 100755 --- a/kotlin-api-examples/run.sh +++ b/kotlin-api-examples/run.sh @@ -60,7 +60,7 @@ function testSpeakerEmbeddingExtractor() { function testOnlineAsr() { if [ ! 
-f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then git lfs install - git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 + GIT_CLONE_PROTECTION_ACTIVE=false git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 fi if [ ! -f ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/tokens.txt ]; then diff --git a/mfc-examples/NonStreamingSpeechRecognition/sherpa-onnx-deps.props b/mfc-examples/NonStreamingSpeechRecognition/sherpa-onnx-deps.props index e81f4b62..9217cefa 100644 --- a/mfc-examples/NonStreamingSpeechRecognition/sherpa-onnx-deps.props +++ b/mfc-examples/NonStreamingSpeechRecognition/sherpa-onnx-deps.props @@ -18,6 +18,7 @@ piper_phonemize.lib; espeak-ng.lib; ucd.lib; + ssentencepiece_core.lib; diff --git a/mfc-examples/NonStreamingTextToSpeech/sherpa-onnx-deps.props b/mfc-examples/NonStreamingTextToSpeech/sherpa-onnx-deps.props index e81f4b62..9217cefa 100644 --- a/mfc-examples/NonStreamingTextToSpeech/sherpa-onnx-deps.props +++ b/mfc-examples/NonStreamingTextToSpeech/sherpa-onnx-deps.props @@ -18,6 +18,7 @@ piper_phonemize.lib; espeak-ng.lib; ucd.lib; + ssentencepiece_core.lib; diff --git a/mfc-examples/StreamingSpeechRecognition/sherpa-onnx-deps.props b/mfc-examples/StreamingSpeechRecognition/sherpa-onnx-deps.props index e81f4b62..9217cefa 100644 --- a/mfc-examples/StreamingSpeechRecognition/sherpa-onnx-deps.props +++ b/mfc-examples/StreamingSpeechRecognition/sherpa-onnx-deps.props @@ -18,6 +18,7 @@ piper_phonemize.lib; espeak-ng.lib; ucd.lib; + ssentencepiece_core.lib; diff --git a/python-api-examples/offline-decode-files.py b/python-api-examples/offline-decode-files.py index a058d843..0f87284e 100755 --- a/python-api-examples/offline-decode-files.py +++ b/python-api-examples/offline-decode-files.py @@ -110,11 +110,9 @@ def get_args(): type=str, default="", help=""" - The file containing hotwords, one words/phrases per line, and for each - phrase the bpe/cjkchar 
are separated by a space. For example: - - ▁HE LL O ▁WORLD - 你 好 世 界 + The file containing hotwords, one words/phrases per line, like + HELLO WORLD + 你好世界 """, ) @@ -128,6 +126,28 @@ def get_args(): """, ) + parser.add_argument( + "--modeling-unit", + type=str, + default="", + help=""" + The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe. + Used only when hotwords-file is given. + """, + ) + + parser.add_argument( + "--bpe-vocab", + type=str, + default="", + help=""" + The path to the bpe vocabulary, the bpe vocabulary is generated by + sentencepiece, you can also export the bpe vocabulary through a bpe model + by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given + and modeling-unit is bpe or cjkchar+bpe. + """, + ) + parser.add_argument( "--encoder", default="", @@ -347,6 +367,8 @@ def main(): decoding_method=args.decoding_method, hotwords_file=args.hotwords_file, hotwords_score=args.hotwords_score, + modeling_unit=args.modeling_unit, + bpe_vocab=args.bpe_vocab, blank_penalty=args.blank_penalty, debug=args.debug, ) diff --git a/python-api-examples/online-decode-files.py b/python-api-examples/online-decode-files.py index 298d5961..b9ab3b98 100755 --- a/python-api-examples/online-decode-files.py +++ b/python-api-examples/online-decode-files.py @@ -198,11 +198,9 @@ def get_args(): type=str, default="", help=""" - The file containing hotwords, one words/phrases per line, and for each - phrase the bpe/cjkchar are separated by a space. For example: - - ▁HE LL O ▁WORLD - 你 好 世 界 + The file containing hotwords, one words/phrases per line, like + HELLO WORLD + 你好世界 """, ) @@ -216,6 +214,28 @@ def get_args(): """, ) + parser.add_argument( + "--modeling-unit", + type=str, + default="", + help=""" + The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe. + Used only when hotwords-file is given. 
+ """, + ) + + parser.add_argument( + "--bpe-vocab", + type=str, + default="", + help=""" + The path to the bpe vocabulary, the bpe vocabulary is generated by + sentencepiece, you can also export the bpe vocabulary through a bpe model + by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given + and modeling-unit is bpe or cjkchar+bpe. + """, + ) + parser.add_argument( "--blank-penalty", type=float, @@ -302,6 +322,8 @@ def main(): lm_scale=args.lm_scale, hotwords_file=args.hotwords_file, hotwords_score=args.hotwords_score, + modeling_unit=args.modeling_unit, + bpe_vocab=args.bpe_vocab, blank_penalty=args.blank_penalty, ) elif args.zipformer2_ctc: diff --git a/scripts/export_bpe_vocab.py b/scripts/export_bpe_vocab.py new file mode 100755 index 00000000..6267e5e3 --- /dev/null +++ b/scripts/export_bpe_vocab.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +import argparse +from typing import Dict + +try: + import sentencepiece as spm +except ImportError: + print('Please run') + print(' pip install sentencepiece') + print('before you continue') + raise + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--bpe-model", + type=str, + help="The path to the bpe model.", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + model_file = args.bpe_model + + vocab_file = model_file.replace(".model", ".vocab") + + sp = spm.SentencePieceProcessor() + sp.Load(model_file) + vocabs = [sp.IdToPiece(id) for id in range(sp.GetPieceSize())] + with open(vocab_file, "w") as vfile: + for v in vocabs: + id = sp.PieceToId(v) + vfile.write(f"{v}\t{sp.GetScore(id)}\n") + print(f"Vocabulary file is written to {vocab_file}") + + +if __name__ == "__main__": + main() diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index fc32e5a4..4ed2cb11 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -165,6 +165,7 @@ endif() target_link_libraries(sherpa-onnx-core kaldi-native-fbank-core kaldi-decoder-core + ssentencepiece_core ) if(SHERPA_ONNX_ENABLE_GPU) @@ -491,6 +492,7 @@ if(SHERPA_ONNX_ENABLE_TESTS) pad-sequence-test.cc slice-test.cc stack-test.cc + text2token-test.cc transpose-test.cc unbind-test.cc utfcpp-test.cc diff --git a/sherpa-onnx/csrc/offline-model-config.cc b/sherpa-onnx/csrc/offline-model-config.cc index 5f4a6770..b85a0a9f 100644 --- a/sherpa-onnx/csrc/offline-model-config.cc +++ b/sherpa-onnx/csrc/offline-model-config.cc @@ -35,6 +35,17 @@ void OfflineModelConfig::Register(ParseOptions *po) { "Valid values are: transducer, paraformer, nemo_ctc, whisper, " "tdnn, zipformer2_ctc" "All other values lead to 
loading the model twice."); + po->Register("modeling-unit", &modeling_unit, + "The modeling unit of the model, commonly used units are bpe, " + "cjkchar, cjkchar+bpe, etc. Currently, it is needed only when " + "hotwords are provided, we need it to encode the hotwords into " + "token sequence."); + po->Register("bpe-vocab", &bpe_vocab, + "The vocabulary generated by google's sentencepiece program. " + "It is a file has two columns, one is the token, the other is " + "the log probability, you can get it from the directory where " + "your bpe model is generated. Only used when hotwords provided " + "and the modeling unit is bpe or cjkchar+bpe"); } bool OfflineModelConfig::Validate() const { @@ -48,6 +59,14 @@ bool OfflineModelConfig::Validate() const { return false; } + if (!modeling_unit.empty() && + (modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) { + if (!FileExists(bpe_vocab)) { + SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str()); + return false; + } + } + if (!paraformer.model.empty()) { return paraformer.Validate(); } @@ -90,7 +109,9 @@ std::string OfflineModelConfig::ToString() const { os << "num_threads=" << num_threads << ", "; os << "debug=" << (debug ? "True" : "False") << ", "; os << "provider=\"" << provider << "\", "; - os << "model_type=\"" << model_type << "\")"; + os << "model_type=\"" << model_type << "\", "; + os << "modeling_unit=\"" << modeling_unit << "\", "; + os << "bpe_vocab=\"" << bpe_vocab << "\")"; return os.str(); } diff --git a/sherpa-onnx/csrc/offline-model-config.h b/sherpa-onnx/csrc/offline-model-config.h index 9750642f..93ea7fd0 100644 --- a/sherpa-onnx/csrc/offline-model-config.h +++ b/sherpa-onnx/csrc/offline-model-config.h @@ -41,6 +41,9 @@ struct OfflineModelConfig { // All other values are invalid and lead to loading the model twice. 
std::string model_type; + std::string modeling_unit = "cjkchar"; + std::string bpe_vocab; + OfflineModelConfig() = default; OfflineModelConfig(const OfflineTransducerModelConfig &transducer, const OfflineParaformerModelConfig ¶former, @@ -50,7 +53,9 @@ struct OfflineModelConfig { const OfflineZipformerCtcModelConfig &zipformer_ctc, const OfflineWenetCtcModelConfig &wenet_ctc, const std::string &tokens, int32_t num_threads, bool debug, - const std::string &provider, const std::string &model_type) + const std::string &provider, const std::string &model_type, + const std::string &modeling_unit, + const std::string &bpe_vocab) : transducer(transducer), paraformer(paraformer), nemo_ctc(nemo_ctc), @@ -62,7 +67,9 @@ struct OfflineModelConfig { num_threads(num_threads), debug(debug), provider(provider), - model_type(model_type) {} + model_type(model_type), + modeling_unit(modeling_unit), + bpe_vocab(bpe_vocab) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index 68ec63a3..5051c8b6 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -31,6 +31,7 @@ #include "sherpa-onnx/csrc/pad-sequence.h" #include "sherpa-onnx/csrc/symbol-table.h" #include "sherpa-onnx/csrc/utils.h" +#include "ssentencepiece/csrc/ssentencepiece.h" namespace sherpa_onnx { @@ -76,9 +77,6 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { : config_(config), symbol_table_(config_.model_config.tokens), model_(std::make_unique(config_.model_config)) { - if (!config_.hotwords_file.empty()) { - InitHotwords(); - } if (config_.decoding_method == "greedy_search") { decoder_ = std::make_unique( model_.get(), config_.blank_penalty); @@ -87,6 +85,15 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { lm_ = OfflineLM::Create(config.lm_config); } + if 
(!config_.model_config.bpe_vocab.empty()) { + bpe_encoder_ = std::make_unique( + config_.model_config.bpe_vocab); + } + + if (!config_.hotwords_file.empty()) { + InitHotwords(); + } + decoder_ = std::make_unique( model_.get(), lm_.get(), config_.max_active_paths, config_.lm_config.scale, config_.blank_penalty); @@ -112,6 +119,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { lm_ = OfflineLM::Create(mgr, config.lm_config); } + if (!config_.model_config.bpe_vocab.empty()) { + auto buf = ReadFile(mgr, config_.model_config.bpe_vocab); + std::istringstream iss(std::string(buf.begin(), buf.end())); + bpe_encoder_ = std::make_unique(iss); + } + + if (!config_.hotwords_file.empty()) { + InitHotwords(mgr); + } + decoder_ = std::make_unique( model_.get(), lm_.get(), config_.max_active_paths, config_.lm_config.scale, config_.blank_penalty); @@ -128,7 +145,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); std::istringstream is(hws); std::vector> current; - if (!EncodeHotwords(is, symbol_table_, ¤t)) { + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, + bpe_encoder_.get(), ¤t)) { SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", hotwords.c_str()); } @@ -207,19 +225,47 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { exit(-1); } - if (!EncodeHotwords(is, symbol_table_, &hotwords_)) { - SHERPA_ONNX_LOGE("Encode hotwords failed."); - exit(-1); + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, + bpe_encoder_.get(), &hotwords_)) { + SHERPA_ONNX_LOGE( + "Failed to encode some hotwords, skip them already, see logs above " + "for details."); } hotwords_graph_ = std::make_shared(hotwords_, config_.hotwords_score); } +#if __ANDROID_API__ >= 9 + void InitHotwords(AAssetManager *mgr) { + // each line in hotwords_file contains space-separated words + + auto buf = ReadFile(mgr, 
config_.hotwords_file); + + std::istringstream is(std::string(buf.begin(), buf.end())); + + if (!is) { + SHERPA_ONNX_LOGE("Open hotwords file failed: %s", + config_.hotwords_file.c_str()); + exit(-1); + } + + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, + bpe_encoder_.get(), &hotwords_)) { + SHERPA_ONNX_LOGE( + "Failed to encode some hotwords, skip them already, see logs above " + "for details."); + } + hotwords_graph_ = + std::make_shared(hotwords_, config_.hotwords_score); + } +#endif + private: OfflineRecognizerConfig config_; SymbolTable symbol_table_; std::vector> hotwords_; ContextGraphPtr hotwords_graph_; + std::unique_ptr bpe_encoder_; std::unique_ptr model_; std::unique_ptr decoder_; std::unique_ptr lm_; diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc index 8005cc85..d6ba4905 100644 --- a/sherpa-onnx/csrc/offline-recognizer.cc +++ b/sherpa-onnx/csrc/offline-recognizer.cc @@ -37,10 +37,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { po->Register( "hotwords-file", &hotwords_file, - "The file containing hotwords, one words/phrases per line, and for each" - "phrase the bpe/cjkchar are separated by a space. For example: " - "▁HE LL O ▁WORLD" - "你 好 世 界"); + "The file containing hotwords, one words/phrases per line, For example: " + "HELLO WORLD" + "你好世界"); po->Register("hotwords-score", &hotwords_score, "The bonus score for each token in context word/phrase. " diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc index bc4d55dc..5ea24bab 100644 --- a/sherpa-onnx/csrc/online-model-config.cc +++ b/sherpa-onnx/csrc/online-model-config.cc @@ -32,6 +32,19 @@ void OnlineModelConfig::Register(ParseOptions *po) { po->Register("provider", &provider, "Specify a provider to use: cpu, cuda, coreml"); + po->Register("modeling-unit", &modeling_unit, + "The modeling unit of the model, commonly used units are bpe, " + "cjkchar, cjkchar+bpe, etc. 
Currently, it is needed only when " + "hotwords are provided, we need it to encode the hotwords into " + "token sequence."); + + po->Register("bpe-vocab", &bpe_vocab, + "The vocabulary generated by google's sentencepiece program. " + "It is a file has two columns, one is the token, the other is " + "the log probability, you can get it from the directory where " + "your bpe model is generated. Only used when hotwords provided " + "and the modeling unit is bpe or cjkchar+bpe"); + po->Register("model-type", &model_type, "Specify it to reduce model initialization time. " "Valid values are: conformer, lstm, zipformer, zipformer2, " @@ -50,6 +63,14 @@ bool OnlineModelConfig::Validate() const { return false; } + if (!modeling_unit.empty() && + (modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) { + if (!FileExists(bpe_vocab)) { + SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str()); + return false; + } + } + if (!paraformer.encoder.empty()) { return paraformer.Validate(); } @@ -83,7 +104,9 @@ std::string OnlineModelConfig::ToString() const { os << "warm_up=" << warm_up << ", "; os << "debug=" << (debug ? "True" : "False") << ", "; os << "provider=\"" << provider << "\", "; - os << "model_type=\"" << model_type << "\")"; + os << "model_type=\"" << model_type << "\", "; + os << "modeling_unit=\"" << modeling_unit << "\", "; + os << "bpe_vocab=\"" << bpe_vocab << "\")"; return os.str(); } diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h index 08acf773..1509bd5b 100644 --- a/sherpa-onnx/csrc/online-model-config.h +++ b/sherpa-onnx/csrc/online-model-config.h @@ -37,6 +37,13 @@ struct OnlineModelConfig { // All other values are invalid and lead to loading the model twice. 
std::string model_type; + // Valid values: + // - cjkchar + // - bpe + // - cjkchar+bpe + std::string modeling_unit = "cjkchar"; + std::string bpe_vocab; + OnlineModelConfig() = default; OnlineModelConfig(const OnlineTransducerModelConfig &transducer, const OnlineParaformerModelConfig ¶former, @@ -45,7 +52,9 @@ struct OnlineModelConfig { const OnlineNeMoCtcModelConfig &nemo_ctc, const std::string &tokens, int32_t num_threads, int32_t warm_up, bool debug, const std::string &provider, - const std::string &model_type) + const std::string &model_type, + const std::string &modeling_unit, + const std::string &bpe_vocab) : transducer(transducer), paraformer(paraformer), wenet_ctc(wenet_ctc), @@ -56,7 +65,9 @@ struct OnlineModelConfig { warm_up(warm_up), debug(debug), provider(provider), - model_type(model_type) {} + model_type(model_type), + modeling_unit(modeling_unit), + bpe_vocab(bpe_vocab) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index dcf52b99..402346fa 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -15,8 +15,6 @@ #include #if __ANDROID_API__ >= 9 -#include - #include "android/asset_manager.h" #include "android/asset_manager_jni.h" #endif @@ -33,6 +31,7 @@ #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/symbol-table.h" #include "sherpa-onnx/csrc/utils.h" +#include "ssentencepiece/csrc/ssentencepiece.h" namespace sherpa_onnx { @@ -94,6 +93,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { model_->SetFeatureDim(config.feat_config.feature_dim); if (config.decoding_method == "modified_beam_search") { + if (!config_.model_config.bpe_vocab.empty()) { + bpe_encoder_ = std::make_unique( + config_.model_config.bpe_vocab); + } + if (!config_.hotwords_file.empty()) { InitHotwords(); } @@ -140,6 +144,12 @@ class 
OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } #endif + if (!config_.model_config.bpe_vocab.empty()) { + auto buf = ReadFile(mgr, config_.model_config.bpe_vocab); + std::istringstream iss(std::string(buf.begin(), buf.end())); + bpe_encoder_ = std::make_unique(iss); + } + if (!config_.hotwords_file.empty()) { InitHotwords(mgr); } @@ -174,7 +184,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); std::istringstream is(hws); std::vector> current; - if (!EncodeHotwords(is, sym_, ¤t)) { + if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, + bpe_encoder_.get(), ¤t)) { SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", hotwords.c_str()); } @@ -363,9 +374,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { exit(-1); } - if (!EncodeHotwords(is, sym_, &hotwords_)) { - SHERPA_ONNX_LOGE("Encode hotwords failed."); - exit(-1); + if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, + bpe_encoder_.get(), &hotwords_)) { + SHERPA_ONNX_LOGE( + "Failed to encode some hotwords, skip them already, see logs above " + "for details."); } hotwords_graph_ = std::make_shared(hotwords_, config_.hotwords_score); @@ -377,7 +390,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { auto buf = ReadFile(mgr, config_.hotwords_file); - std::istrstream is(buf.data(), buf.size()); + std::istringstream is(std::string(buf.begin(), buf.end())); if (!is) { SHERPA_ONNX_LOGE("Open hotwords file failed: %s", @@ -385,9 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { exit(-1); } - if (!EncodeHotwords(is, sym_, &hotwords_)) { - SHERPA_ONNX_LOGE("Encode hotwords failed."); - exit(-1); + if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, + bpe_encoder_.get(), &hotwords_)) { + SHERPA_ONNX_LOGE( + "Failed to encode some hotwords, skip them already, see logs above " + "for 
details."); } hotwords_graph_ = std::make_shared(hotwords_, config_.hotwords_score); @@ -413,6 +428,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { OnlineRecognizerConfig config_; std::vector> hotwords_; ContextGraphPtr hotwords_graph_; + std::unique_ptr bpe_encoder_; std::unique_ptr model_; std::unique_ptr lm_; std::unique_ptr decoder_; diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index a7fdbdff..9004d3fb 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -51,9 +51,7 @@ std::string VecToString(const std::vector &vec, std::string OnlineRecognizerResult::AsJsonString() const { std::ostringstream os; os << "{ "; - os << "\"text\": " - << "\"" << text << "\"" - << ", "; + os << "\"text\": " << "\"" << text << "\"" << ", "; os << "\"tokens\": " << VecToString(tokens) << ", "; os << "\"timestamps\": " << VecToString(timestamps, 2) << ", "; os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", "; @@ -89,10 +87,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { "Used only when decoding_method is modified_beam_search"); po->Register( "hotwords-file", &hotwords_file, - "The file containing hotwords, one words/phrases per line, and for each" - "phrase the bpe/cjkchar are separated by a space. 
For example: " - "▁HE LL O ▁WORLD" - "你 好 世 界"); + "The file containing hotwords, one words/phrases per line, For example: " + "HELLO WORLD" + "你好世界"); po->Register("decoding-method", &decoding_method, "decoding method," "now support greedy_search and modified_beam_search."); diff --git a/sherpa-onnx/csrc/symbol-table.cc b/sherpa-onnx/csrc/symbol-table.cc index 1300919b..524b2689 100644 --- a/sherpa-onnx/csrc/symbol-table.cc +++ b/sherpa-onnx/csrc/symbol-table.cc @@ -38,35 +38,6 @@ void SymbolTable::Init(std::istream &is) { std::string sym; int32_t id; while (is >> sym >> id) { - if (sym.size() >= 3) { - // For BPE-based models, we replace ▁ with a space - // Unicode 9601, hex 0x2581, utf8 0xe29681 - const uint8_t *p = reinterpret_cast(sym.c_str()); - if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { - sym = sym.replace(0, 3, " "); - } - } - - // for byte-level BPE - // id 0 is blank, id 1 is sos/eos, id 2 is unk - if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' && - sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') { - std::ostringstream os; - os << std::hex << std::uppercase << (id - 3); - - if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) { - uint8_t i = id - 3; - sym = std::string(&i, &i + 1); - } - } - - assert(!sym.empty()); - - // for byte bpe, after replacing ▁ with a space, whose ascii is also 0x20, - // there is a conflict between the real byte 0x20 and ▁, so we disable - // the following check. - // - // Note: Only id2sym_ matters as we use it to convert ID to symbols. #if 0 // we disable the test here since for some multi-lingual BPE models // from NeMo, the same symbol can appear multiple times with different IDs. 
@@ -92,8 +63,30 @@ std::string SymbolTable::ToString() const { return os.str(); } -const std::string &SymbolTable::operator[](int32_t id) const { - return id2sym_.at(id); +const std::string SymbolTable::operator[](int32_t id) const { + std::string sym = id2sym_.at(id); + if (sym.size() >= 3) { + // For BPE-based models, we replace ▁ with a space + // Unicode 9601, hex 0x2581, utf8 0xe29681 + const uint8_t *p = reinterpret_cast(sym.c_str()); + if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { + sym = sym.replace(0, 3, " "); + } + } + + // for byte-level BPE + // id 0 is blank, id 1 is sos/eos, id 2 is unk + if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' && + sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') { + std::ostringstream os; + os << std::hex << std::uppercase << (id - 3); + + if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) { + uint8_t i = id - 3; + sym = std::string(&i, &i + 1); + } + } + return sym; } int32_t SymbolTable::operator[](const std::string &sym) const { diff --git a/sherpa-onnx/csrc/symbol-table.h b/sherpa-onnx/csrc/symbol-table.h index 8d0a4e98..00d7a69e 100644 --- a/sherpa-onnx/csrc/symbol-table.h +++ b/sherpa-onnx/csrc/symbol-table.h @@ -35,7 +35,7 @@ class SymbolTable { std::string ToString() const; /// Return the symbol corresponding to the given ID. - const std::string &operator[](int32_t id) const; + const std::string operator[](int32_t id) const; /// Return the ID corresponding to the given symbol. 
int32_t operator[](const std::string &sym) const; diff --git a/sherpa-onnx/csrc/text2token-test.cc b/sherpa-onnx/csrc/text2token-test.cc new file mode 100644 index 00000000..ef07797d --- /dev/null +++ b/sherpa-onnx/csrc/text2token-test.cc @@ -0,0 +1,152 @@ +// sherpa-onnx/csrc/text2token-test.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include +#include +#include + +#include "gtest/gtest.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/utils.h" +#include "ssentencepiece/csrc/ssentencepiece.h" + +namespace sherpa_onnx { + +// Please refer to +// https://github.com/pkufool/sherpa-test-data +// to download test data for testing +static const char dir[] = "/tmp/sherpa-test-data"; + +TEST(TEXT2TOKEN, TEST_cjkchar) { + std::ostringstream oss; + oss << dir << "/text2token/tokens_cn.txt"; + + std::string tokens = oss.str(); + + if (!std::ifstream(tokens).good()) { + SHERPA_ONNX_LOGE( + "No test data found, skipping TEST_cjkchar()." + "You can download the test data by: " + "git clone https://github.com/pkufool/sherpa-test-data.git " + "/tmp/sherpa-test-data"); + return; + } + + auto sym_table = SymbolTable(tokens); + + std::string text = "世界人民大团结\n中国 V S 美国"; + + std::istringstream iss(text); + + std::vector> ids; + + auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids); + + std::vector> expected_ids( + {{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}}); + EXPECT_EQ(ids, expected_ids); +} + +TEST(TEXT2TOKEN, TEST_bpe) { + std::ostringstream oss; + oss << dir << "/text2token/tokens_en.txt"; + std::string tokens = oss.str(); + oss.clear(); + oss.str(""); + oss << dir << "/text2token/bpe_en.vocab"; + std::string bpe = oss.str(); + if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) { + SHERPA_ONNX_LOGE( + "No test data found, skipping TEST_bpe()." 
+ "You can download the test data by: " + "git clone https://github.com/pkufool/sherpa-test-data.git " + "/tmp/sherpa-test-data"); + return; + } + + auto sym_table = SymbolTable(tokens); + auto bpe_processor = std::make_unique(bpe); + + std::string text = "HELLO WORLD\nI LOVE YOU"; + + std::istringstream iss(text); + + std::vector> ids; + + auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids); + + std::vector> expected_ids( + {{22, 58, 24, 425}, {19, 370, 47}}); + EXPECT_EQ(ids, expected_ids); +} + +TEST(TEXT2TOKEN, TEST_cjkchar_bpe) { + std::ostringstream oss; + oss << dir << "/text2token/tokens_mix.txt"; + std::string tokens = oss.str(); + oss.clear(); + oss.str(""); + oss << dir << "/text2token/bpe_mix.vocab"; + std::string bpe = oss.str(); + if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) { + SHERPA_ONNX_LOGE( + "No test data found, skipping TEST_cjkchar_bpe()." + "You can download the test data by: " + "git clone https://github.com/pkufool/sherpa-test-data.git " + "/tmp/sherpa-test-data"); + return; + } + + auto sym_table = SymbolTable(tokens); + auto bpe_processor = std::make_unique(bpe); + + std::string text = "世界人民 GOES TOGETHER\n中国 GOES WITH 美国"; + + std::istringstream iss(text); + + std::vector> ids; + + auto r = + EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), &ids); + + std::vector> expected_ids( + {{1368, 1392, 557, 680, 275, 178, 475}, + {685, 736, 275, 178, 179, 921, 736}}); + EXPECT_EQ(ids, expected_ids); +} + +TEST(TEXT2TOKEN, TEST_bbpe) { + std::ostringstream oss; + oss << dir << "/text2token/tokens_bbpe.txt"; + std::string tokens = oss.str(); + oss.clear(); + oss.str(""); + oss << dir << "/text2token/bbpe.vocab"; + std::string bpe = oss.str(); + if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) { + SHERPA_ONNX_LOGE( + "No test data found, skipping TEST_bbpe()." 
+ "You can download the test data by: " + "git clone https://github.com/pkufool/sherpa-test-data.git " + "/tmp/sherpa-test-data"); + return; + } + + auto sym_table = SymbolTable(tokens); + auto bpe_processor = std::make_unique(bpe); + + std::string text = "频繁\n李鞑靼"; + + std::istringstream iss(text); + + std::vector> ids; + + auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids); + + std::vector> expected_ids( + {{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}}); + EXPECT_EQ(ids, expected_ids); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/utils.cc b/sherpa-onnx/csrc/utils.cc index 657fcbf7..6363f03c 100644 --- a/sherpa-onnx/csrc/utils.cc +++ b/sherpa-onnx/csrc/utils.cc @@ -4,6 +4,7 @@ #include "sherpa-onnx/csrc/utils.h" +#include #include #include #include @@ -12,15 +13,16 @@ #include "sherpa-onnx/csrc/log.h" #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { -static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, +static bool EncodeBase(const std::vector &lines, + const SymbolTable &symbol_table, std::vector> *ids, std::vector *phrases, std::vector *scores, std::vector *thresholds) { - SHERPA_ONNX_CHECK(ids != nullptr); ids->clear(); std::vector tmp_ids; @@ -33,22 +35,15 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, bool has_scores = false; bool has_thresholds = false; bool has_phrases = false; + bool has_oov = false; - while (std::getline(is, line)) { + for (const auto &line : lines) { float score = 0; float threshold = 0; std::string phrase = ""; std::istringstream iss(line); while (iss >> word) { - if (word.size() >= 3) { - // For BPE-based models, we replace ▁ with a space - // Unicode 9601, hex 0x2581, utf8 0xe29681 - const uint8_t *p = reinterpret_cast(word.c_str()); - if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { - word = word.replace(0, 3, " "); - } - } if (symbol_table.Contains(word)) { 
int32_t id = symbol_table[word]; tmp_ids.push_back(id); @@ -71,7 +66,8 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, "Cannot find ID for token %s at line: %s. (Hint: words on " "the same line are separated by spaces)", word.c_str(), line.c_str()); - return false; + has_oov = true; + break; } } } @@ -101,12 +97,95 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, thresholds->clear(); } } - return true; + return !has_oov; } -bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table, +bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, + const SymbolTable &symbol_table, + const ssentencepiece::Ssentencepiece *bpe_encoder, std::vector> *hotwords) { - return EncodeBase(is, symbol_table, hotwords, nullptr, nullptr, nullptr); + std::vector lines; + std::string line; + std::string word; + + while (std::getline(is, line)) { + std::string score; + std::string phrase; + + std::ostringstream oss; + std::istringstream iss(line); + while (iss >> word) { + switch (word[0]) { + case ':': // boosting score for current keyword + score = word; + break; + default: + if (!score.empty()) { + SHERPA_ONNX_LOGE( + "Boosting score should be put after the words/phrase, given " + "%s.", + line.c_str()); + return false; + } + oss << " " << word; + break; + } + } + // Skip blank lines and lines holding only a score: substr(1) on an + // empty string would throw std::out_of_range. + if (oss.str().empty()) { + continue; + } + phrase = oss.str().substr(1); + std::istringstream piss(phrase); + oss.clear(); + oss.str(""); + while (piss >> word) { + if (modeling_unit == "cjkchar") { + for (const auto &w : SplitUtf8(word)) { + oss << " " << w; + } + } else if (modeling_unit == "bpe") { + std::vector bpes; + bpe_encoder->Encode(word, &bpes); + for (const auto &bpe : bpes) { + oss << " " << bpe; + } + } else { + if (modeling_unit != "cjkchar+bpe") { + SHERPA_ONNX_LOGE( + "modeling_unit should be one of bpe, cjkchar or cjkchar+bpe, " + "given " + "%s", + modeling_unit.c_str()); + exit(-1); + } + for (const auto &w : SplitUtf8(word)) { + // Cast first: passing a negative char (UTF-8 lead byte) to + // isalpha() is undefined behavior. + if (isalpha(static_cast<unsigned char>(w[0]))) { + std::vector
bpes; + bpe_encoder->Encode(w, &bpes); + for (const auto &bpe : bpes) { + oss << " " << bpe; + } + } else { + oss << " " << w; + } + } + } + } + // Encoders may yield nothing; guard substr(1) against an empty string. + std::string encoded_phrase = oss.str().substr(oss.str().empty() ? 0 : 1); + oss.clear(); + oss.str(""); + oss << encoded_phrase; + if (!score.empty()) { + oss << " " << score; + } + lines.push_back(oss.str()); + } + return EncodeBase(lines, symbol_table, hotwords, nullptr, nullptr, nullptr); } bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, @@ -114,7 +193,12 @@ bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, std::vector *keywords, std::vector *boost_scores, std::vector *threshold) { - return EncodeBase(is, symbol_table, keywords_id, keywords, boost_scores, + std::vector lines; + std::string line; + while (std::getline(is, line)) { + lines.push_back(line); + } + return EncodeBase(lines, symbol_table, keywords_id, keywords, boost_scores, threshold); } diff --git a/sherpa-onnx/csrc/utils.h b/sherpa-onnx/csrc/utils.h index 9842bbe0..a3189a20 100644 --- a/sherpa-onnx/csrc/utils.h +++ b/sherpa-onnx/csrc/utils.h @@ -8,6 +8,7 @@ #include #include "sherpa-onnx/csrc/symbol-table.h" +#include "ssentencepiece/csrc/ssentencepiece.h" namespace sherpa_onnx { @@ -25,7 +26,9 @@ namespace sherpa_onnx { * @return If all the symbols from ``is`` are in the symbol_table, returns true * otherwise returns false. */ -bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table, +bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, + const SymbolTable &symbol_table, + const ssentencepiece::Ssentencepiece *bpe_encoder_, std::vector> *hotwords_id); /* Encode the keywords in an input stream to be tokens ids.
diff --git a/sherpa-onnx/jni/offline-recognizer.cc b/sherpa-onnx/jni/offline-recognizer.cc index 16ab5167..9ad8defd 100644 --- a/sherpa-onnx/jni/offline-recognizer.cc +++ b/sherpa-onnx/jni/offline-recognizer.cc @@ -76,6 +76,18 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { ans.model_config.model_type = p; env->ReleaseStringUTFChars(s, p); + fid = env->GetFieldID(model_config_cls, "modelingUnit", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.modeling_unit = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.bpe_vocab = p; + env->ReleaseStringUTFChars(s, p); + // transducer fid = env->GetFieldID(model_config_cls, "transducer", "Lcom/k2fsa/sherpa/onnx/OfflineTransducerModelConfig;"); diff --git a/sherpa-onnx/jni/online-recognizer.cc b/sherpa-onnx/jni/online-recognizer.cc index ce0d562f..2c59c6df 100644 --- a/sherpa-onnx/jni/online-recognizer.cc +++ b/sherpa-onnx/jni/online-recognizer.cc @@ -195,6 +195,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { ans.model_config.model_type = p; env->ReleaseStringUTFChars(s, p); + fid = env->GetFieldID(model_config_cls, "modelingUnit", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.modeling_unit = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.bpe_vocab = p; + env->ReleaseStringUTFChars(s, p); + //---------- rnn lm model config ---------- fid = env->GetFieldID(cls, "lmConfig", "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;"); diff 
--git a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt index 3b468997..e7f72884 100644 --- a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt @@ -40,6 +40,8 @@ data class OfflineModelConfig( var provider: String = "cpu", var modelType: String = "", var tokens: String, + var modelingUnit: String = "", + var bpeVocab: String = "", ) data class OfflineRecognizerConfig( diff --git a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt index 02ad5e5e..e78fb654 100644 --- a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt @@ -43,6 +43,8 @@ data class OnlineModelConfig( var debug: Boolean = false, var provider: String = "cpu", var modelType: String = "", + var modelingUnit: String = "", + var bpeVocab: String = "", ) data class OnlineLMConfig( diff --git a/sherpa-onnx/python/csrc/offline-model-config.cc b/sherpa-onnx/python/csrc/offline-model-config.cc index f8a46a3c..3fc3b34c 100644 --- a/sherpa-onnx/python/csrc/offline-model-config.cc +++ b/sherpa-onnx/python/csrc/offline-model-config.cc @@ -36,7 +36,8 @@ void PybindOfflineModelConfig(py::module *m) { const OfflineTdnnModelConfig &, const OfflineZipformerCtcModelConfig &, const OfflineWenetCtcModelConfig &, const std::string &, - int32_t, bool, const std::string &, const std::string &>(), + int32_t, bool, const std::string &, const std::string &, + const std::string &, const std::string &>(), py::arg("transducer") = OfflineTransducerModelConfig(), py::arg("paraformer") = OfflineParaformerModelConfig(), py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), @@ -45,7 +46,8 @@ void PybindOfflineModelConfig(py::module *m) { py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, - py::arg("provider") = "cpu", py::arg("model_type") = 
"") + py::arg("provider") = "cpu", py::arg("model_type") = "", + py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "") .def_readwrite("transducer", &PyClass::transducer) .def_readwrite("paraformer", &PyClass::paraformer) .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) @@ -58,6 +60,8 @@ void PybindOfflineModelConfig(py::module *m) { .def_readwrite("debug", &PyClass::debug) .def_readwrite("provider", &PyClass::provider) .def_readwrite("model_type", &PyClass::model_type) + .def_readwrite("modeling_unit", &PyClass::modeling_unit) + .def_readwrite("bpe_vocab", &PyClass::bpe_vocab) .def("validate", &PyClass::Validate) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc index 7da0089c..d6db809b 100644 --- a/sherpa-onnx/python/csrc/online-model-config.cc +++ b/sherpa-onnx/python/csrc/online-model-config.cc @@ -32,6 +32,7 @@ void PybindOnlineModelConfig(py::module *m) { const OnlineZipformer2CtcModelConfig &, const OnlineNeMoCtcModelConfig &, const std::string &, int32_t, int32_t, bool, const std::string &, + const std::string &, const std::string &, const std::string &>(), py::arg("transducer") = OnlineTransducerModelConfig(), py::arg("paraformer") = OnlineParaformerModelConfig(), @@ -40,7 +41,8 @@ void PybindOnlineModelConfig(py::module *m) { py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0, py::arg("debug") = false, py::arg("provider") = "cpu", - py::arg("model_type") = "") + py::arg("model_type") = "", py::arg("modeling_unit") = "", + py::arg("bpe_vocab") = "") .def_readwrite("transducer", &PyClass::transducer) .def_readwrite("paraformer", &PyClass::paraformer) .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) @@ -51,6 +53,8 @@ void PybindOnlineModelConfig(py::module *m) { .def_readwrite("debug", &PyClass::debug) .def_readwrite("provider", &PyClass::provider) .def_readwrite("model_type", 
&PyClass::model_type) + .def_readwrite("modeling_unit", &PyClass::modeling_unit) + .def_readwrite("bpe_vocab", &PyClass::bpe_vocab) .def("validate", &PyClass::Validate) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index 9bf5d18b..87c5132d 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -49,6 +49,8 @@ class OfflineRecognizer(object): hotwords_file: str = "", hotwords_score: float = 1.5, blank_penalty: float = 0.0, + modeling_unit: str = "cjkchar", + bpe_vocab: str = "", debug: bool = False, provider: str = "cpu", model_type: str = "transducer", @@ -91,6 +93,16 @@ class OfflineRecognizer(object): hotwords_file is given with modified_beam_search as decoding method. blank_penalty: The penalty applied on blank symbol during decoding. + modeling_unit: + The modeling unit of the model; commonly used units are bpe, cjkchar, + cjkchar+bpe, etc. Currently it is needed only when hotwords are + provided, where it is used to encode the hotwords into token sequences. + bpe_vocab: + The vocabulary generated by Google's sentencepiece program. + It is a file with two columns: the token and its log probability. + You can find it in the directory where your bpe model is generated. + Only used when hotwords are provided and the modeling unit is + bpe or cjkchar+bpe. debug: True to show debug messages.
provider: @@ -107,6 +119,8 @@ class OfflineRecognizer(object): num_threads=num_threads, debug=debug, provider=provider, + modeling_unit=modeling_unit, + bpe_vocab=bpe_vocab, model_type=model_type, ) diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 36fb6682..97f7472b 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -58,6 +58,8 @@ class OnlineRecognizer(object): hotwords_file: str = "", provider: str = "cpu", model_type: str = "", + modeling_unit: str = "cjkchar", + bpe_vocab: str = "", lm: str = "", lm_scale: float = 0.1, temperature_scale: float = 2.0, @@ -136,6 +138,16 @@ class OnlineRecognizer(object): model_type: Online transducer model type. Valid values are: conformer, lstm, zipformer, zipformer2. All other values lead to loading the model twice. + modeling_unit: + The modeling unit of the model; commonly used units are bpe, cjkchar, + cjkchar+bpe, etc. Currently it is needed only when hotwords are + provided, where it is used to encode the hotwords into token sequences. + bpe_vocab: + The vocabulary generated by Google's sentencepiece program. + It is a file with two columns: the token and its log probability. + You can find it in the directory where your bpe model is generated. + Only used when hotwords are provided and the modeling unit is + bpe or cjkchar+bpe. """ self = cls.__new__(cls) _assert_file_exists(tokens) @@ -157,6 +169,8 @@ class OnlineRecognizer(object): num_threads=num_threads, provider=provider, model_type=model_type, + modeling_unit=modeling_unit, + bpe_vocab=bpe_vocab, debug=debug, )