Add timestamps for streaming ASR. (#123)

This commit is contained in:
Fangjun Kuang
2023-04-19 16:02:37 +08:00
committed by GitHub
parent 4b5d2887cb
commit ad05f52666
11 changed files with 170 additions and 19 deletions

View File

@@ -87,6 +87,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
r->tokens = std::move(tokens);
r->timestamps = std::move(hyp.timestamps);
r->num_trailing_blanks = hyp.num_trailing_blanks;
}
@@ -148,6 +149,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
float *p_logit = logit.GetTensorMutableData<float>();
for (int32_t b = 0; b < batch_size; ++b) {
int32_t frame_offset = (*result)[b].frame_offset;
int32_t start = hyps_num_split[b];
int32_t end = hyps_num_split[b + 1];
LogSoftmax(p_logit, vocab_size, (end - start));
@@ -162,6 +164,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
Hypothesis new_hyp = prev[hyp_index];
if (new_token != 0) {
new_hyp.ys.push_back(new_token);
new_hyp.timestamps.push_back(t + frame_offset);
new_hyp.num_trailing_blanks = 0;
} else {
++new_hyp.num_trailing_blanks;
@@ -177,10 +180,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
for (int32_t b = 0; b != batch_size; ++b) {
auto &hyps = cur[b];
auto best_hyp = hyps.GetMostProbable(true);
auto &r = (*result)[b];
(*result)[b].hyps = std::move(hyps);
(*result)[b].tokens = std::move(best_hyp.ys);
(*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks;
r.hyps = std::move(hyps);
r.tokens = std::move(best_hyp.ys);
r.num_trailing_blanks = best_hyp.num_trailing_blanks;
r.frame_offset += num_frames;
}
}