Add PackPaddedSequence (#85)

This commit is contained in:
Fangjun Kuang
2023-03-08 14:12:20 +08:00
committed by GitHub
parent 3a79115884
commit 8c6a289e3d
7 changed files with 236 additions and 21 deletions

View File

@@ -6,29 +6,48 @@
#include <assert.h>
#include <algorithm>
#include <vector>
namespace sherpa_onnx {
template <typename T /*=float*/>
Ort::Value Slice(const Ort::Value *v, int32_t dim0, int32_t dim1_start,
Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
int32_t dim0_start, int32_t dim0_end, int32_t dim1_start,
int32_t dim1_end) {
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
assert(shape.size() == 3);
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
assert(0 <= dim0_start);
assert(dim0_start < dim0_end);
assert(dim0_end <= shape[0]);
assert(0 <= dim1_start);
assert(dim1_start < dim1_end);
assert(dim1_end < shape[1]);
std::array<int64_t, 2> ans_shape{dim1_end - dim1_start, shape[2]};
const T *src = v->GetTensorData<T>();
src += dim0 * shape[1] * shape[2] + dim1_start * shape[2];
return Ort::Value::CreateTensor(memory_info, const_cast<T *>(src),
ans_shape[0] * ans_shape[1], ans_shape.data(),
ans_shape.size());
std::array<int64_t, 3> ans_shape{dim0_end - dim0_start, dim1_end - dim1_start,
shape[2]};
Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
ans_shape.size());
T *dst = ans.GetTensorMutableData<T>();
for (int32_t i = dim0_start; i != dim0_end; ++i) {
const T *src = v->GetTensorData<T>() + i * shape[1] * shape[2];
const T *start = src + dim1_start * shape[2];
const T *end = src + dim1_end * shape[2];
std::copy(start, end, dst);
dst += ans_shape[1] * ans_shape[2];
}
return ans;
}
template Ort::Value Slice<float>(const Ort::Value *v, int32_t dim0,
template Ort::Value Slice<float>(OrtAllocator *allocator, const Ort::Value *v,
int32_t dim0_start, int32_t dim0_end,
int32_t dim1_start, int32_t dim1_end);
} // namespace sherpa_onnx