This repository has been archived on 2025-08-26. You can view files and clone it, but cannot push or open issues or pull requests.
Files
enginex_bi_series-sherpa-onnx/sherpa-onnx/csrc/packed-sequence-test.cc
2023-03-08 14:12:20 +08:00

53 lines
1.4 KiB
C++

// sherpa-onnx/csrc/packed-sequence-test.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/packed-sequence.h"

#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

#include "gtest/gtest.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {

// Packs an iota-filled padded tensor holding 5 sequences and verifies the
// sort order, the per-time-step batch sizes, and the packed data returned by
// PackPaddedSequence(). Input and output tensors are also printed to stderr
// to help debugging when a check fails.
//
// NOTE(review): the data checks assume PyTorch-style packing: input laid out
// as (batch, time, feature), sequences sorted by descending length, and output
// rows grouped by time step, where step t holds the first batch_sizes[t]
// sorted sequences. This matches the "batch sizes per time step" output and
// the expected-order comments of the original test; confirm against
// packed-sequence.h.
TEST(PackedSequence, Case1) {
  Ort::AllocatorWithDefaultOptions allocator;

  // 5 sequences x 5 frames x 4 features, filled with 0, 1, 2, ... so that
  // element (n, t, c) holds the value (n * num_frames + t) * feat_dim + c.
  std::array<int64_t, 3> shape{5, 5, 4};
  const int64_t batch = shape[0];
  const int64_t num_frames = shape[1];
  const int64_t feat_dim = shape[2];

  Ort::Value v =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  float *p = v.GetTensorMutableData<float>();
  std::iota(p, p + batch * num_frames * feat_dim, 0.0f);

  // Valid length of each sequence; the longest one equals num_frames.
  constexpr std::array<int64_t, 5> kLengths{1, 2, 3, 5, 2};
  ASSERT_EQ(static_cast<int64_t>(kLengths.size()), batch);

  // The length tensor is 1-D with one entry per sequence. Its shape is spelled
  // out explicitly: passing shape.data() with rank 1 only worked because
  // shape[0] happened to be the batch size.
  std::array<int64_t, 1> length_shape{batch};
  Ort::Value length = Ort::Value::CreateTensor<int64_t>(
      allocator, length_shape.data(), length_shape.size());
  int64_t *p_length = length.GetTensorMutableData<int64_t>();
  std::copy(kLengths.begin(), kLengths.end(), p_length);

  auto packed_seq = PackPaddedSequence(allocator, &v, &length);

  fprintf(stderr, "sorted indexes: ");
  for (auto i : packed_seq.sorted_indexes) {
    fprintf(stderr, "%d ", static_cast<int32_t>(i));
  }
  fprintf(stderr, "\n");
  // output index: 0 1 2 3 4
  // sorted indexes: 3 2 1 4 0
  // length: 5 3 2 2 1
  Print3D(&v);
  Print2D(&packed_seq.data);
  fprintf(stderr, "batch sizes per time step: ");
  for (auto i : packed_seq.batch_sizes) {
    fprintf(stderr, "%d ", static_cast<int32_t>(i));
  }
  fprintf(stderr, "\n");

  // Copy into plain vectors so the checks below do not depend on the exact
  // element/container types used by PackedSequence.
  const std::vector<int64_t> sorted_indexes(packed_seq.sorted_indexes.begin(),
                                            packed_seq.sorted_indexes.end());
  const std::vector<int64_t> batch_sizes(packed_seq.batch_sizes.begin(),
                                         packed_seq.batch_sizes.end());

  // sorted_indexes must be a permutation of [0, batch). ASSERT (not EXPECT)
  // because the indexes are used to subscript kLengths below.
  std::vector<int64_t> permutation = sorted_indexes;
  std::sort(permutation.begin(), permutation.end());
  ASSERT_EQ(permutation, (std::vector<int64_t>{0, 1, 2, 3, 4}));

  // Sequences must be ordered by descending length. Sequences 1 and 4 both
  // have length 2, so their relative order is intentionally not pinned.
  std::vector<int64_t> sorted_lengths;
  sorted_lengths.reserve(sorted_indexes.size());
  for (int64_t idx : sorted_indexes) {
    sorted_lengths.push_back(kLengths[static_cast<size_t>(idx)]);
  }
  EXPECT_EQ(sorted_lengths, (std::vector<int64_t>{5, 3, 2, 2, 1}));

  // At time step t, every sequence longer than t contributes one frame.
  ASSERT_EQ(batch_sizes, (std::vector<int64_t>{5, 4, 2, 1, 1}));

  // The packed data holds one row per valid frame.
  const int64_t total_frames =
      std::accumulate(kLengths.begin(), kLengths.end(), int64_t{0});
  ASSERT_EQ(packed_seq.data.GetTensorTypeAndShapeInfo().GetShape(),
            (std::vector<int64_t>{total_frames, feat_dim}));

  // Rows are grouped by time step; within step t, row b comes from the b-th
  // longest sequence. Expected values follow from the iota fill above and are
  // small integers, so exact float comparison is safe.
  const float *packed = packed_seq.data.GetTensorData<float>();
  int64_t row = 0;
  for (size_t t = 0; t != batch_sizes.size(); ++t) {
    for (int64_t b = 0; b != batch_sizes[t]; ++b, ++row) {
      const int64_t n = sorted_indexes[static_cast<size_t>(b)];
      for (int64_t c = 0; c != feat_dim; ++c) {
        const float expected = static_cast<float>(
            (n * num_frames + static_cast<int64_t>(t)) * feat_dim + c);
        EXPECT_EQ(packed[row * feat_dim + c], expected)
            << "t=" << t << " b=" << b << " c=" << c;
      }
    }
  }
  EXPECT_EQ(row, total_frames);
}

}  // namespace sherpa_onnx