Track token scores (#571)
* add export of per-token scores (ys, lm, context) - for best path of the modified-beam-search decoding of transducer * refactoring JSON export of OnlineRecognitionResult, extending pybind11 API of OnlineRecognitionResult * export per-token scores also for greedy-search (online-transducer) - export un-scaled lm_probs (modified-beam search, online-transducer) - polishing * fill lm_probs/context_scores only if LM/ContextGraph is present (make Result smaller)
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,6 +1,7 @@
|
|||||||
build
|
build
|
||||||
*.zip
|
*.zip
|
||||||
*.tgz
|
*.tgz
|
||||||
|
*.sw?
|
||||||
onnxruntime-*
|
onnxruntime-*
|
||||||
icefall-*
|
icefall-*
|
||||||
run.sh
|
run.sh
|
||||||
|
|||||||
@@ -29,9 +29,21 @@ struct Hypothesis {
|
|||||||
std::vector<int32_t> timestamps;
|
std::vector<int32_t> timestamps;
|
||||||
|
|
||||||
// The acoustic probability for each token in ys.
|
// The acoustic probability for each token in ys.
|
||||||
// Only used for keyword spotting task.
|
// Used for keyword spotting task.
|
||||||
|
// For transducer modified beam-search and greedy-search,
|
||||||
|
// this is filled with log_posterior scores.
|
||||||
std::vector<float> ys_probs;
|
std::vector<float> ys_probs;
|
||||||
|
|
||||||
|
// lm_probs[i] contains the lm score for each token in ys.
|
||||||
|
// Used only in transducer modified beam-search.
|
||||||
|
// Elements filled only if LM is used.
|
||||||
|
std::vector<float> lm_probs;
|
||||||
|
|
||||||
|
// context_scores[i] contains the context-graph score for each token in ys.
|
||||||
|
// Used only in transducer modified beam-search.
|
||||||
|
// Elements filled only if `ContextGraph` is used.
|
||||||
|
std::vector<float> context_scores;
|
||||||
|
|
||||||
// The total score of ys in log space.
|
// The total score of ys in log space.
|
||||||
// It contains only acoustic scores
|
// It contains only acoustic scores
|
||||||
double log_prob = 0;
|
double log_prob = 0;
|
||||||
|
|||||||
@@ -69,6 +69,10 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
|||||||
r.timestamps.push_back(time);
|
r.timestamps.push_back(time);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.ys_probs = std::move(src.ys_probs);
|
||||||
|
r.lm_probs = std::move(src.lm_probs);
|
||||||
|
r.context_scores = std::move(src.context_scores);
|
||||||
|
|
||||||
r.segment = segment;
|
r.segment = segment;
|
||||||
r.start_time = frames_since_start * frame_shift_ms / 1000.;
|
r.start_time = frames_since_start * frame_shift_ms / 1000.;
|
||||||
|
|
||||||
|
|||||||
@@ -18,56 +18,50 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
/// Helper for `OnlineRecognizerResult::AsJsonString()`
|
||||||
|
template<typename T>
|
||||||
|
std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << std::fixed << std::setprecision(precision);
|
||||||
|
oss << "[ ";
|
||||||
|
std::string sep = "";
|
||||||
|
for (const auto& item : vec) {
|
||||||
|
oss << sep << item;
|
||||||
|
sep = ", ";
|
||||||
|
}
|
||||||
|
oss << " ]";
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper for `OnlineRecognizerResult::AsJsonString()`
|
||||||
|
template<> // explicit specialization for T = std::string
|
||||||
|
std::string VecToString<std::string>(const std::vector<std::string>& vec,
|
||||||
|
int32_t) { // ignore 2nd arg
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "[ ";
|
||||||
|
std::string sep = "";
|
||||||
|
for (const auto& item : vec) {
|
||||||
|
oss << sep << "\"" << item << "\"";
|
||||||
|
sep = ", ";
|
||||||
|
}
|
||||||
|
oss << " ]";
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
std::string OnlineRecognizerResult::AsJsonString() const {
|
std::string OnlineRecognizerResult::AsJsonString() const {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
os << "{";
|
os << "{ ";
|
||||||
os << "\"is_final\":" << (is_final ? "true" : "false") << ", ";
|
os << "\"text\": " << "\"" << text << "\"" << ", ";
|
||||||
os << "\"segment\":" << segment << ", ";
|
os << "\"tokens\": " << VecToString(tokens) << ", ";
|
||||||
os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
|
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
|
||||||
<< ", ";
|
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
|
||||||
|
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
|
||||||
os << "\"text\""
|
os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
|
||||||
<< ": ";
|
os << "\"segment\": " << segment << ", ";
|
||||||
os << "\"" << text << "\""
|
os << "\"start_time\": " << std::fixed << std::setprecision(2)
|
||||||
<< ", ";
|
<< start_time << ", ";
|
||||||
|
os << "\"is_final\": " << (is_final ? "true" : "false");
|
||||||
os << "\""
|
|
||||||
<< "timestamps"
|
|
||||||
<< "\""
|
|
||||||
<< ": ";
|
|
||||||
os << "[";
|
|
||||||
|
|
||||||
std::string sep = "";
|
|
||||||
for (auto t : timestamps) {
|
|
||||||
os << sep << std::fixed << std::setprecision(2) << t;
|
|
||||||
sep = ", ";
|
|
||||||
}
|
|
||||||
os << "], ";
|
|
||||||
|
|
||||||
os << "\""
|
|
||||||
<< "tokens"
|
|
||||||
<< "\""
|
|
||||||
<< ":";
|
|
||||||
os << "[";
|
|
||||||
|
|
||||||
sep = "";
|
|
||||||
auto oldFlags = os.flags();
|
|
||||||
for (const auto &t : tokens) {
|
|
||||||
if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {
|
|
||||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());
|
|
||||||
os << sep << "\""
|
|
||||||
<< "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0])
|
|
||||||
<< ">"
|
|
||||||
<< "\"";
|
|
||||||
os.flags(oldFlags);
|
|
||||||
} else {
|
|
||||||
os << sep << "\"" << t << "\"";
|
|
||||||
}
|
|
||||||
sep = ", ";
|
|
||||||
}
|
|
||||||
os << "]";
|
|
||||||
os << "}";
|
os << "}";
|
||||||
|
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,12 @@ struct OnlineRecognizerResult {
|
|||||||
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
|
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
|
||||||
std::vector<float> timestamps;
|
std::vector<float> timestamps;
|
||||||
|
|
||||||
|
std::vector<float> ys_probs; ///< log-prob scores from ASR model
|
||||||
|
std::vector<float> lm_probs; ///< log-prob scores from language model
|
||||||
|
//
|
||||||
|
/// log-domain scores from "hot-phrase" contextual boosting
|
||||||
|
std::vector<float> context_scores;
|
||||||
|
|
||||||
/// ID of this segment
|
/// ID of this segment
|
||||||
/// When an endpoint is detected, it is incremented
|
/// When an endpoint is detected, it is incremented
|
||||||
int32_t segment = 0;
|
int32_t segment = 0;
|
||||||
@@ -58,6 +64,9 @@ struct OnlineRecognizerResult {
|
|||||||
* "text": "The recognition result",
|
* "text": "The recognition result",
|
||||||
* "tokens": [x, x, x],
|
* "tokens": [x, x, x],
|
||||||
* "timestamps": [x, x, x],
|
* "timestamps": [x, x, x],
|
||||||
|
* "ys_probs": [x, x, x],
|
||||||
|
* "lm_probs": [x, x, x],
|
||||||
|
* "context_scores": [x, x, x],
|
||||||
* "segment": x,
|
* "segment": x,
|
||||||
* "start_time": x,
|
* "start_time": x,
|
||||||
* "is_final": true|false
|
* "is_final": true|false
|
||||||
|
|||||||
@@ -37,6 +37,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
|
|||||||
frame_offset = other.frame_offset;
|
frame_offset = other.frame_offset;
|
||||||
timestamps = other.timestamps;
|
timestamps = other.timestamps;
|
||||||
|
|
||||||
|
ys_probs = other.ys_probs;
|
||||||
|
lm_probs = other.lm_probs;
|
||||||
|
context_scores = other.context_scores;
|
||||||
|
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,6 +64,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
|
|||||||
frame_offset = other.frame_offset;
|
frame_offset = other.frame_offset;
|
||||||
timestamps = std::move(other.timestamps);
|
timestamps = std::move(other.timestamps);
|
||||||
|
|
||||||
|
ys_probs = std::move(other.ys_probs);
|
||||||
|
lm_probs = std::move(other.lm_probs);
|
||||||
|
context_scores = std::move(other.context_scores);
|
||||||
|
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,10 @@ struct OnlineTransducerDecoderResult {
|
|||||||
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
|
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
|
||||||
std::vector<int32_t> timestamps;
|
std::vector<int32_t> timestamps;
|
||||||
|
|
||||||
|
std::vector<float> ys_probs;
|
||||||
|
std::vector<float> lm_probs;
|
||||||
|
std::vector<float> context_scores;
|
||||||
|
|
||||||
// Cache decoder_out for endpointing
|
// Cache decoder_out for endpointing
|
||||||
Ort::Value decoder_out;
|
Ort::Value decoder_out;
|
||||||
|
|
||||||
|
|||||||
@@ -71,9 +71,11 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
|
|||||||
r->tokens = std::vector<int64_t>(start, end);
|
r->tokens = std::vector<int64_t>(start, end);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void OnlineTransducerGreedySearchDecoder::Decode(
|
void OnlineTransducerGreedySearchDecoder::Decode(
|
||||||
Ort::Value encoder_out,
|
Ort::Value encoder_out,
|
||||||
std::vector<OnlineTransducerDecoderResult> *result) {
|
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||||
|
|
||||||
std::vector<int64_t> encoder_out_shape =
|
std::vector<int64_t> encoder_out_shape =
|
||||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
|
||||||
@@ -97,6 +99,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (is_batch_decoder_out_cached) {
|
if (is_batch_decoder_out_cached) {
|
||||||
auto &r = result->front();
|
auto &r = result->front();
|
||||||
std::vector<int64_t> decoder_out_shape =
|
std::vector<int64_t> decoder_out_shape =
|
||||||
@@ -124,6 +127,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
|||||||
if (blank_penalty_ > 0.0) {
|
if (blank_penalty_ > 0.0) {
|
||||||
p_logit[0] -= blank_penalty_; // assuming blank id is 0
|
p_logit[0] -= blank_penalty_; // assuming blank id is 0
|
||||||
}
|
}
|
||||||
|
|
||||||
auto y = static_cast<int32_t>(std::distance(
|
auto y = static_cast<int32_t>(std::distance(
|
||||||
static_cast<const float *>(p_logit),
|
static_cast<const float *>(p_logit),
|
||||||
std::max_element(static_cast<const float *>(p_logit),
|
std::max_element(static_cast<const float *>(p_logit),
|
||||||
@@ -138,6 +142,17 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
|||||||
} else {
|
} else {
|
||||||
++r.num_trailing_blanks;
|
++r.num_trailing_blanks;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// export the per-token log scores
|
||||||
|
if (y != 0 && y != unk_id_) {
|
||||||
|
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
|
||||||
|
// save time by doing it only for
|
||||||
|
// emitted symbols
|
||||||
|
const float *p_logprob = p_logit; // rename p_logit as p_logprob,
|
||||||
|
// now it contains normalized
|
||||||
|
// probability
|
||||||
|
r.ys_probs.push_back(p_logprob[y]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (emitted) {
|
if (emitted) {
|
||||||
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
|
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
|
||||||
|
|||||||
@@ -59,6 +59,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
|
|||||||
std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
|
std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
|
||||||
r->tokens = std::move(tokens);
|
r->tokens = std::move(tokens);
|
||||||
r->timestamps = std::move(hyp.timestamps);
|
r->timestamps = std::move(hyp.timestamps);
|
||||||
|
|
||||||
|
// export per-token scores
|
||||||
|
r->ys_probs = std::move(hyp.ys_probs);
|
||||||
|
r->lm_probs = std::move(hyp.lm_probs);
|
||||||
|
r->context_scores = std::move(hyp.context_scores);
|
||||||
|
|
||||||
r->num_trailing_blanks = hyp.num_trailing_blanks;
|
r->num_trailing_blanks = hyp.num_trailing_blanks;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,6 +186,28 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
new_hyp.log_prob = p_logprob[k] + context_score -
|
new_hyp.log_prob = p_logprob[k] + context_score -
|
||||||
prev_lm_log_prob; // log_prob only includes the
|
prev_lm_log_prob; // log_prob only includes the
|
||||||
// score of the transducer
|
// score of the transducer
|
||||||
|
// export the per-token log scores
|
||||||
|
if (new_token != 0 && new_token != unk_id_) {
|
||||||
|
const Hypothesis& prev_i = prev[hyp_index];
|
||||||
|
// subtract 'prev[i]' path scores, which were added before
|
||||||
|
// getting topk tokens
|
||||||
|
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
|
||||||
|
new_hyp.ys_probs.push_back(y_prob);
|
||||||
|
|
||||||
|
if (lm_) { // export only when LM is used
|
||||||
|
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
|
||||||
|
if (lm_scale_ != 0.0) {
|
||||||
|
lm_prob /= lm_scale_; // remove lm-scale
|
||||||
|
}
|
||||||
|
new_hyp.lm_probs.push_back(lm_prob);
|
||||||
|
}
|
||||||
|
|
||||||
|
// export only when `ContextGraph` is used
|
||||||
|
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
|
||||||
|
new_hyp.context_scores.push_back(context_score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
hyps.Add(std::move(new_hyp));
|
hyps.Add(std::move(new_hyp));
|
||||||
} // for (auto k : topk)
|
} // for (auto k : topk)
|
||||||
cur.push_back(std::move(hyps));
|
cur.push_back(std::move(hyps));
|
||||||
|
|||||||
@@ -28,7 +28,26 @@ static void PybindOnlineRecognizerResult(py::module *m) {
|
|||||||
[](PyClass &self) -> float { return self.start_time; })
|
[](PyClass &self) -> float { return self.start_time; })
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"timestamps",
|
"timestamps",
|
||||||
[](PyClass &self) -> std::vector<float> { return self.timestamps; });
|
[](PyClass &self) -> std::vector<float> { return self.timestamps; })
|
||||||
|
.def_property_readonly(
|
||||||
|
"ys_probs",
|
||||||
|
[](PyClass &self) -> std::vector<float> { return self.ys_probs; })
|
||||||
|
.def_property_readonly(
|
||||||
|
"lm_probs",
|
||||||
|
[](PyClass &self) -> std::vector<float> { return self.lm_probs; })
|
||||||
|
.def_property_readonly(
|
||||||
|
"context_scores",
|
||||||
|
[](PyClass &self) -> std::vector<float> {
|
||||||
|
return self.context_scores;
|
||||||
|
})
|
||||||
|
.def_property_readonly(
|
||||||
|
"segment",
|
||||||
|
[](PyClass &self) -> int32_t { return self.segment; })
|
||||||
|
.def_property_readonly(
|
||||||
|
"is_final",
|
||||||
|
[](PyClass &self) -> bool { return self.is_final; })
|
||||||
|
.def("as_json_string", &PyClass::AsJsonString,
|
||||||
|
py::call_guard<py::gil_scoped_release>());
|
||||||
}
|
}
|
||||||
|
|
||||||
static void PybindOnlineRecognizerConfig(py::module *m) {
|
static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||||
|
|||||||
@@ -503,6 +503,9 @@ class OnlineRecognizer(object):
|
|||||||
def get_result(self, s: OnlineStream) -> str:
|
def get_result(self, s: OnlineStream) -> str:
|
||||||
return self.recognizer.get_result(s).text.strip()
|
return self.recognizer.get_result(s).text.strip()
|
||||||
|
|
||||||
|
def get_result_as_json_string(self, s: OnlineStream) -> str:
|
||||||
|
return self.recognizer.get_result(s).as_json_string()
|
||||||
|
|
||||||
def tokens(self, s: OnlineStream) -> List[str]:
|
def tokens(self, s: OnlineStream) -> List[str]:
|
||||||
return self.recognizer.get_result(s).tokens
|
return self.recognizer.get_result(s).tokens
|
||||||
|
|
||||||
@@ -512,6 +515,15 @@ class OnlineRecognizer(object):
|
|||||||
def start_time(self, s: OnlineStream) -> float:
|
def start_time(self, s: OnlineStream) -> float:
|
||||||
return self.recognizer.get_result(s).start_time
|
return self.recognizer.get_result(s).start_time
|
||||||
|
|
||||||
|
def ys_probs(self, s: OnlineStream) -> List[float]:
|
||||||
|
return self.recognizer.get_result(s).ys_probs
|
||||||
|
|
||||||
|
def lm_probs(self, s: OnlineStream) -> List[float]:
|
||||||
|
return self.recognizer.get_result(s).lm_probs
|
||||||
|
|
||||||
|
def context_scores(self, s: OnlineStream) -> List[float]:
|
||||||
|
return self.recognizer.get_result(s).context_scores
|
||||||
|
|
||||||
def is_endpoint(self, s: OnlineStream) -> bool:
|
def is_endpoint(self, s: OnlineStream) -> bool:
|
||||||
return self.recognizer.is_endpoint(s)
|
return self.recognizer.is_endpoint(s)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user