Track token scores (#571)
* add export of per-token scores (ys, lm, context) - for best path of the modified-beam-search decoding of transducer * refactoring JSON export of OnlineRecognitionResult, extending pybind11 API of OnlineRecognitionResult * export per-token scores also for greedy-search (online-transducer) - export un-scaled lm_probs (modified-beam search, online-transducer) - polishing * fill lm_probs/context_scores only if LM/ContextGraph is present (make Result smaller)
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,6 +1,7 @@
|
||||
build
|
||||
*.zip
|
||||
*.tgz
|
||||
*.sw?
|
||||
onnxruntime-*
|
||||
icefall-*
|
||||
run.sh
|
||||
|
||||
@@ -29,9 +29,21 @@ struct Hypothesis {
|
||||
std::vector<int32_t> timestamps;
|
||||
|
||||
// The acoustic probability for each token in ys.
|
||||
// Only used for keyword spotting task.
|
||||
// Used for keyword spotting task.
|
||||
// For transducer mofified beam-search and greedy-search,
|
||||
// this is filled with log_posterior scores.
|
||||
std::vector<float> ys_probs;
|
||||
|
||||
// lm_probs[i] contains the lm score for each token in ys.
|
||||
// Used only in transducer mofified beam-search.
|
||||
// Elements filled only if LM is used.
|
||||
std::vector<float> lm_probs;
|
||||
|
||||
// context_scores[i] contains the context-graph score for each token in ys.
|
||||
// Used only in transducer mofified beam-search.
|
||||
// Elements filled only if `ContextGraph` is used.
|
||||
std::vector<float> context_scores;
|
||||
|
||||
// The total score of ys in log space.
|
||||
// It contains only acoustic scores
|
||||
double log_prob = 0;
|
||||
|
||||
@@ -69,6 +69,10 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
||||
r.timestamps.push_back(time);
|
||||
}
|
||||
|
||||
r.ys_probs = std::move(src.ys_probs);
|
||||
r.lm_probs = std::move(src.lm_probs);
|
||||
r.context_scores = std::move(src.context_scores);
|
||||
|
||||
r.segment = segment;
|
||||
r.start_time = frames_since_start * frame_shift_ms / 1000.;
|
||||
|
||||
|
||||
@@ -18,56 +18,50 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/// Helper for `OnlineRecognizerResult::AsJsonString()`
|
||||
template<typename T>
|
||||
std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
|
||||
std::ostringstream oss;
|
||||
oss << std::fixed << std::setprecision(precision);
|
||||
oss << "[ ";
|
||||
std::string sep = "";
|
||||
for (const auto& item : vec) {
|
||||
oss << sep << item;
|
||||
sep = ", ";
|
||||
}
|
||||
oss << " ]";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
/// Helper for `OnlineRecognizerResult::AsJsonString()`
|
||||
template<> // explicit specialization for T = std::string
|
||||
std::string VecToString<std::string>(const std::vector<std::string>& vec,
|
||||
int32_t) { // ignore 2nd arg
|
||||
std::ostringstream oss;
|
||||
oss << "[ ";
|
||||
std::string sep = "";
|
||||
for (const auto& item : vec) {
|
||||
oss << sep << "\"" << item << "\"";
|
||||
sep = ", ";
|
||||
}
|
||||
oss << " ]";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
std::string OnlineRecognizerResult::AsJsonString() const {
|
||||
std::ostringstream os;
|
||||
os << "{";
|
||||
os << "\"is_final\":" << (is_final ? "true" : "false") << ", ";
|
||||
os << "\"segment\":" << segment << ", ";
|
||||
os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
|
||||
<< ", ";
|
||||
|
||||
os << "\"text\""
|
||||
<< ": ";
|
||||
os << "\"" << text << "\""
|
||||
<< ", ";
|
||||
|
||||
os << "\""
|
||||
<< "timestamps"
|
||||
<< "\""
|
||||
<< ": ";
|
||||
os << "[";
|
||||
|
||||
std::string sep = "";
|
||||
for (auto t : timestamps) {
|
||||
os << sep << std::fixed << std::setprecision(2) << t;
|
||||
sep = ", ";
|
||||
}
|
||||
os << "], ";
|
||||
|
||||
os << "\""
|
||||
<< "tokens"
|
||||
<< "\""
|
||||
<< ":";
|
||||
os << "[";
|
||||
|
||||
sep = "";
|
||||
auto oldFlags = os.flags();
|
||||
for (const auto &t : tokens) {
|
||||
if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());
|
||||
os << sep << "\""
|
||||
<< "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0])
|
||||
<< ">"
|
||||
<< "\"";
|
||||
os.flags(oldFlags);
|
||||
} else {
|
||||
os << sep << "\"" << t << "\"";
|
||||
}
|
||||
sep = ", ";
|
||||
}
|
||||
os << "]";
|
||||
os << "{ ";
|
||||
os << "\"text\": " << "\"" << text << "\"" << ", ";
|
||||
os << "\"tokens\": " << VecToString(tokens) << ", ";
|
||||
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
|
||||
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
|
||||
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
|
||||
os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
|
||||
os << "\"segment\": " << segment << ", ";
|
||||
os << "\"start_time\": " << std::fixed << std::setprecision(2)
|
||||
<< start_time << ", ";
|
||||
os << "\"is_final\": " << (is_final ? "true" : "false");
|
||||
os << "}";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
|
||||
@@ -40,6 +40,12 @@ struct OnlineRecognizerResult {
|
||||
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
|
||||
std::vector<float> timestamps;
|
||||
|
||||
std::vector<float> ys_probs; //< log-prob scores from ASR model
|
||||
std::vector<float> lm_probs; //< log-prob scores from language model
|
||||
//
|
||||
/// log-domain scores from "hot-phrase" contextual boosting
|
||||
std::vector<float> context_scores;
|
||||
|
||||
/// ID of this segment
|
||||
/// When an endpoint is detected, it is incremented
|
||||
int32_t segment = 0;
|
||||
@@ -58,6 +64,9 @@ struct OnlineRecognizerResult {
|
||||
* "text": "The recognition result",
|
||||
* "tokens": [x, x, x],
|
||||
* "timestamps": [x, x, x],
|
||||
* "ys_probs": [x, x, x],
|
||||
* "lm_probs": [x, x, x],
|
||||
* "context_scores": [x, x, x],
|
||||
* "segment": x,
|
||||
* "start_time": x,
|
||||
* "is_final": true|false
|
||||
|
||||
@@ -37,6 +37,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
|
||||
frame_offset = other.frame_offset;
|
||||
timestamps = other.timestamps;
|
||||
|
||||
ys_probs = other.ys_probs;
|
||||
lm_probs = other.lm_probs;
|
||||
context_scores = other.context_scores;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
@@ -60,6 +64,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
|
||||
frame_offset = other.frame_offset;
|
||||
timestamps = std::move(other.timestamps);
|
||||
|
||||
ys_probs = std::move(other.ys_probs);
|
||||
lm_probs = std::move(other.lm_probs);
|
||||
context_scores = std::move(other.context_scores);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,10 @@ struct OnlineTransducerDecoderResult {
|
||||
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
|
||||
std::vector<int32_t> timestamps;
|
||||
|
||||
std::vector<float> ys_probs;
|
||||
std::vector<float> lm_probs;
|
||||
std::vector<float> context_scores;
|
||||
|
||||
// Cache decoder_out for endpointing
|
||||
Ort::Value decoder_out;
|
||||
|
||||
|
||||
@@ -71,9 +71,11 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
|
||||
r->tokens = std::vector<int64_t>(start, end);
|
||||
}
|
||||
|
||||
|
||||
void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
Ort::Value encoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||
|
||||
std::vector<int64_t> encoder_out_shape =
|
||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
@@ -97,6 +99,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_batch_decoder_out_cached) {
|
||||
auto &r = result->front();
|
||||
std::vector<int64_t> decoder_out_shape =
|
||||
@@ -124,6 +127,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
if (blank_penalty_ > 0.0) {
|
||||
p_logit[0] -= blank_penalty_; // assuming blank id is 0
|
||||
}
|
||||
|
||||
auto y = static_cast<int32_t>(std::distance(
|
||||
static_cast<const float *>(p_logit),
|
||||
std::max_element(static_cast<const float *>(p_logit),
|
||||
@@ -138,6 +142,17 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
} else {
|
||||
++r.num_trailing_blanks;
|
||||
}
|
||||
|
||||
// export the per-token log scores
|
||||
if (y != 0 && y != unk_id_) {
|
||||
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
|
||||
// save time by doing it only for
|
||||
// emitted symbols
|
||||
const float *p_logprob = p_logit; // rename p_logit as p_logprob,
|
||||
// now it contains normalized
|
||||
// probability
|
||||
r.ys_probs.push_back(p_logprob[y]);
|
||||
}
|
||||
}
|
||||
if (emitted) {
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
|
||||
|
||||
@@ -59,6 +59,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
|
||||
std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
|
||||
r->tokens = std::move(tokens);
|
||||
r->timestamps = std::move(hyp.timestamps);
|
||||
|
||||
// export per-token scores
|
||||
r->ys_probs = std::move(hyp.ys_probs);
|
||||
r->lm_probs = std::move(hyp.lm_probs);
|
||||
r->context_scores = std::move(hyp.context_scores);
|
||||
|
||||
r->num_trailing_blanks = hyp.num_trailing_blanks;
|
||||
}
|
||||
|
||||
@@ -180,6 +186,28 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
new_hyp.log_prob = p_logprob[k] + context_score -
|
||||
prev_lm_log_prob; // log_prob only includes the
|
||||
// score of the transducer
|
||||
// export the per-token log scores
|
||||
if (new_token != 0 && new_token != unk_id_) {
|
||||
const Hypothesis& prev_i = prev[hyp_index];
|
||||
// subtract 'prev[i]' path scores, which were added before
|
||||
// getting topk tokens
|
||||
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
|
||||
new_hyp.ys_probs.push_back(y_prob);
|
||||
|
||||
if (lm_) { // export only when LM is used
|
||||
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
|
||||
if (lm_scale_ != 0.0) {
|
||||
lm_prob /= lm_scale_; // remove lm-scale
|
||||
}
|
||||
new_hyp.lm_probs.push_back(lm_prob);
|
||||
}
|
||||
|
||||
// export only when `ContextGraph` is used
|
||||
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
|
||||
new_hyp.context_scores.push_back(context_score);
|
||||
}
|
||||
}
|
||||
|
||||
hyps.Add(std::move(new_hyp));
|
||||
} // for (auto k : topk)
|
||||
cur.push_back(std::move(hyps));
|
||||
|
||||
@@ -28,7 +28,26 @@ static void PybindOnlineRecognizerResult(py::module *m) {
|
||||
[](PyClass &self) -> float { return self.start_time; })
|
||||
.def_property_readonly(
|
||||
"timestamps",
|
||||
[](PyClass &self) -> std::vector<float> { return self.timestamps; });
|
||||
[](PyClass &self) -> std::vector<float> { return self.timestamps; })
|
||||
.def_property_readonly(
|
||||
"ys_probs",
|
||||
[](PyClass &self) -> std::vector<float> { return self.ys_probs; })
|
||||
.def_property_readonly(
|
||||
"lm_probs",
|
||||
[](PyClass &self) -> std::vector<float> { return self.lm_probs; })
|
||||
.def_property_readonly(
|
||||
"context_scores",
|
||||
[](PyClass &self) -> std::vector<float> {
|
||||
return self.context_scores;
|
||||
})
|
||||
.def_property_readonly(
|
||||
"segment",
|
||||
[](PyClass &self) -> int32_t { return self.segment; })
|
||||
.def_property_readonly(
|
||||
"is_final",
|
||||
[](PyClass &self) -> bool { return self.is_final; })
|
||||
.def("as_json_string", &PyClass::AsJsonString,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
}
|
||||
|
||||
static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
|
||||
@@ -503,6 +503,9 @@ class OnlineRecognizer(object):
|
||||
def get_result(self, s: OnlineStream) -> str:
|
||||
return self.recognizer.get_result(s).text.strip()
|
||||
|
||||
def get_result_as_json_string(self, s: OnlineStream) -> str:
|
||||
return self.recognizer.get_result(s).as_json_string()
|
||||
|
||||
def tokens(self, s: OnlineStream) -> List[str]:
|
||||
return self.recognizer.get_result(s).tokens
|
||||
|
||||
@@ -512,6 +515,15 @@ class OnlineRecognizer(object):
|
||||
def start_time(self, s: OnlineStream) -> float:
|
||||
return self.recognizer.get_result(s).start_time
|
||||
|
||||
def ys_probs(self, s: OnlineStream) -> List[float]:
|
||||
return self.recognizer.get_result(s).ys_probs
|
||||
|
||||
def lm_probs(self, s: OnlineStream) -> List[float]:
|
||||
return self.recognizer.get_result(s).lm_probs
|
||||
|
||||
def context_scores(self, s: OnlineStream) -> List[float]:
|
||||
return self.recognizer.get_result(s).context_scores
|
||||
|
||||
def is_endpoint(self, s: OnlineStream) -> bool:
|
||||
return self.recognizer.is_endpoint(s)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user