Support multilingual whisper models (#274)
This commit is contained in:
@@ -37,7 +37,7 @@
|
||||
} \
|
||||
\
|
||||
dst = atoi(value.get()); \
|
||||
if (dst <= 0) { \
|
||||
if (dst < 0) { \
|
||||
SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
@@ -77,6 +77,24 @@
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Read a metadata entry into a vector of strings.
//
// Looks up `src_key` via `meta_data` and `allocator` (both must be in scope
// where the macro is expanded), hands the value to SplitStringToVector with
// "," as the delimiter argument, and stores the pieces in `dst`.
//
// Logs an error and calls exit(-1) if the key is missing or if `dst` is empty
// after splitting.
//
// The macro parameters are parenthesized wherever they are used as operands so
// that an expression argument (for example `*ptr`) expands as intended.
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key)                   \
  do {                                                                        \
    auto value =                                                              \
        meta_data.LookupCustomMetadataMapAllocated((src_key), allocator);     \
    if (!value) {                                                             \
      SHERPA_ONNX_LOGE("%s does not exist in the metadata", (src_key));       \
      exit(-1);                                                               \
    }                                                                         \
    SplitStringToVector(value.get(), ",", false, &(dst));                     \
                                                                              \
    if ((dst).empty()) {                                                      \
      SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \
                       (src_key));                                            \
      exit(-1);                                                               \
    }                                                                         \
  } while (0)
|
||||
|
||||
// Read a string
|
||||
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
|
||||
do { \
|
||||
|
||||
@@ -23,21 +23,227 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Return a copy of `s` with every byte that is not part of a well-formed
// UTF-8 sequence removed.
//
// A sequence is kept only if it matches the well-formed byte sequences of
// RFC 3629 / the Unicode Standard:
//   - 1 byte : 00..7F
//   - 2 bytes: C2..DF 80..BF
//   - 3 bytes: E0..EF followed by two continuation bytes, excluding overlong
//              forms (E0 80..9F) and UTF-16 surrogates (ED A0..BF)
//   - 4 bytes: F0..F4 followed by three continuation bytes, excluding overlong
//              forms (F0 80..8F) and code points above U+10FFFF (F4 90..BF)
//
// Anything else (stray continuation bytes, truncated sequences, the lead bytes
// C0/C1/F5..FF, and the obsolete 5- and 6-byte forms) is not valid UTF-8.
// In that case only the offending lead byte is dropped and scanning resumes
// at the next byte, so valid data following a broken sequence is preserved.
static std::string FixInvalidUtf8(const std::string &s) {
  const size_t n = s.size();

  std::string ans;
  ans.reserve(n);

  size_t i = 0;
  while (i < n) {
    const uint8_t lead = static_cast<uint8_t>(s[i]);

    if (lead < 0x80) {
      // Plain ASCII is always valid.
      ans.push_back(s[i]);
      ++i;
      continue;
    }

    // Expected sequence length for this lead byte; 0 means "not a valid lead".
    size_t len = 0;

    // Allowed range of the FIRST continuation byte. It is narrower than
    // 80..BF for a few lead bytes in order to reject overlong encodings,
    // surrogates, and code points beyond U+10FFFF.
    uint8_t first_lo = 0x80;
    uint8_t first_hi = 0xbf;

    if (lead >= 0xc2 && lead <= 0xdf) {
      len = 2;
    } else if (lead >= 0xe0 && lead <= 0xef) {
      len = 3;
      if (lead == 0xe0) {
        first_lo = 0xa0;  // E0 80..9F would be an overlong encoding
      } else if (lead == 0xed) {
        first_hi = 0x9f;  // ED A0..BF would encode a UTF-16 surrogate
      }
    } else if (lead >= 0xf0 && lead <= 0xf4) {
      len = 4;
      if (lead == 0xf0) {
        first_lo = 0x90;  // F0 80..8F would be an overlong encoding
      } else if (lead == 0xf4) {
        first_hi = 0x8f;  // F4 90..BF would exceed U+10FFFF
      }
    }

    // i < n holds here, so n - i cannot underflow.
    bool ok = (len != 0) && (n - i >= len);

    for (size_t k = 1; ok && k < len; ++k) {
      const uint8_t b = static_cast<uint8_t>(s[i + k]);
      if (k == 1) {
        ok = (b >= first_lo) && (b <= first_hi);
      } else {
        ok = (b >= 0x80) && (b <= 0xbf);
      }
    }

    if (!ok) {
      // Drop just this byte; the following bytes are re-examined on their own.
      ++i;
      continue;
    }

    ans.append(s, i, len);
    i += len;
  }

  return ans;
}
|
||||
|
||||
// Build an OfflineRecognitionResult from a whisper decoder result.
//
// Token IDs that sym_table does not contain are skipped. For every other ID
// the corresponding symbol is appended to the result's token list and
// concatenated, in order, into the result's text.
static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
                                        const SymbolTable &sym_table) {
  OfflineRecognitionResult result;
  result.tokens.reserve(src.tokens.size());

  std::string joined;
  for (auto token_id : src.tokens) {
    if (sym_table.contains(token_id)) {
      const auto &symbol = sym_table[token_id];
      joined += symbol;
      result.tokens.push_back(symbol);
    }
  }

  // TODO(fangjun): Fix the following error in offline-stream.cc
  //
  //   j["text"] = text;
  //
  // libc++abi: terminating with uncaught exception of type
  // nlohmann::json_abi_v3_11_2::detail::type_error:
  // [json.exception.type_error.316] incomplete UTF-8 string; last byte: 0x86
  //
  // The sanitizing branch below is currently compiled out, so the text is
  // stored exactly as concatenated.
#if 0
  result.text = FixInvalidUtf8(joined);
#else
  result.text = joined;
#endif

  return result;
}
|
||||
|
||||
@@ -51,8 +257,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
symbol_table_.ApplyBase64Decode();
|
||||
|
||||
if (config.decoding_method == "greedy_search") {
|
||||
decoder_ =
|
||||
std::make_unique<OfflineWhisperGreedySearchDecoder>(model_.get());
|
||||
decoder_ = std::make_unique<OfflineWhisperGreedySearchDecoder>(
|
||||
config_.model_config.whisper, model_.get());
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Only greedy_search is supported at present for whisper. Given %s",
|
||||
@@ -101,6 +307,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
mel = Transpose12(model_->Allocator(), &mel);
|
||||
|
||||
auto cross_kv = model_->ForwardEncoder(std::move(mel));
|
||||
|
||||
auto results =
|
||||
decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
|
||||
|
||||
|
||||
@@ -7,17 +7,106 @@
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Detect the spoken language by running one decoder step on the SOT token.
//
// The decoder is run with a single token (model_->SOT()), a fresh self KV
// cache and offset 0. Among the token IDs returned by
// model_->GetAllLanguageIDs(), the one with the largest logit is returned
// (the first one wins on ties).
//
// cross_k and cross_v are moved into the decoder call and then restored from
// the decoder output (tuple elements 3 and 4), so the caller can keep using
// them after this function returns.
//
// Logs an error and calls exit(-1) if the model reports no language IDs.
int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage(
    Ort::Value &cross_k, Ort::Value &cross_v) const {  // NOLINT
  const auto &all_language_ids = model_->GetAllLanguageIDs();
  if (all_language_ids.empty()) {
    // Check before cross_k / cross_v are moved from; indexing an empty vector
    // below would be undefined behavior.
    SHERPA_ONNX_LOGE(
        "Cannot detect the language: the model has no language tokens");
    exit(-1);
  }

  // token_val backs the tensor below, so it must outlive the decoder call.
  int64_t token_val = model_->SOT();
  std::array<int64_t, 2> token_shape{1, 1};

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  Ort::Value tokens = Ort::Value::CreateTensor(
      memory_info, &token_val, 1, token_shape.data(), token_shape.size());

  auto self_kv_cache = model_->GetInitialSelfKVCache();

  std::array<int64_t, 1> offset_shape{1};
  Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
      model_->Allocator(), offset_shape.data(), offset_shape.size());
  *(offset.GetTensorMutableData<int64_t>()) = 0;

  auto decoder_out = model_->ForwardDecoder(
      std::move(tokens), std::move(self_kv_cache.first),
      std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
      std::move(offset));

  // Hand the cross-attention caches back to the caller.
  cross_k = std::move(std::get<3>(decoder_out));
  cross_v = std::move(std::get<4>(decoder_out));

  const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();

  // std::max_element returns the first of several equal maxima, which matches
  // a manual scan that only replaces the best ID on a strictly larger logit.
  const int32_t lang_id = *std::max_element(
      all_language_ids.begin(), all_language_ids.end(),
      [p_logits](int32_t lhs, int32_t rhs) {
        return p_logits[lhs] < p_logits[rhs];
      });

#if 1
  SHERPA_ONNX_LOGE("Detected language: %s",
                   model_->GetID2Lang().at(lang_id).c_str());
#endif

  return lang_id;
}
|
||||
|
||||
std::vector<OfflineWhisperDecoderResult>
|
||||
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
Ort::Value cross_v) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
auto self_kv_cache = model_->GetInitialSelfKVCache();
|
||||
|
||||
// For multilingual models, initial_tokens contains [sot, language, task]
|
||||
// - language is English by default
|
||||
// - task is transcribe by default
|
||||
//
|
||||
// For non-multilingual models, initial_tokens contains [sot]
|
||||
std::vector<int64_t> initial_tokens = model_->GetInitialTokens();
|
||||
|
||||
if (model_->IsMultiLingual()) {
|
||||
if (!config_.language.empty()) {
|
||||
const auto &lang2id = model_->GetLang2ID();
|
||||
|
||||
if (!lang2id.count(config_.language)) {
|
||||
SHERPA_ONNX_LOGE("Invalid language: %s", config_.language.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t lang_id = lang2id.at(config_.language);
|
||||
|
||||
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
|
||||
initial_tokens[1] = lang_id;
|
||||
} else {
|
||||
int32_t lang_id = DetectLanguage(cross_k, cross_v);
|
||||
|
||||
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
|
||||
initial_tokens[1] = lang_id;
|
||||
}
|
||||
|
||||
if (config_.task == "translate") {
|
||||
initial_tokens[2] = model_->Translate();
|
||||
} else if (config_.task != "transcribe") {
|
||||
// initial_tokens[2] is transcribe by default
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Unsupported task: %s. Valid values are: transcribe, translate.",
|
||||
config_.task.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
initial_tokens.push_back(model_->NoTimeStampsToken());
|
||||
|
||||
int32_t batch_size = 1;
|
||||
std::array<int64_t, 2> token_shape{
|
||||
batch_size, static_cast<int64_t>(initial_tokens.size())};
|
||||
@@ -31,11 +120,16 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
model_->Allocator(), offset_shape.data(), offset_shape.size());
|
||||
*(offset.GetTensorMutableData<int64_t>()) = 0;
|
||||
|
||||
auto self_kv_cache = model_->GetInitialSelfKVCache();
|
||||
|
||||
auto decoder_out = model_->ForwardDecoder(
|
||||
std::move(tokens), std::move(self_kv_cache.first),
|
||||
std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
|
||||
std::move(offset));
|
||||
|
||||
*(std::get<5>(decoder_out).GetTensorMutableData<int64_t>()) =
|
||||
initial_tokens.size();
|
||||
|
||||
const auto &logits = std::get<0>(decoder_out);
|
||||
const float *p_logits = logits.GetTensorData<float>();
|
||||
|
||||
@@ -58,18 +152,10 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
std::array<int64_t, 2> token_shape{1, 1};
|
||||
Ort::Value tokens = Ort::Value::CreateTensor<int64_t>(
|
||||
model_->Allocator(), token_shape.data(), token_shape.size());
|
||||
|
||||
int64_t *p_tokens = tokens.GetTensorMutableData<int64_t>();
|
||||
p_tokens[0] = max_token_id;
|
||||
|
||||
int64_t *p_offset =
|
||||
std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
|
||||
|
||||
if (i == 0) {
|
||||
*p_offset = initial_tokens.size();
|
||||
} else {
|
||||
*p_offset += 1;
|
||||
}
|
||||
|
||||
decoder_out = model_->ForwardDecoder(std::move(tokens),
|
||||
std::move(std::get<1>(decoder_out)),
|
||||
std::move(std::get<2>(decoder_out)),
|
||||
@@ -77,6 +163,11 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
std::move(std::get<4>(decoder_out)),
|
||||
std::move(std::get<5>(decoder_out)));
|
||||
|
||||
int64_t *p_offset =
|
||||
std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
|
||||
|
||||
*p_offset += 1;
|
||||
|
||||
const auto &logits = std::get<0>(decoder_out);
|
||||
const float *p_logits = logits.GetTensorData<float>();
|
||||
|
||||
@@ -85,6 +176,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
}
|
||||
|
||||
std::vector<OfflineWhisperDecoderResult> ans(1);
|
||||
|
||||
ans[0].tokens = std::move(predicted_tokens);
|
||||
|
||||
return ans;
|
||||
|
||||
@@ -8,19 +8,25 @@
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
|
||||
public:
|
||||
explicit OfflineWhisperGreedySearchDecoder(OfflineWhisperModel *model)
|
||||
: model_(model) {}
|
||||
OfflineWhisperGreedySearchDecoder(const OfflineWhisperModelConfig &config,
|
||||
OfflineWhisperModel *model)
|
||||
: config_(config), model_(model) {}
|
||||
|
||||
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
|
||||
Ort::Value cross_v) override;
|
||||
|
||||
int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
|
||||
Ort::Value &cross_v) const; // NOLINT
|
||||
|
||||
private:
|
||||
OfflineWhisperModelConfig config_;
|
||||
OfflineWhisperModel *model_; // not owned
|
||||
};
|
||||
|
||||
|
||||
@@ -17,6 +17,21 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("whisper-decoder", &decoder,
|
||||
"Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, "
|
||||
"medium.en-decoder.onnx.");
|
||||
|
||||
po->Register(
|
||||
"whisper-language", &language,
|
||||
"The spoke language in the input audio file. Example values: "
|
||||
"en, de, fr, zh, jp. If it is not given for a multilingual model, we will"
|
||||
" infer the language from the input audio file. "
|
||||
"Please refer to "
|
||||
"https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10"
|
||||
" for valid values. Note that for non-multilingual models, it supports "
|
||||
"only 'en'");
|
||||
|
||||
po->Register("whisper-task", &task,
|
||||
"Valid values: transcribe, translate. "
|
||||
"Note that for non-multilingual models, it supports "
|
||||
"only 'transcribe'");
|
||||
}
|
||||
|
||||
bool OfflineWhisperModelConfig::Validate() const {
|
||||
@@ -30,6 +45,14 @@ bool OfflineWhisperModelConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (task != "translate" && task != "transcribe") {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"--whisper-task supports only translate and transcribe. Given: %s",
|
||||
task.c_str());
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -38,7 +61,9 @@ std::string OfflineWhisperModelConfig::ToString() const {
|
||||
|
||||
os << "OfflineWhisperModelConfig(";
|
||||
os << "encoder=\"" << encoder << "\", ";
|
||||
os << "decoder=\"" << decoder << "\")";
|
||||
os << "decoder=\"" << decoder << "\", ";
|
||||
os << "language=\"" << language << "\", ";
|
||||
os << "task=\"" << task << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -14,10 +14,26 @@ struct OfflineWhisperModelConfig {
|
||||
std::string encoder;
|
||||
std::string decoder;
|
||||
|
||||
// Available languages can be found at
|
||||
// https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
|
||||
//
|
||||
// Note: For non-multilingual models, it supports only "en"
|
||||
//
|
||||
// If empty, we will infer it from the input audio file when
|
||||
// the model is multilingual.
|
||||
std::string language;
|
||||
|
||||
// Valid values are transcribe and translate
|
||||
//
|
||||
// Note: For non-multilingual models, it supports only "transcribe"
|
||||
std::string task = "transcribe";
|
||||
|
||||
OfflineWhisperModelConfig() = default;
|
||||
OfflineWhisperModelConfig(const std::string &encoder,
|
||||
const std::string &decoder)
|
||||
: encoder(encoder), decoder(decoder) {}
|
||||
const std::string &decoder,
|
||||
const std::string &language,
|
||||
const std::string &task)
|
||||
: encoder(encoder), decoder(decoder), language(language), task(task) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
@@ -88,10 +89,32 @@ class OfflineWhisperModel::Impl {
|
||||
|
||||
// Token sequence stored in sot_sequence_ (read from the "sot_sequence"
// metadata entry).
const std::vector<int64_t> &GetInitialTokens() const {
  return sot_sequence_;
}
|
||||
|
||||
const std::vector<int32_t> &GetAllLanguageIDs() const {
|
||||
return all_language_tokens_;
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, int32_t> &GetLang2ID() const {
|
||||
return lang2id_;
|
||||
}
|
||||
|
||||
const std::unordered_map<int32_t, std::string> &GetID2Lang() const {
|
||||
return id2lang_;
|
||||
}
|
||||
|
||||
// Value read from the "no_timestamps" metadata entry.
int32_t NoTimeStampsToken() const {
  return no_timestamps_;
}
|
||||
|
||||
// Value read from the "eot" metadata entry.
int32_t EOT() const {
  return eot_;
}
|
||||
|
||||
// Value read from the "sot" metadata entry.
int32_t SOT() const {
  return sot_;
}
|
||||
|
||||
// Value read from the "n_text_ctx" metadata entry.
int32_t TextCtx() const {
  return n_text_ctx_;
}
|
||||
|
||||
// Value read from the "n_vocab" metadata entry.
int32_t VocabSize() const {
  return n_vocab_;
}
|
||||
|
||||
// Value read from the "translate" metadata entry.
int32_t Translate() const {
  return translate_;
}
|
||||
|
||||
// True when the "is_multilingual" metadata entry is non-zero.
//
// is_multilingual_ is stored as int32_t; compare explicitly instead of relying
// on an implicit integer-to-bool conversion.
bool IsMultiLingual() const { return is_multilingual_ != 0; }
|
||||
|
||||
private:
|
||||
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
@@ -116,13 +139,35 @@ class OfflineWhisperModel::Impl {
|
||||
SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
|
||||
SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
|
||||
SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
|
||||
SHERPA_ONNX_READ_META_DATA(n_vocab_, "n_vocab");
|
||||
SHERPA_ONNX_READ_META_DATA(sot_, "sot");
|
||||
SHERPA_ONNX_READ_META_DATA(eot_, "eot");
|
||||
SHERPA_ONNX_READ_META_DATA(blank_, "blank_id");
|
||||
SHERPA_ONNX_READ_META_DATA(translate_, "translate");
|
||||
SHERPA_ONNX_READ_META_DATA(transcribe_, "transcribe");
|
||||
SHERPA_ONNX_READ_META_DATA(is_multilingual_, "is_multilingual");
|
||||
SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps");
|
||||
SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech");
|
||||
SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence");
|
||||
|
||||
if (is_multilingual_) {
|
||||
SHERPA_ONNX_READ_META_DATA_VEC(all_language_tokens_,
|
||||
"all_language_tokens");
|
||||
SHERPA_ONNX_READ_META_DATA_VEC_STRING(all_language_codes_,
|
||||
"all_language_codes");
|
||||
if (all_language_tokens_.size() != all_language_codes_.size()) {
|
||||
SHERPA_ONNX_LOGE("# lang_id: %d != # lang_code: %d",
|
||||
static_cast<int32_t>(all_language_tokens_.size()),
|
||||
static_cast<int32_t>(all_language_codes_.size()));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
for (int32_t i = 0;
|
||||
i != static_cast<int32_t>(all_language_tokens_.size()); ++i) {
|
||||
lang2id_[all_language_codes_[i]] = all_language_tokens_[i];
|
||||
id2lang_[all_language_tokens_[i]] = all_language_codes_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void InitDecoder(void *model_data, size_t model_data_length) {
|
||||
@@ -157,16 +202,24 @@ class OfflineWhisperModel::Impl {
|
||||
std::vector<std::string> decoder_output_names_;
|
||||
std::vector<const char *> decoder_output_names_ptr_;
|
||||
|
||||
std::vector<int32_t> all_language_tokens_;
|
||||
std::vector<std::string> all_language_codes_;
|
||||
std::unordered_map<std::string, int32_t> lang2id_;
|
||||
std::unordered_map<int32_t, std::string> id2lang_;
|
||||
|
||||
// model meta data
|
||||
int32_t n_text_layer_;
|
||||
int32_t n_text_ctx_;
|
||||
int32_t n_text_state_;
|
||||
int32_t n_vocab_;
|
||||
int32_t sot_;
|
||||
int32_t eot_;
|
||||
int32_t blank_;
|
||||
int32_t translate_;
|
||||
int32_t transcribe_;
|
||||
int32_t no_timestamps_;
|
||||
int32_t no_speech_;
|
||||
int32_t is_multilingual_;
|
||||
std::vector<int64_t> sot_sequence_;
|
||||
};
|
||||
|
||||
@@ -176,7 +229,7 @@ OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
|
||||
OfflineWhisperModel::~OfflineWhisperModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder(
|
||||
Ort::Value features) {
|
||||
Ort::Value features) const {
|
||||
return impl_->ForwardEncoder(std::move(features));
|
||||
}
|
||||
|
||||
@@ -187,14 +240,15 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
|
||||
Ort::Value n_layer_self_v_cache,
|
||||
Ort::Value n_layer_cross_k,
|
||||
Ort::Value n_layer_cross_v,
|
||||
Ort::Value offset) {
|
||||
Ort::Value offset) const {
|
||||
return impl_->ForwardDecoder(
|
||||
std::move(tokens), std::move(n_layer_self_k_cache),
|
||||
std::move(n_layer_self_v_cache), std::move(n_layer_cross_k),
|
||||
std::move(n_layer_cross_v), std::move(offset));
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() {
|
||||
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache()
|
||||
const {
|
||||
return impl_->GetInitialSelfKVCache();
|
||||
}
|
||||
|
||||
@@ -206,8 +260,36 @@ const std::vector<int64_t> &OfflineWhisperModel::GetInitialTokens() const {
|
||||
return impl_->GetInitialTokens();
|
||||
}
|
||||
|
||||
const std::vector<int32_t> &OfflineWhisperModel::GetAllLanguageIDs() const {
|
||||
return impl_->GetAllLanguageIDs();
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, int32_t>
|
||||
&OfflineWhisperModel::GetLang2ID() const {
|
||||
return impl_->GetLang2ID();
|
||||
}
|
||||
|
||||
const std::unordered_map<int32_t, std::string>
|
||||
&OfflineWhisperModel::GetID2Lang() const {
|
||||
return impl_->GetID2Lang();
|
||||
}
|
||||
|
||||
int32_t OfflineWhisperModel::NoTimeStampsToken() const {
|
||||
return impl_->NoTimeStampsToken();
|
||||
}
|
||||
|
||||
// Forwards to the implementation object.
int32_t OfflineWhisperModel::EOT() const {
  return impl_->EOT();
}
|
||||
|
||||
// Forwards to the implementation object.
int32_t OfflineWhisperModel::SOT() const {
  return impl_->SOT();
}
|
||||
|
||||
// Forwards to the implementation object.
int32_t OfflineWhisperModel::TextCtx() const {
  return impl_->TextCtx();
}
|
||||
|
||||
// Forwards to the implementation object.
int32_t OfflineWhisperModel::VocabSize() const {
  return impl_->VocabSize();
}
|
||||
|
||||
// Forwards to the implementation object.
int32_t OfflineWhisperModel::Translate() const {
  return impl_->Translate();
}
|
||||
|
||||
bool OfflineWhisperModel::IsMultiLingual() const {
|
||||
return impl_->IsMultiLingual();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -5,7 +5,9 @@
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@@ -30,7 +32,7 @@ class OfflineWhisperModel {
|
||||
* - n_layer_cross_v: A 4-D tensor of shape
|
||||
* (n_text_layer, N, n_audio_ctx, n_text_state)
|
||||
*/
|
||||
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features);
|
||||
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) const;
|
||||
|
||||
/** Run the decoder model.
|
||||
*
|
||||
@@ -58,7 +60,9 @@ class OfflineWhisperModel {
|
||||
Ort::Value>
|
||||
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
|
||||
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
|
||||
Ort::Value n_layer_cross_v, Ort::Value offset);
|
||||
Ort::Value n_layer_cross_v, Ort::Value offset) const;
|
||||
|
||||
int32_t DetectLanguage() const;
|
||||
|
||||
/** Return the initial self kv cache in a pair
|
||||
* - n_layer_self_k_cache A 4-D tensor of shape
|
||||
@@ -66,14 +70,23 @@ class OfflineWhisperModel {
|
||||
* - n_layer_self_v_cache A 4-D tensor of shape
|
||||
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
||||
*/
|
||||
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache();
|
||||
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() const;
|
||||
const std::vector<int64_t> &GetInitialTokens() const;
|
||||
const std::vector<int32_t> &GetAllLanguageIDs() const;
|
||||
const std::unordered_map<std::string, int32_t> &GetLang2ID() const;
|
||||
const std::unordered_map<int32_t, std::string> &GetID2Lang() const;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const;
|
||||
|
||||
int32_t NoTimeStampsToken() const;
|
||||
int32_t EOT() const;
|
||||
int32_t SOT() const;
|
||||
int32_t TextCtx() const;
|
||||
int32_t VocabSize() const;
|
||||
int32_t Translate() const;
|
||||
bool IsMultiLingual() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
|
||||
Reference in New Issue
Block a user