260 lines
7.3 KiB
C++
260 lines
7.3 KiB
C++
// c-api-examples/asr-microphone-example/c-api-alsa.cc
|
|
// Copyright (c) 2022-2024 Xiaomi Corporation
|
|
|
|
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <algorithm>
#include <cctype>  // std::tolower
#include <cstdint>
#include <string>
#include <vector>

#include "c-api-examples/asr-microphone-example/alsa.h"

// NOTE: You don't need to use cargs.h in your own project.
// We use it in this file to parse commandline arguments
#include "cargs.h"  // NOLINT
#include "sherpa-onnx/c-api/c-api.h"
|
|
|
|
// Command-line options understood by this example, parsed with cargs.
// cag_option fields are initialized positionally (identifier, access_letters,
// access_name, value_name, description); the /*.field =*/ comments document
// which field each value fills.
static struct cag_option options[] = {
    {/*.identifier =*/'h',
     /*.access_letters =*/"h",
     /*.access_name =*/"help",
     /*.value_name =*/"help",
     /*.description =*/"Show help"},
    {/*.identifier =*/'t',
     /*.access_letters =*/NULL,
     /*.access_name =*/"tokens",
     /*.value_name =*/"tokens",
     /*.description =*/"Tokens file"},
    {/*.identifier =*/'e',
     /*.access_letters =*/NULL,
     /*.access_name =*/"encoder",
     /*.value_name =*/"encoder",
     /*.description =*/"Encoder ONNX file"},
    {/*.identifier =*/'d',
     /*.access_letters =*/NULL,
     /*.access_name =*/"decoder",
     /*.value_name =*/"decoder",
     /*.description =*/"Decoder ONNX file"},
    {/*.identifier =*/'j',
     /*.access_letters =*/NULL,
     /*.access_name =*/"joiner",
     /*.value_name =*/"joiner",
     /*.description =*/"Joiner ONNX file"},
    {/*.identifier =*/'n',
     /*.access_letters =*/NULL,
     /*.access_name =*/"num-threads",
     /*.value_name =*/"num-threads",
     /*.description =*/"Number of threads"},
    {/*.identifier =*/'p',
     /*.access_letters =*/NULL,
     /*.access_name =*/"provider",
     /*.value_name =*/"provider",
     /*.description =*/"Provider: cpu (default), cuda, coreml"},
    {/*.identifier =*/'m',
     /*.access_letters =*/NULL,
     /*.access_name =*/"decoding-method",
     /*.value_name =*/"decoding-method",
     /*.description =*/
     "Decoding method: greedy_search (default), modified_beam_search"},
    {/*.identifier =*/'f',
     /*.access_letters =*/NULL,
     /*.access_name =*/"hotwords-file",
     /*.value_name =*/"hotwords-file",
     /*.description =*/
     "The file containing hotwords, one words/phrases per line, and for each "
     "phrase the bpe/cjkchar are separated by a space. For example: ▁HE LL O "
     "▁WORLD, 你 好 世 界"},
    {/*.identifier =*/'s',
     /*.access_letters =*/NULL,
     /*.access_name =*/"hotwords-score",
     /*.value_name =*/"hotwords-score",
     /*.description =*/
     "The bonus score for each token in hotwords. Used only when "
     "decoding_method is modified_beam_search"},
};
|
|
|
|
// Help text printed for --help or when too few arguments are given.
// Note: the --joiner example previously pointed at decoder.onnx by mistake;
// it now correctly shows joiner.onnx.
const char *kUsage =
    R"(
Usage:
  ./bin/c-api-alsa \
    --tokens=/path/to/tokens.txt \
    --encoder=/path/to/encoder.onnx \
    --decoder=/path/to/decoder.onnx \
    --joiner=/path/to/joiner.onnx \
    device_name

The device name specifies which microphone to use in case there are several
on your system. You can use

  arecord -l

to find all available microphones on your computer. For instance, if it outputs

**** List of CAPTURE Hardware Devices ****
card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio]
  Subdevices: 1/1
  Subdevice #0: subdevice #0

and if you want to select card 3 and device 0 on that card, please use:

  plughw:3,0

as the device_name.
)";
|
|
|
|
// Set to a nonzero value by the SIGINT handler to request a clean shutdown.
// volatile sig_atomic_t is the only non-atomic object type the C/C++
// standards guarantee may be written from an asynchronous signal handler;
// a plain bool here would be undefined behavior.
static volatile sig_atomic_t stop = 0;

// SIGINT (Ctrl + C) handler: records the stop request and notifies the user.
//
// Only async-signal-safe functions may be called from a signal handler
// (CERT SIG30-C), so the message is emitted with write(2) rather than
// fprintf(3), which may take locks or allocate.
static void Handler(int /*sig*/) {
  stop = 1;
  static const char kMsg[] = "\nCaught Ctrl + C. Exiting...\n";
  ssize_t unused = write(STDERR_FILENO, kMsg, sizeof(kMsg) - 1);
  (void)unused;  // best-effort message; nothing sensible to do on failure
}
|
|
|
|
int32_t main(int32_t argc, char *argv[]) {
|
|
if (argc < 6) {
|
|
fprintf(stderr, "%s\n", kUsage);
|
|
exit(0);
|
|
}
|
|
|
|
signal(SIGINT, Handler);
|
|
|
|
SherpaOnnxOnlineRecognizerConfig config;
|
|
memset(&config, 0, sizeof(config));
|
|
|
|
config.model_config.debug = 0;
|
|
config.model_config.num_threads = 1;
|
|
config.model_config.provider = "cpu";
|
|
|
|
config.decoding_method = "greedy_search";
|
|
|
|
config.max_active_paths = 4;
|
|
|
|
config.feat_config.sample_rate = 16000;
|
|
config.feat_config.feature_dim = 80;
|
|
|
|
config.enable_endpoint = 1;
|
|
config.rule1_min_trailing_silence = 2.4;
|
|
config.rule2_min_trailing_silence = 1.2;
|
|
config.rule3_min_utterance_length = 300;
|
|
|
|
cag_option_context context;
|
|
char identifier;
|
|
const char *value;
|
|
|
|
cag_option_prepare(&context, options, CAG_ARRAY_SIZE(options), argc, argv);
|
|
|
|
while (cag_option_fetch(&context)) {
|
|
identifier = cag_option_get(&context);
|
|
value = cag_option_get_value(&context);
|
|
switch (identifier) {
|
|
case 't':
|
|
config.model_config.tokens = value;
|
|
break;
|
|
case 'e':
|
|
config.model_config.transducer.encoder = value;
|
|
break;
|
|
case 'd':
|
|
config.model_config.transducer.decoder = value;
|
|
break;
|
|
case 'j':
|
|
config.model_config.transducer.joiner = value;
|
|
break;
|
|
case 'n':
|
|
config.model_config.num_threads = atoi(value);
|
|
break;
|
|
case 'p':
|
|
config.model_config.provider = value;
|
|
break;
|
|
case 'm':
|
|
config.decoding_method = value;
|
|
break;
|
|
case 'f':
|
|
config.hotwords_file = value;
|
|
break;
|
|
case 's':
|
|
config.hotwords_score = atof(value);
|
|
break;
|
|
case 'h': {
|
|
fprintf(stderr, "%s\n", kUsage);
|
|
exit(0);
|
|
break;
|
|
}
|
|
default:
|
|
// do nothing as config already has valid default values
|
|
break;
|
|
}
|
|
}
|
|
|
|
const SherpaOnnxOnlineRecognizer *recognizer =
|
|
SherpaOnnxCreateOnlineRecognizer(&config);
|
|
const SherpaOnnxOnlineStream *stream =
|
|
SherpaOnnxCreateOnlineStream(recognizer);
|
|
|
|
const SherpaOnnxDisplay *display = SherpaOnnxCreateDisplay(50);
|
|
int32_t segment_id = 0;
|
|
|
|
const char *device_name = argv[context.index];
|
|
sherpa_onnx::Alsa alsa(device_name);
|
|
fprintf(stderr, "Use recording device: %s\n", device_name);
|
|
fprintf(stderr,
|
|
"Please \033[32m\033[1mspeak\033[0m! Press \033[31m\033[1mCtrl + "
|
|
"C\033[0m to exit\n");
|
|
|
|
int32_t expected_sample_rate = 16000;
|
|
|
|
if (alsa.GetExpectedSampleRate() != expected_sample_rate) {
|
|
fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(),
|
|
expected_sample_rate);
|
|
exit(-1);
|
|
}
|
|
|
|
int32_t chunk = 0.1 * alsa.GetActualSampleRate();
|
|
|
|
std::string last_text;
|
|
|
|
int32_t segment_index = 0;
|
|
|
|
while (!stop) {
|
|
const std::vector<float> &samples = alsa.Read(chunk);
|
|
SherpaOnnxOnlineStreamAcceptWaveform(stream, expected_sample_rate,
|
|
samples.data(), samples.size());
|
|
while (SherpaOnnxIsOnlineStreamReady(recognizer, stream)) {
|
|
SherpaOnnxDecodeOnlineStream(recognizer, stream);
|
|
}
|
|
|
|
const SherpaOnnxOnlineRecognizerResult *r =
|
|
SherpaOnnxGetOnlineStreamResult(recognizer, stream);
|
|
|
|
std::string text = r->text;
|
|
SherpaOnnxDestroyOnlineRecognizerResult(r);
|
|
|
|
if (!text.empty() && last_text != text) {
|
|
last_text = text;
|
|
|
|
std::transform(text.begin(), text.end(), text.begin(),
|
|
[](auto c) { return std::tolower(c); });
|
|
|
|
SherpaOnnxPrint(display, segment_index, text.c_str());
|
|
fflush(stderr);
|
|
}
|
|
|
|
if (SherpaOnnxOnlineStreamIsEndpoint(recognizer, stream)) {
|
|
if (!text.empty()) {
|
|
++segment_index;
|
|
}
|
|
SherpaOnnxOnlineStreamReset(recognizer, stream);
|
|
}
|
|
}
|
|
|
|
// free allocated resources
|
|
SherpaOnnxDestroyDisplay(display);
|
|
SherpaOnnxDestroyOnlineStream(stream);
|
|
SherpaOnnxDestroyOnlineRecognizer(recognizer);
|
|
fprintf(stderr, "\n");
|
|
|
|
return 0;
|
|
}
|