Add CTC HLG decoding using OpenFst (#349)
This commit is contained in:
@@ -34,8 +34,8 @@ class OfflineNemoEncDecCtcModel::Impl {
|
||||
}
|
||||
#endif
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
std::vector<int64_t> shape =
|
||||
features_length.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
@@ -57,7 +57,11 @@ class OfflineNemoEncDecCtcModel::Impl {
|
||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
|
||||
return {std::move(out[0]), std::move(out_features_length)};
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(2);
|
||||
ans.push_back(std::move(out[0]));
|
||||
ans.push_back(std::move(out_features_length));
|
||||
return ans;
|
||||
}
|
||||
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
@@ -122,7 +126,7 @@ OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
|
||||
|
||||
OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineNemoEncDecCtcModel::Forward(
|
||||
std::vector<Ort::Value> OfflineNemoEncDecCtcModel::Forward(
|
||||
Ort::Value features, Ort::Value features_length) {
|
||||
return impl_->Forward(std::move(features), std::move(features_length));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user