@@ -38,35 +38,6 @@ void SymbolTable::Init(std::istream &is) {
|
||||
std::string sym;
|
||||
int32_t id;
|
||||
while (is >> sym >> id) {
|
||||
if (sym.size() >= 3) {
|
||||
// For BPE-based models, we replace ▁ with a space
|
||||
// Unicode 9601, hex 0x2581, utf8 0xe29681
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
|
||||
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
|
||||
sym = sym.replace(0, 3, " ");
|
||||
}
|
||||
}
|
||||
|
||||
// for byte-level BPE
|
||||
// id 0 is blank, id 1 is sos/eos, id 2 is unk
|
||||
if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
|
||||
sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
|
||||
std::ostringstream os;
|
||||
os << std::hex << std::uppercase << (id - 3);
|
||||
|
||||
if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) {
|
||||
uint8_t i = id - 3;
|
||||
sym = std::string(&i, &i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
assert(!sym.empty());
|
||||
|
||||
// for byte bpe, after replacing ▁ with a space, whose ascii is also 0x20,
|
||||
// there is a conflict between the real byte 0x20 and ▁, so we disable
|
||||
// the following check.
|
||||
//
|
||||
// Note: Only id2sym_ matters as we use it to convert ID to symbols.
|
||||
#if 0
|
||||
// we disable the test here since for some multi-lingual BPE models
|
||||
// from NeMo, the same symbol can appear multiple times with different IDs.
|
||||
@@ -92,8 +63,30 @@ std::string SymbolTable::ToString() const {
|
||||
return os.str();
|
||||
}
|
||||
|
||||
const std::string &SymbolTable::operator[](int32_t id) const {
|
||||
return id2sym_.at(id);
|
||||
// Returns the display form of the symbol stored for `id`.
//
// The stored symbol is fetched with id2sym_.at(id) and post-processed on a
// local copy, so id2sym_ itself is never modified:
//
//  - BPE models: a leading ▁ (U+2581, UTF-8 bytes 0xE2 0x96 0x81) is replaced
//    with a single space.
//  - Byte-level BPE models: for ids in [3, 258], a symbol of the form
//    "<0xHH>" whose two uppercase hex digits HH equal (id - 3) is replaced
//    with that single raw byte.
//
// The result is returned by value because it may differ from the stored
// symbol.
//
// NOTE(review): behavior for an unknown id is whatever id2sym_.at() does;
// for standard containers that is throwing std::out_of_range - confirm
// against the declaration of id2sym_.
const std::string SymbolTable::operator[](int32_t id) const {
  std::string sym = id2sym_.at(id);

  // For BPE-based models, we replace ▁ with a space
  // Unicode 9601, hex 0x2581, utf8 0xe29681
  if (sym.size() >= 3) {
    const auto *p = reinterpret_cast<const uint8_t *>(sym.data());
    if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
      sym.replace(0, 3, " ");
    }
  }

  // for byte-level BPE
  // id 0 is blank, id 1 is sos/eos, id 2 is unk
  const bool looks_like_byte_token = id >= 3 && id <= 258 &&
                                     sym.size() == 6 && sym[0] == '<' &&
                                     sym[1] == '0' && sym[2] == 'x' &&
                                     sym[5] == '>';
  if (looks_like_byte_token) {
    static const char kHexDigits[] = "0123456789ABCDEF";
    const auto byte = static_cast<uint8_t>(id - 3);

    // Compare against a zero-padded, two-digit rendering of the byte value.
    // Formatting (id - 3) with std::hex alone yields a single digit for
    // values below 0x10, which can never match the two digits of a token
    // such as "<0x0A>", so those bytes would be left unconverted.
    if (sym[3] == kHexDigits[byte >> 4] && sym[4] == kHexDigits[byte & 0x0F]) {
      sym.assign(1, static_cast<char>(byte));
    }
  }

  return sym;
}
|
||||
|
||||
int32_t SymbolTable::operator[](const std::string &sym) const {
|
||||
|
||||
Reference in New Issue
Block a user