Code refactoring (#74)
* Don't reset the model state and feature extractor on endpointing
* Support passing decoding_method from the command line
* Add modified_beam_search to the Python API
* Fix the C API example
* Fix style issues
This commit is contained in:
@@ -13,6 +13,29 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static void UseCachedDecoderOut(
|
||||
const std::vector<int32_t> &hyps_num_split,
|
||||
const std::vector<OnlineTransducerDecoderResult> &results,
|
||||
int32_t context_size, Ort::Value *decoder_out) {
|
||||
std::vector<int64_t> shape =
|
||||
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
float *dst = decoder_out->GetTensorMutableData<float>();
|
||||
|
||||
int32_t batch_size = static_cast<int32_t>(results.size());
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
int32_t num_hyps = hyps_num_split[i + 1] - hyps_num_split[i];
|
||||
if (num_hyps > 1 || !results[i].decoder_out) {
|
||||
dst += num_hyps * shape[1];
|
||||
continue;
|
||||
}
|
||||
|
||||
const float *src = results[i].decoder_out.GetTensorData<float>();
|
||||
std::copy(src, src + shape[1], dst);
|
||||
dst += shape[1];
|
||||
}
|
||||
}
|
||||
|
||||
static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
|
||||
const std::vector<int32_t> &hyps_num_split) {
|
||||
std::vector<int64_t> cur_encoder_out_shape =
|
||||
@@ -50,7 +73,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
|
||||
int32_t context_size = model_->ContextSize();
|
||||
int32_t blank_id = 0; // always 0
|
||||
OnlineTransducerDecoderResult r;
|
||||
std::vector<int32_t> blanks(context_size, blank_id);
|
||||
std::vector<int64_t> blanks(context_size, blank_id);
|
||||
Hypotheses blank_hyp({{blanks, 0}});
|
||||
r.hyps = std::move(blank_hyp);
|
||||
return r;
|
||||
@@ -110,6 +133,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(prev);
|
||||
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
if (t == 0) {
|
||||
UseCachedDecoderOut(hyps_num_split, *result, model_->ContextSize(),
|
||||
&decoder_out);
|
||||
}
|
||||
|
||||
Ort::Value cur_encoder_out =
|
||||
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
|
||||
@@ -147,8 +174,23 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
}
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
(*result)[b].hyps = std::move(cur[b]);
|
||||
auto &hyps = cur[b];
|
||||
auto best_hyp = hyps.GetMostProbable(true);
|
||||
|
||||
(*result)[b].hyps = std::move(hyps);
|
||||
(*result)[b].tokens = std::move(best_hyp.ys);
|
||||
(*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks;
|
||||
}
|
||||
}
|
||||
|
||||
// Refreshes the decoder output cached in `result->decoder_out`.
//
// If `result->tokens` holds exactly ContextSize() entries, the cache is
// cleared (set to a null Ort::Value) and the decoder is not run.
// NOTE(review): presumably this is the state where only the initial blank
// context is present, so there is nothing worth caching -- confirm.
// Otherwise a decoder input is built from `*result` and the decoder output
// produced by the model is stored in `result->decoder_out`.
//
// @param result  Result to update in place; must not be null.
void OnlineTransducerModifiedBeamSearchDecoder::UpdateDecoderOut(
    OnlineTransducerDecoderResult *result) {
  const int32_t context_size = model_->ContextSize();

  // Compare as int32_t to avoid an implicit signed/unsigned comparison.
  if (static_cast<int32_t>(result->tokens.size()) == context_size) {
    result->decoder_out = Ort::Value{nullptr};
    return;
  }

  Ort::Value decoder_input = model_->BuildDecoderInput({*result});
  result->decoder_out = model_->RunDecoder(std::move(decoder_input));
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user