decoder for open vocabulary keyword spotting (#505)
* various fixes to ContextGraph to support open vocabulary keywords decoder * Add keyword spotter runtime * Add binary * First version works * Minor fixes * update text2token * default values * Add jni for kws * add kws android project * Minor fixes * Remove unused interface * Minor fixes * Add workflow * handle extra info in texts * Minor fixes * Add more comments * Fix ci * fix cpp style * Add input box in android demo so that users can specify their keywords * Fix cpp style * Fix comments * Minor fixes * Minor fixes * minor fixes * Minor fixes * Minor fixes * Add CI * Fix code style * cpplint * Fix comments * Fix error
This commit is contained in:
@@ -75,10 +75,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
if (encoder_out_shape[0] != result->size()) {
|
||||
fprintf(stderr,
|
||||
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
|
||||
static_cast<int32_t>(encoder_out_shape[0]),
|
||||
static_cast<int32_t>(result->size()));
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
|
||||
static_cast<int32_t>(encoder_out_shape[0]),
|
||||
static_cast<int32_t>(result->size()));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
@@ -119,8 +119,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
|
||||
cur_encoder_out =
|
||||
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
|
||||
Ort::Value logit = model_->RunJoiner(
|
||||
std::move(cur_encoder_out), View(&decoder_out));
|
||||
Ort::Value logit =
|
||||
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||
@@ -164,8 +164,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
|
||||
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
|
||||
context_state, new_token);
|
||||
context_score = context_res.first;
|
||||
new_hyp.context_state = context_res.second;
|
||||
context_score = std::get<0>(context_res);
|
||||
new_hyp.context_state = std::get<1>(context_res);
|
||||
}
|
||||
if (lm_) {
|
||||
lm_->ComputeLMScore(lm_scale_, &new_hyp);
|
||||
|
||||
Reference in New Issue
Block a user