forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
101
torch_mlu_ops-v1.3.2/tests/kernels_pytest/CMakeLists.txt
Normal file
101
torch_mlu_ops-v1.3.2/tests/kernels_pytest/CMakeLists.txt
Normal file
@@ -0,0 +1,101 @@
|
||||
cmake_minimum_required(VERSION 3.10)
set(CMAKE_C_COMPILER "gcc")
set(CMAKE_CXX_COMPILER "g++")

project(kernel_test)
message(STATUS "project name: ${PROJECT_NAME}")

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/lib")
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/archive")
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "$ENV{NEUWARE_HOME}/cmake/modules")

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# -std=c++17 removed: CMAKE_CXX_STANDARD already injects it, and duplicating it
# in the flags can mask the standard-mode settings above.
add_compile_options(-O3 -g -fPIC -Wall -Werror -Wextra -Wno-unused-parameter
                    -Wno-unused-variable -Wno-unused-function -Wno-unknown-pragmas)

link_directories($ENV{NEUWARE_HOME}/lib64)
link_libraries(m stdc++ dl pthread cnnl cnnl_extra cnrt cndrv
               "-Wl,-rpath,$ENV{NEUWARE_HOME}/lib64 -Wl,--disable-new-dtags")
include_directories($ENV{NEUWARE_HOME}/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../csrc/
                    ${CMAKE_CURRENT_SOURCE_DIR}/../../csrc/kernels)

# Query the python interpreter for the torch install path and its cxx11 ABI
# flag. On success exports Torch_DIR and TORCH_CXX11_ABI to the parent scope.
function(find_torch)
  execute_process(
    COMMAND python3 -c "import torch; print(torch.__path__[0], torch.compiled_with_cxx11_abi(), sep=';')"
    RESULT_VARIABLE TORCH_NOT_FOUND
    OUTPUT_VARIABLE TORCH_INFO
    OUTPUT_STRIP_TRAILING_WHITESPACE
  )

  if(TORCH_NOT_FOUND)
    return()
  endif()

  list(GET TORCH_INFO 0 TORCH_PATH)
  message(STATUS "torch path: ${TORCH_PATH}")

  list(GET TORCH_INFO 1 TORCH_CXX11_ABI)
  message(STATUS "torch cxx11 abi: ${TORCH_CXX11_ABI}")

  set(Torch_DIR ${TORCH_PATH}/share/cmake/Torch PARENT_SCOPE)
  # BUGFIX: TORCH_CXX11_ABI was previously function-local, so the ABI check
  # below always saw it empty and the build was gated only on USE_CXX11_ABI.
  set(TORCH_CXX11_ABI ${TORCH_CXX11_ABI} PARENT_SCOPE)
endfunction()

# import pytorch
find_torch()
message(STATUS "Torch_DIR: ${Torch_DIR}")
find_package(Torch QUIET)
# BUGFIX: keyword is PATHS; the old "PATH" was only accepted by accident via
# the short-form signature of find_library.
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")

# import torch_mlu
execute_process(
  COMMAND python3 -c "import torch_mlu.utils as mlu_utils;print(mlu_utils.cmake_prefix_path)"
  RESULT_VARIABLE TORCH_MLU_NOT_FOUND
  OUTPUT_VARIABLE Torch_MLU_MODULE_DIR
  OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(TORCH_MLU_NOT_FOUND)
  message(FATAL_ERROR "torch_mlu not installed, can not find TorchMLU cmake modules.")
endif()

# find ops library
execute_process(
  COMMAND python3 -c "import torch_mlu_ops as ops;print(ops._utils.get_custom_op_library_path())"
  RESULT_VARIABLE LIBOPS_NOT_FOUND
  OUTPUT_VARIABLE LIBOPS_PATH
  OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(LIBOPS_NOT_FOUND)
  message(FATAL_ERROR "torch_mlu_ops not installed, can not find ops library.")
endif()

set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH} ${Torch_MLU_MODULE_DIR})
find_package(TorchMLU QUIET)
# torch_mlu throw [-Werror=sign-compare] error, so ignore it
add_compile_options(-Wno-sign-compare)
# TorchMLUConfig.cmake run will get TORCH_ATEN_LIBRARY-NOTFOUND, it will cause compile fail, remove it
string(REPLACE "TORCH_ATEN_LIBRARY-NOTFOUND" "" TORCH_MLU_LIBRARIES_MODIFIED "${TORCH_MLU_LIBRARIES}")

execute_process(
  COMMAND python3 -c "from distutils import sysconfig; print(sysconfig.get_python_inc())"
  OUTPUT_VARIABLE PYTHON_INCLUDE_DIR
  OUTPUT_STRIP_TRAILING_WHITESPACE
)

# Build only when torch is present and its C++ ABI matches the requested one.
if(Torch_FOUND AND ((TORCH_CXX11_ABI AND USE_CXX11_ABI) OR (NOT TORCH_CXX11_ABI AND NOT USE_CXX11_ABI)))
  include_directories(${PYTHON_INCLUDE_DIR} ${TORCH_INCLUDE_DIRS} ${TORCH_MLU_INCLUDE_DIRS})
  link_libraries(${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY} ${TORCH_MLU_LIBRARIES_MODIFIED})
  # BUGFIX: the stray "RECURSE" token was being treated as a second glob
  # pattern (matching nothing); GLOB_RECURSE is already recursive.
  file(GLOB_RECURSE TEST_SRCS "src/*.cpp")
  # Python extension suffix (e.g. ".cpython-310-x86_64-linux-gnu.so") so the
  # built module is importable as `btunittests`.
  execute_process(
    COMMAND python3 -c "import sysconfig;print(sysconfig.get_config_var('EXT_SUFFIX'))"
    OUTPUT_VARIABLE TEST_LIB_NAME_SUFFIX
    OUTPUT_STRIP_TRAILING_WHITESPACE
  )
  set(CMAKE_SHARED_LIBRARY_PREFIX "")
  set(CMAKE_SHARED_LIBRARY_SUFFIX "")
  set(TEST_LIB_NAME_PREFIX "btunittests")
  string(APPEND TEST_LIB_NAME "${TEST_LIB_NAME_PREFIX}${TEST_LIB_NAME_SUFFIX}")
  add_library(${TEST_LIB_NAME} SHARED ${TEST_SRCS})
  target_link_libraries(${TEST_LIB_NAME} ${LIBOPS_PATH} -lstdc++fs)
else()
  message(STATUS "Torch not found, or torch abi is different with which you specified, will not build")
  message(STATUS "if torch not found, please set env Torch_DIR to the directory containing TorchConfig.cmake")
endif()
|
||||
17
torch_mlu_ops-v1.3.2/tests/kernels_pytest/README.md
Normal file
17
torch_mlu_ops-v1.3.2/tests/kernels_pytest/README.md
Normal file
@@ -0,0 +1,17 @@
|
||||
## BT_OPS测试脚本使用方式
|
||||
|
||||
```bash
|
||||
# 测试所有测例
|
||||
bash run_test.sh
|
||||
```
|
||||
|
||||
```bash
|
||||
# 测试单个测例
|
||||
python3 test_测例名称.py
|
||||
```
|
||||
|
||||
- 必须在Torch-MLU-Ops docker容器内运行。
|
||||
|
||||
- 测试脚本的命名规则为 `test_测例名称.py`。
|
||||
|
||||
- 必须保证 Torch-MLU-Ops whl包正确安装。
|
||||
10
torch_mlu_ops-v1.3.2/tests/kernels_pytest/build.sh
Executable file
10
torch_mlu_ops-v1.3.2/tests/kernels_pytest/build.sh
Executable file
@@ -0,0 +1,10 @@
|
||||
#!/bin/bash
# Configure and build the kernel pytest extension in ./build.
# Any extra arguments are forwarded to the cmake configure step.
set -e

TOP_DIR="$( cd "$( dirname "$0" )" && pwd )"
cd "${TOP_DIR}"

# Always start from a clean build tree.
rm -rf build > /dev/null 2>&1
# BUGFIX: quote "$TOP_DIR" and "$@" so paths/arguments containing spaces
# survive word splitting.
cmake "$TOP_DIR" -Bbuild "$@"
cmake --build build -- -j32
|
||||
22
torch_mlu_ops-v1.3.2/tests/kernels_pytest/run_test.sh
Executable file
22
torch_mlu_ops-v1.3.2/tests/kernels_pytest/run_test.sh
Executable file
@@ -0,0 +1,22 @@
|
||||
#!/bin/bash
|
||||
|
||||
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
|
||||
tmo_kernel_case=$(find "${SCRIPT_DIR}" -name "test_*.py")
|
||||
coverage=${1}
|
||||
|
||||
for sc in ${tmo_kernel_case}
|
||||
do
|
||||
echo -n "${sc} "
|
||||
echo -n "Testing...."
|
||||
if [ "${coverage}" = "coverage" ];then
|
||||
coverage run -a ${sc}
|
||||
else
|
||||
python3 "${sc}" > "/tmp/$(basename ${sc}).log" 2>&1
|
||||
fi
|
||||
if [ $? == 0 ];then
|
||||
echo -e "\033[32m success \033[0m"
|
||||
else
|
||||
echo -e "\033[31m failed \033[0m"
|
||||
fi
|
||||
done
|
||||
echo "End of pytest..."
|
||||
@@ -0,0 +1,54 @@
|
||||
/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/generate_alibi_slope.mluh"
#include "register_api.h"
#include "utils.h"

namespace tmo {
namespace kernel_test_api {
// Test driver for the ALiBi-slope kernel.
//
// Runs invokeGenerateAlibiSlope once per tensor-parallel shard (tp_num
// launches of tp_head_num heads each) and gathers the per-shard results
// into a single float32 (batch, tp_head_num * tp_num) tensor.
//
// Args:
//   true_seq_lens       - per-batch sequence lengths, on the MLU device
//                         (checked by MLU_TENSOR_CHECK_FATAL).
//   batch               - batch size.
//   tp_head_num         - number of heads per tensor-parallel shard.
//   tp_num              - number of tensor-parallel shards.
//   use_dynamic         - forwarded to the kernel; presumably selects
//                         dynamic slope computation — confirm in the kernel.
//   max_sequence_length - forwarded to the kernel.
// Returns: (batch, head_num) float32 slope tensor on the input's device.
at::Tensor alibi_slope_test(const at::Tensor &true_seq_lens,
                            const int batch,
                            const int tp_head_num,
                            const int tp_num,
                            const bool use_dynamic,
                            const int max_sequence_length) {
  MLU_TENSOR_CHECK_FATAL(true_seq_lens);
  // All kernel launches and copies go on the queue backing the current
  // torch_mlu CNNL handle.
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);

  int head_num = tp_head_num * tp_num;
  std::shared_ptr<torch::Device> dev = std::make_shared<torch::Device>(true_seq_lens.device());
  torch::Tensor output = torch::zeros(
      {batch, head_num}, torch::dtype(torch::kFloat32).device(*dev).requires_grad(false));

  // Emulate tensor parallelism: compute each shard's (batch, tp_head_num)
  // slopes separately, then scatter them into the shard's column range of
  // `output` with per-row device-to-device copies.
  for (int tp_id = 0; tp_id < tp_num; tp_id++) {
    torch::Tensor tp_slopes = torch::zeros(
        {batch, tp_head_num}, torch::dtype(torch::kFloat32).device(*dev).requires_grad(false));
    TMO_KERNEL_CHECK_FATAL(invokeGenerateAlibiSlope(
        queue, tp_slopes.data_ptr(), true_seq_lens.data_ptr(), batch, tp_id * tp_head_num,
        tp_head_num, head_num, max_sequence_length, use_dynamic));
    // Async copies are enqueued on the same queue as the kernel; no explicit
    // sync is done here — the caller is expected to synchronize before reading.
    for (int batch_id = 0; batch_id < batch; batch_id++) {
      CNRT_CHECK(
          cnrtMemcpyAsync((float *)output.data_ptr() + batch_id * head_num + tp_id * tp_head_num,
                          (float *)tp_slopes.data_ptr() + batch_id * tp_head_num,
                          tp_head_num * sizeof(float), queue, cnrtMemcpyDevToDev));
    }
  }

  return output;
}

}  // namespace kernel_test_api
}  // namespace tmo
|
||||
@@ -0,0 +1,83 @@
|
||||
/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "kernels/create_cos_sin_table.mluh"

#include <cnnl.h>
#include <torch/extension.h>

#include "register_api.h"
#include "utils.h"

namespace {
// Value of rotary_scaling_type that selects dynamic NTK scaling, in which
// case the table gains a leading batch dimension.
constexpr int DYNAMIC_NTK_SCALING = 2;
}
namespace tmo {
namespace kernel_test_api {
// Test driver for the cos/sin rotary-table kernel (invokeCreateCosSinTable).
//
// Allocates the output table — (batch, rotary_seq_len, rotary_stride) for
// dynamic NTK scaling, (rotary_seq_len, rotary_stride) otherwise — runs the
// kernel once, and reports the measured kernel time via print_info.
// The output dtype follows rotary_emb_alpha_cached's dtype; all remaining
// scalar parameters are forwarded to the kernel unchanged.
at::Tensor create_cos_sin_table_test(at::Tensor &rotary_emb_alpha_cached,
                                     const at::Tensor &seq_lens,
                                     const int max_position_embeddings,
                                     const int batch_stride,
                                     const int rotary_seq_len,
                                     const int rotary_dim,
                                     const int rotary_stride,
                                     const float rotary_scaling,
                                     const int rotary_scaling_type,
                                     const float rotary_base,
                                     const bool interleaved) {
  MLU_TENSOR_CHECK_FATAL(rotary_emb_alpha_cached);
  MLU_TENSOR_CHECK_FATAL(seq_lens);

  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  cnnlDataType_t dtype =
      torch2cnnlDataType(torch::typeMetaToScalarType(rotary_emb_alpha_cached.dtype()));

  std::shared_ptr<torch::Device> dev =
      std::make_shared<torch::Device>(rotary_emb_alpha_cached.device());
  // Batch size is inferred from seq_lens; assumes seq_lens is 1-D (batch,).
  int batch = seq_lens.size(0);

  // Dynamic NTK scaling produces a per-batch table; otherwise one shared
  // table is enough.
  torch::Tensor cos_sin_table =
      rotary_scaling_type == DYNAMIC_NTK_SCALING
          ? torch::zeros({batch, rotary_seq_len, rotary_stride},
                         torch::dtype(torch::typeMetaToScalarType(rotary_emb_alpha_cached.dtype()))
                             .device(*dev)
                             .requires_grad(false))
          : torch::zeros({rotary_seq_len, rotary_stride},
                         torch::dtype(torch::typeMetaToScalarType(rotary_emb_alpha_cached.dtype()))
                             .device(*dev)
                             .requires_grad(false));

  // io_bytes is never accumulated, so print_info reports 0 bytes of traffic.
  size_t io_bytes = 0;

  // Bracket the kernel with notifiers to measure device time.
  cnrtNotifier_t time_begin;
  cnrtNotifier_t time_end;

  CNRT_CHECK(cnrtNotifierCreate(&time_begin));
  CNRT_CHECK(cnrtNotifierCreate(&time_end));

  CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
  TMO_KERNEL_CHECK_FATAL(invokeCreateCosSinTable(
      queue, cos_sin_table.data_ptr(),
      reinterpret_cast<float *>(rotary_emb_alpha_cached.data_ptr()),
      reinterpret_cast<int *>(seq_lens.data_ptr()), max_position_embeddings, batch, batch_stride,
      rotary_seq_len, rotary_dim, rotary_stride, rotary_base, rotary_scaling, rotary_scaling_type,
      interleaved, dtype));
  CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
  cnrtQueueSync(queue);
  float usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
  // NOTE(review): notifiers are never destroyed (cnrtNotifierDestroy), a
  // small per-call resource leak — acceptable for a test, worth confirming.
  print_info(usec, io_bytes);

  return cos_sin_table;
}

}  // namespace kernel_test_api
}  // namespace tmo
|
||||
@@ -0,0 +1,48 @@
|
||||
/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/embedding.mluh"
#include "register_api.h"
#include "utils.h"

namespace tmo {
namespace kernel_test_api {
// Test driver for the embedding-lookup kernel (invokeEmbedding).
//
// Args:
//   input        - token ids; indexed as (batch, seq) here, so assumed 2-D.
//   weight       - embedding table, (total_vocab_size, hidden_size); its
//                  dtype determines the kernel dtype and the output dtype.
//   vocab_offset - forwarded to the kernel; presumably the first vocab id
//                  owned by this partition — confirm in the kernel.
//   vocab_part   - forwarded to the kernel; presumably the partition's vocab
//                  size — confirm in the kernel.
// Returns: (batch, seq, hidden_size) tensor on input's device, zero-filled
// before the kernel writes it.
at::Tensor embedding_test(const at::Tensor &input,
                          const at::Tensor &weight,
                          const int vocab_offset,
                          const int vocab_part) {
  MLU_TENSOR_CHECK_FATAL(input);
  MLU_TENSOR_CHECK_FATAL(weight);

  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  std::shared_ptr<torch::Device> dev = std::make_shared<torch::Device>(input.device());
  int batch = input.size(0);
  int seq = input.size(1);
  int total_vocab_size = weight.size(0);
  int hidden_size = weight.size(1);
  cnnlDataType_t dtype = torch2cnnlDataType(torch::typeMetaToScalarType(weight.dtype()));
  torch::Tensor output = torch::zeros(
      {batch, seq, hidden_size},
      torch::dtype(torch::typeMetaToScalarType(weight.dtype())).device(*dev).requires_grad(false));

  // Kernel sees the tokens flattened: batch * seq lookups of hidden_size each.
  TMO_KERNEL_CHECK_FATAL(invokeEmbedding(queue, weight.data_ptr(), input.data_ptr(),
                                         output.data_ptr(), dtype, vocab_offset, vocab_part,
                                         total_vocab_size, hidden_size, batch * seq));

  // No queue sync here; the caller is expected to synchronize before reading.
  return output;
}

}  // namespace kernel_test_api
}  // namespace tmo
|
||||
@@ -0,0 +1,80 @@
|
||||
/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/operate_cu_seq_lens.mluh"
#include "register_api.h"
#include "utils.h"

namespace tmo {
namespace kernel_test_api {
// Test driver for invokeSliceCuSeqlens.
//
// Writes the sliced cumulative-sequence-length array into
// sliced_cu_seq_lens (in/out, int32 on device); cu_seq_lens is the source.
// batch and parallel_num are forwarded to the kernel unchanged. Kernel time
// is measured with notifiers and reported via print_info (io_bytes stays 0,
// so no bandwidth figure is produced). Synchronizes the queue before return.
void slice_cu_seq_lens_test(const at::Tensor &cu_seq_lens,
                            at::Tensor &sliced_cu_seq_lens,
                            int batch,
                            int parallel_num) {
  MLU_TENSOR_CHECK_FATAL(cu_seq_lens);
  MLU_TENSOR_CHECK_FATAL(sliced_cu_seq_lens);

  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);

  size_t io_bytes = 0;

  // Timing bracket around the single kernel launch.
  cnrtNotifier_t time_begin;
  cnrtNotifier_t time_end;

  CNRT_CHECK(cnrtNotifierCreate(&time_begin));
  CNRT_CHECK(cnrtNotifierCreate(&time_end));

  CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
  TMO_KERNEL_CHECK_FATAL(invokeSliceCuSeqlens(queue, (int *)cu_seq_lens.data_ptr(),
                                              (int *)sliced_cu_seq_lens.data_ptr(), batch,
                                              parallel_num));
  CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
  cnrtQueueSync(queue);
  float usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
  // NOTE(review): notifiers are never destroyed — minor per-call leak.
  print_info(usec, io_bytes);
}

// Test driver for invokeGenerateCuSeqlens.
//
// Fills gen_cu_seq_lens (in/out, int32 on device) from the scalar
// parameters; seq_len, parallel_num, is_causal_mask and is_kv_seq_len are
// forwarded to the kernel unchanged. Same timing/reporting pattern as
// slice_cu_seq_lens_test above.
void generate_cu_seq_lens_test(at::Tensor &gen_cu_seq_lens,
                               int seq_len,
                               int parallel_num,
                               bool is_causal_mask,
                               bool is_kv_seq_len) {
  MLU_TENSOR_CHECK_FATAL(gen_cu_seq_lens);

  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);

  size_t io_bytes = 0;

  cnrtNotifier_t time_begin;
  cnrtNotifier_t time_end;

  CNRT_CHECK(cnrtNotifierCreate(&time_begin));
  CNRT_CHECK(cnrtNotifierCreate(&time_end));

  CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
  TMO_KERNEL_CHECK_FATAL(invokeGenerateCuSeqlens(queue, (int *)gen_cu_seq_lens.data_ptr(), seq_len,
                                                 parallel_num, is_causal_mask, is_kv_seq_len));
  CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
  cnrtQueueSync(queue);
  float usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
  print_info(usec, io_bytes);
}

}  // namespace kernel_test_api
}  // namespace tmo
|
||||
@@ -0,0 +1,67 @@
|
||||
/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/quantize.mluh"
#include "register_api.h"
#include "utils.h"

namespace tmo {
namespace kernel_test_api {
// Test driver for the per-head quantization kernel (invokeMluQuantizePerHead).
//
// Quantizes src into dst, writing per-head scales into scale; dtypes are
// taken from the tensors themselves and translated to CNNL types. The
// bs/seq_len/head_num/head_size extents and the six stride parameters
// describe the (possibly non-contiguous) dst/src layouts and are forwarded
// to the kernel unchanged. Kernel time is measured with notifiers and
// reported via print_info together with the estimated IO volume.
// Synchronizes the queue before return.
void per_head_quantize_test(at::Tensor &dst,
                            at::Tensor &scale,
                            const at::Tensor &src,
                            int bs,
                            int seq_len,
                            int head_num,
                            int head_size,
                            int dst_bs_stride,
                            int dst_seq_stride,
                            int dst_head_stride,
                            int src_bs_stride,
                            int src_seq_stride,
                            int src_head_stride) {
  MLU_TENSOR_CHECK_FATAL(dst);
  MLU_TENSOR_CHECK_FATAL(scale);
  MLU_TENSOR_CHECK_FATAL(src);

  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);

  cnnlDataType_t dst_dtype = torch2cnnlDataType(torch::typeMetaToScalarType(dst.dtype()));
  cnnlDataType_t scale_dtype = torch2cnnlDataType(torch::typeMetaToScalarType(scale.dtype()));
  cnnlDataType_t src_dtype = torch2cnnlDataType(torch::typeMetaToScalarType(src.dtype()));

  // Bytes read (src) + bytes written (dst); scale traffic is ignored here.
  size_t io_bytes =
      bs * seq_len * head_num * head_size * (dtype_size_map[src_dtype] + dtype_size_map[dst_dtype]);

  // Timing bracket around the single kernel launch.
  cnrtNotifier_t time_begin;
  cnrtNotifier_t time_end;

  CNRT_CHECK(cnrtNotifierCreate(&time_begin));
  CNRT_CHECK(cnrtNotifierCreate(&time_end));

  CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
  TMO_KERNEL_CHECK_FATAL(invokeMluQuantizePerHead(
      queue, (void *)dst.data_ptr(), (void *)scale.data_ptr(), (void *)src.data_ptr(), dst_dtype,
      scale_dtype, src_dtype, bs, seq_len, head_num, head_size, dst_bs_stride, dst_seq_stride,
      dst_head_stride, src_bs_stride, src_seq_stride, src_head_stride));
  CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
  cnrtQueueSync(queue);
  float usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
  // NOTE(review): notifiers are never destroyed — minor per-call leak.
  print_info(usec, io_bytes);
}

}  // namespace kernel_test_api
}  // namespace tmo
|
||||
82
torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/register_api.h
Normal file
82
torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/register_api.h
Normal file
@@ -0,0 +1,82 @@
|
||||
/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#ifndef TEST_KERNELS_PYTEST_REGISTER_API_H_
#define TEST_KERNELS_PYTEST_REGISTER_API_H_
#include <torch/extension.h>
#include <torch/torch.h>

// Declarations of every kernel test entry point exported to python through
// the btunittests pybind module. Each is defined in its own .cpp file in
// this directory.
namespace tmo {
namespace kernel_test_api {
// Embedding-lookup test: returns a (batch, seq, hidden_size) tensor.
at::Tensor embedding_test(const at::Tensor &input,
                          const at::Tensor &weight,
                          const int vocab_offset,
                          const int vocab_part);

// ALiBi-slope test: returns a (batch, tp_head_num * tp_num) float32 tensor.
at::Tensor alibi_slope_test(const at::Tensor &true_seq_lens,
                            const int batch,
                            const int tp_head_num,
                            const int tp_num,
                            const bool use_dynamic,
                            const int max_sequence_length);

// Rotary cos/sin table test: returns the generated table; the table gains a
// leading batch dimension when rotary_scaling_type selects dynamic NTK.
at::Tensor create_cos_sin_table_test(at::Tensor &rotary_emb_alpha_cached,
                                     const at::Tensor &seq_lens,
                                     const int max_position_embeddings,
                                     const int batch_stride,
                                     const int rotary_seq_len,
                                     const int rotary_dim,
                                     const int rotary_stride,
                                     const float rotary_scaling,
                                     const int rotary_scaling_type,
                                     const float rotary_base,
                                     const bool interleaved);

// Cumulative-sequence-length slicing test; writes into sliced_cu_seq_lens.
void slice_cu_seq_lens_test(const at::Tensor &cu_seq_lens,
                            at::Tensor &sliced_cu_seq_lens,
                            int batch,
                            int parallel_num);

// Cumulative-sequence-length generation test; writes into gen_cu_seq_lens.
void generate_cu_seq_lens_test(at::Tensor &gen_cu_seq_lens,
                               int seq_len,
                               int parallel_num,
                               bool is_causal_mask,
                               bool is_kv_seq_len);

// Per-head quantization test; writes quantized data into dst and per-head
// scales into scale. The stride parameters describe dst/src layouts.
void per_head_quantize_test(at::Tensor &dst,
                            at::Tensor &scale,
                            const at::Tensor &src,
                            int bs,
                            int seq_len,
                            int head_num,
                            int head_size,
                            int dst_bs_stride,
                            int dst_seq_stride,
                            int dst_head_stride,
                            int src_bs_stride,
                            int src_seq_stride,
                            int src_head_stride);

// Rotary-embedding test; input may be packed (3-D plus cu_seq_lens) or
// padded (4-D, cu_seq_lens must be empty).
at::Tensor rotary_embedding_test(const at::Tensor &input,
                                 const at::Tensor &sin_table,
                                 const at::Tensor &cos_table,
                                 const c10::optional<at::Tensor> &seq_offsets,
                                 const c10::optional<at::Tensor> &cu_seq_lens,
                                 const bool &interleaved,
                                 const bool &discrete,
                                 const bool &dynamic_ntk,
                                 const bool &rope_2d,
                                 const int &max_seq_len,
                                 const bool &no_offset);

}  // namespace kernel_test_api
}  // namespace tmo
#endif  // TEST_KERNELS_PYTEST_REGISTER_API_H_
|
||||
@@ -0,0 +1,29 @@
|
||||
/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include <torch/extension.h>
#include "register_api.h"

// Python module `btunittests`: binds every kernel test entry point declared
// in register_api.h. The module name must match the library name produced by
// CMakeLists.txt, otherwise the import fails at runtime.
PYBIND11_MODULE(btunittests, m) {
  m.def("embedding_test", &tmo::kernel_test_api::embedding_test, "Test embedding kernel function.");
  m.def("alibi_slope_test", &tmo::kernel_test_api::alibi_slope_test,
        "Test alibi_slope kernel function.");
  m.def("create_cos_sin_table_test", &tmo::kernel_test_api::create_cos_sin_table_test,
        "Test create_cos_sin_table_test kernel function.");
  m.def("slice_cu_seq_lens_test", &tmo::kernel_test_api::slice_cu_seq_lens_test,
        "Test slice_cu_seq_lens_test kernel function.");
  m.def("generate_cu_seq_lens_test", &tmo::kernel_test_api::generate_cu_seq_lens_test,
        "Test generate_cu_seq_lens_test kernel function.");
  m.def("per_head_quantize_test", &tmo::kernel_test_api::per_head_quantize_test,
        "Test per_head_quantize_test kernel function.");
  m.def("rotary_embedding_test", &tmo::kernel_test_api::rotary_embedding_test,
        "Test rotary_embedding kernel function.");
}
|
||||
@@ -0,0 +1,202 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include <cnnl.h>
|
||||
#include <torch/extension.h>
|
||||
#include "kernels/rotary_embedding.mluh"
|
||||
#include "register_api.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace kernel_test_api {
|
||||
at::Tensor rotary_embedding_test(const at::Tensor &input,
|
||||
const at::Tensor &sin_table,
|
||||
const at::Tensor &cos_table,
|
||||
const c10::optional<at::Tensor> &seq_offsets,
|
||||
const c10::optional<at::Tensor> &cu_seq_lens,
|
||||
const bool &interleaved,
|
||||
const bool &discrete,
|
||||
const bool &dynamic_ntk,
|
||||
const bool &rope_2d,
|
||||
const int &max_seq_len,
|
||||
const bool &no_offset) {
|
||||
MLU_TENSOR_CHECK_FATAL(input);
|
||||
MLU_TENSOR_CHECK_FATAL(sin_table);
|
||||
MLU_TENSOR_CHECK_FATAL(cos_table);
|
||||
|
||||
bool has_seq_offsets = seq_offsets.has_value();
|
||||
bool has_cu_seq_lens = cu_seq_lens.has_value();
|
||||
int *seq_offsets_ptr =
|
||||
has_seq_offsets ? reinterpret_cast<int *>(seq_offsets.value().data_ptr()) : nullptr;
|
||||
int *cu_seq_lens_ptr =
|
||||
has_cu_seq_lens ? reinterpret_cast<int *>(cu_seq_lens.value().data_ptr()) : nullptr;
|
||||
|
||||
int batch = 0;
|
||||
int total_seq_len = 0;
|
||||
int head_size = input.size(-1);
|
||||
|
||||
if (input.dim() == 3) { // pack mode
|
||||
TORCH_CHECK(has_cu_seq_lens,
|
||||
"input has 3 dims: (total_seq_len, head_num, head_size),"
|
||||
" which means pack mode, cu_seqlens should not be None");
|
||||
total_seq_len = input.size(0);
|
||||
batch = cu_seq_lens.value().size(0) - 1;
|
||||
} else if (input.dim() == 4) {
|
||||
TORCH_CHECK(!has_cu_seq_lens,
|
||||
"input has 4 dims: (batch_size, seq_len, head_num, head_size),"
|
||||
" which means pad mode, cu_seqlens should be None");
|
||||
TORCH_CHECK(max_seq_len == input.size(1),
|
||||
"input has 4 dims: (batch_size, seq_len, head_num, head_size),"
|
||||
" which means pad mode, max_seqlen must be equtals to input.size(1)");
|
||||
batch = input.size(0);
|
||||
total_seq_len = batch * input.size(1);
|
||||
} else {
|
||||
TORCH_CHECK(false, "input only support 3 or 4 dims");
|
||||
}
|
||||
|
||||
const int rope_seqlen = dynamic_ntk ? sin_table.size(1) : sin_table.size(0);
|
||||
const int rope_dim = dynamic_ntk ? sin_table.size(2) : sin_table.size(1);
|
||||
|
||||
#if 0
|
||||
if (has_seq_offsets) {
|
||||
if (discrete) {
|
||||
CHECK_SHAPE(seq_offsets.value(), total_seq_len);
|
||||
} else {
|
||||
CHECK_SHAPE(seq_offsets.value(), batch);
|
||||
}
|
||||
}
|
||||
if (dynamic_ntk) {
|
||||
CHECK_SHAPE(sin_table, batch, rope_seqlen, rope_dim);
|
||||
CHECK_SHAPE(cos_table, batch, rope_seqlen, rope_dim);
|
||||
} else {
|
||||
CHECK_SHAPE(sin_table, rope_seqlen, rope_dim);
|
||||
CHECK_SHAPE(cos_table, rope_seqlen, rope_dim);
|
||||
}
|
||||
#endif
|
||||
|
||||
TORCH_CHECK(input.stride(-1) == 1, "input last dim must be contiguous");
|
||||
if (dynamic_ntk) {
|
||||
TORCH_CHECK(sin_table.stride(1) == cos_table.stride(1),
|
||||
"sin_table second stride must be equal to cos_table second stride");
|
||||
} else {
|
||||
TORCH_CHECK(sin_table.stride(0) == cos_table.stride(0),
|
||||
"sin_table first stride must be equal to cos_table second stride");
|
||||
}
|
||||
|
||||
if (has_seq_offsets) {
|
||||
TORCH_CHECK(seq_offsets.value().is_contiguous(), "seq_offsets must be contiguous");
|
||||
}
|
||||
|
||||
if (has_cu_seq_lens) {
|
||||
TORCH_CHECK(cu_seq_lens.value().is_contiguous(), "cu_seq_lens must be contiguous");
|
||||
}
|
||||
|
||||
// auto output = input;
|
||||
// if (is_qkv) {
|
||||
// }
|
||||
|
||||
int dims = input.dim();
|
||||
int head_num = input.size(dims - 2); // head_num_qk
|
||||
TORCH_CHECK(head_size <= 256, "only support input head_size <= 256");
|
||||
int head_dim = input.size(dims - 1);
|
||||
int rotary_seq_len = dynamic_ntk ? sin_table.size(1) : sin_table.size(0);
|
||||
int rotary_dim = dynamic_ntk ? sin_table.size(2) : sin_table.size(1);
|
||||
|
||||
#if 0
|
||||
if (rope_2d) {
|
||||
total_seq_len = input.size(0);
|
||||
rotary_seq_len *= 2;
|
||||
}
|
||||
#endif
|
||||
int rotary_stride = dynamic_ntk ? sin_table.stride(1) : sin_table.stride(0);
|
||||
|
||||
size_t input_seq_stride = input.stride(dims - 3);
|
||||
size_t input_head_stride = input.stride(dims - 2);
|
||||
|
||||
cnnlHandle_t handle = torch_mlu::getCurrentHandle();
|
||||
cnrtQueue_t queue;
|
||||
cnnlGetQueue(handle, &queue);
|
||||
|
||||
#if 0
|
||||
std::shared_ptr<torch::Device> dev = std::make_shared<torch::Device>(input.device());
|
||||
|
||||
std::cout
|
||||
<< " batch * max_seq_len : " << batch * max_seq_len
|
||||
<< " max_seq_len : " << max_seq_len
|
||||
<< " total_seq_len : " << total_seq_len
|
||||
<< " head_num : " << head_num
|
||||
<< " head_size : " << head_size << "\n";
|
||||
|
||||
// at::IntArrayRef output_shape = {batch, max_seq_len, head_num, head_size};
|
||||
// if (input.dim() == 3) {
|
||||
// output_shape = {batch * max_seq_len, head_num, head_size};
|
||||
// }
|
||||
#endif
|
||||
|
||||
cnnlDataType_t dtype = torch2cnnlDataType(torch::typeMetaToScalarType(input.dtype()));
|
||||
auto output = input;
|
||||
|
||||
size_t output_seq_stride = output.stride(dims - 3);
|
||||
size_t output_head_stride = output.stride(dims - 2);
|
||||
|
||||
#if 0
|
||||
std::cout << " <<<< batch: " << batch << "\n";
|
||||
std::cout << " <<<< max_seq_len: " << max_seq_len << "\n";
|
||||
std::cout << " <<<< total_seq_len: " << total_seq_len << "\n";
|
||||
std::cout << " <<<< head_num: " << head_num << "\n";
|
||||
std::cout << " <<<< head_size: " << head_size << "\n";
|
||||
std::cout << " <<<< rotary_seq_len_: " << rotary_seq_len << "\n";
|
||||
std::cout << " <<<< rotary_dim: " << rotary_dim << "\n";
|
||||
std::cout << " <<<< rotary_stride: " << rotary_stride << "\n";
|
||||
std::cout << " <<<< input_seq_stride: " << input_seq_stride << "\n";
|
||||
std::cout << " <<<< input_head_stride: " << input_head_stride << "\n";
|
||||
std::cout << " <<<< output_seq_stride: " << output_seq_stride << "\n";
|
||||
std::cout << " <<<< output_head_stride: " << output_head_stride << "\n";
|
||||
std::cout << " <<<< interleaved: " << interleaved << "\n";
|
||||
std::cout << " <<<< discrete: " << discrete << "\n";
|
||||
std::cout << " <<<< dynamic_ntk: " << dynamic_ntk << "\n";
|
||||
#endif
|
||||
|
||||
size_t io_bytes =
|
||||
(size_t)total_seq_len * head_num * (head_size + rotary_dim) * dtype_size_map[dtype] * 2;
|
||||
io_bytes += (discrete && !no_offset) ? total_seq_len * sizeof(int) : batch * sizeof(int);
|
||||
|
||||
cnrtNotifier_t time_begin;
|
||||
cnrtNotifier_t time_end;
|
||||
|
||||
CNRT_CHECK(cnrtNotifierCreate(&time_begin));
|
||||
CNRT_CHECK(cnrtNotifierCreate(&time_end));
|
||||
|
||||
CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
|
||||
if (rope_2d) {
|
||||
TMO_KERNEL_CHECK_FATAL(invokeGlm6BRotaryEmbedding(
|
||||
queue, output.data_ptr(), input.data_ptr(), sin_table.data_ptr(), cos_table.data_ptr(),
|
||||
seq_offsets_ptr, cu_seq_lens_ptr, batch, max_seq_len, total_seq_len, head_num, head_size,
|
||||
rotary_seq_len, rotary_stride, input_seq_stride, input_head_stride, output_seq_stride,
|
||||
output_head_stride, interleaved, dtype));
|
||||
} else {
|
||||
TMO_KERNEL_CHECK_FATAL(invokeRotaryEmbedding(
|
||||
queue, output.data_ptr(), input.data_ptr(), sin_table.data_ptr(), cos_table.data_ptr(),
|
||||
seq_offsets_ptr, cu_seq_lens_ptr, batch, max_seq_len, head_num, head_size, rotary_seq_len,
|
||||
rotary_dim, rotary_stride, input_seq_stride, input_head_stride, output_seq_stride,
|
||||
output_head_stride, interleaved, discrete, dynamic_ntk, dtype));
|
||||
}
|
||||
CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
|
||||
cnrtQueueSync(queue);
|
||||
float usec = 0.0f;
|
||||
CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
|
||||
print_info(usec, io_bytes);
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace kernel_test_api
|
||||
} // namespace tmo
|
||||
133
torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/utils.h
Normal file
133
torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/utils.h
Normal file
@@ -0,0 +1,133 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef TEST_KERNELS_PYTEST_UTILS_H_
|
||||
#define TEST_KERNELS_PYTEST_UTILS_H_
|
||||
#include <cnrt.h>
|
||||
#include <torch/extension.h>
|
||||
#include <torch/torch.h>
|
||||
#include "aten/cnnl/cnnlHandle.h"
|
||||
#include "common/utils.h"
|
||||
#include "framework/core/MLUStream.h"
|
||||
#include "framework/core/caching_allocator.h"
|
||||
#include "framework/core/device.h"
|
||||
#include "framework/core/mlu_guard.h"
|
||||
#include "kernels/kernel_utils.h"
|
||||
|
||||
namespace tmo {
|
||||
// Map a PyTorch scalar type onto its CNNL counterpart.
// Throws std::runtime_error for any scalar type CNNL has no equivalent for.
static cnnlDataType_t torch2cnnlDataType(torch::Dtype dtype) {
  switch (dtype) {
    // floating-point types
    case torch::kFloat16:
      return CNNL_DTYPE_HALF;
    case torch::kBFloat16:
      return CNNL_DTYPE_BFLOAT16;
    case torch::kFloat32:
      return CNNL_DTYPE_FLOAT;
    case torch::kFloat64:
      return CNNL_DTYPE_DOUBLE;
    // integral / boolean types
    case torch::kInt8:
      return CNNL_DTYPE_INT8;
    case torch::kUInt8:
      return CNNL_DTYPE_UINT8;
    case torch::kInt16:
      return CNNL_DTYPE_INT16;
    case torch::kInt32:
      return CNNL_DTYPE_INT32;
    case torch::kInt64:
      return CNNL_DTYPE_INT64;
    case torch::kBool:
      return CNNL_DTYPE_BOOL;
    default:
      throw std::runtime_error("Unsupported torch::Dtype");
  }
}
|
||||
|
||||
// Element size in bytes for each cnnlDataType_t, indexed by the enum's
// numeric value. Uses C-style designated array initializers, which are a
// GNU extension in C++ (this file is only built with gcc/g++, per the
// project CMake).  Complex types store the size of ONE component pair:
// COMPLEX_HALF = 2 x half = 4 bytes, COMPLEX_FLOAT = 2 x float = 8 bytes.
static constexpr int dtype_size_map[] = {
    [CNNL_DTYPE_INVALID] = 0,
    [CNNL_DTYPE_HALF] = 2,
    [CNNL_DTYPE_FLOAT] = 4,
    [CNNL_DTYPE_INT8] = 1,
    [CNNL_DTYPE_INT16] = 2,
    // NOTE(review): INT31 is presumably CNNL's 31-bit fixed-point type
    // stored in 4 bytes -- confirm against cnnl.h.
    [CNNL_DTYPE_INT31] = 4,
    [CNNL_DTYPE_INT32] = 4,
    [CNNL_DTYPE_UINT8] = 1,
    [CNNL_DTYPE_BOOL] = 1,
    [CNNL_DTYPE_INT64] = 8,
    // NOTE(review): index 10 looks like a reserved/removed enum value with
    // no named constant -- verify against the cnnlDataType_t definition.
    [10] = 0,
    [CNNL_DTYPE_UINT32] = 4,
    [CNNL_DTYPE_UINT64] = 8,
    [CNNL_DTYPE_UINT16] = 2,
    [CNNL_DTYPE_DOUBLE] = 8,
    [CNNL_DTYPE_COMPLEX_HALF] = 4,
    [CNNL_DTYPE_COMPLEX_FLOAT] = 8,
    [CNNL_DTYPE_BFLOAT16] = 2,
};
|
||||
|
||||
// Inverse of torch2cnnlDataType: map a CNNL data type back to PyTorch.
// Throws std::runtime_error when the CNNL type has no PyTorch equivalent.
static torch::Dtype cnnl2torchDataType(cnnlDataType_t dtype) {
  switch (dtype) {
    // floating-point types
    case CNNL_DTYPE_HALF:
      return torch::kFloat16;
    case CNNL_DTYPE_BFLOAT16:
      return torch::kBFloat16;
    case CNNL_DTYPE_FLOAT:
      return torch::kFloat32;
    case CNNL_DTYPE_DOUBLE:
      return torch::kFloat64;
    // integral / boolean types
    case CNNL_DTYPE_INT8:
      return torch::kInt8;
    case CNNL_DTYPE_UINT8:
      return torch::kUInt8;
    case CNNL_DTYPE_INT16:
      return torch::kInt16;
    case CNNL_DTYPE_INT32:
      return torch::kInt32;
    case CNNL_DTYPE_INT64:
      return torch::kInt64;
    case CNNL_DTYPE_BOOL:
      return torch::kBool;
    default:
      throw std::runtime_error("Unsupported cnnlDataType_t");
  }
}
|
||||
|
||||
// Query the current MLU card's peak DDR bandwidth in GB/s via the cndev
// management library.  Aborts the process if cndev cannot be initialized
// or the DDR info query fails (this is test-harness code, so fail-fast).
static float getBandWidth() {
  int card = -1;
  CNRT_CHECK(cnrtGetDevice(&card));
  if (cndevInit(0) != CNDEV_SUCCESS) {
    abort();
  }
  cndevDDRInfo_t ddrinfo;
  // Caller must set the struct version before querying cndev.
  ddrinfo.version = CNDEV_VERSION_5;
  if (cndevGetDDRInfo(&ddrinfo, card) != CNDEV_SUCCESS) {
    abort();
  }
  double band_width = ddrinfo.bandWidth;
  double band_width_decimal = ddrinfo.bandWidthDecimal;
  // NOTE(review): assumes bandWidthDecimal encodes the fractional digits as
  // an integer (e.g. 25 -> .25); dividing by 10 until it drops below 1
  // turns it into the fractional part -- confirm against the cndev docs.
  do {
    band_width_decimal /= 10;
  } while (band_width_decimal > 1);
  return float(band_width + band_width_decimal);
}
|
||||
|
||||
static void print_info(float time_usec, size_t io_bytes) {
|
||||
float io_bandwidth = getBandWidth();
|
||||
std::cout << "kernel time: " << time_usec << "us" << std::endl;
|
||||
std::cout << "io_bandwidth: " << io_bandwidth << "GB/s" << std::endl;
|
||||
std::cout << "IO efficiency: " << io_bytes / (time_usec * 1000 * io_bandwidth) << std::endl;
|
||||
}
|
||||
|
||||
// Fail fast when a tensor does not live on an MLU device (i.e. it is a CPU
// or CUDA tensor).  Wrapped in do { ... } while (0) so the macro expands to
// a single statement and composes safely with if/else in the caller -- the
// original bare `if { }` form would silently capture a following `else`.
#define MLU_TENSOR_CHECK_FATAL(tensor)                                        \
  do {                                                                        \
    if (tensor.device().type() == c10::kCPU or                                \
        tensor.device().type() == c10::kCUDA) {                               \
      throw std::runtime_error("Check failed: " #tensor                       \
                               " is not a MLU tensor.");                      \
    }                                                                         \
  } while (0)
|
||||
|
||||
} // namespace tmo
|
||||
#endif // TEST_KERNELS_PYTEST_UTILS_H_
|
||||
@@ -0,0 +1,75 @@
|
||||
import sys
|
||||
import os
|
||||
from unit_test import UnitTest
|
||||
import pytest
|
||||
import torch
|
||||
import torch_mlu
|
||||
import btunittests
|
||||
|
||||
import math
|
||||
|
||||
def alibi_slopes(batch,
                 num_heads,
                 true_seq_lens,
                 train_seq_len,
                 use_dynamic):
    """Reference ALiBi slopes, optionally rescaled with dynamic NTK.

    Args:
        batch: batch size; slopes are repeated/broadcast per sample.
        num_heads: total number of attention heads.
        true_seq_lens: int tensor [batch, 1] of actual sequence lengths
            (only read when use_dynamic is True).
        train_seq_len: sequence length the model was trained with.
        use_dynamic: apply the dynamic-NTK rescaling of the slopes.

    Returns:
        float tensor of shape [batch, num_heads].
    """
    # Dynamic-NTK coefficient b; stays a plain 1.0 when scaling is off.
    coeff = 1.0
    if use_dynamic:
        # step 1: per-sample NTK scaling factor, clipped below at 1.0
        ratio = 1.0 * true_seq_lens / train_seq_len            # [batch, 1]
        ratio = ratio.masked_fill(ratio < 1.0, 1.0)
        # step 2: coefficient b, kept in this form for computation convenience
        coeff = ratio ** (1.0 / (num_heads - 1))

    pow2_heads = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(pow2_heads) - 3))), dtype=torch.float32
    )
    if use_dynamic:
        base = base / coeff    # step 3: fold b into the alibi base
    exponents = torch.arange(1, 1 + pow2_heads, dtype=torch.int32)
    slopes = torch.pow(base, exponents)
    if use_dynamic:
        slopes = slopes * coeff    # step 4: multiply b back onto each slope m_h

    if pow2_heads != num_heads:  # todo: fix ntk when num_heads is not power of 2
        # Heads beyond the closest power of two take interpolated slopes.
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * pow2_heads) - 3))), dtype=torch.float32
        )
        n_extra = min(pow2_heads, num_heads - pow2_heads)
        extra_exponents = torch.arange(1, 1 + 2 * n_extra, 2, dtype=torch.int32)
        extra = torch.pow(extra_base, extra_exponents)
        if use_dynamic:
            extra = extra.unsqueeze(0).repeat(batch, 1)
        slopes = torch.cat([slopes, extra], dim=-1)

    if not use_dynamic:
        slopes = slopes.unsqueeze(0).repeat(batch, 1)

    return slopes
|
||||
|
||||
class TestAlibiSlopeKernel(UnitTest):
    # Compares the MLU alibi_slope kernel against the pure-torch reference
    # above.  Parameters: (batch, heads per tensor-parallel rank, tp degree,
    # dynamic-NTK flag, training sequence length).
    @pytest.mark.parametrize("batch, tp_head_num, tp_num, use_dynamic, max_sequence_length", [
        (1, 11, 1, False, 1024),
        (1, 11, 2, False, 1024),
        (8, 16, 1, False, 4096),
        (8, 16, 2, False, 4096),
        (1, 11, 1, True, 1024),
        (1, 11, 2, True, 1024),
        (8, 16, 1, True, 4096),
        (8, 16, 2, True, 4096)])
    def test_alibi_slope(self, batch, tp_head_num, tp_num, use_dynamic, max_sequence_length):
        torch.manual_seed(0)
        # Actual lengths range up to 2x the training length so the dynamic-NTK
        # path (length > train length) is actually exercised.
        true_seq_lens = torch.randint(0, 2 * max_sequence_length, (batch, 1), dtype=torch.int32)
        # Reference uses the full head count (tp_head_num * tp_num); the kernel
        # receives the per-rank head count and the tp degree separately.
        output_base = alibi_slopes(batch, tp_head_num * tp_num, true_seq_lens, max_sequence_length, use_dynamic)
        output_mlu = btunittests.alibi_slope_test(true_seq_lens.mlu(), batch, tp_head_num, tp_num, use_dynamic, max_sequence_length)
        diff1 = super().diff1(output_mlu, output_base)
        diff2 = super().diff2(output_mlu, output_base)
        assert diff1 <= 1e-5 and diff2 <= 1e-5
|
||||
|
||||
# Allow running this file directly: run the tests verbosely, write a JUnit
# XML report named after this script, and propagate pytest's exit code.
if __name__ == '__main__':
    exit_code = pytest.main(['-vs', os.path.abspath(__file__), '--junitxml=' + os.path.splitext(os.path.basename(__file__))[0] + '_report.xml'])
    exit(exit_code)
|
||||
@@ -0,0 +1,109 @@
|
||||
import sys
|
||||
import os
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(current_dir)
|
||||
sys.path.append(parent_dir)
|
||||
from unit_test import UnitTest
|
||||
import pytest
|
||||
import torch
|
||||
import torch_mlu
|
||||
import btunittests
|
||||
import random
|
||||
import math
|
||||
|
||||
def create_cos_sin_table(rotary_base, rotary_seq_len_, rotary_stride_, rotary_scaling, rotary_dim, interleaved, datatype):
    """Pure-torch reference for the cos/sin rotary-embedding table.

    Returns a [rotary_seq_len_, 2 * rotary_dim] tensor: cosine values in the
    first half of the last dim, sines in the second half.  When interleaved,
    each frequency's value is duplicated into adjacent column pairs instead.

    NOTE(review): rotary_stride_ and datatype are accepted but never read
    here -- presumably kept so the call signature matches the kernel; confirm.
    """
    torch.manual_seed(0)
    random.seed(0)
    # Standard RoPE inverse frequencies: base^(-2i/d) for i in [0, d/2).
    inv_freq=1.0 / (rotary_base ** (torch.arange(0, rotary_dim, 2).mlu() / rotary_dim))
    t = torch.arange(rotary_seq_len_).mlu()
    # Linear position scaling (rotary_scaling == 1 leaves positions unchanged).
    t = t / rotary_scaling
    # Outer product: one row of angles per position.
    freqs = torch.outer(t, inv_freq).mlu()
    cos_table = torch.cos(freqs).mlu()
    sin_table = torch.sin(freqs).mlu()
    if interleaved == 0:
        # Non-interleaved layout: [cos|cos|sin|sin] along the last dim.
        cos_table = torch.cat((cos_table,cos_table), dim = 1).mlu()
        sin_table = torch.cat((sin_table,sin_table), dim = 1).mlu()
        emb = torch.cat((cos_table, sin_table), dim=-1).mlu()
    else:
        # Interleaved layout: duplicate each frequency column into an adjacent
        # pair, cos columns first, then sin columns (built column-by-column).
        cos_tmp = torch.cat((cos_table[:,0].unsqueeze(1),cos_table[:,0].unsqueeze(1)), dim = 1).mlu()
        emb = cos_tmp
        for i in range(1, cos_table.size(1)):
            cos_tmp = torch.cat((cos_table[:,i].unsqueeze(1),cos_table[:,i].unsqueeze(1)), dim = 1).mlu()
            emb = torch.cat((emb,cos_tmp), dim = 1).mlu()
        for i in range(0, cos_table.size(1)):
            sin_tmp = torch.cat((sin_table[:,i].unsqueeze(1),sin_table[:,i].unsqueeze(1)), dim = 1).mlu()
            emb = torch.cat((emb,sin_tmp), dim = 1).mlu()
    return emb
|
||||
|
||||
def get_ntk_alpha(true_seq_len, max_position_embeddings, rotary_base, rotary_dim):
    """Dynamic-NTK scaling for rotary embeddings (Qwen-style).

    Picks the smallest alpha of the form 2^ceil(log2(len/max) + 1) - 1 that
    is at least 1, then scales the rotary base by alpha^(d / (d - 2)).

    Returns:
        (ntk_alpha, scaled_rotary_base)
    """
    log_ratio = math.log(true_seq_len / max_position_embeddings, 2) + 1
    # Round up to the next power-of-two step, floor the result at 1.
    ntk_alpha = max(2 ** math.ceil(log_ratio) - 1, 1)
    scaled_base = rotary_base * (ntk_alpha ** (rotary_dim / (rotary_dim - 2)))
    return ntk_alpha, scaled_base
|
||||
|
||||
class TestCreateCosSinTableKernel(UnitTest):
    # Compares the MLU create_cos_sin_table kernel against the pure-torch
    # reference above, including the dynamic-NTK path (rotary_scaling_type 2),
    # which also updates the per-sample alpha cache.
    # Fixes vs. original: removed the unused ntk_alpha_list local and a
    # duplicated copy of the final assertion (copy-paste leftover).
    @pytest.mark.parametrize("batch, max_position_embeddings, batch_stride, rotary_seq_len, rotary_stride, rotary_base, rotary_scaling,\
                             rotary_scaling_type, rotary_dim , interleaved, datatype",
                             [(8, 512, 0, 267, 128, 10000, 1, 1, 64, 0, torch.float32),
                              (8, 512, 0, 267, 128, 10000, 1, 1, 64, 1, torch.float32),
                              (4, 256, 11456, 179, 64, 10000, 1, 2, 32, 0, torch.float32),
                              (4, 256, 11456, 179, 64, 10000, 1, 2, 32, 1, torch.float32),
                              (4, 256, 34368, 179, 192, 10000, 1, 2, 96, 0, torch.float32),
                              (4, 256, 34368, 179, 192, 10000, 1, 2, 96, 1, torch.float32),
                              (4, 256, 45824, 179, 256, 10000, 1, 2, 128, 0, torch.float32),
                              (4, 256, 45824, 179, 256, 10000, 1, 2, 128, 1, torch.float32),
                              (8, 512, 0, 267, 128, 10000, 1, 1, 64, 0, torch.bfloat16),
                              (8, 512, 0, 267, 128, 10000, 1, 1, 64, 1, torch.bfloat16),
                              (4, 256, 11456, 179, 64, 10000, 1, 2, 32, 0, torch.float16),
                              (4, 256, 11456, 179, 64, 10000, 1, 2, 32, 1, torch.float16),
                              (4, 256, 34368, 179, 192, 10000, 1, 2, 96, 0, torch.float16),
                              (4, 256, 34368, 179, 192, 10000, 1, 2, 96, 1, torch.float16),
                              (4, 256, 45824, 179, 256, 10000, 1, 2, 128, 0, torch.float16),
                              (4, 256, 45824, 179, 256, 10000, 1, 2, 128, 1, torch.float16)])
    def test_all(self, batch, max_position_embeddings, batch_stride, rotary_seq_len, rotary_stride, rotary_base, rotary_scaling,
                 rotary_scaling_type, rotary_dim, interleaved ,datatype):
        mlu_name = torch.mlu.get_device_name()
        # MLU3xx cards have no bfloat16 support; downgrade to half there.
        if "MLU3" in mlu_name and datatype == torch.bfloat16:
            datatype = torch.half
            print("BFLOAT16 is not support on {}, use half instead".format(mlu_name))
        torch.manual_seed(0)
        random.seed(0)
        seq_lens = torch.randint(size = (batch,), low = 0, high = 512,
                                 dtype = torch.int32).mlu()
        # Per-sample NTK alpha cache; the kernel updates its clone in place.
        rotary_emb_alpha_cached = torch.rand((batch), dtype=torch.float32).mlu()
        ref_rotary_emb_alpha_cached = rotary_emb_alpha_cached.clone()
        output_mlu = btunittests.create_cos_sin_table_test(ref_rotary_emb_alpha_cached, seq_lens, max_position_embeddings, batch_stride,
                                                           rotary_seq_len, rotary_dim, rotary_stride, rotary_scaling,
                                                           rotary_scaling_type, rotary_base, interleaved)
        if rotary_scaling_type != 2:
            # Static scaling: a single shared table is the reference.
            output_base = create_cos_sin_table(rotary_base, rotary_seq_len, rotary_stride, rotary_scaling, rotary_dim, interleaved, datatype)
        else:
            # Dynamic NTK: build one table per sample with the scaled base.
            max_seq_len, max_index = torch.max(seq_lens, dim = 0)
            if max_seq_len > max_position_embeddings:
                # Each sample gets its own alpha based on its actual length.
                for i in range(batch):
                    ntk_alpha, rb = get_ntk_alpha(seq_lens[i], max_position_embeddings, rotary_base, rotary_dim)
                    rotary_emb_alpha_cached[i] = ntk_alpha
                    cos_sin_table_tmp = create_cos_sin_table(rb, rotary_seq_len, rotary_stride, rotary_scaling, rotary_dim, interleaved, datatype).unsqueeze(0)
                    cos_sin_table = torch.cat((cos_sin_table,cos_sin_table_tmp),dim = 0).mlu() if i != 0 else cos_sin_table_tmp
            else:
                # No sample exceeds the trained length: one alpha for everyone.
                ntk_alpha, rb = get_ntk_alpha(max_seq_len, max_position_embeddings, rotary_base, rotary_dim)
                for i in range(batch):
                    rotary_emb_alpha_cached[i] = ntk_alpha
                    cos_sin_table_tmp = create_cos_sin_table(rb, rotary_seq_len, rotary_stride, rotary_scaling, rotary_dim, interleaved, datatype).unsqueeze(0)
                    cos_sin_table = torch.cat((cos_sin_table,cos_sin_table_tmp),dim = 0).mlu() if i != 0 else cos_sin_table_tmp
            output_base = cos_sin_table
            output_base_1 = rotary_emb_alpha_cached
        diff1 = super().diff1(output_mlu, output_base)
        diff2 = super().diff2(output_mlu, output_base)
        assert diff1 < 0.003 and diff2 < 0.003
        if rotary_scaling_type == 2:
            # Also check the alpha cache the kernel wrote back.
            diff1 = super().diff1(ref_rotary_emb_alpha_cached, output_base_1)
            diff2 = super().diff2(ref_rotary_emb_alpha_cached, output_base_1)
            assert diff1 < 0.003 and diff2 < 0.003
|
||||
|
||||
# Allow running this file directly: run the tests verbosely, write a JUnit
# XML report named after this script, and propagate pytest's exit code.
if __name__ == '__main__':
    exit_code = pytest.main(['-vs', os.path.abspath(__file__), '--junitxml=' + os.path.splitext(os.path.basename(__file__))[0] + '_report.xml'])
    exit(exit_code)
|
||||
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
from unit_test import UnitTest
|
||||
import pytest
|
||||
import torch
|
||||
import torch_mlu
|
||||
import torch.nn as nn
|
||||
import btunittests
|
||||
|
||||
class TestEmbeddingKernel(UnitTest):
    # Compares the MLU embedding-lookup kernel against torch.nn.Embedding.
    @pytest.mark.parametrize("batch, seq, vocab_size, hidden_size, dtype", [
        (1, 1024, 10000, 128, torch.float16),
        (6, 128, 20000, 4096, torch.float16),
        (1, 10, 5000, 1024, torch.float32),
        (5, 10, 10000, 128, torch.float32)])
    def test_all(self, batch, seq, vocab_size, hidden_size, dtype):
        # Full-vocabulary lookup: kernel gets the whole weight table
        # (offset 0, part size == vocab_size).
        torch.manual_seed(0)
        input = torch.randint(0, vocab_size, (batch, seq), dtype=torch.int32).mlu()
        embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size).to(dtype).mlu()
        output_base = embedding_layer(input)
        output_mlu = btunittests.embedding_test(input, embedding_layer.weight, 0, vocab_size)
        diff1 = super().diff1(output_mlu, output_base)
        diff2 = super().diff2(output_mlu, output_base)
        # An embedding lookup is a pure gather, so the match must be exact.
        assert diff1 == 0 and diff2 == 0

    @pytest.mark.parametrize("batch, seq, vocab_size, hidden_size, dtype, vocab_offset, vocab_part", [
        (1, 1024, 10000, 128, torch.float16, 2000, 1000),
        (6, 128, 20000, 4096, torch.float16, 10000, 10000),
        (1, 10, 5000, 1024, torch.float32, 0, 2500),
        (5, 10, 10000, 128, torch.float32, 100, 512)])
    def test_part(self, batch, seq, vocab_size, hidden_size, dtype, vocab_offset, vocab_part):
        # Tensor-parallel shard: the kernel only holds rows
        # [vocab_offset, vocab_offset + vocab_part).  Ids outside that window
        # are remapped to an extra padding row (index vocab_size) whose
        # embedding is all zeros, and the reference does the same via
        # nn.Embedding's padding_idx.
        torch.manual_seed(0)
        input = torch.randint(0, vocab_size, (batch, seq), dtype=torch.int32).mlu()
        input_mask = torch.ones_like(input) * vocab_size
        input = torch.where(input < vocab_offset, input_mask, input)
        input = torch.where(input >= vocab_offset + vocab_part, input_mask, input)
        embedding_layer = nn.Embedding(num_embeddings=vocab_size + 1, embedding_dim=hidden_size, padding_idx=vocab_size).to(dtype).mlu()
        output_base = embedding_layer(input)
        # Kernel only sees this vocabulary shard of the weight table.
        part_weight = torch.Tensor(embedding_layer.weight)[vocab_offset : vocab_offset + vocab_part, :]
        output_mlu = btunittests.embedding_test(input, part_weight, vocab_offset, vocab_part)
        diff1 = super().diff1(output_mlu, output_base)
        diff2 = super().diff2(output_mlu, output_base)
        assert diff1 == 0 and diff2 == 0
|
||||
|
||||
# Allow running this file directly: run the tests verbosely, write a JUnit
# XML report named after this script, and propagate pytest's exit code.
if __name__ == '__main__':
    exit_code = pytest.main(['-vs', os.path.abspath(__file__), '--junitxml=' + os.path.splitext(os.path.basename(__file__))[0] + '_report.xml'])
    exit(exit_code)
|
||||
@@ -0,0 +1,92 @@
|
||||
import sys
|
||||
import os
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(current_dir)
|
||||
sys.path.append(parent_dir)
|
||||
from unit_test import UnitTest
|
||||
import pytest
|
||||
import torch
|
||||
import torch_mlu
|
||||
import btunittests
|
||||
import random
|
||||
|
||||
class TestSliceCuSeqlensKernel(UnitTest):
    # Splits a cumulative-seq-lens vector [batch+1] into `loop` chunks for
    # parallel workers and rebases each chunk so it starts at zero, then
    # compares against the MLU slice_cu_seq_lens kernel.
    @pytest.mark.parametrize("batch_size, parallel_num", [
        (8,3),
        (16,16),
        (32,16),
        (64,16),
        (128,16),
        (256,16),
        (512,16),
        (1024,16),
        (2048,16),
        (4096,16),
    ])
    def test_all(self, batch_size, parallel_num):
        torch.manual_seed(0)
        random.seed(0)
        # Kernel contract requires batch_size >= parallel_num; skip otherwise.
        if batch_size < parallel_num:
            return
        cu_seq_lens = torch.randint(low=0, high=1024*1024, size=(batch_size+1,), dtype=torch.int32).mlu()
        # Split batch_size entries into `loop` chunks of `every` (last chunk
        # may hold only `remain` entries).
        every = (batch_size + parallel_num - 1) // parallel_num
        repeat = batch_size // every
        remain = batch_size % every
        loop = repeat + (remain != 0)
        result = []
        for i in range(loop):
            # and/or idiom: the last partial chunk covers `remain` entries,
            # every other chunk covers `every`; +1 for the leading boundary.
            elem_num = 1 + (i == loop - 1 and remain != 0 and remain or every)
            start_idx = i * every
            end_idx = start_idx + elem_num
            slice_tensor = cu_seq_lens[start_idx:end_idx]
            # Rebase the chunk so its first cumulative offset is zero.
            adjusted_tensor = slice_tensor - slice_tensor[0]
            result.append(adjusted_tensor)
        sliced_cu_seq_lens = torch.cat(result)
        # Output has one extra boundary element per chunk: batch_size + loop.
        sliced_cu_seq_lens_mlu = torch.zeros(batch_size + loop, dtype=torch.int32).mlu()
        btunittests.slice_cu_seq_lens_test(cu_seq_lens, sliced_cu_seq_lens_mlu, batch_size, parallel_num)
        sliced_cu_seq_lens_diff1 = super().diff1(sliced_cu_seq_lens, sliced_cu_seq_lens_mlu)
        sliced_cu_seq_lens_diff2 = super().diff2(sliced_cu_seq_lens, sliced_cu_seq_lens_mlu)
        assert sliced_cu_seq_lens_diff1 == 0 and sliced_cu_seq_lens_diff2 == 0
|
||||
|
||||
class TestGenerateCuSeqlensKernel(UnitTest):
    # Builds the reference cu_seq_lens the kernel should generate when a
    # single sequence of seq_len tokens is split over parallel workers, and
    # compares it against the MLU generate_cu_seq_lens kernel.  Even indices
    # stay zero (chunk starts); odd indices carry the chunk end offsets.
    @pytest.mark.parametrize("seq_len, parallel_num, is_causal_mask, is_kv_seq_len", [
        (2154, 4, False, False),
        (3412, 3, True, False),
        (996, 7, False, True),
        (872, 4, False, False),
        (634, 12, True, False),
        (486, 4, False, True),
        (125, 6, False, False),
        ])
    def test_all(self, seq_len, parallel_num, is_causal_mask, is_kv_seq_len):
        # Number of chunks when seq_len is split into pieces of size `chunk`.
        chunk = (seq_len + parallel_num - 1) // parallel_num
        full_chunks = seq_len // chunk
        leftover = seq_len % chunk
        loop = full_chunks + (leftover != 0)

        expected = torch.zeros(2 * loop, dtype=torch.int32).mlu()

        # Re-balance: base tokens per chunk once `loop` is fixed.
        per_chunk = seq_len // loop
        base = per_chunk + (seq_len % loop != 0)

        for i in range(loop):
            last = (i == loop - 1)
            if is_causal_mask and is_kv_seq_len:
                # Causal KV lengths are cumulative; the final chunk sees all tokens.
                expected[2 * i + 1] = seq_len if last else (i + 1) * base
            elif not is_kv_seq_len:
                # Query lengths: each chunk holds `base` tokens, remainder on the last.
                expected[2 * i + 1] = seq_len - (loop - 1) * base if last else base
            else:
                # Non-causal KV lengths: every chunk attends to the full sequence.
                expected[2 * i + 1] = seq_len

        generate_cu_seq_lens_mlu = torch.zeros(2 * loop, dtype=torch.int32).mlu()
        btunittests.generate_cu_seq_lens_test(generate_cu_seq_lens_mlu, seq_len, parallel_num, is_causal_mask, is_kv_seq_len)
        diff1 = super().diff1(expected, generate_cu_seq_lens_mlu)
        diff2 = super().diff2(expected, generate_cu_seq_lens_mlu)
        assert diff1 == 0 and diff2 == 0
|
||||
|
||||
|
||||
# Allow running this file directly: run the tests verbosely, write a JUnit
# XML report named after this script, and propagate pytest's exit code.
if __name__ == '__main__':
    exit_code = pytest.main(['-vs', os.path.abspath(__file__), '--junitxml=' + os.path.splitext(os.path.basename(__file__))[0] + '_report.xml'])
    exit(exit_code)
|
||||
@@ -0,0 +1,82 @@
|
||||
import sys
|
||||
import os
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(current_dir)
|
||||
sys.path.append(parent_dir)
|
||||
from unit_test import UnitTest
|
||||
import pytest
|
||||
import torch
|
||||
import torch_mlu
|
||||
import btunittests
|
||||
import random
|
||||
|
||||
def round_func(tensor, mode=1):
    """Round a tensor to integer values under a selectable tie-breaking rule.

    Args:
        tensor: input tensor; converted to float32 before rounding.
        mode: 1 -> round-half-to-even (ties .5 go to the even neighbour),
              2 -> round-half-up (ties .5 are bumped up by 0.5),
              3 -> plain torch.round.

    Returns:
        float32 tensor of rounded values.

    Raises:
        ValueError: if mode is not 1, 2 or 3.  (The original fell through
        and crashed with NameError on an unbound local instead.)
    """
    tensor = tensor.float()
    fractional_part = tensor - tensor.floor()
    # True where the integer part is even; used to break .5 ties in mode 1.
    even_mask = (tensor.floor() % 2 == 0)
    if mode == 1:
        # Exact .5 ties go to the even neighbour; all else via torch.round.
        rounded_tensor = torch.where(fractional_part == 0.5,
                                     torch.where(even_mask, tensor.floor(), tensor.ceil()),
                                     torch.round(tensor))
    elif mode == 2:
        # Exact .5 ties are rounded up (value + 0.5 is already an integer).
        rounded_tensor = torch.where(fractional_part == 0.5, tensor + 0.5, torch.round(tensor))
    elif mode == 3:
        rounded_tensor = torch.round(tensor)
    else:
        raise ValueError("round_func: mode must be 1, 2 or 3, got {}".format(mode))
    return rounded_tensor
|
||||
|
||||
def compute_with_tensor(src, batch, seq_len, head_num, head_size):
    """Per-head symmetric int8-style quantization reference.

    src is viewed as blocks of 2 * head_num rows of head_size elements;
    only the first head_num rows of each block are kept (the second half of
    every head's 2 * head_size slot is discarded).  Each surviving row is
    scaled so its max |value| maps to 127 and rounded with round_func mode 1.

    Returns:
        scale: [batch, seq_len, head_num] dequantization scales (max / 127).
        dst:   [batch, seq_len, head_num, head_size] rounded quantized values.
    """
    rows = src.view(batch * seq_len * head_num * 2, head_size)
    # Keep the first head_num rows of every 2*head_num block, drop the rest.
    keep = torch.ones(rows.size(0), dtype=bool)
    for start in range(head_num, rows.size(0), head_num * 2):
        keep[start:start + head_num] = False
    rows = rows[keep]
    # Per-row absolute maximum drives the symmetric quantization range.
    row_max, _ = torch.max(rows.abs(), dim=1)
    scale = (row_max / 127.0).view(batch, seq_len, head_num)
    quant_scale = 127 / row_max
    scaled = rows * quant_scale.view(quant_scale.size(0), 1)
    dst = round_func(scaled, 1).view(batch, seq_len, head_num, head_size)
    return scale, dst
|
||||
|
||||
class TestPerHeadQuantizeKernel(UnitTest):
    # Compares the MLU per-head quantize kernel against the pure-torch
    # reference compute_with_tensor above.
    # NOTE(review): random.randint here runs at import time, so the "random"
    # cases are fixed per test session, not reseeded per test.
    @pytest.mark.parametrize("batch, seq_len, head_num, head_size, src_dtype", [
        (17,1436,31,21, torch.float32),
        (17,1436,31,21, torch.float16),
        (random.randint(1,32),random.randint(16,2048),random.randint(1,32),random.randint(16,256),torch.float32),
        (random.randint(1,32),random.randint(16,2048),random.randint(1,32),random.randint(16,256),torch.float16),
        (2,3,5,5, torch.float32)
        ])
    def test_all(self, batch, seq_len, head_num, head_size, src_dtype):
        mlu_name = torch.mlu.get_device_name()
        # MLU3xx cards have no bfloat16 support; downgrade to half there.
        if "MLU3" in mlu_name and src_dtype == torch.bfloat16:
            src_dtype = torch.half
            print("BFLOAT16 is not support on {}, use half instead".format(mlu_name))
        torch.manual_seed(0)
        random.seed(0)
        hidden = head_num * head_size
        # Source rows are 2*head_size wide per head; only the first head_size
        # elements of each head are quantized (see compute_with_tensor).
        seq_stride = 2 * hidden
        scale_dtype = torch.float32
        dst_dtype = torch.int8
        # Destination strides describe a dense [batch, seq, head, head_size] layout.
        dst_bs_stride = seq_len * hidden
        dst_seq_stride = hidden
        dst_head_stride = head_size
        # Source strides skip the unquantized second half of every head slot.
        src_bs_stride = seq_len * seq_stride
        src_seq_stride = seq_stride
        src_head_stride = head_size

        src = torch.rand((batch, seq_len, head_num, 2 * head_size), dtype=src_dtype).mlu()
        scale, dst = compute_with_tensor(src, batch, seq_len, head_num, head_size)
        scale_mlu = torch.zeros((batch, seq_len, head_num), dtype=scale_dtype).mlu()
        dst_mlu = torch.zeros((batch, seq_len, head_num, head_size), dtype=dst_dtype).mlu()
        btunittests.per_head_quantize_test(dst_mlu, scale_mlu, src, batch, seq_len, head_num, head_size, dst_bs_stride,
                                           dst_seq_stride, dst_head_stride, src_bs_stride, src_seq_stride,
                                           src_head_stride)
        dst_diff1 = super().diff1(dst, dst_mlu)
        dst_diff2 = super().diff2(dst, dst_mlu)
        scale_diff1 = super().diff1(scale, scale_mlu)
        scale_diff2 = super().diff2(scale, scale_mlu)
        assert dst_diff1 < 0.003 and dst_diff2 < 0.003
        assert scale_diff1 < 0.003 and scale_diff2 < 0.003
|
||||
|
||||
# Allow running this file directly: run the tests verbosely, write a JUnit
# XML report named after this script, and propagate pytest's exit code.
if __name__ == '__main__':
    exit_code = pytest.main(['-vs', os.path.abspath(__file__), '--junitxml=' + os.path.splitext(os.path.basename(__file__))[0] + '_report.xml'])
    exit(exit_code)
|
||||
@@ -0,0 +1,220 @@
|
||||
import sys
|
||||
import os
|
||||
from unit_test import UnitTest
|
||||
import pytest
|
||||
import torch
|
||||
import torch_mlu
|
||||
import btunittests
|
||||
|
||||
class ApplyRotary(torch.nn.Module):
    """Pure-torch reference for the rotary-embedding kernels.

    forward() writes its result into `output` in place through sliced views
    (ouput_i aliases output), so the statement order below is significant.
    """
    def __init__(self):
        super().__init__()

    def rotate(self, x: torch.Tensor, interleaved: bool):
        # Standard RoPE rotation of the last dim.
        if not interleaved:
            # Half-split layout: (x1, x2) -> (-x2, x1).
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat((-x2, x1), dim=-1)
        else:
            # Interleaved layout: rotate each (even, odd) pair in place.
            y = torch.empty_like(x)
            x1, x2 = x[..., ::2], x[..., 1::2]
            y[..., ::2], y[..., 1::2] = -x2, x1
            return y

    def forward(self,
                output: torch.Tensor,     # [total_seqlen, num_heads, head_size]
                input: torch.Tensor,      # [total_seqlen, num_heads, head_size]
                sin_cache: torch.Tensor,  # [rope_seqlen, rotary_dim] / [batch, rope_seqlen, rotary_dim]
                cos_cache: torch.Tensor,  # same layout as sin_cache
                seq_offsets: torch.Tensor,  # [batch] / [batch, max_seqlen]
                cu_seq_lens: torch.Tensor,  # [batch + 1]
                interleaved: bool,
                discrete: bool,   # per-token position ids taken from seq_offsets
                dynamic: bool,    # per-batch cos/sin tables (dynamic NTK)
                max_seqlen: int,
                no_offset: bool,  # ignore seq_offsets, start every sequence at 0
                rope_2d: bool):   # GLM-style 2D rope: rotate both halves of head_size
        # 3-D input means variable-length sequences packed along dim 0.
        packed = input.dim() == 3
        rope_dim = sin_cache.shape[-1]
        batch_size = cu_seq_lens.shape[0] - 1 if packed else input.shape[0]

        if not rope_2d:
            for i in range(batch_size):
                # Views into this sample's tokens; writes hit `output` directly.
                input_i = input[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else input[i]
                ouput_i = output[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else output[i]
                # Only the first rope_dim channels are rotated.
                input_i = input_i[..., 0:rope_dim]
                ouput_i = ouput_i[..., 0:rope_dim]
                start_seq_idx = 0 if no_offset else seq_offsets[i]
                sin_cache_i = sin_cache[i] if dynamic else sin_cache
                cos_cache_i = cos_cache[i] if dynamic else cos_cache
                seq = input_i.shape[0]
                if discrete:
                    # Discrete mode: gather per-token table rows by position id
                    # instead of a contiguous [offset, offset+seq) window.
                    start_seq_idx = 0
                    token_offset = seq_offsets[cu_seq_lens[i]:cu_seq_lens[i]+seq]
                    sin_cache_i = sin_cache_i[token_offset]
                    cos_cache_i = cos_cache_i[token_offset]
                sin_cache_i = sin_cache_i[start_seq_idx:seq+start_seq_idx]
                cos_cache_i = cos_cache_i[start_seq_idx:seq+start_seq_idx]
                rot = self.rotate(input_i, interleaved)

                # In-place: broadcast sin/cos over the head dimension.
                ouput_i[:] = rot * sin_cache_i.unsqueeze(1) + input_i * cos_cache_i.unsqueeze(1)
            # Channels past rope_dim pass through unrotated.
            # NOTE(review): indentation reconstructed from a mangled diff --
            # placed after the loop (runs once); confirm against upstream.
            output[..., rope_dim:] = input[..., rope_dim:]
        else:
            # 2D rope (GLM6B): rotate the two halves of head_size separately,
            # with independent position ids in seq_offsets[0] / seq_offsets[1].
            for i in range(batch_size):
                input_i_left = input[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else input[i]
                ouput_i_left = output[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else output[i]
                input_i_left = input_i_left[..., 0:rope_dim]
                ouput_i_left = ouput_i_left[..., 0:rope_dim]
                sin_cache_i = sin_cache[i] if dynamic else sin_cache
                cos_cache_i = cos_cache[i] if dynamic else cos_cache
                seq = input_i_left.shape[0]
                if discrete:
                    token_offset = seq_offsets[0][cu_seq_lens[i]:cu_seq_lens[i]+seq]
                    sin_cache_i = sin_cache_i[token_offset]
                    cos_cache_i = cos_cache_i[token_offset]
                sin_cache_i = sin_cache_i[:seq]
                cos_cache_i = cos_cache_i[:seq]
                rot = self.rotate(input_i_left, interleaved)

                ouput_i_left[:] = rot * sin_cache_i.unsqueeze(1) + input_i_left * cos_cache_i.unsqueeze(1)

                # Right half: same computation with the second position-id row.
                input_i_right = input[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else input[i]
                ouput_i_right = output[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else output[i]
                input_i_right = input_i_right[..., rope_dim:]
                ouput_i_right = ouput_i_right[..., rope_dim:]
                sin_cache_i = sin_cache[i] if dynamic else sin_cache
                cos_cache_i = cos_cache[i] if dynamic else cos_cache
                seq = input_i_right.shape[0]
                if discrete:
                    token_offset = seq_offsets[1][cu_seq_lens[i]:cu_seq_lens[i]+seq]
                    sin_cache_i = sin_cache_i[token_offset]
                    cos_cache_i = cos_cache_i[token_offset]
                sin_cache_i = sin_cache_i[:seq]
                cos_cache_i = cos_cache_i[:seq]
                rot = self.rotate(input_i_right, interleaved)

                ouput_i_right[:] = rot * sin_cache_i.unsqueeze(1) + input_i_right * cos_cache_i.unsqueeze(1)
|
||||
|
||||
class TestRotaryEmbeddingKernel(UnitTest):
    """Compares the fused MLU rotary-embedding kernel
    (btunittests.rotary_embedding_test) against the pure-PyTorch
    ApplyRotary reference over a grid of layout/dtype configurations,
    using the diff1/diff2 relative-error metrics from UnitTest."""

    @pytest.mark.parametrize(
        "batch, max_seq_len, head_num_q, head_num_kv, head_size, discrete, rotary_dim, interleaved, rope_2d, dynamic_ntk, packed, no_offset, data_type",
        [
            (4, 16, 8, 4, 128, False, 128, False, False, True, False, True, torch.float32),
            (4, 16, 8, 4, 128, False, 128, False, False, True, True, True, torch.float32),
            (4, 16, 8, 4, 128, True, 128, False, False, True, True, True, torch.float32),
            (4, 16, 8, 4, 128, True, 128, False, False, False, True, True, torch.float32),
            (4, 16, 8, 4, 128, False, 128, False, False, True, False, True, torch.float16),
            (4, 16, 8, 4, 128, False, 128, False, False, True, True, True, torch.float16),
            (4, 16, 8, 4, 128, True, 128, False, False, True, True, True, torch.float16),
            (4, 16, 8, 4, 128, True, 128, True, False, True, True, True, torch.float16),
            (4, 16, 8, 4, 128, True, 128, False, False, False, True, True, torch.float16),
            (4, 16, 8, 4, 128, True, 128, True, False, False, True, True, torch.float16),
            (4, 16, 8, 4, 128, True, 128, True, True, False, True, True, torch.float16),
            (4, 16, 8, 4, 128, False, 128, False, False, True, True, False, torch.float32),
            (4, 16, 8, 4, 128, True, 128, False, False, True, True, False, torch.float32),
            (4, 16, 8, 4, 128, True, 128, False, True, True, True, False, torch.float32),
        ])
    def testall(
            self,
            batch,
            max_seq_len,
            head_num_q,
            head_num_kv,
            head_size,
            discrete,
            rotary_dim,
            interleaved,
            rope_2d,
            dynamic_ntk,
            packed,
            no_offset,
            data_type):
        torch.manual_seed(0)
        # rope_2d supports only this fixed mode combination; normalize the
        # incoming parameters so unsupported combinations are never built.
        if rope_2d:
            interleaved = False
            discrete = True
            dynamic_ntk = False
            no_offset = False
        # NOTE(review): q heads are forced to 2 * kv heads — presumably a
        # kernel layout requirement; confirm against the kernel contract.
        if head_num_q != 2 * head_num_kv:
            head_num_q = 2 * head_num_kv
        # Rotary dimension must be even (channels are rotated in pairs).
        if rotary_dim % 2 != 0:
            rotary_dim -= 1
        if no_offset:
            discrete = False
        if rope_2d:
            # head_size is rounded down to a multiple of 4; each of the two
            # rotary axes then uses half of the head.
            head_size = int(head_size if (head_size % 4 == 0) else (head_size // 4) * 4)
            rotary_dim = head_size // 2

        seq_lens = torch.randint(size=(batch, ), low=1, high=max_seq_len+1, dtype=torch.int32, device='mlu')
        total_seq_len = batch * max_seq_len
        if not packed:
            # Unpacked layout is dense: every sequence fills max_seq_len slots.
            seq_lens = torch.randint(size=(batch, ), low=max_seq_len, high=max_seq_len+1, dtype=torch.int32, device='mlu')
        max_context_len = int(seq_lens.max())

        cu_seq_lens = None
        if packed:
            # Varlen convention: prefix sums with a leading 0, so sequence i
            # occupies rows [cu_seq_lens[i], cu_seq_lens[i + 1]).
            cu_seq_lens = torch.cumsum(seq_lens, dim=-1, dtype=torch.int32)
            cu_seq_lens = torch.nn.functional.pad(cu_seq_lens, (1, 0), "constant", 0).mlu()
            total_seq_len = int(cu_seq_lens[-1])

        # NOTE(review): head_num_qkv equals head_num_qk here, so qkv and qk
        # alias the same slice. If the fused buffer is meant to also contain
        # V heads this should be head_num_q + 2 * head_num_kv — confirm
        # against the kernel's expected tensor layout.
        head_num_qkv = head_num_q + head_num_kv
        head_num_qk = head_num_q + head_num_kv

        # One extra head of padding beyond the qkv heads.
        context_shape = (total_seq_len, head_num_qkv + 1, head_size) if packed else \
                        (batch, max_seq_len, head_num_qkv + 1, head_size)
        context = torch.randn(size=context_shape, dtype=data_type).mlu()

        qkv = context[..., 0 : head_num_qkv, :]
        qk = context[..., 0 : head_num_qk, :]
        seq_offsets = None
        if rope_2d:
            # Two offset rows, one per rotary axis.
            seq_offsets = torch.randint(size=(2, total_seq_len), low=1, high=max_seq_len + 1, dtype=torch.int32, device='mlu')
        else:
            if discrete and not no_offset:
                # Per-token position offsets.
                seq_offsets = torch.randint(size=(total_seq_len,), low=1, high=max_seq_len + 1, dtype=torch.int32, device='mlu')
            elif not no_offset:
                # Per-sequence position offsets.
                seq_offsets = torch.randint(size=(batch,), low=1, high=max_seq_len + 1, dtype=torch.int32, device='mlu')

        # Sin/cos tables span 2 * max_seq_len rows so positions shifted by an
        # offset never index past the end of the cache.
        cos_cache = None
        sin_cache = None
        if dynamic_ntk:
            cos_cache = torch.randn(size=(batch, max_seq_len * 2, rotary_dim), dtype=data_type).mlu()
            sin_cache = torch.randn(size=(batch, max_seq_len * 2, rotary_dim), dtype=data_type).mlu()
        else:
            cos_cache = torch.randn(size=(max_seq_len * 2, rotary_dim), dtype=data_type).mlu()
            sin_cache = torch.randn(size=(max_seq_len * 2, rotary_dim), dtype=data_type).mlu()

        # Reference result produced by the pure-PyTorch implementation.
        base_output = torch.empty_like(qk)

        apply_rotary = ApplyRotary()
        apply_rotary(base_output, qkv, sin_cache, cos_cache,
                     seq_offsets, cu_seq_lens, interleaved, discrete, dynamic_ntk, max_context_len, no_offset, rope_2d)

        # Result from the fused MLU kernel under test.
        output = btunittests.rotary_embedding_test(
            qkv,
            sin_cache,
            cos_cache,
            seq_offsets,
            cu_seq_lens,
            interleaved,
            discrete,
            dynamic_ntk,
            rope_2d,
            max_context_len,
            no_offset)

        diff1 = super().diff1(output, base_output)
        diff2 = super().diff2(output, base_output)
        assert diff1 < 0.003 and diff2 < 0.003
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run this file's tests verbosely and emit a JUnit XML report named
    # after the module (e.g. foo.py -> foo_report.xml), then propagate
    # pytest's exit status to the shell.
    report_file = os.path.splitext(os.path.basename(__file__))[0] + "_report.xml"
    pytest_args = [
        "-vs",
        os.path.abspath(__file__),
        "--junitxml=" + report_file,
    ]
    exit(pytest.main(pytest_args))
|
||||
25
torch_mlu_ops-v1.3.2/tests/kernels_pytest/unit_test.py
Normal file
25
torch_mlu_ops-v1.3.2/tests/kernels_pytest/unit_test.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import sys
|
||||
import os
|
||||
build_lib_dir = os.path.dirname(os.path.abspath(__file__)) + "/build/lib"
|
||||
sys.path.append(build_lib_dir)
|
||||
import torch
|
||||
|
||||
class UnitTest:
    """Base class providing relative-error metrics used to compare a
    kernel result against a reference baseline tensor."""

    def diff1(self, result: torch.Tensor, baseline: torch.Tensor) -> float:
        """Mean relative L1 error: sum(|baseline - result|) / sum(|baseline|).

        Both tensors are flattened, upcast to float32, and moved to the CPU
        before comparison, so any device/dtype pairing is accepted. Returns
        0.0 for identical tensors. Raises AssertionError if the flattened
        shapes (i.e. element counts) differ.
        """
        result = result.flatten().float().to('cpu')
        baseline = baseline.flatten().float().to('cpu')
        assert result.shape == baseline.shape
        error = torch.abs(baseline - result)
        denominator = torch.sum(torch.abs(baseline)).item()
        # Guard against an all-zero baseline; eps stays 0 otherwise so the
        # metric remains exact.
        eps = 0.0 if denominator > 0 else 1e-9
        diff1 = torch.sum(error) / (denominator + eps)
        return diff1.item()

    def diff2(self, result: torch.Tensor, baseline: torch.Tensor) -> float:
        """Relative L2 error: sqrt(sum((baseline - result)^2) / sum(baseline^2)).

        Both tensors are flattened, upcast to float32, and moved to the CPU
        before comparison. Returns 0.0 for identical tensors. Raises
        AssertionError if the flattened shapes differ.
        """
        result = result.flatten().float().to('cpu')
        baseline = baseline.flatten().float().to('cpu')
        # Consistency with diff1: mismatched element counts should fail
        # loudly here rather than silently broadcast in the subtraction.
        assert result.shape == baseline.shape
        error = torch.abs(baseline - result)
        denominator = torch.sum(baseline**2).item()
        # Guard against an all-zero baseline, as in diff1.
        eps = 0.0 if denominator > 0 else 1e-9
        diff2 = torch.sqrt(torch.sum(error**2) / (denominator + eps))
        return diff2.item()
|
||||
Reference in New Issue
Block a user