Track token scores (#571)
* add export of per-token scores (ys, lm, context)
  - for the best path of the modified-beam-search decoding of the transducer
* refactoring JSON export of OnlineRecognitionResult, extending pybind11 API of OnlineRecognitionResult
* export per-token scores also for greedy-search (online-transducer)
  - export un-scaled lm_probs (modified-beam search, online-transducer)
  - polishing
* fill lm_probs/context_scores only if LM/ContextGraph is present (make Result smaller)
This commit is contained in:
@@ -28,7 +28,26 @@ static void PybindOnlineRecognizerResult(py::module *m) {
|
||||
[](PyClass &self) -> float { return self.start_time; })
|
||||
.def_property_readonly(
|
||||
"timestamps",
|
||||
[](PyClass &self) -> std::vector<float> { return self.timestamps; });
|
||||
[](PyClass &self) -> std::vector<float> { return self.timestamps; })
|
||||
.def_property_readonly(
|
||||
"ys_probs",
|
||||
[](PyClass &self) -> std::vector<float> { return self.ys_probs; })
|
||||
.def_property_readonly(
|
||||
"lm_probs",
|
||||
[](PyClass &self) -> std::vector<float> { return self.lm_probs; })
|
||||
.def_property_readonly(
|
||||
"context_scores",
|
||||
[](PyClass &self) -> std::vector<float> {
|
||||
return self.context_scores;
|
||||
})
|
||||
.def_property_readonly(
|
||||
"segment",
|
||||
[](PyClass &self) -> int32_t { return self.segment; })
|
||||
.def_property_readonly(
|
||||
"is_final",
|
||||
[](PyClass &self) -> bool { return self.is_final; })
|
||||
.def("as_json_string", &PyClass::AsJsonString,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
}
|
||||
|
||||
static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
|
||||
@@ -503,6 +503,9 @@ class OnlineRecognizer(object):
|
||||
def get_result(self, s: OnlineStream) -> str:
    """Return the recognized text for stream ``s``.

    Fetches the result for ``s`` from the wrapped recognizer and returns
    its ``text`` with leading and trailing whitespace stripped.
    """
    result = self.recognizer.get_result(s)
    text = result.text
    return text.strip()
|
||||
|
||||
def get_result_as_json_string(self, s: OnlineStream) -> str:
    """Return the result for stream ``s`` serialized as a JSON string.

    Delegates to ``as_json_string()`` on the result object obtained from
    the wrapped recognizer.
    """
    result = self.recognizer.get_result(s)
    return result.as_json_string()
|
||||
|
||||
def tokens(self, s: OnlineStream) -> List[str]:
    """Return the ``tokens`` of the recognizer's result for stream ``s``."""
    result = self.recognizer.get_result(s)
    return result.tokens
|
||||
|
||||
@@ -512,6 +515,15 @@ class OnlineRecognizer(object):
|
||||
def start_time(self, s: OnlineStream) -> float:
    """Return the ``start_time`` of the recognizer's result for stream ``s``."""
    result = self.recognizer.get_result(s)
    return result.start_time
|
||||
|
||||
def ys_probs(self, s: OnlineStream) -> List[float]:
    """Return the ``ys_probs`` of the recognizer's result for stream ``s``.

    NOTE(review): presumably one score per emitted token -- confirm against
    the decoder implementation.
    """
    result = self.recognizer.get_result(s)
    return result.ys_probs
|
||||
|
||||
def lm_probs(self, s: OnlineStream) -> List[float]:
    """Return the ``lm_probs`` of the recognizer's result for stream ``s``.

    NOTE(review): presumably empty when no LM is configured -- confirm
    against the decoder implementation.
    """
    result = self.recognizer.get_result(s)
    return result.lm_probs
|
||||
|
||||
def context_scores(self, s: OnlineStream) -> List[float]:
    """Return the ``context_scores`` of the recognizer's result for stream ``s``.

    NOTE(review): presumably empty when no context graph is in use --
    confirm against the decoder implementation.
    """
    result = self.recognizer.get_result(s)
    return result.context_scores
|
||||
|
||||
def is_endpoint(self, s: OnlineStream) -> bool:
    """Return the wrapped recognizer's ``is_endpoint`` verdict for stream ``s``."""
    endpoint_detected = self.recognizer.is_endpoint(s)
    return endpoint_detected
|
||||
|
||||
|
||||
Reference in New Issue
Block a user