Fixing Whisper Model Token Normalization (#1904)

This commit is contained in:
ivan provalov
2025-02-20 20:58:01 -08:00
committed by GitHub
parent ed922e61b5
commit 94728bfbee
3 changed files with 100 additions and 23 deletions

View File

@@ -23,28 +23,6 @@
namespace sherpa_onnx {
static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
const SymbolTable &sym_table) {
OfflineRecognitionResult r;
r.tokens.reserve(src.tokens.size());
std::string text;
for (auto i : src.tokens) {
if (!sym_table.Contains(i)) {
continue;
}
const auto &s = sym_table[i];
text += s;
r.tokens.push_back(s);
}
r.text = text;
r.lang = src.lang;
return r;
}
class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
public:
explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config)
@@ -156,7 +134,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
std::move(cross_kv.second));
auto r = Convert(results[0], symbol_table_);
r.text = ApplyInverseTextNormalization(std::move(r.text));
s->SetResult(r);
} catch (const Ort::Exception &ex) {
SHERPA_ONNX_LOGE(
@@ -169,6 +146,31 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
}
}
private:
OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
const SymbolTable &sym_table) const {
OfflineRecognitionResult r;
r.tokens.reserve(src.tokens.size());
std::string text;
for (auto i : src.tokens) {
if (!sym_table.Contains(i)) {
continue;
}
std::string s = sym_table[i];
s = ApplyInverseTextNormalization(s);
text += s;
r.tokens.push_back(s);
}
r.text = text;
r.lang = src.lang;
return r;
}
private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;