Track token scores (#571)

* add export of per-token scores (ys, lm, context)

- for best path of the modified-beam-search decoding of transducer

* refactoring JSON export of OnlineRecognitionResult, extending pybind11 API of OnlineRecognitionResult

* export per-token scores also for greedy-search (online-transducer)

- export un-scaled lm_probs (modified-beam search, online-transducer)
- polishing

* fill lm_probs/context_scores only if LM/ContextGraph is present (make Result smaller)
This commit is contained in:
Karel Vesely
2024-02-28 23:28:45 +01:00
committed by GitHub
parent 85d59b5840
commit 38c072dcb2
11 changed files with 155 additions and 49 deletions

View File

@@ -18,56 +18,50 @@
namespace sherpa_onnx {
/// Helper for `OnlineRecognizerResult::AsJsonString()`
template<typename T>
std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
std::ostringstream oss;
oss << std::fixed << std::setprecision(precision);
oss << "[ ";
std::string sep = "";
for (const auto& item : vec) {
oss << sep << item;
sep = ", ";
}
oss << " ]";
return oss.str();
}
/// Helper for `OnlineRecognizerResult::AsJsonString()`
template<> // explicit specialization for T = std::string
std::string VecToString<std::string>(const std::vector<std::string>& vec,
int32_t) { // ignore 2nd arg
std::ostringstream oss;
oss << "[ ";
std::string sep = "";
for (const auto& item : vec) {
oss << sep << "\"" << item << "\"";
sep = ", ";
}
oss << " ]";
return oss.str();
}
std::string OnlineRecognizerResult::AsJsonString() const {
std::ostringstream os;
os << "{";
os << "\"is_final\":" << (is_final ? "true" : "false") << ", ";
os << "\"segment\":" << segment << ", ";
os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
<< ", ";
os << "\"text\""
<< ": ";
os << "\"" << text << "\""
<< ", ";
os << "\""
<< "timestamps"
<< "\""
<< ": ";
os << "[";
std::string sep = "";
for (auto t : timestamps) {
os << sep << std::fixed << std::setprecision(2) << t;
sep = ", ";
}
os << "], ";
os << "\""
<< "tokens"
<< "\""
<< ":";
os << "[";
sep = "";
auto oldFlags = os.flags();
for (const auto &t : tokens) {
if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {
const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());
os << sep << "\""
<< "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0])
<< ">"
<< "\"";
os.flags(oldFlags);
} else {
os << sep << "\"" << t << "\"";
}
sep = ", ";
}
os << "]";
os << "{ ";
os << "\"text\": " << "\"" << text << "\"" << ", ";
os << "\"tokens\": " << VecToString(tokens) << ", ";
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
os << "\"segment\": " << segment << ", ";
os << "\"start_time\": " << std::fixed << std::setprecision(2)
<< start_time << ", ";
os << "\"is_final\": " << (is_final ? "true" : "false");
os << "}";
return os.str();
}