Code refactoring (#74)
* Don't reset the model state and feature extractor on endpointing
* Support passing decoding_method from the command line
* Add modified_beam_search to the Python API
* Fix the C API example
* Fix style issues
This commit is contained in:
@@ -19,14 +19,16 @@ const char *kUsage =
|
||||
" /path/to/encoder.onnx \\\n"
|
||||
" /path/to/decoder.onnx \\\n"
|
||||
" /path/to/joiner.onnx \\\n"
|
||||
" /path/to/foo.wav [num_threads]\n"
|
||||
" /path/to/foo.wav [num_threads [decoding_method]]\n"
|
||||
"\n\n"
|
||||
"Default num_threads is 1.\n"
|
||||
"Valid decoding_method: greedy_search (default), modified_beam_search\n\n"
|
||||
"Please refer to \n"
|
||||
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n"
|
||||
"for a list of pre-trained models to download.\n";
|
||||
|
||||
int32_t main(int32_t argc, char *argv[]) {
|
||||
if (argc < 6 || argc > 7) {
|
||||
if (argc < 6 || argc > 8) {
|
||||
fprintf(stderr, "%s\n", kUsage);
|
||||
return -1;
|
||||
}
|
||||
@@ -36,13 +38,20 @@ int32_t main(int32_t argc, char *argv[]) {
|
||||
config.model_config.decoder = argv[3];
|
||||
config.model_config.joiner = argv[4];
|
||||
|
||||
int32_t num_threads = 4;
|
||||
int32_t num_threads = 1;
|
||||
if (argc == 7 && atoi(argv[6]) > 0) {
|
||||
num_threads = atoi(argv[6]);
|
||||
}
|
||||
config.model_config.num_threads = num_threads;
|
||||
config.model_config.debug = 0;
|
||||
|
||||
config.decoding_method = "greedy_search";
|
||||
if (argc == 8) {
|
||||
config.decoding_method = argv[7];
|
||||
}
|
||||
|
||||
config.max_active_paths = 4;
|
||||
|
||||
config.feat_config.sample_rate = 16000;
|
||||
config.feat_config.feature_dim = 80;
|
||||
|
||||
@@ -54,6 +63,9 @@ int32_t main(int32_t argc, char *argv[]) {
|
||||
SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
|
||||
SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);
|
||||
|
||||
SherpaOnnxDisplay *display = CreateDisplay(50);
|
||||
int32_t segment_id = 0;
|
||||
|
||||
const char *wav_filename = argv[5];
|
||||
FILE *fp = fopen(wav_filename, "rb");
|
||||
if (!fp) {
|
||||
@@ -84,9 +96,18 @@ int32_t main(int32_t argc, char *argv[]) {
|
||||
|
||||
SherpaOnnxOnlineRecognizerResult *r =
|
||||
GetOnlineStreamResult(recognizer, stream);
|
||||
|
||||
if (strlen(r->text)) {
|
||||
fprintf(stderr, "%s\n", r->text);
|
||||
SherpaOnnxPrint(display, segment_id, r->text);
|
||||
}
|
||||
|
||||
if (IsEndpoint(recognizer, stream)) {
|
||||
if (strlen(r->text)) {
|
||||
++segment_id;
|
||||
}
|
||||
Reset(recognizer, stream);
|
||||
}
|
||||
|
||||
DestroyOnlineRecognizerResult(r);
|
||||
}
|
||||
}
|
||||
@@ -103,14 +124,17 @@ int32_t main(int32_t argc, char *argv[]) {
|
||||
|
||||
SherpaOnnxOnlineRecognizerResult *r =
|
||||
GetOnlineStreamResult(recognizer, stream);
|
||||
|
||||
if (strlen(r->text)) {
|
||||
fprintf(stderr, "%s\n", r->text);
|
||||
SherpaOnnxPrint(display, segment_id, r->text);
|
||||
}
|
||||
|
||||
DestroyOnlineRecognizerResult(r);
|
||||
|
||||
DestroyDisplay(display);
|
||||
DestoryOnlineStream(stream);
|
||||
DestroyOnlineRecognizer(recognizer);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user