Add endpointing (#54)
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -4,9 +4,11 @@ build
|
||||
onnxruntime-*
|
||||
icefall-*
|
||||
run.sh
|
||||
sherpa-onnx-*
|
||||
__pycache__
|
||||
dist/
|
||||
sherpa_onnx.egg-info/
|
||||
.DS_Store
|
||||
build-aarch64-linux-gnu
|
||||
sherpa-onnx-streaming-zipformer-*
|
||||
sherpa-onnx-lstm-en-*
|
||||
sherpa-onnx-lstm-zh-*
|
||||
|
||||
@@ -13,6 +13,7 @@ endif()
|
||||
|
||||
option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF)
|
||||
option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF)
|
||||
option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" ON)
|
||||
option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
|
||||
|
||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||
@@ -46,6 +47,8 @@ message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
|
||||
message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")
|
||||
message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}")
|
||||
message(STATUS "SHERPA_ONNX_ENABLE_PYTHON ${SHERPA_ONNX_ENABLE_PYTHON}")
|
||||
message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}")
|
||||
message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}")
|
||||
|
||||
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
@@ -56,6 +59,9 @@ if(SHERPA_ONNX_HAS_ALSA)
|
||||
add_definitions(-DSHERPA_ONNX_ENABLE_ALSA=1)
|
||||
endif()
|
||||
|
||||
check_include_file_cxx(cxxabi.h SHERPA_ONNX_HAVE_CXXABI_H)
|
||||
check_include_file_cxx(execinfo.h SHERPA_ONNX_HAVE_EXECINFO_H)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
|
||||
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
|
||||
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Real-time speech recognition from a microphone with sherpa-onnx Python API
|
||||
# with endpoint detection.
|
||||
#
|
||||
# Please refer to
|
||||
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
# to download pre-trained models
|
||||
|
||||
import sys
|
||||
|
||||
try:
|
||||
import sounddevice as sd
|
||||
except ImportError as e:
|
||||
print("Please install sounddevice first. You can use")
|
||||
print()
|
||||
print(" pip install sounddevice")
|
||||
print()
|
||||
print("to install it")
|
||||
sys.exit(-1)
|
||||
|
||||
import sherpa_onnx
|
||||
|
||||
|
||||
def create_recognizer():
    """Construct the streaming recognizer used by this demo.

    The model files are expected in a locally downloaded model directory;
    replace the paths if needed. Download links are listed at
    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html

    Returns:
      A ``sherpa_onnx.OnlineRecognizer`` with endpoint detection enabled.
    """
    options = dict(
        tokens="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt",
        encoder="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx",
        decoder="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx",
        joiner="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx",
        num_threads=4,
        sample_rate=16000,
        feature_dim=80,
        enable_endpoint_detection=True,
        rule1_min_trailing_silence=2.4,
        rule2_min_trailing_silence=1.2,
        rule3_min_utterance_length=300,  # it essentially disables this rule
    )
    return sherpa_onnx.OnlineRecognizer(**options)
|
||||
|
||||
|
||||
def main():
    """Recognize speech from the default microphone until interrupted.

    Audio is read in 100 ms chunks and fed to a single recognizer stream.
    Each time the recognition result changes it is printed, prefixed with
    the current segment id. When the recognizer reports an endpoint for a
    non-empty result, the segment id is incremented and the stream is reset.

    This function loops forever; the caller is expected to stop it, e.g.
    by handling KeyboardInterrupt.
    """
    print("Started! Please speak")
    recognizer = create_recognizer()
    sample_rate = 16000
    samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
    stream = recognizer.create_stream()

    last_result = ""
    segment_id = 0
    with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
        while True:
            samples, _ = s.read(samples_per_read)  # a blocking read
            samples = samples.reshape(-1)
            stream.accept_waveform(sample_rate, samples)
            while recognizer.is_ready(stream):
                recognizer.decode_stream(stream)

            is_endpoint = recognizer.is_endpoint(stream)

            result = recognizer.get_result(stream)

            # Only print when the result actually changed.
            if result and (last_result != result):
                last_result = result
                print(f"{segment_id}: {result}")

            if result and is_endpoint:
                segment_id += 1
                recognizer.reset(stream)
                # Start the next segment with a clean slate; otherwise a new
                # segment whose first result equals the previous segment's
                # last result would never be printed.
                last_result = ""
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # List all audio devices, then report which one is the default input.
    all_devices = sd.query_devices()
    print(all_devices)
    input_idx = sd.default.device[0]
    input_name = all_devices[input_idx]["name"]
    print(f'Use default device: {input_name}')

    # main() loops forever; Ctrl + C is the intended way to stop it.
    try:
        main()
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")
|
||||
@@ -1,7 +1,8 @@
|
||||
include_directories(${CMAKE_SOURCE_DIR})
|
||||
|
||||
add_library(sherpa-onnx-core
|
||||
set(sources
|
||||
cat.cc
|
||||
endpoint.cc
|
||||
features.cc
|
||||
online-lstm-transducer-model.cc
|
||||
online-recognizer.cc
|
||||
@@ -11,6 +12,7 @@ add_library(sherpa-onnx-core
|
||||
online-transducer-model.cc
|
||||
online-zipformer-transducer-model.cc
|
||||
onnx-utils.cc
|
||||
parse-options.cc
|
||||
resample.cc
|
||||
symbol-table.cc
|
||||
text-utils.cc
|
||||
@@ -18,11 +20,29 @@ add_library(sherpa-onnx-core
|
||||
wave-reader.cc
|
||||
)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_CHECK)
|
||||
list(APPEND sources log.cc)
|
||||
endif()
|
||||
|
||||
add_library(sherpa-onnx-core ${sources})
|
||||
|
||||
target_link_libraries(sherpa-onnx-core
|
||||
onnxruntime
|
||||
kaldi-native-fbank-core
|
||||
)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_CHECK)
|
||||
target_compile_definitions(sherpa-onnx-core PUBLIC SHERPA_ONNX_ENABLE_CHECK=1)
|
||||
|
||||
if(SHERPA_ONNX_HAVE_EXECINFO_H)
|
||||
target_compile_definitions(sherpa-onnx-core PRIVATE SHERPA_ONNX_HAVE_EXECINFO_H=1)
|
||||
endif()
|
||||
|
||||
if(SHERPA_ONNX_HAVE_CXXABI_H)
|
||||
target_compile_definitions(sherpa-onnx-core PRIVATE SHERPA_ONNX_HAVE_CXXABI_H=1)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_executable(sherpa-onnx sherpa-onnx.cc)
|
||||
|
||||
target_link_libraries(sherpa-onnx sherpa-onnx-core)
|
||||
|
||||
91
sherpa-onnx/csrc/endpoint.cc
Normal file
91
sherpa-onnx/csrc/endpoint.cc
Normal file
@@ -0,0 +1,91 @@
|
||||
// sherpa-onnx/csrc/endpoint.cc
|
||||
//
|
||||
// Copyright (c) 2022 (authors: Pingfeng Luo)
|
||||
// 2022-2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/endpoint.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static bool RuleActivated(const EndpointRule &rule,
|
||||
const std::string &rule_name, float trailing_silence,
|
||||
float utterance_length) {
|
||||
bool contain_nonsilence = utterance_length > trailing_silence;
|
||||
bool ans = (contain_nonsilence || !rule.must_contain_nonsilence) &&
|
||||
trailing_silence >= rule.min_trailing_silence &&
|
||||
utterance_length >= rule.min_utterance_length;
|
||||
if (ans) {
|
||||
SHERPA_ONNX_LOG(DEBUG) << "Endpointing rule " << rule_name << " activated: "
|
||||
<< (contain_nonsilence ? "true" : "false") << ','
|
||||
<< trailing_silence << ',' << utterance_length;
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
// Registers the three command-line options of one endpointing rule with `po`.
//
// The option names are "<rule_name>-must-contain-nonsilence",
// "<rule_name>-min-trailing-silence" and "<rule_name>-min-utterance-length";
// parsed values are written into the corresponding fields of `rule`.
static void RegisterEndpointRule(ParseOptions *po, EndpointRule *rule,
                                 const std::string &rule_name) {
  po->Register(
      rule_name + "-must-contain-nonsilence", &rule->must_contain_nonsilence,
      "If True, for this endpointing " + rule_name +
          " to apply there must be nonsilence in the best-path traceback. "
          "For decoding, a non-blank token is considered as non-silence");
  // Help text fixed: the original had an unbalanced ")" after "in seconds".
  po->Register(rule_name + "-min-trailing-silence", &rule->min_trailing_silence,
               "This endpointing " + rule_name +
                   " requires duration of trailing silence (in seconds) to "
                   "be >= this value.");
  po->Register(rule_name + "-min-utterance-length", &rule->min_utterance_length,
               "This endpointing " + rule_name +
                   " requires utterance-length (in seconds) to be >= this "
                   "value.");
}
|
||||
|
||||
// Formats this rule as
// "EndpointRule(must_contain_nonsilence=..., min_trailing_silence=...,
// min_utterance_length=...)", with the boolean printed as True/False.
std::string EndpointRule::ToString() const {
  const char *nonsilence = must_contain_nonsilence ? "True" : "False";

  std::ostringstream os;
  os << "EndpointRule("
     << "must_contain_nonsilence=" << nonsilence << ", "
     << "min_trailing_silence=" << min_trailing_silence << ", "
     << "min_utterance_length=" << min_utterance_length << ")";
  return os.str();
}
|
||||
|
||||
// Registers the options of rule1, rule2 and rule3 (in that order) with `po`,
// using the rule name as the option prefix.
void EndpointConfig::Register(ParseOptions *po) {
  struct NamedRule {
    EndpointRule *rule;
    const char *name;
  };
  const NamedRule named_rules[] = {
      {&rule1, "rule1"},
      {&rule2, "rule2"},
      {&rule3, "rule3"},
  };
  for (const auto &entry : named_rules) {
    RegisterEndpointRule(po, entry.rule, entry.name);
  }
}
|
||||
|
||||
// Formats this config as "EndpointConfig(rule1=..., rule2=..., rule3=...)",
// where each rule is rendered by EndpointRule::ToString().
std::string EndpointConfig::ToString() const {
  std::ostringstream os;
  os << "EndpointConfig("
     << "rule1=" << rule1.ToString() << ", "
     << "rule2=" << rule2.ToString() << ", "
     << "rule3=" << rule3.ToString() << ")";
  return os.str();
}
|
||||
|
||||
bool Endpoint::IsEndpoint(int num_frames_decoded, int trailing_silence_frames,
|
||||
float frame_shift_in_seconds) const {
|
||||
float utterance_length = num_frames_decoded * frame_shift_in_seconds;
|
||||
float trailing_silence = trailing_silence_frames * frame_shift_in_seconds;
|
||||
if (RuleActivated(config_.rule1, "rule1", trailing_silence,
|
||||
utterance_length) ||
|
||||
RuleActivated(config_.rule2, "rule2", trailing_silence,
|
||||
utterance_length) ||
|
||||
RuleActivated(config_.rule3, "rule3", trailing_silence,
|
||||
utterance_length)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
76
sherpa-onnx/csrc/endpoint.h
Normal file
76
sherpa-onnx/csrc/endpoint.h
Normal file
@@ -0,0 +1,76 @@
|
||||
// sherpa-onnx/csrc/endpoint.h
|
||||
//
|
||||
// Copyright (c) 2022 (authors: Pingfeng Luo)
|
||||
// 2022-2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ENDPOINT_H_
|
||||
#define SHERPA_ONNX_CSRC_ENDPOINT_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// A single endpointing rule: three thresholds that must all be satisfied for
// the rule to fire.
struct EndpointRule {
  // When true, the rule only applies if the best-path traceback contains
  // nonsilence. For decoding, a non-blank token counts as non-silence.
  bool must_contain_nonsilence = true;

  // Minimum duration of trailing silence, in seconds, required by this rule.
  float min_trailing_silence = 2.0f;

  // Minimum utterance length, in seconds, required by this rule.
  float min_utterance_length = 0.0f;

  EndpointRule() = default;

  EndpointRule(bool must_contain_nonsilence, float min_trailing_silence,
               float min_utterance_length)
      : must_contain_nonsilence{must_contain_nonsilence},
        min_trailing_silence{min_trailing_silence},
        min_utterance_length{min_utterance_length} {}

  // Returns a human-readable description of this rule.
  std::string ToString() const;
};
|
||||
|
||||
class ParseOptions;
|
||||
|
||||
struct EndpointConfig {
|
||||
// For default setting,
|
||||
// rule1 times out after 2.4 seconds of silence, even if we decoded nothing.
|
||||
// rule2 times out after 1.2 seconds of silence after decoding something.
|
||||
// rule3 times out after the utterance is 20 seconds long, regardless of
|
||||
// anything else.
|
||||
EndpointRule rule1;
|
||||
EndpointRule rule2;
|
||||
EndpointRule rule3;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
|
||||
EndpointConfig()
|
||||
: rule1{false, 2.4, 0}, rule2{true, 1.2, 0}, rule3{false, 0, 20} {}
|
||||
|
||||
EndpointConfig(const EndpointRule &rule1, const EndpointRule &rule2,
|
||||
const EndpointRule &rule3)
|
||||
: rule1(rule1), rule2(rule2), rule3(rule3) {}
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
class Endpoint {
|
||||
public:
|
||||
explicit Endpoint(const EndpointConfig &config) : config_(config) {}
|
||||
|
||||
/// This function returns true if this set of endpointing rules thinks we
|
||||
/// should terminate decoding.
|
||||
bool IsEndpoint(int num_frames_decoded, int trailing_silence_frames,
|
||||
float frame_shift_in_seconds) const;
|
||||
|
||||
private:
|
||||
EndpointConfig config_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ENDPOINT_H_
|
||||
122
sherpa-onnx/csrc/log.cc
Normal file
122
sherpa-onnx/csrc/log.cc
Normal file
@@ -0,0 +1,122 @@
|
||||
// sherpa-onnx/csrc/log.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
|
||||
#ifdef SHERPA_ONNX_HAVE_EXECINFO_H
|
||||
#include <execinfo.h> // To get stack trace in error messages.
|
||||
#ifdef SHERPA_ONNX_HAVE_CXXABI_H
|
||||
#include <cxxabi.h> // For name demangling.
|
||||
// Useful to decode the stack trace, but only used if we have execinfo.h
|
||||
#endif // SHERPA_ONNX_HAVE_CXXABI_H
|
||||
#endif // SHERPA_ONNX_HAVE_EXECINFO_H
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <ctime>
|
||||
#include <iomanip>
|
||||
#include <string>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Returns the current local time formatted as "yyyy-mm-dd hh:mm:ss".
//
// Returns an empty string if the current time cannot be converted to local
// time (std::localtime returned a null pointer).
std::string GetDateTimeStr() {
  const std::time_t now = std::time(nullptr);

  const std::tm *local = std::localtime(&now);
  if (local == nullptr) {
    // The original dereferenced this pointer unconditionally.
    return "";
  }

  // Copy the result so it is not affected by later localtime() calls.
  const std::tm tm = *local;

  std::ostringstream os;
  os << std::put_time(&tm, "%F %T");  // yyyy-mm-dd hh:mm:ss
  return os.str();
}
|
||||
|
||||
// Locates a (possibly mangled) symbol inside one backtrace line.
//
// On success, `*begin` is the index of the first '_' that is directly
// preceded by ' ' or '(', `*end` is the index of the first ' ' or '+' at or
// after `*begin`, and true is returned.
//
// Returns false if no such '_' exists (then `*begin` is npos and `*end` is
// left untouched) or if no terminating ' ' or '+' is found (then `*end` is
// npos).
static bool LocateSymbolRange(const std::string &trace_name, std::size_t *begin,
                              std::size_t *end) {
  *begin = std::string::npos;

  // Index 0 is skipped because a candidate needs a preceding character.
  std::size_t pos = trace_name.find('_', 1);
  while (pos != std::string::npos) {
    const char prev = trace_name[pos - 1];
    if (prev == ' ' || prev == '(') {
      *begin = pos;
      break;
    }
    pos = trace_name.find('_', pos + 1);
  }

  if (*begin == std::string::npos) {
    return false;
  }

  *end = trace_name.find_first_of(" +", *begin);
  return *end != std::string::npos;
}
|
||||
|
||||
#ifdef SHERPA_ONNX_HAVE_EXECINFO_H
// Returns `trace_name` with its mangled C++ symbol, if one can be located and
// demangled, replaced by the demangled name. Without cxxabi.h the input is
// returned unchanged.
static std::string Demangle(const std::string &trace_name) {
#ifndef SHERPA_ONNX_HAVE_CXXABI_H
  return trace_name;
#else   // SHERPA_ONNX_HAVE_CXXABI_H
  // The following backtrace line formats are targeted:
  //
  // Linux:
  //    ./kaldi-error-test(_ZN5kaldi13UnitTestErrorEv+0xb) [0x804965d]
  //
  // Mac:
  //    0 server 0x000000010f67614d _ZNK5kaldi13MessageLogger10LogMessageEv + 813
  //
  // The mangled name, e.g. '_ZN5kaldi13UnitTestErrorEv', is extracted and
  // demangled into a readable name like kaldi::UnitTestError.
  std::size_t begin = 0;
  std::size_t end = 0;
  if (!LocateSymbolRange(trace_name, &begin, &end)) {
    return trace_name;
  }

  std::string symbol = trace_name.substr(begin, end - begin);

  int status = 0;
  char *demangled =
      abi::__cxa_demangle(symbol.c_str(), nullptr, nullptr, &status);
  if (status == 0 && demangled != nullptr) {
    symbol = demangled;
    free(demangled);  // the buffer was allocated by __cxa_demangle
  }

  // Re-assemble: text before the symbol + symbol + text after it.
  std::string ans = trace_name.substr(0, begin);
  ans += symbol;
  ans += trace_name.substr(end);
  return ans;
#endif  // SHERPA_ONNX_HAVE_CXXABI_H
}
#endif  // SHERPA_ONNX_HAVE_EXECINFO_H
|
||||
|
||||
// Returns a human-readable stack trace of the calling thread, one frame per
// line, preceded by a "[ Stack-Trace: ]" header.
//
// Returns an empty string when execinfo.h is unavailable or when the frame
// symbols cannot be obtained.
std::string GetStackTrace() {
  std::string ans;
#ifdef SHERPA_ONNX_HAVE_EXECINFO_H
  constexpr const std::size_t kMaxTraceSize = 50;
  constexpr const std::size_t kMaxTracePrint = 50;  // Must be even.

  // Collect the raw return addresses, then resolve them to symbol strings.
  void *frames[kMaxTraceSize];
  const std::size_t num_frames = backtrace(frames, kMaxTraceSize);
  char **symbols = backtrace_symbols(frames, num_frames);
  if (symbols == nullptr) return ans;

  // Appends the demangled frames in [first, last), one per line.
  auto append_frames = [&ans, symbols](std::size_t first, std::size_t last) {
    for (std::size_t i = first; i < last; ++i) {
      ans += Demangle(symbols[i]) + "\n";
    }
  };

  ans += "[ Stack-Trace: ]\n";
  if (num_frames > kMaxTracePrint) {
    // Too many frames: print only the first and last kMaxTracePrint / 2.
    append_frames(0, kMaxTracePrint / 2);
    ans += ".\n.\n.\n";
    append_frames(num_frames - kMaxTracePrint / 2, num_frames);
    if (num_frames == kMaxTraceSize)
      ans += ".\n.\n.\n";  // Stack was too long, probably a bug.
  } else {
    append_frames(0, num_frames);
  }

  // Only the array returned by backtrace_symbols() must be freed, not the
  // individual strings.
  free(symbols);
#endif  // SHERPA_ONNX_HAVE_EXECINFO_H
  return ans;
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
378
sherpa-onnx/csrc/log.h
Normal file
378
sherpa-onnx/csrc/log.h
Normal file
@@ -0,0 +1,378 @@
|
||||
// sherpa-onnx/csrc/log.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_LOG_H_
|
||||
#define SHERPA_ONNX_CSRC_LOG_H_
|
||||
|
||||
#include <stdio.h>

#include <cstdint>
#include <cstdlib>
#include <mutex>  // NOLINT
#include <sstream>
#include <stdexcept>
#include <string>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
#if SHERPA_ONNX_ENABLE_CHECK
|
||||
|
||||
// kDisableDebug is true when NDEBUG is defined. The SHERPA_ONNX_DCHECK* and
// SHERPA_ONNX_DLOG macros test it to turn themselves into no-ops.
#if defined(NDEBUG)
constexpr bool kDisableDebug = true;
#else
constexpr bool kDisableDebug = false;
#endif

// Severity of a log message, ordered from least (kTrace) to most (kFatal)
// severe. A message is printed when its level is >= the current log level.
enum class LogLevel {
  kTrace = 0,
  kDebug = 1,
  kInfo = 2,
  kWarning = 3,
  kError = 4,
  kFatal = 5,  // print message, then abort or throw (see Logger::~Logger)
};

// Short aliases for the log levels. They are used in SHERPA_ONNX_LOG(xxx),
// so their names do not follow the google c++ code style.
//
// You can use them in the following way:
//
//  SHERPA_ONNX_LOG(TRACE) << "some message";
//  SHERPA_ONNX_LOG(DEBUG) << "some message";
//
// NOTE(review): with MSVC the aliases are macros instead of constants,
// presumably to cope with identically named macros from system headers --
// TODO confirm.
#ifndef _MSC_VER
constexpr LogLevel TRACE = LogLevel::kTrace;
constexpr LogLevel DEBUG = LogLevel::kDebug;
constexpr LogLevel INFO = LogLevel::kInfo;
constexpr LogLevel WARNING = LogLevel::kWarning;
constexpr LogLevel ERROR = LogLevel::kError;
constexpr LogLevel FATAL = LogLevel::kFatal;
#else
#define TRACE LogLevel::kTrace
#define DEBUG LogLevel::kDebug
#define INFO LogLevel::kInfo
#define WARNING LogLevel::kWarning
#define ERROR LogLevel::kError
#define FATAL LogLevel::kFatal
#endif
|
||||
|
||||
std::string GetStackTrace();
|
||||
|
||||
/* Return the current log level.
|
||||
|
||||
|
||||
If the current log level is TRACE, then all logged messages are printed out.
|
||||
|
||||
If the current log level is DEBUG, log messages with "TRACE" level are not
|
||||
shown and all other levels are printed out.
|
||||
|
||||
Similarly, if the current log level is INFO, log message with "TRACE" and
|
||||
"DEBUG" are not shown and all other levels are printed out.
|
||||
|
||||
If it is FATAL, then only FATAL messages are shown.
|
||||
*/
|
||||
inline LogLevel GetCurrentLogLevel() {
|
||||
static LogLevel log_level = INFO;
|
||||
static std::once_flag init_flag;
|
||||
std::call_once(init_flag, []() {
|
||||
const char *env_log_level = std::getenv("SHERPA_ONNX_LOG_LEVEL");
|
||||
if (env_log_level == nullptr) return;
|
||||
|
||||
std::string s = env_log_level;
|
||||
if (s == "TRACE")
|
||||
log_level = TRACE;
|
||||
else if (s == "DEBUG")
|
||||
log_level = DEBUG;
|
||||
else if (s == "INFO")
|
||||
log_level = INFO;
|
||||
else if (s == "WARNING")
|
||||
log_level = WARNING;
|
||||
else if (s == "ERROR")
|
||||
log_level = ERROR;
|
||||
else if (s == "FATAL")
|
||||
log_level = FATAL;
|
||||
else
|
||||
fprintf(stderr,
|
||||
"Unknown SHERPA_ONNX_LOG_LEVEL: %s"
|
||||
"\nSupported values are: "
|
||||
"TRACE, DEBUG, INFO, WARNING, ERROR, FATAL",
|
||||
s.c_str());
|
||||
});
|
||||
return log_level;
|
||||
}
|
||||
|
||||
// Returns true if the environment variable SHERPA_ONNX_ABORT was set (to any
// value) when this function was first called. The environment is queried
// only once; later calls return the cached answer.
inline bool EnableAbort() {
  static const bool kAbortRequested = []() {
    return std::getenv("SHERPA_ONNX_ABORT") != nullptr;
  }();
  return kAbortRequested;
}
|
||||
|
||||
class Logger {
|
||||
public:
|
||||
Logger(const char *filename, const char *func_name, uint32_t line_num,
|
||||
LogLevel level)
|
||||
: filename_(filename),
|
||||
func_name_(func_name),
|
||||
line_num_(line_num),
|
||||
level_(level) {
|
||||
cur_level_ = GetCurrentLogLevel();
|
||||
switch (level) {
|
||||
case TRACE:
|
||||
if (cur_level_ <= TRACE) fprintf(stderr, "[T] ");
|
||||
break;
|
||||
case DEBUG:
|
||||
if (cur_level_ <= DEBUG) fprintf(stderr, "[D] ");
|
||||
break;
|
||||
case INFO:
|
||||
if (cur_level_ <= INFO) fprintf(stderr, "[I] ");
|
||||
break;
|
||||
case WARNING:
|
||||
if (cur_level_ <= WARNING) fprintf(stderr, "[W] ");
|
||||
break;
|
||||
case ERROR:
|
||||
if (cur_level_ <= ERROR) fprintf(stderr, "[E] ");
|
||||
break;
|
||||
case FATAL:
|
||||
if (cur_level_ <= FATAL) fprintf(stderr, "[F] ");
|
||||
break;
|
||||
}
|
||||
|
||||
if (cur_level_ <= level_) {
|
||||
fprintf(stderr, "%s:%u:%s ", filename, line_num, func_name);
|
||||
}
|
||||
}
|
||||
|
||||
~Logger() noexcept(false) {
|
||||
static constexpr const char *kErrMsg = R"(
|
||||
Some bad things happened. Please read the above error messages and stack
|
||||
trace. If you are using Python, the following command may be helpful:
|
||||
|
||||
gdb --args python /path/to/your/code.py
|
||||
|
||||
(You can use `gdb` to debug the code. Please consider compiling
|
||||
a debug version of sherpa_onnx.).
|
||||
|
||||
If you are unable to fix it, please open an issue at:
|
||||
|
||||
https://github.com/csukuangfj/kaldi-native-fbank/issues/new
|
||||
)";
|
||||
if (level_ == FATAL) {
|
||||
fprintf(stderr, "\n");
|
||||
std::string stack_trace = GetStackTrace();
|
||||
if (!stack_trace.empty()) {
|
||||
fprintf(stderr, "\n\n%s\n", stack_trace.c_str());
|
||||
}
|
||||
|
||||
fflush(nullptr);
|
||||
|
||||
#ifndef __ANDROID_API__
|
||||
if (EnableAbort()) {
|
||||
// NOTE: abort() will terminate the program immediately without
|
||||
// printing the Python stack backtrace.
|
||||
abort();
|
||||
}
|
||||
|
||||
throw std::runtime_error(kErrMsg);
|
||||
#else
|
||||
abort();
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
const Logger &operator<<(bool b) const {
|
||||
if (cur_level_ <= level_) {
|
||||
fprintf(stderr, b ? "true" : "false");
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
const Logger &operator<<(int8_t i) const {
|
||||
if (cur_level_ <= level_) fprintf(stderr, "%d", i);
|
||||
return *this;
|
||||
}
|
||||
|
||||
const Logger &operator<<(const char *s) const {
|
||||
if (cur_level_ <= level_) fprintf(stderr, "%s", s);
|
||||
return *this;
|
||||
}
|
||||
|
||||
const Logger &operator<<(int32_t i) const {
|
||||
if (cur_level_ <= level_) fprintf(stderr, "%d", i);
|
||||
return *this;
|
||||
}
|
||||
|
||||
const Logger &operator<<(uint32_t i) const {
|
||||
if (cur_level_ <= level_) fprintf(stderr, "%u", i);
|
||||
return *this;
|
||||
}
|
||||
|
||||
const Logger &operator<<(uint64_t i) const {
|
||||
if (cur_level_ <= level_)
|
||||
fprintf(stderr, "%llu", (long long unsigned int)i); // NOLINT
|
||||
return *this;
|
||||
}
|
||||
|
||||
const Logger &operator<<(int64_t i) const {
|
||||
if (cur_level_ <= level_)
|
||||
fprintf(stderr, "%lli", (long long int)i); // NOLINT
|
||||
return *this;
|
||||
}
|
||||
|
||||
const Logger &operator<<(float f) const {
|
||||
if (cur_level_ <= level_) fprintf(stderr, "%f", f);
|
||||
return *this;
|
||||
}
|
||||
|
||||
const Logger &operator<<(double d) const {
|
||||
if (cur_level_ <= level_) fprintf(stderr, "%f", d);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const Logger &operator<<(const T &t) const {
|
||||
// require T overloads operator<<
|
||||
std::ostringstream os;
|
||||
os << t;
|
||||
return *this << os.str().c_str();
|
||||
}
|
||||
|
||||
// specialization to fix compile error: `stringstream << nullptr` is ambiguous
|
||||
const Logger &operator<<(const std::nullptr_t &null) const {
|
||||
if (cur_level_ <= level_) *this << "(null)";
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
const char *filename_;
|
||||
const char *func_name_;
|
||||
uint32_t line_num_;
|
||||
LogLevel level_;
|
||||
LogLevel cur_level_;
|
||||
};
|
||||
#endif // SHERPA_ONNX_ENABLE_CHECK
|
||||
|
||||
// Helper used by the CHECK/LOG macros.
//
// With checks enabled, `Voidifier() & Logger(...)` turns a logging expression
// into a void expression. With SHERPA_ONNX_ENABLE_CHECK undefined, the macros
// expand to a bare Voidifier and everything streamed into it is discarded.
class Voidifier {
 public:
#if SHERPA_ONNX_ENABLE_CHECK
  void operator&(const Logger &) const {}
#endif
};

#if !defined(SHERPA_ONNX_ENABLE_CHECK)
// Swallows any streamed value and returns the same Voidifier so that calls
// can be chained.
template <typename Ignored>
const Voidifier &operator<<(const Voidifier &sink, Ignored &&) {
  return sink;
}
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#define SHERPA_ONNX_STATIC_ASSERT(x) static_assert(x, "")
|
||||
|
||||
#ifdef SHERPA_ONNX_ENABLE_CHECK

// SHERPA_ONNX_FUNC expands to the name of the enclosing function; it is
// passed to Logger as the `func_name` argument.
#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) || \
    defined(__PRETTY_FUNCTION__)
// for clang and GCC
#define SHERPA_ONNX_FUNC __PRETTY_FUNCTION__
#else
// for other compilers
#define SHERPA_ONNX_FUNC __func__
#endif

// Logs a FATAL "Check failed: x" message if `x` evaluates to false.
// Additional context can be streamed after the macro:
//
//   SHERPA_ONNX_CHECK(ptr != nullptr) << "some message";
#define SHERPA_ONNX_CHECK(x)                                            \
  (x) ? (void)0                                                         \
      : ::sherpa_onnx::Voidifier() &                                    \
            ::sherpa_onnx::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, \
                                  ::sherpa_onnx::FATAL)                 \
                << "Check failed: " << #x << " "

// WARNING: x and y may be evaluated multiple times, but this happens only
// when the check fails. Since a failed check ends with a FATAL message, we
// don't think the extra evaluation of x and y matters.
//
// CAUTION: we recommend the following use case:
//
//      auto x = Foo();
//      auto y = Bar();
//      SHERPA_ONNX_CHECK_EQ(x, y) << "Some message";
//
//  And please avoid
//
//      SHERPA_ONNX_CHECK_EQ(Foo(), Bar());
//
//  if `Foo()` or `Bar()` causes some side effects, e.g., changing some
//  local static variables or global variables.
#define _SHERPA_ONNX_CHECK_OP(x, y, op)                                        \
  ((x)op(y)) ? (void)0                                                         \
             : ::sherpa_onnx::Voidifier() &                                    \
                   ::sherpa_onnx::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, \
                                         ::sherpa_onnx::FATAL)                 \
                       << "Check failed: " << #x << " " << #op << " " << #y    \
                       << " (" << (x) << " vs. " << (y) << ") "

// Binary comparison checks built on _SHERPA_ONNX_CHECK_OP.
#define SHERPA_ONNX_CHECK_EQ(x, y) _SHERPA_ONNX_CHECK_OP(x, y, ==)
#define SHERPA_ONNX_CHECK_NE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, !=)
#define SHERPA_ONNX_CHECK_LT(x, y) _SHERPA_ONNX_CHECK_OP(x, y, <)
#define SHERPA_ONNX_CHECK_LE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, <=)
#define SHERPA_ONNX_CHECK_GT(x, y) _SHERPA_ONNX_CHECK_OP(x, y, >)
#define SHERPA_ONNX_CHECK_GE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, >=)

// Creates a Logger for level `x`, where `x` is one of
// TRACE, DEBUG, INFO, WARNING, ERROR, FATAL:
//
//   SHERPA_ONNX_LOG(INFO) << "some message";
#define SHERPA_ONNX_LOG(x) \
  ::sherpa_onnx::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, ::sherpa_onnx::x)

// ------------------------------------------------------------
//       For debug check
// ------------------------------------------------------------
// If you define the macro "-D NDEBUG" while compiling sherpa-onnx,
// the following macros are in fact empty and do nothing.

#define SHERPA_ONNX_DCHECK(x) \
  ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK(x)

#define SHERPA_ONNX_DCHECK_EQ(x, y) \
  ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_EQ(x, y)

#define SHERPA_ONNX_DCHECK_NE(x, y) \
  ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_NE(x, y)

#define SHERPA_ONNX_DCHECK_LT(x, y) \
  ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_LT(x, y)

#define SHERPA_ONNX_DCHECK_LE(x, y) \
  ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_LE(x, y)

#define SHERPA_ONNX_DCHECK_GT(x, y) \
  ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_GT(x, y)

#define SHERPA_ONNX_DCHECK_GE(x, y) \
  ::sherpa_onnx::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_GE(x, y)

#define SHERPA_ONNX_DLOG(x)    \
  ::sherpa_onnx::kDisableDebug \
      ? (void)0                \
      : ::sherpa_onnx::Voidifier() & SHERPA_ONNX_LOG(x)

#else

// Checks are disabled: every macro expands to a Voidifier. The macro
// arguments are not evaluated.

#define SHERPA_ONNX_CHECK(x) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_LOG(x) ::sherpa_onnx::Voidifier()

#define SHERPA_ONNX_CHECK_EQ(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_CHECK_NE(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_CHECK_LT(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_CHECK_LE(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_CHECK_GT(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_CHECK_GE(x, y) ::sherpa_onnx::Voidifier()

#define SHERPA_ONNX_DCHECK(x) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_DLOG(x) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_DCHECK_EQ(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_DCHECK_NE(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_DCHECK_LT(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_DCHECK_LE(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_DCHECK_GT(x, y) ::sherpa_onnx::Voidifier()
#define SHERPA_ONNX_DCHECK_GE(x, y) ::sherpa_onnx::Voidifier()

#endif  // SHERPA_ONNX_ENABLE_CHECK
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_LOG_H_
|
||||
@@ -37,7 +37,9 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "OnlineRecognizerConfig(";
|
||||
os << "feat_config=" << feat_config.ToString() << ", ";
|
||||
os << "model_config=" << model_config.ToString() << ", ";
|
||||
os << "tokens=\"" << tokens << "\")";
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
||||
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
@@ -47,7 +49,8 @@ class OnlineRecognizer::Impl {
|
||||
explicit Impl(const OnlineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||
sym_(config.tokens) {
|
||||
sym_(config.tokens),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
decoder_ =
|
||||
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
|
||||
}
|
||||
@@ -64,7 +67,7 @@ class OnlineRecognizer::Impl {
|
||||
s->NumFramesReady();
|
||||
}
|
||||
|
||||
void DecodeStreams(OnlineStream **ss, int32_t n) {
|
||||
void DecodeStreams(OnlineStream **ss, int32_t n) const {
|
||||
int32_t chunk_size = model_->ChunkSize();
|
||||
int32_t chunk_shift = model_->ChunkShift();
|
||||
|
||||
@@ -111,18 +114,44 @@ class OnlineRecognizer::Impl {
|
||||
}
|
||||
}
|
||||
|
||||
OnlineRecognizerResult GetResult(OnlineStream *s) {
|
||||
OnlineRecognizerResult GetResult(OnlineStream *s) const {
|
||||
OnlineTransducerDecoderResult decoder_result = s->GetResult();
|
||||
decoder_->StripLeadingBlanks(&decoder_result);
|
||||
|
||||
return Convert(decoder_result, sym_);
|
||||
}
|
||||
|
||||
// Returns true if endpoint detection is enabled in the config and the
// endpointing rules fire for stream `s`.
bool IsEndpoint(OnlineStream *s) const {
  if (!config_.enable_endpoint) {
    return false;
  }

  // frame shift is 10 milliseconds
  const float frame_shift_in_seconds = 0.01;

  // subsampling factor is 4
  const int32_t subsampling_factor = 4;

  const int32_t num_processed_frames = s->GetNumProcessedFrames();

  // Trailing blanks are counted in subsampled frames; scale them back to
  // feature frames.
  const int32_t trailing_silence_frames =
      s->GetResult().num_trailing_blanks * subsampling_factor;

  return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
                              frame_shift_in_seconds);
}
|
||||
|
||||
void Reset(OnlineStream *s) const {
|
||||
// reset result and neural network model state,
|
||||
// but keep the feature extractor state
|
||||
|
||||
// reset result
|
||||
s->SetResult(decoder_->GetEmptyResult());
|
||||
|
||||
// reset neural network model state
|
||||
s->SetStates(model_->GetEncoderInitStates());
|
||||
}
|
||||
|
||||
private:
|
||||
OnlineRecognizerConfig config_;
|
||||
std::unique_ptr<OnlineTransducerModel> model_;
|
||||
std::unique_ptr<OnlineTransducerDecoder> decoder_;
|
||||
SymbolTable sym_;
|
||||
Endpoint endpoint_;
|
||||
};
|
||||
|
||||
OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
|
||||
@@ -137,12 +166,18 @@ bool OnlineRecognizer::IsReady(OnlineStream *s) const {
|
||||
return impl_->IsReady(s);
|
||||
}
|
||||
|
||||
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) {
|
||||
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) const {
|
||||
impl_->DecodeStreams(ss, n);
|
||||
}
|
||||
|
||||
OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) {
|
||||
OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) const {
|
||||
return impl_->GetResult(s);
|
||||
}
|
||||
|
||||
bool OnlineRecognizer::IsEndpoint(OnlineStream *s) const {
|
||||
return impl_->IsEndpoint(s);
|
||||
}
|
||||
|
||||
void OnlineRecognizer::Reset(OnlineStream *s) const { impl_->Reset(s); }
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
@@ -22,13 +23,21 @@ struct OnlineRecognizerConfig {
|
||||
FeatureExtractorConfig feat_config;
|
||||
OnlineTransducerModelConfig model_config;
|
||||
std::string tokens;
|
||||
EndpointConfig endpoint_config;
|
||||
bool enable_endpoint;
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
|
||||
const OnlineTransducerModelConfig &model_config,
|
||||
const std::string &tokens)
|
||||
: feat_config(feat_config), model_config(model_config), tokens(tokens) {}
|
||||
const std::string &tokens,
|
||||
const EndpointConfig &endpoint_config,
|
||||
bool enable_endpoint)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
tokens(tokens),
|
||||
endpoint_config(endpoint_config),
|
||||
enable_endpoint(enable_endpoint) {}
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
@@ -48,7 +57,7 @@ class OnlineRecognizer {
|
||||
bool IsReady(OnlineStream *s) const;
|
||||
|
||||
/** Decode a single stream. */
|
||||
void DecodeStream(OnlineStream *s) {
|
||||
void DecodeStream(OnlineStream *s) const {
|
||||
OnlineStream *ss[1] = {s};
|
||||
DecodeStreams(ss, 1);
|
||||
}
|
||||
@@ -58,9 +67,18 @@ class OnlineRecognizer {
|
||||
* @param ss Pointer array containing streams to be decoded.
|
||||
* @param n Number of streams in `ss`.
|
||||
*/
|
||||
void DecodeStreams(OnlineStream **ss, int32_t n);
|
||||
void DecodeStreams(OnlineStream **ss, int32_t n) const;
|
||||
|
||||
OnlineRecognizerResult GetResult(OnlineStream *s);
|
||||
OnlineRecognizerResult GetResult(OnlineStream *s) const;
|
||||
|
||||
// Return true if we detect an endpoint for this stream.
|
||||
// Note: If this function returns true, you usually want to
|
||||
// invoke Reset(s).
|
||||
bool IsEndpoint(OnlineStream *s) const;
|
||||
|
||||
// Clear the state of this stream. If IsEndpoint(s) returns true,
|
||||
// after calling this function, IsEndpoint(s) will return false
|
||||
void Reset(OnlineStream *s) const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
|
||||
@@ -55,7 +55,8 @@ class OnlineStream {
|
||||
|
||||
int32_t FeatureDim() const;
|
||||
|
||||
// Return a reference to the number of processed frames so far.
|
||||
// Return a reference to the number of processed frames so far
|
||||
// before subsampling.
|
||||
// Initially, it is 0. It is always less than NumFramesReady().
|
||||
//
|
||||
// The returned reference is valid as long as this object is alive.
|
||||
|
||||
@@ -14,6 +14,9 @@ namespace sherpa_onnx {
|
||||
struct OnlineTransducerDecoderResult {
|
||||
/// The decoded token IDs so far
|
||||
std::vector<int64_t> tokens;
|
||||
|
||||
/// number of trailing blank frames decoded so far
|
||||
int32_t num_trailing_blanks = 0;
|
||||
};
|
||||
|
||||
class OnlineTransducerDecoder {
|
||||
|
||||
@@ -113,6 +113,9 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
if (y != 0) {
|
||||
emitted = true;
|
||||
(*result)[i].tokens.push_back(y);
|
||||
(*result)[i].num_trailing_blanks = 0;
|
||||
} else {
|
||||
++(*result)[i].num_trailing_blanks;
|
||||
}
|
||||
}
|
||||
if (emitted) {
|
||||
|
||||
774
sherpa-onnx/csrc/parse-options.cc
Normal file
774
sherpa-onnx/csrc/parse-options.cc
Normal file
@@ -0,0 +1,774 @@
|
||||
// sherpa-onnx/csrc/parse-options.cc
|
||||
/**
|
||||
* Copyright 2009-2011 Karel Vesely; Microsoft Corporation;
|
||||
* Saarland University (Author: Arnab Ghoshal);
|
||||
* Copyright 2012-2013 Johns Hopkins University (Author: Daniel Povey);
|
||||
* Frantisek Skala; Arnab Ghoshal
|
||||
* Copyright 2013 Tanel Alumae
|
||||
*/
|
||||
|
||||
// This file is copied and modified from kaldi/src/util/parse-options.cc
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
#include <ctype.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
|
||||
// SHERPA_ONNX_STRTOLL(str, end) parses a base-10 signed 64-bit integer from
// the C string `str` and stores the address of the first unparsed character
// in `*end`. MSVC builds use _strtoi64; all other builds use strtoll.
//
// The expansion carries no trailing semicolon and parenthesizes its
// arguments, so the macro can be used inside a larger expression as well as
// in a plain statement such as `int64_t i = SHERPA_ONNX_STRTOLL(s, &end);`.
#ifdef _MSC_VER
#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \
  _strtoi64((cur_cstr), (end_cstr), 10)
#else
#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \
  strtoll((cur_cstr), (end_cstr), 10)
#endif
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/// Converts a string into an integer via strtoll and returns false if there was
|
||||
/// any kind of problem (i.e. the string was not an integer or contained extra
|
||||
/// non-whitespace junk, or the integer was too large to fit into the type it is
|
||||
/// being converted into). Only sets *out if everything was OK and it returns
|
||||
/// true.
|
||||
template <class Int>
|
||||
bool ConvertStringToInteger(const std::string &str, Int *out) {
|
||||
// copied from kaldi/src/util/text-util.h
|
||||
static_assert(std::is_integral<Int>::value, "");
|
||||
const char *this_str = str.c_str();
|
||||
char *end = nullptr;
|
||||
errno = 0;
|
||||
int64_t i = SHERPA_ONNX_STRTOLL(this_str, &end);
|
||||
if (end != this_str) {
|
||||
while (isspace(*end)) ++end;
|
||||
}
|
||||
if (end == this_str || *end != '\0' || errno != 0) return false;
|
||||
Int iInt = static_cast<Int>(i);
|
||||
if (static_cast<int64_t>(iInt) != i ||
|
||||
(i < 0 && !std::numeric_limits<Int>::is_signed)) {
|
||||
return false;
|
||||
}
|
||||
*out = iInt;
|
||||
return true;
|
||||
}
|
||||
|
||||
// copied from kaldi/src/util/text-util.cc
//
// Wraps a std::istream and reads a single floating-point value of type T
// from it. When the plain stream extraction fails, or succeeds but is
// followed by more than spaces, the whole input is re-read as one token and
// matched against the textual spellings of infinity and NaN. Any failure is
// reported by setting failbit on the wrapped stream.
template <class T>
class NumberIstream {
 public:
  explicit NumberIstream(std::istream &i) : in_(i) {}

  // Reads one number into `x`. Does nothing if the stream is not good() on
  // entry. Returns *this so that calls can be chained.
  NumberIstream &operator>>(T &x) {
    if (!in_.good()) return *this;
    in_ >> x;
    if (!in_.fail() && RemainderIsOnlySpaces()) return *this;
    return ParseOnFail(&x);
  }

 private:
  std::istream &in_;

  // Returns true if nothing but spaces is left in the stream. The state
  // flags of the stream are cleared before returning true.
  bool RemainderIsOnlySpaces() {
    if (in_.tellg() != std::istream::pos_type(-1)) {
      std::string rem;
      in_ >> rem;

      if (rem.find_first_not_of(' ') != std::string::npos) {
        // there is not only spaces
        return false;
      }
    }

    in_.clear();
    return true;
  }

  // Fallback used when the regular extraction did not yield a clean number:
  // rewinds the stream, reads exactly one token and accepts it only if it is
  // one of the known inf/nan spellings (compared case-insensitively).
  NumberIstream &ParseOnFail(T *x) {
    std::string str;
    in_.clear();
    in_.seekg(0);
    // If the stream is broken even before trying
    // to read from it or if there are many tokens,
    // it's pointless to try.
    if (!(in_ >> str) || !RemainderIsOnlySpaces()) {
      in_.setstate(std::ios_base::failbit);
      return *this;
    }

    // Built once per T; the keys are kept in upper case only.
    static const std::unordered_map<std::string, T> kInfNanMap = {
        {"INF", std::numeric_limits<T>::infinity()},
        {"+INF", std::numeric_limits<T>::infinity()},
        {"-INF", -std::numeric_limits<T>::infinity()},
        {"INFINITY", std::numeric_limits<T>::infinity()},
        {"+INFINITY", std::numeric_limits<T>::infinity()},
        {"-INFINITY", -std::numeric_limits<T>::infinity()},
        {"NAN", std::numeric_limits<T>::quiet_NaN()},
        {"+NAN", std::numeric_limits<T>::quiet_NaN()},
        {"-NAN", -std::numeric_limits<T>::quiet_NaN()},
        // MSVC
        {"1.#INF", std::numeric_limits<T>::infinity()},
        {"-1.#INF", -std::numeric_limits<T>::infinity()},
        {"1.#QNAN", std::numeric_limits<T>::quiet_NaN()},
        {"-1.#QNAN", -std::numeric_limits<T>::quiet_NaN()},
    };

    // Upper-case through unsigned char: calling toupper() with a negative
    // value is undefined.
    std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) {
      return static_cast<char>(std::toupper(c));
    });

    const auto it = kInfNanMap.find(str);
    if (it != kInfNanMap.end()) {
      *x = it->second;
    } else {
      in_.setstate(std::ios_base::failbit);
    }

    return *this;
  }
};
|
||||
|
||||
/// ConvertStringToReal converts a string into either float or double
|
||||
/// and returns false if there was any kind of problem (i.e. the string
|
||||
/// was not a floating point number or contained extra non-whitespace junk).
|
||||
/// Be careful- this function will successfully read inf's or nan's.
|
||||
template <typename T>
|
||||
bool ConvertStringToReal(const std::string &str, T *out) {
|
||||
std::istringstream iss(str);
|
||||
|
||||
NumberIstream<T> i(iss);
|
||||
|
||||
i >> *out;
|
||||
|
||||
if (iss.fail()) {
|
||||
// Number conversion failed.
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Builds a parser that registers its options under `prefix` on another
// parser `po` (which may be null).
ParseOptions::ParseOptions(const std::string &prefix, ParseOptions *po)
    : print_args_(false), help_(false), usage_(""), argc_(0), argv_(nullptr) {
  const bool has_parent = (po != nullptr);

  // If `po` already forwards to another parser (this constructor was used
  // twice, recursively), forward to that same parser; otherwise to `po`.
  other_parser_ =
      (has_parent && po->other_parser_ != nullptr) ? po->other_parser_ : po;

  // Chain the prefixes with '.' when the parent carries a non-empty one.
  prefix_ = (has_parent && po->prefix_ != "")
                ? po->prefix_ + std::string(".") + prefix
                : prefix;
}
|
||||
|
||||
void ParseOptions::Register(const std::string &name, bool *ptr,
|
||||
const std::string &doc) {
|
||||
RegisterTmpl(name, ptr, doc);
|
||||
}
|
||||
|
||||
void ParseOptions::Register(const std::string &name, int32_t *ptr,
|
||||
const std::string &doc) {
|
||||
RegisterTmpl(name, ptr, doc);
|
||||
}
|
||||
|
||||
void ParseOptions::Register(const std::string &name, uint32_t *ptr,
|
||||
const std::string &doc) {
|
||||
RegisterTmpl(name, ptr, doc);
|
||||
}
|
||||
|
||||
void ParseOptions::Register(const std::string &name, float *ptr,
|
||||
const std::string &doc) {
|
||||
RegisterTmpl(name, ptr, doc);
|
||||
}
|
||||
|
||||
void ParseOptions::Register(const std::string &name, double *ptr,
|
||||
const std::string &doc) {
|
||||
RegisterTmpl(name, ptr, doc);
|
||||
}
|
||||
|
||||
void ParseOptions::Register(const std::string &name, std::string *ptr,
|
||||
const std::string &doc) {
|
||||
RegisterTmpl(name, ptr, doc);
|
||||
}
|
||||
|
||||
// old-style, used for registering application-specific parameters
|
||||
template <typename T>
|
||||
void ParseOptions::RegisterTmpl(const std::string &name, T *ptr,
|
||||
const std::string &doc) {
|
||||
if (other_parser_ == nullptr) {
|
||||
this->RegisterCommon(name, ptr, doc, false);
|
||||
} else {
|
||||
SHERPA_ONNX_CHECK(prefix_ != "")
|
||||
<< "prefix: " << prefix_ << "\n"
|
||||
<< "Cannot use empty prefix when registering with prefix.";
|
||||
std::string new_name = prefix_ + '.' + name; // name becomes prefix.name
|
||||
other_parser_->Register(new_name, ptr, doc);
|
||||
}
|
||||
}
|
||||
|
||||
// does the common part of the job of registering a parameter
|
||||
template <typename T>
|
||||
void ParseOptions::RegisterCommon(const std::string &name, T *ptr,
|
||||
const std::string &doc, bool is_standard) {
|
||||
SHERPA_ONNX_CHECK(ptr != nullptr);
|
||||
std::string idx = name;
|
||||
NormalizeArgName(&idx);
|
||||
if (doc_map_.find(idx) != doc_map_.end()) {
|
||||
SHERPA_ONNX_LOG(WARNING)
|
||||
<< "Registering option twice, ignoring second time: " << name;
|
||||
} else {
|
||||
this->RegisterSpecific(name, idx, ptr, doc, is_standard);
|
||||
}
|
||||
}
|
||||
|
||||
// used to register standard parameters (those that are present in all of the
|
||||
// applications)
|
||||
template <typename T>
|
||||
void ParseOptions::RegisterStandard(const std::string &name, T *ptr,
|
||||
const std::string &doc) {
|
||||
this->RegisterCommon(name, ptr, doc, true);
|
||||
}
|
||||
|
||||
void ParseOptions::RegisterSpecific(const std::string &name,
|
||||
const std::string &idx, bool *b,
|
||||
const std::string &doc, bool is_standard) {
|
||||
bool_map_[idx] = b;
|
||||
doc_map_[idx] =
|
||||
DocInfo(name, doc + " (bool, default = " + ((*b) ? "true)" : "false)"),
|
||||
is_standard);
|
||||
}
|
||||
|
||||
void ParseOptions::RegisterSpecific(const std::string &name,
|
||||
const std::string &idx, int32_t *i,
|
||||
const std::string &doc, bool is_standard) {
|
||||
int_map_[idx] = i;
|
||||
std::ostringstream ss;
|
||||
ss << doc << " (int, default = " << *i << ")";
|
||||
doc_map_[idx] = DocInfo(name, ss.str(), is_standard);
|
||||
}
|
||||
|
||||
void ParseOptions::RegisterSpecific(const std::string &name,
|
||||
const std::string &idx, uint32_t *u,
|
||||
const std::string &doc, bool is_standard) {
|
||||
uint_map_[idx] = u;
|
||||
std::ostringstream ss;
|
||||
ss << doc << " (uint, default = " << *u << ")";
|
||||
doc_map_[idx] = DocInfo(name, ss.str(), is_standard);
|
||||
}
|
||||
|
||||
void ParseOptions::RegisterSpecific(const std::string &name,
|
||||
const std::string &idx, float *f,
|
||||
const std::string &doc, bool is_standard) {
|
||||
float_map_[idx] = f;
|
||||
std::ostringstream ss;
|
||||
ss << doc << " (float, default = " << *f << ")";
|
||||
doc_map_[idx] = DocInfo(name, ss.str(), is_standard);
|
||||
}
|
||||
|
||||
void ParseOptions::RegisterSpecific(const std::string &name,
|
||||
const std::string &idx, double *f,
|
||||
const std::string &doc, bool is_standard) {
|
||||
double_map_[idx] = f;
|
||||
std::ostringstream ss;
|
||||
ss << doc << " (double, default = " << *f << ")";
|
||||
doc_map_[idx] = DocInfo(name, ss.str(), is_standard);
|
||||
}
|
||||
|
||||
void ParseOptions::RegisterSpecific(const std::string &name,
|
||||
const std::string &idx, std::string *s,
|
||||
const std::string &doc, bool is_standard) {
|
||||
string_map_[idx] = s;
|
||||
doc_map_[idx] =
|
||||
DocInfo(name, doc + " (string, default = \"" + *s + "\")", is_standard);
|
||||
}
|
||||
|
||||
void ParseOptions::DisableOption(const std::string &name) {
|
||||
if (argv_ != nullptr) {
|
||||
SHERPA_ONNX_LOG(FATAL)
|
||||
<< "DisableOption must not be called after calling Read().";
|
||||
}
|
||||
if (doc_map_.erase(name) == 0) {
|
||||
SHERPA_ONNX_LOG(FATAL) << "Option " << name
|
||||
<< " was not registered so cannot be disabled: ";
|
||||
}
|
||||
bool_map_.erase(name);
|
||||
int_map_.erase(name);
|
||||
uint_map_.erase(name);
|
||||
float_map_.erase(name);
|
||||
double_map_.erase(name);
|
||||
string_map_.erase(name);
|
||||
}
|
||||
|
||||
// Number of positional arguments collected by Read().
int ParseOptions::NumArgs() const {
  return static_cast<int>(positional_args_.size());
}
|
||||
|
||||
std::string ParseOptions::GetArg(int i) const {
|
||||
if (i < 1 || i > static_cast<int>(positional_args_.size())) {
|
||||
SHERPA_ONNX_LOG(FATAL) << "ParseOptions::GetArg, invalid index " << i;
|
||||
}
|
||||
|
||||
return positional_args_[i - 1];
|
||||
}
|
||||
|
||||
// We currently do not support any other options.
|
||||
enum ShellType { kBash = 0 };
|
||||
|
||||
// This can be changed in the code if it ever does need to be changed (as it's
|
||||
// unlikely that one compilation of this tool-set would use both shells).
|
||||
static ShellType kShellType = kBash;
|
||||
|
||||
// Returns true if we need to escape a string before putting it into
|
||||
// a shell (mainly thinking of bash shell, but should work for others)
|
||||
// This is for the convenience of the user so command-lines that are
|
||||
// printed out by ParseOptions::Read (with --print-args=true) are
|
||||
// paste-able into the shell and will run. If you use a different type of
|
||||
// shell, it might be necessary to change this function.
|
||||
// But it's mostly a cosmetic issue as it basically affects how
|
||||
// the program echoes its command-line arguments to the screen.
|
||||
static bool MustBeQuoted(const std::string &str, ShellType st) {
|
||||
// Only Bash is supported (for the moment).
|
||||
SHERPA_ONNX_CHECK_EQ(st, kBash) << "Invalid shell type.";
|
||||
|
||||
const char *c = str.c_str();
|
||||
if (*c == '\0') {
|
||||
return true; // Must quote empty string
|
||||
} else {
|
||||
const char *ok_chars[2];
|
||||
|
||||
// These seem not to be interpreted as long as there are no other "bad"
|
||||
// characters involved (e.g. "," would be interpreted as part of something
|
||||
// like a{b,c}, but not on its own.
|
||||
ok_chars[kBash] = "[]~#^_-+=:.,/";
|
||||
|
||||
// Just want to make sure that a space character doesn't get automatically
|
||||
// inserted here via an automated style-checking script, like it did before.
|
||||
SHERPA_ONNX_CHECK(!strchr(ok_chars[kBash], ' '));
|
||||
|
||||
for (; *c != '\0'; ++c) {
|
||||
// For non-alphanumeric characters we have a list of characters which
|
||||
// are OK. All others are forbidden (this is easier since the shell
|
||||
// interprets most non-alphanumeric characters).
|
||||
if (!isalnum(*c)) {
|
||||
const char *d;
|
||||
for (d = ok_chars[st]; *d != '\0'; ++d) {
|
||||
if (*c == *d) break;
|
||||
}
|
||||
// If not alphanumeric or one of the "ok_chars", it must be escaped.
|
||||
if (*d == '\0') return true;
|
||||
}
|
||||
}
|
||||
return false; // The string was OK. No quoting or escaping.
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a quoted and escaped version of "str"
|
||||
// which has previously been determined to need escaping.
|
||||
// Our aim is to print out the command line in such a way that if it's
|
||||
// pasted into a shell of ShellType "st" (only bash for now), it
|
||||
// will get passed to the program in the same way.
|
||||
static std::string QuoteAndEscape(const std::string &str, ShellType st) {
|
||||
// Only Bash is supported (for the moment).
|
||||
SHERPA_ONNX_CHECK_EQ(st, kBash) << "Invalid shell type.";
|
||||
|
||||
// For now we use the following rules:
|
||||
// In the normal case, we quote with single-quote "'", and to escape
|
||||
// a single-quote we use the string: '\'' (interpreted as closing the
|
||||
// single-quote, putting an escaped single-quote from the shell, and
|
||||
// then reopening the single quote).
|
||||
char quote_char = '\'';
|
||||
const char *escape_str = "'\\''"; // e.g. echo 'a'\''b' returns a'b
|
||||
|
||||
// If the string contains single-quotes that would need escaping this
|
||||
// way, and we determine that the string could be safely double-quoted
|
||||
// without requiring any escaping, then we double-quote the string.
|
||||
// This is the case if the characters "`$\ do not appear in the string.
|
||||
// e.g. see http://www.redhat.com/mirrors/LDP/LDP/abs/html/quotingvar.html
|
||||
const char *c_str = str.c_str();
|
||||
if (strchr(c_str, '\'') && !strpbrk(c_str, "\"`$\\")) {
|
||||
quote_char = '"';
|
||||
escape_str = "\\\""; // should never be accessed.
|
||||
}
|
||||
|
||||
char buf[2];
|
||||
buf[1] = '\0';
|
||||
|
||||
buf[0] = quote_char;
|
||||
std::string ans = buf;
|
||||
const char *c = str.c_str();
|
||||
for (; *c != '\0'; ++c) {
|
||||
if (*c == quote_char) {
|
||||
ans += escape_str;
|
||||
} else {
|
||||
buf[0] = *c;
|
||||
ans += buf;
|
||||
}
|
||||
}
|
||||
buf[0] = quote_char;
|
||||
ans += buf;
|
||||
return ans;
|
||||
}
|
||||
|
||||
// static function
|
||||
std::string ParseOptions::Escape(const std::string &str) {
|
||||
return MustBeQuoted(str, kShellType) ? QuoteAndEscape(str, kShellType) : str;
|
||||
}
|
||||
|
||||
int ParseOptions::Read(int argc, const char *const argv[]) {
|
||||
argc_ = argc;
|
||||
argv_ = argv;
|
||||
std::string key, value;
|
||||
int i;
|
||||
|
||||
// first pass: look for config parameter, look for priority
|
||||
for (i = 1; i < argc; ++i) {
|
||||
if (std::strncmp(argv[i], "--", 2) == 0) {
|
||||
if (std::strcmp(argv[i], "--") == 0) {
|
||||
// a lone "--" marks the end of named options
|
||||
break;
|
||||
}
|
||||
bool has_equal_sign;
|
||||
SplitLongArg(argv[i], &key, &value, &has_equal_sign);
|
||||
NormalizeArgName(&key);
|
||||
Trim(&value);
|
||||
if (key.compare("config") == 0) {
|
||||
ReadConfigFile(value);
|
||||
} else if (key.compare("help") == 0) {
|
||||
PrintUsage();
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool double_dash_seen = false;
|
||||
// second pass: add the command line options
|
||||
for (i = 1; i < argc; ++i) {
|
||||
if (std::strncmp(argv[i], "--", 2) == 0) {
|
||||
if (std::strcmp(argv[i], "--") == 0) {
|
||||
// A lone "--" marks the end of named options.
|
||||
// Skip that option and break the processing of named options
|
||||
i += 1;
|
||||
double_dash_seen = true;
|
||||
break;
|
||||
}
|
||||
bool has_equal_sign;
|
||||
SplitLongArg(argv[i], &key, &value, &has_equal_sign);
|
||||
NormalizeArgName(&key);
|
||||
Trim(&value);
|
||||
if (!SetOption(key, value, has_equal_sign)) {
|
||||
PrintUsage(true);
|
||||
SHERPA_ONNX_LOG(FATAL) << "Invalid option " << argv[i];
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// process remaining arguments as positional
|
||||
for (; i < argc; ++i) {
|
||||
if ((std::strcmp(argv[i], "--") == 0) && !double_dash_seen) {
|
||||
double_dash_seen = true;
|
||||
} else {
|
||||
positional_args_.push_back(std::string(argv[i]));
|
||||
}
|
||||
}
|
||||
|
||||
// if the user did not suppress this with --print-args = false....
|
||||
if (print_args_) {
|
||||
std::ostringstream strm;
|
||||
for (int j = 0; j < argc; ++j) strm << Escape(argv[j]) << " ";
|
||||
strm << '\n';
|
||||
SHERPA_ONNX_LOG(INFO) << strm.str();
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const {
|
||||
std::ostringstream os;
|
||||
os << '\n' << usage_ << '\n';
|
||||
// first we print application-specific options
|
||||
bool app_specific_header_printed = false;
|
||||
for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) {
|
||||
if (it->second.is_standard_ == false) { // application-specific option
|
||||
if (app_specific_header_printed == false) { // header was not yet printed
|
||||
os << "Options:" << '\n';
|
||||
app_specific_header_printed = true;
|
||||
}
|
||||
os << " --" << std::setw(25) << std::left << it->second.name_ << " : "
|
||||
<< it->second.use_msg_ << '\n';
|
||||
}
|
||||
}
|
||||
if (app_specific_header_printed == true) {
|
||||
os << '\n';
|
||||
}
|
||||
|
||||
// then the standard options
|
||||
os << "Standard options:" << '\n';
|
||||
for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) {
|
||||
if (it->second.is_standard_ == true) { // we have standard option
|
||||
os << " --" << std::setw(25) << std::left << it->second.name_ << " : "
|
||||
<< it->second.use_msg_ << '\n';
|
||||
}
|
||||
}
|
||||
os << '\n';
|
||||
if (print_command_line) {
|
||||
std::ostringstream strm;
|
||||
strm << "Command line was: ";
|
||||
for (int j = 0; j < argc_; ++j) strm << Escape(argv_[j]) << " ";
|
||||
strm << '\n';
|
||||
os << strm.str();
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOG(INFO) << os.str();
|
||||
}
|
||||
|
||||
void ParseOptions::PrintConfig(std::ostream &os) const {
|
||||
os << '\n' << "[[ Configuration of UI-Registered options ]]" << '\n';
|
||||
std::string key;
|
||||
for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) {
|
||||
key = it->first;
|
||||
os << it->second.name_ << " = ";
|
||||
if (bool_map_.end() != bool_map_.find(key)) {
|
||||
os << (*bool_map_.at(key) ? "true" : "false");
|
||||
} else if (int_map_.end() != int_map_.find(key)) {
|
||||
os << (*int_map_.at(key));
|
||||
} else if (uint_map_.end() != uint_map_.find(key)) {
|
||||
os << (*uint_map_.at(key));
|
||||
} else if (float_map_.end() != float_map_.find(key)) {
|
||||
os << (*float_map_.at(key));
|
||||
} else if (double_map_.end() != double_map_.find(key)) {
|
||||
os << (*double_map_.at(key));
|
||||
} else if (string_map_.end() != string_map_.find(key)) {
|
||||
os << "'" << *string_map_.at(key) << "'";
|
||||
} else {
|
||||
SHERPA_ONNX_LOG(FATAL)
|
||||
<< "PrintConfig: unrecognized option " << key << "[code error]";
|
||||
}
|
||||
os << '\n';
|
||||
}
|
||||
os << '\n';
|
||||
}
|
||||
|
||||
void ParseOptions::ReadConfigFile(const std::string &filename) {
|
||||
std::ifstream is(filename.c_str(), std::ifstream::in);
|
||||
if (!is.good()) {
|
||||
SHERPA_ONNX_LOG(FATAL) << "Cannot open config file: " << filename;
|
||||
}
|
||||
|
||||
std::string line, key, value;
|
||||
int32_t line_number = 0;
|
||||
while (std::getline(is, line)) {
|
||||
++line_number;
|
||||
// trim out the comments
|
||||
size_t pos;
|
||||
if ((pos = line.find_first_of('#')) != std::string::npos) {
|
||||
line.erase(pos);
|
||||
}
|
||||
// skip empty lines
|
||||
Trim(&line);
|
||||
if (line.length() == 0) continue;
|
||||
|
||||
if (line.substr(0, 2) != "--") {
|
||||
SHERPA_ONNX_LOG(FATAL)
|
||||
<< "Reading config file " << filename << ": line " << line_number
|
||||
<< " does not look like a line "
|
||||
<< "from a Kaldi command-line program's config file: should "
|
||||
<< "be of the form --x=y. Note: config files intended to "
|
||||
<< "be sourced by shell scripts lack the '--'.";
|
||||
}
|
||||
|
||||
// parse option
|
||||
bool has_equal_sign;
|
||||
SplitLongArg(line, &key, &value, &has_equal_sign);
|
||||
NormalizeArgName(&key);
|
||||
Trim(&value);
|
||||
if (!SetOption(key, value, has_equal_sign)) {
|
||||
PrintUsage(true);
|
||||
SHERPA_ONNX_LOG(FATAL) << "Invalid option " << line << " in config file "
|
||||
<< filename << ": line " << line_number;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ParseOptions::SplitLongArg(const std::string &in, std::string *key,
|
||||
std::string *value,
|
||||
bool *has_equal_sign) const {
|
||||
SHERPA_ONNX_CHECK(in.substr(0, 2) == "--") << in; // precondition.
|
||||
size_t pos = in.find_first_of('=', 0);
|
||||
if (pos == std::string::npos) { // we allow --option for bools
|
||||
// defaults to empty. We handle this differently in different cases.
|
||||
*key = in.substr(2, in.size() - 2); // 2 because starts with --.
|
||||
*value = "";
|
||||
*has_equal_sign = false;
|
||||
} else if (pos == 2) { // we also don't allow empty keys: --=value
|
||||
PrintUsage(true);
|
||||
SHERPA_ONNX_LOG(FATAL) << "Invalid option (no key): " << in;
|
||||
} else { // normal case: --option=value
|
||||
*key = in.substr(2, pos - 2); // 2 because starts with --.
|
||||
*value = in.substr(pos + 1);
|
||||
*has_equal_sign = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Normalizes an option name in place: every '_' becomes '-' and every other
// character is lower-cased. The result must be non-empty.
void ParseOptions::NormalizeArgName(std::string *str) const {
  std::string out;
  out.reserve(str->size());

  for (const char ch : *str) {
    if (ch == '_') {
      out += '-';  // convert _ to -
    } else {
      // tolower() takes and returns int: go through unsigned char so that a
      // negative char is never passed (undefined), and narrow the result
      // back to char explicitly.
      out += static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
    }
  }
  *str = out;

  SHERPA_ONNX_CHECK_GT(str->length(), 0);
}
|
||||
|
||||
// Removes leading and trailing whitespace (space, \t, \n, \r, \f, \v) from
// `*str` in place. A string made only of whitespace becomes empty.
void ParseOptions::Trim(std::string *str) const {
  const char *white_chars = " \t\n\r\f\v";

  const std::string::size_type last = str->find_last_not_of(white_chars);
  if (last == std::string::npos) {
    // Nothing but whitespace (or already empty).
    str->clear();
    return;
  }

  str->erase(last + 1);

  const std::string::size_type first = str->find_first_not_of(white_chars);
  if (first != std::string::npos) {
    str->erase(0, first);
  }
}
|
||||
|
||||
bool ParseOptions::SetOption(const std::string &key, const std::string &value,
|
||||
bool has_equal_sign) {
|
||||
if (bool_map_.end() != bool_map_.find(key)) {
|
||||
if (has_equal_sign && value == "") {
|
||||
SHERPA_ONNX_LOG(FATAL) << "Invalid option --" << key << "=";
|
||||
}
|
||||
*(bool_map_[key]) = ToBool(value);
|
||||
} else if (int_map_.end() != int_map_.find(key)) {
|
||||
*(int_map_[key]) = ToInt(value);
|
||||
} else if (uint_map_.end() != uint_map_.find(key)) {
|
||||
*(uint_map_[key]) = ToUint(value);
|
||||
} else if (float_map_.end() != float_map_.find(key)) {
|
||||
*(float_map_[key]) = ToFloat(value);
|
||||
} else if (double_map_.end() != double_map_.find(key)) {
|
||||
*(double_map_[key]) = ToDouble(value);
|
||||
} else if (string_map_.end() != string_map_.find(key)) {
|
||||
if (!has_equal_sign) {
|
||||
SHERPA_ONNX_LOG(FATAL)
|
||||
<< "Invalid option --" << key << " (option format is --x=y).";
|
||||
}
|
||||
*(string_map_[key]) = value;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Interprets `str` as a boolean, ignoring case.
//
// "true", "t", "1" and the empty string yield true; "false", "f" and "0"
// yield false. Anything else prints the usage and is logged as FATAL.
bool ParseOptions::ToBool(std::string str) const {
  // Lower-case through unsigned char: calling tolower() with a negative
  // value is undefined.
  std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) {
    return static_cast<char>(std::tolower(c));
  });

  // allow "" as a valid option for "true", so that --x is the same as --x=true
  if ((str.compare("true") == 0) || (str.compare("t") == 0) ||
      (str.compare("1") == 0) || (str.compare("") == 0)) {
    return true;
  }
  if ((str.compare("false") == 0) || (str.compare("f") == 0) ||
      (str.compare("0") == 0)) {
    return false;
  }
  // if it is neither true nor false:
  PrintUsage(true);
  SHERPA_ONNX_LOG(FATAL)
      << "Invalid format for boolean argument [expected true or false]: "
      << str;
  return false;  // never reached
}
|
||||
|
||||
int32_t ParseOptions::ToInt(const std::string &str) const {
|
||||
int32_t ret = 0;
|
||||
if (!ConvertStringToInteger(str, &ret))
|
||||
SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\"";
|
||||
return ret;
|
||||
}
|
||||
|
||||
uint32_t ParseOptions::ToUint(const std::string &str) const {
|
||||
uint32_t ret = 0;
|
||||
if (!ConvertStringToInteger(str, &ret))
|
||||
SHERPA_ONNX_LOG(FATAL) << "Invalid integer option \"" << str << "\"";
|
||||
return ret;
|
||||
}
|
||||
|
||||
float ParseOptions::ToFloat(const std::string &str) const {
|
||||
float ret;
|
||||
if (!ConvertStringToReal(str, &ret))
|
||||
SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\"";
|
||||
return ret;
|
||||
}
|
||||
|
||||
double ParseOptions::ToDouble(const std::string &str) const {
|
||||
double ret;
|
||||
if (!ConvertStringToReal(str, &ret))
|
||||
SHERPA_ONNX_LOG(FATAL) << "Invalid floating-point option \"" << str << "\"";
|
||||
return ret;
|
||||
}
|
||||
|
||||
// instantiate templates
|
||||
template void ParseOptions::RegisterTmpl(const std::string &name, bool *ptr,
|
||||
const std::string &doc);
|
||||
template void ParseOptions::RegisterTmpl(const std::string &name, int32_t *ptr,
|
||||
const std::string &doc);
|
||||
template void ParseOptions::RegisterTmpl(const std::string &name, uint32_t *ptr,
|
||||
const std::string &doc);
|
||||
template void ParseOptions::RegisterTmpl(const std::string &name, float *ptr,
|
||||
const std::string &doc);
|
||||
template void ParseOptions::RegisterTmpl(const std::string &name, double *ptr,
|
||||
const std::string &doc);
|
||||
template void ParseOptions::RegisterTmpl(const std::string &name,
|
||||
std::string *ptr,
|
||||
const std::string &doc);
|
||||
|
||||
template void ParseOptions::RegisterStandard(const std::string &name, bool *ptr,
|
||||
const std::string &doc);
|
||||
template void ParseOptions::RegisterStandard(const std::string &name,
|
||||
int32_t *ptr,
|
||||
const std::string &doc);
|
||||
template void ParseOptions::RegisterStandard(const std::string &name,
|
||||
uint32_t *ptr,
|
||||
const std::string &doc);
|
||||
template void ParseOptions::RegisterStandard(const std::string &name,
|
||||
float *ptr,
|
||||
const std::string &doc);
|
||||
template void ParseOptions::RegisterStandard(const std::string &name,
|
||||
double *ptr,
|
||||
const std::string &doc);
|
||||
template void ParseOptions::RegisterStandard(const std::string &name,
|
||||
std::string *ptr,
|
||||
const std::string &doc);
|
||||
|
||||
template void ParseOptions::RegisterCommon(const std::string &name, bool *ptr,
|
||||
const std::string &doc,
|
||||
bool is_standard);
|
||||
template void ParseOptions::RegisterCommon(const std::string &name,
|
||||
int32_t *ptr, const std::string &doc,
|
||||
bool is_standard);
|
||||
template void ParseOptions::RegisterCommon(const std::string &name,
|
||||
uint32_t *ptr,
|
||||
const std::string &doc,
|
||||
bool is_standard);
|
||||
template void ParseOptions::RegisterCommon(const std::string &name, float *ptr,
|
||||
const std::string &doc,
|
||||
bool is_standard);
|
||||
template void ParseOptions::RegisterCommon(const std::string &name, double *ptr,
|
||||
const std::string &doc,
|
||||
bool is_standard);
|
||||
template void ParseOptions::RegisterCommon(const std::string &name,
|
||||
std::string *ptr,
|
||||
const std::string &doc,
|
||||
bool is_standard);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
252
sherpa-onnx/csrc/parse-options.h
Normal file
252
sherpa-onnx/csrc/parse-options.h
Normal file
@@ -0,0 +1,252 @@
|
||||
// sherpa-onnx/csrc/parse-options.h
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
//
|
||||
// This file is copied and modified from kaldi/src/util/parse-options.h
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_
|
||||
#define SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/// Parser for command-line options and config files.
///
/// Options are registered with a pointer to the variable they fill (see the
/// Register() overloads); Read() then parses argv, and positional arguments
/// are available through NumArgs() / GetArg().
class ParseOptions {
 public:
  /// @param usage  Usage text shown by PrintUsage(). Only the pointer is
  ///               stored (it is not copied), so the string must outlive this
  ///               object.
  ///
  /// Also registers the standard options --config, --print-args and --help.
  explicit ParseOptions(const char *usage)
      : print_args_(true),
        help_(false),
        usage_(usage),
        argc_(0),
        argv_(nullptr),
        prefix_(""),
        other_parser_(nullptr) {
#if !defined(_MSC_VER) && !defined(__CYGWIN__)
    // This is just a convenient place to set the stderr to line
    // buffering mode, since it's called at program start.
    // This helps ensure different programs' output is not mixed up.
    setlinebuf(stderr);
#endif
    RegisterStandard("config", &config_,
                     "Configuration file to read (this "
                     "option may be repeated)");
    RegisterStandard("print-args", &print_args_,
                     "Print the command line arguments (to stderr)");
    RegisterStandard("help", &help_, "Print out usage message");
  }

  /**
    This is a constructor for the special case where some options are
    registered with a prefix to avoid conflicts.  The object thus created will
    only be used temporarily to register an options class with the original
    options parser (which is passed as the *other pointer) using the given
    prefix.  It should not be used for any other purpose, and the prefix must
    not be the empty string.  It seems to be the least bad way of implementing
    options with prefixes at this point.
    Example of usage is:
     ParseOptions po(usage);  // original ParseOptions object
     ParseOptions po_mfcc("mfcc", &po);  // object with prefix.
     MfccOptions mfcc_opts;
     mfcc_opts.Register(&po_mfcc);
    The options will now get registered as, e.g., --mfcc.frame-shift=10.0
    instead of just --frame-shift=10.0
   */
  ParseOptions(const std::string &prefix, ParseOptions *other);

  ParseOptions(const ParseOptions &) = delete;
  ParseOptions &operator=(const ParseOptions &) = delete;
  ~ParseOptions() = default;

  /// Register an option called `name` whose parsed value is written to
  /// `*ptr`; `doc` is the help text for the option. One overload exists per
  /// supported option type.
  void Register(const std::string &name, bool *ptr, const std::string &doc);
  void Register(const std::string &name, int32_t *ptr, const std::string &doc);
  void Register(const std::string &name, uint32_t *ptr, const std::string &doc);
  void Register(const std::string &name, float *ptr, const std::string &doc);
  void Register(const std::string &name, double *ptr, const std::string &doc);
  void Register(const std::string &name, std::string *ptr,
                const std::string &doc);

  /// If called after registering an option and before calling
  /// Read(), disables that option from being used.  Will crash
  /// at runtime if that option had not been registered.
  void DisableOption(const std::string &name);

  /// This one is used for registering standard parameters of all the programs
  template <typename T>
  void RegisterStandard(const std::string &name, T *ptr,
                        const std::string &doc);

  /**
    Parses the command line options and fills the ParseOptions-registered
    variables. This must be called after all the variables were registered!!!

    Initially the variables have implicit values,
    then the config file values are set-up,
    finally the command line values given.
    Returns the first position in argv that was not used.
    [typically not useful: use NumArgs() and GetArg(). ]
   */
  int Read(int argc, const char *const *argv);

  /// Prints the usage documentation [provided in the constructor].
  void PrintUsage(bool print_command_line = false) const;

  /// Prints the actual configuration of all the registered variables
  void PrintConfig(std::ostream &os) const;

  /// Reads the options values from a config file.  Must be called after
  /// registering all options.  This is usually used internally after the
  /// standard --config option is used, but it may also be called from a
  /// program.
  void ReadConfigFile(const std::string &filename);

  /// Number of positional parameters (c.f. argc-1).
  int NumArgs() const;

  /// Returns one of the positional parameters; 1-based indexing for argc/argv
  /// compatibility. Will crash if param is not >=1 and <=NumArgs().
  ///
  /// Note: Index is 1 based.
  std::string GetArg(int param) const;

  /// Like GetArg(), but returns the empty string when `param` is greater
  /// than NumArgs() instead of crashing.
  std::string GetOptArg(int param) const {
    return (param <= NumArgs() ? GetArg(param) : "");
  }

  /// The following function will return a possibly quoted and escaped
  /// version of "str", according to the current shell.  Currently
  /// this is just hardwired to bash.  It's useful for debug output.
  static std::string Escape(const std::string &str);

 private:
  /// Template to register various variable types,
  /// used for program-specific parameters
  template <typename T>
  void RegisterTmpl(const std::string &name, T *ptr, const std::string &doc);

  // Following functions do just the datatype-specific part of the job
  /// Register boolean variable
  void RegisterSpecific(const std::string &name, const std::string &idx,
                        bool *b, const std::string &doc, bool is_standard);
  /// Register int32_t variable
  void RegisterSpecific(const std::string &name, const std::string &idx,
                        int32_t *i, const std::string &doc, bool is_standard);
  /// Register uint32_t variable
  void RegisterSpecific(const std::string &name, const std::string &idx,
                        uint32_t *u, const std::string &doc, bool is_standard);
  /// Register float variable
  void RegisterSpecific(const std::string &name, const std::string &idx,
                        float *f, const std::string &doc, bool is_standard);
  /// Register double variable [useful as we change BaseFloat type].
  void RegisterSpecific(const std::string &name, const std::string &idx,
                        double *f, const std::string &doc, bool is_standard);
  /// Register string variable
  void RegisterSpecific(const std::string &name, const std::string &idx,
                        std::string *s, const std::string &doc,
                        bool is_standard);

  /// Does the actual job for both kinds of parameters
  /// Does the common part of the job for all datatypes,
  /// then calls RegisterSpecific
  template <typename T>
  void RegisterCommon(const std::string &name, T *ptr, const std::string &doc,
                      bool is_standard);

  /// Set option with name "key" to "value"; will crash if can't do it.
  /// "has_equal_sign" is used to allow --x for a boolean option x,
  /// and --y=, for a string option y.
  bool SetOption(const std::string &key, const std::string &value,
                 bool has_equal_sign);

  // Converters from the textual value of an option to its typed value.
  bool ToBool(std::string str) const;
  int32_t ToInt(const std::string &str) const;
  uint32_t ToUint(const std::string &str) const;
  float ToFloat(const std::string &str) const;
  double ToDouble(const std::string &str) const;

  // maps for option variables
  std::unordered_map<std::string, bool *> bool_map_;
  std::unordered_map<std::string, int32_t *> int_map_;
  std::unordered_map<std::string, uint32_t *> uint_map_;
  std::unordered_map<std::string, float *> float_map_;
  std::unordered_map<std::string, double *> double_map_;
  std::unordered_map<std::string, std::string *> string_map_;

  /**
     Structure for options' documentation
   */
  struct DocInfo {
    DocInfo() = default;
    DocInfo(const std::string &name, const std::string &usemsg)
        : name_(name), use_msg_(usemsg), is_standard_(false) {}
    DocInfo(const std::string &name, const std::string &usemsg,
            bool is_standard)
        : name_(name), use_msg_(usemsg), is_standard_(is_standard) {}

    std::string name_;
    std::string use_msg_;
    // Default member initializer so that a default-initialized DocInfo
    // (`DocInfo d;`) does not carry an indeterminate flag.
    bool is_standard_ = false;
  };
  using DocMapType = std::unordered_map<std::string, DocInfo>;
  DocMapType doc_map_;  ///< map for the documentation

  bool print_args_;     ///< variable for the implicit --print-args parameter
  bool help_;           ///< variable for the implicit --help parameter
  std::string config_;  ///< variable for the implicit --config parameter
  std::vector<std::string> positional_args_;
  const char *usage_;  ///< not owned; see the constructor
  int argc_;
  const char *const *argv_;

  /// These members are not normally used. They are only used when the object
  /// is constructed with a prefix
  std::string prefix_;
  ParseOptions *other_parser_;

 protected:
  /// SplitLongArg parses an argument of the form --a=b, --a=, or --a,
  /// and sets "has_equal_sign" to true if an equals-sign was parsed.
  /// This is needed in order to correctly allow --x for a boolean option
  /// x, and --y= for a string option y, and to disallow --x= and --y.
  void SplitLongArg(const std::string &in, std::string *key, std::string *value,
                    bool *has_equal_sign) const;

  void NormalizeArgName(std::string *str) const;

  /// Removes the beginning and trailing whitespaces from a string
  void Trim(std::string *str) const;
};
|
||||
|
||||
/// This template is provided for convenience in reading config classes from
|
||||
/// files; this is not the standard way to read configuration options, but may
|
||||
/// occasionally be needed. This function assumes the config has a function
|
||||
/// "void Register(ParseOptions *opts)" which it can call to register the
|
||||
/// ParseOptions object.
|
||||
template <class C>
|
||||
void ReadConfigFromFile(const std::string &config_filename, C *c) {
|
||||
std::ostringstream usage_str;
|
||||
usage_str << "Parsing config from "
|
||||
<< "from '" << config_filename << "'";
|
||||
ParseOptions po(usage_str.str().c_str());
|
||||
c->Register(&po);
|
||||
po.ReadConfigFile(config_filename);
|
||||
}
|
||||
|
||||
/// This variant of the template ReadConfigFromFile is for if you need to read
|
||||
/// two config classes from the same file.
|
||||
template <class C1, class C2>
|
||||
void ReadConfigsFromFile(const std::string &conf, C1 *c1, C2 *c2) {
|
||||
std::ostringstream usage_str;
|
||||
usage_str << "Parsing config from "
|
||||
<< "from '" << conf << "'";
|
||||
ParseOptions po(usage_str.str().c_str());
|
||||
c1->Register(&po);
|
||||
c2->Register(&po);
|
||||
po.ReadConfigFile(conf);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_
|
||||
142
sherpa-onnx/csrc/sherpa-onnx-alsa.cc
Normal file
142
sherpa-onnx/csrc/sherpa-onnx-alsa.cc
Normal file
@@ -0,0 +1,142 @@
|
||||
// sherpa-onnx/csrc/sherpa-onnx-alsa.cc
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

#include <algorithm>
#include <cctype>  // std::tolower
#include <cstdint>

#include "sherpa-onnx/csrc/alsa.h"
#include "sherpa-onnx/csrc/display.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
|
||||
|
||||
// Set to non-zero by the SIGINT handler and polled by the loop in main().
// volatile sig_atomic_t is a type that may safely be written from an
// asynchronous signal handler; a plain bool gives no such guarantee.
volatile sig_atomic_t stop = 0;

// SIGINT handler: requests termination of the main loop and reports it on
// stderr.
static void Handler(int /*sig*/) {
  stop = 1;

  // Use write(2), which is async-signal-safe; stdio functions such as
  // fprintf() are not and may deadlock or corrupt state if the signal arrives
  // while the main thread is inside stdio.
  static const char kMsg[] = "\nCaught Ctrl + C. Exiting...\n";
  const ssize_t unused = write(STDERR_FILENO, kMsg, sizeof(kMsg) - 1);
  (void)unused;  // nothing useful can be done on failure inside a handler
}
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
if (argc < 6 || argc > 7) {
|
||||
const char *usage = R"usage(
|
||||
Usage:
|
||||
./bin/sherpa-onnx-alsa \
|
||||
/path/to/tokens.txt \
|
||||
/path/to/encoder.onnx \
|
||||
/path/to/decoder.onnx \
|
||||
/path/to/joiner.onnx \
|
||||
device_name \
|
||||
[num_threads]
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
for a list of pre-trained models to download.
|
||||
|
||||
The device name specifies which microphone to use in case there are several
|
||||
on you system. You can use
|
||||
|
||||
arecord -l
|
||||
|
||||
to find all available microphones on your computer. For instance, if it outputs
|
||||
|
||||
**** List of CAPTURE Hardware Devices ****
|
||||
card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio]
|
||||
Subdevices: 1/1
|
||||
Subdevice #0: subdevice #0
|
||||
|
||||
and if you want to select card 3 and the device 0 on that card, please use:
|
||||
|
||||
hw:3,0
|
||||
|
||||
as the device_name.
|
||||
)usage";
|
||||
|
||||
fprintf(stderr, "%s\n", usage);
|
||||
fprintf(stderr, "argc, %d\n", argc);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
signal(SIGINT, Handler);
|
||||
|
||||
sherpa_onnx::OnlineRecognizerConfig config;
|
||||
|
||||
config.tokens = argv[1];
|
||||
|
||||
config.model_config.debug = false;
|
||||
config.model_config.encoder_filename = argv[2];
|
||||
config.model_config.decoder_filename = argv[3];
|
||||
config.model_config.joiner_filename = argv[4];
|
||||
|
||||
const char *device_name = argv[5];
|
||||
|
||||
config.model_config.num_threads = 2;
|
||||
if (argc == 7 && atoi(argv[6]) > 0) {
|
||||
config.model_config.num_threads = atoi(argv[6]);
|
||||
}
|
||||
|
||||
config.enable_endpoint = true;
|
||||
|
||||
config.endpoint_config.rule1.min_trailing_silence = 2.4;
|
||||
config.endpoint_config.rule2.min_trailing_silence = 1.2;
|
||||
config.endpoint_config.rule3.min_utterance_length = 300;
|
||||
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
sherpa_onnx::OnlineRecognizer recognizer(config);
|
||||
|
||||
int32_t expected_sample_rate = config.feat_config.sampling_rate;
|
||||
|
||||
sherpa_onnx::Alsa alsa(device_name);
|
||||
fprintf(stderr, "Use recording device: %s\n", device_name);
|
||||
|
||||
if (alsa.GetExpectedSampleRate() != expected_sample_rate) {
|
||||
fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(),
|
||||
expected_sample_rate);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t chunk = 0.1 * alsa.GetActualSampleRate();
|
||||
|
||||
std::string last_text;
|
||||
|
||||
auto stream = recognizer.CreateStream();
|
||||
|
||||
sherpa_onnx::Display display;
|
||||
|
||||
int32_t segment_index = 0;
|
||||
while (!stop) {
|
||||
const std::vector<float> samples = alsa.Read(chunk);
|
||||
|
||||
stream->AcceptWaveform(expected_sample_rate, samples.data(),
|
||||
samples.size());
|
||||
|
||||
while (recognizer.IsReady(stream.get())) {
|
||||
recognizer.DecodeStream(stream.get());
|
||||
}
|
||||
|
||||
auto text = recognizer.GetResult(stream.get()).text;
|
||||
|
||||
bool is_endpoint = recognizer.IsEndpoint(stream.get());
|
||||
|
||||
if (!text.empty() && last_text != text) {
|
||||
last_text = text;
|
||||
|
||||
std::transform(text.begin(), text.end(), text.begin(),
|
||||
[](auto c) { return std::tolower(c); });
|
||||
|
||||
display.Print(segment_index, text);
|
||||
}
|
||||
|
||||
if (!text.empty() && is_endpoint) {
|
||||
++segment_index;
|
||||
recognizer.Reset(stream.get());
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx
|
||||
features.cc
|
||||
online-transducer-model-config.cc
|
||||
sherpa-onnx.cc
|
||||
endpoint.cc
|
||||
online-stream.cc
|
||||
online-recognizer.cc
|
||||
)
|
||||
|
||||
100
sherpa-onnx/python/csrc/endpoint.cc
Normal file
100
sherpa-onnx/python/csrc/endpoint.cc
Normal file
@@ -0,0 +1,100 @@
|
||||
// sherpa-onnx/python/csrc/endpoint.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/endpoint.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/endpoint.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static constexpr const char *kEndpointRuleInitDoc = R"doc(
|
||||
Constructor for EndpointRule.
|
||||
|
||||
Args:
|
||||
must_contain_nonsilence:
|
||||
If True, for this endpointing rule to apply there must be nonsilence in the
|
||||
best-path traceback. For decoding, a non-blank token is considered as
|
||||
non-silence.
|
||||
min_trailing_silence:
|
||||
This endpointing rule requires duration of trailing silence (in seconds)
|
||||
to be ``>=`` this value.
|
||||
min_utterance_length:
|
||||
This endpointing rule requires utterance-length (in seconds) to
|
||||
be ``>=`` this value.
|
||||
)doc";
|
||||
|
||||
static constexpr const char *kEndpointConfigInitDoc = R"doc(
|
||||
If any rule in EndpointConfig is activated, it is said that an endpointing
|
||||
is detected.
|
||||
|
||||
Args:
|
||||
rule1:
|
||||
By default, it times out after 2.4 seconds of silence, even if
|
||||
we decoded nothing.
|
||||
rule2:
|
||||
By default, it times out after 1.2 seconds of silence after decoding
|
||||
something.
|
||||
rule3:
|
||||
By default, it times out after the utterance is 20 seconds long, regardless of
|
||||
anything else.
|
||||
)doc";
|
||||
|
||||
// Exposes EndpointRule to Python as "EndpointRule": a three-argument
// constructor, __str__, and read/write access to its three fields.
static void PybindEndpointRule(py::module *m) {
  py::class_<EndpointRule> rule(*m, "EndpointRule");

  rule.def(py::init<bool, float, float>(), py::arg("must_contain_nonsilence"),
           py::arg("min_trailing_silence"), py::arg("min_utterance_length"),
           kEndpointRuleInitDoc);

  rule.def("__str__", &EndpointRule::ToString);

  rule.def_readwrite("must_contain_nonsilence",
                     &EndpointRule::must_contain_nonsilence);
  rule.def_readwrite("min_trailing_silence",
                     &EndpointRule::min_trailing_silence);
  rule.def_readwrite("min_utterance_length",
                     &EndpointRule::min_utterance_length);
}
|
||||
|
||||
// Exposes EndpointConfig to Python as "EndpointConfig" with two constructor
// overloads (three float thresholds, or three EndpointRule objects with
// defaults), __str__, and read/write access to rule1/rule2/rule3.
static void PybindEndpointConfig(py::module *m) {
  py::class_<EndpointConfig> config(*m, "EndpointConfig");

  // Overload taking only the threshold of each rule; the remaining rule
  // fields are filled in here.
  config.def(
      py::init([](float rule1_min_trailing_silence,
                  float rule2_min_trailing_silence,
                  float rule3_min_utterance_length)
                   -> std::unique_ptr<EndpointConfig> {
        EndpointRule first(false, rule1_min_trailing_silence, 0);
        EndpointRule second(true, rule2_min_trailing_silence, 0);
        EndpointRule third(false, 0, rule3_min_utterance_length);

        return std::make_unique<EndpointConfig>(first, second, third);
      }),
      py::arg("rule1_min_trailing_silence"),
      py::arg("rule2_min_trailing_silence"),
      py::arg("rule3_min_utterance_length"));

  // Overload taking fully specified rules, each with a default value.
  config.def(
      py::init([](const EndpointRule &rule1, const EndpointRule &rule2,
                  const EndpointRule &rule3)
                   -> std::unique_ptr<EndpointConfig> {
        auto result = std::make_unique<EndpointConfig>();
        result->rule1 = rule1;
        result->rule2 = rule2;
        result->rule3 = rule3;
        return result;
      }),
      py::arg("rule1") = EndpointRule(false, 2.4, 0),
      py::arg("rule2") = EndpointRule(true, 1.2, 0),
      py::arg("rule3") = EndpointRule(false, 0, 20), kEndpointConfigInitDoc);

  config.def("__str__", [](const EndpointConfig &self) -> std::string {
    return self.ToString();
  });

  config.def_readwrite("rule1", &EndpointConfig::rule1);
  config.def_readwrite("rule2", &EndpointConfig::rule2);
  config.def_readwrite("rule3", &EndpointConfig::rule3);
}
|
||||
|
||||
// Registers the endpointing classes on module `m`.
void PybindEndpoint(py::module *m) {
  // EndpointRule is bound first: the EndpointConfig binding uses EndpointRule
  // objects as default argument values.
  PybindEndpointRule(m);

  PybindEndpointConfig(m);
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/endpoint.h
Normal file
16
sherpa-onnx/python/csrc/endpoint.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/endpoint.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindEndpoint(py::module *m);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_
|
||||
@@ -21,11 +21,15 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
using PyClass = OnlineRecognizerConfig;
|
||||
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
||||
.def(py::init<const FeatureExtractorConfig &,
|
||||
const OnlineTransducerModelConfig &, const std::string &>(),
|
||||
py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"))
|
||||
const OnlineTransducerModelConfig &, const std::string &,
|
||||
const EndpointConfig &, bool>(),
|
||||
py::arg("feat_config"), py::arg("model_config"), py::arg("tokens"),
|
||||
py::arg("endpoint_config"), py::arg("enable_endpoint"))
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("tokens", &PyClass::tokens)
|
||||
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
|
||||
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
@@ -43,7 +47,9 @@ void PybindOnlineRecognizer(py::module *m) {
|
||||
[](PyClass &self, std::vector<OnlineStream *> ss) {
|
||||
self.DecodeStreams(ss.data(), ss.size());
|
||||
})
|
||||
.def("get_result", &PyClass::GetResult);
|
||||
.def("get_result", &PyClass::GetResult)
|
||||
.def("is_endpoint", &PyClass::IsEndpoint)
|
||||
.def("reset", &PyClass::Reset);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
#include "sherpa-onnx/python/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/python/csrc/features.h"
|
||||
#include "sherpa-onnx/python/csrc/online-recognizer.h"
|
||||
#include "sherpa-onnx/python/csrc/online-stream.h"
|
||||
@@ -16,6 +17,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
PybindFeatures(&m);
|
||||
PybindOnlineTransducerModelConfig(&m);
|
||||
PybindOnlineStream(&m);
|
||||
PybindEndpoint(&m);
|
||||
PybindOnlineRecognizer(&m);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from _sherpa_onnx import (
|
||||
EndpointConfig,
|
||||
FeatureExtractorConfig,
|
||||
OnlineRecognizerConfig,
|
||||
OnlineStream,
|
||||
|
||||
@@ -2,12 +2,13 @@ from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from _sherpa_onnx import (
|
||||
EndpointConfig,
|
||||
FeatureExtractorConfig,
|
||||
OnlineRecognizer as _Recognizer,
|
||||
OnlineRecognizerConfig,
|
||||
OnlineStream,
|
||||
OnlineTransducerModelConfig,
|
||||
FeatureExtractorConfig,
|
||||
OnlineRecognizerConfig,
|
||||
)
|
||||
from _sherpa_onnx import OnlineRecognizer as _Recognizer
|
||||
|
||||
|
||||
def _assert_file_exists(f: str):
|
||||
@@ -26,6 +27,10 @@ class OnlineRecognizer(object):
|
||||
num_threads: int = 4,
|
||||
sample_rate: float = 16000,
|
||||
feature_dim: int = 80,
|
||||
enable_endpoint_detection: bool = False,
|
||||
rule1_min_trailing_silence: int = 2.4,
|
||||
rule2_min_trailing_silence: int = 1.2,
|
||||
rule3_min_utterance_length: int = 20,
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -52,6 +57,22 @@ class OnlineRecognizer(object):
|
||||
Sample rate of the training data used to train the model.
|
||||
feature_dim:
|
||||
Dimension of the feature used to train the model.
|
||||
enable_endpoint_detection:
|
||||
True to enable endpoint detection. False to disable endpoint
|
||||
detection.
|
||||
rule1_min_trailing_silence:
|
||||
Used only when enable_endpoint_detection is True. If the duration
|
||||
of trailing silence in seconds is larger than this value, we assume
|
||||
an endpoint is detected.
|
||||
rule2_min_trailing_silence:
|
||||
Used only when enable_endpoint_detection is True. If we have decoded
|
||||
something that is nonsilence and if the duration of trailing silence
|
||||
in seconds is larger than this value, we assume an endpoint is
|
||||
detected.
|
||||
rule3_min_utterance_length:
|
||||
Used only when enable_endpoint_detection is True. If the utterance
|
||||
length in seconds is larger than this value, we assume an endpoint
|
||||
is detected.
|
||||
"""
|
||||
_assert_file_exists(tokens)
|
||||
_assert_file_exists(encoder)
|
||||
@@ -72,10 +93,18 @@ class OnlineRecognizer(object):
|
||||
feature_dim=feature_dim,
|
||||
)
|
||||
|
||||
endpoint_config = EndpointConfig(
|
||||
rule1_min_trailing_silence=rule1_min_trailing_silence,
|
||||
rule2_min_trailing_silence=rule2_min_trailing_silence,
|
||||
rule3_min_utterance_length=rule3_min_utterance_length,
|
||||
)
|
||||
|
||||
recognizer_config = OnlineRecognizerConfig(
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
tokens=tokens,
|
||||
endpoint_config=endpoint_config,
|
||||
enable_endpoint=enable_endpoint_detection,
|
||||
)
|
||||
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
@@ -93,4 +122,10 @@ class OnlineRecognizer(object):
|
||||
return self.recognizer.is_ready(s)
|
||||
|
||||
def get_result(self, s: OnlineStream) -> str:
|
||||
return self.recognizer.get_result(s).text
|
||||
return self.recognizer.get_result(s).text.strip()
|
||||
|
||||
def is_endpoint(self, s: OnlineStream) -> bool:
|
||||
return self.recognizer.is_endpoint(s)
|
||||
|
||||
def reset(self, s: OnlineStream) -> bool:
|
||||
return self.recognizer.reset(s)
|
||||
|
||||
Reference in New Issue
Block a user