Fixing Whisper Model Token Normalization (#1904)
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -134,3 +134,5 @@ us_gold.json
|
|||||||
us_silver.json
|
us_silver.json
|
||||||
kokoro-multi-lang-v1_0
|
kokoro-multi-lang-v1_0
|
||||||
sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16
|
sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16
|
||||||
|
cmake-build-debug
|
||||||
|
README-DEV.txt
|
||||||
|
|||||||
@@ -23,28 +23,6 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
|
|
||||||
const SymbolTable &sym_table) {
|
|
||||||
OfflineRecognitionResult r;
|
|
||||||
r.tokens.reserve(src.tokens.size());
|
|
||||||
|
|
||||||
std::string text;
|
|
||||||
for (auto i : src.tokens) {
|
|
||||||
if (!sym_table.Contains(i)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto &s = sym_table[i];
|
|
||||||
text += s;
|
|
||||||
r.tokens.push_back(s);
|
|
||||||
}
|
|
||||||
|
|
||||||
r.text = text;
|
|
||||||
r.lang = src.lang;
|
|
||||||
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
|
|
||||||
class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||||
public:
|
public:
|
||||||
explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config)
|
explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config)
|
||||||
@@ -156,7 +134,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
|||||||
std::move(cross_kv.second));
|
std::move(cross_kv.second));
|
||||||
|
|
||||||
auto r = Convert(results[0], symbol_table_);
|
auto r = Convert(results[0], symbol_table_);
|
||||||
r.text = ApplyInverseTextNormalization(std::move(r.text));
|
|
||||||
s->SetResult(r);
|
s->SetResult(r);
|
||||||
} catch (const Ort::Exception &ex) {
|
} catch (const Ort::Exception &ex) {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
@@ -169,6 +146,31 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
|
||||||
|
const SymbolTable &sym_table) const {
|
||||||
|
OfflineRecognitionResult r;
|
||||||
|
r.tokens.reserve(src.tokens.size());
|
||||||
|
|
||||||
|
std::string text;
|
||||||
|
for (auto i : src.tokens) {
|
||||||
|
if (!sym_table.Contains(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string s = sym_table[i];
|
||||||
|
s = ApplyInverseTextNormalization(s);
|
||||||
|
|
||||||
|
text += s;
|
||||||
|
r.tokens.push_back(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
r.text = text;
|
||||||
|
r.lang = src.lang;
|
||||||
|
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OfflineRecognizerConfig config_;
|
OfflineRecognizerConfig config_;
|
||||||
SymbolTable symbol_table_;
|
SymbolTable symbol_table_;
|
||||||
|
|||||||
@@ -55,4 +55,77 @@ TEST(RemoveInvalidUtf8Sequences, Case1) {
|
|||||||
EXPECT_EQ(s.size() + 4, v.size());
|
EXPECT_EQ(s.size() + 4, v.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Tests for sanitizeUtf8
|
||||||
|
TEST(RemoveInvalidUtf8Sequences, ValidUtf8StringPassesUnchanged) {
|
||||||
|
std::string input = "Valid UTF-8 🌍";
|
||||||
|
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), input);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RemoveInvalidUtf8Sequences, SingleInvalidByteReplaced) {
|
||||||
|
std::string input = "Invalid \xFF UTF-8";
|
||||||
|
std::string expected = "Invalid UTF-8";
|
||||||
|
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RemoveInvalidUtf8Sequences, TruncatedUtf8SequenceReplaced) {
|
||||||
|
std::string input = "Broken \xE2\x82"; // Incomplete UTF-8 sequence
|
||||||
|
std::string expected = "Broken ";
|
||||||
|
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RemoveInvalidUtf8Sequences, MultipleInvalidBytes) {
|
||||||
|
std::string input = "Test \xC0\xC0\xF8\xA0"; // Multiple invalid sequences
|
||||||
|
std::string expected = "Test ";
|
||||||
|
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RemoveInvalidUtf8Sequences, BreakingCase_SpaceFollowedByInvalidByte) {
|
||||||
|
std::string input = "\x20\xC4"; // Space followed by an invalid byte
|
||||||
|
std::string expected = " "; // 0xC4 removed
|
||||||
|
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RemoveInvalidUtf8Sequences, ValidUtf8WithEdgeCaseCharacters) {
|
||||||
|
std::string input = "Edge 🏆💯";
|
||||||
|
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), input);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RemoveInvalidUtf8Sequences, MixedValidAndInvalidBytes) {
|
||||||
|
std::string input = "Mix \xE2\x82\xAC \xF0\x9F\x98\x81 \xFF";
|
||||||
|
std::string expected = "Mix € 😁 "; // Invalid bytes removed
|
||||||
|
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RemoveInvalidUtf8Sequences, SpaceFollowedByInvalidByte) {
|
||||||
|
std::string input = "\x20\xC4"; // Space (0x20) followed by invalid (0xC4)
|
||||||
|
std::string expected = " "; // Space remains, 0xC4 is removed
|
||||||
|
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RemoveInvalidUtf8Sequences, RemoveTruncatedC4) {
|
||||||
|
std::string input = "Hello \xc4 world"; // Invalid `0xC4`
|
||||||
|
std::string expected = "Hello world"; // `0xC4` should be removed
|
||||||
|
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RemoveInvalidUtf8Sequences, SpaceFollowedByInvalidByte_Breaking) {
|
||||||
|
std::string input = "\x20\xc4"; // Space followed by invalid `0xc4`
|
||||||
|
std::string expected = " "; // `0xc4` should be removed, space remains
|
||||||
|
EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RemoveInvalidUtf8Sequences, DebugSpaceFollowedByInvalidByte) {
|
||||||
|
std::string input = "\x20\xc4"; // Space followed by invalid `0xc4`
|
||||||
|
std::string output = RemoveInvalidUtf8Sequences(input);
|
||||||
|
|
||||||
|
std::cout << "Processed string: ";
|
||||||
|
for (unsigned char c : output) {
|
||||||
|
printf("\\x%02x ", c);
|
||||||
|
}
|
||||||
|
std::cout << std::endl;
|
||||||
|
|
||||||
|
EXPECT_EQ(output, " "); // Expect `0xc4` to be removed, leaving only space
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
Reference in New Issue
Block a user