This commit is contained in:
Chranos
2026-02-04 17:39:32 +08:00
parent 8511fe8530
commit 79dfc69789
299 changed files with 55927 additions and 0 deletions

View File

@@ -0,0 +1,101 @@
cmake_minimum_required(VERSION 3.10)
set(CMAKE_C_COMPILER "gcc")
set(CMAKE_CXX_COMPILER "g++")
project(kernel_test)
message(STATUS "project name: ${PROJECT_NAME}")
# Emit compile_commands.json for clangd / IDE tooling.
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/lib")
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/archive")
# NEUWARE_HOME must point at the Cambricon toolkit install; it is read at
# configure time only.
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "$ENV{NEUWARE_HOME}/cmake/modules")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# BUGFIX: dropped "-std=c++17" from add_compile_options. CMAKE_CXX_STANDARD
# above already injects it for C++ translation units, and add_compile_options
# applies to every language, so the flag would also reach the C compiler,
# which rejects it.
add_compile_options(-O3 -g -fPIC -Wall -Werror -Wextra -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unknown-pragmas)
link_directories($ENV{NEUWARE_HOME}/lib64)
link_libraries(m stdc++ dl pthread cnnl cnnl_extra cnrt cndrv "-Wl,-rpath,$ENV{NEUWARE_HOME}/lib64 -Wl,--disable-new-dtags")
include_directories($ENV{NEUWARE_HOME}/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../csrc/ ${CMAKE_CURRENT_SOURCE_DIR}/../../csrc/kernels)
# find_torch
# ----------
# Locate the installed PyTorch via the python3 found on PATH.
# Sets in the caller's scope:
#   Torch_DIR       - directory containing TorchConfig.cmake
#   TORCH_CXX11_ABI - True/False, whether torch was built with the cxx11 ABI
# Returns silently (variables left unset) when the python probe fails.
function(find_torch)
    # Query the path and cxx11_abi flag in one interpreter launch;
    # ';' separator makes the output a CMake list.
    execute_process(
        COMMAND python3 -c "import torch; print(torch.__path__[0], torch.compiled_with_cxx11_abi(), sep=';')"
        RESULT_VARIABLE TORCH_NOT_FOUND
        OUTPUT_VARIABLE TORCH_INFO
        OUTPUT_STRIP_TRAILING_WHITESPACE
    )
    if(TORCH_NOT_FOUND)
        return()
    endif()
    # TORCH_INFO is "<torch path>;<abi flag>"
    list(GET TORCH_INFO 0 TORCH_PATH)
    message(STATUS "torch path: ${TORCH_PATH}")
    list(GET TORCH_INFO 1 TORCH_CXX11_ABI)
    message(STATUS "torch cxx11 abi: ${TORCH_CXX11_ABI}")
    set(Torch_DIR ${TORCH_PATH}/share/cmake/Torch PARENT_SCOPE)
    # BUGFIX: TORCH_CXX11_ABI is consumed after this call (ABI-match check
    # guarding the build); without PARENT_SCOPE the value was lost when the
    # function returned, so the ABI comparison always saw an empty string.
    set(TORCH_CXX11_ABI ${TORCH_CXX11_ABI} PARENT_SCOPE)
endfunction()
# import pytorch (sets Torch_DIR and, when found, TORCH_CXX11_ABI)
find_torch()
message(STATUS "Torch_DIR: ${Torch_DIR}")
find_package(Torch QUIET)
# BUGFIX: the find_library keyword is PATHS; the bare word "PATH" is not a
# keyword and was silently searched as a directory literally named "PATH".
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
# import torch_mlu: ask the python module for its CMake package directory
execute_process(
COMMAND python3 -c "import torch_mlu.utils as mlu_utils;print(mlu_utils.cmake_prefix_path)"
OUTPUT_VARIABLE Torch_MLU_MODULE_DIR
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# find ops library: the prebuilt shared object shipped in the torch_mlu_ops wheel
execute_process(
COMMAND python3 -c "import torch_mlu_ops as ops;print(ops._utils.get_custom_op_library_path())"
RESULT_VARIABLE LIBOPS_NOT_FOUND
OUTPUT_VARIABLE LIBOPS_PATH
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(LIBOPS_NOT_FOUND)
message(FATAL_ERROR "torch_mlu_ops not installed, can not find ops library.")
endif()
set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH} ${Torch_MLU_MODULE_DIR})
find_package(TorchMLU QUIET)
# torch_mlu throw [-Werror=sign-compare] error, so ignore it
add_compile_options(-Wno-sign-compare)
# TorchMLUConfig.cmake run will get TORCH_ATEN_LIBRARY-NOTFOUND, it will cause compile fail, remove it
string(REPLACE "TORCH_ATEN_LIBRARY-NOTFOUND" "" TORCH_MLU_LIBRARIES_MODIFIED "${TORCH_MLU_LIBRARIES}")
# Python headers for the pybind11 extension target
execute_process(
COMMAND python3 -c "from distutils import sysconfig; print(sysconfig.get_python_inc())"
OUTPUT_VARIABLE PYTHON_INCLUDE_DIR
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# Build the test extension only when Torch was found AND its C++ ABI matches
# the one requested via USE_CXX11_ABI.
if(Torch_FOUND AND ((TORCH_CXX11_ABI AND USE_CXX11_ABI) OR (NOT TORCH_CXX11_ABI AND NOT USE_CXX11_ABI)))
include_directories(${PYTHON_INCLUDE_DIR} ${TORCH_INCLUDE_DIRS} ${TORCH_MLU_INCLUDE_DIRS})
link_libraries(${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY} ${TORCH_MLU_LIBRARIES_MODIFIED})
# BUGFIX: dropped the stray "RECURSE" token -- GLOB_RECURSE is already
# recursive, and the extra word was silently treated as a second glob pattern.
file(GLOB_RECURSE TEST_SRCS "src/*.cpp")
# Python extension suffix, e.g. ".cpython-310-x86_64-linux-gnu.so"
execute_process(
COMMAND python3 -c "import sysconfig;print(sysconfig.get_config_var('EXT_SUFFIX'))"
OUTPUT_VARIABLE TEST_LIB_NAME_SUFFIX
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# BUGFIX: the original printed ${PYTHON_EXTENSION}, a variable that is never
# set anywhere; report the suffix that is actually used.
message(STATUS "python extension suffix: ${TEST_LIB_NAME_SUFFIX}")
# EXT_SUFFIX already carries ".so", so suppress CMake's own prefix/suffix.
set(CMAKE_SHARED_LIBRARY_PREFIX "")
set(CMAKE_SHARED_LIBRARY_SUFFIX "")
set(TEST_LIB_NAME_PREFIX "btunittests")
string(APPEND TEST_LIB_NAME "${TEST_LIB_NAME_PREFIX}${TEST_LIB_NAME_SUFFIX}")
add_library(${TEST_LIB_NAME} SHARED ${TEST_SRCS})
target_link_libraries(${TEST_LIB_NAME} ${LIBOPS_PATH} -lstdc++fs)
else()
message(STATUS "Torch not found, or torch abi is different with which you specified, will not build")
message(STATUS "if torch not found, please set env Torch_DIR to the directory containing TorchConfig.cmake")
endif()

View File

@@ -0,0 +1,17 @@
## BT_OPS测试脚本使用方式
```bash
# 测试所有测例
bash run_test.sh
```
```bash
# 测试单个测例
python3 test_测例名称.py
```
- 必须在Torch-MLU-Ops docker容器内运行。
- 测试脚本的命名规则为 `test_测例名称.py`
- 必须保证 Torch-MLU-Ops whl包正确安装。

View File

@@ -0,0 +1,10 @@
#!/bin/bash
# Clean-build this directory's CMake project into ./build.
# Any extra arguments are forwarded to the cmake configure step.
set -e
TOP_DIR="$( cd "$( dirname "$0" )" && pwd )"
cd "${TOP_DIR}"
rm -rf build > /dev/null 2>&1
# Quote "$@" so configure arguments containing spaces survive word splitting.
cmake "${TOP_DIR}" -Bbuild "$@"
cmake --build build -- -j32

View File

@@ -0,0 +1,22 @@
#!/bin/bash
# Run every test_*.py under this script's directory, printing a colored
# pass/fail line per file. Pass "coverage" as $1 to accumulate coverage data
# instead of plain python3 runs.
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
tmo_kernel_case=$(find "${SCRIPT_DIR}" -name "test_*.py")
coverage=${1}
for sc in ${tmo_kernel_case}
do
    echo -n "${sc} "
    echo -n "Testing...."
    # Capture the test command's exit status explicitly instead of relying on
    # $? surviving the surrounding if/fi construct; also quote ${sc} so paths
    # with spaces do not break the command.
    if [ "${coverage}" = "coverage" ]; then
        coverage run -a "${sc}"
        status=$?
    else
        python3 "${sc}" > "/tmp/$(basename "${sc}").log" 2>&1
        status=$?
    fi
    # -eq is the POSIX numeric comparison; "==" inside [ ] is a bashism.
    if [ "${status}" -eq 0 ]; then
        echo -e "\033[32m success \033[0m"
    else
        echo -e "\033[31m failed \033[0m"
    fi
done
echo "End of pytest..."

View File

@@ -0,0 +1,54 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/generate_alibi_slope.mluh"
#include "register_api.h"
#include "utils.h"
namespace tmo {
namespace kernel_test_api {
// Test driver for the generate-alibi-slope kernel.
// Computes per-head ALiBi slopes for `batch` sequences across `tp_num`
// tensor-parallel shards of `tp_head_num` heads each, then gathers every
// shard's (batch, tp_head_num) result into one (batch, head_num) fp32 tensor
// on the same MLU device as `true_seq_lens`.
at::Tensor alibi_slope_test(const at::Tensor &true_seq_lens,
const int batch,
const int tp_head_num,
const int tp_num,
const bool use_dynamic,
const int max_sequence_length) {
MLU_TENSOR_CHECK_FATAL(true_seq_lens);
// All kernel launches and copies below are enqueued on the handle's queue.
cnnlHandle_t handle = torch_mlu::getCurrentHandle();
cnrtQueue_t queue;
cnnlGetQueue(handle, &queue);
int head_num = tp_head_num * tp_num;
std::shared_ptr<torch::Device> dev = std::make_shared<torch::Device>(true_seq_lens.device());
torch::Tensor output = torch::zeros(
{batch, head_num}, torch::dtype(torch::kFloat32).device(*dev).requires_grad(false));
// One kernel launch per tensor-parallel shard; tp_id * tp_head_num is the
// shard's first global head index.
for (int tp_id = 0; tp_id < tp_num; tp_id++) {
torch::Tensor tp_slopes = torch::zeros(
{batch, tp_head_num}, torch::dtype(torch::kFloat32).device(*dev).requires_grad(false));
TMO_KERNEL_CHECK_FATAL(invokeGenerateAlibiSlope(
queue, tp_slopes.data_ptr(), true_seq_lens.data_ptr(), batch, tp_id * tp_head_num,
tp_head_num, head_num, max_sequence_length, use_dynamic));
// Copy each batch row of the shard result into its column slice of the
// gathered output. Same queue as the kernel, so the copies are ordered
// after it. NOTE(review): there is no queue sync before returning --
// presumably the caller synchronizes before reading `output`; confirm.
for (int batch_id = 0; batch_id < batch; batch_id++) {
CNRT_CHECK(
cnrtMemcpyAsync((float *)output.data_ptr() + batch_id * head_num + tp_id * tp_head_num,
(float *)tp_slopes.data_ptr() + batch_id * tp_head_num,
tp_head_num * sizeof(float), queue, cnrtMemcpyDevToDev));
}
}
return output;
}
}  // namespace kernel_test_api
}  // namespace tmo

View File

@@ -0,0 +1,83 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/create_cos_sin_table.mluh"
#include <cnnl.h>
#include <torch/extension.h>
#include "register_api.h"
#include "utils.h"
namespace {
// Rotary-scaling mode selector: 2 selects dynamic NTK scaling, which makes
// the cos/sin table per-batch (see the ternary allocation below).
constexpr int DYNAMIC_NTK_SCALING = 2;
}
namespace tmo {
namespace kernel_test_api {
// Test driver for invokeCreateCosSinTable: builds the RoPE cos/sin lookup
// table on the MLU and prints kernel timing via print_info.
// The output dtype/device follow `rotary_emb_alpha_cached`; its shape is
// (batch, rotary_seq_len, rotary_stride) in dynamic-NTK mode and
// (rotary_seq_len, rotary_stride) otherwise.
at::Tensor create_cos_sin_table_test(at::Tensor &rotary_emb_alpha_cached,
const at::Tensor &seq_lens,
const int max_position_embeddings,
const int batch_stride,
const int rotary_seq_len,
const int rotary_dim,
const int rotary_stride,
const float rotary_scaling,
const int rotary_scaling_type,
const float rotary_base,
const bool interleaved) {
MLU_TENSOR_CHECK_FATAL(rotary_emb_alpha_cached);
MLU_TENSOR_CHECK_FATAL(seq_lens);
cnnlHandle_t handle = torch_mlu::getCurrentHandle();
cnrtQueue_t queue;
cnnlGetQueue(handle, &queue);
cnnlDataType_t dtype =
torch2cnnlDataType(torch::typeMetaToScalarType(rotary_emb_alpha_cached.dtype()));
std::shared_ptr<torch::Device> dev =
std::make_shared<torch::Device>(rotary_emb_alpha_cached.device());
int batch = seq_lens.size(0);
// Dynamic NTK produces a per-batch table; every other mode shares one table.
torch::Tensor cos_sin_table =
rotary_scaling_type == DYNAMIC_NTK_SCALING
? torch::zeros({batch, rotary_seq_len, rotary_stride},
torch::dtype(torch::typeMetaToScalarType(rotary_emb_alpha_cached.dtype()))
.device(*dev)
.requires_grad(false))
: torch::zeros({rotary_seq_len, rotary_stride},
torch::dtype(torch::typeMetaToScalarType(rotary_emb_alpha_cached.dtype()))
.device(*dev)
.requires_grad(false));
// NOTE(review): io_bytes stays 0, so print_info reports 0 IO efficiency for
// this kernel -- presumably intentional; confirm if efficiency matters here.
size_t io_bytes = 0;
// Bracket the launch with notifiers to measure device-side kernel time.
cnrtNotifier_t time_begin;
cnrtNotifier_t time_end;
CNRT_CHECK(cnrtNotifierCreate(&time_begin));
CNRT_CHECK(cnrtNotifierCreate(&time_end));
CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
TMO_KERNEL_CHECK_FATAL(invokeCreateCosSinTable(
queue, cos_sin_table.data_ptr(),
reinterpret_cast<float *>(rotary_emb_alpha_cached.data_ptr()),
reinterpret_cast<int *>(seq_lens.data_ptr()), max_position_embeddings, batch, batch_stride,
rotary_seq_len, rotary_dim, rotary_stride, rotary_base, rotary_scaling, rotary_scaling_type,
interleaved, dtype));
CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
cnrtQueueSync(queue);
float usec = 0.0f;
CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
print_info(usec, io_bytes);
return cos_sin_table;
}
}  // namespace kernel_test_api
}  // namespace tmo

View File

@@ -0,0 +1,48 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/embedding.mluh"
#include "register_api.h"
#include "utils.h"
namespace tmo {
namespace kernel_test_api {
// Test driver for the embedding-lookup kernel: gathers rows of `weight`
// selected by the ids in `input` into a freshly allocated output tensor of
// shape (batch, seq, hidden_size). `vocab_offset`/`vocab_part` describe the
// vocabulary shard handled by this call.
at::Tensor embedding_test(const at::Tensor &input,
                          const at::Tensor &weight,
                          const int vocab_offset,
                          const int vocab_part) {
  MLU_TENSOR_CHECK_FATAL(input);
  MLU_TENSOR_CHECK_FATAL(weight);
  // The kernel is enqueued on the cnnl handle's current queue.
  cnnlHandle_t cnnl_handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t launch_queue;
  cnnlGetQueue(cnnl_handle, &launch_queue);
  const int batch_size = input.size(0);
  const int seq_length = input.size(1);
  const int total_vocab_size = weight.size(0);
  const int hidden_size = weight.size(1);
  const cnnlDataType_t weight_dtype =
      torch2cnnlDataType(torch::typeMetaToScalarType(weight.dtype()));
  // Output lives on the same device as the input ids and uses the weight dtype.
  std::shared_ptr<torch::Device> device = std::make_shared<torch::Device>(input.device());
  torch::Tensor result = torch::zeros(
      {batch_size, seq_length, hidden_size},
      torch::dtype(torch::typeMetaToScalarType(weight.dtype()))
          .device(*device)
          .requires_grad(false));
  TMO_KERNEL_CHECK_FATAL(invokeEmbedding(launch_queue, weight.data_ptr(), input.data_ptr(),
                                         result.data_ptr(), weight_dtype, vocab_offset,
                                         vocab_part, total_vocab_size, hidden_size,
                                         batch_size * seq_length));
  return result;
}
}  // namespace kernel_test_api
}  // namespace tmo

View File

@@ -0,0 +1,80 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/operate_cu_seq_lens.mluh"
#include "register_api.h"
#include "utils.h"
namespace tmo {
namespace kernel_test_api {
// Timed test driver for invokeSliceCuSeqlens: slices the cumulative
// sequence-length array `cu_seq_lens` for `parallel_num`-way parallelism into
// `sliced_cu_seq_lens`, then prints kernel timing.
void slice_cu_seq_lens_test(const at::Tensor &cu_seq_lens,
                            at::Tensor &sliced_cu_seq_lens,
                            int batch,
                            int parallel_num) {
  MLU_TENSOR_CHECK_FATAL(cu_seq_lens);
  MLU_TENSOR_CHECK_FATAL(sliced_cu_seq_lens);
  cnnlHandle_t cnnl_handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t launch_queue;
  cnnlGetQueue(cnnl_handle, &launch_queue);
  // io_bytes stays 0, so print_info reports 0 IO efficiency for this kernel.
  size_t io_bytes = 0;
  // Bracket the launch with notifiers to measure device-side time.
  cnrtNotifier_t start_marker;
  cnrtNotifier_t stop_marker;
  CNRT_CHECK(cnrtNotifierCreate(&start_marker));
  CNRT_CHECK(cnrtNotifierCreate(&stop_marker));
  CNRT_CHECK(cnrtPlaceNotifier(start_marker, launch_queue));
  TMO_KERNEL_CHECK_FATAL(invokeSliceCuSeqlens(launch_queue,
                                              static_cast<int *>(cu_seq_lens.data_ptr()),
                                              static_cast<int *>(sliced_cu_seq_lens.data_ptr()),
                                              batch, parallel_num));
  CNRT_CHECK(cnrtPlaceNotifier(stop_marker, launch_queue));
  cnrtQueueSync(launch_queue);
  float elapsed_usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(start_marker, stop_marker, &elapsed_usec));
  print_info(elapsed_usec, io_bytes);
}
// Timed test driver for invokeGenerateCuSeqlens: fills `gen_cu_seq_lens`
// with generated cumulative sequence lengths, then prints kernel timing.
void generate_cu_seq_lens_test(at::Tensor &gen_cu_seq_lens,
                               int seq_len,
                               int parallel_num,
                               bool is_causal_mask,
                               bool is_kv_seq_len) {
  MLU_TENSOR_CHECK_FATAL(gen_cu_seq_lens);
  cnnlHandle_t cnnl_handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t launch_queue;
  cnnlGetQueue(cnnl_handle, &launch_queue);
  // io_bytes stays 0, so print_info reports 0 IO efficiency for this kernel.
  size_t io_bytes = 0;
  cnrtNotifier_t start_marker;
  cnrtNotifier_t stop_marker;
  CNRT_CHECK(cnrtNotifierCreate(&start_marker));
  CNRT_CHECK(cnrtNotifierCreate(&stop_marker));
  CNRT_CHECK(cnrtPlaceNotifier(start_marker, launch_queue));
  TMO_KERNEL_CHECK_FATAL(invokeGenerateCuSeqlens(launch_queue,
                                                 static_cast<int *>(gen_cu_seq_lens.data_ptr()),
                                                 seq_len, parallel_num, is_causal_mask,
                                                 is_kv_seq_len));
  CNRT_CHECK(cnrtPlaceNotifier(stop_marker, launch_queue));
  cnrtQueueSync(launch_queue);
  float elapsed_usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(start_marker, stop_marker, &elapsed_usec));
  print_info(elapsed_usec, io_bytes);
}
}  // namespace kernel_test_api
}  // namespace tmo

View File

@@ -0,0 +1,67 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/quantize.mluh"
#include "register_api.h"
#include "utils.h"
namespace tmo {
namespace kernel_test_api {
// Timed test driver for invokeMluQuantizePerHead: quantizes `src` into `dst`
// per attention head (with per-head scales written to `scale`), then prints
// kernel timing and IO statistics. All strides are in elements.
void per_head_quantize_test(at::Tensor &dst,
                            at::Tensor &scale,
                            const at::Tensor &src,
                            int bs,
                            int seq_len,
                            int head_num,
                            int head_size,
                            int dst_bs_stride,
                            int dst_seq_stride,
                            int dst_head_stride,
                            int src_bs_stride,
                            int src_seq_stride,
                            int src_head_stride) {
  MLU_TENSOR_CHECK_FATAL(dst);
  MLU_TENSOR_CHECK_FATAL(scale);
  MLU_TENSOR_CHECK_FATAL(src);
  cnnlHandle_t cnnl_handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t launch_queue;
  cnnlGetQueue(cnnl_handle, &launch_queue);
  const cnnlDataType_t out_dtype = torch2cnnlDataType(torch::typeMetaToScalarType(dst.dtype()));
  const cnnlDataType_t scl_dtype = torch2cnnlDataType(torch::typeMetaToScalarType(scale.dtype()));
  const cnnlDataType_t in_dtype = torch2cnnlDataType(torch::typeMetaToScalarType(src.dtype()));
  // Bytes moved: every element is read once at src precision and written once
  // at dst precision.
  size_t io_bytes =
      bs * seq_len * head_num * head_size * (dtype_size_map[in_dtype] + dtype_size_map[out_dtype]);
  // Bracket the launch with notifiers to measure device-side time.
  cnrtNotifier_t start_marker;
  cnrtNotifier_t stop_marker;
  CNRT_CHECK(cnrtNotifierCreate(&start_marker));
  CNRT_CHECK(cnrtNotifierCreate(&stop_marker));
  CNRT_CHECK(cnrtPlaceNotifier(start_marker, launch_queue));
  TMO_KERNEL_CHECK_FATAL(invokeMluQuantizePerHead(
      launch_queue, (void *)dst.data_ptr(), (void *)scale.data_ptr(), (void *)src.data_ptr(),
      out_dtype, scl_dtype, in_dtype, bs, seq_len, head_num, head_size, dst_bs_stride,
      dst_seq_stride, dst_head_stride, src_bs_stride, src_seq_stride, src_head_stride));
  CNRT_CHECK(cnrtPlaceNotifier(stop_marker, launch_queue));
  cnrtQueueSync(launch_queue);
  float elapsed_usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(start_marker, stop_marker, &elapsed_usec));
  print_info(elapsed_usec, io_bytes);
}
}  // namespace kernel_test_api
}  // namespace tmo

View File

@@ -0,0 +1,82 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef TEST_KERNELS_PYTEST_REGISTER_API_H_
#define TEST_KERNELS_PYTEST_REGISTER_API_H_
#include <torch/extension.h>
#include <torch/torch.h>
namespace tmo {
namespace kernel_test_api {
// Declarations of the kernel test drivers exposed to Python through the
// btunittests pybind11 module. Each function wraps one MLU kernel invocation;
// definitions live in the matching src/*.cpp files.
// Embedding lookup: gathers rows of `weight` selected by ids in `input`.
at::Tensor embedding_test(const at::Tensor &input,
const at::Tensor &weight,
const int vocab_offset,
const int vocab_part);
// ALiBi slope generation across tp_num tensor-parallel shards.
at::Tensor alibi_slope_test(const at::Tensor &true_seq_lens,
const int batch,
const int tp_head_num,
const int tp_num,
const bool use_dynamic,
const int max_sequence_length);
// RoPE cos/sin table construction (per-batch when scaling type is dynamic NTK).
at::Tensor create_cos_sin_table_test(at::Tensor &rotary_emb_alpha_cached,
const at::Tensor &seq_lens,
const int max_position_embeddings,
const int batch_stride,
const int rotary_seq_len,
const int rotary_dim,
const int rotary_stride,
const float rotary_scaling,
const int rotary_scaling_type,
const float rotary_base,
const bool interleaved);
// Slices a cumulative-seq-len array for parallel_num-way parallelism.
void slice_cu_seq_lens_test(const at::Tensor &cu_seq_lens,
at::Tensor &sliced_cu_seq_lens,
int batch,
int parallel_num);
// Generates cumulative sequence lengths into gen_cu_seq_lens.
void generate_cu_seq_lens_test(at::Tensor &gen_cu_seq_lens,
int seq_len,
int parallel_num,
bool is_causal_mask,
bool is_kv_seq_len);
// Per-head quantization of src into dst with per-head scales.
void per_head_quantize_test(at::Tensor &dst,
at::Tensor &scale,
const at::Tensor &src,
int bs,
int seq_len,
int head_num,
int head_size,
int dst_bs_stride,
int dst_seq_stride,
int dst_head_stride,
int src_bs_stride,
int src_seq_stride,
int src_head_stride);
// Rotary position embedding applied in place (returned tensor aliases input).
at::Tensor rotary_embedding_test(const at::Tensor &input,
const at::Tensor &sin_table,
const at::Tensor &cos_table,
const c10::optional<at::Tensor> &seq_offsets,
const c10::optional<at::Tensor> &cu_seq_lens,
const bool &interleaved,
const bool &discrete,
const bool &dynamic_ntk,
const bool &rope_2d,
const int &max_seq_len,
const bool &no_offset);
}  // namespace kernel_test_api
}  // namespace tmo
#endif // TEST_KERNELS_PYTEST_REGISTER_API_H_

View File

@@ -0,0 +1,29 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <torch/extension.h>
#include "register_api.h"
// Python bindings for the kernel test drivers. The module name must match the
// shared-library base name produced by CMake ("btunittests" + EXT_SUFFIX).
PYBIND11_MODULE(btunittests, m) {
m.def("embedding_test", &tmo::kernel_test_api::embedding_test, "Test embedding kernel function.");
m.def("alibi_slope_test", &tmo::kernel_test_api::alibi_slope_test,
"Test alibi_slope kernel function.");
m.def("create_cos_sin_table_test", &tmo::kernel_test_api::create_cos_sin_table_test,
"Test create_cos_sin_table_test kernel function.");
m.def("slice_cu_seq_lens_test", &tmo::kernel_test_api::slice_cu_seq_lens_test,
"Test slice_cu_seq_lens_test kernel function.");
m.def("generate_cu_seq_lens_test", &tmo::kernel_test_api::generate_cu_seq_lens_test,
"Test generate_cu_seq_lens_test kernel function.");
m.def("per_head_quantize_test", &tmo::kernel_test_api::per_head_quantize_test,
"Test per_head_quantize_test kernel function.");
m.def("rotary_embedding_test", &tmo::kernel_test_api::rotary_embedding_test,
"Test rotary_embedding kernel function.");
}

View File

@@ -0,0 +1,202 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/rotary_embedding.mluh"
#include "register_api.h"
#include "utils.h"
namespace tmo {
namespace kernel_test_api {
// Timed test driver for the rotary-embedding kernels.
// Applies RoPE in place (the returned tensor aliases `input`) and prints
// kernel timing via print_info. Two input layouts are supported:
//   3-D (total_seq_len, head_num, head_size) "pack" mode, requiring
//       cu_seq_lens; batch is derived as cu_seq_lens.size(0) - 1.
//   4-D (batch, seq_len, head_num, head_size) "pad" mode, requiring
//       cu_seq_lens to be absent and max_seq_len == input.size(1).
// rope_2d selects the GLM-6B variant kernel; dynamic_ntk switches the
// sin/cos tables from (seq, dim) to per-batch (batch, seq, dim) layout.
at::Tensor rotary_embedding_test(const at::Tensor &input,
const at::Tensor &sin_table,
const at::Tensor &cos_table,
const c10::optional<at::Tensor> &seq_offsets,
const c10::optional<at::Tensor> &cu_seq_lens,
const bool &interleaved,
const bool &discrete,
const bool &dynamic_ntk,
const bool &rope_2d,
const int &max_seq_len,
const bool &no_offset) {
MLU_TENSOR_CHECK_FATAL(input);
MLU_TENSOR_CHECK_FATAL(sin_table);
MLU_TENSOR_CHECK_FATAL(cos_table);
bool has_seq_offsets = seq_offsets.has_value();
bool has_cu_seq_lens = cu_seq_lens.has_value();
// Optional device pointers; nullptr when the corresponding tensor is absent.
int *seq_offsets_ptr =
has_seq_offsets ? reinterpret_cast<int *>(seq_offsets.value().data_ptr()) : nullptr;
int *cu_seq_lens_ptr =
has_cu_seq_lens ? reinterpret_cast<int *>(cu_seq_lens.value().data_ptr()) : nullptr;
int batch = 0;
int total_seq_len = 0;
int head_size = input.size(-1);
if (input.dim() == 3) { // pack mode
TORCH_CHECK(has_cu_seq_lens,
"input has 3 dims: (total_seq_len, head_num, head_size),"
" which means pack mode, cu_seqlens should not be None");
total_seq_len = input.size(0);
batch = cu_seq_lens.value().size(0) - 1;
} else if (input.dim() == 4) {
TORCH_CHECK(!has_cu_seq_lens,
"input has 4 dims: (batch_size, seq_len, head_num, head_size),"
" which means pad mode, cu_seqlens should be None");
TORCH_CHECK(max_seq_len == input.size(1),
"input has 4 dims: (batch_size, seq_len, head_num, head_size),"
" which means pad mode, max_seqlen must be equtals to input.size(1)");
batch = input.size(0);
total_seq_len = batch * input.size(1);
} else {
TORCH_CHECK(false, "input only support 3 or 4 dims");
}
const int rope_seqlen = dynamic_ntk ? sin_table.size(1) : sin_table.size(0);
const int rope_dim = dynamic_ntk ? sin_table.size(2) : sin_table.size(1);
// Shape validation intentionally disabled (kept for reference).
#if 0
if (has_seq_offsets) {
if (discrete) {
CHECK_SHAPE(seq_offsets.value(), total_seq_len);
} else {
CHECK_SHAPE(seq_offsets.value(), batch);
}
}
if (dynamic_ntk) {
CHECK_SHAPE(sin_table, batch, rope_seqlen, rope_dim);
CHECK_SHAPE(cos_table, batch, rope_seqlen, rope_dim);
} else {
CHECK_SHAPE(sin_table, rope_seqlen, rope_dim);
CHECK_SHAPE(cos_table, rope_seqlen, rope_dim);
}
#endif
TORCH_CHECK(input.stride(-1) == 1, "input last dim must be contiguous");
// sin/cos tables must share the stride of the dimension the kernel walks.
if (dynamic_ntk) {
TORCH_CHECK(sin_table.stride(1) == cos_table.stride(1),
"sin_table second stride must be equal to cos_table second stride");
} else {
TORCH_CHECK(sin_table.stride(0) == cos_table.stride(0),
"sin_table first stride must be equal to cos_table second stride");
}
if (has_seq_offsets) {
TORCH_CHECK(seq_offsets.value().is_contiguous(), "seq_offsets must be contiguous");
}
if (has_cu_seq_lens) {
TORCH_CHECK(cu_seq_lens.value().is_contiguous(), "cu_seq_lens must be contiguous");
}
// auto output = input;
// if (is_qkv) {
// }
int dims = input.dim();
int head_num = input.size(dims - 2); // head_num_qk
TORCH_CHECK(head_size <= 256, "only support input head_size <= 256");
int head_dim = input.size(dims - 1);
int rotary_seq_len = dynamic_ntk ? sin_table.size(1) : sin_table.size(0);
int rotary_dim = dynamic_ntk ? sin_table.size(2) : sin_table.size(1);
#if 0
if (rope_2d) {
total_seq_len = input.size(0);
rotary_seq_len *= 2;
}
#endif
int rotary_stride = dynamic_ntk ? sin_table.stride(1) : sin_table.stride(0);
size_t input_seq_stride = input.stride(dims - 3);
size_t input_head_stride = input.stride(dims - 2);
cnnlHandle_t handle = torch_mlu::getCurrentHandle();
cnrtQueue_t queue;
cnnlGetQueue(handle, &queue);
// Debug dump of derived geometry (disabled).
#if 0
std::shared_ptr<torch::Device> dev = std::make_shared<torch::Device>(input.device());
std::cout
<< " batch * max_seq_len : " << batch * max_seq_len
<< " max_seq_len : " << max_seq_len
<< " total_seq_len : " << total_seq_len
<< " head_num : " << head_num
<< " head_size : " << head_size << "\n";
// at::IntArrayRef output_shape = {batch, max_seq_len, head_num, head_size};
// if (input.dim() == 3) {
// output_shape = {batch * max_seq_len, head_num, head_size};
// }
#endif
cnnlDataType_t dtype = torch2cnnlDataType(torch::typeMetaToScalarType(input.dtype()));
// In-place: output is the same tensor as input.
auto output = input;
size_t output_seq_stride = output.stride(dims - 3);
size_t output_head_stride = output.stride(dims - 2);
#if 0
std::cout << " <<<< batch: " << batch << "\n";
std::cout << " <<<< max_seq_len: " << max_seq_len << "\n";
std::cout << " <<<< total_seq_len: " << total_seq_len << "\n";
std::cout << " <<<< head_num: " << head_num << "\n";
std::cout << " <<<< head_size: " << head_size << "\n";
std::cout << " <<<< rotary_seq_len_: " << rotary_seq_len << "\n";
std::cout << " <<<< rotary_dim: " << rotary_dim << "\n";
std::cout << " <<<< rotary_stride: " << rotary_stride << "\n";
std::cout << " <<<< input_seq_stride: " << input_seq_stride << "\n";
std::cout << " <<<< input_head_stride: " << input_head_stride << "\n";
std::cout << " <<<< output_seq_stride: " << output_seq_stride << "\n";
std::cout << " <<<< output_head_stride: " << output_head_stride << "\n";
std::cout << " <<<< interleaved: " << interleaved << "\n";
std::cout << " <<<< discrete: " << discrete << "\n";
std::cout << " <<<< dynamic_ntk: " << dynamic_ntk << "\n";
#endif
// Estimated bytes moved: each element read + written, plus the per-position
// index data (seq offsets or cumulative lengths).
size_t io_bytes =
(size_t)total_seq_len * head_num * (head_size + rotary_dim) * dtype_size_map[dtype] * 2;
io_bytes += (discrete && !no_offset) ? total_seq_len * sizeof(int) : batch * sizeof(int);
// Bracket the launch with notifiers to measure device-side kernel time.
cnrtNotifier_t time_begin;
cnrtNotifier_t time_end;
CNRT_CHECK(cnrtNotifierCreate(&time_begin));
CNRT_CHECK(cnrtNotifierCreate(&time_end));
CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
if (rope_2d) {
TMO_KERNEL_CHECK_FATAL(invokeGlm6BRotaryEmbedding(
queue, output.data_ptr(), input.data_ptr(), sin_table.data_ptr(), cos_table.data_ptr(),
seq_offsets_ptr, cu_seq_lens_ptr, batch, max_seq_len, total_seq_len, head_num, head_size,
rotary_seq_len, rotary_stride, input_seq_stride, input_head_stride, output_seq_stride,
output_head_stride, interleaved, dtype));
} else {
TMO_KERNEL_CHECK_FATAL(invokeRotaryEmbedding(
queue, output.data_ptr(), input.data_ptr(), sin_table.data_ptr(), cos_table.data_ptr(),
seq_offsets_ptr, cu_seq_lens_ptr, batch, max_seq_len, head_num, head_size, rotary_seq_len,
rotary_dim, rotary_stride, input_seq_stride, input_head_stride, output_seq_stride,
output_head_stride, interleaved, discrete, dynamic_ntk, dtype));
}
CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
cnrtQueueSync(queue);
float usec = 0.0f;
CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
print_info(usec, io_bytes);
return output;
}
}  // namespace kernel_test_api
}  // namespace tmo

View File

@@ -0,0 +1,133 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef TEST_KERNELS_PYTEST_UTILS_H_
#define TEST_KERNELS_PYTEST_UTILS_H_
#include <cnrt.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include "aten/cnnl/cnnlHandle.h"
#include "common/utils.h"
#include "framework/core/MLUStream.h"
#include "framework/core/caching_allocator.h"
#include "framework/core/device.h"
#include "framework/core/mlu_guard.h"
#include "kernels/kernel_utils.h"
namespace tmo {
static cnnlDataType_t torch2cnnlDataType(torch::Dtype dtype) {
switch (dtype) {
case torch::kFloat32:
return CNNL_DTYPE_FLOAT;
case torch::kFloat16:
return CNNL_DTYPE_HALF;
case torch::kFloat64:
return CNNL_DTYPE_DOUBLE;
case torch::kInt8:
return CNNL_DTYPE_INT8;
case torch::kInt16:
return CNNL_DTYPE_INT16;
case torch::kInt32:
return CNNL_DTYPE_INT32;
case torch::kInt64:
return CNNL_DTYPE_INT64;
case torch::kUInt8:
return CNNL_DTYPE_UINT8;
case torch::kBool:
return CNNL_DTYPE_BOOL;
case torch::kBFloat16:
return CNNL_DTYPE_BFLOAT16;
default:
throw std::runtime_error("Unsupported torch::Dtype");
}
}
static constexpr int dtype_size_map[] = {
[CNNL_DTYPE_INVALID] = 0,
[CNNL_DTYPE_HALF] = 2,
[CNNL_DTYPE_FLOAT] = 4,
[CNNL_DTYPE_INT8] = 1,
[CNNL_DTYPE_INT16] = 2,
[CNNL_DTYPE_INT31] = 4,
[CNNL_DTYPE_INT32] = 4,
[CNNL_DTYPE_UINT8] = 1,
[CNNL_DTYPE_BOOL] = 1,
[CNNL_DTYPE_INT64] = 8,
[10] = 0,
[CNNL_DTYPE_UINT32] = 4,
[CNNL_DTYPE_UINT64] = 8,
[CNNL_DTYPE_UINT16] = 2,
[CNNL_DTYPE_DOUBLE] = 8,
[CNNL_DTYPE_COMPLEX_HALF] = 4,
[CNNL_DTYPE_COMPLEX_FLOAT] = 8,
[CNNL_DTYPE_BFLOAT16] = 2,
};
static torch::Dtype cnnl2torchDataType(cnnlDataType_t dtype) {
switch (dtype) {
case CNNL_DTYPE_FLOAT:
return torch::kFloat32;
case CNNL_DTYPE_HALF:
return torch::kFloat16;
case CNNL_DTYPE_DOUBLE:
return torch::kFloat64;
case CNNL_DTYPE_INT8:
return torch::kInt8;
case CNNL_DTYPE_INT16:
return torch::kInt16;
case CNNL_DTYPE_INT32:
return torch::kInt32;
case CNNL_DTYPE_INT64:
return torch::kInt64;
case CNNL_DTYPE_UINT8:
return torch::kUInt8;
case CNNL_DTYPE_BOOL:
return torch::kBool;
case CNNL_DTYPE_BFLOAT16:
return torch::kBFloat16;
default:
throw std::runtime_error("Unsupported cnnlDataType_t");
}
}
static float getBandWidth() {
int card = -1;
CNRT_CHECK(cnrtGetDevice(&card));
if (cndevInit(0) != CNDEV_SUCCESS) {
abort();
}
cndevDDRInfo_t ddrinfo;
ddrinfo.version = CNDEV_VERSION_5;
if (cndevGetDDRInfo(&ddrinfo, card) != CNDEV_SUCCESS) {
abort();
}
double band_width = ddrinfo.bandWidth;
double band_width_decimal = ddrinfo.bandWidthDecimal;
do {
band_width_decimal /= 10;
} while (band_width_decimal > 1);
return float(band_width + band_width_decimal);
}
static void print_info(float time_usec, size_t io_bytes) {
float io_bandwidth = getBandWidth();
std::cout << "kernel time: " << time_usec << "us" << std::endl;
std::cout << "io_bandwidth: " << io_bandwidth << "GB/s" << std::endl;
std::cout << "IO efficiency: " << io_bytes / (time_usec * 1000 * io_bandwidth) << std::endl;
}
// Fail fast when a tensor is not resident on an MLU device.
// Wrapped in do { } while (0) so the macro expands to a single statement and
// is safe in `if (...) MLU_TENSOR_CHECK_FATAL(t); else ...` chains (the bare
// `if { }` form used previously could capture a following `else`).
#define MLU_TENSOR_CHECK_FATAL(tensor)                                                   \
  do {                                                                                   \
    if (tensor.device().type() == c10::kCPU or tensor.device().type() == c10::kCUDA) {   \
      throw std::runtime_error("Check failed: " #tensor " is not a MLU tensor.");        \
    }                                                                                    \
  } while (0)
} // namespace tmo
#endif // TEST_KERNELS_PYTEST_UTILS_H_

View File

@@ -0,0 +1,75 @@
import sys
import os
from unit_test import UnitTest
import pytest
import torch
import torch_mlu
import btunittests
import math
def alibi_slopes(batch, num_heads, true_seq_lens, train_seq_len, use_dynamic):
    """Reference ALiBi slopes, optionally rescaled with a dynamic-NTK factor.

    Args:
        batch: batch size; returned slopes are tiled to [batch, num_heads].
        num_heads: total number of attention heads.
        true_seq_lens: [batch, 1] int tensor of actual sequence lengths
            (only read when use_dynamic is True).
        train_seq_len: sequence length the model was trained with.
        use_dynamic: apply dynamic-NTK rescaling of the slope base.

    Returns:
        [batch, num_heads] float tensor of per-head slopes.
    """
    scale = 1.0
    if use_dynamic:
        # Dynamic NTK: factor grows with the actual/train length ratio,
        # clamped below at 1.0, then spread across heads via the
        # (num_heads - 1)-th root.
        ratio = 1.0 * true_seq_lens / train_seq_len          # [batch, 1]
        ratio = ratio.masked_fill(ratio < 1.0, 1.0)
        scale = ratio ** (1.0 / (num_heads - 1))
    pow2_heads = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(pow2_heads) - 3))), dtype=torch.float32
    )
    if use_dynamic:
        base = base / scale  # divide the NTK coefficient into the base
    head_powers = torch.arange(1, 1 + pow2_heads, dtype=torch.int32)
    slopes = torch.pow(base, head_powers)
    if use_dynamic:
        slopes = slopes * scale  # restore the per-head bias magnitude
    if pow2_heads != num_heads:
        # Heads past the nearest power of two interleave slopes drawn from the
        # next power-of-two geometric sequence.
        # TODO(review): dynamic NTK is not applied to these extra heads.
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * pow2_heads) - 3))), dtype=torch.float32
        )
        n_extra = min(pow2_heads, num_heads - pow2_heads)
        extra_powers = torch.arange(1, 1 + 2 * n_extra, 2, dtype=torch.int32)
        extra_slopes = torch.pow(extra_base, extra_powers)
        if use_dynamic:
            extra_slopes = extra_slopes.unsqueeze(0)
            extra_slopes = extra_slopes.repeat(batch, 1)
        slopes = torch.cat([slopes, extra_slopes], dim=-1)
    if not use_dynamic:
        slopes = slopes.unsqueeze(0)
        slopes = slopes.repeat(batch, 1)
    return slopes
class TestAlibiSlopeKernel(UnitTest):
    # Compares the MLU alibi_slope kernel against the torch reference above.
    # tp_head_num * tp_num is the total head count; tp_num presumably models a
    # tensor-parallel split inside the kernel — confirm against btunittests.
    @pytest.mark.parametrize("batch, tp_head_num, tp_num, use_dynamic, max_sequence_length", [
        (1, 11, 1, False, 1024),
        (1, 11, 2, False, 1024),
        (8, 16, 1, False, 4096),
        (8, 16, 2, False, 4096),
        (1, 11, 1, True, 1024),
        (1, 11, 2, True, 1024),
        (8, 16, 1, True, 4096),
        (8, 16, 2, True, 4096)])
    def test_alibi_slope(self, batch, tp_head_num, tp_num, use_dynamic, max_sequence_length):
        """Checks kernel slopes against alibi_slopes() within 1e-5."""
        torch.manual_seed(0)
        # Lengths may exceed max_sequence_length to exercise the dynamic branch.
        true_seq_lens = torch.randint(0, 2 * max_sequence_length, (batch, 1), dtype=torch.int32)
        output_base = alibi_slopes(batch, tp_head_num * tp_num, true_seq_lens, max_sequence_length, use_dynamic)
        output_mlu = btunittests.alibi_slope_test(true_seq_lens.mlu(), batch, tp_head_num, tp_num, use_dynamic, max_sequence_length)
        diff1 = super().diff1(output_mlu, output_base)
        diff2 = super().diff2(output_mlu, output_base)
        assert diff1 <= 1e-5 and diff2 <= 1e-5
if __name__ == '__main__':
    # Standalone entry point: run this file under pytest and emit a junit XML
    # report named after the file.
    report_name = os.path.splitext(os.path.basename(__file__))[0] + '_report.xml'
    pytest_args = ['-vs', os.path.abspath(__file__), '--junitxml=' + report_name]
    exit(pytest.main(pytest_args))

View File

@@ -0,0 +1,109 @@
import sys
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
from unit_test import UnitTest
import pytest
import torch
import torch_mlu
import btunittests
import random
import math
def create_cos_sin_table(rotary_base, rotary_seq_len_, rotary_stride_, rotary_scaling, rotary_dim, interleaved, datatype):
    """Build the reference RoPE cos/sin table on the MLU device.

    Args:
        rotary_base: base of the inverse-frequency geometric progression.
        rotary_seq_len_: number of positions to tabulate.
        rotary_stride_: unused in this reference — kept for signature parity.
        rotary_scaling: linear position-scaling divisor.
        rotary_dim: rotary embedding dimension (even).
        interleaved: pairwise-duplicate each frequency column instead of
            concatenating two contiguous half-tables.
        datatype: unused here (table is built in default float) — kept for parity.

    Returns:
        [rotary_seq_len_, 2 * rotary_dim] tensor laid out as [cos | sin],
        duplicated or interleaved per `interleaved`.
    """
    torch.manual_seed(0)
    random.seed(0)
    inv_freq = 1.0 / (rotary_base ** (torch.arange(0, rotary_dim, 2).mlu() / rotary_dim))
    t = torch.arange(rotary_seq_len_).mlu()
    t = t / rotary_scaling
    freqs = torch.outer(t, inv_freq).mlu()
    cos_table = torch.cos(freqs).mlu()
    sin_table = torch.sin(freqs).mlu()
    if interleaved == 0:
        # [cos cos | sin sin]: each half-table duplicated along dim 1.
        cos_table = torch.cat((cos_table, cos_table), dim=1).mlu()
        sin_table = torch.cat((sin_table, sin_table), dim=1).mlu()
        emb = torch.cat((cos_table, sin_table), dim=-1).mlu()
    else:
        # [c0 c0 c1 c1 ... | s0 s0 s1 s1 ...]: pairwise column duplication.
        # repeat_interleave replaces the previous per-column torch.cat loop,
        # which was accidentally O(columns^2).
        emb = torch.cat(
            (torch.repeat_interleave(cos_table, 2, dim=1),
             torch.repeat_interleave(sin_table, 2, dim=1)),
            dim=1,
        ).mlu()
    return emb
def get_ntk_alpha(true_seq_len, max_position_embeddings, rotary_base, rotary_dim):
    """Return (ntk_alpha, adjusted rotary base) for dynamic-NTK RoPE.

    ntk_alpha = max(2 ** ceil(log2(true/max) + 1) - 1, 1); the rotary base is
    then scaled by ntk_alpha ** (rotary_dim / (rotary_dim - 2)).
    """
    ceil_log = math.ceil(math.log(true_seq_len / max_position_embeddings, 2) + 1)
    ntk_alpha = max(2 ** ceil_log - 1, 1)
    rotary_b = rotary_base * ntk_alpha ** (rotary_dim / (rotary_dim - 2))
    return ntk_alpha, rotary_b
class TestCreateCosSinTableKernel(UnitTest):
    """Compares the MLU create_cos_sin_table kernel with the torch reference.

    rotary_scaling_type == 2 selects the dynamic-NTK path, which builds one
    table per batch element and also refreshes the per-batch ntk-alpha cache.
    Changes vs. original: removed a duplicated final assert and the unused
    ntk_alpha_list local.
    """
    @pytest.mark.parametrize("batch, max_position_embeddings, batch_stride, rotary_seq_len, rotary_stride, rotary_base, rotary_scaling,\
                              rotary_scaling_type, rotary_dim , interleaved, datatype",
                             [(8, 512, 0, 267, 128, 10000, 1, 1, 64, 0, torch.float32),
                              (8, 512, 0, 267, 128, 10000, 1, 1, 64, 1, torch.float32),
                              (4, 256, 11456, 179, 64, 10000, 1, 2, 32, 0, torch.float32),
                              (4, 256, 11456, 179, 64, 10000, 1, 2, 32, 1, torch.float32),
                              (4, 256, 34368, 179, 192, 10000, 1, 2, 96, 0, torch.float32),
                              (4, 256, 34368, 179, 192, 10000, 1, 2, 96, 1, torch.float32),
                              (4, 256, 45824, 179, 256, 10000, 1, 2, 128, 0, torch.float32),
                              (4, 256, 45824, 179, 256, 10000, 1, 2, 128, 1, torch.float32),
                              (8, 512, 0, 267, 128, 10000, 1, 1, 64, 0, torch.bfloat16),
                              (8, 512, 0, 267, 128, 10000, 1, 1, 64, 1, torch.bfloat16),
                              (4, 256, 11456, 179, 64, 10000, 1, 2, 32, 0, torch.float16),
                              (4, 256, 11456, 179, 64, 10000, 1, 2, 32, 1, torch.float16),
                              (4, 256, 34368, 179, 192, 10000, 1, 2, 96, 0, torch.float16),
                              (4, 256, 34368, 179, 192, 10000, 1, 2, 96, 1, torch.float16),
                              (4, 256, 45824, 179, 256, 10000, 1, 2, 128, 0, torch.float16),
                              (4, 256, 45824, 179, 256, 10000, 1, 2, 128, 1, torch.float16)])
    def test_all(self, batch, max_position_embeddings, batch_stride, rotary_seq_len, rotary_stride, rotary_base, rotary_scaling,
                 rotary_scaling_type, rotary_dim, interleaved, datatype):
        mlu_name = torch.mlu.get_device_name()
        # bfloat16 is unsupported on MLU3-series boards; fall back to half.
        if "MLU3" in mlu_name and datatype == torch.bfloat16:
            datatype = torch.half
            print("BFLOAT16 is not support on {}, use half instead".format(mlu_name))
        torch.manual_seed(0)
        random.seed(0)
        seq_lens = torch.randint(size = (batch,), low = 0, high = 512,
                                 dtype = torch.int32).mlu()
        rotary_emb_alpha_cached = torch.rand((batch), dtype=torch.float32).mlu()
        ref_rotary_emb_alpha_cached = rotary_emb_alpha_cached.clone()
        output_mlu = btunittests.create_cos_sin_table_test(ref_rotary_emb_alpha_cached, seq_lens, max_position_embeddings, batch_stride,
                                                           rotary_seq_len, rotary_dim, rotary_stride, rotary_scaling,
                                                           rotary_scaling_type, rotary_base, interleaved)
        if rotary_scaling_type != 2:
            output_base = create_cos_sin_table(rotary_base, rotary_seq_len, rotary_stride, rotary_scaling, rotary_dim, interleaved, datatype)
        else:
            # Dynamic NTK: alpha depends on the per-batch or global max length.
            max_seq_len, _ = torch.max(seq_lens, dim = 0)
            if max_seq_len > max_position_embeddings:
                # Per-batch alpha when at least one sequence exceeds the trained range.
                for i in range(batch):
                    ntk_alpha, rb = get_ntk_alpha(seq_lens[i], max_position_embeddings, rotary_base, rotary_dim)
                    rotary_emb_alpha_cached[i] = ntk_alpha
                    cos_sin_table_tmp = create_cos_sin_table(rb, rotary_seq_len, rotary_stride, rotary_scaling, rotary_dim, interleaved, datatype).unsqueeze(0)
                    # First iteration seeds the table; later ones append along batch.
                    cos_sin_table = torch.cat((cos_sin_table, cos_sin_table_tmp), dim = 0).mlu() if i != 0 else cos_sin_table_tmp
            else:
                # One shared alpha computed from the global maximum length.
                ntk_alpha, rb = get_ntk_alpha(max_seq_len, max_position_embeddings, rotary_base, rotary_dim)
                for i in range(batch):
                    rotary_emb_alpha_cached[i] = ntk_alpha
                    cos_sin_table_tmp = create_cos_sin_table(rb, rotary_seq_len, rotary_stride, rotary_scaling, rotary_dim, interleaved, datatype).unsqueeze(0)
                    cos_sin_table = torch.cat((cos_sin_table, cos_sin_table_tmp), dim = 0).mlu() if i != 0 else cos_sin_table_tmp
            output_base = cos_sin_table
        output_base_1 = rotary_emb_alpha_cached
        diff1 = super().diff1(output_mlu, output_base)
        diff2 = super().diff2(output_mlu, output_base)
        assert diff1 < 0.003 and diff2 < 0.003
        if rotary_scaling_type == 2:
            # The kernel must also reproduce the cached per-batch alphas.
            diff1 = super().diff1(ref_rotary_emb_alpha_cached, output_base_1)
            diff2 = super().diff2(ref_rotary_emb_alpha_cached, output_base_1)
            assert diff1 < 0.003 and diff2 < 0.003
if __name__ == '__main__':
    # Standalone entry point: run this file under pytest and emit a junit XML
    # report named after the file.
    report_name = os.path.splitext(os.path.basename(__file__))[0] + '_report.xml'
    pytest_args = ['-vs', os.path.abspath(__file__), '--junitxml=' + report_name]
    exit(pytest.main(pytest_args))

View File

@@ -0,0 +1,46 @@
import os
from unit_test import UnitTest
import pytest
import torch
import torch_mlu
import torch.nn as nn
import btunittests
class TestEmbeddingKernel(UnitTest):
    # Compares btunittests.embedding_test against torch.nn.Embedding.
    @pytest.mark.parametrize("batch, seq, vocab_size, hidden_size, dtype", [
        (1, 1024, 10000, 128, torch.float16),
        (6, 128, 20000, 4096, torch.float16),
        (1, 10, 5000, 1024, torch.float32),
        (5, 10, 10000, 128, torch.float32)])
    def test_all(self, batch, seq, vocab_size, hidden_size, dtype):
        """Full-vocabulary lookup: the kernel sees the whole embedding table."""
        torch.manual_seed(0)
        input = torch.randint(0, vocab_size, (batch, seq), dtype=torch.int32).mlu()
        embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size).to(dtype).mlu()
        output_base = embedding_layer(input)
        output_mlu = btunittests.embedding_test(input, embedding_layer.weight, 0, vocab_size)
        diff1 = super().diff1(output_mlu, output_base)
        diff2 = super().diff2(output_mlu, output_base)
        # A lookup is a pure gather, so results must match exactly.
        assert diff1 == 0 and diff2 == 0
    @pytest.mark.parametrize("batch, seq, vocab_size, hidden_size, dtype, vocab_offset, vocab_part", [
        (1, 1024, 10000, 128, torch.float16, 2000, 1000),
        (6, 128, 20000, 4096, torch.float16, 10000, 10000),
        (1, 10, 5000, 1024, torch.float32, 0, 2500),
        (5, 10, 10000, 128, torch.float32, 100, 512)])
    def test_part(self, batch, seq, vocab_size, hidden_size, dtype, vocab_offset, vocab_part):
        """Partial-vocabulary lookup: only rows [vocab_offset, vocab_offset + vocab_part) are resident."""
        torch.manual_seed(0)
        input = torch.randint(0, vocab_size, (batch, seq), dtype=torch.int32).mlu()
        # Out-of-partition ids are redirected to index vocab_size, which is the
        # padding row (zero-initialized by nn.Embedding's padding_idx) of the
        # reference layer built below.
        input_mask = torch.ones_like(input) * vocab_size
        input = torch.where(input < vocab_offset, input_mask, input)
        input = torch.where(input >= vocab_offset + vocab_part, input_mask, input)
        embedding_layer = nn.Embedding(num_embeddings=vocab_size + 1, embedding_dim=hidden_size, padding_idx=vocab_size).to(dtype).mlu()
        output_base = embedding_layer(input)
        # NOTE(review): torch.Tensor(embedding_layer.weight) re-wraps the weight
        # before slicing — presumably just to detach it; confirm this is intended.
        part_weight = torch.Tensor(embedding_layer.weight)[vocab_offset : vocab_offset + vocab_part, :]
        output_mlu = btunittests.embedding_test(input, part_weight, vocab_offset, vocab_part)
        diff1 = super().diff1(output_mlu, output_base)
        diff2 = super().diff2(output_mlu, output_base)
        assert diff1 == 0 and diff2 == 0
if __name__ == '__main__':
    # Standalone entry point: run this file under pytest and emit a junit XML
    # report named after the file.
    report_name = os.path.splitext(os.path.basename(__file__))[0] + '_report.xml'
    pytest_args = ['-vs', os.path.abspath(__file__), '--junitxml=' + report_name]
    exit(pytest.main(pytest_args))

View File

@@ -0,0 +1,92 @@
import sys
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
from unit_test import UnitTest
import pytest
import torch
import torch_mlu
import btunittests
import random
class TestSliceCuSeqlensKernel(UnitTest):
    # Splits a cumulative-sequence-lengths vector into parallel_num chunks,
    # rebases each chunk so its first element is 0, and compares against the
    # MLU slice_cu_seq_lens kernel.
    @pytest.mark.parametrize("batch_size, parallel_num", [
        (8,3),
        (16,16),
        (32,16),
        (64,16),
        (128,16),
        (256,16),
        (512,16),
        (1024,16),
        (2048,16),
        (4096,16),
        ])
    def test_all(self, batch_size, parallel_num):
        torch.manual_seed(0)
        random.seed(0)
        # batch_size < parallel_num is silently skipped.
        # NOTE(review): consider pytest.skip() so the skip is visible in reports.
        if batch_size < parallel_num:
            return
        cu_seq_lens = torch.randint(low=0, high=1024*1024, size=(batch_size+1,), dtype=torch.int32).mlu()
        every = (batch_size + parallel_num - 1) // parallel_num  # entries per chunk (ceil)
        repeat = batch_size // every
        remain = batch_size % every
        loop = repeat + (remain != 0)  # number of chunks actually produced
        result = []
        for i in range(loop):
            # and/or idiom: the last, partial chunk holds `remain` entries,
            # full chunks hold `every`; +1 keeps each slice's leading boundary.
            elem_num = 1 + (i == loop - 1 and remain != 0 and remain or every)
            start_idx = i * every
            end_idx = start_idx + elem_num
            slice_tensor = cu_seq_lens[start_idx:end_idx]
            # Rebase the chunk so it starts at 0.
            adjusted_tensor = slice_tensor - slice_tensor[0]
            result.append(adjusted_tensor)
        sliced_cu_seq_lens = torch.cat(result)
        # Each of the `loop` chunks contributes one extra boundary element.
        sliced_cu_seq_lens_mlu = torch.zeros(batch_size + loop, dtype=torch.int32).mlu()
        btunittests.slice_cu_seq_lens_test(cu_seq_lens, sliced_cu_seq_lens_mlu, batch_size, parallel_num)
        sliced_cu_seq_lens_diff1 = super().diff1(sliced_cu_seq_lens, sliced_cu_seq_lens_mlu)
        sliced_cu_seq_lens_diff2 = super().diff2(sliced_cu_seq_lens, sliced_cu_seq_lens_mlu)
        assert sliced_cu_seq_lens_diff1 == 0 and sliced_cu_seq_lens_diff2 == 0
class TestGenerateCuSeqlensKernel(UnitTest):
    # Builds the expected (start, end) pairs when one sequence of seq_len
    # tokens is divided across parallel_num workers, then compares with the
    # MLU generate_cu_seq_lens kernel.
    @pytest.mark.parametrize("seq_len, parallel_num, is_causal_mask, is_kv_seq_len", [
        (2154, 4, False, False),
        (3412, 3, True, False),
        (996, 7, False, True),
        (872, 4, False, False),
        (634, 12, True, False),
        (486, 4, False, True),
        (125, 6, False, False),
        ])
    def test_all(self, seq_len, parallel_num, is_causal_mask, is_kv_seq_len):
        every = (seq_len + parallel_num - 1) // parallel_num
        repeat = seq_len // every
        remain = seq_len % every
        loop = repeat + (remain != 0)  # number of (start, end) pairs produced
        generate_cu_seq_lens = torch.zeros(2 * loop, dtype=torch.int32).mlu()
        # Re-split seq_len into `loop` chunks of `base` tokens each.
        rep = seq_len // loop
        rem = seq_len % loop
        base = rep + (rem != 0)
        # Only the odd (end) slots are filled below; starts stay 0.
        if is_causal_mask and is_kv_seq_len:
            # Causal kv lengths: each worker attends the prefix up to its chunk.
            for i in range(loop):
                generate_cu_seq_lens[2 * i + 1] = seq_len if i == loop - 1 else (i + 1) * base
        elif not is_kv_seq_len:
            # Query lengths: each worker's own chunk size (causal or not).
            for i in range(loop):
                generate_cu_seq_lens[2 * i + 1] = seq_len - (loop - 1) * base if i == loop - 1 else base
        elif not is_causal_mask and is_kv_seq_len:
            # Non-causal kv lengths: every worker sees the whole sequence.
            for i in range(loop):
                generate_cu_seq_lens[2 * i + 1] = seq_len
        generate_cu_seq_lens_mlu = torch.zeros(2 * loop, dtype=torch.int32).mlu()
        btunittests.generate_cu_seq_lens_test(generate_cu_seq_lens_mlu, seq_len, parallel_num, is_causal_mask, is_kv_seq_len)
        generate_cu_seq_lens_diff1 = super().diff1(generate_cu_seq_lens, generate_cu_seq_lens_mlu)
        generate_cu_seq_lens_diff2 = super().diff2(generate_cu_seq_lens, generate_cu_seq_lens_mlu)
        assert generate_cu_seq_lens_diff1 == 0 and generate_cu_seq_lens_diff2 == 0
if __name__ == '__main__':
    # Standalone entry point: run this file under pytest and emit a junit XML
    # report named after the file.
    report_name = os.path.splitext(os.path.basename(__file__))[0] + '_report.xml'
    pytest_args = ['-vs', os.path.abspath(__file__), '--junitxml=' + report_name]
    exit(pytest.main(pytest_args))

View File

@@ -0,0 +1,82 @@
import sys
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
from unit_test import UnitTest
import pytest
import torch
import torch_mlu
import btunittests
import random
def round_func(tensor, mode=1):
    """Round a tensor under one of three half-way tie-breaking modes.

    Args:
        tensor: input values (converted to float32 first).
        mode: 1 -> ties go to the even neighbour (banker's rounding),
              2 -> ties round toward +infinity,
              3 -> plain torch.round.

    Returns:
        float32 tensor of rounded values.

    Raises:
        ValueError: if mode is not 1, 2 or 3 (previously this fell through
        and crashed with UnboundLocalError).
    """
    tensor = tensor.float()
    fractional_part = tensor - tensor.floor()
    even_mask = (tensor.floor() % 2 == 0)
    if mode == 1:
        # Exact .5 ties pick whichever neighbour is even.
        rounded_tensor = torch.where(fractional_part == 0.5,
                                     torch.where(even_mask, tensor.floor(), tensor.ceil()),
                                     torch.round(tensor))
    elif mode == 2:
        # Exact .5 ties round up (toward +infinity).
        rounded_tensor = torch.where(fractional_part == 0.5, tensor + 0.5, torch.round(tensor))
    elif mode == 3:
        rounded_tensor = torch.round(tensor)
    else:
        raise ValueError("round_func: mode must be 1, 2 or 3, got {}".format(mode))
    return rounded_tensor
def compute_with_tensor(src, batch, seq_len, head_num, head_size):
    """Reference per-head int8 quantization.

    Views `src` as [batch * seq_len * head_num * 2, head_size] rows, keeps the
    first `head_num`-row group out of every `2 * head_num` rows, then scales
    each kept row by 127 / absmax and rounds half-to-even.

    Returns:
        (scale, dst): scale is [batch, seq_len, head_num] (absmax / 127),
        dst is the rounded payload [batch, seq_len, head_num, head_size].
    """
    rows = src.view(batch * seq_len * head_num * 2, head_size)
    # Keep even-numbered chunks of head_num rows, drop the odd-numbered ones.
    chunk_idx = torch.arange(rows.size(0)) // head_num
    kept = rows[chunk_idx % 2 == 0]
    absmax = kept.abs().max(dim=1).values
    scale = (absmax / 127.0).view(batch, seq_len, head_num)
    quant_factor = (127 / absmax).view(-1, 1)
    dst = round_func(kept * quant_factor, 1).view(batch, seq_len, head_num, head_size)
    return scale, dst
class TestPerHeadQuantizeKernel(UnitTest):
    # Checks btunittests.per_head_quantize_test against compute_with_tensor.
    # NOTE(review): the random.randint(...) cases below are evaluated at pytest
    # collection time, BEFORE random.seed(0) runs inside the test, so those
    # shapes differ between runs.
    @pytest.mark.parametrize("batch, seq_len, head_num, head_size, src_dtype", [
        (17,1436,31,21, torch.float32),
        (17,1436,31,21, torch.float16),
        (random.randint(1,32),random.randint(16,2048),random.randint(1,32),random.randint(16,256),torch.float32),
        (random.randint(1,32),random.randint(16,2048),random.randint(1,32),random.randint(16,256),torch.float16),
        (2,3,5,5, torch.float32)
        ])
    def test_all(self, batch, seq_len, head_num, head_size, src_dtype):
        mlu_name = torch.mlu.get_device_name()
        # bfloat16 falls back to half on MLU3-series boards.
        if "MLU3" in mlu_name and src_dtype == torch.bfloat16:
            src_dtype = torch.half
            print("BFLOAT16 is not support on {}, use half instead".format(mlu_name))
        torch.manual_seed(0)
        random.seed(0)
        hidden = head_num * head_size
        # src carries 2 * head_size values per head; compute_with_tensor keeps
        # half of the flattened rows (see its mask) — presumably matching the
        # kernel's packed q/k layout; confirm against the kernel source.
        seq_stride = 2 * hidden
        scale_dtype = torch.float32
        dst_dtype = torch.int8
        # Dense output strides (dst) and strided input strides (src).
        dst_bs_stride = seq_len * hidden
        dst_seq_stride = hidden
        dst_head_stride = head_size
        src_bs_stride = seq_len * seq_stride
        src_seq_stride = seq_stride
        src_head_stride = head_size
        src = torch.rand((batch, seq_len, head_num, 2 * head_size), dtype=src_dtype).mlu()
        scale, dst = compute_with_tensor(src, batch, seq_len, head_num, head_size)
        scale_mlu = torch.zeros((batch, seq_len, head_num), dtype=scale_dtype).mlu()
        dst_mlu = torch.zeros((batch, seq_len, head_num, head_size), dtype=dst_dtype).mlu()
        btunittests.per_head_quantize_test(dst_mlu, scale_mlu, src, batch, seq_len, head_num, head_size, dst_bs_stride,
                                           dst_seq_stride, dst_head_stride, src_bs_stride, src_seq_stride,
                                           src_head_stride)
        dst_diff1 = super().diff1(dst, dst_mlu)
        dst_diff2 = super().diff2(dst, dst_mlu)
        scale_diff1 = super().diff1(scale, scale_mlu)
        scale_diff2 = super().diff2(scale, scale_mlu)
        assert dst_diff1 < 0.003 and dst_diff2 < 0.003
        assert scale_diff1 < 0.003 and scale_diff2 < 0.003
if __name__ == '__main__':
    # Standalone entry point: run this file under pytest and emit a junit XML
    # report named after the file.
    report_name = os.path.splitext(os.path.basename(__file__))[0] + '_report.xml'
    pytest_args = ['-vs', os.path.abspath(__file__), '--junitxml=' + report_name]
    exit(pytest.main(pytest_args))

View File

@@ -0,0 +1,220 @@
import sys
import os
from unit_test import UnitTest
import pytest
import torch
import torch_mlu
import btunittests
class ApplyRotary(torch.nn.Module):
    """Reference rotary position embedding (RoPE) application.

    Supports packed ([total_seqlen, heads, head_size]) and padded
    ([batch, seq, heads, head_size]) inputs, interleaved or half-split
    rotation, per-batch (dynamic-NTK) cos/sin caches, discrete per-token
    position offsets, and a 2D variant that rotates two halves of the head
    dimension with separate position-id rows.
    """
    def __init__(self):
        super().__init__()
    def rotate(self, x: torch.Tensor, interleaved: bool):
        # Rotate-half helper: (x1, x2) -> (-x2, x1), on contiguous halves
        # (interleaved=False) or on even/odd element pairs (interleaved=True).
        if not interleaved:
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat((-x2, x1), dim=-1)
        else:
            y = torch.empty_like(x)
            x1, x2 = x[..., ::2], x[..., 1::2]
            y[..., ::2], y[..., 1::2] = -x2, x1
            return y
    def forward(self,
                output: torch.Tensor, # [total_seqlen, num_heads, head_size]
                input: torch.Tensor, # [total_seqlen, num_heads, head_size]
                sin_cache: torch.Tensor, # [rope_seqlen, rotary_dim] / [batch, rope_seqlen, rotary_dim]
                cos_cache: torch.Tensor, #
                seq_offsets: torch.Tensor, # [batch] / [batch, max_seqlen]
                cu_seq_lens: torch.Tensor, # [batch + 1]
                interleaved: bool,
                discrete: bool,
                dynamic: bool,
                max_seqlen: int,
                no_offset:bool,
                rope_2d:bool):
        """Apply RoPE to `input`, writing results into `output` in place.

        NOTE(review): `max_seqlen` is currently unused; the "ouput_*" names
        below are typos for "output_*" but still alias `output` correctly.
        """
        packed = input.dim() == 3
        rope_dim = sin_cache.shape[-1]
        batch_size = cu_seq_lens.shape[0] - 1 if packed else input.shape[0]
        if not rope_2d:
            for i in range(batch_size):
                # Per-sequence views; packed layout indexes rows via cu_seq_lens.
                input_i = input[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else input[i]
                ouput_i = output[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else output[i]
                input_i = input_i[..., 0:rope_dim]
                ouput_i = ouput_i[..., 0:rope_dim]
                start_seq_idx = 0 if no_offset else seq_offsets[i]
                # Dynamic NTK keeps one cos/sin cache per batch element.
                sin_cache_i = sin_cache[i] if dynamic else sin_cache
                cos_cache_i = cos_cache[i] if dynamic else cos_cache
                seq = input_i.shape[0]
                if discrete:
                    # Discrete mode: per-token position ids instead of a range.
                    start_seq_idx = 0
                    token_offset = seq_offsets[cu_seq_lens[i]:cu_seq_lens[i]+seq]
                    sin_cache_i = sin_cache_i[token_offset]
                    cos_cache_i = cos_cache_i[token_offset]
                sin_cache_i = sin_cache_i[start_seq_idx:seq+start_seq_idx]
                cos_cache_i = cos_cache_i[start_seq_idx:seq+start_seq_idx]
                rot = self.rotate(input_i, interleaved)
                # out = rotate(x) * sin + x * cos, broadcast over the head dim.
                ouput_i[:] = rot * sin_cache_i.unsqueeze(1) + input_i * cos_cache_i.unsqueeze(1)
            # Dimensions beyond rotary_dim pass through unrotated.
            output[..., rope_dim:] = input[..., rope_dim:]
        else:
            for i in range(batch_size):
                # 2D rope, left half [0, rope_dim): ids come from seq_offsets[0].
                input_i_left = input[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else input[i]
                ouput_i_left = output[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else output[i]
                input_i_left = input_i_left[..., 0:rope_dim]
                ouput_i_left = ouput_i_left[..., 0:rope_dim]
                sin_cache_i = sin_cache[i] if dynamic else sin_cache
                cos_cache_i = cos_cache[i] if dynamic else cos_cache
                seq = input_i_left.shape[0]
                if discrete:
                    token_offset = seq_offsets[0][cu_seq_lens[i]:cu_seq_lens[i]+seq]
                    sin_cache_i = sin_cache_i[token_offset]
                    cos_cache_i = cos_cache_i[token_offset]
                sin_cache_i = sin_cache_i[:seq]
                cos_cache_i = cos_cache_i[:seq]
                rot = self.rotate(input_i_left, interleaved)
                ouput_i_left[:] = rot * sin_cache_i.unsqueeze(1) + input_i_left * cos_cache_i.unsqueeze(1)
                # Right half [rope_dim, head_size): ids come from seq_offsets[1].
                input_i_right = input[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else input[i]
                ouput_i_right = output[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else output[i]
                input_i_right = input_i_right[..., rope_dim:]
                ouput_i_right = ouput_i_right[..., rope_dim:]
                sin_cache_i = sin_cache[i] if dynamic else sin_cache
                cos_cache_i = cos_cache[i] if dynamic else cos_cache
                seq = input_i_right.shape[0]
                if discrete:
                    token_offset = seq_offsets[1][cu_seq_lens[i]:cu_seq_lens[i]+seq]
                    sin_cache_i = sin_cache_i[token_offset]
                    cos_cache_i = cos_cache_i[token_offset]
                sin_cache_i = sin_cache_i[:seq]
                cos_cache_i = cos_cache_i[:seq]
                rot = self.rotate(input_i_right, interleaved)
                ouput_i_right[:] = rot * sin_cache_i.unsqueeze(1) + input_i_right * cos_cache_i.unsqueeze(1)
class TestRotaryEmbeddingKernel(UnitTest):
    """Compares btunittests.rotary_embedding_test against ApplyRotary.

    The parametrize axes toggle packed vs padded layout, interleaved rotation,
    discrete per-token offsets, dynamic-NTK per-batch caches and the 2D rope
    variant.
    """
    @pytest.mark.parametrize(
        "batch, max_seq_len, head_num_q, head_num_kv, head_size, discrete, rotary_dim, interleaved, rope_2d, dynamic_ntk, packed, no_offset, data_type",
        [
            (4, 16, 8, 4, 128, False, 128, False, False, True, False, True, torch.float32),
            (4, 16, 8, 4, 128, False, 128, False, False, True, True, True, torch.float32),
            (4, 16, 8, 4, 128, True, 128, False, False, True, True, True, torch.float32),
            (4, 16, 8, 4, 128, True, 128, False, False, False, True, True, torch.float32),
            (4, 16, 8, 4, 128, False, 128, False, False, True, False, True, torch.float16),
            (4, 16, 8, 4, 128, False, 128, False, False, True, True, True, torch.float16),
            (4, 16, 8, 4, 128, True, 128, False, False, True, True, True, torch.float16),
            (4, 16, 8, 4, 128, True, 128, True, False, True, True, True, torch.float16),
            (4, 16, 8, 4, 128, True, 128, False, False, False, True, True, torch.float16),
            (4, 16, 8, 4, 128, True, 128, True, False, False, True, True, torch.float16),
            (4, 16, 8, 4, 128, True, 128, True, True, False, True, True, torch.float16),
            (4, 16, 8, 4, 128, False, 128, False, False, True, True, False, torch.float32),
            (4, 16, 8, 4, 128, True, 128, False, False, True, True, False, torch.float32),
            (4, 16, 8, 4, 128, True, 128, False, True, True, True, False, torch.float32),
        ])
    def testall(
            self,
            batch,
            max_seq_len,
            head_num_q,
            head_num_kv,
            head_size,
            discrete,
            rotary_dim,
            interleaved,
            rope_2d,
            dynamic_ntk,
            packed,
            no_offset,
            data_type):
        torch.manual_seed(0)
        # rope_2d implies a fixed feature combination; normalize the knobs first.
        if rope_2d:
            interleaved = False
            discrete = True
            dynamic_ntk = False
            no_offset = False
        if head_num_q != 2 * head_num_kv:
            head_num_q = 2 * head_num_kv
        if rotary_dim % 2 != 0:
            rotary_dim -= 1
        if no_offset:
            discrete = False
        if rope_2d:
            # 2D rope rotates two head_size/2 halves with separate position ids;
            # head_size is rounded down to a multiple of 4.
            head_size = int(head_size if (head_size % 4 == 0) else (head_size // 4) * 4)
            rotary_dim = int(head_size / 2)
        seq_lens = torch.randint(size=(batch, ), low=1, high=max_seq_len+1, dtype=torch.int32, device='mlu')
        total_seq_len = batch * max_seq_len
        if not packed:
            # Padded layout: every sequence occupies the full max_seq_len.
            seq_lens = torch.randint(size=(batch, ), low=max_seq_len, high=max_seq_len+1, dtype=torch.int32, device='mlu')
        max_context_len = int(seq_lens.max())
        cu_seq_lens = None
        if packed:
            cu_seq_lens = torch.cumsum(seq_lens, dim=-1, dtype=torch.int32)
            cu_seq_lens = torch.nn.functional.pad(cu_seq_lens, (1,0), "constant", 0).mlu()
            total_seq_len = int(cu_seq_lens[-1])
        head_num_qkv = head_num_q + head_num_kv
        # NOTE(review): head_num_qkv and head_num_qk are computed identically;
        # if qkv was meant to cover k AND v heads this should be q + 2 * kv —
        # confirm against the kernel's layout.
        head_num_qk = head_num_q + head_num_kv
        # One spare head slot is appended, presumably so qkv/qk below are
        # strided (non-contiguous) views — confirm intent.
        context_shape = (total_seq_len, head_num_qkv + 1, head_size) if packed else \
                        (batch, max_seq_len, head_num_qkv + 1, head_size)
        context = torch.randn(size=context_shape, dtype=data_type).mlu()
        qkv = context[..., 0 : head_num_qkv, :]
        qk = context[..., 0 : head_num_qk, :]
        seq_offsets = None
        if rope_2d:
            # Two offset rows: one per rotated half.
            seq_offsets = torch.randint(size=(2, total_seq_len), low=1, high=max_seq_len + 1, dtype=torch.int32, device='mlu')
        else:
            if discrete and not no_offset:
                # Per-token position ids.
                seq_offsets = torch.randint(size=(total_seq_len,), low=1, high=max_seq_len + 1, dtype=torch.int32, device='mlu')
            elif not no_offset:
                # One starting offset per sequence.
                seq_offsets = torch.randint(size=(batch,), low=1, high=max_seq_len + 1, dtype=torch.int32, device='mlu')
        cos_cache = None
        sin_cache = None
        if dynamic_ntk:
            # Dynamic NTK keeps one cache per batch element.
            cos_cache = torch.randn(size=(batch, max_seq_len * 2, rotary_dim), dtype=data_type).mlu()
            sin_cache = torch.randn(size=(batch, max_seq_len * 2, rotary_dim), dtype=data_type).mlu()
        else:
            cos_cache = torch.randn(size=(max_seq_len * 2, rotary_dim), dtype=data_type).mlu()
            sin_cache = torch.randn(size=(max_seq_len * 2, rotary_dim), dtype=data_type).mlu()
        base_output = torch.empty_like(qk)
        apply_rotary = ApplyRotary()
        apply_rotary(base_output, qkv, sin_cache, cos_cache, \
                     seq_offsets, cu_seq_lens, interleaved, discrete, dynamic_ntk, max_context_len, no_offset, rope_2d)
        output = btunittests.rotary_embedding_test(
            qkv,
            sin_cache,
            cos_cache,
            seq_offsets,
            cu_seq_lens,
            interleaved,
            discrete,
            dynamic_ntk,
            rope_2d,
            max_context_len,
            no_offset)
        diff1 = super().diff1(output, base_output)
        diff2 = super().diff2(output, base_output)
        assert diff1 < 0.003 and diff2 < 0.003
if __name__ == "__main__":
    # Standalone entry point: run this file under pytest and emit a junit XML
    # report named after the file.
    report_name = os.path.splitext(os.path.basename(__file__))[0] + "_report.xml"
    pytest_args = ["-vs", os.path.abspath(__file__), "--junitxml=" + report_name]
    exit_code = pytest.main(pytest_args)
    exit(exit_code)

View File

@@ -0,0 +1,25 @@
import sys
import os
build_lib_dir = os.path.dirname(os.path.abspath(__file__)) + "/build/lib"
sys.path.append(build_lib_dir)
import torch
class UnitTest:
def diff1(self, result: torch.Tensor, baseline: torch.Tensor):
result = result.flatten().float().to('cpu')
baseline = baseline.flatten().float().to('cpu')
assert result.shape == baseline.shape
error = torch.abs(baseline - result)
denominator = torch.sum(torch.abs(baseline)).item()
eps = 0.0 if denominator > 0 else 1e-9
diff1 = torch.sum(error) / (denominator + eps)
return diff1.item()
def diff2(self, result: torch.Tensor, baseline: torch.Tensor):
result = result.flatten().float().to('cpu')
baseline = baseline.flatten().float().to('cpu')
error = torch.abs(baseline - result)
denominator = torch.sum(baseline**2).item()
eps = 0.0 if denominator > 0 else 1e-9
diff2 = torch.sqrt(torch.sum(error**2) / (denominator + eps))
return diff2.item()