Added provider option to sherpa-onnx and decode-file-c-api (#162)

This commit is contained in:
Jingzhao Ou
2023-06-02 13:57:48 -07:00
committed by GitHub
parent 5e2dc5ceea
commit 0ed501b8f1
9 changed files with 265 additions and 171 deletions

View File

@@ -1,3 +1,5 @@
include(cargs)
include_directories(${CMAKE_SOURCE_DIR})
add_executable(decode-file-c-api decode-file-c-api.c)
target_link_libraries(decode-file-c-api sherpa-onnx-c-api)
target_link_libraries(decode-file-c-api sherpa-onnx-c-api cargs)

View File

@@ -5,50 +5,85 @@
// This file shows how to use sherpa-onnx C API
// to decode a file.
#include "cargs.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "sherpa-onnx/c-api/c-api.h"
static struct cag_option options[] = {
{
.identifier = 't',
.access_letters = NULL,
.access_name = "tokens",
.value_name = "tokens",
.description = "Tokens file"
}, {
.identifier = 'e',
.access_letters = NULL,
.access_name = "encoder",
.value_name = "encoder",
.description = "Encoder ONNX file"
}, {
.identifier = 'd',
.access_letters = NULL,
.access_name = "decoder",
.value_name = "decoder",
.description = "Decoder ONNX file"
}, {
.identifier = 'j',
.access_letters = NULL,
.access_name = "joiner",
.value_name = "joiner",
.description = "Joiner ONNX file"
}, {
.identifier = 'n',
.access_letters = NULL,
.access_name = "num-threads",
.value_name = "num-threads",
.description = "Number of threads"
}, {
.identifier = 'p',
.access_letters = NULL,
.access_name = "provider",
.value_name = "provider",
.description = "Provider: cpu (default), cuda, coreml"
}, {
.identifier = 'm',
.access_letters = NULL,
.access_name = "decoding-method",
.value_name = "decoding-method",
.description =
"Decoding method: greedy_search (default), modified_beam_search"
}
};
const char *kUsage =
"\n"
"Usage:\n "
" ./bin/decode-file-c-api \\\n"
" /path/to/tokens.txt \\\n"
" /path/to/encoder.onnx \\\n"
" /path/to/decoder.onnx \\\n"
" /path/to/joiner.onnx \\\n"
" /path/to/foo.wav [num_threads [decoding_method]]\n"
" --tokens=/path/to/tokens.txt \\\n"
" --encoder=/path/to/encoder.onnx \\\n"
" --decoder=/path/to/decoder.onnx \\\n"
" --joiner=/path/to/joiner.onnx \\\n"
" /path/to/foo.wav\n"
"\n\n"
"Default num_threads is 1.\n"
"Valid decoding_method: greedy_search (default), modified_beam_search\n\n"
"Valid provider: cpu (default), cuda, coreml\n\n"
"Please refer to \n"
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n"
"for a list of pre-trained models to download.\n";
int32_t main(int32_t argc, char *argv[]) {
if (argc < 6 || argc > 8) {
fprintf(stderr, "%s\n", kUsage);
return -1;
}
SherpaOnnxOnlineRecognizerConfig config;
config.model_config.tokens = argv[1];
config.model_config.encoder = argv[2];
config.model_config.decoder = argv[3];
config.model_config.joiner = argv[4];
int32_t num_threads = 1;
if (argc == 7 && atoi(argv[6]) > 0) {
num_threads = atoi(argv[6]);
}
config.model_config.num_threads = num_threads;
config.model_config.debug = 0;
config.model_config.num_threads = 1;
config.model_config.provider = "cpu";
config.decoding_method = "greedy_search";
if (argc == 8) {
config.decoding_method = argv[7];
}
config.max_active_paths = 4;
@@ -60,13 +95,36 @@ int32_t main(int32_t argc, char *argv[]) {
config.rule2_min_trailing_silence = 1.2;
config.rule3_min_utterance_length = 300;
cag_option_context context;
char identifier;
const char *value;
cag_option_prepare(&context, options, CAG_ARRAY_SIZE(options), argc, argv);
while (cag_option_fetch(&context)) {
identifier = cag_option_get(&context);
value = cag_option_get_value(&context);
switch (identifier) {
case 't': config.model_config.tokens = value; break;
case 'e': config.model_config.encoder = value; break;
case 'd': config.model_config.decoder = value; break;
case 'j': config.model_config.joiner = value; break;
case 'n': config.model_config.num_threads = atoi(value); break;
case 'p': config.model_config.provider = value; break;
case 'm': config.decoding_method = value; break;
default:
// do nothing as config already have valid default values
break;
}
}
SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);
SherpaOnnxDisplay *display = CreateDisplay(50);
int32_t segment_id = 0;
const char *wav_filename = argv[5];
const char *wav_filename = argv[context.index];
FILE *fp = fopen(wav_filename, "rb");
if (!fp) {
fprintf(stderr, "Failed to open %s\n", wav_filename);