Support decoding with byte-level BPE (bbpe) models. (#1633)

This commit is contained in:
Fangjun Kuang
2024-12-20 19:21:32 +08:00
committed by GitHub
parent 7192e576a9
commit b76cd9033a
11 changed files with 270 additions and 10 deletions

View File

@@ -5,6 +5,7 @@
#include "sherpa-onnx/csrc/symbol-table.h"
#include <cassert>
#include <cctype>
#include <fstream>
#include <sstream>
#include <string>
@@ -22,8 +23,10 @@
#endif
#include "sherpa-onnx/csrc/base64-decode.h"
#include "sherpa-onnx/csrc/bbpe.h"
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
@@ -47,6 +50,59 @@ inline void Trim(std::string *s, const char *t = ws) {
TrimRight(s, t);
TrimLeft(s, t);
}
bool IsByteBPE(const char *s, int32_t n) {
const uint8_t *p = reinterpret_cast<const uint8_t *>(s);
if (n >= 3 && p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
return IsByteBPE(s + 3, n - 3);
}
for (int32_t i = 0; i != n; ++i) {
if (p[i] > 0xc6) {
return false;
}
}
return true;
}
bool IsByteBPE(const std::unordered_map<std::string, int32_t> &sym2id) {
uint8_t max_v = 0;
for (const auto &p : sym2id) {
const auto &s = p.first;
if (!IsByteBPE(s.c_str(), s.size())) {
return false;
}
uint8_t m = 0;
if (s.size() >= 3) {
const uint8_t *p = reinterpret_cast<const uint8_t *>(s.c_str());
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
if (s.size() > 3) {
m = *std::max_element(
reinterpret_cast<const uint8_t *>(s.data()) + 3,
reinterpret_cast<const uint8_t *>(s.data()) + s.size());
} else {
m = 0;
}
} else {
m = *std::max_element(
reinterpret_cast<const uint8_t *>(s.data()),
reinterpret_cast<const uint8_t *>(s.data()) + s.size());
}
} else {
m = *std::max_element(
reinterpret_cast<const uint8_t *>(s.data()),
reinterpret_cast<const uint8_t *>(s.data()) + s.size());
}
max_v = (m > max_v) ? m : max_v;
}
return static_cast<uint8_t>(max_v) == 0xc6;
}
} // namespace
std::unordered_map<std::string, int32_t> ReadTokens(
@@ -111,7 +167,10 @@ SymbolTable::SymbolTable(Manager *mgr, const std::string &filename) {
Init(is);
}
void SymbolTable::Init(std::istream &is) { sym2id_ = ReadTokens(is, &id2sym_); }
void SymbolTable::Init(std::istream &is) {
sym2id_ = ReadTokens(is, &id2sym_);
is_bbpe_ = IsByteBPE(sym2id_);
}
std::string SymbolTable::ToString() const {
std::ostringstream os;
@@ -124,7 +183,7 @@ std::string SymbolTable::ToString() const {
const std::string SymbolTable::operator[](int32_t id) const {
std::string sym = id2sym_.at(id);
if (sym.size() >= 3) {
if (sym.size() >= 3 && !is_bbpe_) {
// For BPE-based models, we replace ▁ with a space
// Unicode 9601, hex 0x2581, utf8 0xe29681
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
@@ -133,7 +192,7 @@ const std::string SymbolTable::operator[](int32_t id) const {
}
}
// for byte-level BPE
// for BPE with byte_fallback
// id 0 is blank, id 1 is sos/eos, id 2 is unk
//
// Note: For moonshine models, 0 is <unk>, 1, is <s>, 2 is</s>
@@ -172,6 +231,33 @@ void SymbolTable::ApplyBase64Decode() {
}
}
std::string SymbolTable::DecodeByteBpe(const std::string &text) const {
if (!is_bbpe_) {
return text;
}
auto v = SplitUtf8(text);
const auto &bbpe_table = GetByteBpeTable();
std::string ans;
for (const auto &s : v) {
if (s == "") {
if (!ans.empty() && ans.back() != ' ' && std::isprint(ans.back())) {
ans.push_back(' ');
}
} else if (bbpe_table.count(s)) {
ans.push_back(bbpe_table.at(s));
} else if (std::isprint(s[0])) {
ans.append(s);
} else {
// Should not happen
SHERPA_ONNX_LOGE("Skip OOV: %s from %s", s.c_str(), text.c_str());
}
}
// TODO(fangjun): Filter invalid utf-8 sequences
return ans;
}
#if __ANDROID_API__ >= 9
template SymbolTable::SymbolTable(AAssetManager *mgr,
const std::string &filename);