diff --git a/c-api-examples/Makefile b/c-api-examples/Makefile index 14e38ad4..18ddda27 100644 --- a/c-api-examples/Makefile +++ b/c-api-examples/Makefile @@ -1,10 +1,10 @@ CUR_DIR :=$(shell pwd) -CFLAGS := -I ../ +CFLAGS := -I ../ -I ../build/_deps/cargs-src/include/ LDFLAGS := -L ../build/lib LDFLAGS += -L ../build/_deps/onnxruntime-src/lib -LDFLAGS += -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-native-fbank-core +LDFLAGS += -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-native-fbank-core -lcargs LDFLAGS += -Wl,-rpath,${CUR_DIR}/../build/lib LDFLAGS += -Wl,-rpath,${CUR_DIR}/../build/_deps/onnxruntime-src/lib diff --git a/c-api-examples/decode-file-c-api.c b/c-api-examples/decode-file-c-api.c index afcc3e4b..792b12d6 100644 --- a/c-api-examples/decode-file-c-api.c +++ b/c-api-examples/decode-file-c-api.c @@ -52,7 +52,21 @@ static struct cag_option options[] = { .access_name = "decoding-method", .value_name = "decoding-method", .description = - "Decoding method: greedy_search (default), modified_beam_search"}}; + "Decoding method: greedy_search (default), modified_beam_search"}, + {.identifier = 'f', + .access_letters = NULL, + .access_name = "hotwords-file", + .value_name = "hotwords-file", + .description = "The file containing hotwords, one words/phrases per line, " + "and for each phrase the bpe/cjkchar are separated by a " + "space. For example: ▁HE LL O ▁WORLD, 你 好 世 界"}, + {.identifier = 's', + .access_letters = NULL, + .access_name = "hotwords-score", + .value_name = "hotwords-score", + .description = "The bonus score for each token in hotwords. Used only " + "when decoding_method is modified_beam_search"}, +}; const char *kUsage = "\n" @@ -130,6 +144,12 @@ int32_t main(int32_t argc, char *argv[]) { case 'm': config.decoding_method = value; break; + case 'f': + config.hotwords_file = value; + break; + case 's': + config.hotwords_score = atof(value); + break; case 'h': { fprintf(stderr, "%s\n", kUsage); exit(0); diff --git a/c-api-examples/run.sh b/c-api-examples/run.sh index ac986cf9..45ebaaa1 100755 --- a/c-api-examples/run.sh +++ b/c-api-examples/run.sh @@ -27,8 +27,22 @@ if [ ! -f ./decode-file-c-api ]; then fi ./decode-file-c-api \ - ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \ - ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ - ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ - ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ + --tokens=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \ + --encoder=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ + --decoder=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ + --joiner=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ + ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav + +# Run with hotwords + +echo "礼 拜 二" > hotwords.txt + +./decode-file-c-api \ + --tokens=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \ + --encoder=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ + --decoder=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ + --joiner=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ + --hotwords-file=hotwords.txt \ + --hotwords-score=1.5 \ + --decoding-method=modified_beam_search \ ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 7b3c55b1..07933bee 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -80,6 +80,10 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( recognizer_config.endpoint_config.rule3.min_utterance_length = SHERPA_ONNX_OR(config->rule3_min_utterance_length, 20); + recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, ""); + recognizer_config.hotwords_score = + SHERPA_ONNX_OR(config->hotwords_score, 1.5); + if (config->model_config.debug) { fprintf(stderr, "%s\n", recognizer_config.ToString().c_str()); } @@ -297,6 +301,10 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( recognizer_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4); + recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, ""); + recognizer_config.hotwords_score = + SHERPA_ONNX_OR(config->hotwords_score, 1.5); + if (config->model_config.debug) { fprintf(stderr, "%s\n", recognizer_config.ToString().c_str()); } diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 1a5c4dbf..5e514d3d 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -111,6 +111,12 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig { /// this value. /// Used only when enable_endpoint is not 0. float rule3_min_utterance_length; + + /// Path to the hotwords. + const char *hotwords_file; + + /// Bonus score for each token in hotwords. + float hotwords_score; } SherpaOnnxOnlineRecognizerConfig; SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult { @@ -335,6 +341,12 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig { const char *decoding_method; int32_t max_active_paths; + + /// Path to the hotwords. + const char *hotwords_file; + + /// Bonus score for each token in hotwords. + float hotwords_score; } SherpaOnnxOfflineRecognizerConfig; SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizer