re-pull-request allow tokens and hotwords be loaded from buffered string driectly (#1339)
Co-authored-by: xiao <shawl336@163.com>
This commit is contained in:
121
.github/workflows/c-api-test-loading-tokens-hotwords-from-memory.yaml
vendored
Normal file
121
.github/workflows/c-api-test-loading-tokens-hotwords-from-memory.yaml
vendored
Normal file
@@ -0,0 +1,121 @@
|
||||
name: c-api-test-loading-tokens-hotwords-from-memory
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
tags:
|
||||
- 'v[0-9]+.[0-9]+.[0-9]+*'
|
||||
paths:
|
||||
- '.github/workflows/c-api.yaml'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
- 'sherpa-onnx/c-api/*'
|
||||
- 'c-api-examples/**'
|
||||
- 'ffmpeg-examples/**'
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
paths:
|
||||
- '.github/workflows/c-api.yaml'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
- 'sherpa-onnx/c-api/*'
|
||||
- 'c-api-examples/**'
|
||||
- 'ffmpeg-examples/**'
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: c-api-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
c_api:
|
||||
name: ${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2
|
||||
with:
|
||||
key: ${{ matrix.os }}-c-api-shared
|
||||
|
||||
- name: Build sherpa-onnx
|
||||
shell: bash
|
||||
run: |
|
||||
export CMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
|
||||
cmake --version
|
||||
|
||||
mkdir build
|
||||
cd build
|
||||
|
||||
cmake \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D BUILD_SHARED_LIBS=ON \
|
||||
-D CMAKE_INSTALL_PREFIX=./install \
|
||||
-D SHERPA_ONNX_ENABLE_BINARY=OFF \
|
||||
..
|
||||
|
||||
make -j2 install
|
||||
|
||||
ls -lh install/lib
|
||||
ls -lh install/include
|
||||
|
||||
if [[ ${{ matrix.os }} == ubuntu-latest ]]; then
|
||||
ldd ./install/lib/libsherpa-onnx-c-api.so
|
||||
echo "---"
|
||||
readelf -d ./install/lib/libsherpa-onnx-c-api.so
|
||||
fi
|
||||
|
||||
if [[ ${{ matrix.os }} == macos-latest ]]; then
|
||||
otool -L ./install/lib/libsherpa-onnx-c-api.dylib
|
||||
fi
|
||||
|
||||
- name: Test streaming zipformer with tokens and hotwords loaded from buffers
|
||||
shell: bash
|
||||
run: |
|
||||
gcc -o streaming-zipformer-buffered-tokens-hotwords-c-api ./c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c \
|
||||
-I ./build/install/include \
|
||||
-L ./build/install/lib/ \
|
||||
-l sherpa-onnx-c-api \
|
||||
-l onnxruntime
|
||||
|
||||
ls -lh streaming-zipformer-buffered-tokens-hotwords-c-api
|
||||
|
||||
if [[ ${{ matrix.os }} == ubuntu-latest ]]; then
|
||||
ldd ./streaming-zipformer-buffered-tokens-hotwords-c-api
|
||||
echo "----"
|
||||
readelf -d ./streaming-zipformer-buffered-tokens-hotwords-c-api
|
||||
fi
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
|
||||
tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
|
||||
rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
|
||||
curl -SL -O https://huggingface.co/desh2608/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-small/blob/main/data/lang_bpe_500/bpe.model
|
||||
cp bpe.model sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/
|
||||
rm bpe.model
|
||||
|
||||
printf "▁A ▁T ▁P :1.5\n▁A ▁B ▁C :3.0" > hotwords.txt
|
||||
|
||||
ls -lh sherpa-onnx-streaming-zipformer-en-20M-2023-02-17
|
||||
echo "---"
|
||||
ls -lh sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs
|
||||
|
||||
export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH
|
||||
export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH
|
||||
|
||||
./streaming-zipformer-buffered-tokens-hotwords-c-api
|
||||
|
||||
rm -rf sherpa-onnx-streaming-zipformer-*
|
||||
@@ -48,6 +48,10 @@ target_link_libraries(telespeech-c-api sherpa-onnx-c-api)
|
||||
add_executable(vad-sense-voice-c-api vad-sense-voice-c-api.c)
|
||||
target_link_libraries(vad-sense-voice-c-api sherpa-onnx-c-api)
|
||||
|
||||
add_executable(streaming-zipformer-buffered-tokens-hotwords-c-api
|
||||
streaming-zipformer-buffered-tokens-hotwords-c-api.c)
|
||||
target_link_libraries(streaming-zipformer-buffered-tokens-hotwords-c-api sherpa-onnx-c-api)
|
||||
|
||||
if(SHERPA_ONNX_HAS_ALSA)
|
||||
add_subdirectory(./asr-microphone-example)
|
||||
elseif((UNIX AND NOT APPLE) OR LINUX)
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
// c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
// Copyright (c) 2024 Luo Xiao
|
||||
|
||||
//
|
||||
// This file demonstrates how to use streaming Zipformer with sherpa-onnx's C
|
||||
// and with tokens and hotwords loaded from buffered strings instead of from external
|
||||
// files API.
|
||||
// clang-format off
|
||||
//
|
||||
// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
|
||||
// tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
|
||||
// rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
|
||||
//
|
||||
// clang-format on
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "sherpa-onnx/c-api/c-api.h"
|
||||
|
||||
static size_t ReadFile(const char *filename, const char **buffer_out) {
|
||||
FILE *file = fopen(filename, "rb");
|
||||
if (file == NULL) {
|
||||
fprintf(stderr, "Failed to open %s\n", filename);
|
||||
return -1;
|
||||
}
|
||||
fseek(file, 0L, SEEK_END);
|
||||
long size = ftell(file);
|
||||
rewind(file);
|
||||
*buffer_out = malloc(size);
|
||||
if (*buffer_out == NULL) {
|
||||
fclose(file);
|
||||
fprintf(stderr, "Memory error\n");
|
||||
return -1;
|
||||
}
|
||||
size_t read_bytes = fread(*buffer_out, 1, size, file);
|
||||
if (read_bytes != size) {
|
||||
printf("Errors occured in reading the file %s\n", filename);
|
||||
free(*buffer_out);
|
||||
*buffer_out = NULL;
|
||||
fclose(file);
|
||||
return -1;
|
||||
}
|
||||
fclose(file);
|
||||
return read_bytes;
|
||||
}
|
||||
|
||||
int32_t main() {
|
||||
const char *wav_filename =
|
||||
"sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs/0.wav";
|
||||
const char *encoder_filename =
|
||||
"sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/"
|
||||
"encoder-epoch-99-avg-1.onnx";
|
||||
const char *decoder_filename =
|
||||
"sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/"
|
||||
"decoder-epoch-99-avg-1.onnx";
|
||||
const char *joiner_filename =
|
||||
"sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/"
|
||||
"joiner-epoch-99-avg-1.onnx";
|
||||
const char *provider = "cpu";
|
||||
const char *modeling_unit = "bpe";
|
||||
const char *tokens_filename =
|
||||
"sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/tokens.txt";
|
||||
const char *hotwords_filename =
|
||||
"sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/hotwords.txt";
|
||||
const char *bpe_vocab =
|
||||
"sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/"
|
||||
"bpe.vocab";
|
||||
const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename);
|
||||
if (wave == NULL) {
|
||||
fprintf(stderr, "Failed to read %s\n", wav_filename);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// reading tokens and hotwords to buffers
|
||||
const char *tokens_buf;
|
||||
size_t token_buf_size = ReadFile(tokens_filename, &tokens_buf);
|
||||
if (token_buf_size < 1) {
|
||||
fprintf(stderr, "Please check your tokens.txt!\n");
|
||||
free(tokens_buf);
|
||||
return -1;
|
||||
}
|
||||
const char *hotwords_buf;
|
||||
size_t hotwords_buf_size = ReadFile(hotwords_filename, &hotwords_buf);
|
||||
if (hotwords_buf_size < 1) {
|
||||
fprintf(stderr, "Please check your hotwords.txt!\n");
|
||||
free(hotwords_buf);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Zipformer config
|
||||
SherpaOnnxOnlineTransducerModelConfig zipformer_config;
|
||||
memset(&zipformer_config, 0, sizeof(zipformer_config));
|
||||
zipformer_config.encoder = encoder_filename;
|
||||
zipformer_config.decoder = decoder_filename;
|
||||
zipformer_config.joiner = joiner_filename;
|
||||
|
||||
// Online model config
|
||||
SherpaOnnxOnlineModelConfig online_model_config;
|
||||
memset(&online_model_config, 0, sizeof(online_model_config));
|
||||
online_model_config.debug = 1;
|
||||
online_model_config.num_threads = 1;
|
||||
online_model_config.provider = provider;
|
||||
online_model_config.tokens_buf = tokens_buf;
|
||||
online_model_config.tokens_buf_size = token_buf_size;
|
||||
online_model_config.transducer = zipformer_config;
|
||||
|
||||
// Recognizer config
|
||||
SherpaOnnxOnlineRecognizerConfig recognizer_config;
|
||||
memset(&recognizer_config, 0, sizeof(recognizer_config));
|
||||
recognizer_config.decoding_method = "modified_beam_search";
|
||||
recognizer_config.model_config = online_model_config;
|
||||
recognizer_config.hotwords_buf = hotwords_buf;
|
||||
recognizer_config.hotwords_buf_size = hotwords_buf_size;
|
||||
|
||||
SherpaOnnxOnlineRecognizer *recognizer =
|
||||
SherpaOnnxCreateOnlineRecognizer(&recognizer_config);
|
||||
|
||||
free(tokens_buf);
|
||||
tokens_buf = NULL;
|
||||
free(hotwords_buf);
|
||||
hotwords_buf = NULL;
|
||||
|
||||
if (recognizer == NULL) {
|
||||
fprintf(stderr, "Please check your config!\n");
|
||||
SherpaOnnxFreeWave(wave);
|
||||
return -1;
|
||||
}
|
||||
|
||||
SherpaOnnxOnlineStream *stream = SherpaOnnxCreateOnlineStream(recognizer);
|
||||
|
||||
const SherpaOnnxDisplay *display = SherpaOnnxCreateDisplay(50);
|
||||
int32_t segment_id = 0;
|
||||
|
||||
// simulate streaming. You can choose an arbitrary N
|
||||
#define N 3200
|
||||
|
||||
fprintf(stderr, "sample rate: %d, num samples: %d, duration: %.2f s\n",
|
||||
wave->sample_rate, wave->num_samples,
|
||||
(float)wave->num_samples / wave->sample_rate);
|
||||
|
||||
int32_t k = 0;
|
||||
while (k < wave->num_samples) {
|
||||
int32_t start = k;
|
||||
int32_t end =
|
||||
(start + N > wave->num_samples) ? wave->num_samples : (start + N);
|
||||
k += N;
|
||||
|
||||
SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate,
|
||||
wave->samples + start, end - start);
|
||||
while (SherpaOnnxIsOnlineStreamReady(recognizer, stream)) {
|
||||
SherpaOnnxDecodeOnlineStream(recognizer, stream);
|
||||
}
|
||||
|
||||
const SherpaOnnxOnlineRecognizerResult *r =
|
||||
SherpaOnnxGetOnlineStreamResult(recognizer, stream);
|
||||
|
||||
if (strlen(r->text)) {
|
||||
SherpaOnnxPrint(display, segment_id, r->text);
|
||||
}
|
||||
|
||||
if (SherpaOnnxOnlineStreamIsEndpoint(recognizer, stream)) {
|
||||
if (strlen(r->text)) {
|
||||
++segment_id;
|
||||
}
|
||||
SherpaOnnxOnlineStreamReset(recognizer, stream);
|
||||
}
|
||||
|
||||
SherpaOnnxDestroyOnlineRecognizerResult(r);
|
||||
}
|
||||
|
||||
// add some tail padding
|
||||
float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate
|
||||
SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings,
|
||||
4800);
|
||||
|
||||
SherpaOnnxFreeWave(wave);
|
||||
|
||||
SherpaOnnxOnlineStreamInputFinished(stream);
|
||||
while (SherpaOnnxIsOnlineStreamReady(recognizer, stream)) {
|
||||
SherpaOnnxDecodeOnlineStream(recognizer, stream);
|
||||
}
|
||||
|
||||
const SherpaOnnxOnlineRecognizerResult *r =
|
||||
SherpaOnnxGetOnlineStreamResult(recognizer, stream);
|
||||
|
||||
if (strlen(r->text)) {
|
||||
SherpaOnnxPrint(display, segment_id, r->text);
|
||||
}
|
||||
|
||||
SherpaOnnxDestroyOnlineRecognizerResult(r);
|
||||
|
||||
SherpaOnnxDestroyDisplay(display);
|
||||
SherpaOnnxDestroyOnlineStream(stream);
|
||||
SherpaOnnxDestroyOnlineRecognizer(recognizer);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -73,6 +73,12 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
|
||||
|
||||
recognizer_config.model_config.tokens =
|
||||
SHERPA_ONNX_OR(config->model_config.tokens, "");
|
||||
if (config->model_config.tokens_buf &&
|
||||
config->model_config.tokens_buf_size > 0) {
|
||||
recognizer_config.model_config.tokens_buf = std::string(
|
||||
config->model_config.tokens_buf, config->model_config.tokens_buf_size);
|
||||
}
|
||||
|
||||
recognizer_config.model_config.num_threads =
|
||||
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
|
||||
recognizer_config.model_config.provider_config.provider =
|
||||
@@ -120,6 +126,10 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
|
||||
recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, "");
|
||||
recognizer_config.hotwords_score =
|
||||
SHERPA_ONNX_OR(config->hotwords_score, 1.5);
|
||||
if (config->hotwords_buf && config->hotwords_buf_size > 0) {
|
||||
recognizer_config.hotwords_buf =
|
||||
std::string(config->hotwords_buf, config->hotwords_buf_size);
|
||||
}
|
||||
|
||||
recognizer_config.blank_penalty = config->blank_penalty;
|
||||
|
||||
|
||||
@@ -88,6 +88,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig {
|
||||
// - cjkchar+bpe
|
||||
const char *modeling_unit;
|
||||
const char *bpe_vocab;
|
||||
/// if non-null, loading the tokens from the buffered string directly in
|
||||
/// prioriy
|
||||
const char *tokens_buf;
|
||||
/// byte size excluding the tailing '\0'
|
||||
int32_t tokens_buf_size;
|
||||
} SherpaOnnxOnlineModelConfig;
|
||||
|
||||
/// It expects 16 kHz 16-bit single channel wave format.
|
||||
@@ -147,6 +152,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig {
|
||||
const char *rule_fsts;
|
||||
const char *rule_fars;
|
||||
float blank_penalty;
|
||||
|
||||
/// if non-nullptr, loading the hotwords from the buffered string directly in
|
||||
const char *hotwords_buf;
|
||||
/// byte size excluding the tailing '\0'
|
||||
int32_t hotwords_buf_size;
|
||||
} SherpaOnnxOnlineRecognizerConfig;
|
||||
|
||||
SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult {
|
||||
|
||||
@@ -56,8 +56,19 @@ bool OnlineModelConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(tokens)) {
|
||||
SHERPA_ONNX_LOGE("tokens: '%s' does not exist", tokens.c_str());
|
||||
if (!tokens_buf.empty() && FileExists(tokens)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"you can not provide a tokens_buf and a tokens file: '%s', "
|
||||
"at the same time, which is confusing",
|
||||
tokens.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tokens_buf.empty() && !FileExists(tokens)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"tokens: '%s' does not exist, you should provide "
|
||||
"either a tokens buffer or a tokens file",
|
||||
tokens.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -45,6 +45,11 @@ struct OnlineModelConfig {
|
||||
std::string modeling_unit = "cjkchar";
|
||||
std::string bpe_vocab;
|
||||
|
||||
/// if tokens_buf is non-empty,
|
||||
/// the tokens will be loaded from the buffered string instead of from the
|
||||
/// ${tokens} file
|
||||
std::string tokens_buf;
|
||||
|
||||
OnlineModelConfig() = default;
|
||||
OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
|
||||
const OnlineParaformerModelConfig ¶former,
|
||||
@@ -53,8 +58,7 @@ struct OnlineModelConfig {
|
||||
const OnlineNeMoCtcModelConfig &nemo_ctc,
|
||||
const ProviderConfig &provider_config,
|
||||
const std::string &tokens, int32_t num_threads,
|
||||
int32_t warm_up, bool debug,
|
||||
const std::string &model_type,
|
||||
int32_t warm_up, bool debug, const std::string &model_type,
|
||||
const std::string &modeling_unit,
|
||||
const std::string &bpe_vocab)
|
||||
: transducer(transducer),
|
||||
|
||||
@@ -83,8 +83,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
: OnlineRecognizerImpl(config),
|
||||
config_(config),
|
||||
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||
sym_(config.model_config.tokens),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
if (!config.model_config.tokens_buf.empty()) {
|
||||
sym_ = SymbolTable(config.model_config.tokens_buf, false);
|
||||
} else {
|
||||
/// assuming tokens_buf and tokens are guaranteed not being both empty
|
||||
sym_ = SymbolTable(config.model_config.tokens, true);
|
||||
}
|
||||
|
||||
if (sym_.Contains("<unk>")) {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
@@ -97,7 +103,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
config_.model_config.bpe_vocab);
|
||||
}
|
||||
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
if (!config_.hotwords_buf.empty()) {
|
||||
InitHotwordsFromBufStr();
|
||||
} else if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords();
|
||||
}
|
||||
|
||||
@@ -108,8 +116,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
|
||||
config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
config_.blank_penalty, config_.temperature_scale);
|
||||
|
||||
} else if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||
@@ -158,8 +165,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
|
||||
config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
config_.blank_penalty, config_.temperature_scale);
|
||||
|
||||
} else if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||
@@ -446,6 +452,20 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
#endif
|
||||
|
||||
void InitHotwordsFromBufStr() {
|
||||
// each line in hotwords_file contains space-separated words
|
||||
|
||||
std::istringstream iss(config_.hotwords_buf);
|
||||
if (!EncodeHotwords(iss, config_.model_config.modeling_unit, sym_,
|
||||
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ = std::make_shared<ContextGraph>(
|
||||
hotwords_, config_.hotwords_score, boost_scores_);
|
||||
}
|
||||
|
||||
void InitOnlineStream(OnlineStream *stream) const {
|
||||
auto r = decoder_->GetEmptyResult();
|
||||
|
||||
|
||||
@@ -44,10 +44,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
|
||||
const OnlineRecognizerConfig &config)
|
||||
: OnlineRecognizerImpl(config),
|
||||
config_(config),
|
||||
symbol_table_(config.model_config.tokens),
|
||||
endpoint_(config_.endpoint_config),
|
||||
model_(
|
||||
std::make_unique<OnlineTransducerNeMoModel>(config.model_config)) {
|
||||
if (!config.model_config.tokens_buf.empty()) {
|
||||
symbol_table_ = SymbolTable(config.model_config.tokens_buf, false);
|
||||
} else {
|
||||
/// assuming tokens_buf and tokens are guaranteed not being both empty
|
||||
symbol_table_ = SymbolTable(config.model_config.tokens, true);
|
||||
}
|
||||
|
||||
if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
|
||||
model_.get(), config_.blank_penalty);
|
||||
|
||||
@@ -106,6 +106,11 @@ struct OnlineRecognizerConfig {
|
||||
// If there are multiple FST archives, they are applied from left to right.
|
||||
std::string rule_fars;
|
||||
|
||||
/// used only for modified_beam_search, if hotwords_buf is non-empty,
|
||||
/// the hotwords will be loaded from the buffered string instead of from
|
||||
/// ${hotwords_file}
|
||||
std::string hotwords_buf;
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
OnlineRecognizerConfig(
|
||||
|
||||
@@ -20,9 +20,14 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
SymbolTable::SymbolTable(const std::string &filename) {
|
||||
std::ifstream is(filename);
|
||||
Init(is);
|
||||
SymbolTable::SymbolTable(const std::string &filename, bool is_file) {
|
||||
if (is_file) {
|
||||
std::ifstream is(filename);
|
||||
Init(is);
|
||||
} else {
|
||||
std::istringstream iss(filename);
|
||||
Init(iss);
|
||||
}
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
|
||||
@@ -19,13 +19,13 @@ namespace sherpa_onnx {
|
||||
class SymbolTable {
|
||||
public:
|
||||
SymbolTable() = default;
|
||||
/// Construct a symbol table from a file.
|
||||
/// Construct a symbol table from a file or from a buffered string.
|
||||
/// Each line in the file contains two fields:
|
||||
///
|
||||
/// sym ID
|
||||
///
|
||||
/// Fields are separated by space(s).
|
||||
explicit SymbolTable(const std::string &filename);
|
||||
explicit SymbolTable(const std::string &filename, bool is_file = true);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
SymbolTable(AAssetManager *mgr, const std::string &filename);
|
||||
|
||||
Reference in New Issue
Block a user