Add PackPaddedSequence (#85)
This commit is contained in:
@@ -6,29 +6,48 @@
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
template <typename T /*=float*/>
|
||||
Ort::Value Slice(const Ort::Value *v, int32_t dim0, int32_t dim1_start,
|
||||
Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
|
||||
int32_t dim0_start, int32_t dim0_end, int32_t dim1_start,
|
||||
int32_t dim1_end) {
|
||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||
assert(shape.size() == 3);
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
assert(0 <= dim0_start);
|
||||
assert(dim0_start < dim0_end);
|
||||
assert(dim0_end <= shape[0]);
|
||||
|
||||
assert(0 <= dim1_start);
|
||||
assert(dim1_start < dim1_end);
|
||||
assert(dim1_end < shape[1]);
|
||||
|
||||
std::array<int64_t, 2> ans_shape{dim1_end - dim1_start, shape[2]};
|
||||
const T *src = v->GetTensorData<T>();
|
||||
src += dim0 * shape[1] * shape[2] + dim1_start * shape[2];
|
||||
|
||||
return Ort::Value::CreateTensor(memory_info, const_cast<T *>(src),
|
||||
ans_shape[0] * ans_shape[1], ans_shape.data(),
|
||||
ans_shape.size());
|
||||
std::array<int64_t, 3> ans_shape{dim0_end - dim0_start, dim1_end - dim1_start,
|
||||
shape[2]};
|
||||
|
||||
Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
|
||||
ans_shape.size());
|
||||
T *dst = ans.GetTensorMutableData<T>();
|
||||
for (int32_t i = dim0_start; i != dim0_end; ++i) {
|
||||
const T *src = v->GetTensorData<T>() + i * shape[1] * shape[2];
|
||||
const T *start = src + dim1_start * shape[2];
|
||||
const T *end = src + dim1_end * shape[2];
|
||||
|
||||
std::copy(start, end, dst);
|
||||
dst += ans_shape[1] * ans_shape[2];
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
template Ort::Value Slice<float>(const Ort::Value *v, int32_t dim0,
|
||||
template Ort::Value Slice<float>(OrtAllocator *allocator, const Ort::Value *v,
|
||||
int32_t dim0_start, int32_t dim0_end,
|
||||
int32_t dim1_start, int32_t dim1_end);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user