Fix a punctuation bug (#764)
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
||||||
project(sherpa-onnx)
|
project(sherpa-onnx)
|
||||||
|
|
||||||
set(SHERPA_ONNX_VERSION "1.9.18")
|
set(SHERPA_ONNX_VERSION "1.9.19")
|
||||||
|
|
||||||
# Disable warning about
|
# Disable warning about
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
|
|||||||
int32_t dot_index = -1;
|
int32_t dot_index = -1;
|
||||||
int32_t comma_index = -1;
|
int32_t comma_index = -1;
|
||||||
|
|
||||||
for (int32_t m = this_punctuations.size() - 1; m >= 1; --m) {
|
for (int32_t m = this_punctuations.size() - 2; m >= 1; --m) {
|
||||||
int32_t punct_id = this_punctuations[m];
|
int32_t punct_id = this_punctuations[m];
|
||||||
|
|
||||||
if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
|
if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
|
||||||
@@ -126,27 +126,20 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
last = this_start + dot_index + 1;
|
last = this_start + dot_index + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dot_index != 1) {
|
||||||
punctuations.insert(punctuations.end(), this_punctuations.begin(),
|
punctuations.insert(punctuations.end(), this_punctuations.begin(),
|
||||||
this_punctuations.begin() + (dot_index + 1));
|
this_punctuations.begin() + (dot_index + 1));
|
||||||
}
|
}
|
||||||
} // for (int32_t i = 0; i != num_segments; ++i)
|
} // for (int32_t i = 0; i != num_segments; ++i)
|
||||||
|
|
||||||
if (punctuations.size() != token_ids.size() &&
|
|
||||||
punctuations.size() + 1 == token_ids.size()) {
|
|
||||||
punctuations.push_back(meta_data.dot_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (punctuations.size() != token_ids.size()) {
|
|
||||||
SHERPA_ONNX_LOGE("%s, %d, %d. Some unexpected things happened",
|
|
||||||
text.c_str(), static_cast<int32_t>(punctuations.size()),
|
|
||||||
static_cast<int32_t>(token_ids.size()));
|
|
||||||
return text;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string ans;
|
std::string ans;
|
||||||
|
|
||||||
for (int32_t i = 0; i != static_cast<int32_t>(punctuations.size()); ++i) {
|
for (int32_t i = 0; i != static_cast<int32_t>(punctuations.size()); ++i) {
|
||||||
|
if (i > tokens.size()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
const std::string &w = tokens[i];
|
const std::string &w = tokens[i];
|
||||||
if (i > 0 && !(ans.back() & 0x80) && !(w[0] & 0x80)) {
|
if (i > 0 && !(ans.back() & 0x80) && !(w[0] & 0x80)) {
|
||||||
ans.push_back(' ');
|
ans.push_back(' ');
|
||||||
@@ -156,6 +149,9 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
|
|||||||
ans.append(meta_data.id2punct[punctuations[i]]);
|
ans.append(meta_data.id2punct[punctuations[i]]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (ans.back() != meta_data.dot_id && ans.back() != meta_data.quest_id) {
|
||||||
|
ans.push_back(meta_data.dot_id);
|
||||||
|
}
|
||||||
|
|
||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user