Support decoding with byte-level BPE (bbpe) models. (#1633)
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cctype>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@@ -22,8 +23,10 @@
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/base64-decode.h"
|
||||
#include "sherpa-onnx/csrc/bbpe.h"
|
||||
#include "sherpa-onnx/csrc/lexicon.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -47,6 +50,59 @@ inline void Trim(std::string *s, const char *t = ws) {
|
||||
TrimRight(s, t);
|
||||
TrimLeft(s, t);
|
||||
}
|
||||
|
||||
// Checks whether a single token could come from a byte-level BPE vocabulary.
//
// Any number of leading "▁" markers (U+2581, UTF-8 bytes e2 96 81) are
// ignored. After that, every remaining byte must be <= 0xc6.
//
// @param s Pointer to the token bytes (need not be NUL-terminated within n).
// @param n Number of bytes to inspect.
// @return true if no byte after the leading markers exceeds 0xc6.
bool IsByteBPE(const char *s, int32_t n) {
  const auto *bytes = reinterpret_cast<const uint8_t *>(s);

  // Peel off the word-boundary markers one at a time.
  while (n >= 3 && bytes[0] == 0xe2 && bytes[1] == 0x96 && bytes[2] == 0x81) {
    bytes += 3;
    n -= 3;
  }

  for (int32_t k = 0; k != n; ++k) {
    if (bytes[k] > 0xc6) {
      return false;
    }
  }

  return true;
}
|
||||
|
||||
bool IsByteBPE(const std::unordered_map<std::string, int32_t> &sym2id) {
|
||||
uint8_t max_v = 0;
|
||||
for (const auto &p : sym2id) {
|
||||
const auto &s = p.first;
|
||||
if (!IsByteBPE(s.c_str(), s.size())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint8_t m = 0;
|
||||
if (s.size() >= 3) {
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(s.c_str());
|
||||
|
||||
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
|
||||
if (s.size() > 3) {
|
||||
m = *std::max_element(
|
||||
reinterpret_cast<const uint8_t *>(s.data()) + 3,
|
||||
reinterpret_cast<const uint8_t *>(s.data()) + s.size());
|
||||
} else {
|
||||
m = 0;
|
||||
}
|
||||
} else {
|
||||
m = *std::max_element(
|
||||
reinterpret_cast<const uint8_t *>(s.data()),
|
||||
reinterpret_cast<const uint8_t *>(s.data()) + s.size());
|
||||
}
|
||||
} else {
|
||||
m = *std::max_element(
|
||||
reinterpret_cast<const uint8_t *>(s.data()),
|
||||
reinterpret_cast<const uint8_t *>(s.data()) + s.size());
|
||||
}
|
||||
|
||||
max_v = (m > max_v) ? m : max_v;
|
||||
}
|
||||
|
||||
return static_cast<uint8_t>(max_v) == 0xc6;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unordered_map<std::string, int32_t> ReadTokens(
|
||||
@@ -111,7 +167,10 @@ SymbolTable::SymbolTable(Manager *mgr, const std::string &filename) {
|
||||
Init(is);
|
||||
}
|
||||
|
||||
void SymbolTable::Init(std::istream &is) { sym2id_ = ReadTokens(is, &id2sym_); }
|
||||
// Loads the token table from `is` and records whether it looks like a
// byte-level BPE vocabulary.
//
// @param is Input stream handed to ReadTokens().
void SymbolTable::Init(std::istream &is) {
  // ReadTokens() returns the symbol -> id map; it is also given &id2sym_,
  // presumably to fill the reverse mapping in the same pass.
  sym2id_ = ReadTokens(is, &id2sym_);

  // The detection result is cached so later lookups/decoding can branch on it.
  const bool looks_like_bbpe = IsByteBPE(sym2id_);
  is_bbpe_ = looks_like_bbpe;
}
|
||||
|
||||
std::string SymbolTable::ToString() const {
|
||||
std::ostringstream os;
|
||||
@@ -124,7 +183,7 @@ std::string SymbolTable::ToString() const {
|
||||
|
||||
const std::string SymbolTable::operator[](int32_t id) const {
|
||||
std::string sym = id2sym_.at(id);
|
||||
if (sym.size() >= 3) {
|
||||
if (sym.size() >= 3 && !is_bbpe_) {
|
||||
// For BPE-based models, we replace ▁ with a space
|
||||
// Unicode 9601, hex 0x2581, utf8 0xe29681
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
|
||||
@@ -133,7 +192,7 @@ const std::string SymbolTable::operator[](int32_t id) const {
|
||||
}
|
||||
}
|
||||
|
||||
// for byte-level BPE
|
||||
// for BPE with byte_fallback
|
||||
// id 0 is blank, id 1 is sos/eos, id 2 is unk
|
||||
//
|
||||
// Note: For moonshine models, 0 is <unk>, 1 is <s>, 2 is </s>
|
||||
@@ -172,6 +231,33 @@ void SymbolTable::ApplyBase64Decode() {
|
||||
}
|
||||
}
|
||||
|
||||
std::string SymbolTable::DecodeByteBpe(const std::string &text) const {
|
||||
if (!is_bbpe_) {
|
||||
return text;
|
||||
}
|
||||
auto v = SplitUtf8(text);
|
||||
|
||||
const auto &bbpe_table = GetByteBpeTable();
|
||||
std::string ans;
|
||||
for (const auto &s : v) {
|
||||
if (s == "▁") {
|
||||
if (!ans.empty() && ans.back() != ' ' && std::isprint(ans.back())) {
|
||||
ans.push_back(' ');
|
||||
}
|
||||
} else if (bbpe_table.count(s)) {
|
||||
ans.push_back(bbpe_table.at(s));
|
||||
} else if (std::isprint(s[0])) {
|
||||
ans.append(s);
|
||||
} else {
|
||||
// Should not happen
|
||||
SHERPA_ONNX_LOGE("Skip OOV: %s from %s", s.c_str(), text.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(fangjun): Filter invalid utf-8 sequences
|
||||
return ans;
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
template SymbolTable::SymbolTable(AAssetManager *mgr,
|
||||
const std::string &filename);
|
||||
|
||||
Reference in New Issue
Block a user