Preserve more context after endpointing in transducer (#2061)
This commit is contained in:
committed by
GitHub
parent
da4aad1189
commit
18a6ed5ddc
@@ -388,16 +388,20 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
auto r = decoder_->GetEmptyResult();
|
auto r = decoder_->GetEmptyResult();
|
||||||
auto last_result = s->GetResult();
|
auto last_result = s->GetResult();
|
||||||
// if last result is not empty, then
|
// if last result is not empty, then
|
||||||
// preserve last tokens as the context for next result
|
// truncate all last hyps and save as the context for next result
|
||||||
if (static_cast<int32_t>(last_result.tokens.size()) > context_size) {
|
if (static_cast<int32_t>(last_result.tokens.size()) > context_size) {
|
||||||
std::vector<int64_t> context(last_result.tokens.end() - context_size,
|
for (const auto &it : last_result.hyps) {
|
||||||
last_result.tokens.end());
|
auto h = it.second;
|
||||||
|
r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size,
|
||||||
|
h.ys.end()),
|
||||||
|
h.log_prob});
|
||||||
|
}
|
||||||
|
|
||||||
Hypotheses context_hyp({{context, 0}});
|
r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size,
|
||||||
r.hyps = std::move(context_hyp);
|
last_result.tokens.end());
|
||||||
r.tokens = std::move(context);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// but reset all contextual biasing graph states to root
|
||||||
if (config_.decoding_method == "modified_beam_search" &&
|
if (config_.decoding_method == "modified_beam_search" &&
|
||||||
nullptr != s->GetContextGraph()) {
|
nullptr != s->GetContextGraph()) {
|
||||||
for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
|
for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
|
||||||
|
|||||||
Reference in New Issue
Block a user