Code refactoring (#74)

* Don't reset model state and feature extractor on endpointing

* support passing decoding_method from commandline

* Add modified_beam_search to Python API

* fix C API example

* Fix style issues
This commit is contained in:
Fangjun Kuang
2023-03-03 12:10:59 +08:00
committed by GitHub
parent c241f93c40
commit 7f72c13d9a
34 changed files with 744 additions and 374 deletions

View File

@@ -19,14 +19,16 @@ const char *kUsage =
" /path/to/encoder.onnx \\\n"
" /path/to/decoder.onnx \\\n"
" /path/to/joiner.onnx \\\n"
" /path/to/foo.wav [num_threads]\n"
" /path/to/foo.wav [num_threads [decoding_method]]\n"
"\n\n"
"Default num_threads is 1.\n"
"Valid decoding_method: greedy_search (default), modified_beam_search\n\n"
"Please refer to \n"
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n"
"for a list of pre-trained models to download.\n";
int32_t main(int32_t argc, char *argv[]) {
if (argc < 6 || argc > 7) {
if (argc < 6 || argc > 8) {
fprintf(stderr, "%s\n", kUsage);
return -1;
}
@@ -36,13 +38,20 @@ int32_t main(int32_t argc, char *argv[]) {
config.model_config.decoder = argv[3];
config.model_config.joiner = argv[4];
int32_t num_threads = 4;
int32_t num_threads = 1;
if (argc == 7 && atoi(argv[6]) > 0) {
num_threads = atoi(argv[6]);
}
config.model_config.num_threads = num_threads;
config.model_config.debug = 0;
config.decoding_method = "greedy_search";
if (argc == 8) {
config.decoding_method = argv[7];
}
config.max_active_paths = 4;
config.feat_config.sample_rate = 16000;
config.feat_config.feature_dim = 80;
@@ -54,6 +63,9 @@ int32_t main(int32_t argc, char *argv[]) {
SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);
SherpaOnnxDisplay *display = CreateDisplay(50);
int32_t segment_id = 0;
const char *wav_filename = argv[5];
FILE *fp = fopen(wav_filename, "rb");
if (!fp) {
@@ -84,9 +96,18 @@ int32_t main(int32_t argc, char *argv[]) {
SherpaOnnxOnlineRecognizerResult *r =
GetOnlineStreamResult(recognizer, stream);
if (strlen(r->text)) {
fprintf(stderr, "%s\n", r->text);
SherpaOnnxPrint(display, segment_id, r->text);
}
if (IsEndpoint(recognizer, stream)) {
if (strlen(r->text)) {
++segment_id;
}
Reset(recognizer, stream);
}
DestroyOnlineRecognizerResult(r);
}
}
@@ -103,14 +124,17 @@ int32_t main(int32_t argc, char *argv[]) {
SherpaOnnxOnlineRecognizerResult *r =
GetOnlineStreamResult(recognizer, stream);
if (strlen(r->text)) {
fprintf(stderr, "%s\n", r->text);
SherpaOnnxPrint(display, segment_id, r->text);
}
DestroyOnlineRecognizerResult(r);
DestroyDisplay(display);
DestoryOnlineStream(stream);
DestroyOnlineRecognizer(recognizer);
fprintf(stderr, "\n");
return 0;
}