add modified beam search (#69)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
// sherpa-onnx/csrc/online-transducer-model.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
// Copyright (c) 2023 Pingfeng Luo
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
@@ -8,6 +9,7 @@
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@@ -75,6 +77,40 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Ort::Value OnlineTransducerModel::BuildDecoderInput(
|
||||
const std::vector<OnlineTransducerDecoderResult> &results) {
|
||||
int32_t batch_size = static_cast<int32_t>(results.size());
|
||||
int32_t context_size = ContextSize();
|
||||
std::array<int64_t, 2> shape{batch_size, context_size};
|
||||
Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
|
||||
Allocator(), shape.data(), shape.size());
|
||||
int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
|
||||
|
||||
for (const auto &r : results) {
|
||||
const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size;
|
||||
const int64_t *end = r.tokens.data() + r.tokens.size();
|
||||
std::copy(begin, end, p);
|
||||
p += context_size;
|
||||
}
|
||||
return decoder_input;
|
||||
}
|
||||
|
||||
Ort::Value OnlineTransducerModel::BuildDecoderInput(
|
||||
const std::vector<Hypothesis> &hyps) {
|
||||
int32_t batch_size = static_cast<int32_t>(hyps.size());
|
||||
int32_t context_size = ContextSize();
|
||||
std::array<int64_t, 2> shape{batch_size, context_size};
|
||||
Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
|
||||
Allocator(), shape.data(), shape.size());
|
||||
int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
|
||||
|
||||
for (const auto &h : hyps) {
|
||||
std::copy(h.ys.end() - context_size, h.ys.end(), p);
|
||||
p += context_size;
|
||||
}
|
||||
return decoder_input;
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
AAssetManager *mgr, const OnlineTransducerModelConfig &config) {
|
||||
|
||||
Reference in New Issue
Block a user