Add CTC HLG decoding using OpenFst (#349)

This commit is contained in:
Fangjun Kuang
2023-10-08 11:32:39 +08:00
committed by GitHub
parent c12286fe5e
commit 407602445d
39 changed files with 964 additions and 56 deletions

View File

@@ -34,8 +34,8 @@ class OfflineNemoEncDecCtcModel::Impl {
}
#endif
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) {
std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) {
std::vector<int64_t> shape =
features_length.GetTensorTypeAndShapeInfo().GetShape();
@@ -57,7 +57,11 @@ class OfflineNemoEncDecCtcModel::Impl {
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
return {std::move(out[0]), std::move(out_features_length)};
std::vector<Ort::Value> ans;
ans.reserve(2);
ans.push_back(std::move(out[0]));
ans.push_back(std::move(out_features_length));
return ans;
}
int32_t VocabSize() const { return vocab_size_; }
@@ -122,7 +126,7 @@ OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineNemoEncDecCtcModel::Forward(
std::vector<Ort::Value> OfflineNemoEncDecCtcModel::Forward(
Ort::Value features, Ort::Value features_length) {
return impl_->Forward(std::move(features), std::move(features_length));
}