Support RKNN for Zipformer CTC models. (#1948)
This commit is contained in:
@@ -24,12 +24,11 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
|
||||
const SymbolTable &sym_table,
|
||||
float frame_shift_ms,
|
||||
int32_t subsampling_factor,
|
||||
int32_t segment,
|
||||
int32_t frames_since_start) {
|
||||
OnlineRecognizerResult ConvertCtc(const OnlineCtcDecoderResult &src,
|
||||
const SymbolTable &sym_table,
|
||||
float frame_shift_ms,
|
||||
int32_t subsampling_factor, int32_t segment,
|
||||
int32_t frames_since_start) {
|
||||
OnlineRecognizerResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
r.timestamps.reserve(src.tokens.size());
|
||||
@@ -182,7 +181,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
std::vector<std::vector<Ort::Value>> next_states =
|
||||
model_->UnStackStates(std::move(out_states));
|
||||
|
||||
decoder_->Decode(std::move(out[0]), &results, ss, n);
|
||||
std::vector<int64_t> log_probs_shape =
|
||||
out[0].GetTensorTypeAndShapeInfo().GetShape();
|
||||
decoder_->Decode(out[0].GetTensorData<float>(), log_probs_shape[0],
|
||||
log_probs_shape[1], log_probs_shape[2], &results, ss, n);
|
||||
|
||||
for (int32_t k = 0; k != n; ++k) {
|
||||
ss[k]->SetCtcResult(results[k]);
|
||||
@@ -196,8 +198,9 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
// TODO(fangjun): Remember to change these constants if needed
|
||||
int32_t frame_shift_ms = 10;
|
||||
int32_t subsampling_factor = 4;
|
||||
auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||
auto r =
|
||||
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||
r.text = ApplyInverseTextNormalization(r.text);
|
||||
return r;
|
||||
}
|
||||
@@ -306,7 +309,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
std::vector<OnlineCtcDecoderResult> results(1);
|
||||
results[0] = std::move(s->GetCtcResult());
|
||||
|
||||
decoder_->Decode(std::move(out[0]), &results, &s, 1);
|
||||
std::vector<int64_t> log_probs_shape =
|
||||
out[0].GetTensorTypeAndShapeInfo().GetShape();
|
||||
decoder_->Decode(out[0].GetTensorData<float>(), log_probs_shape[0],
|
||||
log_probs_shape[1], log_probs_shape[2], &results, &s, 1);
|
||||
s->SetCtcResult(results[0]);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user