Support adding punctuations to the speech recogntion result (#761)
This commit is contained in:
41
.github/scripts/test-offline-punctuation.sh
vendored
Executable file
41
.github/scripts/test-offline-punctuation.sh
vendored
Executable file
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -ex
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
echo "EXE is $EXE"
|
||||
echo "PATH: $PATH"
|
||||
|
||||
which $EXE
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
log "Download model "
|
||||
log "------------------------------------------------------------"
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
|
||||
tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
|
||||
rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
|
||||
repo=sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
|
||||
ls -lh $repo
|
||||
|
||||
$EXE \
|
||||
--debug=1 \
|
||||
--ct-transformer=$repo/model.onnx \
|
||||
"这是一个测试你好吗How are you我很好thank you are you ok谢谢你"
|
||||
|
||||
$EXE \
|
||||
--debug=1 \
|
||||
--ct-transformer=$repo/model.onnx \
|
||||
"我们都是木头人不会说话不会动"
|
||||
|
||||
$EXE \
|
||||
--debug=1 \
|
||||
--ct-transformer=$repo/model.onnx \
|
||||
"The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry"
|
||||
|
||||
rm -rf $repo
|
||||
10
.github/workflows/linux.yaml
vendored
10
.github/workflows/linux.yaml
vendored
@@ -16,6 +16,7 @@ on:
|
||||
- '.github/scripts/test-online-ctc.sh'
|
||||
- '.github/scripts/test-offline-tts.sh'
|
||||
- '.github/scripts/test-audio-tagging.sh'
|
||||
- '.github/scripts/test-offline-punctuation.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -34,6 +35,7 @@ on:
|
||||
- '.github/scripts/test-online-ctc.sh'
|
||||
- '.github/scripts/test-offline-tts.sh'
|
||||
- '.github/scripts/test-audio-tagging.sh'
|
||||
- '.github/scripts/test-offline-punctuation.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -126,6 +128,14 @@ jobs:
|
||||
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
|
||||
path: build/bin/*
|
||||
|
||||
- name: Test offline punctuation
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline-punctuation
|
||||
|
||||
.github/scripts/test-offline-punctuation.sh
|
||||
|
||||
- name: Test C API
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
10
.github/workflows/macos.yaml
vendored
10
.github/workflows/macos.yaml
vendored
@@ -16,6 +16,7 @@ on:
|
||||
- '.github/scripts/test-offline-tts.sh'
|
||||
- '.github/scripts/test-online-ctc.sh'
|
||||
- '.github/scripts/test-audio-tagging.sh'
|
||||
- '.github/scripts/test-offline-punctuation.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -33,6 +34,7 @@ on:
|
||||
- '.github/scripts/test-offline-tts.sh'
|
||||
- '.github/scripts/test-online-ctc.sh'
|
||||
- '.github/scripts/test-audio-tagging.sh'
|
||||
- '.github/scripts/test-offline-punctuation.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -105,6 +107,14 @@ jobs:
|
||||
otool -L build/bin/sherpa-onnx
|
||||
otool -l build/bin/sherpa-onnx
|
||||
|
||||
- name: Test offline punctuation
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline-punctuation
|
||||
|
||||
.github/scripts/test-offline-punctuation.sh
|
||||
|
||||
- name: Test C API
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
11
.github/workflows/windows-x64.yaml
vendored
11
.github/workflows/windows-x64.yaml
vendored
@@ -15,6 +15,7 @@ on:
|
||||
- '.github/scripts/test-online-ctc.sh'
|
||||
- '.github/scripts/test-offline-tts.sh'
|
||||
- '.github/scripts/test-audio-tagging.sh'
|
||||
- '.github/scripts/test-offline-punctuation.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -30,6 +31,7 @@ on:
|
||||
- '.github/scripts/test-online-ctc.sh'
|
||||
- '.github/scripts/test-offline-tts.sh'
|
||||
- '.github/scripts/test-audio-tagging.sh'
|
||||
- '.github/scripts/test-offline-punctuation.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -72,6 +74,14 @@ jobs:
|
||||
|
||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||
|
||||
- name: Test offline punctuation
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
export EXE=sherpa-onnx-offline-punctuation.exe
|
||||
|
||||
.github/scripts/test-offline-punctuation.sh
|
||||
|
||||
- name: Test C API
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -82,7 +92,6 @@ jobs:
|
||||
|
||||
.github/scripts/test-c-api.sh
|
||||
|
||||
|
||||
- name: Test Audio tagging
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
10
.github/workflows/windows-x86.yaml
vendored
10
.github/workflows/windows-x86.yaml
vendored
@@ -15,6 +15,7 @@ on:
|
||||
- '.github/scripts/test-offline-tts.sh'
|
||||
- '.github/scripts/test-online-ctc.sh'
|
||||
- '.github/scripts/test-audio-tagging.sh'
|
||||
- '.github/scripts/test-offline-punctuation.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -30,6 +31,7 @@ on:
|
||||
- '.github/scripts/test-offline-tts.sh'
|
||||
- '.github/scripts/test-online-ctc.sh'
|
||||
- '.github/scripts/test-audio-tagging.sh'
|
||||
- '.github/scripts/test-offline-punctuation.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -72,6 +74,14 @@ jobs:
|
||||
|
||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||
|
||||
- name: Test offline punctuation
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
export EXE=sherpa-onnx-offline-punctuation.exe
|
||||
|
||||
.github/scripts/test-offline-punctuation.sh
|
||||
|
||||
- name: Test spoken language identification (C API)
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
@@ -46,14 +46,15 @@ def enable_alsa():
|
||||
def get_binaries():
|
||||
binaries = [
|
||||
"sherpa-onnx",
|
||||
"sherpa-onnx-offline-audio-tagging",
|
||||
"sherpa-onnx-keyword-spotter",
|
||||
"sherpa-onnx-microphone",
|
||||
"sherpa-onnx-microphone-offline",
|
||||
"sherpa-onnx-microphone-offline-audio-tagging",
|
||||
"sherpa-onnx-microphone-offline-speaker-identification",
|
||||
"sherpa-onnx-offline",
|
||||
"sherpa-onnx-offline-audio-tagging",
|
||||
"sherpa-onnx-offline-language-identification",
|
||||
"sherpa-onnx-offline-punctuation",
|
||||
"sherpa-onnx-offline-tts",
|
||||
"sherpa-onnx-offline-tts-play",
|
||||
"sherpa-onnx-offline-websocket-server",
|
||||
|
||||
@@ -408,8 +408,11 @@ def main():
|
||||
vad_config.silero_vad.min_silence_duration = 0.25
|
||||
vad_config.silero_vad.min_speech_duration = 0.25
|
||||
vad_config.sample_rate = g_sample_rate
|
||||
if not vad_config.validate():
|
||||
raise ValueError("Errors in vad config")
|
||||
|
||||
window_size = vad_config.silero_vad.window_size
|
||||
|
||||
vad = sherpa_onnx.VoiceActivityDetector(vad_config, buffer_size_in_seconds=100)
|
||||
|
||||
samples_per_read = int(0.1 * g_sample_rate) # 0.1 second = 100 ms
|
||||
|
||||
@@ -121,6 +121,14 @@ list(APPEND sources
|
||||
offline-zipformer-audio-tagging-model.cc
|
||||
)
|
||||
|
||||
# punctuation
|
||||
list(APPEND sources
|
||||
offline-ct-transformer-model.cc
|
||||
offline-punctuation-impl.cc
|
||||
offline-punctuation-model-config.cc
|
||||
offline-punctuation.cc
|
||||
)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_TTS)
|
||||
list(APPEND sources
|
||||
lexicon.cc
|
||||
@@ -201,9 +209,10 @@ if(SHERPA_ONNX_ENABLE_BINARY)
|
||||
add_executable(sherpa-onnx sherpa-onnx.cc)
|
||||
add_executable(sherpa-onnx-keyword-spotter sherpa-onnx-keyword-spotter.cc)
|
||||
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
|
||||
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
|
||||
add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
|
||||
add_executable(sherpa-onnx-offline-audio-tagging sherpa-onnx-offline-audio-tagging.cc)
|
||||
add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
|
||||
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
|
||||
add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_TTS)
|
||||
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
|
||||
@@ -213,9 +222,10 @@ if(SHERPA_ONNX_ENABLE_BINARY)
|
||||
sherpa-onnx
|
||||
sherpa-onnx-keyword-spotter
|
||||
sherpa-onnx-offline
|
||||
sherpa-onnx-offline-parallel
|
||||
sherpa-onnx-offline-language-identification
|
||||
sherpa-onnx-offline-audio-tagging
|
||||
sherpa-onnx-offline-language-identification
|
||||
sherpa-onnx-offline-parallel
|
||||
sherpa-onnx-offline-punctuation
|
||||
)
|
||||
if(SHERPA_ONNX_ENABLE_TTS)
|
||||
list(APPEND main_exes
|
||||
@@ -260,11 +270,11 @@ endif()
|
||||
|
||||
if(SHERPA_ONNX_HAS_ALSA AND SHERPA_ONNX_ENABLE_BINARY)
|
||||
add_executable(sherpa-onnx-alsa sherpa-onnx-alsa.cc alsa.cc)
|
||||
add_executable(sherpa-onnx-keyword-spotter-alsa sherpa-onnx-keyword-spotter-alsa.cc alsa.cc)
|
||||
add_executable(sherpa-onnx-alsa-offline sherpa-onnx-alsa-offline.cc alsa.cc)
|
||||
add_executable(sherpa-onnx-alsa-offline-speaker-identification sherpa-onnx-alsa-offline-speaker-identification.cc alsa.cc)
|
||||
add_executable(sherpa-onnx-vad-alsa sherpa-onnx-vad-alsa.cc alsa.cc)
|
||||
add_executable(sherpa-onnx-alsa-offline-audio-tagging sherpa-onnx-alsa-offline-audio-tagging.cc alsa.cc)
|
||||
add_executable(sherpa-onnx-alsa-offline-speaker-identification sherpa-onnx-alsa-offline-speaker-identification.cc alsa.cc)
|
||||
add_executable(sherpa-onnx-keyword-spotter-alsa sherpa-onnx-keyword-spotter-alsa.cc alsa.cc)
|
||||
add_executable(sherpa-onnx-vad-alsa sherpa-onnx-vad-alsa.cc alsa.cc)
|
||||
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_TTS)
|
||||
|
||||
@@ -74,11 +74,6 @@ static std::vector<std::string> ProcessHeteronyms(
|
||||
return ans;
|
||||
}
|
||||
|
||||
static void ToLowerCase(std::string *in_out) {
|
||||
std::transform(in_out->begin(), in_out->end(), in_out->begin(),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
}
|
||||
|
||||
// Note: We don't use SymbolTable here since tokens may contain a blank
|
||||
// in the first column
|
||||
static std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is) {
|
||||
|
||||
@@ -118,6 +118,24 @@
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// read a vector of strings separated by sep
|
||||
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \
|
||||
do { \
|
||||
auto value = \
|
||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
||||
if (!value) { \
|
||||
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
SplitStringToVector(value.get(), sep, false, &dst); \
|
||||
\
|
||||
if (dst.empty()) { \
|
||||
SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \
|
||||
src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Read a string
|
||||
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
|
||||
do { \
|
||||
|
||||
29
sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h
Normal file
29
sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h
Normal file
@@ -0,0 +1,29 @@
|
||||
// sherpa-onnx/csrc/offline-ct-transformer-model-meta_data.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineCtTransformerModelMetaData {
|
||||
std::unordered_map<std::string, int32_t> token2id;
|
||||
std::unordered_map<std::string, int32_t> punct2id;
|
||||
std::vector<std::string> id2punct;
|
||||
|
||||
int32_t unk_id;
|
||||
int32_t dot_id;
|
||||
int32_t comma_id;
|
||||
int32_t quest_id;
|
||||
int32_t pause_id;
|
||||
int32_t underline_id;
|
||||
int32_t num_punctuations;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_
|
||||
164
sherpa-onnx/csrc/offline-ct-transformer-model.cc
Normal file
164
sherpa-onnx/csrc/offline-ct-transformer-model.cc
Normal file
@@ -0,0 +1,164 @@
|
||||
// sherpa-onnx/csrc/offline-ct-transformer-model.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-ct-transformer-model.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineCtTransformerModel::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflinePunctuationModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(config_.ct_transformer);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const OfflinePunctuationModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(mgr, config_.ct_transformer);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
#endif
|
||||
|
||||
Ort::Value Forward(Ort::Value text, Ort::Value text_len) {
|
||||
std::array<Ort::Value, 2> inputs = {std::move(text), std::move(text_len)};
|
||||
|
||||
auto ans =
|
||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
return std::move(ans[0]);
|
||||
}
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
|
||||
const OfflineCtTransformerModelMetaData &GetModelMetadata() const {
|
||||
return meta_data_;
|
||||
}
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
|
||||
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
|
||||
std::vector<std::string> tokens;
|
||||
SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(tokens, "tokens", "|");
|
||||
|
||||
int32_t vocab_size;
|
||||
SHERPA_ONNX_READ_META_DATA(vocab_size, "vocab_size");
|
||||
if (tokens.size() != vocab_size) {
|
||||
SHERPA_ONNX_LOGE("tokens.size() %d != vocab_size %d",
|
||||
static_cast<int32_t>(tokens.size()), vocab_size);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(meta_data_.id2punct,
|
||||
"punctuations", "|");
|
||||
|
||||
std::string unk_symbol;
|
||||
SHERPA_ONNX_READ_META_DATA_STR(unk_symbol, "unk_symbol");
|
||||
|
||||
// output shape is (N, T, num_punctuations)
|
||||
meta_data_.num_punctuations =
|
||||
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[2];
|
||||
|
||||
int32_t i = 0;
|
||||
for (const auto &t : tokens) {
|
||||
meta_data_.token2id[t] = i;
|
||||
i += 1;
|
||||
}
|
||||
|
||||
i = 0;
|
||||
for (const auto &p : meta_data_.id2punct) {
|
||||
meta_data_.punct2id[p] = i;
|
||||
i += 1;
|
||||
}
|
||||
|
||||
meta_data_.unk_id = meta_data_.token2id.at(unk_symbol);
|
||||
|
||||
meta_data_.dot_id = meta_data_.punct2id.at("。");
|
||||
meta_data_.comma_id = meta_data_.punct2id.at(",");
|
||||
meta_data_.quest_id = meta_data_.punct2id.at("?");
|
||||
meta_data_.pause_id = meta_data_.punct2id.at("、");
|
||||
meta_data_.underline_id = meta_data_.punct2id.at("_");
|
||||
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
os << "vocab_size: " << meta_data_.token2id.size() << "\n";
|
||||
os << "num_punctuations: " << meta_data_.num_punctuations << "\n";
|
||||
os << "punctuations: ";
|
||||
for (const auto &s : meta_data_.id2punct) {
|
||||
os << s << " ";
|
||||
}
|
||||
os << "\n";
|
||||
SHERPA_ONNX_LOGE("\n%s\n", os.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
OfflinePunctuationModelConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> sess_;
|
||||
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<const char *> input_names_ptr_;
|
||||
|
||||
std::vector<std::string> output_names_;
|
||||
std::vector<const char *> output_names_ptr_;
|
||||
|
||||
OfflineCtTransformerModelMetaData meta_data_;
|
||||
};
|
||||
|
||||
OfflineCtTransformerModel::OfflineCtTransformerModel(
|
||||
const OfflinePunctuationModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineCtTransformerModel::OfflineCtTransformerModel(
|
||||
AAssetManager *mgr, const OfflinePunctuationModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OfflineCtTransformerModel::~OfflineCtTransformerModel() = default;
|
||||
|
||||
Ort::Value OfflineCtTransformerModel::Forward(Ort::Value text,
|
||||
Ort::Value text_len) const {
|
||||
return impl_->Forward(std::move(text), std::move(text_len));
|
||||
}
|
||||
|
||||
OrtAllocator *OfflineCtTransformerModel::Allocator() const {
|
||||
return impl_->Allocator();
|
||||
}
|
||||
|
||||
const OfflineCtTransformerModelMetaData &
|
||||
OfflineCtTransformerModel::GetModelMetadata() const {
|
||||
return impl_->GetModelMetadata();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
59
sherpa-onnx/csrc/offline-ct-transformer-model.h
Normal file
59
sherpa-onnx/csrc/offline-ct-transformer-model.h
Normal file
@@ -0,0 +1,59 @@
|
||||
// sherpa-onnx/csrc/offline-ct-transformer-model.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_H_
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h"
|
||||
#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** This class implements
|
||||
* https://github.com/alibaba-damo-academy/FunASR/blob/main/runtime/python/onnxruntime/funasr_onnx/punc_bin.py#L17
|
||||
* from FunASR
|
||||
*/
|
||||
class OfflineCtTransformerModel {
|
||||
public:
|
||||
explicit OfflineCtTransformerModel(
|
||||
const OfflinePunctuationModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineCtTransformerModel(AAssetManager *mgr,
|
||||
const OfflinePunctuationModelConfig &config);
|
||||
#endif
|
||||
|
||||
~OfflineCtTransformerModel();
|
||||
|
||||
/** Run the forward method of the model.
|
||||
*
|
||||
* @param text A tensor of shape (N, T) of dtype int32.
|
||||
* @param text A tensor of shape (N) of dtype int32.
|
||||
*
|
||||
* @return Return a tensor
|
||||
* - punctuation_ids: A 2-D tensor of shape (N, T).
|
||||
*/
|
||||
Ort::Value Forward(Ort::Value text, Ort::Value text_len) const;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const;
|
||||
|
||||
const OfflineCtTransformerModelMetaData &GetModelMetadata() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_H_
|
||||
170
sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h
Normal file
170
sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h
Normal file
@@ -0,0 +1,170 @@
|
||||
// sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/math.h"
|
||||
#include "sherpa-onnx/csrc/offline-ct-transformer-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-punctuation-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-punctuation.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
|
||||
public:
|
||||
explicit OfflinePunctuationCtTransformerImpl(
|
||||
const OfflinePunctuationConfig &config)
|
||||
: config_(config), model_(config.model) {}
|
||||
|
||||
std::string AddPunctuation(const std::string &text) const override {
|
||||
if (text.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<std::string> tokens = SplitUtf8(text);
|
||||
std::vector<int32_t> token_ids;
|
||||
token_ids.reserve(tokens.size());
|
||||
|
||||
const auto &meta_data = model_.GetModelMetadata();
|
||||
|
||||
for (const auto &t : tokens) {
|
||||
std::string token = ToLowerCase(t);
|
||||
if (meta_data.token2id.count(token)) {
|
||||
token_ids.push_back(meta_data.token2id.at(token));
|
||||
} else {
|
||||
token_ids.push_back(meta_data.unk_id);
|
||||
}
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
int32_t segment_size = 20;
|
||||
int32_t max_len = 200;
|
||||
int32_t num_segments = (token_ids.size() + segment_size - 1) / segment_size;
|
||||
|
||||
std::vector<int32_t> punctuations;
|
||||
int32_t last = -1;
|
||||
for (int32_t i = 0; i != num_segments; ++i) {
|
||||
int32_t this_start = i * segment_size; // inclusive
|
||||
int32_t this_end = this_start + segment_size; // exclusive
|
||||
if (this_end > token_ids.size()) {
|
||||
this_end = token_ids.size();
|
||||
}
|
||||
|
||||
if (last != -1) {
|
||||
this_start = last;
|
||||
}
|
||||
// token_ids[this_start:this_end] is sent to the model
|
||||
|
||||
std::array<int64_t, 2> x_shape = {1, this_end - this_start};
|
||||
Ort::Value x =
|
||||
Ort::Value::CreateTensor(memory_info, token_ids.data() + this_start,
|
||||
x_shape[1], x_shape.data(), x_shape.size());
|
||||
|
||||
int64_t len_shape = 1;
|
||||
int32_t len = x_shape[1];
|
||||
Ort::Value x_len =
|
||||
Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);
|
||||
|
||||
Ort::Value out = model_.Forward(std::move(x), std::move(x_len));
|
||||
|
||||
// [N, T, num_punctuations]
|
||||
std::vector<int64_t> out_shape =
|
||||
out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
assert(out_shape[0] == 1);
|
||||
assert(out_shape[1] == len);
|
||||
assert(out_shape[2] == meta_data.num_punctuations);
|
||||
|
||||
std::vector<int32_t> this_punctuations;
|
||||
this_punctuations.reserve(len);
|
||||
|
||||
const float *p = out.GetTensorData<float>();
|
||||
for (int32_t k = 0; k != len; ++k, p += meta_data.num_punctuations) {
|
||||
auto index = static_cast<int32_t>(std::distance(
|
||||
p, std::max_element(p, p + meta_data.num_punctuations)));
|
||||
this_punctuations.push_back(index);
|
||||
} // for (int32_t k = 0; k != len; ++k, p += meta_data.num_punctuations)
|
||||
|
||||
int32_t dot_index = -1;
|
||||
int32_t comma_index = -1;
|
||||
|
||||
for (int32_t m = this_punctuations.size() - 1; m >= 1; --m) {
|
||||
int32_t punct_id = this_punctuations[m];
|
||||
|
||||
if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
|
||||
dot_index = m;
|
||||
break;
|
||||
}
|
||||
|
||||
if (comma_index == -1 && punct_id == meta_data.comma_id) {
|
||||
comma_index = m;
|
||||
}
|
||||
} // for (int32_t k = this_punctuations.size() - 1; k >= 1; --k)
|
||||
|
||||
if (dot_index == -1 && len >= max_len && comma_index != -1) {
|
||||
dot_index = comma_index;
|
||||
this_punctuations[dot_index] = meta_data.dot_id;
|
||||
}
|
||||
|
||||
if (dot_index == -1) {
|
||||
if (last == -1) {
|
||||
last = this_start;
|
||||
}
|
||||
|
||||
if (i == num_segments - 1) {
|
||||
dot_index = token_ids.size() - 1;
|
||||
}
|
||||
} else {
|
||||
last = this_start + dot_index + 1;
|
||||
|
||||
punctuations.insert(punctuations.end(), this_punctuations.begin(),
|
||||
this_punctuations.begin() + (dot_index + 1));
|
||||
}
|
||||
} // for (int32_t i = 0; i != num_segments; ++i)
|
||||
|
||||
if (punctuations.size() != token_ids.size() &&
|
||||
punctuations.size() + 1 == token_ids.size()) {
|
||||
punctuations.push_back(meta_data.dot_id);
|
||||
}
|
||||
|
||||
if (punctuations.size() != token_ids.size()) {
|
||||
SHERPA_ONNX_LOGE("%s, %d, %d. Some unexpected things happened",
|
||||
text.c_str(), static_cast<int32_t>(punctuations.size()),
|
||||
static_cast<int32_t>(token_ids.size()));
|
||||
return text;
|
||||
}
|
||||
|
||||
std::string ans;
|
||||
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(punctuations.size()); ++i) {
|
||||
const std::string &w = tokens[i];
|
||||
if (i > 0 && !(ans.back() & 0x80) && !(w[0] & 0x80)) {
|
||||
ans.push_back(' ');
|
||||
}
|
||||
ans.append(w);
|
||||
if (punctuations[i] != meta_data.underline_id) {
|
||||
ans.append(meta_data.id2punct[punctuations[i]]);
|
||||
}
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
OfflinePunctuationConfig config_;
|
||||
OfflineCtTransformerModel model_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
|
||||
22
sherpa-onnx/csrc/offline-punctuation-impl.cc
Normal file
22
sherpa-onnx/csrc/offline-punctuation-impl.cc
Normal file
@@ -0,0 +1,22 @@
|
||||
// sherpa-onnx/csrc/offline-punctuation-impl.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-punctuation-impl.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<OfflinePunctuationImpl> OfflinePunctuationImpl::Create(
|
||||
const OfflinePunctuationConfig &config) {
|
||||
if (!config.model.ct_transformer.empty()) {
|
||||
return std::make_unique<OfflinePunctuationCtTransformerImpl>(config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("Please specify a punctuation model! Return a null pointer");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
27
sherpa-onnx/csrc/offline-punctuation-impl.h
Normal file
27
sherpa-onnx/csrc/offline-punctuation-impl.h
Normal file
@@ -0,0 +1,27 @@
|
||||
// sherpa-onnx/csrc/offline-punctuation-impl.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-punctuation.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflinePunctuationImpl {
|
||||
public:
|
||||
virtual ~OfflinePunctuationImpl() = default;
|
||||
|
||||
static std::unique_ptr<OfflinePunctuationImpl> Create(
|
||||
const OfflinePunctuationConfig &config);
|
||||
|
||||
virtual std::string AddPunctuation(const std::string &text) const = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_IMPL_H_
|
||||
53
sherpa-onnx/csrc/offline-punctuation-model-config.cc
Normal file
53
sherpa-onnx/csrc/offline-punctuation-model-config.cc
Normal file
@@ -0,0 +1,53 @@
|
||||
// sherpa-onnx/csrc/offline-punctuation-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflinePunctuationModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("ct-transformer", &ct_transformer,
|
||||
"Path to the controllable time-delay (CT) transformer model");
|
||||
|
||||
po->Register("num-threads", &num_threads,
|
||||
"Number of threads to run the neural network");
|
||||
|
||||
po->Register("debug", &debug,
|
||||
"true to print model information while loading it.");
|
||||
|
||||
po->Register("provider", &provider,
|
||||
"Specify a provider to use: cpu, cuda, coreml");
|
||||
}
|
||||
|
||||
bool OfflinePunctuationModelConfig::Validate() const {
|
||||
if (ct_transformer.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --ct-transformer");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(ct_transformer)) {
|
||||
SHERPA_ONNX_LOGE("--ct-transformer %s does not exist",
|
||||
ct_transformer.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OfflinePunctuationModelConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflinePunctuationModelConfig(";
|
||||
os << "ct_transformer=\"" << ct_transformer << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
38
sherpa-onnx/csrc/offline-punctuation-model-config.h
Normal file
38
sherpa-onnx/csrc/offline-punctuation-model-config.h
Normal file
@@ -0,0 +1,38 @@
|
||||
// sherpa-onnx/csrc/offline-punctuation-model-config.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflinePunctuationModelConfig {
|
||||
std::string ct_transformer;
|
||||
|
||||
int32_t num_threads = 1;
|
||||
bool debug = false;
|
||||
std::string provider = "cpu";
|
||||
|
||||
OfflinePunctuationModelConfig() = default;
|
||||
|
||||
OfflinePunctuationModelConfig(const std::string &ct_transformer,
|
||||
int32_t num_threads, bool debug,
|
||||
const std::string &provider)
|
||||
: ct_transformer(ct_transformer),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
provider(provider) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_MODEL_CONFIG_H_
|
||||
42
sherpa-onnx/csrc/offline-punctuation.cc
Normal file
42
sherpa-onnx/csrc/offline-punctuation.cc
Normal file
@@ -0,0 +1,42 @@
|
||||
// sherpa-onnx/csrc/offline-punctuation.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-punctuation.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-punctuation-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflinePunctuationConfig::Register(ParseOptions *po) {
|
||||
model.Register(po);
|
||||
}
|
||||
|
||||
bool OfflinePunctuationConfig::Validate() const {
|
||||
if (!model.Validate()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OfflinePunctuationConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflinePunctuationConfig(";
|
||||
os << "model=" << model.ToString() << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
OfflinePunctuation::OfflinePunctuation(const OfflinePunctuationConfig &config)
|
||||
: impl_(OfflinePunctuationImpl::Create(config)) {}
|
||||
|
||||
OfflinePunctuation::~OfflinePunctuation() = default;
|
||||
|
||||
std::string OfflinePunctuation::AddPunctuation(const std::string &text) const {
|
||||
return impl_->AddPunctuation(text);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
47
sherpa-onnx/csrc/offline-punctuation.h
Normal file
47
sherpa-onnx/csrc/offline-punctuation.h
Normal file
@@ -0,0 +1,47 @@
|
||||
// sherpa-onnx/csrc/offline-punctuation.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflinePunctuationConfig {
|
||||
OfflinePunctuationModelConfig model;
|
||||
|
||||
OfflinePunctuationConfig() = default;
|
||||
|
||||
explicit OfflinePunctuationConfig(const OfflinePunctuationModelConfig &model)
|
||||
: model(model) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
class OfflinePunctuationImpl;
|
||||
|
||||
class OfflinePunctuation {
|
||||
public:
|
||||
explicit OfflinePunctuation(const OfflinePunctuationConfig &config);
|
||||
|
||||
~OfflinePunctuation();
|
||||
|
||||
// Add punctuation to the input text and return it.
|
||||
std::string AddPunctuation(const std::string &text) const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<OfflinePunctuationImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_H_
|
||||
@@ -29,7 +29,6 @@ void OnlineWebsocketDecoderConfig::Validate() const {
|
||||
SHERPA_ONNX_CHECK_GT(loop_interval_ms, 0);
|
||||
SHERPA_ONNX_CHECK_GT(max_batch_size, 0);
|
||||
SHERPA_ONNX_CHECK_GT(end_tail_padding, 0);
|
||||
|
||||
}
|
||||
|
||||
void OnlineWebsocketServerConfig::Register(sherpa_onnx::ParseOptions *po) {
|
||||
@@ -87,7 +86,8 @@ void OnlineWebsocketDecoder::InputFinished(std::shared_ptr<Connection> c) {
|
||||
c->samples.pop_front();
|
||||
}
|
||||
|
||||
std::vector<float> tail_padding(static_cast<int64_t>(config_.end_tail_padding * sample_rate));
|
||||
std::vector<float> tail_padding(
|
||||
static_cast<int64_t>(config_.end_tail_padding * sample_rate));
|
||||
|
||||
c->s->AcceptWaveform(sample_rate, tail_padding.data(), tail_padding.size());
|
||||
|
||||
|
||||
@@ -160,4 +160,9 @@ Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const OfflinePunctuationModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
@@ -43,6 +44,9 @@ Ort::SessionOptions GetSessionOptions(
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const OfflinePunctuationModelConfig &config);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SESSION_H_
|
||||
|
||||
68
sherpa-onnx/csrc/sherpa-onnx-offline-punctuation.cc
Normal file
68
sherpa-onnx/csrc/sherpa-onnx-offline-punctuation.cc
Normal file
@@ -0,0 +1,68 @@
|
||||
// sherpa-onnx/csrc/sherpa-onnx-offline-punctuation.cc
|
||||
//
|
||||
// Copyright (c) 2022-2024 Xiaomi Corporation
|
||||
#include <stdio.h>
|
||||
|
||||
#include <chrono> // NOLINT
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-punctuation.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
const char *kUsageMessage = R"usage(
|
||||
Add punctuations to the input text.
|
||||
|
||||
The input text can contain both Chinese and English words.
|
||||
|
||||
Usage:
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
|
||||
tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
|
||||
rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
|
||||
|
||||
./bin/sherpa-onnx-offline-punctuation \
|
||||
--ct-transformer=./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx
|
||||
"你好吗how are you Fantasitic 谢谢我很好你怎么样呢"
|
||||
|
||||
The output text should look like below:
|
||||
)usage";
|
||||
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
sherpa_onnx::OfflinePunctuationConfig config;
|
||||
config.Register(&po);
|
||||
po.Read(argc, argv);
|
||||
if (po.NumArgs() != 1) {
|
||||
fprintf(stderr,
|
||||
"Error: Please provide only 1 position argument containing the "
|
||||
"input text.\n\n");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
if (!config.Validate()) {
|
||||
fprintf(stderr, "Errors in config!\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "Creating OfflinePunctuation ...\n");
|
||||
sherpa_onnx::OfflinePunctuation punct(config);
|
||||
fprintf(stderr, "Started\n");
|
||||
const auto begin = std::chrono::steady_clock::now();
|
||||
|
||||
std::string text = po.GetArg(1);
|
||||
std::string text_with_punct = punct.AddPunctuation(text);
|
||||
fprintf(stderr, "Done\n");
|
||||
const auto end = std::chrono::steady_clock::now();
|
||||
|
||||
float elapsed_seconds =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
||||
.count() /
|
||||
1000.;
|
||||
|
||||
fprintf(stderr, "Num threads: %d\n", config.model.num_threads);
|
||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||
fprintf(stderr, "Input text: %s\n", text.c_str());
|
||||
fprintf(stderr, "Output text: %s\n", text_with_punct.c_str());
|
||||
}
|
||||
@@ -111,8 +111,8 @@ for a list of pre-trained models to download.
|
||||
fprintf(stderr, "Creating recognizer ...\n");
|
||||
sherpa_onnx::OfflineRecognizer recognizer(config);
|
||||
|
||||
const auto begin = std::chrono::steady_clock::now();
|
||||
fprintf(stderr, "Started\n");
|
||||
const auto begin = std::chrono::steady_clock::now();
|
||||
|
||||
std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss;
|
||||
std::vector<sherpa_onnx::OfflineStream *> ss_pointers;
|
||||
|
||||
@@ -385,4 +385,16 @@ std::vector<std::string> SplitUtf8(const std::string &text) {
|
||||
return MergeCharactersIntoWords(ans);
|
||||
}
|
||||
|
||||
std::string ToLowerCase(const std::string &s) {
|
||||
std::string ans(s.size(), 0);
|
||||
std::transform(s.begin(), s.end(), ans.begin(),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
return ans;
|
||||
}
|
||||
|
||||
void ToLowerCase(std::string *in_out) {
|
||||
std::transform(in_out->begin(), in_out->end(), in_out->begin(),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -121,6 +121,9 @@ bool ConvertStringToReal(const std::string &str, T *out);
|
||||
|
||||
std::vector<std::string> SplitUtf8(const std::string &text);
|
||||
|
||||
std::string ToLowerCase(const std::string &s);
|
||||
void ToLowerCase(std::string *in_out);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_
|
||||
|
||||
Reference in New Issue
Block a user