diff --git a/.gitignore b/.gitignore index 1de2fb4a..0f34dac8 100644 --- a/.gitignore +++ b/.gitignore @@ -4,9 +4,11 @@ build onnxruntime-* icefall-* run.sh -sherpa-onnx-* __pycache__ dist/ sherpa_onnx.egg-info/ .DS_Store build-aarch64-linux-gnu +sherpa-onnx-streaming-zipformer-* +sherpa-onnx-lstm-en-* +sherpa-onnx-lstm-zh-* diff --git a/CMakeLists.txt b/CMakeLists.txt index 9ae83e32..19a97017 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,6 +13,7 @@ endif() option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF) option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF) +option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" ON) option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") @@ -46,6 +47,8 @@ message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}") message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}") message(STATUS "SHERPA_ONNX_ENABLE_PYTHON ${SHERPA_ONNX_ENABLE_PYTHON}") +message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}") +message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}") set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") set(CMAKE_CXX_EXTENSIONS OFF) @@ -56,6 +59,9 @@ if(SHERPA_ONNX_HAS_ALSA) add_definitions(-DSHERPA_ONNX_ENABLE_ALSA=1) endif() +check_include_file_cxx(cxxabi.h SHERPA_ONNX_HAVE_CXXABI_H) +check_include_file_cxx(execinfo.h SHERPA_ONNX_HAVE_EXECINFO_H) + list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) diff --git a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py new file mode 100644 index 00000000..4d2992ac --- /dev/null +++ b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py @@ -0,0 +1,85 @@ 
+#!/usr/bin/env python3 + +# Real-time speech recognition from a microphone with sherpa-onnx Python API +# with endpoint detection. +# +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +# to download pre-trained models + +import sys + +try: + import sounddevice as sd +except ImportError as e: + print("Please install sounddevice first. You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + +import sherpa_onnx + + +def create_recognizer(): + # Please replace the model files if needed. + # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html + # for download links. + recognizer = sherpa_onnx.OnlineRecognizer( + tokens="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt", + encoder="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx", + decoder="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx", + joiner="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx", + num_threads=4, + sample_rate=16000, + feature_dim=80, + enable_endpoint_detection=True, + rule1_min_trailing_silence=2.4, + rule2_min_trailing_silence=1.2, + rule3_min_utterance_length=300, # it essentially disables this rule + ) + return recognizer + + +def main(): + print("Started! 
Please speak") + recognizer = create_recognizer() + sample_rate = 16000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + last_result = "" + stream = recognizer.create_stream() + + last_result = "" + segment_id = 0 + with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: + while True: + samples, _ = s.read(samples_per_read) # a blocking read + samples = samples.reshape(-1) + stream.accept_waveform(sample_rate, samples) + while recognizer.is_ready(stream): + recognizer.decode_stream(stream) + + is_endpoint = recognizer.is_endpoint(stream) + + result = recognizer.get_result(stream) + + if result and (last_result != result): + last_result = result + print(f"{segment_id}: {result}") + + if result and is_endpoint: + segment_id += 1 + recognizer.reset(stream) + + +if __name__ == "__main__": + devices = sd.query_devices() + print(devices) + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. 
Exiting") diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 801f78f4..4a9c007c 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -1,7 +1,8 @@ include_directories(${CMAKE_SOURCE_DIR}) -add_library(sherpa-onnx-core +set(sources cat.cc + endpoint.cc features.cc online-lstm-transducer-model.cc online-recognizer.cc @@ -11,6 +12,7 @@ add_library(sherpa-onnx-core online-transducer-model.cc online-zipformer-transducer-model.cc onnx-utils.cc + parse-options.cc resample.cc symbol-table.cc text-utils.cc @@ -18,11 +20,29 @@ add_library(sherpa-onnx-core wave-reader.cc ) +if(SHERPA_ONNX_ENABLE_CHECK) + list(APPEND sources log.cc) +endif() + +add_library(sherpa-onnx-core ${sources}) + target_link_libraries(sherpa-onnx-core onnxruntime kaldi-native-fbank-core ) +if(SHERPA_ONNX_ENABLE_CHECK) + target_compile_definitions(sherpa-onnx-core PUBLIC SHERPA_ONNX_ENABLE_CHECK=1) + + if(SHERPA_ONNX_HAVE_EXECINFO_H) + target_compile_definitions(sherpa-onnx-core PRIVATE SHERPA_ONNX_HAVE_EXECINFO_H=1) + endif() + + if(SHERPA_ONNX_HAVE_CXXABI_H) + target_compile_definitions(sherpa-onnx-core PRIVATE SHERPA_ONNX_HAVE_CXXABI_H=1) + endif() +endif() + add_executable(sherpa-onnx sherpa-onnx.cc) target_link_libraries(sherpa-onnx sherpa-onnx-core) diff --git a/sherpa-onnx/csrc/endpoint.cc b/sherpa-onnx/csrc/endpoint.cc new file mode 100644 index 00000000..3a9a424c --- /dev/null +++ b/sherpa-onnx/csrc/endpoint.cc @@ -0,0 +1,91 @@ +// sherpa-onnx/csrc/endpoint.cc +// +// Copyright (c) 2022 (authors: Pingfeng Luo) +// 2022-2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/endpoint.h" + +#include + +#include "sherpa-onnx/csrc/log.h" +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +static bool RuleActivated(const EndpointRule &rule, + const std::string &rule_name, float trailing_silence, + float utterance_length) { + bool contain_nonsilence = utterance_length > trailing_silence; + bool ans = 
(contain_nonsilence || !rule.must_contain_nonsilence) && + trailing_silence >= rule.min_trailing_silence && + utterance_length >= rule.min_utterance_length; + if (ans) { + SHERPA_ONNX_LOG(DEBUG) << "Endpointing rule " << rule_name << " activated: " + << (contain_nonsilence ? "true" : "false") << ',' + << trailing_silence << ',' << utterance_length; + } + return ans; +} + +static void RegisterEndpointRule(ParseOptions *po, EndpointRule *rule, + const std::string &rule_name) { + po->Register( + rule_name + "-must-contain-nonsilence", &rule->must_contain_nonsilence, + "If True, for this endpointing " + rule_name + + " to apply there must be nonsilence in the best-path traceback. " + "For decoding, a non-blank token is considered as non-silence"); + po->Register(rule_name + "-min-trailing-silence", &rule->min_trailing_silence, + "This endpointing " + rule_name + + " requires duration of trailing silence (in seconds) to " + "be >= this value."); + po->Register(rule_name + "-min-utterance-length", &rule->min_utterance_length, + "This endpointing " + rule_name + + " requires utterance-length (in seconds) to be >= this " + "value."); +} + +std::string EndpointRule::ToString() const { + std::ostringstream os; + + os << "EndpointRule("; + os << "must_contain_nonsilence=" + << (must_contain_nonsilence ? 
"True" : "False") << ", "; + os << "min_trailing_silence=" << min_trailing_silence << ", "; + os << "min_utterance_length=" << min_utterance_length << ")"; + + return os.str(); +} + +void EndpointConfig::Register(ParseOptions *po) { + RegisterEndpointRule(po, &rule1, "rule1"); + RegisterEndpointRule(po, &rule2, "rule2"); + RegisterEndpointRule(po, &rule3, "rule3"); +} + +std::string EndpointConfig::ToString() const { + std::ostringstream os; + + os << "EndpointConfig("; + os << "rule1=" << rule1.ToString() << ", "; + os << "rule2=" << rule2.ToString() << ", "; + os << "rule3=" << rule3.ToString() << ")"; + + return os.str(); +} + +bool Endpoint::IsEndpoint(int num_frames_decoded, int trailing_silence_frames, + float frame_shift_in_seconds) const { + float utterance_length = num_frames_decoded * frame_shift_in_seconds; + float trailing_silence = trailing_silence_frames * frame_shift_in_seconds; + if (RuleActivated(config_.rule1, "rule1", trailing_silence, + utterance_length) || + RuleActivated(config_.rule2, "rule2", trailing_silence, + utterance_length) || + RuleActivated(config_.rule3, "rule3", trailing_silence, + utterance_length)) { + return true; + } + return false; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/endpoint.h b/sherpa-onnx/csrc/endpoint.h new file mode 100644 index 00000000..73995840 --- /dev/null +++ b/sherpa-onnx/csrc/endpoint.h @@ -0,0 +1,76 @@ +// sherpa-onnx/csrc/endpoint.h +// +// Copyright (c) 2022 (authors: Pingfeng Luo) +// 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ENDPOINT_H_ +#define SHERPA_ONNX_CSRC_ENDPOINT_H_ + +#include +#include + +namespace sherpa_onnx { + +struct EndpointRule { + // If True, for this endpointing rule to apply there must + // be nonsilence in the best-path traceback. + // For decoding, a non-blank token is considered as non-silence + bool must_contain_nonsilence = true; + // This endpointing rule requires duration of trailing silence + // (in seconds) to be >= this value. 
+ float min_trailing_silence = 2.0; + // This endpointing rule requires utterance-length (in seconds) + // to be >= this value. + float min_utterance_length = 0.0f; + + EndpointRule() = default; + + EndpointRule(bool must_contain_nonsilence, float min_trailing_silence, + float min_utterance_length) + : must_contain_nonsilence(must_contain_nonsilence), + min_trailing_silence(min_trailing_silence), + min_utterance_length(min_utterance_length) {} + + std::string ToString() const; +}; + +class ParseOptions; + +struct EndpointConfig { + // For default setting, + // rule1 times out after 2.4 seconds of silence, even if we decoded nothing. + // rule2 times out after 1.2 seconds of silence after decoding something. + // rule3 times out after the utterance is 20 seconds long, regardless of + // anything else. + EndpointRule rule1; + EndpointRule rule2; + EndpointRule rule3; + + void Register(ParseOptions *po); + + EndpointConfig() + : rule1{false, 2.4, 0}, rule2{true, 1.2, 0}, rule3{false, 0, 20} {} + + EndpointConfig(const EndpointRule &rule1, const EndpointRule &rule2, + const EndpointRule &rule3) + : rule1(rule1), rule2(rule2), rule3(rule3) {} + + std::string ToString() const; +}; + +class Endpoint { + public: + explicit Endpoint(const EndpointConfig &config) : config_(config) {} + + /// This function returns true if this set of endpointing rules thinks we + /// should terminate decoding. 
+ bool IsEndpoint(int num_frames_decoded, int trailing_silence_frames, + float frame_shift_in_seconds) const; + + private: + EndpointConfig config_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ENDPOINT_H_ diff --git a/sherpa-onnx/csrc/log.cc b/sherpa-onnx/csrc/log.cc new file mode 100644 index 00000000..df3018b9 --- /dev/null +++ b/sherpa-onnx/csrc/log.cc @@ -0,0 +1,122 @@ +// sherpa-onnx/csrc/log.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/log.h" + +#ifdef SHERPA_ONNX_HAVE_EXECINFO_H +#include // To get stack trace in error messages. +#ifdef SHERPA_ONNX_HAVE_CXXABI_H +#include // For name demangling. +// Useful to decode the stack trace, but only used if we have execinfo.h +#endif // SHERPA_ONNX_HAVE_CXXABI_H +#endif // SHERPA_ONNX_HAVE_EXECINFO_H + +#include + +#include +#include +#include + +namespace sherpa_onnx { + +std::string GetDateTimeStr() { + std::ostringstream os; + std::time_t t = std::time(nullptr); + std::tm tm = *std::localtime(&t); + os << std::put_time(&tm, "%F %T"); // yyyy-mm-dd hh:mm:ss + return os.str(); +} + +static bool LocateSymbolRange(const std::string &trace_name, std::size_t *begin, + std::size_t *end) { + // Find the first '_' with leading ' ' or '('. + *begin = std::string::npos; + for (std::size_t i = 1; i < trace_name.size(); ++i) { + if (trace_name[i] != '_') { + continue; + } + if (trace_name[i - 1] == ' ' || trace_name[i - 1] == '(') { + *begin = i; + break; + } + } + if (*begin == std::string::npos) { + return false; + } + *end = trace_name.find_first_of(" +", *begin); + return *end != std::string::npos; +} + +#ifdef SHERPA_ONNX_HAVE_EXECINFO_H +static std::string Demangle(const std::string &trace_name) { +#ifndef SHERPA_ONNX_HAVE_CXXABI_H + return trace_name; +#else // SHERPA_ONNX_HAVE_CXXABI_H + // Try demangle the symbol. 
We are trying to support the following formats + // produced by different platforms: + // + // Linux: + // ./kaldi-error-test(_ZN5kaldi13UnitTestErrorEv+0xb) [0x804965d] + // + // Mac: + // 0 server 0x000000010f67614d _ZNK5kaldi13MessageLogger10LogMessageEv + 813 + // + // We want to extract the name e.g., '_ZN5kaldi13UnitTestErrorEv' and + // demangle it into a readable name like kaldi::UnitTestError. + std::size_t begin, end; + if (!LocateSymbolRange(trace_name, &begin, &end)) { + return trace_name; + } + std::string symbol = trace_name.substr(begin, end - begin); + int status; + char *demangled_name = abi::__cxa_demangle(symbol.c_str(), 0, 0, &status); + if (status == 0 && demangled_name != nullptr) { + symbol = demangled_name; + free(demangled_name); + } + return trace_name.substr(0, begin) + symbol + + trace_name.substr(end, std::string::npos); +#endif // SHERPA_ONNX_HAVE_CXXABI_H +} +#endif // SHERPA_ONNX_HAVE_EXECINFO_H + +std::string GetStackTrace() { + std::string ans; +#ifdef SHERPA_ONNX_HAVE_EXECINFO_H + constexpr const std::size_t kMaxTraceSize = 50; + constexpr const std::size_t kMaxTracePrint = 50; // Must be even. + // Buffer for the trace. + void *trace[kMaxTraceSize]; + // Get the trace. + std::size_t size = backtrace(trace, kMaxTraceSize); + // Get the trace symbols. + char **trace_symbol = backtrace_symbols(trace, size); + if (trace_symbol == nullptr) return ans; + + // Compose a human-readable backtrace string. + ans += "[ Stack-Trace: ]\n"; + if (size <= kMaxTracePrint) { + for (std::size_t i = 0; i < size; ++i) { + ans += Demangle(trace_symbol[i]) + "\n"; + } + } else { // Print out first+last (e.g.) 5. + for (std::size_t i = 0; i < kMaxTracePrint / 2; ++i) { + ans += Demangle(trace_symbol[i]) + "\n"; + } + ans += ".\n.\n.\n"; + for (std::size_t i = size - kMaxTracePrint / 2; i < size; ++i) { + ans += Demangle(trace_symbol[i]) + "\n"; + } + if (size == kMaxTraceSize) + ans += ".\n.\n.\n"; // Stack was too long, probably a bug. 
+ } + + // We must free the array of pointers allocated by backtrace_symbols(), + // but not the strings themselves. + free(trace_symbol); +#endif // SHERPA_ONNX_HAVE_EXECINFO_H + return ans; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/log.h b/sherpa-onnx/csrc/log.h new file mode 100644 index 00000000..d2d29fe0 --- /dev/null +++ b/sherpa-onnx/csrc/log.h @@ -0,0 +1,378 @@ +// sherpa-onnx/csrc/log.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_LOG_H_ +#define SHERPA_ONNX_CSRC_LOG_H_ + +#include + +#include // NOLINT +#include +#include + +namespace sherpa_onnx { + +#if SHERPA_ONNX_ENABLE_CHECK + +#if defined(NDEBUG) +constexpr bool kDisableDebug = true; +#else +constexpr bool kDisableDebug = false; +#endif + +enum class LogLevel { + kTrace = 0, + kDebug = 1, + kInfo = 2, + kWarning = 3, + kError = 4, + kFatal = 5, // print message and abort the program +}; + +// They are used in SHERPA_ONNX_LOG(xxx), so their names +// do not follow the google c++ code style +// +// You can use them in the following way: +// +// SHERPA_ONNX_LOG(TRACE) << "some message"; +// SHERPA_ONNX_LOG(DEBUG) << "some message"; +#ifndef _MSC_VER +constexpr LogLevel TRACE = LogLevel::kTrace; +constexpr LogLevel DEBUG = LogLevel::kDebug; +constexpr LogLevel INFO = LogLevel::kInfo; +constexpr LogLevel WARNING = LogLevel::kWarning; +constexpr LogLevel ERROR = LogLevel::kError; +constexpr LogLevel FATAL = LogLevel::kFatal; +#else +#define TRACE LogLevel::kTrace +#define DEBUG LogLevel::kDebug +#define INFO LogLevel::kInfo +#define WARNING LogLevel::kWarning +#define ERROR LogLevel::kError +#define FATAL LogLevel::kFatal +#endif + +std::string GetStackTrace(); + +/* Return the current log level. + + + If the current log level is TRACE, then all logged messages are printed out. + + If the current log level is DEBUG, log messages with "TRACE" level are not + shown and all other levels are printed out. 
+ + Similarly, if the current log level is INFO, log message with "TRACE" and + "DEBUG" are not shown and all other levels are printed out. + + If it is FATAL, then only FATAL messages are shown. + */ +inline LogLevel GetCurrentLogLevel() { + static LogLevel log_level = INFO; + static std::once_flag init_flag; + std::call_once(init_flag, []() { + const char *env_log_level = std::getenv("SHERPA_ONNX_LOG_LEVEL"); + if (env_log_level == nullptr) return; + + std::string s = env_log_level; + if (s == "TRACE") + log_level = TRACE; + else if (s == "DEBUG") + log_level = DEBUG; + else if (s == "INFO") + log_level = INFO; + else if (s == "WARNING") + log_level = WARNING; + else if (s == "ERROR") + log_level = ERROR; + else if (s == "FATAL") + log_level = FATAL; + else + fprintf(stderr, + "Unknown SHERPA_ONNX_LOG_LEVEL: %s" + "\nSupported values are: " + "TRACE, DEBUG, INFO, WARNING, ERROR, FATAL", + s.c_str()); + }); + return log_level; +} + +inline bool EnableAbort() { + static std::once_flag init_flag; + static bool enable_abort = false; + std::call_once(init_flag, []() { + enable_abort = (std::getenv("SHERPA_ONNX_ABORT") != nullptr); + }); + return enable_abort; +} + +class Logger { + public: + Logger(const char *filename, const char *func_name, uint32_t line_num, + LogLevel level) + : filename_(filename), + func_name_(func_name), + line_num_(line_num), + level_(level) { + cur_level_ = GetCurrentLogLevel(); + switch (level) { + case TRACE: + if (cur_level_ <= TRACE) fprintf(stderr, "[T] "); + break; + case DEBUG: + if (cur_level_ <= DEBUG) fprintf(stderr, "[D] "); + break; + case INFO: + if (cur_level_ <= INFO) fprintf(stderr, "[I] "); + break; + case WARNING: + if (cur_level_ <= WARNING) fprintf(stderr, "[W] "); + break; + case ERROR: + if (cur_level_ <= ERROR) fprintf(stderr, "[E] "); + break; + case FATAL: + if (cur_level_ <= FATAL) fprintf(stderr, "[F] "); + break; + } + + if (cur_level_ <= level_) { + fprintf(stderr, "%s:%u:%s ", filename, line_num, func_name); + } 
+ } + + ~Logger() noexcept(false) { + static constexpr const char *kErrMsg = R"( + Some bad things happened. Please read the above error messages and stack + trace. If you are using Python, the following command may be helpful: + + gdb --args python /path/to/your/code.py + + (You can use `gdb` to debug the code. Please consider compiling + a debug version of sherpa_onnx.). + + If you are unable to fix it, please open an issue at: + + https://github.com/csukuangfj/kaldi-native-fbank/issues/new + )"; + if (level_ == FATAL) { + fprintf(stderr, "\n"); + std::string stack_trace = GetStackTrace(); + if (!stack_trace.empty()) { + fprintf(stderr, "\n\n%s\n", stack_trace.c_str()); + } + + fflush(nullptr); + +#ifndef __ANDROID_API__ + if (EnableAbort()) { + // NOTE: abort() will terminate the program immediately without + // printing the Python stack backtrace. + abort(); + } + + throw std::runtime_error(kErrMsg); +#else + abort(); +#endif + } + } + + const Logger &operator<<(bool b) const { + if (cur_level_ <= level_) { + fprintf(stderr, b ? 
"true" : "false"); + } + return *this; + } + + const Logger &operator<<(int8_t i) const { + if (cur_level_ <= level_) fprintf(stderr, "%d", i); + return *this; + } + + const Logger &operator<<(const char *s) const { + if (cur_level_ <= level_) fprintf(stderr, "%s", s); + return *this; + } + + const Logger &operator<<(int32_t i) const { + if (cur_level_ <= level_) fprintf(stderr, "%d", i); + return *this; + } + + const Logger &operator<<(uint32_t i) const { + if (cur_level_ <= level_) fprintf(stderr, "%u", i); + return *this; + } + + const Logger &operator<<(uint64_t i) const { + if (cur_level_ <= level_) + fprintf(stderr, "%llu", (long long unsigned int)i); // NOLINT + return *this; + } + + const Logger &operator<<(int64_t i) const { + if (cur_level_ <= level_) + fprintf(stderr, "%lli", (long long int)i); // NOLINT + return *this; + } + + const Logger &operator<<(float f) const { + if (cur_level_ <= level_) fprintf(stderr, "%f", f); + return *this; + } + + const Logger &operator<<(double d) const { + if (cur_level_ <= level_) fprintf(stderr, "%f", d); + return *this; + } + + template + const Logger &operator<<(const T &t) const { + // require T overloads operator<< + std::ostringstream os; + os << t; + return *this << os.str().c_str(); + } + + // specialization to fix compile error: `stringstream << nullptr` is ambiguous + const Logger &operator<<(const std::nullptr_t &null) const { + if (cur_level_ <= level_) *this << "(null)"; + return *this; + } + + private: + const char *filename_; + const char *func_name_; + uint32_t line_num_; + LogLevel level_; + LogLevel cur_level_; +}; +#endif // SHERPA_ONNX_ENABLE_CHECK + +class Voidifier { + public: +#if SHERPA_ONNX_ENABLE_CHECK + void operator&(const Logger &) const {} +#endif +}; +#if !defined(SHERPA_ONNX_ENABLE_CHECK) +template +const Voidifier &operator<<(const Voidifier &v, T &&) { + return v; +} +#endif + +} // namespace sherpa_onnx + +#define SHERPA_ONNX_STATIC_ASSERT(x) static_assert(x, "") + +#ifdef 
SHERPA_ONNX_ENABLE_CHECK + +#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) || \ + defined(__PRETTY_FUNCTION__) +// for clang and GCC +#define SHERPA_ONNX_FUNC __PRETTY_FUNCTION__ +#else +// for other compilers +#define SHERPA_ONNX_FUNC __func__ +#endif + +#define SHERPA_ONNX_CHECK(x) \ + (x) ? (void)0 \ + : ::sherpa_onnx::Voidifier() & \ + ::sherpa_onnx::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, \ + ::sherpa_onnx::FATAL) \ + << "Check failed: " << #x << " " + +// WARNING: x and y may be evaluated multiple times, but this happens only +// when the check fails. Since the program aborts if it fails, we don't think +// the extra evaluation of x and y matters. +// +// CAUTION: we recommend the following use case: +// +// auto x = Foo(); +// auto y = Bar(); +// SHERPA_ONNX_CHECK_EQ(x, y) << "Some message"; +// +// And please avoid +// +// SHERPA_ONNX_CHECK_EQ(Foo(), Bar()); +// +// if `Foo()` or `Bar()` causes some side effects, e.g., changing some +// local static variables or global variables. +#define _SHERPA_ONNX_CHECK_OP(x, y, op) \ + ((x)op(y)) ? (void)0 \ + : ::sherpa_onnx::Voidifier() & \ + ::sherpa_onnx::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, \ + ::sherpa_onnx::FATAL) \ + << "Check failed: " << #x << " " << #op << " " << #y \ + << " (" << (x) << " vs. 
" << (y) << ") " + +#define SHERPA_ONNX_CHECK_EQ(x, y) _SHERPA_ONNX_CHECK_OP(x, y, ==) +#define SHERPA_ONNX_CHECK_NE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, !=) +#define SHERPA_ONNX_CHECK_LT(x, y) _SHERPA_ONNX_CHECK_OP(x, y, <) +#define SHERPA_ONNX_CHECK_LE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, <=) +#define SHERPA_ONNX_CHECK_GT(x, y) _SHERPA_ONNX_CHECK_OP(x, y, >) +#define SHERPA_ONNX_CHECK_GE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, >=) + +#define SHERPA_ONNX_LOG(x) \ + ::sherpa_onnx::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, ::sherpa_onnx::x) + +// ------------------------------------------------------------ +// For debug check +// ------------------------------------------------------------ +// If you define the macro "-D NDEBUG" while compiling kaldi-native-fbank, +// the following macros are in fact empty and does nothing. + +#define SHERPA_ONNX_DCHECK(x) \ + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK(x) + +#define SHERPA_ONNX_DCHECK_EQ(x, y) \ + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_EQ(x, y) + +#define SHERPA_ONNX_DCHECK_NE(x, y) \ + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_NE(x, y) + +#define SHERPA_ONNX_DCHECK_LT(x, y) \ + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_LT(x, y) + +#define SHERPA_ONNX_DCHECK_LE(x, y) \ + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_LE(x, y) + +#define SHERPA_ONNX_DCHECK_GT(x, y) \ + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_GT(x, y) + +#define SHERPA_ONNX_DCHECK_GE(x, y) \ + ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_GE(x, y) + +#define SHERPA_ONNX_DLOG(x) \ + ::sherpa_onnx::kDisableDebug \ + ? 
(void)0 \ + : ::sherpa_onnx::Voidifier() & SHERPA_ONNX_LOG(x) + +#else + +#define SHERPA_ONNX_CHECK(x) ::sherpa_onnx::Voidifier() +#define SHERPA_ONNX_LOG(x) ::sherpa_onnx::Voidifier() + +#define SHERPA_ONNX_CHECK_EQ(x, y) ::sherpa_onnx::Voidifier() +#define SHERPA_ONNX_CHECK_NE(x, y) ::sherpa_onnx::Voidifier() +#define SHERPA_ONNX_CHECK_LT(x, y) ::sherpa_onnx::Voidifier() +#define SHERPA_ONNX_CHECK_LE(x, y) ::sherpa_onnx::Voidifier() +#define SHERPA_ONNX_CHECK_GT(x, y) ::sherpa_onnx::Voidifier() +#define SHERPA_ONNX_CHECK_GE(x, y) ::sherpa_onnx::Voidifier() + +#define SHERPA_ONNX_DCHECK(x) ::sherpa_onnx::Voidifier() +#define SHERPA_ONNX_DLOG(x) ::sherpa_onnx::Voidifier() +#define SHERPA_ONNX_DCHECK_EQ(x, y) ::sherpa_onnx::Voidifier() +#define SHERPA_ONNX_DCHECK_NE(x, y) ::sherpa_onnx::Voidifier() +#define SHERPA_ONNX_DCHECK_LT(x, y) ::sherpa_onnx::Voidifier() +#define SHERPA_ONNX_DCHECK_LE(x, y) ::sherpa_onnx::Voidifier() +#define SHERPA_ONNX_DCHECK_GT(x, y) ::sherpa_onnx::Voidifier() +#define SHERPA_ONNX_DCHECK_GE(x, y) ::sherpa_onnx::Voidifier() + +#endif // SHERPA_ONNX_ENABLE_CHECK + +#endif // SHERPA_ONNX_CSRC_LOG_H_ diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 29aeca16..6292eb22 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -37,7 +37,9 @@ std::string OnlineRecognizerConfig::ToString() const { os << "OnlineRecognizerConfig("; os << "feat_config=" << feat_config.ToString() << ", "; os << "model_config=" << model_config.ToString() << ", "; - os << "tokens=\"" << tokens << "\")"; + os << "tokens=\"" << tokens << "\", "; + os << "endpoint_config=" << endpoint_config.ToString() << ", "; + os << "enable_endpoint=" << (enable_endpoint ? 
"True" : "False") << ")"; return os.str(); } @@ -47,7 +49,8 @@ class OnlineRecognizer::Impl { explicit Impl(const OnlineRecognizerConfig &config) : config_(config), model_(OnlineTransducerModel::Create(config.model_config)), - sym_(config.tokens) { + sym_(config.tokens), + endpoint_(config_.endpoint_config) { decoder_ = std::make_unique(model_.get()); } @@ -64,7 +67,7 @@ class OnlineRecognizer::Impl { s->NumFramesReady(); } - void DecodeStreams(OnlineStream **ss, int32_t n) { + void DecodeStreams(OnlineStream **ss, int32_t n) const { int32_t chunk_size = model_->ChunkSize(); int32_t chunk_shift = model_->ChunkShift(); @@ -111,18 +114,44 @@ class OnlineRecognizer::Impl { } } - OnlineRecognizerResult GetResult(OnlineStream *s) { + OnlineRecognizerResult GetResult(OnlineStream *s) const { OnlineTransducerDecoderResult decoder_result = s->GetResult(); decoder_->StripLeadingBlanks(&decoder_result); return Convert(decoder_result, sym_); } + bool IsEndpoint(OnlineStream *s) const { + if (!config_.enable_endpoint) return false; + int32_t num_processed_frames = s->GetNumProcessedFrames(); + + // frame shift is 10 milliseconds + float frame_shift_in_seconds = 0.01; + + // subsampling factor is 4 + int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 4; + + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, + frame_shift_in_seconds); + } + + void Reset(OnlineStream *s) const { + // reset result and neural network model state, + // but keep the feature extractor state + + // reset result + s->SetResult(decoder_->GetEmptyResult()); + + // reset neural network model state + s->SetStates(model_->GetEncoderInitStates()); + } + private: OnlineRecognizerConfig config_; std::unique_ptr model_; std::unique_ptr decoder_; SymbolTable sym_; + Endpoint endpoint_; }; OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config) @@ -137,12 +166,18 @@ bool OnlineRecognizer::IsReady(OnlineStream *s) const { return impl_->IsReady(s); } -void 
OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) { +void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) const { impl_->DecodeStreams(ss, n); } -OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) { +OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) const { return impl_->GetResult(s); } +bool OnlineRecognizer::IsEndpoint(OnlineStream *s) const { + return impl_->IsEndpoint(s); +} + +void OnlineRecognizer::Reset(OnlineStream *s) const { impl_->Reset(s); } + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index 0d85d38c..5066ee25 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -8,6 +8,7 @@ #include #include +#include "sherpa-onnx/csrc/endpoint.h" #include "sherpa-onnx/csrc/features.h" #include "sherpa-onnx/csrc/online-stream.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h" @@ -22,13 +23,21 @@ struct OnlineRecognizerConfig { FeatureExtractorConfig feat_config; OnlineTransducerModelConfig model_config; std::string tokens; + EndpointConfig endpoint_config; + bool enable_endpoint = false; OnlineRecognizerConfig() = default; OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, const OnlineTransducerModelConfig &model_config, - const std::string &tokens) - : feat_config(feat_config), model_config(model_config), tokens(tokens) {} + const std::string &tokens, + const EndpointConfig &endpoint_config, + bool enable_endpoint) + : feat_config(feat_config), + model_config(model_config), + tokens(tokens), + endpoint_config(endpoint_config), + enable_endpoint(enable_endpoint) {} std::string ToString() const; }; @@ -48,7 +57,7 @@ class OnlineRecognizer { bool IsReady(OnlineStream *s) const; /** Decode a single stream. 
*/ - void DecodeStream(OnlineStream *s) { + void DecodeStream(OnlineStream *s) const { OnlineStream *ss[1] = {s}; DecodeStreams(ss, 1); } @@ -58,9 +67,18 @@ class OnlineRecognizer { * @param ss Pointer array containing streams to be decoded. * @param n Number of streams in `ss`. */ - void DecodeStreams(OnlineStream **ss, int32_t n); + void DecodeStreams(OnlineStream **ss, int32_t n) const; - OnlineRecognizerResult GetResult(OnlineStream *s); + OnlineRecognizerResult GetResult(OnlineStream *s) const; + + // Return true if we detect an endpoint for this stream. + // Note: If this function returns true, you usually want to + // invoke Reset(s). + bool IsEndpoint(OnlineStream *s) const; + + // Clear the state of this stream. If IsEndpoint(s) returns true, + // after calling this function, IsEndpoint(s) will return false. + void Reset(OnlineStream *s) const; private: class Impl; diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h index a945aa32..42bf6d6e 100644 --- a/sherpa-onnx/csrc/online-stream.h +++ b/sherpa-onnx/csrc/online-stream.h @@ -55,7 +55,8 @@ class OnlineStream { int32_t FeatureDim() const; - // Return a reference to the number of processed frames so far. + // Return a reference to the number of processed frames so far + // before subsampling. // Initially, it is 0. It is always less than NumFramesReady(). // // The returned reference is valid as long as this object is alive. 
diff --git a/sherpa-onnx/csrc/online-transducer-decoder.h b/sherpa-onnx/csrc/online-transducer-decoder.h index 92f4eeaa..6b8eb4cc 100644 --- a/sherpa-onnx/csrc/online-transducer-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-decoder.h @@ -14,6 +14,9 @@ namespace sherpa_onnx { struct OnlineTransducerDecoderResult { /// The decoded token IDs so far std::vector tokens; + + /// number of trailing blank frames decoded so far + int32_t num_trailing_blanks = 0; }; class OnlineTransducerDecoder { diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index 1776805c..c4b9ae15 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -113,6 +113,9 @@ void OnlineTransducerGreedySearchDecoder::Decode( if (y != 0) { emitted = true; (*result)[i].tokens.push_back(y); + (*result)[i].num_trailing_blanks = 0; + } else { + ++(*result)[i].num_trailing_blanks; } } if (emitted) { diff --git a/sherpa-onnx/csrc/parse-options.cc b/sherpa-onnx/csrc/parse-options.cc new file mode 100644 index 00000000..54628949 --- /dev/null +++ b/sherpa-onnx/csrc/parse-options.cc @@ -0,0 +1,774 @@ +// sherpa-onnx/csrc/parse-options.cc +/** + * Copyright 2009-2011 Karel Vesely; Microsoft Corporation; + * Saarland University (Author: Arnab Ghoshal); + * Copyright 2012-2013 Johns Hopkins University (Author: Daniel Povey); + * Frantisek Skala; Arnab Ghoshal + * Copyright 2013 Tanel Alumae + */ + +// This file is copied and modified from kaldi/src/util/parse-options.cu + +#include "sherpa-onnx/csrc/parse-options.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/log.h" + +#ifdef _MSC_VER +#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \ + _strtoi64(cur_cstr, end_cstr, 10); +#else +#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10); +#endif + 
+namespace sherpa_onnx { + +/// Converts a string into an integer via strtoll and returns false if there was +/// any kind of problem (i.e. the string was not an integer or contained extra +/// non-whitespace junk, or the integer was too large to fit into the type it is +/// being converted into). Only sets *out if everything was OK and it returns +/// true. +template +bool ConvertStringToInteger(const std::string &str, Int *out) { + // copied from kaldi/src/util/text-util.h + static_assert(std::is_integral::value, ""); + const char *this_str = str.c_str(); + char *end = nullptr; + errno = 0; + int64_t i = SHERPA_ONNX_STRTOLL(this_str, &end); + if (end != this_str) { + while (isspace(*end)) ++end; + } + if (end == this_str || *end != '\0' || errno != 0) return false; + Int iInt = static_cast(i); + if (static_cast(iInt) != i || + (i < 0 && !std::numeric_limits::is_signed)) { + return false; + } + *out = iInt; + return true; +} + +// copied from kaldi/src/util/text-util.cc +template +class NumberIstream { + public: + explicit NumberIstream(std::istream &i) : in_(i) {} + + NumberIstream &operator>>(T &x) { + if (!in_.good()) return *this; + in_ >> x; + if (!in_.fail() && RemainderIsOnlySpaces()) return *this; + return ParseOnFail(&x); + } + + private: + std::istream &in_; + + bool RemainderIsOnlySpaces() { + if (in_.tellg() != std::istream::pos_type(-1)) { + std::string rem; + in_ >> rem; + + if (rem.find_first_not_of(' ') != std::string::npos) { + // there is not only spaces + return false; + } + } + + in_.clear(); + return true; + } + + NumberIstream &ParseOnFail(T *x) { + std::string str; + in_.clear(); + in_.seekg(0); + // If the stream is broken even before trying + // to read from it or if there are many tokens, + // it's pointless to try. + if (!(in_ >> str) || !RemainderIsOnlySpaces()) { + in_.setstate(std::ios_base::failbit); + return *this; + } + + std::unordered_map inf_nan_map; + // we'll keep just uppercase values. 
+ inf_nan_map["INF"] = std::numeric_limits::infinity(); + inf_nan_map["+INF"] = std::numeric_limits::infinity(); + inf_nan_map["-INF"] = -std::numeric_limits::infinity(); + inf_nan_map["INFINITY"] = std::numeric_limits::infinity(); + inf_nan_map["+INFINITY"] = std::numeric_limits::infinity(); + inf_nan_map["-INFINITY"] = -std::numeric_limits::infinity(); + inf_nan_map["NAN"] = std::numeric_limits::quiet_NaN(); + inf_nan_map["+NAN"] = std::numeric_limits::quiet_NaN(); + inf_nan_map["-NAN"] = -std::numeric_limits::quiet_NaN(); + // MSVC + inf_nan_map["1.#INF"] = std::numeric_limits::infinity(); + inf_nan_map["-1.#INF"] = -std::numeric_limits::infinity(); + inf_nan_map["1.#QNAN"] = std::numeric_limits::quiet_NaN(); + inf_nan_map["-1.#QNAN"] = -std::numeric_limits::quiet_NaN(); + + std::transform(str.begin(), str.end(), str.begin(), ::toupper); + + if (inf_nan_map.find(str) != inf_nan_map.end()) { + *x = inf_nan_map[str]; + } else { + in_.setstate(std::ios_base::failbit); + } + + return *this; + } +}; + +/// ConvertStringToReal converts a string into either float or double +/// and returns false if there was any kind of problem (i.e. the string +/// was not a floating point number or contained extra non-whitespace junk). +/// Be careful- this function will successfully read inf's or nan's. +template +bool ConvertStringToReal(const std::string &str, T *out) { + std::istringstream iss(str); + + NumberIstream i(iss); + + i >> *out; + + if (iss.fail()) { + // Number conversion failed. + return false; + } + + return true; +} + +ParseOptions::ParseOptions(const std::string &prefix, ParseOptions *po) + : print_args_(false), help_(false), usage_(""), argc_(0), argv_(nullptr) { + if (po != nullptr && po->other_parser_ != nullptr) { + // we get here if this constructor is used twice, recursively. 
+ other_parser_ = po->other_parser_; + } else { + other_parser_ = po; + } + if (po != nullptr && po->prefix_ != "") { + prefix_ = po->prefix_ + std::string(".") + prefix; + } else { + prefix_ = prefix; + } +} + +void ParseOptions::Register(const std::string &name, bool *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, int32_t *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, uint32_t *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, float *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, double *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, std::string *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +// old-style, used for registering application-specific parameters +template +void ParseOptions::RegisterTmpl(const std::string &name, T *ptr, + const std::string &doc) { + if (other_parser_ == nullptr) { + this->RegisterCommon(name, ptr, doc, false); + } else { + SHERPA_ONNX_CHECK(prefix_ != "") + << "prefix: " << prefix_ << "\n" + << "Cannot use empty prefix when registering with prefix."; + std::string new_name = prefix_ + '.' 
+ name; // name becomes prefix.name + other_parser_->Register(new_name, ptr, doc); + } +} + +// does the common part of the job of registering a parameter +template +void ParseOptions::RegisterCommon(const std::string &name, T *ptr, + const std::string &doc, bool is_standard) { + SHERPA_ONNX_CHECK(ptr != nullptr); + std::string idx = name; + NormalizeArgName(&idx); + if (doc_map_.find(idx) != doc_map_.end()) { + SHERPA_ONNX_LOG(WARNING) + << "Registering option twice, ignoring second time: " << name; + } else { + this->RegisterSpecific(name, idx, ptr, doc, is_standard); + } +} + +// used to register standard parameters (those that are present in all of the +// applications) +template +void ParseOptions::RegisterStandard(const std::string &name, T *ptr, + const std::string &doc) { + this->RegisterCommon(name, ptr, doc, true); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, bool *b, + const std::string &doc, bool is_standard) { + bool_map_[idx] = b; + doc_map_[idx] = + DocInfo(name, doc + " (bool, default = " + ((*b) ? 
"true)" : "false)"), + is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, int32_t *i, + const std::string &doc, bool is_standard) { + int_map_[idx] = i; + std::ostringstream ss; + ss << doc << " (int, default = " << *i << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, uint32_t *u, + const std::string &doc, bool is_standard) { + uint_map_[idx] = u; + std::ostringstream ss; + ss << doc << " (uint, default = " << *u << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, float *f, + const std::string &doc, bool is_standard) { + float_map_[idx] = f; + std::ostringstream ss; + ss << doc << " (float, default = " << *f << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, double *f, + const std::string &doc, bool is_standard) { + double_map_[idx] = f; + std::ostringstream ss; + ss << doc << " (double, default = " << *f << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, std::string *s, + const std::string &doc, bool is_standard) { + string_map_[idx] = s; + doc_map_[idx] = + DocInfo(name, doc + " (string, default = \"" + *s + "\")", is_standard); +} + +void ParseOptions::DisableOption(const std::string &name) { + if (argv_ != nullptr) { + SHERPA_ONNX_LOG(FATAL) + << "DisableOption must not be called after calling Read()."; + } + if (doc_map_.erase(name) == 0) { + SHERPA_ONNX_LOG(FATAL) << "Option " << name + << " was not registered so cannot be disabled: "; + } + bool_map_.erase(name); + int_map_.erase(name); + uint_map_.erase(name); + float_map_.erase(name); + double_map_.erase(name); + 
string_map_.erase(name); +} + +int ParseOptions::NumArgs() const { return positional_args_.size(); } + +std::string ParseOptions::GetArg(int i) const { + if (i < 1 || i > static_cast(positional_args_.size())) { + SHERPA_ONNX_LOG(FATAL) << "ParseOptions::GetArg, invalid index " << i; + } + + return positional_args_[i - 1]; +} + +// We currently do not support any other options. +enum ShellType { kBash = 0 }; + +// This can be changed in the code if it ever does need to be changed (as it's +// unlikely that one compilation of this tool-set would use both shells). +static ShellType kShellType = kBash; + +// Returns true if we need to escape a string before putting it into +// a shell (mainly thinking of bash shell, but should work for others) +// This is for the convenience of the user so command-lines that are +// printed out by ParseOptions::Read (with --print-args=true) are +// paste-able into the shell and will run. If you use a different type of +// shell, it might be necessary to change this function. +// But it's mostly a cosmetic issue as it basically affects how +// the program echoes its command-line arguments to the screen. +static bool MustBeQuoted(const std::string &str, ShellType st) { + // Only Bash is supported (for the moment). + SHERPA_ONNX_CHECK_EQ(st, kBash) << "Invalid shell type."; + + const char *c = str.c_str(); + if (*c == '\0') { + return true; // Must quote empty string + } else { + const char *ok_chars[2]; + + // These seem not to be interpreted as long as there are no other "bad" + // characters involved (e.g. "," would be interpreted as part of something + // like a{b,c}, but not on its own. + ok_chars[kBash] = "[]~#^_-+=:.,/"; + + // Just want to make sure that a space character doesn't get automatically + // inserted here via an automated style-checking script, like it did before. 
+ SHERPA_ONNX_CHECK(!strchr(ok_chars[kBash], ' ')); + + for (; *c != '\0'; ++c) { + // For non-alphanumeric characters we have a list of characters which + // are OK. All others are forbidden (this is easier since the shell + // interprets most non-alphanumeric characters). + if (!isalnum(*c)) { + const char *d; + for (d = ok_chars[st]; *d != '\0'; ++d) { + if (*c == *d) break; + } + // If not alphanumeric or one of the "ok_chars", it must be escaped. + if (*d == '\0') return true; + } + } + return false; // The string was OK. No quoting or escaping. + } +} + +// Returns a quoted and escaped version of "str" +// which has previously been determined to need escaping. +// Our aim is to print out the command line in such a way that if it's +// pasted into a shell of ShellType "st" (only bash for now), it +// will get passed to the program in the same way. +static std::string QuoteAndEscape(const std::string &str, ShellType st) { + // Only Bash is supported (for the moment). + SHERPA_ONNX_CHECK_EQ(st, kBash) << "Invalid shell type."; + + // For now we use the following rules: + // In the normal case, we quote with single-quote "'", and to escape + // a single-quote we use the string: '\'' (interpreted as closing the + // single-quote, putting an escaped single-quote from the shell, and + // then reopening the single quote). + char quote_char = '\''; + const char *escape_str = "'\\''"; // e.g. echo 'a'\''b' returns a'b + + // If the string contains single-quotes that would need escaping this + // way, and we determine that the string could be safely double-quoted + // without requiring any escaping, then we double-quote the string. + // This is the case if the characters "`$\ do not appear in the string. + // e.g. see http://www.redhat.com/mirrors/LDP/LDP/abs/html/quotingvar.html + const char *c_str = str.c_str(); + if (strchr(c_str, '\'') && !strpbrk(c_str, "\"`$\\")) { + quote_char = '"'; + escape_str = "\\\""; // should never be accessed. 
+ } + + char buf[2]; + buf[1] = '\0'; + + buf[0] = quote_char; + std::string ans = buf; + const char *c = str.c_str(); + for (; *c != '\0'; ++c) { + if (*c == quote_char) { + ans += escape_str; + } else { + buf[0] = *c; + ans += buf; + } + } + buf[0] = quote_char; + ans += buf; + return ans; +} + +// static function +std::string ParseOptions::Escape(const std::string &str) { + return MustBeQuoted(str, kShellType) ? QuoteAndEscape(str, kShellType) : str; +} + +int ParseOptions::Read(int argc, const char *const argv[]) { + argc_ = argc; + argv_ = argv; + std::string key, value; + int i; + + // first pass: look for config parameter, look for priority + for (i = 1; i < argc; ++i) { + if (std::strncmp(argv[i], "--", 2) == 0) { + if (std::strcmp(argv[i], "--") == 0) { + // a lone "--" marks the end of named options + break; + } + bool has_equal_sign; + SplitLongArg(argv[i], &key, &value, &has_equal_sign); + NormalizeArgName(&key); + Trim(&value); + if (key.compare("config") == 0) { + ReadConfigFile(value); + } else if (key.compare("help") == 0) { + PrintUsage(); + exit(0); + } + } + } + + bool double_dash_seen = false; + // second pass: add the command line options + for (i = 1; i < argc; ++i) { + if (std::strncmp(argv[i], "--", 2) == 0) { + if (std::strcmp(argv[i], "--") == 0) { + // A lone "--" marks the end of named options. 
+ // Skip that option and break the processing of named options + i += 1; + double_dash_seen = true; + break; + } + bool has_equal_sign; + SplitLongArg(argv[i], &key, &value, &has_equal_sign); + NormalizeArgName(&key); + Trim(&value); + if (!SetOption(key, value, has_equal_sign)) { + PrintUsage(true); + SHERPA_ONNX_LOG(FATAL) << "Invalid option " << argv[i]; + } + } else { + break; + } + } + + // process remaining arguments as positional + for (; i < argc; ++i) { + if ((std::strcmp(argv[i], "--") == 0) && !double_dash_seen) { + double_dash_seen = true; + } else { + positional_args_.push_back(std::string(argv[i])); + } + } + + // if the user did not suppress this with --print-args = false.... + if (print_args_) { + std::ostringstream strm; + for (int j = 0; j < argc; ++j) strm << Escape(argv[j]) << " "; + strm << '\n'; + SHERPA_ONNX_LOG(INFO) << strm.str(); + } + return i; +} + +void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const { + std::ostringstream os; + os << '\n' << usage_ << '\n'; + // first we print application-specific options + bool app_specific_header_printed = false; + for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) { + if (it->second.is_standard_ == false) { // application-specific option + if (app_specific_header_printed == false) { // header was not yet printed + os << "Options:" << '\n'; + app_specific_header_printed = true; + } + os << " --" << std::setw(25) << std::left << it->second.name_ << " : " + << it->second.use_msg_ << '\n'; + } + } + if (app_specific_header_printed == true) { + os << '\n'; + } + + // then the standard options + os << "Standard options:" << '\n'; + for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) { + if (it->second.is_standard_ == true) { // we have standard option + os << " --" << std::setw(25) << std::left << it->second.name_ << " : " + << it->second.use_msg_ << '\n'; + } + } + os << '\n'; + if (print_command_line) { + std::ostringstream strm; + strm << "Command line was: "; + 
for (int j = 0; j < argc_; ++j) strm << Escape(argv_[j]) << " "; + strm << '\n'; + os << strm.str(); + } + + SHERPA_ONNX_LOG(INFO) << os.str(); +} + +void ParseOptions::PrintConfig(std::ostream &os) const { + os << '\n' << "[[ Configuration of UI-Registered options ]]" << '\n'; + std::string key; + for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) { + key = it->first; + os << it->second.name_ << " = "; + if (bool_map_.end() != bool_map_.find(key)) { + os << (*bool_map_.at(key) ? "true" : "false"); + } else if (int_map_.end() != int_map_.find(key)) { + os << (*int_map_.at(key)); + } else if (uint_map_.end() != uint_map_.find(key)) { + os << (*uint_map_.at(key)); + } else if (float_map_.end() != float_map_.find(key)) { + os << (*float_map_.at(key)); + } else if (double_map_.end() != double_map_.find(key)) { + os << (*double_map_.at(key)); + } else if (string_map_.end() != string_map_.find(key)) { + os << "'" << *string_map_.at(key) << "'"; + } else { + SHERPA_ONNX_LOG(FATAL) + << "PrintConfig: unrecognized option " << key << "[code error]"; + } + os << '\n'; + } + os << '\n'; +} + +void ParseOptions::ReadConfigFile(const std::string &filename) { + std::ifstream is(filename.c_str(), std::ifstream::in); + if (!is.good()) { + SHERPA_ONNX_LOG(FATAL) << "Cannot open config file: " << filename; + } + + std::string line, key, value; + int32_t line_number = 0; + while (std::getline(is, line)) { + ++line_number; + // trim out the comments + size_t pos; + if ((pos = line.find_first_of('#')) != std::string::npos) { + line.erase(pos); + } + // skip empty lines + Trim(&line); + if (line.length() == 0) continue; + + if (line.substr(0, 2) != "--") { + SHERPA_ONNX_LOG(FATAL) + << "Reading config file " << filename << ": line " << line_number + << " does not look like a line " + << "from a Kaldi command-line program's config file: should " + << "be of the form --x=y. 
Note: config files intended to " + << "be sourced by shell scripts lack the '--'."; + } + + // parse option + bool has_equal_sign; + SplitLongArg(line, &key, &value, &has_equal_sign); + NormalizeArgName(&key); + Trim(&value); + if (!SetOption(key, value, has_equal_sign)) { + PrintUsage(true); + SHERPA_ONNX_LOG(FATAL) << "Invalid option " << line << " in config file " + << filename << ": line " << line_number; + } + } +} + +void ParseOptions::SplitLongArg(const std::string &in, std::string *key, + std::string *value, + bool *has_equal_sign) const { + SHERPA_ONNX_CHECK(in.substr(0, 2) == "--") << in; // precondition. + size_t pos = in.find_first_of('=', 0); + if (pos == std::string::npos) { // we allow --option for bools + // defaults to empty. We handle this differently in different cases. + *key = in.substr(2, in.size() - 2); // 2 because starts with --. + *value = ""; + *has_equal_sign = false; + } else if (pos == 2) { // we also don't allow empty keys: --=value + PrintUsage(true); + SHERPA_ONNX_LOG(FATAL) << "Invalid option (no key): " << in; + } else { // normal case: --option=value + *key = in.substr(2, pos - 2); // 2 because starts with --. 
+ *value = in.substr(pos + 1); + *has_equal_sign = true; + } +} + +void ParseOptions::NormalizeArgName(std::string *str) const { + std::string out; + std::string::iterator it; + + for (it = str->begin(); it != str->end(); ++it) { + if (*it == '_') { + out += '-'; // convert _ to - + } else { + out += std::tolower(*it); + } + } + *str = out; + + SHERPA_ONNX_CHECK_GT(str->length(), 0); +} + +void ParseOptions::Trim(std::string *str) const { + const char *white_chars = " \t\n\r\f\v"; + + std::string::size_type pos = str->find_last_not_of(white_chars); + if (pos != std::string::npos) { + str->erase(pos + 1); + pos = str->find_first_not_of(white_chars); + if (pos != std::string::npos) str->erase(0, pos); + } else { + str->erase(str->begin(), str->end()); + } +} + +bool ParseOptions::SetOption(const std::string &key, const std::string &value, + bool has_equal_sign) { + if (bool_map_.end() != bool_map_.find(key)) { + if (has_equal_sign && value == "") { + SHERPA_ONNX_LOG(FATAL) << "Invalid option --" << key << "="; + } + *(bool_map_[key]) = ToBool(value); + } else if (int_map_.end() != int_map_.find(key)) { + *(int_map_[key]) = ToInt(value); + } else if (uint_map_.end() != uint_map_.find(key)) { + *(uint_map_[key]) = ToUint(value); + } else if (float_map_.end() != float_map_.find(key)) { + *(float_map_[key]) = ToFloat(value); + } else if (double_map_.end() != double_map_.find(key)) { + *(double_map_[key]) = ToDouble(value); + } else if (string_map_.end() != string_map_.find(key)) { + if (!has_equal_sign) { + SHERPA_ONNX_LOG(FATAL) + << "Invalid option --" << key << " (option format is --x=y)."; + } + *(string_map_[key]) = value; + } else { + return false; + } + return true; +} + +bool ParseOptions::ToBool(std::string str) const { + std::transform(str.begin(), str.end(), str.begin(), ::tolower); + + // allow "" as a valid option for "true", so that --x is the same as --x=true + if ((str.compare("true") == 0) || (str.compare("t") == 0) || + (str.compare("1") == 0) || 
(str.compare("") == 0)) { + return true; + } + if ((str.compare("false") == 0) || (str.compare("f") == 0) || + (str.compare("0") == 0)) { + return false; + } + // if it is neither true nor false: + PrintUsage(true); + SHERPA_ONNX_LOG(FATAL) + << "Invalid format for boolean argument [expected true or false]: " + << str; + return false; // never reached +} + +int32_t ParseOptions::ToInt(const std::string &str) const { + int32_t ret = 0; + if (!ConvertStringToInteger(str, &ret)) + SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\""; + return ret; +} + +uint32_t ParseOptions::ToUint(const std::string &str) const { + uint32_t ret = 0; + if (!ConvertStringToInteger(str, &ret)) + SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\""; + return ret; +} + +float ParseOptions::ToFloat(const std::string &str) const { + float ret; + if (!ConvertStringToReal(str, &ret)) + SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\""; + return ret; +} + +double ParseOptions::ToDouble(const std::string &str) const { + double ret; + if (!ConvertStringToReal(str, &ret)) + SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\""; + return ret; +} + +// instantiate templates +template void ParseOptions::RegisterTmpl(const std::string &name, bool *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, int32_t *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, uint32_t *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, float *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, double *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, + std::string *ptr, + const std::string &doc); + +template void ParseOptions::RegisterStandard(const std::string &name, bool *ptr, + const std::string &doc); 
+template void ParseOptions::RegisterStandard(const std::string &name, + int32_t *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + uint32_t *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + float *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + double *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + std::string *ptr, + const std::string &doc); + +template void ParseOptions::RegisterCommon(const std::string &name, bool *ptr, + const std::string &doc, + bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, + int32_t *ptr, const std::string &doc, + bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, + uint32_t *ptr, + const std::string &doc, + bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, float *ptr, + const std::string &doc, + bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, double *ptr, + const std::string &doc, + bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, + std::string *ptr, + const std::string &doc, + bool is_standard); + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/parse-options.h b/sherpa-onnx/csrc/parse-options.h new file mode 100644 index 00000000..a46f794f --- /dev/null +++ b/sherpa-onnx/csrc/parse-options.h @@ -0,0 +1,252 @@ +// sherpa-onnx/csrc/parse-options.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation +// +// This file is copied and modified from kaldi/src/util/parse-options.h + +#ifndef SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_ +#define SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_ + +#include +#include +#include +#include + +namespace sherpa_onnx { + +class ParseOptions { + public: + explicit ParseOptions(const char *usage) + : 
print_args_(true), + help_(false), + usage_(usage), + argc_(0), + argv_(nullptr), + prefix_(""), + other_parser_(nullptr) { +#if !defined(_MSC_VER) && !defined(__CYGWIN__) + // This is just a convenient place to set the stderr to line + // buffering mode, since it's called at program start. + // This helps ensure different programs' output is not mixed up. + setlinebuf(stderr); +#endif + RegisterStandard("config", &config_, + "Configuration file to read (this " + "option may be repeated)"); + RegisterStandard("print-args", &print_args_, + "Print the command line arguments (to stderr)"); + RegisterStandard("help", &help_, "Print out usage message"); + } + + /** + This is a constructor for the special case where some options are + registered with a prefix to avoid conflicts. The object thus created will + only be used temporarily to register an options class with the original + options parser (which is passed as the *other pointer) using the given + prefix. It should not be used for any other purpose, and the prefix must + not be the empty string. It seems to be the least bad way of implementing + options with prefixes at this point. + Example of usage is: + ParseOptions po; // original ParseOptions object + ParseOptions po_mfcc("mfcc", &po); // object with prefix. 
+ MfccOptions mfcc_opts; + mfcc_opts.Register(&po_mfcc); + The options will now get registered as, e.g., --mfcc.frame-shift=10.0 + instead of just --frame-shift=10.0 + */ + ParseOptions(const std::string &prefix, ParseOptions *other); + + ParseOptions(const ParseOptions &) = delete; + ParseOptions &operator=(const ParseOptions &) = delete; + ~ParseOptions() = default; + + void Register(const std::string &name, bool *ptr, const std::string &doc); + void Register(const std::string &name, int32_t *ptr, const std::string &doc); + void Register(const std::string &name, uint32_t *ptr, const std::string &doc); + void Register(const std::string &name, float *ptr, const std::string &doc); + void Register(const std::string &name, double *ptr, const std::string &doc); + void Register(const std::string &name, std::string *ptr, + const std::string &doc); + + /// If called after registering an option and before calling + /// Read(), disables that option from being used. Will crash + /// at runtime if that option had not been registered. + void DisableOption(const std::string &name); + + /// This one is used for registering standard parameters of all the programs + template + void RegisterStandard(const std::string &name, T *ptr, + const std::string &doc); + + /** + Parses the command line options and fills the ParseOptions-registered + variables. This must be called after all the variables were registered!!! + + Initially the variables have implicit values, + then the config file values are set-up, + finally the command line values given. + Returns the first position in argv that was not used. + [typically not useful: use NumParams() and GetParam(). ] + */ + int Read(int argc, const char *const *argv); + + /// Prints the usage documentation [provided in the constructor]. 
+ void PrintUsage(bool print_command_line = false) const; + + /// Prints the actual configuration of all the registered variables + void PrintConfig(std::ostream &os) const; + + /// Reads the options values from a config file. Must be called after + /// registering all options. This is usually used internally after the + /// standard --config option is used, but it may also be called from a + /// program. + void ReadConfigFile(const std::string &filename); + + /// Number of positional parameters (c.f. argc-1). + int NumArgs() const; + + /// Returns one of the positional parameters; 1-based indexing for argc/argv + /// compatibility. Will crash if param is not >=1 and <=NumArgs(). + /// + /// Note: Index is 1 based. + std::string GetArg(int param) const; + + std::string GetOptArg(int param) const { + return (param <= NumArgs() ? GetArg(param) : ""); + } + + /// The following function will return a possibly quoted and escaped + /// version of "str", according to the current shell. Currently + /// this is just hardwired to bash. It's useful for debug output. 
+ static std::string Escape(const std::string &str); + + private: + /// Template to register various variable types, + /// used for program-specific parameters + template + void RegisterTmpl(const std::string &name, T *ptr, const std::string &doc); + + // Following functions do just the datatype-specific part of the job + /// Register boolean variable + void RegisterSpecific(const std::string &name, const std::string &idx, + bool *b, const std::string &doc, bool is_standard); + /// Register int32_t variable + void RegisterSpecific(const std::string &name, const std::string &idx, + int32_t *i, const std::string &doc, bool is_standard); + /// Register unsigned int32_t variable + void RegisterSpecific(const std::string &name, const std::string &idx, + uint32_t *u, const std::string &doc, bool is_standard); + /// Register float variable + void RegisterSpecific(const std::string &name, const std::string &idx, + float *f, const std::string &doc, bool is_standard); + /// Register double variable [useful as we change BaseFloat type]. + void RegisterSpecific(const std::string &name, const std::string &idx, + double *f, const std::string &doc, bool is_standard); + /// Register string variable + void RegisterSpecific(const std::string &name, const std::string &idx, + std::string *s, const std::string &doc, + bool is_standard); + + /// Does the actual job for both kinds of parameters + /// Does the common part of the job for all datatypes, + /// then calls RegisterSpecific + template + void RegisterCommon(const std::string &name, T *ptr, const std::string &doc, + bool is_standard); + + /// Set option with name "key" to "value"; will crash if can't do it. + /// "has_equal_sign" is used to allow --x for a boolean option x, + /// and --y=, for a string option y. 
+ bool SetOption(const std::string &key, const std::string &value, + bool has_equal_sign); + + bool ToBool(std::string str) const; + int32_t ToInt(const std::string &str) const; + uint32_t ToUint(const std::string &str) const; + float ToFloat(const std::string &str) const; + double ToDouble(const std::string &str) const; + + // maps for option variables + std::unordered_map bool_map_; + std::unordered_map int_map_; + std::unordered_map uint_map_; + std::unordered_map float_map_; + std::unordered_map double_map_; + std::unordered_map string_map_; + + /** + Structure for options' documentation + */ + struct DocInfo { + DocInfo() = default; + DocInfo(const std::string &name, const std::string &usemsg) + : name_(name), use_msg_(usemsg), is_standard_(false) {} + DocInfo(const std::string &name, const std::string &usemsg, + bool is_standard) + : name_(name), use_msg_(usemsg), is_standard_(is_standard) {} + + std::string name_; + std::string use_msg_; + bool is_standard_; + }; + using DocMapType = std::unordered_map; + DocMapType doc_map_; ///< map for the documentation + + bool print_args_; ///< variable for the implicit --print-args parameter + bool help_; ///< variable for the implicit --help parameter + std::string config_; ///< variable for the implicit --config parameter + std::vector positional_args_; + const char *usage_; + int argc_; + const char *const *argv_; + + /// These members are not normally used. They are only used when the object + /// is constructed with a prefix + std::string prefix_; + ParseOptions *other_parser_; + + protected: + /// SplitLongArg parses an argument of the form --a=b, --a=, or --a, + /// and sets "has_equal_sign" to true if an equals-sign was parsed.. + /// this is needed in order to correctly allow --x for a boolean option + /// x, and --y= for a string option y, and to disallow --x= and --y. 
+ void SplitLongArg(const std::string &in, std::string *key, std::string *value, + bool *has_equal_sign) const; + + void NormalizeArgName(std::string *str) const; + + /// Removes the beginning and trailing whitespaces from a string + void Trim(std::string *str) const; +}; + +/// This template is provided for convenience in reading config classes from +/// files; this is not the standard way to read configuration options, but may +/// occasionally be needed. This function assumes the config has a function +/// "void Register(ParseOptions *opts)" which it can call to register the +/// ParseOptions object. +template +void ReadConfigFromFile(const std::string &config_filename, C *c) { + // usage_str must outlive po: ParseOptions has a raw `const char *usage_`. + const std::string usage_str = + "Parsing config from '" + config_filename + "'"; + ParseOptions po(usage_str.c_str()); + c->Register(&po); + po.ReadConfigFile(config_filename); +} + +/// This variant of the template ReadConfigFromFile is for if you need to read +/// two config classes from the same file. 
+template +void ReadConfigsFromFile(const std::string &conf, C1 *c1, C2 *c2) { + // usage_str must outlive po: ParseOptions has a raw `const char *usage_`. + const std::string usage_str = + "Parsing config from '" + conf + "'"; + ParseOptions po(usage_str.c_str()); + c1->Register(&po); + c2->Register(&po); + po.ReadConfigFile(conf); +} + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx-alsa.cc b/sherpa-onnx/csrc/sherpa-onnx-alsa.cc new file mode 100644 index 00000000..83e79b8a --- /dev/null +++ b/sherpa-onnx/csrc/sherpa-onnx-alsa.cc @@ -0,0 +1,142 @@ +// sherpa-onnx/csrc/sherpa-onnx-alsa.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation +#include +#include +#include + +#include +#include // std::tolower +#include + +#include "sherpa-onnx/csrc/alsa.h" +#include "sherpa-onnx/csrc/display.h" +#include "sherpa-onnx/csrc/online-recognizer.h" + +bool stop = false; + +static void Handler(int sig) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n"); +} + +int main(int32_t argc, char *argv[]) { + if (argc < 6 || argc > 7) { + const char *usage = R"usage( +Usage: + ./bin/sherpa-onnx-alsa \ + /path/to/tokens.txt \ + /path/to/encoder.onnx \ + /path/to/decoder.onnx \ + /path/to/joiner.onnx \ + device_name \ + [num_threads] + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models to download. + +The device name specifies which microphone to use in case there are several +on your system. You can use + + arecord -l + +to find all available microphones on your computer. For instance, if it outputs + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and the device 0 on that card, please use: + + hw:3,0 + +as the device_name. 
+)usage"; + + fprintf(stderr, "%s\n", usage); + fprintf(stderr, "argc, %d\n", argc); + + return 0; + } + + signal(SIGINT, Handler); + + sherpa_onnx::OnlineRecognizerConfig config; + + config.tokens = argv[1]; + + config.model_config.debug = false; + config.model_config.encoder_filename = argv[2]; + config.model_config.decoder_filename = argv[3]; + config.model_config.joiner_filename = argv[4]; + + const char *device_name = argv[5]; + + config.model_config.num_threads = 2; + if (argc == 7 && atoi(argv[6]) > 0) { + config.model_config.num_threads = atoi(argv[6]); + } + + config.enable_endpoint = true; + + config.endpoint_config.rule1.min_trailing_silence = 2.4; + config.endpoint_config.rule2.min_trailing_silence = 1.2; + config.endpoint_config.rule3.min_utterance_length = 300; + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + sherpa_onnx::OnlineRecognizer recognizer(config); + + int32_t expected_sample_rate = config.feat_config.sampling_rate; + + sherpa_onnx::Alsa alsa(device_name); + fprintf(stderr, "Use recording device: %s\n", device_name); + + if (alsa.GetExpectedSampleRate() != expected_sample_rate) { + fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(), + expected_sample_rate); + exit(-1); + } + + int32_t chunk = 0.1 * alsa.GetActualSampleRate(); + + std::string last_text; + + auto stream = recognizer.CreateStream(); + + sherpa_onnx::Display display; + + int32_t segment_index = 0; + while (!stop) { + const std::vector samples = alsa.Read(chunk); + + stream->AcceptWaveform(expected_sample_rate, samples.data(), + samples.size()); + + while (recognizer.IsReady(stream.get())) { + recognizer.DecodeStream(stream.get()); + } + + auto text = recognizer.GetResult(stream.get()).text; + + bool is_endpoint = recognizer.IsEndpoint(stream.get()); + + if (!text.empty() && last_text != text) { + last_text = text; + // unsigned char: std::tolower() is undefined for negative char values. + std::transform(text.begin(), text.end(), text.begin(), + [](unsigned char c) { return std::tolower(c); }); + + display.Print(segment_index, 
text); + } + // Endpoint with non-empty text: advance segment_index and reset the stream. + if (!text.empty() && is_endpoint) { + ++segment_index; + recognizer.Reset(stream.get()); + } + } + + return 0; +} diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index e2efa3f8..73edbd10 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx features.cc online-transducer-model-config.cc sherpa-onnx.cc + endpoint.cc online-stream.cc online-recognizer.cc ) diff --git a/sherpa-onnx/python/csrc/endpoint.cc b/sherpa-onnx/python/csrc/endpoint.cc new file mode 100644 index 00000000..52985c49 --- /dev/null +++ b/sherpa-onnx/python/csrc/endpoint.cc @@ -0,0 +1,100 @@ +// sherpa-onnx/python/csrc/endpoint.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/endpoint.h" + +#include +#include + +#include "sherpa-onnx/csrc/endpoint.h" + +namespace sherpa_onnx { + +static constexpr const char *kEndpointRuleInitDoc = R"doc( +Constructor for EndpointRule. + +Args: + must_contain_nonsilence: + If True, for this endpointing rule to apply there must be nonsilence in the + best-path traceback. For decoding, a non-blank token is considered as + non-silence. + min_trailing_silence: + This endpointing rule requires duration of trailing silence (in seconds) + to be ``>=`` this value. + min_utterance_length: + This endpointing rule requires utterance-length (in seconds) to + be ``>=`` this value. +)doc"; + +static constexpr const char *kEndpointConfigInitDoc = R"doc( +If any rule in EndpointConfig is activated, it is said that an endpointing +is detected. + +Args: + rule1: + By default, it times out after 2.4 seconds of silence, even if + we decoded nothing. + rule2: + By default, it times out after 1.2 seconds of silence after decoding + something. + rule3: + By default, it times out after the utterance is 20 seconds long, regardless of + anything else. 
+)doc"; +// Binds EndpointRule: 3-arg constructor, __str__, and its read/write fields. +static void PybindEndpointRule(py::module *m) { + using PyClass = EndpointRule; + py::class_(*m, "EndpointRule") + .def(py::init(), py::arg("must_contain_nonsilence"), + py::arg("min_trailing_silence"), py::arg("min_utterance_length"), + kEndpointRuleInitDoc) + .def("__str__", &PyClass::ToString) + .def_readwrite("must_contain_nonsilence", + &PyClass::must_contain_nonsilence) + .def_readwrite("min_trailing_silence", &PyClass::min_trailing_silence) + .def_readwrite("min_utterance_length", &PyClass::min_utterance_length); +} +// Binds EndpointConfig: init from 3 floats or from 3 rules (with defaults). +static void PybindEndpointConfig(py::module *m) { + using PyClass = EndpointConfig; + py::class_(*m, "EndpointConfig") + .def( + py::init( + [](float rule1_min_trailing_silence, + float rule2_min_trailing_silence, + float rule3_min_utterance_length) -> std::unique_ptr { + EndpointRule rule1(false, rule1_min_trailing_silence, 0); + EndpointRule rule2(true, rule2_min_trailing_silence, 0); + EndpointRule rule3(false, 0, rule3_min_utterance_length); + + return std::make_unique(rule1, rule2, rule3); + }), + py::arg("rule1_min_trailing_silence"), + py::arg("rule2_min_trailing_silence"), + py::arg("rule3_min_utterance_length")) + .def(py::init([](const EndpointRule &rule1, const EndpointRule &rule2, + const EndpointRule &rule3) -> std::unique_ptr { + auto ans = std::make_unique(); + ans->rule1 = rule1; + ans->rule2 = rule2; + ans->rule3 = rule3; + return ans; + }), + py::arg("rule1") = EndpointRule(false, 2.4, 0), + py::arg("rule2") = EndpointRule(true, 1.2, 0), + py::arg("rule3") = EndpointRule(false, 0, 20), + kEndpointConfigInitDoc) + .def("__str__", + [](const PyClass &self) -> std::string { return self.ToString(); }) + .def_readwrite("rule1", &PyClass::rule1) + .def_readwrite("rule2", &PyClass::rule2) + .def_readwrite("rule3", &PyClass::rule3); +} +// Registers both endpoint classes (rule and config) on module m. +void PybindEndpoint(py::module *m) { + PybindEndpointRule(m); + PybindEndpointConfig(m); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/endpoint.h 
b/sherpa-onnx/python/csrc/endpoint.h new file mode 100644 index 00000000..f4d3e5b7 --- /dev/null +++ b/sherpa-onnx/python/csrc/endpoint.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/endpoint.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { +// Registers the EndpointRule and EndpointConfig Python bindings on module m. +void PybindEndpoint(py::module *m); + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_ diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index 52c74f23..ff22d74a 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -21,11 +21,15 @@ static void PybindOnlineRecognizerConfig(py::module *m) { using PyClass = OnlineRecognizerConfig; py::class_(*m, "OnlineRecognizerConfig") .def(py::init(), - py::arg("feat_config"), py::arg("model_config"), py::arg("tokens")) + const OnlineTransducerModelConfig &, const std::string &, + const EndpointConfig &, bool>(), + py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"), + py::arg("endpoint_config"), py::arg("enable_endpoint")) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("tokens", &PyClass::tokens) + .def_readwrite("endpoint_config", &PyClass::endpoint_config) + .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) .def("__str__", &PyClass::ToString); } @@ -43,7 +47,9 @@ void PybindOnlineRecognizer(py::module *m) { [](PyClass &self, std::vector ss) { self.DecodeStreams(ss.data(), ss.size()); }) - .def("get_result", &PyClass::GetResult); + .def("get_result", &PyClass::GetResult) + .def("is_endpoint", &PyClass::IsEndpoint) + .def("reset", &PyClass::Reset); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index cca04c09..4d6a798c 
100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -4,6 +4,7 @@ #include "sherpa-onnx/python/csrc/sherpa-onnx.h" +#include "sherpa-onnx/python/csrc/endpoint.h" #include "sherpa-onnx/python/csrc/features.h" #include "sherpa-onnx/python/csrc/online-recognizer.h" #include "sherpa-onnx/python/csrc/online-stream.h" @@ -16,6 +17,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindFeatures(&m); PybindOnlineTransducerModelConfig(&m); PybindOnlineStream(&m); + PybindEndpoint(&m); PybindOnlineRecognizer(&m); } diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index 60a03cc2..e50b5f2e 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -1,4 +1,5 @@ from _sherpa_onnx import ( + EndpointConfig, FeatureExtractorConfig, OnlineRecognizerConfig, OnlineStream, diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 90ba196f..9e992ce0 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -2,12 +2,13 @@ from pathlib import Path from typing import List from _sherpa_onnx import ( + EndpointConfig, + FeatureExtractorConfig, + OnlineRecognizer as _Recognizer, + OnlineRecognizerConfig, OnlineStream, OnlineTransducerModelConfig, - FeatureExtractorConfig, - OnlineRecognizerConfig, ) -from _sherpa_onnx import OnlineRecognizer as _Recognizer def _assert_file_exists(f: str): @@ -26,6 +27,10 @@ class OnlineRecognizer(object): num_threads: int = 4, sample_rate: float = 16000, feature_dim: int = 80, + enable_endpoint_detection: bool = False, + rule1_min_trailing_silence: float = 2.4, + rule2_min_trailing_silence: float = 1.2, + rule3_min_utterance_length: float = 20, ): """ Please refer to @@ -52,6 +57,22 @@ class OnlineRecognizer(object): Sample rate of the training data used to train the model. 
feature_dim: Dimension of the feature used to train the model. + enable_endpoint_detection: + True to enable endpoint detection. False to disable endpoint + detection. + rule1_min_trailing_silence: + Used only when enable_endpoint_detection is True. If the duration + of trailing silence in seconds is larger than this value, we assume + an endpoint is detected. + rule2_min_trailing_silence: + Used only when enable_endpoint_detection is True. If we have decoded + something that is nonsilence and if the duration of trailing silence + in seconds is larger than this value, we assume an endpoint is + detected. + rule3_min_utterance_length: + Used only when enable_endpoint_detection is True. If the utterance + length in seconds is larger than this value, we assume an endpoint + is detected. """ _assert_file_exists(tokens) _assert_file_exists(encoder) @@ -72,10 +93,18 @@ class OnlineRecognizer(object): feature_dim=feature_dim, ) + endpoint_config = EndpointConfig( + rule1_min_trailing_silence=rule1_min_trailing_silence, + rule2_min_trailing_silence=rule2_min_trailing_silence, + rule3_min_utterance_length=rule3_min_utterance_length, + ) + recognizer_config = OnlineRecognizerConfig( feat_config=feat_config, model_config=model_config, tokens=tokens, + endpoint_config=endpoint_config, + enable_endpoint=enable_endpoint_detection, ) self.recognizer = _Recognizer(recognizer_config) @@ -93,4 +122,10 @@ class OnlineRecognizer(object): return self.recognizer.is_ready(s) def get_result(self, s: OnlineStream) -> str: - return self.recognizer.get_result(s).text + return self.recognizer.get_result(s).text.strip() + + def is_endpoint(self, s: OnlineStream) -> bool: + return self.recognizer.is_endpoint(s) + + def reset(self, s: OnlineStream) -> bool: # NOTE(review): the C++ caller ignores Reset()'s result; confirm this is not None + return self.recognizer.reset(s)