diff --git a/CMakeLists.txt b/CMakeLists.txt index 69c417d6..845d300e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.4.4") +set(SHERPA_ONNX_VERSION "1.4.5") # Disable warning about # diff --git a/c-api-examples/decode-file-c-api.c b/c-api-examples/decode-file-c-api.c index 07672fb6..963b40c4 100644 --- a/c-api-examples/decode-file-c-api.c +++ b/c-api-examples/decode-file-c-api.c @@ -5,59 +5,54 @@ // This file shows how to use sherpa-onnx C API // to decode a file. -#include "cargs.h" #include #include #include +#include "cargs.h" #include "sherpa-onnx/c-api/c-api.h" static struct cag_option options[] = { - { - .identifier = 't', - .access_letters = NULL, - .access_name = "tokens", - .value_name = "tokens", - .description = "Tokens file" - }, { - .identifier = 'e', - .access_letters = NULL, - .access_name = "encoder", - .value_name = "encoder", - .description = "Encoder ONNX file" - }, { - .identifier = 'd', - .access_letters = NULL, - .access_name = "decoder", - .value_name = "decoder", - .description = "Decoder ONNX file" - }, { - .identifier = 'j', - .access_letters = NULL, - .access_name = "joiner", - .value_name = "joiner", - .description = "Joiner ONNX file" - }, { - .identifier = 'n', - .access_letters = NULL, - .access_name = "num-threads", - .value_name = "num-threads", - .description = "Number of threads" - }, { - .identifier = 'p', - .access_letters = NULL, - .access_name = "provider", - .value_name = "provider", - .description = "Provider: cpu (default), cuda, coreml" - }, { - .identifier = 'm', - .access_letters = NULL, - .access_name = "decoding-method", - .value_name = "decoding-method", - .description = - "Decoding method: greedy_search (default), modified_beam_search" - } -}; + {.identifier = 'h', + .access_letters = "h", + .access_name = "help", + .description = "Show help"}, + {.identifier = 't', + .access_letters = NULL, + .access_name = "tokens", + .value_name = "tokens", + .description = "Tokens file"}, + {.identifier = 'e', + .access_letters = NULL, + .access_name = "encoder", + .value_name = "encoder", + .description = "Encoder ONNX file"}, + {.identifier = 'd', + .access_letters = NULL, + .access_name = "decoder", + .value_name = "decoder", + .description = "Decoder ONNX file"}, + {.identifier = 'j', + .access_letters = NULL, + .access_name = "joiner", + .value_name = "joiner", + .description = "Joiner ONNX file"}, + {.identifier = 'n', + .access_letters = NULL, + .access_name = "num-threads", + .value_name = "num-threads", + .description = "Number of threads"}, + {.identifier = 'p', + .access_letters = NULL, + .access_name = "provider", + .value_name = "provider", + .description = "Provider: cpu (default), cuda, coreml"}, + {.identifier = 'm', + .access_letters = NULL, + .access_name = "decoding-method", + .value_name = "decoding-method", + .description = + "Decoding method: greedy_search (default), modified_beam_search"}}; const char *kUsage = "\n" @@ -67,6 +62,7 @@ const char *kUsage = " --encoder=/path/to/encoder.onnx \\\n" " --decoder=/path/to/decoder.onnx \\\n" " --joiner=/path/to/joiner.onnx \\\n" + " --provider=cpu \\\n" " /path/to/foo.wav\n" "\n\n" "Default num_threads is 1.\n" @@ -77,6 +73,11 @@ const char *kUsage = "for a list of pre-trained models to download.\n"; int32_t main(int32_t argc, char *argv[]) { + if (argc < 6) { + fprintf(stderr, "%s\n", kUsage); + exit(0); + } + SherpaOnnxOnlineRecognizerConfig config; config.model_config.debug = 0; @@ -105,19 +106,38 @@ int32_t main(int32_t argc, char *argv[]) { identifier = cag_option_get(&context); value = cag_option_get_value(&context); switch (identifier) { - case 't': config.model_config.tokens = value; break; - case 'e': config.model_config.encoder = value; break; - case 'd': config.model_config.decoder = value; break; - case 'j': config.model_config.joiner = value; break; - case 'n': config.model_config.num_threads = atoi(value); break; - case 'p': config.model_config.provider = value; break; - case 'm': config.decoding_method = value; break; - default: + case 't': + config.model_config.tokens = value; + break; + case 'e': + config.model_config.encoder = value; + break; + case 'd': + config.model_config.decoder = value; + break; + case 'j': + config.model_config.joiner = value; + break; + case 'n': + config.model_config.num_threads = atoi(value); + break; + case 'p': + config.model_config.provider = value; + break; + case 'm': + config.decoding_method = value; + break; + case 'h': { + fprintf(stderr, "%s\n", kUsage); + exit(0); + break; + } + default: // do nothing as config already have valid default values break; } } - + SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config); SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer); diff --git a/dotnet-examples/online-decode-files/Program.cs b/dotnet-examples/online-decode-files/Program.cs index f37ad2bc..8fa41920 100644 --- a/dotnet-examples/online-decode-files/Program.cs +++ b/dotnet-examples/online-decode-files/Program.cs @@ -20,6 +20,9 @@ class OnlineDecodeFiles [Option(Required = true, HelpText = "Path to tokens.txt")] public string Tokens { get; set; } + [Option(Required = false, Default = "cpu", HelpText = "Provider, e.g., cpu, coreml")] + public string Provider { get; set; } + [Option(Required = true, HelpText = "Path to encoder.onnx")] public string Encoder { get; set; } @@ -124,6 +127,7 @@ to download pre-trained streaming models. config.TransducerModelConfig.Decoder = options.Decoder; config.TransducerModelConfig.Joiner = options.Joiner; config.TransducerModelConfig.Tokens = options.Tokens; + config.TransducerModelConfig.Provider = options.Provider; config.TransducerModelConfig.NumThreads = options.NumThreads; config.TransducerModelConfig.Debug = options.Debug ? 1 : 0; diff --git a/scripts/dotnet/online.cs b/scripts/dotnet/online.cs index a51b72e3..d86d1520 100644 --- a/scripts/dotnet/online.cs +++ b/scripts/dotnet/online.cs @@ -23,6 +23,7 @@ namespace SherpaOnnx Joiner = ""; Tokens = ""; NumThreads = 1; + Provider = "cpu"; Debug = 0; } [MarshalAs(UnmanagedType.LPStr)] @@ -40,6 +41,9 @@ namespace SherpaOnnx /// Number of threads used to run the neural network model public int NumThreads; + [MarshalAs(UnmanagedType.LPStr)] + public string Provider; + /// true to print debug information of the model public int Debug; }