Fixing Whisper Model Token Normalization (#1904)
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -134,3 +134,5 @@ us_gold.json
|
||||
us_silver.json
|
||||
kokoro-multi-lang-v1_0
|
||||
sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16
|
||||
cmake-build-debug
|
||||
README-DEV.txt
|
||||
|
||||
@@ -23,28 +23,6 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
|
||||
const SymbolTable &sym_table) {
|
||||
OfflineRecognitionResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
|
||||
std::string text;
|
||||
for (auto i : src.tokens) {
|
||||
if (!sym_table.Contains(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &s = sym_table[i];
|
||||
text += s;
|
||||
r.tokens.push_back(s);
|
||||
}
|
||||
|
||||
r.text = text;
|
||||
r.lang = src.lang;
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
public:
|
||||
explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config)
|
||||
@@ -156,7 +134,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
std::move(cross_kv.second));
|
||||
|
||||
auto r = Convert(results[0], symbol_table_);
|
||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
||||
s->SetResult(r);
|
||||
} catch (const Ort::Exception &ex) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
@@ -169,6 +146,31 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
|
||||
const SymbolTable &sym_table) const {
|
||||
OfflineRecognitionResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
|
||||
std::string text;
|
||||
for (auto i : src.tokens) {
|
||||
if (!sym_table.Contains(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string s = sym_table[i];
|
||||
s = ApplyInverseTextNormalization(s);
|
||||
|
||||
text += s;
|
||||
r.tokens.push_back(s);
|
||||
}
|
||||
|
||||
r.text = text;
|
||||
r.lang = src.lang;
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineRecognizerConfig config_;
|
||||
SymbolTable symbol_table_;
|
||||
|
||||
@@ -55,4 +55,77 @@ TEST(RemoveInvalidUtf8Sequences, Case1) {
|
||||
EXPECT_EQ(s.size() + 4, v.size());
|
||||
}
|
||||
|
||||
|
||||
// Tests for sanitizeUtf8
|
||||
TEST(RemoveInvalidUtf8Sequences, ValidUtf8StringPassesUnchanged) {
|
||||
std::string input = "Valid UTF-8 🌍";
|
||||
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), input);
|
||||
}
|
||||
|
||||
TEST(RemoveInvalidUtf8Sequences, SingleInvalidByteReplaced) {
|
||||
std::string input = "Invalid \xFF UTF-8";
|
||||
std::string expected = "Invalid UTF-8";
|
||||
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
|
||||
}
|
||||
|
||||
TEST(RemoveInvalidUtf8Sequences, TruncatedUtf8SequenceReplaced) {
|
||||
std::string input = "Broken \xE2\x82"; // Incomplete UTF-8 sequence
|
||||
std::string expected = "Broken ";
|
||||
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
|
||||
}
|
||||
|
||||
TEST(RemoveInvalidUtf8Sequences, MultipleInvalidBytes) {
|
||||
std::string input = "Test \xC0\xC0\xF8\xA0"; // Multiple invalid sequences
|
||||
std::string expected = "Test ";
|
||||
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
|
||||
}
|
||||
|
||||
TEST(RemoveInvalidUtf8Sequences, BreakingCase_SpaceFollowedByInvalidByte) {
|
||||
std::string input = "\x20\xC4"; // Space followed by an invalid byte
|
||||
std::string expected = " "; // 0xC4 removed
|
||||
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
|
||||
}
|
||||
|
||||
TEST(RemoveInvalidUtf8Sequences, ValidUtf8WithEdgeCaseCharacters) {
|
||||
std::string input = "Edge 🏆💯";
|
||||
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), input);
|
||||
}
|
||||
|
||||
TEST(RemoveInvalidUtf8Sequences, MixedValidAndInvalidBytes) {
|
||||
std::string input = "Mix \xE2\x82\xAC \xF0\x9F\x98\x81 \xFF";
|
||||
std::string expected = "Mix € 😁 "; // Invalid bytes removed
|
||||
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
|
||||
}
|
||||
|
||||
TEST(RemoveInvalidUtf8Sequences, SpaceFollowedByInvalidByte) {
|
||||
std::string input = "\x20\xC4"; // Space (0x20) followed by invalid (0xC4)
|
||||
std::string expected = " "; // Space remains, 0xC4 is removed
|
||||
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
|
||||
}
|
||||
|
||||
TEST(RemoveInvalidUtf8Sequences, RemoveTruncatedC4) {
|
||||
std::string input = "Hello \xc4 world"; // Invalid `0xC4`
|
||||
std::string expected = "Hello world"; // `0xC4` should be removed
|
||||
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
|
||||
}
|
||||
|
||||
TEST(RemoveInvalidUtf8Sequences, SpaceFollowedByInvalidByte_Breaking) {
|
||||
std::string input = "\x20\xc4"; // Space followed by invalid `0xc4`
|
||||
std::string expected = " "; // `0xc4` should be removed, space remains
|
||||
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
|
||||
}
|
||||
|
||||
TEST(RemoveInvalidUtf8Sequences, DebugSpaceFollowedByInvalidByte) {
|
||||
std::string input = "\x20\xc4"; // Space followed by invalid `0xc4`
|
||||
std::string output = RemoveInvalidUtf8Sequences(input);
|
||||
|
||||
std::cout << "Processed string: ";
|
||||
for (unsigned char c : output) {
|
||||
printf("\\x%02x ", c);
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
EXPECT_EQ(output, " "); // Expect `0xc4` to be removed, leaving only space
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user