This repository has been archived on 2025-08-26. You can view files and clone it, but cannot push or open issues or pull requests.
Files
enginex_bi_series-sherpa-onnx/sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h
2024-04-13 19:08:46 +08:00

167 lines
5.1 KiB
C++

// sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/offline-ct-transformer-model.h"
#include "sherpa-onnx/csrc/offline-punctuation-impl.h"
#include "sherpa-onnx/csrc/offline-punctuation.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
// Punctuation restoration backed by a CT-Transformer ONNX model.
//
// The input text is split into UTF-8 tokens, which are mapped to ids via the
// model metadata and fed to the model in windows of kSegmentSize tokens.
// For every window except the last one, only the prefix up to (and
// including) the last predicted period / question mark is committed. The
// remaining tokens are carried over and re-run together with the next
// window. The last window is committed in full, so every token receives
// exactly one punctuation label.
class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
 public:
  explicit OfflinePunctuationCtTransformerImpl(
      const OfflinePunctuationConfig &config)
      : config_(config), model_(config.model) {}

  // Returns `text` with punctuation inserted after its tokens.
  //
  // An empty input (or one that yields no UTF-8 tokens) returns an empty
  // string. Otherwise the label of the final token is forced to the
  // period unless the model already predicted a period or question mark.
  std::string AddPunctuation(const std::string &text) const override {
    if (text.empty()) {
      return {};
    }

    std::vector<std::string> tokens = SplitUtf8(text);
    if (tokens.empty()) {
      return {};
    }

    const auto &meta_data = model_.GetModelMetadata();

    std::vector<int32_t> token_ids;
    token_ids.reserve(tokens.size());
    for (const auto &t : tokens) {
      // Vocabulary lookup is case-insensitive; unknown tokens map to unk_id.
      auto it = meta_data.token2id.find(ToLowerCase(t));
      token_ids.push_back(it != meta_data.token2id.end() ? it->second
                                                          : meta_data.unk_id);
    }

    // Number of new tokens added per model call, and the pending-window
    // length at which a sentence break is forced at the last comma.
    constexpr int32_t kSegmentSize = 20;
    constexpr int32_t kMaxLen = 200;

    const int32_t num_tokens = static_cast<int32_t>(token_ids.size());
    const int32_t num_segments =
        (num_tokens + kSegmentSize - 1) / kSegmentSize;

    // punctuations[k] is the punctuation id predicted for tokens[k].
    std::vector<int32_t> punctuations;
    punctuations.reserve(num_tokens);

    // Index of the first token not yet committed to `punctuations`,
    // or -1 before the first window has been processed.
    int32_t last = -1;
    for (int32_t i = 0; i != num_segments; ++i) {
      int32_t this_start = i * kSegmentSize;  // inclusive
      const int32_t this_end =
          std::min(this_start + kSegmentSize, num_tokens);  // exclusive
      if (last != -1) {
        // Re-run the tokens carried over from previous window(s).
        this_start = last;
      }
      const int32_t len = this_end - this_start;

      // token_ids[this_start:this_end] is sent to the model
      std::vector<int32_t> this_punctuations =
          PredictPunctuations(token_ids.data() + this_start, len);

      if (i == num_segments - 1) {
        // Final window: there is no later context, so commit every label.
        punctuations.insert(punctuations.end(), this_punctuations.begin(),
                            this_punctuations.end());
        break;
      }

      // Search backwards (skipping the first and last position) for the
      // last period / question mark. Remember the last comma seen on the way.
      int32_t dot_index = -1;
      int32_t comma_index = -1;
      for (int32_t m = len - 2; m >= 1; --m) {
        const int32_t punct_id = this_punctuations[m];
        if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
          dot_index = m;
          break;
        }
        if (comma_index == -1 && punct_id == meta_data.comma_id) {
          comma_index = m;
        }
      }

      if (dot_index == -1 && len >= kMaxLen && comma_index != -1) {
        // The pending text is too long: turn its last comma into a period
        // and break the sentence there.
        dot_index = comma_index;
        this_punctuations[dot_index] = meta_data.dot_id;
      }

      if (dot_index == -1) {
        // No sentence end yet: carry the whole window over.
        last = this_start;
      } else {
        punctuations.insert(punctuations.end(), this_punctuations.begin(),
                            this_punctuations.begin() + (dot_index + 1));
        last = this_start + dot_index + 1;
      }
    }  // for (int32_t i = 0; i != num_segments; ++i)

    assert(static_cast<int32_t>(punctuations.size()) == num_tokens);

    // The result must end with a sentence-final mark: any other trailing
    // label (none, comma, ...) is replaced by the period.
    int32_t &final_punct = punctuations.back();
    if (final_punct != meta_data.dot_id && final_punct != meta_data.quest_id) {
      final_punct = meta_data.dot_id;
    }

    std::string ans;
    for (int32_t i = 0; i != num_tokens; ++i) {
      const std::string &w = tokens[i];
      // Insert a space only when the previous byte and the first byte of
      // this token are both ASCII (high bit clear).
      if (i > 0 && !ans.empty() && !(ans.back() & 0x80) && !(w[0] & 0x80)) {
        ans.push_back(' ');
      }
      ans.append(w);

      const int32_t punct_id = punctuations[i];
      if (punct_id != meta_data.underline_id) {
        ans.append(meta_data.id2punct[punct_id]);
      }
    }
    return ans;
  }

 private:
  // Runs the model on the `n` token ids starting at `ids`.
  // Returns, for each of those tokens, the argmax punctuation id.
  // `ids` must stay valid for the duration of the call; the input tensor
  // wraps it without copying.
  std::vector<int32_t> PredictPunctuations(int32_t *ids, int32_t n) const {
    const auto &meta_data = model_.GetModelMetadata();

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    std::array<int64_t, 2> x_shape = {1, n};
    Ort::Value x = Ort::Value::CreateTensor(memory_info, ids, n,
                                            x_shape.data(), x_shape.size());

    int64_t len_shape = 1;
    int32_t len = n;
    Ort::Value x_len =
        Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);

    Ort::Value out = model_.Forward(std::move(x), std::move(x_len));

    // [N, T, num_punctuations]
    std::vector<int64_t> out_shape =
        out.GetTensorTypeAndShapeInfo().GetShape();
    assert(out_shape[0] == 1);
    assert(out_shape[1] == n);
    assert(out_shape[2] == meta_data.num_punctuations);

    const int32_t num_punct = meta_data.num_punctuations;
    std::vector<int32_t> ans;
    ans.reserve(n);
    const float *p = out.GetTensorData<float>();
    for (int32_t k = 0; k != n; ++k, p += num_punct) {
      ans.push_back(static_cast<int32_t>(
          std::distance(p, std::max_element(p, p + num_punct))));
    }
    return ans;
  }

  OfflinePunctuationConfig config_;
  OfflineCtTransformerModel model_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_