Code refactoring (#74)

* Don't reset model state and feature extractor on endpointing

* Support passing decoding_method from the command line

* Add modified_beam_search to Python API

* Fix C API example

* Fix style issues
Fangjun Kuang
2023-03-03 12:10:59 +08:00
committed by GitHub
parent c241f93c40
commit 7f72c13d9a
34 changed files with 744 additions and 374 deletions
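
The headline change is a user-selectable decoding method, plumbed from the command line of every example down into the recognizer config. A minimal sketch of the new knobs through the C API, condensed from the updated decode-file-c-api.c example below (model paths are placeholders, error handling omitted):

```c
#include <string.h>

#include "sherpa-onnx/c-api/c-api.h"

int main(void) {
  SherpaOnnxOnlineRecognizerConfig config;
  memset(&config, 0, sizeof(config));  // leave unset options at zero

  config.model_config.tokens = "tokens.txt";  // placeholder paths
  config.model_config.encoder = "encoder.onnx";
  config.model_config.decoder = "decoder.onnx";
  config.model_config.joiner = "joiner.onnx";
  config.model_config.num_threads = 1;

  config.feat_config.sample_rate = 16000;
  config.feat_config.feature_dim = 80;

  // New in this commit: the decoding method is configurable.
  config.decoding_method = "modified_beam_search";  // or "greedy_search"
  config.max_active_paths = 4;  // beam size, used only by modified_beam_search

  SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
  SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);

  // ... feed samples with AcceptWaveform() and decode, as in the examples below ...

  DestoryOnlineStream(stream);  // sic: spelling as in the C API
  DestroyOnnlineRecognizer(recognizer);
  return 0;
}
```

(Note: the destroy function is `DestroyOnlineRecognizer`; the line above shows the intended call, mirrored from the example below.)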

.gitignore

@@ -36,4 +36,5 @@ tokens.txt
 *.onnx
 log.txt
 tags
+run-decode-file-python.sh
 android/SherpaOnnx/app/src/main/assets/

c-api-examples/decode-file-c-api.c

@@ -19,14 +19,16 @@ const char *kUsage =
     "  /path/to/encoder.onnx \\\n"
     "  /path/to/decoder.onnx \\\n"
     "  /path/to/joiner.onnx \\\n"
-    "  /path/to/foo.wav [num_threads]\n"
+    "  /path/to/foo.wav [num_threads [decoding_method]]\n"
     "\n\n"
+    "Default num_threads is 1.\n"
+    "Valid decoding_method: greedy_search (default), modified_beam_search\n\n"
     "Please refer to \n"
     "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n"
     "for a list of pre-trained models to download.\n";

 int32_t main(int32_t argc, char *argv[]) {
-  if (argc < 6 || argc > 7) {
+  if (argc < 6 || argc > 8) {
     fprintf(stderr, "%s\n", kUsage);
     return -1;
   }

@@ -36,13 +38,20 @@ int32_t main(int32_t argc, char *argv[]) {
   config.model_config.decoder = argv[3];
   config.model_config.joiner = argv[4];

-  int32_t num_threads = 4;
+  int32_t num_threads = 1;
   if (argc == 7 && atoi(argv[6]) > 0) {
     num_threads = atoi(argv[6]);
   }
   config.model_config.num_threads = num_threads;
   config.model_config.debug = 0;

+  config.decoding_method = "greedy_search";
+  if (argc == 8) {
+    config.decoding_method = argv[7];
+  }
+  config.max_active_paths = 4;
+
   config.feat_config.sample_rate = 16000;
   config.feat_config.feature_dim = 80;

@@ -54,6 +63,9 @@ int32_t main(int32_t argc, char *argv[]) {
   SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
   SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);

+  SherpaOnnxDisplay *display = CreateDisplay(50);
+  int32_t segment_id = 0;
+
   const char *wav_filename = argv[5];
   FILE *fp = fopen(wav_filename, "rb");
   if (!fp) {

@@ -84,9 +96,18 @@ int32_t main(int32_t argc, char *argv[]) {
     SherpaOnnxOnlineRecognizerResult *r =
         GetOnlineStreamResult(recognizer, stream);
     if (strlen(r->text)) {
-      fprintf(stderr, "%s\n", r->text);
+      SherpaOnnxPrint(display, segment_id, r->text);
     }
+
+    if (IsEndpoint(recognizer, stream)) {
+      if (strlen(r->text)) {
+        ++segment_id;
+      }
+      Reset(recognizer, stream);
+    }
+
     DestroyOnlineRecognizerResult(r);
   }
 }

@@ -103,14 +124,17 @@ int32_t main(int32_t argc, char *argv[]) {
   SherpaOnnxOnlineRecognizerResult *r =
       GetOnlineStreamResult(recognizer, stream);
   if (strlen(r->text)) {
-    fprintf(stderr, "%s\n", r->text);
+    SherpaOnnxPrint(display, segment_id, r->text);
   }
   DestroyOnlineRecognizerResult(r);

+  DestroyDisplay(display);
   DestoryOnlineStream(stream);
   DestroyOnlineRecognizer(recognizer);

+  fprintf(stderr, "\n");
+
   return 0;
 }

@@ -26,12 +26,17 @@ if [ ! -f ./sherpa-onnx-ffmpeg ]; then
   make
 fi

-../ffmpeg-examples/sherpa-onnx-ffmpeg \
-  ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
-  ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
-  ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
-  ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
-  ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/4.wav
+for method in greedy_search modified_beam_search; do
+  echo "test method: $method"
+  ../ffmpeg-examples/sherpa-onnx-ffmpeg \
+    ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
+    ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
+    ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
+    ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
+    ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav \
+    2 \
+    $method
+done

 echo "Decoding a URL"

ffmpeg-examples/sherpa-onnx-ffmpeg.c

@@ -7,7 +7,6 @@
 #include "sherpa-onnx/c-api/c-api.h"
-
 /*
  * Copyright (c) 2010 Nicolas George
  * Copyright (c) 2011 Stefano Sabatini

@@ -43,14 +42,15 @@
 #include <unistd.h>

 extern "C" {
 #include <libavcodec/avcodec.h>
-#include <libavformat/avformat.h>
 #include <libavfilter/buffersink.h>
 #include <libavfilter/buffersrc.h>
+#include <libavformat/avformat.h>
 #include <libavutil/channel_layout.h>
 #include <libavutil/opt.h>
 }

-static const char *filter_descr = "aresample=16000,aformat=sample_fmts=s16:channel_layouts=mono";
+static const char *filter_descr =
+    "aresample=16000,aformat=sample_fmts=s16:channel_layouts=mono";

 static AVFormatContext *fmt_ctx;
 static AVCodecContext *dec_ctx;

@@ -59,308 +59,172 @@ AVFilterContext *buffersrc_ctx;
 AVFilterGraph *filter_graph;
 static int audio_stream_index = -1;

 static int open_input_file(const char *filename) {
   const AVCodec *dec;
   int ret;

   if ((ret = avformat_open_input(&fmt_ctx, filename, NULL, NULL)) < 0) {
     av_log(NULL, AV_LOG_ERROR, "Cannot open input file %s\n", filename);
     return ret;
   }

   if ((ret = avformat_find_stream_info(fmt_ctx, NULL)) < 0) {
     av_log(NULL, AV_LOG_ERROR, "Cannot find stream information\n");
     return ret;
   }

   /* select the audio stream */
   ret = av_find_best_stream(fmt_ctx, AVMEDIA_TYPE_AUDIO, -1, -1, &dec, 0);
   if (ret < 0) {
     av_log(NULL, AV_LOG_ERROR,
            "Cannot find an audio stream in the input file\n");
     return ret;
   }
   audio_stream_index = ret;

   /* create decoding context */
   dec_ctx = avcodec_alloc_context3(dec);
   if (!dec_ctx) return AVERROR(ENOMEM);
   avcodec_parameters_to_context(dec_ctx,
                                 fmt_ctx->streams[audio_stream_index]->codecpar);

   /* init the audio decoder */
   if ((ret = avcodec_open2(dec_ctx, dec, NULL)) < 0) {
     av_log(NULL, AV_LOG_ERROR, "Cannot open audio decoder\n");
     return ret;
   }

   return 0;
 }

 static int init_filters(const char *filters_descr) {
   char args[512];
   int ret = 0;
   const AVFilter *abuffersrc = avfilter_get_by_name("abuffer");
   const AVFilter *abuffersink = avfilter_get_by_name("abuffersink");
   AVFilterInOut *outputs = avfilter_inout_alloc();
   AVFilterInOut *inputs = avfilter_inout_alloc();
   static const enum AVSampleFormat out_sample_fmts[] = {AV_SAMPLE_FMT_S16,
                                                         AV_SAMPLE_FMT_NONE};
   static const int out_sample_rates[] = {16000, -1};
   const AVFilterLink *outlink;
   AVRational time_base = fmt_ctx->streams[audio_stream_index]->time_base;

   filter_graph = avfilter_graph_alloc();
   if (!outputs || !inputs || !filter_graph) {
     ret = AVERROR(ENOMEM);
     goto end;
   }

   /* buffer audio source: the decoded frames from the decoder will be inserted
    * here. */
   if (dec_ctx->ch_layout.order == AV_CHANNEL_ORDER_UNSPEC)
     av_channel_layout_default(&dec_ctx->ch_layout,
                               dec_ctx->ch_layout.nb_channels);
   ret = snprintf(args, sizeof(args),
                  "time_base=%d/%d:sample_rate=%d:sample_fmt=%s:channel_layout=",
                  time_base.num, time_base.den, dec_ctx->sample_rate,
                  av_get_sample_fmt_name(dec_ctx->sample_fmt));
   av_channel_layout_describe(&dec_ctx->ch_layout, args + ret,
                              sizeof(args) - ret);
   ret = avfilter_graph_create_filter(&buffersrc_ctx, abuffersrc, "in", args,
                                      NULL, filter_graph);
   if (ret < 0) {
     av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer source\n");
     goto end;
   }

   /* buffer audio sink: to terminate the filter chain. */
   ret = avfilter_graph_create_filter(&buffersink_ctx, abuffersink, "out", NULL,
                                      NULL, filter_graph);
   if (ret < 0) {
     av_log(NULL, AV_LOG_ERROR, "Cannot create audio buffer sink\n");
     goto end;
   }

   ret = av_opt_set_int_list(buffersink_ctx, "sample_fmts", out_sample_fmts, -1,
                             AV_OPT_SEARCH_CHILDREN);
   if (ret < 0) {
     av_log(NULL, AV_LOG_ERROR, "Cannot set output sample format\n");
     goto end;
   }

   ret =
       av_opt_set(buffersink_ctx, "ch_layouts", "mono", AV_OPT_SEARCH_CHILDREN);
   if (ret < 0) {
     av_log(NULL, AV_LOG_ERROR, "Cannot set output channel layout\n");
     goto end;
   }

   ret = av_opt_set_int_list(buffersink_ctx, "sample_rates", out_sample_rates,
                             -1, AV_OPT_SEARCH_CHILDREN);
   if (ret < 0) {
     av_log(NULL, AV_LOG_ERROR, "Cannot set output sample rate\n");
     goto end;
   }

   /*
    * Set the endpoints for the filter graph. The filter_graph will
    * be linked to the graph described by filters_descr.
    */

   /*
    * The buffer source output must be connected to the input pad of
    * the first filter described by filters_descr; since the first
    * filter input label is not specified, it is set to "in" by
    * default.
    */
   outputs->name = av_strdup("in");
   outputs->filter_ctx = buffersrc_ctx;
   outputs->pad_idx = 0;
   outputs->next = NULL;

   /*
    * The buffer sink input must be connected to the output pad of
    * the last filter described by filters_descr; since the last
    * filter output label is not specified, it is set to "out" by
    * default.
    */
   inputs->name = av_strdup("out");
   inputs->filter_ctx = buffersink_ctx;
   inputs->pad_idx = 0;
   inputs->next = NULL;

   if ((ret = avfilter_graph_parse_ptr(filter_graph, filters_descr, &inputs,
                                       &outputs, NULL)) < 0)
     goto end;

   if ((ret = avfilter_graph_config(filter_graph, NULL)) < 0) goto end;

   /* Print summary of the sink buffer
    * Note: args buffer is reused to store channel layout string */
   outlink = buffersink_ctx->inputs[0];
   av_channel_layout_describe(&outlink->ch_layout, args, sizeof(args));
   av_log(NULL, AV_LOG_INFO, "Output: srate:%dHz fmt:%s chlayout:%s\n",
          (int)outlink->sample_rate,
          (char *)av_x_if_null(
              av_get_sample_fmt_name((AVSampleFormat)outlink->format), "?"),
          args);

 end:
   avfilter_inout_free(&inputs);
   avfilter_inout_free(&outputs);

   return ret;
 }

-static void sherpa_decode_frame(const AVFrame *frame, SherpaOnnxOnlineRecognizer *recognizer,
-                                SherpaOnnxOnlineStream* stream)
-{
+static void sherpa_decode_frame(const AVFrame *frame,
+                                SherpaOnnxOnlineRecognizer *recognizer,
+                                SherpaOnnxOnlineStream *stream,
+                                SherpaOnnxDisplay *display,
+                                int32_t *segment_id) {
 #define N 3200  // 100s. Sample rate is fixed to 16 kHz
   static float samples[N];
   static int nb_samples = 0;
   const int16_t *p = (int16_t *)frame->data[0];

   if (frame->nb_samples + nb_samples > N) {
     AcceptWaveform(stream, 16000, samples, nb_samples);

     while (IsOnlineStreamReady(recognizer, stream)) {
       DecodeOnlineStream(recognizer, stream);
     }

@@ -368,25 +232,180 @@ int main(int argc, char **argv)
     SherpaOnnxOnlineRecognizerResult *r =
         GetOnlineStreamResult(recognizer, stream);
     if (strlen(r->text)) {
-      fprintf(stderr, "%s\n", r->text);
+      SherpaOnnxPrint(display, *segment_id, r->text);
+    }
+
+    if (IsEndpoint(recognizer, stream)) {
+      if (strlen(r->text)) {
+        ++*segment_id;
+      }
+      Reset(recognizer, stream);
     }
     DestroyOnlineRecognizerResult(r);

     nb_samples = 0;
   }

   for (int i = 0; i < frame->nb_samples; i++) {
     samples[nb_samples++] = p[i] / 32768.;
   }
 }

 static inline char *__av_err2str(int errnum) {
   static char str[AV_ERROR_MAX_STRING_SIZE];
   memset(str, 0, sizeof(str));
   return av_make_error_string(str, AV_ERROR_MAX_STRING_SIZE, errnum);
 }

 int main(int argc, char **argv) {
   int ret;
-  int num_threads = 4;
+  int num_threads = 1;
   AVPacket *packet = av_packet_alloc();
   AVFrame *frame = av_frame_alloc();
   AVFrame *filt_frame = av_frame_alloc();
   const char *kUsage =
       "\n"
       "Usage:\n"
       "  ./sherpa-onnx-ffmpeg \\\n"
       "    /path/to/tokens.txt \\\n"
       "    /path/to/encoder.onnx\\\n"
       "    /path/to/decoder.onnx\\\n"
       "    /path/to/joiner.onnx\\\n"
-      "    /path/to/foo.wav [num_threads]"
+      "    /path/to/foo.wav [num_threads [decoding_method]]"
       "\n\n"
+      "Default num_threads is 1.\n"
+      "Valid decoding_method: greedy_search (default), modified_beam_search\n\n"
       "Please refer to \n"
       "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n"
       "for a list of pre-trained models to download.\n";

   if (!packet || !frame || !filt_frame) {
     fprintf(stderr, "Could not allocate frame or packet\n");
     exit(1);
   }

-  if (argc < 6 || argc > 7) {
+  if (argc < 6 || argc > 8) {
     fprintf(stderr, "%s\n", kUsage);
     return -1;
   }

   SherpaOnnxOnlineRecognizerConfig config;
   config.model_config.tokens = argv[1];
   config.model_config.encoder = argv[2];
   config.model_config.decoder = argv[3];
   config.model_config.joiner = argv[4];

   if (argc == 7 && atoi(argv[6]) > 0) {
     num_threads = atoi(argv[6]);
   }
   config.model_config.num_threads = num_threads;
   config.model_config.debug = 0;

   config.feat_config.sample_rate = 16000;
   config.feat_config.feature_dim = 80;

+  config.decoding_method = "greedy_search";
+  if (argc == 8) {
+    config.decoding_method = argv[7];
+  }
+  config.max_active_paths = 4;
+
   config.enable_endpoint = 1;
   config.rule1_min_trailing_silence = 2.4;
   config.rule2_min_trailing_silence = 1.2;
   config.rule3_min_utterance_length = 300;

   SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
   SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);

+  SherpaOnnxDisplay *display = CreateDisplay(50);
+  int32_t segment_id = 0;
+
   if ((ret = open_input_file(argv[5])) < 0) exit(1);
   if ((ret = init_filters(filter_descr)) < 0) exit(1);

   /* read all packets */
   while (1) {
     if ((ret = av_read_frame(fmt_ctx, packet)) < 0) break;

     if (packet->stream_index == audio_stream_index) {
       ret = avcodec_send_packet(dec_ctx, packet);
       if (ret < 0) {
         av_log(NULL, AV_LOG_ERROR,
                "Error while sending a packet to the decoder\n");
         break;
       }

       while (ret >= 0) {
         ret = avcodec_receive_frame(dec_ctx, frame);
         if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
           break;
         } else if (ret < 0) {
           av_log(NULL, AV_LOG_ERROR,
                  "Error while receiving a frame from the decoder\n");
           exit(1);
         }

         if (ret >= 0) {
           /* push the audio data from decoded frame into the filtergraph */
           if (av_buffersrc_add_frame_flags(buffersrc_ctx, frame,
                                            AV_BUFFERSRC_FLAG_KEEP_REF) < 0) {
             av_log(NULL, AV_LOG_ERROR,
                    "Error while feeding the audio filtergraph\n");
             break;
           }

           /* pull filtered audio from the filtergraph */
           while (1) {
             ret = av_buffersink_get_frame(buffersink_ctx, filt_frame);
             if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) break;
             if (ret < 0) exit(1);
-            sherpa_decode_frame(filt_frame, recognizer, stream);
+            sherpa_decode_frame(filt_frame, recognizer, stream, display,
+                                &segment_id);
             av_frame_unref(filt_frame);
           }
           av_frame_unref(frame);
         }
       }
     }
     av_packet_unref(packet);
   }

   // add some tail padding
   float tail_paddings[4800] = {0};  // 0.3 seconds at 16 kHz sample rate
   AcceptWaveform(stream, 16000, tail_paddings, 4800);

   InputFinished(stream);
   while (IsOnlineStreamReady(recognizer, stream)) {
     DecodeOnlineStream(recognizer, stream);
   }

   SherpaOnnxOnlineRecognizerResult *r =
       GetOnlineStreamResult(recognizer, stream);
   if (strlen(r->text)) {
-    fprintf(stderr, "%s\n", r->text);
+    SherpaOnnxPrint(display, segment_id, r->text);
   }
   DestroyOnlineRecognizerResult(r);

+  DestroyDisplay(display);
   DestoryOnlineStream(stream);
   DestroyOnlineRecognizer(recognizer);

   avfilter_graph_free(&filter_graph);
   avcodec_free_context(&dec_ctx);
   avformat_close_input(&fmt_ctx);
   av_packet_free(&packet);
   av_frame_free(&frame);
   av_frame_free(&filt_frame);

   if (ret < 0 && ret != AVERROR_EOF) {
     fprintf(stderr, "Error occurred: %s\n", __av_err2str(ret));
     exit(1);
   }

+  fprintf(stderr, "\n");
+
   return 0;
 }

@@ -53,6 +53,20 @@ def get_args():
         help="Path to the joiner model",
     )

+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=1,
+        help="Number of threads for neural network computation",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="Valid values are greedy_search and modified_beam_search",
+    )
+
     parser.add_argument(
         "--wave-filename",
         type=str,

@@ -65,7 +79,6 @@ def get_args():

 def main():
     sample_rate = 16000
-    num_threads = 2

     args = get_args()
     assert_file_exists(args.encoder)

@@ -81,9 +94,10 @@ def main():
         encoder=args.encoder,
         decoder=args.decoder,
         joiner=args.joiner,
-        num_threads=num_threads,
+        num_threads=args.num_threads,
         sample_rate=sample_rate,
         feature_dim=80,
+        decoding_method=args.decoding_method,
     )
     with wave.open(args.wave_filename) as f:
         assert f.getframerate() == sample_rate, f.getframerate()

@@ -119,7 +133,8 @@ def main():
     end_time = time.time()
     elapsed_seconds = end_time - start_time
     rtf = elapsed_seconds / duration
-    print(f"num_threads: {num_threads}")
+    print(f"num_threads: {args.num_threads}")
+    print(f"decoding_method: {args.decoding_method}")
     print(f"Wave duration: {duration:.3f} s")
     print(f"Elapsed time: {elapsed_seconds:.3f} s")
     print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")

@@ -60,10 +60,10 @@ def get_args():
     )

     parser.add_argument(
-        "--wave-filename",
+        "--decoding-method",
         type=str,
-        help="""Path to the wave filename. Must be 16 kHz,
-        mono with 16-bit samples""",
+        default="greedy_search",
+        help="Valid values are greedy_search and modified_beam_search",
     )

     return parser.parse_args()

@@ -83,17 +83,23 @@ def create_recognizer():
         encoder=args.encoder,
         decoder=args.decoder,
         joiner=args.joiner,
+        num_threads=1,
+        sample_rate=16000,
+        feature_dim=80,
         enable_endpoint_detection=True,
         rule1_min_trailing_silence=2.4,
         rule2_min_trailing_silence=1.2,
         rule3_min_utterance_length=300,  # it essentially disables this rule
+        decoding_method=args.decoding_method,
+        max_feature_vectors=100,  # 1 second
     )
     return recognizer


 def main():
-    print("Started! Please speak")
     recognizer = create_recognizer()
+    print("Started! Please speak")
     sample_rate = 16000
     samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
     last_result = ""

@@ -101,6 +107,7 @@ def main():
     last_result = ""
     segment_id = 0
+    display = sherpa_onnx.Display(max_word_per_line=30)

     with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
         while True:
             samples, _ = s.read(samples_per_read)  # a blocking read

@@ -115,7 +122,7 @@ def main():

             if result and (last_result != result):
                 last_result = result
-                print(f"{segment_id}: {result}")
+                display.print(segment_id, result)

             if is_endpoint:
                 if result:

@@ -59,10 +59,10 @@ def get_args():
     )

     parser.add_argument(
-        "--wave-filename",
+        "--decoding-method",
         type=str,
-        help="""Path to the wave filename. Must be 16 kHz,
-        mono with 16-bit samples""",
+        default="greedy_search",
+        help="Valid values are greedy_search and modified_beam_search",
     )

     return parser.parse_args()

@@ -82,9 +82,11 @@ def create_recognizer():
         encoder=args.encoder,
         decoder=args.decoder,
         joiner=args.joiner,
-        num_threads=4,
+        num_threads=1,
         sample_rate=16000,
         feature_dim=80,
+        decoding_method=args.decoding_method,
+        max_feature_vectors=100,  # 1 second
     )
     return recognizer

@@ -96,6 +98,7 @@ def main():
     samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
     last_result = ""
     stream = recognizer.create_stream()
+    display = sherpa_onnx.Display(max_word_per_line=40)

     with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
         while True:
             samples, _ = s.read(samples_per_read)  # a blocking read

@@ -106,7 +109,7 @@ def main():
             result = recognizer.get_result(stream)
             if last_result != result:
                 last_result = result
-                print(result)
+                display.print(-1, result)


 if __name__ == "__main__":

sherpa-onnx/c-api/c-api.cc

@@ -9,6 +9,7 @@
 #include <utility>
 #include <vector>

+#include "sherpa-onnx/csrc/display.h"
 #include "sherpa-onnx/csrc/online-recognizer.h"

 struct SherpaOnnxOnlineRecognizer {

@@ -21,6 +22,10 @@ struct SherpaOnnxOnlineStream {
       : impl(std::move(p)) {}
 };

+struct SherpaOnnxDisplay {
+  std::unique_ptr<sherpa_onnx::Display> impl;
+};
+
 SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
     const SherpaOnnxOnlineRecognizerConfig *config) {
   sherpa_onnx::OnlineRecognizerConfig recognizer_config;

@@ -37,6 +42,9 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
   recognizer_config.model_config.num_threads = config->model_config.num_threads;
   recognizer_config.model_config.debug = config->model_config.debug;

+  recognizer_config.decoding_method = config->decoding_method;
+  recognizer_config.max_active_paths = config->max_active_paths;
+
   recognizer_config.enable_endpoint = config->enable_endpoint;

   recognizer_config.endpoint_config.rule1.min_trailing_silence =

@@ -124,3 +132,15 @@ int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
                    SherpaOnnxOnlineStream *stream) {
   return recognizer->impl->IsEndpoint(stream->impl.get());
 }
+
+SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line) {
+  SherpaOnnxDisplay *ans = new SherpaOnnxDisplay;
+  ans->impl = std::make_unique<sherpa_onnx::Display>(max_word_per_line);
+  return ans;
+}
+
+void DestroyDisplay(SherpaOnnxDisplay *display) { delete display; }
+
+void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s) {
+  display->impl->Print(idx, s);
+}

sherpa-onnx/c-api/c-api.h

@@ -48,6 +48,13 @@ typedef struct SherpaOnnxOnlineRecognizerConfig {
   SherpaOnnxFeatureConfig feat_config;
   SherpaOnnxOnlineTransducerModelConfig model_config;

+  /// Possible values are: greedy_search, modified_beam_search
+  const char *decoding_method;
+
+  /// Used only when decoding_method is modified_beam_search
+  /// Example value: 4
+  int32_t max_active_paths;
+
   /// 0 to disable endpoint detection.
   /// A non-zero value to enable endpoint detection.
   int32_t enable_endpoint;

@@ -187,6 +194,18 @@ void InputFinished(SherpaOnnxOnlineStream *stream);
 int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
                    SherpaOnnxOnlineStream *stream);

+// for displaying results on Linux/macOS.
+typedef struct SherpaOnnxDisplay SherpaOnnxDisplay;
+
+/// Create a display object. Must be freed using DestroyDisplay to avoid
+/// memory leak.
+SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line);
+
+void DestroyDisplay(SherpaOnnxDisplay *display);
+
+/// Print the result.
+void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s);
+
 #ifdef __cplusplus
 } /* extern "C" */
 #endif

sherpa-onnx/csrc/CMakeLists.txt

@@ -9,10 +9,11 @@ set(sources
   online-lstm-transducer-model.cc
   online-recognizer.cc
   online-stream.cc
+  online-transducer-decoder.cc
   online-transducer-greedy-search-decoder.cc
   online-transducer-model-config.cc
-  online-transducer-modified-beam-search-decoder.cc
   online-transducer-model.cc
+  online-transducer-modified-beam-search-decoder.cc
   online-zipformer-transducer-model.cc
   onnx-utils.cc
   parse-options.cc

sherpa-onnx/csrc/display.h

@@ -12,9 +12,16 @@ namespace sherpa_onnx {

 class Display {
  public:
+  explicit Display(int32_t max_word_per_line = 60)
+      : max_word_per_line_(max_word_per_line) {}
+
   void Print(int32_t segment_id, const std::string &s) {
 #ifdef _MSC_VER
-    fprintf(stderr, "%d:%s\n", segment_id, s.c_str());
+    if (segment_id != -1) {
+      fprintf(stderr, "%d:%s\n", segment_id, s.c_str());
+    } else {
+      fprintf(stderr, "%s\n", s.c_str());
+    }
     return;
 #endif
     if (last_segment_ == segment_id) {

@@ -27,7 +34,9 @@ class Display {
       num_previous_lines_ = 0;
     }

-    fprintf(stderr, "\r%d:", segment_id);
+    if (segment_id != -1) {
+      fprintf(stderr, "\r%d:", segment_id);
+    }

     int32_t i = 0;
     for (size_t n = 0; n < s.size();) {

@@ -69,7 +78,7 @@ class Display {
   void GoUpOneLine() const { fprintf(stderr, "\033[1A\r"); }

  private:
-  int32_t max_word_per_line_ = 60;
+  int32_t max_word_per_line_;
   int32_t num_previous_lines_ = 0;
   int32_t last_segment_ = -1;
 };
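
The display now treats a segment id of -1 as "no segment": the `%d:` prefix is suppressed and the text is printed plainly, while a real id redraws the current segment's line in place (via the `\r` and cursor-up escapes above). A small sketch of that behavior through the C wrapper; the width 30 is arbitrary, and the redraw semantics are my reading of the escape sequences:

```c
#include "sherpa-onnx/c-api/c-api.h"

int main(void) {
  SherpaOnnxDisplay *display = CreateDisplay(30 /* max words per line */);

  // With a segment id >= 0, output is prefixed "0:" and the line is
  // redrawn in place as the partial result grows.
  SherpaOnnxPrint(display, 0, "hello");
  SherpaOnnxPrint(display, 0, "hello world");  // overwrites the previous line

  // A new segment id starts a fresh line.
  SherpaOnnxPrint(display, 1, "next utterance");

  // idx == -1 prints the bare text without a segment prefix.
  SherpaOnnxPrint(display, -1, "no segment prefix");

  DestroyDisplay(display);
  return 0;
}
```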

sherpa-onnx/csrc/features.cc

@@ -28,7 +28,8 @@ std::string FeatureExtractorConfig::ToString() const {
   os << "FeatureExtractorConfig(";
   os << "sampling_rate=" << sampling_rate << ", ";
-  os << "feature_dim=" << feature_dim << ")";
+  os << "feature_dim=" << feature_dim << ", ";
+  os << "max_feature_vectors=" << max_feature_vectors << ")";

   return os.str();
 }

@@ -40,9 +41,7 @@ class FeatureExtractor::Impl {
     opts_.frame_opts.snip_edges = false;
     opts_.frame_opts.samp_freq = config.sampling_rate;

-    // cache 100 seconds of feature frames, which is more than enough
-    // for real needs
-    opts_.frame_opts.max_feature_vectors = 100 * 100;
+    opts_.frame_opts.max_feature_vectors = config.max_feature_vectors;

     opts_.mel_opts.num_bins = config.feature_dim;

sherpa-onnx/csrc/features.h

@@ -16,6 +16,7 @@ namespace sherpa_onnx {
 struct FeatureExtractorConfig {
   float sampling_rate = 16000;
   int32_t feature_dim = 80;
+  int32_t max_feature_vectors = -1;

   std::string ToString() const;

sherpa-onnx/csrc/hypothesis.h

@@ -18,7 +18,7 @@ namespace sherpa_onnx {

 struct Hypothesis {
   // The predicted tokens so far. Newly predicated tokens are appended.
-  std::vector<int32_t> ys;
+  std::vector<int64_t> ys;

   // timestamps[i] contains the frame number after subsampling
   // on which ys[i] is decoded.

@@ -30,7 +30,7 @@ struct Hypothesis {
   int32_t num_trailing_blanks = 0;

   Hypothesis() = default;
-  Hypothesis(const std::vector<int32_t> &ys, double log_prob)
+  Hypothesis(const std::vector<int64_t> &ys, double log_prob)
       : ys(ys), log_prob(log_prob) {}

   // If two Hypotheses have the same `Key`, then they contain

sherpa-onnx/csrc/online-recognizer.cc

@@ -43,7 +43,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
                "True to enable endpoint detection. False to disable it.");
   po->Register("max-active-paths", &max_active_paths,
                "beam size used in modified beam search.");
-  po->Register("decoding-mothod", &decoding_method,
+  po->Register("decoding-method", &decoding_method,
                "decoding method,"
                "now support greedy_search and modified_beam_search.");
 }

@@ -59,8 +59,8 @@ std::string OnlineRecognizerConfig::ToString() const {
   os << "feat_config=" << feat_config.ToString() << ", ";
   os << "model_config=" << model_config.ToString() << ", ";
   os << "endpoint_config=" << endpoint_config.ToString() << ", ";
-  os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ",";
-  os << "max_active_paths=" << max_active_paths << ",";
+  os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
+  os << "max_active_paths=" << max_active_paths << ", ";
   os << "decoding_method=\"" << decoding_method << "\")";

   return os.str();

@@ -187,16 +187,14 @@ class OnlineRecognizer::Impl {
   }

   void Reset(OnlineStream *s) const {
-    // reset result, neural network model state, and
-    // the feature extractor state
+    // we keep the decoder_out
+    decoder_->UpdateDecoderOut(&s->GetResult());
+    Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
+
+    // reset result
     s->SetResult(decoder_->GetEmptyResult());
+    s->GetResult().decoder_out = std::move(decoder_out);

-    // reset neural network model state
-    s->SetStates(model_->GetEncoderInitStates());
-
-    // reset feature extractor
+    // Note: We only update counters. The underlying audio samples
+    // are not discarded.
     s->Reset();
   }
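
After this change, Reset() on an endpoint no longer re-initializes the encoder states or the feature extractor; it preserves the cached decoder output and swaps in an empty result, so resets are cheap and no buffered audio is lost. The caller-side loop is unchanged. A hedged sketch using the C API, mirroring the updated examples above (decode_pending is a made-up helper name):

```c
#include <stdint.h>
#include <string.h>

#include "sherpa-onnx/c-api/c-api.h"

// Drain pending frames and handle endpoints; mirrors the updated C example.
static void decode_pending(SherpaOnnxOnlineRecognizer *recognizer,
                           SherpaOnnxOnlineStream *stream,
                           SherpaOnnxDisplay *display, int32_t *segment_id) {
  while (IsOnlineStreamReady(recognizer, stream)) {
    DecodeOnlineStream(recognizer, stream);
  }

  SherpaOnnxOnlineRecognizerResult *r =
      GetOnlineStreamResult(recognizer, stream);
  if (strlen(r->text)) {
    SherpaOnnxPrint(display, *segment_id, r->text);
  }

  if (IsEndpoint(recognizer, stream)) {
    if (strlen(r->text)) {
      ++*segment_id;  // the next partial result starts a new display segment
    }
    Reset(recognizer, stream);  // cheap now: no model/feature re-initialization
  }
  DestroyOnlineRecognizerResult(r);
}
```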

sherpa-onnx/csrc/online-recognizer.h

@@ -33,21 +33,26 @@ struct OnlineRecognizerConfig {
   OnlineTransducerModelConfig model_config;
   EndpointConfig endpoint_config;
   bool enable_endpoint = true;
+
+  int32_t max_active_paths = 4;

-  std::string decoding_method = "modified_beam_search";
+  std::string decoding_method = "greedy_search";
   // now support modified_beam_search and greedy_search
-  int32_t max_active_paths = 4;
+  // used only for modified_beam_search

   OnlineRecognizerConfig() = default;

   OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
                          const OnlineTransducerModelConfig &model_config,
                          const EndpointConfig &endpoint_config,
-                         bool enable_endpoint)
+                         bool enable_endpoint,
+                         const std::string &decoding_method,
+                         int32_t max_active_paths)
       : feat_config(feat_config),
         model_config(model_config),
         endpoint_config(endpoint_config),
-        enable_endpoint(enable_endpoint) {}
+        enable_endpoint(enable_endpoint),
+        decoding_method(decoding_method),
+        max_active_paths(max_active_paths) {}

   void Register(ParseOptions *po);
   bool Validate() const;

sherpa-onnx/csrc/online-stream.cc

@@ -22,18 +22,21 @@ class OnlineStream::Impl {

   void InputFinished() { feat_extractor_.InputFinished(); }

-  int32_t NumFramesReady() const { return feat_extractor_.NumFramesReady(); }
+  int32_t NumFramesReady() const {
+    return feat_extractor_.NumFramesReady() - start_frame_index_;
+  }

   bool IsLastFrame(int32_t frame) const {
     return feat_extractor_.IsLastFrame(frame);
   }

   std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
-    return feat_extractor_.GetFrames(frame_index, n);
+    return feat_extractor_.GetFrames(frame_index + start_frame_index_, n);
   }

   void Reset() {
-    feat_extractor_.Reset();
+    // we don't reset the feature extractor
+    start_frame_index_ += num_processed_frames_;
     num_processed_frames_ = 0;
   }

@@ -41,7 +44,7 @@ class OnlineStream::Impl {

   void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }

-  const OnlineTransducerDecoderResult &GetResult() const { return result_; }
+  OnlineTransducerDecoderResult &GetResult() { return result_; }

   int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }

@@ -54,6 +57,7 @@ class OnlineStream::Impl {
  private:
   FeatureExtractor feat_extractor_;
   int32_t num_processed_frames_ = 0;  // before subsampling
+  int32_t start_frame_index_ = 0;     // never reset
   OnlineTransducerDecoderResult result_;
   std::vector<Ort::Value> states_;
 };

@@ -93,7 +97,7 @@ void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) {
   impl_->SetResult(r);
 }

-const OnlineTransducerDecoderResult &OnlineStream::GetResult() const {
+OnlineTransducerDecoderResult &OnlineStream::GetResult() {
   return impl_->GetResult();
 }
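
The feature extractor is likewise never cleared: Reset() just re-bases frame indices by the number of frames the finished segment consumed, so frames buffered past the endpoint stay available to the next segment. A self-contained toy model of that bookkeeping (hypothetical names, not the real classes):

```c
#include <stdint.h>
#include <stdio.h>

// Toy model of the stream-reset trick: the frame buffer is never cleared;
// reset() just moves the logical origin of the current segment.
typedef struct {
  int32_t total_frames;      // frames the extractor has produced (never reset)
  int32_t start_frame;       // logical origin of the current segment
  int32_t processed_frames;  // frames of the current segment already decoded
} Stream;

static int32_t frames_ready(const Stream *s) {
  return s->total_frames - s->start_frame;
}

static int32_t physical_index(const Stream *s, int32_t segment_index) {
  return segment_index + s->start_frame;  // what GetFrames() actually reads
}

static void reset(Stream *s) {
  s->start_frame += s->processed_frames;  // re-base instead of discarding
  s->processed_frames = 0;
}

int main(void) {
  Stream s = {100, 0, 80};  // 100 frames produced, 80 decoded, endpoint fires
  reset(&s);
  // The 20 unconsumed frames are still visible to the next segment:
  printf("ready=%d first=%d\n", (int)frames_ready(&s),
         (int)physical_index(&s, 0));  // -> ready=20 first=80
  return 0;
}
```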

sherpa-onnx/csrc/online-stream.h

@@ -63,7 +63,7 @@ class OnlineStream {
   int32_t &GetNumProcessedFrames();

   void SetResult(const OnlineTransducerDecoderResult &r);
-  const OnlineTransducerDecoderResult &GetResult() const;
+  OnlineTransducerDecoderResult &GetResult();

   void SetStates(std::vector<Ort::Value> states);
   std::vector<Ort::Value> &GetStates();

sherpa-onnx/csrc/online-transducer-decoder.cc

@@ -0,0 +1,60 @@
+// sherpa-onnx/csrc/online-transducer-decoder.cc
+//
+// Copyright (c)  2023  Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
+
+#include <utility>
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
+namespace sherpa_onnx {
+
+OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
+    const OnlineTransducerDecoderResult &other)
+    : OnlineTransducerDecoderResult() {
+  *this = other;
+}
+
+OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
+    const OnlineTransducerDecoderResult &other) {
+  if (this == &other) {
+    return *this;
+  }
+
+  tokens = other.tokens;
+  num_trailing_blanks = other.num_trailing_blanks;
+
+  Ort::AllocatorWithDefaultOptions allocator;
+  if (other.decoder_out) {
+    decoder_out = Clone(allocator, &other.decoder_out);
+  }
+
+  hyps = other.hyps;
+
+  return *this;
+}
+
+OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
+    OnlineTransducerDecoderResult &&other)
+    : OnlineTransducerDecoderResult() {
+  *this = std::move(other);
+}
+
+OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
+    OnlineTransducerDecoderResult &&other) {
+  if (this == &other) {
+    return *this;
+  }
+
+  tokens = std::move(other.tokens);
+  num_trailing_blanks = other.num_trailing_blanks;
+  decoder_out = std::move(other.decoder_out);
+  hyps = std::move(other.hyps);
+
+  return *this;
+}
+
+}  // namespace sherpa_onnx

sherpa-onnx/csrc/online-transducer-decoder.h

@@ -19,8 +19,24 @@ struct OnlineTransducerDecoderResult {
   /// number of trailing blank frames decoded so far
   int32_t num_trailing_blanks = 0;

+  // Cache decoder_out for endpointing
+  Ort::Value decoder_out;
+
   // used only in modified beam_search
   Hypotheses hyps;
+
+  OnlineTransducerDecoderResult()
+      : tokens{}, num_trailing_blanks(0), decoder_out{nullptr}, hyps{} {}
+
+  OnlineTransducerDecoderResult(const OnlineTransducerDecoderResult &other);
+
+  OnlineTransducerDecoderResult &operator=(
+      const OnlineTransducerDecoderResult &other);
+
+  OnlineTransducerDecoderResult(OnlineTransducerDecoderResult &&other);
+
+  OnlineTransducerDecoderResult &operator=(
+      OnlineTransducerDecoderResult &&other);
 };

 class OnlineTransducerDecoder {

@@ -53,6 +69,9 @@ class OnlineTransducerDecoder {
    */
   virtual void Decode(Ort::Value encoder_out,
                       std::vector<OnlineTransducerDecoderResult> *result) = 0;
+
+  // used for endpointing. We need to keep decoder_out after reset
+  virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
 };

 }  // namespace sherpa_onnx

sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc

@@ -13,6 +13,43 @@

 namespace sherpa_onnx {

+static void UseCachedDecoderOut(
+    const std::vector<OnlineTransducerDecoderResult> &results,
+    Ort::Value *decoder_out) {
+  std::vector<int64_t> shape =
+      decoder_out->GetTensorTypeAndShapeInfo().GetShape();
+
+  float *dst = decoder_out->GetTensorMutableData<float>();
+  for (const auto &r : results) {
+    if (r.decoder_out) {
+      const float *src = r.decoder_out.GetTensorData<float>();
+      std::copy(src, src + shape[1], dst);
+    }
+    dst += shape[1];
+  }
+}
+
+static void UpdateCachedDecoderOut(
+    OrtAllocator *allocator, const Ort::Value *decoder_out,
+    std::vector<OnlineTransducerDecoderResult> *results) {
+  std::vector<int64_t> shape =
+      decoder_out->GetTensorTypeAndShapeInfo().GetShape();
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+  std::array<int64_t, 2> v_shape{1, shape[1]};
+
+  const float *src = decoder_out->GetTensorData<float>();
+  for (auto &r : *results) {
+    if (!r.decoder_out) {
+      r.decoder_out = Ort::Value::CreateTensor<float>(allocator, v_shape.data(),
+                                                      v_shape.size());
+    }
+
+    float *dst = r.decoder_out.GetTensorMutableData<float>();
+    std::copy(src, src + shape[1], dst);
+    src += shape[1];
+  }
+}
+
 OnlineTransducerDecoderResult
 OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
   int32_t context_size = model_->ContextSize();

@@ -53,6 +90,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(

   Ort::Value decoder_input = model_->BuildDecoderInput(*result);
   Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
+  UseCachedDecoderOut(*result, &decoder_out);

   for (int32_t t = 0; t != num_frames; ++t) {
     Ort::Value cur_encoder_out =

@@ -77,10 +115,12 @@ void OnlineTransducerGreedySearchDecoder::Decode(
       }
     }
     if (emitted) {
-      decoder_input = model_->BuildDecoderInput(*result);
+      Ort::Value decoder_input = model_->BuildDecoderInput(*result);
       decoder_out = model_->RunDecoder(std::move(decoder_input));
     }
   }
+
+  UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result);
 }

 }  // namespace sherpa_onnx

sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc

@@ -13,6 +13,29 @@

 namespace sherpa_onnx {

+static void UseCachedDecoderOut(
+    const std::vector<int32_t> &hyps_num_split,
+    const std::vector<OnlineTransducerDecoderResult> &results,
+    int32_t context_size, Ort::Value *decoder_out) {
+  std::vector<int64_t> shape =
+      decoder_out->GetTensorTypeAndShapeInfo().GetShape();
+
+  float *dst = decoder_out->GetTensorMutableData<float>();
+
+  int32_t batch_size = static_cast<int32_t>(results.size());
+  for (int32_t i = 0; i != batch_size; ++i) {
+    int32_t num_hyps = hyps_num_split[i + 1] - hyps_num_split[i];
+    if (num_hyps > 1 || !results[i].decoder_out) {
+      dst += num_hyps * shape[1];
+      continue;
+    }
+
+    const float *src = results[i].decoder_out.GetTensorData<float>();
+    std::copy(src, src + shape[1], dst);
+    dst += shape[1];
+  }
+}
+
 static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
                          const std::vector<int32_t> &hyps_num_split) {
   std::vector<int64_t> cur_encoder_out_shape =

@@ -50,7 +73,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
   int32_t context_size = model_->ContextSize();
   int32_t blank_id = 0;  // always 0
   OnlineTransducerDecoderResult r;
-  std::vector<int32_t> blanks(context_size, blank_id);
+  std::vector<int64_t> blanks(context_size, blank_id);
   Hypotheses blank_hyp({{blanks, 0}});
   r.hyps = std::move(blank_hyp);
   return r;

@@ -110,6 +133,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(

     Ort::Value decoder_input = model_->BuildDecoderInput(prev);
     Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
+    if (t == 0) {
+      UseCachedDecoderOut(hyps_num_split, *result, model_->ContextSize(),
+                          &decoder_out);
+    }

     Ort::Value cur_encoder_out =
         GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);

@@ -147,8 +174,23 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
     }

     for (int32_t b = 0; b != batch_size; ++b) {
-      (*result)[b].hyps = std::move(cur[b]);
+      auto &hyps = cur[b];
+      auto best_hyp = hyps.GetMostProbable(true);
+      (*result)[b].hyps = std::move(hyps);
+      (*result)[b].tokens = std::move(best_hyp.ys);
+      (*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks;
     }
   }
 }
+
+void OnlineTransducerModifiedBeamSearchDecoder::UpdateDecoderOut(
+    OnlineTransducerDecoderResult *result) {
+  if (result->tokens.size() == model_->ContextSize()) {
+    result->decoder_out = Ort::Value{nullptr};
+    return;
+  }
+  Ort::Value decoder_input = model_->BuildDecoderInput({*result});
+  result->decoder_out = model_->RunDecoder(std::move(decoder_input));
+}
+
 }  // namespace sherpa_onnx

sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h

@@ -27,6 +27,8 @@ class OnlineTransducerModifiedBeamSearchDecoder
   void Decode(Ort::Value encoder_out,
               std::vector<OnlineTransducerDecoderResult> *result) override;

+  void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;
+
  private:
   OnlineTransducerModel *model_;  // Not owned
   int32_t max_active_paths_;

sherpa-onnx/csrc/sherpa-onnx-alsa.cc

@@ -21,7 +21,7 @@ static void Handler(int sig) {
 }

 int main(int32_t argc, char *argv[]) {
-  if (argc < 6 || argc > 7) {
+  if (argc < 6 || argc > 8) {
     const char *usage = R"usage(
 Usage:
   ./bin/sherpa-onnx-alsa \

@@ -30,7 +30,10 @@ Usage:
     /path/to/decoder.onnx \
     /path/to/joiner.onnx \
     device_name \
-    [num_threads]
+    [num_threads [decoding_method]]
+
+Default value for num_threads is 2.
+Valid values for decoding_method: greedy_search (default), modified_beam_search.

 Please refer to
 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html

@@ -79,6 +82,11 @@ as the device_name.
     config.model_config.num_threads = atoi(argv[6]);
   }

+  if (argc == 8) {
+    config.decoding_method = argv[7];
+  }
+  config.max_active_paths = 4;
+
   config.enable_endpoint = true;

   config.endpoint_config.rule1.min_trailing_silence = 2.4;

sherpa-onnx/csrc/sherpa-onnx-microphone.cc

@@ -36,7 +36,7 @@ static void Handler(int32_t sig) {
 }

 int32_t main(int32_t argc, char *argv[]) {
-  if (argc < 5 || argc > 6) {
+  if (argc < 5 || argc > 7) {
     const char *usage = R"usage(
 Usage:
   ./bin/sherpa-onnx-microphone \

@@ -44,7 +44,10 @@ Usage:
     /path/to/encoder.onnx\
     /path/to/decoder.onnx\
     /path/to/joiner.onnx\
-    [num_threads]
+    [num_threads [decoding_method]]
+
+Default value for num_threads is 2.
+Valid values for decoding_method: greedy_search (default), modified_beam_search.

 Please refer to
 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html

@@ -70,6 +73,11 @@ for a list of pre-trained models to download.
     config.model_config.num_threads = atoi(argv[5]);
   }

+  if (argc == 7) {
+    config.decoding_method = argv[6];
+  }
+  config.max_active_paths = 4;
+
   config.enable_endpoint = true;

   config.endpoint_config.rule1.min_trailing_silence = 2.4;

sherpa-onnx/csrc/sherpa-onnx.cc

@@ -14,7 +14,7 @@
 #include "sherpa-onnx/csrc/wave-reader.h"

 int main(int32_t argc, char *argv[]) {
-  if (argc < 6 || argc > 7) {
+  if (argc < 6 || argc > 8) {
     const char *usage = R"usage(
 Usage:
   ./bin/sherpa-onnx \

@@ -22,7 +22,10 @@ Usage:
     /path/to/encoder.onnx \
     /path/to/decoder.onnx \
     /path/to/joiner.onnx \
-    /path/to/foo.wav [num_threads]
+    /path/to/foo.wav [num_threads [decoding_method]]
+
+Default value for num_threads is 2.
+Valid values for decoding_method: greedy_search (default), modified_beam_search.

 Please refer to
 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html

@@ -45,9 +48,15 @@ for a list of pre-trained models to download.
   std::string wav_filename = argv[5];

   config.model_config.num_threads = 2;
-  if (argc == 7) {
+  if (argc == 7 && atoi(argv[6]) > 0) {
     config.model_config.num_threads = atoi(argv[6]);
   }

+  if (argc == 8) {
+    config.decoding_method = argv[7];
+  }
+  config.max_active_paths = 4;
+
   fprintf(stderr, "%s\n", config.ToString().c_str());

   sherpa_onnx::OnlineRecognizer recognizer(config);

@@ -98,6 +107,7 @@ for a list of pre-trained models to download.
       1000.;

   fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
+  fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
   fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
   float rtf = elapsed_seconds / duration;

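For a quick sanity check of the new log line: the real-time factor is elapsed_seconds / duration, so decoding a 6.0 s wave file in 1.5 s yields an RTF of 0.25, i.e. four times faster than real time.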
View File

@@ -1,12 +1,13 @@
include_directories(${CMAKE_SOURCE_DIR}) include_directories(${CMAKE_SOURCE_DIR})
pybind11_add_module(_sherpa_onnx pybind11_add_module(_sherpa_onnx
display.cc
endpoint.cc
features.cc features.cc
online-recognizer.cc
online-stream.cc
online-transducer-model-config.cc online-transducer-model-config.cc
sherpa-onnx.cc sherpa-onnx.cc
endpoint.cc
online-stream.cc
online-recognizer.cc
) )
if(APPLE) if(APPLE)

View File

@@ -0,0 +1,18 @@
// sherpa-onnx/python/csrc/display.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/display.h"
#include "sherpa-onnx/csrc/display.h"
namespace sherpa_onnx {
void PybindDisplay(py::module *m) {
using PyClass = Display;
py::class_<PyClass>(*m, "Display")
.def(py::init<int32_t>(), py::arg("max_word_per_line") = 60)
.def("print", &PyClass::Print, py::arg("idx"), py::arg("s"));
}
} // namespace sherpa_onnx

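Once these bindings are in place, the class should be usable from Python roughly as follows (a minimal sketch; the printed text is made up):

from sherpa_onnx import Display

display = Display(max_word_per_line=60)
display.print(0, "this text belongs to segment 0")
display.print(0, "this text belongs to segment 0 and keeps growing")
display.print(1, "a new segment starts after an endpoint")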
View File

@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/display.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
#define SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindDisplay(py::module *m);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_

View File

@@ -11,10 +11,12 @@ namespace sherpa_onnx {
static void PybindFeatureExtractorConfig(py::module *m) { static void PybindFeatureExtractorConfig(py::module *m) {
using PyClass = FeatureExtractorConfig; using PyClass = FeatureExtractorConfig;
py::class_<PyClass>(*m, "FeatureExtractorConfig") py::class_<PyClass>(*m, "FeatureExtractorConfig")
.def(py::init<float, int32_t>(), py::arg("sampling_rate") = 16000, .def(py::init<float, int32_t, int32_t>(),
py::arg("feature_dim") = 80) py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,
py::arg("max_feature_vectors") = -1)
.def_readwrite("sampling_rate", &PyClass::sampling_rate) .def_readwrite("sampling_rate", &PyClass::sampling_rate)
.def_readwrite("feature_dim", &PyClass::feature_dim) .def_readwrite("feature_dim", &PyClass::feature_dim)
.def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
.def("__str__", &PyClass::ToString); .def("__str__", &PyClass::ToString);
} }

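A small sketch of constructing the extended config from Python (note that after this commit the class is reachable via the private _sherpa_onnx module rather than the package root):

from _sherpa_onnx import FeatureExtractorConfig

config = FeatureExtractorConfig(
    sampling_rate=16000,
    feature_dim=80,
    max_feature_vectors=-1,  # -1: keep all processed frames
)
print(config)  # __str__ is bound to ToString()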
View File

@@ -22,13 +22,16 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::class_<PyClass>(*m, "OnlineRecognizerConfig") py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(py::init<const FeatureExtractorConfig &, .def(py::init<const FeatureExtractorConfig &,
const OnlineTransducerModelConfig &, const EndpointConfig &, const OnlineTransducerModelConfig &, const EndpointConfig &,
bool>(), bool, const std::string &, int32_t>(),
py::arg("feat_config"), py::arg("model_config"), py::arg("feat_config"), py::arg("model_config"),
py::arg("endpoint_config"), py::arg("enable_endpoint")) py::arg("endpoint_config"), py::arg("enable_endpoint"),
py::arg("decoding_method"), py::arg("max_active_paths"))
.def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config) .def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("endpoint_config", &PyClass::endpoint_config) .def_readwrite("endpoint_config", &PyClass::endpoint_config)
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint) .def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
.def_readwrite("decoding_method", &PyClass::decoding_method)
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def("__str__", &PyClass::ToString); .def("__str__", &PyClass::ToString);
} }

View File

@@ -4,6 +4,7 @@
#include "sherpa-onnx/python/csrc/sherpa-onnx.h" #include "sherpa-onnx/python/csrc/sherpa-onnx.h"
#include "sherpa-onnx/python/csrc/display.h"
#include "sherpa-onnx/python/csrc/endpoint.h" #include "sherpa-onnx/python/csrc/endpoint.h"
#include "sherpa-onnx/python/csrc/features.h" #include "sherpa-onnx/python/csrc/features.h"
#include "sherpa-onnx/python/csrc/online-recognizer.h" #include "sherpa-onnx/python/csrc/online-recognizer.h"
@@ -19,6 +20,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindOnlineStream(&m); PybindOnlineStream(&m);
PybindEndpoint(&m); PybindEndpoint(&m);
PybindOnlineRecognizer(&m); PybindOnlineRecognizer(&m);
PybindDisplay(&m);
} }
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -1,9 +1,3 @@
from _sherpa_onnx import ( from _sherpa_onnx import Display
EndpointConfig,
FeatureExtractorConfig,
OnlineRecognizerConfig,
OnlineStream,
OnlineTransducerModelConfig,
)
from .online_recognizer import OnlineRecognizer from .online_recognizer import OnlineRecognizer

View File

@@ -32,6 +32,9 @@ class OnlineRecognizer(object):
rule1_min_trailing_silence: float = 2.4, rule1_min_trailing_silence: float = 2.4,
rule2_min_trailing_silence: float = 1.2, rule2_min_trailing_silence: float = 1.2,
rule3_min_utterance_length: float = 20, rule3_min_utterance_length: float = 20,
decoding_method: str = "greedy_search",
max_active_paths: int = 4,
max_feature_vectors: int = -1,
): ):
""" """
Please refer to Please refer to
@@ -74,6 +77,14 @@ class OnlineRecognizer(object):
Used only when enable_endpoint_detection is True. If the utterance Used only when enable_endpoint_detection is True. If the utterance
length in seconds is larger than this value, we assume an endpoint length in seconds is larger than this value, we assume an endpoint
is detected. is detected.
decoding_method:
Valid values are greedy_search (default) and modified_beam_search.
max_active_paths:
Used only when decoding_method is modified_beam_search. It specifies
the maximum number of active paths during beam search.
max_feature_vectors:
Number of feature vectors (frames) to cache. -1 means to cache
all feature vectors that have been computed.
""" """
_assert_file_exists(tokens) _assert_file_exists(tokens)
_assert_file_exists(encoder) _assert_file_exists(encoder)
@@ -93,6 +104,7 @@ class OnlineRecognizer(object):
feat_config = FeatureExtractorConfig( feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate, sampling_rate=sample_rate,
feature_dim=feature_dim, feature_dim=feature_dim,
max_feature_vectors=max_feature_vectors,
) )
endpoint_config = EndpointConfig( endpoint_config = EndpointConfig(
@@ -106,6 +118,8 @@ class OnlineRecognizer(object):
model_config=model_config, model_config=model_config,
endpoint_config=endpoint_config, endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection, enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method,
max_active_paths=max_active_paths,
) )
self.recognizer = _Recognizer(recognizer_config) self.recognizer = _Recognizer(recognizer_config)
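Taken together, the new keyword arguments let a Python caller opt into beam search without touching the C++ layer. A minimal sketch (file paths are placeholders, and the keyword names for the model files are inferred from the _assert_file_exists calls above):

from sherpa_onnx import OnlineRecognizer

recognizer = OnlineRecognizer(
    tokens="./tokens.txt",
    encoder="./encoder.onnx",
    decoder="./decoder.onnx",
    joiner="./joiner.onnx",
    decoding_method="modified_beam_search",  # new in this commit
    max_active_paths=4,
    max_feature_vectors=-1,
)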