Fix a punctuation bug (#764)

This commit is contained in:
Fangjun Kuang
2024-04-13 19:08:46 +08:00
committed by GitHub
parent b6ad0436fa
commit 983df28a83
2 changed files with 10 additions and 14 deletions

View File

@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR) cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx) project(sherpa-onnx)
set(SHERPA_ONNX_VERSION "1.9.18") set(SHERPA_ONNX_VERSION "1.9.19")
# Disable warning about # Disable warning about
# #

View File

@@ -98,7 +98,7 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
int32_t dot_index = -1; int32_t dot_index = -1;
int32_t comma_index = -1; int32_t comma_index = -1;
for (int32_t m = this_punctuations.size() - 1; m >= 1; --m) { for (int32_t m = this_punctuations.size() - 2; m >= 1; --m) {
int32_t punct_id = this_punctuations[m]; int32_t punct_id = this_punctuations[m];
if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) { if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
@@ -126,27 +126,20 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
} }
} else { } else {
last = this_start + dot_index + 1; last = this_start + dot_index + 1;
}
if (dot_index != 1) {
punctuations.insert(punctuations.end(), this_punctuations.begin(), punctuations.insert(punctuations.end(), this_punctuations.begin(),
this_punctuations.begin() + (dot_index + 1)); this_punctuations.begin() + (dot_index + 1));
} }
} // for (int32_t i = 0; i != num_segments; ++i) } // for (int32_t i = 0; i != num_segments; ++i)
if (punctuations.size() != token_ids.size() &&
punctuations.size() + 1 == token_ids.size()) {
punctuations.push_back(meta_data.dot_id);
}
if (punctuations.size() != token_ids.size()) {
SHERPA_ONNX_LOGE("%s, %d, %d. Some unexpected things happened",
text.c_str(), static_cast<int32_t>(punctuations.size()),
static_cast<int32_t>(token_ids.size()));
return text;
}
std::string ans; std::string ans;
for (int32_t i = 0; i != static_cast<int32_t>(punctuations.size()); ++i) { for (int32_t i = 0; i != static_cast<int32_t>(punctuations.size()); ++i) {
if (i > tokens.size()) {
break;
}
const std::string &w = tokens[i]; const std::string &w = tokens[i];
if (i > 0 && !(ans.back() & 0x80) && !(w[0] & 0x80)) { if (i > 0 && !(ans.back() & 0x80) && !(w[0] & 0x80)) {
ans.push_back(' '); ans.push_back(' ');
@@ -156,6 +149,9 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
ans.append(meta_data.id2punct[punctuations[i]]); ans.append(meta_data.id2punct[punctuations[i]]);
} }
} }
if (ans.back() != meta_data.dot_id && ans.back() != meta_data.quest_id) {
ans.push_back(meta_data.dot_id);
}
return ans; return ans;
} }