Support multilingual whisper models (#274)
This commit is contained in:
@@ -23,21 +23,227 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static std::string FixInvalidUtf8(const std::string &s) {
|
||||
int32_t s_size = s.size();
|
||||
|
||||
std::string ans;
|
||||
ans.reserve(s_size);
|
||||
|
||||
for (int32_t i = 0; i < s_size;) {
|
||||
uint8_t c = s[i];
|
||||
if (c < 0x80) {
|
||||
// valid
|
||||
ans.append(1, c);
|
||||
++i;
|
||||
continue;
|
||||
} else if ((c >= 0xc0) && (c < 0xe0)) {
|
||||
// beginning of two bytes
|
||||
if ((i + 1) > (s_size - 1)) {
|
||||
// no subsequent byte. invalid!
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
uint8_t next = s[i + 1];
|
||||
if (!(next >= 0x80 && next < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
// valid 2-byte utf-8
|
||||
ans.append(1, c);
|
||||
ans.append(1, next);
|
||||
i += 2;
|
||||
continue;
|
||||
} else if ((c >= 0xe0) && (c < 0xf0)) {
|
||||
// beginning of 3 bytes
|
||||
if ((i + 2) > (s_size - 1)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next = s[i + 1];
|
||||
if (!(next >= 0x80 && next < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next2 = s[i + 2];
|
||||
if (!(next2 >= 0x80 && next2 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
ans.append(1, c);
|
||||
ans.append(1, next);
|
||||
ans.append(1, next2);
|
||||
i += 3;
|
||||
continue;
|
||||
} else if ((c >= 0xf0) && (c < 0xf8)) {
|
||||
// 4 bytes
|
||||
if ((i + 3) > (s_size - 1)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next = s[i + 1];
|
||||
if (!(next >= 0x80 && next < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next2 = s[i + 2];
|
||||
if (!(next2 >= 0x80 && next2 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next3 = s[i + 3];
|
||||
if (!(next3 >= 0x80 && next3 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
ans.append(1, c);
|
||||
ans.append(1, next);
|
||||
ans.append(1, next2);
|
||||
ans.append(1, next3);
|
||||
i += 4;
|
||||
continue;
|
||||
} else if ((c >= 0xf8) && (c < 0xfc)) {
|
||||
// 5 bytes
|
||||
if ((i + 4) > (s_size - 1)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next = s[i + 1];
|
||||
if (!(next >= 0x80 && next < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next2 = s[i + 2];
|
||||
if (!(next2 >= 0x80 && next2 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next3 = s[i + 3];
|
||||
if (!(next3 >= 0x80 && next3 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next4 = s[i + 4];
|
||||
if (!(next4 >= 0x80 && next4 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
ans.append(1, c);
|
||||
ans.append(1, next);
|
||||
ans.append(1, next2);
|
||||
ans.append(1, next3);
|
||||
ans.append(1, next4);
|
||||
i += 5;
|
||||
continue;
|
||||
} else if ((c >= 0xfc) && (c < 0xfe)) {
|
||||
// 6 bytes
|
||||
if ((i + 5) > (s_size - 1)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next = s[i + 1];
|
||||
if (!(next >= 0x80 && next < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next2 = s[i + 2];
|
||||
if (!(next2 >= 0x80 && next2 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next3 = s[i + 3];
|
||||
if (!(next3 >= 0x80 && next3 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next4 = s[i + 4];
|
||||
if (!(next4 >= 0x80 && next4 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next5 = s[i + 5];
|
||||
if (!(next5 >= 0x80 && next5 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
ans.append(1, c);
|
||||
ans.append(1, next);
|
||||
ans.append(1, next2);
|
||||
ans.append(1, next3);
|
||||
ans.append(1, next4);
|
||||
ans.append(1, next5);
|
||||
i += 6;
|
||||
continue;
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
|
||||
const SymbolTable &sym_table) {
|
||||
OfflineRecognitionResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
|
||||
std::string text;
|
||||
for (auto i : src.tokens) {
|
||||
if (!sym_table.contains(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &s = sym_table[i];
|
||||
r.text += s;
|
||||
text += s;
|
||||
r.tokens.push_back(s);
|
||||
}
|
||||
|
||||
// TODO(fangjun): Fix the following error in offline-stream.cc
|
||||
//
|
||||
// j["text"] = text;
|
||||
|
||||
// libc++abi: terminating with uncaught exception of type
|
||||
// nlohmann::json_abi_v3_11_2::detail::type_error:
|
||||
// [json.exception.type_error.316] incomplete UTF-8 string; last byte: 0x86
|
||||
|
||||
#if 0
|
||||
r.text = FixInvalidUtf8(text);
|
||||
#else
|
||||
r.text = text;
|
||||
#endif
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -51,8 +257,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
symbol_table_.ApplyBase64Decode();
|
||||
|
||||
if (config.decoding_method == "greedy_search") {
|
||||
decoder_ =
|
||||
std::make_unique<OfflineWhisperGreedySearchDecoder>(model_.get());
|
||||
decoder_ = std::make_unique<OfflineWhisperGreedySearchDecoder>(
|
||||
config_.model_config.whisper, model_.get());
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Only greedy_search is supported at present for whisper. Given %s",
|
||||
@@ -101,6 +307,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
mel = Transpose12(model_->Allocator(), &mel);
|
||||
|
||||
auto cross_kv = model_->ForwardEncoder(std::move(mel));
|
||||
|
||||
auto results =
|
||||
decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user