Fix punctuation (#976)
This commit is contained in:
14
.github/workflows/sanitizer.yaml
vendored
14
.github/workflows/sanitizer.yaml
vendored
@@ -76,6 +76,14 @@ jobs:
|
||||
otool -L build/bin/sherpa-onnx
|
||||
otool -l build/bin/sherpa-onnx
|
||||
|
||||
- name: Test offline punctuation
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline-punctuation
|
||||
|
||||
.github/scripts/test-offline-punctuation.sh
|
||||
|
||||
- name: Test offline transducer
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -92,13 +100,7 @@ jobs:
|
||||
|
||||
.github/scripts/test-online-ctc.sh
|
||||
|
||||
- name: Test offline punctuation
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline-punctuation
|
||||
|
||||
.github/scripts/test-offline-punctuation.sh
|
||||
|
||||
- name: Test C API
|
||||
shell: bash
|
||||
|
||||
@@ -69,8 +69,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
|
||||
std::vector<int32_t> punctuations;
|
||||
int32_t last = -1;
|
||||
for (int32_t i = 0; i != num_segments; ++i) {
|
||||
int32_t this_start = i * segment_size; // inclusive
|
||||
int32_t this_end = this_start + segment_size; // exclusive
|
||||
int32_t this_start = i * segment_size; // included
|
||||
int32_t this_end = this_start + segment_size; // not included
|
||||
if (this_end > static_cast<int32_t>(token_ids.size())) {
|
||||
this_end = token_ids.size();
|
||||
}
|
||||
@@ -113,7 +113,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
|
||||
int32_t dot_index = -1;
|
||||
int32_t comma_index = -1;
|
||||
|
||||
for (int32_t m = this_punctuations.size() - 2; m >= 1; --m) {
|
||||
for (int32_t m = static_cast<int32_t>(this_punctuations.size()) - 2;
|
||||
m >= 1; --m) {
|
||||
int32_t punct_id = this_punctuations[m];
|
||||
|
||||
if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
|
||||
@@ -137,13 +138,13 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
|
||||
}
|
||||
|
||||
if (i == num_segments - 1) {
|
||||
dot_index = token_ids.size() - 1;
|
||||
dot_index = static_cast<int32_t>(this_punctuations.size()) - 1;
|
||||
}
|
||||
} else {
|
||||
last = this_start + dot_index + 1;
|
||||
}
|
||||
|
||||
if (dot_index != 1) {
|
||||
if (dot_index != -1) {
|
||||
punctuations.insert(punctuations.end(), this_punctuations.begin(),
|
||||
this_punctuations.begin() + (dot_index + 1));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user