2
.github/scripts/test-offline-ctc.sh
vendored
2
.github/scripts/test-offline-ctc.sh
vendored
@@ -8,6 +8,8 @@ log() {
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||
|
||||
echo "EXE is $EXE"
|
||||
echo "PATH: $PATH"
|
||||
|
||||
|
||||
2
.github/scripts/test-offline-transducer.sh
vendored
2
.github/scripts/test-offline-transducer.sh
vendored
@@ -8,6 +8,8 @@ log() {
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||
|
||||
echo "EXE is $EXE"
|
||||
echo "PATH: $PATH"
|
||||
|
||||
|
||||
2
.github/scripts/test-offline-tts.sh
vendored
2
.github/scripts/test-offline-tts.sh
vendored
@@ -8,6 +8,8 @@ log() {
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||
|
||||
echo "EXE is $EXE"
|
||||
echo "PATH: $PATH"
|
||||
|
||||
|
||||
2
.github/scripts/test-offline-whisper.sh
vendored
2
.github/scripts/test-offline-whisper.sh
vendored
@@ -8,6 +8,8 @@ log() {
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||
|
||||
echo "EXE is $EXE"
|
||||
echo "PATH: $PATH"
|
||||
|
||||
|
||||
2
.github/scripts/test-online-ctc.sh
vendored
2
.github/scripts/test-online-ctc.sh
vendored
@@ -8,6 +8,8 @@ log() {
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||
|
||||
echo "EXE is $EXE"
|
||||
echo "PATH: $PATH"
|
||||
|
||||
|
||||
2
.github/scripts/test-online-paraformer.sh
vendored
2
.github/scripts/test-online-paraformer.sh
vendored
@@ -8,6 +8,8 @@ log() {
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||
|
||||
echo "EXE is $EXE"
|
||||
echo "PATH: $PATH"
|
||||
|
||||
|
||||
2
.github/scripts/test-online-transducer.sh
vendored
2
.github/scripts/test-online-transducer.sh
vendored
@@ -8,6 +8,8 @@ log() {
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||
|
||||
echo "EXE is $EXE"
|
||||
echo "PATH: $PATH"
|
||||
|
||||
|
||||
2
.github/scripts/test-python.sh
vendored
2
.github/scripts/test-python.sh
vendored
@@ -8,6 +8,8 @@ log() {
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||
|
||||
log "test online NeMo CTC"
|
||||
|
||||
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2
|
||||
|
||||
@@ -8,6 +8,8 @@ log() {
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||
|
||||
echo "EXE is $EXE"
|
||||
echo "PATH: $PATH"
|
||||
|
||||
|
||||
@@ -234,6 +234,7 @@ endif()
|
||||
include(kaldi-native-fbank)
|
||||
include(kaldi-decoder)
|
||||
include(onnxruntime)
|
||||
include(simple-sentencepiece)
|
||||
set(ONNXRUNTIME_DIR ${onnxruntime_SOURCE_DIR})
|
||||
message(STATUS "ONNXRUNTIME_DIR: ${ONNXRUNTIME_DIR}")
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ echo "Generate xcframework"
|
||||
|
||||
mkdir -p "build/simulator/lib"
|
||||
for f in libkaldi-native-fbank-core.a libsherpa-onnx-c-api.a libsherpa-onnx-core.a \
|
||||
libsherpa-onnx-fst.a libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a; do
|
||||
libsherpa-onnx-fst.a libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a libssentencepiece_core.a; do
|
||||
lipo -create build/simulator_arm64/lib/${f} \
|
||||
build/simulator_x86_64/lib/${f} \
|
||||
-output build/simulator/lib/${f}
|
||||
@@ -140,7 +140,8 @@ libtool -static -o build/simulator/sherpa-onnx.a \
|
||||
build/simulator/lib/libsherpa-onnx-core.a \
|
||||
build/simulator/lib/libsherpa-onnx-fst.a \
|
||||
build/simulator/lib/libsherpa-onnx-kaldifst-core.a \
|
||||
build/simulator/lib/libkaldi-decoder-core.a
|
||||
build/simulator/lib/libkaldi-decoder-core.a \
|
||||
build/simulator/lib/libssentencepiece_core.a
|
||||
|
||||
libtool -static -o build/os64/sherpa-onnx.a \
|
||||
build/os64/lib/libkaldi-native-fbank-core.a \
|
||||
@@ -148,7 +149,8 @@ libtool -static -o build/os64/sherpa-onnx.a \
|
||||
build/os64/lib/libsherpa-onnx-core.a \
|
||||
build/os64/lib/libsherpa-onnx-fst.a \
|
||||
build/os64/lib/libsherpa-onnx-kaldifst-core.a \
|
||||
build/os64/lib/libkaldi-decoder-core.a
|
||||
build/os64/lib/libkaldi-decoder-core.a \
|
||||
build/os64/lib/libssentencepiece_core.a
|
||||
|
||||
rm -rf sherpa-onnx.xcframework
|
||||
|
||||
|
||||
@@ -129,7 +129,7 @@ echo "Generate xcframework"
|
||||
|
||||
mkdir -p "build/simulator/lib"
|
||||
for f in libkaldi-native-fbank-core.a libsherpa-onnx-c-api.a libsherpa-onnx-core.a \
|
||||
libsherpa-onnx-fstfar.a \
|
||||
libsherpa-onnx-fstfar.a libssentencepiece_core.a \
|
||||
libsherpa-onnx-fst.a libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a \
|
||||
libucd.a libpiper_phonemize.a libespeak-ng.a; do
|
||||
lipo -create build/simulator_arm64/lib/${f} \
|
||||
@@ -150,6 +150,7 @@ libtool -static -o build/simulator/sherpa-onnx.a \
|
||||
build/simulator/lib/libucd.a \
|
||||
build/simulator/lib/libpiper_phonemize.a \
|
||||
build/simulator/lib/libespeak-ng.a \
|
||||
build/simulator/lib/libssentencepiece_core.a
|
||||
|
||||
libtool -static -o build/os64/sherpa-onnx.a \
|
||||
build/os64/lib/libkaldi-native-fbank-core.a \
|
||||
@@ -162,6 +163,7 @@ libtool -static -o build/os64/sherpa-onnx.a \
|
||||
build/os64/lib/libucd.a \
|
||||
build/os64/lib/libpiper_phonemize.a \
|
||||
build/os64/lib/libespeak-ng.a \
|
||||
build/os64/lib/libssentencepiece_core.a
|
||||
|
||||
|
||||
rm -rf sherpa-onnx.xcframework
|
||||
|
||||
@@ -33,4 +33,5 @@ libtool -static -o ./install/lib/libsherpa-onnx.a \
|
||||
./install/lib/libkaldi-decoder-core.a \
|
||||
./install/lib/libucd.a \
|
||||
./install/lib/libpiper_phonemize.a \
|
||||
./install/lib/libespeak-ng.a
|
||||
./install/lib/libespeak-ng.a \
|
||||
./install/lib/libssentencepiece_core.a
|
||||
|
||||
63
cmake/simple-sentencepiece.cmake
Normal file
63
cmake/simple-sentencepiece.cmake
Normal file
@@ -0,0 +1,63 @@
|
||||
function(download_simple_sentencepiece)
|
||||
include(FetchContent)
|
||||
|
||||
set(simple-sentencepiece_URL "https://github.com/pkufool/simple-sentencepiece/archive/refs/tags/v0.7.tar.gz")
|
||||
set(simple-sentencepiece_URL2 "https://hub.nauu.cf/pkufool/simple-sentencepiece/archive/refs/tags/v0.7.tar.gz")
|
||||
set(simple-sentencepiece_HASH "SHA256=1748a822060a35baa9f6609f84efc8eb54dc0e74b9ece3d82367b7119fdc75af")
|
||||
|
||||
# If you don't have access to the Internet,
|
||||
# please pre-download simple-sentencepiece
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/simple-sentencepiece-0.7.tar.gz
|
||||
${CMAKE_SOURCE_DIR}/simple-sentencepiece-0.7.tar.gz
|
||||
${CMAKE_BINARY_DIR}/simple-sentencepiece-0.7.tar.gz
|
||||
/tmp/simple-sentencepiece-0.7.tar.gz
|
||||
/star-fj/fangjun/download/github/simple-sentencepiece-0.7.tar.gz
|
||||
)
|
||||
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
if(EXISTS ${f})
|
||||
set(simple-sentencepiece_URL "${f}")
|
||||
file(TO_CMAKE_PATH "${simple-sentencepiece_URL}" simple-sentencepiece_URL)
|
||||
message(STATUS "Found local downloaded simple-sentencepiece: ${simple-sentencepiece_URL}")
|
||||
set(simple-sentencepiece_URL2)
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(SBPE_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(SBPE_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
||||
|
||||
FetchContent_Declare(simple-sentencepiece
|
||||
URL
|
||||
${simple-sentencepiece_URL}
|
||||
${simple-sentencepiece_URL2}
|
||||
URL_HASH
|
||||
${simple-sentencepiece_HASH}
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(simple-sentencepiece)
|
||||
if(NOT simple-sentencepiece_POPULATED)
|
||||
message(STATUS "Downloading simple-sentencepiece ${simple-sentencepiece_URL}")
|
||||
FetchContent_Populate(simple-sentencepiece)
|
||||
endif()
|
||||
message(STATUS "simple-sentencepiece is downloaded to ${simple-sentencepiece_SOURCE_DIR}")
|
||||
add_subdirectory(${simple-sentencepiece_SOURCE_DIR} ${simple-sentencepiece_BINARY_DIR} EXCLUDE_FROM_ALL)
|
||||
|
||||
target_include_directories(ssentencepiece_core
|
||||
PUBLIC
|
||||
${simple-sentencepiece_SOURCE_DIR}/
|
||||
)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32)
|
||||
install(TARGETS ssentencepiece_core DESTINATION ..)
|
||||
else()
|
||||
install(TARGETS ssentencepiece_core DESTINATION lib)
|
||||
endif()
|
||||
|
||||
if(WIN32 AND BUILD_SHARED_LIBS)
|
||||
install(TARGETS ssentencepiece_core DESTINATION bin)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
download_simple_sentencepiece()
|
||||
@@ -60,7 +60,7 @@ function testSpeakerEmbeddingExtractor() {
|
||||
function testOnlineAsr() {
|
||||
if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then
|
||||
git lfs install
|
||||
git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
|
||||
GIT_CLONE_PROTECTION_ACTIVE=false git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
|
||||
fi
|
||||
|
||||
if [ ! -f ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/tokens.txt ]; then
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
piper_phonemize.lib;
|
||||
espeak-ng.lib;
|
||||
ucd.lib;
|
||||
ssentencepiece_core.lib;
|
||||
</SherpaOnnxLibraries>
|
||||
</PropertyGroup>
|
||||
<ItemDefinitionGroup>
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
piper_phonemize.lib;
|
||||
espeak-ng.lib;
|
||||
ucd.lib;
|
||||
ssentencepiece_core.lib;
|
||||
</SherpaOnnxLibraries>
|
||||
</PropertyGroup>
|
||||
<ItemDefinitionGroup>
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
piper_phonemize.lib;
|
||||
espeak-ng.lib;
|
||||
ucd.lib;
|
||||
ssentencepiece_core.lib;
|
||||
</SherpaOnnxLibraries>
|
||||
</PropertyGroup>
|
||||
<ItemDefinitionGroup>
|
||||
|
||||
@@ -110,11 +110,9 @@ def get_args():
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The file containing hotwords, one words/phrases per line, and for each
|
||||
phrase the bpe/cjkchar are separated by a space. For example:
|
||||
|
||||
▁HE LL O ▁WORLD
|
||||
你 好 世 界
|
||||
The file containing hotwords, one words/phrases per line, like
|
||||
HELLO WORLD
|
||||
你好世界
|
||||
""",
|
||||
)
|
||||
|
||||
@@ -128,6 +126,28 @@ def get_args():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--modeling-unit",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
|
||||
Used only when hotwords-file is given.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-vocab",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The path to the bpe vocabulary, the bpe vocabulary is generated by
|
||||
sentencepiece, you can also export the bpe vocabulary through a bpe model
|
||||
by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
|
||||
and modeling-unit is bpe or cjkchar+bpe.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder",
|
||||
default="",
|
||||
@@ -347,6 +367,8 @@ def main():
|
||||
decoding_method=args.decoding_method,
|
||||
hotwords_file=args.hotwords_file,
|
||||
hotwords_score=args.hotwords_score,
|
||||
modeling_unit=args.modeling_unit,
|
||||
bpe_vocab=args.bpe_vocab,
|
||||
blank_penalty=args.blank_penalty,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
@@ -198,11 +198,9 @@ def get_args():
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The file containing hotwords, one words/phrases per line, and for each
|
||||
phrase the bpe/cjkchar are separated by a space. For example:
|
||||
|
||||
▁HE LL O ▁WORLD
|
||||
你 好 世 界
|
||||
The file containing hotwords, one words/phrases per line, like
|
||||
HELLO WORLD
|
||||
你好世界
|
||||
""",
|
||||
)
|
||||
|
||||
@@ -216,6 +214,28 @@ def get_args():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--modeling-unit",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
|
||||
Used only when hotwords-file is given.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-vocab",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The path to the bpe vocabulary, the bpe vocabulary is generated by
|
||||
sentencepiece, you can also export the bpe vocabulary through a bpe model
|
||||
by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
|
||||
and modeling-unit is bpe or cjkchar+bpe.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--blank-penalty",
|
||||
type=float,
|
||||
@@ -302,6 +322,8 @@ def main():
|
||||
lm_scale=args.lm_scale,
|
||||
hotwords_file=args.hotwords_file,
|
||||
hotwords_score=args.hotwords_score,
|
||||
modeling_unit=args.modeling_unit,
|
||||
bpe_vocab=args.bpe_vocab,
|
||||
blank_penalty=args.blank_penalty,
|
||||
)
|
||||
elif args.zipformer2_ctc:
|
||||
|
||||
68
scripts/export_bpe_vocab.py
Executable file
68
scripts/export_bpe_vocab.py
Executable file
@@ -0,0 +1,68 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# You can install sentencepiece via:
|
||||
#
|
||||
# pip install sentencepiece
|
||||
#
|
||||
# Due to an issue reported in
|
||||
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
|
||||
#
|
||||
# Please install a version >=0.1.96
|
||||
|
||||
import argparse
|
||||
from typing import Dict
|
||||
|
||||
try:
|
||||
import sentencepiece as spm
|
||||
except ImportError:
|
||||
print('Please run')
|
||||
print(' pip install sentencepiece')
|
||||
print('before you continue')
|
||||
raise
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="The path to the bpe model.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
model_file = args.bpe_model
|
||||
|
||||
vocab_file = model_file.replace(".model", ".vocab")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.Load(model_file)
|
||||
vocabs = [sp.IdToPiece(id) for id in range(sp.GetPieceSize())]
|
||||
with open(vocab_file, "w") as vfile:
|
||||
for v in vocabs:
|
||||
id = sp.PieceToId(v)
|
||||
vfile.write(f"{v}\t{sp.GetScore(id)}\n")
|
||||
print(f"Vocabulary file is written to {vocab_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -165,6 +165,7 @@ endif()
|
||||
target_link_libraries(sherpa-onnx-core
|
||||
kaldi-native-fbank-core
|
||||
kaldi-decoder-core
|
||||
ssentencepiece_core
|
||||
)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_GPU)
|
||||
@@ -491,6 +492,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
|
||||
pad-sequence-test.cc
|
||||
slice-test.cc
|
||||
stack-test.cc
|
||||
text2token-test.cc
|
||||
transpose-test.cc
|
||||
unbind-test.cc
|
||||
utfcpp-test.cc
|
||||
|
||||
@@ -35,6 +35,17 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
||||
"Valid values are: transducer, paraformer, nemo_ctc, whisper, "
|
||||
"tdnn, zipformer2_ctc"
|
||||
"All other values lead to loading the model twice.");
|
||||
po->Register("modeling-unit", &modeling_unit,
|
||||
"The modeling unit of the model, commonly used units are bpe, "
|
||||
"cjkchar, cjkchar+bpe, etc. Currently, it is needed only when "
|
||||
"hotwords are provided, we need it to encode the hotwords into "
|
||||
"token sequence.");
|
||||
po->Register("bpe-vocab", &bpe_vocab,
|
||||
"The vocabulary generated by google's sentencepiece program. "
|
||||
"It is a file has two columns, one is the token, the other is "
|
||||
"the log probability, you can get it from the directory where "
|
||||
"your bpe model is generated. Only used when hotwords provided "
|
||||
"and the modeling unit is bpe or cjkchar+bpe");
|
||||
}
|
||||
|
||||
bool OfflineModelConfig::Validate() const {
|
||||
@@ -48,6 +59,14 @@ bool OfflineModelConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!modeling_unit.empty() &&
|
||||
(modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
|
||||
if (!FileExists(bpe_vocab)) {
|
||||
SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!paraformer.model.empty()) {
|
||||
return paraformer.Validate();
|
||||
}
|
||||
@@ -90,7 +109,9 @@ std::string OfflineModelConfig::ToString() const {
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\", ";
|
||||
os << "model_type=\"" << model_type << "\")";
|
||||
os << "model_type=\"" << model_type << "\", ";
|
||||
os << "modeling_unit=\"" << modeling_unit << "\", ";
|
||||
os << "bpe_vocab=\"" << bpe_vocab << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -41,6 +41,9 @@ struct OfflineModelConfig {
|
||||
// All other values are invalid and lead to loading the model twice.
|
||||
std::string model_type;
|
||||
|
||||
std::string modeling_unit = "cjkchar";
|
||||
std::string bpe_vocab;
|
||||
|
||||
OfflineModelConfig() = default;
|
||||
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
|
||||
const OfflineParaformerModelConfig ¶former,
|
||||
@@ -50,7 +53,9 @@ struct OfflineModelConfig {
|
||||
const OfflineZipformerCtcModelConfig &zipformer_ctc,
|
||||
const OfflineWenetCtcModelConfig &wenet_ctc,
|
||||
const std::string &tokens, int32_t num_threads, bool debug,
|
||||
const std::string &provider, const std::string &model_type)
|
||||
const std::string &provider, const std::string &model_type,
|
||||
const std::string &modeling_unit,
|
||||
const std::string &bpe_vocab)
|
||||
: transducer(transducer),
|
||||
paraformer(paraformer),
|
||||
nemo_ctc(nemo_ctc),
|
||||
@@ -62,7 +67,9 @@ struct OfflineModelConfig {
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
provider(provider),
|
||||
model_type(model_type) {}
|
||||
model_type(model_type),
|
||||
modeling_unit(modeling_unit),
|
||||
bpe_vocab(bpe_vocab) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -31,6 +31,7 @@
|
||||
#include "sherpa-onnx/csrc/pad-sequence.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
#include "ssentencepiece/csrc/ssentencepiece.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -76,9 +77,6 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
: config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords();
|
||||
}
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
|
||||
model_.get(), config_.blank_penalty);
|
||||
@@ -87,6 +85,15 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
lm_ = OfflineLM::Create(config.lm_config);
|
||||
}
|
||||
|
||||
if (!config_.model_config.bpe_vocab.empty()) {
|
||||
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
|
||||
config_.model_config.bpe_vocab);
|
||||
}
|
||||
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords();
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, config_.blank_penalty);
|
||||
@@ -112,6 +119,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
lm_ = OfflineLM::Create(mgr, config.lm_config);
|
||||
}
|
||||
|
||||
if (!config_.model_config.bpe_vocab.empty()) {
|
||||
auto buf = ReadFile(mgr, config_.model_config.bpe_vocab);
|
||||
std::istringstream iss(std::string(buf.begin(), buf.end()));
|
||||
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
|
||||
}
|
||||
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords(mgr);
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, config_.blank_penalty);
|
||||
@@ -128,7 +145,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
|
||||
std::istringstream is(hws);
|
||||
std::vector<std::vector<int32_t>> current;
|
||||
if (!EncodeHotwords(is, symbol_table_, ¤t)) {
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
|
||||
bpe_encoder_.get(), ¤t)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
|
||||
hotwords.c_str());
|
||||
}
|
||||
@@ -207,19 +225,47 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, symbol_table_, &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed.");
|
||||
exit(-1);
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
|
||||
bpe_encoder_.get(), &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
void InitHotwords(AAssetManager *mgr) {
|
||||
// each line in hotwords_file contains space-separated words
|
||||
|
||||
auto buf = ReadFile(mgr, config_.hotwords_file);
|
||||
|
||||
std::istringstream is(std::string(buf.begin(), buf.end()));
|
||||
|
||||
if (!is) {
|
||||
SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
|
||||
config_.hotwords_file.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
|
||||
bpe_encoder_.get(), &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
}
|
||||
#endif
|
||||
|
||||
private:
|
||||
OfflineRecognizerConfig config_;
|
||||
SymbolTable symbol_table_;
|
||||
std::vector<std::vector<int32_t>> hotwords_;
|
||||
ContextGraphPtr hotwords_graph_;
|
||||
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
|
||||
std::unique_ptr<OfflineTransducerModel> model_;
|
||||
std::unique_ptr<OfflineTransducerDecoder> decoder_;
|
||||
std::unique_ptr<OfflineLM> lm_;
|
||||
|
||||
@@ -37,10 +37,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
|
||||
|
||||
po->Register(
|
||||
"hotwords-file", &hotwords_file,
|
||||
"The file containing hotwords, one words/phrases per line, and for each"
|
||||
"phrase the bpe/cjkchar are separated by a space. For example: "
|
||||
"▁HE LL O ▁WORLD"
|
||||
"你 好 世 界");
|
||||
"The file containing hotwords, one words/phrases per line, For example: "
|
||||
"HELLO WORLD"
|
||||
"你好世界");
|
||||
|
||||
po->Register("hotwords-score", &hotwords_score,
|
||||
"The bonus score for each token in context word/phrase. "
|
||||
|
||||
@@ -32,6 +32,19 @@ void OnlineModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("provider", &provider,
|
||||
"Specify a provider to use: cpu, cuda, coreml");
|
||||
|
||||
po->Register("modeling-unit", &modeling_unit,
|
||||
"The modeling unit of the model, commonly used units are bpe, "
|
||||
"cjkchar, cjkchar+bpe, etc. Currently, it is needed only when "
|
||||
"hotwords are provided, we need it to encode the hotwords into "
|
||||
"token sequence.");
|
||||
|
||||
po->Register("bpe-vocab", &bpe_vocab,
|
||||
"The vocabulary generated by google's sentencepiece program. "
|
||||
"It is a file has two columns, one is the token, the other is "
|
||||
"the log probability, you can get it from the directory where "
|
||||
"your bpe model is generated. Only used when hotwords provided "
|
||||
"and the modeling unit is bpe or cjkchar+bpe");
|
||||
|
||||
po->Register("model-type", &model_type,
|
||||
"Specify it to reduce model initialization time. "
|
||||
"Valid values are: conformer, lstm, zipformer, zipformer2, "
|
||||
@@ -50,6 +63,14 @@ bool OnlineModelConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!modeling_unit.empty() &&
|
||||
(modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
|
||||
if (!FileExists(bpe_vocab)) {
|
||||
SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!paraformer.encoder.empty()) {
|
||||
return paraformer.Validate();
|
||||
}
|
||||
@@ -83,7 +104,9 @@ std::string OnlineModelConfig::ToString() const {
|
||||
os << "warm_up=" << warm_up << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\", ";
|
||||
os << "model_type=\"" << model_type << "\")";
|
||||
os << "model_type=\"" << model_type << "\", ";
|
||||
os << "modeling_unit=\"" << modeling_unit << "\", ";
|
||||
os << "bpe_vocab=\"" << bpe_vocab << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -37,6 +37,13 @@ struct OnlineModelConfig {
|
||||
// All other values are invalid and lead to loading the model twice.
|
||||
std::string model_type;
|
||||
|
||||
// Valid values:
|
||||
// - cjkchar
|
||||
// - bpe
|
||||
// - cjkchar+bpe
|
||||
std::string modeling_unit = "cjkchar";
|
||||
std::string bpe_vocab;
|
||||
|
||||
OnlineModelConfig() = default;
|
||||
OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
|
||||
const OnlineParaformerModelConfig ¶former,
|
||||
@@ -45,7 +52,9 @@ struct OnlineModelConfig {
|
||||
const OnlineNeMoCtcModelConfig &nemo_ctc,
|
||||
const std::string &tokens, int32_t num_threads,
|
||||
int32_t warm_up, bool debug, const std::string &provider,
|
||||
const std::string &model_type)
|
||||
const std::string &model_type,
|
||||
const std::string &modeling_unit,
|
||||
const std::string &bpe_vocab)
|
||||
: transducer(transducer),
|
||||
paraformer(paraformer),
|
||||
wenet_ctc(wenet_ctc),
|
||||
@@ -56,7 +65,9 @@ struct OnlineModelConfig {
|
||||
warm_up(warm_up),
|
||||
debug(debug),
|
||||
provider(provider),
|
||||
model_type(model_type) {}
|
||||
model_type(model_type),
|
||||
modeling_unit(modeling_unit),
|
||||
bpe_vocab(bpe_vocab) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include <strstream>
|
||||
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
@@ -33,6 +31,7 @@
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
#include "ssentencepiece/csrc/ssentencepiece.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -94,6 +93,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
model_->SetFeatureDim(config.feat_config.feature_dim);
|
||||
|
||||
if (config.decoding_method == "modified_beam_search") {
|
||||
if (!config_.model_config.bpe_vocab.empty()) {
|
||||
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
|
||||
config_.model_config.bpe_vocab);
|
||||
}
|
||||
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords();
|
||||
}
|
||||
@@ -140,6 +144,12 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
#endif
|
||||
|
||||
if (!config_.model_config.bpe_vocab.empty()) {
|
||||
auto buf = ReadFile(mgr, config_.model_config.bpe_vocab);
|
||||
std::istringstream iss(std::string(buf.begin(), buf.end()));
|
||||
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
|
||||
}
|
||||
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords(mgr);
|
||||
}
|
||||
@@ -174,7 +184,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
|
||||
std::istringstream is(hws);
|
||||
std::vector<std::vector<int32_t>> current;
|
||||
if (!EncodeHotwords(is, sym_, ¤t)) {
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
|
||||
bpe_encoder_.get(), ¤t)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
|
||||
hotwords.c_str());
|
||||
}
|
||||
@@ -363,9 +374,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, sym_, &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed.");
|
||||
exit(-1);
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
|
||||
bpe_encoder_.get(), &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
@@ -377,7 +390,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
|
||||
auto buf = ReadFile(mgr, config_.hotwords_file);
|
||||
|
||||
std::istrstream is(buf.data(), buf.size());
|
||||
std::istringstream is(std::string(buf.begin(), buf.end()));
|
||||
|
||||
if (!is) {
|
||||
SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
|
||||
@@ -385,9 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, sym_, &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed.");
|
||||
exit(-1);
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
|
||||
bpe_encoder_.get(), &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
@@ -413,6 +428,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
OnlineRecognizerConfig config_;
|
||||
std::vector<std::vector<int32_t>> hotwords_;
|
||||
ContextGraphPtr hotwords_graph_;
|
||||
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
|
||||
std::unique_ptr<OnlineTransducerModel> model_;
|
||||
std::unique_ptr<OnlineLM> lm_;
|
||||
std::unique_ptr<OnlineTransducerDecoder> decoder_;
|
||||
|
||||
@@ -51,9 +51,7 @@ std::string VecToString<std::string>(const std::vector<std::string> &vec,
|
||||
std::string OnlineRecognizerResult::AsJsonString() const {
|
||||
std::ostringstream os;
|
||||
os << "{ ";
|
||||
os << "\"text\": "
|
||||
<< "\"" << text << "\""
|
||||
<< ", ";
|
||||
os << "\"text\": " << "\"" << text << "\"" << ", ";
|
||||
os << "\"tokens\": " << VecToString(tokens) << ", ";
|
||||
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
|
||||
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
|
||||
@@ -89,10 +87,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
"Used only when decoding_method is modified_beam_search");
|
||||
po->Register(
|
||||
"hotwords-file", &hotwords_file,
|
||||
"The file containing hotwords, one words/phrases per line, and for each"
|
||||
"phrase the bpe/cjkchar are separated by a space. For example: "
|
||||
"▁HE LL O ▁WORLD"
|
||||
"你 好 世 界");
|
||||
"The file containing hotwords, one words/phrases per line, For example: "
|
||||
"HELLO WORLD"
|
||||
"你好世界");
|
||||
po->Register("decoding-method", &decoding_method,
|
||||
"decoding method,"
|
||||
"now support greedy_search and modified_beam_search.");
|
||||
|
||||
@@ -38,35 +38,6 @@ void SymbolTable::Init(std::istream &is) {
|
||||
std::string sym;
|
||||
int32_t id;
|
||||
while (is >> sym >> id) {
|
||||
if (sym.size() >= 3) {
|
||||
// For BPE-based models, we replace ▁ with a space
|
||||
// Unicode 9601, hex 0x2581, utf8 0xe29681
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
|
||||
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
|
||||
sym = sym.replace(0, 3, " ");
|
||||
}
|
||||
}
|
||||
|
||||
// for byte-level BPE
|
||||
// id 0 is blank, id 1 is sos/eos, id 2 is unk
|
||||
if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
|
||||
sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
|
||||
std::ostringstream os;
|
||||
os << std::hex << std::uppercase << (id - 3);
|
||||
|
||||
if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) {
|
||||
uint8_t i = id - 3;
|
||||
sym = std::string(&i, &i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
assert(!sym.empty());
|
||||
|
||||
// for byte bpe, after replacing ▁ with a space, whose ascii is also 0x20,
|
||||
// there is a conflict between the real byte 0x20 and ▁, so we disable
|
||||
// the following check.
|
||||
//
|
||||
// Note: Only id2sym_ matters as we use it to convert ID to symbols.
|
||||
#if 0
|
||||
// we disable the test here since for some multi-lingual BPE models
|
||||
// from NeMo, the same symbol can appear multiple times with different IDs.
|
||||
@@ -92,8 +63,30 @@ std::string SymbolTable::ToString() const {
|
||||
return os.str();
|
||||
}
|
||||
|
||||
const std::string &SymbolTable::operator[](int32_t id) const {
|
||||
return id2sym_.at(id);
|
||||
const std::string SymbolTable::operator[](int32_t id) const {
|
||||
std::string sym = id2sym_.at(id);
|
||||
if (sym.size() >= 3) {
|
||||
// For BPE-based models, we replace ▁ with a space
|
||||
// Unicode 9601, hex 0x2581, utf8 0xe29681
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
|
||||
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
|
||||
sym = sym.replace(0, 3, " ");
|
||||
}
|
||||
}
|
||||
|
||||
// for byte-level BPE
|
||||
// id 0 is blank, id 1 is sos/eos, id 2 is unk
|
||||
if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
|
||||
sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
|
||||
std::ostringstream os;
|
||||
os << std::hex << std::uppercase << (id - 3);
|
||||
|
||||
if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) {
|
||||
uint8_t i = id - 3;
|
||||
sym = std::string(&i, &i + 1);
|
||||
}
|
||||
}
|
||||
return sym;
|
||||
}
|
||||
|
||||
int32_t SymbolTable::operator[](const std::string &sym) const {
|
||||
|
||||
@@ -35,7 +35,7 @@ class SymbolTable {
|
||||
std::string ToString() const;
|
||||
|
||||
/// Return the symbol corresponding to the given ID.
|
||||
const std::string &operator[](int32_t id) const;
|
||||
const std::string operator[](int32_t id) const;
|
||||
/// Return the ID corresponding to the given symbol.
|
||||
int32_t operator[](const std::string &sym) const;
|
||||
|
||||
|
||||
152
sherpa-onnx/csrc/text2token-test.cc
Normal file
152
sherpa-onnx/csrc/text2token-test.cc
Normal file
@@ -0,0 +1,152 @@
|
||||
// sherpa-onnx/csrc/text2token-test.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
#include "ssentencepiece/csrc/ssentencepiece.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Please refer to
|
||||
// https://github.com/pkufool/sherpa-test-data
|
||||
// to download test data for testing
|
||||
static const char dir[] = "/tmp/sherpa-test-data";
|
||||
|
||||
TEST(TEXT2TOKEN, TEST_cjkchar) {
|
||||
std::ostringstream oss;
|
||||
oss << dir << "/text2token/tokens_cn.txt";
|
||||
|
||||
std::string tokens = oss.str();
|
||||
|
||||
if (!std::ifstream(tokens).good()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No test data found, skipping TEST_cjkchar()."
|
||||
"You can download the test data by: "
|
||||
"git clone https://github.com/pkufool/sherpa-test-data.git "
|
||||
"/tmp/sherpa-test-data");
|
||||
return;
|
||||
}
|
||||
|
||||
auto sym_table = SymbolTable(tokens);
|
||||
|
||||
std::string text = "世界人民大团结\n中国 V S 美国";
|
||||
|
||||
std::istringstream iss(text);
|
||||
|
||||
std::vector<std::vector<int32_t>> ids;
|
||||
|
||||
auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids);
|
||||
|
||||
std::vector<std::vector<int32_t>> expected_ids(
|
||||
{{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}});
|
||||
EXPECT_EQ(ids, expected_ids);
|
||||
}
|
||||
|
||||
TEST(TEXT2TOKEN, TEST_bpe) {
|
||||
std::ostringstream oss;
|
||||
oss << dir << "/text2token/tokens_en.txt";
|
||||
std::string tokens = oss.str();
|
||||
oss.clear();
|
||||
oss.str("");
|
||||
oss << dir << "/text2token/bpe_en.vocab";
|
||||
std::string bpe = oss.str();
|
||||
if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No test data found, skipping TEST_bpe()."
|
||||
"You can download the test data by: "
|
||||
"git clone https://github.com/pkufool/sherpa-test-data.git "
|
||||
"/tmp/sherpa-test-data");
|
||||
return;
|
||||
}
|
||||
|
||||
auto sym_table = SymbolTable(tokens);
|
||||
auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
|
||||
|
||||
std::string text = "HELLO WORLD\nI LOVE YOU";
|
||||
|
||||
std::istringstream iss(text);
|
||||
|
||||
std::vector<std::vector<int32_t>> ids;
|
||||
|
||||
auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
|
||||
|
||||
std::vector<std::vector<int32_t>> expected_ids(
|
||||
{{22, 58, 24, 425}, {19, 370, 47}});
|
||||
EXPECT_EQ(ids, expected_ids);
|
||||
}
|
||||
|
||||
TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
|
||||
std::ostringstream oss;
|
||||
oss << dir << "/text2token/tokens_mix.txt";
|
||||
std::string tokens = oss.str();
|
||||
oss.clear();
|
||||
oss.str("");
|
||||
oss << dir << "/text2token/bpe_mix.vocab";
|
||||
std::string bpe = oss.str();
|
||||
if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No test data found, skipping TEST_cjkchar_bpe()."
|
||||
"You can download the test data by: "
|
||||
"git clone https://github.com/pkufool/sherpa-test-data.git "
|
||||
"/tmp/sherpa-test-data");
|
||||
return;
|
||||
}
|
||||
|
||||
auto sym_table = SymbolTable(tokens);
|
||||
auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
|
||||
|
||||
std::string text = "世界人民 GOES TOGETHER\n中国 GOES WITH 美国";
|
||||
|
||||
std::istringstream iss(text);
|
||||
|
||||
std::vector<std::vector<int32_t>> ids;
|
||||
|
||||
auto r =
|
||||
EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), &ids);
|
||||
|
||||
std::vector<std::vector<int32_t>> expected_ids(
|
||||
{{1368, 1392, 557, 680, 275, 178, 475},
|
||||
{685, 736, 275, 178, 179, 921, 736}});
|
||||
EXPECT_EQ(ids, expected_ids);
|
||||
}
|
||||
|
||||
TEST(TEXT2TOKEN, TEST_bbpe) {
|
||||
std::ostringstream oss;
|
||||
oss << dir << "/text2token/tokens_bbpe.txt";
|
||||
std::string tokens = oss.str();
|
||||
oss.clear();
|
||||
oss.str("");
|
||||
oss << dir << "/text2token/bbpe.vocab";
|
||||
std::string bpe = oss.str();
|
||||
if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No test data found, skipping TEST_bbpe()."
|
||||
"You can download the test data by: "
|
||||
"git clone https://github.com/pkufool/sherpa-test-data.git "
|
||||
"/tmp/sherpa-test-data");
|
||||
return;
|
||||
}
|
||||
|
||||
auto sym_table = SymbolTable(tokens);
|
||||
auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
|
||||
|
||||
std::string text = "频繁\n李鞑靼";
|
||||
|
||||
std::istringstream iss(text);
|
||||
|
||||
std::vector<std::vector<int32_t>> ids;
|
||||
|
||||
auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
|
||||
|
||||
std::vector<std::vector<int32_t>> expected_ids(
|
||||
{{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}});
|
||||
EXPECT_EQ(ids, expected_ids);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@@ -12,15 +13,16 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||
static bool EncodeBase(const std::vector<std::string> &lines,
|
||||
const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *ids,
|
||||
std::vector<std::string> *phrases,
|
||||
std::vector<float> *scores,
|
||||
std::vector<float> *thresholds) {
|
||||
SHERPA_ONNX_CHECK(ids != nullptr);
|
||||
ids->clear();
|
||||
|
||||
std::vector<int32_t> tmp_ids;
|
||||
@@ -33,22 +35,15 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||
bool has_scores = false;
|
||||
bool has_thresholds = false;
|
||||
bool has_phrases = false;
|
||||
bool has_oov = false;
|
||||
|
||||
while (std::getline(is, line)) {
|
||||
for (const auto &line : lines) {
|
||||
float score = 0;
|
||||
float threshold = 0;
|
||||
std::string phrase = "";
|
||||
|
||||
std::istringstream iss(line);
|
||||
while (iss >> word) {
|
||||
if (word.size() >= 3) {
|
||||
// For BPE-based models, we replace ▁ with a space
|
||||
// Unicode 9601, hex 0x2581, utf8 0xe29681
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(word.c_str());
|
||||
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
|
||||
word = word.replace(0, 3, " ");
|
||||
}
|
||||
}
|
||||
if (symbol_table.Contains(word)) {
|
||||
int32_t id = symbol_table[word];
|
||||
tmp_ids.push_back(id);
|
||||
@@ -71,7 +66,8 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||
"Cannot find ID for token %s at line: %s. (Hint: words on "
|
||||
"the same line are separated by spaces)",
|
||||
word.c_str(), line.c_str());
|
||||
return false;
|
||||
has_oov = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -101,12 +97,87 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||
thresholds->clear();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return !has_oov;
|
||||
}
|
||||
|
||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
|
||||
const SymbolTable &symbol_table,
|
||||
const ssentencepiece::Ssentencepiece *bpe_encoder,
|
||||
std::vector<std::vector<int32_t>> *hotwords) {
|
||||
return EncodeBase(is, symbol_table, hotwords, nullptr, nullptr, nullptr);
|
||||
std::vector<std::string> lines;
|
||||
std::string line;
|
||||
std::string word;
|
||||
|
||||
while (std::getline(is, line)) {
|
||||
std::string score;
|
||||
std::string phrase;
|
||||
|
||||
std::ostringstream oss;
|
||||
std::istringstream iss(line);
|
||||
while (iss >> word) {
|
||||
switch (word[0]) {
|
||||
case ':': // boosting score for current keyword
|
||||
score = word;
|
||||
break;
|
||||
default:
|
||||
if (!score.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Boosting score should be put after the words/phrase, given "
|
||||
"%s.",
|
||||
line.c_str());
|
||||
return false;
|
||||
}
|
||||
oss << " " << word;
|
||||
break;
|
||||
}
|
||||
}
|
||||
phrase = oss.str().substr(1);
|
||||
std::istringstream piss(phrase);
|
||||
oss.clear();
|
||||
oss.str("");
|
||||
while (piss >> word) {
|
||||
if (modeling_unit == "cjkchar") {
|
||||
for (const auto &w : SplitUtf8(word)) {
|
||||
oss << " " << w;
|
||||
}
|
||||
} else if (modeling_unit == "bpe") {
|
||||
std::vector<std::string> bpes;
|
||||
bpe_encoder->Encode(word, &bpes);
|
||||
for (const auto &bpe : bpes) {
|
||||
oss << " " << bpe;
|
||||
}
|
||||
} else {
|
||||
if (modeling_unit != "cjkchar+bpe") {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"modeling_unit should be one of bpe, cjkchar or cjkchar+bpe, "
|
||||
"given "
|
||||
"%s",
|
||||
modeling_unit.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
for (const auto &w : SplitUtf8(word)) {
|
||||
if (isalpha(w[0])) {
|
||||
std::vector<std::string> bpes;
|
||||
bpe_encoder->Encode(w, &bpes);
|
||||
for (const auto &bpe : bpes) {
|
||||
oss << " " << bpe;
|
||||
}
|
||||
} else {
|
||||
oss << " " << w;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
std::string encoded_phrase = oss.str().substr(1);
|
||||
oss.clear();
|
||||
oss.str("");
|
||||
oss << encoded_phrase;
|
||||
if (!score.empty()) {
|
||||
oss << " " << score;
|
||||
}
|
||||
lines.push_back(oss.str());
|
||||
}
|
||||
return EncodeBase(lines, symbol_table, hotwords, nullptr, nullptr, nullptr);
|
||||
}
|
||||
|
||||
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
|
||||
@@ -114,7 +185,12 @@ bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::string> *keywords,
|
||||
std::vector<float> *boost_scores,
|
||||
std::vector<float> *threshold) {
|
||||
return EncodeBase(is, symbol_table, keywords_id, keywords, boost_scores,
|
||||
std::vector<std::string> lines;
|
||||
std::string line;
|
||||
while (std::getline(is, line)) {
|
||||
lines.push_back(line);
|
||||
}
|
||||
return EncodeBase(lines, symbol_table, keywords_id, keywords, boost_scores,
|
||||
threshold);
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "ssentencepiece/csrc/ssentencepiece.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -25,7 +26,9 @@ namespace sherpa_onnx {
|
||||
* @return If all the symbols from ``is`` are in the symbol_table, returns true
|
||||
* otherwise returns false.
|
||||
*/
|
||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
|
||||
const SymbolTable &symbol_table,
|
||||
const ssentencepiece::Ssentencepiece *bpe_encoder_,
|
||||
std::vector<std::vector<int32_t>> *hotwords_id);
|
||||
|
||||
/* Encode the keywords in an input stream to be tokens ids.
|
||||
|
||||
@@ -76,6 +76,18 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
|
||||
ans.model_config.model_type = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "modelingUnit", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.modeling_unit = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.bpe_vocab = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
// transducer
|
||||
fid = env->GetFieldID(model_config_cls, "transducer",
|
||||
"Lcom/k2fsa/sherpa/onnx/OfflineTransducerModelConfig;");
|
||||
|
||||
@@ -195,6 +195,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||
ans.model_config.model_type = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "modelingUnit", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.modeling_unit = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.bpe_vocab = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
//---------- rnn lm model config ----------
|
||||
fid = env->GetFieldID(cls, "lmConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");
|
||||
|
||||
@@ -40,6 +40,8 @@ data class OfflineModelConfig(
|
||||
var provider: String = "cpu",
|
||||
var modelType: String = "",
|
||||
var tokens: String,
|
||||
var modelingUnit: String = "",
|
||||
var bpeVocab: String = "",
|
||||
)
|
||||
|
||||
data class OfflineRecognizerConfig(
|
||||
|
||||
@@ -43,6 +43,8 @@ data class OnlineModelConfig(
|
||||
var debug: Boolean = false,
|
||||
var provider: String = "cpu",
|
||||
var modelType: String = "",
|
||||
var modelingUnit: String = "",
|
||||
var bpeVocab: String = "",
|
||||
)
|
||||
|
||||
data class OnlineLMConfig(
|
||||
|
||||
@@ -36,7 +36,8 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
const OfflineTdnnModelConfig &,
|
||||
const OfflineZipformerCtcModelConfig &,
|
||||
const OfflineWenetCtcModelConfig &, const std::string &,
|
||||
int32_t, bool, const std::string &, const std::string &>(),
|
||||
int32_t, bool, const std::string &, const std::string &,
|
||||
const std::string &, const std::string &>(),
|
||||
py::arg("transducer") = OfflineTransducerModelConfig(),
|
||||
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
||||
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
|
||||
@@ -45,7 +46,8 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
|
||||
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
|
||||
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
|
||||
py::arg("provider") = "cpu", py::arg("model_type") = "")
|
||||
py::arg("provider") = "cpu", py::arg("model_type") = "",
|
||||
py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
|
||||
.def_readwrite("transducer", &PyClass::transducer)
|
||||
.def_readwrite("paraformer", &PyClass::paraformer)
|
||||
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
|
||||
@@ -58,6 +60,8 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
.def_readwrite("debug", &PyClass::debug)
|
||||
.def_readwrite("provider", &PyClass::provider)
|
||||
.def_readwrite("model_type", &PyClass::model_type)
|
||||
.def_readwrite("modeling_unit", &PyClass::modeling_unit)
|
||||
.def_readwrite("bpe_vocab", &PyClass::bpe_vocab)
|
||||
.def("validate", &PyClass::Validate)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ void PybindOnlineModelConfig(py::module *m) {
|
||||
const OnlineZipformer2CtcModelConfig &,
|
||||
const OnlineNeMoCtcModelConfig &, const std::string &,
|
||||
int32_t, int32_t, bool, const std::string &,
|
||||
const std::string &, const std::string &,
|
||||
const std::string &>(),
|
||||
py::arg("transducer") = OnlineTransducerModelConfig(),
|
||||
py::arg("paraformer") = OnlineParaformerModelConfig(),
|
||||
@@ -40,7 +41,8 @@ void PybindOnlineModelConfig(py::module *m) {
|
||||
py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("tokens"),
|
||||
py::arg("num_threads"), py::arg("warm_up") = 0,
|
||||
py::arg("debug") = false, py::arg("provider") = "cpu",
|
||||
py::arg("model_type") = "")
|
||||
py::arg("model_type") = "", py::arg("modeling_unit") = "",
|
||||
py::arg("bpe_vocab") = "")
|
||||
.def_readwrite("transducer", &PyClass::transducer)
|
||||
.def_readwrite("paraformer", &PyClass::paraformer)
|
||||
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
|
||||
@@ -51,6 +53,8 @@ void PybindOnlineModelConfig(py::module *m) {
|
||||
.def_readwrite("debug", &PyClass::debug)
|
||||
.def_readwrite("provider", &PyClass::provider)
|
||||
.def_readwrite("model_type", &PyClass::model_type)
|
||||
.def_readwrite("modeling_unit", &PyClass::modeling_unit)
|
||||
.def_readwrite("bpe_vocab", &PyClass::bpe_vocab)
|
||||
.def("validate", &PyClass::Validate)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
@@ -49,6 +49,8 @@ class OfflineRecognizer(object):
|
||||
hotwords_file: str = "",
|
||||
hotwords_score: float = 1.5,
|
||||
blank_penalty: float = 0.0,
|
||||
modeling_unit: str = "cjkchar",
|
||||
bpe_vocab: str = "",
|
||||
debug: bool = False,
|
||||
provider: str = "cpu",
|
||||
model_type: str = "transducer",
|
||||
@@ -91,6 +93,16 @@ class OfflineRecognizer(object):
|
||||
hotwords_file is given with modified_beam_search as decoding method.
|
||||
blank_penalty:
|
||||
The penalty applied on blank symbol during decoding.
|
||||
modeling_unit:
|
||||
The modeling unit of the model, commonly used units are bpe, cjkchar,
|
||||
cjkchar+bpe, etc. Currently, it is needed only when hotwords are
|
||||
provided, we need it to encode the hotwords into token sequence.
|
||||
and the modeling unit is bpe or cjkchar+bpe.
|
||||
bpe_vocab:
|
||||
The vocabulary generated by google's sentencepiece program.
|
||||
It is a file has two columns, one is the token, the other is
|
||||
the log probability, you can get it from the directory where
|
||||
your bpe model is generated. Only used when hotwords provided
|
||||
debug:
|
||||
True to show debug messages.
|
||||
provider:
|
||||
@@ -107,6 +119,8 @@ class OfflineRecognizer(object):
|
||||
num_threads=num_threads,
|
||||
debug=debug,
|
||||
provider=provider,
|
||||
modeling_unit=modeling_unit,
|
||||
bpe_vocab=bpe_vocab,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
|
||||
@@ -58,6 +58,8 @@ class OnlineRecognizer(object):
|
||||
hotwords_file: str = "",
|
||||
provider: str = "cpu",
|
||||
model_type: str = "",
|
||||
modeling_unit: str = "cjkchar",
|
||||
bpe_vocab: str = "",
|
||||
lm: str = "",
|
||||
lm_scale: float = 0.1,
|
||||
temperature_scale: float = 2.0,
|
||||
@@ -136,6 +138,16 @@ class OnlineRecognizer(object):
|
||||
model_type:
|
||||
Online transducer model type. Valid values are: conformer, lstm,
|
||||
zipformer, zipformer2. All other values lead to loading the model twice.
|
||||
modeling_unit:
|
||||
The modeling unit of the model, commonly used units are bpe, cjkchar,
|
||||
cjkchar+bpe, etc. Currently, it is needed only when hotwords are
|
||||
provided, we need it to encode the hotwords into token sequence.
|
||||
bpe_vocab:
|
||||
The vocabulary generated by google's sentencepiece program.
|
||||
It is a file has two columns, one is the token, the other is
|
||||
the log probability, you can get it from the directory where
|
||||
your bpe model is generated. Only used when hotwords provided
|
||||
and the modeling unit is bpe or cjkchar+bpe.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
_assert_file_exists(tokens)
|
||||
@@ -157,6 +169,8 @@ class OnlineRecognizer(object):
|
||||
num_threads=num_threads,
|
||||
provider=provider,
|
||||
model_type=model_type,
|
||||
modeling_unit=modeling_unit,
|
||||
bpe_vocab=bpe_vocab,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user