54 lines
1.5 KiB
C++
54 lines
1.5 KiB
C++
// sherpa-onnx/csrc/fst-utils.cc
|
|
//
|
|
// Copyright (c) 2024 Xiaomi Corporation
|
|
|
|
#include "sherpa-onnx/csrc/fst-utils.h"
|
|
|
|
#include "sherpa-onnx/csrc/macros.h"
|
|
|
|
namespace sherpa_onnx {
|
|
|
|
// This function is copied from kaldi.
|
|
//
|
|
// @param filename Path to a StdVectorFst or StdConstFst graph
|
|
// @return The caller should free the returned pointer using `delete` to
|
|
// avoid memory leak.
|
|
fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
|
|
// read decoding network FST
|
|
std::ifstream is(filename, std::ios::binary);
|
|
if (!is.good()) {
|
|
SHERPA_ONNX_LOGE("Could not open decoding-graph FST %s", filename.c_str());
|
|
}
|
|
|
|
fst::FstHeader hdr;
|
|
if (!hdr.Read(is, "<unknown>")) {
|
|
SHERPA_ONNX_LOGE("Reading FST: error reading FST header.");
|
|
}
|
|
|
|
if (hdr.ArcType() != fst::StdArc::Type()) {
|
|
SHERPA_ONNX_LOGE("FST with arc type %s not supported",
|
|
hdr.ArcType().c_str());
|
|
}
|
|
fst::FstReadOptions ropts("<unspecified>", &hdr);
|
|
|
|
fst::Fst<fst::StdArc> *decode_fst = nullptr;
|
|
|
|
if (hdr.FstType() == "vector") {
|
|
decode_fst = fst::VectorFst<fst::StdArc>::Read(is, ropts);
|
|
} else if (hdr.FstType() == "const") {
|
|
decode_fst = fst::ConstFst<fst::StdArc>::Read(is, ropts);
|
|
} else {
|
|
SHERPA_ONNX_LOGE("Reading FST: unsupported FST type: %s",
|
|
hdr.FstType().c_str());
|
|
}
|
|
|
|
if (decode_fst == nullptr) { // fst code will warn.
|
|
SHERPA_ONNX_LOGE("Error reading FST (after reading header).");
|
|
return nullptr;
|
|
} else {
|
|
return decode_fst;
|
|
}
|
|
}
|
|
|
|
} // namespace sherpa_onnx
|