Support whisper models (#238)
This commit is contained in:
@@ -11,6 +11,7 @@ if(SHERPA_ONNX_ENABLE_PYTHON)
|
||||
endif()
|
||||
|
||||
set(sources
|
||||
base64-decode.cc
|
||||
cat.cc
|
||||
context-graph.cc
|
||||
endpoint.cc
|
||||
@@ -35,6 +36,9 @@ set(sources
|
||||
offline-transducer-model-config.cc
|
||||
offline-transducer-model.cc
|
||||
offline-transducer-modified-beam-search-decoder.cc
|
||||
offline-whisper-greedy-search-decoder.cc
|
||||
offline-whisper-model-config.cc
|
||||
offline-whisper-model.cc
|
||||
online-conformer-transducer-model.cc
|
||||
online-lm-config.cc
|
||||
online-lm.cc
|
||||
@@ -50,12 +54,12 @@ set(sources
|
||||
online-zipformer-transducer-model.cc
|
||||
online-zipformer2-transducer-model.cc
|
||||
onnx-utils.cc
|
||||
session.cc
|
||||
packed-sequence.cc
|
||||
pad-sequence.cc
|
||||
parse-options.cc
|
||||
provider.cc
|
||||
resample.cc
|
||||
session.cc
|
||||
slice.cc
|
||||
stack.cc
|
||||
symbol-table.cc
|
||||
|
||||
67
sherpa-onnx/csrc/base64-decode.cc
Normal file
67
sherpa-onnx/csrc/base64-decode.cc
Normal file
@@ -0,0 +1,67 @@
|
||||
// sherpa-onnx/csrc/base64-decode.cc
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/base64-decode.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static int32_t Ord(char c) {
|
||||
if (c >= 'A' && c <= 'Z') {
|
||||
return c - 'A';
|
||||
} else if (c >= 'a' && c <= 'z') {
|
||||
return c - 'a' + ('Z' - 'A') + 1;
|
||||
} else if (c >= '0' && c <= '9') {
|
||||
return c - '0' + ('Z' - 'A') + ('z' - 'a') + 2;
|
||||
} else if (c == '+') {
|
||||
return 62;
|
||||
} else if (c == '/') {
|
||||
return 63;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("Unknown character %d, %c\n", c, c);
|
||||
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// see
|
||||
// https://github.com/ReneNyffenegger/cpp-base64/blob/master/base64.cpp#L243
|
||||
std::string Base64Decode(const std::string &s) {
|
||||
if (s.empty()) {
|
||||
SHERPA_ONNX_LOGE("Empty string!");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t n = s.size() / 4 * 3;
|
||||
|
||||
std::string ans;
|
||||
ans.reserve(n);
|
||||
|
||||
int32_t i = 0;
|
||||
while (i < static_cast<int32_t>(s.size())) {
|
||||
if (s[i] == '=') {
|
||||
return " ";
|
||||
}
|
||||
|
||||
int32_t first = (Ord(s[i]) << 2) + ((Ord(s[i + 1]) & 0x30) >> 4);
|
||||
ans.push_back(first);
|
||||
|
||||
if (i + 2 < static_cast<int32_t>(s.size()) && s[i + 2] != '=') {
|
||||
int32_t second =
|
||||
((Ord(s[i + 1]) & 0x0f) << 4) + ((Ord(s[i + 2]) & 0x3c) >> 2);
|
||||
ans.push_back(second);
|
||||
|
||||
if (i + 3 < static_cast<int32_t>(s.size()) && s[i + 3] != '=') {
|
||||
int32_t third = ((Ord(s[i + 2]) & 0x03) << 6) + Ord(s[i + 3]);
|
||||
ans.push_back(third);
|
||||
}
|
||||
}
|
||||
i += 4;
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
19
sherpa-onnx/csrc/base64-decode.h
Normal file
19
sherpa-onnx/csrc/base64-decode.h
Normal file
@@ -0,0 +1,19 @@
|
||||
// sherpa-onnx/csrc/base64-decode.h
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_BASE64_DECODE_H_
|
||||
#define SHERPA_ONNX_CSRC_BASE64_DECODE_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** @param s A base64 encoded string.
|
||||
* @return Return the decoded string.
|
||||
*/
|
||||
std::string Base64Decode(const std::string &s);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_BASE64_DECODE_H_
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
// sherpa-onnx/csrc/macros.h
|
||||
//
|
||||
// Copyright 2023 Xiaomi Corporation
|
||||
|
||||
@@ -14,6 +14,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
||||
transducer.Register(po);
|
||||
paraformer.Register(po);
|
||||
nemo_ctc.Register(po);
|
||||
whisper.Register(po);
|
||||
|
||||
po->Register("tokens", &tokens, "Path to tokens.txt");
|
||||
|
||||
@@ -28,7 +29,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
||||
|
||||
po->Register("model-type", &model_type,
|
||||
"Specify it to reduce model initialization time. "
|
||||
"Valid values are: transducer, paraformer, nemo_ctc. "
|
||||
"Valid values are: transducer, paraformer, nemo_ctc, whisper."
|
||||
"All other values lead to loading the model twice.");
|
||||
}
|
||||
|
||||
@@ -51,6 +52,10 @@ bool OfflineModelConfig::Validate() const {
|
||||
return nemo_ctc.Validate();
|
||||
}
|
||||
|
||||
if (!whisper.encoder.empty()) {
|
||||
return whisper.Validate();
|
||||
}
|
||||
|
||||
return transducer.Validate();
|
||||
}
|
||||
|
||||
@@ -61,6 +66,7 @@ std::string OfflineModelConfig::ToString() const {
|
||||
os << "transducer=" << transducer.ToString() << ", ";
|
||||
os << "paraformer=" << paraformer.ToString() << ", ";
|
||||
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
|
||||
os << "whisper=" << whisper.ToString() << ", ";
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -16,6 +17,7 @@ struct OfflineModelConfig {
|
||||
OfflineTransducerModelConfig transducer;
|
||||
OfflineParaformerModelConfig paraformer;
|
||||
OfflineNemoEncDecCtcModelConfig nemo_ctc;
|
||||
OfflineWhisperModelConfig whisper;
|
||||
|
||||
std::string tokens;
|
||||
int32_t num_threads = 2;
|
||||
@@ -37,11 +39,13 @@ struct OfflineModelConfig {
|
||||
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
|
||||
const OfflineParaformerModelConfig ¶former,
|
||||
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
|
||||
const OfflineWhisperModelConfig &whisper,
|
||||
const std::string &tokens, int32_t num_threads, bool debug,
|
||||
const std::string &provider, const std::string &model_type)
|
||||
: transducer(transducer),
|
||||
paraformer(paraformer),
|
||||
nemo_ctc(nemo_ctc),
|
||||
whisper(whisper),
|
||||
tokens(tokens),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
|
||||
@@ -16,7 +16,7 @@ void OfflineNemoEncDecCtcModelConfig::Register(ParseOptions *po) {
|
||||
|
||||
bool OfflineNemoEncDecCtcModelConfig::Validate() const {
|
||||
if (!FileExists(model)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
|
||||
SHERPA_ONNX_LOGE("NeMo model: %s does not exist", model.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ void OfflineParaformerModelConfig::Register(ParseOptions *po) {
|
||||
|
||||
bool OfflineParaformerModelConfig::Validate() const {
|
||||
if (!FileExists(model)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
|
||||
SHERPA_ONNX_LOGE("Paraformer model %s does not exist", model.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
@@ -26,6 +27,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||
} else if (model_type == "nemo_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||
} else if (model_type == "whisper") {
|
||||
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Invalid model_type: %s. Trying to load the model to get its type",
|
||||
@@ -43,6 +46,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
model_filename = config.model_config.paraformer.model;
|
||||
} else if (!config.model_config.nemo_ctc.model.empty()) {
|
||||
model_filename = config.model_config.nemo_ctc.model;
|
||||
} else if (!config.model_config.whisper.encoder.empty()) {
|
||||
model_filename = config.model_config.whisper.encoder;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Please provide a model");
|
||||
exit(-1);
|
||||
@@ -77,6 +82,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
"\n "
|
||||
"https://huggingface.co/csukuangfj/"
|
||||
"paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
|
||||
"\n "
|
||||
"(3) Whisper"
|
||||
"\n");
|
||||
exit(-1);
|
||||
}
|
||||
@@ -95,12 +102,17 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||
}
|
||||
|
||||
if (strncmp(model_type.c_str(), "whisper", 7) == 0) {
|
||||
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE(
|
||||
"\nUnsupported model_type: %s\n"
|
||||
"We support only the following model types at present: \n"
|
||||
" - Non-streaming transducer models from icefall\n"
|
||||
" - Non-streaming Paraformer models from FunASR\n"
|
||||
" - EncDecCTCModelBPE models from NeMo\n",
|
||||
" - EncDecCTCModelBPE models from NeMo\n"
|
||||
" - Whisper models\n",
|
||||
model_type.c_str());
|
||||
|
||||
exit(-1);
|
||||
|
||||
152
sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
Normal file
152
sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
Normal file
@@ -0,0 +1,152 @@
|
||||
// sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/transpose.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
|
||||
const SymbolTable &sym_table) {
|
||||
OfflineRecognitionResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
|
||||
for (auto i : src.tokens) {
|
||||
if (!sym_table.contains(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &s = sym_table[i];
|
||||
r.text += s;
|
||||
r.tokens.push_back(s);
|
||||
}
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
public:
|
||||
explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineWhisperModel>(config.model_config)) {
|
||||
// tokens.txt from whisper is base64 encoded, so we need to decode it
|
||||
symbol_table_.ApplyBase64Decode();
|
||||
|
||||
if (config.decoding_method == "greedy_search") {
|
||||
decoder_ =
|
||||
std::make_unique<OfflineWhisperGreedySearchDecoder>(model_.get());
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Only greedy_search is supported at present for whisper. Given %s",
|
||||
config.decoding_method.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>(WhisperTag{});
|
||||
}
|
||||
|
||||
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||
// batch decoding is not implemented yet
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
DecodeStream(ss[i]);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void DecodeStream(OfflineStream *s) const {
|
||||
int32_t max_num_frames = 3000;
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
int32_t feat_dim = s->FeatureDim();
|
||||
std::vector<float> f = s->GetFrames();
|
||||
int32_t num_frames = f.size() / feat_dim;
|
||||
|
||||
if (num_frames > max_num_frames) {
|
||||
SHERPA_ONNX_LOGE("Only waves less than 30 seconds are supported.");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
NormalizeFeatures(f.data(), num_frames, feat_dim);
|
||||
|
||||
std::array<int64_t, 3> shape{1, max_num_frames, feat_dim};
|
||||
|
||||
Ort::Value mel = Ort::Value::CreateTensor<float>(
|
||||
model_->Allocator(), shape.data(), shape.size());
|
||||
float *p_mel = mel.GetTensorMutableData<float>();
|
||||
std::copy(f.begin(), f.end(), p_mel);
|
||||
|
||||
memset(p_mel + f.size(), 0,
|
||||
(max_num_frames - num_frames) * feat_dim * sizeof(float));
|
||||
mel = Transpose12(model_->Allocator(), &mel);
|
||||
|
||||
auto cross_kv = model_->ForwardEncoder(std::move(mel));
|
||||
auto results =
|
||||
decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
|
||||
|
||||
auto r = Convert(results[0], symbol_table_);
|
||||
s->SetResult(r);
|
||||
}
|
||||
|
||||
private:
|
||||
static void NormalizeFeatures(float *features, int32_t num_frames,
|
||||
int32_t feat_dim) {
|
||||
// log_spec = torch.clamp(features, min=1e-10).log10()
|
||||
// log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
// mel = (log_spec + 4.0) / 4.0
|
||||
|
||||
int32_t n = num_frames * feat_dim;
|
||||
float max_v = -1e20;
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
float f = features[i];
|
||||
|
||||
f = std::max<float>(f, 1e-10);
|
||||
f = std::log10(f);
|
||||
|
||||
max_v = std::max(f, max_v);
|
||||
|
||||
features[i] = f;
|
||||
}
|
||||
|
||||
max_v -= 8;
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
float f = features[i];
|
||||
f = std::max(f, max_v);
|
||||
|
||||
f = (f + 4) / 4;
|
||||
|
||||
features[i] = f;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineRecognizerConfig config_;
|
||||
SymbolTable symbol_table_;
|
||||
std::unique_ptr<OfflineWhisperModel> model_;
|
||||
std::unique_ptr<OfflineWhisperDecoder> decoder_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_
|
||||
@@ -86,6 +86,15 @@ class OfflineStream::Impl {
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
}
|
||||
|
||||
Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph)
|
||||
: context_graph_(context_graph) {
|
||||
config_.normalize_samples = true;
|
||||
opts_.frame_opts.samp_freq = 16000;
|
||||
opts_.mel_opts.num_bins = 80;
|
||||
whisper_fbank_ =
|
||||
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
|
||||
}
|
||||
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||
if (config_.normalize_samples) {
|
||||
AcceptWaveformImpl(sampling_rate, waveform, n);
|
||||
@@ -117,20 +126,35 @@ class OfflineStream::Impl {
|
||||
lowpass_filter_width);
|
||||
std::vector<float> samples;
|
||||
resampler->Resample(waveform, n, true, &samples);
|
||||
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
|
||||
samples.size());
|
||||
fbank_->InputFinished();
|
||||
return;
|
||||
}
|
||||
|
||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
fbank_->InputFinished();
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
|
||||
samples.size());
|
||||
fbank_->InputFinished();
|
||||
} else {
|
||||
whisper_fbank_->AcceptWaveform(opts_.frame_opts.samp_freq,
|
||||
samples.data(), samples.size());
|
||||
whisper_fbank_->InputFinished();
|
||||
}
|
||||
|
||||
return;
|
||||
} // if (sampling_rate != opts_.frame_opts.samp_freq)
|
||||
|
||||
if (fbank_) {
|
||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
fbank_->InputFinished();
|
||||
} else {
|
||||
whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
whisper_fbank_->InputFinished();
|
||||
}
|
||||
}
|
||||
|
||||
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
|
||||
|
||||
std::vector<float> GetFrames() const {
|
||||
int32_t n = fbank_->NumFramesReady();
|
||||
int32_t n =
|
||||
fbank_ ? fbank_->NumFramesReady() : whisper_fbank_->NumFramesReady();
|
||||
|
||||
assert(n > 0 && "Please first call AcceptWaveform()");
|
||||
|
||||
int32_t feature_dim = FeatureDim();
|
||||
@@ -140,7 +164,8 @@ class OfflineStream::Impl {
|
||||
float *p = features.data();
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
const float *f = fbank_->GetFrame(i);
|
||||
const float *f =
|
||||
fbank_ ? fbank_->GetFrame(i) : whisper_fbank_->GetFrame(i);
|
||||
std::copy(f, f + feature_dim, p);
|
||||
p += feature_dim;
|
||||
}
|
||||
@@ -191,6 +216,7 @@ class OfflineStream::Impl {
|
||||
private:
|
||||
OfflineFeatureExtractorConfig config_;
|
||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
|
||||
knf::FbankOptions opts_;
|
||||
OfflineRecognitionResult r_;
|
||||
ContextGraphPtr context_graph_;
|
||||
@@ -201,6 +227,10 @@ OfflineStream::OfflineStream(
|
||||
ContextGraphPtr context_graph /*= nullptr*/)
|
||||
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||
|
||||
OfflineStream::OfflineStream(WhisperTag tag,
|
||||
ContextGraphPtr context_graph /*= nullptr*/)
|
||||
: impl_(std::make_unique<Impl>(tag, context_graph)) {}
|
||||
|
||||
OfflineStream::~OfflineStream() = default;
|
||||
|
||||
void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
|
||||
|
||||
@@ -65,10 +65,15 @@ struct OfflineFeatureExtractorConfig {
|
||||
void Register(ParseOptions *po);
|
||||
};
|
||||
|
||||
struct WhisperTag {};
|
||||
|
||||
class OfflineStream {
|
||||
public:
|
||||
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
|
||||
ContextGraphPtr context_graph = nullptr);
|
||||
|
||||
explicit OfflineStream(WhisperTag tag,
|
||||
ContextGraphPtr context_graph = nullptr);
|
||||
~OfflineStream();
|
||||
|
||||
/**
|
||||
|
||||
@@ -18,17 +18,20 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) {
|
||||
|
||||
bool OfflineTransducerModelConfig::Validate() const {
|
||||
if (!FileExists(encoder_filename)) {
|
||||
SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str());
|
||||
SHERPA_ONNX_LOGE("transducer encoder: %s does not exist",
|
||||
encoder_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(decoder_filename)) {
|
||||
SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str());
|
||||
SHERPA_ONNX_LOGE("transducer decoder: %s does not exist",
|
||||
decoder_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(joiner_filename)) {
|
||||
SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str());
|
||||
SHERPA_ONNX_LOGE("transducer joiner: %s does not exist",
|
||||
joiner_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
38
sherpa-onnx/csrc/offline-whisper-decoder.h
Normal file
38
sherpa-onnx/csrc/offline-whisper-decoder.h
Normal file
@@ -0,0 +1,38 @@
|
||||
// sherpa-onnx/csrc/offline-whisper-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineWhisperDecoderResult {
|
||||
/// The decoded token IDs
|
||||
std::vector<int32_t> tokens;
|
||||
};
|
||||
|
||||
class OfflineWhisperDecoder {
|
||||
public:
|
||||
virtual ~OfflineWhisperDecoder() = default;
|
||||
|
||||
/** Run beam search given the output from the whisper encoder model.
|
||||
*
|
||||
* @param n_layer_cross_k A 4-D tensor of shape
|
||||
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
||||
* @param n_layer_cross_v A 4-D tensor of shape
|
||||
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
||||
*
|
||||
* @return Return a vector of size `N` containing the decoded results.
|
||||
*/
|
||||
virtual std::vector<OfflineWhisperDecoderResult> Decode(
|
||||
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
|
||||
93
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
Normal file
93
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
Normal file
@@ -0,0 +1,93 @@
|
||||
// sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::vector<OfflineWhisperDecoderResult>
|
||||
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
Ort::Value cross_v) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
auto self_kv_cache = model_->GetInitialSelfKVCache();
|
||||
|
||||
std::vector<int64_t> initial_tokens = model_->GetInitialTokens();
|
||||
int32_t batch_size = 1;
|
||||
std::array<int64_t, 2> token_shape{
|
||||
batch_size, static_cast<int64_t>(initial_tokens.size())};
|
||||
|
||||
Ort::Value tokens = Ort::Value::CreateTensor(
|
||||
memory_info, initial_tokens.data(), initial_tokens.size(),
|
||||
token_shape.data(), token_shape.size());
|
||||
|
||||
std::array<int64_t, 1> offset_shape{1};
|
||||
Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
|
||||
model_->Allocator(), offset_shape.data(), offset_shape.size());
|
||||
*(offset.GetTensorMutableData<int64_t>()) = 0;
|
||||
|
||||
auto decoder_out = model_->ForwardDecoder(
|
||||
std::move(tokens), std::move(self_kv_cache.first),
|
||||
std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
|
||||
std::move(offset));
|
||||
|
||||
const auto &logits = std::get<0>(decoder_out);
|
||||
const float *p_logits = logits.GetTensorData<float>();
|
||||
|
||||
auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape();
|
||||
int32_t vocab_size = logits_shape[2];
|
||||
|
||||
int32_t max_token_id = static_cast<int32_t>(std::distance(
|
||||
p_logits, std::max_element(p_logits, p_logits + vocab_size)));
|
||||
|
||||
int32_t n_text_ctx = model_->TextCtx();
|
||||
|
||||
std::vector<int32_t> predicted_tokens;
|
||||
for (int32_t i = 0; i < n_text_ctx; ++i) {
|
||||
if (max_token_id == model_->EOT()) {
|
||||
break;
|
||||
}
|
||||
|
||||
predicted_tokens.push_back(max_token_id);
|
||||
|
||||
std::array<int64_t, 2> token_shape{1, 1};
|
||||
Ort::Value tokens = Ort::Value::CreateTensor<int64_t>(
|
||||
model_->Allocator(), token_shape.data(), token_shape.size());
|
||||
int64_t *p_tokens = tokens.GetTensorMutableData<int64_t>();
|
||||
p_tokens[0] = max_token_id;
|
||||
|
||||
int64_t *p_offset =
|
||||
std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
|
||||
|
||||
if (i == 0) {
|
||||
*p_offset = initial_tokens.size();
|
||||
} else {
|
||||
*p_offset += 1;
|
||||
}
|
||||
|
||||
decoder_out = model_->ForwardDecoder(std::move(tokens),
|
||||
std::move(std::get<1>(decoder_out)),
|
||||
std::move(std::get<2>(decoder_out)),
|
||||
std::move(std::get<3>(decoder_out)),
|
||||
std::move(std::get<4>(decoder_out)),
|
||||
std::move(std::get<5>(decoder_out)));
|
||||
|
||||
const auto &logits = std::get<0>(decoder_out);
|
||||
const float *p_logits = logits.GetTensorData<float>();
|
||||
|
||||
max_token_id = static_cast<int64_t>(std::distance(
|
||||
p_logits, std::max_element(p_logits, p_logits + vocab_size)));
|
||||
}
|
||||
|
||||
std::vector<OfflineWhisperDecoderResult> ans(1);
|
||||
ans[0].tokens = std::move(predicted_tokens);
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
29
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
Normal file
29
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
Normal file
@@ -0,0 +1,29 @@
|
||||
// sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
|
||||
public:
|
||||
explicit OfflineWhisperGreedySearchDecoder(OfflineWhisperModel *model)
|
||||
: model_(model) {}
|
||||
|
||||
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
|
||||
Ort::Value cross_v) override;
|
||||
|
||||
private:
|
||||
OfflineWhisperModel *model_; // not owned
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
|
||||
46
sherpa-onnx/csrc/offline-whisper-model-config.cc
Normal file
46
sherpa-onnx/csrc/offline-whisper-model-config.cc
Normal file
@@ -0,0 +1,46 @@
|
||||
// sherpa-onnx/csrc/offline-whisper-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineWhisperModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("whisper-encoder", &encoder,
|
||||
"Path to onnx encoder of whisper, e.g., tiny-encoder.onnx, "
|
||||
"medium.en-encoder.onnx.");
|
||||
|
||||
po->Register("whisper-decoder", &decoder,
|
||||
"Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, "
|
||||
"medium.en-decoder.onnx.");
|
||||
}
|
||||
|
||||
bool OfflineWhisperModelConfig::Validate() const {
|
||||
if (!FileExists(encoder)) {
|
||||
SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(decoder)) {
|
||||
SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OfflineWhisperModelConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineWhisperModelConfig(";
|
||||
os << "encoder=\"" << encoder << "\", ";
|
||||
os << "decoder=\"" << decoder << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
30
sherpa-onnx/csrc/offline-whisper-model-config.h
Normal file
30
sherpa-onnx/csrc/offline-whisper-model-config.h
Normal file
@@ -0,0 +1,30 @@
|
||||
// sherpa-onnx/csrc/offline-whisper-model-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineWhisperModelConfig {
|
||||
std::string encoder;
|
||||
std::string decoder;
|
||||
|
||||
OfflineWhisperModelConfig() = default;
|
||||
OfflineWhisperModelConfig(const std::string &encoder,
|
||||
const std::string &decoder)
|
||||
: encoder(encoder), decoder(decoder) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
|
||||
213
sherpa-onnx/csrc/offline-whisper-model.cc
Normal file
213
sherpa-onnx/csrc/offline-whisper-model.cc
Normal file
@@ -0,0 +1,213 @@
|
||||
// sherpa-onnx/csrc/offline-whisper-model.cc
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineWhisperModel::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
auto buf = ReadFile(config.whisper.encoder);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.whisper.decoder);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) {
|
||||
auto encoder_out = encoder_sess_->Run(
|
||||
{}, encoder_input_names_ptr_.data(), &features, 1,
|
||||
encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
|
||||
|
||||
return {std::move(encoder_out[0]), std::move(encoder_out[1])};
|
||||
}
|
||||
|
||||
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
|
||||
Ort::Value>
|
||||
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
|
||||
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
|
||||
Ort::Value n_layer_cross_v, Ort::Value offset) {
|
||||
std::array<Ort::Value, 6> decoder_input = {std::move(tokens),
|
||||
std::move(n_layer_self_k_cache),
|
||||
std::move(n_layer_self_v_cache),
|
||||
std::move(n_layer_cross_k),
|
||||
std::move(n_layer_cross_v),
|
||||
std::move(offset)};
|
||||
|
||||
auto decoder_out = decoder_sess_->Run(
|
||||
{}, decoder_input_names_ptr_.data(), decoder_input.data(),
|
||||
decoder_input.size(), decoder_output_names_ptr_.data(),
|
||||
decoder_output_names_ptr_.size());
|
||||
|
||||
return {std::move(decoder_out[0]), std::move(decoder_out[1]),
|
||||
std::move(decoder_out[2]), std::move(decoder_input[3]),
|
||||
std::move(decoder_input[4]), std::move(decoder_input[5])};
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
|
||||
std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_};
|
||||
|
||||
Ort::Value n_layer_self_k_cache = Ort::Value::CreateTensor<float>(
|
||||
Allocator(), shape.data(), shape.size());
|
||||
|
||||
Ort::Value n_layer_self_v_cache = Ort::Value::CreateTensor<float>(
|
||||
Allocator(), shape.data(), shape.size());
|
||||
|
||||
auto n = shape[0] * shape[1] * shape[2] * shape[3];
|
||||
|
||||
float *p_k = n_layer_self_k_cache.GetTensorMutableData<float>();
|
||||
float *p_v = n_layer_self_v_cache.GetTensorMutableData<float>();
|
||||
|
||||
memset(p_k, 0, sizeof(float) * n);
|
||||
memset(p_v, 0, sizeof(float) * n);
|
||||
|
||||
return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
|
||||
}
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
|
||||
const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }
|
||||
|
||||
int32_t EOT() const { return eot_; }
|
||||
|
||||
int32_t TextCtx() const { return n_text_ctx_; }
|
||||
|
||||
private:
|
||||
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, model_data, model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||
&encoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
|
||||
&encoder_output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
os << "---encoder---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
|
||||
SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
|
||||
SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
|
||||
SHERPA_ONNX_READ_META_DATA(sot_, "sot");
|
||||
SHERPA_ONNX_READ_META_DATA(eot_, "eot");
|
||||
SHERPA_ONNX_READ_META_DATA(blank_, "blank_id");
|
||||
SHERPA_ONNX_READ_META_DATA(translate_, "translate");
|
||||
SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps");
|
||||
SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech");
|
||||
SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence");
|
||||
}
|
||||
|
||||
void InitDecoder(void *model_data, size_t model_data_length) {
|
||||
decoder_sess_ = std::make_unique<Ort::Session>(
|
||||
env_, model_data, model_data_length, sess_opts_);
|
||||
|
||||
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||
&decoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
|
||||
&decoder_output_names_ptr_);
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineModelConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> encoder_sess_;
|
||||
std::unique_ptr<Ort::Session> decoder_sess_;
|
||||
|
||||
std::vector<std::string> encoder_input_names_;
|
||||
std::vector<const char *> encoder_input_names_ptr_;
|
||||
|
||||
std::vector<std::string> encoder_output_names_;
|
||||
std::vector<const char *> encoder_output_names_ptr_;
|
||||
|
||||
std::vector<std::string> decoder_input_names_;
|
||||
std::vector<const char *> decoder_input_names_ptr_;
|
||||
|
||||
std::vector<std::string> decoder_output_names_;
|
||||
std::vector<const char *> decoder_output_names_ptr_;
|
||||
|
||||
// model meta data
|
||||
int32_t n_text_layer_;
|
||||
int32_t n_text_ctx_;
|
||||
int32_t n_text_state_;
|
||||
int32_t sot_;
|
||||
int32_t eot_;
|
||||
int32_t blank_;
|
||||
int32_t translate_;
|
||||
int32_t no_timestamps_;
|
||||
int32_t no_speech_;
|
||||
std::vector<int64_t> sot_sequence_;
|
||||
};
|
||||
|
||||
OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
OfflineWhisperModel::~OfflineWhisperModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder(
|
||||
Ort::Value features) {
|
||||
return impl_->ForwardEncoder(std::move(features));
|
||||
}
|
||||
|
||||
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
|
||||
Ort::Value>
|
||||
OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
|
||||
Ort::Value n_layer_self_k_cache,
|
||||
Ort::Value n_layer_self_v_cache,
|
||||
Ort::Value n_layer_cross_k,
|
||||
Ort::Value n_layer_cross_v,
|
||||
Ort::Value offset) {
|
||||
return impl_->ForwardDecoder(
|
||||
std::move(tokens), std::move(n_layer_self_k_cache),
|
||||
std::move(n_layer_self_v_cache), std::move(n_layer_cross_k),
|
||||
std::move(n_layer_cross_v), std::move(offset));
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() {
|
||||
return impl_->GetInitialSelfKVCache();
|
||||
}
|
||||
|
||||
OrtAllocator *OfflineWhisperModel::Allocator() const {
|
||||
return impl_->Allocator();
|
||||
}
|
||||
|
||||
const std::vector<int64_t> &OfflineWhisperModel::GetInitialTokens() const {
|
||||
return impl_->GetInitialTokens();
|
||||
}
|
||||
|
||||
int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); }
|
||||
|
||||
int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
85
sherpa-onnx/csrc/offline-whisper-model.h
Normal file
85
sherpa-onnx/csrc/offline-whisper-model.h
Normal file
@@ -0,0 +1,85 @@
|
||||
// sherpa-onnx/csrc/offline-whisper-model.h
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineWhisperModel {
|
||||
public:
|
||||
explicit OfflineWhisperModel(const OfflineModelConfig &config);
|
||||
~OfflineWhisperModel();
|
||||
|
||||
/** Run the encoder model.
|
||||
*
|
||||
* @param features A tensor of shape (N, C, T). It is changed in-place.
|
||||
* C is 80 and T is 3000.
|
||||
*
|
||||
* @return Return a pair containing:
|
||||
* - n_layer_cross_k: A 4-D tensor of shape
|
||||
* (n_text_layer, N, n_audio_ctx, n_text_state)
|
||||
* - n_layer_cross_v: A 4-D tensor of shape
|
||||
* (n_text_layer, N, n_audio_ctx, n_text_state)
|
||||
*/
|
||||
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features);
|
||||
|
||||
/** Run the decoder model.
|
||||
*
|
||||
* @param tokens A int64 tensor of shape (N, num_words)
|
||||
* @param n_layer_self_k_cache A 4-D tensor of shape
|
||||
* (n_text_layer, N, n_text_ctx, n_text_state).
|
||||
* @param n_layer_self_v_cache A 4-D tensor of shape
|
||||
* (n_text_layer, N, n_text_ctx, n_text_state).
|
||||
* @param n_layer_cross_k A 4-D tensor of shape
|
||||
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
||||
* @param n_layer_cross_v A 4-D tensor of shape
|
||||
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
||||
* @param offset A int64 tensor of shape (N,)
|
||||
*
|
||||
* @return Return a tuple containing 6 tensors:
|
||||
*
|
||||
* - logits A 3-D tensor of shape (N, num_words, vocab_size)
|
||||
* - out_n_layer_self_k_cache Same shape as n_layer_self_k_cache
|
||||
* - out_n_layer_self_v_cache Same shape as n_layer_self_v_cache
|
||||
* - out_n_layer_cross_k Same as n_layer_cross_k
|
||||
* - out_n_layer_cross_v Same as n_layer_cross_v
|
||||
* - out_offset Same as offset
|
||||
*/
|
||||
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
|
||||
Ort::Value>
|
||||
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
|
||||
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
|
||||
Ort::Value n_layer_cross_v, Ort::Value offset);
|
||||
|
||||
/** Return the initial self kv cache in a pair
|
||||
* - n_layer_self_k_cache A 4-D tensor of shape
|
||||
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
||||
* - n_layer_self_v_cache A 4-D tensor of shape
|
||||
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
||||
*/
|
||||
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache();
|
||||
const std::vector<int64_t> &GetInitialTokens() const;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const;
|
||||
int32_t EOT() const;
|
||||
int32_t TextCtx() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
|
||||
@@ -98,11 +98,15 @@ Usage:
|
||||
./bin/sherpa-onnx-microphone-offline \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--paraformer=/path/to/model.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search
|
||||
--num-threads=1
|
||||
|
||||
Default value for num_threads is 2.
|
||||
Valid values for decoding_method: greedy_search.
|
||||
(3) Whisper models
|
||||
|
||||
./bin/sherpa-onnx-microphone-offline \
|
||||
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
|
||||
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
|
||||
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
|
||||
--num-threads=1
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
|
||||
@@ -23,7 +23,7 @@ Usage:
|
||||
--encoder=/path/to/encoder.onnx \
|
||||
--decoder=/path/to/decoder.onnx \
|
||||
--joiner=/path/to/joiner.onnx \
|
||||
--num-threads=2 \
|
||||
--num-threads=1 \
|
||||
--decoding-method=greedy_search \
|
||||
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||
|
||||
@@ -33,14 +33,22 @@ Usage:
|
||||
./bin/sherpa-onnx-offline \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--paraformer=/path/to/model.onnx \
|
||||
--num-threads=2 \
|
||||
--num-threads=1 \
|
||||
--decoding-method=greedy_search \
|
||||
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||
|
||||
(3) Whisper models
|
||||
|
||||
./bin/sherpa-onnx-offline \
|
||||
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
|
||||
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
|
||||
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
|
||||
--num-threads=1 \
|
||||
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||
|
||||
|
||||
Note: It supports decoding multiple files in batches
|
||||
|
||||
Default value for num_threads is 2.
|
||||
Valid values for decoding_method: greedy_search.
|
||||
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
|
||||
sampling rate can be arbitrary and does not need to be 16kHz.
|
||||
|
||||
@@ -55,6 +63,7 @@ for a list of pre-trained models to download.
|
||||
|
||||
po.Read(argc, argv);
|
||||
if (po.NumArgs() < 1) {
|
||||
fprintf(stderr, "Error: Please provide at least 1 wave file.\n\n");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <sstream>
|
||||
#include <strstream>
|
||||
|
||||
#include "sherpa-onnx/csrc/base64-decode.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
@@ -82,4 +83,12 @@ std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) {
|
||||
return os << symbol_table.ToString();
|
||||
}
|
||||
|
||||
void SymbolTable::ApplyBase64Decode() {
|
||||
sym2id_.clear();
|
||||
for (auto &p : id2sym_) {
|
||||
p.second = Base64Decode(p.second);
|
||||
sym2id_[p.second] = p.first;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -45,6 +45,9 @@ class SymbolTable {
|
||||
/// Return true if there is a given symbol in the symbol table.
|
||||
bool contains(const std::string &sym) const;
|
||||
|
||||
// for tokens.txt from Whisper
|
||||
void ApplyBase64Decode();
|
||||
|
||||
private:
|
||||
void Init(std::istream &is);
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ pybind11_add_module(_sherpa_onnx
|
||||
offline-recognizer.cc
|
||||
offline-stream.cc
|
||||
offline-transducer-model-config.cc
|
||||
offline-whisper-model-config.cc
|
||||
online-lm-config.cc
|
||||
online-recognizer.cc
|
||||
online-stream.cc
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -18,22 +19,25 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
PybindOfflineTransducerModelConfig(m);
|
||||
PybindOfflineParaformerModelConfig(m);
|
||||
PybindOfflineNemoEncDecCtcModelConfig(m);
|
||||
PybindOfflineWhisperModelConfig(m);
|
||||
|
||||
using PyClass = OfflineModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineModelConfig")
|
||||
.def(
|
||||
py::init<const OfflineTransducerModelConfig &,
|
||||
const OfflineParaformerModelConfig &,
|
||||
const OfflineNemoEncDecCtcModelConfig &, const std::string &,
|
||||
int32_t, bool, const std::string &, const std::string &>(),
|
||||
py::arg("transducer") = OfflineTransducerModelConfig(),
|
||||
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
||||
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
|
||||
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
|
||||
py::arg("provider") = "cpu", py::arg("model_type") = "")
|
||||
.def(py::init<const OfflineTransducerModelConfig &,
|
||||
const OfflineParaformerModelConfig &,
|
||||
const OfflineNemoEncDecCtcModelConfig &,
|
||||
const OfflineWhisperModelConfig &, const std::string &,
|
||||
int32_t, bool, const std::string &, const std::string &>(),
|
||||
py::arg("transducer") = OfflineTransducerModelConfig(),
|
||||
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
||||
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
|
||||
py::arg("whisper") = OfflineWhisperModelConfig(), py::arg("tokens"),
|
||||
py::arg("num_threads"), py::arg("debug") = false,
|
||||
py::arg("provider") = "cpu", py::arg("model_type") = "")
|
||||
.def_readwrite("transducer", &PyClass::transducer)
|
||||
.def_readwrite("paraformer", &PyClass::paraformer)
|
||||
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
|
||||
.def_readwrite("whisper", &PyClass::whisper)
|
||||
.def_readwrite("tokens", &PyClass::tokens)
|
||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||
.def_readwrite("debug", &PyClass::debug)
|
||||
|
||||
24
sherpa-onnx/python/csrc/offline-whisper-model-config.cc
Normal file
24
sherpa-onnx/python/csrc/offline-whisper-model-config.cc
Normal file
@@ -0,0 +1,24 @@
|
||||
// sherpa-onnx/python/csrc/offline-whisper-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineWhisperModelConfig(py::module *m) {
|
||||
using PyClass = OfflineWhisperModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
|
||||
.def(py::init<const std::string &, const std::string &>(),
|
||||
py::arg("encoder"), py::arg("decoder"))
|
||||
.def_readwrite("encoder", &PyClass::encoder)
|
||||
.def_readwrite("decoder", &PyClass::decoder)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/offline-whisper-model-config.h
Normal file
16
sherpa-onnx/python/csrc/offline-whisper-model-config.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/offline-whisper-model-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineWhisperModelConfig(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright (c) 2023 by manyeyes
|
||||
# Copyright (c) 2023 Xiaomi Corporation
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -7,6 +8,7 @@ from _sherpa_onnx import (
|
||||
OfflineModelConfig,
|
||||
OfflineNemoEncDecCtcModelConfig,
|
||||
OfflineParaformerModelConfig,
|
||||
OfflineWhisperModelConfig,
|
||||
)
|
||||
from _sherpa_onnx import OfflineRecognizer as _Recognizer
|
||||
from _sherpa_onnx import (
|
||||
@@ -69,7 +71,7 @@ class OfflineRecognizer(object):
|
||||
feature_dim:
|
||||
Dimension of the feature used to train the model.
|
||||
decoding_method:
|
||||
Support only greedy_search for now.
|
||||
Valid values: greedy_search, modified_beam_search.
|
||||
debug:
|
||||
True to show debug messages.
|
||||
provider:
|
||||
@@ -137,7 +139,7 @@ class OfflineRecognizer(object):
|
||||
feature_dim:
|
||||
Dimension of the feature used to train the model.
|
||||
decoding_method:
|
||||
Valid values are greedy_search, modified_beam_search.
|
||||
Valid values are greedy_search.
|
||||
debug:
|
||||
True to show debug messages.
|
||||
provider:
|
||||
@@ -185,14 +187,14 @@ class OfflineRecognizer(object):
|
||||
English, etc.
|
||||
|
||||
Args:
|
||||
model:
|
||||
Path to ``model.onnx``.
|
||||
tokens:
|
||||
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
||||
columns::
|
||||
|
||||
symbol integer_id
|
||||
|
||||
model:
|
||||
Path to ``model.onnx``.
|
||||
num_threads:
|
||||
Number of threads for neural network computation.
|
||||
sample_rate:
|
||||
@@ -200,7 +202,7 @@ class OfflineRecognizer(object):
|
||||
feature_dim:
|
||||
Dimension of the feature used to train the model.
|
||||
decoding_method:
|
||||
Valid values are greedy_search, modified_beam_search.
|
||||
Valid values are greedy_search.
|
||||
debug:
|
||||
True to show debug messages.
|
||||
provider:
|
||||
@@ -229,6 +231,68 @@ class OfflineRecognizer(object):
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_whisper(
|
||||
cls,
|
||||
encoder: str,
|
||||
decoder: str,
|
||||
tokens: str,
|
||||
num_threads: int,
|
||||
decoding_method: str = "greedy_search",
|
||||
debug: bool = False,
|
||||
provider: str = "cpu",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
|
||||
to download pre-trained models for different kinds of whisper models,
|
||||
e.g., tiny, tiny.en, base, base.en, etc.
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
Path to the encoder model, e.g., tiny-encoder.onnx,
|
||||
tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
|
||||
decoder_model:
|
||||
Path to the encoder model, e.g., tiny-encoder.onnx,
|
||||
tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
|
||||
tokens:
|
||||
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
||||
columns::
|
||||
|
||||
symbol integer_id
|
||||
|
||||
num_threads:
|
||||
Number of threads for neural network computation.
|
||||
decoding_method:
|
||||
Valid values: greedy_search.
|
||||
debug:
|
||||
True to show debug messages.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
model_config = OfflineModelConfig(
|
||||
whisper=OfflineWhisperModelConfig(encoder=encoder, decoder=decoder),
|
||||
tokens=tokens,
|
||||
num_threads=num_threads,
|
||||
debug=debug,
|
||||
provider=provider,
|
||||
model_type="whisper",
|
||||
)
|
||||
|
||||
feat_config = OfflineFeatureExtractorConfig(
|
||||
sampling_rate=16000,
|
||||
feature_dim=80,
|
||||
)
|
||||
|
||||
recognizer_config = OfflineRecognizerConfig(
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
decoding_method=decoding_method,
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
return self
|
||||
|
||||
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
|
||||
if contexts_list is None:
|
||||
return self.recognizer.create_stream()
|
||||
|
||||
Reference in New Issue
Block a user