From 7f72c13d9aa0f6bea5337b9f5615a4666b51a007 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 3 Mar 2023 12:10:59 +0800 Subject: [PATCH] Code refactoring (#74) * Don't reset model state and feature extractor on endpointing * support passing decoding_method from commandline * Add modified_beam_search to Python API * fix C API example * Fix style issues --- .gitignore | 1 + c-api-examples/decode-file-c-api.c | 34 +- ffmpeg-examples/run.sh | 17 +- ffmpeg-examples/sherpa-onnx-ffmpeg.c | 609 +++++++++--------- python-api-examples/decode-file.py | 21 +- ...from-microphone-with-endpoint-detection.py | 17 +- .../speech-recognition-from-microphone.py | 13 +- sherpa-onnx/c-api/c-api.cc | 20 + sherpa-onnx/c-api/c-api.h | 19 + sherpa-onnx/csrc/CMakeLists.txt | 3 +- sherpa-onnx/csrc/display.h | 15 +- sherpa-onnx/csrc/features.cc | 7 +- sherpa-onnx/csrc/features.h | 1 + sherpa-onnx/csrc/hypothesis.h | 4 +- sherpa-onnx/csrc/online-recognizer.cc | 20 +- sherpa-onnx/csrc/online-recognizer.h | 13 +- sherpa-onnx/csrc/online-stream.cc | 14 +- sherpa-onnx/csrc/online-stream.h | 2 +- sherpa-onnx/csrc/online-transducer-decoder.cc | 60 ++ sherpa-onnx/csrc/online-transducer-decoder.h | 19 + ...online-transducer-greedy-search-decoder.cc | 42 +- ...transducer-modified-beam-search-decoder.cc | 46 +- ...-transducer-modified-beam-search-decoder.h | 2 + sherpa-onnx/csrc/sherpa-onnx-alsa.cc | 12 +- sherpa-onnx/csrc/sherpa-onnx-microphone.cc | 12 +- sherpa-onnx/csrc/sherpa-onnx.cc | 16 +- sherpa-onnx/python/csrc/CMakeLists.txt | 7 +- sherpa-onnx/python/csrc/display.cc | 18 + sherpa-onnx/python/csrc/display.h | 16 + sherpa-onnx/python/csrc/features.cc | 6 +- sherpa-onnx/python/csrc/online-recognizer.cc | 7 +- sherpa-onnx/python/csrc/sherpa-onnx.cc | 3 + sherpa-onnx/python/sherpa_onnx/__init__.py | 8 +- .../python/sherpa_onnx/online_recognizer.py | 14 + 34 files changed, 744 insertions(+), 374 deletions(-) create mode 100644 sherpa-onnx/csrc/online-transducer-decoder.cc create mode 100644 sherpa-onnx/python/csrc/display.cc create mode 100644 sherpa-onnx/python/csrc/display.h diff --git a/.gitignore b/.gitignore index bded3d9a..43e3a0af 100644 --- a/.gitignore +++ b/.gitignore @@ -36,4 +36,5 @@ tokens.txt *.onnx log.txt tags +run-decode-file-python.sh android/SherpaOnnx/app/src/main/assets/ diff --git a/c-api-examples/decode-file-c-api.c b/c-api-examples/decode-file-c-api.c index f7a7d2e1..0b67e1d5 100644 --- a/c-api-examples/decode-file-c-api.c +++ b/c-api-examples/decode-file-c-api.c @@ -19,14 +19,16 @@ const char *kUsage = " /path/to/encoder.onnx \\\n" " /path/to/decoder.onnx \\\n" " /path/to/joiner.onnx \\\n" - " /path/to/foo.wav [num_threads]\n" + " /path/to/foo.wav [num_threads [decoding_method]]\n" "\n\n" + "Default num_threads is 1.\n" + "Valid decoding_method: greedy_search (default), modified_beam_search\n\n" "Please refer to \n" "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n" "for a list of pre-trained models to download.\n"; int32_t main(int32_t argc, char *argv[]) { - if (argc < 6 || argc > 7) { + if (argc < 6 || argc > 8) { fprintf(stderr, "%s\n", kUsage); return -1; } @@ -36,13 +38,20 @@ int32_t main(int32_t argc, char *argv[]) { config.model_config.decoder = argv[3]; config.model_config.joiner = argv[4]; - int32_t num_threads = 4; + int32_t num_threads = 1; if (argc == 7 && atoi(argv[6]) > 0) { num_threads = atoi(argv[6]); } config.model_config.num_threads = num_threads; config.model_config.debug = 0; + config.decoding_method = "greedy_search"; + if (argc == 8) { + 
config.decoding_method = argv[7];
+  }
+
+  config.max_active_paths = 4;
+
   config.feat_config.sample_rate = 16000;
   config.feat_config.feature_dim = 80;
@@ -54,6 +63,9 @@ int32_t main(int32_t argc, char *argv[]) {
   SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
   SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);
 
+  SherpaOnnxDisplay *display = CreateDisplay(50);
+  int32_t segment_id = 0;
+
   const char *wav_filename = argv[5];
   FILE *fp = fopen(wav_filename, "rb");
   if (!fp) {
@@ -84,9 +96,18 @@ int32_t main(int32_t argc, char *argv[]) {
     SherpaOnnxOnlineRecognizerResult *r =
         GetOnlineStreamResult(recognizer, stream);
+
     if (strlen(r->text)) {
-      fprintf(stderr, "%s\n", r->text);
+      SherpaOnnxPrint(display, segment_id, r->text);
     }
+
+    if (IsEndpoint(recognizer, stream)) {
+      if (strlen(r->text)) {
+        ++segment_id;
+      }
+      Reset(recognizer, stream);
+    }
+
     DestroyOnlineRecognizerResult(r);
   }
 }
@@ -103,14 +124,17 @@ int32_t main(int32_t argc, char *argv[]) {
   SherpaOnnxOnlineRecognizerResult *r =
       GetOnlineStreamResult(recognizer, stream);
+
   if (strlen(r->text)) {
-    fprintf(stderr, "%s\n", r->text);
+    SherpaOnnxPrint(display, segment_id, r->text);
   }
 
   DestroyOnlineRecognizerResult(r);
 
+  DestroyDisplay(display);
   DestoryOnlineStream(stream);
   DestroyOnlineRecognizer(recognizer);
 
+  fprintf(stderr, "\n");
   return 0;
 }
diff --git a/ffmpeg-examples/run.sh b/ffmpeg-examples/run.sh
index 7e07747d..a0651e00 100755
--- a/ffmpeg-examples/run.sh
+++ b/ffmpeg-examples/run.sh
@@ -26,12 +26,17 @@ if [ ! -f ./sherpa-onnx-ffmpeg ]; then
   make
 fi
 
-../ffmpeg-examples/sherpa-onnx-ffmpeg \
-  ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
-  ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
-  ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
-  ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
-  ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/4.wav
+for method in greedy_search modified_beam_search; do
+  echo "test method: $method"
+  ../ffmpeg-examples/sherpa-onnx-ffmpeg \
+    ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
+    ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
+    ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
+    ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
+    ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav \
+    2 \
+    $method
+done
 
 echo "Decoding a URL"
diff --git a/ffmpeg-examples/sherpa-onnx-ffmpeg.c b/ffmpeg-examples/sherpa-onnx-ffmpeg.c
index 3a8ce271..96d5f35c 100644
--- a/ffmpeg-examples/sherpa-onnx-ffmpeg.c
+++ b/ffmpeg-examples/sherpa-onnx-ffmpeg.c
@@ -7,7 +7,6 @@
 #include "sherpa-onnx/c-api/c-api.h"
 
-
 /*
  * Copyright (c) 2010 Nicolas George
  * Copyright (c) 2011 Stefano Sabatini
@@ -43,14 +42,15 @@
 #include <unistd.h>
 extern "C" {
 #include <libavcodec/avcodec.h>
-#include <libavformat/avformat.h>
 #include <libavfilter/buffersink.h>
 #include <libavfilter/buffersrc.h>
+#include <libavformat/avformat.h>
 #include <libavutil/channel_layout.h>
 #include <libavutil/opt.h>
 }
 
-static const char *filter_descr = "aresample=16000,aformat=sample_fmts=s16:channel_layouts=mono";
+static const char *filter_descr =
+    "aresample=16000,aformat=sample_fmts=s16:channel_layouts=mono";
 
 static AVFormatContext *fmt_ctx;
 static AVCodecContext *dec_ctx;
@@ -59,308 +59,172 @@ AVFilterContext *buffersrc_ctx;
 AVFilterGraph *filter_graph;
 static int audio_stream_index = -1;
 
-static int open_input_file(const char *filename)
-{
-    const AVCodec *dec;
- int ret; +static int open_input_file(const char *filename) { + const AVCodec *dec; + int ret; - if ((ret = avformat_open_input(&fmt_ctx, filename, NULL, NULL)) < 0) { - av_log(NULL, AV_LOG_ERROR, "Cannot open input file %s\n", filename); - return ret; - } + if ((ret = avformat_open_input(&fmt_ctx, filename, NULL, NULL)) < 0) { + av_log(NULL, AV_LOG_ERROR, "Cannot open input file %s\n", filename); + return ret; + } - if ((ret = avformat_find_stream_info(fmt_ctx, NULL)) < 0) { - av_log(NULL, AV_LOG_ERROR, "Cannot find stream information\n"); - return ret; - } + if ((ret = avformat_find_stream_info(fmt_ctx, NULL)) < 0) { + av_log(NULL, AV_LOG_ERROR, "Cannot find stream information\n"); + return ret; + } - /* select the audio stream */ - ret = av_find_best_stream(fmt_ctx, AVMEDIA_TYPE_AUDIO, -1, -1, &dec, 0); - if (ret < 0) { - av_log(NULL, AV_LOG_ERROR, "Cannot find an audio stream in the input file\n"); - return ret; - } - audio_stream_index = ret; + /* select the audio stream */ + ret = av_find_best_stream(fmt_ctx, AVMEDIA_TYPE_AUDIO, -1, -1, &dec, 0); + if (ret < 0) { + av_log(NULL, AV_LOG_ERROR, + "Cannot find an audio stream in the input file\n"); + return ret; + } + audio_stream_index = ret; - /* create decoding context */ - dec_ctx = avcodec_alloc_context3(dec); - if (!dec_ctx) - return AVERROR(ENOMEM); - avcodec_parameters_to_context(dec_ctx, fmt_ctx->streams[audio_stream_index]->codecpar); + /* create decoding context */ + dec_ctx = avcodec_alloc_context3(dec); + if (!dec_ctx) return AVERROR(ENOMEM); + avcodec_parameters_to_context(dec_ctx, + fmt_ctx->streams[audio_stream_index]->codecpar); - /* init the audio decoder */ - if ((ret = avcodec_open2(dec_ctx, dec, NULL)) < 0) { - av_log(NULL, AV_LOG_ERROR, "Cannot open audio decoder\n"); - return ret; - } + /* init the audio decoder */ + if ((ret = avcodec_open2(dec_ctx, dec, NULL)) < 0) { + av_log(NULL, AV_LOG_ERROR, "Cannot open audio decoder\n"); + return ret; + } - return 0; + return 0; } -static int init_filters(const char *filters_descr) -{ - char args[512]; - int ret = 0; - const AVFilter *abuffersrc = avfilter_get_by_name("abuffer"); - const AVFilter *abuffersink = avfilter_get_by_name("abuffersink"); - AVFilterInOut *outputs = avfilter_inout_alloc(); - AVFilterInOut *inputs = avfilter_inout_alloc(); - static const enum AVSampleFormat out_sample_fmts[] = { AV_SAMPLE_FMT_S16, AV_SAMPLE_FMT_NONE }; - static const int out_sample_rates[] = { 16000, -1 }; - const AVFilterLink *outlink; - AVRational time_base = fmt_ctx->streams[audio_stream_index]->time_base; +static int init_filters(const char *filters_descr) { + char args[512]; + int ret = 0; + const AVFilter *abuffersrc = avfilter_get_by_name("abuffer"); + const AVFilter *abuffersink = avfilter_get_by_name("abuffersink"); + AVFilterInOut *outputs = avfilter_inout_alloc(); + AVFilterInOut *inputs = avfilter_inout_alloc(); + static const enum AVSampleFormat out_sample_fmts[] = {AV_SAMPLE_FMT_S16, + AV_SAMPLE_FMT_NONE}; + static const int out_sample_rates[] = {16000, -1}; + const AVFilterLink *outlink; + AVRational time_base = fmt_ctx->streams[audio_stream_index]->time_base; - filter_graph = avfilter_graph_alloc(); - if (!outputs || !inputs || !filter_graph) { - ret = AVERROR(ENOMEM); - goto end; - } + filter_graph = avfilter_graph_alloc(); + if (!outputs || !inputs || !filter_graph) { + ret = AVERROR(ENOMEM); + goto end; + } - /* buffer audio source: the decoded frames from the decoder will be inserted here. 
*/ - if (dec_ctx->ch_layout.order == AV_CHANNEL_ORDER_UNSPEC) - av_channel_layout_default(&dec_ctx->ch_layout, dec_ctx->ch_layout.nb_channels); - ret = snprintf(args, sizeof(args), - "time_base=%d/%d:sample_rate=%d:sample_fmt=%s:channel_layout=", - time_base.num, time_base.den, dec_ctx->sample_rate, - av_get_sample_fmt_name(dec_ctx->sample_fmt)); - av_channel_layout_describe(&dec_ctx->ch_layout, args + ret, sizeof(args) - ret); - ret = avfilter_graph_create_filter(&buffersrc_ctx, abuffersrc, "in", - args, NULL, filter_graph); - if (ret < 0) { - av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer source\n"); - goto end; - } + /* buffer audio source: the decoded frames from the decoder will be inserted + * here. */ + if (dec_ctx->ch_layout.order == AV_CHANNEL_ORDER_UNSPEC) + av_channel_layout_default(&dec_ctx->ch_layout, + dec_ctx->ch_layout.nb_channels); + ret = snprintf(args, sizeof(args), + "time_base=%d/%d:sample_rate=%d:sample_fmt=%s:channel_layout=", + time_base.num, time_base.den, dec_ctx->sample_rate, + av_get_sample_fmt_name(dec_ctx->sample_fmt)); + av_channel_layout_describe(&dec_ctx->ch_layout, args + ret, + sizeof(args) - ret); + ret = avfilter_graph_create_filter(&buffersrc_ctx, abuffersrc, "in", args, + NULL, filter_graph); + if (ret < 0) { + av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer source\n"); + goto end; + } - /* buffer audio sink: to terminate the filter chain. */ - ret = avfilter_graph_create_filter(&buffersink_ctx, abuffersink, "out", - NULL, NULL, filter_graph); - if (ret < 0) { - av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer sink\n"); - goto end; - } + /* buffer audio sink: to terminate the filter chain. */ + ret = avfilter_graph_create_filter(&buffersink_ctx, abuffersink, "out", NULL, + NULL, filter_graph); + if (ret < 0) { + av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer sink\n"); + goto end; + } - ret = av_opt_set_int_list(buffersink_ctx, "sample_fmts", out_sample_fmts, -1, - AV_OPT_SEARCH_CHILDREN); - if (ret < 0) { - av_log(NULL, AV_LOG_ERROR, "Cannot set output sample format\n"); - goto end; - } + ret = av_opt_set_int_list(buffersink_ctx, "sample_fmts", out_sample_fmts, -1, + AV_OPT_SEARCH_CHILDREN); + if (ret < 0) { + av_log(NULL, AV_LOG_ERROR, "Cannot set output sample format\n"); + goto end; + } - ret = av_opt_set(buffersink_ctx, "ch_layouts", "mono", - AV_OPT_SEARCH_CHILDREN); - if (ret < 0) { - av_log(NULL, AV_LOG_ERROR, "Cannot set output channel layout\n"); - goto end; - } + ret = + av_opt_set(buffersink_ctx, "ch_layouts", "mono", AV_OPT_SEARCH_CHILDREN); + if (ret < 0) { + av_log(NULL, AV_LOG_ERROR, "Cannot set output channel layout\n"); + goto end; + } - ret = av_opt_set_int_list(buffersink_ctx, "sample_rates", out_sample_rates, -1, - AV_OPT_SEARCH_CHILDREN); - if (ret < 0) { - av_log(NULL, AV_LOG_ERROR, "Cannot set output sample rate\n"); - goto end; - } + ret = av_opt_set_int_list(buffersink_ctx, "sample_rates", out_sample_rates, + -1, AV_OPT_SEARCH_CHILDREN); + if (ret < 0) { + av_log(NULL, AV_LOG_ERROR, "Cannot set output sample rate\n"); + goto end; + } - /* - * Set the endpoints for the filter graph. The filter_graph will - * be linked to the graph described by filters_descr. - */ + /* + * Set the endpoints for the filter graph. The filter_graph will + * be linked to the graph described by filters_descr. + */ - /* - * The buffer source output must be connected to the input pad of - * the first filter described by filters_descr; since the first - * filter input label is not specified, it is set to "in" by - * default. 
-     */
-    outputs->name       = av_strdup("in");
-    outputs->filter_ctx = buffersrc_ctx;
-    outputs->pad_idx    = 0;
-    outputs->next       = NULL;
+  /*
+   * The buffer source output must be connected to the input pad of
+   * the first filter described by filters_descr; since the first
+   * filter input label is not specified, it is set to "in" by
+   * default.
+   */
+  outputs->name = av_strdup("in");
+  outputs->filter_ctx = buffersrc_ctx;
+  outputs->pad_idx = 0;
+  outputs->next = NULL;
 
-    /*
-     * The buffer sink input must be connected to the output pad of
-     * the last filter described by filters_descr; since the last
-     * filter output label is not specified, it is set to "out" by
-     * default.
-     */
-    inputs->name       = av_strdup("out");
-    inputs->filter_ctx = buffersink_ctx;
-    inputs->pad_idx    = 0;
-    inputs->next       = NULL;
+  /*
+   * The buffer sink input must be connected to the output pad of
+   * the last filter described by filters_descr; since the last
+   * filter output label is not specified, it is set to "out" by
+   * default.
+   */
+  inputs->name = av_strdup("out");
+  inputs->filter_ctx = buffersink_ctx;
+  inputs->pad_idx = 0;
+  inputs->next = NULL;
 
-    if ((ret = avfilter_graph_parse_ptr(filter_graph, filters_descr,
-                                        &inputs, &outputs, NULL)) < 0)
-        goto end;
+  if ((ret = avfilter_graph_parse_ptr(filter_graph, filters_descr, &inputs,
+                                      &outputs, NULL)) < 0)
+    goto end;
 
-    if ((ret = avfilter_graph_config(filter_graph, NULL)) < 0)
-        goto end;
+  if ((ret = avfilter_graph_config(filter_graph, NULL)) < 0) goto end;
 
-    /* Print summary of the sink buffer
-     * Note: args buffer is reused to store channel layout string */
-    outlink = buffersink_ctx->inputs[0];
-    av_channel_layout_describe(&outlink->ch_layout, args, sizeof(args));
-    av_log(NULL, AV_LOG_INFO, "Output: srate:%dHz fmt:%s chlayout:%s\n",
-           (int)outlink->sample_rate,
-           (char *)av_x_if_null(av_get_sample_fmt_name((AVSampleFormat)outlink->format), "?"),
-           args);
+  /* Print summary of the sink buffer
+   * Note: args buffer is reused to store channel layout string */
+  outlink = buffersink_ctx->inputs[0];
+  av_channel_layout_describe(&outlink->ch_layout, args, sizeof(args));
+  av_log(NULL, AV_LOG_INFO, "Output: srate:%dHz fmt:%s chlayout:%s\n",
+         (int)outlink->sample_rate,
+         (char *)av_x_if_null(
+             av_get_sample_fmt_name((AVSampleFormat)outlink->format), "?"),
+         args);
 
 end:
-    avfilter_inout_free(&inputs);
-    avfilter_inout_free(&outputs);
+  avfilter_inout_free(&inputs);
+  avfilter_inout_free(&outputs);
 
-    return ret;
+  return ret;
 }
 
-static void sherpa_decode_frame(const AVFrame *frame, SherpaOnnxOnlineRecognizer *recognizer,
-                                SherpaOnnxOnlineStream* stream)
-{
+static void sherpa_decode_frame(const AVFrame *frame,
+                                SherpaOnnxOnlineRecognizer *recognizer,
+                                SherpaOnnxOnlineStream *stream,
+                                SherpaOnnxDisplay *display,
+                                int32_t *segment_id) {
#define N 3200  // 0.2 s.
Sample rate is fixed to 16 kHz - static float samples[N]; - static int nb_samples = 0; - const int16_t *p = (int16_t*)frame->data[0]; - - if (frame->nb_samples + nb_samples > N) { - AcceptWaveform(stream, 16000, samples, nb_samples); - while (IsOnlineStreamReady(recognizer, stream)) { - DecodeOnlineStream(recognizer, stream); - } - - - if (IsEndpoint(recognizer, stream)) { - SherpaOnnxOnlineRecognizerResult *r = - GetOnlineStreamResult(recognizer, stream); - if (strlen(r->text)) { - fprintf(stderr, "%s\n", r->text); - } - DestroyOnlineRecognizerResult(r); - - Reset(recognizer, stream); - } - nb_samples = 0; - } - - for (int i = 0; i < frame->nb_samples; i++) { - samples[nb_samples++] = p[i] / 32768.; - } -} - -static inline char *__av_err2str(int errnum) -{ - static char str[AV_ERROR_MAX_STRING_SIZE]; - memset(str, 0, sizeof(str)); - return av_make_error_string(str, AV_ERROR_MAX_STRING_SIZE, errnum); -} - -int main(int argc, char **argv) -{ - int ret; - int num_threads = 4; - AVPacket *packet = av_packet_alloc(); - AVFrame *frame = av_frame_alloc(); - AVFrame *filt_frame = av_frame_alloc(); - const char *kUsage = - "\n" - "Usage:\n" - " ./sherpa-onnx-ffmpeg \\\n" - " /path/to/tokens.txt \\\n" - " /path/to/encoder.onnx\\\n" - " /path/to/decoder.onnx\\\n" - " /path/to/joiner.onnx\\\n" - " /path/to/foo.wav [num_threads]" - "\n\n" - "Please refer to \n" - "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n" - "for a list of pre-trained models to download.\n"; - - - if (!packet || !frame || !filt_frame) { - fprintf(stderr, "Could not allocate frame or packet\n"); - exit(1); - } - - if (argc < 6 || argc > 7) { - fprintf(stderr, "%s\n", kUsage); - return -1; - } - - SherpaOnnxOnlineRecognizerConfig config; - config.model_config.tokens = argv[1]; - config.model_config.encoder = argv[2]; - config.model_config.decoder = argv[3]; - config.model_config.joiner = argv[4]; - - if (argc == 7 && atoi(argv[6]) > 0) { - num_threads = atoi(argv[6]); - } - config.model_config.num_threads = num_threads; - config.model_config.debug = 0; - - config.feat_config.sample_rate = 16000; - config.feat_config.feature_dim = 80; - - config.enable_endpoint = 1; - config.rule1_min_trailing_silence = 2.4; - config.rule2_min_trailing_silence = 1.2; - config.rule3_min_utterance_length = 300; - - SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config); - SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer); - - if ((ret = open_input_file(argv[5])) < 0) - exit(1); - - if ((ret = init_filters(filter_descr)) < 0) - exit(1); - - /* read all packets */ - while (1) { - if ((ret = av_read_frame(fmt_ctx, packet)) < 0) - break; - - if (packet->stream_index == audio_stream_index) { - ret = avcodec_send_packet(dec_ctx, packet); - if (ret < 0) { - av_log(NULL, AV_LOG_ERROR, "Error while sending a packet to the decoder\n"); - break; - } - - while (ret >= 0) { - ret = avcodec_receive_frame(dec_ctx, frame); - if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { - break; - } else if (ret < 0) { - av_log(NULL, AV_LOG_ERROR, "Error while receiving a frame from the decoder\n"); - exit(1); - } - - if (ret >= 0) { - /* push the audio data from decoded frame into the filtergraph */ - if (av_buffersrc_add_frame_flags(buffersrc_ctx, frame, AV_BUFFERSRC_FLAG_KEEP_REF) < 0) { - av_log(NULL, AV_LOG_ERROR, "Error while feeding the audio filtergraph\n"); - break; - } - - /* pull filtered audio from the filtergraph */ - while (1) { - ret = av_buffersink_get_frame(buffersink_ctx, filt_frame); - if (ret == 
AVERROR(EAGAIN) || ret == AVERROR_EOF) - break; - if (ret < 0) - exit(1); - sherpa_decode_frame(filt_frame, recognizer, stream); - av_frame_unref(filt_frame); - } - av_frame_unref(frame); - } - } - } - av_packet_unref(packet); - } - - // add some tail padding - float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate - AcceptWaveform(stream, 16000, tail_paddings, 4800); - InputFinished(stream); + static float samples[N]; + static int nb_samples = 0; + const int16_t *p = (int16_t *)frame->data[0]; + if (frame->nb_samples + nb_samples > N) { + AcceptWaveform(stream, 16000, samples, nb_samples); while (IsOnlineStreamReady(recognizer, stream)) { DecodeOnlineStream(recognizer, stream); } @@ -368,25 +232,180 @@ int main(int argc, char **argv) SherpaOnnxOnlineRecognizerResult *r = GetOnlineStreamResult(recognizer, stream); if (strlen(r->text)) { - fprintf(stderr, "%s\n", r->text); + SherpaOnnxPrint(display, *segment_id, r->text); + } + + if (IsEndpoint(recognizer, stream)) { + if (strlen(r->text)) { + ++*segment_id; + } + Reset(recognizer, stream); } DestroyOnlineRecognizerResult(r); + nb_samples = 0; + } - DestoryOnlineStream(stream); - DestroyOnlineRecognizer(recognizer); - - avfilter_graph_free(&filter_graph); - avcodec_free_context(&dec_ctx); - avformat_close_input(&fmt_ctx); - av_packet_free(&packet); - av_frame_free(&frame); - av_frame_free(&filt_frame); - - if (ret < 0 && ret != AVERROR_EOF) { - fprintf(stderr, "Error occurred: %s\n", __av_err2str(ret)); - exit(1); - } - - return 0; + for (int i = 0; i < frame->nb_samples; i++) { + samples[nb_samples++] = p[i] / 32768.; + } +} + +static inline char *__av_err2str(int errnum) { + static char str[AV_ERROR_MAX_STRING_SIZE]; + memset(str, 0, sizeof(str)); + return av_make_error_string(str, AV_ERROR_MAX_STRING_SIZE, errnum); +} + +int main(int argc, char **argv) { + int ret; + int num_threads = 1; + AVPacket *packet = av_packet_alloc(); + AVFrame *frame = av_frame_alloc(); + AVFrame *filt_frame = av_frame_alloc(); + const char *kUsage = + "\n" + "Usage:\n" + " ./sherpa-onnx-ffmpeg \\\n" + " /path/to/tokens.txt \\\n" + " /path/to/encoder.onnx\\\n" + " /path/to/decoder.onnx\\\n" + " /path/to/joiner.onnx\\\n" + " /path/to/foo.wav [num_threads [decoding_method]]" + "\n\n" + "Default num_threads is 1.\n" + "Valid decoding_method: greedy_search (default), modified_beam_search\n\n" + "Please refer to \n" + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n" + "for a list of pre-trained models to download.\n"; + + if (!packet || !frame || !filt_frame) { + fprintf(stderr, "Could not allocate frame or packet\n"); + exit(1); + } + + if (argc < 6 || argc > 8) { + fprintf(stderr, "%s\n", kUsage); + return -1; + } + + SherpaOnnxOnlineRecognizerConfig config; + config.model_config.tokens = argv[1]; + config.model_config.encoder = argv[2]; + config.model_config.decoder = argv[3]; + config.model_config.joiner = argv[4]; + + if (argc == 7 && atoi(argv[6]) > 0) { + num_threads = atoi(argv[6]); + } + + config.model_config.num_threads = num_threads; + config.model_config.debug = 0; + + config.feat_config.sample_rate = 16000; + config.feat_config.feature_dim = 80; + + config.decoding_method = "greedy_search"; + if (argc == 8) { + config.decoding_method = argv[7]; + } + + config.max_active_paths = 4; + + config.enable_endpoint = 1; + config.rule1_min_trailing_silence = 2.4; + config.rule2_min_trailing_silence = 1.2; + config.rule3_min_utterance_length = 300; + + SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config); + 
SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer); + SherpaOnnxDisplay *display = CreateDisplay(50); + int32_t segment_id = 0; + + if ((ret = open_input_file(argv[5])) < 0) exit(1); + + if ((ret = init_filters(filter_descr)) < 0) exit(1); + + /* read all packets */ + while (1) { + if ((ret = av_read_frame(fmt_ctx, packet)) < 0) break; + + if (packet->stream_index == audio_stream_index) { + ret = avcodec_send_packet(dec_ctx, packet); + if (ret < 0) { + av_log(NULL, AV_LOG_ERROR, + "Error while sending a packet to the decoder\n"); + break; + } + + while (ret >= 0) { + ret = avcodec_receive_frame(dec_ctx, frame); + if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { + break; + } else if (ret < 0) { + av_log(NULL, AV_LOG_ERROR, + "Error while receiving a frame from the decoder\n"); + exit(1); + } + + if (ret >= 0) { + /* push the audio data from decoded frame into the filtergraph */ + if (av_buffersrc_add_frame_flags(buffersrc_ctx, frame, + AV_BUFFERSRC_FLAG_KEEP_REF) < 0) { + av_log(NULL, AV_LOG_ERROR, + "Error while feeding the audio filtergraph\n"); + break; + } + + /* pull filtered audio from the filtergraph */ + while (1) { + ret = av_buffersink_get_frame(buffersink_ctx, filt_frame); + if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) break; + if (ret < 0) exit(1); + sherpa_decode_frame(filt_frame, recognizer, stream, display, + &segment_id); + av_frame_unref(filt_frame); + } + av_frame_unref(frame); + } + } + } + av_packet_unref(packet); + } + + // add some tail padding + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate + AcceptWaveform(stream, 16000, tail_paddings, 4800); + InputFinished(stream); + + while (IsOnlineStreamReady(recognizer, stream)) { + DecodeOnlineStream(recognizer, stream); + } + + SherpaOnnxOnlineRecognizerResult *r = + GetOnlineStreamResult(recognizer, stream); + if (strlen(r->text)) { + SherpaOnnxPrint(display, segment_id, r->text); + } + + DestroyOnlineRecognizerResult(r); + + DestroyDisplay(display); + DestoryOnlineStream(stream); + DestroyOnlineRecognizer(recognizer); + + avfilter_graph_free(&filter_graph); + avcodec_free_context(&dec_ctx); + avformat_close_input(&fmt_ctx); + av_packet_free(&packet); + av_frame_free(&frame); + av_frame_free(&filt_frame); + + if (ret < 0 && ret != AVERROR_EOF) { + fprintf(stderr, "Error occurred: %s\n", __av_err2str(ret)); + exit(1); + } + fprintf(stderr, "\n"); + + return 0; } diff --git a/python-api-examples/decode-file.py b/python-api-examples/decode-file.py index 79e846a4..368499f8 100755 --- a/python-api-examples/decode-file.py +++ b/python-api-examples/decode-file.py @@ -53,6 +53,20 @@ def get_args(): help="Path to the joiner model", ) + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + parser.add_argument( "--wave-filename", type=str, @@ -65,7 +79,6 @@ def get_args(): def main(): sample_rate = 16000 - num_threads = 2 args = get_args() assert_file_exists(args.encoder) @@ -81,9 +94,10 @@ def main(): encoder=args.encoder, decoder=args.decoder, joiner=args.joiner, - num_threads=num_threads, + num_threads=args.num_threads, sample_rate=sample_rate, feature_dim=80, + decoding_method=args.decoding_method, ) with wave.open(args.wave_filename) as f: assert f.getframerate() == sample_rate, f.getframerate() @@ -119,7 +133,8 @@ def main(): end_time = time.time() 
elapsed_seconds = end_time - start_time rtf = elapsed_seconds / duration - print(f"num_threads: {num_threads}") + print(f"num_threads: {args.num_threads}") + print(f"decoding_method: {args.decoding_method}") print(f"Wave duration: {duration:.3f} s") print(f"Elapsed time: {elapsed_seconds:.3f} s") print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") diff --git a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py index 93571364..44d22549 100755 --- a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py +++ b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py @@ -60,10 +60,10 @@ def get_args(): ) parser.add_argument( - "--wave-filename", + "--decoding-method", type=str, - help="""Path to the wave filename. Must be 16 kHz, - mono with 16-bit samples""", + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", ) return parser.parse_args() @@ -83,17 +83,23 @@ def create_recognizer(): encoder=args.encoder, decoder=args.decoder, joiner=args.joiner, + num_threads=1, + sample_rate=16000, + feature_dim=80, enable_endpoint_detection=True, rule1_min_trailing_silence=2.4, rule2_min_trailing_silence=1.2, rule3_min_utterance_length=300, # it essentially disables this rule + decoding_method=args.decoding_method, + max_feature_vectors=100, # 1 second ) return recognizer def main(): - print("Started! Please speak") recognizer = create_recognizer() + print("Started! Please speak") + sample_rate = 16000 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms last_result = "" @@ -101,6 +107,7 @@ def main(): last_result = "" segment_id = 0 + display = sherpa_onnx.Display(max_word_per_line=30) with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: while True: samples, _ = s.read(samples_per_read) # a blocking read @@ -115,7 +122,7 @@ def main(): if result and (last_result != result): last_result = result - print(f"{segment_id}: {result}") + display.print(segment_id, result) if is_endpoint: if result: diff --git a/python-api-examples/speech-recognition-from-microphone.py b/python-api-examples/speech-recognition-from-microphone.py index dcc72f51..bca4f2b0 100755 --- a/python-api-examples/speech-recognition-from-microphone.py +++ b/python-api-examples/speech-recognition-from-microphone.py @@ -59,10 +59,10 @@ def get_args(): ) parser.add_argument( - "--wave-filename", + "--decoding-method", type=str, - help="""Path to the wave filename. 
Must be 16 kHz,
-    mono with 16-bit samples""",
+        default="greedy_search",
+        help="Valid values are greedy_search and modified_beam_search",
     )
 
     return parser.parse_args()
@@ -82,9 +82,11 @@ def create_recognizer():
         encoder=args.encoder,
         decoder=args.decoder,
         joiner=args.joiner,
-        num_threads=4,
+        num_threads=1,
         sample_rate=16000,
         feature_dim=80,
+        decoding_method=args.decoding_method,
+        max_feature_vectors=100,  # 1 second
     )
     return recognizer
 
@@ -96,6 +98,7 @@ def main():
     samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
     last_result = ""
     stream = recognizer.create_stream()
+    display = sherpa_onnx.Display(max_word_per_line=40)
     with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
         while True:
             samples, _ = s.read(samples_per_read)  # a blocking read
@@ -106,7 +109,7 @@ def main():
             result = recognizer.get_result(stream)
             if last_result != result:
                 last_result = result
-                print(result)
+                display.print(-1, result)
 
 
 if __name__ == "__main__":
diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc
index 2c08a91d..79427bce 100644
--- a/sherpa-onnx/c-api/c-api.cc
+++ b/sherpa-onnx/c-api/c-api.cc
@@ -9,6 +9,7 @@
 #include <memory>
 #include <utility>
 
+#include "sherpa-onnx/csrc/display.h"
 #include "sherpa-onnx/csrc/online-recognizer.h"
 
 struct SherpaOnnxOnlineRecognizer {
@@ -21,6 +22,10 @@ struct SherpaOnnxOnlineStream {
       : impl(std::move(p)) {}
 };
 
+struct SherpaOnnxDisplay {
+  std::unique_ptr<sherpa_onnx::Display> impl;
+};
+
 SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
     const SherpaOnnxOnlineRecognizerConfig *config) {
   sherpa_onnx::OnlineRecognizerConfig recognizer_config;
@@ -37,6 +42,9 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
   recognizer_config.model_config.num_threads = config->model_config.num_threads;
   recognizer_config.model_config.debug = config->model_config.debug;
 
+  recognizer_config.decoding_method = config->decoding_method;
+  recognizer_config.max_active_paths = config->max_active_paths;
+
   recognizer_config.enable_endpoint = config->enable_endpoint;
 
   recognizer_config.endpoint_config.rule1.min_trailing_silence =
@@ -124,3 +132,15 @@ int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
                    SherpaOnnxOnlineStream *stream) {
   return recognizer->impl->IsEndpoint(stream->impl.get());
 }
+
+SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line) {
+  SherpaOnnxDisplay *ans = new SherpaOnnxDisplay;
+  ans->impl = std::make_unique<sherpa_onnx::Display>(max_word_per_line);
+  return ans;
+}
+
+void DestroyDisplay(SherpaOnnxDisplay *display) { delete display; }
+
+void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s) {
+  display->impl->Print(idx, s);
+}
diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h
index 1732d4b3..5f5e07bb 100644
--- a/sherpa-onnx/c-api/c-api.h
+++ b/sherpa-onnx/c-api/c-api.h
@@ -48,6 +48,13 @@ typedef struct SherpaOnnxOnlineRecognizerConfig {
   SherpaOnnxFeatureConfig feat_config;
   SherpaOnnxOnlineTransducerModelConfig model_config;
 
+  /// Possible values are: greedy_search, modified_beam_search
+  const char *decoding_method;
+
+  /// Used only when decoding_method is modified_beam_search
+  /// Example value: 4
+  int32_t max_active_paths;
+
   /// 0 to disable endpoint detection.
   /// A non-zero value to enable endpoint detection.
   int32_t enable_endpoint;
@@ -187,6 +194,18 @@ void InputFinished(SherpaOnnxOnlineStream *stream);
 int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
                    SherpaOnnxOnlineStream *stream);
 
+// for displaying results on Linux/macOS.
+typedef struct SherpaOnnxDisplay SherpaOnnxDisplay;
+
+/// Create a display object.
Must be freed using DestroyDisplay to avoid +/// memory leak. +SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line); + +void DestroyDisplay(SherpaOnnxDisplay *display); + +/// Print the result. +void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s); + #ifdef __cplusplus } /* extern "C" */ #endif diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index b3ff7189..68eeeb9e 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -9,10 +9,11 @@ set(sources online-lstm-transducer-model.cc online-recognizer.cc online-stream.cc + online-transducer-decoder.cc online-transducer-greedy-search-decoder.cc online-transducer-model-config.cc - online-transducer-modified-beam-search-decoder.cc online-transducer-model.cc + online-transducer-modified-beam-search-decoder.cc online-zipformer-transducer-model.cc onnx-utils.cc parse-options.cc diff --git a/sherpa-onnx/csrc/display.h b/sherpa-onnx/csrc/display.h index c7bbf292..a366a4f8 100644 --- a/sherpa-onnx/csrc/display.h +++ b/sherpa-onnx/csrc/display.h @@ -12,9 +12,16 @@ namespace sherpa_onnx { class Display { public: + explicit Display(int32_t max_word_per_line = 60) + : max_word_per_line_(max_word_per_line) {} + void Print(int32_t segment_id, const std::string &s) { #ifdef _MSC_VER - fprintf(stderr, "%d:%s\n", segment_id, s.c_str()); + if (segment_id != -1) { + fprintf(stderr, "%d:%s\n", segment_id, s.c_str()); + } else { + fprintf(stderr, "%s\n", s.c_str()); + } return; #endif if (last_segment_ == segment_id) { @@ -27,7 +34,9 @@ class Display { num_previous_lines_ = 0; } - fprintf(stderr, "\r%d:", segment_id); + if (segment_id != -1) { + fprintf(stderr, "\r%d:", segment_id); + } int32_t i = 0; for (size_t n = 0; n < s.size();) { @@ -69,7 +78,7 @@ class Display { void GoUpOneLine() const { fprintf(stderr, "\033[1A\r"); } private: - int32_t max_word_per_line_ = 60; + int32_t max_word_per_line_; int32_t num_previous_lines_ = 0; int32_t last_segment_ = -1; }; diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc index b0defc50..04999924 100644 --- a/sherpa-onnx/csrc/features.cc +++ b/sherpa-onnx/csrc/features.cc @@ -28,7 +28,8 @@ std::string FeatureExtractorConfig::ToString() const { os << "FeatureExtractorConfig("; os << "sampling_rate=" << sampling_rate << ", "; - os << "feature_dim=" << feature_dim << ")"; + os << "feature_dim=" << feature_dim << ", "; + os << "max_feature_vectors=" << max_feature_vectors << ")"; return os.str(); } @@ -40,9 +41,7 @@ class FeatureExtractor::Impl { opts_.frame_opts.snip_edges = false; opts_.frame_opts.samp_freq = config.sampling_rate; - // cache 100 seconds of feature frames, which is more than enough - // for real needs - opts_.frame_opts.max_feature_vectors = 100 * 100; + opts_.frame_opts.max_feature_vectors = config.max_feature_vectors; opts_.mel_opts.num_bins = config.feature_dim; diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h index 5f0ad967..53183136 100644 --- a/sherpa-onnx/csrc/features.h +++ b/sherpa-onnx/csrc/features.h @@ -16,6 +16,7 @@ namespace sherpa_onnx { struct FeatureExtractorConfig { float sampling_rate = 16000; int32_t feature_dim = 80; + int32_t max_feature_vectors = -1; std::string ToString() const; diff --git a/sherpa-onnx/csrc/hypothesis.h b/sherpa-onnx/csrc/hypothesis.h index 6023af8b..725dba07 100644 --- a/sherpa-onnx/csrc/hypothesis.h +++ b/sherpa-onnx/csrc/hypothesis.h @@ -18,7 +18,7 @@ namespace sherpa_onnx { struct Hypothesis { // The predicted tokens so far. 
Newly predicted tokens are appended.
-  std::vector<int32_t> ys;
+  std::vector<int64_t> ys;
 
   // timestamps[i] contains the frame number after subsampling
   // on which ys[i] is decoded.
@@ -30,7 +30,7 @@ struct Hypothesis {
   int32_t num_trailing_blanks = 0;
 
   Hypothesis() = default;
-  Hypothesis(const std::vector<int32_t> &ys, double log_prob)
+  Hypothesis(const std::vector<int64_t> &ys, double log_prob)
       : ys(ys), log_prob(log_prob) {}
 
   // If two Hypotheses have the same `Key`, then they contain
diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc
index e85f4574..3ef911ee 100644
--- a/sherpa-onnx/csrc/online-recognizer.cc
+++ b/sherpa-onnx/csrc/online-recognizer.cc
@@ -43,7 +43,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
                "True to enable endpoint detection. False to disable it.");
   po->Register("max-active-paths", &max_active_paths,
                "beam size used in modified beam search.");
-  po->Register("decoding-mothod", &decoding_method,
+  po->Register("decoding-method", &decoding_method,
                "decoding method, "
                "now support greedy_search and modified_beam_search.");
 }
@@ -59,8 +59,8 @@ std::string OnlineRecognizerConfig::ToString() const {
   os << "feat_config=" << feat_config.ToString() << ", ";
   os << "model_config=" << model_config.ToString() << ", ";
   os << "endpoint_config=" << endpoint_config.ToString() << ", ";
-  os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ",";
-  os << "max_active_paths=" << max_active_paths << ",";
+  os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
+  os << "max_active_paths=" << max_active_paths << ", ";
   os << "decoding_method=\"" << decoding_method << "\")";
 
   return os.str();
@@ -187,16 +187,14 @@ class OnlineRecognizer::Impl {
   }
 
   void Reset(OnlineStream *s) const {
-    // reset result, neural network model state, and
-    // the feature extractor state
-
-    // reset result
+    // we keep the decoder_out
+    decoder_->UpdateDecoderOut(&s->GetResult());
+    Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
     s->SetResult(decoder_->GetEmptyResult());
+    s->GetResult().decoder_out = std::move(decoder_out);
 
-    // reset neural network model state
-    s->SetStates(model_->GetEncoderInitStates());
-
-    // reset feature extractor
+    // Note: We only update counters. The underlying audio samples
+    // are not discarded.
s->Reset(); } diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index d03b1795..521e2f1a 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -33,21 +33,26 @@ struct OnlineRecognizerConfig { OnlineTransducerModelConfig model_config; EndpointConfig endpoint_config; bool enable_endpoint = true; - int32_t max_active_paths = 4; - std::string decoding_method = "modified_beam_search"; + std::string decoding_method = "greedy_search"; // now support modified_beam_search and greedy_search + int32_t max_active_paths = 4; // used only for modified_beam_search + OnlineRecognizerConfig() = default; OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, const OnlineTransducerModelConfig &model_config, const EndpointConfig &endpoint_config, - bool enable_endpoint) + bool enable_endpoint, + const std::string &decoding_method, + int32_t max_active_paths) : feat_config(feat_config), model_config(model_config), endpoint_config(endpoint_config), - enable_endpoint(enable_endpoint) {} + enable_endpoint(enable_endpoint), + decoding_method(decoding_method), + max_active_paths(max_active_paths) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc index 98a0d96a..1ed1588f 100644 --- a/sherpa-onnx/csrc/online-stream.cc +++ b/sherpa-onnx/csrc/online-stream.cc @@ -22,18 +22,21 @@ class OnlineStream::Impl { void InputFinished() { feat_extractor_.InputFinished(); } - int32_t NumFramesReady() const { return feat_extractor_.NumFramesReady(); } + int32_t NumFramesReady() const { + return feat_extractor_.NumFramesReady() - start_frame_index_; + } bool IsLastFrame(int32_t frame) const { return feat_extractor_.IsLastFrame(frame); } std::vector GetFrames(int32_t frame_index, int32_t n) const { - return feat_extractor_.GetFrames(frame_index, n); + return feat_extractor_.GetFrames(frame_index + start_frame_index_, n); } void Reset() { - feat_extractor_.Reset(); + // we don't reset the feature extractor + start_frame_index_ += num_processed_frames_; num_processed_frames_ = 0; } @@ -41,7 +44,7 @@ class OnlineStream::Impl { void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } - const OnlineTransducerDecoderResult &GetResult() const { return result_; } + OnlineTransducerDecoderResult &GetResult() { return result_; } int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); } @@ -54,6 +57,7 @@ class OnlineStream::Impl { private: FeatureExtractor feat_extractor_; int32_t num_processed_frames_ = 0; // before subsampling + int32_t start_frame_index_ = 0; // never reset OnlineTransducerDecoderResult result_; std::vector states_; }; @@ -93,7 +97,7 @@ void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) { impl_->SetResult(r); } -const OnlineTransducerDecoderResult &OnlineStream::GetResult() const { +OnlineTransducerDecoderResult &OnlineStream::GetResult() { return impl_->GetResult(); } diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h index 42bf6d6e..0bba1847 100644 --- a/sherpa-onnx/csrc/online-stream.h +++ b/sherpa-onnx/csrc/online-stream.h @@ -63,7 +63,7 @@ class OnlineStream { int32_t &GetNumProcessedFrames(); void SetResult(const OnlineTransducerDecoderResult &r); - const OnlineTransducerDecoderResult &GetResult() const; + OnlineTransducerDecoderResult &GetResult(); void SetStates(std::vector states); std::vector &GetStates(); diff --git 
a/sherpa-onnx/csrc/online-transducer-decoder.cc b/sherpa-onnx/csrc/online-transducer-decoder.cc
new file mode 100644
index 00000000..102b358d
--- /dev/null
+++ b/sherpa-onnx/csrc/online-transducer-decoder.cc
@@ -0,0 +1,60 @@
+// sherpa-onnx/csrc/online-transducer-decoder.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
+
+#include <utility>
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
+namespace sherpa_onnx {
+
+OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
+    const OnlineTransducerDecoderResult &other)
+    : OnlineTransducerDecoderResult() {
+  *this = other;
+}
+
+OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
+    const OnlineTransducerDecoderResult &other) {
+  if (this == &other) {
+    return *this;
+  }
+
+  tokens = other.tokens;
+  num_trailing_blanks = other.num_trailing_blanks;
+
+  Ort::AllocatorWithDefaultOptions allocator;
+  if (other.decoder_out) {
+    decoder_out = Clone(allocator, &other.decoder_out);
+  }
+
+  hyps = other.hyps;
+
+  return *this;
+}
+
+OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
+    OnlineTransducerDecoderResult &&other)
+    : OnlineTransducerDecoderResult() {
+  *this = std::move(other);
+}
+
+OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
+    OnlineTransducerDecoderResult &&other) {
+  if (this == &other) {
+    return *this;
+  }
+
+  tokens = std::move(other.tokens);
+  num_trailing_blanks = other.num_trailing_blanks;
+  decoder_out = std::move(other.decoder_out);
+  hyps = std::move(other.hyps);
+
+  return *this;
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-transducer-decoder.h b/sherpa-onnx/csrc/online-transducer-decoder.h
index c70afc30..592c206c 100644
--- a/sherpa-onnx/csrc/online-transducer-decoder.h
+++ b/sherpa-onnx/csrc/online-transducer-decoder.h
@@ -19,8 +19,24 @@ struct OnlineTransducerDecoderResult {
   /// number of trailing blank frames decoded so far
   int32_t num_trailing_blanks = 0;
 
+  // Cache decoder_out for endpointing
+  Ort::Value decoder_out;
+
   // used only in modified beam_search
   Hypotheses hyps;
+
+  OnlineTransducerDecoderResult()
+      : tokens{}, num_trailing_blanks(0), decoder_out{nullptr}, hyps{} {}
+
+  OnlineTransducerDecoderResult(const OnlineTransducerDecoderResult &other);
+
+  OnlineTransducerDecoderResult &operator=(
+      const OnlineTransducerDecoderResult &other);
+
+  OnlineTransducerDecoderResult(OnlineTransducerDecoderResult &&other);
+
+  OnlineTransducerDecoderResult &operator=(
+      OnlineTransducerDecoderResult &&other);
 };
 
 class OnlineTransducerDecoder {
@@ -53,6 +69,9 @@ class OnlineTransducerDecoder {
    */
   virtual void Decode(Ort::Value encoder_out,
                       std::vector<OnlineTransducerDecoderResult> *result) = 0;
+
+  // used for endpointing.
We need to keep decoder_out after reset
+  virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
 };
 
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
index 5e194f3d..b4b191d8 100644
--- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
+++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
@@ -13,6 +13,43 @@
 
 namespace sherpa_onnx {
 
+static void UseCachedDecoderOut(
+    const std::vector<OnlineTransducerDecoderResult> &results,
+    Ort::Value *decoder_out) {
+  std::vector<int64_t> shape =
+      decoder_out->GetTensorTypeAndShapeInfo().GetShape();
+
+  float *dst = decoder_out->GetTensorMutableData<float>();
+  for (const auto &r : results) {
+    if (r.decoder_out) {
+      const float *src = r.decoder_out.GetTensorData<float>();
+      std::copy(src, src + shape[1], dst);
+    }
+    dst += shape[1];
+  }
+}
+
+static void UpdateCachedDecoderOut(
+    OrtAllocator *allocator, const Ort::Value *decoder_out,
+    std::vector<OnlineTransducerDecoderResult> *results) {
+  std::vector<int64_t> shape =
+      decoder_out->GetTensorTypeAndShapeInfo().GetShape();
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+  std::array<int64_t, 2> v_shape{1, shape[1]};
+
+  const float *src = decoder_out->GetTensorData<float>();
+  for (auto &r : *results) {
+    if (!r.decoder_out) {
+      r.decoder_out = Ort::Value::CreateTensor<float>(
+          allocator, v_shape.data(), v_shape.size());
+    }
+
+    float *dst = r.decoder_out.GetTensorMutableData<float>();
+    std::copy(src, src + shape[1], dst);
+    src += shape[1];
+  }
+}
+
 OnlineTransducerDecoderResult
 OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
   int32_t context_size = model_->ContextSize();
@@ -53,6 +90,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
 
   Ort::Value decoder_input = model_->BuildDecoderInput(*result);
   Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
+  UseCachedDecoderOut(*result, &decoder_out);
 
   for (int32_t t = 0; t != num_frames; ++t) {
     Ort::Value cur_encoder_out =
@@ -77,10 +115,12 @@ void OnlineTransducerGreedySearchDecoder::Decode(
       }
     }
     if (emitted) {
-      decoder_input = model_->BuildDecoderInput(*result);
+      Ort::Value decoder_input = model_->BuildDecoderInput(*result);
       decoder_out = model_->RunDecoder(std::move(decoder_input));
     }
   }
+
+  UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result);
 }
 
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
index eab279c5..2d9825d9 100644
--- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
+++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
@@ -13,6 +13,29 @@
 
 namespace sherpa_onnx {
 
+static void UseCachedDecoderOut(
+    const std::vector<int32_t> &hyps_num_split,
+    const std::vector<OnlineTransducerDecoderResult> &results,
+    int32_t context_size, Ort::Value *decoder_out) {
+  std::vector<int64_t> shape =
+      decoder_out->GetTensorTypeAndShapeInfo().GetShape();
+
+  float *dst = decoder_out->GetTensorMutableData<float>();
+
+  int32_t batch_size = static_cast<int32_t>(results.size());
+  for (int32_t i = 0; i != batch_size; ++i) {
+    int32_t num_hyps = hyps_num_split[i + 1] - hyps_num_split[i];
+    if (num_hyps > 1 || !results[i].decoder_out) {
+      dst += num_hyps * shape[1];
+      continue;
+    }
+
+    const float *src = results[i].decoder_out.GetTensorData<float>();
+    std::copy(src, src + shape[1], dst);
+    dst += shape[1];
+  }
+}
+
 static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
                          const std::vector<int32_t> &hyps_num_split) {
   std::vector<int64_t>
cur_encoder_out_shape =
@@ -50,7 +73,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
   int32_t context_size = model_->ContextSize();
   int32_t blank_id = 0;  // always 0
   OnlineTransducerDecoderResult r;
-  std::vector<int32_t> blanks(context_size, blank_id);
+  std::vector<int64_t> blanks(context_size, blank_id);
   Hypotheses blank_hyp({{blanks, 0}});
   r.hyps = std::move(blank_hyp);
   return r;
@@ -110,6 +133,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
     Ort::Value decoder_input = model_->BuildDecoderInput(prev);
     Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
+    if (t == 0) {
+      UseCachedDecoderOut(hyps_num_split, *result, model_->ContextSize(),
+                          &decoder_out);
+    }
 
     Ort::Value cur_encoder_out =
         GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
@@ -147,8 +174,23 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
   }
 
   for (int32_t b = 0; b != batch_size; ++b) {
-    (*result)[b].hyps = std::move(cur[b]);
+    auto &hyps = cur[b];
+    auto best_hyp = hyps.GetMostProbable(true);
+
+    (*result)[b].hyps = std::move(hyps);
+    (*result)[b].tokens = std::move(best_hyp.ys);
+    (*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks;
   }
 }
 
+void OnlineTransducerModifiedBeamSearchDecoder::UpdateDecoderOut(
+    OnlineTransducerDecoderResult *result) {
+  if (result->tokens.size() == model_->ContextSize()) {
+    result->decoder_out = Ort::Value{nullptr};
+    return;
+  }
+  Ort::Value decoder_input = model_->BuildDecoderInput({*result});
+  result->decoder_out = model_->RunDecoder(std::move(decoder_input));
+}
+
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
index f1443539..86df4d72 100644
--- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
+++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
@@ -27,6 +27,8 @@ class OnlineTransducerModifiedBeamSearchDecoder
   void Decode(Ort::Value encoder_out,
               std::vector<OnlineTransducerDecoderResult> *result) override;
 
+  void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;
+
  private:
   OnlineTransducerModel *model_;  // Not owned
   int32_t max_active_paths_;
diff --git a/sherpa-onnx/csrc/sherpa-onnx-alsa.cc b/sherpa-onnx/csrc/sherpa-onnx-alsa.cc
index 0d57b6c6..730f7618 100644
--- a/sherpa-onnx/csrc/sherpa-onnx-alsa.cc
+++ b/sherpa-onnx/csrc/sherpa-onnx-alsa.cc
@@ -21,7 +21,7 @@ static void Handler(int sig) {
 }
 
 int main(int32_t argc, char *argv[]) {
-  if (argc < 6 || argc > 7) {
+  if (argc < 6 || argc > 8) {
     const char *usage = R"usage(
Usage:
  ./bin/sherpa-onnx-alsa \
@@ -30,7 +30,10 @@ Usage:
     /path/to/decoder.onnx \
     /path/to/joiner.onnx \
     device_name \
-    [num_threads]
+    [num_threads [decoding_method]]
+
+Default value for num_threads is 2.
+Valid values for decoding_method: greedy_search (default), modified_beam_search.
 
 Please refer to
 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
@@ -79,6 +82,11 @@ as the device_name.
config.model_config.num_threads = atoi(argv[6]); } + if (argc == 8) { + config.decoding_method = argv[7]; + } + config.max_active_paths = 4; + config.enable_endpoint = true; config.endpoint_config.rule1.min_trailing_silence = 2.4; diff --git a/sherpa-onnx/csrc/sherpa-onnx-microphone.cc b/sherpa-onnx/csrc/sherpa-onnx-microphone.cc index 2150d48b..57c448e9 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-microphone.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-microphone.cc @@ -36,7 +36,7 @@ static void Handler(int32_t sig) { } int32_t main(int32_t argc, char *argv[]) { - if (argc < 5 || argc > 6) { + if (argc < 5 || argc > 7) { const char *usage = R"usage( Usage: ./bin/sherpa-onnx-microphone \ @@ -44,7 +44,10 @@ Usage: /path/to/encoder.onnx\ /path/to/decoder.onnx\ /path/to/joiner.onnx\ - [num_threads] + [num_threads [decoding_method]] + +Default value for num_threads is 2. +Valid values for decoding_method: greedy_search (default), modified_beam_search. Please refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html @@ -70,6 +73,11 @@ for a list of pre-trained models to download. config.model_config.num_threads = atoi(argv[5]); } + if (argc == 7) { + config.decoding_method = argv[6]; + } + config.max_active_paths = 4; + config.enable_endpoint = true; config.endpoint_config.rule1.min_trailing_silence = 2.4; diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index 380368a8..0ce89afc 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -14,7 +14,7 @@ #include "sherpa-onnx/csrc/wave-reader.h" int main(int32_t argc, char *argv[]) { - if (argc < 6 || argc > 7) { + if (argc < 6 || argc > 8) { const char *usage = R"usage( Usage: ./bin/sherpa-onnx \ @@ -22,7 +22,10 @@ Usage: /path/to/encoder.onnx \ /path/to/decoder.onnx \ /path/to/joiner.onnx \ - /path/to/foo.wav [num_threads] + /path/to/foo.wav [num_threads [decoding_method]] + +Default value for num_threads is 2. +Valid values for decoding_method: greedy_search (default), modified_beam_search. Please refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html @@ -45,9 +48,15 @@ for a list of pre-trained models to download. std::string wav_filename = argv[5]; config.model_config.num_threads = 2; - if (argc == 7) { + if (argc == 7 && atoi(argv[6]) > 0) { config.model_config.num_threads = atoi(argv[6]); } + + if (argc == 8) { + config.decoding_method = argv[7]; + } + config.max_active_paths = 4; + fprintf(stderr, "%s\n", config.ToString().c_str()); sherpa_onnx::OnlineRecognizer recognizer(config); @@ -98,6 +107,7 @@ for a list of pre-trained models to download. 
1000.;
 
   fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
+  fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
   fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
 
   float rtf = elapsed_seconds / duration;
diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt
index 73edbd10..88c484c2 100644
--- a/sherpa-onnx/python/csrc/CMakeLists.txt
+++ b/sherpa-onnx/python/csrc/CMakeLists.txt
@@ -1,12 +1,13 @@
 include_directories(${CMAKE_SOURCE_DIR})
 
 pybind11_add_module(_sherpa_onnx
+  display.cc
+  endpoint.cc
   features.cc
+  online-recognizer.cc
+  online-stream.cc
   online-transducer-model-config.cc
   sherpa-onnx.cc
-  endpoint.cc
-  online-stream.cc
-  online-recognizer.cc
 )
 
 if(APPLE)
diff --git a/sherpa-onnx/python/csrc/display.cc b/sherpa-onnx/python/csrc/display.cc
new file mode 100644
index 00000000..44b76b5a
--- /dev/null
+++ b/sherpa-onnx/python/csrc/display.cc
@@ -0,0 +1,18 @@
+// sherpa-onnx/python/csrc/display.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/python/csrc/display.h"
+
+#include "sherpa-onnx/csrc/display.h"
+
+namespace sherpa_onnx {
+
+void PybindDisplay(py::module *m) {
+  using PyClass = Display;
+  py::class_<PyClass>(*m, "Display")
+      .def(py::init<int32_t>(), py::arg("max_word_per_line") = 60)
+      .def("print", &PyClass::Print, py::arg("idx"), py::arg("s"));
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/display.h b/sherpa-onnx/python/csrc/display.h
new file mode 100644
index 00000000..23773306
--- /dev/null
+++ b/sherpa-onnx/python/csrc/display.h
@@ -0,0 +1,16 @@
+// sherpa-onnx/python/csrc/display.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
+#define SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindDisplay(py::module *m);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
diff --git a/sherpa-onnx/python/csrc/features.cc b/sherpa-onnx/python/csrc/features.cc
index 6458f5cc..4c58b3eb 100644
--- a/sherpa-onnx/python/csrc/features.cc
+++ b/sherpa-onnx/python/csrc/features.cc
@@ -11,10 +11,12 @@ namespace sherpa_onnx {
 static void PybindFeatureExtractorConfig(py::module *m) {
   using PyClass = FeatureExtractorConfig;
   py::class_<PyClass>(*m, "FeatureExtractorConfig")
-      .def(py::init<float, int32_t>(), py::arg("sampling_rate") = 16000,
-           py::arg("feature_dim") = 80)
+      .def(py::init<float, int32_t, int32_t>(),
+           py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,
+           py::arg("max_feature_vectors") = -1)
       .def_readwrite("sampling_rate", &PyClass::sampling_rate)
       .def_readwrite("feature_dim", &PyClass::feature_dim)
+      .def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
       .def("__str__", &PyClass::ToString);
 }
diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc
index ba31b783..d5cb70f6 100644
--- a/sherpa-onnx/python/csrc/online-recognizer.cc
+++ b/sherpa-onnx/python/csrc/online-recognizer.cc
@@ -22,13 +22,16 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
   py::class_<PyClass>(*m, "OnlineRecognizerConfig")
-      .def(py::init<const FeatureExtractorConfig &,
-                    const OnlineTransducerModelConfig &, const EndpointConfig &,
-                    bool>(),
+      .def(py::init<const FeatureExtractorConfig &,
+                    const OnlineTransducerModelConfig &, const EndpointConfig &,
+                    bool, const std::string &, int32_t>(),
           py::arg("feat_config"), py::arg("model_config"),
-          py::arg("endpoint_config"), py::arg("enable_endpoint"))
+          py::arg("endpoint_config"), py::arg("enable_endpoint"),
+          py::arg("decoding_method"), py::arg("max_active_paths"))
       .def_readwrite("feat_config", &PyClass::feat_config)
       .def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("endpoint_config", &PyClass::endpoint_config) .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) + .def_readwrite("decoding_method", &PyClass::decoding_method) + .def_readwrite("max_active_paths", &PyClass::max_active_paths) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 4d6a798c..5e5886d4 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -4,6 +4,7 @@ #include "sherpa-onnx/python/csrc/sherpa-onnx.h" +#include "sherpa-onnx/python/csrc/display.h" #include "sherpa-onnx/python/csrc/endpoint.h" #include "sherpa-onnx/python/csrc/features.h" #include "sherpa-onnx/python/csrc/online-recognizer.h" @@ -19,6 +20,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindOnlineStream(&m); PybindEndpoint(&m); PybindOnlineRecognizer(&m); + + PybindDisplay(&m); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index e50b5f2e..13a98c53 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -1,9 +1,3 @@ -from _sherpa_onnx import ( - EndpointConfig, - FeatureExtractorConfig, - OnlineRecognizerConfig, - OnlineStream, - OnlineTransducerModelConfig, -) +from _sherpa_onnx import Display from .online_recognizer import OnlineRecognizer diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 265f0d1e..e8fd64a7 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -32,6 +32,9 @@ class OnlineRecognizer(object): rule1_min_trailing_silence: int = 2.4, rule2_min_trailing_silence: int = 1.2, rule3_min_utterance_length: int = 20, + decoding_method: str = "greedy_search", + max_active_paths: int = 4, + max_feature_vectors: int = -1, ): """ Please refer to @@ -74,6 +77,14 @@ class OnlineRecognizer(object): Used only when enable_endpoint_detection is True. If the utterance length in seconds is larger than this value, we assume an endpoint is detected. + decoding_method: + Valid values are greedy_search, modified_beam_search. + max_active_paths: + Use only when decoding_method is modified_beam_search. It specifies + the maximum number of active paths during beam search. + max_feature_vectors: + Number of feature vectors to cache. -1 means to cache all feature + frames that have been processed. """ _assert_file_exists(tokens) _assert_file_exists(encoder) @@ -93,6 +104,7 @@ class OnlineRecognizer(object): feat_config = FeatureExtractorConfig( sampling_rate=sample_rate, feature_dim=feature_dim, + max_feature_vectors=max_feature_vectors, ) endpoint_config = EndpointConfig( @@ -106,6 +118,8 @@ class OnlineRecognizer(object): model_config=model_config, endpoint_config=endpoint_config, enable_endpoint=enable_endpoint_detection, + decoding_method=decoding_method, + max_active_paths=max_active_paths, ) self.recognizer = _Recognizer(recognizer_config)