diff --git a/.github/scripts/Main.kt b/.github/scripts/Main.kt index 8bbeb76e..49d8379e 100644 --- a/.github/scripts/Main.kt +++ b/.github/scripts/Main.kt @@ -35,7 +35,7 @@ fun main() { var objArray = WaveReader.readWave( assetManager = AssetManager(), - filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav", + filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav", ) var samples : FloatArray = objArray[0] as FloatArray var sampleRate : Int = objArray[1] as Int diff --git a/.github/scripts/test-offline-transducer.sh b/.github/scripts/test-offline-transducer.sh index 6f3fea1d..9fc1107a 100755 --- a/.github/scripts/test-offline-transducer.sh +++ b/.github/scripts/test-offline-transducer.sh @@ -25,6 +25,7 @@ log "Download pretrained model and test-data from $repo_url" GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url pushd $repo git lfs pull --include "*.onnx" +ls -lh *.onnx popd time $EXE \ @@ -37,6 +38,16 @@ time $EXE \ $repo/test_wavs/1.wav \ $repo/test_wavs/8k.wav +time $EXE \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \ + --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \ + --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \ + --num-threads=2 \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + rm -rf $repo log "------------------------------------------------------------" @@ -51,6 +62,7 @@ log "Download pretrained model and test-data from $repo_url" GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url pushd $repo git lfs pull --include "*.onnx" +ls -lh *.onnx popd time $EXE \ @@ -63,6 +75,16 @@ time $EXE \ $repo/test_wavs/1.wav \ $repo/test_wavs/8k.wav +time $EXE \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \ + --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \ + --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \ + --num-threads=2 \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + rm -rf $repo log "------------------------------------------------------------" @@ -77,6 +99,7 @@ log "Download pretrained model and test-data from $repo_url" GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url pushd $repo git lfs pull --include "*.onnx" +ls -lh *.onnx popd time $EXE \ @@ -89,4 +112,14 @@ time $EXE \ $repo/test_wavs/2.wav \ $repo/test_wavs/8k.wav +time $EXE \ + --tokens=$repo/tokens.txt \ + --paraformer=$repo/model.int8.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/8k.wav + rm -rf $repo diff --git a/.github/scripts/test-online-transducer.sh b/.github/scripts/test-online-transducer.sh index 138e0f73..2a5bac87 100755 --- a/.github/scripts/test-online-transducer.sh +++ b/.github/scripts/test-online-transducer.sh @@ -25,12 +25,13 @@ log "Download pretrained model and test-data from $repo_url" GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url pushd $repo git lfs pull --include "*.onnx" +ls -lh *.onnx popd waves=( -$repo/test_wavs/1089-134686-0001.wav -$repo/test_wavs/1221-135766-0001.wav -$repo/test_wavs/1221-135766-0002.wav +$repo/test_wavs/0.wav +$repo/test_wavs/1.wav +$repo/test_wavs/8k.wav ) for wave in ${waves[@]}; do @@ -43,6 +44,16 @@ for wave in ${waves[@]}; do 2 done +for wave in ${waves[@]}; do + time $EXE \ + $repo/tokens.txt \ + $repo/encoder-epoch-99-avg-1.int8.onnx \ + $repo/decoder-epoch-99-avg-1.int8.onnx \ + $repo/joiner-epoch-99-avg-1.int8.onnx \ + $wave \ + 2 +done + rm -rf $repo log "------------------------------------------------------------" @@ -57,12 +68,13 @@ log "Download pretrained model and test-data from $repo_url" GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url pushd $repo git lfs pull --include "*.onnx" +ls -lh *.onnx popd waves=( $repo/test_wavs/0.wav $repo/test_wavs/1.wav -$repo/test_wavs/2.wav +$repo/test_wavs/8k.wav ) for wave in ${waves[@]}; do @@ -75,6 +87,16 @@ for wave in ${waves[@]}; do 2 done +for wave in ${waves[@]}; do + time $EXE \ + $repo/tokens.txt \ + $repo/encoder-epoch-11-avg-1.int8.onnx \ + $repo/decoder-epoch-11-avg-1.int8.onnx \ + $repo/joiner-epoch-11-avg-1.int8.onnx \ + $wave \ + 2 +done + rm -rf $repo log "------------------------------------------------------------" @@ -89,12 +111,13 @@ log "Download pretrained model and test-data from $repo_url" GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url pushd $repo git lfs pull --include "*.onnx" +ls -lh *.onnx popd waves=( -$repo/test_wavs/1089-134686-0001.wav -$repo/test_wavs/1221-135766-0001.wav -$repo/test_wavs/1221-135766-0002.wav +$repo/test_wavs/0.wav +$repo/test_wavs/1.wav +$repo/test_wavs/8k.wav ) for wave in ${waves[@]}; do @@ -107,10 +130,22 @@ for wave in ${waves[@]}; do 2 done +# test int8 +# +for wave in ${waves[@]}; do + time $EXE \ + $repo/tokens.txt \ + $repo/encoder-epoch-99-avg-1.int8.onnx \ + $repo/decoder-epoch-99-avg-1.int8.onnx \ + $repo/joiner-epoch-99-avg-1.int8.onnx \ + $wave \ + 2 +done + rm -rf $repo log "------------------------------------------------------------" -log "Run streaming Zipformer transducer (Bilingual, Chinse + English)" +log "Run streaming Zipformer transducer (Bilingual, Chinese + English)" log "------------------------------------------------------------" repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 @@ -121,6 +156,7 @@ log "Download pretrained model and test-data from $repo_url" GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url pushd $repo git lfs pull --include "*.onnx" +ls -lh *.onnx popd waves=( @@ -128,7 +164,7 @@ $repo/test_wavs/0.wav $repo/test_wavs/1.wav $repo/test_wavs/2.wav $repo/test_wavs/3.wav -$repo/test_wavs/4.wav +$repo/test_wavs/8k.wav ) for wave in ${waves[@]}; do @@ -141,6 +177,16 @@ for wave in ${waves[@]}; do 2 done +for wave in ${waves[@]}; do + time $EXE \ + $repo/tokens.txt \ + $repo/encoder-epoch-99-avg-1.int8.onnx \ + $repo/decoder-epoch-99-avg-1.int8.onnx \ + $repo/joiner-epoch-99-avg-1.int8.onnx \ + $wave \ + 2 +done + # Decode a URL if [ $EXE == "sherpa-onnx-ffmpeg" ]; then time $EXE \ @@ -152,4 +198,14 @@ if [ $EXE == "sherpa-onnx-ffmpeg" ]; then 2 fi +if [ $EXE == "sherpa-onnx-ffmpeg" ]; then + time $EXE \ + $repo/tokens.txt \ + $repo/encoder-epoch-99-avg-1.int8.onnx \ + $repo/decoder-epoch-99-avg-1.int8.onnx \ + $repo/joiner-epoch-99-avg-1.int8.onnx \ + https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/resolve/main/test_wavs/4.wav \ + 2 +fi + rm -rf $repo diff --git a/.gitignore b/.gitignore index 8704f21d..800713d2 100644 --- a/.gitignore +++ b/.gitignore @@ -46,3 +46,8 @@ run-sherpa-onnx-offline-paraformer.sh run-sherpa-onnx-offline-transducer.sh sherpa-onnx-paraformer-zh-2023-03-28 run-offline-websocket-server-paraformer.sh +run-*int8.sh +a.sh +run-offline-websocket-client-*.sh +run-sherpa-onnx-*.sh +sherpa-onnx-zipformer-en-2023-03-30 diff --git a/sherpa-onnx/csrc/parse-options.cc b/sherpa-onnx/csrc/parse-options.cc index 54628949..34b52f2c 100644 --- a/sherpa-onnx/csrc/parse-options.cc +++ b/sherpa-onnx/csrc/parse-options.cc @@ -18,139 +18,13 @@ #include #include #include -#include -#include -#include #include "sherpa-onnx/csrc/log.h" - -#ifdef _MSC_VER -#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \ - _strtoi64(cur_cstr, end_cstr, 10); -#else -#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10); -#endif +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { -/// Converts a string into an integer via strtoll and returns false if there was -/// any kind of problem (i.e. the string was not an integer or contained extra -/// non-whitespace junk, or the integer was too large to fit into the type it is -/// being converted into). Only sets *out if everything was OK and it returns -/// true. -template -bool ConvertStringToInteger(const std::string &str, Int *out) { - // copied from kaldi/src/util/text-util.h - static_assert(std::is_integral::value, ""); - const char *this_str = str.c_str(); - char *end = nullptr; - errno = 0; - int64_t i = SHERPA_ONNX_STRTOLL(this_str, &end); - if (end != this_str) { - while (isspace(*end)) ++end; - } - if (end == this_str || *end != '\0' || errno != 0) return false; - Int iInt = static_cast(i); - if (static_cast(iInt) != i || - (i < 0 && !std::numeric_limits::is_signed)) { - return false; - } - *out = iInt; - return true; -} - -// copied from kaldi/src/util/text-util.cc -template -class NumberIstream { - public: - explicit NumberIstream(std::istream &i) : in_(i) {} - - NumberIstream &operator>>(T &x) { - if (!in_.good()) return *this; - in_ >> x; - if (!in_.fail() && RemainderIsOnlySpaces()) return *this; - return ParseOnFail(&x); - } - - private: - std::istream &in_; - - bool RemainderIsOnlySpaces() { - if (in_.tellg() != std::istream::pos_type(-1)) { - std::string rem; - in_ >> rem; - - if (rem.find_first_not_of(' ') != std::string::npos) { - // there is not only spaces - return false; - } - } - - in_.clear(); - return true; - } - - NumberIstream &ParseOnFail(T *x) { - std::string str; - in_.clear(); - in_.seekg(0); - // If the stream is broken even before trying - // to read from it or if there are many tokens, - // it's pointless to try. - if (!(in_ >> str) || !RemainderIsOnlySpaces()) { - in_.setstate(std::ios_base::failbit); - return *this; - } - - std::unordered_map inf_nan_map; - // we'll keep just uppercase values. - inf_nan_map["INF"] = std::numeric_limits::infinity(); - inf_nan_map["+INF"] = std::numeric_limits::infinity(); - inf_nan_map["-INF"] = -std::numeric_limits::infinity(); - inf_nan_map["INFINITY"] = std::numeric_limits::infinity(); - inf_nan_map["+INFINITY"] = std::numeric_limits::infinity(); - inf_nan_map["-INFINITY"] = -std::numeric_limits::infinity(); - inf_nan_map["NAN"] = std::numeric_limits::quiet_NaN(); - inf_nan_map["+NAN"] = std::numeric_limits::quiet_NaN(); - inf_nan_map["-NAN"] = -std::numeric_limits::quiet_NaN(); - // MSVC - inf_nan_map["1.#INF"] = std::numeric_limits::infinity(); - inf_nan_map["-1.#INF"] = -std::numeric_limits::infinity(); - inf_nan_map["1.#QNAN"] = std::numeric_limits::quiet_NaN(); - inf_nan_map["-1.#QNAN"] = -std::numeric_limits::quiet_NaN(); - - std::transform(str.begin(), str.end(), str.begin(), ::toupper); - - if (inf_nan_map.find(str) != inf_nan_map.end()) { - *x = inf_nan_map[str]; - } else { - in_.setstate(std::ios_base::failbit); - } - - return *this; - } -}; - -/// ConvertStringToReal converts a string into either float or double -/// and returns false if there was any kind of problem (i.e. the string -/// was not a floating point number or contained extra non-whitespace junk). -/// Be careful- this function will successfully read inf's or nan's. -template -bool ConvertStringToReal(const std::string &str, T *out) { - std::istringstream iss(str); - - NumberIstream i(iss); - - i >> *out; - - if (iss.fail()) { - // Number conversion failed. - return false; - } - - return true; -} - ParseOptions::ParseOptions(const std::string &prefix, ParseOptions *po) : print_args_(false), help_(false), usage_(""), argc_(0), argv_(nullptr) { if (po != nullptr && po->other_parser_ != nullptr) { @@ -219,8 +93,8 @@ void ParseOptions::RegisterCommon(const std::string &name, T *ptr, std::string idx = name; NormalizeArgName(&idx); if (doc_map_.find(idx) != doc_map_.end()) { - SHERPA_ONNX_LOG(WARNING) - << "Registering option twice, ignoring second time: " << name; + SHERPA_ONNX_LOGE("Registering option twice, ignoring second time: %s", + name.c_str()); } else { this->RegisterSpecific(name, idx, ptr, doc, is_standard); } @@ -289,12 +163,13 @@ void ParseOptions::RegisterSpecific(const std::string &name, void ParseOptions::DisableOption(const std::string &name) { if (argv_ != nullptr) { - SHERPA_ONNX_LOG(FATAL) - << "DisableOption must not be called after calling Read()."; + SHERPA_ONNX_LOGE("DisableOption must not be called after calling Read()."); + exit(-1); } if (doc_map_.erase(name) == 0) { - SHERPA_ONNX_LOG(FATAL) << "Option " << name - << " was not registered so cannot be disabled: "; + SHERPA_ONNX_LOGE("Option %s was not registered so cannot be disabled: ", + name.c_str()); + exit(-1); } bool_map_.erase(name); int_map_.erase(name); @@ -308,7 +183,8 @@ int ParseOptions::NumArgs() const { return positional_args_.size(); } std::string ParseOptions::GetArg(int i) const { if (i < 1 || i > static_cast(positional_args_.size())) { - SHERPA_ONNX_LOG(FATAL) << "ParseOptions::GetArg, invalid index " << i; + SHERPA_ONNX_LOGE("ParseOptions::GetArg, invalid index %d", i); + exit(-1); } return positional_args_[i - 1]; @@ -460,7 +336,8 @@ int ParseOptions::Read(int argc, const char *const argv[]) { Trim(&value); if (!SetOption(key, value, has_equal_sign)) { PrintUsage(true); - SHERPA_ONNX_LOG(FATAL) << "Invalid option " << argv[i]; + SHERPA_ONNX_LOGE("Invalid option %s", argv[i]); + exit(-1); } } else { break; @@ -481,7 +358,7 @@ int ParseOptions::Read(int argc, const char *const argv[]) { std::ostringstream strm; for (int j = 0; j < argc; ++j) strm << Escape(argv[j]) << " "; strm << '\n'; - SHERPA_ONNX_LOG(INFO) << strm.str(); + SHERPA_ONNX_LOGE("%s", strm.str().c_str()); } return i; } @@ -522,7 +399,7 @@ void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const { os << strm.str(); } - SHERPA_ONNX_LOG(INFO) << os.str(); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); } void ParseOptions::PrintConfig(std::ostream &os) const { @@ -544,8 +421,9 @@ void ParseOptions::PrintConfig(std::ostream &os) const { } else if (string_map_.end() != string_map_.find(key)) { os << "'" << *string_map_.at(key) << "'"; } else { - SHERPA_ONNX_LOG(FATAL) - << "PrintConfig: unrecognized option " << key << "[code error]"; + SHERPA_ONNX_LOGE("PrintConfig: unrecognized option %s [code error]", + key.c_str()); + exit(-1); } os << '\n'; } @@ -555,7 +433,8 @@ void ParseOptions::PrintConfig(std::ostream &os) const { void ParseOptions::ReadConfigFile(const std::string &filename) { std::ifstream is(filename.c_str(), std::ifstream::in); if (!is.good()) { - SHERPA_ONNX_LOG(FATAL) << "Cannot open config file: " << filename; + SHERPA_ONNX_LOGE("Cannot open config file: %s", filename.c_str()); + exit(-1); } std::string line, key, value; @@ -572,12 +451,13 @@ void ParseOptions::ReadConfigFile(const std::string &filename) { if (line.length() == 0) continue; if (line.substr(0, 2) != "--") { - SHERPA_ONNX_LOG(FATAL) - << "Reading config file " << filename << ": line " << line_number - << " does not look like a line " - << "from a Kaldi command-line program's config file: should " - << "be of the form --x=y. Note: config files intended to " - << "be sourced by shell scripts lack the '--'."; + SHERPA_ONNX_LOGE( + "Reading config file %s: line %d does not look like a line " + "from a sherpa-onnx command-line program's config file: should " + "be of the form --x=y. Note: config files intended to " + "be sourced by shell scripts lack the '--'.", + filename.c_str(), line_number); + exit(-1); } // parse option @@ -587,8 +467,9 @@ void ParseOptions::ReadConfigFile(const std::string &filename) { Trim(&value); if (!SetOption(key, value, has_equal_sign)) { PrintUsage(true); - SHERPA_ONNX_LOG(FATAL) << "Invalid option " << line << " in config file " - << filename << ": line " << line_number; + SHERPA_ONNX_LOGE("Invalid option %s in config file %s: line %d", + line.c_str(), filename.c_str(), line_number); + exit(-1); } } } @@ -605,7 +486,8 @@ void ParseOptions::SplitLongArg(const std::string &in, std::string *key, *has_equal_sign = false; } else if (pos == 2) { // we also don't allow empty keys: --=value PrintUsage(true); - SHERPA_ONNX_LOG(FATAL) << "Invalid option (no key): " << in; + SHERPA_ONNX_LOGE("Invalid option (no key): %s", in.c_str()); + exit(-1); } else { // normal case: --option=value *key = in.substr(2, pos - 2); // 2 because starts with --. *value = in.substr(pos + 1); @@ -646,7 +528,8 @@ bool ParseOptions::SetOption(const std::string &key, const std::string &value, bool has_equal_sign) { if (bool_map_.end() != bool_map_.find(key)) { if (has_equal_sign && value == "") { - SHERPA_ONNX_LOG(FATAL) << "Invalid option --" << key << "="; + SHERPA_ONNX_LOGE("Invalid option --%s=", key.c_str()); + exit(-1); } *(bool_map_[key]) = ToBool(value); } else if (int_map_.end() != int_map_.find(key)) { @@ -659,8 +542,9 @@ bool ParseOptions::SetOption(const std::string &key, const std::string &value, *(double_map_[key]) = ToDouble(value); } else if (string_map_.end() != string_map_.find(key)) { if (!has_equal_sign) { - SHERPA_ONNX_LOG(FATAL) - << "Invalid option --" << key << " (option format is --x=y)."; + SHERPA_ONNX_LOGE("Invalid option --%s (option format is --x=y).", + key.c_str()); + exit(-1); } *(string_map_[key]) = value; } else { @@ -683,37 +567,46 @@ bool ParseOptions::ToBool(std::string str) const { } // if it is neither true nor false: PrintUsage(true); - SHERPA_ONNX_LOG(FATAL) - << "Invalid format for boolean argument [expected true or false]: " - << str; + SHERPA_ONNX_LOGE( + "Invalid format for boolean argument [expected true or false]: %s", + str.c_str()); + exit(-1); return false; // never reached } int32_t ParseOptions::ToInt(const std::string &str) const { int32_t ret = 0; - if (!ConvertStringToInteger(str, &ret)) - SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\""; + if (!ConvertStringToInteger(str, &ret)) { + SHERPA_ONNX_LOGE("Invalid integer option \"%s\"", str.c_str()); + exit(-1); + } return ret; } uint32_t ParseOptions::ToUint(const std::string &str) const { uint32_t ret = 0; - if (!ConvertStringToInteger(str, &ret)) - SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\""; + if (!ConvertStringToInteger(str, &ret)) { + SHERPA_ONNX_LOGE("Invalid integer option \"%s\"", str.c_str()); + exit(-1); + } return ret; } float ParseOptions::ToFloat(const std::string &str) const { float ret; - if (!ConvertStringToReal(str, &ret)) - SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\""; + if (!ConvertStringToReal(str, &ret)) { + SHERPA_ONNX_LOGE("Invalid floating-point option \"%s\"", str.c_str()); + exit(-1); + } return ret; } double ParseOptions::ToDouble(const std::string &str) const { double ret; - if (!ConvertStringToReal(str, &ret)) - SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\""; + if (!ConvertStringToReal(str, &ret)) { + SHERPA_ONNX_LOGE("Invalid floating-point option \"%s\"", str.c_str()); + exit(-1); + } return ret; } diff --git a/sherpa-onnx/csrc/text-utils.cc b/sherpa-onnx/csrc/text-utils.cc index 44eb70c0..f54acc83 100644 --- a/sherpa-onnx/csrc/text-utils.cc +++ b/sherpa-onnx/csrc/text-utils.cc @@ -7,7 +7,11 @@ #include +#include +#include +#include #include +#include #include // This file is copied/modified from @@ -15,6 +19,102 @@ namespace sherpa_onnx { +// copied from kaldi/src/util/text-util.cc +template +class NumberIstream { + public: + explicit NumberIstream(std::istream &i) : in_(i) {} + + NumberIstream &operator>>(T &x) { + if (!in_.good()) return *this; + in_ >> x; + if (!in_.fail() && RemainderIsOnlySpaces()) return *this; + return ParseOnFail(&x); + } + + private: + std::istream &in_; + + bool RemainderIsOnlySpaces() { + if (in_.tellg() != std::istream::pos_type(-1)) { + std::string rem; + in_ >> rem; + + if (rem.find_first_not_of(' ') != std::string::npos) { + // there is not only spaces + return false; + } + } + + in_.clear(); + return true; + } + + NumberIstream &ParseOnFail(T *x) { + std::string str; + in_.clear(); + in_.seekg(0); + // If the stream is broken even before trying + // to read from it or if there are many tokens, + // it's pointless to try. + if (!(in_ >> str) || !RemainderIsOnlySpaces()) { + in_.setstate(std::ios_base::failbit); + return *this; + } + + std::unordered_map inf_nan_map; + // we'll keep just uppercase values. + inf_nan_map["INF"] = std::numeric_limits::infinity(); + inf_nan_map["+INF"] = std::numeric_limits::infinity(); + inf_nan_map["-INF"] = -std::numeric_limits::infinity(); + inf_nan_map["INFINITY"] = std::numeric_limits::infinity(); + inf_nan_map["+INFINITY"] = std::numeric_limits::infinity(); + inf_nan_map["-INFINITY"] = -std::numeric_limits::infinity(); + inf_nan_map["NAN"] = std::numeric_limits::quiet_NaN(); + inf_nan_map["+NAN"] = std::numeric_limits::quiet_NaN(); + inf_nan_map["-NAN"] = -std::numeric_limits::quiet_NaN(); + // MSVC + inf_nan_map["1.#INF"] = std::numeric_limits::infinity(); + inf_nan_map["-1.#INF"] = -std::numeric_limits::infinity(); + inf_nan_map["1.#QNAN"] = std::numeric_limits::quiet_NaN(); + inf_nan_map["-1.#QNAN"] = -std::numeric_limits::quiet_NaN(); + + std::transform(str.begin(), str.end(), str.begin(), ::toupper); + + if (inf_nan_map.find(str) != inf_nan_map.end()) { + *x = inf_nan_map[str]; + } else { + in_.setstate(std::ios_base::failbit); + } + + return *this; + } +}; + +/// ConvertStringToReal converts a string into either float or double +/// and returns false if there was any kind of problem (i.e. the string +/// was not a floating point number or contained extra non-whitespace junk). +/// Be careful- this function will successfully read inf's or nan's. +template +bool ConvertStringToReal(const std::string &str, T *out) { + std::istringstream iss(str); + + NumberIstream i(iss); + + i >> *out; + + if (iss.fail()) { + // Number conversion failed. + return false; + } + + return true; +} + +template bool ConvertStringToReal(const std::string &str, float *out); + +template bool ConvertStringToReal(const std::string &str, double *out); + void SplitStringToVector(const std::string &full, const char *delim, bool omit_empty_strings, std::vector *out) { @@ -43,7 +143,9 @@ bool SplitStringToFloats(const std::string &full, const char *delim, out->resize(split.size()); for (size_t i = 0; i < split.size(); ++i) { // assume atof never fails - (*out)[i] = atof(split[i].c_str()); + F f = 0; + if (!ConvertStringToReal(split[i], &f)) return false; + (*out)[i] = f; } return true; } diff --git a/sherpa-onnx/csrc/text-utils.h b/sherpa-onnx/csrc/text-utils.h index 6c91b805..8f1cb696 100644 --- a/sherpa-onnx/csrc/text-utils.h +++ b/sherpa-onnx/csrc/text-utils.h @@ -6,7 +6,9 @@ #define SHERPA_ONNX_CSRC_TEXT_UTILS_H_ #include +#include #include +#include #include #ifdef _MSC_VER @@ -21,6 +23,32 @@ namespace sherpa_onnx { +/// Converts a string into an integer via strtoll and returns false if there was +/// any kind of problem (i.e. the string was not an integer or contained extra +/// non-whitespace junk, or the integer was too large to fit into the type it is +/// being converted into). Only sets *out if everything was OK and it returns +/// true. +template +bool ConvertStringToInteger(const std::string &str, Int *out) { + // copied from kaldi/src/util/text-util.h + static_assert(std::is_integral::value, ""); + const char *this_str = str.c_str(); + char *end = nullptr; + errno = 0; + int64_t i = SHERPA_ONNX_STRTOLL(this_str, &end); + if (end != this_str) { + while (isspace(*end)) ++end; + } + if (end == this_str || *end != '\0' || errno != 0) return false; + Int iInt = static_cast(i); + if (static_cast(iInt) != i || + (i < 0 && !std::numeric_limits::is_signed)) { + return false; + } + *out = iInt; + return true; +} + /// Split a string using any of the single character delimiters. /// If omit_empty_strings == true, the output will contain any /// nonempty strings after splitting on any of the @@ -86,6 +114,10 @@ bool SplitStringToFloats(const std::string &full, const char *delim, bool omit_empty_strings, // typically false std::vector *out); +// This is defined for F = float and double. +template +bool ConvertStringToReal(const std::string &str, T *out); + } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_