Support whisper models (#238)

This commit is contained in:
Fangjun Kuang
2023-08-07 12:34:18 +08:00
committed by GitHub
parent 64efbd82af
commit 45b9d4ab37
39 changed files with 1836 additions and 52 deletions

View File

@@ -11,6 +11,7 @@ if(SHERPA_ONNX_ENABLE_PYTHON)
endif()
set(sources
base64-decode.cc
cat.cc
context-graph.cc
endpoint.cc
@@ -35,6 +36,9 @@ set(sources
offline-transducer-model-config.cc
offline-transducer-model.cc
offline-transducer-modified-beam-search-decoder.cc
offline-whisper-greedy-search-decoder.cc
offline-whisper-model-config.cc
offline-whisper-model.cc
online-conformer-transducer-model.cc
online-lm-config.cc
online-lm.cc
@@ -50,12 +54,12 @@ set(sources
online-zipformer-transducer-model.cc
online-zipformer2-transducer-model.cc
onnx-utils.cc
session.cc
packed-sequence.cc
pad-sequence.cc
parse-options.cc
provider.cc
resample.cc
session.cc
slice.cc
stack.cc
symbol-table.cc

View File

@@ -0,0 +1,67 @@
// sherpa-onnx/csrc/base64-decode.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/base64-decode.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
static int32_t Ord(char c) {
if (c >= 'A' && c <= 'Z') {
return c - 'A';
} else if (c >= 'a' && c <= 'z') {
return c - 'a' + ('Z' - 'A') + 1;
} else if (c >= '0' && c <= '9') {
return c - '0' + ('Z' - 'A') + ('z' - 'a') + 2;
} else if (c == '+') {
return 62;
} else if (c == '/') {
return 63;
}
SHERPA_ONNX_LOGE("Unknown character %d, %c\n", c, c);
exit(-1);
}
// see
// https://github.com/ReneNyffenegger/cpp-base64/blob/master/base64.cpp#L243
std::string Base64Decode(const std::string &s) {
if (s.empty()) {
SHERPA_ONNX_LOGE("Empty string!");
exit(-1);
}
int32_t n = s.size() / 4 * 3;
std::string ans;
ans.reserve(n);
int32_t i = 0;
while (i < static_cast<int32_t>(s.size())) {
if (s[i] == '=') {
return " ";
}
int32_t first = (Ord(s[i]) << 2) + ((Ord(s[i + 1]) & 0x30) >> 4);
ans.push_back(first);
if (i + 2 < static_cast<int32_t>(s.size()) && s[i + 2] != '=') {
int32_t second =
((Ord(s[i + 1]) & 0x0f) << 4) + ((Ord(s[i + 2]) & 0x3c) >> 2);
ans.push_back(second);
if (i + 3 < static_cast<int32_t>(s.size()) && s[i + 3] != '=') {
int32_t third = ((Ord(s[i + 2]) & 0x03) << 6) + Ord(s[i + 3]);
ans.push_back(third);
}
}
i += 4;
}
return ans;
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,19 @@
// sherpa-onnx/csrc/base64-decode.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_BASE64_DECODE_H_
#define SHERPA_ONNX_CSRC_BASE64_DECODE_H_
#include <string>
namespace sherpa_onnx {
/** @param s A base64 encoded string.
* @return Return the decoded string.
*/
std::string Base64Decode(const std::string &s);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_BASE64_DECODE_H_

View File

@@ -1,4 +1,3 @@
// sherpa-onnx/csrc/macros.h
//
// Copyright 2023 Xiaomi Corporation

View File

@@ -14,6 +14,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
transducer.Register(po);
paraformer.Register(po);
nemo_ctc.Register(po);
whisper.Register(po);
po->Register("tokens", &tokens, "Path to tokens.txt");
@@ -28,7 +29,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
po->Register("model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: transducer, paraformer, nemo_ctc. "
"Valid values are: transducer, paraformer, nemo_ctc, whisper."
"All other values lead to loading the model twice.");
}
@@ -51,6 +52,10 @@ bool OfflineModelConfig::Validate() const {
return nemo_ctc.Validate();
}
if (!whisper.encoder.empty()) {
return whisper.Validate();
}
return transducer.Validate();
}
@@ -61,6 +66,7 @@ std::string OfflineModelConfig::ToString() const {
os << "transducer=" << transducer.ToString() << ", ";
os << "paraformer=" << paraformer.ToString() << ", ";
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
os << "whisper=" << whisper.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";

View File

@@ -9,6 +9,7 @@
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
namespace sherpa_onnx {
@@ -16,6 +17,7 @@ struct OfflineModelConfig {
OfflineTransducerModelConfig transducer;
OfflineParaformerModelConfig paraformer;
OfflineNemoEncDecCtcModelConfig nemo_ctc;
OfflineWhisperModelConfig whisper;
std::string tokens;
int32_t num_threads = 2;
@@ -37,11 +39,13 @@ struct OfflineModelConfig {
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
const OfflineParaformerModelConfig &paraformer,
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
const OfflineWhisperModelConfig &whisper,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type)
: transducer(transducer),
paraformer(paraformer),
nemo_ctc(nemo_ctc),
whisper(whisper),
tokens(tokens),
num_threads(num_threads),
debug(debug),

View File

@@ -16,7 +16,7 @@ void OfflineNemoEncDecCtcModelConfig::Register(ParseOptions *po) {
bool OfflineNemoEncDecCtcModelConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
SHERPA_ONNX_LOGE("NeMo model: %s does not exist", model.c_str());
return false;
}

View File

@@ -15,7 +15,7 @@ void OfflineParaformerModelConfig::Register(ParseOptions *po) {
bool OfflineParaformerModelConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
SHERPA_ONNX_LOGE("Paraformer model %s does not exist", model.c_str());
return false;
}

View File

@@ -11,6 +11,7 @@
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
@@ -26,6 +27,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
} else if (model_type == "nemo_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
} else {
SHERPA_ONNX_LOGE(
"Invalid model_type: %s. Trying to load the model to get its type",
@@ -43,6 +46,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
model_filename = config.model_config.paraformer.model;
} else if (!config.model_config.nemo_ctc.model.empty()) {
model_filename = config.model_config.nemo_ctc.model;
} else if (!config.model_config.whisper.encoder.empty()) {
model_filename = config.model_config.whisper.encoder;
} else {
SHERPA_ONNX_LOGE("Please provide a model");
exit(-1);
@@ -77,6 +82,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"\n "
"https://huggingface.co/csukuangfj/"
"paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
"\n "
"(3) Whisper"
"\n");
exit(-1);
}
@@ -95,12 +102,17 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerCtcImpl>(config);
}
if (strncmp(model_type.c_str(), "whisper", 7) == 0) {
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
}
SHERPA_ONNX_LOGE(
"\nUnsupported model_type: %s\n"
"We support only the following model types at present: \n"
" - Non-streaming transducer models from icefall\n"
" - Non-streaming Paraformer models from FunASR\n"
" - EncDecCTCModelBPE models from NeMo\n",
" - EncDecCTCModelBPE models from NeMo\n"
" - Whisper models\n",
model_type.c_str());
exit(-1);

View File

@@ -0,0 +1,152 @@
// sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
#include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-whisper-model.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
const SymbolTable &sym_table) {
OfflineRecognitionResult r;
r.tokens.reserve(src.tokens.size());
for (auto i : src.tokens) {
if (!sym_table.contains(i)) {
continue;
}
const auto &s = sym_table[i];
r.text += s;
r.tokens.push_back(s);
}
return r;
}
class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
public:
explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config)
: config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineWhisperModel>(config.model_config)) {
// tokens.txt from whisper is base64 encoded, so we need to decode it
symbol_table_.ApplyBase64Decode();
if (config.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OfflineWhisperGreedySearchDecoder>(model_.get());
} else {
SHERPA_ONNX_LOGE(
"Only greedy_search is supported at present for whisper. Given %s",
config.decoding_method.c_str());
exit(-1);
}
}
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(WhisperTag{});
}
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
// batch decoding is not implemented yet
for (int32_t i = 0; i != n; ++i) {
DecodeStream(ss[i]);
}
}
private:
void DecodeStream(OfflineStream *s) const {
int32_t max_num_frames = 3000;
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int32_t feat_dim = s->FeatureDim();
std::vector<float> f = s->GetFrames();
int32_t num_frames = f.size() / feat_dim;
if (num_frames > max_num_frames) {
SHERPA_ONNX_LOGE("Only waves less than 30 seconds are supported.");
exit(-1);
}
NormalizeFeatures(f.data(), num_frames, feat_dim);
std::array<int64_t, 3> shape{1, max_num_frames, feat_dim};
Ort::Value mel = Ort::Value::CreateTensor<float>(
model_->Allocator(), shape.data(), shape.size());
float *p_mel = mel.GetTensorMutableData<float>();
std::copy(f.begin(), f.end(), p_mel);
memset(p_mel + f.size(), 0,
(max_num_frames - num_frames) * feat_dim * sizeof(float));
mel = Transpose12(model_->Allocator(), &mel);
auto cross_kv = model_->ForwardEncoder(std::move(mel));
auto results =
decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
auto r = Convert(results[0], symbol_table_);
s->SetResult(r);
}
private:
static void NormalizeFeatures(float *features, int32_t num_frames,
int32_t feat_dim) {
// log_spec = torch.clamp(features, min=1e-10).log10()
// log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
// mel = (log_spec + 4.0) / 4.0
int32_t n = num_frames * feat_dim;
float max_v = -1e20;
for (int32_t i = 0; i != n; ++i) {
float f = features[i];
f = std::max<float>(f, 1e-10);
f = std::log10(f);
max_v = std::max(f, max_v);
features[i] = f;
}
max_v -= 8;
for (int32_t i = 0; i != n; ++i) {
float f = features[i];
f = std::max(f, max_v);
f = (f + 4) / 4;
features[i] = f;
}
}
private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
std::unique_ptr<OfflineWhisperModel> model_;
std::unique_ptr<OfflineWhisperDecoder> decoder_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_

View File

@@ -86,6 +86,15 @@ class OfflineStream::Impl {
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph)
: context_graph_(context_graph) {
config_.normalize_samples = true;
opts_.frame_opts.samp_freq = 16000;
opts_.mel_opts.num_bins = 80;
whisper_fbank_ =
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
if (config_.normalize_samples) {
AcceptWaveformImpl(sampling_rate, waveform, n);
@@ -117,20 +126,35 @@ class OfflineStream::Impl {
lowpass_filter_width);
std::vector<float> samples;
resampler->Resample(waveform, n, true, &samples);
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
samples.size());
fbank_->InputFinished();
return;
}
fbank_->AcceptWaveform(sampling_rate, waveform, n);
fbank_->InputFinished();
if (fbank_) {
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
samples.size());
fbank_->InputFinished();
} else {
whisper_fbank_->AcceptWaveform(opts_.frame_opts.samp_freq,
samples.data(), samples.size());
whisper_fbank_->InputFinished();
}
return;
} // if (sampling_rate != opts_.frame_opts.samp_freq)
if (fbank_) {
fbank_->AcceptWaveform(sampling_rate, waveform, n);
fbank_->InputFinished();
} else {
whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
whisper_fbank_->InputFinished();
}
}
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
std::vector<float> GetFrames() const {
int32_t n = fbank_->NumFramesReady();
int32_t n =
fbank_ ? fbank_->NumFramesReady() : whisper_fbank_->NumFramesReady();
assert(n > 0 && "Please first call AcceptWaveform()");
int32_t feature_dim = FeatureDim();
@@ -140,7 +164,8 @@ class OfflineStream::Impl {
float *p = features.data();
for (int32_t i = 0; i != n; ++i) {
const float *f = fbank_->GetFrame(i);
const float *f =
fbank_ ? fbank_->GetFrame(i) : whisper_fbank_->GetFrame(i);
std::copy(f, f + feature_dim, p);
p += feature_dim;
}
@@ -191,6 +216,7 @@ class OfflineStream::Impl {
private:
OfflineFeatureExtractorConfig config_;
std::unique_ptr<knf::OnlineFbank> fbank_;
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
knf::FbankOptions opts_;
OfflineRecognitionResult r_;
ContextGraphPtr context_graph_;
@@ -201,6 +227,10 @@ OfflineStream::OfflineStream(
ContextGraphPtr context_graph /*= nullptr*/)
: impl_(std::make_unique<Impl>(config, context_graph)) {}
OfflineStream::OfflineStream(WhisperTag tag,
ContextGraphPtr context_graph /*= nullptr*/)
: impl_(std::make_unique<Impl>(tag, context_graph)) {}
OfflineStream::~OfflineStream() = default;
void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,

View File

@@ -65,10 +65,15 @@ struct OfflineFeatureExtractorConfig {
void Register(ParseOptions *po);
};
struct WhisperTag {};
class OfflineStream {
public:
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
ContextGraphPtr context_graph = nullptr);
explicit OfflineStream(WhisperTag tag,
ContextGraphPtr context_graph = nullptr);
~OfflineStream();
/**

View File

@@ -18,17 +18,20 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) {
bool OfflineTransducerModelConfig::Validate() const {
if (!FileExists(encoder_filename)) {
SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str());
SHERPA_ONNX_LOGE("transducer encoder: %s does not exist",
encoder_filename.c_str());
return false;
}
if (!FileExists(decoder_filename)) {
SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str());
SHERPA_ONNX_LOGE("transducer decoder: %s does not exist",
decoder_filename.c_str());
return false;
}
if (!FileExists(joiner_filename)) {
SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str());
SHERPA_ONNX_LOGE("transducer joiner: %s does not exist",
joiner_filename.c_str());
return false;
}

View File

@@ -0,0 +1,38 @@
// sherpa-onnx/csrc/offline-whisper-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
struct OfflineWhisperDecoderResult {
/// The decoded token IDs
std::vector<int32_t> tokens;
};
class OfflineWhisperDecoder {
public:
virtual ~OfflineWhisperDecoder() = default;
/** Run beam search given the output from the whisper encoder model.
*
* @param n_layer_cross_k A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state).
* @param n_layer_cross_v A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state).
*
* @return Return a vector of size `N` containing the decoded results.
*/
virtual std::vector<OfflineWhisperDecoderResult> Decode(
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_

View File

@@ -0,0 +1,93 @@
// sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h"
#include <algorithm>
#include <utility>
namespace sherpa_onnx {
std::vector<OfflineWhisperDecoderResult>
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
Ort::Value cross_v) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
auto self_kv_cache = model_->GetInitialSelfKVCache();
std::vector<int64_t> initial_tokens = model_->GetInitialTokens();
int32_t batch_size = 1;
std::array<int64_t, 2> token_shape{
batch_size, static_cast<int64_t>(initial_tokens.size())};
Ort::Value tokens = Ort::Value::CreateTensor(
memory_info, initial_tokens.data(), initial_tokens.size(),
token_shape.data(), token_shape.size());
std::array<int64_t, 1> offset_shape{1};
Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
model_->Allocator(), offset_shape.data(), offset_shape.size());
*(offset.GetTensorMutableData<int64_t>()) = 0;
auto decoder_out = model_->ForwardDecoder(
std::move(tokens), std::move(self_kv_cache.first),
std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
std::move(offset));
const auto &logits = std::get<0>(decoder_out);
const float *p_logits = logits.GetTensorData<float>();
auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape();
int32_t vocab_size = logits_shape[2];
int32_t max_token_id = static_cast<int32_t>(std::distance(
p_logits, std::max_element(p_logits, p_logits + vocab_size)));
int32_t n_text_ctx = model_->TextCtx();
std::vector<int32_t> predicted_tokens;
for (int32_t i = 0; i < n_text_ctx; ++i) {
if (max_token_id == model_->EOT()) {
break;
}
predicted_tokens.push_back(max_token_id);
std::array<int64_t, 2> token_shape{1, 1};
Ort::Value tokens = Ort::Value::CreateTensor<int64_t>(
model_->Allocator(), token_shape.data(), token_shape.size());
int64_t *p_tokens = tokens.GetTensorMutableData<int64_t>();
p_tokens[0] = max_token_id;
int64_t *p_offset =
std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
if (i == 0) {
*p_offset = initial_tokens.size();
} else {
*p_offset += 1;
}
decoder_out = model_->ForwardDecoder(std::move(tokens),
std::move(std::get<1>(decoder_out)),
std::move(std::get<2>(decoder_out)),
std::move(std::get<3>(decoder_out)),
std::move(std::get<4>(decoder_out)),
std::move(std::get<5>(decoder_out)));
const auto &logits = std::get<0>(decoder_out);
const float *p_logits = logits.GetTensorData<float>();
max_token_id = static_cast<int64_t>(std::distance(
p_logits, std::max_element(p_logits, p_logits + vocab_size)));
}
std::vector<OfflineWhisperDecoderResult> ans(1);
ans[0].tokens = std::move(predicted_tokens);
return ans;
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,29 @@
// sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
#include <vector>
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
#include "sherpa-onnx/csrc/offline-whisper-model.h"
namespace sherpa_onnx {
class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
public:
explicit OfflineWhisperGreedySearchDecoder(OfflineWhisperModel *model)
: model_(model) {}
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
Ort::Value cross_v) override;
private:
OfflineWhisperModel *model_; // not owned
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_

View File

@@ -0,0 +1,46 @@
// sherpa-onnx/csrc/offline-whisper-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineWhisperModelConfig::Register(ParseOptions *po) {
po->Register("whisper-encoder", &encoder,
"Path to onnx encoder of whisper, e.g., tiny-encoder.onnx, "
"medium.en-encoder.onnx.");
po->Register("whisper-decoder", &decoder,
"Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, "
"medium.en-decoder.onnx.");
}
bool OfflineWhisperModelConfig::Validate() const {
if (!FileExists(encoder)) {
SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
return false;
}
if (!FileExists(decoder)) {
SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
return false;
}
return true;
}
std::string OfflineWhisperModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineWhisperModelConfig(";
os << "encoder=\"" << encoder << "\", ";
os << "decoder=\"" << decoder << "\")";
return os.str();
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,30 @@
// sherpa-onnx/csrc/offline-whisper-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineWhisperModelConfig {
std::string encoder;
std::string decoder;
OfflineWhisperModelConfig() = default;
OfflineWhisperModelConfig(const std::string &encoder,
const std::string &decoder)
: encoder(encoder), decoder(decoder) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_

View File

@@ -0,0 +1,213 @@
// sherpa-onnx/csrc/offline-whisper-model.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-whisper-model.h"
#include <algorithm>
#include <string>
#include <tuple>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class OfflineWhisperModel::Impl {
public:
explicit Impl(const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.whisper.decoder);
InitDecoder(buf.data(), buf.size());
}
}
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) {
auto encoder_out = encoder_sess_->Run(
{}, encoder_input_names_ptr_.data(), &features, 1,
encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
return {std::move(encoder_out[0]), std::move(encoder_out[1])};
}
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
Ort::Value>
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
Ort::Value n_layer_cross_v, Ort::Value offset) {
std::array<Ort::Value, 6> decoder_input = {std::move(tokens),
std::move(n_layer_self_k_cache),
std::move(n_layer_self_v_cache),
std::move(n_layer_cross_k),
std::move(n_layer_cross_v),
std::move(offset)};
auto decoder_out = decoder_sess_->Run(
{}, decoder_input_names_ptr_.data(), decoder_input.data(),
decoder_input.size(), decoder_output_names_ptr_.data(),
decoder_output_names_ptr_.size());
return {std::move(decoder_out[0]), std::move(decoder_out[1]),
std::move(decoder_out[2]), std::move(decoder_input[3]),
std::move(decoder_input[4]), std::move(decoder_input[5])};
}
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_};
Ort::Value n_layer_self_k_cache = Ort::Value::CreateTensor<float>(
Allocator(), shape.data(), shape.size());
Ort::Value n_layer_self_v_cache = Ort::Value::CreateTensor<float>(
Allocator(), shape.data(), shape.size());
auto n = shape[0] * shape[1] * shape[2] * shape[3];
float *p_k = n_layer_self_k_cache.GetTensorMutableData<float>();
float *p_v = n_layer_self_v_cache.GetTensorMutableData<float>();
memset(p_k, 0, sizeof(float) * n);
memset(p_v, 0, sizeof(float) * n);
return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
}
OrtAllocator *Allocator() const { return allocator_; }
const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }
int32_t EOT() const { return eot_; }
int32_t TextCtx() const { return n_text_ctx_; }
private:
void InitEncoder(void *model_data, size_t model_data_length) {
encoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_);
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
&encoder_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
SHERPA_ONNX_READ_META_DATA(sot_, "sot");
SHERPA_ONNX_READ_META_DATA(eot_, "eot");
SHERPA_ONNX_READ_META_DATA(blank_, "blank_id");
SHERPA_ONNX_READ_META_DATA(translate_, "translate");
SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps");
SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech");
SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence");
}
void InitDecoder(void *model_data, size_t model_data_length) {
decoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
&decoder_input_names_ptr_);
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
&decoder_output_names_ptr_);
}
private:
OfflineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> encoder_sess_;
std::unique_ptr<Ort::Session> decoder_sess_;
std::vector<std::string> encoder_input_names_;
std::vector<const char *> encoder_input_names_ptr_;
std::vector<std::string> encoder_output_names_;
std::vector<const char *> encoder_output_names_ptr_;
std::vector<std::string> decoder_input_names_;
std::vector<const char *> decoder_input_names_ptr_;
std::vector<std::string> decoder_output_names_;
std::vector<const char *> decoder_output_names_ptr_;
// model meta data
int32_t n_text_layer_;
int32_t n_text_ctx_;
int32_t n_text_state_;
int32_t sot_;
int32_t eot_;
int32_t blank_;
int32_t translate_;
int32_t no_timestamps_;
int32_t no_speech_;
std::vector<int64_t> sot_sequence_;
};
OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
OfflineWhisperModel::~OfflineWhisperModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder(
Ort::Value features) {
return impl_->ForwardEncoder(std::move(features));
}
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
Ort::Value>
OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
Ort::Value n_layer_self_k_cache,
Ort::Value n_layer_self_v_cache,
Ort::Value n_layer_cross_k,
Ort::Value n_layer_cross_v,
Ort::Value offset) {
return impl_->ForwardDecoder(
std::move(tokens), std::move(n_layer_self_k_cache),
std::move(n_layer_self_v_cache), std::move(n_layer_cross_k),
std::move(n_layer_cross_v), std::move(offset));
}
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() {
return impl_->GetInitialSelfKVCache();
}
OrtAllocator *OfflineWhisperModel::Allocator() const {
return impl_->Allocator();
}
const std::vector<int64_t> &OfflineWhisperModel::GetInitialTokens() const {
return impl_->GetInitialTokens();
}
int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); }
int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }
} // namespace sherpa_onnx

View File

@@ -0,0 +1,85 @@
// sherpa-onnx/csrc/offline-whisper-model.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
#include <memory>
#include <tuple>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
namespace sherpa_onnx {
class OfflineWhisperModel {
public:
explicit OfflineWhisperModel(const OfflineModelConfig &config);
~OfflineWhisperModel();
/** Run the encoder model.
*
* @param features A tensor of shape (N, C, T). It is changed in-place.
* C is 80 and T is 3000.
*
* @return Return a pair containing:
* - n_layer_cross_k: A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state)
* - n_layer_cross_v: A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state)
*/
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features);
/** Run the decoder model.
*
* @param tokens A int64 tensor of shape (N, num_words)
* @param n_layer_self_k_cache A 4-D tensor of shape
* (n_text_layer, N, n_text_ctx, n_text_state).
* @param n_layer_self_v_cache A 4-D tensor of shape
* (n_text_layer, N, n_text_ctx, n_text_state).
* @param n_layer_cross_k A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state).
* @param n_layer_cross_v A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state).
* @param offset A int64 tensor of shape (N,)
*
* @return Return a tuple containing 6 tensors:
*
* - logits A 3-D tensor of shape (N, num_words, vocab_size)
* - out_n_layer_self_k_cache Same shape as n_layer_self_k_cache
* - out_n_layer_self_v_cache Same shape as n_layer_self_v_cache
* - out_n_layer_cross_k Same as n_layer_cross_k
* - out_n_layer_cross_v Same as n_layer_cross_v
* - out_offset Same as offset
*/
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
Ort::Value>
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
Ort::Value n_layer_cross_v, Ort::Value offset);
/** Return the initial self kv cache in a pair
* - n_layer_self_k_cache A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state).
* - n_layer_self_v_cache A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state).
*/
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache();
const std::vector<int64_t> &GetInitialTokens() const;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const;
int32_t EOT() const;
int32_t TextCtx() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_

View File

@@ -98,11 +98,15 @@ Usage:
./bin/sherpa-onnx-microphone-offline \
--tokens=/path/to/tokens.txt \
--paraformer=/path/to/model.onnx \
--num-threads=2 \
--decoding-method=greedy_search
--num-threads=1
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search.
(3) Whisper models
./bin/sherpa-onnx-microphone-offline \
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
--num-threads=1
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html

View File

@@ -23,7 +23,7 @@ Usage:
--encoder=/path/to/encoder.onnx \
--decoder=/path/to/decoder.onnx \
--joiner=/path/to/joiner.onnx \
--num-threads=2 \
--num-threads=1 \
--decoding-method=greedy_search \
/path/to/foo.wav [bar.wav foobar.wav ...]
@@ -33,14 +33,22 @@ Usage:
./bin/sherpa-onnx-offline \
--tokens=/path/to/tokens.txt \
--paraformer=/path/to/model.onnx \
--num-threads=2 \
--num-threads=1 \
--decoding-method=greedy_search \
/path/to/foo.wav [bar.wav foobar.wav ...]
(3) Whisper models
./bin/sherpa-onnx-offline \
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
--num-threads=1 \
/path/to/foo.wav [bar.wav foobar.wav ...]
Note: It supports decoding multiple files in batches
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search.
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
sampling rate can be arbitrary and does not need to be 16kHz.
@@ -55,6 +63,7 @@ for a list of pre-trained models to download.
po.Read(argc, argv);
if (po.NumArgs() < 1) {
fprintf(stderr, "Error: Please provide at least 1 wave file.\n\n");
po.PrintUsage();
exit(EXIT_FAILURE);
}

View File

@@ -9,6 +9,7 @@
#include <sstream>
#include <strstream>
#include "sherpa-onnx/csrc/base64-decode.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#if __ANDROID_API__ >= 9
@@ -82,4 +83,12 @@ std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) {
return os << symbol_table.ToString();
}
void SymbolTable::ApplyBase64Decode() {
sym2id_.clear();
for (auto &p : id2sym_) {
p.second = Base64Decode(p.second);
sym2id_[p.second] = p.first;
}
}
} // namespace sherpa_onnx

View File

@@ -45,6 +45,9 @@ class SymbolTable {
/// Return true if there is a given symbol in the symbol table.
bool contains(const std::string &sym) const;
// for tokens.txt from Whisper
void ApplyBase64Decode();
private:
void Init(std::istream &is);

View File

@@ -11,6 +11,7 @@ pybind11_add_module(_sherpa_onnx
offline-recognizer.cc
offline-stream.cc
offline-transducer-model-config.cc
offline-whisper-model-config.cc
online-lm-config.cc
online-recognizer.cc
online-stream.cc

View File

@@ -11,6 +11,7 @@
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
namespace sherpa_onnx {
@@ -18,22 +19,25 @@ void PybindOfflineModelConfig(py::module *m) {
PybindOfflineTransducerModelConfig(m);
PybindOfflineParaformerModelConfig(m);
PybindOfflineNemoEncDecCtcModelConfig(m);
PybindOfflineWhisperModelConfig(m);
using PyClass = OfflineModelConfig;
py::class_<PyClass>(*m, "OfflineModelConfig")
.def(
py::init<const OfflineTransducerModelConfig &,
const OfflineParaformerModelConfig &,
const OfflineNemoEncDecCtcModelConfig &, const std::string &,
int32_t, bool, const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "")
.def(py::init<const OfflineTransducerModelConfig &,
const OfflineParaformerModelConfig &,
const OfflineNemoEncDecCtcModelConfig &,
const OfflineWhisperModelConfig &, const std::string &,
int32_t, bool, const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
py::arg("whisper") = OfflineWhisperModelConfig(), py::arg("tokens"),
py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
.def_readwrite("whisper", &PyClass::whisper)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)

View File

@@ -0,0 +1,24 @@
// sherpa-onnx/python/csrc/offline-whisper-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
#include <string>
#include <vector>
#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
namespace sherpa_onnx {
void PybindOfflineWhisperModelConfig(py::module *m) {
using PyClass = OfflineWhisperModelConfig;
py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
.def(py::init<const std::string &, const std::string &>(),
py::arg("encoder"), py::arg("decoder"))
.def_readwrite("encoder", &PyClass::encoder)
.def_readwrite("decoder", &PyClass::decoder)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/offline-whisper-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOfflineWhisperModelConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_

View File

@@ -1,4 +1,5 @@
# Copyright (c) 2023 by manyeyes
# Copyright (c) 2023 Xiaomi Corporation
from pathlib import Path
from typing import List, Optional
@@ -7,6 +8,7 @@ from _sherpa_onnx import (
OfflineModelConfig,
OfflineNemoEncDecCtcModelConfig,
OfflineParaformerModelConfig,
OfflineWhisperModelConfig,
)
from _sherpa_onnx import OfflineRecognizer as _Recognizer
from _sherpa_onnx import (
@@ -69,7 +71,7 @@ class OfflineRecognizer(object):
feature_dim:
Dimension of the feature used to train the model.
decoding_method:
Support only greedy_search for now.
Valid values: greedy_search, modified_beam_search.
debug:
True to show debug messages.
provider:
@@ -137,7 +139,7 @@ class OfflineRecognizer(object):
feature_dim:
Dimension of the feature used to train the model.
decoding_method:
Valid values are greedy_search, modified_beam_search.
Valid values are greedy_search.
debug:
True to show debug messages.
provider:
@@ -185,14 +187,14 @@ class OfflineRecognizer(object):
English, etc.
Args:
model:
Path to ``model.onnx``.
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
model:
Path to ``model.onnx``.
num_threads:
Number of threads for neural network computation.
sample_rate:
@@ -200,7 +202,7 @@ class OfflineRecognizer(object):
feature_dim:
Dimension of the feature used to train the model.
decoding_method:
Valid values are greedy_search, modified_beam_search.
Valid values are greedy_search.
debug:
True to show debug messages.
provider:
@@ -229,6 +231,68 @@ class OfflineRecognizer(object):
self.recognizer = _Recognizer(recognizer_config)
return self
@classmethod
def from_whisper(
cls,
encoder: str,
decoder: str,
tokens: str,
num_threads: int,
decoding_method: str = "greedy_search",
debug: bool = False,
provider: str = "cpu",
):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
to download pre-trained models for different kinds of whisper models,
e.g., tiny, tiny.en, base, base.en, etc.
Args:
encoder_model:
Path to the encoder model, e.g., tiny-encoder.onnx,
tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
decoder_model:
Path to the encoder model, e.g., tiny-encoder.onnx,
tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
num_threads:
Number of threads for neural network computation.
decoding_method:
Valid values: greedy_search.
debug:
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
"""
self = cls.__new__(cls)
model_config = OfflineModelConfig(
whisper=OfflineWhisperModelConfig(encoder=encoder, decoder=decoder),
tokens=tokens,
num_threads=num_threads,
debug=debug,
provider=provider,
model_type="whisper",
)
feat_config = OfflineFeatureExtractorConfig(
sampling_rate=16000,
feature_dim=80,
)
recognizer_config = OfflineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
decoding_method=decoding_method,
)
self.recognizer = _Recognizer(recognizer_config)
return self
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
if contexts_list is None:
return self.recognizer.create_stream()