Test int8 models (#107)
* Test int8 models * Fix displaying help messages * small fixes * Fix jni test
This commit is contained in:
2
.github/scripts/Main.kt
vendored
2
.github/scripts/Main.kt
vendored
@@ -35,7 +35,7 @@ fun main() {
|
||||
|
||||
var objArray = WaveReader.readWave(
|
||||
assetManager = AssetManager(),
|
||||
filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav",
|
||||
filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav",
|
||||
)
|
||||
var samples : FloatArray = objArray[0] as FloatArray
|
||||
var sampleRate : Int = objArray[1] as Int
|
||||
|
||||
33
.github/scripts/test-offline-transducer.sh
vendored
33
.github/scripts/test-offline-transducer.sh
vendored
@@ -25,6 +25,7 @@ log "Download pretrained model and test-data from $repo_url"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
ls -lh *.onnx
|
||||
popd
|
||||
|
||||
time $EXE \
|
||||
@@ -37,6 +38,16 @@ time $EXE \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \
|
||||
--decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \
|
||||
--joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \
|
||||
--num-threads=2 \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
@@ -51,6 +62,7 @@ log "Download pretrained model and test-data from $repo_url"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
ls -lh *.onnx
|
||||
popd
|
||||
|
||||
time $EXE \
|
||||
@@ -63,6 +75,16 @@ time $EXE \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \
|
||||
--decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \
|
||||
--joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \
|
||||
--num-threads=2 \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
@@ -77,6 +99,7 @@ log "Download pretrained model and test-data from $repo_url"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
ls -lh *.onnx
|
||||
popd
|
||||
|
||||
time $EXE \
|
||||
@@ -89,4 +112,14 @@ time $EXE \
|
||||
$repo/test_wavs/2.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--paraformer=$repo/model.int8.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/2.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
74
.github/scripts/test-online-transducer.sh
vendored
74
.github/scripts/test-online-transducer.sh
vendored
@@ -25,12 +25,13 @@ log "Download pretrained model and test-data from $repo_url"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
ls -lh *.onnx
|
||||
popd
|
||||
|
||||
waves=(
|
||||
$repo/test_wavs/1089-134686-0001.wav
|
||||
$repo/test_wavs/1221-135766-0001.wav
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
$repo/test_wavs/0.wav
|
||||
$repo/test_wavs/1.wav
|
||||
$repo/test_wavs/8k.wav
|
||||
)
|
||||
|
||||
for wave in ${waves[@]}; do
|
||||
@@ -43,6 +44,16 @@ for wave in ${waves[@]}; do
|
||||
2
|
||||
done
|
||||
|
||||
for wave in ${waves[@]}; do
|
||||
time $EXE \
|
||||
$repo/tokens.txt \
|
||||
$repo/encoder-epoch-99-avg-1.int8.onnx \
|
||||
$repo/decoder-epoch-99-avg-1.int8.onnx \
|
||||
$repo/joiner-epoch-99-avg-1.int8.onnx \
|
||||
$wave \
|
||||
2
|
||||
done
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
@@ -57,12 +68,13 @@ log "Download pretrained model and test-data from $repo_url"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
ls -lh *.onnx
|
||||
popd
|
||||
|
||||
waves=(
|
||||
$repo/test_wavs/0.wav
|
||||
$repo/test_wavs/1.wav
|
||||
$repo/test_wavs/2.wav
|
||||
$repo/test_wavs/8k.wav
|
||||
)
|
||||
|
||||
for wave in ${waves[@]}; do
|
||||
@@ -75,6 +87,16 @@ for wave in ${waves[@]}; do
|
||||
2
|
||||
done
|
||||
|
||||
for wave in ${waves[@]}; do
|
||||
time $EXE \
|
||||
$repo/tokens.txt \
|
||||
$repo/encoder-epoch-11-avg-1.int8.onnx \
|
||||
$repo/decoder-epoch-11-avg-1.int8.onnx \
|
||||
$repo/joiner-epoch-11-avg-1.int8.onnx \
|
||||
$wave \
|
||||
2
|
||||
done
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
@@ -89,12 +111,13 @@ log "Download pretrained model and test-data from $repo_url"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
ls -lh *.onnx
|
||||
popd
|
||||
|
||||
waves=(
|
||||
$repo/test_wavs/1089-134686-0001.wav
|
||||
$repo/test_wavs/1221-135766-0001.wav
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
$repo/test_wavs/0.wav
|
||||
$repo/test_wavs/1.wav
|
||||
$repo/test_wavs/8k.wav
|
||||
)
|
||||
|
||||
for wave in ${waves[@]}; do
|
||||
@@ -107,10 +130,22 @@ for wave in ${waves[@]}; do
|
||||
2
|
||||
done
|
||||
|
||||
# test int8
|
||||
#
|
||||
for wave in ${waves[@]}; do
|
||||
time $EXE \
|
||||
$repo/tokens.txt \
|
||||
$repo/encoder-epoch-99-avg-1.int8.onnx \
|
||||
$repo/decoder-epoch-99-avg-1.int8.onnx \
|
||||
$repo/joiner-epoch-99-avg-1.int8.onnx \
|
||||
$wave \
|
||||
2
|
||||
done
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
log "Run streaming Zipformer transducer (Bilingual, Chinse + English)"
|
||||
log "Run streaming Zipformer transducer (Bilingual, Chinese + English)"
|
||||
log "------------------------------------------------------------"
|
||||
|
||||
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
|
||||
@@ -121,6 +156,7 @@ log "Download pretrained model and test-data from $repo_url"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
ls -lh *.onnx
|
||||
popd
|
||||
|
||||
waves=(
|
||||
@@ -128,7 +164,7 @@ $repo/test_wavs/0.wav
|
||||
$repo/test_wavs/1.wav
|
||||
$repo/test_wavs/2.wav
|
||||
$repo/test_wavs/3.wav
|
||||
$repo/test_wavs/4.wav
|
||||
$repo/test_wavs/8k.wav
|
||||
)
|
||||
|
||||
for wave in ${waves[@]}; do
|
||||
@@ -141,6 +177,16 @@ for wave in ${waves[@]}; do
|
||||
2
|
||||
done
|
||||
|
||||
for wave in ${waves[@]}; do
|
||||
time $EXE \
|
||||
$repo/tokens.txt \
|
||||
$repo/encoder-epoch-99-avg-1.int8.onnx \
|
||||
$repo/decoder-epoch-99-avg-1.int8.onnx \
|
||||
$repo/joiner-epoch-99-avg-1.int8.onnx \
|
||||
$wave \
|
||||
2
|
||||
done
|
||||
|
||||
# Decode a URL
|
||||
if [ $EXE == "sherpa-onnx-ffmpeg" ]; then
|
||||
time $EXE \
|
||||
@@ -152,4 +198,14 @@ if [ $EXE == "sherpa-onnx-ffmpeg" ]; then
|
||||
2
|
||||
fi
|
||||
|
||||
if [ $EXE == "sherpa-onnx-ffmpeg" ]; then
|
||||
time $EXE \
|
||||
$repo/tokens.txt \
|
||||
$repo/encoder-epoch-99-avg-1.int8.onnx \
|
||||
$repo/decoder-epoch-99-avg-1.int8.onnx \
|
||||
$repo/joiner-epoch-99-avg-1.int8.onnx \
|
||||
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/resolve/main/test_wavs/4.wav \
|
||||
2
|
||||
fi
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -46,3 +46,8 @@ run-sherpa-onnx-offline-paraformer.sh
|
||||
run-sherpa-onnx-offline-transducer.sh
|
||||
sherpa-onnx-paraformer-zh-2023-03-28
|
||||
run-offline-websocket-server-paraformer.sh
|
||||
run-*int8.sh
|
||||
a.sh
|
||||
run-offline-websocket-client-*.sh
|
||||
run-sherpa-onnx-*.sh
|
||||
sherpa-onnx-zipformer-en-2023-03-30
|
||||
|
||||
@@ -18,139 +18,13 @@
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \
|
||||
_strtoi64(cur_cstr, end_cstr, 10);
|
||||
#else
|
||||
#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10);
|
||||
#endif
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/// Converts a string into an integer via strtoll and returns false if there was
|
||||
/// any kind of problem (i.e. the string was not an integer or contained extra
|
||||
/// non-whitespace junk, or the integer was too large to fit into the type it is
|
||||
/// being converted into). Only sets *out if everything was OK and it returns
|
||||
/// true.
|
||||
template <class Int>
|
||||
bool ConvertStringToInteger(const std::string &str, Int *out) {
|
||||
// copied from kaldi/src/util/text-util.h
|
||||
static_assert(std::is_integral<Int>::value, "");
|
||||
const char *this_str = str.c_str();
|
||||
char *end = nullptr;
|
||||
errno = 0;
|
||||
int64_t i = SHERPA_ONNX_STRTOLL(this_str, &end);
|
||||
if (end != this_str) {
|
||||
while (isspace(*end)) ++end;
|
||||
}
|
||||
if (end == this_str || *end != '\0' || errno != 0) return false;
|
||||
Int iInt = static_cast<Int>(i);
|
||||
if (static_cast<int64_t>(iInt) != i ||
|
||||
(i < 0 && !std::numeric_limits<Int>::is_signed)) {
|
||||
return false;
|
||||
}
|
||||
*out = iInt;
|
||||
return true;
|
||||
}
|
||||
|
||||
// copied from kaldi/src/util/text-util.cc
//
// Wrapper around a std::istream that extracts a single number of type T.
// In addition to the plain operator>> of the wrapped stream it
//   - fails when another token follows the number, and
//   - accepts the spellings of infinity and NaN listed in ParseOnFail()
//     (matched case-insensitively), which the plain extraction may refuse.
// Failure is reported by setting failbit on the wrapped stream.
template <class T>
class NumberIstream {
 public:
  explicit NumberIstream(std::istream &i) : in_(i) {}

  // Reads one number into x. Does nothing if the stream is not good().
  NumberIstream &operator>>(T &x) {
    if (!in_.good()) return *this;
    in_ >> x;
    if (!in_.fail() && RemainderIsOnlySpaces()) return *this;
    // The plain extraction failed or something follows the number; fall back
    // to matching the whole input against the inf/nan spellings.
    return ParseOnFail(&x);
  }

 private:
  std::istream &in_;

  // Returns true if no further token can be read from the stream. The stream
  // state is cleared before returning true, since the probe below (and
  // tellg() itself) can leave eofbit/failbit set.
  bool RemainderIsOnlySpaces() {
    if (in_.tellg() != std::istream::pos_type(-1)) {
      std::string rem;
      in_ >> rem;

      if (rem.find_first_not_of(' ') != std::string::npos) {
        // something other than spaces is left
        return false;
      }
    }

    in_.clear();
    return true;
  }

  // Re-reads the input from position 0 as a single token and, if it is one of
  // the known inf/nan spellings, stores the matching value in *x. Otherwise
  // sets failbit.
  NumberIstream &ParseOnFail(T *x) {
    std::string str;
    in_.clear();
    in_.seekg(0);
    // If the stream is broken even before trying
    // to read from it or if there are many tokens,
    // it's pointless to try.
    if (!(in_ >> str) || !RemainderIsOnlySpaces()) {
      in_.setstate(std::ios_base::failbit);
      return *this;
    }

    // Keys are uppercase only; the token is upper-cased before the lookup.
    // The table is built once per T instead of on every failed parse.
    static const std::unordered_map<std::string, T> inf_nan_map = {
        {"INF", std::numeric_limits<T>::infinity()},
        {"+INF", std::numeric_limits<T>::infinity()},
        {"-INF", -std::numeric_limits<T>::infinity()},
        {"INFINITY", std::numeric_limits<T>::infinity()},
        {"+INFINITY", std::numeric_limits<T>::infinity()},
        {"-INFINITY", -std::numeric_limits<T>::infinity()},
        {"NAN", std::numeric_limits<T>::quiet_NaN()},
        {"+NAN", std::numeric_limits<T>::quiet_NaN()},
        {"-NAN", -std::numeric_limits<T>::quiet_NaN()},
        // MSVC
        {"1.#INF", std::numeric_limits<T>::infinity()},
        {"-1.#INF", -std::numeric_limits<T>::infinity()},
        {"1.#QNAN", std::numeric_limits<T>::quiet_NaN()},
        {"-1.#QNAN", -std::numeric_limits<T>::quiet_NaN()},
    };

    // toupper() must not be given a negative value, so go through
    // unsigned char rather than passing plain char.
    std::transform(
        str.begin(), str.end(), str.begin(),
        [](unsigned char c) { return static_cast<char>(::toupper(c)); });

    const auto it = inf_nan_map.find(str);
    if (it != inf_nan_map.end()) {
      *x = it->second;
    } else {
      in_.setstate(std::ios_base::failbit);
    }

    return *this;
  }
};
|
||||
|
||||
/// ConvertStringToReal converts a string into either float or double
|
||||
/// and returns false if there was any kind of problem (i.e. the string
|
||||
/// was not a floating point number or contained extra non-whitespace junk).
|
||||
/// Be careful- this function will successfully read inf's or nan's.
|
||||
template <typename T>
|
||||
bool ConvertStringToReal(const std::string &str, T *out) {
|
||||
std::istringstream iss(str);
|
||||
|
||||
NumberIstream<T> i(iss);
|
||||
|
||||
i >> *out;
|
||||
|
||||
if (iss.fail()) {
|
||||
// Number conversion failed.
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
ParseOptions::ParseOptions(const std::string &prefix, ParseOptions *po)
|
||||
: print_args_(false), help_(false), usage_(""), argc_(0), argv_(nullptr) {
|
||||
if (po != nullptr && po->other_parser_ != nullptr) {
|
||||
@@ -219,8 +93,8 @@ void ParseOptions::RegisterCommon(const std::string &name, T *ptr,
|
||||
std::string idx = name;
|
||||
NormalizeArgName(&idx);
|
||||
if (doc_map_.find(idx) != doc_map_.end()) {
|
||||
SHERPA_ONNX_LOG(WARNING)
|
||||
<< "Registering option twice, ignoring second time: " << name;
|
||||
SHERPA_ONNX_LOGE("Registering option twice, ignoring second time: %s",
|
||||
name.c_str());
|
||||
} else {
|
||||
this->RegisterSpecific(name, idx, ptr, doc, is_standard);
|
||||
}
|
||||
@@ -289,12 +163,13 @@ void ParseOptions::RegisterSpecific(const std::string &name,
|
||||
|
||||
void ParseOptions::DisableOption(const std::string &name) {
|
||||
if (argv_ != nullptr) {
|
||||
SHERPA_ONNX_LOG(FATAL)
|
||||
<< "DisableOption must not be called after calling Read().";
|
||||
SHERPA_ONNX_LOGE("DisableOption must not be called after calling Read().");
|
||||
exit(-1);
|
||||
}
|
||||
if (doc_map_.erase(name) == 0) {
|
||||
SHERPA_ONNX_LOG(FATAL) << "Option " << name
|
||||
<< " was not registered so cannot be disabled: ";
|
||||
SHERPA_ONNX_LOGE("Option %s was not registered so cannot be disabled: ",
|
||||
name.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
bool_map_.erase(name);
|
||||
int_map_.erase(name);
|
||||
@@ -308,7 +183,8 @@ int ParseOptions::NumArgs() const { return positional_args_.size(); }
|
||||
|
||||
std::string ParseOptions::GetArg(int i) const {
|
||||
if (i < 1 || i > static_cast<int>(positional_args_.size())) {
|
||||
SHERPA_ONNX_LOG(FATAL) << "ParseOptions::GetArg, invalid index " << i;
|
||||
SHERPA_ONNX_LOGE("ParseOptions::GetArg, invalid index %d", i);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return positional_args_[i - 1];
|
||||
@@ -460,7 +336,8 @@ int ParseOptions::Read(int argc, const char *const argv[]) {
|
||||
Trim(&value);
|
||||
if (!SetOption(key, value, has_equal_sign)) {
|
||||
PrintUsage(true);
|
||||
SHERPA_ONNX_LOG(FATAL) << "Invalid option " << argv[i];
|
||||
SHERPA_ONNX_LOGE("Invalid option %s", argv[i]);
|
||||
exit(-1);
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
@@ -481,7 +358,7 @@ int ParseOptions::Read(int argc, const char *const argv[]) {
|
||||
std::ostringstream strm;
|
||||
for (int j = 0; j < argc; ++j) strm << Escape(argv[j]) << " ";
|
||||
strm << '\n';
|
||||
SHERPA_ONNX_LOG(INFO) << strm.str();
|
||||
SHERPA_ONNX_LOGE("%s", strm.str().c_str());
|
||||
}
|
||||
return i;
|
||||
}
|
||||
@@ -522,7 +399,7 @@ void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const {
|
||||
os << strm.str();
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOG(INFO) << os.str();
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
}
|
||||
|
||||
void ParseOptions::PrintConfig(std::ostream &os) const {
|
||||
@@ -544,8 +421,9 @@ void ParseOptions::PrintConfig(std::ostream &os) const {
|
||||
} else if (string_map_.end() != string_map_.find(key)) {
|
||||
os << "'" << *string_map_.at(key) << "'";
|
||||
} else {
|
||||
SHERPA_ONNX_LOG(FATAL)
|
||||
<< "PrintConfig: unrecognized option " << key << "[code error]";
|
||||
SHERPA_ONNX_LOGE("PrintConfig: unrecognized option %s [code error]",
|
||||
key.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
os << '\n';
|
||||
}
|
||||
@@ -555,7 +433,8 @@ void ParseOptions::PrintConfig(std::ostream &os) const {
|
||||
void ParseOptions::ReadConfigFile(const std::string &filename) {
|
||||
std::ifstream is(filename.c_str(), std::ifstream::in);
|
||||
if (!is.good()) {
|
||||
SHERPA_ONNX_LOG(FATAL) << "Cannot open config file: " << filename;
|
||||
SHERPA_ONNX_LOGE("Cannot open config file: %s", filename.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
std::string line, key, value;
|
||||
@@ -572,12 +451,13 @@ void ParseOptions::ReadConfigFile(const std::string &filename) {
|
||||
if (line.length() == 0) continue;
|
||||
|
||||
if (line.substr(0, 2) != "--") {
|
||||
SHERPA_ONNX_LOG(FATAL)
|
||||
<< "Reading config file " << filename << ": line " << line_number
|
||||
<< " does not look like a line "
|
||||
<< "from a Kaldi command-line program's config file: should "
|
||||
<< "be of the form --x=y. Note: config files intended to "
|
||||
<< "be sourced by shell scripts lack the '--'.";
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Reading config file %s: line %d does not look like a line "
|
||||
"from a sherpa-onnx command-line program's config file: should "
|
||||
"be of the form --x=y. Note: config files intended to "
|
||||
"be sourced by shell scripts lack the '--'.",
|
||||
filename.c_str(), line_number);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// parse option
|
||||
@@ -587,8 +467,9 @@ void ParseOptions::ReadConfigFile(const std::string &filename) {
|
||||
Trim(&value);
|
||||
if (!SetOption(key, value, has_equal_sign)) {
|
||||
PrintUsage(true);
|
||||
SHERPA_ONNX_LOG(FATAL) << "Invalid option " << line << " in config file "
|
||||
<< filename << ": line " << line_number;
|
||||
SHERPA_ONNX_LOGE("Invalid option %s in config file %s: line %d",
|
||||
line.c_str(), filename.c_str(), line_number);
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -605,7 +486,8 @@ void ParseOptions::SplitLongArg(const std::string &in, std::string *key,
|
||||
*has_equal_sign = false;
|
||||
} else if (pos == 2) { // we also don't allow empty keys: --=value
|
||||
PrintUsage(true);
|
||||
SHERPA_ONNX_LOG(FATAL) << "Invalid option (no key): " << in;
|
||||
SHERPA_ONNX_LOGE("Invalid option (no key): %s", in.c_str());
|
||||
exit(-1);
|
||||
} else { // normal case: --option=value
|
||||
*key = in.substr(2, pos - 2); // 2 because starts with --.
|
||||
*value = in.substr(pos + 1);
|
||||
@@ -646,7 +528,8 @@ bool ParseOptions::SetOption(const std::string &key, const std::string &value,
|
||||
bool has_equal_sign) {
|
||||
if (bool_map_.end() != bool_map_.find(key)) {
|
||||
if (has_equal_sign && value == "") {
|
||||
SHERPA_ONNX_LOG(FATAL) << "Invalid option --" << key << "=";
|
||||
SHERPA_ONNX_LOGE("Invalid option --%s=", key.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
*(bool_map_[key]) = ToBool(value);
|
||||
} else if (int_map_.end() != int_map_.find(key)) {
|
||||
@@ -659,8 +542,9 @@ bool ParseOptions::SetOption(const std::string &key, const std::string &value,
|
||||
*(double_map_[key]) = ToDouble(value);
|
||||
} else if (string_map_.end() != string_map_.find(key)) {
|
||||
if (!has_equal_sign) {
|
||||
SHERPA_ONNX_LOG(FATAL)
|
||||
<< "Invalid option --" << key << " (option format is --x=y).";
|
||||
SHERPA_ONNX_LOGE("Invalid option --%s (option format is --x=y).",
|
||||
key.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
*(string_map_[key]) = value;
|
||||
} else {
|
||||
@@ -683,37 +567,46 @@ bool ParseOptions::ToBool(std::string str) const {
|
||||
}
|
||||
// if it is neither true nor false:
|
||||
PrintUsage(true);
|
||||
SHERPA_ONNX_LOG(FATAL)
|
||||
<< "Invalid format for boolean argument [expected true or false]: "
|
||||
<< str;
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Invalid format for boolean argument [expected true or false]: %s",
|
||||
str.c_str());
|
||||
exit(-1);
|
||||
return false; // never reached
|
||||
}
|
||||
|
||||
int32_t ParseOptions::ToInt(const std::string &str) const {
|
||||
int32_t ret = 0;
|
||||
if (!ConvertStringToInteger(str, &ret))
|
||||
SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\"";
|
||||
if (!ConvertStringToInteger(str, &ret)) {
|
||||
SHERPA_ONNX_LOGE("Invalid integer option \"%s\"", str.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
uint32_t ParseOptions::ToUint(const std::string &str) const {
|
||||
uint32_t ret = 0;
|
||||
if (!ConvertStringToInteger(str, &ret))
|
||||
SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\"";
|
||||
if (!ConvertStringToInteger(str, &ret)) {
|
||||
SHERPA_ONNX_LOGE("Invalid integer option \"%s\"", str.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
float ParseOptions::ToFloat(const std::string &str) const {
|
||||
float ret;
|
||||
if (!ConvertStringToReal(str, &ret))
|
||||
SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\"";
|
||||
if (!ConvertStringToReal(str, &ret)) {
|
||||
SHERPA_ONNX_LOGE("Invalid floating-point option \"%s\"", str.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
double ParseOptions::ToDouble(const std::string &str) const {
|
||||
double ret;
|
||||
if (!ConvertStringToReal(str, &ret))
|
||||
SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\"";
|
||||
if (!ConvertStringToReal(str, &ret)) {
|
||||
SHERPA_ONNX_LOGE("Invalid floating-point option \"%s\"", str.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,11 @@
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
// This file is copied/modified from
|
||||
@@ -15,6 +19,102 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// copied from kaldi/src/util/text-util.cc
//
// Wrapper around a std::istream that extracts a single number of type T.
// In addition to the plain operator>> of the wrapped stream it
//   - fails when another token follows the number, and
//   - accepts the spellings of infinity and NaN listed in ParseOnFail()
//     (matched case-insensitively), which the plain extraction may refuse.
// Failure is reported by setting failbit on the wrapped stream.
template <class T>
class NumberIstream {
 public:
  explicit NumberIstream(std::istream &i) : in_(i) {}

  // Reads one number into x. Does nothing if the stream is not good().
  NumberIstream &operator>>(T &x) {
    if (!in_.good()) return *this;
    in_ >> x;
    if (!in_.fail() && RemainderIsOnlySpaces()) return *this;
    // The plain extraction failed or something follows the number; fall back
    // to matching the whole input against the inf/nan spellings.
    return ParseOnFail(&x);
  }

 private:
  std::istream &in_;

  // Returns true if no further token can be read from the stream. The stream
  // state is cleared before returning true, since the probe below (and
  // tellg() itself) can leave eofbit/failbit set.
  bool RemainderIsOnlySpaces() {
    if (in_.tellg() != std::istream::pos_type(-1)) {
      std::string rem;
      in_ >> rem;

      if (rem.find_first_not_of(' ') != std::string::npos) {
        // something other than spaces is left
        return false;
      }
    }

    in_.clear();
    return true;
  }

  // Re-reads the input from position 0 as a single token and, if it is one of
  // the known inf/nan spellings, stores the matching value in *x. Otherwise
  // sets failbit.
  NumberIstream &ParseOnFail(T *x) {
    std::string str;
    in_.clear();
    in_.seekg(0);
    // If the stream is broken even before trying
    // to read from it or if there are many tokens,
    // it's pointless to try.
    if (!(in_ >> str) || !RemainderIsOnlySpaces()) {
      in_.setstate(std::ios_base::failbit);
      return *this;
    }

    // Keys are uppercase only; the token is upper-cased before the lookup.
    // The table is built once per T instead of on every failed parse.
    static const std::unordered_map<std::string, T> inf_nan_map = {
        {"INF", std::numeric_limits<T>::infinity()},
        {"+INF", std::numeric_limits<T>::infinity()},
        {"-INF", -std::numeric_limits<T>::infinity()},
        {"INFINITY", std::numeric_limits<T>::infinity()},
        {"+INFINITY", std::numeric_limits<T>::infinity()},
        {"-INFINITY", -std::numeric_limits<T>::infinity()},
        {"NAN", std::numeric_limits<T>::quiet_NaN()},
        {"+NAN", std::numeric_limits<T>::quiet_NaN()},
        {"-NAN", -std::numeric_limits<T>::quiet_NaN()},
        // MSVC
        {"1.#INF", std::numeric_limits<T>::infinity()},
        {"-1.#INF", -std::numeric_limits<T>::infinity()},
        {"1.#QNAN", std::numeric_limits<T>::quiet_NaN()},
        {"-1.#QNAN", -std::numeric_limits<T>::quiet_NaN()},
    };

    // toupper() must not be given a negative value, so go through
    // unsigned char rather than passing plain char.
    std::transform(
        str.begin(), str.end(), str.begin(),
        [](unsigned char c) { return static_cast<char>(::toupper(c)); });

    const auto it = inf_nan_map.find(str);
    if (it != inf_nan_map.end()) {
      *x = it->second;
    } else {
      in_.setstate(std::ios_base::failbit);
    }

    return *this;
  }
};
|
||||
|
||||
/// ConvertStringToReal converts a string into either float or double
|
||||
/// and returns false if there was any kind of problem (i.e. the string
|
||||
/// was not a floating point number or contained extra non-whitespace junk).
|
||||
/// Be careful- this function will successfully read inf's or nan's.
|
||||
template <typename T>
|
||||
bool ConvertStringToReal(const std::string &str, T *out) {
|
||||
std::istringstream iss(str);
|
||||
|
||||
NumberIstream<T> i(iss);
|
||||
|
||||
i >> *out;
|
||||
|
||||
if (iss.fail()) {
|
||||
// Number conversion failed.
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template bool ConvertStringToReal<float>(const std::string &str, float *out);
|
||||
|
||||
template bool ConvertStringToReal<double>(const std::string &str, double *out);
|
||||
|
||||
void SplitStringToVector(const std::string &full, const char *delim,
|
||||
bool omit_empty_strings,
|
||||
std::vector<std::string> *out) {
|
||||
@@ -43,7 +143,9 @@ bool SplitStringToFloats(const std::string &full, const char *delim,
|
||||
out->resize(split.size());
|
||||
for (size_t i = 0; i < split.size(); ++i) {
|
||||
// assume atof never fails
|
||||
(*out)[i] = atof(split[i].c_str());
|
||||
F f = 0;
|
||||
if (!ConvertStringToReal(split[i], &f)) return false;
|
||||
(*out)[i] = f;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,9 @@
|
||||
#define SHERPA_ONNX_CSRC_TEXT_UTILS_H_
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#ifdef _MSC_VER
|
||||
@@ -21,6 +23,32 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/// Converts a string into an integer via strtoll and returns false if there was
|
||||
/// any kind of problem (i.e. the string was not an integer or contained extra
|
||||
/// non-whitespace junk, or the integer was too large to fit into the type it is
|
||||
/// being converted into). Only sets *out if everything was OK and it returns
|
||||
/// true.
|
||||
template <class Int>
|
||||
bool ConvertStringToInteger(const std::string &str, Int *out) {
|
||||
// copied from kaldi/src/util/text-util.h
|
||||
static_assert(std::is_integral<Int>::value, "");
|
||||
const char *this_str = str.c_str();
|
||||
char *end = nullptr;
|
||||
errno = 0;
|
||||
int64_t i = SHERPA_ONNX_STRTOLL(this_str, &end);
|
||||
if (end != this_str) {
|
||||
while (isspace(*end)) ++end;
|
||||
}
|
||||
if (end == this_str || *end != '\0' || errno != 0) return false;
|
||||
Int iInt = static_cast<Int>(i);
|
||||
if (static_cast<int64_t>(iInt) != i ||
|
||||
(i < 0 && !std::numeric_limits<Int>::is_signed)) {
|
||||
return false;
|
||||
}
|
||||
*out = iInt;
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Split a string using any of the single character delimiters.
|
||||
/// If omit_empty_strings == true, the output will contain any
|
||||
/// nonempty strings after splitting on any of the
|
||||
@@ -86,6 +114,10 @@ bool SplitStringToFloats(const std::string &full, const char *delim,
|
||||
bool omit_empty_strings, // typically false
|
||||
std::vector<F> *out);
|
||||
|
||||
// This is defined for F = float and double.
|
||||
template <typename T>
|
||||
bool ConvertStringToReal(const std::string &str, T *out);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_
|
||||
|
||||
Reference in New Issue
Block a user