Code refactoring (#74)

* Don't reset model state and feature extractor on endpointing

* Support passing decoding_method from the command line

* Add modified_beam_search to Python API

* Fix C API example

* Fix style issues
This commit is contained in:
Fangjun Kuang
2023-03-03 12:10:59 +08:00
committed by GitHub
parent c241f93c40
commit 7f72c13d9a
34 changed files with 744 additions and 374 deletions

View File

@@ -13,6 +13,29 @@
namespace sherpa_onnx {
// Overwrites rows of `decoder_out` with previously cached decoder outputs.
//
// @param hyps_num_split  Prefix offsets; entry i+1 minus entry i is the number
//                        of hypotheses that belong to stream i.
// @param results         Per-stream results; `results[i].decoder_out` is the
//                        cached value (may be empty/null).
// @param context_size    Not used by this function.
// @param decoder_out     Tensor updated in place. Each row holds `shape[1]`
//                        floats; rows are laid out stream after stream, one
//                        row per hypothesis.
//
// A stream's row is replaced only when it has at most one hypothesis and a
// cached decoder output is available; otherwise its rows are left untouched.
static void UseCachedDecoderOut(
    const std::vector<int32_t> &hyps_num_split,
    const std::vector<OnlineTransducerDecoderResult> &results,
    int32_t context_size, Ort::Value *decoder_out) {
  const std::vector<int64_t> out_shape =
      decoder_out->GetTensorTypeAndShapeInfo().GetShape();
  // Number of floats in a single row of decoder_out.
  const int64_t row_len = out_shape[1];

  float *write_ptr = decoder_out->GetTensorMutableData<float>();

  const int32_t num_streams = static_cast<int32_t>(results.size());
  for (int32_t s = 0; s != num_streams; ++s) {
    const int32_t hyp_count = hyps_num_split[s + 1] - hyps_num_split[s];
    const bool cache_missing = !results[s].decoder_out;

    if (hyp_count <= 1 && !cache_missing) {
      // Single hypothesis with a cached value: reuse it for this row.
      const float *cached = results[s].decoder_out.GetTensorData<float>();
      std::copy(cached, cached + row_len, write_ptr);
      write_ptr += row_len;
    } else {
      // Nothing to reuse: skip over all rows owned by this stream.
      write_ptr += hyp_count * row_len;
    }
  }
}
static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
const std::vector<int32_t> &hyps_num_split) {
std::vector<int64_t> cur_encoder_out_shape =
@@ -50,7 +73,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
int32_t context_size = model_->ContextSize();
int32_t blank_id = 0; // always 0
OnlineTransducerDecoderResult r;
std::vector<int32_t> blanks(context_size, blank_id);
std::vector<int64_t> blanks(context_size, blank_id);
Hypotheses blank_hyp({{blanks, 0}});
r.hyps = std::move(blank_hyp);
return r;
@@ -110,6 +133,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
Ort::Value decoder_input = model_->BuildDecoderInput(prev);
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
if (t == 0) {
UseCachedDecoderOut(hyps_num_split, *result, model_->ContextSize(),
&decoder_out);
}
Ort::Value cur_encoder_out =
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
@@ -147,8 +174,23 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
}
for (int32_t b = 0; b != batch_size; ++b) {
(*result)[b].hyps = std::move(cur[b]);
auto &hyps = cur[b];
auto best_hyp = hyps.GetMostProbable(true);
(*result)[b].hyps = std::move(hyps);
(*result)[b].tokens = std::move(best_hyp.ys);
(*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks;
}
}
// Refreshes the cached decoder output stored in `result`.
//
// When the token sequence is exactly as long as the model's context size, the
// cache is cleared (set to a null Ort::Value). Otherwise the decoder is run on
// an input built from `*result` and its output is stored in
// `result->decoder_out`.
//
// NOTE(review): a token count equal to the context size presumably means that
// only the initial blank context is present -- confirm against GetEmptyResult.
void OnlineTransducerModifiedBeamSearchDecoder::UpdateDecoderOut(
    OnlineTransducerDecoderResult *result) {
  const bool nothing_beyond_context =
      result->tokens.size() == model_->ContextSize();

  if (!nothing_beyond_context) {
    result->decoder_out =
        model_->RunDecoder(model_->BuildDecoderInput({*result}));
    return;
  }

  // No decoder output worth caching; drop whatever was stored before.
  result->decoder_out = Ort::Value{nullptr};
}
} // namespace sherpa_onnx