Support getting word IDs for CTC HLG decoding. (#978)
This commit is contained in:
@@ -15,8 +15,16 @@ struct OfflineCtcDecoderResult {
|
||||
/// The decoded token IDs
|
||||
std::vector<int64_t> tokens;
|
||||
|
||||
/// The decoded word IDs
|
||||
/// Note: tokens.size() is usually not equal to words.size()
|
||||
/// words is empty for greedy search decoding.
|
||||
/// it is not empty when an HLG graph is used.
|
||||
std::vector<int32_t> words;
|
||||
|
||||
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
|
||||
/// Note: The index is after subsampling
|
||||
///
|
||||
/// tokens.size() == timestamps.size()
|
||||
std::vector<int32_t> timestamps;
|
||||
};
|
||||
|
||||
|
||||
@@ -108,6 +108,9 @@ static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder,
|
||||
// -1 here since the input labels are incremented during graph
|
||||
// construction
|
||||
r.tokens.push_back(arc.ilabel - 1);
|
||||
if (arc.olabel != 0) {
|
||||
r.words.push_back(arc.olabel);
|
||||
}
|
||||
|
||||
r.timestamps.push_back(t);
|
||||
prev = arc.ilabel;
|
||||
|
||||
@@ -64,10 +64,6 @@ OfflineParaformerGreedySearchDecoder::Decode(
|
||||
|
||||
if (timestamps.size() == results[i].tokens.size()) {
|
||||
results[i].timestamps = std::move(timestamps);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("time stamp for batch: %d, %d vs %d", i,
|
||||
static_cast<int32_t>(results[i].tokens.size()),
|
||||
static_cast<int32_t>(timestamps.size()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,6 +65,8 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
|
||||
r.timestamps.push_back(time);
|
||||
}
|
||||
|
||||
r.words = std::move(src.words);
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
@@ -339,6 +339,20 @@ std::string OfflineRecognitionResult::AsJsonString() const {
|
||||
}
|
||||
sep = ", ";
|
||||
}
|
||||
os << "], ";
|
||||
|
||||
sep = "";
|
||||
|
||||
os << "\""
|
||||
<< "words"
|
||||
<< "\""
|
||||
<< ": ";
|
||||
os << "[";
|
||||
for (int32_t w : words) {
|
||||
os << sep << w;
|
||||
sep = ", ";
|
||||
}
|
||||
|
||||
os << "]";
|
||||
os << "}";
|
||||
|
||||
|
||||
@@ -30,6 +30,8 @@ struct OfflineRecognitionResult {
|
||||
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
|
||||
std::vector<float> timestamps;
|
||||
|
||||
std::vector<int32_t> words;
|
||||
|
||||
std::string AsJsonString() const;
|
||||
};
|
||||
|
||||
|
||||
@@ -22,8 +22,16 @@ struct OnlineCtcDecoderResult {
|
||||
/// The decoded token IDs
|
||||
std::vector<int64_t> tokens;
|
||||
|
||||
/// The decoded word IDs
|
||||
/// Note: tokens.size() is usually not equal to words.size()
|
||||
/// words is empty for greedy search decoding.
|
||||
/// it is not empty when an HLG graph is used.
|
||||
std::vector<int32_t> words;
|
||||
|
||||
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
|
||||
/// Note: The index is after subsampling
|
||||
///
|
||||
/// tokens.size() == timestamps.size()
|
||||
std::vector<int32_t> timestamps;
|
||||
|
||||
int32_t num_trailing_blanks = 0;
|
||||
|
||||
@@ -51,9 +51,9 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
|
||||
bool ok = decoder->GetBestPath(&fst_out);
|
||||
if (ok) {
|
||||
std::vector<int32_t> isymbols_out;
|
||||
std::vector<int32_t> osymbols_out_unused;
|
||||
ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out,
|
||||
&osymbols_out_unused, nullptr);
|
||||
std::vector<int32_t> osymbols_out;
|
||||
ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out, &osymbols_out,
|
||||
nullptr);
|
||||
std::vector<int64_t> tokens;
|
||||
tokens.reserve(isymbols_out.size());
|
||||
|
||||
@@ -83,6 +83,7 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
|
||||
}
|
||||
|
||||
result->tokens = std::move(tokens);
|
||||
result->words = std::move(osymbols_out);
|
||||
result->timestamps = std::move(timestamps);
|
||||
// no need to set frame_offset
|
||||
}
|
||||
|
||||
@@ -59,6 +59,7 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
|
||||
}
|
||||
|
||||
r.segment = segment;
|
||||
r.words = std::move(src.words);
|
||||
r.start_time = frames_since_start * frame_shift_ms / 1000.;
|
||||
|
||||
return r;
|
||||
|
||||
@@ -22,14 +22,16 @@ namespace sherpa_onnx {
|
||||
template <typename T>
|
||||
std::string VecToString(const std::vector<T> &vec, int32_t precision = 6) {
|
||||
std::ostringstream oss;
|
||||
oss << std::fixed << std::setprecision(precision);
|
||||
oss << "[ ";
|
||||
if (precision != 0) {
|
||||
oss << std::fixed << std::setprecision(precision);
|
||||
}
|
||||
oss << "[";
|
||||
std::string sep = "";
|
||||
for (const auto &item : vec) {
|
||||
oss << sep << item;
|
||||
sep = ", ";
|
||||
}
|
||||
oss << " ]";
|
||||
oss << "]";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -38,26 +40,29 @@ template <> // explicit specialization for T = std::string
|
||||
std::string VecToString<std::string>(const std::vector<std::string> &vec,
|
||||
int32_t) { // ignore 2nd arg
|
||||
std::ostringstream oss;
|
||||
oss << "[ ";
|
||||
oss << "[";
|
||||
std::string sep = "";
|
||||
for (const auto &item : vec) {
|
||||
oss << sep << "\"" << item << "\"";
|
||||
sep = ", ";
|
||||
}
|
||||
oss << " ]";
|
||||
oss << "]";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
std::string OnlineRecognizerResult::AsJsonString() const {
|
||||
std::ostringstream os;
|
||||
os << "{ ";
|
||||
os << "\"text\": " << "\"" << text << "\"" << ", ";
|
||||
os << "\"text\": "
|
||||
<< "\"" << text << "\""
|
||||
<< ", ";
|
||||
os << "\"tokens\": " << VecToString(tokens) << ", ";
|
||||
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
|
||||
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
|
||||
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
|
||||
os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
|
||||
os << "\"segment\": " << segment << ", ";
|
||||
os << "\"words\": " << VecToString(words, 0) << ", ";
|
||||
os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time
|
||||
<< ", ";
|
||||
os << "\"is_final\": " << (is_final ? "true" : "false");
|
||||
|
||||
@@ -47,6 +47,8 @@ struct OnlineRecognizerResult {
|
||||
/// log-domain scores from "hot-phrase" contextual boosting
|
||||
std::vector<float> context_scores;
|
||||
|
||||
std::vector<int32_t> words;
|
||||
|
||||
/// ID of this segment
|
||||
/// When an endpoint is detected, it is incremented
|
||||
int32_t segment = 0;
|
||||
|
||||
@@ -34,6 +34,8 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
|
||||
})
|
||||
.def_property_readonly("tokens",
|
||||
[](const PyClass &self) { return self.tokens; })
|
||||
.def_property_readonly("words",
|
||||
[](const PyClass &self) { return self.words; })
|
||||
.def_property_readonly(
|
||||
"timestamps", [](const PyClass &self) { return self.timestamps; });
|
||||
}
|
||||
|
||||
@@ -40,6 +40,9 @@ static void PybindOnlineRecognizerResult(py::module *m) {
|
||||
})
|
||||
.def_property_readonly(
|
||||
"segment", [](PyClass &self) -> int32_t { return self.segment; })
|
||||
.def_property_readonly(
|
||||
"words",
|
||||
[](PyClass &self) -> std::vector<int32_t> { return self.words; })
|
||||
.def_property_readonly(
|
||||
"is_final", [](PyClass &self) -> bool { return self.is_final; })
|
||||
.def("__str__", &PyClass::AsJsonString,
|
||||
|
||||
Reference in New Issue
Block a user