online-transducer: reset the encoder toghter with 2 previous output symbols (non-blank) (#2129)
* online-transducer: reset the encoder toghter with 2 previous output symbols (non-blank) - added `reset_encoder` boolean member into the OnlineRecognizerConfig class - by default the encoder is not reset * pybind11, adding empty symbols for disabled modules (tts, diarization) * reset_encoder, add default value (false) [pybind11]
This commit is contained in:
@@ -382,14 +382,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
}
|
||||
|
||||
// reset encoder states
|
||||
// s->SetStates(model_->GetEncoderInitStates());
|
||||
|
||||
auto r = decoder_->GetEmptyResult();
|
||||
auto last_result = s->GetResult();
|
||||
// if last result is not empty, then
|
||||
// truncate all last hyps and save as the context for next result
|
||||
|
||||
if (static_cast<int32_t>(last_result.tokens.size()) > context_size) {
|
||||
// if last result is not empty, then
|
||||
// truncate all last hyps and save as the 'ys' context for next result
|
||||
// (the encoder state buffers are kept)
|
||||
for (const auto &it : last_result.hyps) {
|
||||
auto h = it.second;
|
||||
r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size,
|
||||
@@ -399,6 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
|
||||
r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size,
|
||||
last_result.tokens.end());
|
||||
} else {
|
||||
if(config_.reset_encoder) {
|
||||
// reset encoder states, use blanks as 'ys' context
|
||||
s->SetStates(model_->GetEncoderInitStates());
|
||||
}
|
||||
}
|
||||
|
||||
// but reset all contextual biasing graph states to root
|
||||
|
||||
Reference in New Issue
Block a user