Fix punctuation (#976)
This commit is contained in:
14
.github/workflows/sanitizer.yaml
vendored
14
.github/workflows/sanitizer.yaml
vendored
@@ -76,6 +76,14 @@ jobs:
|
|||||||
otool -L build/bin/sherpa-onnx
|
otool -L build/bin/sherpa-onnx
|
||||||
otool -l build/bin/sherpa-onnx
|
otool -l build/bin/sherpa-onnx
|
||||||
|
|
||||||
|
- name: Test offline punctuation
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx-offline-punctuation
|
||||||
|
|
||||||
|
.github/scripts/test-offline-punctuation.sh
|
||||||
|
|
||||||
- name: Test offline transducer
|
- name: Test offline transducer
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
@@ -92,13 +100,7 @@ jobs:
|
|||||||
|
|
||||||
.github/scripts/test-online-ctc.sh
|
.github/scripts/test-online-ctc.sh
|
||||||
|
|
||||||
- name: Test offline punctuation
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
export PATH=$PWD/build/bin:$PATH
|
|
||||||
export EXE=sherpa-onnx-offline-punctuation
|
|
||||||
|
|
||||||
.github/scripts/test-offline-punctuation.sh
|
|
||||||
|
|
||||||
- name: Test C API
|
- name: Test C API
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|||||||
@@ -69,8 +69,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
|
|||||||
std::vector<int32_t> punctuations;
|
std::vector<int32_t> punctuations;
|
||||||
int32_t last = -1;
|
int32_t last = -1;
|
||||||
for (int32_t i = 0; i != num_segments; ++i) {
|
for (int32_t i = 0; i != num_segments; ++i) {
|
||||||
int32_t this_start = i * segment_size; // inclusive
|
int32_t this_start = i * segment_size; // included
|
||||||
int32_t this_end = this_start + segment_size; // exclusive
|
int32_t this_end = this_start + segment_size; // not included
|
||||||
if (this_end > static_cast<int32_t>(token_ids.size())) {
|
if (this_end > static_cast<int32_t>(token_ids.size())) {
|
||||||
this_end = token_ids.size();
|
this_end = token_ids.size();
|
||||||
}
|
}
|
||||||
@@ -113,7 +113,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
|
|||||||
int32_t dot_index = -1;
|
int32_t dot_index = -1;
|
||||||
int32_t comma_index = -1;
|
int32_t comma_index = -1;
|
||||||
|
|
||||||
for (int32_t m = this_punctuations.size() - 2; m >= 1; --m) {
|
for (int32_t m = static_cast<int32_t>(this_punctuations.size()) - 2;
|
||||||
|
m >= 1; --m) {
|
||||||
int32_t punct_id = this_punctuations[m];
|
int32_t punct_id = this_punctuations[m];
|
||||||
|
|
||||||
if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
|
if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
|
||||||
@@ -137,13 +138,13 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (i == num_segments - 1) {
|
if (i == num_segments - 1) {
|
||||||
dot_index = token_ids.size() - 1;
|
dot_index = static_cast<int32_t>(this_punctuations.size()) - 1;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
last = this_start + dot_index + 1;
|
last = this_start + dot_index + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (dot_index != 1) {
|
if (dot_index != -1) {
|
||||||
punctuations.insert(punctuations.end(), this_punctuations.begin(),
|
punctuations.insert(punctuations.end(), this_punctuations.begin(),
|
||||||
this_punctuations.begin() + (dot_index + 1));
|
this_punctuations.begin() + (dot_index + 1));
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user