diff --git a/sherpa-onnx/csrc/parse-options.cc b/sherpa-onnx/csrc/parse-options.cc index 33a07f32..87bc9cf0 100644 --- a/sherpa-onnx/csrc/parse-options.cc +++ b/sherpa-onnx/csrc/parse-options.cc @@ -49,6 +49,11 @@ void ParseOptions::Register(const std::string &name, int32_t *ptr, RegisterTmpl(name, ptr, doc); } +void ParseOptions::Register(const std::string &name, int64_t *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + void ParseOptions::Register(const std::string &name, uint32_t *ptr, const std::string &doc) { RegisterTmpl(name, ptr, doc); @@ -125,6 +130,15 @@ void ParseOptions::RegisterSpecific(const std::string &name, doc_map_[idx] = DocInfo(name, ss.str(), is_standard); } +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, int64_t *i, + const std::string &doc, bool is_standard) { + int64_map_[idx] = i; + std::ostringstream ss; + ss << doc << " (int64, default = " << *i << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + void ParseOptions::RegisterSpecific(const std::string &name, const std::string &idx, uint32_t *u, const std::string &doc, bool is_standard) { @@ -172,6 +186,7 @@ void ParseOptions::DisableOption(const std::string &name) { } bool_map_.erase(name); int_map_.erase(name); + int64_map_.erase(name); uint_map_.erase(name); float_map_.erase(name); double_map_.erase(name); @@ -411,6 +426,8 @@ void ParseOptions::PrintConfig(std::ostream &os) const { os << (*bool_map_.at(key) ? "true" : "false"); } else if (int_map_.end() != int_map_.find(key)) { os << (*int_map_.at(key)); + } else if (int64_map_.end() != int64_map_.find(key)) { + os << (*int64_map_.at(key)); } else if (uint_map_.end() != uint_map_.find(key)) { os << (*uint_map_.at(key)); } else if (float_map_.end() != float_map_.find(key)) { @@ -533,6 +550,8 @@ bool ParseOptions::SetOption(const std::string &key, const std::string &value, *(bool_map_[key]) = ToBool(value); } else if (int_map_.end() != int_map_.find(key)) { *(int_map_[key]) = ToInt(value); + } else if (int64_map_.end() != int64_map_.find(key)) { + *(int64_map_[key]) = ToInt64(value); } else if (uint_map_.end() != uint_map_.find(key)) { *(uint_map_[key]) = ToUint(value); } else if (float_map_.end() != float_map_.find(key)) { @@ -580,6 +599,15 @@ int32_t ParseOptions::ToInt(const std::string &str) const { return ret; } +int64_t ParseOptions::ToInt64(const std::string &str) const { + int64_t ret = 0; + if (!ConvertStringToInteger(str, &ret)) { + SHERPA_ONNX_LOGE("Invalid integer int64 option \"%s\"", str.c_str()); + exit(-1); + } + return ret; +} + uint32_t ParseOptions::ToUint(const std::string &str) const { uint32_t ret = 0; if (!ConvertStringToInteger(str, &ret)) { @@ -612,6 +640,8 @@ template void ParseOptions::RegisterTmpl(const std::string &name, bool *ptr, const std::string &doc); template void ParseOptions::RegisterTmpl(const std::string &name, int32_t *ptr, const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, int64_t *ptr, + const std::string &doc); template void ParseOptions::RegisterTmpl(const std::string &name, uint32_t *ptr, const std::string &doc); template void ParseOptions::RegisterTmpl(const std::string &name, float *ptr, @@ -627,6 +657,9 @@ template void ParseOptions::RegisterStandard(const std::string &name, bool *ptr, template void ParseOptions::RegisterStandard(const std::string &name, int32_t *ptr, const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + int64_t *ptr, + const std::string &doc); template void ParseOptions::RegisterStandard(const std::string &name, uint32_t *ptr, const std::string &doc); @@ -646,6 +679,9 @@ template void ParseOptions::RegisterCommon(const std::string &name, bool *ptr, template void ParseOptions::RegisterCommon(const std::string &name, int32_t *ptr, const std::string &doc, bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, + int64_t *ptr, const std::string &doc, + bool is_standard); template void ParseOptions::RegisterCommon(const std::string &name, uint32_t *ptr, const std::string &doc, diff --git a/sherpa-onnx/csrc/parse-options.h b/sherpa-onnx/csrc/parse-options.h index 17ead6fe..615ee7d3 100644 --- a/sherpa-onnx/csrc/parse-options.h +++ b/sherpa-onnx/csrc/parse-options.h @@ -63,6 +63,7 @@ class ParseOptions { void Register(const std::string &name, bool *ptr, const std::string &doc); void Register(const std::string &name, int32_t *ptr, const std::string &doc); + void Register(const std::string &name, int64_t *ptr, const std::string &doc); void Register(const std::string &name, uint32_t *ptr, const std::string &doc); void Register(const std::string &name, float *ptr, const std::string &doc); void Register(const std::string &name, double *ptr, const std::string &doc); @@ -134,6 +135,9 @@ class ParseOptions { /// Register int32_t variable void RegisterSpecific(const std::string &name, const std::string &idx, int32_t *i, const std::string &doc, bool is_standard); + /// Register int64_t variable + void RegisterSpecific(const std::string &name, const std::string &idx, + int64_t *i, const std::string &doc, bool is_standard); /// Register unsigned int32_t variable void RegisterSpecific(const std::string &name, const std::string &idx, uint32_t *u, const std::string &doc, bool is_standard); @@ -163,6 +167,7 @@ class ParseOptions { bool ToBool(std::string str) const; int32_t ToInt(const std::string &str) const; + int64_t ToInt64(const std::string &str) const; uint32_t ToUint(const std::string &str) const; float ToFloat(const std::string &str) const; double ToDouble(const std::string &str) const; @@ -170,6 +175,7 @@ class ParseOptions { // maps for option variables std::unordered_map bool_map_; std::unordered_map int_map_; + std::unordered_map int64_map_; std::unordered_map uint_map_; std::unordered_map float_map_; std::unordered_map double_map_;