Fix punctuation (#976)
This commit is contained in:
@@ -69,8 +69,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
|
||||
std::vector<int32_t> punctuations;
|
||||
int32_t last = -1;
|
||||
for (int32_t i = 0; i != num_segments; ++i) {
|
||||
int32_t this_start = i * segment_size; // inclusive
|
||||
int32_t this_end = this_start + segment_size; // exclusive
|
||||
int32_t this_start = i * segment_size; // included
|
||||
int32_t this_end = this_start + segment_size; // not included
|
||||
if (this_end > static_cast<int32_t>(token_ids.size())) {
|
||||
this_end = token_ids.size();
|
||||
}
|
||||
@@ -113,7 +113,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
|
||||
int32_t dot_index = -1;
|
||||
int32_t comma_index = -1;
|
||||
|
||||
for (int32_t m = this_punctuations.size() - 2; m >= 1; --m) {
|
||||
for (int32_t m = static_cast<int32_t>(this_punctuations.size()) - 2;
|
||||
m >= 1; --m) {
|
||||
int32_t punct_id = this_punctuations[m];
|
||||
|
||||
if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
|
||||
@@ -137,13 +138,13 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
|
||||
}
|
||||
|
||||
if (i == num_segments - 1) {
|
||||
dot_index = token_ids.size() - 1;
|
||||
dot_index = static_cast<int32_t>(this_punctuations.size()) - 1;
|
||||
}
|
||||
} else {
|
||||
last = this_start + dot_index + 1;
|
||||
}
|
||||
|
||||
if (dot_index != 1) {
|
||||
if (dot_index != -1) {
|
||||
punctuations.insert(punctuations.end(), this_punctuations.begin(),
|
||||
this_punctuations.begin() + (dot_index + 1));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user