Track token scores (#571)
* Add export of per-token scores (ys, lm, context) for the best path of the modified-beam-search decoding of the transducer.
* Refactor the JSON export of OnlineRecognizerResult and extend the pybind11 API of OnlineRecognizerResult.
* Export per-token scores for greedy search as well (online transducer).
* Export un-scaled lm_probs (modified-beam search, online transducer).
* Polishing.
* Fill lm_probs/context_scores only if an LM/ContextGraph is present (makes the Result smaller).
This commit is contained in:
@@ -18,56 +18,50 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/// Helper for `OnlineRecognizerResult::AsJsonString()`
///
/// Formats the elements of `vec` as a bracketed, comma-separated list,
/// e.g. `[ 1.50, 2.00 ]`. An empty vector yields `[  ]`.
///
/// @param vec        Elements to format; each one is written with `operator<<`.
/// @param precision  Number of digits after the decimal point. The stream is
///                   put into `std::fixed` mode, so this only affects
///                   floating-point element types.
/// @return The formatted list.
template<typename T>
std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
  std::ostringstream oss;
  oss << std::fixed << std::setprecision(precision);
  oss << "[ ";
  std::string sep = "";
  for (const auto& item : vec) {
    oss << sep << item;
    sep = ", ";  // every element after the first is preceded by a separator
  }
  oss << " ]";
  return oss.str();
}

/// Helper for `OnlineRecognizerResult::AsJsonString()`
///
/// Explicit specialization for T = std::string. Each element is written as a
/// double-quoted string. Embedded `"` and `\` characters are backslash-escaped
/// (via `std::quoted`) so that such an element cannot terminate the string
/// literal early and corrupt the surrounding JSON. Other characters, including
/// control characters, are written through unchanged.
///
/// @param vec  Strings to format.
/// @return The formatted list, e.g. `[ "a", "b" ]`.
template<>  // explicit specialization for T = std::string
std::string VecToString<std::string>(const std::vector<std::string>& vec,
                                     int32_t) {  // ignore 2nd arg
  std::ostringstream oss;
  oss << "[ ";
  std::string sep = "";
  for (const auto& item : vec) {
    oss << sep << std::quoted(item);
    sep = ", ";
  }
  oss << " ]";
  return oss.str();
}
|
||||
|
||||
std::string OnlineRecognizerResult::AsJsonString() const {
|
||||
std::ostringstream os;
|
||||
os << "{";
|
||||
os << "\"is_final\":" << (is_final ? "true" : "false") << ", ";
|
||||
os << "\"segment\":" << segment << ", ";
|
||||
os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
|
||||
<< ", ";
|
||||
|
||||
os << "\"text\""
|
||||
<< ": ";
|
||||
os << "\"" << text << "\""
|
||||
<< ", ";
|
||||
|
||||
os << "\""
|
||||
<< "timestamps"
|
||||
<< "\""
|
||||
<< ": ";
|
||||
os << "[";
|
||||
|
||||
std::string sep = "";
|
||||
for (auto t : timestamps) {
|
||||
os << sep << std::fixed << std::setprecision(2) << t;
|
||||
sep = ", ";
|
||||
}
|
||||
os << "], ";
|
||||
|
||||
os << "\""
|
||||
<< "tokens"
|
||||
<< "\""
|
||||
<< ":";
|
||||
os << "[";
|
||||
|
||||
sep = "";
|
||||
auto oldFlags = os.flags();
|
||||
for (const auto &t : tokens) {
|
||||
if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());
|
||||
os << sep << "\""
|
||||
<< "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0])
|
||||
<< ">"
|
||||
<< "\"";
|
||||
os.flags(oldFlags);
|
||||
} else {
|
||||
os << sep << "\"" << t << "\"";
|
||||
}
|
||||
sep = ", ";
|
||||
}
|
||||
os << "]";
|
||||
os << "{ ";
|
||||
os << "\"text\": " << "\"" << text << "\"" << ", ";
|
||||
os << "\"tokens\": " << VecToString(tokens) << ", ";
|
||||
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
|
||||
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
|
||||
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
|
||||
os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
|
||||
os << "\"segment\": " << segment << ", ";
|
||||
os << "\"start_time\": " << std::fixed << std::setprecision(2)
|
||||
<< start_time << ", ";
|
||||
os << "\"is_final\": " << (is_final ? "true" : "false");
|
||||
os << "}";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user