This commit is contained in:
Chranos
2026-02-04 17:39:32 +08:00
parent 8511fe8530
commit 79dfc69789
299 changed files with 55927 additions and 0 deletions

View File

@@ -0,0 +1,101 @@
# -----------------------------------------------------------------------------
# Build script for the MLU kernel unit-test python extension (btunittests).
# Requires the Cambricon toolkit; NEUWARE_HOME must point at its install root
# (read from the environment at configure time).
# -----------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.10)
# Compilers are pinned before project() so CMake's compiler probe uses them.
set(CMAKE_C_COMPILER "gcc")
set(CMAKE_CXX_COMPILER "g++")
project(kernel_test)
message(STATUS "project name: ${PROJECT_NAME}")
# Emit compile_commands.json for clangd / static-analysis tooling.
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/lib")
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/archive")
# Extra find_package() modules shipped with the Cambricon toolkit.
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "$ENV{NEUWARE_HOME}/cmake/modules")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# NOTE(review): -std=c++17 duplicates CMAKE_CXX_STANDARD above, and these
# directory-scoped options/libraries leak into every target defined below —
# target_compile_options/target_link_libraries would scope them tighter.
add_compile_options(-std=c++17 -O3 -g -fPIC -Wall -Werror -Wextra -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unknown-pragmas)
link_directories($ENV{NEUWARE_HOME}/lib64)
# Cambricon runtime/driver/CNNL libraries; the rpath keeps the built module
# loadable without LD_LIBRARY_PATH pointing at the toolkit.
link_libraries(m stdc++ dl pthread cnnl cnnl_extra cnrt cndrv "-Wl,-rpath,$ENV{NEUWARE_HOME}/lib64 -Wl,--disable-new-dtags")
include_directories($ENV{NEUWARE_HOME}/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../csrc/ ${CMAKE_CURRENT_SOURCE_DIR}/../../csrc/kernels)
# find_torch
# ----------
# Locates the installed PyTorch by querying the active python3 interpreter.
# On success exports to the caller's scope:
#   Torch_DIR        - directory containing TorchConfig.cmake
#   TORCH_CXX11_ABI  - True/False, whether torch was built with the cxx11 ABI
# If torch is not importable the function returns without setting anything.
function(find_torch)
  # Get the package path and cxx11_abi flag in one interpreter invocation,
  # ';'-separated so the output parses directly as a CMake list.
  execute_process(
    COMMAND python3 -c "import torch; print(torch.__path__[0], torch.compiled_with_cxx11_abi(), sep=';')"
    RESULT_VARIABLE TORCH_NOT_FOUND
    OUTPUT_VARIABLE TORCH_INFO
    OUTPUT_STRIP_TRAILING_WHITESPACE
  )
  # A non-zero exit code means torch could not be imported.
  if(TORCH_NOT_FOUND)
    return()
  endif()
  list(GET TORCH_INFO 0 TORCH_PATH)
  message(STATUS "torch path: ${TORCH_PATH}")
  list(GET TORCH_INFO 1 TORCH_CXX11_ABI)
  message(STATUS "torch cxx11 abi: ${TORCH_CXX11_ABI}")
  # BUGFIX: TORCH_CXX11_ABI is consumed after the call site (the ABI
  # compatibility check around the build targets); without PARENT_SCOPE it was
  # lost when the function returned, making that check always see an empty
  # value.
  set(TORCH_CXX11_ABI ${TORCH_CXX11_ABI} PARENT_SCOPE)
  set(Torch_DIR ${TORCH_PATH}/share/cmake/Torch PARENT_SCOPE)
endfunction()
# ---- import pytorch ---------------------------------------------------------
find_torch()
message(STATUS "Torch_DIR: ${Torch_DIR}")
find_package(Torch QUIET)
# BUGFIX: the find_library keyword is PATHS, not PATH; the stray "PATH" token
# was silently treated as an extra (nonexistent) search entry.
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
# ---- import torch_mlu -------------------------------------------------------
# Ask the interpreter where torch_mlu ships its CMake package files.
execute_process(
  COMMAND python3 -c "import torch_mlu.utils as mlu_utils;print(mlu_utils.cmake_prefix_path)"
  OUTPUT_VARIABLE Torch_MLU_MODULE_DIR
  OUTPUT_STRIP_TRAILING_WHITESPACE
)
# ---- find ops library -------------------------------------------------------
# The prebuilt custom-op shared library shipped inside the torch_mlu_ops wheel.
execute_process(
  COMMAND python3 -c "import torch_mlu_ops as ops;print(ops._utils.get_custom_op_library_path())"
  RESULT_VARIABLE LIBOPS_NOT_FOUND
  OUTPUT_VARIABLE LIBOPS_PATH
  OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(LIBOPS_NOT_FOUND)
  message(FATAL_ERROR "torch_mlu_ops not installed, can not find ops library.")
endif()
set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH} ${Torch_MLU_MODULE_DIR})
find_package(TorchMLU QUIET)
# torch_mlu headers trigger [-Werror=sign-compare]; silence just that warning.
add_compile_options(-Wno-sign-compare)
# TorchMLUConfig.cmake leaves a literal TORCH_ATEN_LIBRARY-NOTFOUND entry in
# its library list, which would break the link line — strip it out.
string(REPLACE "TORCH_ATEN_LIBRARY-NOTFOUND" "" TORCH_MLU_LIBRARIES_MODIFIED "${TORCH_MLU_LIBRARIES}")
# NOTE(review): distutils is removed in python 3.12; sysconfig.get_paths() is
# the forward-compatible way to obtain the include directory.
execute_process(
  COMMAND python3 -c "from distutils import sysconfig; print(sysconfig.get_python_inc())"
  OUTPUT_VARIABLE PYTHON_INCLUDE_DIR
  OUTPUT_STRIP_TRAILING_WHITESPACE
)
# Build the test extension only when torch was found AND its C++ ABI matches
# the requested USE_CXX11_ABI setting (mixing ABIs causes link/runtime errors).
if(Torch_FOUND AND ((TORCH_CXX11_ABI AND USE_CXX11_ABI) OR (NOT TORCH_CXX11_ABI AND NOT USE_CXX11_ABI)))
  include_directories(${PYTHON_INCLUDE_DIR} ${TORCH_INCLUDE_DIRS} ${TORCH_MLU_INCLUDE_DIRS})
  link_libraries(${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY} ${TORCH_MLU_LIBRARIES_MODIFIED})
  # BUGFIX: "RECURSE" is not a valid file(GLOB_RECURSE) argument — it was being
  # treated as a literal glob pattern. GLOB_RECURSE already recurses.
  file(GLOB_RECURSE TEST_SRCS "src/*.cpp")
  # Python extension suffix, e.g. ".cpython-310-x86_64-linux-gnu.so".
  execute_process(
    COMMAND python3 -c "import sysconfig;print(sysconfig.get_config_var('EXT_SUFFIX'))"
    OUTPUT_VARIABLE TEST_LIB_NAME_SUFFIX
    OUTPUT_STRIP_TRAILING_WHITESPACE
  )
  # NOTE(review): PYTHON_EXTENSION is never set anywhere in this file, so this
  # prints an empty line — confirm whether it was meant to be
  # TEST_LIB_NAME_SUFFIX.
  message("${PYTHON_EXTENSION}")
  # Clear the default "lib" prefix / ".so" suffix so the module file name is
  # exactly "btunittests<EXT_SUFFIX>", importable as `import btunittests`.
  set(CMAKE_SHARED_LIBRARY_PREFIX "")
  set(CMAKE_SHARED_LIBRARY_SUFFIX "")
  set(TEST_LIB_NAME_PREFIX "btunittests")
  string(APPEND TEST_LIB_NAME "${TEST_LIB_NAME_PREFIX}${TEST_LIB_NAME_SUFFIX}")
  add_library(${TEST_LIB_NAME} SHARED ${TEST_SRCS})
  target_link_libraries(${TEST_LIB_NAME} ${LIBOPS_PATH} -lstdc++fs)
else()
  message(STATUS "Torch not found, or torch abi is different with which you specified, will not build")
  message(STATUS "if torch not found, please set env Torch_DIR to the directory containing TorchConfig.cmake")
endif()

View File

@@ -0,0 +1,17 @@
## BT_OPS测试脚本使用方式
```bash
# 测试所有测例
bash run_test.sh
```
```bash
# 测试单个测例
python3 test_测例名称.py
```
- 必须在Torch-MLU-Ops docker容器内运行。
- 测试脚本的命名规则为 `test_测例名称.py`
- 必须保证 Torch-MLU-Ops whl包正确安装。

View File

@@ -0,0 +1,10 @@
#!/bin/bash
# Configure and build the kernel unit tests out-of-source in ./build.
# Extra arguments are forwarded verbatim to the cmake configure step.
set -e
TOP_DIR="$( cd "$( dirname "$0" )" && pwd )"
cd "${TOP_DIR}"
# Start from a clean build tree so stale caches never leak between runs.
rm -rf build > /dev/null 2>&1
# BUGFIX: quote "$TOP_DIR" and use "$@" (not bare $@) so paths and arguments
# containing spaces are forwarded as single words.
cmake "$TOP_DIR" -Bbuild "$@"
cmake --build build -- -j32

View File

@@ -0,0 +1,22 @@
#!/bin/bash
# Run every test_*.py found beside this script, one after another, printing a
# colored pass/fail marker per test.
# Usage: run_test.sh [coverage]
#   With "coverage", tests accumulate into a coverage data file via
#   `coverage run -a`; otherwise each test's output is logged under /tmp.
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
tmo_kernel_case=$(find "${SCRIPT_DIR}" -name "test_*.py")
coverage=${1}
for sc in ${tmo_kernel_case}
do
    echo -n "${sc} "
    echo -n "Testing...."
    if [ "${coverage}" = "coverage" ];then
        coverage run -a "${sc}"
    else
        python3 "${sc}" > "/tmp/$(basename "${sc}").log" 2>&1
    fi
    # BUGFIX: capture the exit status immediately so later commands cannot
    # clobber it, and use POSIX "=" / -eq instead of the bashism "==".
    status=$?
    if [ "${status}" -eq 0 ];then
        echo -e "\033[32m success \033[0m"
    else
        echo -e "\033[31m failed \033[0m"
    fi
done
echo "End of pytest..."

View File

@@ -0,0 +1,54 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/generate_alibi_slope.mluh"
#include "register_api.h"
#include "utils.h"
namespace tmo {
namespace kernel_test_api {
// Python-facing wrapper around invokeGenerateAlibiSlope.
// Generates ALiBi slopes for every tensor-parallel (tp) rank and assembles
// them into a single [batch, tp_head_num * tp_num] float32 tensor on the MLU.
// true_seq_lens: per-batch actual sequence lengths consumed by the kernel
//                when use_dynamic is set.
at::Tensor alibi_slope_test(const at::Tensor &true_seq_lens,
const int batch,
const int tp_head_num,
const int tp_num,
const bool use_dynamic,
const int max_sequence_length) {
MLU_TENSOR_CHECK_FATAL(true_seq_lens);
// Use the queue behind torch_mlu's current CNNL handle so the work below is
// ordered after any prior torch ops on the same queue.
cnnlHandle_t handle = torch_mlu::getCurrentHandle();
cnrtQueue_t queue;
cnnlGetQueue(handle, &queue);
int head_num = tp_head_num * tp_num;
std::shared_ptr<torch::Device> dev = std::make_shared<torch::Device>(true_seq_lens.device());
torch::Tensor output = torch::zeros(
{batch, head_num}, torch::dtype(torch::kFloat32).device(*dev).requires_grad(false));
// Run the kernel once per tp rank, then scatter that rank's
// [batch, tp_head_num] slice into the matching column block of `output`.
for (int tp_id = 0; tp_id < tp_num; tp_id++) {
torch::Tensor tp_slopes = torch::zeros(
{batch, tp_head_num}, torch::dtype(torch::kFloat32).device(*dev).requires_grad(false));
TMO_KERNEL_CHECK_FATAL(invokeGenerateAlibiSlope(
queue, tp_slopes.data_ptr(), true_seq_lens.data_ptr(), batch, tp_id * tp_head_num,
tp_head_num, head_num, max_sequence_length, use_dynamic));
// Row-by-row device-to-device copy, async on the same queue so it stays
// ordered after the kernel above.
// NOTE(review): there is no queue sync before returning — presumably later
// torch_mlu work on the same queue provides ordering before the result is
// read; confirm callers synchronize before inspecting `output`.
for (int batch_id = 0; batch_id < batch; batch_id++) {
CNRT_CHECK(
cnrtMemcpyAsync((float *)output.data_ptr() + batch_id * head_num + tp_id * tp_head_num,
(float *)tp_slopes.data_ptr() + batch_id * tp_head_num,
tp_head_num * sizeof(float), queue, cnrtMemcpyDevToDev));
}
}
return output;
}
} // namespace kernel_test_api
} // namespace tmo

View File

@@ -0,0 +1,83 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/create_cos_sin_table.mluh"
#include <cnnl.h>
#include <torch/extension.h>
#include "register_api.h"
#include "utils.h"
namespace {
constexpr int DYNAMIC_NTK_SCALING = 2;
}
namespace tmo {
namespace kernel_test_api {
// Python-facing wrapper around invokeCreateCosSinTable.
// Builds the rotary-embedding cos/sin lookup table on the MLU and prints
// kernel timing stats. With DYNAMIC_NTK_SCALING the table is per-batch
// ([batch, rotary_seq_len, rotary_stride]); otherwise it is shared across the
// batch ([rotary_seq_len, rotary_stride]). The output dtype follows
// rotary_emb_alpha_cached.
at::Tensor create_cos_sin_table_test(at::Tensor &rotary_emb_alpha_cached,
                                     const at::Tensor &seq_lens,
                                     const int max_position_embeddings,
                                     const int batch_stride,
                                     const int rotary_seq_len,
                                     const int rotary_dim,
                                     const int rotary_stride,
                                     const float rotary_scaling,
                                     const int rotary_scaling_type,
                                     const float rotary_base,
                                     const bool interleaved) {
  MLU_TENSOR_CHECK_FATAL(rotary_emb_alpha_cached);
  MLU_TENSOR_CHECK_FATAL(seq_lens);
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  cnnlDataType_t dtype =
      torch2cnnlDataType(torch::typeMetaToScalarType(rotary_emb_alpha_cached.dtype()));
  std::shared_ptr<torch::Device> dev =
      std::make_shared<torch::Device>(rotary_emb_alpha_cached.device());
  int batch = seq_lens.size(0);
  torch::Tensor cos_sin_table =
      rotary_scaling_type == DYNAMIC_NTK_SCALING
          ? torch::zeros({batch, rotary_seq_len, rotary_stride},
                         torch::dtype(torch::typeMetaToScalarType(rotary_emb_alpha_cached.dtype()))
                             .device(*dev)
                             .requires_grad(false))
          : torch::zeros({rotary_seq_len, rotary_stride},
                         torch::dtype(torch::typeMetaToScalarType(rotary_emb_alpha_cached.dtype()))
                             .device(*dev)
                             .requires_grad(false));
  // NOTE(review): io_bytes is never computed here, so the reported IO
  // efficiency is always 0 — fill in the real byte count if the metric is
  // wanted for this kernel.
  size_t io_bytes = 0;
  cnrtNotifier_t time_begin;
  cnrtNotifier_t time_end;
  CNRT_CHECK(cnrtNotifierCreate(&time_begin));
  CNRT_CHECK(cnrtNotifierCreate(&time_end));
  CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
  TMO_KERNEL_CHECK_FATAL(invokeCreateCosSinTable(
      queue, cos_sin_table.data_ptr(),
      reinterpret_cast<float *>(rotary_emb_alpha_cached.data_ptr()),
      reinterpret_cast<int *>(seq_lens.data_ptr()), max_position_embeddings, batch, batch_stride,
      rotary_seq_len, rotary_dim, rotary_stride, rotary_base, rotary_scaling, rotary_scaling_type,
      interleaved, dtype));
  CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
  cnrtQueueSync(queue);
  float usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
  // BUGFIX: destroy the timing notifiers — they were created on every call
  // and never released, leaking a driver resource per invocation.
  CNRT_CHECK(cnrtNotifierDestroy(time_begin));
  CNRT_CHECK(cnrtNotifierDestroy(time_end));
  print_info(usec, io_bytes);
  return cos_sin_table;
}
} // namespace kernel_test_api
} // namespace tmo

View File

@@ -0,0 +1,48 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/embedding.mluh"
#include "register_api.h"
#include "utils.h"
namespace tmo {
namespace kernel_test_api {
// Python-facing wrapper around invokeEmbedding: gathers rows of `weight` for
// every token id in `input` and returns the looked-up activations as a
// [batch, seq, hidden_size] tensor with weight's dtype, on input's device.
// vocab_offset / vocab_part describe this rank's vocabulary shard.
at::Tensor embedding_test(const at::Tensor &input,
                          const at::Tensor &weight,
                          const int vocab_offset,
                          const int vocab_part) {
  MLU_TENSOR_CHECK_FATAL(input);
  MLU_TENSOR_CHECK_FATAL(weight);
  // Fetch the MLU queue associated with torch_mlu's current CNNL handle.
  cnrtQueue_t compute_queue;
  cnnlGetQueue(torch_mlu::getCurrentHandle(), &compute_queue);
  const int batch_size = input.size(0);
  const int sequence_len = input.size(1);
  const int vocab_rows = weight.size(0);
  const int embed_dim = weight.size(1);
  const auto weight_scalar_type = torch::typeMetaToScalarType(weight.dtype());
  const cnnlDataType_t weight_dtype = torch2cnnlDataType(weight_scalar_type);
  // Result lives on the same device as the inputs.
  std::shared_ptr<torch::Device> device_holder =
      std::make_shared<torch::Device>(input.device());
  torch::Tensor result = torch::zeros(
      {batch_size, sequence_len, embed_dim},
      torch::dtype(weight_scalar_type).device(*device_holder).requires_grad(false));
  TMO_KERNEL_CHECK_FATAL(invokeEmbedding(compute_queue, weight.data_ptr(), input.data_ptr(),
                                         result.data_ptr(), weight_dtype, vocab_offset, vocab_part,
                                         vocab_rows, embed_dim, batch_size * sequence_len));
  return result;
}
} // namespace kernel_test_api
} // namespace tmo

View File

@@ -0,0 +1,80 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/operate_cu_seq_lens.mluh"
#include "register_api.h"
#include "utils.h"
namespace tmo {
namespace kernel_test_api {
// Python-facing wrapper around invokeSliceCuSeqlens: writes the sliced
// cumulative-sequence-length array into `sliced_cu_seq_lens` on the MLU and
// prints kernel timing stats.
void slice_cu_seq_lens_test(const at::Tensor &cu_seq_lens,
                            at::Tensor &sliced_cu_seq_lens,
                            int batch,
                            int parallel_num) {
  MLU_TENSOR_CHECK_FATAL(cu_seq_lens);
  MLU_TENSOR_CHECK_FATAL(sliced_cu_seq_lens);
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  // NOTE(review): io_bytes is never filled in, so IO efficiency prints as 0.
  size_t io_bytes = 0;
  cnrtNotifier_t time_begin;
  cnrtNotifier_t time_end;
  CNRT_CHECK(cnrtNotifierCreate(&time_begin));
  CNRT_CHECK(cnrtNotifierCreate(&time_end));
  CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
  TMO_KERNEL_CHECK_FATAL(invokeSliceCuSeqlens(queue, (int *)cu_seq_lens.data_ptr(),
                                              (int *)sliced_cu_seq_lens.data_ptr(), batch,
                                              parallel_num));
  CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
  cnrtQueueSync(queue);
  float usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
  // BUGFIX: release the timing notifiers (previously leaked on every call).
  CNRT_CHECK(cnrtNotifierDestroy(time_begin));
  CNRT_CHECK(cnrtNotifierDestroy(time_end));
  print_info(usec, io_bytes);
}
// Python-facing wrapper around invokeGenerateCuSeqlens: fills
// `gen_cu_seq_lens` on the MLU from a uniform `seq_len` and prints kernel
// timing stats.
void generate_cu_seq_lens_test(at::Tensor &gen_cu_seq_lens,
                               int seq_len,
                               int parallel_num,
                               bool is_causal_mask,
                               bool is_kv_seq_len) {
  MLU_TENSOR_CHECK_FATAL(gen_cu_seq_lens);
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  // NOTE(review): io_bytes is never filled in, so IO efficiency prints as 0.
  size_t io_bytes = 0;
  cnrtNotifier_t time_begin;
  cnrtNotifier_t time_end;
  CNRT_CHECK(cnrtNotifierCreate(&time_begin));
  CNRT_CHECK(cnrtNotifierCreate(&time_end));
  CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
  TMO_KERNEL_CHECK_FATAL(invokeGenerateCuSeqlens(queue, (int *)gen_cu_seq_lens.data_ptr(), seq_len,
                                                 parallel_num, is_causal_mask, is_kv_seq_len));
  CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
  cnrtQueueSync(queue);
  float usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
  // BUGFIX: release the timing notifiers (previously leaked on every call).
  CNRT_CHECK(cnrtNotifierDestroy(time_begin));
  CNRT_CHECK(cnrtNotifierDestroy(time_end));
  print_info(usec, io_bytes);
}
} // namespace kernel_test_api
} // namespace tmo

View File

@@ -0,0 +1,67 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/quantize.mluh"
#include "register_api.h"
#include "utils.h"
namespace tmo {
namespace kernel_test_api {
// Python-facing wrapper around invokeMluQuantizePerHead: quantizes `src` into
// `dst` with one scale per head (written to `scale`) and prints kernel timing
// and IO-efficiency stats. The *_stride arguments describe potentially
// non-contiguous layouts of dst and src.
void per_head_quantize_test(at::Tensor &dst,
                            at::Tensor &scale,
                            const at::Tensor &src,
                            int bs,
                            int seq_len,
                            int head_num,
                            int head_size,
                            int dst_bs_stride,
                            int dst_seq_stride,
                            int dst_head_stride,
                            int src_bs_stride,
                            int src_seq_stride,
                            int src_head_stride) {
  MLU_TENSOR_CHECK_FATAL(dst);
  MLU_TENSOR_CHECK_FATAL(scale);
  MLU_TENSOR_CHECK_FATAL(src);
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  cnnlDataType_t dst_dtype = torch2cnnlDataType(torch::typeMetaToScalarType(dst.dtype()));
  cnnlDataType_t scale_dtype = torch2cnnlDataType(torch::typeMetaToScalarType(scale.dtype()));
  cnnlDataType_t src_dtype = torch2cnnlDataType(torch::typeMetaToScalarType(src.dtype()));
  // One full read of src plus one full write of dst, for the IO report.
  size_t io_bytes =
      bs * seq_len * head_num * head_size * (dtype_size_map[src_dtype] + dtype_size_map[dst_dtype]);
  cnrtNotifier_t time_begin;
  cnrtNotifier_t time_end;
  CNRT_CHECK(cnrtNotifierCreate(&time_begin));
  CNRT_CHECK(cnrtNotifierCreate(&time_end));
  CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
  TMO_KERNEL_CHECK_FATAL(invokeMluQuantizePerHead(
      queue, (void *)dst.data_ptr(), (void *)scale.data_ptr(), (void *)src.data_ptr(), dst_dtype,
      scale_dtype, src_dtype, bs, seq_len, head_num, head_size, dst_bs_stride, dst_seq_stride,
      dst_head_stride, src_bs_stride, src_seq_stride, src_head_stride));
  CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
  cnrtQueueSync(queue);
  float usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
  // BUGFIX: release the timing notifiers (previously leaked on every call).
  CNRT_CHECK(cnrtNotifierDestroy(time_begin));
  CNRT_CHECK(cnrtNotifierDestroy(time_end));
  print_info(usec, io_bytes);
}
} // namespace kernel_test_api
} // namespace tmo

View File

@@ -0,0 +1,82 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef TEST_KERNELS_PYTEST_REGISTER_API_H_
#define TEST_KERNELS_PYTEST_REGISTER_API_H_
#include <torch/extension.h>
#include <torch/torch.h>
// Declarations of the python-facing kernel-test wrappers. Each is implemented
// in a sibling .cpp file and exposed to python via PYBIND11_MODULE.
namespace tmo {
namespace kernel_test_api {
// Gathers rows of `weight` by the token ids in `input`; returns a
// [batch, seq, hidden_size] tensor in weight's dtype.
at::Tensor embedding_test(const at::Tensor &input,
const at::Tensor &weight,
const int vocab_offset,
const int vocab_part);
// Generates ALiBi slopes for all tp ranks; returns a float32
// [batch, tp_head_num * tp_num] tensor.
at::Tensor alibi_slope_test(const at::Tensor &true_seq_lens,
const int batch,
const int tp_head_num,
const int tp_num,
const bool use_dynamic,
const int max_sequence_length);
// Builds the rotary-embedding cos/sin lookup table; per-batch when
// rotary_scaling_type selects dynamic NTK, otherwise shared.
at::Tensor create_cos_sin_table_test(at::Tensor &rotary_emb_alpha_cached,
const at::Tensor &seq_lens,
const int max_position_embeddings,
const int batch_stride,
const int rotary_seq_len,
const int rotary_dim,
const int rotary_stride,
const float rotary_scaling,
const int rotary_scaling_type,
const float rotary_base,
const bool interleaved);
// Writes the sliced cumulative-sequence-length array into sliced_cu_seq_lens.
void slice_cu_seq_lens_test(const at::Tensor &cu_seq_lens,
at::Tensor &sliced_cu_seq_lens,
int batch,
int parallel_num);
// Fills gen_cu_seq_lens from a uniform seq_len.
void generate_cu_seq_lens_test(at::Tensor &gen_cu_seq_lens,
int seq_len,
int parallel_num,
bool is_causal_mask,
bool is_kv_seq_len);
// Quantizes src into dst with one scale per head (written to `scale`);
// the *_stride arguments describe non-contiguous layouts.
void per_head_quantize_test(at::Tensor &dst,
at::Tensor &scale,
const at::Tensor &src,
int bs,
int seq_len,
int head_num,
int head_size,
int dst_bs_stride,
int dst_seq_stride,
int dst_head_stride,
int src_bs_stride,
int src_seq_stride,
int src_head_stride);
// Applies rotary position embedding to `input` in place and returns it;
// 3-dim input means pack mode (requires cu_seq_lens), 4-dim means pad mode.
at::Tensor rotary_embedding_test(const at::Tensor &input,
const at::Tensor &sin_table,
const at::Tensor &cos_table,
const c10::optional<at::Tensor> &seq_offsets,
const c10::optional<at::Tensor> &cu_seq_lens,
const bool &interleaved,
const bool &discrete,
const bool &dynamic_ntk,
const bool &rope_2d,
const int &max_seq_len,
const bool &no_offset);
} // namespace kernel_test_api
} // namespace tmo
#endif // TEST_KERNELS_PYTEST_REGISTER_API_H_

View File

@@ -0,0 +1,29 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <torch/extension.h>
#include "register_api.h"
// Python module "btunittests": exposes every kernel-test wrapper declared in
// register_api.h under the same name on the python side.
PYBIND11_MODULE(btunittests, m) {
m.def("embedding_test", &tmo::kernel_test_api::embedding_test, "Test embedding kernel function.");
m.def("alibi_slope_test", &tmo::kernel_test_api::alibi_slope_test,
"Test alibi_slope kernel function.");
m.def("create_cos_sin_table_test", &tmo::kernel_test_api::create_cos_sin_table_test,
"Test create_cos_sin_table_test kernel function.");
m.def("slice_cu_seq_lens_test", &tmo::kernel_test_api::slice_cu_seq_lens_test,
"Test slice_cu_seq_lens_test kernel function.");
m.def("generate_cu_seq_lens_test", &tmo::kernel_test_api::generate_cu_seq_lens_test,
"Test generate_cu_seq_lens_test kernel function.");
m.def("per_head_quantize_test", &tmo::kernel_test_api::per_head_quantize_test,
"Test per_head_quantize_test kernel function.");
m.def("rotary_embedding_test", &tmo::kernel_test_api::rotary_embedding_test,
"Test rotary_embedding kernel function.");
}

View File

@@ -0,0 +1,202 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/rotary_embedding.mluh"
#include "register_api.h"
#include "utils.h"
namespace tmo {
namespace kernel_test_api {
// Python-facing wrapper around invokeRotaryEmbedding /
// invokeGlm6BRotaryEmbedding. Applies rotary position embedding to `input`
// IN PLACE (the returned tensor aliases `input`) and prints kernel timing /
// IO-efficiency stats.
//
//   input     3 dims (total_seq_len, head_num, head_size) -> pack mode,
//             cu_seq_lens required; 4 dims (batch, seq_len, head_num,
//             head_size) -> pad mode, cu_seq_lens must be empty.
//   sin/cos   [rope_seq, rope_dim], or with dynamic_ntk a leading batch dim.
//   rope_2d   selects the GLM-6B 2D-rope kernel variant.
//
// Cleanup in this revision: dead "#if 0" debug/check blocks and the unused
// locals they referenced (rope_seqlen, rope_dim, head_dim) were removed; the
// timing notifiers are now destroyed (they leaked on every call); two
// inaccurate TORCH_CHECK messages were corrected.
at::Tensor rotary_embedding_test(const at::Tensor &input,
                                 const at::Tensor &sin_table,
                                 const at::Tensor &cos_table,
                                 const c10::optional<at::Tensor> &seq_offsets,
                                 const c10::optional<at::Tensor> &cu_seq_lens,
                                 const bool &interleaved,
                                 const bool &discrete,
                                 const bool &dynamic_ntk,
                                 const bool &rope_2d,
                                 const int &max_seq_len,
                                 const bool &no_offset) {
  MLU_TENSOR_CHECK_FATAL(input);
  MLU_TENSOR_CHECK_FATAL(sin_table);
  MLU_TENSOR_CHECK_FATAL(cos_table);
  bool has_seq_offsets = seq_offsets.has_value();
  bool has_cu_seq_lens = cu_seq_lens.has_value();
  int *seq_offsets_ptr =
      has_seq_offsets ? reinterpret_cast<int *>(seq_offsets.value().data_ptr()) : nullptr;
  int *cu_seq_lens_ptr =
      has_cu_seq_lens ? reinterpret_cast<int *>(cu_seq_lens.value().data_ptr()) : nullptr;
  int batch = 0;
  int total_seq_len = 0;
  int head_size = input.size(-1);
  if (input.dim() == 3) {  // pack mode
    TORCH_CHECK(has_cu_seq_lens,
                "input has 3 dims: (total_seq_len, head_num, head_size),"
                " which means pack mode, cu_seqlens should not be None");
    total_seq_len = input.size(0);
    batch = cu_seq_lens.value().size(0) - 1;
  } else if (input.dim() == 4) {  // pad mode
    TORCH_CHECK(!has_cu_seq_lens,
                "input has 4 dims: (batch_size, seq_len, head_num, head_size),"
                " which means pad mode, cu_seqlens should be None");
    // BUGFIX(message): "equtals" -> "equal".
    TORCH_CHECK(max_seq_len == input.size(1),
                "input has 4 dims: (batch_size, seq_len, head_num, head_size),"
                " which means pad mode, max_seqlen must be equal to input.size(1)");
    batch = input.size(0);
    total_seq_len = batch * input.size(1);
  } else {
    TORCH_CHECK(false, "input only support 3 or 4 dims");
  }
  TORCH_CHECK(input.stride(-1) == 1, "input last dim must be contiguous");
  // The sin/cos tables must share layout so the kernel walks them in lockstep.
  if (dynamic_ntk) {
    TORCH_CHECK(sin_table.stride(1) == cos_table.stride(1),
                "sin_table second stride must be equal to cos_table second stride");
  } else {
    // BUGFIX(message): this branch compares the FIRST strides.
    TORCH_CHECK(sin_table.stride(0) == cos_table.stride(0),
                "sin_table first stride must be equal to cos_table first stride");
  }
  if (has_seq_offsets) {
    TORCH_CHECK(seq_offsets.value().is_contiguous(), "seq_offsets must be contiguous");
  }
  if (has_cu_seq_lens) {
    TORCH_CHECK(cu_seq_lens.value().is_contiguous(), "cu_seq_lens must be contiguous");
  }
  int dims = input.dim();
  int head_num = input.size(dims - 2);  // head_num_qk
  TORCH_CHECK(head_size <= 256, "only support input head_size <= 256");
  // With dynamic_ntk the tables carry a leading batch dimension.
  int rotary_seq_len = dynamic_ntk ? sin_table.size(1) : sin_table.size(0);
  int rotary_dim = dynamic_ntk ? sin_table.size(2) : sin_table.size(1);
  int rotary_stride = dynamic_ntk ? sin_table.stride(1) : sin_table.stride(0);
  size_t input_seq_stride = input.stride(dims - 3);
  size_t input_head_stride = input.stride(dims - 2);
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  cnnlDataType_t dtype = torch2cnnlDataType(torch::typeMetaToScalarType(input.dtype()));
  // In-place operation: output aliases input, so its strides mirror input's.
  auto output = input;
  size_t output_seq_stride = output.stride(dims - 3);
  size_t output_head_stride = output.stride(dims - 2);
  // Read + write of the rotated data plus the offset/length metadata, for the
  // IO-efficiency report.
  size_t io_bytes =
      (size_t)total_seq_len * head_num * (head_size + rotary_dim) * dtype_size_map[dtype] * 2;
  io_bytes += (discrete && !no_offset) ? total_seq_len * sizeof(int) : batch * sizeof(int);
  cnrtNotifier_t time_begin;
  cnrtNotifier_t time_end;
  CNRT_CHECK(cnrtNotifierCreate(&time_begin));
  CNRT_CHECK(cnrtNotifierCreate(&time_end));
  CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
  if (rope_2d) {
    TMO_KERNEL_CHECK_FATAL(invokeGlm6BRotaryEmbedding(
        queue, output.data_ptr(), input.data_ptr(), sin_table.data_ptr(), cos_table.data_ptr(),
        seq_offsets_ptr, cu_seq_lens_ptr, batch, max_seq_len, total_seq_len, head_num, head_size,
        rotary_seq_len, rotary_stride, input_seq_stride, input_head_stride, output_seq_stride,
        output_head_stride, interleaved, dtype));
  } else {
    TMO_KERNEL_CHECK_FATAL(invokeRotaryEmbedding(
        queue, output.data_ptr(), input.data_ptr(), sin_table.data_ptr(), cos_table.data_ptr(),
        seq_offsets_ptr, cu_seq_lens_ptr, batch, max_seq_len, head_num, head_size, rotary_seq_len,
        rotary_dim, rotary_stride, input_seq_stride, input_head_stride, output_seq_stride,
        output_head_stride, interleaved, discrete, dynamic_ntk, dtype));
  }
  CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
  cnrtQueueSync(queue);
  float usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
  // BUGFIX: release the timing notifiers (previously leaked on every call).
  CNRT_CHECK(cnrtNotifierDestroy(time_begin));
  CNRT_CHECK(cnrtNotifierDestroy(time_end));
  print_info(usec, io_bytes);
  return output;
}
} // namespace kernel_test_api
} // namespace tmo

View File

@@ -0,0 +1,133 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef TEST_KERNELS_PYTEST_UTILS_H_
#define TEST_KERNELS_PYTEST_UTILS_H_
#include <cnrt.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include "aten/cnnl/cnnlHandle.h"
#include "common/utils.h"
#include "framework/core/MLUStream.h"
#include "framework/core/caching_allocator.h"
#include "framework/core/device.h"
#include "framework/core/mlu_guard.h"
#include "kernels/kernel_utils.h"
namespace tmo {
// Translates a torch scalar type into the matching cnnlDataType_t.
// Throws std::runtime_error for any dtype CNNL has no counterpart for.
static cnnlDataType_t torch2cnnlDataType(torch::Dtype dtype) {
  if (dtype == torch::kFloat32) {
    return CNNL_DTYPE_FLOAT;
  }
  if (dtype == torch::kFloat16) {
    return CNNL_DTYPE_HALF;
  }
  if (dtype == torch::kFloat64) {
    return CNNL_DTYPE_DOUBLE;
  }
  if (dtype == torch::kInt8) {
    return CNNL_DTYPE_INT8;
  }
  if (dtype == torch::kInt16) {
    return CNNL_DTYPE_INT16;
  }
  if (dtype == torch::kInt32) {
    return CNNL_DTYPE_INT32;
  }
  if (dtype == torch::kInt64) {
    return CNNL_DTYPE_INT64;
  }
  if (dtype == torch::kUInt8) {
    return CNNL_DTYPE_UINT8;
  }
  if (dtype == torch::kBool) {
    return CNNL_DTYPE_BOOL;
  }
  if (dtype == torch::kBFloat16) {
    return CNNL_DTYPE_BFLOAT16;
  }
  throw std::runtime_error("Unsupported torch::Dtype");
}
// Element size in bytes for each cnnlDataType_t, indexed by the enum's value;
// used when computing kernel IO byte counts for the efficiency report.
// NOTE(review): designated array initializers ([IDX] = v) are a C99/GNU
// extension in C++ — confirm every supported compiler accepts this. The bare
// "[10] = 0" entry presumably covers a reserved/unused enum slot; verify
// against the cnnl.h enum definition.
static constexpr int dtype_size_map[] = {
[CNNL_DTYPE_INVALID] = 0,
[CNNL_DTYPE_HALF] = 2,
[CNNL_DTYPE_FLOAT] = 4,
[CNNL_DTYPE_INT8] = 1,
[CNNL_DTYPE_INT16] = 2,
[CNNL_DTYPE_INT31] = 4,
[CNNL_DTYPE_INT32] = 4,
[CNNL_DTYPE_UINT8] = 1,
[CNNL_DTYPE_BOOL] = 1,
[CNNL_DTYPE_INT64] = 8,
[10] = 0,
[CNNL_DTYPE_UINT32] = 4,
[CNNL_DTYPE_UINT64] = 8,
[CNNL_DTYPE_UINT16] = 2,
[CNNL_DTYPE_DOUBLE] = 8,
// Complex sizes here are per-component pair as declared; half-complex = 4
// bytes, float-complex = 8 bytes.
[CNNL_DTYPE_COMPLEX_HALF] = 4,
[CNNL_DTYPE_COMPLEX_FLOAT] = 8,
[CNNL_DTYPE_BFLOAT16] = 2,
};
// Inverse of torch2cnnlDataType: translates a cnnlDataType_t back into the
// matching torch scalar type. Throws std::runtime_error for unmapped values.
static torch::Dtype cnnl2torchDataType(cnnlDataType_t dtype) {
  if (dtype == CNNL_DTYPE_FLOAT) {
    return torch::kFloat32;
  }
  if (dtype == CNNL_DTYPE_HALF) {
    return torch::kFloat16;
  }
  if (dtype == CNNL_DTYPE_DOUBLE) {
    return torch::kFloat64;
  }
  if (dtype == CNNL_DTYPE_INT8) {
    return torch::kInt8;
  }
  if (dtype == CNNL_DTYPE_INT16) {
    return torch::kInt16;
  }
  if (dtype == CNNL_DTYPE_INT32) {
    return torch::kInt32;
  }
  if (dtype == CNNL_DTYPE_INT64) {
    return torch::kInt64;
  }
  if (dtype == CNNL_DTYPE_UINT8) {
    return torch::kUInt8;
  }
  if (dtype == CNNL_DTYPE_BOOL) {
    return torch::kBool;
  }
  if (dtype == CNNL_DTYPE_BFLOAT16) {
    return torch::kBFloat16;
  }
  throw std::runtime_error("Unsupported cnnlDataType_t");
}
// Queries the peak DRAM bandwidth (GB/s) of the current MLU card via the
// cndev device-management library. Aborts the process outright if cndev
// cannot be initialised or queried — this is test-only code.
static float getBandWidth() {
int card = -1;
CNRT_CHECK(cnrtGetDevice(&card));
if (cndevInit(0) != CNDEV_SUCCESS) {
abort();
}
cndevDDRInfo_t ddrinfo;
ddrinfo.version = CNDEV_VERSION_5;
if (cndevGetDDRInfo(&ddrinfo, card) != CNDEV_SUCCESS) {
abort();
}
double band_width = ddrinfo.bandWidth;
double band_width_decimal = ddrinfo.bandWidthDecimal;
// Scale the separately-reported decimal part below 1.0, then combine it with
// the integer part.
// NOTE(review): presumably bandWidthDecimal carries the fractional digits as
// an integer — confirm against the cndev documentation; if it can be 0 the
// loop still terminates after one iteration.
do {
band_width_decimal /= 10;
} while (band_width_decimal > 1);
return float(band_width + band_width_decimal);
}
static void print_info(float time_usec, size_t io_bytes) {
float io_bandwidth = getBandWidth();
std::cout << "kernel time: " << time_usec << "us" << std::endl;
std::cout << "io_bandwidth: " << io_bandwidth << "GB/s" << std::endl;
std::cout << "IO efficiency: " << io_bytes / (time_usec * 1000 * io_bandwidth) << std::endl;
}
// Throw when `tensor` lives on the CPU or a CUDA device (i.e. is not an MLU
// tensor). Wrapped in do/while(0) so the macro expands to a single statement
// and composes safely with if/else at the call site (the previous bare `if`
// form could capture a following `else`).
#define MLU_TENSOR_CHECK_FATAL(tensor)                                       \
  do {                                                                       \
    if (tensor.device().type() == c10::kCPU or                               \
        tensor.device().type() == c10::kCUDA) {                              \
      throw std::runtime_error("Check failed: " #tensor                      \
                               " is not a MLU tensor.");                     \
    }                                                                        \
  } while (0)
} // namespace tmo
#endif // TEST_KERNELS_PYTEST_UTILS_H_

View File

@@ -0,0 +1,75 @@
import sys
import os
from unit_test import UnitTest
import pytest
import torch
import torch_mlu
import btunittests
import math
def alibi_slopes(batch,
                 num_heads,
                 true_seq_lens,
                 train_seq_len,
                 use_dynamic):
    """CPU reference for ALiBi slopes, optionally with dynamic-NTK scaling.

    Args:
        batch: batch size.
        num_heads: total attention head count.
        true_seq_lens: [batch, 1] int tensor of actual sequence lengths
            (only read when use_dynamic is True).
        train_seq_len: sequence length the model was trained with.
        use_dynamic: apply dynamic NTK rescaling of the slopes.

    Returns:
        [batch, num_heads] tensor of per-head slopes.
    """
    ntk_factor = 1.0
    if use_dynamic:
        # Dynamic NTK: scaling factor grows with the actual sequence length,
        # clamped below at 1.0, then spread across heads as coefficient b.
        ratio = 1.0 * true_seq_lens / train_seq_len          # [batch, 1]
        ratio = ratio.masked_fill(ratio < 1.0, 1.0)
        ntk_factor = ratio ** (1.0 / (num_heads - 1))
    pow2_heads = 2 ** math.floor(math.log2(num_heads))
    main_base = torch.tensor(
        2 ** (-(2 ** -(math.log2(pow2_heads) - 3))), dtype=torch.float32
    )
    if use_dynamic:
        main_base = main_base / ntk_factor   # divide b into the alibi base
    exponents = torch.arange(1, 1 + pow2_heads, dtype=torch.int32)
    slopes = torch.pow(main_base, exponents)
    if use_dynamic:
        slopes = slopes * ntk_factor         # restore the m_h bias with b
    if pow2_heads != num_heads:
        # Non-power-of-two head counts take extra slopes sampled from the
        # next power of two (NTK fixup for this path is still TODO upstream).
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * pow2_heads) - 3))), dtype=torch.float32
        )
        n_extra = min(pow2_heads, num_heads - pow2_heads)
        extra_exponents = torch.arange(1, 1 + 2 * n_extra, 2, dtype=torch.int32)
        extra = torch.pow(extra_base, extra_exponents)
        if use_dynamic:
            extra = extra.unsqueeze(0).repeat(batch, 1)
        slopes = torch.cat([slopes, extra], dim=-1)
    if not use_dynamic:
        slopes = slopes.unsqueeze(0).repeat(batch, 1)
    return slopes
class TestAlibiSlopeKernel(UnitTest):
    # Compares the MLU alibi_slope kernel against the CPU reference above
    # for several (batch, heads, tensor-parallel, dynamic-NTK) combinations.
    @pytest.mark.parametrize("batch, tp_head_num, tp_num, use_dynamic, max_sequence_length", [
        (1, 11, 1, False, 1024),
        (1, 11, 2, False, 1024),
        (8, 16, 1, False, 4096),
        (8, 16, 2, False, 4096),
        (1, 11, 1, True, 1024),
        (1, 11, 2, True, 1024),
        (8, 16, 1, True, 4096),
        (8, 16, 2, True, 4096)])
    def test_alibi_slope(self, batch, tp_head_num, tp_num, use_dynamic, max_sequence_length):
        torch.manual_seed(0)
        # Actual lengths may exceed the training length, exercising the
        # dynamic-NTK rescaling path.
        true_seq_lens = torch.randint(0, 2 * max_sequence_length, (batch, 1), dtype=torch.int32)
        output_base = alibi_slopes(batch, tp_head_num * tp_num, true_seq_lens, max_sequence_length, use_dynamic)
        output_mlu = btunittests.alibi_slope_test(true_seq_lens.mlu(), batch, tp_head_num, tp_num, use_dynamic, max_sequence_length)
        diff1 = super().diff1(output_mlu, output_base)
        diff2 = super().diff2(output_mlu, output_base)
        assert diff1 <= 1e-5 and diff2 <= 1e-5
if __name__ == '__main__':
    # Run this file's tests verbosely, emitting a JUnit XML report named
    # after the file.
    exit_code = pytest.main(['-vs', os.path.abspath(__file__), '--junitxml=' + os.path.splitext(os.path.basename(__file__))[0] + '_report.xml'])
    exit(exit_code)

View File

@@ -0,0 +1,109 @@
import sys
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
from unit_test import UnitTest
import pytest
import torch
import torch_mlu
import btunittests
import random
import math
def create_cos_sin_table(rotary_base, rotary_seq_len_, rotary_stride_, rotary_scaling, rotary_dim, interleaved, datatype):
    # Reference construction of the rotary cos/sin lookup table on MLU.
    # interleaved == 0: emb = [cos|cos | sin|sin] along dim 1.
    # interleaved != 0: each cos column duplicated pairwise, then each sin
    # column duplicated pairwise, concatenated along dim 1.
    # NOTE(review): rotary_stride_ and datatype are not used in this body --
    # verify against the kernel signature.
    torch.manual_seed(0)
    random.seed(0)
    inv_freq=1.0 / (rotary_base ** (torch.arange(0, rotary_dim, 2).mlu() / rotary_dim))
    t = torch.arange(rotary_seq_len_).mlu()
    # Linear position scaling.
    t = t / rotary_scaling
    freqs = torch.outer(t, inv_freq).mlu()
    cos_table = torch.cos(freqs).mlu()
    sin_table = torch.sin(freqs).mlu()
    if interleaved == 0:
        cos_table = torch.cat((cos_table,cos_table), dim = 1).mlu()
        sin_table = torch.cat((sin_table,sin_table), dim = 1).mlu()
        emb = torch.cat((cos_table, sin_table), dim=-1).mlu()
    else:
        # Build [c0,c0,c1,c1,...,s0,s0,s1,s1,...] column by column.
        cos_tmp = torch.cat((cos_table[:,0].unsqueeze(1),cos_table[:,0].unsqueeze(1)), dim = 1).mlu()
        emb = cos_tmp
        for i in range(1, cos_table.size(1)):
            cos_tmp = torch.cat((cos_table[:,i].unsqueeze(1),cos_table[:,i].unsqueeze(1)), dim = 1).mlu()
            emb = torch.cat((emb,cos_tmp), dim = 1).mlu()
        for i in range(0, cos_table.size(1)):
            sin_tmp = torch.cat((sin_table[:,i].unsqueeze(1),sin_table[:,i].unsqueeze(1)), dim = 1).mlu()
            emb = torch.cat((emb,sin_tmp), dim = 1).mlu()
    return emb
def get_ntk_alpha(true_seq_len, max_position_embeddings, rotary_base, rotary_dim):
    """Compute the dynamic-NTK alpha and the rescaled rotary base.

    Args:
        true_seq_len: actual sequence length (number or 0-dim tensor).
        max_position_embeddings: trained maximum position count.
        rotary_base: original rotary base (e.g. 10000).
        rotary_dim: rotary embedding dimension.

    Returns:
        (ntk_alpha, rotary_b): alpha clamped to >= 1, and the rotary base
        scaled by alpha ** (rotary_dim / (rotary_dim - 2)).
    """
    # ceil(log2(len / max_pos) + 1), then 2**ceil - 1, floored at 1.
    log_ratio = math.log(true_seq_len / max_position_embeddings, 2) + 1
    alpha = max(2 ** math.ceil(log_ratio) - 1, 1)
    scaled_base = rotary_base * (alpha ** (rotary_dim / (rotary_dim - 2)))
    return alpha, scaled_base
class TestCreateCosSinTableKernel(UnitTest):
    # Validates the MLU create_cos_sin_table kernel against the CPU/MLU
    # reference above, for linear (type 1) and dynamic-NTK (type 2) scaling.
    @pytest.mark.parametrize("batch, max_position_embeddings, batch_stride, rotary_seq_len, rotary_stride, rotary_base, rotary_scaling,\
        rotary_scaling_type, rotary_dim , interleaved, datatype",
        [(8, 512, 0, 267, 128, 10000, 1, 1, 64, 0, torch.float32),
         (8, 512, 0, 267, 128, 10000, 1, 1, 64, 1, torch.float32),
         (4, 256, 11456, 179, 64, 10000, 1, 2, 32, 0, torch.float32),
         (4, 256, 11456, 179, 64, 10000, 1, 2, 32, 1, torch.float32),
         (4, 256, 34368, 179, 192, 10000, 1, 2, 96, 0, torch.float32),
         (4, 256, 34368, 179, 192, 10000, 1, 2, 96, 1, torch.float32),
         (4, 256, 45824, 179, 256, 10000, 1, 2, 128, 0, torch.float32),
         (4, 256, 45824, 179, 256, 10000, 1, 2, 128, 1, torch.float32),
         (8, 512, 0, 267, 128, 10000, 1, 1, 64, 0, torch.bfloat16),
         (8, 512, 0, 267, 128, 10000, 1, 1, 64, 1, torch.bfloat16),
         (4, 256, 11456, 179, 64, 10000, 1, 2, 32, 0, torch.float16),
         (4, 256, 11456, 179, 64, 10000, 1, 2, 32, 1, torch.float16),
         (4, 256, 34368, 179, 192, 10000, 1, 2, 96, 0, torch.float16),
         (4, 256, 34368, 179, 192, 10000, 1, 2, 96, 1, torch.float16),
         (4, 256, 45824, 179, 256, 10000, 1, 2, 128, 0, torch.float16),
         (4, 256, 45824, 179, 256, 10000, 1, 2, 128, 1, torch.float16)])
    def test_all(self, batch, max_position_embeddings, batch_stride, rotary_seq_len, rotary_stride, rotary_base, rotary_scaling,
                 rotary_scaling_type, rotary_dim, interleaved ,datatype):
        mlu_name = torch.mlu.get_device_name()
        # MLU3 devices lack bfloat16 support; run those cases in half instead.
        if "MLU3" in mlu_name and datatype == torch.bfloat16:
            datatype = torch.half
            print("BFLOAT16 is not support on {}, use half instead".format(mlu_name))
        torch.manual_seed(0)
        random.seed(0)
        seq_lens = torch.randint(size = (batch,), low = 0, high = 512,
                                 dtype = torch.int32).mlu()
        ntk_alpha_list = []
        rotary_emb_alpha_cached = torch.rand((batch), dtype=torch.float32).mlu()
        ref_rotary_emb_alpha_cached = rotary_emb_alpha_cached.clone()
        output_mlu = btunittests.create_cos_sin_table_test(ref_rotary_emb_alpha_cached, seq_lens, max_position_embeddings, batch_stride,
                                                           rotary_seq_len, rotary_dim, rotary_stride, rotary_scaling,
                                                           rotary_scaling_type, rotary_base, interleaved)
        if rotary_scaling_type != 2:
            output_base = create_cos_sin_table(rotary_base, rotary_seq_len, rotary_stride, rotary_scaling, rotary_dim, interleaved, datatype)
        else:
            # Dynamic NTK: the rotary base (and cached alpha) is derived
            # per-batch-entry when any sequence exceeds the trained maximum,
            # otherwise once from the longest sequence.
            max_seq_len, max_index = torch.max(seq_lens, dim = 0)
            if max_seq_len > max_position_embeddings:
                for i in range(batch):
                    ntk_alpha, rb = get_ntk_alpha(seq_lens[i], max_position_embeddings, rotary_base, rotary_dim)
                    rotary_emb_alpha_cached[i] = ntk_alpha
                    cos_sin_table_tmp = create_cos_sin_table(rb, rotary_seq_len, rotary_stride, rotary_scaling, rotary_dim, interleaved, datatype).unsqueeze(0)
                    cos_sin_table = torch.cat((cos_sin_table,cos_sin_table_tmp),dim = 0).mlu() if i != 0 else cos_sin_table_tmp
            else:
                ntk_alpha, rb = get_ntk_alpha(max_seq_len, max_position_embeddings, rotary_base, rotary_dim)
                for i in range(batch):
                    rotary_emb_alpha_cached[i] = ntk_alpha
                    cos_sin_table_tmp = create_cos_sin_table(rb, rotary_seq_len, rotary_stride, rotary_scaling, rotary_dim, interleaved, datatype).unsqueeze(0)
                    cos_sin_table = torch.cat((cos_sin_table,cos_sin_table_tmp),dim = 0).mlu() if i != 0 else cos_sin_table_tmp
            output_base = cos_sin_table
        output_base_1 = rotary_emb_alpha_cached
        diff1 = super().diff1(output_mlu, output_base)
        diff2 = super().diff2(output_mlu, output_base)
        assert diff1 < 0.003 and diff2 < 0.003
        if rotary_scaling_type == 2:
            # Also check the cached NTK alpha values written by the kernel.
            # (A duplicated copy of this assert was removed.)
            diff1 = super().diff1(ref_rotary_emb_alpha_cached, output_base_1)
            diff2 = super().diff2(ref_rotary_emb_alpha_cached, output_base_1)
            assert diff1 < 0.003 and diff2 < 0.003
if __name__ == '__main__':
    # Run this file's tests verbosely, emitting a JUnit XML report named
    # after the file.
    exit_code = pytest.main(['-vs', os.path.abspath(__file__), '--junitxml=' + os.path.splitext(os.path.basename(__file__))[0] + '_report.xml'])
    exit(exit_code)

View File

@@ -0,0 +1,46 @@
import os
from unit_test import UnitTest
import pytest
import torch
import torch_mlu
import torch.nn as nn
import btunittests
class TestEmbeddingKernel(UnitTest):
    # Compares the MLU embedding kernel against torch.nn.Embedding.
    @pytest.mark.parametrize("batch, seq, vocab_size, hidden_size, dtype", [
        (1, 1024, 10000, 128, torch.float16),
        (6, 128, 20000, 4096, torch.float16),
        (1, 10, 5000, 1024, torch.float32),
        (5, 10, 10000, 128, torch.float32)])
    def test_all(self, batch, seq, vocab_size, hidden_size, dtype):
        # Full-vocabulary lookup: a pure gather, so the match must be exact.
        torch.manual_seed(0)
        input = torch.randint(0, vocab_size, (batch, seq), dtype=torch.int32).mlu()
        embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size).to(dtype).mlu()
        output_base = embedding_layer(input)
        output_mlu = btunittests.embedding_test(input, embedding_layer.weight, 0, vocab_size)
        diff1 = super().diff1(output_mlu, output_base)
        diff2 = super().diff2(output_mlu, output_base)
        assert diff1 == 0 and diff2 == 0
    @pytest.mark.parametrize("batch, seq, vocab_size, hidden_size, dtype, vocab_offset, vocab_part", [
        (1, 1024, 10000, 128, torch.float16, 2000, 1000),
        (6, 128, 20000, 4096, torch.float16, 10000, 10000),
        (1, 10, 5000, 1024, torch.float32, 0, 2500),
        (5, 10, 10000, 128, torch.float32, 100, 512)])
    def test_part(self, batch, seq, vocab_size, hidden_size, dtype, vocab_offset, vocab_part):
        # Partial-vocabulary (tensor-parallel shard) lookup: ids outside
        # [vocab_offset, vocab_offset + vocab_part) are remapped to the
        # padding_idx row so both sides produce the same padding output.
        torch.manual_seed(0)
        input = torch.randint(0, vocab_size, (batch, seq), dtype=torch.int32).mlu()
        input_mask = torch.ones_like(input) * vocab_size
        input = torch.where(input < vocab_offset, input_mask, input)
        input = torch.where(input >= vocab_offset + vocab_part, input_mask, input)
        embedding_layer = nn.Embedding(num_embeddings=vocab_size + 1, embedding_dim=hidden_size, padding_idx=vocab_size).to(dtype).mlu()
        output_base = embedding_layer(input)
        # NOTE(review): torch.Tensor(tensor) takes a copy/view of the weight
        # before slicing the shard -- confirm this is intentional rather than
        # embedding_layer.weight[...] directly.
        part_weight = torch.Tensor(embedding_layer.weight)[vocab_offset : vocab_offset + vocab_part, :]
        output_mlu = btunittests.embedding_test(input, part_weight, vocab_offset, vocab_part)
        diff1 = super().diff1(output_mlu, output_base)
        diff2 = super().diff2(output_mlu, output_base)
        assert diff1 == 0 and diff2 == 0
if __name__ == '__main__':
    # Run this file's tests verbosely, emitting a JUnit XML report named
    # after the file.
    exit_code = pytest.main(['-vs', os.path.abspath(__file__), '--junitxml=' + os.path.splitext(os.path.basename(__file__))[0] + '_report.xml'])
    exit(exit_code)

View File

@@ -0,0 +1,92 @@
import sys
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
from unit_test import UnitTest
import pytest
import torch
import torch_mlu
import btunittests
import random
class TestSliceCuSeqlensKernel(UnitTest):
    # Splits a cumulative-sequence-lengths vector into parallel_num chunks,
    # each rebased so its first entry is zero, and compares against the MLU
    # slice_cu_seq_lens kernel.
    @pytest.mark.parametrize("batch_size, parallel_num", [
        (8,3),
        (16,16),
        (32,16),
        (64,16),
        (128,16),
        (256,16),
        (512,16),
        (1024,16),
        (2048,16),
        (4096,16),
    ])
    def test_all(self, batch_size, parallel_num):
        torch.manual_seed(0)
        random.seed(0)
        # The kernel contract requires batch_size >= parallel_num; skip others.
        if batch_size < parallel_num:
            return
        cu_seq_lens = torch.randint(low=0, high=1024*1024, size=(batch_size+1,), dtype=torch.int32).mlu()
        every = (batch_size + parallel_num - 1) // parallel_num
        repeat = batch_size // every
        remain = batch_size % every
        loop = repeat + (remain != 0)
        result = []
        for i in range(loop):
            # Each slice keeps one extra leading element (its base offset);
            # the last slice is shorter when batch_size % every != 0
            # (the and/or chain picks `remain` for the final partial chunk).
            elem_num = 1 + (i == loop - 1 and remain != 0 and remain or every)
            start_idx = i * every
            end_idx = start_idx + elem_num
            slice_tensor = cu_seq_lens[start_idx:end_idx]
            # Rebase so each chunk's cumulative lengths start at zero.
            adjusted_tensor = slice_tensor - slice_tensor[0]
            result.append(adjusted_tensor)
        sliced_cu_seq_lens = torch.cat(result)
        sliced_cu_seq_lens_mlu = torch.zeros(batch_size + loop, dtype=torch.int32).mlu()
        btunittests.slice_cu_seq_lens_test(cu_seq_lens, sliced_cu_seq_lens_mlu, batch_size, parallel_num)
        sliced_cu_seq_lens_diff1 = super().diff1(sliced_cu_seq_lens, sliced_cu_seq_lens_mlu)
        sliced_cu_seq_lens_diff2 = super().diff2(sliced_cu_seq_lens, sliced_cu_seq_lens_mlu)
        assert sliced_cu_seq_lens_diff1 == 0 and sliced_cu_seq_lens_diff2 == 0
class TestGenerateCuSeqlensKernel(UnitTest):
    # Builds the expected (start, end) pairs for splitting one sequence of
    # length seq_len across parallel_num workers and compares with the MLU
    # generate_cu_seq_lens kernel.
    @pytest.mark.parametrize("seq_len, parallel_num, is_causal_mask, is_kv_seq_len", [
        (2154, 4, False, False),
        (3412, 3, True, False),
        (996, 7, False, True),
        (872, 4, False, False),
        (634, 12, True, False),
        (486, 4, False, True),
        (125, 6, False, False),
    ])
    def test_all(self, seq_len, parallel_num, is_causal_mask, is_kv_seq_len):
        every = (seq_len + parallel_num - 1) // parallel_num
        repeat = seq_len // every
        remain = seq_len % every
        loop = repeat + (remain != 0)
        # Output layout: loop pairs of (start, end); starts stay zero here.
        generate_cu_seq_lens = torch.zeros(2 * loop, dtype=torch.int32).mlu()
        rep = seq_len // loop
        rem = seq_len % loop
        base = rep + (rem != 0)
        if is_causal_mask and is_kv_seq_len:
            # Causal KV lengths grow with the chunk index; the last chunk
            # sees the whole sequence.
            for i in range(loop):
                generate_cu_seq_lens[2 * i + 1] = seq_len if i == loop - 1 else (i + 1) * base
        elif not is_kv_seq_len:
            # Query lengths: uniform chunks, remainder folded into the last.
            for i in range(loop):
                generate_cu_seq_lens[2 * i + 1] = seq_len - (loop - 1) * base if i == loop - 1 else base
        elif not is_causal_mask and is_kv_seq_len:
            # Non-causal KV: every chunk attends to the full sequence.
            for i in range(loop):
                generate_cu_seq_lens[2 * i + 1] = seq_len
        generate_cu_seq_lens_mlu = torch.zeros(2 * loop, dtype=torch.int32).mlu()
        btunittests.generate_cu_seq_lens_test(generate_cu_seq_lens_mlu, seq_len, parallel_num, is_causal_mask, is_kv_seq_len)
        generate_cu_seq_lens_diff1 = super().diff1(generate_cu_seq_lens, generate_cu_seq_lens_mlu)
        generate_cu_seq_lens_diff2 = super().diff2(generate_cu_seq_lens, generate_cu_seq_lens_mlu)
        assert generate_cu_seq_lens_diff1 == 0 and generate_cu_seq_lens_diff2 == 0
if __name__ == '__main__':
    # Run this file's tests verbosely, emitting a JUnit XML report named
    # after the file.
    exit_code = pytest.main(['-vs', os.path.abspath(__file__), '--junitxml=' + os.path.splitext(os.path.basename(__file__))[0] + '_report.xml'])
    exit(exit_code)

View File

@@ -0,0 +1,82 @@
import sys
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
from unit_test import UnitTest
import pytest
import torch
import torch_mlu
import btunittests
import random
def round_func(tensor, mode=1):
    """Round a tensor to integers under one of three tie-breaking modes.

    Args:
        tensor: input tensor (cast to float32 before rounding).
        mode: 1 = round-half-to-even (banker's rounding),
              2 = round-half-up (ties at .5 go to the next integer),
              3 = plain torch.round.

    Returns:
        Float tensor of rounded values.

    Raises:
        ValueError: if mode is not 1, 2 or 3 (previously an unknown mode
        fell through and raised UnboundLocalError on the return).
    """
    tensor = tensor.float()
    fractional_part = tensor - tensor.floor()
    even_mask = (tensor.floor() % 2 == 0)
    if mode == 1:
        # Ties go to the even neighbour; everything else uses torch.round.
        rounded_tensor = torch.where(fractional_part == 0.5,
                                     torch.where(even_mask, tensor.floor(), tensor.ceil()),
                                     torch.round(tensor))
    elif mode == 2:
        # Ties always round up (x + 0.5 is exact when the fraction is 0.5).
        rounded_tensor = torch.where(fractional_part == 0.5, tensor + 0.5, torch.round(tensor))
    elif mode == 3:
        rounded_tensor = torch.round(tensor)
    else:
        raise ValueError("round_func: unsupported mode {!r}; expected 1, 2 or 3".format(mode))
    return rounded_tensor
def compute_with_tensor(src, batch, seq_len, head_num, head_size):
    # Reference per-head symmetric int8 quantisation.
    # src is viewed as [batch*seq_len*head_num*2, head_size] and every second
    # group of head_num rows is discarded -- i.e. only the first head_size
    # half of each 2*head_size source slot is quantised (assumes that source
    # layout; TODO confirm against the kernel).
    src = src.view(batch * seq_len * head_num * 2, head_size)
    mask = torch.ones(src.size(0), dtype=bool)
    for i in range(head_num, src.size(0), head_num * 2):
        mask[i:i+head_num] = False
    src = src[mask]
    # Per-row max-abs defines the dequantisation scale (max/127)...
    max_value_tensor, _ = torch.max(src.abs(), dim=1)
    scale = (max_value_tensor / 127.0).view(batch, seq_len, head_num)
    # ...and its reciprocal quantises the data, rounded half-to-even (mode 1).
    scale_ori_quant = 127 / max_value_tensor
    temp_tensor = src * scale_ori_quant.view(scale_ori_quant.size(0),1)
    dst = round_func(temp_tensor, 1).view(batch, seq_len, head_num, head_size)
    return scale,dst
class TestPerHeadQuantizeKernel(UnitTest):
    # Compares the MLU per-head int8 quantisation kernel with the
    # compute_with_tensor reference above.
    # NOTE(review): random.randint in the parametrize list executes at import
    # time, before random.seed(0) -- those two cases vary between runs.
    @pytest.mark.parametrize("batch, seq_len, head_num, head_size, src_dtype", [
        (17,1436,31,21, torch.float32),
        (17,1436,31,21, torch.float16),
        (random.randint(1,32),random.randint(16,2048),random.randint(1,32),random.randint(16,256),torch.float32),
        (random.randint(1,32),random.randint(16,2048),random.randint(1,32),random.randint(16,256),torch.float16),
        (2,3,5,5, torch.float32)
    ])
    def test_all(self, batch, seq_len, head_num, head_size, src_dtype):
        mlu_name = torch.mlu.get_device_name()
        # MLU3 devices lack bfloat16 support; fall back to half.
        if "MLU3" in mlu_name and src_dtype == torch.bfloat16:
            src_dtype = torch.half
            print("BFLOAT16 is not support on {}, use half instead".format(mlu_name))
        torch.manual_seed(0)
        random.seed(0)
        hidden = head_num * head_size
        # Source rows are 2*head_size wide; only the first half is quantised.
        seq_stride = 2 * hidden
        scale_dtype = torch.float32
        dst_dtype = torch.int8
        dst_bs_stride = seq_len * hidden
        dst_seq_stride = hidden
        dst_head_stride = head_size
        src_bs_stride = seq_len * seq_stride
        src_seq_stride = seq_stride
        src_head_stride = head_size
        src = torch.rand((batch, seq_len, head_num, 2 * head_size), dtype=src_dtype).mlu()
        scale, dst = compute_with_tensor(src, batch, seq_len, head_num, head_size)
        scale_mlu = torch.zeros((batch, seq_len, head_num), dtype=scale_dtype).mlu()
        dst_mlu = torch.zeros((batch, seq_len, head_num, head_size), dtype=dst_dtype).mlu()
        btunittests.per_head_quantize_test(dst_mlu, scale_mlu, src, batch, seq_len, head_num, head_size, dst_bs_stride,
                                           dst_seq_stride, dst_head_stride, src_bs_stride, src_seq_stride,
                                           src_head_stride)
        dst_diff1 = super().diff1(dst, dst_mlu)
        dst_diff2 = super().diff2(dst, dst_mlu)
        scale_diff1 = super().diff1(scale, scale_mlu)
        scale_diff2 = super().diff2(scale, scale_mlu)
        assert dst_diff1 < 0.003 and dst_diff2 < 0.003
        assert scale_diff1 < 0.003 and scale_diff2 < 0.003
if __name__ == '__main__':
    # Run this file's tests verbosely, emitting a JUnit XML report named
    # after the file.
    exit_code = pytest.main(['-vs', os.path.abspath(__file__), '--junitxml=' + os.path.splitext(os.path.basename(__file__))[0] + '_report.xml'])
    exit(exit_code)

View File

@@ -0,0 +1,220 @@
import sys
import os
from unit_test import UnitTest
import pytest
import torch
import torch_mlu
import btunittests
class ApplyRotary(torch.nn.Module):
    # CPU/torch reference implementation of rotary position embedding,
    # covering packed/padded layouts, interleaved rotation, discrete
    # per-token offsets, dynamic (per-batch) caches and 2D RoPE.
    def __init__(self):
        super().__init__()
    def rotate(self, x: torch.Tensor, interleaved: bool):
        # RoPE "rotate-half": split into halves (non-interleaved) or swap
        # even/odd lanes (interleaved), negating the second partner.
        if not interleaved:
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat((-x2, x1), dim=-1)
        else:
            y = torch.empty_like(x)
            x1, x2 = x[..., ::2], x[..., 1::2]
            y[..., ::2], y[..., 1::2] = -x2, x1
            return y
    def forward(self,
                output: torch.Tensor, # [total_seqlen, num_heads, head_size]
                input: torch.Tensor, # [total_seqlen, num_heads, head_size]
                sin_cache: torch.Tensor, # [rope_seqlen, rotary_dim] / [batch, rope_seqlen, rotary_dim]
                cos_cache: torch.Tensor, # same layout as sin_cache
                seq_offsets: torch.Tensor, # [batch] / [total_seqlen] / [2, total_seqlen]
                cu_seq_lens: torch.Tensor, # [batch + 1]
                interleaved: bool,
                discrete: bool,
                dynamic: bool,
                max_seqlen: int,
                no_offset:bool,
                rope_2d:bool):
        # Packed layout has no batch dim; sequences are delimited by
        # cu_seq_lens. Writes results into `output` in place.
        packed = input.dim() == 3
        rope_dim = sin_cache.shape[-1]
        batch_size = cu_seq_lens.shape[0] - 1 if packed else input.shape[0]
        if not rope_2d:
            for i in range(batch_size):
                input_i = input[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else input[i]
                ouput_i = output[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else output[i]
                # Only the first rope_dim channels are rotated.
                input_i = input_i[..., 0:rope_dim]
                ouput_i = ouput_i[..., 0:rope_dim]
                start_seq_idx = 0 if no_offset else seq_offsets[i]
                sin_cache_i = sin_cache[i] if dynamic else sin_cache
                cos_cache_i = cos_cache[i] if dynamic else cos_cache
                seq = input_i.shape[0]
                if discrete:
                    # Discrete mode: seq_offsets holds one cache index per
                    # token, so gather instead of sliding by a start offset.
                    start_seq_idx = 0
                    token_offset = seq_offsets[cu_seq_lens[i]:cu_seq_lens[i]+seq]
                    sin_cache_i = sin_cache_i[token_offset]
                    cos_cache_i = cos_cache_i[token_offset]
                sin_cache_i = sin_cache_i[start_seq_idx:seq+start_seq_idx]
                cos_cache_i = cos_cache_i[start_seq_idx:seq+start_seq_idx]
                rot = self.rotate(input_i, interleaved)
                ouput_i[:] = rot * sin_cache_i.unsqueeze(1) + input_i * cos_cache_i.unsqueeze(1)
            # Channels past rope_dim pass through untouched.
            output[..., rope_dim:] = input[..., rope_dim:]
        else:
            # 2D RoPE: channels [0, rope_dim) use seq_offsets[0] as token
            # indices, the remaining channels use seq_offsets[1].
            for i in range(batch_size):
                input_i_left = input[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else input[i]
                ouput_i_left = output[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else output[i]
                input_i_left = input_i_left[..., 0:rope_dim]
                ouput_i_left = ouput_i_left[..., 0:rope_dim]
                sin_cache_i = sin_cache[i] if dynamic else sin_cache
                cos_cache_i = cos_cache[i] if dynamic else cos_cache
                seq = input_i_left.shape[0]
                if discrete:
                    token_offset = seq_offsets[0][cu_seq_lens[i]:cu_seq_lens[i]+seq]
                    sin_cache_i = sin_cache_i[token_offset]
                    cos_cache_i = cos_cache_i[token_offset]
                sin_cache_i = sin_cache_i[:seq]
                cos_cache_i = cos_cache_i[:seq]
                rot = self.rotate(input_i_left, interleaved)
                ouput_i_left[:] = rot * sin_cache_i.unsqueeze(1) + input_i_left * cos_cache_i.unsqueeze(1)
                # Second plane: channels [rope_dim, end) with seq_offsets[1].
                input_i_right = input[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else input[i]
                ouput_i_right = output[cu_seq_lens[i] : cu_seq_lens[i + 1]] if packed else output[i]
                input_i_right = input_i_right[..., rope_dim:]
                ouput_i_right = ouput_i_right[..., rope_dim:]
                sin_cache_i = sin_cache[i] if dynamic else sin_cache
                cos_cache_i = cos_cache[i] if dynamic else cos_cache
                seq = input_i_right.shape[0]
                if discrete:
                    token_offset = seq_offsets[1][cu_seq_lens[i]:cu_seq_lens[i]+seq]
                    sin_cache_i = sin_cache_i[token_offset]
                    cos_cache_i = cos_cache_i[token_offset]
                sin_cache_i = sin_cache_i[:seq]
                cos_cache_i = cos_cache_i[:seq]
                rot = self.rotate(input_i_right, interleaved)
                ouput_i_right[:] = rot * sin_cache_i.unsqueeze(1) + input_i_right * cos_cache_i.unsqueeze(1)
class TestRotaryEmbeddingKernel(UnitTest):
    # Compares the MLU rotary-embedding kernel against the ApplyRotary
    # reference across layouts (packed/padded), interleaving, discrete
    # offsets, dynamic-NTK caches and 2D RoPE.
    @pytest.mark.parametrize(
        "batch, max_seq_len, head_num_q, head_num_kv, head_size, discrete, rotary_dim, interleaved, rope_2d, dynamic_ntk, packed, no_offset, data_type",
        [
            (4, 16, 8, 4, 128, False, 128, False, False, True, False, True, torch.float32),
            (4, 16, 8, 4, 128, False, 128, False, False, True, True, True, torch.float32),
            (4, 16, 8, 4, 128, True, 128, False, False, True, True, True, torch.float32),
            (4, 16, 8, 4, 128, True, 128, False, False, False, True, True, torch.float32),
            (4, 16, 8, 4, 128, False, 128, False, False, True, False, True, torch.float16),
            (4, 16, 8, 4, 128, False, 128, False, False, True, True, True, torch.float16),
            (4, 16, 8, 4, 128, True, 128, False, False, True, True, True, torch.float16),
            (4, 16, 8, 4, 128, True, 128, True, False, True, True, True, torch.float16),
            (4, 16, 8, 4, 128, True, 128, False, False, False, True, True, torch.float16),
            (4, 16, 8, 4, 128, True, 128, True, False, False, True, True, torch.float16),
            (4, 16, 8, 4, 128, True, 128, True, True, False, True, True, torch.float16),
            (4, 16, 8, 4, 128, False, 128, False, False, True, True, False, torch.float32),
            (4, 16, 8, 4, 128, True, 128, False, False, True, True, False, torch.float32),
            (4, 16, 8, 4, 128, True, 128, False, True, True, True, False, torch.float32),
        ])
    def testall(
            self,
            batch,
            max_seq_len,
            head_num_q,
            head_num_kv,
            head_size,
            discrete,
            rotary_dim,
            interleaved,
            rope_2d,
            dynamic_ntk,
            packed,
            no_offset,
            data_type):
        torch.manual_seed(0)
        # Normalise mutually-exclusive parameter combinations up front.
        if rope_2d:
            # 2D RoPE implies discrete per-token offsets, no interleaving
            # and no dynamic cache.
            interleaved = False
            discrete = True
            dynamic_ntk = False
            no_offset = False
        if head_num_q != 2 * head_num_kv:
            head_num_q = 2 * head_num_kv
        if rotary_dim %2 != 0:
            rotary_dim -= 1
        if no_offset:
            discrete = False
        if rope_2d:
            # 2D RoPE rotates half the head in each plane; head_size is
            # rounded down to a multiple of 4.
            head_size = int(head_size if (head_size % 4 == 0) else (head_size // 4) * 4)
            rotary_dim = int(head_size / 2);
        seq_lens = torch.randint(size=(batch, ), low=1, high=max_seq_len+1, dtype=torch.int32, device='mlu')
        total_seq_len = batch * max_seq_len
        if not packed:
            # Padded layout: every sequence occupies the full max_seq_len.
            seq_lens = torch.randint(size=(batch, ), low=max_seq_len, high=max_seq_len+1, dtype=torch.int32, device='mlu')
        max_context_len = int(seq_lens.max())
        cu_seq_lens = None
        if packed:
            cu_seq_lens = torch.cumsum(seq_lens, dim=-1, dtype=torch.int32)
            cu_seq_lens = torch.nn.functional.pad(cu_seq_lens, (1,0), "constant", 0).mlu()
            total_seq_len = int(cu_seq_lens[-1])
        head_num_qkv = head_num_q + head_num_kv
        head_num_qk = head_num_q + head_num_kv
        # One spare head slot checks the kernel honours the qkv stride.
        context_shape = (total_seq_len, head_num_qkv + 1, head_size) if packed else \
                        (batch, max_seq_len, head_num_qkv + 1, head_size)
        context = torch.randn(size=context_shape, dtype=data_type).mlu()
        qkv = context[..., 0 : head_num_qkv, :]
        qk = context[..., 0 : head_num_qk, :]
        seq_offsets = None
        if rope_2d:
            # Two per-token index planes for 2D RoPE.
            seq_offsets = torch.randint(size=(2, total_seq_len), low=1, high=max_seq_len + 1, dtype=torch.int32, device='mlu')
        else:
            if discrete and not no_offset:
                seq_offsets = torch.randint(size=(total_seq_len,), low=1, high=max_seq_len + 1, dtype=torch.int32, device='mlu')
            elif not no_offset:
                seq_offsets = torch.randint(size=(batch,), low=1, high=max_seq_len + 1, dtype=torch.int32, device='mlu')
        cos_cache = None
        sin_cache = None
        if dynamic_ntk:
            # Dynamic caches carry a leading batch dimension.
            cos_cache = torch.randn(size=(batch, max_seq_len * 2, rotary_dim), dtype=data_type).mlu()
            sin_cache = torch.randn(size=(batch, max_seq_len * 2, rotary_dim), dtype=data_type).mlu()
        else:
            cos_cache = torch.randn(size=(max_seq_len * 2, rotary_dim), dtype=data_type).mlu()
            sin_cache = torch.randn(size=(max_seq_len * 2, rotary_dim), dtype=data_type).mlu()
        base_output = torch.empty_like(qk)
        apply_rotary = ApplyRotary()
        apply_rotary(base_output, qkv, sin_cache, cos_cache, \
                     seq_offsets, cu_seq_lens, interleaved, discrete, dynamic_ntk, max_context_len, no_offset, rope_2d)
        output = btunittests.rotary_embedding_test(
            qkv,
            sin_cache,
            cos_cache,
            seq_offsets,
            cu_seq_lens,
            interleaved,
            discrete,
            dynamic_ntk,
            rope_2d,
            max_context_len,
            no_offset)
        diff1 = super().diff1(output, base_output)
        diff2 = super().diff2(output, base_output)
        assert diff1 < 0.003 and diff2 < 0.003
if __name__ == "__main__":
exit_code = pytest.main(
[
"-vs",
os.path.abspath(__file__),
"--junitxml="
+ os.path.splitext(os.path.basename(__file__))[0]
+ "_report.xml",
]
)
exit(exit_code)

View File

@@ -0,0 +1,25 @@
import sys
import os
build_lib_dir = os.path.dirname(os.path.abspath(__file__)) + "/build/lib"
sys.path.append(build_lib_dir)
import torch
class UnitTest:
    """Relative-error helpers shared by the MLU kernel unit tests."""

    def diff1(self, result: torch.Tensor, baseline: torch.Tensor):
        """Sum of absolute errors normalised by sum(|baseline|) (L1-style)."""
        res = result.flatten().float().to('cpu')
        ref = baseline.flatten().float().to('cpu')
        assert res.shape == ref.shape
        abs_err = torch.abs(ref - res)
        denom = torch.sum(torch.abs(ref)).item()
        # Avoid dividing by zero when the baseline is identically zero.
        guard = 0.0 if denom > 0 else 1e-9
        return (torch.sum(abs_err) / (denom + guard)).item()

    def diff2(self, result: torch.Tensor, baseline: torch.Tensor):
        """sqrt(sum(err^2) / sum(baseline^2)) (relative L2 error)."""
        res = result.flatten().float().to('cpu')
        ref = baseline.flatten().float().to('cpu')
        abs_err = torch.abs(ref - res)
        denom = torch.sum(ref**2).item()
        guard = 0.0 if denom > 0 else 1e-9
        return torch.sqrt(torch.sum(abs_err**2) / (denom + guard)).item()

View File

@@ -0,0 +1,17 @@
## BT_OPS测试脚本使用方式
```bash
# 测试所有测例
bash run_test.sh
```
```bash
# 测试单个测例
python3 test_测例名称.py
```
- 必须在Torch-MLU-Ops docker容器内运行。
- 测试脚本的命名规则为 `test_测例名称.py`
- 必须保证 Torch-MLU-Ops whl包正确安装。

View File

@@ -0,0 +1,147 @@
import torch
import torch_mlu
import unittest
import math
import torch_mlu_ops as ops
import torch.multiprocessing as mp
import sys
import os
work_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(work_dir))
from common_utils import *
def flash_attn_sq_mm(q, k, v, smooth, quant_weight, weight_scale,
                     bias, softmax_scale, is_causal, world_size = 1):
    # Reference pipeline: flash-attention -> per-token smooth quantise ->
    # smooth-quant matmul, computed two ways:
    #   output1: sharded across `world_size` head groups and summed (the
    #            tensor-parallel result; only shard 0 adds the bias),
    #   output2: computed single-shot on the unsharded tensors.
    q_list = q.chunk(world_size, dim=2)
    k_list = k.chunk(world_size, dim=2)
    v_list = v.chunk(world_size, dim=2)
    smooth_list = smooth.chunk(world_size, dim=0)
    quant_weight_list = quant_weight.chunk(world_size, dim=1)
    quant_weight_list = [w.contiguous() for w in quant_weight_list]
    output1 = torch.zeros(q.size(0) * q.size(1), q.size(2) * q.size(3), dtype=q.dtype).mlu()
    for i in range(world_size):
        attn_output = ops.flash_attention(q_list[i], k_list[i], v_list[i], None, None, None,
                                          None, None, q.size(1), k.size(1), softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1)
        quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth_list[i], None)
        output1 += ops.smooth_quant_matmul(quant_input, input_scale,
                                           quant_weight_list[i], weight_scale, q.dtype, bias if i == 0 else None)
    attn_output = ops.flash_attention(q, k, v, None, None, None,
                                      None, None, q.size(1), k.size(1), softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1)
    quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth, None)
    output2 = ops.smooth_quant_matmul(quant_input, input_scale, quant_weight, weight_scale, q.dtype, bias)
    return output1, output2
def tp_flash_attn_sq_mm(rank, *args):
    # Per-process worker spawned by mp.spawn: runs the fused
    # flash_attn_sq_mm_allreduce op on this rank's head shard (in both pad
    # and pack layouts) and checks it against the precomputed baseline.
    world_size, q, k, v, smooth, quant_weight, weight_scale, bias, softmax_scale, is_causal, base_output = args
    # Round-trip tensors through the CPU so each spawned process re-uploads
    # them in its own device context.
    q_cpu, k_cpu, v_cpu, smooth_cpu = q.cpu(), k.cpu(), v.cpu(), smooth.cpu()
    quant_weight_cpu, weight_scale_cpu, bias_cpu = quant_weight.cpu(), weight_scale.cpu(), bias.cpu()
    base_output_cpu = base_output.cpu()
    setup(rank, world_size)
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
    head_num_q = q.size(2)
    head_num_kv = k.size(2)
    seq = q.size(1)
    # Heads must shard evenly across ranks.
    assert head_num_q % world_size == 0
    assert head_num_kv % world_size == 0
    head_num_q_tp = head_num_q // world_size
    head_num_kv_tp = head_num_kv // world_size
    q = q_cpu.mlu()
    k = k_cpu.mlu()
    v = v_cpu.mlu()
    smooth = smooth_cpu.mlu()
    quant_weight = quant_weight_cpu.mlu()
    weight_scale = weight_scale_cpu.mlu()
    # Note: only tp0 add bias
    bias = bias_cpu.mlu() if rank == 0 else None
    q_list = q.chunk(world_size, dim=2)
    k_list = k.chunk(world_size, dim=2)
    v_list = v.chunk(world_size, dim=2)
    smooth_list = smooth.chunk(world_size, dim=0)
    quant_weight_list = quant_weight.chunk(world_size, dim=1)
    quant_weight_list = [w.contiguous() for w in quant_weight_list]
    # test pad mode
    output_pad = ops.flash_attn_sq_mm_allreduce(cncl_comm, q_list[rank], k_list[rank], v_list[rank],
                                                None, None, None, None,
                                                smooth_list[rank], quant_weight_list[rank], weight_scale, bias, seq,
                                                seq, softmax_scale, is_causal)
    assertTensorsEqual(output_pad.cpu().float(), base_output_cpu.float(), 0.006, use_MSE=True, use_RAE=True)
    # test pack mode
    cu_seq_lens_q = torch.tensor([0, seq], dtype=torch.int32).mlu()
    cu_seq_lens_k = torch.tensor([0, seq], dtype=torch.int32).mlu()
    q_pack = q_list[rank].flatten(0, 1)
    k_pack = k_list[rank].flatten(0, 1)
    v_pack = v_list[rank].flatten(0, 1)
    output_pack = ops.flash_attn_sq_mm_allreduce(cncl_comm, q_pack, k_pack, v_pack,
                                                 cu_seq_lens_q, cu_seq_lens_k, None, None,
                                                 smooth_list[rank], quant_weight_list[rank], weight_scale, bias, seq,
                                                 seq, softmax_scale, is_causal)
    assertTensorsEqual(output_pack.cpu().float(), base_output_cpu.float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
class TestFlashAttnSqMMAllreduce(BtTestCase):
def op_impl_base(self, *args):
return super().op_impl_base(*args)
@unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_flash_attn_sq_mm_split_seq due to ASan issues")
def test_flash_attn_sq_mm_split_seq(self):
batch, seq, head_num_q, head_num_kv, head_size, is_causal, block_seq = 16, 1024, 8, 1, 128, True, 4
dtype = torch.bfloat16 if torch_mlu.mlu.is_bf16_supported() else torch.half
hidden_size = head_num_q * head_size
softmax_scale = 1 / math.sqrt(head_size)
qkv = torch.randn(batch, seq, head_num_q + 2 * head_num_kv, head_size, dtype=dtype).mlu()
smooth = torch.zeros(hidden_size, dtype=torch.float).mlu() + 1.0
bias = torch.zeros(hidden_size, dtype=dtype).mlu()
weight = torch.randn(hidden_size, hidden_size, dtype=dtype).mlu()
quant_weight, weight_scale = QuantByRow(weight / smooth, 8)
q = qkv[:, :, :head_num_q, :]
k = qkv[:, :, head_num_q:head_num_q+head_num_kv, :]
v = qkv[:, :, head_num_q+head_num_kv:, :]
attn_output = ops.flash_attention(q, k, v, None, None, None,
None, None, q.size(1), k.size(1), softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1)
quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth, None)
output1 = ops.smooth_quant_matmul(quant_input, input_scale, quant_weight, weight_scale, q.dtype, bias)
output_list = []
block_size = seq // block_seq
for i in range(block_seq):
start = i * block_size
end = seq if i == block_seq - 1 else (i + 1) * block_size
attn_output = ops.flash_attention(q[:,start:end], k[:,:end], v[:,:end], None, None, None,
None, None, end - start, end, softmax_scale, is_causal).flatten(-2, -1).flatten(0, 1)
quant_input, input_scale = ops.per_token_smooth_quantize(attn_output, smooth, None)
out = ops.smooth_quant_matmul(quant_input, input_scale, quant_weight, weight_scale, q.dtype, bias)
output_list.append(out)
output2 = torch.cat(output_list, dim=0)
output2 = output2.reshape(block_seq, batch, block_size, hidden_size).transpose(0, 1).reshape(batch*seq, hidden_size)
assertTensorsEqual(output1.cpu().float(), output2.cpu().float(), 0.0045, use_MSE=True, use_RAE=True)
@unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
def test_flash_attn_sq_mm_allreduce(self):
world_size = min(torch_mlu.mlu.device_count(), 8)
batch, seq, head_num_q, head_num_kv, head_size, is_causal = 1, 8192, 64, 8, 128, True
dtype = torch.bfloat16 if torch_mlu.mlu.is_bf16_supported() else torch.half
for i in range(1):
# seq = random.randint(1, 32768)
# is_causal = bool(random.randint(0, 1))
print("=============test[{}]: seq = {}, causal = {}===============".format(i, seq, is_causal))
hidden_size = head_num_q * head_size
softmax_scale = 1 / math.sqrt(head_size)
qkv = torch.randn(batch, seq, head_num_q + 2 * head_num_kv, head_size, dtype=dtype).mlu()
smooth = torch.zeros(hidden_size, dtype=torch.float).mlu() + 1.0
bias = torch.randn(hidden_size, dtype=dtype).mlu()
weight = torch.randn(hidden_size, hidden_size, dtype=dtype).mlu()
quant_weight, weight_scale = QuantByRow(weight / smooth, 8)
q = qkv[:, :, :head_num_q, :]
k = qkv[:, :, head_num_q:head_num_q+head_num_kv, :]
v = qkv[:, :, head_num_q+head_num_kv:, :]
output1, output2 = flash_attn_sq_mm(q, k, v, smooth, quant_weight, weight_scale,
bias, softmax_scale, is_causal, world_size)
args = world_size, q, k, v, smooth, quant_weight, weight_scale, bias, softmax_scale, is_causal, output1
mp.spawn(tp_flash_attn_sq_mm, args, nprocs=world_size, join=True)
    def test_inductor(self):
        # Delegate to the shared inductor-compatibility check in the base class.
        return super().test_inductor()
# Script entry point: run this file's test case and propagate the unittest
# result as the process exit status (0 on success, non-zero on failure).
if __name__ == '__main__':
    raise SystemExit(run_unittest(TestFlashAttnSqMMAllreduce))

View File

@@ -0,0 +1,109 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from itertools import product
import torch.multiprocessing as mp
import sys
import os
work_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(work_dir))
from common_utils import *
def mm_split_k(rank, *args):
    # Per-process worker for ops.matmul_allreduce: each rank multiplies its
    # K-slice of the inputs and the partial products are summed across ranks.
    # Rank 0 compares the result against a full-precision matmul reference.
    torch.manual_seed(0)  # same seed on every rank -> identical random tensors
    mat_m, mat_n, mat_k, has_bias, has_res, dtype, world_size, block_m = args
    assert mat_k % world_size == 0, f"mat_k{mat_k} must be divisible by tp{world_size}"
    block_k = mat_k // world_size
    start_k = rank * block_k  # this rank owns K-slice [start_k, end_k)
    end_k = start_k + block_k
    setup(rank, world_size)  # process-group setup helper from common_utils
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
    alpha = 0.625
    beta = 1.0 if has_res else 0.
    input = torch.randn((mat_m, mat_k), dtype=dtype, device='mlu')
    residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu') if has_res else None
    weight = torch.randn((mat_n, mat_k), dtype=dtype, device="mlu")
    bias = torch.randn(mat_n, dtype=dtype, device="mlu") if has_bias else None
    # Reference: pt_output = alpha * (input @ weight^T + bias) + beta * residual
    pt_output = torch.matmul(input, weight.permute(1, 0))
    if has_bias:
        pt_output = pt_output + bias
    pt_output *= alpha
    if has_res:
        pt_output += beta * residual
    if has_bias:
        # Pre-scale the bias by alpha before handing it to the op; presumably the
        # op applies alpha to the matmul term only — TODO confirm against op docs.
        bias *= alpha
    input = input[..., start_k:end_k].contiguous()
    weight = weight[..., start_k:end_k].contiguous()
    # Bias/residual must be contributed exactly once across ranks, so only rank 0
    # passes them; all other ranks pass None.
    bias = bias if (bias is not None and rank == 0) else None
    residual = residual if (residual is not None and rank == 0) else None
    output = ops.matmul_allreduce(cncl_comm, input, weight, bias, residual, alpha, beta, block_m)
    if rank == 0:
        # NOTE(review): assumes common_utils exports a module-level
        # assertTensorsEqual (imported via *); only a method is visible here.
        assertTensorsEqual(output.cpu().float(), pt_output.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
class TestMatMulAllReduceOp(BtTestCase):
    """Multi-process tests for ops.matmul_allreduce: a matmul split along K
    across devices with partial results combined via an all-reduce."""

    def op_impl_base(self, *args):
        # Delegate to the shared base-class implementation.
        return super().op_impl_base(*args)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_matmul_allreduce due to ASan issues")
    def test_matmul_allreduce(self):
        # Sweep bias/residual/dtype combinations on a fixed problem size,
        # spawning one worker process per device.
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
            # BUGFIX: previously fell through and still called mp.spawn, which
            # fails with 0 devices and tests nothing with 1.
            return
        mat_m_list = [32]
        mat_n_list = [256]
        mat_k_list = [1680]
        has_res_list = [False, True]
        has_bias_list = [False, True]
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        block_m = 4
        args = product(mat_m_list, mat_n_list, mat_k_list, has_bias_list, has_res_list, dtype_list)
        for mat_m, mat_n, mat_k, has_bias, has_res, dtype in args:
            print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp={}, block_m={}, testing...".format(
                mat_m, mat_n, mat_k, has_bias, has_res, dtype, device_n, block_m), flush=True)
            param = [mat_m, mat_n, mat_k, has_bias, has_res, dtype, device_n, block_m]
            mp.spawn(mm_split_k, param, nprocs=device_n, join=True)

    @unittest.skip("not test")
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_matmul_allreduce_random due to ASan issues")
    def test_matmul_allreduce_random(self):
        # Randomized sweep over shapes, dtypes and tensor-parallel sizes
        # (currently disabled via @unittest.skip).
        import random
        random.seed(0)
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
            # BUGFIX: early return; also avoids random.randint(1, 0) below when
            # no device is present.
            return
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for i in range(10):
            tp_num = random.randint(1, device_n)
            mat_m = random.randint(1, 4096)
            mat_n = random.choice([512, 1024, 2048, 4096])
            # Smallest K >= 1024 that is a multiple of tp_num; step keeps the
            # divisibility invariant asserted below.
            k_start = (1024 // tp_num) * tp_num
            mat_k = random.randrange(k_start, 10240, tp_num)
            has_res = random.choice([False, True])
            has_bias = random.choice([False, True])
            dtype = random.choice(dtype_list)
            block_m = random.randint(1, 10)
            assert mat_k % tp_num == 0, f"mat_k{mat_k} must be divisible by tp_num{tp_num}"
            print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format(
                mat_m, mat_n, mat_k, has_bias, has_res, dtype, tp_num, block_m), flush=True)
            param = [mat_m, mat_n, mat_k, has_bias, has_res, dtype, tp_num, block_m]
            mp.spawn(mm_split_k, param, nprocs=tp_num, join=True)

    def test_inductor(self):
        # Delegate to the shared inductor-compatibility check.
        return super().test_inductor()
# Script entry point: run this file's test case and propagate the unittest
# result as the process exit status (0 on success, non-zero on failure).
if __name__ == '__main__':
    raise SystemExit(run_unittest(TestMatMulAllReduceOp))

View File

@@ -0,0 +1,305 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from typing import Union, List, Tuple
from torch.nn.parameter import Parameter
from torch.nn import functional as F
import torch.multiprocessing as mp
import math
import sys
import os
work_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(work_dir))
from common_utils import *
def moe_split_inner_size(rank, *args):
    # Per-process worker for floating-point fused_moe with tensor parallelism:
    # weight1's output dim and weight2's input dim are sliced to this rank's
    # [start_k, end_k) chunk of inner_size.  The op's built-in all-reduce path
    # (cncl_comm) is compared against fused_moe + an explicit all_reduce.
    torch.manual_seed(0)  # same seed on every rank -> identical random tensors
    batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, world_size = args
    assert inner_size % world_size == 0, f"inner_size{inner_size} must be divisible by tp_num{world_size}"
    block_k = inner_size // world_size
    start_k = rank * block_k
    end_k = start_k + block_k
    setup(rank, world_size)  # process-group setup helper from common_utils
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
    scale_s = 0.01 # avoid the occurrence of inf
    hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=dtype)
    router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
    residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=dtype)
    # Gated MoE doubles weight1's output dim: inner_size * (1 + gated).
    weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=dtype) * scale_s
    bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=dtype)
    weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=dtype) * scale_s
    bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=dtype)
    # Slice weight1 along inner_size; the extra view dim handles the gated layout.
    weight1 = weight1.view(expert_num, -1, inner_size, hidden_size)
    weight1 = weight1[:, :, start_k:end_k]
    weight1 = weight1.reshape(expert_num, -1, hidden_size).contiguous()
    weight2 = weight2[..., start_k:end_k].contiguous()
    if bias1 is not None:
        # bias1 follows the same inner_size split as weight1.
        bias1 = bias1.view(expert_num, -1, inner_size)
        bias1 = bias1[..., start_k:end_k]
        bias1 = bias1.reshape(expert_num, -1).contiguous()
    # Residual must be added exactly once across ranks, so only rank 0 passes it.
    residual = residual if (residual is not None and rank == 0) else None
    # Reference: fused_moe without comm handle + explicit all-reduce.
    param = [hidden_states, router_logit, weight1, weight2, bias1, bias2, residual,
             None, None, None, None, topk, renormalize, gated, act_mode]
    output = ops.fused_moe(*param)
    all_reduce(output, ReduceOp.SUM, group=pg)
    # Same computation using the op's internal all-reduce (cncl_comm, block_n).
    param = [hidden_states, router_logit, weight1, weight2, bias1, bias2, residual,
             None, None, None, None, topk, renormalize, gated, act_mode, 0, block_n, cncl_comm]
    output1 = ops.fused_moe(*param)
    new_inner_size = weight2.shape[-1]
    # Optional 4-D weight2 layout: experts grouped in blocks of block_e such that
    # block_e * new_inner_size == 4096 and expert_num is a multiple of block_e.
    block_e = 4096 // new_inner_size
    if block_e * new_inner_size == 4096 and expert_num % block_e == 0:
        param = [hidden_states, router_logit, weight1, weight2.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(),
                 bias1, bias2, residual,
                 None, None, None, None, topk, renormalize, gated, act_mode, 0, block_n, cncl_comm]
        output2 = ops.fused_moe(*param)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), output1.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
        if block_e * new_inner_size == 4096 and expert_num % block_e == 0:
            assertTensorsEqual(output1.cpu().float(), output2.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
def sq_moe_split_inner_size(rank, *args):
    # Per-process worker for smooth-quant (int8) fused_moe with tensor
    # parallelism: expert weights are smoothed, row-quantized to int8, then
    # sliced along inner_size to this rank's [start_k, end_k) chunk.  The op's
    # built-in all-reduce output is compared against fused_moe + explicit
    # all_reduce, and optionally against a 4-D weight2 layout.
    torch.manual_seed(0)  # same seed on every rank -> identical random tensors
    batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, world_size = args
    assert inner_size % world_size == 0, f"inner_size{inner_size} must be divisible by tp_num{world_size}"
    block_k = inner_size // world_size
    start_k = rank * block_k
    end_k = start_k + block_k
    setup(rank, world_size)  # process-group setup helper from common_utils
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
    scale_s = 0.1 # avoid the occurrence of inf
    eps = 0.1 # Avoid the occurrence of nan
    residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=dtype)
    bias1, bias2 = None, None
    hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=dtype)
    # Gated MoE doubles weight1's output dim: inner_size * (1 + gated).
    weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=dtype)
    weight2 = torch.normal(0, 0.01, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=dtype)
    router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32)
    # Smoothing vectors are strictly positive (abs + eps) because the weights
    # are divided by them below.
    input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps
    act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps
    weight1_shape, weight2_shape = weight1.shape, weight2.shape
    weight1 = weight1 / input_smooth.unsqueeze(1)
    weight2 = weight2 / act_smooth.unsqueeze(1)
    # Row-wise int8 quantization of both expert weight matrices.
    quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
    quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
    quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
    w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
    # Slice weight1 along inner_size; the extra view dim handles the gated layout.
    quant_w1 = quant_w1.view(expert_num, -1, inner_size, hidden_size)
    quant_w1 = quant_w1[:, :, start_k:end_k]
    quant_w1 = quant_w1.reshape(expert_num, -1, hidden_size).contiguous()
    quant_w2 = quant_w2[..., start_k:end_k].contiguous()
    if bias1 is not None:
        bias1 = bias1.view(expert_num, -1, inner_size)
        bias1 = bias1[..., start_k:end_k]
        bias1 = bias1.reshape(expert_num, -1).contiguous()
    if w1_scale is not None:
        # weight1's per-row scales follow the same inner_size split.
        w1_scale = w1_scale.view(expert_num, -1, inner_size)
        w1_scale = w1_scale[..., start_k:end_k]
        w1_scale = w1_scale.reshape(expert_num, -1).contiguous()
    if act_smooth is not None:
        # The activation smoothing vector is indexed by inner_size, so slice it too.
        act_smooth = act_smooth[..., start_k:end_k].contiguous()
    # Residual must be added exactly once across ranks, so only rank 0 passes it.
    residual = residual if (residual is not None and rank == 0) else None
    # Reference: fused_moe without comm handle + explicit all-reduce.
    param = [hidden_states, router_logit, quant_w1, quant_w2, bias1, bias2, residual,
             input_smooth, act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode]
    output = ops.fused_moe(*param)
    all_reduce(output, ReduceOp.SUM, group=pg)
    # Same computation using the op's internal all-reduce (cncl_comm, block_n).
    param = [hidden_states, router_logit, quant_w1, quant_w2, bias1, bias2, residual,
             input_smooth, act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0,
             block_n, cncl_comm]
    output1 = ops.fused_moe(*param)
    new_inner_size = quant_w2.shape[-1]
    # Optional 4-D weight2 layout: experts grouped in blocks of block_e such that
    # block_e * new_inner_size == 4096 and expert_num is a multiple of block_e.
    block_e = 4096 // new_inner_size
    if block_e * new_inner_size == 4096 and expert_num % block_e == 0:
        param = [hidden_states, router_logit, quant_w1,
                 quant_w2.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(),
                 bias1, bias2, residual, input_smooth, act_smooth, w1_scale, w2_scale, topk,
                 renormalize, gated, act_mode, 0, block_n, cncl_comm]
        output2 = ops.fused_moe(*param)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), output1.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
        if block_e * new_inner_size == 4096 and expert_num % block_e == 0:
            assertTensorsEqual(output1.cpu().float(), output2.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
class TestFusedMOEAllReduceOp(BtTestCase):
    """Multi-process tests for ops.fused_moe with the expert weights split along
    inner_size across devices and combined via an all-reduce."""

    def op_impl_base(self, *args):
        # Delegate to the shared base-class implementation.
        return super().op_impl_base(*args)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_single_fused_moe_allreduce due to ASan issues")
    def test_single_fused_moe_allreduce(self):
        # One fixed fp16 configuration, spawned across all available devices.
        print("test_single_fused_moe_allreduce")
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
            # BUGFIX: previously fell through and still called mp.spawn, which
            # fails with 0 devices and tests nothing with 1.
            return
        block_n = 2
        batch, seq, hidden_size, inner_size = 1, 1024, 8192, 8192
        expert_num, topk, gated, renormalize, act_mode, dtype = 8, 2, True, True, 'silu', torch.float16
        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
            topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
            tp_num: {device_n}, block_n: {block_n}, testing...", flush=True)
        param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, device_n]
        mp.spawn(moe_split_inner_size, param, nprocs=device_n, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_fused_moe_allreduce due to ASan issues")
    def test_random_fused_moe_allreduce(self):
        # Ten distinct random configurations (deduplicated via a set of tuples).
        print("test_random_fused_moe_allreduce")
        import random
        random.seed(0)
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
            # BUGFIX: early return; also avoids random.randint(1, device_n)
            # failing when device_n == 0.
            return
        act_mode = 'gelu'
        case_list = set()
        dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
        while(len(case_list) < 10):
            block_n = random.randint(-5, 5)
            tp_num = random.randint(1, device_n)
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            hidden_size = random.randrange(512, 1024, 2)
            if block_n != 0:
                if hidden_size // abs(block_n) < 256:
                    continue
                # Round hidden_size to a multiple of block_n.
                hidden_size = (hidden_size // block_n) * block_n
            else:
                hidden_size = 1024 * random.randint(1, 10)
            # Smallest inner_size that is a multiple of tp_num; step preserves it.
            k_start = (512 // tp_num) * tp_num
            inner_size = random.randrange(k_start, 1024, tp_num * 2)
            expert_num = random.randint(1, 32)
            topk = random.randint(1,expert_num)
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            dtype = random.choice(dtype_list)
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, tp_num, block_n)
            if case in case_list:
                continue
            case_list.add(case)
            print(f"random bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
                topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
                tp_num: {tp_num}, block_n: {block_n}, testing...", flush=True)
            param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, tp_num]
            mp.spawn(moe_split_inner_size, param, nprocs=tp_num, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_sq_fused_moe_allreduce due to ASan issues")
    def test_sq_fused_moe_allreduce(self):
        # One fixed smooth-quant (int8) configuration across all devices.
        print("test_sq_fused_moe_allreduce")
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
            # BUGFIX: early return instead of spawning with < 2 devices.
            return
        block_n = 2
        batch, seq, hidden_size, inner_size = 5, 9, 8192, 8192
        expert_num, topk, gated, renormalize, act_mode, dtype = 20, 16, False, True, 'gelu', torch.float16
        print(f"sq bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
            topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
            tp_num: {device_n}, block_n: {block_n}, testing...", flush=True)
        param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, device_n]
        mp.spawn(sq_moe_split_inner_size, param, nprocs=device_n, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_sq_fused_moe_allreduce due to ASan issues")
    def test_random_sq_fused_moe_allreduce(self):
        # Ten distinct random smooth-quant configurations.
        print("test_random_sq_fused_moe_allreduce")
        import random
        random.seed(0)
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
            # BUGFIX: early return; also avoids random.randint(1, device_n)
            # failing when device_n == 0.
            return
        act_mode = 'gelu'
        case_list = set()
        while(len(case_list) < 10):
            block_n = random.randint(-5, 5)
            tp_num = random.randint(1, device_n)
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            hidden_size = random.randrange(256, 512, 2)
            if block_n != 0:
                if hidden_size // abs(block_n) < 256:
                    continue
                hidden_size = (hidden_size // block_n) * block_n
            else:
                hidden_size = 1024 * random.randint(1, 10)
            k_start = (512 // tp_num) * tp_num
            inner_size = random.randrange(k_start, 1024, tp_num * 2)
            expert_num = random.randint(1, 32)
            topk = random.randint(1, expert_num)
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            # Only fp16 is exercised for the smooth-quant path.
            dtype = random.choice([ torch.float16])
            if torch_mlu.mlu.get_device_name() == 'MLU370':
                dtype = torch.float16
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, tp_num, block_n)
            if case in case_list:
                continue
            case_list.add(case)
            print(f"random sq bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
                topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
                tp_num: {tp_num}, block_n: {block_n}, testing...", flush=True)
            param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, tp_num]
            mp.spawn(sq_moe_split_inner_size, param, nprocs=tp_num, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_sq_fused_moe_with_4D_w2_allreduce due to ASan issues")
    def test_random_sq_fused_moe_with_4D_w2_allreduce(self):
        # Random smooth-quant cases sized so the worker's 4-D weight2 layout
        # branch (block_e * inner_size_per_tp == 4096) is always taken.
        print("test_random_sq_fused_moe_with_4D_w2_allreduce")
        import random
        random.seed(0)
        device_n = min(torch_mlu.mlu.device_count(), 8)
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
            # BUGFIX: early return; also avoids random.randint(1, device_n)
            # failing when device_n == 0.
            return
        act_mode = 'gelu'
        case_list = set()
        while (len(case_list) < 10):
            block_n = random.randint(-5, 5)
            tp_num = random.randint(1, device_n)
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            hidden_size = random.randrange(512, 4096, 2)
            if block_n != 0:
                if hidden_size // abs(block_n) < 256:
                    continue
                hidden_size = (hidden_size // block_n) * block_n
            else:
                hidden_size = 1024 * random.randint(1, 10)
            inner_size_per_tp = random.choice([256, 512])
            inner_size = inner_size_per_tp * tp_num
            # expert_num is a multiple of 4096 // inner_size_per_tp so the 4-D
            # layout condition in the worker holds.
            expert_num_base = random.randint(1, 4)
            expert_num_factor = 4096 // inner_size_per_tp
            expert_num = expert_num_base * expert_num_factor
            topk = random.randint(1, expert_num)
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            dtype = random.choice([ torch.float16])
            if torch_mlu.mlu.get_device_name() == 'MLU370':
                dtype = torch.float16
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, tp_num, block_n)
            if case in case_list:
                continue
            case_list.add(case)
            print(f"random sq bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
                topk: {topk}, gated: {gated}, renormalize: {renormalize}, dtype: {dtype}, act_mode: {act_mode}, \
                tp_num: {tp_num}, block_n: {block_n}, testing...", flush=True)
            param = [batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, act_mode, dtype, block_n, tp_num]
            mp.spawn(sq_moe_split_inner_size, param, nprocs=tp_num, join=True)

    def test_inductor(self):
        # Delegate to the shared inductor-compatibility check.
        return super().test_inductor()
# Script entry point: run this file's test case and propagate the unittest
# result as the process exit status (0 on success, non-zero on failure).
if __name__ == '__main__':
    raise SystemExit(run_unittest(TestFusedMOEAllReduceOp))

View File

@@ -0,0 +1,157 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from torch.nn.parameter import Parameter
from torch.nn import functional as F
import torch.multiprocessing as mp
import sys
import os
work_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(work_dir))
from common_utils import *
def compute_weight_only_scale(weight, quant_bit):
    """Symmetric per-row (axis-1) weight-only quantization.

    Returns (weight_int, weight_scale_recip): the int8-truncated weights and the
    per-row dequantization factor (row_max / int_max) as float32.  Note: values
    are truncated toward zero by the int8 cast, not rounded.
    """
    qmax = float((1 << (quant_bit - 1)) - 1)
    # Per-row absolute maximum, kept as a column for broadcasting.
    row_max = weight.abs().max(dim=1, keepdim=True).values
    weight_int = (weight * (qmax / row_max)).type(torch.int8)
    weight_scale_recip = (row_max / qmax).type(torch.float).squeeze()
    return weight_int, weight_scale_recip
def quant_mm_split_k(rank, *args):
    # Worker for quant_matmul_allreduce in weight-only int8 mode: activations
    # stay in floating point, the weight is per-channel int8-quantized, and each
    # rank multiplies its K-slice before the partial results are all-reduced.
    torch.manual_seed(0)  # same seed on every rank -> identical random tensors
    M, N, K, has_bias, has_res, dtype, block_m, world_size = args
    assert K % world_size == 0
    block_k = K // world_size
    start_k = rank * block_k  # this rank owns K-slice [start_k, end_k)
    end_k = start_k + block_k
    setup(rank, world_size)  # process-group setup helper from common_utils
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
    a = torch.randn(M, K, device="mlu", dtype=dtype)
    b = torch.randn(N, K, device="mlu", dtype=dtype)
    b_int, gemm_output_scale = compute_weight_only_scale(b, 8)
    # c/bias are created but only passed through when enabled; the only caller
    # in this file uses has_bias=False, has_res=False.
    c = torch.randn(M, N, device="mlu", dtype=dtype) if has_res else None
    bias = torch.randn(N, device="mlu", dtype=dtype) if has_bias else None
    # Reference output via the QuantMatmul helper (presumably from common_utils;
    # not visible here — TODO confirm its dequant semantics).
    torch_quant_matmul = QuantMatmul(b_int, None, None, None, None, gemm_output_scale, dtype)
    pt_output = torch_quant_matmul(a).detach()
    a = a[..., start_k:end_k].contiguous()
    b_int = b_int[..., start_k:end_k].contiguous()
    # Bias/residual are contributed exactly once across ranks (rank 0 only).
    bias = bias if (bias is not None and rank == 0) else None
    c = c if (c is not None and rank == 0) else None
    # Positional argument list for the raw op; the mode strings select the
    # weight-only path with per-channel weight quantization and no activation
    # quantization.
    param = [cncl_comm, a, None, None, b_int, None, None, bias, c, None,
             None, gemm_output_scale, None, "half", "weight_only", "quantize_none",
             "quantize_per_channel", 8, 1.0, 1.0, False, True, block_m]
    # beta = beta if (c is not None and rank == 0) else 0.
    output = ops._ops.quant_matmul_allreduce(*param)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), pt_output.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
def sq_mm_split_k(rank, *args):
    # Worker for ops.smooth_quant_matmul_allreduce: int8 activations and weights
    # with per-token (a_scale) and per-channel (b_scale) dequant scales; each
    # rank multiplies its K-slice and the partial sums are all-reduced.
    torch.manual_seed(0)  # same seed on every rank -> identical random tensors
    M, N, K, has_bias, has_res, dtype, world_size, block_m = args
    assert K % world_size == 0
    block_k = K // world_size
    start_k = rank * block_k  # this rank owns K-slice [start_k, end_k)
    end_k = start_k + block_k
    setup(rank, world_size)  # process-group setup helper from common_utils
    pg = get_default_group()
    cncl_comm = pg._get_backend(torch.device("mlu")).get_cncl_comm(rank)
    a = torch.randint(-10, 10, (M, K), dtype=torch.int8).mlu()
    b = torch.randint(-10, 10, (N, K), dtype=torch.int8).mlu()
    c = torch.randn(M, N, device="mlu", dtype=dtype) if has_res else None
    bias = torch.randn(N, device="mlu", dtype=dtype) if has_bias else None
    a_scale = torch.randn(M, device="mlu", dtype=torch.float)
    b_scale = torch.randn(N, device="mlu", dtype=torch.float)
    # Reference output via the QuantMatmul helper (presumably from common_utils;
    # not visible here — TODO confirm its dequant semantics).
    torch_quant_matmul = QuantMatmul(b, bias, c, a_scale, b_scale, None, dtype)
    pt_output = torch_quant_matmul(a).detach()
    a = a[..., start_k:end_k].contiguous()
    b = b[..., start_k:end_k].contiguous()
    # Bias/residual are contributed exactly once across ranks (rank 0 only).
    bias = bias if (bias is not None and rank == 0) else None
    c = c if (c is not None and rank == 0) else None
    # beta = beta if (c is not None and rank == 0) else 0.
    param = [cncl_comm, a, a_scale, b, b_scale, dtype, bias, c, 1.0, 1.0, block_m]
    output = ops.smooth_quant_matmul_allreduce(*param)
    if rank == 0:
        assertTensorsEqual(output.cpu().float(), pt_output.cpu().float(), 0.006, use_MSE=True, use_RAE=True)
    cleanup()
class TestGptQuantMatmulOp(BtTestCase):
    """Multi-process tests for the quantized matmul + all-reduce ops
    (weight-only int8 and smooth-quant int8 paths)."""

    def op_impl_base(self, *args):
        # Delegate to the shared base-class implementation.
        return super().op_impl_base(*args)

    # weight only, no bias, no residual, no activation
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_single_weight_only_matmul_allreduce due to ASan issues")
    def test_single_weight_only_matmul_allreduce(self):
        # One fixed weight-only configuration across all devices.
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
            # BUGFIX: previously fell through; with 0 devices `K % device_n`
            # below raises ZeroDivisionError, and with 1 device the spawn tests
            # no inter-device communication.
            return
        block_m = 8
        M, K, N = 32, 256, 128
        has_bias = False
        has_res = False
        dtype = torch.half
        print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format(
            M, N, K, has_bias, has_res, dtype, device_n, block_m), flush=True)
        assert K % device_n == 0, f"K{K} must be divisible by tp_num{device_n}"
        param = [M, N, K, has_bias, has_res, dtype, block_m, device_n]
        mp.spawn(quant_mm_split_k, param, nprocs=device_n, join=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_single_sq_matmul_allreduce due to ASan issues")
    def test_single_sq_matmul_allreduce(self):
        # One fixed smooth-quant configuration across all devices.
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
            # BUGFIX: early return (see test_single_weight_only_matmul_allreduce).
            return
        block_m = 1
        M, K, N = 2598, 1024, 1024
        has_bias = False
        has_res = False
        dtype = torch.half
        print("m={}, n={}, k={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format(
            M, N, K, has_bias, has_res, dtype, device_n, block_m), flush=True)
        assert K % device_n == 0, f"K{K} must be divisible by tp_num{device_n}"
        param = [M, N, K, has_bias, has_res, dtype, device_n, block_m]
        mp.spawn(sq_mm_split_k, param, nprocs=device_n, join=True)

    @unittest.skip("not test")
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_random_sq_mm_allreduce due to ASan issues")
    def test_random_sq_mm_allreduce(self):
        # Randomized smooth-quant sweep (currently disabled via @unittest.skip).
        import random
        random.seed(0)
        device_n = torch_mlu.mlu.device_count()
        if device_n < 2:
            print(f"device count is {device_n}, can not test communication between devices.", flush=True)
            # BUGFIX: early return; also avoids random.randint(1, device_n)
            # failing when device_n == 0.
            return
        dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
        for i in range(10):
            tp_num = random.randint(1, device_n)
            M = random.randint(1, 4096)
            N = random.choice([512, 1024, 2048, 4096])
            # Smallest K >= 1024 that is a multiple of tp_num; step preserves it.
            k_start = (1024 // tp_num) * tp_num
            K = random.randrange(k_start, 2048, tp_num)
            has_res = random.choice([False, True])
            has_bias = random.choice([False, True])
            dtype = random.choice(dtype_list)
            block_m = random.randint(1, 10)
            print("M={}, N={}, K={}, has_bias={}, has_res={}, dtype={}, tp_num={}, block_m={}, testing...".format(
                M, N, K, has_bias, has_res, dtype, tp_num, block_m), flush=True)
            param = [M, N, K, has_bias, has_res, dtype, tp_num, block_m]
            mp.spawn(sq_mm_split_k, param, nprocs=tp_num, join=True)

    def test_inductor(self):
        # Delegate to the shared inductor-compatibility check.
        return super().test_inductor()
# Script entry point: run this file's test case and propagate the unittest
# result as the process exit status (0 on success, non-zero on failure).
if __name__ == '__main__':
    raise SystemExit(run_unittest(TestGptQuantMatmulOp))

View File

@@ -0,0 +1,734 @@
import sys
sys_args = sys.argv
sys.argv = [sys_args.pop(0)] # prevent unittest printing help info
import os
import torch
import torch_mlu
from torch.testing._internal.common_utils import TestCase
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from typing import List, Tuple, Optional
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group as get_default_group, all_reduce, ReduceOp
import torch.testing._internal.optests as optests
import random
import argparse
from abc import abstractmethod, ABC
import unittest
import torch_mlu_ops as tmo
import os
# Map activation-mode names (the `act_mode` strings passed to fused ops)
# to their torch.nn.functional implementations.
act_mode_dict = {
    name: getattr(torch.nn.functional, name)
    for name in ("relu", "gelu", "silu")
}
class BtTestCase(TestCase, ABC):
    # Shared base class for torch_mlu_ops unit tests.  Subclasses must implement
    # op_impl_base and test_inductor; common comparison helpers live here.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Disable the TF32 override for cnMatmul so comparisons are deterministic.
        os.environ['TORCH_ALLOW_TF32_CNMATMUL_OVERRIDE'] = '0'

    @abstractmethod
    def op_impl_base(self, *args):
        # Hook for the op-under-test implementation; provided by each subclass.
        pass

    @abstractmethod
    def test_inductor(self):
        # Hook for the inductor-compatibility test; provided by each subclass.
        pass

    def base_opcheck(self, interface_overload, args):
        # Run torch.testing's opcheck on a custom-op overload and require every
        # selected sub-check to report "SUCCESS".
        target_check = ["test_schema", "test_autograd_registration"]
        # test_faketensor only exists in newer torch releases.
        # NOTE(review): this is a lexicographic string compare, so e.g.
        # '2.10.0' < '2.3.0'; consider a proper version parse for robustness.
        if torch.__version__ >= '2.3.0':
            target_check.append("test_faketensor")
        target_status = {key: "SUCCESS" for key in target_check}
        result = optests.opcheck(interface_overload, args, test_utils=target_check)
        self.assertEqual(result, target_status,)

    def assertException(self, error_msg, func, *args, **kwinputs):
        # Assert that func(*args, **kwinputs) raises; when error_msg is
        # non-empty, the raised exception's message must match it exactly.
        # NOTE(review): the assertTrue(False) on the no-raise path raises
        # AssertionError, which this same except block catches; with a non-empty
        # error_msg a silent non-raise is then reported as a message mismatch
        # instead of "did not raise" — confirm this is intended.
        try:
            func(*args, **kwinputs)
            self.assertTrue(False)
        except Exception as e:
            if error_msg:
                self.assertTrue(error_msg == str(e))
            else:
                self.assertTrue(True)

    def assertTensorsEqual(self,
                           a,
                           b,
                           prec=None,
                           message='',
                           allow_inf=False,
                           use_MSE=False,
                           use_RAE=False,
                           use_RMA=False):
        '''unittest.TestCase'''
        # Compare two equal-shaped tensors under one of several error metrics:
        #   use_MSE: normalized RMS error, sqrt(sum|a-b|^2 / sum|a|^2) <= prec
        #   use_RAE: relative absolute error, sum|a-b| / sum|a| <= prec
        #   use_RMA: relative difference of mean magnitudes <= prec
        #   default: max elementwise |a-b| <= prec
        # NaNs must appear at identical positions in a and b; with allow_inf,
        # infs must match positions too and are excluded from the metric.
        if a.dtype == torch.bool:
            a = a.float()
        if b.dtype == torch.bool:
            b = b.float()
        epsilon = 1.0 / 16384  # guards the near-zero denominator in the MSE path
        self.assertEqual(a.size(), b.size(), message)
        assert (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)), "a and b are need be torch tensor."
        if a.numel() > 0:
            # check that NaNs are in the same locations
            nan_mask = a != a
            self.assertTrue(torch.equal(nan_mask, b != b), message)
            diff = a - b
            diff[nan_mask] = 0
            # Clone before masking so the caller's tensors are not mutated.
            a = a.clone()
            b = b.clone()
            a[nan_mask] = 0
            b[nan_mask] = 0
            # inf check if allow_inf=True
            if allow_inf:
                inf_mask = (a == float("inf")) | (a == float("-inf"))
                self.assertTrue(torch.equal(inf_mask,
                                (b == float("inf")) | (b == float("-inf"))),
                                message)
                diff[inf_mask] = 0
                a[inf_mask] = 0
                b[inf_mask] = 0
            # TODO: implement abs on CharTensor
            if diff.is_signed() and 'CharTensor' not in diff.type():
                diff = diff.abs()
            if use_MSE:
                diff = diff.abs().pow(2).sum()
                a_pow_sum = a.pow(2).sum()
                # Treat tiny absolute error as exactly zero.
                if diff <= (2 * epsilon) * (2 * epsilon):
                    diff = 0.0
                if a_pow_sum <= epsilon:
                    a_pow_sum = a_pow_sum + epsilon
                diff = torch.div(diff, (a_pow_sum * 1.0))
                self.assertLessEqual(diff.sqrt(), prec, message)
            elif use_RAE:
                diff = diff.abs().sum()
                a_sum = a.abs().sum()
                if a_sum == 0:
                    # Reference is all zeros: require exact equality.
                    self.assertEqual(a, b, message)
                else:
                    diff = torch.div(diff, a_sum)
                    self.assertLessEqual(diff, prec, message)
            elif use_RMA:
                a_mean = a.abs().mean()
                b_mean = b.abs().mean()
                if a_mean == 0:
                    self.assertEqual(a, b, message)
                else:
                    diff = torch.div((a_mean - b_mean).abs(), a_mean)
                    self.assertLessEqual(diff, prec, message)
            else:
                max_err = diff.max()
                self.assertLessEqual(max_err, prec, message)
def run_unittest(case) -> int:
    """Run the given TestCase class, honoring an optional `-k NAME...` filter
    taken from the module-level `sys_args` (the CLI args captured at import).

    Returns 0 when every test passed, 1 otherwise (suitable for exit codes).
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-k', nargs='+', type=str, default="", help='specify case to run')
    cli = arg_parser.parse_args(sys_args)
    loader = unittest.TestLoader()
    if cli.k != "":
        # Run only the named test methods of `case`.
        suite = loader.loadTestsFromNames(cli.k, case)
    else:
        suite = loader.loadTestsFromTestCase(case)
    outcome = unittest.TextTestRunner().run(suite)
    return not outcome.wasSuccessful()
class TMOTimer:
    """Context manager measuring average MLU hardware time per iteration.

    Usage:
        with TMOTimer(repeat=N):
            ...  # enqueue the work N times
    After the block exits, `average_hardware_time` holds the total hardware
    time between the start/end events divided by `repeat`.
    """
    def __init__(self, repeat: int = 1):
        # Number of iterations executed inside the `with` block.
        self.repeat = repeat
    def __enter__(self):
        # Events must be created with enable_timing=True to support timing queries.
        self.notify_start = torch.mlu.Event(enable_timing=True)
        self.notify_end = torch.mlu.Event(enable_timing=True)
        self.notify_start.record()
    def __exit__(self, exc_type, exc_value, traceback):
        self.notify_end.record()
        self.notify_end.synchronize()  # wait until the end event has executed
        total_hardware_time = self.notify_start.hardware_time(self.notify_end)
        self.average_hardware_time = total_hardware_time / self.repeat
def QuantByRow(input: torch.Tensor, quant_bit: int, group_num: int=1):
    """Symmetric row-wise (optionally grouped) int quantization.

    Each row (or each of `group_num` equal column groups within a row) is scaled
    so its absolute maximum maps to the signed `quant_bit` integer maximum, then
    rounded and clamped to int8 storage.

    Returns (quantized tensor with the input's original shape,
             float32 scales with the trailing size-1 dim squeezed away).
    """
    orig_shape = input.shape
    # Normalize to a 2-D view: rows x features.
    if input.dim() > 2:
        input = input.view(-1, orig_shape[-1])
    if input.dim() == 1:
        input = input.unsqueeze(0)
    assert input.dim() == 2, "input must be 2-D tensor."
    assert quant_bit == 4 or quant_bit == 8, "quant_bit must be 4 or 8."
    assert group_num >= 1, "group_num >= 1."
    q_hi = float(2 ** (quant_bit - 1) - 1)
    q_lo = -float(2 ** (quant_bit - 1))
    cols_per_group = input.size(-1) // group_num
    grouped = input if group_num == 1 else input.view(input.size(0), group_num, cols_per_group)
    # Per-row/per-group absolute peak determines the quantization step.
    abs_peak = grouped.abs().max(dim=-1, keepdim=True)[0]
    scale = abs_peak.to(torch.float) / q_hi
    quantized = (grouped / scale).round().clamp(q_lo, q_hi).to(torch.int8).view(input.size())
    return quantized.view(orig_shape), scale.squeeze(-1)
def QuantByTensor(input: torch.Tensor, quant_bit: int):
    """Symmetric whole-tensor int quantization.

    One scalar scale maps the tensor's global absolute maximum to the signed
    `quant_bit` integer maximum; values are rounded, clamped, stored as int8.

    Returns (int8 tensor, the scalar multiplier int_max / max|input|).
    """
    q_hi = float(2 ** (quant_bit - 1) - 1)
    q_lo = -float(2 ** (quant_bit - 1))
    peak = input.abs().max()
    input_scale = q_hi / peak
    input_int = (input * input_scale).round().clamp(q_lo, q_hi).to(torch.int8)
    return input_int, input_scale
def PairlyPackInt8(input):
    """Pack adjacent int8 pairs into single bytes, halving the last dim.

    Element 2i contributes the low nibble and element 2i+1 the high nibble
    of the packed byte.
    """
    assert input.dtype == torch.int8, "dtype of input must be int8."
    assert input.dim() == 2 or input.dim() == 3, "input must be 2-D or 3-D tensor."
    assert input.size(-1) % 2 == 0, "size(-1) of input must be even."
    out_shape = list(input.shape)
    out_shape[-1] = out_shape[-1] // 2
    flat = input.flatten()
    # Work in uint8 so shifts and adds behave as plain modular bit ops.
    even = flat[0::2].to(torch.uint8)
    odd = flat[1::2].to(torch.uint8)
    packed = (odd << 4) + (even & 0x0F)
    return packed.to(torch.int8).reshape(out_shape)
def UnpackInt4(input):
    """Expand each int8 byte into two sign-extended int4 values.

    Output element 2i holds the low nibble of byte i and element 2i+1 the
    high nibble, as a flat int8 tensor of twice the input length.
    """
    assert input.dtype == torch.int8, "dtype of input must be int8."
    flat = input.flatten()
    result = torch.zeros(flat.size(0) * 2, dtype=torch.int8, device=input.device)
    # Arithmetic right shift preserves the sign of the high nibble.
    result[1::2] = flat >> 4
    # Shift left then right to sign-extend the low nibble.
    result[0::2] = (flat << 4) >> 4
    return result
def smooth_quant_matmul(a, a_scale, b, b_scale, out_dtype, bias=None):
    """Dequantizing reference matmul: (a @ b.T) scaled per row/column [+ bias].

    ``a`` is (m, k) with one scale per row; ``b`` is (n, k) with per-group,
    per-column scales.  When a's K dim is exactly twice b's, ``b`` is treated
    as packed int4 and unpacked first.  Accumulation is float, cast to
    ``out_dtype`` before the optional bias add.
    """
    assert a.dim() == 2 and b.dim() == 2, "a.dim() == 2 and b.dim() == 2"
    assert a_scale.dim() == 1, "a_scale.dim() == 1"
    assert a.size(0) == a_scale.size(0), "a.size(0) == a_scale.size(0)"
    assert b.size(0) == b_scale.size(-1), "b.size(0) == b_scale.size(-1)"
    m, a_k = a.size(0), a.size(1)
    n, b_k = b.size(0), b.size(1)
    if b_scale.dim() == 1:
        b_scale = b_scale.unsqueeze(0)
    groups = b_scale.size(0)
    # Reorder to (group, rows, group_k) so every group slice is contiguous.
    a = a.view(m, groups, -1).transpose(0, 1).contiguous()
    if a_k == b_k * 2:
        # b stores two int4 values per byte; expand to one int8 per value.
        b = UnpackInt4(b)
    b = b.view(n, groups, -1).transpose(0, 1).contiguous()
    out = torch.zeros(m, n, dtype=torch.float, device=a.device)
    for g in range(groups):
        # Outer product of the row scales with this group's column scales.
        scale_mn = torch.matmul(a_scale.unsqueeze(1), b_scale[g].unsqueeze(0))
        out += torch.einsum('mk,nk->mn', a[g].to(torch.float), b[g].to(torch.float)) * scale_mn
    out = out.to(out_dtype)
    if bias is not None:
        out += bias
    return out
def smooth_quant_matmul_w4w8_mixed(a, a_scale, b, b_scale, out_dtype, bias=None, quant_flag=None):
    # Reference matmul where each K-group of `b` is stored either as packed
    # int4 (quant_flag[i] == 4, half the bytes) or as plain int8.  Every
    # group is first expanded to one int8 per value, the groups are
    # re-concatenated, and the result is accumulated group by group through
    # smooth_quant_matmul.
    m = a.shape[0]
    k = a.shape[1]
    quant_group = b_scale.shape[0]
    group_wise = k // quant_group
    n = b_scale.shape[1]
    b = b.view(n, -1)
    # (group, m, group_k) layout: one contiguous activation slice per group.
    a = a.view(m, quant_group, -1).transpose(0, 1).contiguous()
    new_b = []
    start = 0
    end = 0
    # Walk b's packed byte columns; int4 groups occupy half the bytes of an
    # int8 group, so `end` advances by group_wise//2 or group_wise.
    for i in range(quant_group):
        if quant_flag[i] == 4:
            end += group_wise // 2
            new_b.append(UnpackInt4(b[:, start:end]).view(n, -1))
        else:
            end += group_wise
            new_b.append((b[:, start:end]))
        start = end
    new_b = torch.cat(new_b, 1)
    b = new_b.view(n, quant_group, -1).transpose(0, 1).contiguous()
    out = torch.zeros(m, n, dtype=torch.float, device=a.device)
    # Each group contributes a partial matmul scaled by its own b_scale row.
    for i in range(quant_group):
        out += smooth_quant_matmul(a[i], a_scale, b[i], b_scale[i], out_dtype)
    out = out.to(out_dtype)
    if bias is not None:
        out += bias
    return out
def weight_only_quant_matmul(a, b, scale, bias=None):
    """Reference weight-only-quant matmul: dequantize `b` with `scale`, then a @ b.T.

    `scale` is per-output-channel (1-D) or per-(channel, group) (2-D); in the
    2-D case each group scale is repeated across its slice of the K dim.
    """
    assert a.dim() == 2 and b.dim() == 2, "a.dim() == 2 and b.dim() == 2"
    assert scale.dim() == 1 or scale.dim() == 2, "scale.dim() == 1 or scale.dim() == 2"
    assert b.size(0) == scale.size(0), "b.size(0) == b_scale.size(0)"
    assert a.size(1) == b.size(1), "a.size(1) == b.size(1)"
    if scale.dim() == 2:
        # Broadcast every group scale over its group_size slice of K.
        per_group = b.size(1) // scale.size(1)
        full_scale = scale.unsqueeze(-1).repeat(1, 1, per_group).reshape(b.shape)
    else:
        # One scale per output channel, broadcast along K.
        full_scale = scale.unsqueeze(-1)
    dequant_b = b * full_scale
    out = torch.einsum('mk,nk->mn', a.to(torch.float), dequant_b.to(torch.float)).to(a.dtype)
    if bias is not None:
        out += bias
    return out
def single_query_cached_kv_attn(q, k_cache, v_cache, block_tables, context_lens, k_cache_quant_scale,
    v_cache_quant_scale, alibi_slopes, window_size_left, window_size_right, softmax_scale, return_lse):
    # Reference (CPU-math, MLU-tensor) paged-KV attention for decode:
    # q is (bs, seq_q, num_heads, head_size); k_cache/v_cache are paged as
    # (num_blocks, num_kv_heads, block_size, head_size[_v]) and gathered per
    # batch entry via block_tables.  Optional int8 cache dequant scales,
    # ALiBi biases, causal masking and a sliding window are applied before
    # softmax.  Returns the attention output, plus per-head logsumexp when
    # return_lse is True.
    q = q.float()
    k_cache = k_cache.float()
    v_cache = v_cache.float()
    def masked_attention(query, key, value, alibi_slope, context_len, window_size_left, window_size_right, qk_scale) -> torch.Tensor:
        # (num_heads, seq_q, seq_k)
        qk = torch.einsum('qhd,hkd->hqk', query, key)
        qk = qk * qk_scale
        if alibi_slope is not None:
            # ALiBi: add a per-head linear bias over key positions.
            alibi_dist = torch.arange(0, context_len, dtype=torch.float32).mlu()
            alibi = alibi_slope[:, None] * alibi_dist
            qk = qk + alibi[:, None, :]
        _, seq_q, seq_k = qk.size()
        if seq_q > 1: #causal mask
            # Left part (older keys) stays visible; upper triangle of the
            # trailing seq_q x seq_q square is masked with -inf.
            ml = torch.zeros((seq_q, seq_k - seq_q), dtype=qk.dtype).mlu()
            ones = torch.ones((seq_q, seq_q), dtype=qk.dtype).mlu() * -torch.inf
            mr = torch.triu(ones, diagonal=1)
            mask = torch.cat((ml, mr), dim=-1)
            qk = qk + mask
        if window_size_left != -1 or window_size_right != -1:
            # Sliding window: only keys within [left, right) of each query's
            # absolute position stay unmasked; -1 means unbounded on a side.
            mask_w = torch.full((seq_q, seq_k), -torch.inf, dtype=torch.float, device="mlu")
            for qi in range(seq_q):
                left = max(seq_k - seq_q + qi - window_size_left, 0) if window_size_left != -1 else 0
                right = min(max(seq_k - seq_q + qi + window_size_right + 1, 0), seq_k) if window_size_right != -1 else seq_k
                mask_w[qi, left:right] = 0
            qk += mask_w
        attention = torch.softmax(qk, dim = -1, dtype=qk.dtype)
        qkv = torch.einsum('hqk,hkd->qhd', attention, value)
        return qkv, qk
    if k_cache_quant_scale is not None and v_cache_quant_scale is not None:
        # Reshape the dequant scales so they broadcast against the 4-D cache.
        if k_cache_quant_scale.dim() == 2: # per_channel: [kv_head_num, head_size]
            k_cache_quant_scale = k_cache_quant_scale.reshape(1, k_cache_quant_scale.shape[0], 1, k_cache_quant_scale.shape[1])
            v_cache_quant_scale = v_cache_quant_scale.reshape(1, v_cache_quant_scale.shape[0], 1, v_cache_quant_scale.shape[1])
        elif k_cache_quant_scale.dim() == 3: # per_token: [num_blocks, k_head_num, block_size]
            k_cache_quant_scale = k_cache_quant_scale.reshape(*k_cache_quant_scale.shape, 1)
            v_cache_quant_scale = v_cache_quant_scale.reshape(*v_cache_quant_scale.shape, 1)
        k_cache *= k_cache_quant_scale
        v_cache *= v_cache_quant_scale
    bs, seq_q, num_heads, head_size = q.size()
    head_size_v = v_cache.size(-1)
    num_blocks, num_kv_heads, block_size, _ = k_cache.size()
    output = torch.zeros((bs, seq_q, num_heads, head_size_v), dtype=torch.float16)
    lse = torch.zeros((bs, num_heads, seq_q), dtype=torch.float)
    assert (num_heads % num_kv_heads == 0)
    # GQA: each KV head serves num_heads // num_kv_heads query heads.
    head_repeats = num_heads // num_kv_heads
    for bs_id in range(bs):
        q_bs = q[bs_id]
        context_len = int(context_lens[bs_id])
        if context_len == 0:
            # Empty context: zero output, lse of -inf (logsumexp of nothing).
            output[bs_id] = torch.zeros((seq_q, num_heads, head_size_v), device = q.device, dtype=output.dtype)
            lse[bs_id] = lse[bs_id].fill_(-float('inf'))
        else :
            # Gather this sequence's cache blocks, repeat KV heads for GQA,
            # and flatten blocks into a (num_heads, context_len, head) view.
            block_table = block_tables[bs_id]
            table_end = (context_len + block_size - 1) // block_size
            block_ids = block_table[0 : table_end]
            keys, values = k_cache[block_ids], v_cache[block_ids]
            keys = torch.repeat_interleave(keys, head_repeats, dim=1)
            keys = keys.transpose(1, 0).contiguous().view(num_heads, -1, head_size)
            keys = keys[:, 0:context_len, :]
            values = torch.repeat_interleave(values, head_repeats, dim=1)
            values = values.transpose(1, 0).contiguous().view(num_heads, -1, head_size_v)
            values = values[:, 0:context_len, :]
            alibi_slope = alibi_slopes[bs_id] if alibi_slopes is not None else None
            qkv, qk= masked_attention(q_bs, keys, values, alibi_slope, context_len, window_size_left, window_size_right, softmax_scale)
            output[bs_id] = qkv
            lse[bs_id] = torch.logsumexp(qk, dim = -1)
    return (output, lse) if return_lse else output
def update_out_and_lse_torch(out, lse, block_out, block_lse, seq_offsets, cu_seqs, block_cu_seqs):
    # only pad
    # Reference for merging a partial attention result (block_out/block_lse)
    # into a running result (out/lse) using the standard logsumexp-merge:
    #   new_out = out - sigmoid(block_lse - lse) * (out - block_out)
    #   new_lse = lse - logsigmoid(lse - block_lse)
    # 3-D `out` means the packed (varlen) layout indexed via cu_seqs /
    # block_cu_seqs; otherwise the padded layout is updated per batch at
    # seq_offsets.
    is_pack = out.dim() == 3
    new_out, new_lse = out.clone(), lse.clone()
    batch, max_seq_len, block_seq_len = lse.shape[0], lse.shape[-1], block_lse.shape[-1]
    # transpose+unsqueeze produce views: writes into new_lse_bsh land in new_lse.
    lse_bsh = lse.transpose(-2, -1).unsqueeze(dim=-1)
    new_lse_bsh = new_lse.transpose(-2, -1).unsqueeze(dim=-1)
    block_lse_bsh = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
    if not is_pack:
        for i in range(batch):
            # Padded layout: the block covers block_seq_len tokens starting
            # at this batch entry's sequence offset.
            out_seq_offset = 0 if seq_offsets is None else seq_offsets[i]
            out_i = out[i, out_seq_offset : out_seq_offset + block_seq_len]
            lse_i = lse_bsh[i, out_seq_offset : out_seq_offset + block_seq_len]
            block_out_i = block_out[i, :]
            block_lse_i = block_lse_bsh[i, :]
            new_out[i, out_seq_offset : out_seq_offset + block_seq_len] = out_i - F.sigmoid(block_lse_i - lse_i) * (out_i - block_out_i)
            new_lse_bsh[i, out_seq_offset : out_seq_offset + block_seq_len] = (lse_i - F.logsigmoid(lse_i - block_lse_i))
    else:
        for i in range(batch):
            # Packed layout: per-batch ranges come from the cumulative
            # sequence tables; the destination additionally shifts by
            # seq_offsets[i] inside the batch's packed segment.
            block_i_begin = block_cu_seqs[i]
            block_i_end = block_cu_seqs[i + 1]
            block_i_lens = block_i_end - block_i_begin
            out_i_begin = cu_seqs[i]
            out_seq_offset = seq_offsets[i]
            block_out_i = block_out[block_i_begin : block_i_end]
            block_lse_i = block_lse_bsh[i, 0 : block_i_lens]
            out_i = out[out_i_begin + out_seq_offset: out_i_begin + out_seq_offset + block_i_lens]
            lse_i = lse_bsh[i, out_seq_offset: out_seq_offset + block_i_lens]
            new_out_i = out_i - F.sigmoid(block_lse_i - lse_i) * (out_i - block_out_i)
            new_lse_i = (lse_i - F.logsigmoid(lse_i - block_lse_i))
            new_out[out_i_begin + out_seq_offset: out_i_begin + out_seq_offset + block_i_lens] = new_out_i
            new_lse_bsh[i, out_seq_offset: out_seq_offset + block_i_lens] = new_lse_i
    return (new_out, new_lse_bsh.squeeze(dim=-1).transpose(-2, -1))
class QuantMatmul(torch.nn.Module):
    """Reference linear layer with optional quant scales, residual and activation.

    Computes act(((input @ weight.T + bias) * provided scales) * alpha
    + residual * beta); each scale factor is applied only when given.
    """
    def __init__(self, weight, bias, residual, input_scale, weight_scale, gemm_output_scale, dtype,
                 alpha:float = 1.0, beta:float = 1.0, act_mode:str = 'none') -> None:
        super().__init__()
        self.dtype = dtype
        self.weight = Parameter(weight.type(dtype))
        self.input_scale = input_scale
        self.weight_scale = weight_scale
        self.gemm_output_scale = gemm_output_scale
        self.bias = Parameter(bias) if bias is not None else None
        self.residual = Parameter(residual) if residual is not None else None
        self.alpha = alpha
        self.beta = beta
        # 'none' disables activation; otherwise look up the callable by name.
        self.act = None if act_mode == 'none' else act_mode_dict[act_mode]
    # d = (a * b + bias) * alpha + c * beta
    # output = (input * weight + bias) * alpha + residual * beta
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        result = F.linear(input.type(self.dtype), self.weight, self.bias)
        if self.input_scale is not None:
            # Broadcast the per-row input scale across the output channels.
            row_scale = self.input_scale.expand(self.weight_scale.shape[0], -1).transpose(0, 1)
            result = torch.mul(result, row_scale)
        if self.weight_scale is not None:
            result = torch.mul(result, self.weight_scale)
        if self.gemm_output_scale is not None:
            result = torch.mul(result, self.gemm_output_scale)
        result = torch.mul(result, self.alpha)
        if self.residual is not None:
            result = torch.add(result, torch.mul(self.residual, self.beta))
        if self.act is not None:
            result = self.act(result)
        return result
# for multiprocessing
def assertTensorsEqual( a,
                        b,
                        prec=None,
                        message='',
                        allow_inf=False,
                        use_MSE=False,
                        use_RAE=False,
                        use_RMA=False):
    # Standalone tensor comparison usable outside a TestCase instance.
    # Exactly one error metric is used, chosen by priority
    # use_MSE > use_RAE > use_RMA > max absolute error, and compared
    # against `prec`.  NaNs must match position-wise in a and b and are
    # zeroed out of the metric; the same applies to +/-inf when
    # allow_inf=True.
    tc = TestCase()
    if a.dtype == torch.bool:
        a = a.float()
    if b.dtype == torch.bool:
        b = b.float()
    epsilon = 1.0 / 16384
    tc.assertEqual(a.size(), b.size(), message)
    assert (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)), "a and b are need be torch tensor."
    if a.numel() > 0:
        # check that NaNs are in the same locations
        nan_mask = a != a
        tc.assertTrue(torch.equal(nan_mask, b != b), message)
        diff = a - b
        diff[nan_mask] = 0
        # Clone before zeroing NaN slots so the caller's tensors are untouched.
        a = a.clone()
        b = b.clone()
        a[nan_mask] = 0
        b[nan_mask] = 0
        # inf check if allow_inf=True
        if allow_inf:
            inf_mask = (a == float("inf")) | (a == float("-inf"))
            tc.assertTrue(torch.equal(inf_mask,
                                      (b == float("inf")) | (b == float("-inf"))),
                          message)
            diff[inf_mask] = 0
            a[inf_mask] = 0
            b[inf_mask] = 0
        # TODO: implement abs on CharTensor
        if diff.is_signed() and 'CharTensor' not in diff.type():
            diff = diff.abs()
        if use_MSE:
            # Relative RMS error: sqrt(sum(diff^2) / sum(a^2)), with small
            # denominators/numerators nudged by epsilon to avoid blow-ups.
            diff = diff.abs().pow(2).sum()
            a_pow_sum = a.pow(2).sum()
            if diff <= (2 * epsilon) * (2 * epsilon):
                diff = 0.0
            if a_pow_sum <= epsilon:
                a_pow_sum = a_pow_sum + epsilon
            diff = torch.div(diff, (a_pow_sum * 1.0))
            tc.assertLessEqual(diff.sqrt(), prec, message)
        elif use_RAE:
            # Relative absolute error: sum(|diff|) / sum(|a|); falls back to
            # exact equality when the reference is all zeros.
            diff = diff.abs().sum()
            a_sum = a.abs().sum()
            if a_sum == 0:
                tc.assertEqual(a, b, message)
            else:
                diff = torch.div(diff, a_sum)
                tc.assertLessEqual(diff, prec, message)
        elif use_RMA:
            # Relative mean-absolute error between the two tensors' means.
            a_mean = a.abs().mean()
            b_mean = b.abs().mean()
            if a_mean == 0:
                tc.assertEqual(a, b, message)
            else:
                diff = torch.div((a_mean - b_mean).abs(), a_mean)
                tc.assertLessEqual(diff, prec, message)
        else:
            # Default: maximum elementwise absolute difference.
            max_err = diff.max()
            tc.assertLessEqual(max_err, prec, message)
def setup(rank, world_size, backend='cncl'):
    # Initialize the distributed process group (rendezvous over localhost on
    # a fixed port) and bind this rank to its MLU device.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '3458'
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    torch_mlu.mlu.set_device(rank)
def cleanup():
    # Wait for every rank before tearing down the process group.
    dist.barrier()
    dist.destroy_process_group()
def generate_token_count(num_expert,
                         total_token_count):
    """Randomly split `total_token_count` tokens across `num_expert` experts.

    Returns (cumulative counts of length num_expert + 1, per-expert counts).
    The cumulative sums are clipped so they never exceed the total and the
    last entry is forced to exactly `total_token_count`.
    """
    counts = torch.randint(low=1, high=1024, size=(num_expert, ), dtype=torch.int32).to(dtype=torch.float32)
    # Rescale the random draw so the (float) counts sum to the target total.
    draw_total = torch.sum(counts, dim=-1) * 1.0
    counts *= total_token_count / draw_total.item()
    counts = counts.to(dtype=torch.int32)
    cumulative = torch.zeros(num_expert + 1, dtype=torch.int32)
    ceiling = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count
    cumulative[1:] = torch.cumsum(counts, dim=-1)
    # Truncation above can drift off the target; clamp and pin the last entry.
    cumulative = torch.minimum(cumulative, ceiling)
    cumulative[-1] = total_token_count
    return cumulative, cumulative[1:] - cumulative[:-1]
def generate_cache_func_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype,
                             quant_mode=False, offline=False, invalid_batch_size=0):
    # Build the full argument list for the KV-cache ops under test:
    # random key/value tensors (packed varlen or padded layout), a K/V cache
    # pair (optionally int8-quantized with per-channel "offline" or
    # per-token scales), shuffled batch-id / offset tensors, and a random
    # slot_mapping for the paged-cache variants.  Shapes follow the layouts
    # used elsewhere in this file; a few entries are deliberately made
    # invalid (-1) to exercise skip paths.
    q_heads = 1
    # Context holds q, k and v heads side by side; k/v are sliced out below.
    total_heads = q_heads + num_heads * 2
    max_bs = batch_size + 1
    context_lens = torch.randint(size=(batch_size, ), low=1,
                                 high=cache_memory_len // 2,
                                 dtype=torch.int32, device='mlu')
    max_context_len = context_lens.max().item()
    max_seq_offset = max_context_len // 3 + 1
    # Random permutation of batch ids, so cache rows are visited out of order.
    cache_bs_id = random.sample([*range(0, batch_size)], batch_size)
    cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
    if invalid_batch_size > 0:
        # Mark a random subset of batch entries as invalid (-1).
        cache_bs_id[random.sample([*range(0, batch_size)], invalid_batch_size)] = -1
    context_seq_offsets = torch.randint(size=(batch_size, ), low=0, high=max_seq_offset,
                                        dtype=torch.int32, device='mlu')
    cache_seq_offsets = torch.randint(size=(batch_size, ), low=-1,
                                      high=(cache_memory_len - max_context_len) // 3 + 1,
                                      dtype=torch.int32, device='mlu')
    cu_context_lens = torch.cumsum(context_lens, dim=-1)
    cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
    total_seqlen = cu_context_lens[-1]
    if packed > 0:
        # Varlen layout: all sequences concatenated along dim 0.
        context = torch.randn((total_seqlen, total_heads, head_size),
                              dtype=torch.float, device='mlu')
        context_seq_offsets = None
    else:
        # Padded layout: per-batch rows with room for the sequence offset;
        # plain lengths replace the cumulative table.
        context = torch.randn((batch_size, max_context_len + max_seq_offset, total_heads, head_size),
                              dtype=torch.float, device='mlu')
        cu_context_lens = context_lens
    context = context.to(dtype)
    key = context[..., q_heads:q_heads + num_heads, :]
    value = context[..., q_heads + num_heads:q_heads + 2 * num_heads, :]
    cache = torch.randn((2, max_bs, num_heads, cache_memory_len, head_size), dtype=torch.float, device='mlu')
    cache_scale = None
    if quant_mode:
        # Spread values over roughly the int8 range before the cast.
        cache = (cache - 0.5) * 256
        cache = cache.to(torch.int8)
        if offline:
            # Offline (per-channel) scales: one per head-dim element.
            cache_scale = torch.randn((2, cache.shape[2], cache.shape[4]), dtype=torch.float, device='mlu')
        else:
            # Online (per-token) scales: one per cache position.
            cache_scale = torch.randn((2, max_bs, num_heads, cache_memory_len), dtype=torch.float, device='mlu')
    else:
        cache = cache.to(dtype)
    block_size = 16 if "MLU3" not in torch.mlu.get_device_name() else max_context_len
    min_blocks = (total_seqlen + block_size - 1) // block_size
    # Over-allocate a little so slot assignment is not fully dense.
    num_blocks = min(min_blocks + 10, 2 * min_blocks)
    num_slots = num_blocks * block_size
    slot_mapping = random.sample(range(num_slots), total_seqlen.item())
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu()
    # Last slot is invalid (-1) to exercise the skip path in the kernels.
    slot_mapping[-1] = -1
    return [key, value, cache[0], cache[1], cu_context_lens, max_context_len,
            packed > 0, context_seq_offsets, cache_bs_id, cache_seq_offsets,
            cache_scale, slot_mapping]
def fused_moe(hidden_states: torch.Tensor,
              gating_output: torch.Tensor,
              w1: torch.Tensor,
              w2: torch.Tensor,
              bias1: Optional[torch.Tensor],
              bias2: Optional[torch.Tensor],
              residual: Optional[torch.Tensor],
              input_smooth: Optional[torch.Tensor],
              act_smooth: Optional[torch.Tensor],
              w1_scale: Optional[torch.Tensor],
              w2_scale: Optional[torch.Tensor],
              topk: int,
              renormalized: bool,
              gated: bool,
              act_mode: str,
              start_expert_id: int = 0,
              block_n: int = 0,
              cncl_comm: int = 0,
              w1_quant_flag: Optional[List] = None,
              w2_quant_flag: Optional[List] = None):
    # Reference MoE pipeline composed from individual tmo kernels:
    # softmax+topk routing -> index generation -> (optional smooth-quant)
    # input expansion -> group GEMM 1 -> bias+activation -> group GEMM 2 ->
    # combine with routing weights (plus optional residual/bias2).
    # Smooth-quant mode is enabled only when all four quant tensors
    # (input_smooth, act_smooth, w1_scale, w2_scale) are provided.
    dtype = hidden_states.dtype
    ori_input_shape = hidden_states.shape
    # Flatten any leading dims to (tokens, hidden).
    hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
    tokens = hidden_states.size(0)
    gating_output = gating_output.reshape(-1, gating_output.size(-1))
    residual = residual.reshape(-1, residual.size(-1)) if residual is not None else None
    expert_num = gating_output.size(-1)
    expert_size = w1.size(0) if w1_quant_flag is None else w1_scale.size(1)
    per_token_sq = False
    # check quant
    check_list = [input_smooth, act_smooth, w1_scale, w2_scale]
    if all(x is not None for x in check_list):
        per_token_sq = True
    if not (all(x is None for x in check_list) or all(x is not None for x in check_list)):
        raise ValueError("input_smooth, act_smooth, w1_scale and w2_scale must be present "
                         "and absent at the same time.")
    # softmax_topk
    reduce_weight, expert_id = tmo.moe_softmax_topk(gating_output, topk, renormalized)
    # gen_idx
    expand_idx, combine_idx, token_count, cusum_token_count = tmo.moe_gen_idx(expert_id, expert_num)
    if per_token_sq:
        # MLU370 expands first then quantizes; newer devices fuse the
        # expansion into moe_quantize via expand_idx.
        if torch.mlu.get_device_name() == 'MLU370':
            expand_hidden_states = tmo.moe_expand_input(hidden_states, expand_idx,
                cusum_token_count, start_expert_id, expert_size)
            quant_input, input_scale = tmo.moe_quantize(expand_hidden_states,
                input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size])
        else:
            quant_input, input_scale = tmo.moe_quantize(hidden_states,
                input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size], expand_idx,
                cusum_token_count[start_expert_id].unsqueeze(0))
    else:
        expand_hidden_states = tmo.moe_expand_input(hidden_states, expand_idx,
            cusum_token_count, start_expert_id, expert_size)
    # group gemm
    if per_token_sq:
        gemm1_out = tmo.smooth_quant_group_gemm(quant_input,
                                                w1,
                                                token_count[start_expert_id:start_expert_id+expert_size],
                                                None, None, None, None,
                                                input_scale, w1_scale, dtype, tokens, quant_flag = w1_quant_flag)
    else:
        gemm1_out = tmo.group_gemm(expand_hidden_states,
                                   w1,
                                   token_count[start_expert_id:start_expert_id+expert_size],
                                   None,
                                   None,
                                   None,
                                   None, tokens)
    # add_bias_active
    act_out = tmo.moe_active(gemm1_out, act_mode, gated, None, bias1, cusum_token_count, start_expert_id, expert_size)
    if per_token_sq:
        # Re-quantize the activation before the second group GEMM.
        quant_input, input_scale = tmo.moe_quantize(act_out, act_smooth, None,
            token_count[start_expert_id:start_expert_id+expert_size])
    if cncl_comm > 0:
        raise ValueError("not support communication and computing fusion currently.")
    else:
        if per_token_sq:
            gemm2_out = tmo.smooth_quant_group_gemm(quant_input,
                w2, token_count[start_expert_id:start_expert_id+expert_size],
                None, None, None, None, input_scale, w2_scale, dtype, tokens, quant_flag = w2_quant_flag)
        else:
            gemm2_out = tmo.group_gemm(act_out,
                                       w2,
                                       token_count[start_expert_id:start_expert_id+expert_size],
                                       None, None, None, None, tokens)
    # Weighted combine back to one row per token (plus residual/bias2).
    output = tmo.moe_combine_result(gemm2_out, reduce_weight, combine_idx,
                                    residual, cusum_token_count, start_expert_id,
                                    expert_size, bias2)
    return output.reshape(ori_input_shape)
def min_mem_size(shape, stride):
    """Smallest element count a buffer needs for a tensor of `shape`/`stride`.

    A contiguous tensor (stride is None) needs exactly numel() elements; a
    strided one needs 1 + sum((extent - 1) * step) to reach its farthest
    element.
    """
    if stride is None:
        return shape.numel()
    total = 1
    for extent, step in zip(shape, stride):
        total += (extent - 1) * step
    return total
def create_tensor(shape, dtype, is_contiguous, device, stride = None, mean=0, var=1, is_uniform=False, low=0, high=1):
    """Build a random tensor, either contiguous or with an explicit stride.

    Integer dtypes draw from [-128, 127); float dtypes draw from a normal
    (mean/var) or uniform (low/high) distribution.  For the strided case a
    flat buffer of minimal size is filled and viewed via as_strided.
    """
    def _random_fill(size):
        # Random payload of the requested dtype on the requested device.
        if dtype in (torch.int8, torch.uint8):
            return torch.randint(-128, 127, size, device=device).to(dtype)
        if is_uniform:
            return torch.empty(size, dtype=dtype, device=device).uniform_(low, high)
        return torch.normal(mean, var, size, dtype=dtype, device=device)
    if is_contiguous:
        return _random_fill(shape)
    backing = _random_fill((min_mem_size(shape, stride),))
    return backing.as_strided(shape, stride)
def create_tensor_from_dic(dic:dict, mean=0, var=1, is_uniform=False, low=0, high=1):
    """Materialize a random tensor from a serialized tensor-descriptor dict.

    Returns None when the descriptor recorded no data.
    """
    if dic['data'] is None:
        return None
    # Delegate to create_tensor using the recorded layout metadata.
    return create_tensor(dic['shape'], dic['dtype'], dic['is_contiguous'],
                         dic['device'], dic['stride'], mean, var, is_uniform, low, high)
def create_op_param(dic: dict):
    """Rebuild one operator argument from its serialized descriptor.

    Containers are rebuilt recursively, tensor descriptors become fresh
    random tensors (integer tensors keep their recorded data), and plain
    values pass through unchanged.
    """
    kind = dic['type']
    if kind in (list, tuple):
        # Recurse only when the container holds nested descriptors.
        return [create_op_param(elem) for elem in dic['data']] if dic['has_compound'] else dic['data']
    if kind is dict:
        return {key: create_op_param(val) for key, val in dic['data'].items()}
    if kind is torch.Tensor:
        if dic['data'] is None:
            return None
        # Integer tensors are value-sensitive (indices/lengths): keep as dumped.
        if dic['dtype'] in (torch.int16, torch.int32, torch.int64):
            return dic['data']
        return create_tensor(dic['shape'], dic['dtype'], dic['is_contiguous'], dic['device'], dic['stride'])
    return dic['data']

View File

@@ -0,0 +1,151 @@
import os
# Presumably disables torch_mlu_ops case dumping; set before any torch_mlu
# import so it takes effect — TODO confirm against the extension.
os.environ['TMO_GEN_CASE'] = '0'
import sys
# Stash the real CLI args and hand the interpreter an argv containing only
# the program name, so unittest does not parse our custom flags.
sys_args = sys.argv
sys.argv = [sys_args.pop(0)] # prevent unittest printing help info
import copy
import argparse
import torch
import torch_mlu
import importlib
import random
from common_utils import create_op_param
def assert_tensor_equal(a: torch.Tensor, b: torch.Tensor, threshold):
    """Assert `a` and `b` agree within `threshold` under relative L1 and L2 error.

    NaNs must occupy identical positions in both tensors and are excluded
    from the error computation.
    """
    assert a.size() == b.size()
    ref = a.cpu().reshape(-1).double()
    got = b.cpu().reshape(-1).double()
    nan_at = ref != ref
    assert torch.equal(nan_at, got != got), "tensor a and tensor b have different number of nan"
    delta = ref - got
    # Zero NaN slots so they contribute nothing to either metric.
    delta[nan_at] = 0
    ref[nan_at] = 0
    got[nan_at] = 0
    eps = 1e-10
    rel_l1 = delta.abs().sum() / (ref.abs().sum() + eps)
    rel_l2 = torch.sqrt((delta**2).sum() / ((ref**2).sum() + eps))
    print(f"[torch_mlu_ops] diff1: {rel_l1}, diff2: {rel_l2}")
    assert rel_l1 <= threshold, f"diff1: {rel_l1} <= threshold: {threshold}"
    assert rel_l2 <= threshold, f"diff2: {rel_l2} <= threshold: {threshold}"
def check_value_equal(x, y, threshold):
    """Assert two same-typed values match (tensors within `threshold`)."""
    assert type(x) == type(y)
    if type(x) is torch.Tensor:
        assert_tensor_equal(x, y, threshold)
    elif type(x) is list or type(x) is tuple:
        # Containers are compared element-wise, recursively (indexing by x's
        # length, as the original did).
        for idx in range(len(x)):
            check_value_equal(x[idx], y[idx], threshold)
    elif x is not None:
        assert x == y
def check_equal(a, b, threshold):
    """Assert two op results (scalars or containers of outputs) are equivalent."""
    assert type(a) == type(b)
    if type(a) is tuple or type(a) is list:
        # Multi-output results: require equal arity, then compare pairwise.
        assert len(a) == len(b)
        for lhs, rhs in zip(a, b):
            check_value_equal(lhs, rhs, threshold)
    else:
        check_value_equal(a, b, threshold)
def get_base_obj(module_name, class_name):
    """Import `module_name` and return a new instance of `class_name`.

    Returns None (after logging) when the module or the attribute is missing.
    """
    try:
        module = importlib.import_module(module_name)
        return getattr(module, class_name)()
    except ImportError as e:
        print(f"Failed to import class '{class_name}' from module '{module_name}': {e}")
        return None
    except AttributeError as e:
        print(f"Module '{module_name}' does not have a class named '{class_name}': {e}")
        return None
def get_tmo_func(func_name):
    # Resolve a torch_mlu_ops entry point by name; the import is done lazily
    # inside the function rather than at module load.
    import torch_mlu_ops as tmo
    return getattr(tmo, func_name)
# Registry of replayable ops:
#   op name -> [test module to import, test class inside that module,
#               error threshold handed to check_equal].
# Commented-out entries appear to need multi-device / communication setups
# and are skipped here — TODO confirm.
op_map = {
    "active": ["test_active", "TestActive", 0.004],
    "apply_rotary": ["test_apply_rotary", "TestApplyRotaryOp", 0.003],
    "attention_project": ["test_attn_proj", "TestAttnProjOp", 0.003],
    "batch_matmul": ["test_batch_matmul", "TestBatchMatMulOp", 0.004],
    "copy_blocks": ["test_copy_blocks", "TestCopyBlocksOp", 0],
    "dequant_from_linear_cache": ["test_dequant_from_linear_cache", "TestDequantFromLinearCache", 0.001],
    "dequant_from_paged_cache": ["test_dequant_from_paged_cache", "TestDequantFromPagedCache", 0.001],
    "ffn": ["test_ffn", "TestFFNOp", 0.005],
    "flash_attention": ["test_flash_attention", "TestFlashAttnOp", 0.005],
    # "flash_attn_sq_mm_allreduce": ["test_flash_attn_sq_mm_allreduce", "TestFlashAttnSqMMAllreduce", 0.003],
    "fused_layer_norm": ["test_fused_layernorm", "TestFuseLayerNormOp", 0.003],
    "fused_moe": ["test_moe", "TestFusedMOEOp", 0.006],
    "fused_norm_attention_project": ["test_fused_attn_proj", "TestFusedNormAttnProjOp", 0.003],
    "fused_norm_residual_ffn": ["test_fused_ffn", "TestFusedNormResidualFFNoP", 0.003],
    "fused_rms_norm": ["test_fused_rmsnorm", "TestFuseRmsNormOp", 0.0032],
    "fused_rope": ["test_fused_rope", "TestFusedRopeOp", 0.003],
    "group_gemm": ["test_group_gemm", "TestGroupGemmOp", 0.006],
    "matmul": ["test_matmul", "TestMatMulOp", 0.003],
    # "matmul_allreduce": ["test_matmul_all_reduce", "TestMatMulAllReduceOp", 0.006],
    "moe_active": ["test_moe_add_bias_activation", "TestMoeActiveKernel", 0.003],
    "moe_cast_gating": ["test_moe_cast_gating", "TestMoeCastGating", 0.0001],
    "moe_combine_result": ["test_moe_combine_result", "TestCombineResult", 0.003],
    "moe_expand_input": ["test_moe_expand_input", "TestExpandInput", 0],
    "moe_gen_idx": ["test_moe_gen_idx", "TestGenIdx", 0],
    "moe_quantize": ["test_smooth_quant", "TestSmoothQuantOp", 0.01],
    "moe_softmax_topk": ["test_moe_softmax_topk", "TestSoftmaxTopkOp", 0.003],
    "offline_quant_to_linear_cache": ["test_offline_quant_to_linear_cache", "TestOfflineQuantToLinearCache", 0.03],
    "offline_quant_to_paged_cache": ["test_offline_quant_to_paged_cache", "TestOfflineQuantToPagedCache", 0.03],
    "per_token_smooth_quantize": ["test_per_token_smooth_quantize", "TestPerTokenSmoothQuantizeOp", 0.003],
    # "preload": ["test_preload", "TestPreloadOp", 1],
    "quant_to_linear_cache": ["test_quant_to_linear_cache", "TestQuantToLinearCache", 0.003],
    "quant_to_paged_cache": ["test_quant_to_paged_cache", "TestQuantToPagedCache", 0.009],
    "quantize": ["test_quantize", "TestQuantizeOp", 0.003],
    "reshape_linear_cache": ["test_reshape_linear_cache", "TestReshapeLinearCache", 0],
    "reshape_paged_cache": ["test_reshape_paged_cache", "TestReshapePagedCacheOp", 0],
    "single_query_cached_kv_attn": ["test_single_query_cached_kv_attn", "TestSingleQueryAttnOp", 0.003],
    "single_query_mixed_cached_kv_attn": ["test_single_query_mixed_cached_kv_attn", "TestSingleQueryMixedKVAttnOp", 0.003],
    "smooth_quant_group_gemm": ["test_smooth_quant_group_gemm", "TestSmoothQuantGroupGemmOp", 0.006],
    "smooth_quant_matmul": ["test_smooth_quant_matmul", "TestSmoothQuantMatmulOp", 0.006],
    # "smooth_quant_matmul_allreduce": ["test_quant_matmul_all_reduce", "TestGptQuantMatmulOp", 0.006],
    "swap_blocks": ["test_swap_blocks", "TestSwapBlocksOp", 0],
    "update_out_and_lse": ["test_update_out_and_lse", "TestUpdateOutAndLse", 0.005],
    "weight_only_quant_matmul": ["test_weight_only_quant_matmul", "TestWeightOnlyQuantMatmulOp", 0.004],
}
def run_case(pt_case):
    """Replay one dumped case: run the tmo op and its torch reference, compare.

    When the test class implements run_gen_case, the whole replay is
    delegated to it; otherwise the arguments are rebuilt, both
    implementations are invoked, and the results are checked against the
    op's registered threshold.
    """
    op_name = pt_case.pop('op')
    entry = op_map[op_name]
    op_obj = get_base_obj(entry[0], entry[1])
    if hasattr(op_obj, "run_gen_case"):
        op_obj.run_gen_case(pt_case)
        return
    dump_data = pt_case.pop('dump_data')
    if dump_data:
        # The dump already carries concrete values; use them as-is.
        params = pt_case
    else:
        # Only descriptors were dumped; materialize each argument.
        params = {key: create_op_param(val) for key, val in pt_case.items()}
    # Deep-copy so in-place ops cannot let one run contaminate the other.
    params_bak = copy.deepcopy(params)
    result_tmo = get_tmo_func(op_name)(**params)
    result_base = op_obj.op_impl_base(*params_bak.values())
    check_equal(result_tmo, result_base, entry[-1])
def main():
    """Entry point: load a dumped .pt case, then replay or inspect it."""
    random.seed(0)
    torch.manual_seed(0)
    parser = argparse.ArgumentParser()
    parser.add_argument('--case_path', required=True, type=str, help='specify the case path')
    parser.add_argument('--detail', action="store_true", help="show content of pt file")
    # sys_args holds the CLI args captured before unittest argv stripping.
    args = parser.parse_args(args=sys_args)
    pt_case = torch.load(args.case_path)
    if args.detail:
        # Inspection mode: print the raw record and stop.
        for key, value in pt_case.items():
            print(f"{key}: {value}")
        exit(0)
    print(f"[torch_mlu_ops] run {args.case_path} ...")
    run_case(pt_case)
    print(f"[torch_mlu_ops] run {args.case_path} successfully")
# Allow running this replay tool directly from the command line.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,22 @@
#!/bin/bash
# Run every test*.py found beside this script, optionally under coverage.
# Usage: <script> [coverage]
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
tmo_ops_case=$(find "${SCRIPT_DIR}" -name "test*.py")
coverage=${1}
for sc in ${tmo_ops_case}
do
    echo -n "${sc} "
    echo -n "Testing...."
    if [ "${coverage}" = "coverage" ];then
        coverage run -a "${sc}"
    else
        python3 "${sc}" > "/tmp/$(basename "${sc}").log" 2>&1
    fi
    # Capture the exit status immediately so nothing can clobber it before
    # the success/failure report below.
    rc=$?
    if [ "${rc}" -eq 0 ];then
        echo -e "\033[32m success \033[0m"
    else
        echo -e "\033[31m failed \033[0m"
    fi
done
echo "End of pytest..."

View File

@@ -0,0 +1,88 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from common_utils import *
import random
from itertools import product
import time
import os
class TestActive(BtTestCase):
    # Tests for the tmo.active kernel against a torch reference covering
    # gelu / silu / quick_gelu / swish, gated and ungated.
    def run_gen_case(self, dic):
        # Replay a dumped case: either the dump carries concrete tensors
        # (dump_data) or only descriptors that must be materialized first.
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            input = create_tensor_from_dic(dic['input'])
            act_mode = dic['act_mode']['data']
            is_gated = dic['is_gated']['data']
            active_coef = dic['active_coef']['data']
            self.launch(input, act_mode, is_gated, active_coef)
    def launch(self, *args):
        # Run reference and kernel with identical args and compare in float.
        torch_out = self.op_impl_base(*args)
        tmo_out = tmo.active(*args)
        self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
                                0.004, use_MSE=True, use_RAE=True)
    def op_impl_base(self, *args):
        # Torch reference.  Gated mode applies the activation to the first
        # half of the last dim and multiplies by the second half.  silu and
        # quick_gelu are swish with fixed coefficients (1.0 and 1.702).
        input, act_mode, is_gated, active_coef = args
        channel = input.size(-1)
        if act_mode == "gelu":
            if is_gated:
                out = torch.nn.functional.gelu(input[..., :channel//2])
                out *= input[..., channel//2:]
            else:
                out = torch.nn.functional.gelu(input)
        else:
            if act_mode == "silu":
                active_coef = 1.0
            elif act_mode == "quick_gelu":
                active_coef = 1.702
            def swish(input, coef):
                return input * torch.sigmoid(coef * input)
            if is_gated:
                out = swish(input[..., :channel//2], active_coef) * input[..., channel//2:]
            else:
                out = swish(input, active_coef)
        return out
    def test_active_random(self):
        # Randomized sweep over dtypes, shapes, gating and activation modes.
        for _ in range(500):
            dtype_list = [torch.float, torch.half]
            if torch_mlu.mlu.is_bf16_supported():
                dtype_list.append(torch.bfloat16)
            input_dtype = random.choice(dtype_list)
            batch = random.randint(1, 10)
            seq = random.randint(1, 2048)
            # Even hidden size so the gated split is always clean.
            hidden_size = random.randrange(2, 8192, 2)
            is_gated = random.choice([True, False])
            act_mode = random.choice(['gelu', 'silu', 'quick_gelu', 'swish'])
            if act_mode == 'silu':
                active_coef = 1.0
            elif act_mode == 'quick_gelu':
                active_coef = 1.702
            else:
                active_coef = random.uniform(0, 1)
            print("input_shape: {}, is_gated: {}, act_mode: {}, dtype: {} testing...".format( \
                [batch, seq, hidden_size], is_gated, act_mode, input_dtype), flush=True)
            input = torch.randn(batch, seq, hidden_size, dtype=input_dtype, device="mlu")
            self.launch(input, act_mode, is_gated, active_coef)
    def test_inductor(self):
        # Opcheck the raw torch.ops entry point with a preallocated output.
        input = torch.randn(3, 4, 12, dtype=torch.half, device="mlu")
        output = torch.empty(3, 4, 12, dtype=torch.half, device="mlu")
        is_gated = True
        act_mode = 'silu'
        coef = 1.0
        args = (input, output, None, None, act_mode, is_gated, 0, 0, coef)
        self.base_opcheck(torch.ops.torch_mlu_ops.active, args)
# Run this op's unittest suite when executed directly.
if __name__ == '__main__':
    exit(run_unittest(TestActive))

View File

@@ -0,0 +1,152 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from itertools import product
def gen_args(bs, seq_len, q_heads, kv_heads, head_size, rope_dim, interleaved, discrete, dynamic_ntk, packed, dtype):
    # Build the argument tuple for the rotary-embedding op under test:
    # a fused q/k/v context tensor (packed varlen or padded), random
    # position ids (per-token when `discrete`, per-batch otherwise), and
    # sin/cos caches (per-batch when `dynamic_ntk`).
    cu_context_lens = None
    total_seq_len = bs * seq_len
    max_context_len = seq_len
    if packed:
        # Varlen: random per-batch lengths plus a 0-prefixed cumulative table.
        context_lens = torch.randint(size=(bs, ), low=1, high=seq_len+1, dtype=torch.int32, device='mlu')
        total_seq_len = context_lens.sum().item()
        max_context_len = context_lens.max().item()
        cu_context_lens = torch.cumsum(context_lens, dim=-1, dtype=torch.int32)
        cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0)
    context_shape = (total_seq_len, q_heads + kv_heads + 1, head_size) if packed else \
                    (bs, seq_len, q_heads + kv_heads + 1, head_size)
    context = torch.randn(size=context_shape, dtype=dtype).mlu()
    # Rotary is applied to q and k heads only; the trailing head is excluded.
    qk = context[..., 0 : q_heads + kv_heads, :]
    position_id = None
    if discrete:
        position_id = torch.randint(0, max_context_len, size=(total_seq_len,), dtype=torch.int32, device="mlu")
    else:
        position_id = torch.randint(0, max_context_len, size=(bs,), dtype=torch.int32, device="mlu")
    # Cache longer than any sequence so offset lookups never run out.
    rope_seqlen = seq_len * 2
    cos_shape = (bs, rope_seqlen, rope_dim) if dynamic_ntk else (rope_seqlen, rope_dim)
    cos_cache = torch.randn(size=cos_shape, dtype=dtype, device="mlu")
    sin_cache = torch.randn(size=cos_shape, dtype=dtype, device="mlu")
    return (qk, sin_cache, cos_cache, position_id, cu_context_lens, \
            interleaved, discrete, dynamic_ntk, max_context_len)
class TestApplyRotaryOp(BtTestCase):
    # Tests for the fused rotary position embedding (RoPE) op `apply_rotary`.
    def op_impl_base(self, *args):
        # Reference RoPE implementation. Rotates the leading `rope_dim`
        # channels of `input` IN PLACE, batch by batch, computing in float32
        # and casting back to the input dtype.
        def rotate(x: torch.Tensor, interleaved: bool):
            # Half-rotation in either split-halves layout (interleaved=False)
            # or even/odd interleaved layout (interleaved=True).
            if not interleaved:
                x1, x2 = x.chunk(2, dim=-1)
                return torch.cat((-x2, x1), dim=-1)
            else:
                y = torch.empty_like(x)
                x1, x2 = x[..., ::2], x[..., 1::2]
                y[..., ::2], y[..., 1::2] = -x2, x1
                return y
        input, sin_cache, cos_cache, position_ids, cu_seqlen, interleaved, discrete, dynamic_ntk, max_seqlen = args
        # A 3-D input means variable-length sequences packed along dim 0.
        packed = input.dim() == 3
        rope_dim = sin_cache.shape[-1]
        batch_size = cu_seqlen.shape[0] - 1 if packed else input.shape[0]
        sin_cache_float = sin_cache.float()
        cos_cache_float = cos_cache.float()
        cu_seqlen_cpu = cu_seqlen.cpu() if cu_seqlen is not None else None
        position_ids_cpu = position_ids.cpu() if position_ids is not None else None
        for i in range(batch_size):
            # Slice this batch's tokens; only the first rope_dim channels
            # of head_size are rotated.
            input_i = input[cu_seqlen_cpu[i] : cu_seqlen_cpu[i + 1]] if packed else input[i]
            input_i = input_i[..., 0:rope_dim]
            # dynamic_ntk caches carry a per-batch leading dim.
            sin_cache_i = sin_cache_float[i] if dynamic_ntk else sin_cache_float
            cos_cache_i = cos_cache_float[i] if dynamic_ntk else cos_cache_float
            seq = input_i.shape[0]
            if discrete:
                # Per-token position ids gather cache rows directly.
                if packed:
                    position_id_i = position_ids_cpu[cu_seqlen_cpu[i] : cu_seqlen_cpu[i + 1]]
                else:
                    position_id_i = position_ids_cpu.view(batch_size, -1)[i]
                sin_cache_i = sin_cache_i[position_id_i]
                cos_cache_i = cos_cache_i[position_id_i]
            else:
                if position_ids_cpu is None:
                    sin_cache_i = sin_cache_i[:seq]
                    cos_cache_i = cos_cache_i[:seq]
                else:
                    # A single per-batch offset shifts the cache window.
                    pos_id = position_ids_cpu[i].item()
                    sin_cache_i = sin_cache_i[pos_id : seq + pos_id]
                    cos_cache_i = cos_cache_i[pos_id : seq + pos_id]
            rot = rotate(input_i.float(), interleaved)
            # out = rotate(x) * sin + x * cos, broadcast over the head dim.
            output_i = rot * sin_cache_i.unsqueeze(1) + input_i * cos_cache_i.unsqueeze(1)
            # Write back in place so the caller's tensor is updated.
            input_i[:] = output_i.to(input.dtype)
        return input
    def test_apply_rotary(self):
        # Sweep the full cross-product of shapes and mode flags, comparing the
        # MLU op against the in-place python reference.
        bs_list = [1, 8]
        seq_len_list = [1, 128, 1024]
        q_heads_list = [8, 32]
        kv_heads_list = [1]
        head_size_list = [128, 256]
        rope_dim_list = [256, 128, 64, 24]
        is_interleaved_list = [True, False]
        discrete_list = [True, False]
        dynamic_ntk_list = [True, False]
        packed_list = [True, False]
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        args = product(bs_list, seq_len_list, q_heads_list, kv_heads_list, head_size_list, \
                       rope_dim_list, is_interleaved_list, discrete_list, dynamic_ntk_list, packed_list, dtype_list)
        for bs, seq_len, q_heads, kv_heads, head_size, rope_dim, interleaved, discrete, dynamic_ntk, packed, dtype in args:
            print("bs: {}, seq_len: {}, q_heads: {}, kv_heads: {}, head_size: {}, "
                  "rope_dim: {}, interleaved: {}, discrete: {}, dynamic_ntk: {}, packed: {}, dtype: {} testing...".format( \
                  bs, seq_len, q_heads, kv_heads, head_size, rope_dim, interleaved, discrete, dynamic_ntk, packed, dtype), flush=True)
            # Invalid combination: cannot rotate more channels than exist.
            if rope_dim > head_size:
                print("rope_dim = {}, head_size = {},rope_dim should less than head_size".format(rope_dim, head_size))
                continue
            qk, sin_cache, cos_cache, position_id, cu_context_lens, \
                interleaved, _, _, max_context_len = gen_args(bs, seq_len, q_heads, kv_heads, head_size,
                                                              rope_dim, interleaved, discrete, dynamic_ntk, packed, dtype)
            # Reference mutates its copy in place; the op returns its result.
            qk_base = qk.clone()
            self.op_impl_base(qk_base, sin_cache, cos_cache, position_id, cu_context_lens, \
                              interleaved, discrete, dynamic_ntk, max_context_len)
            qk_out = ops.apply_rotary(qk, sin_cache, cos_cache, position_id, cu_context_lens, \
                                      interleaved, discrete, dynamic_ntk, max_context_len)
            self.assertTensorsEqual(qk_out.cpu().float(), qk_base.cpu().float(), 0.003, use_MSE=True, use_RAE=True)
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        # Each call below must raise with the exact message under test.
        bs, seq, head_num, head_size, dynamic_ntk, rotary_seqlen, rotary_dim, discrete = 1, 1024, 8, 128, False, 512, 128, True
        q = torch.randn(bs, seq, head_num, head_size, dtype=torch.half, device="mlu")
        sin = torch.randn(rotary_seqlen, rotary_dim, dtype=torch.half, device="mlu")
        cos = torch.randn(rotary_seqlen, rotary_dim, dtype=torch.half, device="mlu")
        self.assertException("discrete must be false if position ids is null.", ops.apply_rotary,
                             q, sin, cos, None, None, False, discrete, dynamic_ntk, seq)
        discrete = False
        # seq (1024) exceeds rotary_seqlen (512) both with and without ids.
        self.assertException("max_seqlen must less than or equal to rope_seqlen.", ops.apply_rotary,
                             q, sin, cos, None, None, False, discrete, dynamic_ntk, seq)
        position_ids = torch.zeros(bs, dtype=torch.int32, device="mlu")
        self.assertException("max_seqlen must less than or equal to rope_seqlen.", ops.apply_rotary,
                             q, sin, cos, position_ids, None, False, discrete, dynamic_ntk, seq)
    def test_inductor(self):
        # Opcheck the registered op across all mode-flag combinations with a
        # single fixed shape.
        is_interleaved_list = [True, False]
        discrete_list = [True, False]
        dynamic_ntk_list = [True, False]
        packed_list = [True, False]
        params = product(is_interleaved_list, discrete_list, dynamic_ntk_list, packed_list)
        bs, seq_len, q_heads, kv_heads, head_size, rope_dim, dtype = 8, 1024, 8, 1, 256, 24, torch.float16
        for interleaved, discrete, dynamic_ntk, packed in params:
            print(f"==== check apply_rotary interleaved: {interleaved}, discrete: {discrete}, dynamic_ntk: {dynamic_ntk}, packed: {packed} ====")
            args = gen_args(bs, seq_len, q_heads, kv_heads, head_size,
                            rope_dim, interleaved, discrete, dynamic_ntk, packed, dtype)
            self.base_opcheck(torch.ops.torch_mlu_ops.apply_rotary, args)
# Entry point: run the TestApplyRotaryOp suite and return its exit status.
if __name__ == '__main__':
    exit(run_unittest(TestApplyRotaryOp))

View File

@@ -0,0 +1,44 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from torch.nn.parameter import Parameter
from torch.nn import functional as F
class TestAttnProjOp(BtTestCase):
    """Checks ops.attention_project against the torch reference
    out = alpha * linear(x, W, b) + beta * residual."""

    def op_impl_base(self, *args):
        # Torch reference: fused linear projection plus scaled residual add.
        x, w, b, res, alpha, beta = args
        return alpha * F.linear(x, w, b) + beta * res

    def test_attn_proj(self):
        # Fixed shape, swept over the supported dtypes.
        N, T, input_size, hidden_size, alpha, beta = 32, 129, 2048, 4096, 0.5, 0.1
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for dtype in dtype_list:
            print("N: {}, T: {}, input_size: {}, hidden_size: {}, testing...".format(
                N, T, input_size, hidden_size), flush=True)
            x = torch.randn(N, T, input_size, dtype=dtype, device="mlu")
            w = torch.randn(hidden_size * 3, input_size, dtype=dtype, device="mlu")
            b = torch.randn(hidden_size * 3, dtype=dtype, device="mlu")
            res = torch.randn(N, T, hidden_size * 3, dtype=dtype, device="mlu")
            expected = self.op_impl_base(x, w, b, res, alpha, beta)
            actual = ops.attention_project(x, w, b, res, alpha, beta)
            self.assertTensorsEqual(actual.cpu().float(), expected.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        # Opcheck the raw registered op with its full default argument list.
        N, T, input_size, hidden_size, dtype = 32, 129, 2048, 4096, torch.half
        x = torch.randn(N, T, input_size, dtype=dtype, device="mlu")
        w = torch.randn(hidden_size * 3, input_size, dtype=dtype, device="mlu")
        b = torch.randn(hidden_size * 3, dtype=dtype, device="mlu")
        args = (x, w, b, None, None, None,
                None, None, None, None, "nthc", 1,
                1e-5, 1., 0., False)
        self.base_opcheck(torch.ops.torch_mlu_ops.attention_project, args)
# Entry point: run the TestAttnProjOp suite and return its exit status.
if __name__ == '__main__':
    exit(run_unittest(TestAttnProjOp))

View File

@@ -0,0 +1,83 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from itertools import product
import numpy as np
class TestBatchMatMulOp(BtTestCase):
    # Tests ops.batch_matmul (alpha * bmm(a, b) + beta * c) in float dtypes
    # and with per-tensor int8 quantized inputs.
    def op_impl_base(self, *args):
        # Reference implementation. When `c` is provided it is BOTH the
        # residual term and the output buffer: the final result is copied
        # back into `c` in place, and also returned.
        a, b, c, alpha, beta, a_scale, b_scale, trans_a, trans_b = args
        if trans_a:
            a = a.transpose(1, 2)
        if trans_b:
            b = b.transpose(1, 2)
        if c is None:
            output_dtype = a.dtype
        else:
            output_dtype = c.dtype
        if a_scale is not None:
            # Dequantize inputs to the output dtype before bmm; in the float
            # path the tests pass scales of 1.0, making this a no-op divide.
            a = torch.div(a, a_scale).to(output_dtype)
            b = torch.div(b, b_scale).to(output_dtype)
        output = alpha * torch.bmm(a, b)
        if c is not None:
            output += beta * c
            c.copy_(output)
        return output
    def test_batch_matmul(self):
        # Sweep residual/transpose/dtype combinations; the int8 path is
        # additionally checked for every dtype except bfloat16.
        batch_list = [5]
        mat_m_list = [32]
        mat_n_list = [256]
        mat_k_list = [128]
        has_res_list = [False, True]
        trans_a_list = [False, True]
        trans_b_list = [False, True]
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        alpha = 0.625
        beta = 1.0
        args = product(batch_list, mat_m_list, mat_n_list, mat_k_list, has_res_list, dtype_list, trans_a_list, trans_b_list)
        for batch, mat_m, mat_n, mat_k, has_res, dtype, trans_a, trans_b in args:
            print("batch={}, m={}, n={}, k={}, has_res={}, dtype={}, trans_a={}, trans_b={} testing...".format(
                batch, mat_m, mat_n, mat_k, has_res, dtype, trans_a, trans_b), flush=True)
            # Pre-transposed layouts when the transpose flags are set.
            shape_a, shape_b = (batch, mat_m, mat_k), (batch, mat_k, mat_n)
            if trans_a:
                shape_a = (batch, mat_k, mat_m)
            if trans_b:
                shape_b = (batch, mat_n, mat_k)
            input = torch.randn(shape_a, dtype=dtype, device='mlu')
            weight = torch.randn(shape_b, dtype=dtype, device='mlu')
            input8, a_scale = QuantByTensor(input, 8)
            weight8, b_scale = QuantByTensor(weight, 8)
            # Residual is cloned because both implementations overwrite it.
            residual = torch.randn((batch, mat_m, mat_n), dtype=dtype, device='mlu')
            res_bak = residual.clone()
            tmo_output_int8 = torch.zeros((batch, mat_m, mat_n), dtype=dtype, device='mlu')
            torch_output_int8 = tmo_output_int8.clone()
            output = self.op_impl_base(input, weight,
                residual if has_res else None, alpha, beta, 1.0, 1.0, trans_a, trans_b)
            tmo_output = ops.batch_matmul(input, weight,
                res_bak if has_res else None, alpha, beta, 1.0, 1.0, trans_a, trans_b)
            self.assertTensorsEqual(tmo_output.cpu().float(), output.cpu().float(),
                0.004, use_MSE=True, use_RAE=True)
            if dtype != torch.bfloat16:
                # int8 path: results land in the output buffers in place.
                self.op_impl_base(input8, weight8, torch_output_int8, alpha, beta, a_scale.item(), b_scale.item(), trans_a, trans_b)
                ops.batch_matmul(input8, weight8, tmo_output_int8, alpha, beta, a_scale.item(), b_scale.item(), trans_a, trans_b)
                self.assertTensorsEqual(tmo_output_int8.cpu().float(), torch_output_int8.cpu().float(),
                    0.004, use_MSE=True, use_RAE=True)
    def test_inductor(self):
        # Opcheck the registered op once, with residual and trans_b set.
        batch, mat_m, mat_n, mat_k, alpha, beta, dtype = 6, 64, 256, 128, 0.8, 0.3, torch.float16
        a = torch.randn((batch, mat_m, mat_k), dtype=dtype, device='mlu')
        b = torch.randn((batch, mat_n, mat_k), dtype=dtype, device='mlu')
        c = torch.randn((batch, mat_m, mat_n), dtype=dtype, device='mlu')
        args = (a, b, c, alpha, beta, 1.0, 1.0, False, True)
        self.base_opcheck(torch.ops.torch_mlu_ops.batch_matmul, args)
# Entry point: run the TestBatchMatMulOp suite and return its exit status.
if __name__ == '__main__':
    exit(run_unittest(TestBatchMatMulOp))

View File

@@ -0,0 +1,167 @@
import math
import random
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from itertools import product
from typing import List, Tuple
import os
import copy
class TestCopyBlocksOp(BtTestCase):
    """Tests ops.copy_blocks: duplicating KV-cache blocks from source block
    ids to (multiple) destination block ids across every layer's caches."""

    def create_kv_caches(
        self,
        num_blocks: int,
        block_size: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        seed: int
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        # Build per-layer key caches (num_blocks, num_heads, block_size,
        # head_size) and value caches (num_blocks, num_heads, head_size,
        # block_size). Integer dtypes get full-range random ints; float
        # dtypes get uniform values scaled by 1/sqrt(head_size).
        torch.random.manual_seed(seed)
        # NOTE(review): relies on torch_mlu redirecting torch.cuda seeding to
        # the MLU generator — confirm this seeds the intended device.
        torch.cuda.manual_seed(seed)
        scale = head_size**-0.5
        # vllm scale
        # x = 16 // torch.tensor([], dtype=dtype).element_size()
        # key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
        int_dtypes = {torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}
        key_cache_shape = (num_blocks, num_heads, block_size, head_size)
        key_caches = []
        for _ in range(num_layers):
            if dtype in int_dtypes:
                info = torch.iinfo(dtype)
                key_cache = torch.randint(info.min, info.max, size=key_cache_shape, dtype=dtype).mlu()
            else:
                key_cache = torch.empty(size=key_cache_shape, dtype=dtype).mlu()
                key_cache.uniform_(-scale, scale)
            key_caches.append(key_cache)
        value_cache_shape = (num_blocks, num_heads, head_size, block_size)
        value_caches = []
        for _ in range(num_layers):
            if dtype in int_dtypes:
                info = torch.iinfo(dtype)
                value_cache = torch.randint(info.min, info.max, size=value_cache_shape, dtype=dtype).mlu()
            else:
                value_cache = torch.empty(size=value_cache_shape, dtype=dtype).mlu()
                value_cache.uniform_(-scale, scale)
            value_caches.append(value_cache)
        return key_caches, value_caches

    def create_block_mapping(self, num_blocks, num_mappings, seed = 0):
        # Map `num_mappings` distinct source blocks each to two distinct
        # destination blocks, with sources and destinations disjoint.
        random.seed(seed)
        torch.random.manual_seed(seed)
        assert 3 * num_mappings <= num_blocks
        block_mapping = {}
        src_blocks = random.sample(range(num_blocks), num_mappings)
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
        for i in range(num_mappings):
            src = src_blocks[i]
            block_mapping[src] = [dst_blocks[2 * i], dst_blocks[2 * i + 1]]
        return block_mapping

    def op_impl_base(self, *args):
        # Reference implementation: index-copy src rows to dst rows, in
        # place, for every layer of both caches. v_caches may be None.
        k_caches, v_caches, block_mapping = args
        for src, dsts in block_mapping.items():
            srcs_ind = torch.tensor([src] * len(dsts), dtype=torch.int64)
            dsts_ind = torch.tensor(dsts, dtype=torch.int64)
            for key_cache in k_caches:
                key_cache[dsts_ind] = key_cache[srcs_ind]
            if v_caches is not None:
                for value_cache in v_caches:
                    value_cache[dsts_ind] = value_cache[srcs_ind]
        return (k_caches, v_caches) if v_caches is not None else k_caches

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_copy_blocks due to ASan issues")
    def test_copy_blocks(self):
        # Sweep dtypes/shapes, randomly dropping the value caches, and
        # compare the kernel's result against the CPU reference.
        dtype_list = [torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        num_tokens_list = [83]
        num_heads_list = [8]
        head_size_list = [64, 512]
        num_blocks_list = [3600]
        block_size_list = [8]
        num_layers_list = [1, 6]
        num_mappings_list = [128, 600]
        seeds_list = [0]
        only_key_cache_list = [True, False]
        args = product(num_tokens_list, num_heads_list, head_size_list, num_blocks_list, block_size_list, dtype_list,
                       num_layers_list, num_mappings_list, seeds_list)
        for num_tokens, num_heads, head_size, num_blocks, block_size, dtype, num_layers, num_mappings, seed in args:
            print("num_tokens: {}, num_heads: {}, head_size: {}, num_blocks: {}, block_size: {}, dtype: {}, num_layers: {}, \
num_mappings: {}, seed: {}, testing...".format(
                num_tokens, num_heads, head_size, num_blocks, block_size, dtype, num_layers, num_mappings, seed), flush=True)
            block_mapping = self.create_block_mapping(num_blocks, num_mappings, seed)
            only_key_cache = random.choice(only_key_cache_list)
            # Create the KV caches.
            key_caches, value_caches = self.create_kv_caches(num_blocks, block_size,
                                                             num_layers, num_heads,
                                                             head_size, dtype, seed)
            # Clone the KV caches for the CPU reference.
            cloned_key_caches = [key_cache.cpu().clone() for key_cache in key_caches]
            cloned_value_caches = [value_cache.cpu().clone() for value_cache in value_caches]
            # Call the copy blocks kernel; optionally only on key caches.
            if only_key_cache:
                value_caches = None
                cloned_value_caches = None
            ops.copy_blocks(key_caches, value_caches, block_mapping)
            self.op_impl_base(cloned_key_caches, cloned_value_caches, block_mapping)
            # Compare the results (exact match expected: pure copies).
            for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
                self.assertTensorsEqual(key_cache.cpu().float(), cloned_key_cache.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True)
            if not only_key_cache:
                for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
                    self.assertTensorsEqual(value_cache.cpu().float(), cloned_value_cache.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True)

    # Fixed: skip reason previously said "test_inductor" for this test.
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        # Each call below must raise with the exact message under test.
        print("test_copy_block: test_prevent...")
        num_blocks, block_size, head_num, head_size, num_mappings = 384, 6, 32, 128, 128
        k_cache = torch.randn(num_blocks * head_num, block_size, head_size, dtype=torch.half, device="mlu")
        v_cache = torch.randn(num_blocks * head_num, head_size, block_size + 1, dtype=torch.float, device="mlu")
        key_caches = [k_cache,]
        value_caches = None
        block_mapping = self.create_block_mapping(num_blocks, num_mappings)
        # 3-D key cache must be rejected.
        self.assertException("every layer k_cache must be 4d.", ops.copy_blocks,
                             key_caches, value_caches, block_mapping)
        k_cache = k_cache.reshape(num_blocks, head_num, block_size, head_size)
        # Layer-count mismatch between key and value caches.
        key_caches = [k_cache, k_cache,]
        value_caches = [v_cache]
        self.assertException("k_caches size must equal to v_caches size if v_caches is not none.",
                             ops.copy_blocks, key_caches, value_caches, block_mapping)
        # dtype mismatch (half vs float).
        key_caches = [k_cache]
        value_caches = [v_cache]
        self.assertException("the data type of k_caches and v_caches are not the same.",
                             ops.copy_blocks, key_caches, value_caches, block_mapping)
        # Rank mismatch (4-D key vs 3-D value).
        value_caches[0] = value_caches[0].to(torch.half)
        self.assertException("every layer k_cache dim must equal to v_cache dim.",
                             ops.copy_blocks, key_caches, value_caches, block_mapping)
        # block_size mismatch (block_size vs block_size + 1).
        value_caches[0] = value_caches[0].reshape(num_blocks, head_num, head_size, block_size + 1)
        self.assertException("the block_size of k_caches and v_caches are not the same.",
                             ops.copy_blocks, key_caches, value_caches, block_mapping)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_inductor due to ASan issues")
    def test_inductor(self):
        # Opcheck the registered op with one small configuration.
        num_heads, head_size, num_blocks, block_size, num_layers, num_mappings = 8, 64, 384, 8, 1, 128
        key_caches, value_caches = self.create_kv_caches(num_blocks, block_size,
                                                         num_layers, num_heads,
                                                         head_size, torch.float16, 0)
        block_mapping = self.create_block_mapping(num_blocks, num_mappings)
        self.base_opcheck(torch.ops.torch_mlu_ops.copy_blocks, (key_caches, value_caches, block_mapping))
# Entry point: run the TestCopyBlocksOp suite and return its exit status.
if __name__ == '__main__':
    exit(run_unittest(TestCopyBlocksOp))

View File

@@ -0,0 +1,280 @@
import random
import unittest
import numpy as np
import math
import torch
import torch_mlu_ops as ops
from common_utils import *
def gen_args(max_batch_size,
             batch_size,
             max_context_len,
             head_num_q,
             head_num_kv,
             cache_mem_len,
             head_size,
             group_size,
             use_seq_offset,
             dtype,
             quant_mode,
             quant_bit,
             has_value = True):
    # Build inputs for dequant_from_linear_cache tests: packed QKV context
    # views, quantized key/value caches (int8, or int4 packed two-per-byte),
    # quantization scales, and per-batch cache indices/offsets.
    # Preprocess arguments
    assert max_batch_size >= batch_size, \
        "max_batch_size should greater than or equal to batch_size."
    assert cache_mem_len >= max_context_len, \
        "cache_mem_len should greater then or equal to max_context_len."
    assert head_size % group_size == 0, \
        "head_size should be a multiply of groupwise."
    total_heads = head_num_q + head_num_kv * 2
    max_seq_offset = cache_mem_len - max_context_len
    # Generates key and cache from context
    context_lens = torch.randint(size=[batch_size], low=1, high=max_context_len + 1,
                                 dtype=torch.int32, device="mlu")
    if use_seq_offset:
        context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset,
                                         dtype=torch.int32, device="mlu")
    else:
        context_paddings = torch.zeros_like(context_lens)
    cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1)
    total_seqlen = cu_context_lens[-1]
    context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device="mlu")
    context_seq_offset[1:] = cu_context_lens[:-1]
    # key/value are strided views into the fused QKV context buffer.
    context = torch.randn([total_seqlen, total_heads, head_size], dtype=dtype, device="mlu")
    key = context[..., head_num_q:head_num_q + head_num_kv, :]
    value = context[..., head_num_q + head_num_kv:head_num_q + 2 * head_num_kv, :]
    # Generates key_cache and value_cache
    # Distinct cache-slot ids sampled from [0, batch_size].
    cache_bs_id = torch.IntTensor(random.sample([*range(0, batch_size + 1)], batch_size)).mlu()
    # low=-1 lets a batch get a negative offset; the reference impl treats
    # negative offsets as "skip this batch".
    cache_seq_offset = torch.randint(low=-1, high=max_seq_offset, size=[batch_size],
                                     dtype=torch.int32, device="mlu")
    if quant_bit == 4:
        # int4: keys pack two values per byte along head_size, values pack
        # two values per byte along the sequence dim.
        key_cache = torch.randint(low=-128, high=127, dtype=torch.int32, size=(max_batch_size,
            head_num_kv, cache_mem_len, head_size // 2), device="mlu")
        value_cache = torch.randint(low=-128, high=127, dtype=torch.int32, size=(max_batch_size,
            head_num_kv, cache_mem_len // 2, head_size), device="mlu")
        key_cache, value_cache = key_cache.to(torch.int8), value_cache.to(torch.int8)
    else:
        cache = torch.randint(size=(2, max_batch_size, head_num_kv, cache_mem_len, head_size),
                              low=-128, high=127, dtype=torch.int32, device="mlu")
        cache = cache.to(torch.int8)
        key_cache, value_cache = cache[[0, 1]]
    # Generates key_cache_scale and value_cache_scale
    if quant_mode == 0: # quant_mode == 0 is per channel
        cache_scale = torch.randn((2, head_num_kv, head_size), dtype=torch.float, device="mlu")
    else: # quant_mode != 1 (== 1 for extend) is per head
        cache_scale = torch.randn((2, max_batch_size, head_num_kv, cache_mem_len),
                                  dtype=torch.float, device="mlu")
    key_cache_scale, value_cache_scale = cache_scale[[0, 1]]
    # Prepare arguments
    if has_value == False:
        value = None
        value_cache = None
        value_cache_scale = None
    args = [key, value, key_cache, value_cache, key_cache_scale, value_cache_scale]
    args += [context_lens, max_context_len, context_seq_offset if use_seq_offset else None,
             cache_bs_id, cache_seq_offset]
    args += [quant_mode, quant_bit]
    return args
class TestDequantFromLinearCache(BtTestCase):
    # Tests ops.dequant_from_linear_cache: dequantize int8/int4 KV caches
    # back into the (strided) key/value views, in place.
    def run_gen_case(self, dic):
        # Replay a pre-generated case dict; with dump_data the dict values
        # are already live tensors, otherwise they are serialized specs.
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            key = create_tensor_from_dic(dic['key'])
            value = create_tensor_from_dic(dic['value'])
            key_cache = create_tensor_from_dic(dic['key_cache'])
            value_cache = create_tensor_from_dic(dic['value_cache'])
            key_cache_quant_scale = create_tensor_from_dic(dic['key_cache_quant_scale'])
            value_cache_quant_scale = create_tensor_from_dic(dic['value_cache_quant_scale'])
            context_lengths = dic['context_lengths']['data']
            max_context_len = dic['max_context_len']['data']
            context_seq_offset = dic['context_seq_offset']['data']
            cache_bs_id = dic['cache_bs_id']['data']
            cache_seq_offset = dic['cache_seq_offset']['data']
            quant_mode = dic['quant_mode']['data']
            quant_bit = dic['quant_bit']['data']
            self.launch(key, value, key_cache, value_cache, key_cache_quant_scale,
                        value_cache_quant_scale, context_lengths, max_context_len,
                        context_seq_offset, cache_bs_id, cache_seq_offset, quant_mode, quant_bit)
    def launch(self, *args):
        # Run the MLU op first (it writes key/value in place), snapshot its
        # output via clone(), then overwrite key/value with the python
        # reference and compare. Order matters here.
        key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, \
            context_lengths, max_context_len, context_seq_offset, cache_bs_id, \
            cache_seq_offset, quant_mode, quant_bit = args
        if value is None or value_cache is None or value_cache_scale is None:
            has_value = False
        else:
            has_value = True
        ops.dequant_from_linear_cache(key, value, key_cache, value_cache, key_cache_scale,
                                      value_cache_scale, context_lengths, max_context_len,
                                      context_seq_offset, cache_bs_id, cache_seq_offset,
                                      quant_mode, quant_bit)
        key_clone = key.clone()
        if has_value:
            value_clone = value.clone()
        self.op_impl_base(key, value, key_cache, value_cache, key_cache_scale,
                          value_cache_scale, context_lengths, max_context_len,
                          context_seq_offset, cache_bs_id, cache_seq_offset, quant_mode,
                          quant_bit)
        self.assertTensorsEqual(key_clone.cpu().float(), key.cpu().float(), 0.001, use_MSE=True)
        if has_value:
            self.assertTensorsEqual(value_clone.cpu().float(), value.cpu().float(), 0.001,
                                    use_MSE=True)
    def op_impl_base(self, *args):
        # Python reference: dequantize each batch's cache window and write
        # the result into the key/value views IN PLACE.
        def dequant_from_cache(quant_data: torch.Tensor,
                               scale_data: torch.Tensor,
                               quant_mode: int):
            # Broadcast scale over seq dim (per-channel) or over head_size
            # (per head/token) and multiply in float32.
            quant_data_fp32 = quant_data.clone().to(torch.float)
            scale_data_fp32 = scale_data.clone().to(torch.float)
            if quant_mode == 0: # per channel
                scale_data_fp32 = scale_data[..., None, :]
            else: # per head/token
                scale_data_fp32 = scale_data[..., None]
            dequant_data_fp32 = quant_data_fp32 * scale_data_fp32
            return dequant_data_fp32
        key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, \
            context_lengths, max_context_len, context_seq_offset, cache_bs_id, \
            cache_seqlen_offset, quant_mode, quant_bit = args
        batch_size = context_lengths.size(0)
        if context_seq_offset is None:
            # Derive per-batch starting offsets from the cumulative lengths.
            cu_seq_offset = torch.cumsum(context_lengths, dim=-1)
            context_seq_offset = torch.zeros_like(cu_seq_offset)
            context_seq_offset[1:] = cu_seq_offset[:-1]
        total = 0
        cache_mem_len = key_cache.size(2)
        for i in range(batch_size):
            context_len = context_lengths[i].item()
            seq_begin = context_seq_offset[i].item()
            seq_end = seq_begin + context_len
            total += context_len
            cache_id = cache_bs_id[i] if cache_bs_id is not None else i
            cache_seq_begin = cache_seqlen_offset[i] if cache_seqlen_offset is not None else 0
            cache_seq_end = cache_seq_begin + context_len
            key_i = key[seq_begin:seq_end].transpose(1, 0)
            if quant_bit == 4:
                # int4: two values per byte along head_size. `<< 4 >> 4`
                # sign-extends the low nibble; `>> 4` (arithmetic shift on
                # int8) yields the high nibble.
                key_cache_i_temp = key_cache[cache_id, :, cache_seq_begin:cache_seq_end]
                cache_size = list(key_cache_i_temp.size())
                cache_size[-1] *= 2
                key_cache_i = torch.zeros(cache_size, dtype=torch.int8, device="mlu")
                key_cache_i[...,::2] = key_cache_i_temp << 4 >> 4
                key_cache_i[...,1::2] = key_cache_i_temp >> 4
            else:
                key_cache_i = key_cache[cache_id, :, cache_seq_begin:cache_seq_end]
            # We use negative cache_seq_offset to skip unused batch
            if cache_seq_begin < 0 or cache_seq_end > cache_mem_len:
                continue
            # dequant key from cache
            if quant_mode == 0:
                key_cache_scale_i = key_cache_scale
            else:
                key_cache_scale_i = key_cache_scale[cache_id, :, cache_seq_begin:cache_seq_end]
            dequant_key_i = dequant_from_cache(key_cache_i, key_cache_scale_i, quant_mode)
            key_i[...] = dequant_key_i.to(key_i.dtype)
            # dequant value from cache
            if not (value_cache is None or value is None or value_cache_scale is None):
                value_i = value[seq_begin:seq_end].transpose(1, 0)
                if quant_bit == 4:
                    # int4 values pack along the sequence dim; handle odd
                    # begin/end offsets by padding and trimming one row.
                    pad_front = cache_seq_begin % 2
                    pad_back = cache_seq_end % 2
                    cache_seq_begin_temp = int(cache_seq_begin // 2)
                    cache_seq_end_temp = int(math.ceil(cache_seq_end / 2.0))
                    value_cache_i_temp = value_cache[cache_id, :,
                                                     cache_seq_begin_temp:cache_seq_end_temp]
                    cache_size = list(value_cache_i_temp.size())
                    cache_size[-2] *= 2
                    value_cache_i = torch.zeros(cache_size, dtype=torch.int8, device="mlu")
                    value_cache_i[...,::2,:] = value_cache_i_temp << 4 >> 4
                    value_cache_i[...,1::2,:] = value_cache_i_temp >> 4
                    if pad_front:
                        value_cache_i = value_cache_i[...,1:,:]
                    if pad_back:
                        value_cache_i = value_cache_i[...,:-1,:]
                else:
                    value_cache_i = value_cache[cache_id, :, cache_seq_begin:cache_seq_end]
                if quant_mode == 0:
                    value_cache_scale_i = value_cache_scale
                else:
                    value_cache_scale_i = value_cache_scale[cache_id, :,
                                                            cache_seq_begin:cache_seq_end]
                dequant_value_i = dequant_from_cache(value_cache_i, value_cache_scale_i,
                                                     quant_mode)
                value_i[...] = dequant_value_i.to(value.dtype)
    def test_dequant_from_linear_cache(self):
        # Randomized sweep: 100 cases with random shapes, quant modes/bits,
        # offsets, and dtypes (bf16 excluded on MLU3).
        test_cases = 100
        head_size_times = 16
        mlu_name = torch.mlu.get_device_name()
        max_batch_size_list = torch.randint(low=32, high=64, size=[test_cases],
                                            dtype=torch.int32)
        batch_size_list = torch.randint(low=1, high=32, size=[test_cases], dtype=torch.int32)
        max_context_len_list = torch.randint(low=2, high=2048, size=[test_cases],
                                             dtype=torch.int32)
        head_num_q_list = torch.randint(low=1, high=64, size=[test_cases], dtype=torch.int32)
        head_num_kv_list = torch.randint(low=1, high=64, size=[test_cases], dtype=torch.int32)
        head_size_list = torch.randint(low=1, high=16, size=[test_cases],
                                       dtype=torch.int32) * head_size_times
        cache_mem_len_list = torch.randint(low=1024, high=2048, size=[test_cases],
                                           dtype=torch.int32) * 2
        quant_mode_list = np.random.choice([0, 1], test_cases)
        quant_bit_list = np.random.choice([4, 8], test_cases)
        use_offset_list = np.random.choice([False, True], test_cases)
        has_value_list = np.random.choice([False, True], test_cases)
        dtype_list = [torch.half, torch.bfloat16]
        dtype_list = dtype_list[:-1] if "MLU3" in mlu_name else dtype_list
        dtype_list = np.random.choice(dtype_list, test_cases)
        for i in range(test_cases):
            max_batch_size = max_batch_size_list[i].item()
            batch_size = batch_size_list[i].item()
            head_num_q = head_num_q_list[i].item()
            max_context_len = max_context_len_list[i].item()
            head_num_kv = head_num_kv_list[i].item()
            head_size = head_size_list[i].item()
            cache_mem_len = cache_mem_len_list[i].item()
            quant_mode = quant_mode_list[i]
            quant_bit = quant_bit_list[i]
            use_seq_offset = use_offset_list[i]
            has_value = has_value_list[i]
            dtype = dtype_list[i]
            # Skip cases whose cache would exceed the int32 element limit.
            if "MLU3" in mlu_name and (2 * cache_mem_len * max_batch_size * head_num_kv \
                    * head_size >= 2**31 - 1):
                print("large tensor is not support on {}, skip".format(mlu_name))
                continue
            print("batch_size={}, head_num={}, head_size={}, max_context_len={}, quant_mode={}, "
                  "quant_bit={}, dtype={} testing...".format(batch_size, head_num_kv, head_size,
                  max_context_len, quant_mode, quant_bit, dtype))
            torch.manual_seed(2766)
            # group_size = head_size here, i.e. one group per head.
            args = gen_args(max_batch_size, batch_size, max_context_len, head_num_q, head_num_kv,
                            cache_mem_len, head_size, head_size, use_seq_offset, dtype, quant_mode,
                            quant_bit, has_value)
            self.launch(*args)
    def test_inductor(self):
        # Opcheck the registered op with and without the seq-offset path.
        max_batch_size, batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len, \
            head_size, dtype = 16, 8, 1024, 16, 32, 2048, 128, torch.float16
        quant_mode, quant_bit = 0, 8
        args = gen_args(max_batch_size, batch_size, max_context_len, head_num_q, head_num_kv,
                        cache_mem_len, head_size, head_size, 1, dtype, quant_mode, quant_bit)
        self.base_opcheck(torch.ops.torch_mlu_ops.dequant_from_linear_cache, args)
        args = gen_args(max_batch_size, batch_size, max_context_len, head_num_q, head_num_kv,
                        cache_mem_len, head_size, head_size, 0, dtype, quant_mode, quant_bit)
        self.base_opcheck(torch.ops.torch_mlu_ops.dequant_from_linear_cache, args)
# Entry point: run the TestDequantFromLinearCache suite, return exit status.
if __name__ == '__main__':
    exit(run_unittest(TestDequantFromLinearCache))

View File

@@ -0,0 +1,286 @@
import random
import unittest
import numpy as np
import math
import torch
import torch_mlu_ops as ops
from common_utils import *
def gen_args(batch_size,
             max_context_len,
             head_num_q,
             head_num_kv,
             cache_mem_len,
             head_size,
             group_size,
             block_size,
             use_seq_offset,
             dtype,
             quant_mode,
             quant_bit,
             has_value = True):
    """Build inputs for dequant_from_paged_cache tests: fused QKV views,
    paged int8 caches with a random block table, scales, and offsets."""
    assert cache_mem_len >= max_context_len, \
        "cache_mem_len should greater then or equal to max_context_len."
    assert head_size % group_size == 0, \
        "head_size should be a multiply of groupwise."
    total_heads = head_num_q + 2 * head_num_kv
    max_seq_offset = cache_mem_len - max_context_len
    blocks_per_seq = int(math.ceil(max_context_len / block_size))
    total_blocks = int(math.ceil(cache_mem_len / block_size)) * batch_size
    # Distinct physical block ids for every (batch, logical block) slot.
    table_ids = random.sample(range(0, total_blocks), batch_size * blocks_per_seq)
    block_tables = torch.tensor(table_ids, dtype=torch.int32).mlu().view(
        batch_size, blocks_per_seq)
    # Per-batch context lengths, optionally padded by a random offset.
    context_lens = torch.randint(size=[batch_size], low=1, high=max_context_len + 1,
                                 dtype=torch.int32, device="mlu")
    if use_seq_offset:
        context_paddings = torch.randint(size=[batch_size], low=0, high=max_seq_offset,
                                         dtype=torch.int32, device="mlu")
    else:
        context_paddings = torch.zeros_like(context_lens)
    cu_context_lens = torch.cumsum(context_lens + context_paddings, dim=-1)
    total_seqlen = cu_context_lens[-1]
    context_seq_offset = torch.zeros([batch_size], dtype=torch.int32, device="mlu")
    context_seq_offset[1:] = cu_context_lens[:-1]
    # key/value are strided views into one fused QKV context buffer.
    context = torch.randn([total_seqlen, total_heads, head_size], dtype=dtype,
                          device="mlu")
    kv_begin = head_num_q
    key = context[..., kv_begin:kv_begin + head_num_kv, :]
    value = context[..., kv_begin + head_num_kv:kv_begin + 2 * head_num_kv, :]
    # Paged int8 caches, one shared randint draw split into key/value.
    cache = torch.randint(size=(2, total_blocks, head_num_kv, block_size, head_size),
                          low=-128, high=127, dtype=torch.int32,
                          device="mlu").to(torch.int8)
    key_cache, value_cache = cache[[0, 1]]
    if quant_mode == 0:
        # Per-channel scales.
        scale_shape = (2, head_num_kv, head_size)
    else:
        # Per-head/token scales (quant_mode == 1 reserved for extension).
        scale_shape = (2, total_blocks, head_num_kv, block_size)
    cache_scale = torch.randn(scale_shape, dtype=torch.float, device="mlu")
    key_cache_scale, value_cache_scale = cache_scale[[0, 1]]
    if not has_value:
        value = None
        value_cache = None
        value_cache_scale = None
    return [key, value, key_cache, value_cache, key_cache_scale, value_cache_scale,
            context_lens, max_context_len,
            context_seq_offset if use_seq_offset else None,
            block_tables, quant_mode, quant_bit]
class TestDequantFromPagedCache(BtTestCase):
    """Unit tests for ops.dequant_from_paged_cache.

    The op reads int8-quantized key/value blocks out of a paged cache,
    dequantizes them with either per-channel (quant_mode == 0) or
    per-head/token (quant_mode == 1) scales, and writes the result in-place
    into the packed key/value tensors.  `op_impl_base` is a pure-PyTorch
    reference that overwrites key/value the same way, so `launch` snapshots
    the op's output and compares it against the reference.
    """

    def run_gen_case(self, dic):
        # Replay a dumped test case: when dump_data is set, the dict values are
        # already the launch arguments; otherwise each tensor is rebuilt from
        # its dumped per-tensor dict.
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            key = create_tensor_from_dic(dic['key'])
            value = create_tensor_from_dic(dic['value'])
            key_cache = create_tensor_from_dic(dic['key_cache'])
            value_cache = create_tensor_from_dic(dic['value_cache'])
            key_cache_quant_scale = create_tensor_from_dic(dic['key_cache_quant_scale'])
            value_cache_quant_scale = create_tensor_from_dic(dic['value_cache_quant_scale'])
            context_lengths = dic['context_lengths']['data']
            max_context_len = dic['max_context_len']['data']
            context_seq_offset = dic['context_seq_offset']['data']
            block_tables = dic['block_tables']['data']
            quant_mode = dic['quant_mode']['data']
            quant_bit = dic['quant_bit']['data']
            self.launch(key, value, key_cache, value_cache, key_cache_quant_scale,
                        value_cache_quant_scale, context_lengths, max_context_len,
                        context_seq_offset, block_tables, quant_mode, quant_bit)

    def launch(self, *args):
        """Run the MLU op, then the reference impl, and compare in-place results."""
        key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, \
            context_lengths, max_context_len, context_seq_offset, block_tables, \
            quant_mode, quant_bit = args
        # The value side is optional: it is active only when all three value
        # tensors are provided.
        if value is None or value_cache is None or value_cache_scale is None:
            has_value = False
        else:
            has_value = True
        ops.dequant_from_paged_cache(key, value, key_cache, value_cache, key_cache_scale,
                                     value_cache_scale, context_lengths, max_context_len,
                                     context_seq_offset, block_tables, quant_mode, quant_bit)
        # Snapshot the op's in-place outputs before the reference overwrites them.
        key_clone = key.clone()
        if has_value:
            value_clone = value.clone()
        self.op_impl_base(key, value, key_cache, value_cache, key_cache_scale,
                          value_cache_scale, context_lengths, max_context_len,
                          context_seq_offset, block_tables, quant_mode, quant_bit)
        self.assertTensorsEqual(key_clone.cpu().float(), key.cpu().float(), 0.001, use_MSE=True)
        if has_value:
            self.assertTensorsEqual(value_clone.cpu().float(), value.cpu().float(), 0.001,
                                    use_MSE=True)

    def op_impl_base(self, *args):
        """Pure-PyTorch reference: dequantize cache blocks into key/value in-place."""
        def dequant_from_cache(quant_data: torch.Tensor,
                               scale_data: torch.Tensor,
                               quant_mode: int):
            # Broadcast the scale against the quantized data and multiply in fp32.
            quant_data_fp32 = quant_data.clone().to(torch.float)
            scale_data_fp32 = scale_data.clone().to(torch.float)
            if quant_mode == 0:  # per channel [head_num, 1, head_size]
                scale_data_fp32 = scale_data[..., None, :]
            else:  # per head/token [head_num, context_len, 1]
                scale_data_fp32 = scale_data[..., None]
            dequant_data_fp32 = quant_data_fp32 * scale_data_fp32
            return dequant_data_fp32
        key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, \
            context_lengths, max_context_len, context_seq_offset, block_tables, \
            quant_mode, quant_bit = args
        batch_size = context_lengths.size(0)
        if context_seq_offset is None:
            # Default offsets: exclusive prefix sum of the context lengths.
            cu_seq_offset = torch.cumsum(context_lengths, dim=-1)
            context_seq_offset = torch.zeros_like(cu_seq_offset)
            context_seq_offset[1:] = cu_seq_offset[:-1]
        total_seqlen = 0
        block_size = key_cache.size(2)
        for i in range(batch_size):
            context_len = context_lengths[i].item()
            seq_begin = context_seq_offset[i].item()
            seq_end = seq_begin + context_len
            total_seqlen += context_len
            # Split the context into whole cache blocks plus an optional
            # partial trailing block of rem_token_num tokens.
            full_block_num = context_len // block_size
            rem_token_num = context_len % block_size
            # key_i is a view into the packed output; writing it writes `key`.
            key_i = key[seq_begin:seq_end].transpose(1, 0)
            # [head_num, seq_num, head_size]
            key_cache_i = torch.concat(
                [key_cache[block_tables[i, j], ...] for j in range(full_block_num)] +
                ([key_cache[block_tables[i, full_block_num], :, :rem_token_num, :]] \
                 if rem_token_num > 0 else []), dim=-2
            )
            # dequant key from cache
            if quant_mode == 0:
                # [head_num, head_size]
                key_cache_scale_i = key_cache_scale
            else:
                # [head_num, seq_num]
                key_cache_scale_i = torch.concat(
                    [key_cache_scale[block_tables[i, j],...] for j in range(full_block_num)] +
                    ([key_cache_scale[block_tables[i, full_block_num], :, :rem_token_num]] \
                     if rem_token_num > 0 else []), dim=-1
                )
            dequant_key_i = dequant_from_cache(key_cache_i, key_cache_scale_i, quant_mode)
            key_i[...] = dequant_key_i.to(key_i.dtype)
            # dequant value from cache (only when the value side is present)
            if not (value_cache is None or value is None or value_cache_scale is None):
                value_i = value[seq_begin:seq_end].transpose(1, 0)
                # [head_num, seq_num, head_size]
                value_cache_i = torch.concat(
                    [value_cache[block_tables[i, j], ...] for j in range(full_block_num)] +
                    ([value_cache[block_tables[i, full_block_num], :, :rem_token_num, :]] \
                     if rem_token_num > 0 else []), dim=-2
                )
                # dequant value from cache
                if quant_mode == 0:
                    # [head_num, head_size]
                    value_cache_scale_i = value_cache_scale
                else:
                    # [head_num, seq_num]
                    value_cache_scale_i = torch.concat(
                        [value_cache_scale[block_tables[i, j],...] for j in range(full_block_num)] +
                        ([value_cache_scale[block_tables[i, full_block_num], :, :rem_token_num]] \
                         if rem_token_num > 0 else []), dim=-1
                    )
                dequant_value_i = dequant_from_cache(value_cache_i, value_cache_scale_i, quant_mode)
                value_i[...] = dequant_value_i.to(value_i.dtype)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ or "MLU3" in torch.mlu.get_device_name(),
                     "Skipping test_prevent due to ASan issues or in MLU300 series.")
    def test_prevent(self):
        """Each invalid argument combination must raise with the expected message."""
        batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len, \
            head_size, block_size = 8, 1024, 16, 32, 2048, 128, 16
        dtype, quant_mode, quant_bit = torch.float16, 0, 8
        default_args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len,
                                head_size, head_size, block_size, 1, dtype, quant_mode, quant_bit)
        # Unsupported key dtype (float32).
        args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len,
                        head_size, head_size, block_size, 1, torch.float32, quant_mode, quant_bit)
        self.assertException("Tensor key type should be half or bfloat16.",
                             torch.ops.torch_mlu_ops.dequant_from_paged_cache, *args)
        # Unsupported quant_mode (only 0 and 1 are valid).
        args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len,
                        head_size, head_size, block_size, 1, dtype, 2, quant_bit)
        self.assertException("quantization mode support 0 and 1.",
                             torch.ops.torch_mlu_ops.dequant_from_paged_cache, *args)
        # Unsupported quant_bit (only 8-bit is valid).
        args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len,
                        head_size, head_size, block_size, 1, dtype, quant_mode, 4)
        self.assertException("quantization bit width only supports 8.",
                             torch.ops.torch_mlu_ops.dequant_from_paged_cache, *args)
        # args[-5] is max_context_len: make it exceed the cache capacity.
        default_args[-5] = 10240
        self.assertException("max_context_len should smaller than or equal to block_size * max_block_num.",
                             torch.ops.torch_mlu_ops.dequant_from_paged_cache, *default_args)
        default_args[-5] = max_context_len
        # args[1] is value: break its innermost-dim contiguity.
        default_args[1] = default_args[1].transpose(-1, -2)
        self.assertException("Tensor value last dim must be contiguous.",
                             torch.ops.torch_mlu_ops.dequant_from_paged_cache, *default_args)

    @unittest.skipIf("MLU3" in torch.mlu.get_device_name(),
                     "Skipping test_dequant_from_paged_cache in MLU300 series")
    def test_dequant_from_paged_cache(self):
        """Randomized sweep over shapes, quant modes, dtypes and optional inputs."""
        test_cases = 100
        head_size_times = 16  # head sizes are drawn as multiples of 16
        mlu_name = torch.mlu.get_device_name()  # NOTE(review): unused
        batch_size_list = torch.randint(low=1, high=32, size=[test_cases], dtype=torch.int32)
        max_context_len_list = torch.randint(low=2, high=1024, size=[test_cases],
                                             dtype=torch.int32)
        head_num_q_list = torch.randint(low=1, high=64, size=[test_cases], dtype=torch.int32)
        head_num_kv_list = torch.randint(low=1, high=8, size=[test_cases], dtype=torch.int32)
        head_size_list = torch.randint(low=1, high=16, size=[test_cases],
                                       dtype=torch.int32) * head_size_times
        block_size_list = torch.randint(low=1, high=32, size=[test_cases], dtype=torch.int32)
        cache_mem_len_list = torch.randint(low=512, high=1024, size=[test_cases],
                                           dtype=torch.int32) * 2
        quant_mode_list = np.random.choice([0, 1], test_cases)
        quant_bit_list = np.random.choice([8], test_cases)
        use_offset_list = np.random.choice([False, True], test_cases)
        has_value_list = np.random.choice([False, True], test_cases)
        dtype_list = [torch.half, torch.bfloat16]
        dtype_list = np.random.choice(dtype_list, test_cases)
        for i in range(test_cases):
            batch_size = batch_size_list[i].item()
            head_num_q = head_num_q_list[i].item()
            max_context_len = max_context_len_list[i].item()
            head_num_kv = head_num_kv_list[i].item()
            head_size = head_size_list[i].item()
            block_size = block_size_list[i].item()
            cache_mem_len = cache_mem_len_list[i].item()
            quant_mode = quant_mode_list[i]
            quant_bit = quant_bit_list[i]
            use_seq_offset = use_offset_list[i]
            has_value = has_value_list[i]
            dtype = dtype_list[i]
            print("batch_size={}, head_num={}, head_size={}, max_context_len={}, quant_mode={}, "
                  "quant_bit={}, dtype={} testing...".format(batch_size, head_num_kv, head_size,
                  max_context_len, quant_mode, quant_bit, dtype))
            # Fixed seed so the tensors generated for each case are reproducible.
            torch.manual_seed(2766)
            args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len,
                            head_size, head_size, block_size, use_seq_offset, dtype, quant_mode,
                            quant_bit, has_value)
            self.launch(*args)

    @unittest.skipIf("MLU3" in torch.mlu.get_device_name(),
                     "Skipping test_dequant_from_paged_cache in MLU300 series")
    def test_inductor(self):
        """Opcheck validation (fake tensor / inductor) with and without seq offsets."""
        batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len, \
            head_size, block_size = 8, 1024, 16, 32, 2048, 128, 16
        dtype, quant_mode, quant_bit = torch.float16, 0, 8
        args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len,
                        head_size, head_size, block_size, 1, dtype, quant_mode, quant_bit)
        self.base_opcheck(torch.ops.torch_mlu_ops.dequant_from_paged_cache, args)
        args = gen_args(batch_size, max_context_len, head_num_q, head_num_kv, cache_mem_len,
                        head_size, head_size, block_size, 0, dtype, quant_mode, quant_bit)
        self.base_opcheck(torch.ops.torch_mlu_ops.dequant_from_paged_cache, args)
# Script entry point: run the suite and propagate the unittest exit status.
if __name__ == '__main__':
    exit(run_unittest(TestDequantFromPagedCache))

View File

@@ -0,0 +1,299 @@
import torch
import torch_mlu
from itertools import product
import torch_mlu_ops as ops
import random
import torch.testing._internal.optests as optests
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.testing._internal.common_utils import (
run_tests,
TestCase,
)
# Map torch dtypes to the string names understood by the custom ops.
dtype_dict = dict(
    zip(
        (torch.half, torch.bfloat16, torch.float),
        ("half", "bfloat16", "float"),
    )
)

# Shared fake-tensor mode: shape/dtype propagation only, no kernel fallback.
fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
class FakeTensorTest(TestCase):
    """FakeTensorMode smoke tests for torch_mlu_ops custom-op meta kernels.

    Every test body runs under `fake_tensor_mode`, so no real kernel executes:
    each test only checks that the op's shape/dtype inference produces outputs
    of the expected shape across a sweep of configurations.
    """

    def test_matmul(self):
        """matmul meta function: fp16/fp32 path and int8-quantized path."""
        with fake_tensor_mode:
            mat_m, mat_n, mat_k, alpha, beta, act_mode, fast_act, approximate = 32, 256, 128, 0.8, 0.3, 'silu', True, True
            trans_a_list = [True, False]
            trans_b_list = [True, False]
            dtype_list = [torch.half, torch.float]
            args = product(trans_a_list, trans_b_list, dtype_list)
            for trans_a, trans_b, dtype in args:
                print(f"matmul... trans_a: {trans_a}, trans_b: {trans_b}, dtype: {dtype}")
                # Swap the stored shape when an operand is flagged as transposed.
                shape_a, shape_b = (mat_m, mat_k), (mat_k, mat_n)
                if trans_a:
                    shape_a = (mat_k, mat_m)
                if trans_b:
                    shape_b = (mat_n, mat_k)
                a = torch.randn(shape_a, dtype=dtype, device='mlu')
                b = torch.randn(shape_b, dtype=dtype, device='mlu')
                bias = torch.randn((mat_n), dtype=dtype, device='mlu')
                c = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')
                output = torch.ops.torch_mlu_ops.matmul(a, b, None, bias, c, None, act_mode, alpha, beta, fast_act, approximate, 1.0, 1.0, trans_a, trans_b)
                self.assertEqual(output.shape, (mat_m, mat_n))
                # int8 path: quantized operands with explicit scales plus an
                # output dtype string.
                a8 = torch.randint(-128, 127, shape_a).to(torch.int8).mlu()
                b8 = torch.randint(-128, 127, shape_b).to(torch.int8).mlu()
                a_scale = 280.8
                b_scale = 190.8
                str_dtype = dtype_dict[dtype]
                output = torch.ops.torch_mlu_ops.matmul(a8, b8, None, bias, c, str_dtype, act_mode, alpha, beta, fast_act, approximate, a_scale, b_scale, trans_a, trans_b)
                self.assertEqual(output.shape, (mat_m, mat_n))

    def test_weight_only_quant_matmul(self):
        """quant_matmul meta function with the weight_only quantization algorithm."""
        with fake_tensor_mode:
            M, K, N, group_num = 2, 256, 32, 4
            quant_bit_size, act_mode, use_hp_active, act_coef, alpha, beta = 8, 'none', True, 1., 0.8, 0.3
            group_quant_list = [True, False]
            trans_a_list = [True, False]
            trans_b_list = [True, False]
            dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
            args = product(group_quant_list, dtype_list, trans_a_list, trans_b_list)
            for group_quant, dtype, trans_a, trans_b in args:
                print(f"weight_only_quant_matmul... group_quant: {group_quant}, dtype: {dtype}, trans_a: {trans_a}, trans_b: {trans_b}")
                a = torch.randn((M, K), dtype=dtype).mlu()
                b = torch.randint(-128, 127, (N, K), dtype=torch.int8).mlu()
                c = torch.randn(M, N, device="mlu", dtype=dtype)
                bias = torch.randn(N, device="mlu", dtype=dtype)
                group_wise_scale = torch.randn((N, group_num), device="mlu", dtype=dtype)
                # Group-wise quantization carries a per-group weight scale;
                # per-channel uses a gemm output scale instead.
                b_quant_layout = "quantize_group_wise" if group_quant else "quantize_per_channel"
                b_scale = group_wise_scale if group_quant else None
                gemm_output_scale = None if group_quant else torch.randn(N, device="mlu", dtype=torch.float)
                a_scale, a_zero, b_zero, c_zero = None, None, None, None
                c_scale, gemm_output_zero = None, None
                quant_algo, a_quant_layout = "weight_only", "quantize_none"
                str_dtype = dtype_dict[dtype]
                output = torch.ops.torch_mlu_ops.quant_matmul(a, a_scale, a_zero,
                                                              b, b_scale, b_zero,
                                                              bias, c, c_scale, c_zero,
                                                              gemm_output_scale, gemm_output_zero,
                                                              str_dtype, None, quant_algo,
                                                              a_quant_layout, b_quant_layout,
                                                              quant_bit_size, act_mode, use_hp_active, act_coef,
                                                              alpha, beta, trans_a, trans_b,)
                self.assertEqual(output.shape, (M, N))

    def test_smooth_quant_matmul(self):
        """quant_matmul meta function with the smooth_quant algorithm (int8 a and b)."""
        with fake_tensor_mode:
            act_mode_list = ["none", "silu", "gelu"]
            dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
            trans_a_list = [True, False]
            trans_b_list = [True, False]
            M, K, N = 2, 16, 32
            quant_bit_size, use_hp_active, act_coef, alpha, beta = 8, True, 1., 0.8, 0.3
            arg = product(act_mode_list, dtype_list, trans_a_list, trans_b_list)
            for act_mode, dtype, trans_a, trans_b in arg:
                print(f"smooth_quant_matmul... act_mode: {act_mode}, dtype: {dtype}, trans_a: {trans_a}, trans_b: {trans_b}")
                a = torch.randint(-128, 127, (M, K), dtype=torch.int8).mlu()
                b = torch.randint(-128, 127, (N, K), dtype=torch.int8).mlu()
                c = None
                bias = torch.randn(N, device="mlu", dtype=dtype)
                # Per-token scale for activations, per-channel scale for weights.
                a_scale = torch.randn(M, device="mlu", dtype=torch.float)
                b_scale = torch.randn(N, device="mlu", dtype=torch.float)
                a_zero, b_zero, c_zero = None, None, None
                c_scale, gemm_output_scale, gemm_output_zero = None, None, None
                quant_algo, a_quant_layout, b_quant_layout = "smooth_quant", "quantize_per_token", "quantize_per_channel"
                str_dtype = dtype_dict[dtype]
                output = torch.ops.torch_mlu_ops.quant_matmul(a, a_scale, a_zero,
                                                              b, b_scale, b_zero,
                                                              bias, c, c_scale, c_zero,
                                                              gemm_output_scale, gemm_output_zero,
                                                              str_dtype, None, quant_algo,
                                                              a_quant_layout, b_quant_layout,
                                                              quant_bit_size, act_mode, use_hp_active, act_coef,
                                                              alpha, beta, trans_a, trans_b)
                self.assertEqual(output.shape, (M, N))

    def test_group_gemm(self):
        """group_gemm meta function (per-expert GEMM) over dtype/idx/bias combos."""
        with fake_tensor_mode:
            batch, seq, k, n, experts_num, topk = 1, 8, 1024, 2048, 8, 5
            dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
            idx_list = [False, True]
            has_bias_list = [True, False]
            args = product( dtype_list, idx_list, has_bias_list)
            for data_type, idx_mode, has_bias in args:
                print(f"group_gemm... has_bias: {has_bias}, idx_mode: {idx_mode}, dtype: {data_type}")
                bs = batch * seq
                token_topk = bs * topk
                # Route each (token, topk) slot to an expert and sort by expert id.
                expert_id = torch.randint(experts_num, (token_topk,), device="mlu")
                sorted_expert_id, indices = expert_id.sort()
                gather_idx = indices // topk
                gather_idx = gather_idx.to(torch.int32)
                token_count = torch.randint(0, token_topk, (experts_num,)).to(torch.int32)
                a = torch.randn(bs, k, device="mlu", dtype=data_type)
                if not idx_mode:
                    # Pre-gather activations instead of passing gather_idx to the op.
                    a = a[gather_idx]
                b = torch.randn(experts_num, n, k, device="mlu", dtype=data_type)
                c = torch.randn(token_topk, n, device="mlu", dtype=data_type)
                alpha = torch.randn(experts_num, device="mlu", dtype=torch.float32)
                beta = torch.randn(experts_num, device="mlu", dtype=torch.float32)
                a_scale = None
                b_scale = None
                bias = None
                if has_bias:
                    bias = torch.randn(experts_num, n, device="mlu", dtype=data_type)
                gather_idx_ = gather_idx if idx_mode else None
                output = torch.ops.torch_mlu_ops.group_gemm(a, b, token_count, gather_idx_, c, alpha, beta, a_scale, b_scale, bias, None, None, None, bs)
                self.assertEqual(output.shape, (token_topk, n))

    def test_smoothquant_group_gemm(self):
        """group_gemm meta function on the int8 (smooth-quant) path."""
        with fake_tensor_mode:
            batch, seq, k, n, experts_num, topk = 1, 8, 1024, 2048, 8, 5
            dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
            idx_list = [False, True]
            has_bias_list = [True, False]
            args = product( dtype_list, idx_list, has_bias_list)
            for data_type, idx_mode, has_bias in args:
                print(f"smoothquant_group_gemm... has_bias: {has_bias}, idx_mode: {idx_mode}, dtype: {data_type}")
                bs = batch * seq
                token_topk = bs * topk
                expert_id = torch.randint(experts_num, (token_topk,), device="mlu")
                sorted_expert_id, indices = expert_id.sort()
                gather_idx = indices // topk
                gather_idx = gather_idx.to(torch.int32)
                token_count = torch.randint(0, token_topk, (experts_num,)).to(torch.int32)
                # int8 activations/weights with per-token and per-channel scales.
                a8 = torch.randint(-128, 127, (bs, k)).to(torch.int8).mlu()
                b8 = torch.randint(-128, 127, (experts_num, n, k)).to(torch.int8).mlu()
                a_scale = torch.randn(token_topk, dtype=torch.float32).mlu()
                b_scale = torch.randn(experts_num, n, dtype=torch.float32).mlu()
                c = torch.randn(token_topk, n, device="mlu", dtype=data_type)
                if not idx_mode:
                    a8 = a8[gather_idx]
                alpha = torch.randn(experts_num, device="mlu", dtype=torch.float32)
                beta = torch.randn(experts_num, device="mlu", dtype=torch.float32)
                bias = None
                if has_bias:
                    bias = torch.randn(experts_num, n, device="mlu", dtype=data_type)
                gather_idx_ = gather_idx if idx_mode else None
                str_dtype = dtype_dict[data_type]
                output = torch.ops.torch_mlu_ops.group_gemm(a8, b8, token_count, gather_idx_, c, alpha, beta, a_scale, b_scale, bias, str_dtype, None, None, bs)
                self.assertEqual(output.shape, (token_topk, n))

    def test_moe_expand_input(self):
        """moe_expand_input meta function: output shape is (token_num * topk, hidden)."""
        with fake_tensor_mode:
            token_num, hidden_size, expert_num, topk, start_expert_id, expert_size = 2048, 4096, 32, 8, 3, 20
            dtype_list = [torch.half, torch.float, torch.int8, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half, torch.float, torch.int8]
            for dtype in dtype_list:
                print(f"moe_expand_input... token_num: {token_num}, expert_num: {expert_num}, topk: {topk}, dtype: {dtype}")
                input = torch.randn(token_num, hidden_size, device='mlu').to(dtype)
                gather_idx = torch.randint(low=0, high=token_num, size=(token_num * topk,), dtype=torch.int32, device='mlu')
                cusum_token_count = torch.zeros(expert_num + 1, dtype=torch.int32).mlu()
                output=torch.ops.torch_mlu_ops.moe_expand_input(input, gather_idx, cusum_token_count, start_expert_id, expert_size)
                self.assertEqual(output.shape, (token_num*topk, hidden_size))

    def test_moe_gen_idx(self):
        """moe_gen_idx meta function: four outputs with routing-metadata shapes."""
        with fake_tensor_mode:
            token_num, expert_num, topk = 2048, 32, 8
            print(f"moe_gen_idx... token_num: {token_num}, expert_num: {expert_num}, topk: {topk}")
            expert_id = torch.randint(low=0, high=expert_num, size=(token_num, topk)).to(torch.int32).to('mlu')
            expand_idx, combine_idx, token_count, cusum_token_count =torch.ops.torch_mlu_ops.moe_gen_idx(expert_id, expert_num)
            self.assertEqual(expand_idx.shape, (token_num * topk,))
            self.assertEqual(combine_idx.shape, (token_num * topk,))
            self.assertEqual(token_count.shape, (expert_num,))
            self.assertEqual(cusum_token_count.shape, (expert_num + 1,))

    def test_moe_combine_result(self):
        """moe_combine_result meta function with optional bias/residual/token counts."""
        with fake_tensor_mode:
            has_bias_list = [True, False]
            num_tokens, hidden_size, num_expert, topk, start_expert_id = 1, 2048, 8, 2, 0
            expert_size_list = [5, 8]
            dtype_list = [torch.half, torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [torch.half]
            args = product(has_bias_list, expert_size_list, dtype_list)
            for has_bias, expert_size, dtype in args:
                print(f"moe_combine_result... has_bias: {has_bias}, expert_size: {expert_size}, dtype: {dtype}")
                input = torch.randn((num_tokens * topk, hidden_size), dtype=dtype, device='mlu')
                reduce_weight = torch.randn((num_tokens, topk), dtype=torch.float32, device='mlu')
                gather_ids = torch.randperm(num_tokens * topk, dtype=torch.int32, device='mlu')
                bias = None
                residual = None
                cusum_token_count = None
                if has_bias:
                    bias = torch.randn((num_expert, hidden_size), dtype=dtype, device='mlu')
                    residual = torch.randn((num_tokens, hidden_size), dtype=dtype, device='mlu')
                # Token counts are required when biasing or using a partial
                # expert range (expert_size < num_expert).
                if has_bias or expert_size < num_expert:
                    cusum_token_count =torch.zeros(num_expert + 1, dtype=torch.int32).mlu()
                output=torch.ops.torch_mlu_ops.moe_combine_result(input, reduce_weight, gather_ids, residual, cusum_token_count, start_expert_id, expert_size, bias)
                self.assertEqual(output.shape, (num_tokens, hidden_size))

    def test_moe(self):
        """fused_moe meta function over 100 distinct random configurations."""
        with fake_tensor_mode:
            act_mode = 'gelu'
            case_list = set()
            while (len(case_list) < 100):
                batch = random.randint(1, 10)
                seq = random.randint(1, 10)
                hidden_size = random.randrange(1024, 3072, 512)
                inner_size = random.randrange(1024, 3072, 512)
                expert_num = random.randint(1, 40)
                topk = random.randint(1, expert_num)
                gated = random.choice([True, False])
                renormalize = random.choice([True, False])
                quant_mode = random.choice(["no_quant", "w4", "w8", "w4w8"])
                quant_wise = random.choice([128, 256, 512])
                data_type = random.choice([torch.bfloat16, torch.float16])
                if not torch_mlu.mlu.is_bf16_supported():
                    data_type = torch.float16
                # Deduplicate so exactly 100 distinct configurations run.
                case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_mode, quant_wise, act_mode)
                if case in case_list:
                    continue
                case_list.add(case)
                print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_mode: {quant_mode}, quant_wise: {quant_wise}, act_mode: {act_mode} testing...", flush=True)
                # if expert_size == -1:
                # expert_size = expert_num
                # NOTE(review): the "no_quant" branch never assigns
                # w1_scale/w2_scale, yet both are passed to fused_moe below; if
                # "no_quant" is the first case drawn this raises NameError, and
                # otherwise stale scales from a previous iteration are reused.
                # Confirm whether these should be None for no_quant.
                if quant_mode == "no_quant":
                    w1 = torch.randn((expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type)
                    w2 = torch.randn((expert_num, hidden_size, inner_size), device="mlu", dtype=data_type)
                    w1_quant_flag = None
                    w2_quant_flag = None
                elif quant_mode == "w4":
                    # 4-bit weights are packed two per int8 byte (last dim halved).
                    w1_quant_group = hidden_size // quant_wise
                    w2_quant_group = inner_size // quant_wise
                    w1_quant_flag = None
                    w2_quant_flag = None
                    w1 = torch.randint(-128, 127, (expert_num, inner_size*(1+gated), hidden_size // 2), device="mlu", dtype=torch.int32).to(torch.int8)
                    w2 = torch.randint(-128, 127, (expert_num, hidden_size, inner_size // 2), device="mlu", dtype=torch.int32).to(torch.int8)
                    w1_scale = torch.empty((w1_quant_group, expert_num, inner_size*(1+gated)), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
                    w2_scale = torch.empty((w2_quant_group, expert_num, hidden_size), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
                elif quant_mode == "w8":
                    w1_quant_flag = None
                    w2_quant_flag = None
                    w1 = torch.randint(-128, 127, (expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=torch.int32).to(torch.int8)
                    w2 = torch.randint(-128, 127, (expert_num, hidden_size, inner_size), device="mlu", dtype=torch.int32).to(torch.int8)
                    w1_scale = torch.empty((expert_num, inner_size*(1+gated)), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
                    w2_scale = torch.empty((expert_num, hidden_size), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
                elif quant_mode == "w4w8":
                    # Mixed 4/8-bit: per-group bit widths; total byte counts
                    # follow from the sum of the per-group flags.
                    w1_quant_group = hidden_size // quant_wise
                    w2_quant_group = inner_size // quant_wise
                    w1_quant_flag = random.choices([4,8], k=expert_num * w1_quant_group)
                    w2_quant_flag = random.choices([4,8], k=expert_num * w2_quant_group)
                    w1_count = (sum(w1_quant_flag) // 4) * (quant_wise // 2) * inner_size*(1+gated)
                    w2_count = (sum(w2_quant_flag) // 4) * (quant_wise // 2) * hidden_size
                    w1 = torch.randint(-128, 127, (w1_count,), device="mlu", dtype=torch.int32).to(torch.int8)
                    w2 = torch.randint(-128, 127, (w2_count,), device="mlu", dtype=torch.int32).to(torch.int8)
                    w1_scale = torch.empty((w1_quant_group, expert_num, inner_size*(1+gated)), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
                    w2_scale = torch.empty((w2_quant_group, expert_num, hidden_size), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
                hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
                residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
                router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
                input_smooth = None if quant_mode == "no_quant" else torch.empty(expert_num, hidden_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05)
                act_smooth = None if quant_mode == "no_quant" else torch.empty(expert_num, inner_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05)
                bias1, bias2 = None, None
                output = torch.ops.torch_mlu_ops.fused_moe(hidden_states, router_logit, w1, w2, bias1, bias2, residual,
                                                           input_smooth, act_smooth, w1_scale, w2_scale, w1_quant_flag,
                                                           w2_quant_flag, topk, renormalize, gated, act_mode, 0,
                                                           0, 0)
                self.assertEqual(output.shape, hidden_states.shape)
def version_at_least(version_str, major, minor):
    """Return True if a version string like '2.3.0' or '2.3.0+cu121' is >= major.minor.

    Plain string comparison misorders versions ('2.10.0' < '2.3.0'
    lexicographically), so compare the leading numeric components instead.
    Unparseable strings conservatively return False.
    """
    numeric = version_str.split('+')[0].split('.')
    try:
        return (int(numeric[0]), int(numeric[1])) >= (major, minor)
    except (ValueError, IndexError):
        return False


if __name__ == "__main__":
    # The fake-tensor op checks in this file rely on torch >= 2.3.0 behavior.
    if version_at_least(torch.__version__, 2, 3):
        run_tests()

View File

@@ -0,0 +1,65 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
class TestFFNOp(BtTestCase):
    """Compares ops.ffn (and an equivalent matmul+active composition) against
    a pure-PyTorch reference feed-forward network."""

    def op_impl_base(self, *args):
        """Reference FFN: up-projection -> activation (optionally gated) -> down-projection."""
        hidden, up_w, up_b, down_w, down_b, gate_w, gate_b, act_mode = args
        # Activation is evaluated in fp32 then cast back to the input dtype.
        pre_act = F.linear(hidden, up_w, up_b)
        act = act_mode_dict[act_mode](pre_act.float()).to(hidden.dtype)
        if gate_w is not None:
            # Gated variant: modulate the activation by a parallel projection.
            act = act * F.linear(hidden, gate_w, gate_b)
        return F.linear(act, down_w, down_b)

    def test_ffn(self):
        """Sweep (input_size, seq_len, bias/gate flag) cases across fp16/bf16."""
        hidden_size = 1024
        batch = 5
        dtype_list = [torch.half]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        # NOTE(review): zip() truncates to the shortest list, so only the first
        # two combinations run and input_size=512 is never exercised — confirm
        # this is intended.
        for input_size, seq_len, bool_value in zip([128, 256, 512],
                                                   [10, 16, 20],
                                                   [True, False]):
            print("input_size={}, seq_len={}, bias={}, gated={}, testing...".format(
                input_size, seq_len, bool_value, bool_value), flush=True)
            use_gate = bool_value
            for dtype in dtype_list:
                # Up / down projection weights plus an optional gate projection.
                w1 = torch.randn((hidden_size, input_size), dtype=dtype, device="mlu")
                b1 = torch.randn((hidden_size), dtype=dtype, device="mlu")
                w2 = torch.randn((input_size, hidden_size), dtype=dtype, device="mlu")
                b2 = torch.randn((input_size), dtype=dtype, device="mlu")
                w3 = torch.randn((hidden_size, input_size), dtype=dtype, device="mlu") if use_gate else None
                b3 = torch.randn((hidden_size), dtype=dtype, device="mlu") if use_gate else None
                input = torch.randn((batch, seq_len, input_size), dtype=dtype, device="mlu")
                args = (input, w1, b1, w2, b2, w3, b3, 'silu')
                reference = self.op_impl_base(*args)
                # Path 1: the fused ffn op.
                fused = ops.ffn(*args)
                self.assertTensorsEqual(reference.cpu().float(), fused.cpu().float(),
                                        0.005, use_MSE=True, use_RAE=True)
                # Path 2: compose the same FFN out of matmul + active ops
                # (gate weights concatenated onto the up projection).
                f1_weight = torch.cat((w1, w3), dim=0) if use_gate else w1
                f1_bias = torch.cat((b1, b3), dim=0) if use_gate else b1
                pre_gemm_out = ops.matmul(input.view(-1, input_size), f1_weight, f1_bias, None, "none", 1.0, 0)
                act_out = ops.active(pre_gemm_out, 'silu', use_gate)
                composed = ops.matmul(act_out, w2, b2, None, 'none', 1.0, 0)
                composed = composed.view(batch, seq_len, input_size)
                self.assertTensorsEqual(reference.cpu().float(), composed.cpu().float(),
                                        0.005, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """Opcheck (fake-tensor / inductor) validation of the ffn custom op."""
        batch, seq_len, input_size, hidden_size, act_mode, dtype = 1, 10, 128, 1024, 'silu', torch.half
        x = torch.randn((batch, seq_len, input_size), dtype=dtype, device="mlu")
        up_weight = torch.randn((hidden_size, input_size), dtype=dtype, device="mlu")
        down_weight = torch.randn((input_size, hidden_size), dtype=dtype, device="mlu")
        self.base_opcheck(torch.ops.torch_mlu_ops.ffn,
                          (x, up_weight, None, down_weight, None, None, None, None, None,
                           act_mode, "none", 1e-5, 1., 0.))
# Script entry point: run the suite and propagate the unittest exit status.
if __name__ == '__main__':
    exit(run_unittest(TestFFNOp))

View File

@@ -0,0 +1,357 @@
import math
import random
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from itertools import product
def gen_args(seq_q, seq_k, head_num_q, head_num_k, head_size, has_alibi, has_mask, is_causal, use_block, return_lse, dtype):
    """Build the complete positional-argument tuple for the varlen flash-attention op.

    seq_q / seq_k are per-batch sequence lengths.  With use_block, k/v are laid
    out as a paged cache with a shuffled block table; otherwise they are packed
    along the token dimension.
    """
    batch = len(seq_q)
    max_seq_q = max(seq_q)
    max_seq_k = max(seq_k)
    # Packed-sequence offsets: running prefix sums with a leading zero.
    q_offsets = [0]
    k_offsets = [0]
    for len_q, len_k in zip(seq_q, seq_k):
        q_offsets.append(q_offsets[-1] + len_q)
        k_offsets.append(k_offsets[-1] + len_k)
    softmax_scale = 1 / math.sqrt(head_size)
    # Optional per-query-head ALiBi slopes and a dense additive attention bias.
    alibi_slope = None if has_alibi == False else torch.zeros((head_num_q)).uniform_(0, 0.1).to(torch.float32).mlu()
    attn_bias = None if has_mask is False else torch.randn((batch, head_num_q, max_seq_q, max_seq_k), dtype=dtype).mlu()
    total_seq_q = sum(seq_q)
    total_seq_k = sum(seq_k)
    q = torch.randn(total_seq_q, head_num_q, head_size, dtype=dtype, device="mlu")
    block_tables = None
    if use_block:
        # Paged k/v cache: a full permutation of all physical block ids serves
        # as the block table.
        block_size = 16
        max_num_blocks_per_seq = (max_seq_k + block_size - 1) // block_size
        num_blocks = int(batch * max_num_blocks_per_seq)
        cache_shape = (num_blocks, head_num_k, block_size, head_size)
        block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
        block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
        k = torch.randn(size=cache_shape, dtype=dtype).mlu()
        v = torch.randn(size=cache_shape, dtype=dtype).mlu()
    else:
        # Packed layout: all tokens of all sequences concatenated.
        k = torch.randn(total_seq_k, head_num_k, head_size, dtype=dtype, device="mlu")
        v = torch.randn(total_seq_k, head_num_k, head_size, dtype=dtype, device="mlu")
    tmo_out = torch.empty_like(q)
    # Log-sum-exp output buffer, allocated only when requested.
    out_lse = torch.randn(batch, head_num_q, max_seq_q, dtype=torch.float, device="mlu") if return_lse else None
    return (q, k, v, tmo_out, out_lse,
            torch.tensor(q_offsets, dtype=torch.int32, device="mlu"),
            torch.tensor(k_offsets, dtype=torch.int32, device="mlu"),
            alibi_slope, attn_bias, None, None,
            block_tables, max_seq_q, max_seq_k,
            softmax_scale, is_causal,
            -1, -1, "float", return_lse)
def gen_params(seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, use_block, dtype):
    """Generate q/k/v tensors plus varlen metadata for flash-attention tests.

    Unlike gen_args, head_size_v may differ from head_size, and only the raw
    tensors (not the full op argument tuple) are returned.
    """
    batch = len(seq_q)
    max_seq_q = max(seq_q)
    max_seq_k = max(seq_k)
    # Packed-sequence offsets: running prefix sums with a leading zero.
    q_offsets = [0]
    k_offsets = [0]
    for len_q, len_k in zip(seq_q, seq_k):
        q_offsets.append(q_offsets[-1] + len_q)
        k_offsets.append(k_offsets[-1] + len_k)
    cu_seq_len_q = torch.tensor(q_offsets, dtype=torch.int32, device="mlu")
    cu_seq_len_k = torch.tensor(k_offsets, dtype=torch.int32, device="mlu")
    # Optional per-query-head ALiBi slopes and a dense additive attention bias.
    alibi_slope = None if has_alibi == False else torch.zeros((head_num_q)).uniform_(0, 0.1).to(torch.float32).mlu()
    attn_bias = None if has_mask is False else torch.randn((batch, head_num_q, max_seq_q, max_seq_k), dtype=dtype).mlu()
    total_seq_q = sum(seq_q)
    total_seq_k = sum(seq_k)
    q = torch.randn(total_seq_q, head_num_q, head_size, dtype=dtype, device="mlu")
    block_tables = None
    if not use_block:
        # Packed layout: all tokens of all sequences concatenated.
        k = torch.randn(total_seq_k, head_num_k, head_size, dtype=dtype, device="mlu")
        v = torch.randn(total_seq_k, head_num_k, head_size_v, dtype=dtype, device="mlu")
        return q, k, v, cu_seq_len_q, cu_seq_len_k, alibi_slope, attn_bias, block_tables
    # Paged layout: a full permutation of physical block ids is the block table.
    block_size = 16
    max_num_blocks_per_seq = (max_seq_k + block_size - 1) // block_size
    num_blocks = int(batch * max_num_blocks_per_seq)
    block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
    block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
    k = torch.randn(num_blocks, head_num_k, block_size, head_size, dtype=dtype).mlu()
    v = torch.randn(num_blocks, head_num_k, block_size, head_size_v, dtype=dtype).mlu()
    return q, k, v, cu_seq_len_q, cu_seq_len_k, alibi_slope, attn_bias, block_tables
class TestFlashAttnOp(BtTestCase):
    """Tests for ops.flash_attention against an eager-mode PyTorch reference.

    Covers packed (varlen) attention with optional GQA head grouping, ALiBi
    slopes, dense attention bias, causal masking, sliding windows, paged
    (block-table) KV caches, LSE output, and in-place output.
    """

    def op_impl_base(self, *args):
        """Eager reference for flash attention; returns out or (out, lse).

        Expects the same positional argument list as ops.flash_attention.
        Processes one batch entry at a time: slices q/k/v (or gathers KV
        blocks via the block table), computes scaled qk, applies ALiBi /
        causal / window / bias masks, then softmax and the value matmul.
        """
        q, k, v, out, cu_seq_lens_q, cu_seq_lens_kv, alibi_slope, attn_bias, max_seq_len_q, \
            max_seq_len_kv, softmax_scale, is_causal, window_size_left, window_size_right, \
            compute_dtype, return_lse, block_tables, k_cache_quant_scale, v_cache_quant_scale = args
        # Packed (varlen) layout is signalled by the presence of cumulative seq lens.
        is_pack = cu_seq_lens_q is not None
        has_block_table = block_tables is not None
        if has_block_table:
            # Paged KV is only supported together with the packed layout here.
            assert is_pack == True
        batch = len(cu_seq_lens_q) - 1 if is_pack else q.size(0)
        head_num_q = q.size(-2)
        # Paged k is (num_blocks, head_num_kv, block_size, head_size): head dim is -3.
        head_num_kv = k.size(-2) if block_tables is None else k.size(-3)
        head_size = q.size(-1)
        head_size_v = v.size(-1)
        # GQA: query heads must be a multiple of kv heads.
        assert head_num_q >= head_num_kv and head_num_q % head_num_kv == 0
        group = head_num_q // head_num_kv
        device = q.device
        # Broadcast kv heads up to head_num_q along the head dimension.
        repeat_dim = -3 if has_block_table else -2
        k_bd = torch.repeat_interleave(k, group, dim=repeat_dim)
        v_bd = torch.repeat_interleave(v, group, dim=repeat_dim)
        out_list = []
        # Finite stand-in for -inf so softmax stays numerically well-behaved.
        inf = 1e6
        lse = torch.zeros(batch, head_num_q, max_seq_len_q, dtype=torch.float)
        for i in range(batch):
            q_i = q[cu_seq_lens_q[i]:cu_seq_lens_q[i+1], ...] if is_pack else q[i]
            if not has_block_table:
                # NOTE(review): the non-packed branch uses k[i]/v[i] (not the
                # head-broadcast k_bd/v_bd); all tests below use the packed path.
                k_i = k_bd[cu_seq_lens_kv[i]:cu_seq_lens_kv[i+1], ...] if is_pack else k[i]
                v_i = v_bd[cu_seq_lens_kv[i]:cu_seq_lens_kv[i+1], ...] if is_pack else v[i]
            else:
                # Gather this sequence's KV blocks and flatten them back into
                # a contiguous (seq_k, head, head_size) view.
                block_table = block_tables[i]
                context_len = cu_seq_lens_kv[i+1] - cu_seq_lens_kv[i]
                block_size = k.size(-2)  # (num_block, head_num, block_size, head_size)
                table_end = (context_len + block_size - 1) // block_size
                block_ids = block_table[0 : table_end]
                keys, values = k[block_ids], v[block_ids]
                keys = torch.repeat_interleave(keys, group, dim=1)  # [num_block, head_q, block_size, head_size]
                keys = keys.transpose(1, 0).contiguous().view(head_num_q, -1, head_size)  # [head_q, num_block * block_size, head_size]
                k_i = keys.transpose(1,0)  # [num_block * block_size, head_q, head_size]
                k_i = k_i[0:context_len, ...]  # [seq_k, head_q, head_size]
                values = torch.repeat_interleave(values, group, dim=1)
                values = values.transpose(1, 0).contiguous().view(head_num_q, -1, head_size_v)
                v_i = values.transpose(1,0)
                v_i = v_i[0:context_len, ...]
            # Scaled attention scores: (head, seq_q, seq_k) in fp32.
            qk = torch.einsum('qhd,khd->hqk', q_i, k_i).to(torch.float) * softmax_scale
            seq_q, seq_k = q_i.size(0), k_i.size(0)
            if alibi_slope is not None:
                slope = alibi_slope.reshape(1, head_num_q, 1, 1)
                slope_bias = torch.zeros(1, head_num_q, seq_q, seq_k).to(device=device)
                if is_causal:
                    # Causal ALiBi: bias depends only on the key's distance from
                    # the sequence end, broadcast over query positions.
                    relative_pos = torch.arange(-seq_k + 1, 1, dtype=torch.float32).to(device=device)
                    slope_bias = relative_pos * slope
                else:
                    # Non-causal ALiBi: penalty grows with |relative distance|,
                    # with queries right-aligned against the keys.
                    row_idx = torch.arange(seq_q, dtype=torch.long).reshape(-1, 1)
                    col_idx = torch.arange(seq_k, dtype=torch.long)
                    relative_pos = torch.abs(row_idx + seq_k - seq_q - col_idx).to(device=device)
                    slope_bias = -slope * relative_pos.to(dtype=slope.dtype)
                qk += (slope_bias.squeeze(0))
            if is_causal:
                assert seq_q <= seq_k, "seq_q <= seq_k if causal=True"
                # Right-aligned causal mask: full prefix visible, then upper triangle blocked.
                zeros = torch.zeros(seq_q, seq_k-seq_q, dtype=torch.float, device="mlu")
                tri = torch.full((seq_q, seq_q), -inf, dtype=torch.float, device="mlu").triu(diagonal=1)
                mask = torch.cat([zeros, tri], dim=1)  # (q, k-q) + (q, q) => (q, k)
                qk += mask
            if window_size_left != -1 or window_size_right != -1:
                # Sliding-window mask: only keys within [pos-left, pos+right]
                # of the (right-aligned) query position stay unmasked.
                mask_w = torch.full((seq_q, seq_k), -inf, dtype=torch.float, device="mlu")
                for qi in range(seq_q):
                    left = max(seq_k - seq_q + qi - window_size_left, 0) if window_size_left != -1 else 0
                    right = min(max(seq_k - seq_q + qi + window_size_right + 1, 0), seq_k) if window_size_right != -1 else seq_k
                    mask_w[qi, left:right] = 0
                qk += mask_w
            if attn_bias is not None:
                qk += attn_bias[i][:, :seq_q, :seq_k]
            if return_lse:
                # Log-sum-exp of the (masked) scores, stored per batch entry.
                lse[i][:, :seq_q] = torch.logsumexp(qk, dim=-1)
            attn = torch.softmax(qk, dim=-1, dtype=torch.float).to(q.dtype)
            qkv = torch.einsum('hqk,khd->qhd', attn, v_i)
            out_list.append(qkv)
        attn_out = torch.cat(out_list, dim=0)
        if is_pack == False:
            # Restore the (batch, seq, head, head_size_v) shape for the dense layout.
            attn_out = attn_out.view(q.size(0), q.size(1), q.size(2), head_size_v)
        if return_lse:
            attn_out = (attn_out, lse)
        return attn_out

    def test_flash_attention(self):
        """Sweep the full config product and compare against the reference."""
        seq_len_list = [((38, 64, 128), (38, 64, 128)), ((30, 40, 50), (60, 90, 120))]
        head_num_list = [(32, 32), (32, 4)]
        use_block_list = [False, True]
        head_size_list = [(64, 64), (128, 512)]
        alibi_list = [False, True]
        mask_list = [False, True]
        causal_list = [False, True]
        dtype_list = [torch.half, torch.float]
        window_size_list = [(-1, -1), (10, -1), (8, 8)]
        mlu_name = torch.mlu.get_device_name()
        if "MLU3" in mlu_name:
            # Paged attention is unavailable on this device generation.
            print("pagedattn is not implement on mlu370, skip it")
            use_block_list = [False]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        args = product(seq_len_list, head_num_list, head_size_list, alibi_list, mask_list, causal_list, use_block_list, window_size_list, dtype_list)
        for ((seq_q, seq_k), (head_num_q, head_num_k), (head_size, head_size_v), has_alibi, has_mask, is_causal, use_block,
                (window_size_left, window_size_right), dtype) in args:
            batch = len(seq_q)
            print("batch={}, seq_lens_q={}, seq_lens_k={}, head_num_q={}, head_num_k={}, head_size={}, head_size_v= {}, has_alibi={}, \
has_mask={}, is_causal={}, use_block={}, window_size_left={}, window_size_right={}, dtype={}, testing...".format(
                batch, seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, is_causal, use_block,
                window_size_left, window_size_right, dtype), flush=True)
            max_seq_q = max(seq_q)
            max_seq_k = max(seq_k)
            params = gen_params(seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, use_block, dtype)
            q, k, v, cu_seq_len_q, cu_seq_len_k, alibi_slope, attn_bias, block_tables = params
            softmax_scale = 1 / math.sqrt(head_size)
            torch_output = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
                                             alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                                             is_causal, window_size_left, window_size_right, torch.float, False,
                                             block_tables if use_block else None, None, None)
            tmo_output = ops.flash_attention(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
                                             alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                                             is_causal, window_size_left, window_size_right, torch.float, False,
                                             block_tables if use_block else None, None, None)
            self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
                                    0.005, use_MSE=True, use_RAE=True)
            if use_block:  # test block_tables is [batch, 1]
                # Degenerate paging: one block holds an entire sequence, so the
                # block table collapses to a single column.
                block_size = max_seq_k
                num_blocks = (int)(batch * ((max_seq_k + block_size -1 )// block_size))  # batch
                max_num_blocks_per_seq = (max_seq_k + block_size - 1) // block_size  # 1
                block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
                block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
                k = torch.randn(num_blocks, head_num_k, block_size, head_size, dtype=dtype).mlu()
                v = torch.randn(num_blocks, head_num_k, block_size, head_size_v, dtype=dtype).mlu()
                torch_output = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
                                                 alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                                                 is_causal, -1, -1, torch.float, False,
                                                 block_tables if use_block else None, None, None)
                tmo_output = ops.flash_attention(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
                                                 alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                                                 is_causal, -1, -1, torch.float, False,
                                                 block_tables if use_block else None, None, None)
                self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
                                        0.005, use_MSE=True, use_RAE=True)

    def test_flash_attention_lse(self):
        """Same comparison as above but with return_lse=True: checks both the
        attention output and the returned log-sum-exp tensor."""
        seq_len_list = [((66, 77, 88), (77, 88, 99)), ((1024, 1024), (1024, 1024))]
        head_num_q, head_num_k = 16, 8
        head_size_list = [(64, 64), (16, 128)]
        is_causal = True
        has_alibi = True
        has_mask = True
        window_size_list = [(-1, -1), (10, -1)]
        dtype_list = [torch.half, torch.float]
        use_block_list = [False, True]
        mlu_name = torch.mlu.get_device_name()
        if "MLU3" in mlu_name:
            print("pagedattn is not implement on mlu370, skip it")
            use_block_list = [False]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        args = product(seq_len_list, head_size_list, use_block_list, window_size_list, dtype_list)
        for (seq_q, seq_k), (head_size, head_size_v), use_block, (window_size_left, window_size_right), dtype in args:
            batch = len(seq_q)
            print("batch={}, seq_lens_q={}, seq_lens_k={}, head_num_q={}, head_num_k={}, head_size={}, head_size_v={}, has_alibi={}, \
has_mask={}, is_causal={}, use_block={}, window_size_left={}, window_size_right={}, dtype={}, testing...".format(
                batch, seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, is_causal, use_block,
                window_size_left, window_size_right, dtype), flush=True)
            max_seq_q = max(seq_q)
            max_seq_k = max(seq_k)
            params = gen_params(seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, use_block, dtype)
            q, k, v, cu_seq_len_q, cu_seq_len_k, alibi_slope, attn_bias, block_tables = params
            softmax_scale = 1 / math.sqrt(head_size)
            torch_output, torch_output_lse = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
                                                               alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                                                               is_causal, window_size_left, window_size_right, torch.float, True,
                                                               block_tables if use_block else None, None, None)
            tmo_output, tmo_output_lse = ops.flash_attention(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
                                                             alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                                                             is_causal, window_size_left, window_size_right, torch.float, True,
                                                             block_tables if use_block else None, None, None)
            self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
                                    0.004, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(torch_output_lse.cpu(), tmo_output_lse.cpu(), 0.0006, use_MSE=True, use_RAE=True)
            if use_block:  # test block_table = [batch, 1]
                # Same degenerate one-block-per-sequence case as in test_flash_attention.
                block_size = max_seq_k
                num_blocks = (int)(batch * ((max_seq_k + block_size -1 )// block_size))  # batch
                max_num_blocks_per_seq = (max_seq_k + block_size - 1) // block_size  # 1
                block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
                block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
                k = torch.randn(num_blocks, head_num_k, block_size, head_size, dtype=dtype).mlu()
                v = torch.randn(num_blocks, head_num_k, block_size, head_size_v, dtype=dtype).mlu()
                torch_output, torch_output_lse = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
                                                                   alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                                                                   is_causal, window_size_left, window_size_right, torch.float, True,
                                                                   block_tables if use_block else None, None, None)
                tmo_output, tmo_output_lse = ops.flash_attention(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
                                                                 alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                                                                 is_causal, window_size_left, window_size_right, torch.float, True,
                                                                 block_tables if use_block else None, None, None)
                self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
                                        0.004, use_MSE=True, use_RAE=True)
                self.assertTensorsEqual(torch_output_lse.cpu(), tmo_output_lse.cpu(), 0.0006, use_MSE=True, use_RAE=True)

    def test_flash_attention_inplace(self):
        """Checks the in-place output path: a preallocated out tensor is passed
        to ops.flash_attention instead of letting the op allocate one."""
        seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v = (66, 77, 88), (77, 88, 99), 16, 8, 64, 64
        use_block_list = [False, True]
        softmax_scale = 1 / math.sqrt(head_size)
        is_causal = True
        has_alibi = True
        has_mask = True
        batch = len(seq_q)
        max_seq_q = max(seq_q)
        max_seq_k = max(seq_k)
        # NOTE(review): these two lists are overwritten by the gen_params unpack
        # below and appear to be dead initializers.
        cu_seq_len_q = [0]
        cu_seq_len_k = [0]
        window_size_left, window_size_right = -1, -1
        mlu_name = torch.mlu.get_device_name()
        if "MLU3" in mlu_name:
            print("pagedattn is not implement on mlu370, skip it")
            use_block_list = [False]
        total_seq_q = sum(seq_q)
        for use_block in use_block_list:
            print("test_flash_attention_inplace: use_block: {}, testing....".format(use_block))
            params = gen_params(seq_q, seq_k, head_num_q, head_num_k, head_size, head_size_v, has_alibi, has_mask, use_block, torch.half)
            q, k, v, cu_seq_len_q, cu_seq_len_k, alibi_slope, attn_bias, block_tables = params
            torch_output = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
                                             alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                                             is_causal, window_size_left, window_size_right, torch.float, False,
                                             block_tables if use_block else None, None, None)
            # Preallocated output buffer that the op must fill in place.
            tmo_output = torch.empty((total_seq_q, head_num_q, head_size_v), dtype=torch.half, device="mlu")
            ops.flash_attention(q, k, v, tmo_output, cu_seq_len_q, cu_seq_len_k,
                                alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                                is_causal, window_size_left, window_size_right, torch.float, False,
                                block_tables if use_block else None, None, None)
            self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
            if use_block:  # test block_table = [batch, 1]
                block_size = max_seq_k
                num_blocks = (int)(batch * ((max_seq_k + block_size -1 )// block_size))  # batch
                max_num_blocks_per_seq = (max_seq_k + block_size - 1) // block_size  # 1
                block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
                block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
                k = torch.randn(num_blocks, head_num_k, block_size, head_size, dtype=torch.float16).mlu()
                v = torch.randn(num_blocks, head_num_k, block_size, head_size_v, dtype=torch.float16).mlu()
                torch_output = self.op_impl_base(q, k, v, None, cu_seq_len_q, cu_seq_len_k,
                                                 alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                                                 is_causal, window_size_left, window_size_right, torch.float, False,
                                                 block_tables if use_block else None, None, None)
                ops.flash_attention(q, k, v, tmo_output, cu_seq_len_q, cu_seq_len_k,
                                    alibi_slope, attn_bias, max_seq_q, max_seq_k, softmax_scale,
                                    is_causal, window_size_left, window_size_right, torch.float, False,
                                    block_tables if use_block else None, None, None)
                self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
                                        0.003, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """Runs torch.library opcheck on the registered custom op schema across
        the block/alibi/mask/causal/lse flag combinations."""
        seq_q, seq_k, head_num_q, head_num_k, head_size, dtype = (66, 77, 88), (77, 88, 99), 16, 8, 64, torch.float16
        use_block_list = [False, True]
        mlu_name = torch.mlu.get_device_name()
        if "MLU3" in mlu_name:
            print("pagedattn is not implement on mlu370, skip it")
            use_block_list = [False]
        alibi_list = [False, True]
        mask_list = [False, True]
        causal_list = [False, True]
        return_lse_list = [False, True]
        test_flags = product(use_block_list, alibi_list, mask_list, causal_list, return_lse_list)
        for use_block, has_alibi, has_mask, is_causal, return_lse in test_flags:
            print(f"==== use_block: {use_block}, has_alibi: {has_alibi}, has_mask: {has_mask}, is_causal: {is_causal} return_lse: {return_lse}====")
            args = gen_args(seq_q, seq_k, head_num_q, head_num_k, head_size, has_alibi, has_mask, is_causal, use_block, return_lse, dtype)
            self.base_opcheck(torch.ops.torch_mlu_ops.flash_attention, args)
if __name__ == '__main__':
    # Propagate the unittest result as the process exit code.
    raise SystemExit(run_unittest(TestFlashAttnOp))

View File

@@ -0,0 +1,94 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from torch.nn.parameter import Parameter
from torch.nn import functional as F
import torch.nn as nn
class TestFusedNormAttnProjOp(BtTestCase):
    """Tests for ops.fused_norm_attention_project: LayerNorm followed by the
    fused Q/K/V linear projections, compared against an eager reference."""

    def op_impl_base(self, *args):
        """Eager reference: LayerNorm(input) then q/k/v projections.

        Returns (q_out,) or (q_out, k_out, v_out), optionally followed by the
        LayerNorm output when norm_out is True. With out_layout == 'nhtc' the
        projections are reshaped to (batch, heads, seq, head_size).
        """
        input, q_weight, q_bias, k_weight, k_bias, v_weight, v_bias, norm_weight, \
            norm_bias, eps, out_layout, head_size, norm_out = args
        input_size = input.size(-1)
        layernorm = torch.nn.LayerNorm(input_size)
        layernorm.eps = eps
        layernorm.weight = nn.Parameter(norm_weight)
        layernorm.bias = nn.Parameter(norm_bias)
        layernorm_out = layernorm(input)
        # Projections are y = x @ W^T + b; weights stored (out_features, in_features).
        q_out = torch.matmul(layernorm_out, q_weight.permute(1, 0)) + q_bias
        if k_weight is not None:
            k_out = torch.matmul(layernorm_out, k_weight.permute(1, 0)) + k_bias
            v_out = torch.matmul(layernorm_out, v_weight.permute(1, 0)) + v_bias
        if out_layout == 'nhtc':
            # Split the hidden dim into heads and move the head axis before seq:
            # (batch, seq, hidden) -> (batch, heads, seq, head_size).
            batch, seq, _ = input.shape
            hidden_size_q = q_weight.size(0)
            q_head = hidden_size_q // head_size
            q_out = q_out.reshape(batch, seq, q_head, head_size).transpose(1, 2)
            if k_weight is not None:
                hidden_size_kv = k_weight.size(0)
                kv_head = hidden_size_kv // head_size
                k_out = k_out.reshape(batch, seq, kv_head, head_size).transpose(1, 2)
                v_out = v_out.reshape(batch, seq, kv_head, head_size).transpose(1, 2)
        outs = (q_out,) if k_weight is None else (q_out, k_out, v_out,)
        if norm_out is True:
            # Also expose the pre-projection LayerNorm result when requested.
            outs += (layernorm_out,)
        return outs

    def test_attn_proj(self):
        """Compares all returned tensors (q/k/v and norm output) between the
        fused op and the reference, per supported dtype."""
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for dtype in dtype_list:
            N, T, input_size, hidden_size, head_size, eps, alpha, beta = 4, 16, 512, 768, 64, 1e-5, 0.5, 0.3
            print("N: {}, T: {}, input_size: {}, hidden_size: {}, testing...".format(
                N, T, input_size, hidden_size), flush=True)
            input = torch.randn(N, T, input_size, dtype=dtype, device="mlu")
            # One fused weight/bias chunked 3 ways into q/k/v below.
            weight = torch.randn(hidden_size * 3, input_size, dtype=dtype, device="mlu")
            bias = torch.randn(hidden_size * 3, dtype=dtype, device="mlu")
            norm_weight = torch.randn(input_size, dtype=dtype, device="mlu")
            norm_bias = torch.randn(input_size, dtype=dtype, device="mlu")
            # NOTE(review): residual (and alpha/beta above) are built but never
            # used in this test.
            residual = torch.randn(N, T, hidden_size, dtype=dtype, device="mlu")
            weights = torch.chunk(weight, 3)
            biass = torch.chunk(bias, 3)
            # test pre_attn_proj
            print("test pre_attn_proj...")
            out_torch = self.op_impl_base(input,
                                          weights[0], biass[0],
                                          weights[1], biass[1],
                                          weights[2], biass[2],
                                          norm_weight, norm_bias,
                                          eps, 'nthc', head_size, True)
            out_tmo = ops.fused_norm_attention_project(input,
                                                       weights[0], biass[0],
                                                       weights[1], biass[1],
                                                       weights[2], biass[2],
                                                       norm_weight, norm_bias,
                                                       eps, 'nthc', head_size, True)
            for o1, o2 in list(zip(out_torch, out_tmo)):
                self.assertTensorsEqual(o1.cpu().float(), o2.cpu().float(),
                                        0.003, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """Runs torch.library opcheck on the registered attention_project op for
        both 'nhtc' and 'nthc' output layouts."""
        N, T, input_size, hidden_size, head_size, eps, alpha, beta, dtype = 4, 16, 512, 768, 64, 1e-5, 0.5, 0.3, torch.half
        input = torch.randn(N, T, input_size, dtype=dtype, device="mlu")
        weight = torch.randn(hidden_size * 3, input_size, dtype=dtype, device="mlu")
        bias = torch.randn(hidden_size * 3, dtype=dtype, device="mlu")
        norm_weight = torch.randn(input_size, dtype=dtype, device="mlu")
        norm_bias = torch.randn(input_size, dtype=dtype, device="mlu")
        weights = torch.chunk(weight, 3)
        biass = torch.chunk(bias, 3)
        args = (input, weights[0], biass[0], weights[1], biass[1], weights[2], biass[2],
                norm_weight, norm_bias, None, "nhtc", head_size, eps, alpha,
                beta, True)
        self.base_opcheck(torch.ops.torch_mlu_ops.attention_project, args)
        args = (input, weights[0], biass[0], weights[1], biass[1], weights[2], biass[2],
                norm_weight, norm_bias, None, "nthc", head_size, eps, alpha,
                beta, True)
        self.base_opcheck(torch.ops.torch_mlu_ops.attention_project, args)
if __name__ == '__main__':
    # Propagate the unittest result as the process exit code.
    raise SystemExit(run_unittest(TestFusedNormAttnProjOp))

View File

@@ -0,0 +1,89 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
class TestFusedNormResidualFFNoP(BtTestCase):
    """Tests for ops.fused_norm_residual_ffn: optional LayerNorm, (gated) FFN,
    and scaled residual add, compared against an eager reference."""

    def op_impl_base(self, *args):
        """Eager reference for the fused norm + (gated) FFN + residual op.

        Pipeline: optional LayerNorm -> up_fc -> activation -> optional
        elementwise gate (gate_up_proj) -> down_proj -> optional
        out = beta * residual + alpha * down_proj_out, where residual is the
        raw input or the normed input depending on residual_is.
        """
        input, up_fc_weight, up_fc_bias, down_proj_weight, down_proj_bias, gate_up_proj_weight, \
            gate_up_proj_bias, layernorm_weight, layernorm_bias, eps, act_mode, residual_is, \
            alpha, beta = args
        hidden_size = input.size(-1)
        layernorm = torch.nn.LayerNorm(hidden_size)
        layernorm.weight = torch.nn.Parameter(layernorm_weight)
        layernorm.bias = torch.nn.Parameter(layernorm_bias)
        layernorm.eps = eps
        # act_mode_dict maps the mode string (e.g. "silu") to an activation
        # callable; presumably provided by common_utils.
        act = act_mode_dict[act_mode]
        residual = input
        norm_out = input
        if layernorm_weight is not None:
            norm_out = layernorm(input)
        if residual_is == "normed_input":
            # Residual branch taken after normalization rather than before.
            residual = norm_out
        up_fc_out = torch.matmul(norm_out, up_fc_weight.permute(1, 0)) + up_fc_bias
        # Activation computed in fp32 then cast back, matching the fused op's
        # compute precision.
        act_out = act(up_fc_out.float()).to(up_fc_out.dtype)
        if gate_up_proj_weight is not None:
            # Gated FFN (e.g. SwiGLU-style): act(up) * gate before down projection.
            gate_up_proj_out = torch.matmul(norm_out, gate_up_proj_weight.permute(1, 0)) + gate_up_proj_bias
            down_proj_out = torch.matmul(act_out * gate_up_proj_out, down_proj_weight.permute(1, 0)) + down_proj_bias
        else:
            down_proj_out = torch.matmul(act_out, down_proj_weight.permute(1, 0)) + down_proj_bias
        if residual_is != 'none':
            out = beta * residual + alpha * down_proj_out
        else:
            out = down_proj_out
        return out

    def test_ffn(self):
        """Compares the fused op against the reference per supported dtype."""
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for dtype in dtype_list:
            batch, seq_len, hidden_size, inner_size, alpha, beta = 4, 16, 512, 512, 0.5, 0.3
            eps, act_mode, residual_is = 1e-5, "silu", "input"
            print(f"batch: {batch}, seq_len: {seq_len}, hidden_size: {hidden_size}, inner_size: {inner_size}, alpha: {alpha}, beta: {beta}, \
eps: {eps}, act_mode: {act_mode}, residual_is: {residual_is}, dtype: {dtype} testing...", flush=True)
            input = torch.randn((batch, seq_len, hidden_size), dtype=dtype, device="mlu")
            layernorm_weight = torch.randn(hidden_size, dtype=dtype, device="mlu")
            layernorm_bias = torch.normal(0, 0.1, (hidden_size,), dtype=dtype, device="mlu")
            up_fc_weight = torch.randn((inner_size, hidden_size), dtype=dtype, device="mlu")
            up_fc_bias = torch.normal(0, 0.1, (inner_size,), dtype=dtype, device="mlu")
            gated_fc_weight = torch.randn((inner_size, hidden_size), dtype=dtype, device="mlu")
            gated_fc_bias = torch.normal(0, 0.1, (inner_size,), dtype=dtype, device="mlu")
            down_fc_weight = torch.randn((hidden_size, inner_size), dtype=dtype, device="mlu")
            down_fc_bias = torch.normal(0, 0.1, (hidden_size,), dtype=dtype, device="mlu")
            torch_output = self.op_impl_base(input,
                                             up_fc_weight, up_fc_bias,
                                             down_fc_weight, down_fc_bias,
                                             gated_fc_weight, gated_fc_bias,
                                             layernorm_weight, layernorm_bias,
                                             eps, act_mode, residual_is,
                                             alpha, beta)
            tmo_output = ops.fused_norm_residual_ffn(input,
                                                     up_fc_weight, up_fc_bias,
                                                     down_fc_weight, down_fc_bias,
                                                     gated_fc_weight, gated_fc_bias,
                                                     layernorm_weight, layernorm_bias,
                                                     eps, act_mode, residual_is,
                                                     alpha, beta)
            self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
                                    0.005, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """Runs torch.library opcheck on the registered ffn custom op."""
        batch, seq_len, hidden_size, inner_size, act_mode, dtype = 1, 10, 128, 1024, 'silu', torch.half
        input = torch.randn((batch, seq_len, hidden_size), dtype=dtype, device="mlu")
        layernorm_weight = torch.randn((hidden_size,), dtype=dtype, device="mlu")
        layernorm_bias = torch.randn((hidden_size,), dtype=dtype, device="mlu")
        up_fc_weight = torch.randn((inner_size, hidden_size), dtype=dtype, device="mlu")
        gate_up_proj_weight = torch.randn((inner_size, hidden_size), dtype=dtype, device="mlu")
        down_proj_weight = torch.randn((hidden_size, inner_size), dtype=dtype, device="mlu")
        up_fc_bias = torch.randn((inner_size,), dtype=dtype, device="mlu")
        gate_up_proj_bias = torch.randn((inner_size,), dtype=dtype, device="mlu")
        down_proj_bias = torch.randn((hidden_size,), dtype=dtype, device="mlu")
        args = (input, up_fc_weight, up_fc_bias, down_proj_weight, down_proj_bias,
                gate_up_proj_weight, gate_up_proj_bias, layernorm_weight, layernorm_bias, act_mode, "normed_input", 1e-5, 1., 0.)
        self.base_opcheck(torch.ops.torch_mlu_ops.ffn, args)
if __name__ == '__main__':
    # Propagate the unittest result as the process exit code.
    raise SystemExit(run_unittest(TestFusedNormResidualFFNoP))

View File

@@ -0,0 +1,213 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from typing import Union, List, Tuple
from torch.nn.parameter import Parameter
from torch import Size
import os
# Dtypes exercised by every test in this file; bf16 only where the device supports it.
dtype_list = [torch.half, torch.float]
if torch_mlu.mlu.is_bf16_supported():
    dtype_list.append(torch.bfloat16)
# LayerNorm epsilon shared by all cases below.
eps = 0.001
class TestFuseLayerNormOp(BtTestCase):
    """Tests for ops.fused_layer_norm: LayerNorm with optional bias add,
    residual add, int8 quantized output, in-place output, and strided
    input/residual/output tensors, compared against an eager reference."""

    def op_impl_base(self, *args):
        """Eager reference for fused_layer_norm.

        Computes layernorm((x + bias) + residual); optionally quantizes to
        int8 with quant_scale, writes into `out` when provided, and returns
        (output, pre-norm input) when store_output_before_norm is True.
        """
        x, residual, gamma, beta, bias, eps, store_output_before_norm, quant_scale, out, dynamic_quant = args
        layernorm = torch.nn.LayerNorm(x.size(-1))
        layernorm.eps = eps
        layernorm.weight = Parameter(gamma)
        layernorm.bias = Parameter(beta)
        x = x + bias if bias is not None else x
        # pro_input is the pre-normalization sum, optionally returned below.
        pro_input = x + residual if residual is not None else x
        output = layernorm(pro_input)
        if quant_scale is not None:
            # Static per-channel int8 quantization with round + clamp.
            output = (output * quant_scale).round().clamp(-128, 127).to(torch.int8)
        if out is None:
            if store_output_before_norm:
                return (output, pro_input)
            else:
                return output
        else:
            out.copy_(output)
            return out

    def test_layernorm(self):
        """Core cases: residual add, in-place output, strided input/residual,
        and strided output, per supported dtype."""
        C = 128
        input_shape = (8, 8, 6, C)
        print("test layernorm...")
        for dtype in dtype_list:
            torch.manual_seed(1)
            input = torch.randn(input_shape, device="mlu", dtype=dtype)
            residual = torch.randn(input_shape, device="mlu", dtype=dtype)
            gamma = torch.randn(C, device="mlu", dtype=dtype)
            beta = torch.randn(C, device="mlu", dtype=dtype)
            tmo_out_0, tmo_out_1 = ops.fused_layer_norm(input, residual, gamma,
                                                        beta, None, eps, True, None, None, False)
            torch_out_0, torch_out_1 = self.op_impl_base(input, residual, gamma, beta, None, eps, True, None, None, False)
            self.assertTensorsEqual(tmo_out_0.cpu().float(), torch_out_0.cpu().float(), 0.005, use_MSE=True)
            self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out_1.cpu().float(), 0.005, use_MSE=True)
            # test output_inplace
            tmo_output = torch.empty(input_shape, device="mlu", dtype=dtype)
            torch_output = torch.empty_like(tmo_output)
            ops.fused_layer_norm(input, residual, gamma, beta, None, eps, False, None, tmo_output, False)
            self.op_impl_base(input, residual, gamma, beta, None, eps, False, None, torch_output, False)
            self.assertTensorsEqual(tmo_output.cpu().float(), torch_output.cpu().float(), 0.003, use_MSE=True)
            # test strided input and contiguous output
            inputs_shape = (8, 8, 10, C)
            inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype)
            input = inputs[:, :, 0:6, :]
            tmo_out = ops.fused_layer_norm(input, None, gamma, beta, None, eps, False, None, None, False)
            torch_out = self.op_impl_base(input, None, gamma, beta, None, eps, False, None, None, False)
            self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.003, use_MSE=True)
            # has res
            tmo_out, tmo_res = ops.fused_layer_norm(input, residual, gamma, beta, None, eps, True, None, None, False)
            torch_out, torch_res = self.op_impl_base(input, residual, gamma, beta, None, eps, True, None, None, False)
            self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.003, use_MSE=True)
            self.assertTensorsEqual(tmo_res.cpu().float(), torch_res.cpu().float(), 0.003, use_MSE=True)
            # res has stride
            res_shape = (8, 8, 16, C)
            res = torch.randn(res_shape, device="mlu", dtype=dtype)
            residual = res[..., 0:6, :]
            tmo_out, tmo_res = ops.fused_layer_norm(input, residual, gamma, beta, None, eps, True, None, None, False)
            torch_out, torch_res = self.op_impl_base(input, residual, gamma, beta, None, eps, True, None, None, False)
            self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.003, use_MSE=True)
            self.assertTensorsEqual(tmo_res.cpu().float(), torch_res.cpu().float(), 0.003, use_MSE=True)
            # output has different stride
            outputs = torch.randn(8, 8, 12, C, dtype=dtype, device='mlu')
            output = outputs[..., 0:6, :]
            ops.fused_layer_norm(input, None, gamma, beta, None, eps, False, None, output, False)
            torch_out = self.op_impl_base(input, None, gamma, beta, None, eps, False, None, None, False)
            self.assertTensorsEqual(output.cpu().float(), torch_out.cpu().float(), 0.003, use_MSE=True)

    def test_squant_layernorm(self):
        """Static-quantization path: layernorm output scaled and rounded to int8."""
        C = 128
        input_shape = (8, 8, 6, C)
        print("test squant_layernorm...")
        for dtype in dtype_list:
            torch.manual_seed(1)
            input = torch.randn(input_shape, device="mlu", dtype=dtype)
            residual = torch.randn(input_shape, device="mlu", dtype=dtype)
            # Large-ish scales so the int8 range is actually exercised.
            quant_scale = torch.randn(C, device="mlu", dtype=torch.float) * 30
            gamma = torch.randn(C, device="mlu", dtype=dtype)
            beta = torch.randn(C, device="mlu", dtype=dtype)
            tmo_out_0, tmo_out_1 = ops.fused_layer_norm(input, residual, gamma, beta, None, eps, True, quant_scale, None, False)
            torch_out_0, torch_out_1 = self.op_impl_base(input, residual, gamma, beta, None, eps, True, quant_scale, None, False)
            self.assertTensorsEqual(tmo_out_0.cpu().float(), torch_out_0.cpu().float(), 0.006, use_MSE=True)
            self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out_1.cpu().float(), 0.006, use_MSE=True)

    def test_layernorm_stride(self):
        """Strided-tensor coverage: strided in-place output, size-1 dims with
        unusual strides, column slices, and 3-d inputs."""
        C = 128
        print("test layernorm stride input...")
        for dtype in dtype_list:
            gamma = torch.randn(C, device="mlu", dtype=dtype)
            beta = torch.randn(C, device="mlu", dtype=dtype)
            inputs_shape = (8, 8, 10, C)
            inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype)
            input = inputs[:, :, 0:6, :]
            # Output shares the input's (strided) layout.
            tmo_output = torch.empty(inputs.shape, device="mlu", dtype=dtype)
            tmo_output = tmo_output.as_strided(input.shape, input.stride())
            torch_out_0 = torch.empty_like(tmo_output)
            self.op_impl_base(input, None, gamma, beta, None, eps, False, None, torch_out_0, False)
            ops.fused_layer_norm(input, None, gamma, beta, None, eps, False, None, tmo_output, False)
            self.assertTensorsEqual(tmo_output.cpu().float(), torch_out_0.cpu().float(), 0.003, use_MSE=True)
            # Regression: a size-1 dim with a bogus stride must not confuse the
            # op's stride handling.
            input = inputs[:, :, 8:9, :]
            input = input.as_strided(input.shape, (10240, 1280, 0, 1))
            torch_out_0 = torch.empty_like(input)
            self.op_impl_base(input, None, gamma, beta, None, eps, False, None, torch_out_0, False)
            ops.fused_layer_norm(input, None, gamma, beta, None, eps, False, None, input, False)
            self.assertTensorsEqual(input.cpu().float(), torch_out_0.cpu().float(), 0.003, use_MSE=True)
            # Column slice: last dim contiguous but narrower than storage.
            inputs_shape = (8, 8, 10, 2*C)
            inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype)
            input1 = inputs[:, :, 0:2, 0:C]
            torch_out_0 = torch.empty_like(input1)
            self.op_impl_base(input1, None, gamma, beta, None, eps, False, None, torch_out_0, False)
            ops.fused_layer_norm(input1, None, gamma, beta, None, eps, False, None, input1, False)
            self.assertTensorsEqual(input1.cpu().float(), torch_out_0.cpu().float(), 0.003, use_MSE=True)
            input2 = inputs[:, :, :, 0:C]
            torch_out_0 = torch.empty_like(input2)
            self.op_impl_base(input2, None, gamma, beta, None, eps, False, None, torch_out_0, False)
            ops.fused_layer_norm(input2, None, gamma, beta, None, eps, False, None, input2, False)
            self.assertTensorsEqual(input2.cpu().float(), torch_out_0.cpu().float(), 0.003, use_MSE=True)
            # test 3-dim input
            inputs_shape = (8, 8, 2*C)
            inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype)
            input1 = inputs[:, 0:2, 0:C]
            torch_out_0 = torch.empty_like(input1)
            self.op_impl_base(input1, None, gamma, beta, None, eps, False, None, torch_out_0, False)
            ops.fused_layer_norm(input1, None, gamma, beta, None, eps, False, None, input1, False)
            self.assertTensorsEqual(input1.cpu().float(), torch_out_0.cpu().float(), 0.003, use_MSE=True)
            input2 = inputs[0:2, 0:2, 0:C]
            torch_out_0 = torch.empty_like(input2)
            self.op_impl_base(input2, None, gamma, beta, None, eps, False, None, torch_out_0, False)
            ops.fused_layer_norm(input2, None, gamma, beta, None, eps, False, None, input2, False)
            self.assertTensorsEqual(input2.cpu().float(), torch_out_0.cpu().float(), 0.0032, use_MSE=True)

    # Guard-rail tests: malformed inputs must raise descriptive errors.
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        """Each call below passes an invalid input and asserts the expected
        error message (via assertException from common_utils)."""
        print("test_prevent....")
        func1 = ops.fused_layer_norm
        batch, seq_len, hidden_size = 5, 12, 512
        input = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        residual = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu')
        quant_scale = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        beta = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        # 1-d input is rejected.
        self.assertException("input.dim() >= 2.",
                             func1, input, None, gamma, beta, None, eps, False, None, None)
        # Non-contiguous last dim is rejected.
        inputs = torch.randn(batch, seq_len, 2*hidden_size, dtype=torch.half, device='mlu')
        input = inputs[..., ::2]
        self.assertException("input last dim must be contiguous.",
                             func1, input, None, gamma, beta, None, eps, False, None, None)
        # Missing beta is rejected in layernorm mode.
        input = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu')
        self.assertException("layernorm mode need gamma and beta.",
                             func1, input, None, gamma, None, None, eps, False, None, None)
        # gamma/beta of the wrong size are rejected.
        gamma = torch.randn(2 * hidden_size, dtype=torch.half, device='mlu')
        beta = torch.randn(2 * hidden_size, dtype=torch.half, device='mlu')
        self.assertException("layernorm mode, gamma and beta size must be hidden_size.",
                             func1, input, None, gamma, beta, None, eps, False, None, None)
        gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        beta = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        # Quantized output combined with strided input is rejected.
        inputs = torch.randn(batch, seq_len, 12, hidden_size, dtype=torch.half, device='mlu')
        input = inputs[..., 6:9, :]
        self.assertException("quant_out is not support when input has stride.",
                             func1, input, None, gamma, beta, None, eps, False, quant_scale, None)
        # Unsupported input strides are rejected.
        inputs = torch.randn(batch, 12, seq_len, hidden_size, dtype=torch.half, device='mlu')
        input = inputs[:, 6:9, ...]
        self.assertException("check the strides of input.",
                             func1, input, None, gamma, beta, None, eps, False, None, input)
        # Unsupported output strides are rejected.
        input = inputs[..., 6:9, :]
        outputs = torch.randn(batch, 2*seq_len, 3, hidden_size, dtype=torch.half, device='mlu')
        output = outputs[:, :seq_len, ...]
        self.assertException("check the strides of output.",
                             func1, input, None, gamma, beta, None, eps, False, None, output)

    def test_inductor(self):
        """Runs torch.library opcheck on the registered fused_layernorm op for
        quantized (with/without residual_out) and non-quantized variants."""
        print("test_inductor....")
        batch, seq_len, hidden_size = 5, 12, 512
        input = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu')
        residual = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu')
        quant_scale = torch.randn(hidden_size, dtype=torch.float, device='mlu')
        gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        beta = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        bias = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        residual_out = torch.empty_like(input)
        output = torch.zeros(input.shape, dtype=torch.int8, device='mlu')
        args = (input, output, residual, gamma, beta, bias, quant_scale, residual_out, None, "layernorm", eps, True, False)
        self.base_opcheck(torch.ops.torch_mlu_ops.fused_layernorm, args)
        args = (input, output, residual, gamma, beta, bias, quant_scale, None, None, "layernorm", eps, False, False)
        self.base_opcheck(torch.ops.torch_mlu_ops.fused_layernorm, args)
        output = torch.empty_like(input)
        args = (input, output, residual, gamma, beta, bias, None, None, None, "layernorm", eps, False, False)
        self.base_opcheck(torch.ops.torch_mlu_ops.fused_layernorm, args)
# Allow running this file directly; propagate the unittest result as the exit code.
if __name__ == '__main__':
    exit(run_unittest(TestFuseLayerNormOp))

View File

@@ -0,0 +1,202 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from typing import Union, List, Tuple
from torch.nn.parameter import Parameter
from torch import Size
import os
# Dtypes exercised by every test below; bf16 only where the device supports it.
dtype_list = [torch.half, torch.float]
if torch_mlu.mlu.is_bf16_supported():
    dtype_list.append(torch.bfloat16)
# Numerical-stability epsilon passed to fused_rms_norm in all tests below.
eps = 0.001
class TestFuseRmsNormOp(BtTestCase):
    """Tests ops.fused_rms_norm against a pure-PyTorch reference.

    Covers contiguous and strided inputs, the optional residual add, optional
    int8-quantized output, in-place/aliased output tensors, error paths, and
    torch.library opchecks for the inductor integration.
    """
    def run_gen_case(self, dic):
        # Replay a dumped case: either forward the stored values directly, or
        # rebuild each tensor argument from its dict description first.
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            x = create_tensor_from_dic(dic['x'])
            residual = None if dic['residual']['data'] is None else create_tensor_from_dic(dic['residual'])
            gamma = None if dic['gamma']['data'] is None else create_tensor_from_dic(dic['gamma'])
            beta = None if dic['beta']['data'] is None else create_tensor_from_dic(dic['beta'])
            bias = None if dic['bias']['data'] is None else create_tensor_from_dic(dic['bias'])
            eps = dic['eps']['data']
            store_output_before_norm = dic['store_output_before_norm']['data']
            quant_scale = None if dic['quant_scale']['data'] is None else create_tensor_from_dic(dic['quant_scale'])
            out = None if dic['out']['data'] is None else create_tensor_from_dic(dic['out'])
            dynamic_quant = dic['dynamic_quant']['data']
            self.launch(x, residual, gamma, beta, bias, eps, store_output_before_norm, quant_scale, out, dynamic_quant)
    def launch(self, *args):
        # Run the fused op and the reference on the same inputs and compare.
        # args layout: (x, residual, gamma, beta, bias, eps,
        #               store_output_before_norm, quant_scale, out, dynamic_quant)
        args = list(args)
        # The fused op may write `out` (args[-2]) in place; give the reference
        # its own buffer, and snapshot x when x aliases out.
        base_out = None if args[-2] is None else torch.empty_like(args[-2])
        base_input = args[0].clone() if args[-2] is not None and args[0] is args[-2] else args[0]
        tmo_out = ops.fused_rms_norm(*args)
        args[0] = base_input
        args[-2] = base_out
        torch_out = self.op_impl_base(*args)
        if tmo_out.__class__ in (list, tuple):
            for o1, o2 in zip(tmo_out, torch_out):
                self.assertTensorsEqual(o1.cpu().float(), o2.cpu().float(), 0.007, use_MSE=True)
        else:
            self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.007, use_MSE=True)
    def op_impl_base(self, *args):
        # Reference RMSNorm: optional bias/residual add, fp32 accumulation,
        # optional per-channel int8 quantization of the normalized output.
        x, residual, gamma, beta, bias, eps, store_output_before_norm, quant_scale, out, dynamic_quant = args
        x = x + bias if bias is not None else x
        pro_input = x + residual if residual is not None else x
        store_input = pro_input  # pre-norm value (optional second output)
        pro_input = pro_input.to(torch.float32)
        variance = pro_input.pow(2).mean(-1, keepdim=True)
        pro_input = pro_input * torch.rsqrt(variance + eps)
        # gamma is assumed non-None here: rmsnorm mode requires it (see test_prevent).
        if gamma.dtype in [torch.float16, torch.bfloat16]:
            pro_input = pro_input.to(gamma.dtype)
        output = gamma * pro_input
        if quant_scale is not None:
            output = (output * quant_scale).round().clamp(-128, 127).to(torch.int8)
        if out is None:
            if store_output_before_norm:
                return (output, store_input)
            else:
                return output
        else:
            out.copy_(output)
            return out
    def test_rmsnorm(self):
        # Basic paths: with/without residual, in-place output, strided views.
        C = 128
        input_shape = (8, 8, 6, C)
        print("test rmsnorm...")
        for dtype in dtype_list:
            torch.manual_seed(1)
            input = torch.randn(input_shape, device="mlu", dtype=dtype)
            residual = torch.randn(input_shape, device="mlu", dtype=dtype)
            gamma = torch.randn(C, device="mlu", dtype=dtype)
            self.launch(input, residual, gamma, None, None, eps, True, None, None, False)
            # test inplace output
            output = torch.empty(input_shape, device="mlu", dtype=dtype)
            self.launch(input, residual, gamma, None, None, eps, False, None, output, False)
            inputs_shape = (8, 8, 10, C)
            inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype)
            input = inputs[:, :, 0:6, :]
            self.launch(input, None, gamma, None, None, eps, False, None, None, False)
            self.launch(input, residual, gamma, None, None, eps, True, None, None, False)
            # residual has stride
            res_shape = (8, 8, 16, C)
            res = torch.randn(res_shape, device="mlu", dtype=dtype)
            residual = res[..., 0:6, :]
            self.launch(input, residual, gamma, None, None, eps, True, None, None, False)
            # output has different stride
            outputs = torch.randn(8, 8, 12, C, dtype=dtype, device='mlu')
            output = outputs[..., 0:6, :]
            self.launch(input, residual, gamma, None, None, eps, False, None, output, False)
    def test_squant_rmsnorm(self):
        # Static per-channel quantized (int8) output path.
        C = 128
        input_shape = (8, 8, 6, C)
        print("test squant_rmsnorm...")
        for dtype in dtype_list:
            torch.manual_seed(1)
            input = torch.randn(input_shape, device="mlu", dtype=dtype)
            residual = torch.randn(input_shape, device="mlu", dtype=dtype)
            quant_scale = torch.randn(C, device="mlu", dtype=torch.float) * 30
            gamma = torch.randn(C, device="mlu", dtype=dtype)
            self.launch(input, residual, gamma, None, None, eps, True, quant_scale, None, False)
            # test one output
            self.launch(input, None, gamma, None, None, eps, False, quant_scale, None, False)
    def test_rmsnorm_stride(self):
        # Strided / as_strided inputs, including degenerate size-1, stride-0 dims.
        C = 128
        print("test rmsnorm stride input...")
        for dtype in dtype_list:
            gamma = torch.randn(C, device="mlu", dtype=dtype)
            inputs_shape = (8, 8, 10, C)
            inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype)
            input = inputs[:, :, 0:6, :]
            output = torch.empty(inputs.shape, device="mlu", dtype=dtype)
            output = output.as_strided(input.shape, input.stride())
            self.launch(input, None, gamma, None, None, eps, False, None, output, False)
            # test assumption get wrong stride when dim = 1
            input = inputs[:, :, 0:1, :]
            input = input.as_strided(input.shape, (10240, 1280, 0, 1))
            self.launch(input, None, gamma, None, None, eps, False, None, input, False)
            inputs_shape = (8, 8, 10, 2*C)
            inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype)
            input1 = inputs[:, :, 0:2, 0:C]
            self.launch(input1, None, gamma, None, None, eps, False, None, input1, False)
            input2 = inputs[:, :, :, 0:C]
            self.launch(input2, None, gamma, None, None, eps, False, None, input2, False)
            # test 3-dim input
            inputs_shape = (8, 8, 2*C)
            inputs = torch.randn(inputs_shape, device="mlu", dtype=dtype)
            input1 = inputs[:, 0:2, 0:C]
            self.launch(input1, None, gamma, None, None, eps, False, None, input1, False)
            input2 = inputs[0:2, 0:2, 0:C]
            self.launch(input2, None, gamma, None, None, eps, False, None, input2, False)
    # Defensive tests: each invalid call must raise with the expected message.
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        func1 = ops.fused_rms_norm
        batch, seq_len, hidden_size = 5, 12, 512
        input = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        quant_scale = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        self.assertException("input.dim() >= 2.",
            func1, input, None, gamma, None, None, eps, False, None, None)
        inputs = torch.randn(batch, seq_len, 2*hidden_size, dtype=torch.half, device='mlu')
        input = inputs[..., ::2]
        self.assertException("input last dim must be contiguous.",
            func1, input, None, gamma, None, None, eps, False, None, None)
        input = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu')
        self.assertException("rmsnorm mode need gamma.",
            func1, input, None, None, None, None, eps, False, None, None)
        gamma = torch.randn(2 * hidden_size, dtype=torch.half, device='mlu')
        self.assertException("rmsnorm mode, gamma size must be hidden_size.",
            func1, input, None, gamma, None, None, eps, False, None, None)
        gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        inputs = torch.randn(batch, seq_len, 12, hidden_size, dtype=torch.half, device='mlu')
        input = inputs[..., 6:9, :]
        self.assertException("quant_out is not support when input has stride.",
            func1, input, None, gamma, None, None, eps, False, quant_scale, None)
        inputs = torch.randn(batch, 12, seq_len, hidden_size, dtype=torch.half, device='mlu')
        input = inputs[:, 6:9, ...]
        self.assertException("check the strides of input.",
            func1, input, None, gamma, None, None, eps, False, None, input)
        input = inputs[..., 6:9, :]
        outputs = torch.randn(batch, 2 * seq_len, 12, hidden_size, dtype=torch.half, device='mlu')
        output = outputs[:, 0:seq_len, 6:9, :]
        self.assertException("check the strides of output.",
            func1, input, None, gamma, None, None, eps, False, None, output)
    def test_inductor(self):
        # torch.library opchecks for the registered fused_layernorm op in
        # "rmsnorm" mode: quantized with/without residual_out, then plain.
        batch, seq_len, hidden_size = 5, 12, 512
        input = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu')
        residual = torch.randn(batch, seq_len, hidden_size, dtype=torch.half, device='mlu')
        quant_scale = torch.randn(hidden_size, dtype=torch.float, device='mlu')
        gamma = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        beta = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        bias = torch.randn(hidden_size, dtype=torch.half, device='mlu')
        residual_out = torch.empty_like(input)
        output = torch.zeros(input.shape, dtype=torch.int8, device='mlu')
        args = (input, output, residual, gamma, beta, bias, quant_scale, residual_out, None, "rmsnorm", eps, True, False)
        self.base_opcheck(torch.ops.torch_mlu_ops.fused_layernorm, args)
        args = (input, output, residual, gamma, beta, bias, quant_scale, None, None, "rmsnorm", eps, False, False)
        self.base_opcheck(torch.ops.torch_mlu_ops.fused_layernorm, args)
        output = torch.empty_like(input)
        args = (input, output, residual, gamma, beta, bias, None, None, None, "rmsnorm", eps, False, False)
        self.base_opcheck(torch.ops.torch_mlu_ops.fused_layernorm, args)
# Allow running this file directly; propagate the unittest result as the exit
# code (previously the result was discarded, so CI saw exit status 0 even on
# failure — now consistent with the sibling op-test files).
if __name__ == '__main__':
    exit(run_unittest(TestFuseRmsNormOp))

View File

@@ -0,0 +1,447 @@
import torch
import unittest
from torch.nn import functional as F
import torch_mlu_ops as ops
from common_utils import *
import random
def genSlotMapping(batch, block_size):
    """Pick one random slot index inside each batch element's block.

    Batch element i owns slots [i*block_size, (i+1)*block_size); one slot is
    drawn uniformly from that range per element.
    """
    return [
        random.randint(i * block_size, (i + 1) * block_size - 1)
        for i in range(batch)
    ]
def rotate(x: torch.Tensor):
    """Rotate-half used by RoPE: split the last dim into [a, b], return [-b, a]."""
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((second_half.neg(), first_half), dim=-1)
def quant(input: torch.Tensor):
    """Symmetric quantization over the last dim to the int4 value range.

    The scale is max(|x|)/7 per row, so rounded codes land in [-7, 7].
    Returns (int8 codes, per-row scale with the kept dim dropped, and the
    un-rounded scaled values in fp32).
    """
    as_fp32 = input.to(torch.float32)
    max_abs = as_fp32.abs().max(dim=-1, keepdim=True).values
    scale = max_abs / 7
    normalized = as_fp32 / scale
    return torch.round(normalized).to(torch.int8), scale[..., 0], normalized
class TestFusedRopeOp(BtTestCase):
    """Tests ops.fused_rope (RoPE on q/k + layernorm on k + kv-cache scatter)
    against a pure-PyTorch reference.

    Naming: "hp" = high-precision cache (fp16/bf16, or int8 when statically
    quantized); "lp" = optional int4-packed low-precision mixed cache with
    per-group dynamic scales. Caches are addressed either linearly (batch row
    + sequence offset, optionally via a discrete batch-id table) or in
    paged-attention style through a slot mapping.

    Fix in this revision: in test_fused_rope the quantized hp v-cache was
    initialized from k_cache_hp by a copy-paste mistake, making the initial
    v cache an exact copy of the k cache; it now uses its own random tensor.
    """
    def op_impl_base(self, *args):
        """Reference implementation. Mutates the cache tensors in place and
        returns (qkv, k_cache_hp, v_cache_hp) plus, in mixed-cache mode,
        (k_cache_lp, v_cache_lp) and then the lp scales."""
        qkv, k_cache_hp, v_cache_hp, sin_cache, cos_cache, position_id, gamma, beta, \
        k_cache_lp, v_cache_lp, cache_bs_id_hp, cache_seq_offsets_hp, cache_bs_id_lp, \
        cache_seq_offsets_lp, k_scale_hp, v_scale_hp, k_scale_lp, v_scale_lp, slot_mapping_hp, \
        slot_mapping_lp, eps = args
        mixed_cache = k_cache_lp is not None and v_cache_lp is not None
        qkv = qkv.to(torch.float32)
        sin_cache = sin_cache.to(torch.float32)
        cos_cache = cos_cache.to(torch.float32)
        gamma = gamma.to(torch.float32)
        beta = beta.to(torch.float32)
        # The op is handed reciprocal hp scales; invert them back here.
        if k_scale_hp is not None:
            k_scale_hp = 1 / k_scale_hp
        if v_scale_hp is not None:
            v_scale_hp = 1 / v_scale_hp
        if not mixed_cache and k_scale_hp is None:
            k_cache_hp = k_cache_hp.to(torch.float32)
            v_cache_hp = v_cache_hp.to(torch.float32)
        discrete_batch_hp = cache_bs_id_hp is not None
        discrete_batch_lp = cache_bs_id_lp is not None
        paged_cache_hp = slot_mapping_hp is not None
        paged_cache_lp = slot_mapping_lp is not None
        head_size = qkv.shape[-1]
        rope_dim = head_size
        batch_size = qkv.shape[0]
        head_qkv = qkv.shape[-2]
        kv_heads = k_cache_hp.shape[1]
        q_heads = head_qkv - 2 * kv_heads
        head_qk = q_heads + kv_heads
        group_num = 1
        group_size = head_size
        if mixed_cache:
            # Unpack the int4 wire-format caches into int8 working copies.
            group_num = k_scale_lp.shape[-1]
            group_size = int(head_size / group_num)
            k_cache_lp_shape = list(k_cache_lp.shape)
            k_cache_lp_shape[-1] *= 2
            k_cache_lp_int8 = UnpackInt4(k_cache_lp).reshape(k_cache_lp_shape)
            v_cache_lp_shape = list(v_cache_lp.shape)
            v_cache_lp_shape.append(2)
            v_cache_lp_int8 = UnpackInt4(v_cache_lp).reshape(v_cache_lp_shape)
        # Apply RoPE to the q and k heads (v heads are untouched).
        qk = qkv[:, :, 0:q_heads + kv_heads].clone()
        for i in range(batch_size):
            qk_i = qk[i]
            sin_cache_i = sin_cache[position_id[i]:position_id[i] + 1]
            cos_cache_i = cos_cache[position_id[i]:position_id[i] + 1]
            sin_cache_i = sin_cache_i[:1]
            cos_cache_i = cos_cache_i[:1]
            rot = rotate(qk_i)
            qk_i[:] = rot * sin_cache_i.unsqueeze(1) + qk_i * cos_cache_i.unsqueeze(1)
        q = qk[:, :, 0:q_heads]
        k = qk[:, :, q_heads:head_qk].contiguous().reshape(batch_size, kv_heads, head_size)
        qkv_q = qkv[:, :, 0:q_heads]
        qkv_q[...] = q
        # Layernorm over each k head before it is written to the cache.
        shape_k = k.shape
        k = torch.reshape(k, (-1, shape_k[-1]))
        k_norm = F.layer_norm(k, (head_size,), gamma, beta, eps)
        k_norm = torch.reshape(k_norm, shape_k)
        k_out = k_norm.reshape(batch_size, 1, kv_heads, head_size)
        v_out = qkv[:, :, head_qk:head_qkv].contiguous()
        k_out_hp = k_out.clone()
        v_out_hp = v_out.clone()
        k_out_lp = None
        v_out_lp = None
        if k_scale_hp is not None and v_scale_hp is not None:
            # Static per-channel int8 quantization of the hp cache entries.
            k_scale_hp = k_scale_hp.reshape(kv_heads, head_size)
            v_scale_hp = v_scale_hp.reshape(kv_heads, head_size)
            k_out_hp = (k_out * k_scale_hp).round().clamp(-128, 127).to(torch.int8)
            v_out_hp = (v_out * v_scale_hp).round().clamp(-128, 127).to(torch.int8)
        if paged_cache_hp:
            block_size = k_cache_hp.shape[2]
            for i in range(batch_size):
                if slot_mapping_hp[i] >= 0:  # negative slot means "skip this token"
                    block_id = torch.div(slot_mapping_hp[i], block_size, rounding_mode='floor')
                    block_offset = slot_mapping_hp[i] % block_size
                    k_cache_hp[block_id, :, block_offset, :] = k_out_hp[i]
                    v_cache_hp[block_id, :, block_offset, :] = v_out_hp[i]
        else:
            for i in range(batch_size):
                key_i = k_out_hp[i].transpose(1, 0)
                value_i = v_out_hp[i].transpose(1, 0)
                cache_bs_id_hp_i = cache_bs_id_hp[i] if discrete_batch_hp else i
                cache_seqlen_offset_hp_i = cache_seq_offsets_hp[i]
                if cache_seqlen_offset_hp_i < 0 or cache_bs_id_hp_i < 0:
                    continue
                key_cache_hp_i = \
                    k_cache_hp[cache_bs_id_hp_i, :, cache_seqlen_offset_hp_i:cache_seqlen_offset_hp_i + 1]
                key_cache_hp_i[...] = key_i[...]
                value_cache_hp_i = \
                    v_cache_hp[cache_bs_id_hp_i, :, cache_seqlen_offset_hp_i:cache_seqlen_offset_hp_i + 1]
                value_cache_hp_i[...] = value_i[...]
        if mixed_cache:
            # Dynamically quantize each token into int4 groups; v entries are
            # stored pairwise-interleaved along the sequence dimension.
            for i in range(batch_size):
                key_i = k_out[i].reshape(kv_heads, group_num, group_size)
                value_i = v_out[i].reshape(kv_heads, group_num, group_size)
                key_i_lp, key_scale_i_lp, _ = quant(key_i)
                value_i_lp, value_scale_i_lp, scaled_input = quant(value_i)
                if paged_cache_lp:
                    block_size = k_cache_lp_int8.shape[2]
                    if slot_mapping_lp[i] >= 0:
                        block_id = torch.div(slot_mapping_lp[i], block_size, rounding_mode='floor')
                        block_offset_k = slot_mapping_lp[i] % block_size
                        block_offset_v = torch.div(block_offset_k, 2, rounding_mode='floor')
                        odd_even_offset = block_offset_k % 2
                        k_cache_lp_int8[block_id, :, block_offset_k, :] = key_i_lp.reshape(kv_heads, head_size)
                        v_cache_lp_int8[block_id, :, block_offset_v, :, odd_even_offset] = value_i_lp.reshape(kv_heads, head_size)
                        k_scale_lp[block_id, :, block_offset_k, :] = key_scale_i_lp[...]
                        v_scale_lp[block_id, :, block_offset_k, :] = value_scale_i_lp[...]
                else:
                    cache_bs_id_lp_i = cache_bs_id_lp[i] if discrete_batch_lp else i
                    cache_seqlen_offset_lp_i = cache_seq_offsets_lp[i]
                    v_seq_offset = torch.div(cache_seqlen_offset_lp_i, 2, rounding_mode='floor')
                    odd_even_offset = cache_seqlen_offset_lp_i % 2
                    if cache_seqlen_offset_lp_i < 0 or cache_bs_id_lp_i < 0:
                        continue
                    key_cache_lp_i = \
                        k_cache_lp_int8[cache_bs_id_lp_i, :, cache_seqlen_offset_lp_i]
                    key_cache_lp_i[...] = key_i_lp.reshape(kv_heads, head_size)
                    key_scale_lp_i = \
                        k_scale_lp[cache_bs_id_lp_i, :, cache_seqlen_offset_lp_i]
                    key_scale_lp_i[...] = key_scale_i_lp[...]
                    value_cache_lp_i = \
                        v_cache_lp_int8[cache_bs_id_lp_i, :, v_seq_offset, :, odd_even_offset]
                    value_cache_lp_i[...] = value_i_lp.reshape(kv_heads, head_size)
                    value_scale_lp_i = \
                        v_scale_lp[cache_bs_id_lp_i, :, cache_seqlen_offset_lp_i]
                    value_scale_lp_i[...] = value_scale_i_lp[...]
        out = (qkv, k_cache_hp, v_cache_hp)
        if mixed_cache:
            # Re-pack the int8 working copies back to the int4 wire layout.
            k_cache_lp = PairlyPackInt8(k_cache_lp_int8.view(-1, head_size)).reshape(k_cache_lp.shape)
            v_cache_lp_int8 = v_cache_lp_int8.transpose(2, 3)
            s0, s1, s2, s3, s4 = v_cache_lp_int8.shape
            v_cache_lp = PairlyPackInt8(v_cache_lp_int8.reshape(-1, s3 * s4)).reshape(s0, s1, s2, s3 * s4 // 2).transpose(2, 3)
            out += (k_cache_lp, v_cache_lp)
        if k_scale_lp is not None:
            out += (k_scale_lp, v_scale_lp)
        return out
    @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "fused_rope not support MLU3XX device")
    def test_fused_rope(self):
        """Randomized sweep over shapes, dtypes, cache layouts and quant modes."""
        random_cases = 100
        random.seed(355)
        eps = 1e-6
        for _ in range(random_cases):
            if random_cases == 1:
                # Debug branch: set random_cases = 1 above to pin one config.
                bs = 512
                seq_len = 1
                q_heads = 8
                kv_heads = 1
                head_size = 128
                rope_dim = 128
                dtype = torch.float16
                need_quant_kv = False
                mixed_cache = False
                max_decode_len_hp = 128
                max_decode_len_lp = 128
                discrete_batch_hp = True
                discrete_batch_lp = False
                paged_cache_hp = False
                paged_cache_lp = True
                block_size_hp = 16
                block_size_lp = 128
                num_blocks_hp = int((bs * max_decode_len_hp + block_size_hp - 1) / block_size_hp)
                num_blocks_lp = int((bs * max_decode_len_lp + block_size_lp - 1) / block_size_lp)
                group_num = 1
            else:
                bs = random.randint(1, 512)
                seq_len = 1
                q_heads = random.randint(1, 32)
                kv_heads = random.randint(1, 32)
                head_size_list = [32, 64, 96, 128]
                head_size = random.choice(head_size_list)
                rope_dim = head_size
                dtype_list = [torch.half, torch.bfloat16]
                bool_list = [True, False]
                block_size_list = [16, 32, 64, 128]
                dtype = random.choice(dtype_list)
                need_quant_kv = random.choice(bool_list)
                mixed_cache = random.choice(bool_list)
                max_decode_len_hp = random.randint(128, 1024)
                max_decode_len_lp = random.randint(128, 1024)
                discrete_batch_hp = random.choice(bool_list)
                discrete_batch_lp = random.choice(bool_list)
                paged_cache_hp = random.choice(bool_list)
                paged_cache_lp = random.choice(bool_list)
                block_size_hp = random.choice(block_size_list)
                num_blocks_hp = int((bs * max_decode_len_hp + block_size_hp - 1) / block_size_hp)
                block_size_lp = random.choice(block_size_list)
                num_blocks_lp = int((bs * max_decode_len_lp + block_size_lp - 1) / block_size_lp)
                group_num_list = [1, 2, 4, 8]
                group_num = random.choice(group_num_list)
            if mixed_cache:
                # Mixed (int4) cache implies quantized hp cache, an even lp
                # decode length (pairwise interleaving), and int4 groups.
                need_quant_kv = True
                group_size = int(head_size / group_num)
                if max_decode_len_lp % 2 != 0:
                    max_decode_len_lp = max_decode_len_lp - 1
            print("bs: {}, seq_len: {}, q_heads: {}, kv_heads: {}, head_size: {}, rope_dim: {}, "
                  "dtype: {}, mixed_cache: {}, quant_kv: {}, paged_cache_hp: {}, paged_cache_lp: {}, "
                  "discrete_batch_hp: {}, discrete_batch_lp: {}."
                  .format(bs, seq_len, q_heads, kv_heads, head_size, rope_dim, dtype, mixed_cache,
                          need_quant_kv, paged_cache_hp, paged_cache_lp, discrete_batch_hp, discrete_batch_lp))
            if mixed_cache:
                print("max_decode_len_hp: {}, max_decode_len_lp: {}, num_blocks_hp: {}, num_blocks_lp: {}, "
                      "block_size_hp: {}, block_size_lp: {}, group_num: {}, testing..."
                      .format(max_decode_len_hp, max_decode_len_lp, num_blocks_hp, num_blocks_lp,
                              block_size_hp, block_size_lp, group_num))
            max_bs_hp = bs + 1 if discrete_batch_hp else bs
            max_bs_lp = bs + 1 if discrete_batch_lp else bs
            cache_size = 1 if need_quant_kv else 2
            # Skip configurations whose cache byte count overflows int32.
            cache_bytes_hp = num_blocks_hp * kv_heads * block_size_hp * head_size if paged_cache_hp else \
                max_bs_hp * kv_heads * max_decode_len_hp * head_size
            cache_bytes_hp = cache_bytes_hp * cache_size
            cache_bytes_lp = 0
            if mixed_cache:
                cache_bytes_lp = num_blocks_lp * kv_heads * block_size_lp * head_size if paged_cache_lp else \
                    max_bs_lp * kv_heads * max_decode_len_lp * head_size
            if cache_bytes_hp > 2**31 or cache_bytes_lp > 2**31:
                print("cache bytes can not be larger than int32max. ")
                continue
            input_shape = (bs, seq_len, q_heads + 2 * kv_heads, head_size)
            input = torch.randn(size=input_shape, dtype=dtype).mlu()
            if paged_cache_hp:
                cache_shape_hp = (num_blocks_hp, kv_heads, block_size_hp, head_size)
            else:
                cache_shape_hp = (max_bs_hp, kv_heads, max_decode_len_hp, head_size)
            if need_quant_kv:
                k_cache_hp = torch.randn(cache_shape_hp, dtype=dtype, device='mlu')
                k_cache_hp = (k_cache_hp - 0.5) * 256
                k_cache_hp = k_cache_hp.to(torch.int8)
                v_cache_hp = torch.randn(cache_shape_hp, dtype=dtype, device='mlu')
                # Fix: build the initial v cache from its own random tensor
                # (a copy-paste bug previously reused k_cache_hp here, so the
                # v cache started as an exact copy of the k cache).
                v_cache_hp = (v_cache_hp - 0.5) * 256
                v_cache_hp = v_cache_hp.to(torch.int8)
                k_scale_ops_hp = 1 / (torch.randn(size=(kv_heads, head_size), dtype=torch.float).abs().mlu() + 0.01)
                v_scale_ops_hp = 1 / (torch.randn(size=(kv_heads, head_size), dtype=torch.float).abs().mlu() + 0.01)
            else:
                k_cache_hp = torch.randn(cache_shape_hp, dtype=dtype, device='mlu')
                v_cache_hp = torch.randn(cache_shape_hp, dtype=dtype, device='mlu')
                k_scale_ops_hp = None
                v_scale_ops_hp = None
            k_cache_lp = None
            v_cache_lp = None
            k_scale_lp = None
            v_scale_lp = None
            if mixed_cache:
                # Pre-fill the int4-packed lp caches and their scale tensors.
                if paged_cache_lp:
                    k_scale_lp = torch.randn(size=(num_blocks_lp, kv_heads, block_size_lp, group_num), dtype=torch.float).mlu()
                    v_scale_lp = torch.randn(size=(num_blocks_lp, kv_heads, block_size_lp, group_num), dtype=torch.float).mlu()
                    cache_raw = torch.randn((num_blocks_lp * kv_heads * block_size_lp, head_size), dtype=dtype, device='mlu')
                    max_value = torch.amax(torch.abs(cache_raw))
                    cache_raw = cache_raw * (7 / max_value)
                    cache_raw = cache_raw.to(torch.int8)
                    k_cache_lp = PairlyPackInt8(cache_raw).reshape(num_blocks_lp, kv_heads, block_size_lp, int(head_size / 2))
                    cache_raw = torch.randn((int(num_blocks_lp * kv_heads * block_size_lp / 2), head_size * 2), dtype=dtype, device='mlu')
                    max_value = torch.amax(torch.abs(cache_raw))
                    cache_raw = cache_raw * (7 / max_value)
                    cache_raw = cache_raw.to(torch.int8)
                    v_cache_lp = PairlyPackInt8(cache_raw).reshape(num_blocks_lp, kv_heads, int(block_size_lp / 2), head_size)
                    k_cache_lp_ref_shape = (num_blocks_lp, kv_heads, block_size_lp, head_size)
                    v_cache_lp_ref_shape = (num_blocks_lp, kv_heads, int(block_size_lp / 2), head_size, 2)
                else:
                    k_scale_lp = torch.randn(size=(max_bs_lp, kv_heads, max_decode_len_lp, group_num), dtype=torch.float).mlu()
                    v_scale_lp = torch.randn(size=(max_bs_lp, kv_heads, max_decode_len_lp, group_num), dtype=torch.float).mlu()
                    cache_raw = torch.randn((max_bs_lp * kv_heads * max_decode_len_lp, head_size), dtype=dtype, device='mlu')
                    max_value = torch.amax(torch.abs(cache_raw))
                    cache_raw = cache_raw * (7 / max_value)
                    cache_raw = cache_raw.to(torch.int8)
                    k_cache_lp = PairlyPackInt8(cache_raw).reshape(max_bs_lp, kv_heads, max_decode_len_lp, int(head_size / 2))
                    cache_raw = torch.randn((int(max_bs_lp * kv_heads * max_decode_len_lp / 2), head_size * 2), dtype=dtype, device='mlu')
                    max_value = torch.amax(torch.abs(cache_raw))
                    cache_raw = cache_raw * (7 / max_value)
                    cache_raw = cache_raw.to(torch.int8)
                    v_cache_lp = PairlyPackInt8(cache_raw).reshape(max_bs_lp, kv_heads, int(max_decode_len_lp / 2), head_size)
                    k_cache_lp_ref_shape = (max_bs_lp, kv_heads, max_decode_len_lp, head_size)
                    v_cache_lp_ref_shape = (max_bs_lp, kv_heads, int(max_decode_len_lp / 2), head_size, 2)
                del cache_raw
            cache_bs_id_hp = None
            cache_bs_id_lp = None
            cache_seq_offsets_hp = None
            cache_seq_offsets_lp = None
            slot_mapping_hp = None
            slot_mapping_lp = None
            if not paged_cache_hp:
                if discrete_batch_hp:
                    cache_bs_id_hp = random.sample([*range(0, max_bs_hp)], bs)
                    cache_bs_id_hp = torch.IntTensor(cache_bs_id_hp).mlu()
                cache_seq_offsets_hp = torch.randint(size=(bs, ), low=-1, high=max_decode_len_hp - 2,
                                                     dtype=torch.int32, device='mlu')
            else:
                slot_mapping_hp = random.sample([*range(-1, block_size_hp * num_blocks_hp)], bs)
                slot_mapping_hp = torch.IntTensor(slot_mapping_hp).mlu()
            # Clone everything the reference mutates so the op sees fresh data.
            input_ref = input.clone()
            k_cache_hp_ref = k_cache_hp.clone()
            v_cache_hp_ref = v_cache_hp.clone()
            k_cache_lp_ref = None
            v_cache_lp_ref = None
            k_scale_lp_ref = None
            v_scale_lp_ref = None
            if mixed_cache:
                if not paged_cache_lp:
                    if discrete_batch_lp:
                        cache_bs_id_lp = random.sample([*range(0, max_bs_lp)], bs)
                        cache_bs_id_lp = torch.IntTensor(cache_bs_id_lp).mlu()
                    cache_seq_offsets_lp = torch.randint(size=(bs, ), low=-1, high=max_decode_len_lp - 2,
                                                         dtype=torch.int32, device='mlu')
                else:
                    slot_mapping_lp = genSlotMapping(bs, block_size_lp)
                    slot_mapping_lp = torch.IntTensor(slot_mapping_lp).mlu()
                k_cache_lp_ref = k_cache_lp.clone()
                v_cache_lp_ref = v_cache_lp.clone()
                k_scale_lp_ref = k_scale_lp.clone()
                v_scale_lp_ref = v_scale_lp.clone()
            cos_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
            sin_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
            gamma = torch.randn(size=(head_size, ), dtype=dtype).mlu()
            beta = torch.randn(size=(head_size, ), dtype=dtype).mlu()
            position_id = torch.randint(size=(bs, ), low=0, high=bs, dtype=torch.int32, device='mlu')
            notify_start = torch.mlu.Event(enable_timing=True)
            notify_end = torch.mlu.Event(enable_timing=True)
            notify_start.record()
            base_output = self.op_impl_base(input_ref, k_cache_hp_ref, v_cache_hp_ref, sin_table, \
                                            cos_table, position_id, gamma, beta, \
                                            k_cache_lp_ref, v_cache_lp_ref, cache_bs_id_hp, cache_seq_offsets_hp, \
                                            cache_bs_id_lp, cache_seq_offsets_lp, k_scale_ops_hp, v_scale_ops_hp, \
                                            k_scale_lp_ref, v_scale_lp_ref, slot_mapping_hp, slot_mapping_lp, eps)
            notify_end.record()
            notify_end.synchronize()
            time = notify_start.hardware_time(notify_end)
            print("baseline hw_time is: ", time, "us")
            del input_ref, k_cache_hp_ref, v_cache_hp_ref, k_cache_lp_ref, v_cache_lp_ref, k_scale_lp_ref, v_scale_lp_ref
            notify_start.record()
            loop = 1
            for _ in range(loop):
                ops.fused_rope(input, k_cache_hp, v_cache_hp, sin_table, cos_table, position_id, \
                               gamma, beta, k_cache_lp, v_cache_lp, cache_bs_id_hp, cache_seq_offsets_hp, \
                               cache_bs_id_lp, cache_seq_offsets_lp, k_scale_ops_hp, v_scale_ops_hp, \
                               k_scale_lp, v_scale_lp, slot_mapping_hp, slot_mapping_lp, eps)
            notify_end.record()
            notify_end.synchronize()
            time = notify_start.hardware_time(notify_end) / loop
            print("hw time is: ", time, "us")
            print("check input diff \n")
            self.assertTensorsEqual(input.cpu().float(), base_output[0].cpu().float(), 0.003, use_MSE=True)
            print("pass \n")
            print("check key cache hp diff \n")
            self.assertTensorsEqual(k_cache_hp.cpu().float(), base_output[1].cpu().float(), 0.003, use_MSE=True)
            print("pass \n")
            print("check value cache hp diff \n")
            self.assertTensorsEqual(v_cache_hp.cpu().float(), base_output[2].cpu().float(), 0.003, use_MSE=True)
            print("pass \n")
            if mixed_cache:
                # Compare the int4 caches after unpacking; codes may differ by
                # at most one quantization step (tolerance 1).
                k_cache_lp_int8 = UnpackInt4(k_cache_lp).reshape(k_cache_lp_ref_shape)
                v_cache_lp_int8 = UnpackInt4(v_cache_lp).reshape(v_cache_lp_ref_shape)
                k_cache_lp_ref_int8 = UnpackInt4(base_output[3]).reshape(k_cache_lp_ref_shape)
                v_cache_lp_ref_int8 = UnpackInt4(base_output[4]).reshape(v_cache_lp_ref_shape)
                print("check key cache lp diff \n")
                self.assertTensorsEqual(k_cache_lp_int8.cpu().float(), k_cache_lp_ref_int8.cpu().float(), 1)
                print("pass \n")
                print("check key scale lp diff \n")
                self.assertTensorsEqual(k_scale_lp.cpu().float(), base_output[5].cpu().float(), 0.003, use_MSE=True)
                print("pass \n")
                print("check value cache lp diff \n")
                self.assertTensorsEqual(v_cache_lp_int8.cpu().float(), v_cache_lp_ref_int8.cpu().float(), 1)
                print("pass \n")
                print("check value scale lp diff \n")
                self.assertTensorsEqual(v_scale_lp.cpu().float(), base_output[6].cpu().float(), 0.003, use_MSE=True)
                print("pass \n")
    @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "fused_rope not support MLU3XX device")
    def test_inductor(self):
        """torch.library opcheck for the registered fused_rope op (linear
        cache, discrete batch ids, no quantization)."""
        bs, seq_len, q_heads, kv_heads, head_size, rope_dim, max_decode_len, dtype = 40, 1, 8, 1, 128, 128, 2048, torch.bfloat16
        input_shape = (bs, seq_len, q_heads + 2 * kv_heads, head_size)
        input = torch.randn(size=input_shape, dtype=dtype).mlu()
        max_bs = bs + 1
        cos_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
        sin_table = torch.randn(size=(bs, rope_dim), dtype=dtype).mlu()
        gamma = torch.randn(size=(head_size, ), dtype=dtype).mlu()
        beta = torch.randn(size=(head_size, ), dtype=dtype).mlu()
        cache = torch.randn((2, max_bs, kv_heads, max_decode_len, head_size), dtype=dtype, device='mlu')
        k_cache = cache[0]
        v_cache = cache[1]
        cache_bs_id = random.sample([*range(0, max_bs)], bs)
        cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
        cache_seq_offsets = torch.randint(size=(bs, ), low=-1, high=max_decode_len - 2,
                                          dtype=torch.int32, device='mlu')
        position_id = torch.randint(size=(bs, ), low=0, high=bs, dtype=torch.int32, device='mlu')
        args = (input, k_cache, v_cache, None, None, sin_table, cos_table, position_id, gamma, beta, None,
                None, None, None, cache_bs_id, cache_seq_offsets, None, None, None, None, 1e-5)
        self.base_opcheck(torch.ops.torch_mlu_ops.fused_rope, args)
# Allow running this file directly; propagate the unittest result as the exit code.
if __name__ == "__main__":
    exit(run_unittest(TestFusedRopeOp))

View File

@@ -0,0 +1,122 @@
import torch
from torch_mlu import mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from itertools import product
from torch.nn import functional as F
def gen_params(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias=False):
    """Build one random group-gemm case on the MLU device.

    Assigns each of batch*seq*topk tokens a random expert, sorts tokens by
    expert, and derives the per-expert token counts plus the gather index that
    maps sorted rows back to source tokens. When idx_mode is False the
    activation `a` is pre-gathered and no index is returned; when True, `a`
    stays ungathered and the index is handed to the op.

    Returns the 11-tuple (a, b, token_count, gather_idx_or_None, c, alpha,
    beta, a_scale, b_scale, data_type, bias) expected by the tests; a_scale
    and b_scale are always None (no quantization in these cases).
    """
    bs = batch * seq
    token_topk = bs * topk
    expert_id = torch.randint(experts_num, (token_topk,), device="mlu")
    sorted_expert_id, indices = expert_id.sort()
    gather_idx = (indices // topk).to(torch.int32)
    token_count = torch.bincount(sorted_expert_id, minlength=experts_num).to(torch.int32)
    a = torch.randn(bs, k, device="mlu", dtype=data_type)
    if not idx_mode:
        a = a[gather_idx]
    b = torch.randn(experts_num, n, k, device="mlu", dtype=data_type)
    c = torch.randn(token_topk, n, device="mlu", dtype=data_type)
    alpha = torch.randn(experts_num, device="mlu", dtype=torch.float32)
    beta = torch.randn(experts_num, device="mlu", dtype=torch.float32)
    bias = torch.randn(experts_num, n, device="mlu", dtype=data_type) if has_bias else None
    return (a, b, token_count, gather_idx if idx_mode else None, c, alpha, beta,
            None, None, data_type, bias)
class TestGroupGemmOp(BtTestCase):
    """Tests ops.group_gemm (per-expert GEMM with optional gather index,
    alpha/beta scaling and bias) against a looped torch.matmul reference."""
    def run_gen_case(self, dic):
        # Replay a dumped case: either forward the stored values directly, or
        # rebuild each tensor argument from its dict description first.
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            a = create_tensor_from_dic(dic['a'])
            b = create_tensor_from_dic(dic['b'])
            m_list = dic['m_list']['data']
            expand_idx = dic['expand_idx']['data']
            c = None if dic['c']['data'] is None else create_tensor_from_dic(dic['c'])
            alpha = None if dic['alpha']['data'] is None else create_tensor_from_dic(dic['alpha'])
            beta = None if dic['beta']['data'] is None else create_tensor_from_dic(dic['beta'])
            max_m = dic['max_m']['data']
            bias = None if dic['bias']['data'] is None else create_tensor_from_dic(dic['bias'])
            self.launch(a, b, m_list, expand_idx, c, alpha, beta, max_m, bias)
    def launch(self, *args):
        # Compare only the first total_m rows: rows past the valid token count
        # are intentionally left uninitialized by both implementations.
        total_m = args[2].sum().item()
        torch_out = self.op_impl_base(*args)
        tmo_out = ops.group_gemm(*args)
        self.assertTensorsEqual(tmo_out.cpu().float()[0:total_m], torch_out.cpu().float()[0:total_m], 0.006, use_MSE=True)
    def op_impl_base(self, *args):
        # Reference: split `a` (and optionally `c`) into per-expert row groups
        # by m_list, then for each expert i compute
        #   alpha[i] * (a_i @ b[i].T) + bias[i] + beta[i] * c_i.
        a, b, m_list, expand_idx, c, alpha, beta, max_m, bias = args
        a = a.reshape(-1, a.size(-1))
        if expand_idx is not None:
            a = a[expand_idx]  # gather source tokens into expert-sorted order
        total_m = m_list.sum().item()
        a_list = a[:total_m].split(tuple(m_list))
        c_list = []
        if c is not None:
            c = c.reshape(-1, c.size(-1))
            c_list = c[:total_m].split(tuple(m_list))
        output_list = []
        for i in range(b.size(0)):  # alpha*(a*b) + bias + beta*c
            if (a_list[i].size(0) > 0):  # skip experts with no tokens
                gemm_out = torch.matmul(a_list[i], b[i].permute(1, 0))
                if alpha is not None:
                    gemm_out *= alpha[i]
                if bias is not None:
                    gemm_out += bias[i]
                if beta is not None and c_list != []:
                    gemm_out += c_list[i] * beta[i]
                output_list.append(gemm_out)
        real_res = torch.cat(output_list, dim=0)
        # Rows beyond total_m stay uninitialized (matches the fused op).
        output = torch.empty(a.shape[0], b.shape[1], device=real_res.device).to(real_res.dtype)
        output[:total_m] = real_res
        return output
    def test_group_gemm(self):
        # Full sweep over shapes, expert counts, dtypes, index mode and bias.
        bs_list = [1, 3]
        seq_list = [5, 8]
        k_list = [512, 1024]
        n_list = [512, 768, 2048]
        expert_list = [8, 32]
        topk_list = [2, 5]
        dtype_list = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half]
        idx_list = [False] if mlu.get_device_name() == 'MLU370' else [False, True]
        has_bias_list = [True, False]
        args = product(bs_list, seq_list, k_list, n_list, expert_list, topk_list, dtype_list, idx_list, has_bias_list)
        for batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias in args:
            print(f"bs: {batch}, seq_len: {seq}, k: {k}, n: {n}, experts_num: {experts_num}, topk: {topk}, \
            dtype: {data_type}, idx_mode: {idx_mode}, has_bias: {has_bias} testing...", flush=True)
            param = gen_params(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias)
            # param[7:9] (a_scale/b_scale) are dropped; max_m = batch*seq, bias = param[10].
            torch_out = self.op_impl_base(*param[:7], batch * seq, param[10])
            tmo_out = ops.group_gemm(*param[:7], batch * seq, param[10])
            self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
    def test_inductor(self):
        # torch.library opcheck for the registered group_gemm op; rebuild the
        # gen_params tuple into the registered-op argument order.
        batch, seq, k, n, experts_num, topk = 1, 8, 1024, 2048, 8, 5
        dtype_list = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half]
        idx_list = [False] if mlu.get_device_name() == 'MLU370' else [False, True]
        has_bias_list = [True, False]
        args = product(dtype_list, idx_list, has_bias_list)
        for data_type, idx_mode, has_bias in args:
            args = gen_params(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias)
            args = list(args)
            args[-2] = args[-1]  # move bias into the slot the op expects
            args[-1] = None  # drop the trailing dtype entry
            args.extend([None, None, batch * seq])  # b_offset placeholders + max_m
            self.base_opcheck(torch.ops.torch_mlu_ops.group_gemm, args)
# Script entry point: run the suite and propagate its status to the shell.
if __name__ == '__main__':
    raise SystemExit(run_unittest(TestGroupGemmOp))

View File

@@ -0,0 +1,193 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from itertools import product
import numpy as np
class TestMatMulOp(BtTestCase):
    """Tests for ops.matmul (fused GEMM with bias/residual/activation epilogue)."""
    def op_impl_base(self, *args):
        """Pure-PyTorch reference: act(alpha*(a @ b) + bias + beta*c).

        a_scale/b_scale de-quantize (divide) the operands before the matmul;
        trans_a/trans_b transpose the 2-D operands. `fast_act`, `approximate`
        and `d` are accepted for signature parity with ops.matmul but unused.
        """
        a, b, bias, c, act_mode, alpha, beta, fast_act, approximate, d, \
        a_scale, b_scale, trans_a, trans_b = args
        if a_scale is not None:
            a = a / a_scale
        if b_scale is not None:
            b = b / b_scale
        if trans_a:
            a = a.transpose(0, 1)
        if trans_b:
            b = b.transpose(0, 1)
        mul_out = alpha * torch.matmul(a, b)
        if bias is not None:
            mul_out += bias
        if c is not None:
            mul_out += beta * c
        if act_mode in act_mode_dict.keys():
            # activation is computed in fp32 then cast back for accuracy parity
            active = act_mode_dict[act_mode]
            mul_out = active(mul_out.float()).to(a.dtype)
        return mul_out
    def test_matmul(self):
        """Float/half path; inputs are strided slices to also cover non-contiguous layouts."""
        mat_m_list = [32]
        mat_n_list = [256]
        mat_k_list = [128]
        has_res_list = [ False, True]
        has_bias_list = [True, False]
        trans_a_list = [False, True]
        trans_b_list = [False, True]
        act_mode_list = ['none', 'relu', 'gelu', 'silu']
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        alpha = 0.625
        beta = 1.0
        args = product( mat_m_list, mat_n_list, mat_k_list, has_bias_list, has_res_list, act_mode_list, dtype_list, trans_a_list, trans_b_list)
        for mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b in args:
            torch.manual_seed(1)
            print("m={}, n={}, k={}, has_bias={}, has_res={}, act_mode={}, dtype={}, trans_a={}, trans_b={} testing...".format(
                mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b), flush=True)
            if has_res :
                beta = 1.0
            else :
                beta = 0.
            # allocate 3-D buffers and slice the middle dim -> strided 2-D views
            shape_a, shape_b = (mat_m, 4, mat_k), (mat_k, 3, mat_n)
            if trans_a:
                shape_a = (mat_k, 4, mat_m)
            if trans_b:
                shape_b = (mat_n, 3, mat_k)
            input0 = torch.randn(shape_a, dtype=dtype, device='mlu')
            weight0 = torch.randn(shape_b, dtype=dtype, device='mlu')
            input = input0[:, 1, :]
            weight = weight0[:, 0, :]
            bias = torch.randn((mat_n), dtype=dtype, device='mlu')
            residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')
            output = self.op_impl_base(input,
                                       weight,
                                       alpha * bias if has_bias else None,
                                       residual if has_res else None,
                                       act_mode, alpha, beta, False, False, None, 1.0, 1.0, trans_a, trans_b)
            tmo_output = ops.matmul(input,
                                    weight,
                                    alpha * bias if has_bias else None,
                                    residual if has_res else None,
                                    act_mode, alpha, beta, False, False, None, 1.0, 1.0, trans_a, trans_b)
            tmo_output_contiguous = ops.matmul(input.contiguous(), weight.contiguous(),
                                               alpha * bias if has_bias else None,
                                               residual if has_res else None,
                                               act_mode, alpha, beta, False, False, None, 1.0, 1.0, trans_a, trans_b)
            if act_mode == 'gelu':
                # also cover the high-precision (approximate=True) gelu path
                tmo_output_high = ops.matmul(input.contiguous(), weight.contiguous(),
                                             alpha * bias if has_bias else None,
                                             residual if has_res else None,
                                             act_mode, alpha, beta, False, True, None, 1.0, 1.0, trans_a, trans_b)
                self.assertTensorsEqual(tmo_output_high.cpu().float(), output.cpu().float(),
                                        0.004, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(tmo_output_contiguous.cpu().float(), output.cpu().float(),
                                    0.004, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(tmo_output.cpu().float(), output.cpu().float(),
                                    0.004, use_MSE=True, use_RAE=True)
    # @unittest.skip("not test")
    def test_matmul_int8(self):
        """int8-quantized operands with per-tensor scales; output cast to `dtype`."""
        mat_m_list = [32]
        mat_n_list = [256]
        mat_k_list = [128]
        has_res_list = [True, False]
        has_bias_list = [True, False]
        trans_a_list = [True, False]
        trans_b_list = [True, False]
        act_mode_list = ['none', 'relu', 'silu', 'gelu']
        dtype_list = [torch.half, torch.float]
        alpha = 0.625
        beta = 1.0
        args = product( mat_m_list, mat_n_list, mat_k_list, has_bias_list, has_res_list, act_mode_list, dtype_list, trans_a_list, trans_b_list)
        for mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b in args:
            print("int8 test: m={}, n={}, k={}, has_bias={}, has_res={}, act_mode={}, dtype={}, trans_a={}, trans_b={} testing...".format(
                mat_m, mat_n, mat_k, has_bias, has_res, act_mode, dtype, trans_a, trans_b), flush=True)
            torch.manual_seed(1)
            if has_res :
                beta = 1.0
            else :
                beta = 0.
            shape_a, shape_b = (mat_m, 4, mat_k), (mat_k, 3, mat_n)
            if trans_a:
                shape_a = (mat_k, 4, mat_m)
            if trans_b:
                shape_b = (mat_n, 3, mat_k)
            input0 = torch.randn(shape_a, dtype=dtype, device='mlu')
            weight0 = torch.randn(shape_b, dtype=dtype, device='mlu')
            input = input0[:, 1, :]
            weight = weight0[:, 0, :]
            # per-tensor 8-bit quantization (helper from common_utils)
            input8, a_scale = QuantByTensor(input, 8)
            weight8, b_scale = QuantByTensor(weight, 8)
            bias = torch.randn((mat_n), dtype=dtype, device='mlu')
            residual = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')
            output = self.op_impl_base(input8, weight8,
                                       alpha * bias if has_bias else None,
                                       residual if has_res else None,
                                       act_mode, alpha, beta, False, False, None, a_scale, b_scale, trans_a, trans_b)
            tmo_output = ops.matmul(input8, weight8,
                                    alpha * bias if has_bias else None,
                                    residual if has_res else None,
                                    act_mode, alpha, beta, False, False, dtype, a_scale, b_scale, trans_a, trans_b)
            tmo_output_contiguous = ops.matmul(input8.contiguous(), weight8.contiguous(),
                                               alpha * bias if has_bias else None,
                                               residual if has_res else None,
                                               act_mode, alpha, beta, False, False, dtype, a_scale, b_scale, trans_a, trans_b)
            self.assertTensorsEqual(tmo_output.cpu().float(), output.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(tmo_output_contiguous.cpu().float(), output.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
            if act_mode == 'gelu':
                tmo_output_high = ops.matmul(input8.contiguous(), weight8.contiguous(),
                                             alpha * bias if has_bias else None,
                                             residual if has_res else None,
                                             act_mode, alpha, beta, False, True, dtype, a_scale, b_scale, trans_a, trans_b)
                self.assertTensorsEqual(tmo_output_high.cpu().float(), output.cpu().float(),
                                        0.003, use_MSE=True, use_RAE=True)
    def test_inductor(self):
        """opcheck the raw torch.ops entry point for both float and int8 forms."""
        mat_m, mat_n, mat_k, alpha, beta, act_mode, fast_act, approximate = 32, 256, 128, 0.8, 0.3, 'silu', True, True
        trans_a_list = [True, False]
        trans_b_list = [True, False]
        dtype_list = [torch.half, torch.float]
        args = product(trans_a_list, trans_b_list, dtype_list)
        for trans_a, trans_b, dtype in args:
            print("trans_a: {}, trans_b: {}, dtype: {}".format(trans_a, trans_b, dtype))
            shape_a, shape_b = (mat_m, 4, mat_k), (mat_k, 3, mat_n)
            if trans_a:
                shape_a = (mat_k, 4, mat_m)
            if trans_b:
                shape_b = (mat_n, 3, mat_k)
            input0 = torch.randn(shape_a, dtype=dtype, device='mlu')
            weight0 = torch.randn(shape_b, dtype=dtype, device='mlu')
            a = input0[:, 1, :]
            b = weight0[:, 0, :]
            bias = torch.randn((mat_n), dtype=dtype, device='mlu')
            c = torch.randn((mat_m, mat_n), dtype=dtype, device='mlu')
            args = (a, b, None, bias, c, None, act_mode, alpha, beta, fast_act, approximate, 1.0, 1.0, trans_a, trans_b)
            self.base_opcheck(torch.ops.torch_mlu_ops.matmul, args)
            # int8 form: raw op takes output dtype as a string
            a8, a_scale = QuantByTensor(a, 8)
            b8, b_scale = QuantByTensor(b, 8)
            str_dtype = "half"
            if dtype == torch.float:
                str_dtype = "float"
            elif dtype == torch.bfloat16:
                str_dtype = "bfloat16"
            args = (a8, b8, None, bias, c, str_dtype, act_mode, alpha, beta, fast_act, approximate, a_scale, b_scale, trans_a, trans_b)
            self.base_opcheck(torch.ops.torch_mlu_ops.matmul, args)
# Script entry point: run the suite and propagate its status to the shell.
if __name__ == '__main__':
    raise SystemExit(run_unittest(TestMatMulOp))

View File

@@ -0,0 +1,946 @@
import torch
import unittest
import torch_mlu_ops as ops
from common_utils import *
from typing import Union, List, Tuple
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from torch_mlu import mlu
class TestFusedMOEOp(BtTestCase):
    """Tests for ops.fused_moe (router + grouped expert FFN, optional smooth-quant)."""
    def run_gen_case(self, dic):
        """Replay one dumped case description.

        If `dump_data` is truthy the dict already holds live launch() arguments
        in order; otherwise each tensor/scalar is rebuilt from its description.
        """
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            hidden_states = create_tensor_from_dic(dic['hidden_states'])
            gating_output = create_tensor_from_dic(dic['gating_output'])
            w1 = create_tensor_from_dic(dic['w1'])
            w2 = create_tensor_from_dic(dic['w2'])
            if dic['w1']['dtype'] is not torch.int8:
                # shrink float weights to keep activations numerically tame
                w1 *= 0.1
                w2 *= 0.1
            bias1 = None if dic['bias1']['data'] is None else create_tensor_from_dic(dic['bias1'])
            bias2 = None if dic['bias2']['data'] is None else create_tensor_from_dic(dic['bias2'])
            residual = None if dic['residual']['data'] is None else create_tensor_from_dic(dic['residual'])
            # smooth/scale tensors use narrow uniform ranges to avoid overflow
            input_smooth = None if dic['input_smooth']['data'] is None else create_tensor_from_dic(dic['input_smooth'], is_uniform=True, low=0.01, high=0.05)
            act_smooth = None if dic['act_smooth']['data'] is None else create_tensor_from_dic(dic['act_smooth'], is_uniform=True, low=0.01, high=0.05)
            w1_scale = None if dic['w1_scale']['data'] is None else create_tensor_from_dic(dic['w1_scale'], is_uniform=True, low=-0.05, high=0.05)
            w2_scale = None if dic['w2_scale']['data'] is None else create_tensor_from_dic(dic['w2_scale'], is_uniform=True, low=-0.05, high=0.05)
            topk = dic['topk']['data']
            renormalize = dic['renormalize']['data']
            gated = dic['gated']['data']
            act_mode = dic['act_mode']['data']
            start_expert_id = dic['start_expert_id']['data']
            block_n = dic['block_n']['data']
            cncl_comm = dic['cncl_comm']['data']
            w1_quant_flag = dic['w1_quant_flag']['data']
            w2_quant_flag = dic['w2_quant_flag']['data']
            self.launch(hidden_states, gating_output, w1, w2, bias1, bias2, residual, input_smooth,
                        act_smooth, w1_scale, w2_scale, topk, renormalize, gated, act_mode,
                        start_expert_id, block_n, 0, w1_quant_flag, w2_quant_flag)
def launch(self, *args):
base = self.op_impl_base(*args)
tmo_res = tmo.fused_moe(*args)
self.assertTensorsEqual(tmo_res.cpu().float(), base.cpu().float(), 0.03, use_MSE=True)
    def op_impl_base(self, *args):
        """Pure-PyTorch fused-MoE reference.

        Routes tokens with softmax+topk, runs each expert's two GEMMs
        (optionally smooth-quant int8 or mixed w4/w8), applies the activation
        (gated or not), then scatters results back and does the topk-weighted
        reduction.  Only experts in [start_expert_id, start_expert_id +
        expert_size) are computed; other experts contribute zeros.
        """
        hidden_states, gating_output, w1, w2, bias1, bias2, residual, input_smooth, \
        act_smooth, w1_scale, w2_scale, topk, renormalize, \
        gated, act_mode, start_expert_id, block_n, cncl_comm, w1_quant_flag, w2_quant_flag = args
        if w2.dim() == 4:
            # fold a 4-D w2 layout back to (experts, hidden, inner)
            w2 = w2.transpose(1, 2).reshape(-1, w2.size(1), w2.size(-1))
        expert_num = gating_output.size(-1)
        expert_size = w1.size(0) if w1_quant_flag is None else w1_scale.size(1)
        def router(hidden_states, router_logit):
            """Softmax+topk routing; returns per-expert token splits plus
            combine indices, and (when smoothing) quantized inputs/scales."""
            router_logit = torch.softmax(router_logit.view(-1, router_logit.size(-1)), dim=1)
            topk_logit, expert_id = torch.topk(router_logit, k=topk, dim=1)
            if renormalize:
                topk_logit = topk_logit / topk_logit.sum(-1).unsqueeze(1)
            sorted_expert_id, indices = expert_id.int().flatten().sort()
            nk = indices.size(0)
            token_cout = torch.bincount(sorted_expert_id, minlength=expert_num).cpu()
            expand_idx = indices.int() // topk
            # combine_idx maps expert-sorted rows back to (token, k) order
            combine_idx = torch.zeros((nk,), dtype=torch.int, device="mlu")
            combine_idx.scatter_(0, indices, torch.arange(nk, dtype=torch.int, device="mlu"))
            input_expand = hidden_states[expand_idx]
            input_list = input_expand.split(tuple(token_cout))
            input_scale_list = []
            input_split_result = []
            if input_smooth is not None:
                # do smooth quant on input (only for experts owned by this rank)
                idx = 0
                for i in range(expert_num):
                    if i >= start_expert_id and i < start_expert_id + expert_size:
                        if (input_list[i].size(0) > 0):
                            temp = input_list[i] * input_smooth[idx]
                            input_split_result.append(temp)
                        idx += 1
                    else:
                        input_split_result.append(input_list[i])
                quant_input, input_scale = QuantByRow(torch.cat(input_split_result, dim=0), 8)
                input_list = quant_input.split(token_cout.tolist())
                input_scale_list = input_scale.split(token_cout.tolist())
            return input_list, input_scale_list, topk_logit.flatten().view(nk , 1), combine_idx
        dtype = hidden_states.dtype
        # derive inner/hidden sizes from weights or, when quantized, from scales
        if w1_quant_flag is None:
            inner_size = w1.size(1) // 2 if gated else w1.size(1)
        else:
            inner_size = w1_scale.size(2) // 2 if gated else w1_scale.size(2)
        hidden_size = w2.size(1) if w2_quant_flag is None else w2_scale.size(2)
        input = hidden_states.view(-1, hidden_states.size(-1))
        gating_output = gating_output.view(-1, gating_output.size(-1))
        input_list, input_scale_list, reduce_weight, combine_idx = router(input, gating_output)
        output_list = []
        idx = 0
        need_quant = len(input_scale_list) != 0
        if need_quant and w1_scale.dim() == 3:
            w1_scale = w1_scale.transpose(0, 1).contiguous()
            w2_scale = w2_scale.transpose(0, 1).contiguous()
        if w1_quant_flag is not None:
            # cumulative byte offsets of each expert's packed w4/w8 weight slab
            w1_quant_group = w1_scale.size(1)
            quant_wise = hidden_size // w1_quant_group
            w1_quant_flag = torch.tensor(w1_quant_flag).view(-1, w1_quant_group)
            w1_offset_cu = torch.cumsum(w1_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * inner_size*(1+gated)
            w1_offset_cu = torch.nn.functional.pad(w1_offset_cu, (1,0), "constant", 0)
        if w2_quant_flag is not None:
            w2_quant_group = w2_scale.size(1)
            quant_wise = inner_size // w2_quant_group
            w2_quant_flag = torch.tensor(w2_quant_flag).view(-1, w2_quant_group)
            w2_offset_cu = torch.cumsum(w2_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * hidden_size
            w2_offset_cu = torch.nn.functional.pad(w2_offset_cu, (1,0), "constant", 0)
        for i in range(expert_num): # start_expert_id : start_expert_id + expert_size
            if i >= start_expert_id and i < start_expert_id + expert_size:
                if (input_list[i].size(0) > 0):
                    if need_quant:
                        if w1_quant_flag is None:
                            gemm1_out = smooth_quant_matmul(input_list[i], input_scale_list[i],
                                                            w1[idx], w1_scale[idx], dtype)
                        else:
                            gemm1_out = smooth_quant_matmul_w4w8_mixed(input_list[i], input_scale_list[i],
                                                            w1[w1_offset_cu[idx]:w1_offset_cu[idx+1]], w1_scale[idx], dtype,
                                                            quant_flag = w1_quant_flag[idx])
                    else:
                        gemm1_out = torch.matmul(input_list[i], w1[idx].permute(1, 0))
                    # gated: first half -> activation, second half -> gate
                    act_in = gemm1_out[:, :inner_size].float()
                    gate = gemm1_out[:, inner_size:]
                    act = act_mode_dict[act_mode]
                    gemm1_out = act(act_in).to(dtype=dtype) if gated == False else \
                                act(act_in).to(dtype=dtype) * gate
                    if need_quant:
                        quant_gemm1_out, gemm1_out_scale = QuantByRow(gemm1_out * act_smooth[idx], 8)
                        if w2_quant_flag is None:
                            gemm2_out = smooth_quant_matmul(quant_gemm1_out, gemm1_out_scale, w2[idx], w2_scale[idx], dtype)
                        else:
                            gemm2_out = smooth_quant_matmul_w4w8_mixed(quant_gemm1_out, gemm1_out_scale,
                                                            w2[w2_offset_cu[idx]:w2_offset_cu[idx+1]], w2_scale[idx], dtype,
                                                            quant_flag = w2_quant_flag[idx])
                    else:
                        gemm2_out = torch.matmul(gemm1_out, w2[idx].permute(1, 0))
                    output_list.append(gemm2_out)
                idx += 1
            else:
                # expert not owned by this rank: contributes zeros
                output_list.append(torch.zeros_like(input_list[i]))
        output = torch.cat(output_list, dim=0)[combine_idx].float() * reduce_weight
        output = output.reshape(-1, topk, hidden_size).sum(dim=1).to(dtype=dtype)
        if residual is not None:
            output = output + residual.view(input.shape)
        return output.view(hidden_states.shape)
    def test_fused_moe(self):
        """Unquantized fused-MoE over several expert-partition windows.

        Checks the op against the reference for (N, T, C) and flattened
        (N*T, C) inputs, and against the common_utils `fused_moe` wrapper.
        """
        print("test_fused_moe")
        batch, seq, hidden_size, inner_size = 3, 5, 512, 768
        dtype_list = [torch.half]
        if mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for dtype in dtype_list:
            expert_num, topk, gated, renormalize, act_mode, data_type = 8, 2, True, True, 'silu', dtype
            hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
            router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
            residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
            weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type)
            bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type)
            weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type)
            bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type)
            # each pair (start, size) selects a window of experts to own
            start_expert_id_list = [0, 1, 3, 4, 5, 6]
            expert_size_list = [8, 4, 3, 2, 3, 2]
            for i in range(len(start_expert_id_list)):
                start_expert_id = start_expert_id_list[i]
                expert_size = expert_size_list[i]
                print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
                      start_expert_id: {start_expert_id}, expert_size: {expert_size}, topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
                torch_out = self.op_impl_base(hidden_states,
                                              router_logit,
                                              weight1[start_expert_id:start_expert_id+expert_size],
                                              weight2[start_expert_id:start_expert_id+expert_size],
                                              bias1,
                                              bias2,
                                              residual,
                                              None,
                                              None,
                                              None,
                                              None,
                                              topk,
                                              renormalize,
                                              gated,
                                              act_mode,
                                              start_expert_id,
                                              0, 0, None, None)
                # (N, T, C)
                tmo_out_1 = ops.fused_moe(hidden_states,
                                          router_logit,
                                          weight1[start_expert_id:start_expert_id+expert_size],
                                          weight2[start_expert_id:start_expert_id+expert_size],
                                          bias1,
                                          bias2,
                                          residual,
                                          None,
                                          None,
                                          None,
                                          None,
                                          topk,
                                          renormalize,
                                          gated,
                                          act_mode,
                                          start_expert_id)
                # (N*T, C)
                tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size),
                                          router_logit.view(-1, expert_num),
                                          weight1[start_expert_id:start_expert_id+expert_size],
                                          weight2[start_expert_id:start_expert_id+expert_size],
                                          bias1,
                                          bias2,
                                          residual.view(-1, hidden_size) if residual is not None else None,
                                          None,
                                          None,
                                          None,
                                          None,
                                          topk,
                                          renormalize,
                                          gated,
                                          act_mode,
                                          start_expert_id).reshape(batch, seq, hidden_size)
                tmo_out_3 = fused_moe(hidden_states,
                                      router_logit,
                                      weight1[start_expert_id:start_expert_id+expert_size],
                                      weight2[start_expert_id:start_expert_id+expert_size],
                                      bias1,
                                      bias2,
                                      residual,
                                      None,
                                      None,
                                      None,
                                      None,
                                      topk,
                                      renormalize,
                                      gated,
                                      act_mode,
                                      start_expert_id)
                self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
                self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
                self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_3.cpu().float(), 0.006, use_MSE=True)
def test_pertoken_quant_fused_moe_tp(self):
print("test_pertoken_quant_fused_moe_tp")
batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, False, True, 'gelu', torch.bfloat16
if not mlu.is_bf16_supported():
data_type = torch.float16
print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
self.__run_sq_case(batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, act_mode)
    def test_moe_tp2_mixed_ep4_no_quant(self):
        """Emulate TP=2 x EP=4 sharding without quantization.

        Weights are split along the inner dimension per TP rank and along the
        expert dimension per EP rank; partial outputs are accumulated and must
        equal the per-shard reference sum.
        """
        print("test_moe_tp2_mixed_ep4_no_quant")
        tp_num = 2
        ep_num = 4
        expert_num = 32
        expert_size = expert_num // ep_num
        batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
        topk, gated, renormalize, act_mode, data_type = 5, False, True, 'gelu', torch.bfloat16
        if not mlu.is_bf16_supported():
            data_type = torch.float16
        hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
        residual = None
        weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type)
        bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type)
        weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type)
        bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type)
        # introduce explicit TP axes: w1 splits rows, w2 splits columns
        w1 = weight1.reshape(ep_num * expert_size, tp_num, (inner_size * (1 + gated)) // tp_num, hidden_size)
        w2 = weight2.reshape(ep_num * expert_size, hidden_size, tp_num, inner_size // tp_num)
        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
              topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
        tmo_out1 = torch.zeros_like(hidden_states)
        tmo_out2 = torch.zeros_like(hidden_states)
        torch_out = torch.zeros_like(hidden_states)
        for tp_idx in range(tp_num):
            w1_curr_tp = w1[:, tp_idx, ...]
            w2_curr_tp = w2[:, :, tp_idx, :]
            for ep_idx in range(ep_num):
                start_expert_id = ep_idx * expert_size
                w1_curr_tp_and_ep = w1_curr_tp.reshape((ep_num, expert_size)+w1_curr_tp.shape[1:])[ep_idx].contiguous()
                w2_curr_tp_and_ep = w2_curr_tp.reshape((ep_num, expert_size)+w2_curr_tp.shape[1:])[ep_idx].contiguous()
                # accumulate each shard's partial output
                tmo_out1 += ops.fused_moe(hidden_states,
                                          router_logit,
                                          w1_curr_tp_and_ep,
                                          w2_curr_tp_and_ep,
                                          bias1,
                                          bias2,
                                          residual,
                                          None,
                                          None,
                                          None,
                                          None,
                                          topk,
                                          renormalize,
                                          gated,
                                          act_mode,
                                          start_expert_id)
                tmo_out2 += fused_moe(hidden_states,
                                      router_logit,
                                      w1_curr_tp_and_ep,
                                      w2_curr_tp_and_ep,
                                      bias1,
                                      bias2,
                                      residual,
                                      None,
                                      None,
                                      None,
                                      None,
                                      topk,
                                      renormalize,
                                      gated,
                                      act_mode,
                                      start_expert_id)
                torch_out += self.op_impl_base(hidden_states,
                                               router_logit,
                                               w1_curr_tp_and_ep,
                                               w2_curr_tp_and_ep,
                                               bias1,
                                               bias2,
                                               residual,
                                               None,
                                               None,
                                               None,
                                               None,
                                               topk,
                                               renormalize,
                                               gated,
                                               act_mode,
                                               start_expert_id,
                                               0, 0, None, None)
        tmo_out2 = tmo_out2.reshape(batch, seq, hidden_size)
        self.assertTensorsEqual(tmo_out1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
        self.assertTensorsEqual(tmo_out1.cpu().float(), tmo_out2.cpu().float(), 0.006, use_MSE=True)
    def test_moe_tp2_mixed_ep4_quant(self):
        """Emulate TP=2 x EP=4 sharding with smooth-quant int8 weights.

        Weights are smoothed, row-quantized, then split per-TP/per-EP; the
        accumulated shard outputs are compared against the reference sum.
        """
        print("test_moe_tp2_mixed_ep4_quant")
        tp_num = 2
        ep_num = 4
        expert_num = 32
        expert_size = expert_num // ep_num
        batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
        topk, gated, renormalize, act_mode, data_type = 5, False, True, 'gelu', torch.bfloat16
        if not mlu.is_bf16_supported():
            data_type = torch.float16
        scale_s = 0.1 # avoid the occurrence of inf
        eps = 0.01 # Avoid the occurrence of nan
        hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s
        router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
        residual = None
        weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s
        bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type)
        weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s
        bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type)
        input_smooth = torch.randn(expert_num, hidden_size, device='mlu', dtype=torch.float32).abs() + eps
        act_smooth = torch.randn(expert_num, (1+gated)*inner_size, device='mlu', dtype=torch.float32).abs() + eps
        weight1_shape, weight2_shape = weight1.shape, weight2.shape
        # fold the smoothing factors into the weights before quantization
        weight1 = weight1 / input_smooth.unsqueeze(1)
        weight2 = weight2 / act_smooth.unsqueeze(1)
        quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1_shape[-1]), 8)
        quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2_shape[-1]), 8)
        quant_w1, quant_w2 = quant_w1.view(weight1.shape), quant_w2.view(weight2.shape)
        w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
        torch_out = torch.zeros_like(hidden_states)
        # NOTE(review): this outer loop repeats the whole check tp_num times and
        # re-derives the same reshaped views each pass — presumably intentional
        # (runs the identical split twice); confirm against the test's intent.
        for _ in range(tp_num):
            w1_scale_tp = w1_scale.reshape(expert_num, tp_num, (1+gated)*inner_size//tp_num)
            act_smooth_tp = act_smooth.reshape(expert_num, tp_num, (1+gated)*inner_size//tp_num)
            w1 = quant_w1.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num, hidden_size)
            w2 = quant_w2.reshape(ep_num, expert_size, hidden_size, tp_num, inner_size*(1+gated)//tp_num)
            input_smooth = input_smooth.reshape(ep_num, expert_size, hidden_size)
            act_smooth = act_smooth.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num)
            w1_scale = w1_scale.reshape(ep_num, expert_size, tp_num, inner_size*(1+gated)//tp_num)
            w2_scale = w2_scale.reshape(ep_num, expert_size, hidden_size)
            print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
                  topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
            tmo_out = torch.zeros_like(hidden_states)
            tmo_out1 = torch.zeros_like(hidden_states)
            torch_out = torch.zeros_like(hidden_states)
            for tp_idx in range(tp_num):
                w1_curr_tp = w1[:, :, tp_idx, ...]
                w2_curr_tp = w2[:, :, :, tp_idx, :]
                act_smooth_tp = act_smooth[:, :, tp_idx, :]
                w1_scale_tp = w1_scale[:, :, tp_idx, :]
                for ep_idx in range(ep_num):
                    start_expert_id = ep_idx * expert_size
                    w1_curr_tp_and_ep = w1_curr_tp[ep_idx].contiguous()
                    w2_curr_tp_and_ep = w2_curr_tp[ep_idx].contiguous()
                    input_smooth_curr_ep = input_smooth[ep_idx].contiguous()
                    act_smooth_curr_ep = act_smooth_tp[ep_idx].contiguous()
                    w1_scale_curr_ep = w1_scale_tp[ep_idx].contiguous()
                    w2_scale_curr_ep = w2_scale[ep_idx].contiguous()
                    tmo_out += ops.fused_moe(hidden_states, # [batch, seq, hidden_size]
                                             router_logit, # [batch, seq, expert_num]
                                             w1_curr_tp_and_ep, # [expert_size, inner_size*(1+gated)//tp_num, hidden_size]
                                             w2_curr_tp_and_ep, # [expert_size, hidden_size, inner_size*(1+gated)//tp_num]
                                             bias1,
                                             bias2,
                                             residual,
                                             input_smooth_curr_ep, # [expert_size, hidden_size]
                                             act_smooth_curr_ep, # [expert_size, inner_size*(1+gated)//tp_num]
                                             w1_scale_curr_ep, # [expert_size, inner_size*(1+gated)//tp_num]
                                             w2_scale_curr_ep, # [expert_size, hidden_size]
                                             topk,
                                             renormalize,
                                             gated,
                                             act_mode,
                                             start_expert_id)
                    tmo_out1 += fused_moe(hidden_states,
                                          router_logit,
                                          w1_curr_tp_and_ep,
                                          w2_curr_tp_and_ep,
                                          bias1,
                                          bias2,
                                          residual,
                                          input_smooth_curr_ep,
                                          act_smooth_curr_ep,
                                          w1_scale_curr_ep,
                                          w2_scale_curr_ep,
                                          topk,
                                          renormalize,
                                          gated,
                                          act_mode,
                                          start_expert_id)
                    torch_out += self.op_impl_base(hidden_states,
                                                   router_logit,
                                                   w1_curr_tp_and_ep,
                                                   w2_curr_tp_and_ep,
                                                   bias1,
                                                   bias2,
                                                   residual,
                                                   input_smooth_curr_ep,
                                                   act_smooth_curr_ep,
                                                   w1_scale_curr_ep,
                                                   w2_scale_curr_ep,
                                                   topk,
                                                   renormalize,
                                                   gated,
                                                   act_mode,
                                                   start_expert_id,
                                                   0, 0, None, None)
            self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.02, use_MSE=True)
            self.assertTensorsEqual(tmo_out1.cpu().float(), torch_out.cpu().float(), 0.02, use_MSE=True)
    def test_smq_fused_moe_random_tp(self):
        """Fuzz the smooth-quant path with 10 random (deduplicated) configs.

        The RNG is seeded so the case list is reproducible across runs.
        """
        print("test_smq_fused_moe_random_tp")
        import random
        random.seed(0)
        act_mode = 'gelu'
        case_list = set()
        for i in range(10):
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            hidden_size = random.randrange(512, 2048, 2)
            inner_size = random.randrange(512, 2048, 2)
            expert_num = random.randint(1, 40)
            topk = random.randint(1,expert_num)
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            data_type = random.choice([torch.bfloat16, torch.float16])
            if not mlu.is_bf16_supported():
                data_type = torch.float16
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, act_mode)
            if case in case_list:
                # skip duplicate random draws
                continue
            case_list.add(case)
            print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
                  topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
            self.__run_sq_case(*case)
    def __run_sq_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, act_mode='gelu', start_expert_id=0, expert_size=-1):
        """Build one smooth-quant MoE case and check fused vs reference output.

        Weights are smoothed (divided by input/act smoothing factors), then
        row-quantized to int8; expert_size == -1 means "all experts". The op
        is exercised with both (N, T, C) and flattened (N*T, C) inputs, and
        via the common_utils `fused_moe` wrapper.
        """
        if expert_size == -1:
            expert_size = expert_num
        scale_s = 0.01 # avoid the occurrence of inf
        eps = 0.1 # Avoid the occurrence of nan
        hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s
        residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s
        weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s
        router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
        input_smooth = torch.randn(expert_num, hidden_size, device="mlu", dtype=torch.float32).abs() + eps
        act_smooth = torch.randn(expert_num, inner_size, device="mlu", dtype=torch.float32).abs() + eps
        bias1, bias2 = None, None
        weight1_shape, weight2_shape = weight1.shape, weight2.shape
        # fold smoothing into weights before per-row int8 quantization
        weight1 = weight1 / input_smooth.unsqueeze(1)
        weight2 = weight2 / act_smooth.unsqueeze(1)
        quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
        quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
        quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
        w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
        torch_out = self.op_impl_base(hidden_states,
                                      router_logit,
                                      quant_w1[start_expert_id:start_expert_id+expert_size],
                                      quant_w2[start_expert_id:start_expert_id+expert_size],
                                      bias1,
                                      bias2,
                                      residual,
                                      input_smooth[start_expert_id:start_expert_id+expert_size],
                                      act_smooth[start_expert_id:start_expert_id+expert_size],
                                      w1_scale[start_expert_id:start_expert_id+expert_size],
                                      w2_scale[start_expert_id:start_expert_id+expert_size],
                                      topk,
                                      renormalize,
                                      gated,
                                      act_mode,
                                      start_expert_id,
                                      0, 0, None, None)
        # (N, T, C)
        tmo_out_1 = ops.fused_moe(hidden_states,
                                  router_logit,
                                  quant_w1[start_expert_id:start_expert_id+expert_size],
                                  quant_w2[start_expert_id:start_expert_id+expert_size],
                                  bias1,
                                  bias2,
                                  residual,
                                  input_smooth[start_expert_id:start_expert_id+expert_size],
                                  act_smooth[start_expert_id:start_expert_id+expert_size],
                                  w1_scale[start_expert_id:start_expert_id+expert_size],
                                  w2_scale[start_expert_id:start_expert_id+expert_size],
                                  topk,
                                  renormalize,
                                  gated,
                                  act_mode,
                                  start_expert_id)
        # # (N*T, C)
        tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size),
                                  router_logit.view(-1, expert_num),
                                  quant_w1[start_expert_id:start_expert_id+expert_size],
                                  quant_w2[start_expert_id:start_expert_id+expert_size],
                                  bias1,
                                  bias2,
                                  residual.view(-1, hidden_size) if residual is not None else None,
                                  input_smooth[start_expert_id:start_expert_id+expert_size],
                                  act_smooth[start_expert_id:start_expert_id+expert_size],
                                  w1_scale[start_expert_id:start_expert_id+expert_size],
                                  w2_scale[start_expert_id:start_expert_id+expert_size],
                                  topk,
                                  renormalize,
                                  gated,
                                  act_mode,
                                  start_expert_id).view(batch, seq, hidden_size)
        tmo_out_3 = fused_moe(hidden_states.view(-1, hidden_size),
                              router_logit.view(-1, expert_num),
                              quant_w1[start_expert_id:start_expert_id+expert_size],
                              quant_w2[start_expert_id:start_expert_id+expert_size],
                              bias1,
                              bias2,
                              residual.view(-1, hidden_size) if residual is not None else None,
                              input_smooth[start_expert_id:start_expert_id+expert_size],
                              act_smooth[start_expert_id:start_expert_id+expert_size],
                              w1_scale[start_expert_id:start_expert_id+expert_size],
                              w2_scale[start_expert_id:start_expert_id+expert_size],
                              topk,
                              renormalize,
                              gated,
                              act_mode,
                              start_expert_id).view(batch, seq, hidden_size)
        self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
        # the two input layouts must produce bit-identical results (tolerance 0)
        self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_2.cpu().float(), 0, use_MSE=True)
        self.assertTensorsEqual(tmo_out_3.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
def test_fused_moe_with_4D_w2(self):
print("test_fused_moe_with_4D_w2")
batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
dtype_list = [torch.half]
if mlu.is_bf16_supported():
dtype_list.append(torch.bfloat16)
for dtype in dtype_list:
expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, True, True, 'silu', dtype
hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type)
router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32)
residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type)
bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type)
weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type)
bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type)
print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
torch_out = self.op_impl_base(hidden_states, router_logit,
weight1, weight2,
bias1, bias2, residual, None, None, None, None,
topk, renormalize, gated, act_mode, 0, 0, 0, None, None)
# (N, T, C)
tmo_out_1 = ops.fused_moe(hidden_states, router_logit,
weight1, weight2,
bias1, bias2, residual, None, None, None, None,
topk, renormalize, gated, act_mode, 0)
# (N*T, C)
tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size),
router_logit.view(-1, expert_num),
weight1, weight2.view(8, 4, hidden_size, inner_size).transpose(1,2).contiguous(),
bias1, bias2,
residual.view(-1, hidden_size) if residual is not None else None,
None, None, None, None,
topk, renormalize, gated, act_mode, 0).view(batch, seq, hidden_size)
self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
    def test_moe_tp2_mixed_ep4_with_4D_w2(self):
        """Tensor-parallel (tp=2) + expert-parallel (ep=4) slices of fused_moe
        with w2 in the 4D block layout; per-slice partial results are summed
        and compared against the summed reference results."""
        print("test_moe_tp2_mixed_ep4_with_4D_w2")
        tp_num, ep_num = 2, 4
        batch, seq, hidden_size, inner_size = 3, 5, 8192, 2048
        expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, False, True, 'gelu', torch.bfloat16
        assert inner_size % tp_num == 0
        assert expert_num % ep_num == 0
        expert_size = expert_num // ep_num
        # block_e experts are folded into one 4D block of w2.
        assert 4096 % inner_size == 0
        block_e = 4096 // inner_size
        assert expert_size % block_e == 0
        if not mlu.is_bf16_supported():
            data_type = torch.float16
        hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type)
        router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32)
        weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type)
        weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type)
        residual = None
        bias1 = None # torch.randn(expert_num, inner_size*(1+gated), device="mlu", dtype=data_type)
        bias2 = None # torch.randn(expert_num, hidden_size, device="mlu", dtype=data_type)
        # Pre-split weights along the tp dimension: w1 by output rows, w2 by
        # inner columns.
        w1 = weight1.reshape(expert_num, tp_num, -1, hidden_size)
        w2 = weight2.reshape(expert_num, hidden_size, tp_num, -1)
        print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
        topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, act_mode: {act_mode} testing...", flush=True)
        tmo_out = torch.zeros_like(hidden_states)
        torch_out = torch.zeros_like(hidden_states)
        for tp_idx in range(tp_num):
            new_inner_size = w2.shape[-1]
            w1_curr_tp = w1[:, tp_idx, ...]
            w2_curr_tp = w2[:, :, tp_idx, :]
            for ep_idx in range(ep_num):
                start_expert_id = ep_idx * expert_size
                # Extract this ep rank's contiguous expert slice.
                w1_curr_tp_and_ep = w1_curr_tp.reshape((ep_num, expert_size)+w1_curr_tp.shape[1:])[ep_idx].contiguous()
                w2_curr_tp_and_ep = w2_curr_tp.reshape((ep_num, expert_size)+w2_curr_tp.shape[1:])[ep_idx].contiguous()
                # Accumulate partial outputs; both paths see the 4D w2 layout.
                torch_out += self.op_impl_base(hidden_states, router_logit, w1_curr_tp_and_ep,
                                               w2_curr_tp_and_ep.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(),
                                               bias1, bias2, residual, None, None, None, None,
                                               topk, renormalize, gated, act_mode, start_expert_id, 0, 0, None, None)
                tmo_out += ops.fused_moe(hidden_states, router_logit, w1_curr_tp_and_ep,
                                         w2_curr_tp_and_ep.view(-1, block_e, hidden_size, new_inner_size).transpose(1,2).contiguous(),
                                         bias1, bias2, residual, None, None, None, None,
                                         topk, renormalize, gated, act_mode, start_expert_id)
        self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
def test_sq_fused_moe_with_4D_w2(self):
print("test_sq_fused_moe_with_4D_w2")
batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
expert_num, topk, gated, renormalize, act_mode, data_type = 32, 5, False, True, 'gelu', torch.bfloat16
if not mlu.is_bf16_supported():
data_type = torch.float16
assert 4096 % inner_size == 0
block_e = 4096 // inner_size
assert expert_num % block_e == 0
scale_s = 0.01 # avoid the occurrence of inf
eps = 0.1 # Avoid the occurrence of nan
hidden_states = torch.normal(0, 0.1, size=(batch, seq, hidden_size), device="mlu", dtype=data_type)
router_logit = torch.normal(0, 0.1, size=(batch, seq, expert_num), device="mlu", dtype=torch.float32)
weight1 = torch.normal(0, 0.1, size=(expert_num, inner_size*(1+gated), hidden_size), device="mlu", dtype=data_type)
weight2 = torch.normal(0, 0.1, size=(expert_num, hidden_size, inner_size), device="mlu", dtype=data_type)
residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
bias1, bias2 = None, None
input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps
act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps
weight1_shape, weight2_shape = weight1.shape, weight2.shape
weight1 = weight1 / input_smooth.unsqueeze(1)
weight2 = weight2 / act_smooth.unsqueeze(1)
quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
torch_out = self.op_impl_base(hidden_states, router_logit, quant_w1, quant_w2,
bias1, bias2, residual, input_smooth, act_smooth,
w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0, 0, 0, None, None)
# (N, T, C)
tmo_out_1 = ops.fused_moe(hidden_states, router_logit, quant_w1, quant_w2,
bias1, bias2, residual, input_smooth, act_smooth,
w1_scale, w2_scale, topk, renormalize, gated, act_mode, 0)
# # (N*T, C)
tmo_out_2 = ops.fused_moe(hidden_states.view(-1, hidden_size),
router_logit.view(-1, expert_num),
quant_w1,
quant_w2.view(-1,block_e,hidden_size, inner_size).transpose(1,2).contiguous(),
bias1, bias2,
residual.view(-1, hidden_size) if residual is not None else None,
input_smooth, act_smooth,
w1_scale, w2_scale, topk, renormalize,
gated,
act_mode,
0).view(batch, seq, hidden_size)
self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
self.assertTensorsEqual(tmo_out_1.cpu().float(), tmo_out_2.cpu().float(), 0, use_MSE=True)
    def test_sq_fused_moe_random_tp_quant_grouped(self):
        """Randomized sweep over 100 distinct group-quantized configurations.

        NOTE(review): the order of the random.* draws below is significant —
        with seed 0 it determines exactly which 100 cases are exercised, so
        do not reorder the draws.
        """
        print("test_sq_fused_moe_random_tp_quant_grouped")
        import random
        random.seed(0)
        act_mode = 'gelu'
        case_list = set()
        # Keep drawing until 100 unique parameter tuples have been tested.
        while (len(case_list) < 100):
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            hidden_size = random.randrange(512, 2048, 128)
            inner_size = random.randrange(512, 2048, 512)
            expert_num = random.randint(1, 40)
            topk = random.randint(1, expert_num)
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            quant_bit = random.choice([4, 8])
            data_type = random.choice([torch.bfloat16, torch.float16])
            if not mlu.is_bf16_supported():
                data_type = torch.float16
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_bit, act_mode)
            if case in case_list:
                continue
            case_list.add(case)
            print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
            topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_bit: {quant_bit}, act_mode: {act_mode} testing...", flush=True)
            self.__run_quant_grouped_case(*case)
    def __run_quant_grouped_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_bit, act_mode='gelu', start_expert_id=0, expert_size=-1, quant_group_size=128):
        """Run one group-quantized (int4/int8) fused_moe case.

        Builds smooth-quant weights with per-group scales, then compares
        ops.fused_moe against the reference op_impl_base and the python
        fused_moe on both (N, T, C) and flattened inputs.
        expert_size == -1 means "use all experts".
        """
        # At least one quant group per row, even when n < quant_group_size.
        def get_quant_group(n, quant_group_size):
            quant_group = n // quant_group_size
            return quant_group if quant_group >= 1 else 1
        if expert_size == -1:
            expert_size = expert_num
        scale_s = 0.01 # avoid the occurrence of inf
        eps = 0.1 # Avoid the occurrence of nan
        hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type) * scale_s
        residual = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type) * scale_s
        weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type) * scale_s
        router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
        # Strictly positive smoothing factors so the divisions below are safe.
        input_smooth = torch.randn(expert_num, hidden_size, device="mlu", dtype=torch.float32).abs() + eps
        act_smooth = torch.randn(expert_num, inner_size, device="mlu", dtype=torch.float32).abs() + eps
        bias1, bias2 = None, None
        weight1_shape, weight2_shape = weight1.shape, weight2.shape
        # Fold smoothing into the weights before quantization.
        weight1 = weight1 / input_smooth.unsqueeze(1)
        weight2 = weight2 / act_smooth.unsqueeze(1)
        w1_quant_group = get_quant_group(hidden_size, quant_group_size)
        w2_quant_group = get_quant_group(inner_size, quant_group_size)
        quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), quant_bit, w1_quant_group)
        quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), quant_bit, w2_quant_group)
        quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
        if quant_bit == 4:
            # Two int4 values per int8 byte.
            quant_w1, quant_w2 = PairlyPackInt8(quant_w1), PairlyPackInt8(quant_w2)
        # Scales laid out as (quant_group, expert, row).
        w1_scale = w1_scale.view(expert_num, -1, w1_quant_group).permute(2, 0, 1).contiguous()
        w2_scale = w2_scale.view(expert_num, -1, w2_quant_group).permute(2, 0, 1).contiguous()
        # split scale and transpose
        def extract_scale(scale, start_expert, expert_size):
            return scale[:, start_expert:start_expert+expert_size, :].contiguous()
        torch_out = self.op_impl_base(hidden_states,
                                      router_logit,
                                      quant_w1[start_expert_id:start_expert_id+expert_size],
                                      quant_w2[start_expert_id:start_expert_id+expert_size],
                                      bias1,
                                      bias2,
                                      residual,
                                      input_smooth[start_expert_id:start_expert_id+expert_size],
                                      act_smooth[start_expert_id:start_expert_id+expert_size],
                                      extract_scale(w1_scale, start_expert_id, expert_size),
                                      extract_scale(w2_scale, start_expert_id, expert_size),
                                      topk,
                                      renormalize,
                                      gated,
                                      act_mode,
                                      start_expert_id,
                                      0, 0, None, None)
        # (N, T, C)
        tmo_out_1 = ops.fused_moe(hidden_states,
                                  router_logit,
                                  quant_w1[start_expert_id:start_expert_id+expert_size],
                                  quant_w2[start_expert_id:start_expert_id+expert_size],
                                  bias1,
                                  bias2,
                                  residual,
                                  input_smooth[start_expert_id:start_expert_id+expert_size],
                                  act_smooth[start_expert_id:start_expert_id+expert_size],
                                  extract_scale(w1_scale, start_expert_id, expert_size),
                                  extract_scale(w2_scale, start_expert_id, expert_size),
                                  topk,
                                  renormalize,
                                  gated,
                                  act_mode,
                                  start_expert_id)
        # Python-level fused_moe (presumably from common_utils) on the
        # flattened (N*T, C) layout — TODO confirm origin of `fused_moe`.
        tmo_out_3 = fused_moe(hidden_states.view(-1, hidden_size),
                              router_logit.view(-1, expert_num),
                              quant_w1[start_expert_id:start_expert_id+expert_size],
                              quant_w2[start_expert_id:start_expert_id+expert_size],
                              bias1,
                              bias2,
                              residual.view(-1, hidden_size) if residual is not None else None,
                              input_smooth[start_expert_id:start_expert_id+expert_size],
                              act_smooth[start_expert_id:start_expert_id+expert_size],
                              extract_scale(w1_scale, start_expert_id, expert_size),
                              extract_scale(w2_scale, start_expert_id, expert_size),
                              topk,
                              renormalize,
                              gated,
                              act_mode,
                              start_expert_id).view(batch, seq, hidden_size)
        self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
        self.assertTensorsEqual(tmo_out_3.cpu().float(), torch_out.cpu().float(), 0.01, use_MSE=True)
    def __run_w4w8_mixed_case(self, batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_wise, act_mode='gelu', start_expert_id=0, expert_size=-1):
        """Run one mixed int4/int8 group-quantized fused_moe case.

        Each quant group carries a flag (4 or 8 bits); packed weight buffers
        are flat int8 with per-expert byte offsets (w*_offset_cu).
        expert_size == -1 means "use all experts".
        """
        if expert_size == -1:
            expert_size = expert_num
        w1_quant_group = hidden_size // quant_wise
        w2_quant_group = inner_size // quant_wise
        # Per-group bit width: randint(1, 3) * 4 yields 4 or 8.
        w1_quant_flag = torch.randint(1, 3, (expert_num, w1_quant_group), dtype=torch.int32) * 4
        w2_quant_flag = torch.randint(1, 3, (expert_num, w2_quant_group), dtype=torch.int32) * 4
        # Total packed byte counts: sum(bits)/4 half-group units, each holding
        # quant_wise/2 bytes per output row.
        w1_count = (w1_quant_flag.sum().item() // 4) * (quant_wise // 2) * inner_size*(1+gated)
        w2_count = (w2_quant_flag.sum().item() // 4) * (quant_wise // 2) * hidden_size
        w1 = torch.randint(-128, 127, (w1_count,), device="mlu", dtype=torch.int32).to(torch.int8)
        w2 = torch.randint(-128, 127, (w2_count,), device="mlu", dtype=torch.int32).to(torch.int8)
        hidden_states = torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        residual = None # torch.randn(batch, seq, hidden_size, device="mlu", dtype=data_type)
        router_logit = torch.randn(batch, seq, expert_num, device="mlu", dtype=torch.float32)
        input_smooth = torch.empty(expert_num, hidden_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05)
        act_smooth = torch.empty(expert_num, inner_size, device="mlu", dtype=torch.float32).uniform_(0.01, 0.05)
        bias1, bias2 = None, None
        w1_scale = torch.empty((w1_quant_group, expert_num, inner_size*(1+gated)), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
        w2_scale = torch.empty((w2_quant_group, expert_num, hidden_size), device="mlu", dtype=torch.float32).uniform_(-0.05, 0.05)
        # Cumulative per-expert byte offsets into the flat weight buffers,
        # padded with a leading 0 so offsets[e]..offsets[e+1] brackets expert e.
        w1_offset_cu = torch.cumsum(w1_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * inner_size*(1+gated)
        w1_offset_cu = torch.nn.functional.pad(w1_offset_cu, (1,0), "constant", 0)
        w2_offset_cu = torch.cumsum(w2_quant_flag.sum(dim=1), dim=0) // 4 * (quant_wise // 2) * hidden_size
        w2_offset_cu = torch.nn.functional.pad(w2_offset_cu, (1,0), "constant", 0)
        # split scale and transpose
        def extract_scale(scale, start_expert, expert_size):
            return scale[:, start_expert:start_expert+expert_size, :].contiguous()
        # Identical positional argument list shared by all three call paths.
        params = [hidden_states, router_logit,
                  w1[w1_offset_cu[start_expert_id]:w1_offset_cu[start_expert_id+expert_size]],
                  w2[w2_offset_cu[start_expert_id]:w2_offset_cu[start_expert_id+expert_size]],
                  bias1, bias2, residual,
                  input_smooth[start_expert_id:start_expert_id+expert_size],
                  act_smooth[start_expert_id:start_expert_id+expert_size],
                  extract_scale(w1_scale, start_expert_id, expert_size),
                  extract_scale(w2_scale, start_expert_id, expert_size),
                  topk, renormalize, gated, act_mode, start_expert_id, 0, 0,
                  w1_quant_flag[start_expert_id:start_expert_id+expert_size].flatten().tolist(),
                  w2_quant_flag[start_expert_id:start_expert_id+expert_size].flatten().tolist()]
        torch_out = self.op_impl_base(*params)
        # (N, T, C)
        tmo_out_1 = ops.fused_moe(*params)
        tmo_out_2 = fused_moe(*params)
        self.assertTensorsEqual(tmo_out_1.cpu().float(), torch_out.cpu().float(), 0.03, use_MSE=True)
        self.assertTensorsEqual(tmo_out_2.cpu().float(), torch_out.cpu().float(), 0.03, use_MSE=True)
    def test_sq_fused_moe_random_tp_quant_grouped_w4w8_mixed(self):
        """Randomized sweep over 100 distinct mixed int4/int8 configurations.

        NOTE(review): the order of the random.* draws is significant — with
        seed 0 it fixes exactly which cases run; do not reorder the draws.
        """
        print("test_sq_fused_moe_random_tp_quant_grouped_w4w8_mixed")
        import random
        random.seed(0)
        act_mode = 'gelu'
        case_list = set()
        # Keep drawing until 100 unique parameter tuples have been tested.
        while (len(case_list) < 100):
            batch = random.randint(1, 10)
            seq = random.randint(1, 10)
            hidden_size = random.randrange(1024, 3072, 512)
            inner_size = random.randrange(1024, 3072, 512)
            expert_num = random.randint(1, 40)
            topk = random.randint(1, expert_num)
            gated = random.choice([True, False])
            renormalize = random.choice([True, False])
            quant_wise = random.choice([128, 256, 512])
            data_type = random.choice([torch.bfloat16, torch.float16])
            if not mlu.is_bf16_supported():
                data_type = torch.float16
            case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_wise, act_mode)
            if case in case_list:
                continue
            case_list.add(case)
            print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
            topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_wise: {quant_wise}, act_mode: {act_mode} testing...", flush=True)
            # Trailing 0, -1 are start_expert_id / expert_size (all experts).
            self.__run_w4w8_mixed_case(*case, 0, -1)
def test_sq_fused_moe_single_quant_group(self):
print("test_sq_fused_moe_single_quant_group")
import random
random.seed(0)
act_mode = 'gelu'
batch = 9
seq = 10
hidden_size = 1664
inner_size = 512
expert_num = 15
topk = 2
gated = False
renormalize = False
quant_bit = 4
data_type = torch.float16
case = (batch, seq, hidden_size, inner_size, expert_num, topk, gated, renormalize, data_type, quant_bit, act_mode)
print(f"bs: {batch}, seq_len: {seq}, hidden_size: {hidden_size}, inner_size: {inner_size}, experts_num: {expert_num}, \
topk: {topk}, gated: {gated}, renormalize: {renormalize}, data_type: {data_type}, quant_bit: {quant_bit}, act_mode: {act_mode} testing...", flush=True)
self.__run_quant_grouped_case(*case)
    def test_inductor(self):
        """Opcheck (torch.compile/inductor compatibility) for the raw
        torch.ops.torch_mlu_ops.fused_moe schema, in both the plain and the
        smooth-quant configurations."""
        batch, seq, hidden_size, inner_size = 3, 5, 8192, 1024
        expert_num, topk, gated, renormalize, act_mode, data_type = 8, 2, True, True, 'silu', torch.float16
        start_expert_id, expert_size = 0, 8
        hidden_states = torch.randn(batch * seq, hidden_size, device="mlu", dtype=data_type)
        router_logit = torch.randn(batch * seq, expert_num, device="mlu", dtype=torch.float32)
        residual = torch.randn(batch * seq, hidden_size, device="mlu", dtype=data_type)
        weight1 = torch.randn(expert_num, inner_size*(1+gated), hidden_size, device="mlu", dtype=data_type)
        weight2 = torch.randn(expert_num, hidden_size, inner_size, device="mlu", dtype=data_type)
        # Plain (non-quantized) argument tuple; the Nones fill the optional
        # smooth/scale/flag slots of the raw op schema.
        args = (hidden_states, router_logit,
                weight1[start_expert_id:start_expert_id+expert_size],
                weight2[start_expert_id:start_expert_id+expert_size],
                None, None, residual, None, None, None, None, None, None,
                topk, renormalize, gated, act_mode, 0, 0, 0)
        self.base_opcheck(torch.ops.torch_mlu_ops.fused_moe, args)
        eps = 1e-5
        # Strictly positive smoothing factors so the divisions below are safe.
        input_smooth = torch.normal(0, 0.1, size=(expert_num, hidden_size), device="mlu", dtype=torch.float32).abs() + eps
        act_smooth = torch.normal(0, 0.1, size=(expert_num, inner_size), device="mlu", dtype=torch.float32).abs() + eps
        weight1_shape, weight2_shape = weight1.shape, weight2.shape
        weight1 = weight1 / input_smooth.unsqueeze(1)
        weight2 = weight2 / act_smooth.unsqueeze(1)
        quant_w1, w1_scale = QuantByRow(weight1.view(-1, weight1.shape[-1]), 8)
        quant_w2, w2_scale = QuantByRow(weight2.view(-1, weight2.shape[-1]), 8)
        quant_w1, quant_w2 = quant_w1.view(weight1_shape), quant_w2.view(weight2_shape)
        w1_scale, w2_scale = w1_scale.view(expert_num, -1), w2_scale.view(expert_num, -1)
        # Smooth-quant argument tuple for the same schema.
        args = (hidden_states, router_logit,
                quant_w1, quant_w2, None, None, residual,
                input_smooth, act_smooth, w1_scale, w2_scale,
                None, None, topk, renormalize, gated, act_mode, 0, 0, 0)
        self.base_opcheck(torch.ops.torch_mlu_ops.fused_moe, args)
# Entry point: run the fused-MoE test suite via the shared unittest runner.
if __name__ == '__main__':
    exit(run_unittest(TestFusedMOEOp))

View File

@@ -0,0 +1,242 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from common_utils import *
import random
from itertools import product
import time
import os
def gen_data(num_expert,
             total_tokens,
             inner_size,
             output_stride,
             dtype,
             is_gated,
             has_bias,
             is_ep):
    """Build random inputs for the moe_active tests.

    Returns (input, bias, token_count, cusum_token_count, output,
    start_expert_id, expert_size). When is_ep is True a random contiguous
    expert slice is selected; otherwise all experts are used.
    """
    ci = inner_size * (1 + is_gated)
    input = torch.randn(total_tokens, ci, dtype=dtype, device='mlu')
    cusum_token_count, token_count = generate_token_count(num_expert, total_tokens)
    output = torch.empty((total_tokens, inner_size), dtype=dtype, device='mlu')
    # Bug fix: as_strided() returns a new view and does not modify in place;
    # the result was previously discarded, so output_stride was silently
    # ignored. (All current callers pass output_stride == inner_size, for
    # which this assignment is behavior-identical.)
    output = output.as_strided(output.size(), (output_stride, 1))
    start_expert_id = random.randint(0, num_expert - 1) if is_ep else 0
    expert_size = random.randint(1, num_expert - start_expert_id) if is_ep else num_expert
    bias = torch.randn(num_expert, ci, dtype=dtype, device='mlu') if has_bias else None
    return input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size
class TestMoeActiveKernel(BtTestCase):
    # Tests for ops.moe_active: per-expert (optionally gated) activation with
    # optional per-expert bias and expert-parallel (EP) slicing.
    def op_impl_base(self, *args):
        """Reference implementation of moe_active built from torch ops.

        args: (input, act_mode, is_gated, output, bias, cusum_token_count,
               start_expert_id, expert_size). If output is None the activated
        tensor is returned; otherwise it is copied into output.
        """
        input, act_mode, is_gated, output, bias, cusum_token_count, start_expert_id, expert_size = args
        act_fun = torch.nn.functional.gelu if act_mode == 'gelu' else torch.nn.functional.silu
        total_token_num = input.size(0)
        # Gated layout packs [activation | gate] along the channel dim.
        inner_size = input.size(1) // 2 if is_gated else input.size(1)
        input_ = input.clone()
        if bias is not None:
            # Only tokens routed to experts [start, start+expert_size) get a
            # per-expert bias; the remainder (if any) is zero-padded back.
            deal_token_num = cusum_token_count[start_expert_id+expert_size] - cusum_token_count[start_expert_id]
            input_ = input_[:deal_token_num, :]
            token_count = cusum_token_count[1:] - cusum_token_count[:-1]
            token_count_ = token_count[start_expert_id:start_expert_id+expert_size].tolist()
            input_list = list(input_.split(token_count_))
            for i in range(expert_size):
                input_list[i] += bias[i]
            input_ = torch.cat(input_list, dim=0)
            if cusum_token_count.size(0) - 1 != expert_size:
                pad = torch.zeros(total_token_num-deal_token_num, input_.size(-1)).to(input_.dtype).mlu()
                input_ = torch.cat((input_, pad), dim=0)
        acted = act_fun(input_[:, :inner_size])
        acted = acted * input_[:, inner_size:] if is_gated else acted
        if output is None:
            return acted
        else:
            return output.copy_(acted)
    # Functional test
    def test_functional(self):
        """Sweep moe_active over the cartesian product of shapes/dtypes/modes
        and compare against op_impl_base."""
        num_expert_list = [5, 32]
        total_tokens_list = [64, 1024]
        inner_size_list = [1024, 8192]
        dtype_list = [torch.half, torch.float32]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        is_gated_list = [True, False]
        is_ep_list = [True, False]
        # change True when test
        # NOTE(review): the duplicated False makes product() run every
        # has_bias=False combination twice — presumably intentional weighting
        # or a leftover; confirm.
        has_bias_list = [False, False, True]
        act_mode_list = ['silu', 'gelu']
        args = product(num_expert_list, total_tokens_list, inner_size_list, dtype_list,
                       is_gated_list, is_ep_list, has_bias_list, act_mode_list)
        for num_expert, total_tokens, inner_size, dtype, is_gated, is_ep, has_bias, act_mode in args:
            print("===============================================================================")
            print(f"num_expert: {num_expert}, total_tokens: {total_tokens}")
            print(f"inner_size: {inner_size}, dtype: {dtype}, is_gated: {is_gated}")
            print(f"is_ep: {is_ep}, has_bias: {has_bias}, act_mode: {act_mode}")
            input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \
                gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep)
            bias_real = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None
            base_output = torch.empty_like(output)
            ops.moe_active(input,
                           act_mode,
                           is_gated,
                           output,
                           bias_real,
                           cusum_token_count.mlu() if has_bias or is_ep else None,
                           start_expert_id,
                           expert_size)
            # Only tokens belonging to the selected expert slice are compared.
            deal_token_num = cusum_token_count[start_expert_id+expert_size] - cusum_token_count[start_expert_id]
            output = output[:deal_token_num, :]
            self.op_impl_base(input,
                              act_mode,
                              is_gated,
                              base_output,
                              bias_real,
                              cusum_token_count.mlu() if has_bias or is_ep else None,
                              start_expert_id,
                              expert_size)
            base_output = base_output[:deal_token_num, :]
            self.assertTensorsEqual(base_output.cpu().float(), output.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
    def test_inplace(self):
        """moe_active writing its result back into (a strided view of) its
        own input buffer must match the out-of-place reference."""
        num_expert, total_tokens, inner_size, dtype, is_gated, is_ep, has_bias, act_mode = \
            32, 80, 1024, torch.half, True, False, False, 'gelu'
        input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \
            gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep)
        input_bak = input.clone()
        # test output inplace with stride
        output = input.as_strided((total_tokens, inner_size), (inner_size * 2, 1))
        base_output = torch.empty_like(output)
        ops.moe_active(input,
                       act_mode,
                       is_gated,
                       output,
                       bias,
                       None,
                       start_expert_id,
                       expert_size)
        self.op_impl_base(input_bak, act_mode, is_gated, base_output, bias, None, start_expert_id, expert_size)
        self.assertTensorsEqual(base_output.cpu().float(), output.cpu().float(),
                                0.003, use_MSE=True, use_RAE=True)
    # Randomized sweep test
    def test_random(self):
        """500 random shape/dtype/mode combinations against the reference."""
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for i in range(500):
            num_expert = random.randint(1, 64)
            total_tokens = random.randint(1, 32768)
            inner_size = random.randint(1, 8192)
            dtype = random.sample(dtype_list, 1)[0]
            is_gated = random.sample([True, False], 1)[0]
            is_ep = random.sample([True, False], 1)[0]
            # change True when test
            has_bias = random.sample([False, True], 1)[0]
            act_mode = random.sample(['gelu', 'silu'], 1)[0]
            print("===============================================================================")
            print(f"[{i}]: num_expert: {num_expert}, total_tokens: {total_tokens}")
            print(f"    inner_size: {inner_size}, dtype: {dtype}, is_gated: {is_gated}")
            print(f"    is_ep: {is_ep}, has_bias: {has_bias}, act_mode: {act_mode}")
            input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \
                gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep)
            bias_real = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None
            base_output = torch.empty_like(output)
            ops.moe_active(input,
                           act_mode,
                           is_gated,
                           output,
                           bias_real,
                           cusum_token_count.mlu() if has_bias or is_ep else None,
                           start_expert_id,
                           expert_size)
            deal_token_num = cusum_token_count[start_expert_id+expert_size] - cusum_token_count[start_expert_id]
            output = output[:deal_token_num, :]
            self.op_impl_base(input,
                              act_mode,
                              is_gated,
                              base_output,
                              bias_real,
                              cusum_token_count.mlu() if has_bias or is_ep else None,
                              start_expert_id,
                              expert_size)
            base_output = base_output[:deal_token_num, :]
            self.assertTensorsEqual(base_output.cpu().float(), output.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
    def test_single(self):
        """One fixed fp32 gated configuration (handy for debugging)."""
        num_expert, total_tokens, inner_size, dtype, is_gated, is_ep, has_bias, act_mode = \
            32, 64, 8192, torch.float, True, False, False, 'gelu'
        input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \
            gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep)
        bias_real = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None
        base_output = torch.empty_like(output)
        ops.moe_active(input,
                       act_mode,
                       is_gated,
                       output,
                       bias_real,
                       cusum_token_count.mlu() if has_bias or is_ep else None,
                       start_expert_id,
                       expert_size)
        deal_token_num = cusum_token_count[start_expert_id+expert_size] - cusum_token_count[start_expert_id]
        output = output[:deal_token_num, :]
        self.op_impl_base(input, act_mode, is_gated, base_output, bias_real,
                          cusum_token_count.mlu() if has_bias or is_ep else None,
                          start_expert_id,
                          expert_size)
        base_output = base_output[:deal_token_num, :]
        self.assertTensorsEqual(base_output.cpu().float(), output.cpu().float(),
                                0.003, use_MSE=True, use_RAE=True)
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        """Invalid inputs must raise with the documented error messages."""
        func = ops.moe_active
        num_expert, total_tokens, inner_size, dtype, is_gated, is_ep, has_bias, act_mode = \
            32, 64, 8192, torch.float, True, False, False, 'gelu'
        input, bias, token_count, cusum_token_count, output, start_expert_id, expert_size = \
            gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep)
        bias_real = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None
        # Unknown activation name.
        act_mode = 'abc'
        self.assertException("act_mode must be 'silu', 'gelu', 'quick_gelu' or 'swish'.", func,
                             input, act_mode, is_gated, output, bias_real,
                             cusum_token_count.mlu() if has_bias or is_ep else None,
                             start_expert_id, expert_size)
        act_mode = 'gelu'
        # 1-D input is rejected.
        self.assertException("input.dim() >= 2", func,
                             input.reshape(-1), act_mode, is_gated, output, bias_real,
                             cusum_token_count.mlu() if has_bias or is_ep else None,
                             start_expert_id, expert_size)
        is_gated = True
        # Gated layout requires an even channel count.
        self.assertException("in_channel % 2 == 0 if is_gated is true", func,
                             torch.randn(total_tokens, inner_size * 2 - 1, dtype=dtype, device='mlu'),
                             act_mode, is_gated, output, bias_real,
                             cusum_token_count.mlu() if has_bias or is_ep else None,
                             start_expert_id, expert_size)
    def test_inductor(self):
        """Opcheck (torch.compile/inductor compatibility) for the raw
        torch.ops.torch_mlu_ops.active schema."""
        num_expert, total_tokens, inner_size, dtype, is_gated, is_ep, has_bias, act_mode = \
            32, 64, 8192, torch.float, True, False, True, 'gelu'
        input, bias, _, cusum_token_count, output, start_expert_id, expert_size = \
            gen_data(num_expert, total_tokens, inner_size, inner_size, dtype, is_gated, has_bias, is_ep)
        bias_real = bias[start_expert_id:start_expert_id + expert_size] if has_bias else None
        args = (input, output, bias_real,
                cusum_token_count.mlu() if has_bias or is_ep else None,
                act_mode, is_gated, start_expert_id, expert_size)
        self.base_opcheck(torch.ops.torch_mlu_ops.active, args)
# Entry point: seed both RNGs for reproducible runs, then run the suite.
if __name__ == '__main__':
    random.seed(0)
    torch.manual_seed(0)
    exit(run_unittest(TestMoeActiveKernel))

View File

@@ -0,0 +1,77 @@
import torch
import unittest
import torch_mlu_ops as ops
import random
from common_utils import *
from itertools import product
import copy
from typing import Optional
class TestMoeCastGating(BtTestCase):
    # Tests for ops.moe_cast_gating: casts the input to fp32 and computes the
    # router logits input @ weight^T.
    # NOTE(review): `torch_mlu` and `os` are not imported in this file
    # directly — presumably provided by `from common_utils import *`; verify.
    def op_impl_base(self, *args):
        """Reference: fp32 matmul of input against the transposed weight."""
        input, weight = args
        input = input.to(torch.float)
        output = torch.matmul(input, weight.permute(1, 0))
        return output
    @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "moe_cast_gating not support MLU3XX device")
    def test_moe_cast_gating_random(self):
        """1000 random shape/dtype combinations against the reference."""
        for _ in range(1000):
            total_seq = random.randint(1, 32768)
            hidden_size = random.randint(1, 16384)
            expert_num = random.randint(1, 128)
            input_dtype_list = [torch.half]
            if torch_mlu.mlu.is_bf16_supported():
                input_dtype_list.append(torch.bfloat16)
            input_dtype = random.choice(input_dtype_list)
            weight_dtype = torch.float
            print("total_seqlen={}, hidden_size={}, expert_num={}, input_dtype={}, testing...".format(
                total_seq, hidden_size, expert_num, input_dtype))
            input = torch.randn(total_seq, hidden_size, dtype=input_dtype, device="mlu")
            weight = torch.randn(expert_num, hidden_size, dtype=weight_dtype, device="mlu")
            tmo_out = ops.moe_cast_gating(input, weight)
            torch_out = self.op_impl_base(input, weight)
            self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 1e-4,
                                    use_MSE=True, use_RAE=True)
    @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "moe_cast_gating not support MLU3XX device")
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        """Invalid inputs must raise with the documented error messages."""
        func = ops.moe_cast_gating
        # Non-contiguous input.
        input = torch.randn(1024, 8192, dtype=torch.half, device="mlu")
        input = input.as_strided(input.shape, (100, 1))
        weight = torch.randn(128, 8192, dtype=torch.float, device="mlu")
        self.assertException("input must be contiguous", func, input, weight)
        # Non-contiguous weight.
        input = input.contiguous()
        weight = weight.as_strided(weight.shape, (100, 1))
        self.assertException("weight must be contiguous", func, input, weight)
        # 3-D weight.
        weight = weight.contiguous()
        weight = weight.reshape(1, 128, 8192)
        self.assertException("weight.dim() == 2", func, input, weight)
        # Mismatched hidden size.
        weight = torch.randn(128, 2048, dtype=torch.float, device="mlu")
        self.assertException("input.size(-1) == weight.size(-1)", func, input, weight)
        # Wrong dtypes.
        weight = torch.randn(128, 8192, dtype=torch.half, device="mlu")
        self.assertException("weight type need be torch::kFloat32", func, input, weight)
        weight = weight.to(torch.float)
        input = input.to(torch.float)
        self.assertException("input type need be torch::kFloat16 or torch::kBFloat16", func, input, weight)
        # Shape limits of the kernel.
        input = torch.randn(1024, 16388, dtype=torch.half, device="mlu")
        weight = torch.randn(128, 16388, dtype=torch.float, device="mlu")
        self.assertException("hidden_size > 0 && hidden_size <= 16384", func, input, weight)
        input = torch.randn(1024, 16384, dtype=torch.half, device="mlu")
        weight = torch.randn(129, 16384, dtype=torch.float, device="mlu")
        self.assertException("expert_num > 0 && expert_num <= 128", func, input, weight)
    @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "moe_cast_gating not support MLU3XX device")
    def test_inductor(self):
        """Opcheck (torch.compile/inductor compatibility) for the raw op."""
        m, hidden_size, expert_num, input_dtype = 1024, 4096, 32, torch.half
        input = torch.randn(m, hidden_size, dtype=input_dtype, device="mlu")
        weight = torch.randn(expert_num, hidden_size, dtype=torch.float, device="mlu")
        args = (input, weight)
        self.base_opcheck(torch.ops.torch_mlu_ops.moe_cast_gating, args)
if __name__ == "__main__":
    # moe_cast_gating is unsupported on MLU3xx devices; skip the whole suite there.
    if "MLU3" not in torch.mlu.get_device_name():
        exit(run_unittest(TestMoeCastGating))

View File

@@ -0,0 +1,242 @@
import torch
import unittest
import torch_mlu_ops as ops
import random
from common_utils import *
from itertools import product
import copy
from typing import Optional
def generate_token_count(num_expert,
                         total_token_count):
    """Build a random inclusive prefix-sum of per-expert token counts.

    Returns an int32 tensor of shape (num_expert + 1,) whose first element is 0,
    whose entries are monotonically non-decreasing and clamped to
    total_token_count, and whose last element is exactly total_token_count.
    """
    token_count = torch.randint(low=1, high=1024, size=(num_expert, ), \
        dtype=torch.int32).to(dtype=torch.float32)
    # Rescale the random counts so they sum (approximately) to the requested
    # total.  Renamed from `sum` to avoid shadowing the Python builtin.
    total = torch.sum(token_count, dim=-1) * 1.0
    token_count *= total_token_count / total.item()
    token_count = token_count.to(dtype=torch.int32)
    cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32)
    end_expert_token_count = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count
    cusum_token_count[1:] = torch.cumsum(token_count, dim=-1)
    # Float truncation above can overshoot; clamp every prefix to the total.
    cusum_token_count = torch.minimum(cusum_token_count, end_expert_token_count)
    # Force the final boundary to cover all tokens exactly.
    cusum_token_count[-1] = total_token_count
    return cusum_token_count
def gen_case(num_tokens,
             topk,
             hidden_size,
             num_expert,
             expert_size,
             has_bias,
             has_residual,
             dtype,
             device):
    """Create one random input bundle for moe_combine_result tests.

    Returns (input, reduce_weight, gather_ids, residual, bias,
    cusum_token_count); residual/bias/cusum_token_count are None when not
    requested.
    """
    expanded_tokens = num_tokens * topk
    input = torch.randn((expanded_tokens, hidden_size), dtype=dtype, device=device)
    reduce_weight = torch.randn((num_tokens, topk), dtype=torch.float32, device=device)
    gather_ids = torch.randperm(expanded_tokens, dtype=torch.int32, device=device)
    bias = torch.randn((num_expert, hidden_size), dtype=dtype, device=device) if has_bias else None
    residual = torch.randn((num_tokens, hidden_size), dtype=dtype, device=device) if has_residual else None
    # A token-count prefix sum is only needed when a per-expert bias is applied
    # or when only a slice of the experts is being combined.
    cusum_token_count = None
    if has_bias or expert_size < num_expert:
        cusum_token_count = generate_token_count(num_expert, expanded_tokens).to(device=device)
    return input, reduce_weight, gather_ids, residual, bias, cusum_token_count
class TestCombineResult(BtTestCase):
    """Tests for ops.moe_combine_result against a pure-PyTorch reference."""

    def op_impl_base(self, *args):
        """CPU/float32 reference for moe_combine_result.

        args: (input, reduce_weight, gather_ids, residual, cusum_token_count,
        start_expert_id, expert_size, bias).
        """
        input, reduce_weight, gather_ids, residual, cusum_token_count, start_expert_id, \
            expert_size, bias = args
        input = input.to(dtype=torch.float32).cpu()
        reduce_weight = reduce_weight.cpu()
        gather_ids = gather_ids.cpu()
        # NOTE(review): bias is forced to None here, so the bias branches below
        # are dead; every call site in this file also passes bias=None.
        bias = None
        # NOTE(review): dtype is captured *after* the float32 cast, so the final
        # .to(dtype=dtype) keeps float32; callers compare with .float() anyway.
        dtype = input.dtype
        num_expert = expert_size
        hidden_size = input.shape[1]
        if bias is not None:
            bias = bias.to(dtype=torch.float32).cpu()
            num_expert = bias.shape[0]
        if cusum_token_count is not None:
            # The prefix sum carries num_expert + 1 boundaries.
            num_expert = cusum_token_count.shape[0] - 1
            cusum_token_count = cusum_token_count.cpu()
        if cusum_token_count is not None and expert_size < num_expert:
            # Expert-slice mode: input rows are local to the slice, so shift
            # the global gather ids by the slice's first-token offset.
            gathered_input = input[gather_ids - cusum_token_count[start_expert_id].item()]
        else:
            gathered_input = input[gather_ids]
        if bias is not None and cusum_token_count is not None:
            # Add each expert's bias to its contiguous token segment.
            for i in range(start_expert_id, start_expert_id + expert_size):
                gathered_input[cusum_token_count[i] : cusum_token_count[i+1]] += bias[i]
        gathered_input = gathered_input.reshape(*reduce_weight.shape, hidden_size)
        if cusum_token_count is not None:
            # Zero the reduce weights of tokens outside the active expert slice.
            filtered_ids = (gather_ids >= cusum_token_count[start_expert_id]) * \
                (gather_ids < cusum_token_count[start_expert_id + expert_size])
            filtered_ids = filtered_ids.to(dtype=torch.float32)
            reduce_weight = reduce_weight * filtered_ids.reshape(reduce_weight.shape)
        # Weighted reduction over the topk axis.
        gathered_input *= reduce_weight.reshape(*reduce_weight.shape, -1)
        output = torch.sum(gathered_input, dim=1, keepdim=False)
        if residual is not None:
            residual = residual.to(dtype=torch.float32).cpu()
            output += residual
        return output.to(dtype=dtype)

    def test_random_case(self):
        """Randomized sweep over shapes/dtypes; op output vs. the reference."""
        torch.manual_seed(444)
        test_cases = 200
        num_tokens_list = torch.randint(low=1, high=2048, size=(test_cases, ), dtype=torch.int32)
        topk_list = torch.randint(low=1, high=33, size=(test_cases, ), dtype=torch.int32)
        hidden_size_list = torch.randint(low=256, high=8193, size=(test_cases, ), dtype=torch.int32)
        num_expert_list = torch.randint(low=1, high=129, size=(test_cases, ), dtype=torch.int32)
        # Keep num_expert >= topk.
        num_expert_list = torch.maximum(topk_list, num_expert_list)
        expert_size_list = torch.randint(low=1, high=129, size=(test_cases, ), dtype=torch.int32)
        expert_size_list = torch.minimum(expert_size_list, num_expert_list)
        start_expert_id_list = torch.randint(low=0, high=129, size=(test_cases, ), dtype=torch.int32)
        # Clamp start ids into [0, num_expert - expert_size].
        start_expert_id_list = torch.minimum(start_expert_id_list, num_expert_list - expert_size_list)
        start_expert_id_list = torch.maximum(start_expert_id_list, torch.zeros(test_cases, dtype=torch.int32))
        has_bias_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
        has_residual_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
        dtype_list = torch.randint(low=0, high=10, size=(test_cases, ), dtype=torch.int32)
        dtype_map = [torch.half, torch.bfloat16, torch.float]
        device = 'mlu'
        mlu_name = torch.mlu.get_device_name()
        is_mlu370 = "MLU3" in mlu_name
        max_num_tokens = 128 * 1024
        for i in range(test_cases):
            num_tokens = num_tokens_list[i].item()
            topk = topk_list[i].item()
            hidden_size = hidden_size_list[i].item()
            num_expert = num_expert_list[i].item()
            expert_size = expert_size_list[i].item()
            start_expert_id = start_expert_id_list[i].item()
            # NOTE(review): bias coverage is intentionally disabled here.
            has_bias = False # has_bias_list[i].item()
            has_residual = has_residual_list[i].item()
            # Cap the expanded token count to bound device memory usage.
            topk = min(topk, (int)(max_num_tokens / num_tokens))
            # Skewed dtype draw: 0-3 -> half, 4-8 -> bfloat16, 9 -> float.
            if dtype_list[i].item() < 4:
                dtype = dtype_map[0]
            elif dtype_list[i].item() < 9:
                dtype = dtype_map[1]
            else:
                dtype = dtype_map[2]
            # bfloat16 is unsupported on MLU3xx devices.
            if is_mlu370 and dtype is torch.bfloat16:
                continue
            inputs = gen_case(num_tokens,
                              topk,
                              hidden_size,
                              num_expert,
                              expert_size,
                              has_bias,
                              has_residual,
                              dtype,
                              device)
            input = inputs[0]
            reduce_weight = inputs[1]
            gather_ids = inputs[2]
            residual = inputs[3]
            bias = inputs[4]
            cusum_token_count = inputs[5]
            print("num_tokens={}, topk={}, hidden_size={}, num_expert={}, expert_size={}, "
                  "start_expert_id={}, has_bias={}, has_residual={}, dtype={}, testing...".format(
                  num_tokens, topk, hidden_size, num_expert, expert_size, start_expert_id, \
                  has_bias, has_residual, dtype))
            golden_output = self.op_impl_base(input, reduce_weight, gather_ids, residual,
                                              cusum_token_count, start_expert_id, expert_size, None)
            output = ops.moe_combine_result(input, reduce_weight, gather_ids, residual,
                                            cusum_token_count, start_expert_id, expert_size)
            self.assertTensorsEqual(output.cpu().float(), golden_output.cpu().float(), 0.003,
                                    "golden_output must equal output", True, True, True, True)

    def test_perf_case(self):
        """Timed sweep over representative shapes; still checks correctness."""
        num_tokens_list = [1, 72, 512]
        hidden_size_list = [2048, 4096, 5120, 8192]
        # [num_expert, topk, start_expert_id, expert_size]
        expert_options_list = [[1, 1, 0, 1], [8, 2, 0, 8], [32, 5, 0, 32], [32, 5, 24, 8]]
        has_residual_list = [True, False]
        dtype_list = [torch.half, torch.bfloat16]
        device = 'mlu'
        mlu_name = torch.mlu.get_device_name()
        is_mlu370 = "MLU3" in mlu_name
        args = product(num_tokens_list, hidden_size_list, expert_options_list,\
                       has_residual_list, dtype_list)
        for num_tokens, hidden_size, expert_options, has_residual, dtype in args:
            num_expert = expert_options[0]
            topk = expert_options[1]
            start_expert_id = expert_options[2]
            expert_size = expert_options[3]
            has_bias = False
            # bfloat16 is unsupported on MLU3xx devices.
            if is_mlu370 and dtype is torch.bfloat16:
                continue
            torch.manual_seed(444)
            inputs = gen_case(num_tokens,
                              topk,
                              hidden_size,
                              num_expert,
                              expert_size,
                              has_bias,
                              has_residual,
                              dtype,
                              device)
            input = inputs[0]
            reduce_weight = inputs[1]
            gather_ids = inputs[2]
            residual = inputs[3]
            bias = inputs[4]
            cusum_token_count = inputs[5]
            golden_output = self.op_impl_base(input, reduce_weight, gather_ids, residual,
                                              cusum_token_count, start_expert_id, expert_size, None)
            # Time `loop` back-to-back launches with device events and report
            # the mean hardware time per launch.
            notify_start = torch.mlu.Event(enable_timing=True)
            notify_end = torch.mlu.Event(enable_timing=True)
            notify_start.record()
            loop = 10
            for _ in range(loop):
                output = ops.moe_combine_result(input, reduce_weight, gather_ids, residual,
                                                cusum_token_count, start_expert_id, expert_size)
            notify_end.record()
            notify_end.synchronize()
            time = notify_start.hardware_time(notify_end) / loop
            print("num_tokens={}, topk={}, hidden_size={}, num_expert={}, expert_size={}, "
                  "start_expert_id={}, has_bias={}, has_residual={}, dtype={}, time={:.1f}".format(
                  num_tokens, topk, hidden_size, num_expert, expert_size, start_expert_id, \
                  has_bias, has_residual, dtype, time))
            self.assertTensorsEqual(output.cpu().float(), golden_output.cpu().float(), 0.003,
                                    "golden_output must equal output", True, True, True, True)

    def test_inductor(self):
        """Run the opcheck harness on moe_combine_result with one fixed case."""
        num_tokens, hidden_size, has_bias, has_residual, dtype = 1, 2048, False, True, torch.float16
        num_expert, topk, start_expert_id, expert_size = 8, 2, 0, 8
        input, reduce_weight, gather_ids, residual, \
            bias, cusum_token_count = gen_case(num_tokens, topk, hidden_size, num_expert, expert_size,
                                               has_bias, has_residual, dtype, 'mlu')
        args = (input, reduce_weight, gather_ids, residual, cusum_token_count, start_expert_id, expert_size, bias)
        self.base_opcheck(torch.ops.torch_mlu_ops.moe_combine_result, args)
if __name__ == '__main__':
    # Propagate the unittest result as the process exit code.
    exit(run_unittest(TestCombineResult))

View File

@@ -0,0 +1,132 @@
import torch
import torch_mlu_ops
import unittest
import torch_mlu_ops
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from common_utils import *
import random
from itertools import product
from typing import Tuple, Optional
import os
class TestExpandInput(BtTestCase):
    """Tests for torch_mlu_ops.moe_expand_input (token gather/expand for MoE)."""

    def run_gen_case(self, dic):
        """Drive launch() from a case dict; 'dump_data' selects raw tensors."""
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            input = create_tensor_from_dic(dic['input'])
            gather_idx = dic['gather_idx']['data']
            cusum_token_count = dic['cusum_token_count']['data']
            start_expert_id = dic['start_expert_id']['data']
            expert_size = dic['expert_size']['data']
            self.launch(input, gather_idx, cusum_token_count, start_expert_id, expert_size)

    def launch(self, *args):
        """Run op and reference, comparing only the first real_token_count rows."""
        cusum_token_count = args[2]
        start_expert_id = args[3]
        expert_size = args[4]
        # NOTE(review): this uses cusum[start+size-1] as the upper boundary
        # while op_impl_base slices up to cusum[start+size]; both outputs are
        # truncated to the same count, but confirm the off-by-one is intended.
        real_token_count = None if cusum_token_count is None else cusum_token_count[start_expert_id+expert_size-1] - cusum_token_count[start_expert_id]
        base_out = self.op_impl_base(*args)
        tmo_out = torch_mlu_ops.moe_expand_input(*args)
        self.assertTensorsEqual(base_out[:real_token_count].cpu().float(), tmo_out[:real_token_count].cpu().float(),
                                0.00, use_MSE=True, use_RAE=True)

    def op_impl_base(self, *args):
        """Reference: plain gather, optionally restricted to an expert slice."""
        input, gather_idx, cusum_token_count, start_expert_id, expert_size = args
        if cusum_token_count is None:
            return input[gather_idx]
        else:
            # Only gather the ids that fall inside the selected expert range.
            idx = gather_idx[cusum_token_count[start_expert_id]:cusum_token_count[start_expert_id + expert_size]]
            return input[idx]

    def get_tensor(self, token_num, hidden_size, expert_num, topk, start_expert_id, expert_size, dtype):
        """Build (input, gather_idx, cusum_token_count, real_token_count) for one case."""
        input = torch.randn(token_num, hidden_size, device='mlu').to(dtype)
        gather_idx = torch.randint(low=0, high=token_num, size=(token_num * topk,), dtype=torch.int32, device='mlu')
        # generate_token_count here comes from common_utils (star import) and
        # returns a pair — distinct from same-named helpers in sibling files.
        cusum_token_count, _ = generate_token_count(expert_num, token_num * topk)
        cusum_token_count = cusum_token_count.to('mlu')
        use_all_experts = expert_num == expert_size
        if use_all_experts:
            # Full-expert mode: the op is called without a token-count tensor.
            cusum_token_count = None
            real_token_count = token_num * topk
        else:
            real_token_count = cusum_token_count[start_expert_id+expert_size-1] - cusum_token_count[start_expert_id]
        return input, gather_idx, cusum_token_count, real_token_count

    def test_kernel_random(self):
        """Randomized sweep; checks both full-expand and expert-slice paths."""
        for i in range(100):
            token_num = random.randint(1, 2048)
            hidden_size = random.randint(1, 4096)
            expert_num = random.randint(1, 32)
            topk = random.randint(1, expert_num)
            start_expert_id = random.randint(0, expert_num-1)
            expert_size = random.randint(1, expert_num-start_expert_id)
            dtype_list = [torch.half, torch.float, torch.int8]
            # NOTE(review): torch_mlu is not imported in this file's header;
            # presumably it arrives via `from common_utils import *` — verify.
            if torch_mlu.mlu.is_bf16_supported():
                dtype_list.append(torch.bfloat16)
            dtype = random.sample(dtype_list, 1)[0]
            print("===============================================================================")
            print(f"[{i}]: token_num: {token_num}, hidden_size: {hidden_size}, expert_num: {expert_num}")
            print(f"      topk: {topk}, start_expert_id: {start_expert_id}, expert_size: {expert_size}, dtype: {dtype}")
            input, gather_idx, cusum_token_count, real_token_count = self.get_tensor(token_num,
                hidden_size, expert_num, topk, start_expert_id, expert_size, dtype)
            # Full-expand path (no token-count tensor).
            base_expand_hidden_states = self.op_impl_base(input, gather_idx, None, 0, 0)
            expand_hidden_states = torch_mlu_ops.moe_expand_input(input, gather_idx)
            self.assertTensorsEqual(base_expand_hidden_states.cpu().float(), expand_hidden_states.cpu().float(),
                                    0.00, use_MSE=True, use_RAE=True)
            del base_expand_hidden_states, expand_hidden_states
            # Expert-slice path, truncated to the valid row count.
            base_expand_hidden_states = self.op_impl_base(input, gather_idx, cusum_token_count, start_expert_id, expert_size)[:real_token_count]
            expand_hidden_states = torch_mlu_ops.moe_expand_input(input, gather_idx, cusum_token_count,
                                                                  start_expert_id, expert_size)[:real_token_count]
            self.assertTensorsEqual(base_expand_hidden_states.cpu().float(), expand_hidden_states.cpu().float(),
                                    0.00, use_MSE=True, use_RAE=True)
            del base_expand_hidden_states, expand_hidden_states

    # Negative-path tests for every parameter check in the binding; each call
    # must fail with the exact message the op emits.
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        func = torch_mlu_ops.moe_expand_input
        token_num, hidden_size, topk, expert_num, start_expert_id, expert_size, dtype = 5, 5, 6, 32, 0, 16, torch.half
        input, gather_idx, cusum_token_count, _ = self.get_tensor(token_num, hidden_size, expert_num, \
            topk, start_expert_id, expert_size, dtype)
        input = torch.randn(token_num, 1, hidden_size).to(dtype).to('mlu')
        self.assertException("input dim must be equal to 2.", func, input, gather_idx)
        input = torch.randn(token_num, hidden_size).to(dtype).to('mlu')
        gather_idx = torch.randint(0, token_num, (token_num * topk, 1)).to(torch.int32).to('mlu')
        self.assertException("gather_idx dim must be equal to 1.", func, input, gather_idx)
        gather_idx = gather_idx.reshape(token_num * topk)
        self.assertException("input tensor must on mlu.", func, input.cpu(), gather_idx)
        self.assertException("gather_idx must on mlu.", func, input, gather_idx.cpu())
        self.assertException("data type of gather_idx must be int32.", func, input, gather_idx.to(torch.int64))
        input = torch.randn(token_num, hidden_size).to(dtype).to('mlu')
        gather_idx = torch.randint(0, token_num, (8,)).to(torch.int32).to('mlu')
        self.assertException("expand_token_num % token_num == 0.", func, input, gather_idx)
        gather_idx = torch.randint(1, token_num, (token_num * topk,)).to(torch.int32).to('mlu')
        self.assertException("cusum_token_count must on mlu.", func, input, gather_idx, cusum_token_count.cpu(),
                             start_expert_id, expert_size)
        self.assertException("data type of cusum_token_count must be int32.", func, input, gather_idx, cusum_token_count.to(torch.int64),
                             start_expert_id, expert_size)
        # Range checks on start_expert_id and expert_size.
        start_expert_id = -1
        self.assertException("start_expert_id >=0 && start_expert_id < expert_num.",
                             func, input, gather_idx, cusum_token_count, start_expert_id, expert_size)
        start_expert_id = expert_num
        self.assertException("start_expert_id >=0 && start_expert_id < expert_num.",
                             func, input, gather_idx, cusum_token_count, start_expert_id, expert_size)
        start_expert_id = 16
        expert_size = 17
        self.assertException("start_expert_id + expert_size <= expert_num.",
                             func, input, gather_idx, cusum_token_count, start_expert_id, expert_size)

    def test_inductor(self):
        """Run the opcheck harness on moe_expand_input with one fixed case."""
        dtype = torch.float16
        token_num, hidden_size, expert_num, topk, start_expert_id, expert_size = 64, 128, 16, 8, 3, 10
        input, gather_idx, cusum_token_count, _ = self.get_tensor(token_num, hidden_size, expert_num, \
            topk, start_expert_id, expert_size, dtype)
        args = (input, gather_idx, cusum_token_count, start_expert_id, expert_size)
        self.base_opcheck(torch.ops.torch_mlu_ops.moe_expand_input, args)
if __name__ == "__main__":
    # Fix both RNGs for reproducible random sweeps, then run the suite.
    random.seed(0)
    torch.manual_seed(0)
    exit(run_unittest(TestExpandInput))

View File

@@ -0,0 +1,62 @@
import torch
import torch_mlu_ops
import unittest
import torch_mlu_ops
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from common_utils import *
import random
from itertools import product
from typing import Tuple, Optional
import os
def gen_args(index: int = 0):
    """Draw one random (expert_id, expert_num) pair for moe_gen_idx tests."""
    token_num = random.randint(1, 32768)
    expert_num = random.randint(1, 256)
    topk = random.randint(1, expert_num)
    ids = torch.randint(low=0, high=expert_num, size=(token_num, topk))
    expert_id = ids.to(torch.int32).to('mlu')
    print("===============================================================================")
    print(f"[{index}]: token_num: {token_num}, expert_num: {expert_num}, topk: {topk}")
    return expert_id, expert_num
class TestGenIdx(BtTestCase):
    """Tests for torch_mlu_ops.moe_gen_idx (expand/combine index generation)."""

    def op_impl_base(self, *args):
        """Reference: sort flattened expert ids and derive the gather indices.

        Returns (expand_idx, combine_idx, token_count, cusum_token_count).
        """
        expert_id, expert_num = args
        token_num, topk = expert_id.size(0), expert_id.size(1)
        sorted_expert_id, indices = expert_id.int().flatten().sort()
        # Expanded-row -> source-token index (each token occupies topk slots).
        expand_idx_out = indices.int() // topk
        combine_idx_out = torch.zeros((token_num * topk,), dtype=torch.int, device="mlu")
        # Inverse permutation of the sort: original slot -> sorted position.
        combine_idx_out.scatter_(0, indices, torch.arange(token_num * topk, dtype=torch.int, device="mlu"))
        # Per-expert token counts and their exclusive prefix sum (len + 1).
        token_count_out = torch.bincount(sorted_expert_id, minlength=expert_num)
        cusum_token_count_out = torch.cat((torch.tensor([0]).to('mlu'), torch.cumsum(token_count_out, dim=0))).to(torch.int32)
        return tuple([expand_idx_out, combine_idx_out, token_count_out, cusum_token_count_out])

    def test_kernel_random(self):
        """Randomized sweep; all four op outputs must match the reference exactly."""
        for i in range(1500):
            expert_id, expert_num = gen_args(i)
            base_gather_expand_idx, base_gather_combine_idx, base_token_count, base_cusum_token_count = self.op_impl_base(expert_id, expert_num)
            gather_expand_idx_out, gather_combine_idx_out, token_count_out, cusum_token_count_out = \
                torch_mlu_ops.moe_gen_idx(expert_id, expert_num)
            self.assertTensorsEqual(base_gather_expand_idx.cpu(), gather_expand_idx_out.cpu(), 0.00, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(base_gather_combine_idx.cpu(), gather_combine_idx_out.cpu(), 0.00, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(base_token_count.cpu(), token_count_out.cpu(), 0.00, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(base_cusum_token_count.cpu(), cusum_token_count_out.cpu(), 0.00, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """Run the opcheck harness on moe_gen_idx with one random case."""
        args = gen_args()
        self.base_opcheck(torch.ops.torch_mlu_ops.moe_gen_idx, args)

    # Negative-path tests: each call must fail with the binding's exact message.
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        func = torch_mlu_ops.moe_gen_idx
        token_num, expert_num, topk = 128, 32, 16
        expert_id = torch.randint(low=0, high=expert_num, size=(token_num, topk, 1)).to(torch.int32).to('mlu')
        self.assertException("expert_id dim must be equal to 2.", func, expert_id, expert_num)
        expert_id = expert_id.reshape(token_num, topk)
        self.assertException("data type of expert_id must be int32.", func, expert_id.to(torch.int64), expert_num)
if __name__ == "__main__":
    # Fix both RNGs for reproducible random sweeps, then run the suite.
    random.seed(0)
    torch.manual_seed(0)
    exit(run_unittest(TestGenIdx))

View File

@@ -0,0 +1,200 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from common_utils import *
import random
from itertools import product
import time
import numpy as np
import os
class TestSoftmaxTopkOp(BtTestCase):
    """Tests for moe_softmax_topk: softmax + (optionally group-limited) top-k."""

    def op_impl_base(self, *args):
        """Reference softmax-topk.

        When num_expert_group <= 1 this is plain (optionally masked) top-k over
        the softmax; otherwise experts are partitioned into groups, only the
        topk_group groups with the largest per-group max survive, and top-k is
        taken over the surviving experts.  `normed_by` selects the weight
        normalization denominator.
        """
        input, topk, normalize, num_expert_group, topk_group, origin_mask, normed_by = args
        softmax = torch.softmax(input.float(), dim=-1)
        if num_expert_group <= 1:
            if origin_mask is not None:
                # Element-wise mask applied before top-k selection.
                softmax = softmax * origin_mask
            reduce_weight, expert_id = torch.topk(softmax, k=topk, dim=-1)
            if normalize:
                if normed_by == "topk_logit":
                    reduce_weight = reduce_weight / reduce_weight.sum(dim=-1, keepdim=True)
                if normed_by == "softmax_logit":
                    reduce_weight = reduce_weight / softmax.sum(dim=-1, keepdim=True)
            return reduce_weight, expert_id
        else:
            # Group-limited top-k: rank groups by their max probability.
            group_size = softmax.shape[-1] // num_expert_group
            new_shape = softmax.shape[:-1] + (num_expert_group, group_size)
            group_data = softmax.view(new_shape)
            group_max_value = group_data.max(dim=-1).values
            group_idx = torch.topk(group_max_value, k=topk_group, dim=-1)[1]
            # Build a boolean keep-mask over groups, broadcast to experts.
            mask_shape = softmax.shape[:-1] + (num_expert_group,)
            mask = torch.zeros((mask_shape), dtype = torch.bool, device = group_idx.device)
            mask.scatter_(-1, group_idx, True)
            mask = mask.unsqueeze(-1).expand(new_shape)
            masked_data = group_data.masked_fill(~mask, 0.0)
            masked_data = masked_data.reshape(softmax.shape)
            reduce_weight, expert_id = torch.topk(masked_data, k=topk, dim=-1)
            if normalize:
                if normed_by == "topk_logit":
                    reduce_weight = reduce_weight / reduce_weight.sum(dim=-1, keepdim=True)
                if normed_by == "softmax_logit":
                    reduce_weight = reduce_weight / softmax.sum(dim=-1, keepdim=True)
            return reduce_weight, expert_id

    # Interface test: one fixed shape through the public Python wrapper.
    def test_interface(self):
        num_token, num_expert, topk, num_expert_group, topk_group, normalize = 1024, 160, 6, 10, 5, False
        input = torch.randn(num_token, num_expert, dtype=torch.half, device='mlu')
        mask = None
        normed_by = "topk_logit"
        base_reduce_weight, base_expert_id = self.op_impl_base(input, topk, normalize, num_expert_group, topk_group, mask, normed_by)
        reduce_weight, expert_id = torch_mlu_ops.moe_softmax_topk(input, topk, normalize, num_expert_group, topk_group)
        # Sort ids before comparing; ties may be ordered differently.
        base_expert_id, _ = base_expert_id.sort()
        expert_id, _ = expert_id.sort()
        self.assertTensorsEqual(base_reduce_weight.cpu().float(), reduce_weight.cpu().float(),
                                0.003, use_MSE=True, use_RAE=True)
        self.assertTensorsEqual(base_expert_id.cpu().float(), expert_id.cpu().float(),
                                0.003, use_MSE=True, use_RAE=True)

    # Randomized sweep test over shapes, grouping, masking and dtypes.
    def test_kernel_random(self):
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for i in range(1000):
            num_batch = random.randint(1, 64)
            num_mask = random.randint(1, 1024)
            num_expert = random.randint(1, 512)
            # num_expert_group must divide num_expert; pick from its divisors.
            factors = [i for i in range(1, num_expert + 1) if num_expert % i == 0]
            num_expert_group = random.choice(factors)
            topk_group = random.randint(1, num_expert_group)
            topk = random.randint(1, num_expert / num_expert_group * topk_group)
            normalize = random.sample([True, False], 1)[0]
            dtype = random.sample(dtype_list, 1)[0]
            # Randomly exercise the ungrouped path by disabling grouping.
            group_invalid = random.sample([True, False], 1)[0]
            if group_invalid:
                num_expert_group = -1
            normed_by = random.choice(["topk_logit", "softmax_logit"])
            print("===============================================================================")
            print(f"[{i}]: num_batch: {num_batch}, num_mask: {num_mask}, num_expert: {num_expert}, num_expert_group: {num_expert_group}")
            print(f"      topk_group: {topk_group}, topk: {topk}, normalize: {normalize}, dtype: {dtype}, normed_by: {normed_by}")
            input = torch.randn(num_batch, num_mask, num_expert, dtype=dtype).mlu()
            mask = torch.randint(0, 2, (1, num_mask, num_expert), dtype = dtype).mlu()
            # Mask is not supported together with expert grouping.
            if num_expert_group > 1:
                mask = None
            base_reduce_weight, base_expert_id = self.op_impl_base(input, topk, normalize, num_expert_group, topk_group, mask, normed_by)
            reduce_weight, expert_id = torch.ops.torch_mlu_ops.moe_softmax_topk(input,
                                                                                topk,
                                                                                num_expert_group,
                                                                                topk_group,
                                                                                normalize,
                                                                                mask,
                                                                                normed_by)
            # Near-equal softmax values can make top-k order differ, e.g.:
            # base_reduce_weight[N, 10:12] = [0.012, 0.011]
            # reduce_weight[N, 10:12] = [0.011, 0.012]
            # base_expert_id[N, 10:12] = [7, 8]
            # expert_id[N, 10:12] = [8, 7]
            # Such ordering differences are not errors, but they would make a
            # direct comparison fail, so sort both results before comparing.
            base_expert_id, _ = base_expert_id.sort()
            expert_id, _ = expert_id.sort()
            self.assertTensorsEqual(base_reduce_weight.cpu().float(), reduce_weight.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(base_expert_id.cpu().float(), expert_id.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)

    # Guard tests: each call must fail with the binding's exact check message.
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        func = torch.ops.torch_mlu_ops.moe_softmax_topk
        num_batch, num_mask, num_expert, topk, num_expert_group, topk_group, normalize, normed_by = 2, 1024, 160, 0, 1, 1, True, "abc_logit"
        input = torch.randn(num_batch, num_mask, num_expert, dtype = torch.half, device='mlu')
        input_permute = input.permute(0, 2, 1)
        mask = torch.randint(0, 2, (num_mask, num_expert), dtype = torch.half, device='mlu')
        mask_permute = mask.permute(1, 0)
        self.assertException("input must be contiguous.",
                             func, input_permute, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        self.assertException("mask must be contiguous.",
                             func, input, topk, num_expert_group, topk_group, normalize, mask_permute, normed_by)
        self.assertException("normed_by must be 'topk_logit' or 'softmax_logit'",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        input = torch.randn(num_expert, dtype = torch.half, device = 'mlu')
        normed_by = "softmax_logit"
        self.assertException("input.dim() >= 2",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        input = torch.randn(num_batch, num_mask, num_expert, dtype = torch.half, device='mlu')
        self.assertException("topk > 0 && topk <= num_expert",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        topk = 5
        self.assertException("the dim of mask should be the same as input",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        mask = torch.randint(0, 2, (1, num_mask, num_expert - 1), dtype = torch.half, device='mlu')
        self.assertException("the last dim of mask should be the same as the last dim of input",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        mask = torch.randint(0, 2, (1, num_mask - 1, num_expert), dtype = torch.half, device='mlu')
        self.assertException("the penultimate dim of mask should be the same as the penultimate dim of input",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        mask = torch.randint(0, 2, (2, num_mask, num_expert), dtype = torch.half, device='mlu')
        self.assertException("the product of all but the lower two dimensions of mask is 1",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        mask = torch.randint(0, 2, (1, num_mask, num_expert), dtype = torch.float, device='mlu')
        self.assertException("the dtype of mask should be the same as input",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        mask = torch.randint(0, 2, (1, num_mask, num_expert), dtype = torch.half, device='mlu')
        num_expert_group = 8
        self.assertException("if num_expert_group > 1, mask should be None",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        mask = None
        topk = 160
        self.assertException("topk <= (num_expert / num_expert_group) * topk_group",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        num_expert_group = 11
        self.assertException("num_expert % num_expert_group == 0",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        num_expert_group = 8
        topk_group = 9
        self.assertException("topk_group > 0 && topk_group <= num_expert_group",
                             func, input, topk, num_expert_group, topk_group, normalize, mask, normed_by)

    # Single fixed case (4-D input, broadcast mask, ungrouped).
    def test_single(self):
        num_batch, num_mask, num_expert, topk, num_expert_group, topk_group, normalize, normed_by = 32, 16, 128, 34, 1, 1, True, "softmax_logit"
        input = torch.randn(2, num_batch, num_mask, num_expert, dtype = torch.float32, device='mlu')
        mask = torch.randint(0, 2, (1, 1, num_mask, num_expert), dtype = torch.float32, device='mlu')
        # mask = None
        base_reduce_weight, base_expert_id = self.op_impl_base(input, topk, normalize, num_expert_group, topk_group, mask, normed_by)
        reduce_weight, expert_id = torch.ops.torch_mlu_ops.moe_softmax_topk(input,
                                                                            topk,
                                                                            num_expert_group,
                                                                            topk_group,
                                                                            normalize,
                                                                            mask,
                                                                            normed_by)
        # Sort ids before comparing; ties may be ordered differently.
        base_expert_id, _ = base_expert_id.sort()
        expert_id, _ = expert_id.sort()
        self.assertTensorsEqual(base_reduce_weight.cpu().float(), reduce_weight.cpu().float(),
                                0.003, use_MSE=True, use_RAE=True)
        self.assertTensorsEqual(base_expert_id.cpu().float(), expert_id.cpu().float(),
                                0.003, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        """Run the opcheck harness on moe_softmax_topk with one grouped case."""
        num_batch, num_mask, num_expert, topk, num_expert_group, topk_group, normalize, normed_by = 32, 16, 128, 34, 4, 3, True, "softmax_logit"
        input = torch.randn(num_batch, num_mask, num_expert, dtype=torch.half, device='mlu')
        mask = torch.randint(0, 2, (1, num_mask, num_expert), dtype = torch.half, device='mlu')
        normed_by = "softmax_logit"
        # Mask is not supported together with expert grouping.
        if num_expert_group > 1:
            mask = None
        args = (input, topk, num_expert_group, topk_group, normalize, mask, normed_by)
        self.base_opcheck(torch.ops.torch_mlu_ops.moe_softmax_topk, args)
if __name__ == '__main__':
    # Fix both RNGs for reproducible random sweeps, then run the suite.
    random.seed(0)
    torch.manual_seed(0)
    exit(run_unittest(TestSoftmaxTopkOp))

View File

@@ -0,0 +1,194 @@
import torch
import unittest
import torch_mlu_ops as ops
import random
from common_utils import *
def gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype):
    """Adapt generate_cache_func_args output to the offline-quant cache op's
    argument order: splice the two quant scales in after the cache tensors and
    insert a literal False flag before the trailing offsets."""
    raw = generate_cache_func_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype, True, True)
    key_scale, value_scale = raw[10][0], raw[10][1]
    args = raw[0:10]
    args.insert(4, key_scale)
    args.insert(5, value_scale)
    args.insert(8, False)
    return args
class TestOfflineQuantToLinearCache(BtTestCase):
def op_impl_base(self, *args):
    """Reference implementation: offline-quantize key/value into int8 linear caches.

    args: (key, value, key_cache, value_cache, key_cache_quant_scale,
    value_cache_quant_scale, context_lengths, max_context_len, quant_mode,
    packed, context_seq_offset, cache_bs_id, cache_seqlen_offset).
    Writes the quantized slices into key_cache/value_cache in place and
    returns the caches plus their scales (value outputs only when present).
    """
    def quant2Int8(input_fp: torch.Tensor,
                   scale_fp: torch.tensor,
                   quant_mode: torch.int64):
        # Divide by the scale, round, and saturate to the int8 range.
        head_num = input_fp.size(0)
        seq = input_fp.size(1)
        head_size = input_fp.size(2)
        input_fp32 = input_fp.to(torch.float32)
        if quant_mode == 0: # per_channel
            scale = scale_fp.reshape((head_num, 1, head_size))
        else:
            # per-token scales: one scale per (head, seq) position.
            scale = scale_fp.reshape((head_num, seq, 1))
        scaled_context = input_fp32 / scale
        rounded = torch.round(scaled_context)
        clipped = torch.clip(rounded, -128, 127)
        return clipped.to(torch.int8)
    key, value, key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale, \
        context_lengths, max_context_len, quant_mode, packed, context_seq_offset, cache_bs_id, \
        cache_seqlen_offset = args
    # Packed layout stores cumulative lengths (batch_size + 1 entries).
    batch_size = context_lengths.size(0) - 1 if packed else context_lengths.size(0)
    for i in range(batch_size):
        if packed:
            # Slice this batch's tokens out of the packed [tokens, heads, hs] layout.
            key_i = key[context_lengths[i]:context_lengths[i+1]].transpose(1, 0)
            value_i = value[context_lengths[i]:context_lengths[i+1]].transpose(1, 0) if value is not None else None
            context_len_i = context_lengths[i+1] - context_lengths[i]
            context_seq_offset_i = 0
        else:
            key_i = key[i].transpose(1, 0)
            value_i = value[i].transpose(1, 0) if value is not None else None
            context_len_i = context_lengths[i]
            context_seq_offset_i = context_seq_offset[i] if context_seq_offset is not None else 0
        cache_bs_id_i = cache_bs_id[i] if cache_bs_id is not None else i
        cache_seq_begin = cache_seqlen_offset[i] if cache_seqlen_offset is not None else 0
        # Negative ids/offsets mark batches that must be skipped entirely.
        if cache_bs_id_i < 0 or cache_seq_begin < 0:
            continue
        cache_seq_end = cache_seq_begin + context_len_i
        # quant key to int8
        key_i = key_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i]
        if quant_mode == 0:
            key_cache_scale_i = key_cache_quant_scale
        else:
            key_cache_scale_i = key_cache_quant_scale[:, cache_seq_begin:cache_seq_end]
        quant_key_i = quant2Int8(key_i, key_cache_scale_i, quant_mode)
        key_cache_i = key_cache[cache_bs_id_i, :, cache_seq_begin:cache_seq_end]
        key_cache_i[...] = quant_key_i
        # quant value to int8
        if value_cache is not None and value is not None and value_cache_quant_scale is not None:
            value_i = value_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i]
            if quant_mode == 0:
                value_cache_scale_i = value_cache_quant_scale
            else:
                value_cache_scale_i = value_cache_quant_scale[:, cache_seq_begin:cache_seq_end]
            quant_value_i = quant2Int8(value_i, value_cache_scale_i, quant_mode)
            value_cache_i = value_cache[cache_bs_id_i, :, cache_seq_begin:cache_seq_end]
            value_cache_i[...] = quant_value_i
    return (key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale) if value_cache is not None else (key_cache, key_cache_quant_scale)
    def test_offline_quant_to_linear_cache(self):
        """Randomized correctness sweep: run the MLU op and the python
        reference (op_impl_base) on identical inputs and require the int8
        caches to match within 1 LSB (rounding-mode differences).

        Sweeps batch size, head count/size, cache length, packed/unpacked
        layout, quant mode (0 = per-channel scale, else per-token) and dtype.
        """
        test_cases = 100
        # One random configuration entry per test case.
        bs_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32)
        num_heads_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32)
        head_size_list = torch.randint(low=1, high=16, size=(test_cases, ), dtype=torch.int32)
        head_size_list *= 16  # head_size is always a multiple of 16
        cache_memory_len_list = torch.randint(low=2, high=1024, size=(test_cases, ), dtype=torch.int32)
        packed_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
        mode_list= torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
        dtype_list = torch.randint(low=0, high=3, size=(test_cases, ), dtype=torch.int32)
        dtype_map = [torch.half, torch.bfloat16, torch.float32]
        for i in range(test_cases):
            q_heads = 1
            batch_size = bs_list[i].item()
            invalid_batch = batch_size // 10  # how many batches get masked out below
            num_heads = num_heads_list[i].item()
            head_size = head_size_list[i].item()
            cache_memory_len = cache_memory_len_list[i].item()
            packed = packed_list[i].item()
            quant_mode = mode_list[i].item()
            total_heads = q_heads + num_heads * 2  # q + k + v heads stored together
            dtype = dtype_map[dtype_list[i]]
            mlu_name = torch.mlu.get_device_name()
            if "MLU3" in mlu_name and dtype == torch.bfloat16:
                # bf16 unsupported on MLU3xx; fall back to half
                dtype = torch.half
                print("BFLOAT16 is not support on {}, use half instead".format(mlu_name))
            print("batch_size={}, num_heads={}, head_size={}, cache_memory_len={}, packed={}, mode ={}, dtype={} testing...".format(
                batch_size, num_heads, head_size, cache_memory_len, packed >
                0, quant_mode, dtype))
            torch.manual_seed(1)
            max_bs = batch_size + 1
            context_lens = torch.randint(size=(batch_size, ), low=1,
                                         high=cache_memory_len // 2,
                                         dtype=torch.int32, device='mlu')
            max_context_len = context_lens.max().item()
            max_seq_offset = max_context_len // 3 + 1
            context_seq_offsets = torch.randint(size=(batch_size, ),
                                                low=0, high=max_seq_offset,
                                                dtype=torch.int32, device='mlu')
            cache_seq_offsets = torch.randint(size=(batch_size, ), low=0,
                                              high=(cache_memory_len - max_context_len) // 3 + 1,
                                              dtype=torch.int32, device='mlu')
            # Mark a few batches invalid via negative offsets; the op must skip them.
            cache_seq_offsets[random.sample([*range(0, batch_size)], invalid_batch)] = -1
            cu_context_lens = torch.cumsum(context_lens, dim=-1)
            cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
            total_seqlen = cu_context_lens[-1]
            if packed > 0:
                # Packed layout: all sequences concatenated along dim 0.
                context = torch.randn((total_seqlen, total_heads, head_size),
                                      dtype=torch.float, device='mlu')
            else:
                context = torch.randn((batch_size, max_context_len + max_seq_offset, total_heads, head_size),
                                      dtype=torch.float, device='mlu')
            context = context.to(dtype)
            key = context[..., q_heads:q_heads + num_heads, :]
            value = context[..., q_heads + num_heads:q_heads + 2 * num_heads, :]
            # prepare key_scale and value_scale
            if quant_mode == 0 : # per_channel
                cache_scale = torch.randn((2, num_heads, head_size), dtype=torch.float, device='mlu')
            else:
                cache_scale = torch.randn((2, num_heads, cache_memory_len), dtype=torch.float, device='mlu')
            key_cache_scale = cache_scale[0]
            value_cache_scale = cache_scale[1]
            # Pre-fill the int8 caches with noise so untouched regions are checked too.
            cache = torch.randn((2, max_bs, num_heads, cache_memory_len, head_size), dtype=torch.float, device='mlu')
            cache = (cache - 0.5) * 256
            cache = cache.to(torch.int8)
            key_cache = cache[0]
            value_cache = cache[1]
            ref_cache = cache.clone()
            ref_key_cache = ref_cache[0]
            ref_value_cache = ref_cache[1]
            # Random batch -> cache-slot mapping, with some entries marked invalid.
            cache_bs_id = random.sample([*range(0, max_bs)], batch_size)
            cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
            cache_bs_id[random.sample([*range(0, batch_size)], invalid_batch)] = -1 # set invalid batch
            if packed > 0:
                # Packed path takes cumulative lengths and no per-batch seq offsets.
                ops.offline_quant_to_linear_cache(key, value, key_cache, value_cache, key_cache_scale,
                                                  value_cache_scale, cu_context_lens, max_context_len,
                                                  quant_mode, packed > 0, None, cache_bs_id,
                                                  cache_seq_offsets)
                self.op_impl_base(key, value, ref_key_cache, ref_value_cache, key_cache_scale,
                                  value_cache_scale, cu_context_lens, max_context_len,
                                  quant_mode, packed > 0, None, cache_bs_id,
                                  cache_seq_offsets)
            else:
                ops.offline_quant_to_linear_cache(key, value, key_cache, value_cache, key_cache_scale,
                                                  value_cache_scale, context_lens, max_context_len,
                                                  quant_mode, packed > 0, context_seq_offsets, cache_bs_id,
                                                  cache_seq_offsets)
                self.op_impl_base(key, value, ref_key_cache, ref_value_cache, key_cache_scale,
                                  value_cache_scale, context_lens, max_context_len,
                                  quant_mode, packed > 0, context_seq_offsets, cache_bs_id,
                                  cache_seq_offsets)
            # for debug
            cache = cache.cpu().flatten()
            ref_cache = ref_cache.cpu().flatten()
            diff = cache - ref_cache
            diff = diff.abs()
            assert torch.max(diff) < 2, "ref_cache must equal cache or absolute values differ by 1 due to round_mode!"
def test_inductor(self):
batch_size, num_heads, head_size, cache_memory_len, dtype = 4, 8, 64, 128, torch.float16
args = gen_args(batch_size, num_heads, head_size, cache_memory_len, 1, dtype)
self.base_opcheck(torch.ops.torch_mlu_ops.offline_quant_to_linear_cache, args)
args = gen_args(batch_size, num_heads, head_size, cache_memory_len, 0, dtype)
self.base_opcheck(torch.ops.torch_mlu_ops.offline_quant_to_linear_cache, args)
if __name__ == '__main__':
    # Propagate the unittest result as the process exit status.
    raise SystemExit(run_unittest(TestOfflineQuantToLinearCache))

View File

@@ -0,0 +1,209 @@
import torch
import unittest
import torch_mlu_ops as ops
import random
from common_utils import *
import numpy as np
def quant2int8(input_fp: torch.Tensor,
               scale_fp: torch.Tensor):
    """Symmetric int8 quantization: round(x / scale), saturated to [-128, 127]."""
    # Compute in fp32 so low-precision inputs quantize consistently.
    scaled = input_fp.to(torch.float32) / scale_fp
    return torch.clamp(torch.round(scaled), -128, 127).to(torch.int8)
def gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype):
    """Assemble the argument list for the offline_quant_to_paged_cache op
    from the shared cache-test argument generator."""
    raw = generate_cache_func_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype, True, True)
    k_scale, v_scale = raw[10][0], raw[10][1]
    slot_mapping = raw[11]
    # Splice the scales and slot mapping between the first two and next two args.
    return raw[0:2] + [k_scale, v_scale, slot_mapping] + raw[2:4]
class TestOfflineQuantToPagedCache(BtTestCase):
    """Tests for ops.offline_quant_to_paged_cache: quantize K/V tokens to
    int8 and scatter them into a paged (block) cache via slot_mapping."""
    def op_impl_base(self, *args):
        """Python reference: quantize each token with the given scales and
        write it into its (block, offset) slot; negative slots are skipped.
        Returns the mutated cache(s)."""
        k, v, k_cache_scale, v_cache_scale, slot_mapping, k_cache, v_cache = args
        tokens_num = k.shape[0]
        block_size = k_cache.shape[2]
        for i in range(tokens_num):
            if slot_mapping[i] >= 0:  # negative slot -> masked-out token
                key_i = k[i]
                # slot id -> (block index, offset inside the block)
                block_id = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
                block_offset = slot_mapping[i] % block_size
                key_cache_i = k_cache[block_id, :, block_offset, :]
                quant_key_i = quant2int8(key_i, k_cache_scale)
                key_cache_i[...] = quant_key_i
                if v is not None:
                    value_i = v[i]
                    value_cache_i = v_cache[block_id, :, block_offset, :]
                    quant_value_i = quant2int8(value_i, v_cache_scale)
                    value_cache_i[...] = quant_value_i
        return (k_cache, v_cache) if v is not None else k_cache
    @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "offline_quant_to_paged_cache not support MLU3XX device")
    def test_offline_quant_to_paged_cache(self):
        """Randomized comparison of the MLU op against op_impl_base; int8
        results may differ by at most 1 LSB."""
        test_cases = 10
        token_list = torch.randint(low=1, high=512, size=(test_cases, ), dtype=torch.int32)
        head_num_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32)
        head_size_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32)
        block_size_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32)
        dtype_list = torch.randint(low=0, high=3, size=(test_cases, ), dtype=torch.int32)
        dtype_map = [torch.bfloat16, torch.half, torch.float32]
        only_quant_key_list = [True, False]
        for i in range(test_cases):
            tokens_num = token_list[i].item()
            head_num = head_num_list[i].item()
            head_size = head_size_list[i].item()
            block_size = block_size_list[i].item()
            dtype = dtype_map[dtype_list[i]]
            print("tokens_num={}, head_num={}, head_size={}, block_size={}, dtype={} testing...".format(
                tokens_num, head_num, head_size, block_size, dtype))
            np.random.seed(1)
            # Randomly exercise the key-only path (value args all None).
            only_quant_key = random.choice(only_quant_key_list)
            key_data = np.random.uniform(-1, 1, size=[tokens_num, head_num, head_size])
            key_cache_scale_data = np.random.uniform(-10, 10, size=[head_num, head_size])
            key = torch.tensor(key_data, dtype=dtype, device="mlu")
            key_cache_scale = torch.tensor(key_cache_scale_data, dtype=torch.float32, device="mlu")
            value_data = np.random.uniform(-0.25, 0.25, size=[tokens_num, head_num, head_size])
            value_cache_scale_data = np.random.uniform(-10, 10, size=[head_num, head_size])
            value = torch.tensor(value_data, dtype=dtype, device="mlu")
            value_cache_scale = torch.tensor(value_cache_scale_data, dtype=torch.float32, device="mlu")
            # Allocate more blocks than strictly needed so some slots stay unused.
            min_blocks = (int)((tokens_num + block_size - 1) / block_size)
            blocks_num = min(min_blocks + 10, 2 * min_blocks)
            num_slots = blocks_num * block_size
            slot_mapping = random.sample(range(num_slots), tokens_num)
            slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, device="mlu")
            slot_mapping[-1] = -1 # test mask
            torch.manual_seed(0)
            # Pre-fill caches with noise so untouched slots are verified too.
            key_cache = torch.randint(-128, 127, (blocks_num, head_num, block_size, head_size), dtype=torch.int8).mlu()
            value_cache = torch.randint(-128, 127, (blocks_num, head_num, block_size, head_size), dtype=torch.int8).mlu()
            #python base result
            key_cache_base = key_cache.clone()
            value_cache_base = value_cache.clone()
            if only_quant_key:
                value, value_cache_scale, value_cache_base, value_cache = None, None, None, None
            self.op_impl_base(key, value, key_cache_scale, value_cache_scale,
                              slot_mapping, key_cache_base, value_cache_base)
            #mlu result
            ops.offline_quant_to_paged_cache(key, value, key_cache_scale, value_cache_scale,
                                             slot_mapping, key_cache, value_cache)
            #compute diff
            baseline_key_cache = key_cache_base.cpu().flatten()
            mlu_key_cache = key_cache.cpu().flatten()
            key_cache_diff = (mlu_key_cache - baseline_key_cache).abs()
            assert torch.max(key_cache_diff) < 2, "key_cache_diff exceed threshold"
            if not only_quant_key:
                baseline_value_cache = value_cache_base.cpu().flatten()
                mlu_value_cache = value_cache.cpu().flatten()
                value_cache_diff = (mlu_value_cache - baseline_value_cache).abs()
                assert torch.max(value_cache_diff) < 2, "value_cache_diff exceed threshold"
    @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "offline_quant_to_paged_cache not support MLU3XX device")
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_large_tensor(self):
        """Exercise a cache sized just under the 4GB addressing limit, then
        verify that exceeding the limit raises the expected error."""
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for dtype in dtype_list:
            print("offline_quant_to_paged_cache: test_large_tensor...")
            head_num = 16
            head_size = 128
            token_nums = 200
            block_size = 16
            # Largest block count whose int8 cache still fits in 2^32 - 1 bytes.
            block_nums = ((2**32 - 1) // 1 // head_num // head_size // block_size)
            num_slots = blocks_num = block_nums * block_size
            slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_nums]
            key = torch.randn(token_nums, head_num, head_size, dtype=dtype, device="mlu")
            value = torch.randn(token_nums, head_num, head_size, dtype=dtype, device="mlu")
            key_cache_scale = torch.randn(head_num, head_size, dtype=torch.float32, device="mlu")
            value_cache_scale = torch.randn(head_num, head_size, dtype=torch.float32, device="mlu")
            key_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu")
            value_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu")
            #python base result
            key_cache_base = key_cache.clone()
            value_cache_base = value_cache.clone()
            self.op_impl_base(key, value, key_cache_scale, value_cache_scale,
                              slot_mapping, key_cache_base, value_cache_base)
            #mlu result
            ops.offline_quant_to_paged_cache(key, value, key_cache_scale, value_cache_scale,
                                             slot_mapping, key_cache, value_cache)
            #compute diff
            baseline_key_cache = key_cache_base.cpu().flatten()
            mlu_key_cache = key_cache.cpu().flatten()
            key_cache_diff = (mlu_key_cache - baseline_key_cache).abs()
            assert torch.max(key_cache_diff) < 2, "key_cache_diff exceed threshold"
            baseline_value_cache = value_cache_base.cpu().flatten()
            mlu_value_cache = value_cache.cpu().flatten()
            value_cache_diff = (mlu_value_cache - baseline_value_cache).abs()
            assert torch.max(value_cache_diff) < 2, "value_cache_diff exceed threshold"
            # A cache of exactly 2^32 bytes (one step over the limit) must be rejected.
            block_nums = (2**32 // 1 // head_num // head_size // block_size)
            num_slots = block_nums * block_size
            slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_nums]
            key_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu")
            value_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu")
            self.assertException("The addressing range of kv_cache cannot exceed 4G.", ops.offline_quant_to_paged_cache,
                                 key, value, key_cache_scale, value_cache_scale, slot_mapping, key_cache, value_cache)
    @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "offline_quant_to_paged_cache not support MLU3XX device")
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        """Verify the op rejects malformed value/value_cache arguments with
        the expected error messages."""
        print("offline_quant_to_paged_cache: test prevent...")
        head_num = 16
        head_size = 128
        token_nums = 200
        block_size = 16
        block_nums = 20
        num_slots = block_nums * block_size
        dtype = torch.half
        slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_nums]
        key = torch.randn(token_nums, head_num, head_size, dtype=dtype, device="mlu")
        value = torch.randn(token_nums, head_num, head_size, dtype=dtype, device="mlu")
        key_cache_scale = torch.randn(head_num, head_size, dtype=torch.float32, device="mlu")
        value_cache_scale = torch.randn(head_num, head_size, dtype=torch.float32, device="mlu")
        key_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu")
        value_cache = None
        # value given without value_cache must be rejected
        self.assertException("v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value().",
                             ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale,
                             slot_mapping, key_cache, value_cache)
        value_cache = torch.zeros(block_nums, head_num, block_size, head_size, dtype=torch.int8, device="mlu")
        # wrong rank for v
        value = value.reshape(token_nums, head_num, head_size, 1)
        self.assertException("dim of v must be 3",
                             ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale,
                             slot_mapping, key_cache, value_cache)
        value = value.squeeze()
        # wrong rank for v_cache
        value_cache = value_cache.reshape(block_nums, head_num, block_size, head_size, 1)
        self.assertException("dim of v_cache must be 4",
                             ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale,
                             slot_mapping, key_cache, value_cache)
        value_cache = value_cache.squeeze()
        # wrong rank for v_cache_scale
        value_cache_scale = value_cache_scale.reshape(head_num, head_size, 1)
        self.assertException("dim of v_cache_scale must be 2",
                             ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale,
                             slot_mapping, key_cache, value_cache)
        value_cache_scale = value_cache_scale.squeeze()
        # non-contiguous strides must be rejected
        value = value.as_strided(size=(token_nums, head_num, head_size), stride=(head_num, 1, head_size))
        self.assertException("v last dim must be contiguous.",
                             ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale,
                             slot_mapping, key_cache, value_cache)
        value = value.as_strided(size=(token_nums, head_num, head_size), stride=(head_num, token_nums, 1))
        self.assertException("v second dim must be contiguous.",
                             ops.offline_quant_to_paged_cache, key, value, key_cache_scale, value_cache_scale,
                             slot_mapping, key_cache, value_cache)
    @unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "offline_quant_to_paged_cache not support MLU3XX device")
    def test_inductor(self):
        # opcheck the registered custom op for torch.compile/inductor support
        batch_size, num_heads, head_size, cache_memory_len, dtype = 4, 8, 64, 128, torch.float16
        args = gen_args(batch_size, num_heads, head_size, cache_memory_len, 1, dtype)
        self.base_opcheck(torch.ops.torch_mlu_ops.offline_quant_to_paged_cache, args)
if __name__ == '__main__':
    # Propagate the unittest result as the process exit status.
    raise SystemExit(run_unittest(TestOfflineQuantToPagedCache))

View File

@@ -0,0 +1,72 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as tmo
from common_utils import *
import random
class TestPerTokenSmoothQuantizeOp(BtTestCase):
    """Tests for tmo.per_token_smooth_quantize: multiply by a smooth scale
    (optionally per expert group) then row-quantize to int8."""
    def op_impl_base(self, *args):
        """Torch reference: apply the smooth vector (per-expert rows when
        m_list is given) and quantize each row to int8 via QuantByRow."""
        x, smooth, zero, m_list = args
        x_shape = x.size()
        scale_shape = x.size()[0:-1]
        if m_list is None:
            smoothed = x * smooth
        else:
            # Split rows into per-expert groups; each group uses its own smooth row.
            input_list = x.split(tuple(m_list))
            experts = len(input_list)
            result = []
            for i in range(experts):
                result.append(input_list[i] * smooth[i])
            smoothed = torch.concat(result, dim=0)
        output, scale = QuantByRow(smoothed, 8)
        return output.reshape(x_shape), scale.reshape(scale_shape)
    def test_random_case(self):
        """200 unique random (experts, ci, co, dtype) configurations compared
        against the torch reference."""
        torch.manual_seed(0)
        case_list = set()
        while(len(case_list) < 200):
            dtype_list = [torch.half, torch.float]
            if torch_mlu.mlu.is_bf16_supported():
                dtype_list.append(torch.bfloat16)
            dtype = random.choice(dtype_list)
            has_group = random.choice([False, True])
            if has_group:
                experts = random.randint(1, 40)
                m_list = torch.randint(1, 100, (experts,), device="mlu", dtype=torch.int32)
                ci = m_list.sum().item()
            else:
                experts = None
                ci = random.randint(1, 4096)
            co = random.randint(1, 4096)
            case = (experts, ci, co, dtype)
            # Skip configurations that were already tested.
            if case in case_list:
                continue
            else:
                case_list.add(case)
            x = torch.randn(ci, co, device="mlu", dtype=dtype)
            if has_group:
                scale = torch.randn(experts, co, device="mlu", dtype=torch.float32)
            else:
                scale = torch.randn(co, device="mlu", dtype=torch.float32)
            print("experts={}, ci={}, co={}, dtype={}, testing...".format(experts, ci, co, dtype), flush=True)
            param = (x, scale, None, m_list if has_group else None)
            tmo_output, tmo_scale = tmo.per_token_smooth_quantize(*param)
            torch_output, torch_scale = self.op_impl_base(*param)
            self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), 0.01, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(torch_scale.cpu().float(), tmo_scale.cpu().float(), 0.01, use_MSE=True, use_RAE=True)
    def test_inductor(self):
        # opcheck the underlying smooth_quant custom-op registration
        m_list = torch.randint(1, 100, (8,), device="mlu", dtype=torch.int32)
        total_m = m_list.sum().item()
        x = torch.randn(total_m, 1024, device="mlu", dtype=torch.half)
        scale = torch.randn(8, 1024, device="mlu", dtype=torch.float32)
        output = torch.empty(x.size(), dtype=torch.int8, device="mlu")
        output_scale = torch.empty(x.size()[:-1], dtype=torch.float32, device="mlu")
        args = (x, scale, output, output_scale, None, m_list, None, None, 'per_token', True)
        self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args)
if __name__ == '__main__':
    # Propagate the unittest result as the process exit status.
    raise SystemExit(run_unittest(TestPerTokenSmoothQuantizeOp))

View File

@@ -0,0 +1,21 @@
import torch
import unittest
import torch_mlu_ops as ops
from common_utils import *
class TestPreloadOp(BtTestCase):
    """Tests for ops.preload: prefetch a weight tensor (no computed output)."""
    def op_impl_base(self, *args):
        """Reference impl: preload has no observable result to compare, so
        defer to the base class implementation.

        args: (weight, size_in_bytes)
        """
        weight, size = args  # fix: local was misspelled 'wegiht'
        return super().op_impl_base(*args)
    def test_preload(self):
        """Smoke test: preload the whole weight tensor and synchronize."""
        weight = torch.randn((1024, 8, 5, 1024)).half().mlu()
        ops.preload(weight, weight.element_size() * weight.numel())
        torch.mlu.synchronize()
    def test_inductor(self):
        # opcheck the registered custom op for torch.compile support
        weight = torch.randn((1024, 8, 5, 1024)).half().mlu()
        self.base_opcheck(torch.ops.torch_mlu_ops.preload, (weight, weight.element_size() * weight.numel()))
if __name__ == '__main__':
    # Propagate the unittest result as the process exit status.
    raise SystemExit(run_unittest(TestPreloadOp))

View File

@@ -0,0 +1,137 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from torch.nn import functional as F
from torch.nn.parameter import Parameter
active = torch.nn.functional.silu
def torch_ffn(input, w1, scale1, bias1, w2, scale2, bias2, is_gated):
    """Float reference FFN built from weight-only-quant matmuls.

    When is_gated, w1 stacks [act, gate] halves and the activated half is
    multiplied by the gate half before the down projection.
    """
    flat = input.flatten(0, -2).to(torch.float)
    inner_size = w1.size(0) // (1 + is_gated)
    hidden = weight_only_quant_matmul(flat, w1, scale1, bias1)
    acted = active(hidden[:, :inner_size])
    if is_gated:
        acted = acted * hidden[:, inner_size:]
    out = weight_only_quant_matmul(acted, w2, scale2, bias2)
    return out.reshape(input.shape)
def torch_smooth_quant_ffn(input, w1, scale1, bias1, w2, scale2, bias2, input_smooth, act_smooth, is_gated):
    """Float reference FFN on the smooth-quant path: row-quantize the
    smoothed input/activation, then run int8 matmuls."""
    inner_size = w1.size(0) // (1 + is_gated)
    flat = input.flatten(0, -2).to(torch.float)
    q_in, in_scale = QuantByRow(flat.flatten(0, -2) * input_smooth, 8)
    hidden = smooth_quant_matmul(q_in, in_scale, w1, scale1, input.dtype, bias1)
    acted = active(hidden[:, :inner_size])
    if is_gated:
        acted = acted * hidden[:, inner_size:]
    q_act, act_scale = QuantByRow(acted.flatten(0, -2) * act_smooth, 8)
    out = smooth_quant_matmul(q_act, act_scale, w2, scale2, input.dtype, bias2)
    return out.reshape(input.shape)
def tmo_weight_only_quant_ffn(input, w1, scale1, bias1, w2, scale2, bias2, quant_bit):
    """FFN via tmo weight-only-quant matmuls; silu is fused into the first matmul."""
    flat = input.flatten(0, -2)
    hidden = ops.weight_only_quant_matmul(flat, w1, scale1, None, bias1, None, 'silu', quant_bit, True)
    result = ops.weight_only_quant_matmul(hidden, w2, scale2, None, bias2, None, "none", quant_bit)
    return result.reshape(input.shape)
def tmo_weight_only_group_quant_ffn(input, w1, scale1, bias1, w2, scale2, bias2, quant_bit):
    """FFN via tmo group-quantized matmuls with the activation applied in torch."""
    flat = input.flatten(0, -2)
    hidden = ops.weight_only_quant_matmul(flat, w1, scale1, None, bias1, None, 'none', quant_bit)
    acted = active(hidden)
    result = ops.weight_only_quant_matmul(acted, w2, scale2, None, bias2, None, "none", quant_bit)
    return result.reshape(input.shape)
def tmo_weight_only_quant_gated_ffn(input, w1, scale1, bias1, w2, scale2, bias2, quant_bit):
    """Gated FFN: quant matmul -> ops.active(silu, gated=True) -> quant matmul."""
    flat = input.flatten(0, -2)
    hidden = ops.weight_only_quant_matmul(flat, w1, scale1, None, bias1, None, 'none', quant_bit)
    acted = ops.active(hidden, 'silu', True)
    result = ops.weight_only_quant_matmul(acted, w2, scale2, None, bias2, None, "none", quant_bit)
    return result.reshape(input.shape)
def tmo_pertoken_smooth_quant_gated_ffn(input, w1, scale1, bias13, w2, scale2, bias2, input_smooth, act_smooth, dtype):
    """Gated FFN on the tmo per-token smooth-quant int8 path."""
    flat = input.flatten(0, -2)
    q_in, in_scale = ops.per_token_smooth_quantize(flat, input_smooth, None)
    hidden = ops.smooth_quant_matmul(q_in, in_scale, w1, scale1, dtype, bias13)
    acted = ops.active(hidden, 'silu', True)
    q_act, act_scale = ops.per_token_smooth_quantize(acted, act_smooth, None)
    result = ops.smooth_quant_matmul(q_act, act_scale, w2, scale2, dtype, bias2)
    return result.reshape(input.shape)
def init_tensors(batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype=torch.half, group_num=1):
    """Build random input/weight/scale/bias tensors for the FFN tests.

    Returns (input, quant_w1, scale1, quant_w2, scale2, bias1, bias2,
    input_smooth, act_smooth). Weights are row-quantized to quant_bit with
    group_num groups; when with_smooth is set, weights are pre-divided by
    the corresponding smooth vector.
    """
    sigma = 0.1
    eps = 0.1 # Avoid the occurrence of nan
    torch.manual_seed(1)
    input_smooth = torch.randn(hidden_size, dtype=torch.float, device="mlu").abs() + eps
    act_smooth = torch.randn(inner_size, dtype=torch.float, device="mlu").abs() + eps
    # w1 stacks the gate projection on top when is_gated, hence (1 + is_gated).
    w1 = torch.randn((1 + is_gated) * inner_size, hidden_size, dtype=dtype, device="mlu") * sigma
    w1 = w1 / input_smooth if with_smooth else w1
    bias1 = torch.randn((1 + is_gated) * inner_size, dtype=dtype, device="mlu") * sigma
    w2 = torch.randn(hidden_size, inner_size, dtype=dtype, device="mlu") * sigma
    w2 = w2 / act_smooth if with_smooth else w2
    bias2 = torch.randn(hidden_size, dtype=dtype, device="mlu") * sigma
    input = torch.randn(batch, seq, hidden_size, dtype=dtype, device="mlu")
    quant_w1, scale1 = QuantByRow(w1, quant_bit, group_num)
    quant_w2, scale2 = QuantByRow(w2, quant_bit, group_num)
    return input, quant_w1, scale1, quant_w2, scale2, bias1, bias2, input_smooth, act_smooth
# Dtypes exercised by every FFN test below; bf16 only where hardware supports it.
dtype_list = [torch.half]
if torch_mlu.mlu.is_bf16_supported():
    dtype_list.append(torch.bfloat16)
class TestQuantFFN(BtTestCase):
    """End-to-end FFN tests combining quantized matmuls and activations,
    comparing tmo kernels against the module-level torch reference helpers."""
    def op_impl_base(self, *args):
        # References are computed by the module-level torch_* helpers instead.
        return super().op_impl_base(*args)
    def test_weight_only_quant_ffn(self):
        """Ungated int8 weight-only FFN vs torch reference."""
        for dtype in dtype_list:
            print("test_weight_only_quant_ffn...")
            batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth = 3, 5, 512, 768, 8, False, False
            input, quant_w1, scale1, quant_w2, scale2, bias1, bias2, _, _ = init_tensors(batch, seq,
                hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype=dtype)
            torch_out = torch_ffn(input, quant_w1, scale1, bias1, quant_w2, scale2, bias2, False)
            tmo_out = tmo_weight_only_quant_ffn(input, quant_w1, scale1, bias1, quant_w2, scale2, bias2, quant_bit)
            self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
                                    0.005, use_MSE=True, use_RAE=True)
    def test_weight_only_quant_gated_ffn(self):
        """Gated int4 weight-only FFN: weights are packed two int4 per byte."""
        for dtype in dtype_list:
            print("test_weight_only_quant_gated_ffn...")
            batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth = 3, 5, 512, 768, 4, True, False
            input, quant_w1, scale1, quant_w2, scale2, bias1, bias2, _, _ = init_tensors(batch, seq,
                hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype=dtype)
            # Pack adjacent int8 values pairwise for the int4 kernel path.
            quant_w1_int4 = PairlyPackInt8(quant_w1)
            quant_w2_int4 = PairlyPackInt8(quant_w2)
            torch_out = torch_ffn(input, quant_w1, scale1, bias1, quant_w2, scale2, bias2, True)
            tmo_out = tmo_weight_only_quant_gated_ffn(input, quant_w1_int4, scale1, bias1, quant_w2_int4, scale2, bias2, quant_bit)
            self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
                                    0.005, use_MSE=True, use_RAE=True)
    def test_weight_only_group_quant_ffn(self):
        """Ungated int8 FFN with 8 quantization groups per weight row."""
        for dtype in dtype_list:
            print("test_weight_only_group_quant_ffn...")
            batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth = 3, 5, 512, 512, 8, False, False
            input, quant_w1, scale1, quant_w2, scale2, bias1, bias2, _, _ = init_tensors(batch, seq,
                hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype=dtype, group_num=8)
            torch_out = torch_ffn(input, quant_w1, scale1, bias1, quant_w2, scale2, bias2, is_gated)
            tmo_out = tmo_weight_only_group_quant_ffn(input, quant_w1, scale1.to(dtype=input.dtype), bias1,
                                                      quant_w2, scale2.to(dtype=input.dtype), bias2, quant_bit)
            self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
                                    0.05, use_MSE=True, use_RAE=True)
    def test_pertoken_smooth_quant_ffn(self):
        """Gated FFN on the per-token smooth-quant int8 path."""
        for dtype in dtype_list:
            print("test_pertoken_smooth_quant_ffn...")
            batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype = 3, 5, 512, 768, 8, True, True, dtype
            input, quant_w1, scale1, quant_w2, scale2, bias1, bias2, input_smooth, act_smooth = \
                init_tensors(batch, seq, hidden_size, inner_size, quant_bit, is_gated, with_smooth, dtype)
            torch_out = torch_smooth_quant_ffn(input, quant_w1, scale1, bias1, quant_w2, scale2, bias2,
                                               input_smooth, act_smooth, is_gated)
            tmo_out = tmo_pertoken_smooth_quant_gated_ffn(input, quant_w1, scale1, bias1, quant_w2,
                                                          scale2, bias2, input_smooth, act_smooth, dtype)
            self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
                                    0.05, use_MSE=True, use_RAE=True)
    def test_inductor(self):
        # No extra custom op to check here; defer to the base class.
        return super().test_inductor()
if __name__ == '__main__':
    # Propagate the unittest result as the process exit status.
    raise SystemExit(run_unittest(TestQuantFFN))

View File

@@ -0,0 +1,275 @@
import torch
import unittest
import torch_mlu_ops as ops
import numpy as np
import random
import math
from common_utils import *
from typing import Optional
def gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype, quant_bit):
    """Assemble the argument list for the quant_to_linear_cache op from the
    shared cache-test argument generator."""
    raw = generate_cache_func_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype, True)
    k_scale, v_scale = raw[10][0], raw[10][1]
    # Splice the scales after the first four args and append quant_bit at the end.
    return raw[0:4] + [k_scale, v_scale] + raw[4:10] + [quant_bit]
class TestQuantToLinearCache(BtTestCase):
    def op_impl_base(self, *args):
        """Python reference for quant_to_linear_cache: per-token (grouped)
        max-abs quantize key/value to int4/int8 and write values plus scales
        into the linear caches. Batches with negative cache_bs_id or
        cache_seqlen_offset are skipped; int4 data is packed two per byte.
        """
        key, value, key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale, \
            context_lengths, max_context_len, packed, context_seq_offset, cache_bs_id, cache_seqlen_offset, \
            quant_bit = args
        def quant(context: torch.Tensor,
                  quant_bit: int,
                  group_size: int):
            # Symmetric max-abs quantization per group along the last dim;
            # returns the scaled (not yet rounded) values plus the scales.
            # context:[head_num, seq, head_size]
            head_num = context.shape[0]
            head_size = context.shape[-1]
            if group_size != head_size:
                context = context.reshape(head_num, -1, group_size)
            context_fp32 = context.to(torch.float32)
            max_value, _ = torch.max(context_fp32.abs(), dim=-1, keepdim=True)
            int_max = float(2 ** (quant_bit - 1) - 1)
            scale = max_value / int_max
            scaled_context = context_fp32 / scale
            return scaled_context.reshape(head_num, -1, head_size), scale[..., 0]
        # Packed layout stores cumulative lengths (batch_size + 1 entries).
        batch_size = context_lengths.shape[0] - 1 if packed else context_lengths.shape[0]
        head_num = key.shape[-2]
        head_size = key.shape[-1]
        # 3-D scale -> one scale per whole head_size; 4-D -> grouped scales.
        if key_cache_quant_scale.dim() == 3:
            group_size = head_size
        else:
            group_size = head_size // key_cache_quant_scale.size(-1)
        for i in range(batch_size):
            if packed:
                key_i = key[context_lengths[i]:context_lengths[i+1]].transpose(1, 0)
                value_i = value[context_lengths[i]:context_lengths[i+1]].transpose(1, 0) if value is not None else None
                context_len_i = context_lengths[i+1] - context_lengths[i]
                context_seq_offset_i = 0
            else:
                key_i = key[i].transpose(1, 0)
                value_i = value[i].transpose(1, 0) if value is not None else None
                context_len_i = context_lengths[i]
                context_seq_offset_i = context_seq_offset[i]
            cache_bs_id_i = cache_bs_id[i]
            cache_seqlen_offset_i = cache_seqlen_offset[i] if cache_seqlen_offset is not None else 0
            if cache_bs_id_i < 0 or cache_seqlen_offset_i < 0:
                continue  # masked-out batch
            key_cache_i = \
                key_cache[cache_bs_id_i, :, \
                cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i]
            key_cache_scale_i = \
                key_cache_quant_scale[cache_bs_id_i, :, \
                cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i]
            # key_i[head_num, context_len[i], head_size]
            key_i = key_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i]
            float_key_i, key_scale_i = quant(key_i, quant_bit, group_size)
            key_cache_scale_i = key_cache_scale_i.reshape(key_cache_scale_i.shape[0], -1)
            key_cache_scale_i[...] = key_scale_i
            rounded = torch.round(float_key_i)
            clipped = torch.clip(rounded, -2 ** (quant_bit - 1), 2 ** (quant_bit - 1) - 1)
            quant_key_i = clipped.to(torch.int8)
            if quant_bit == 4:
                # Pack adjacent int4 pairs into one byte (even index -> low nibble).
                quant_key_flat = quant_key_i.flatten()
                d0 = quant_key_flat[0::2].to(torch.uint8)
                d1 = quant_key_flat[1::2].to(torch.uint8)
                dp = (d1 << 4) + (d0 & 0x0F)
                quant_key_i = dp.to(torch.int8).reshape(head_num, -1, head_size // 2)
            key_cache_i[...] = quant_key_i
            if value_cache is not None and value is not None:
                value_cache_scale_i = \
                    value_cache_quant_scale[cache_bs_id_i, :, \
                    cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i]
                value_i = value_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i]
                float_value_i, value_scale_i = quant(value_i, quant_bit, group_size)
                value_cache_scale_i = value_cache_scale_i.reshape(value_cache_scale_i.shape[0], -1)
                value_cache_scale_i[...] = value_scale_i
                rounded = torch.round(float_value_i)
                clipped = torch.clip(rounded, -2 ** (quant_bit - 1), 2 ** (quant_bit - 1) - 1)
                quant_value_i = clipped.to(torch.int8)
                if quant_bit == 8:
                    value_cache_i = \
                        value_cache[cache_bs_id_i, :, \
                        cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i]
                else:
                    # int4 value cache packs along the seq dim, so the write window
                    # may straddle partially-filled bytes at both ends.
                    value_cache_i = \
                        value_cache[cache_bs_id_i, :, \
                        cache_seqlen_offset_i // 2:math.ceil((cache_seqlen_offset_i + context_len_i) / 2)]
                    if cache_seqlen_offset_i % 2 == 1:
                        # Preserve the existing low nibble of the leading byte.
                        front_vec = value_cache[cache_bs_id_i, :, cache_seqlen_offset_i // 2, :]
                        front_low_bits = front_vec & (0x0F)
                        front_low_bits_expand = front_low_bits.unsqueeze(1) # [head_num, 1, head_size]
                        quant_value_i = torch.cat((front_low_bits_expand, quant_value_i), dim=1)
                    if (cache_seqlen_offset_i + context_len_i) % 2 == 1:
                        # Preserve the existing high nibble of the trailing byte.
                        back_vec = value_cache[cache_bs_id_i, :, math.ceil((cache_seqlen_offset_i + context_len_i) / 2) - 1, :]
                        back_high_bits = (back_vec >> 4) & (0x0F)
                        back_high_bits_expand = back_high_bits.unsqueeze(1) # [head_num, 1, head_size]
                        quant_value_i = torch.cat((quant_value_i, back_high_bits_expand), dim=1)
                    # Interleave seq-adjacent pairs so each byte holds two seq positions.
                    value_temp = quant_value_i.reshape(head_num, -1, 2, head_size)
                    quant_value_flat = value_temp.permute(0, 1, 3, 2).flatten()
                    v0 = quant_value_flat[0::2].to(torch.uint8)
                    v1 = quant_value_flat[1::2].to(torch.uint8)
                    vp = (v1 << 4) + (v0 & 0x0F)
                    quant_value_i = vp.to(torch.int8).reshape(head_num, -1, head_size)
                value_cache_i[...] = quant_value_i
        return (key_cache, value_cache, key_cache_quant_scale, value_cache_quant_scale) if value_cache is not None else (key_cache, key_cache_quant_scale)
def int8_to_int4(self,
input):
input_flat = input.flatten()
size = input_flat.size(0)
output = torch.zeros(size * 2, dtype=torch.int8, device=input.device)
high = input_flat >> 4
low = input_flat << 4
low = low >> 4
output[0::2] = low
output[1::2] = high
return output
def test_quant_to_linear_cache(self):
    """Randomized check of ops.quant_to_linear_cache against the Python
    reference (op_impl_base) over packed/unpacked layouts, 4/8-bit
    quantization, per-head vs. per-group scales, and masked batches."""
    test_cases = 100
    # Random configuration draws, one entry per case.
    bs_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32)
    num_heads_list = torch.randint(low=1, high=32, size=(test_cases, ), dtype=torch.int32)
    head_size_list = torch.randint(low=1, high=16, size=(test_cases, ), dtype=torch.int32)
    head_size_list *= 16  # head_size is always a multiple of 16
    cache_memory_len_list = torch.randint(low=8, high=512, size=(test_cases, ), dtype=torch.int32)
    cache_memory_len_list = cache_memory_len_list * 2  # keep cache length even (int4 packs 2 values/byte)
    packed_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
    dtype_list = torch.randint(low=0, high=3, size=(test_cases, ), dtype=torch.int32)
    dtype_map = [torch.half, torch.bfloat16, torch.float]
    quant_bit_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
    quant_bit_map = [4, 8]
    for i in range(test_cases):
        batch_size = bs_list[i].item()
        # A tenth of the batches are marked invalid (bs_id / offset set to -1).
        invalid_batch = batch_size // 10
        num_heads = num_heads_list[i].item()
        head_size = head_size_list[i].item()
        # Divisors of head_size >= 4 (comprehension's `i` does not leak in Py3).
        group_size_factors = [i for i in range(4, head_size + 1) if head_size % i == 0] # group_size should > 1
        group_size = random.choice(group_size_factors)
        cache_memory_len = cache_memory_len_list[i].item()
        packed = packed_list[i].item()
        dtype = dtype_map[dtype_list[i]]
        quant_bit = quant_bit_map[quant_bit_list[i]]
        mlu_name = torch.mlu.get_device_name()
        if "MLU3" in mlu_name and dtype == torch.bfloat16:
            # MLU3XX has no bfloat16 support; fall back to half.
            dtype = torch.half
            print("BFLOAT16 is not support on {}, use half instead".format(mlu_name))
        print("case num={}, batch_size={}, num_heads={}, head_size={}, cache_memory_len={}, packed={}, dtype={}, quant_bit={}, group_size={} testing...".format(
              i, batch_size, num_heads, head_size, cache_memory_len, packed > 0, dtype, quant_bit, group_size))
        max_bs = batch_size + 1
        # Even per-batch context lengths (doubled below) so int4 halves align.
        context_lens = torch.randint(size=(batch_size, ), low=1,
                                     high=cache_memory_len // 4,
                                     dtype=torch.int32, device='mlu')
        context_lens = context_lens * 2
        max_context_len = context_lens.max().item()
        max_seq_offset = max_context_len // 3 + 1
        context_seq_offsets = torch.randint(size=(batch_size, ),
                                            low=0, high=max_seq_offset,
                                            dtype=torch.int32, device='mlu')
        cache_seq_offsets = torch.randint(size=(batch_size, ), low=0,
                                          high=(cache_memory_len - max_context_len) // 3 + 1,
                                          dtype=torch.int32, device='mlu')
        cache_seq_offsets[random.sample([*range(0, batch_size)], invalid_batch)] = -1 # set invalid
        # Exclusive prefix sum of context lengths, used for the packed layout.
        cu_context_lens = torch.cumsum(context_lens, dim=-1)
        cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
        total_seqlen = cu_context_lens[-1]
        if packed > 0:
            # Packed: all sequences concatenated along the token dimension.
            key = torch.randn((total_seqlen, num_heads, head_size),
                              dtype=torch.float, device='mlu')
            value = torch.randn((total_seqlen, num_heads, head_size),
                                dtype=torch.float, device='mlu')
        else:
            # Unpacked: [batch, padded_seq, heads, head_size].
            key = torch.randn((batch_size, max_context_len + max_seq_offset, num_heads, head_size),
                              dtype=torch.float, device='mlu')
            value = torch.randn((batch_size, max_context_len + max_seq_offset, num_heads, head_size),
                                dtype=torch.float, device='mlu')
        key = key.to(dtype)
        value = value.to(dtype)
        # Cache/scale layouts differ per (quant_bit, group granularity):
        # int8 stores one byte per element; int4 packs key along head_size
        # and value along the sequence dimension.
        if quant_bit == 8 and group_size == head_size:
            key_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size), dtype=torch.int8).mlu()
            value_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size), dtype=torch.int8).mlu()
            key_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len), dtype=torch.float32).mlu()
            value_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len), dtype=torch.float32).mlu()
        if quant_bit == 8 and group_size != head_size:
            key_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size), dtype=torch.int8).mlu()
            value_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size), dtype=torch.int8).mlu()
            key_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len, head_size // group_size), dtype=torch.float32).mlu()
            value_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len, head_size // group_size), dtype=torch.float32).mlu()
        if quant_bit == 4 and group_size == head_size:
            key_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size // 2), dtype=torch.int8).mlu()
            value_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len // 2, head_size), dtype=torch.int8).mlu()
            key_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len), dtype=torch.float32).mlu()
            value_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len), dtype=torch.float32).mlu()
        if quant_bit == 4 and group_size != head_size:
            key_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len, head_size // 2), dtype=torch.int8).mlu()
            value_cache = torch.randint(-128, 127, (max_bs, num_heads, cache_memory_len // 2, head_size), dtype=torch.int8).mlu()
            key_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len, head_size // group_size), dtype=torch.float32).mlu()
            value_cache_scale = torch.randn((max_bs, num_heads, cache_memory_len, head_size // group_size), dtype=torch.float32).mlu()
        # Reference copies: the device op and the Python reference each mutate
        # their own cache tensors in place.
        ref_key_cache = key_cache.clone()
        ref_value_cache = value_cache.clone()
        ref_key_cache_scale = key_cache_scale.clone()
        ref_value_cache_scale = value_cache_scale.clone()
        cache_bs_id = random.sample([*range(0, max_bs)], batch_size)
        cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
        cache_bs_id[random.sample([*range(0, batch_size)], invalid_batch)] = -1 # set invalid batch
        if packed > 0:
            # Packed path uses cumulative lengths; per-batch context offsets are None.
            self.op_impl_base(key, value, ref_key_cache, ref_value_cache, ref_key_cache_scale,
                              ref_value_cache_scale, cu_context_lens, max_context_len,
                              packed > 0, None, cache_bs_id, cache_seq_offsets,
                              quant_bit)
            ops.quant_to_linear_cache(key, value, key_cache, value_cache, key_cache_scale,
                                      value_cache_scale, cu_context_lens, max_context_len,
                                      packed > 0, None, cache_bs_id, cache_seq_offsets,
                                      quant_bit)
        else:
            self.op_impl_base(key, value, ref_key_cache, ref_value_cache, ref_key_cache_scale,
                              ref_value_cache_scale, context_lens, max_context_len,
                              packed > 0, context_seq_offsets, cache_bs_id,
                              cache_seq_offsets, quant_bit)
            ops.quant_to_linear_cache(key, value, key_cache, value_cache, key_cache_scale,
                                      value_cache_scale, context_lens, max_context_len,
                                      packed > 0, context_seq_offsets, cache_bs_id,
                                      cache_seq_offsets, quant_bit)
        if quant_bit == 8:
            self.assertTensorsEqual(key_cache.cpu().float(), ref_key_cache.cpu().float(), 0.003,
                                    "key_cache must equal ref_key_cache", True, True, True, True)
            self.assertTensorsEqual(value_cache.cpu().float(), ref_value_cache.cpu().float(), 0.003,
                                    "value_cache must equal ref_value_cache", True, True, True, True)
        else:
            # int4: unpack nibbles, then allow an off-by-one tolerance because
            # rounding mode may differ between device and reference.
            key_cache_int8 = self.int8_to_int4(key_cache)
            ref_key_cache_int8 = self.int8_to_int4(ref_key_cache)
            value_cache_int8 = self.int8_to_int4(value_cache)
            ref_value_cache_int8 = self.int8_to_int4(ref_value_cache)
            key_cache_diff = (key_cache_int8.cpu() - ref_key_cache_int8.cpu()).abs()
            assert torch.max(key_cache_diff) < 2, "ref_key_cache must equal key_cache or absolute values differ by 1 due to round_mode!"
            value_cache_diff = (value_cache_int8.cpu() - ref_value_cache_int8.cpu()).abs()
            assert torch.max(value_cache_diff) < 2, "ref_value_cache must equal value_cache or absolute values differ by 1 due to round_mode!"
        self.assertTensorsEqual(key_cache_scale.cpu().float(), ref_key_cache_scale.cpu().float(), 0.003,
                                "key_cache_scale must equal ref_key_cache_scale", True, True, True, True)
        self.assertTensorsEqual(value_cache_scale.cpu().float(), ref_value_cache_scale.cpu().float(), 0.003,
                                "value_cache_scale must equal ref_value_cache_scale", True, True, True, True)
def test_inductor(self):
    """Opcheck registration of quant_to_linear_cache for both packed and
    unpacked argument layouts."""
    batch_size, num_heads, head_size, cache_memory_len = 4, 8, 64, 128
    dtype, quant_bit = torch.float16, 8
    # Exercise the packed layout first, then the unpacked one.
    for packed in (1, 0):
        args = gen_args(batch_size, num_heads, head_size, cache_memory_len,
                        packed, dtype, quant_bit)
        self.base_opcheck(torch.ops.torch_mlu_ops.quant_to_linear_cache, args)
# Run the TestQuantToLinearCache suite when executed as a script.
if __name__ == '__main__':
    exit(run_unittest(TestQuantToLinearCache))

View File

@@ -0,0 +1,213 @@
import random
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
import copy
class TestQuantToPagedCache(BtTestCase):
def run_gen_case(self, dic):
    """Replay one generated test case described by `dic`.

    If `dic['dump_data']` is truthy the dict values are already live tensors
    and are forwarded directly; otherwise each tensor is rebuilt from its
    serialized description before launching.
    """
    dump_data = dic.pop('dump_data')
    if dump_data:
        # Case dict already carries concrete tensors in launch-argument order.
        self.launch(*dic.values())
    else:
        # Reconstruct inputs from their descriptions; k/v use uniform fill.
        k = create_tensor_from_dic(dic['k'], is_uniform=True, low=-1, high=1)
        v = create_tensor_from_dic(dic['v'], is_uniform=True, low=-0.25, high=0.25)
        k_cache = create_tensor_from_dic(dic['k_cache'])
        v_cache = create_tensor_from_dic(dic['v_cache'])
        k_cache_quant_scale = create_tensor_from_dic(dic['k_cache_quant_scale'])
        v_cache_quant_scale = create_tensor_from_dic(dic['v_cache_quant_scale'])
        slot_mapping = dic['slot_mapping']['data']
        self.launch(k, v, k_cache, v_cache, k_cache_quant_scale, v_cache_quant_scale, slot_mapping)
def launch(self, *args):
    """Run the Python reference and the device op on identical inputs and
    compare the resulting caches and scales."""
    v = args[1]
    # Deep-copy so the reference mutates its own cache tensors, not the ones
    # the device op writes in place.
    args_bak = copy.deepcopy(args)
    torch_out = self.op_impl_base(*args_bak)
    tmo_out = ops.quant_to_paged_cache(*args)
    self.assertTensorsEqual(torch_out[0].cpu().float(), tmo_out[0].cpu().float(), 9e-3,
                            use_MSE=True, use_RAE=True)
    self.assertTensorsEqual(torch_out[1].cpu().float(), tmo_out[1].cpu().float(), 9e-3,
                            use_MSE=True, use_RAE=True)
    if v is not None:
        # When value quantization is enabled the outputs also include
        # v_cache and its quant scale (indices 2 and 3).
        self.assertTensorsEqual(torch_out[2].cpu().float(), tmo_out[2].cpu().float(), 3e-3,
                                use_MSE=True, use_RAE=True)
        self.assertTensorsEqual(torch_out[3].cpu().float(), tmo_out[3].cpu().float(), 3e-3,
                                use_MSE=True, use_RAE=True)
def op_impl_base(self, *args):
    """Reference implementation of quant_to_paged_cache.

    For each token with a non-negative slot, symmetrically quantizes its
    K (and optionally V) vectors to int8 per head and scatters them, along
    with the per-head scale, into the paged caches in place.

    Args (positional): k, v, k_cache, v_cache, k_cache_quant_scale,
        v_cache_quant_scale, slot_mapping.
    Returns:
        (k_cache, v_cache, k_scale, v_scale) when v is given, else
        (k_cache, k_scale).
    """
    k, v, k_cache, v_cache, k_cache_quant_scale, v_cache_quant_scale, slot_mapping = args

    def _quantize_per_head(x):
        # Symmetric int8 quantization over the last dim: scale = absmax / 127.
        x_fp32 = x.to(torch.float32)
        absmax = x_fp32.abs().max(dim=-1, keepdim=True).values
        scale = absmax / 127.0
        return (x_fp32 / scale).to(torch.int8), scale[..., 0]

    block_size = k_cache.shape[2]
    for token in range(k.shape[0]):  # k is [token_num, head_num, head_size]
        slot = slot_mapping[token]
        if slot < 0:
            continue  # negative slot masks this token out
        block_id = torch.div(slot, block_size, rounding_mode='floor')
        offset = slot % block_size
        q_key, key_scale = _quantize_per_head(k[token])
        k_cache[block_id, :, offset, :] = q_key
        k_cache_quant_scale[block_id, :, offset] = key_scale
        if v is not None:
            q_val, val_scale = _quantize_per_head(v[token])
            v_cache[block_id, :, offset, :] = q_val
            v_cache_quant_scale[block_id, :, offset] = val_scale
    if v is None:
        return (k_cache, k_cache_quant_scale)
    return (k_cache, v_cache, k_cache_quant_scale, v_cache_quant_scale)
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "quant_to_paged_cache not support MLU3XX device")
def test_quant_to_paged_cache(self):
    """Random-shape test of quant_to_paged_cache over dtypes and the
    key-only / key+value modes, with one masked slot."""
    token_nums = random.randint(1, 2048)
    head_num_kv = random.randint(1, 128)
    head_size = random.randint(1, 1024)
    block_size = random.randint(1, 50)
    # Minimum number of blocks needed to hold all tokens.
    min_blocks = (int)((token_nums + block_size - 1) / block_size)
    block_nums = min(min_blocks + 10, 2 * min_blocks)
    num_slots = block_nums * block_size
    # Distinct slots per token; the last token is masked out with -1.
    slot_mapping = random.sample(range(num_slots), token_nums)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu()
    slot_mapping[-1] = -1 # test mask
    dtype_list = [torch.half, torch.float]
    if torch_mlu.mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    only_quant_key_list = [True, False]
    for _ in range(100):
        print("test_quant_to_paged_cache...")
        dtype = random.choice(dtype_list)
        only_quant_key = random.choice(only_quant_key_list)
        key = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-1, 1)
        key_cache = torch.zeros(block_nums, head_num_kv, block_size, head_size, dtype=torch.int8, device="mlu")
        key_cache_scale = torch.zeros(block_nums, head_num_kv, block_size, dtype=torch.float, device="mlu")
        if only_quant_key:
            # Key-only mode: value-side tensors are omitted entirely.
            value, value_cache, value_cache_scale = None, None, None
        else:
            value = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-0.25, 0.25)
            value_cache = torch.zeros_like(key_cache)
            value_cache_scale = torch.zeros_like(key_cache_scale)
        self.launch(key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, slot_mapping)
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "quant_to_paged_cache not support MLU3XX device")
@unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
def test_large_tensor(self):
    """Check behavior at the 4 GiB cache-addressing boundary: a cache just
    under 2**32 bytes must work; one at/over it must raise."""
    dtype_list = [torch.half, torch.float]
    if torch_mlu.mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    for dtype in dtype_list:
        print('quant_to_paged_cache: test_large_tensor')
        head_num_kv = 16
        head_size = 128
        token_nums = 20
        block_size = 16
        # Largest block count keeping the int8 cache under 2**32 bytes.
        block_nums = ((2**32 - 1) // 1 // head_num_kv // head_size // block_size)
        num_slots = block_nums * block_size
        key = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-1, 1)
        value = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-0.25, 0.25)
        key_cache_torch = torch.zeros(block_nums, head_num_kv, block_size, head_size, dtype=torch.int8, device="mlu")
        value_cache_torch = torch.zeros_like(key_cache_torch)
        key_cache_scale_torch = torch.zeros(block_nums, head_num_kv, block_size, dtype=torch.float, device="mlu")
        value_cache_scale_torch = torch.zeros_like(key_cache_scale_torch)
        slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_nums]
        # Separate cache copies for the device op so reference/device writes
        # do not interfere.
        key_cache_tmo = torch.zeros_like(key_cache_torch)
        value_cache_tmo = torch.zeros_like(value_cache_torch)
        key_cache_scale_tmo = torch.zeros_like(key_cache_scale_torch)
        value_cache_scale_tmo = torch.zeros_like(value_cache_scale_torch)
        self.op_impl_base(key, value, key_cache_torch, value_cache_torch, key_cache_scale_torch, value_cache_scale_torch, slot_mapping)
        ops.quant_to_paged_cache(key, value, key_cache_tmo, value_cache_tmo, key_cache_scale_tmo, value_cache_scale_tmo, slot_mapping)
        # Quantized values may differ by 1 LSB (rounding); scales must match closely.
        self.assertTensorsEqual(key_cache_torch.cpu(), key_cache_tmo.cpu(), 1)
        self.assertTensorsEqual(value_cache_torch.cpu(), value_cache_tmo.cpu(), 1)
        self.assertTensorsEqual(key_cache_scale_torch.cpu().float(), key_cache_scale_tmo.cpu().float(), 3e-3,
                                use_MSE=True, use_RAE=True)
        self.assertTensorsEqual(value_cache_scale_torch.cpu().float(), value_cache_scale_tmo.cpu().float(), 3e-3,
                                use_MSE=True, use_RAE=True)
        # Now size the cache to reach 2**32 bytes: the op must refuse it.
        block_nums = (2**32 // 1 // head_num_kv // head_size // block_size)
        num_slots = block_nums * block_size
        slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_nums]
        key_cache_torch = torch.zeros(block_nums, head_num_kv, block_size, head_size, dtype=torch.int8, device="mlu")
        value_cache_torch = torch.zeros_like(key_cache_torch)
        key_cache_scale_torch = torch.zeros(block_nums, head_num_kv, block_size, dtype=torch.float, device="mlu")
        value_cache_scale_torch = torch.zeros_like(key_cache_scale_torch)
        key_cache_tmo = torch.zeros_like(key_cache_torch)
        value_cache_tmo = torch.zeros_like(value_cache_torch)
        key_cache_scale_tmo = torch.zeros_like(key_cache_scale_torch)
        value_cache_scale_tmo = torch.zeros_like(value_cache_scale_torch)
        self.assertException("The addressing range of kv_cache cannot exceed 4G.", ops.quant_to_paged_cache,
                             key, value, key_cache_tmo, value_cache_tmo, key_cache_scale_tmo, value_cache_scale_tmo, slot_mapping)
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "quant_to_paged_cache not support MLU3XX device")
@unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
def test_prevent(self):
    """Negative tests: invalid argument combinations must raise with the
    expected error messages."""
    token_nums = random.randint(1, 2048)
    head_num_kv = random.randint(1, 128)
    head_size = random.randint(1, 1024)
    block_size = random.randint(1, 50)
    min_blocks = (int)((token_nums + block_size - 1) / block_size)
    block_nums = min(min_blocks + 10, 2 * min_blocks)
    num_slots = block_nums * block_size
    slot_mapping = random.sample(range(num_slots), token_nums)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int).mlu()
    slot_mapping[-1] = -1 # test mask
    dtype_list = [torch.half, torch.float]
    if torch_mlu.mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    dtype = random.choice(dtype_list)
    print("quant_to_paged_cache: test_prevent...")
    key = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-1, 1)
    value = None
    key_cache = torch.zeros(block_nums, head_num_kv, block_size, head_size, dtype=torch.int8, device="mlu")
    value_cache = torch.zeros_like(key_cache)
    key_cache_scale = torch.zeros(block_nums, head_num_kv, block_size, dtype=torch.float, device="mlu")
    value_cache_scale = torch.zeros_like(key_cache_scale)
    # v is None while v_cache / v_cache_scale are provided: must be rejected.
    self.assertException("v.has_value() == v_cache.has_value() && v.has_value() == v_cache_scale.has_value().",
                         ops.quant_to_paged_cache, key, value, key_cache, value_cache, key_cache_scale,
                         value_cache_scale, slot_mapping)
    value = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-0.25, 0.25)
    # Strides make the last dim non-contiguous: must be rejected.
    value = value.as_strided(size=(token_nums, head_num_kv, head_size), stride=(1, token_nums, token_nums * head_num_kv))
    self.assertException("v last dim must be contiguous.",
                         ops.quant_to_paged_cache, key, value, key_cache, value_cache, key_cache_scale,
                         value_cache_scale, slot_mapping)
    # Last dim contiguous but the second dim is not: must also be rejected.
    value = value.as_strided(size=(token_nums, head_num_kv, head_size), stride=(head_size, head_num_kv, 1))
    self.assertException("v second dim must be contiguous.",
                         ops.quant_to_paged_cache, key, value, key_cache, value_cache, key_cache_scale,
                         value_cache_scale, slot_mapping)
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "quant_to_paged_cache not support MLU3XX device")
def test_inductor(self):
    """Opcheck registration of quant_to_paged_cache on a small fixed shape."""
    token_nums, head_num_kv, head_size, block_size = 20, 30, 20, 10
    min_blocks = (int)((token_nums + block_size - 1) / block_size)
    block_nums = min(min_blocks + 10, 2 * min_blocks)
    num_slots = block_nums * block_size
    # Distinct slots per token; last token masked out with -1.
    slot_mapping = torch.tensor(random.sample(range(num_slots), token_nums),
                                dtype=torch.int).mlu()
    slot_mapping[-1] = -1 # test mask
    dtype = torch.half
    key = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-1, 1)
    value = torch.randn(token_nums, head_num_kv, head_size, dtype=dtype, device="mlu").uniform_(-0.25, 0.25)
    key_cache = torch.zeros(block_nums, head_num_kv, block_size, head_size, dtype=torch.int8, device="mlu")
    value_cache = torch.zeros_like(key_cache)
    key_cache_scale = torch.zeros(block_nums, head_num_kv, block_size, dtype=torch.float, device="mlu")
    value_cache_scale = torch.zeros_like(key_cache_scale)
    args = (key, value, key_cache, value_cache, key_cache_scale, value_cache_scale, slot_mapping)
    self.base_opcheck(torch.ops.torch_mlu_ops.quant_to_paged_cache, args)
# Run the TestQuantToPagedCache suite when executed as a script.
if __name__ == '__main__':
    exit(run_unittest(TestQuantToPagedCache))

View File

@@ -0,0 +1,46 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as tmo
from common_utils import *
import random
class TestQuantizeOp(BtTestCase):
def op_impl_base(self, *args):
    """Reference quantize: scale x by per-channel `smooth`, round to nearest
    (ties-to-even via torch.round), saturate to [-128, 127], cast to int8.

    Args (positional): x, smooth, zero. The zero-point argument is accepted
    but unused here (symmetric quantization).
    """
    x, smooth, _zero = args
    scaled = x * smooth
    saturated = torch.clamp(torch.round(scaled), min=-128.0, max=127.0)
    return saturated.to(torch.int8)
def test_random_case(self):
    """Compare tmo.quantize against the Python reference on 100 distinct
    random (ci, co) shapes across supported dtypes."""
    torch.manual_seed(0)
    case_list = set()
    # Loop until 100 unique (ci, co) shape pairs have been tested.
    while(len(case_list) < 100):
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        dtype = random.choice(dtype_list)
        ci = random.randint(1, 4096)
        co = random.randint(1, 4096)
        case = (ci, co)
        if case in case_list:
            continue  # shape already covered; draw again
        else:
            case_list.add((ci, co))
        x = torch.randn(ci, co, device="mlu", dtype=dtype)
        # Per-channel scale over the last dim.
        scale = torch.randn(co, device="mlu", dtype=torch.float32)
        print("ci={}, co={}, dtype={}, testing...".format(ci, co, dtype), flush=True)
        # Third argument is the (unused) zero-point.
        param = (x, scale, None)
        tmo_output = tmo.quantize(*param)
        torch_output = self.op_impl_base(*param)
        self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(), 0.01, use_MSE=True, use_RAE=True)
def test_inductor(self):
    """Opcheck registration of the underlying smooth_quant op with a fixed
    per-token configuration."""
    x = torch.randn(16,128, 1024, device="mlu", dtype=torch.half)
    scale = torch.randn(1024, device="mlu", dtype=torch.float32)
    # Pre-allocated int8 output buffer the op writes into.
    output = torch.empty(x.size(), dtype=torch.int8, device="mlu")
    # Trailing args: empty tensor + Nones for optional inputs, quant mode,
    # and a final boolean flag.
    args = (x, scale, output, torch.Tensor(), None, None, None, None, 'per_token', False)
    self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args)
# Run the TestQuantizeOp suite when executed as a script.
if __name__ == '__main__':
    exit(run_unittest(TestQuantizeOp))

View File

@@ -0,0 +1,179 @@
import torch
import unittest
import torch_mlu_ops as ops
import random
from common_utils import *
from typing import Optional
def gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype):
    """Build the first ten positional arguments for reshape_linear_cache
    from the shared cache-argument generator."""
    full_args = generate_cache_func_args(batch_size, num_heads, head_size,
                                         cache_memory_len, packed, dtype)
    return full_args[0:10]
class TestReshapeLinearCache(BtTestCase):
def op_impl_base(self, *args):
    """Reference implementation of reshape_linear_cache.

    Copies each batch's key (and optionally value) slice into the linear
    cache at (cache_bs_id[i], :, cache_seqlen_offset[i]:...), skipping
    batches whose cache offset or batch id is negative (masked).
    In packed mode, context_lengths is a cumulative-length (prefix-sum)
    tensor of shape [batch_size + 1].
    """
    key, value, key_cache, value_cache, context_lengths, max_context_len, packed, \
        context_seq_offset, cache_bs_id, cache_seqlen_offset = args
    # Packed layout carries cumulative lengths, hence one extra entry.
    batch_size = context_lengths.shape[0] - 1 if packed else context_lengths.shape[0]
    for i in range(batch_size):
        if packed:
            # Slice this batch's tokens out of the packed stream, then move
            # heads to the front: [head_num, seq, head_size].
            key_i = key[context_lengths[i]:context_lengths[i+1]].transpose(1, 0)
            value_i = value[context_lengths[i]:context_lengths[i+1]].transpose(1, 0) if value is not None else None
            context_len_i = context_lengths[i+1] - context_lengths[i]
            context_seq_offset_i = 0
        else:
            key_i = key[i].transpose(1, 0)
            value_i = value[i].transpose(1, 0) if value is not None else None
            context_len_i = context_lengths[i]
            context_seq_offset_i = context_seq_offset[i]
        cache_bs_id_i = cache_bs_id[i]
        cache_seqlen_offset_i = cache_seqlen_offset[i]
        if cache_seqlen_offset_i < 0 or cache_bs_id_i < 0:
            continue  # masked-out batch: leave the cache untouched
        key_cache_i = \
            key_cache[cache_bs_id_i, :, \
                      cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i]
        key_cache_i[...] = key_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i]
        if value_cache is not None and value is not None:
            value_cache_i = \
                value_cache[cache_bs_id_i, :, \
                            cache_seqlen_offset_i:cache_seqlen_offset_i + context_len_i]
            value_cache_i[...] = value_i[:, context_seq_offset_i:context_seq_offset_i + context_len_i]
    return (key_cache, value_cache, ) if value_cache is not None else key_cache
def test_reshape_linear_cache(self):
    """Randomized check of ops.reshape_linear_cache against the Python
    reference over packed/unpacked layouts, dtypes, and masked batches."""
    test_cases = 100
    # Random shape/config draws for each case.
    bs_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32)
    num_heads_list = torch.randint(low=1, high=32, size=(test_cases, ), dtype=torch.int32)
    head_size_list = torch.randint(low=1, high=16, size=(test_cases, ), dtype=torch.int32)
    head_size_list *= 16  # head_size is a multiple of 16
    cache_memory_len_list = torch.randint(low=16, high=1024, size=(test_cases, ), dtype=torch.int32)
    packed_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
    dtype_list = torch.randint(low=0, high=4, size=(test_cases, ), dtype=torch.int32)
    dtype_map = [torch.int8, torch.half, torch.bfloat16, torch.float]
    for i in range(test_cases):
        q_heads = 1
        batch_size = bs_list[i].item()
        # A tenth of the batches are masked out as invalid.
        invalid_batch = batch_size // 10
        num_heads = num_heads_list[i].item()
        head_size = head_size_list[i].item()
        cache_memory_len = cache_memory_len_list[i].item()
        packed = packed_list[i].item()
        # Fused QKV head layout: q_heads + K heads + V heads.
        total_heads = q_heads + num_heads * 2
        dtype = dtype_map[dtype_list[i]]
        mlu_name = torch.mlu.get_device_name()
        if "MLU3" in mlu_name and dtype == torch.bfloat16:
            # MLU3XX has no bfloat16 support; fall back to half.
            dtype = torch.half
            print("BFLOAT16 is not support on {}, use half instead".format(mlu_name))
        print("batch_size={}, num_heads={}, head_size={}, cache_memory_len={}, packed={}, dtype={} testing...".format(
              batch_size, num_heads, head_size, cache_memory_len, packed > 0, dtype))
        max_bs = batch_size + 1
        context_lens = torch.randint(size=(batch_size, ), low=1,
                                     high=cache_memory_len // 2,
                                     dtype=torch.int32, device='mlu')
        # print(context_lens)
        max_context_len = context_lens.max().item()
        max_seq_offset = max_context_len // 3 + 1
        context_seq_offsets = torch.randint(size=(batch_size, ), low=0, high=max_seq_offset,
                                            dtype=torch.int32, device='mlu')
        # low=-1 already allows invalid offsets; extra ones are forced below.
        cache_seq_offsets = torch.randint(size=(batch_size, ), low=-1,
                                          high=(cache_memory_len - max_context_len) // 3 + 1,
                                          dtype=torch.int32, device='mlu')
        cache_seq_offsets[random.sample([*range(0, batch_size)], invalid_batch)] = -1
        # Exclusive prefix sum of lengths for the packed layout.
        cu_context_lens = torch.cumsum(context_lens, dim=-1)
        cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
        total_seqlen = cu_context_lens[-1]
        if packed > 0:
            context = torch.randn((total_seqlen, total_heads, head_size),
                                  dtype=torch.float, device='mlu')
        else:
            context = torch.randn((batch_size, max_context_len + max_seq_offset, total_heads, head_size),
                                  dtype=torch.float, device='mlu')
        # One buffer holding both K (index 0) and V (index 1) caches.
        cache = torch.randn((2, max_bs, num_heads, cache_memory_len, head_size), dtype=torch.float, device='mlu')
        context = context.to(dtype)
        cache = cache.to(dtype)
        ref_cache = cache.clone()
        # Slice K and V heads out of the fused QKV context.
        key = context[..., q_heads:q_heads + num_heads, :]
        value = context[..., q_heads + num_heads:q_heads + 2 * num_heads, :]
        key_cache = cache[0]
        value_cache = cache[1]
        ref_key_cache = ref_cache[0]
        ref_value_cache = ref_cache[1]
        cache_bs_id = None
        cache_bs_id = random.sample([*range(0, max_bs)], batch_size)
        cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
        cache_bs_id[random.sample([*range(0, batch_size)], invalid_batch)] = -1 # set invalid batch
        if packed > 0:
            # Packed path uses cumulative lengths; per-batch offsets are None.
            self.op_impl_base(key, value, ref_key_cache, ref_value_cache, cu_context_lens,
                              max_context_len, packed > 0, None,
                              cache_bs_id, cache_seq_offsets)
            ops.reshape_linear_cache(key, value, key_cache, value_cache, cu_context_lens,
                                     max_context_len, packed > 0, None,
                                     cache_bs_id, cache_seq_offsets)
        else:
            self.op_impl_base(key, value, ref_key_cache, ref_value_cache, context_lens,
                              max_context_len, packed > 0, context_seq_offsets,
                              cache_bs_id, cache_seq_offsets)
            ops.reshape_linear_cache(key, value, key_cache, value_cache, context_lens,
                                     max_context_len, packed > 0, context_seq_offsets,
                                     cache_bs_id, cache_seq_offsets)
        # Pure copy operation: results must match bit-exactly (tolerance 0).
        self.assertTensorsEqual(cache.cpu().float(), ref_cache.cpu().float(), 0, "ref_cache must equal cache",
                                True, True, True, True)
def test_reshape_linear_key_cache(self):
    """Key-only mode: value and value_cache are None; only the key cache
    must be updated, bit-exactly."""
    batch_size, num_heads, head_size, cache_memory_len = 2, 2, 16, 128
    print("[test_reshape_linear_key_cache] batch_size={}, num_heads={}, head_size={}, cache_memory_len={} testing...".format(
          batch_size, num_heads, head_size, cache_memory_len))
    max_bs = batch_size + 1
    context_lens = torch.randint(size=(batch_size, ), low=1,
                                 high=cache_memory_len // 2,
                                 dtype=torch.int32, device='mlu')
    max_context_len = context_lens.max().item()
    max_seq_offset = max_context_len // 3 + 1
    context_seq_offsets = torch.randint(size=(batch_size, ), low=0, high=max_seq_offset,
                                        dtype=torch.int32, device='mlu')
    # low=-1 allows masked (invalid) cache offsets to occur naturally.
    cache_seq_offsets = torch.randint(size=(batch_size, ), low=-1,
                                      high=(cache_memory_len - max_context_len) // 3 + 1,
                                      dtype=torch.int32, device='mlu')
    cu_context_lens = torch.cumsum(context_lens, dim=-1)
    cu_context_lens = torch.nn.functional.pad(cu_context_lens, (1,0), "constant", 0).to(torch.int32)
    context = torch.randn((batch_size, max_context_len + max_seq_offset, num_heads, head_size),
                          dtype=torch.float, device='mlu')
    key_cache = torch.randn((max_bs, num_heads, cache_memory_len, head_size), dtype=torch.float, device='mlu')
    ref_key_cache = key_cache.clone()
    key = context[..., :]
    cache_bs_id = None
    cache_bs_id = random.sample([*range(0, max_bs)], batch_size)
    cache_bs_id = torch.IntTensor(cache_bs_id).mlu()
    # Unpacked mode (packed=False) with value side disabled.
    self.op_impl_base(key, None, ref_key_cache, None, context_lens,
                      max_context_len, False, context_seq_offsets,
                      cache_bs_id, cache_seq_offsets)
    ops.reshape_linear_cache(key, None, key_cache, None, context_lens,
                             max_context_len, False, context_seq_offsets,
                             cache_bs_id, cache_seq_offsets)
    # Pure copy: must match bit-exactly.
    self.assertTensorsEqual(key_cache.cpu().float(), ref_key_cache.cpu().float(), 0, "ref_cache must equal cache",
                            True, True, True, True)
def test_inductor(self):
    """Opcheck registration of reshape_linear_cache for both packed and
    unpacked argument layouts."""
    batch_size, num_heads, head_size, cache_memory_len = 4, 8, 64, 128
    dtype = torch.float16
    # Packed layout first, then unpacked.
    for packed in (1, 0):
        args = gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype)
        self.base_opcheck(torch.ops.torch_mlu_ops.reshape_linear_cache, args)
# Run the TestReshapeLinearCache suite when executed as a script.
if __name__ == '__main__':
    exit(run_unittest(TestReshapeLinearCache))

View File

@@ -0,0 +1,141 @@
import random
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
def gen_args(batch_size, num_heads, head_size, cache_memory_len, packed, dtype):
    """Select k, v, k_cache, v_cache plus slot_mapping for reshape_paged_cache
    from the shared cache-argument generator."""
    full_args = generate_cache_func_args(batch_size, num_heads, head_size,
                                         cache_memory_len, packed, dtype)
    # Index 11 of the generated arguments is the slot mapping tensor.
    slot_mapping = full_args[11]
    return full_args[0:4] + [slot_mapping]
class TestReshapePagedCacheOp(BtTestCase):
def op_impl_base(self, *args):
    """Reference reshape_paged_cache: scatter each token's K (and optional V)
    row into its mapped slot of the paged cache, in place.

    Args (positional): k, v, k_cache, v_cache, slot_mapping.
    Returns:
        (k_cache, v_cache) when v is given, else k_cache.
    """
    k, v, k_cache, v_cache, slot_mapping = args
    block_size = k_cache.shape[2]
    for token in range(k.shape[0]):
        slot = slot_mapping[token]
        if slot < 0:
            continue  # negative slot masks this token out
        block_id = torch.div(slot, block_size, rounding_mode='floor')
        offset = slot % block_size
        k_cache[block_id, :, offset, :] = k[token]
        if v is not None:
            v_cache[block_id, :, offset, :] = v[token]
    if v is None:
        return k_cache
    return (k_cache, v_cache)
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "reshape_paged_cache not support MLU3XX device")
def test_reshape_paged_cache(self):
    """Randomized check of ops.reshape_paged_cache against the Python
    reference, across dtypes and key-only / key+value modes."""
    test_cases = 100
    num_tokens_list = torch.randint(low=1, high=1024, size=(test_cases, ), dtype=torch.int32)
    num_heads_list = torch.randint(low=1, high=64, size=(test_cases, ), dtype=torch.int32)
    head_size_list = torch.randint(low=1, high=16, size=(test_cases, ), dtype=torch.int32)
    head_size_list *= 16  # head_size is a multiple of 16
    block_size_list = torch.randint(low=1, high=4, size=(test_cases, ), dtype=torch.int32)
    block_size_list *= 16  # block_size is a multiple of 16
    only_reshape_key_list = [True, False]
    for i in range(test_cases):
        num_tokens = num_tokens_list[i]
        num_heads = num_heads_list[i]
        head_size = head_size_list[i]
        block_size = block_size_list[i]
        # Minimum blocks to hold all tokens, with a little headroom.
        min_blocks = (num_tokens + block_size - 1) // block_size
        num_blocks = min(min_blocks + 10, 2 * min_blocks)
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        only_reshape_key = random.choice(only_reshape_key_list)
        for dtype in dtype_list:
            print("num_tokens: {}, num_heads: {}, head_size: {}, num_blocks: {}, block_size: {}, testing...".format(
                  num_tokens, num_heads, head_size, num_blocks, block_size), flush=True)
            # Fused QKV; only the K and V slices are used.
            qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device="mlu")
            _, key, value = qkv.unbind(dim=1)
            key_cache = torch.randn(num_blocks, num_heads, block_size, head_size, dtype=dtype, device="mlu")
            value_cache = torch.randn(num_blocks, num_heads, block_size, head_size, dtype=dtype, device="mlu")
            num_slots = num_blocks * block_size
            # Distinct slot per token; last token masked out with -1.
            slot_mapping = random.sample(range(num_slots.item()), num_tokens.item())
            slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="mlu")
            slot_mapping[-1] = -1
            ref_key_cache, ref_value_cache = key_cache.clone(), value_cache.clone()
            if only_reshape_key:
                # Key-only mode drops the entire value side.
                value, ref_value_cache, value_cache = None, None, None
            self.op_impl_base(key, value, ref_key_cache, ref_value_cache, slot_mapping)
            ops.reshape_paged_cache(key, value, key_cache, value_cache, slot_mapping)
            # Pure copy: must match bit-exactly.
            self.assertTensorsEqual(key_cache.cpu().float(), ref_key_cache.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True)
            if not only_reshape_key:
                self.assertTensorsEqual(value_cache.cpu().float(), ref_value_cache.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True)
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "reshape_paged_cache not support MLU3XX device")
@unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
def test_large_tensor(self):
    """Check behavior at the 4 GiB cache-addressing boundary: a cache just
    under 2**32 bytes must work; one at/over it must raise."""
    dtype_list = [torch.half, torch.float]
    if torch_mlu.mlu.is_bf16_supported():
        dtype_list.append(torch.bfloat16)
    for dtype in dtype_list:
        print("reshape_paged_cache: test_large_tensor")
        # Element size in bytes drives the addressable block count.
        if dtype == torch.float:
            dtype_size = 4
        elif dtype == torch.half or dtype == torch.bfloat16:
            dtype_size = 2
        head_num = 16
        head_size = 128
        token_num = 20
        block_size = 16
        # Largest block count keeping the cache under 2**32 bytes.
        block_num = ((2**32 - 1) // dtype_size // head_num // head_size // block_size)
        k = torch.randn(token_num, head_num, head_size, dtype=dtype, device="mlu")
        v = torch.randn(token_num, head_num, head_size, dtype=dtype, device="mlu")
        k_cache = torch.randn(block_num, head_num, block_size, head_size, dtype=dtype, device="mlu")
        v_cache = torch.randn(block_num, head_num, block_size, head_size, dtype=dtype, device="mlu")
        num_slots = block_num * block_size
        slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_num]
        ref_key_cache, ref_value_cache = k_cache.clone(), v_cache.clone()
        self.op_impl_base(k, v, ref_key_cache, ref_value_cache, slot_mapping)
        ops.reshape_paged_cache(k, v, k_cache, v_cache, slot_mapping)
        # Compare slice-by-slice over the block dim to limit host memory use.
        for i in range(block_size):
            self.assertTensorsEqual(k_cache[:, :, i, :].cpu().float(), ref_key_cache[:, :, i, :].cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True)
            self.assertTensorsEqual(v_cache[:, :, i, :].cpu().float(), ref_value_cache[:, :, i, :].cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True)
        # Now size the cache to reach 2**32 bytes: the op must refuse it.
        block_num = (2**32 // dtype_size // head_num // head_size // block_size)
        k_cache = torch.randn(block_num, head_num, block_size, head_size, dtype=dtype, device="mlu")
        v_cache = torch.randn(block_num, head_num, block_size, head_size, dtype=dtype, device="mlu")
        num_slots = block_num * block_size
        slot_mapping = torch.randperm(num_slots, dtype=torch.int32, device="mlu")[:token_num]
        self.assertException("The addressing range of kv_cache cannot exceed 4G.", ops.reshape_paged_cache,
                             k, v, k_cache, v_cache, slot_mapping)
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "reshape_paged_cache not support MLU3XX device")
@unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
def test_prevent(self):
    """Invalid-argument checks for reshape_paged_cache: mismatched v/v_cache
    optionality and non-contiguous value layouts must raise with the exact
    error messages asserted below."""
    k = torch.randn(1024, 8, 128, dtype=torch.half, device="mlu")
    v = torch.randn(1024, 8, 128, dtype=torch.half, device="mlu")
    k_cache = torch.randn(1024, 8, 4, 128, dtype=torch.half, device="mlu")
    v_cache = None
    slot_mapping = random.sample(range(1024 * 4), 1024)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="mlu")
    # v provided while v_cache is None -> optionality mismatch must raise.
    self.assertException("v.has_value() == v_cache.has_value().", ops.reshape_paged_cache,
                         k, v, k_cache, v_cache, slot_mapping)
    v_cache = torch.randn(1024, 8, 4, 128, dtype=torch.half, device="mlu")
    # Break v_cache contiguity with a bogus stride (127 instead of 128) on dim 2.
    v_cache = v_cache.as_strided(size=(1024, 8, 4, 128), stride=(8*4*128, 4* 128, 127, 1))
    self.assertException("v_cache need be contiguous.", ops.reshape_paged_cache,
                         k, v, k_cache, v_cache, slot_mapping)
    v_cache = v_cache.contiguous()
    # v with a non-unit stride on its last dim.
    v = v.as_strided(size=(1024, 8, 128), stride=(1, 1024, 1024 * 8))
    self.assertException("v last dim must be contiguous.", ops.reshape_paged_cache,
                         k, v, k_cache, v_cache, slot_mapping)
    # v with a contiguous last dim but a wrong stride on the second dim
    # (8 instead of the contiguous 128).
    v = v.as_strided(size=(1024, 8, 128), stride=(1024, 8, 1))
    self.assertException("v second dim must be contiguous.", ops.reshape_paged_cache,
                         k, v, k_cache, v_cache, slot_mapping)
@unittest.skipIf('MLU3' in torch.mlu.get_device_name(), "reshape_paged_cache not support MLU3XX device")
def test_inductor(self):
    """Run the inductor opcheck for reshape_paged_cache on one fixed config."""
    batch_size = 4
    num_heads = 8
    head_size = 64
    cache_memory_len = 128
    dtype = torch.float16
    opcheck_args = gen_args(batch_size, num_heads, head_size, cache_memory_len, 1, dtype)
    self.base_opcheck(torch.ops.torch_mlu_ops.reshape_paged_cache, opcheck_args)
if __name__ == '__main__':
    # Propagate the unittest result code when run as a script.
    raise SystemExit(run_unittest(TestReshapePagedCacheOp))

View File

@@ -0,0 +1,158 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from torch.nn.parameter import Parameter
from torch.nn import functional as F
import torch.nn as nn
class SelfAttn(torch.nn.Module):
def __init__(
self,
qkv_weights,
qkv_biass,
o_weight,
o_bias,
norm_weight,
norm_bias,
input_size,
head_size,
query_factor,
eps = 1e-5
) -> None:
super().__init__()
assert len(qkv_weights) == 3 and len(qkv_biass) == 3, 'length of weights and biass must be 3'
self.layernorm = torch.nn.LayerNorm(input_size)
self.layernorm.eps = eps
self.layernorm.weight = nn.Parameter(norm_weight)
self.layernorm.bias = nn.Parameter(norm_bias)
self.weights = qkv_weights
self.biass = qkv_biass
self.o_weight = o_weight
self.o_bias = o_bias
self.head_size = head_size
self.head_num = self.weights[0].size(0) // head_size
self.query_factor = query_factor
def forward(self, input: torch.Tensor):
n = input.size(0)
t = input.size(1)
normed_input = self.layernorm(input)
q = F.linear(normed_input, nn.Parameter(self.weights[0]), nn.Parameter(self.biass[0])).view(n, t, self.head_num, self.head_size)
k = F.linear(normed_input, nn.Parameter(self.weights[1]), nn.Parameter(self.biass[1])).view(n, t, self.head_num, self.head_size)
v = F.linear(normed_input, nn.Parameter(self.weights[2]), nn.Parameter(self.biass[2])).view(n, t, self.head_num, self.head_size)
qk = torch.einsum('bthd,bshd->bhts', q, k) * self.query_factor
attn = torch.softmax(qk, dim=-1, dtype=v.dtype)
qkv = torch.einsum('bhts,bshd->bthd', attn, v).reshape(n, t, -1)
output = F.linear(qkv, nn.Parameter(self.o_weight), nn.Parameter(self.o_bias)) + input
return output
class BTSelfAttn(torch.nn.Module):
    """Fused self-attention block built from torch_mlu_ops kernels; the
    device-side counterpart of the eager SelfAttn reference module."""
    def __init__(
        self,
        qkv_weights,
        qkv_biass,
        o_weight,
        o_bias,
        norm_weight,
        norm_bias,
        head_size,
        query_factor,
        eps = 1e-5
    ) -> None:
        super().__init__()
        self.weights = qkv_weights
        self.biass = qkv_biass
        self.o_weight = o_weight
        self.o_bias = o_bias
        self.norm_weight = norm_weight
        self.norm_bias = norm_bias
        self.head_size = head_size
        self.query_factor = query_factor
        self.eps = eps

    def forward(self, input: torch.Tensor):
        seq_len = input.size(1)
        # Fused LayerNorm + Q/K/V projection ("nhtc" layout, split per head).
        query, key, value = ops.fused_norm_attention_project(
            input,
            self.weights[0], self.biass[0],
            self.weights[1], self.biass[1],
            self.weights[2], self.biass[2],
            self.norm_weight, self.norm_bias,
            self.eps, "nhtc", self.head_size, False)
        # Swap dims 1 and 2 on each projection (layout expected by
        # flash_attention — presumably head/seq transposed; see op docs).
        query, key, value = (x.transpose(1, 2) for x in (query, key, value))
        attn_out = ops.flash_attention(query, key, value, None, None, None, None, None,
                                       seq_len, seq_len, self.query_factor, False,
                                       -1, -1, query.dtype).flatten(-2, -1)
        # Output projection fused with the residual add of the block input.
        return ops.attention_project(attn_out, self.o_weight, self.o_bias, input)
# Dtypes exercised by the tests below; bfloat16 only on devices that support it.
dtype_list = [torch.half, torch.float]
if torch_mlu.mlu.is_bf16_supported():
    dtype_list += [torch.bfloat16]
class TestSelfAttn(BtTestCase):
    """Compares the fused torch_mlu_ops self-attention path (BTSelfAttn)
    against the eager PyTorch reference (SelfAttn)."""

    def op_impl_base(self, *args):
        # Delegate to the shared baseline implementation in the base class.
        return super().op_impl_base(*args)

    def test_self_attn(self):
        """End-to-end self-attention block: fused ops vs. eager reference."""
        N, T, input_size, hidden_size, head_size, eps, query_factor = 5, 2048, 512, 768, 64, 1e-5, 0.125
        for dtype in dtype_list:
            # Fix: the original format string had only four placeholders, so
            # the trailing dtype argument was silently dropped from the log.
            print("N: {}, T: {}, input_size: {}, hidden_size: {}, dtype: {}, testing...".format(
                N, T, input_size, hidden_size, dtype), flush=True)
            input = torch.randn(N, T, input_size, dtype=dtype, device="mlu")
            # Scale the random weights down to keep activations numerically tame.
            weight = torch.randn(hidden_size * 3, input_size, dtype=dtype, device="mlu") / 10
            bias = torch.randn(hidden_size * 3, dtype=dtype, device="mlu")
            norm_weight = torch.randn(input_size, dtype=dtype, device="mlu")
            norm_bias = torch.randn(input_size, dtype=dtype, device="mlu")
            # Split the stacked QKV weight/bias into the three projections.
            weights = torch.chunk(weight, 3)
            biass = torch.chunk(bias, 3)
            o_weight = torch.randn(input_size, hidden_size, dtype=dtype, device="mlu") / 10
            o_bias = torch.randn(input_size, dtype=dtype, device="mlu")
            torch_self_attn = SelfAttn(weights, biass, o_weight, o_bias, norm_weight, norm_bias,
                                       input_size, head_size, query_factor, eps)
            tmo_self_attn = BTSelfAttn(weights, biass, o_weight, o_bias, norm_weight, norm_bias,
                                       head_size, query_factor, eps)
            # test self_attn
            print("test self_attn...")
            torch_out = torch_self_attn(input)
            tmo_out = tmo_self_attn(input)
            self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
                                    0.011, use_MSE=True, use_RAE=True)

    def test_sd_attention(self):
        """Cross-attention shapes: long query sequence (4096) against a short
        key/value sequence (77), eager einsum reference vs. flash_attention."""
        N, Tq, Tk, hq, hk, head_size, query_factor = 2, 4096, 77, 8, 8, 40, 0.125
        for dtype in dtype_list:
            q = torch.randn(N, Tq, hq, head_size, dtype=dtype, device="mlu")
            k = torch.randn(N, Tk, hk, head_size, dtype=dtype, device="mlu")
            v = torch.randn(N, Tk, hk, head_size, dtype=dtype, device="mlu")
            # Eager reference attention.
            qk = torch.einsum('bthd,bshd->bhts', q, k) * query_factor
            attn = torch.softmax(qk, dim=-1, dtype=v.dtype)
            torch_out = torch.einsum('bhts,bshd->bthd', attn, v).reshape(N, Tq, -1)
            # Feed flash_attention transposed views of contiguous tensors
            # (non-contiguous q/k/v inputs with a contiguous last dim).
            qt = q.transpose(1, 2).contiguous()
            kt = k.transpose(1, 2).contiguous()
            vt = v.transpose(1, 2).contiguous()
            tmo_out = ops.flash_attention(qt.transpose(1, 2),
                                          kt.transpose(1, 2),
                                          vt.transpose(1, 2),
                                          None, None, None, None, None,
                                          Tq, Tk, query_factor, False).flatten(-2, -1)
            self.assertTensorsEqual(torch_out.cpu().float(), tmo_out.cpu().float(),
                                    0.004, use_MSE=True, use_RAE=True)

    def test_inductor(self):
        # Inductor opcheck is provided by the base class.
        return super().test_inductor()
if __name__ == '__main__':
    # Propagate the unittest result code when run as a script.
    raise SystemExit(run_unittest(TestSelfAttn))

View File

@@ -0,0 +1,348 @@
import math
import random
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from itertools import product
class TestSessionCacheAttnOp(BtTestCase):
    """Session-cache attention tests: dequantize int4/int8 KV caches into a
    linear (varlen) cache that already holds the current-session K/V, then run
    flash_attention over the concatenated cache, comparing both the dequantized
    cache contents and the attention output against an eager float baseline."""
    def op_impl_base(self, *args):
        """Eager float baseline.

        Dequantizes up to two quantized KV caches, concatenates them with the
        session K/V into a varlen layout, and runs a reference scaled-dot-
        product attention. Returns (attn_out, concatenated_key_cache,
        concatenated_value_cache)."""
        def dequant_from_cache(key_cache, value_cache, key_cache_scale, value_cache_scale, cache_bs_id, cache_lens,
                               cache_seq_offset, quant_bit, quant_layout, max_cache_len):
            # Dequantize per-batch cache slices from
            # [max_batch, head_num_kv, cache_mem_len, head_size] (int8; for
            # int4 two values are packed per byte) into float
            # [batch, max_cache_len, head_num_kv, head_size].
            batch = cache_bs_id.size(0)
            head_num = key_cache.size(1)
            head_size = value_cache.size(-1)
            cache_len = key_cache.size(-2)
            cache_shape = (batch, max_cache_len, head_num, head_size)
            key_cache_mem = torch.zeros(cache_shape, dtype=torch.float, device='mlu')
            value_cache_mem = torch.zeros_like(key_cache_mem)
            for i in range(batch):
                batch_id = cache_bs_id[i]
                cache_offset = 0 if cache_seq_offset is None else cache_seq_offset[i]
                key_cache_data = key_cache[batch_id] # head_num_kv, cache_mem_len, head_size
                value_cache_data = value_cache[batch_id]
                if quant_bit == 4:
                    # int4 unpacking: << 4 >> 4 sign-extends the low nibble;
                    # >> 4 takes the high nibble. Keys are packed along
                    # head_size, values along the sequence dim.
                    key_quant_data_int8 = torch.zeros(head_num, cache_len, head_size, dtype=torch.int8, device="mlu")
                    value_quant_data_int8 = torch.zeros(head_num, cache_len, head_size, dtype=torch.int8, device="mlu")
                    key_quant_data_int8[...,::2] = key_cache_data << 4 >> 4
                    key_quant_data_int8[...,1::2] = key_cache_data >> 4
                    key_quant_data_fp32 = key_quant_data_int8.clone().to(torch.float)
                    value_quant_data_int8[:,0::2,:] = value_cache_data << 4 >> 4
                    value_quant_data_int8[:,1::2,:] = value_cache_data >> 4
                    value_quant_data_fp32 = value_quant_data_int8.clone().to(torch.float)
                else:
                    key_quant_data_fp32 = key_cache_data.clone().to(torch.float)
                    value_quant_data_fp32 = value_cache_data.clone().to(torch.float)
                if quant_layout == 'per_channel':
                    # Scale broadcasts over the sequence dimension.
                    key_quant_data_fp32 = key_quant_data_fp32 * key_cache_scale[..., None, :]
                    value_quant_data_fp32 = value_quant_data_fp32 * value_cache_scale[..., None, :]
                else: # per token
                    # Scale broadcasts over the head_size dimension.
                    key_cache_scale_data = key_cache_scale[batch_id]
                    value_cache_scale_data = value_cache_scale[batch_id]
                    key_quant_data_fp32 = key_quant_data_fp32 * key_cache_scale_data[..., None]
                    value_quant_data_fp32 = value_quant_data_fp32 * value_cache_scale_data[..., None]
                key_quant_data_fp32 = key_quant_data_fp32[:, cache_offset:cache_offset+max_cache_len, :] #cut vaild cache
                value_quant_data_fp32 = value_quant_data_fp32[:, cache_offset:cache_offset+max_cache_len, :]
                # head_num_kv, max_cache_len, headsize-> max_cache_len, head_num_kv, headsize
                key_cache_mem[i] = key_quant_data_fp32.transpose(0,1)
                value_cache_mem[i] = value_quant_data_fp32.transpose(0,1)
            return key_cache_mem, value_cache_mem
        def scale_dot_attn(q_sess, key_cache, value_cache, cu_seq_lens_q, cu_seq_lens_cache, alibi_slope, is_causal, softmax_scale):
            # Reference varlen attention over the concatenated cache, with GQA
            # (query heads grouped over KV heads) and optional ALiBi bias.
            # cache_shape: [batch, cache_len, head_kv,head_size]
            batch = cu_seq_lens_q.size(0) - 1
            head_num_q = q_sess.size(-2)
            head_num_kv = key_cache.size(-2)
            assert head_num_q >= head_num_kv and head_num_q % head_num_kv == 0
            group = head_num_q // head_num_kv
            # NOTE(review): named "inf" but actually a large finite mask value.
            inf = 1e6
            device= 'mlu'
            out_list = []
            for i in range(batch):
                q = q_sess[cu_seq_lens_q[i]:cu_seq_lens_q[i+1], ...]
                k = key_cache[cu_seq_lens_cache[i]:cu_seq_lens_cache[i+1], ...] # [cache_len, head_num_kv, head_size]
                v = value_cache[cu_seq_lens_cache[i]:cu_seq_lens_cache[i+1], ...]
                k = torch.repeat_interleave(k, group, dim=1) #[cache_len, head_num_q, head_size]
                v = torch.repeat_interleave(v, group, dim=1)
                qk = torch.einsum('qhd,khd->hqk', q, k) * softmax_scale
                seq_q, seq_k = q.size(0), k.size(0)
                if alibi_slope is not None:
                    slope = alibi_slope.reshape(1, head_num_q, 1, 1)
                    slope_bias = torch.zeros(1, head_num_q, seq_q, seq_k).to(device=device)
                    if is_causal:
                        relative_pos = torch.arange(-seq_k + 1, 1, dtype=torch.float32).to(device=device)
                        slope_bias = relative_pos * slope
                    else:
                        row_idx = torch.arange(seq_q, dtype=torch.long).reshape(-1, 1)
                        col_idx = torch.arange(seq_k, dtype=torch.long)
                        relative_pos = torch.abs(row_idx + seq_k - seq_q - col_idx).to(device=device)
                        slope_bias = -slope * relative_pos.to(dtype=slope.dtype)
                    qk += (slope_bias.squeeze(0))
                if is_causal:
                    assert seq_q <= seq_k, "seq_q <= seq_k if causal=True"
                    # Causal mask aligned to the right edge of the key sequence.
                    zeros = torch.zeros(seq_q, seq_k-seq_q, dtype=torch.float, device="mlu")
                    tri = torch.full((seq_q, seq_q), -inf, dtype=torch.float, device="mlu").triu(diagonal=1)
                    mask = torch.cat([zeros, tri], dim=1) # (q, k-q) + (q, q) => (q, k)
                    qk += mask
                attn = torch.softmax(qk, dim=-1, dtype=torch.float).to(q.dtype)
                qkv = torch.einsum('hqk,khd->qhd', attn, v)
                out_list.append(qkv)
            output = torch.cat(out_list, dim=0)
            return output
        # Unpack the 26-element argument tuple; the second cache (suffix 2)
        # may be None for the single-cache case.
        q_sess, k_sess, v_sess, key_cache1, value_cache1, key_cache_scale1, value_cache_scale1, cache_lens1,\
        cache_seq_offset1, quant_bit1, quant_layout1, max_cache_len1,\
        key_cache2, value_cache2, key_cache_scale2, value_cache_scale2, cache_lens2,\
        cache_seq_offset2, quant_bit2, quant_layout2, max_cache_len2,\
        sess_lens, cache_bs_id, cu_seq_lens_q, is_causal, softmax_scale = args
        #1. dequant cache
        #input cache shape [max_batch, head_num_kv, cache_mem_len, head_size]
        #output cache shape [batch, max_cache_len, head_num_kv, head_size]
        key_cache_mem1, value_cache_mem1 = dequant_from_cache(key_cache1, value_cache1, key_cache_scale1,
                                                              value_cache_scale1, cache_bs_id, cache_lens1, cache_seq_offset1,
                                                              quant_bit1, quant_layout1, max_cache_len1)
        key_cache_mem2, value_cache_mem2 = None, None
        if key_cache2 is not None:
            key_cache_mem2, value_cache_mem2 = dequant_from_cache(key_cache2, value_cache2, key_cache_scale2,
                                                                  value_cache_scale2, cache_bs_id, cache_lens2, cache_seq_offset2,
                                                                  quant_bit2, quant_layout2, max_cache_len2)
        #2. concat cache
        batch = cache_bs_id.size(0)
        head_num_kv = k_sess.size(-2)
        head_size = k_sess.size(-1)
        # concat cache1 and cache2
        if key_cache2 is not None:
            cache_lens = cache_lens1 + cache_lens2
            max_cache_len = max_cache_len1 + max_cache_len2
            key_cache_mem = torch.zeros(batch, max_cache_len, head_num_kv, head_size, dtype=torch.float, device='mlu')
            value_cache_mem = torch.zeros(batch, max_cache_len, head_num_kv, head_size, dtype=torch.float, device='mlu')
            for i in range(batch): #concat
                len1 = cache_lens1[i]
                len2 = cache_lens2[i]
                key_cache_mem[i, :len1, ...] = key_cache_mem1[i, :len1, ...]
                key_cache_mem[i, len1:len1+len2, ...] = key_cache_mem2[i, :len2, ...]
                value_cache_mem[i, :len1, ...] = value_cache_mem1[i, :len1, ...]
                value_cache_mem[i, len1:len1+len2, ...] = value_cache_mem2[i, :len2, ...]
        else:
            key_cache_mem, value_cache_mem = key_cache_mem1, value_cache_mem1
            cache_lens = cache_lens1
        #concat mem cache and sess cache
        concat_cache_lens = cache_lens + sess_lens
        cu_seq_lens_cache = torch.zeros((batch+1), dtype=torch.int32)
        cu_seq_lens_cache[1:] = torch.cumsum(concat_cache_lens, dim=-1)
        total_cache_len = cu_seq_lens_cache[-1]
        concate_key_cache = torch.zeros((total_cache_len, head_num_kv, head_size), dtype=torch.float, device='mlu')
        concate_value_cache = torch.zeros((total_cache_len, head_num_kv, head_size), dtype=torch.float, device='mlu')
        for i in range(batch): #concat
            mem_len = cache_lens[i]
            sess_len = sess_lens[i]
            off = cu_seq_lens_cache[i]
            concate_key_cache[off:off+mem_len, ...] = key_cache_mem[i, :mem_len, ...]
            concate_key_cache[off+mem_len:off+mem_len+sess_len, ...] = k_sess[i, :sess_len, ...]
            concate_value_cache[off:off+mem_len, ...] = value_cache_mem[i, :mem_len, ...]
            concate_value_cache[off+mem_len:off+mem_len+sess_len, ...] = v_sess[i, :sess_len, ...]
        #3. attn
        attn_out = scale_dot_attn(q_sess, concate_key_cache, concate_value_cache, cu_seq_lens_q, cu_seq_lens_cache, None, is_causal, softmax_scale)
        return attn_out, concate_key_cache, concate_value_cache
    def test_session_cache_attn_quant_kv(self):
        """Single quantized cache: sweep int4/int8 x per_token/per_channel x
        causal/non-causal x dtypes against the eager baseline."""
        max_batch = 64
        cache_mem_len = 4096
        head_num_kv = 1
        head_num_q = 8
        head_size = 128
        batch = 32
        max_sess_len = 100
        max_cache_len = 3072
        dtype_list=[torch.float16]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        quant_bit_list = [4, 8]
        quant_layout_list = ['per_token', 'per_channel']
        is_causal_list = [True, False]
        arg = product(is_causal_list, dtype_list, quant_bit_list, quant_layout_list)
        softmax_scale = 1 / math.sqrt(head_size)
        for is_causal, dtype, quant_bit, quant_layout in arg:
            print(f"is_causal: {is_causal}, dtype: {dtype}, quant_bit: {quant_bit}, quant_layout: {quant_layout}")
            if quant_bit == 8:
                cache_shape_key = (max_batch, head_num_kv, cache_mem_len, head_size)
                cache_shape_value = cache_shape_key
            else:
                # int4 packs two values per int8 byte: keys along head_size,
                # values along the sequence dim.
                cache_shape_key = (max_batch, head_num_kv, cache_mem_len, head_size//2)
                cache_shape_value = (max_batch, head_num_kv, cache_mem_len//2, head_size)
            if quant_layout == "per_token":
                scale_shape = (max_batch, head_num_kv, cache_mem_len)
                quant_mode = 1
            else: # per channel
                scale_shape = (head_num_kv, head_size)
                quant_mode = 0
            key_cache = torch.randint(-127, 128, cache_shape_key, device='mlu').to(torch.int8)
            value_cache = torch.randint(-127, 128, cache_shape_value, device='mlu').to(torch.int8)
            key_cache_scale = torch.randn(scale_shape, device='mlu', dtype=torch.float)
            value_cache_scale = torch.randn(scale_shape, device='mlu', dtype=torch.float)
            # Constant small scales keep the dequantized magnitudes bounded.
            key_cache_scale = torch.fill(key_cache_scale, 0.01)
            value_cache_scale = torch.fill(value_cache_scale, 0.01)
            cache_lens = torch.randint(1, max_cache_len + 1, (batch,), dtype=torch.int32, device='mlu')
            sess_lens = torch.randint(1, max_sess_len + 1, (batch,), dtype=torch.int32, device='mlu')
            context_lens = cache_lens + sess_lens
            max_cache_len_new = torch.max(context_lens)
            cu_seq_lens_q = torch.zeros(batch+1, dtype=torch.int32)
            cu_seq_lens_q[1:] = torch.cumsum(sess_lens, dim=-1)
            total_sess_len = cu_seq_lens_q[-1]
            cu_seq_lens_q=cu_seq_lens_q.mlu()
            # Each test batch entry maps to a distinct slot of the cache pool.
            cache_bs_id = random.sample([*range(0, max_batch)], batch)
            cache_bs_id = torch.IntTensor(cache_bs_id).mlu() #block_tables
            mem_cache_seq_offset = torch.randint(0, cache_mem_len-max_cache_len, (batch,), dtype=torch.int32, device='mlu')
            q_sess = torch.randn(total_sess_len, head_num_q, head_size, dtype=dtype, device='mlu')
            k_sess = torch.randn(batch, max_sess_len, head_num_kv, head_size, dtype=dtype, device='mlu')
            v_sess = torch.randn(batch, max_sess_len, head_num_kv, head_size, dtype=dtype, device='mlu')
            #fill sess cache
            cache_seq_offset = torch.zeros((batch + 1), dtype=torch.int32)
            cache_seq_offset[1:]= torch.cumsum(context_lens, dim=-1)
            cache_seq_offset = cache_seq_offset.mlu()
            context_seq_offset = cache_seq_offset[:-1]
            sess_seq_offset = context_seq_offset + cache_lens
            total_mem_len = cache_seq_offset[-1]
            cache_shape_mem = (total_mem_len, head_num_kv, head_size)#NT,H,C
            key_cache_mem = torch.zeros(cache_shape_mem, dtype= dtype, device='mlu')
            value_cache_mem = torch.zeros(cache_shape_mem, dtype= dtype, device='mlu')
            for i in range(batch):
                offset = sess_seq_offset[i]
                # NOTE(review): "len" shadows the builtin within this loop.
                len = sess_lens[i]
                key_cache_mem[offset:offset+len, ...] = k_sess[i, :len, ...]
                value_cache_mem[offset:offset+len, ...] = v_sess[i, :len, ...]
            # baseline
            base_output, base_key_cache, base_value_cache = self.op_impl_base(q_sess.float(), k_sess.float(), v_sess.float(),
                key_cache, value_cache, key_cache_scale, value_cache_scale, cache_lens, mem_cache_seq_offset,
                quant_bit, quant_layout, max_cache_len,
                None, None, None, None, None, None, -1, None, -1,
                sess_lens, cache_bs_id, cu_seq_lens_q, is_causal, softmax_scale)
            #tmo
            #1. dequant cache
            ops.dequant_from_linear_cache(key_cache_mem, value_cache_mem, key_cache, value_cache, key_cache_scale,
                                          value_cache_scale, cache_lens, max_cache_len, context_seq_offset,
                                          cache_bs_id, mem_cache_seq_offset, quant_mode, quant_bit)
            self.assertTensorsEqual(base_key_cache.cpu().float(), key_cache_mem.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(base_value_cache.cpu().float(), value_cache_mem.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
            #2. flash_attn
            tmo_output = ops.flash_attention(q_sess, key_cache_mem, value_cache_mem, None, cu_seq_lens_q, cache_seq_offset,
                                             None, None, max_sess_len, max_cache_len_new, softmax_scale,
                                             is_causal, -1, -1, torch.float, False, None, None, None)
            self.assertTensorsEqual(base_output.cpu().float(), tmo_output.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
    def test_session_cache_attn_kv_mixquant(self):
        """Two quantized caches combined: a per-token int4 cache followed by a
        per-channel int8 cache, dequantized into one linear cache and compared
        against the eager baseline."""
        max_batch = 64
        cache_mem_len_int4 = 4064
        cache_mem_len_int8 = 32
        head_num_kv = 1
        head_num_q = 8
        head_size = 128
        batch = 32
        max_sess_len = 100
        max_cache_len_int4 = 3072
        max_cache_len_int8 = 32
        dtype_list=[torch.float16]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        is_causal_list = [True, False]
        arg = product(is_causal_list, dtype_list)
        softmax_scale = 1 / math.sqrt(head_size)
        for is_causal, dtype in arg:
            print(f"is_causal: {is_causal}, dtype: {dtype}")
            cache_shape_int8 = (max_batch, head_num_kv, cache_mem_len_int8, head_size)
            cache_shape_int4_key = (max_batch, head_num_kv, cache_mem_len_int4, head_size//2)
            cache_shape_int4_value = (max_batch, head_num_kv, cache_mem_len_int4//2, head_size)
            scale_shape_int4 = (max_batch, head_num_kv, cache_mem_len_int4) #per_token
            quant_mode_int4 = 1
            scale_shape_int8 = (head_num_kv, head_size) # per channel
            quant_mode_int8 = 0
            key_cache_int8 = torch.randint(-127, 128, cache_shape_int8, device='mlu').to(torch.int8)
            value_cache_int8 = torch.randint(-127, 128, cache_shape_int8, device='mlu').to(torch.int8)
            key_cache_int4 = torch.randint(-127, 128, cache_shape_int4_key, device='mlu').to(torch.int8)
            value_cache_int4 = torch.randint(-127, 128, cache_shape_int4_value, device='mlu').to(torch.int8)
            key_cache_scale_int4 = torch.randn(scale_shape_int4, device='mlu', dtype=torch.float)
            value_cache_scale_int4 = torch.randn(scale_shape_int4, device='mlu', dtype=torch.float)
            key_cache_scale_int8 = torch.randn(scale_shape_int8, device='mlu', dtype=torch.float)
            value_cache_scale_int8 = torch.randn(scale_shape_int8, device='mlu', dtype=torch.float)
            # Constant small scales keep the dequantized magnitudes bounded.
            key_cache_scale_int4 = torch.fill(key_cache_scale_int4, 0.01)
            value_cache_scale_int4 = torch.fill(value_cache_scale_int4, 0.01)
            key_cache_scale_int8 = torch.fill(key_cache_scale_int8, 0.01)
            value_cache_scale_int8 = torch.fill(value_cache_scale_int8, 0.01)
            cache_lens_int4 = torch.randint(1, max_cache_len_int4 + 1, (batch,), dtype=torch.int32, device='mlu')
            cache_lens_int8 = torch.randint(1, max_cache_len_int8 + 1, (batch,), dtype=torch.int32, device='mlu')
            sess_lens = torch.randint(max_sess_len-1, max_sess_len, (batch,), dtype=torch.int32, device='mlu')
            cu_seq_lens_q = torch.zeros(batch+1, dtype=torch.int32)
            cu_seq_lens_q[1:] = torch.cumsum(sess_lens, dim=-1)
            total_sess_len = cu_seq_lens_q[-1]
            cu_seq_lens_q=cu_seq_lens_q.mlu()
            context_lens = cache_lens_int4 + cache_lens_int8 + sess_lens
            max_cache_len_new = torch.max(context_lens)
            cache_bs_id = random.sample([*range(0, max_batch)], batch)
            cache_bs_id = torch.IntTensor(cache_bs_id).mlu() #block_tables
            cache_seq_offset_int4 = torch.randint(0, cache_mem_len_int4-max_cache_len_int4, (batch,), dtype=torch.int32, device='mlu')
            cache_seq_offset_int8 = None
            q_sess = torch.randn(total_sess_len, head_num_q, head_size, dtype=dtype, device='mlu')
            k_sess = torch.randn(batch, max_sess_len, head_num_kv, head_size, dtype=dtype, device='mlu')
            v_sess = torch.randn(batch, max_sess_len, head_num_kv, head_size, dtype=dtype, device='mlu')
            #fill sess cache
            # Offsets: cache1 starts each context; cache2 follows the int4
            # span; the session K/V follows the int8 span.
            cache_seq_offset1 = torch.zeros((batch + 1), dtype=torch.int32)
            cache_seq_offset2 = torch.zeros((batch + 1), dtype=torch.int32)
            cache_seq_offset1[1:]= torch.cumsum(context_lens, dim=-1)
            cache_seq_offset2[:-1] = cache_seq_offset1[:-1] + cache_lens_int4.cpu()
            cache_seq_offset1 = cache_seq_offset1.mlu()
            cache_seq_offset2 = cache_seq_offset2.mlu()
            context_seq_offset1 = cache_seq_offset1[:-1]
            context_seq_offset2 = cache_seq_offset2[:-1]
            sess_seq_offset = context_seq_offset2 + cache_lens_int8
            total_mem_len = cache_seq_offset1[-1]
            cache_shape_mem = (total_mem_len, head_num_kv, head_size)#NT,H,C
            key_cache_mem = torch.zeros(cache_shape_mem, dtype= dtype, device='mlu')
            value_cache_mem = torch.zeros(cache_shape_mem, dtype= dtype, device='mlu')
            for i in range(batch):
                offset = sess_seq_offset[i]
                # NOTE(review): "len" shadows the builtin within this loop.
                len = sess_lens[i]
                key_cache_mem[offset:offset+len, ...] = k_sess[i, :len, ...]
                value_cache_mem[offset:offset+len, ...] = v_sess[i, :len, ...]
            # baseline
            base_output, base_key_cache, base_value_cache = self.op_impl_base(q_sess.float(), k_sess.float(), v_sess.float(),
                key_cache_int4, value_cache_int4, key_cache_scale_int4, value_cache_scale_int4, cache_lens_int4,
                cache_seq_offset_int4, 4, 'per_token', max_cache_len_int4,
                key_cache_int8, value_cache_int8, key_cache_scale_int8, value_cache_scale_int8, cache_lens_int8,
                cache_seq_offset_int8, 8, 'per_channel', max_cache_len_int8,
                sess_lens, cache_bs_id, cu_seq_lens_q, is_causal, softmax_scale)
            #tmo
            #1. dequant cache
            ops.dequant_from_linear_cache(key_cache_mem, value_cache_mem, key_cache_int4, value_cache_int4,
                                          key_cache_scale_int4, value_cache_scale_int4, cache_lens_int4, max_cache_len_int4,
                                          context_seq_offset1, cache_bs_id, cache_seq_offset_int4, quant_mode_int4, 4)
            ops.dequant_from_linear_cache(key_cache_mem, value_cache_mem, key_cache_int8, value_cache_int8,
                                          key_cache_scale_int8, value_cache_scale_int8, cache_lens_int8, max_cache_len_int8,
                                          context_seq_offset2, cache_bs_id, cache_seq_offset_int8, quant_mode_int8, 8)
            self.assertTensorsEqual(base_key_cache.cpu().float(), key_cache_mem.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(base_value_cache.cpu().float(), value_cache_mem.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
            #2. flash_attn
            tmo_output = ops.flash_attention(q_sess, key_cache_mem, value_cache_mem, None, cu_seq_lens_q, cache_seq_offset1,
                                             None, None, max_sess_len, max_cache_len_new, softmax_scale,
                                             is_causal, -1, -1, torch.float, False, None, None, None)
            self.assertTensorsEqual(base_output.cpu().float(), tmo_output.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
    def test_inductor(self):
        # Inductor opcheck is provided by the base class.
        return super().test_inductor()
if __name__ == '__main__':
    # Propagate the unittest result code when run as a script.
    raise SystemExit(run_unittest(TestSessionCacheAttnOp))

View File

@@ -0,0 +1,454 @@
import math
import random
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from itertools import product
def gen_args(batch, head_size, is_pagedattn, has_alibi, kv_data_type, max_seqlen, seq_q, head_num, num_kv_heads):
    """Build the full positional-argument tuple consumed by
    single_query_cached_kv_attn / the opcheck harness: query, k/v caches
    (half, or int8 with per-token scales), block tables, context lengths,
    optional ALiBi slopes, and the scalar attention parameters."""
    qkv = torch.randn((batch, 3 * seq_q, 3 * head_num, head_size)).mlu().half()
    # Carve the query out of a larger tensor so it starts as a strided view.
    query = qkv[:,0:seq_q, 0:head_num,:].view(batch, seq_q, head_num, head_size)
    context_lens = torch.randint(seq_q, max_seqlen + 1, (batch, ), dtype=torch.int32).mlu()
    max_context_len = int(max(context_lens))
    # Paged attention uses fixed 16-token blocks; otherwise one block per sequence.
    block_size = 16 if is_pagedattn else max_seqlen
    num_blocks = batch * ((max_seqlen + block_size - 1) // block_size)
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
    scale_shape = (num_blocks, num_kv_heads, block_size)
    # Distinct random block id for every (sequence, slot) pair.
    table_ids = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
    block_tables = torch.tensor(table_ids, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
    if kv_data_type is torch.int8:
        key_cache = torch.zeros(cache_shape).uniform_(-128, 128).to(kv_data_type).mlu()
        value_cache = torch.zeros(cache_shape).uniform_(-128, 128).to(kv_data_type).mlu()
        key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
        value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
    else:
        key_cache = torch.randn(size=cache_shape, dtype=torch.float16).mlu()
        value_cache = torch.randn(size=cache_shape, dtype=torch.float16).mlu()
        key_cache_scale = None
        value_cache_scale = None
    alibi_slopes = None
    if has_alibi:
        alibi_slopes = torch.zeros((batch, head_num), dtype=torch.float32).mlu()
        alibi_slopes.uniform_(0, 0.125)
    softmax_scale = 1 / math.sqrt(head_size)
    return (query.contiguous(), key_cache, value_cache, torch.empty_like(query), block_tables, context_lens, None,
            key_cache_scale, value_cache_scale, alibi_slopes, max_context_len, -1, -1, softmax_scale, False, -1)
def gen_params(batch, head_size, head_size_v, is_pagedattn, has_alibi, dtype, kv_dtype, max_seqlen, seq_q, head_num, num_kv_heads):
    """Like gen_args but with an independent value head size and caller-chosen
    activation/cache dtypes. Returns the individual tensors rather than the
    packed op-argument tuple."""
    qkv = torch.randn((batch, 3 * seq_q, 3 * head_num, head_size), dtype = dtype).mlu()
    # Carve the query out of a larger tensor so it starts as a strided view.
    query = qkv[:,0:seq_q, 0:head_num,:].view(batch, seq_q, head_num, head_size)
    context_lens = torch.randint(seq_q, max_seqlen + 1, (batch, ), dtype=torch.int32).mlu()
    max_context_len = int(max(context_lens))
    # Paged attention uses fixed 16-token blocks; otherwise one block per sequence.
    block_size = 16 if is_pagedattn else max_seqlen
    num_blocks = batch * ((max_seqlen + block_size - 1) // block_size)
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    cache_shape_k = (num_blocks, num_kv_heads, block_size, head_size)
    cache_shape_v = (num_blocks, num_kv_heads, block_size, head_size_v)
    scale_shape = (num_blocks, num_kv_heads, block_size)
    # Distinct random block id for every (sequence, slot) pair.
    table_ids = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
    block_tables = torch.tensor(table_ids, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
    if kv_dtype is torch.int8:
        key_cache = torch.zeros(cache_shape_k).uniform_(-128, 127).to(torch.int8).mlu()
        value_cache = torch.zeros(cache_shape_v).uniform_(-128, 127).to(torch.int8).mlu()
        key_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
        value_cache_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
    else:
        key_cache = torch.randn(size=cache_shape_k, dtype=dtype).mlu()
        value_cache = torch.randn(size=cache_shape_v, dtype=dtype).mlu()
        key_cache_scale = None
        value_cache_scale = None
    alibi_slopes = None
    if has_alibi:
        alibi_slopes = torch.zeros((batch, head_num), dtype=torch.float32).mlu()
        alibi_slopes.uniform_(0, 0.125)
    return query, key_cache, value_cache, block_tables, context_lens, key_cache_scale, value_cache_scale, alibi_slopes, max_context_len
class TestSingleQueryAttnOp(BtTestCase):
def op_impl_base(self, *args):
    """Baseline: unpack the 15-element op argument tuple and run the eager
    single_query_cached_kv_attn reference implementation.

    NOTE(review): `out`, `max_contxt_len` (sic) and `kv_cache_quant_bit_size`
    are unpacked but intentionally not forwarded to the reference — the
    baseline allocates its own output and ignores the quantization bit size.
    """
    q, k_cache, v_cache, out, block_tables, context_lens, k_cache_quant_scale, v_cache_quant_scale, \
    alibi_slopes, max_contxt_len, windows_size_left, windows_size_right, softmax_scale, return_lse, \
    kv_cache_quant_bit_size = args
    base_output = single_query_cached_kv_attn(q, k_cache, v_cache, block_tables, context_lens, k_cache_quant_scale,
                                              v_cache_quant_scale, alibi_slopes, windows_size_left, windows_size_right, softmax_scale, return_lse)
    return base_output
def test_single_query_attention(self):
head_num = 16
batch_list = [5, 12]
num_kv_heads = 4
head_size_list = [(128, 128), (192, 384)]
seq_len_list = [512]
is_pagedattn_list = [False, True]
mlu_name = torch.mlu.get_device_name()
if "MLU3" in mlu_name:
print("pagedattn is not implement on mlu370, skip it")
is_pagedattn_list = [False]
has_alibi_list = [True, False]
seq_q_list = [1, 5]
window_size_list = [(-1, -1), (10, -1)]
data_type_list = [torch.float16, torch.float]
if torch_mlu.mlu.is_bf16_supported():
data_type_list.append(torch.bfloat16)
args = product(batch_list, head_size_list, is_pagedattn_list, has_alibi_list, seq_len_list, seq_q_list, window_size_list, data_type_list)
for batch, (head_size, head_size_v), is_pagedattn, has_alibi, max_seqlen, seq_q, (window_size_left, window_size_right), dtype in args:
print("batch: {}, max_seqlen: {}, head_size: {}, head_size_v: {}, is_pagedattn: {}, has_alibi {}, seq_q {}, window_size_left {},\
window_size_right {}, dtype {}, testing...".format(
batch, max_seqlen, head_size, head_size_v, is_pagedattn, has_alibi, seq_q, window_size_left, window_size_right, dtype))
# prepare input
params = gen_params(batch, head_size, head_size_v, is_pagedattn, has_alibi, dtype, dtype, max_seqlen, seq_q, head_num, num_kv_heads)
input_q, key_cache, value_cache, block_tables, context_lens, key_cache_scale, value_cache_scale, alibi_slopes, max_context_len = params
softmax_scale = 1 / math.sqrt(head_size)
torch_output = self.op_impl_base(input_q, key_cache, value_cache,
None, block_tables, context_lens, key_cache_scale,
value_cache_scale, alibi_slopes, max_context_len,
window_size_left, window_size_right, softmax_scale, False, -1)
tmo_output1 = ops.single_query_cached_kv_attn(input_q, key_cache, value_cache,
None, block_tables, context_lens, key_cache_scale,
value_cache_scale, alibi_slopes, max_context_len,
window_size_left, window_size_right, softmax_scale)
self.assertTensorsEqual(torch_output.cpu().float(), tmo_output1.cpu().float(),
0.003, use_MSE=True)
if seq_q == 1:
torch_output, torch_lse = self.op_impl_base(input_q, key_cache, value_cache,
None, block_tables, context_lens, key_cache_scale,
value_cache_scale, alibi_slopes, max_context_len,
window_size_left, window_size_right, softmax_scale, True, -1)
tmo_output, tmo_lse = ops.single_query_cached_kv_attn(input_q, key_cache, value_cache,
None, block_tables, context_lens, key_cache_scale,
value_cache_scale, alibi_slopes, max_context_len,
window_size_left, window_size_right, softmax_scale, True)
self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
0.003, use_MSE=True)
self.assertTensorsEqual(torch_lse.cpu().float(), tmo_lse.cpu().float(),
0.0001, use_MSE=True)
# @unittest.skip("not test")
def test_single_query_attention_quantize_kv(self):
head_num = 16
batch_list = [5, 12]
num_kv_heads = 4
head_size_list = [(128, 128), (16, 384)]
seq_len_list = [512]
is_pagedattn_list = [False, True]
mlu_name = torch.mlu.get_device_name()
if "MLU3" in mlu_name:
print("pagedattn is not implement on mlu370, skip it")
is_pagedattn_list = [False]
has_alibi_list = [True, False]
quant_mode_list = ['per_token', 'per_channel']
kv_data_type_list = [torch.int8]
seq_q_list = [1, 5]
window_size_list = [(-1, -1), (10, -1)]
args = product(batch_list, head_size_list, is_pagedattn_list, has_alibi_list, kv_data_type_list, quant_mode_list, seq_len_list, seq_q_list, window_size_list)
for batch, (head_size, head_size_v), is_pagedattn, has_alibi, kv_data_type, quant_mode, max_seqlen, seq_q, (window_size_left, window_size_right) in args:
print("batch: {}, max_seqlen: {}, head_size: {}, head_size_v: {}, is_pagedattn: {}, has_alibi {}, kv_datatype {}, \
quant_mode {}, seq_q {}, window_size_left {}, window_size_right {}, testing...".format(
batch, max_seqlen, head_size, head_size_v, is_pagedattn, has_alibi, kv_data_type, quant_mode, seq_q, window_size_left, window_size_right))
# prepare input
params = gen_params(batch, head_size, head_size_v, is_pagedattn, has_alibi, torch.float16, kv_data_type, max_seqlen, seq_q, head_num, num_kv_heads)
input_q, key_cache, value_cache, block_tables, context_lens, key_cache_scale, value_cache_scale, alibi_slopes, max_context_len = params
softmax_scale = 1 / math.sqrt(head_size)
torch_output_contiguous = self.op_impl_base(input_q.contiguous(),
key_cache, value_cache, None, block_tables, context_lens,
key_cache_scale, value_cache_scale, alibi_slopes,
max_context_len, window_size_left, window_size_right, softmax_scale, False, -1)
tmo_output_contiguous = ops.single_query_cached_kv_attn(input_q.contiguous(),
key_cache, value_cache, None, block_tables, context_lens,
key_cache_scale, value_cache_scale, alibi_slopes,
max_context_len, window_size_left, window_size_right, softmax_scale)
self.assertTensorsEqual(torch_output_contiguous.cpu().float(), tmo_output_contiguous.cpu().float(),
0.003, use_MSE=True)
if seq_q == 1:
torch_output, torch_lse = self.op_impl_base(input_q, key_cache, value_cache,
None, block_tables, context_lens, key_cache_scale, value_cache_scale,
alibi_slopes, max_context_len, window_size_left, window_size_right, softmax_scale, True, -1)
tmo_output, tmo_lse = ops.single_query_cached_kv_attn(input_q, key_cache, value_cache,
None, block_tables, context_lens, key_cache_scale, value_cache_scale,
alibi_slopes, max_context_len, window_size_left, window_size_right, softmax_scale, True)
self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
0.005, use_MSE=True)
self.assertTensorsEqual(torch_lse.cpu().float(), tmo_lse.cpu().float(), 0.0003, use_MSE=True)
def test_single_query_attention_int4_kv(self):
head_num = 16
batch_list = [5, 12]
num_kv_heads = 4
head_size_list = [(64, 128), (256, 128), (64, 384)]
seq_len_list = [512]
is_pagedattn_list = [False, True]
mlu_name = torch.mlu.get_device_name()
if "MLU3" in mlu_name:
print("pagedattn is not implement on mlu370, skip it")
is_pagedattn_list = [False]
has_alibi_list = [True, False]
quant_mode_list = ['per_token', 'per_channel', 'per_token_group']
data_type_list = [torch.float, torch.half]
if torch_mlu.mlu.is_bf16_supported():
data_type_list.append(torch.bfloat16)
kv_data_type = torch.int8
seq_q_list = [1, 5]
#int4 range
quant_bit = 4
int_max = float(2 ** (quant_bit - 1) - 1)
int_min = -float(2 ** (quant_bit - 1))
window_size_list = [(-1, -1), (20, -1)]
args = product(batch_list, head_size_list, is_pagedattn_list, has_alibi_list, quant_mode_list, seq_len_list, seq_q_list, window_size_list, data_type_list)
for batch, (head_size, head_size_v), is_pagedattn, has_alibi, quant_mode, max_seqlen, seq_q, (window_size_left, window_size_right), data_type in args:
print("kv4: batch: {}, head_size: {}, head_size_v: {}, is_pagedattn: {}, has_alibi {}, quant_mode {}, max_seqlen: {}, seq_q {}, \
window_size_left {}, window_size_right {}, data_type {}, testing...".format(
batch, head_size, head_size_v, is_pagedattn, has_alibi, quant_mode, max_seqlen, seq_q, window_size_left, window_size_right, data_type))
# prepare input
input_qkv = torch.randn((batch, seq_q, 3 * head_num, head_size), dtype=data_type).mlu()
input_q = input_qkv[..., 0:head_num,:]
context_lens = torch.randint(seq_q, max_seqlen + 1, (batch, ), dtype=torch.int32).mlu()
max_context_len = context_lens.max().item()
if max_seqlen % 2 == 1:
max_seqlen += 1
if is_pagedattn:
block_size = 16
else:
block_size = max_seqlen
num_blocks = batch * ((max_seqlen + block_size - 1) // block_size)
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
cache_shape_k = (num_blocks, num_kv_heads, block_size, head_size)
cache_shape_v = (num_blocks, num_kv_heads, block_size, head_size_v)
cache_shape_k_int4 = (num_blocks, num_kv_heads, block_size, int(head_size/2))
cache_shape_v_int4 = (num_blocks, num_kv_heads, int(block_size/2), head_size_v)
cache_shape_v_int4_tmp = (num_blocks, num_kv_heads, head_size_v, int(block_size/2))
if quant_mode == "per_channel":
scale_shape_k = (num_kv_heads, head_size)
scale_shape_v = (num_kv_heads, head_size_v)
elif quant_mode == "per_token":
scale_shape_k = (num_blocks, num_kv_heads, block_size)
scale_shape_v = (num_blocks, num_kv_heads, block_size)
elif quant_mode == "per_token_group":
scale_shape_k = (num_blocks, num_kv_heads, block_size, 1) #group_size = head_size
scale_shape_v = (num_blocks, num_kv_heads, block_size, 1)
block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
key_cache = torch.zeros(cache_shape_k).uniform_(int_min, int_max).to(kv_data_type).mlu()
value_cache = torch.zeros(cache_shape_v).uniform_(int_min, int_max).to(kv_data_type).mlu()
key_cache_scale = torch.randn(size=scale_shape_k, dtype=torch.float32).mlu()
value_cache_scale = torch.randn(size=scale_shape_v, dtype=torch.float32).mlu()
key_cache_view = key_cache.reshape(-1, head_size)
value_cache_view = value_cache.transpose(2, 3).reshape(-1, block_size)
key_cache_int4 = PairlyPackInt8(key_cache_view).view(cache_shape_k_int4)
value_cache_int4 = PairlyPackInt8(value_cache_view).view(cache_shape_v_int4_tmp).transpose(2,3).contiguous()
alibi_slopes = None
if has_alibi:
alibi_slopes = torch.zeros((batch, head_num), dtype=torch.float32).mlu()
alibi_slopes.uniform_(0, 0.125)
softmax_scale = 1 / math.sqrt(head_size)
torch_output = self.op_impl_base(input_q.contiguous(),
key_cache, value_cache,
None, block_tables, context_lens,
key_cache_scale, value_cache_scale,
alibi_slopes, max_context_len,
window_size_left, window_size_right, softmax_scale, False, 4)
tmo_output_contigous = ops.single_query_cached_kv_attn(input_q.contiguous(),
key_cache_int4, value_cache_int4,
None, block_tables, context_lens,
key_cache_scale, value_cache_scale,
alibi_slopes, max_context_len,
window_size_left, window_size_right, softmax_scale, False, 4)
self.assertTensorsEqual(torch_output.cpu().float(), tmo_output_contigous.cpu().float(),
0.003, use_MSE=True)
tmo_output_inplace = torch.empty((batch, seq_q, head_num, head_size_v), dtype=data_type, device="mlu")
ops.single_query_cached_kv_attn(input_q, key_cache_int4, value_cache_int4,
tmo_output_inplace, block_tables, context_lens,
key_cache_scale, value_cache_scale,
alibi_slopes, max_context_len,
window_size_left, window_size_right, softmax_scale, False, 4)
self.assertTensorsEqual(torch_output.cpu().float(), tmo_output_inplace.cpu().float(),
0.003, use_MSE=True)
if seq_q == 1:
torch_output, torch_lse = self.op_impl_base(input_q,
key_cache, value_cache,
None, block_tables, context_lens,
key_cache_scale, value_cache_scale,
alibi_slopes, max_context_len,
window_size_left, window_size_right, softmax_scale, True, 4)
tmo_output1, tmo_lse = ops.single_query_cached_kv_attn(input_q,
key_cache_int4, value_cache_int4,
None, block_tables, context_lens,
key_cache_scale, value_cache_scale,
alibi_slopes, max_context_len,
window_size_left, window_size_right, softmax_scale, True, 4)
self.assertTensorsEqual(torch_lse.cpu().float(), tmo_lse.cpu().float(),
0.003, use_MSE=True)
self.assertTensorsEqual(torch_output.cpu().float(), tmo_output1.cpu().float(),
0.003, use_MSE=True)
def test_single_query_attention_inplace(self):
head_num = 16
batch_list = [5]
num_kv_heads = 4
head_size_list = [64]
seq_len_list = [512]
is_pagedattn_list = [False, True]
mlu_name = torch.mlu.get_device_name()
if "MLU3" in mlu_name:
print("pagedattn is not implement on mlu370, skip it")
is_pagedattn_list = [False]
has_alibi_list = [True, False]
kv_data_type_list = [torch.int8, torch.float16]
seq_q_list = [1, 5]
window_size_list = [(-1, -1), (20, -1)]
args = product(batch_list, head_size_list, is_pagedattn_list, has_alibi_list, kv_data_type_list, seq_len_list, seq_q_list, window_size_list)
for batch, head_size, is_pagedattn, has_alibi, kv_data_type, max_seqlen, seq_q, (window_size_left, window_size_right) in args:
print("batch: {}, max_seqlen: {}, head_size: {}, is_pagedattn: {}, has_alibi {}, kv_datatype {}, seq_q {}, window_size_left {}, window_size_right {}, testing...".format(
batch, max_seqlen, head_size, is_pagedattn, has_alibi, kv_data_type, seq_q, window_size_left, window_size_right))
# prepare input
params = gen_params(batch, head_size, head_size, is_pagedattn, has_alibi, torch.float16, kv_data_type, max_seqlen, seq_q, head_num, num_kv_heads)
input_q, key_cache, value_cache, block_tables, context_lens, key_cache_scale, value_cache_scale, alibi_slopes, max_context_len = params
softmax_scale = 1 / math.sqrt(head_size)
tmo_output = torch.empty_like(input_q)
tmo_output_contigous = torch.empty_like(input_q)
torch_output = self.op_impl_base(input_q, key_cache, value_cache, None,
block_tables, context_lens, key_cache_scale,
value_cache_scale, alibi_slopes, max_context_len,
window_size_left, window_size_right, softmax_scale, False, -1)
ops.single_query_cached_kv_attn(input_q, key_cache, value_cache, tmo_output,
block_tables, context_lens, key_cache_scale,
value_cache_scale, alibi_slopes, max_context_len,
window_size_left, window_size_right, softmax_scale)
ops.single_query_cached_kv_attn(input_q.contiguous(), key_cache, value_cache,
tmo_output_contigous, block_tables, context_lens,
key_cache_scale, value_cache_scale, alibi_slopes,
max_context_len, window_size_left, window_size_right, softmax_scale)
self.assertTensorsEqual(tmo_output.cpu().float(), tmo_output_contigous.cpu().float(),
0.000, use_MSE=True)
self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
0.003, use_MSE=True)
# 防呆测试
@unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
def test_prevent(self):
func = ops.single_query_cached_kv_attn
batch, seq_q, head_num, num_kv_heads, head_size_qk, head_size_v, max_seqlen, softmax_scale = 5, 1, 8, 3, 64, 128, 512, 0.625
dtype = torch.float16
input = torch.randn((batch, seq_q, head_num, head_size_qk), dtype = dtype).mlu()
context_lens = torch.randint(seq_q, max_seqlen + 1, (batch, ), dtype=torch.int32).mlu()
max_context_len = int(max(context_lens))
block_size = 16
num_blocks = batch * ((max_seqlen + block_size - 1) // block_size)
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
cache_shape_k = (num_blocks, num_kv_heads, block_size, head_size_qk)
cache_shape_v = (num_blocks, num_kv_heads, block_size, head_size_v)
block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
key_cache = torch.randn(size=cache_shape_k, dtype=dtype).mlu()
value_cache = torch.randn(size=cache_shape_v, dtype=dtype).mlu()
key_cache_scale = None
value_cache_scale = None
window_size_left, window_size_right = 10, 10
self.assertException("only support windows_size_right < 0 currently.",
func, input, key_cache, value_cache,
None, block_tables, context_lens,
key_cache_scale, value_cache_scale,
None, max_context_len,
window_size_left, window_size_right, softmax_scale, True, -1)
window_size_right = -1
self.assertException("num_heads need be mutiple of num_kv_heads.",
func, input, key_cache, value_cache,
None, block_tables, context_lens,
key_cache_scale, value_cache_scale,
None, max_context_len,
window_size_left, window_size_right, softmax_scale, True, -1)
num_kv_heads = 1
cache_shape_k = (num_blocks, num_kv_heads, block_size, head_size_qk)
cache_shape_v = (num_blocks, num_kv_heads, block_size, head_size_v)
key_cache = torch.randn(size=cache_shape_k, dtype=dtype).mlu()
value_cache = torch.randn(size=cache_shape_v, dtype=dtype).mlu()
self.assertException("illegal quant bit size, only support 4, 8 or -1.",
func, input, key_cache, value_cache,
None, block_tables, context_lens,
key_cache_scale, value_cache_scale,
None, max_context_len,
window_size_left, window_size_right, softmax_scale, True, 16)
input1 = torch.randn((batch, seq_q, head_num, head_size_qk * 2), dtype = dtype).mlu()
input = input1[..., 0::2]
self.assertException("q last two dim need be contiguous.",
func, input, key_cache, value_cache,
None, block_tables, context_lens,
key_cache_scale, value_cache_scale,
None, max_context_len,
window_size_left, window_size_right, softmax_scale, True, -1)
input = torch.randn((batch, seq_q, head_num, head_size_qk), dtype = dtype).mlu()
key_cache1 = key_cache[..., :8, :]
value_cache1 = value_cache[..., :8, :]
self.assertException("k_cache and v_cache need be contiguous.",
func, input, key_cache1, value_cache1,
None, block_tables, context_lens,
key_cache_scale, value_cache_scale,
None, max_context_len,
window_size_left, window_size_right, softmax_scale, True, -1)
self.assertException("q_ori need be mlu tensor.",
func, input.cpu(), key_cache, value_cache,
None, block_tables, context_lens,
key_cache_scale, value_cache_scale,
None, max_context_len,
window_size_left, window_size_right, softmax_scale, True, -1)
input = torch.randn((batch, 5, head_num, head_size_qk), dtype = dtype).mlu()
self.assertException("return lse only support seq_q = 1 currently.",
func, input, key_cache, value_cache,
None, block_tables, context_lens,
key_cache_scale, value_cache_scale,
None, max_context_len,
window_size_left, window_size_right, softmax_scale, True, -1)
self.assertException("block_tables type need be torch::kInt32 or torch::kLong.",
func, input, key_cache, value_cache,
None, block_tables.float(), context_lens,
key_cache_scale, value_cache_scale,
None, max_context_len,
window_size_left, window_size_right, softmax_scale, False, -1)
self.assertException("context_lens type need be torch::kInt32.",
func, input, key_cache, value_cache,
None, block_tables, context_lens.to(torch.int64),
key_cache_scale, value_cache_scale,
None, max_context_len,
window_size_left, window_size_right, softmax_scale, False, -1)
self.assertException("context_lens need be contiguous.",
func, input, key_cache, value_cache,
None, block_tables, context_lens[0::2],
key_cache_scale, value_cache_scale,
None, max_context_len,
window_size_left, window_size_right, softmax_scale, False, -1)
scale_shape = (num_blocks)
key_cache = torch.zeros(cache_shape_k).uniform_(-128, 127).to(torch.int8).mlu()
value_cache = torch.zeros(cache_shape_v).uniform_(-128, 127).to(torch.int8).mlu()
key_cache_scale = torch.randn(scale_shape, dtype=torch.float32).mlu()
value_cache_scale = torch.randn(scale_shape, dtype=torch.float32).mlu()
self.assertException("k_cache_quant_scale must be 2d or 3d or 4d.",
func, input, key_cache, value_cache,
None, block_tables, context_lens,
key_cache_scale, value_cache_scale,
None, max_context_len,
window_size_left, window_size_right, softmax_scale, True, -1)
def test_inductor(self):
batch, seq_q, head_num, num_kv_heads, head_size, max_seqlen = 1, 5, 16, 16, 128, 512
is_pagedattn_list = [False, True] if "MLU3" not in torch.mlu.get_device_name() else [False]
has_alibi_list = [True, False]
test_flags = product(is_pagedattn_list, has_alibi_list)
for is_pagedattn, has_alibi in test_flags:
args = gen_args(batch, head_size, is_pagedattn, has_alibi, torch.int8, max_seqlen, seq_q, head_num, num_kv_heads)
self.base_opcheck(torch.ops.torch_mlu_ops.single_query_cached_kv_attn, args)
# Entry point: run every TestSingleQueryAttnOp case and propagate the result
# as the process exit code.
if __name__ == '__main__':
    exit(run_unittest(TestSingleQueryAttnOp))

View File

@@ -0,0 +1,393 @@
import math
import random
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from itertools import product
def gen_cache(batch, seq_q, num_kv_heads, head_size, is_pagedattn, seq_len, data_type, quant_bit, quant_mode, is_normal=True):
    """Build one (optionally quantized) KV cache plus its test metadata.

    Args:
        batch: number of sequences.
        seq_q: query length; per-batch context lengths are sampled from
            [seq_q, seq_len].
        num_kv_heads: number of KV heads.
        head_size: per-head dimension.
        is_pagedattn: True -> paged layout with block_size 16; otherwise a
            single block of size seq_len per sequence.
        seq_len: maximum context length.
        data_type: cache dtype when quant_bit == -1 (unquantized).
        quant_bit: -1 (float cache, no scales), 8 (int8) or 4 (int4, two
            values packed per byte via PairlyPackInt8).
        quant_mode: "per_token" -> scale shape (num_blocks, num_kv_heads,
            block_size, 1); anything else -> per-channel (num_kv_heads, head_size).
        is_normal: when False (and batch > 3), zero out roughly a third of the
            batches' context lengths to exercise empty-context handling.

    Returns:
        [key_cache, value_cache, key_scale, value_scale, context_lens,
        block_tables]; for quant_bit == 4 the unpacked int8 key/value caches
        are appended so callers can feed a non-packed reference.

    Raises:
        ValueError: if quant_bit is not one of {-1, 4, 8}.
    """
    # Validate up front. The previous version only printed a message for an
    # unsupported quant_bit and then crashed later with a NameError on the
    # undefined key_cache; it also computed a meaningless int range from
    # quant_bit == -1 before the branch.
    if quant_bit not in (-1, 4, 8):
        raise ValueError("gen case error, quant_bit must be in {-1, 4, 8}, got %r" % (quant_bit,))
    int_max = float(2 ** (quant_bit - 1) - 1)
    int_min = -float(2 ** (quant_bit - 1))
    context_lens = torch.randint(seq_q, seq_len + 1, (batch, ), dtype=torch.int32).mlu()
    if is_normal is False and batch > 3: # replace some batch's context to 0
        num = batch // 3
        index = torch.randint(0, batch, (num,))
        context_lens[index] = 0
    max_context_len = context_lens.max().item()
    block_size = 16
    if is_pagedattn is False:
        block_size = seq_len
    num_blocks = batch * ((seq_len + block_size - 1) // block_size)
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
    if quant_mode == "per_token":
        scale_shape = (num_blocks, num_kv_heads, block_size, 1)
    else: # per channel
        scale_shape = (num_kv_heads, head_size)
    if quant_bit == 4:
        # k packs along head_size, v packs along the block (token) dimension;
        # the packed v ends up shaped (num_blocks, num_kv_heads, block_size//2, head_size).
        cache_shape_k_int4 = (num_blocks, num_kv_heads, block_size, head_size//2)
        cache_shape_v_int4_tmp = (num_blocks, num_kv_heads, head_size, block_size//2)
        key_cache_int8 = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu()
        value_cache_int8 = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu()
        # pre_process int4_kv_cache: pack two int8 values into one byte
        key_cache_view = key_cache_int8.reshape(-1, head_size)
        value_cache_view = value_cache_int8.transpose(2, 3).reshape(-1, block_size)
        key_cache = PairlyPackInt8(key_cache_view).view(cache_shape_k_int4)
        value_cache = PairlyPackInt8(value_cache_view).view(cache_shape_v_int4_tmp).transpose(2,3).contiguous()
        key_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
        value_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
    elif quant_bit == 8:
        key_cache = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu()
        value_cache = torch.zeros(cache_shape).uniform_(int_min, int_max).to(torch.int8).mlu()
        key_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
        value_scale = torch.randn(size=scale_shape, dtype=torch.float32).mlu()
    else:  # quant_bit == -1: unquantized float cache, no scales
        key_cache = torch.randn(cache_shape, dtype=data_type).mlu()
        value_cache = torch.randn(cache_shape, dtype=data_type).mlu()
        key_scale = None
        value_scale = None
    block_tables = random.sample(range(0, num_blocks), batch * max_num_blocks_per_seq)
    block_tables = torch.tensor(block_tables, dtype=torch.int32).mlu().view(batch, max_num_blocks_per_seq)
    output = [key_cache, value_cache, key_scale, value_scale, context_lens, block_tables]
    if quant_bit == 4:
        output.append(key_cache_int8)
        output.append(value_cache_int8)
    return output
def concate_cache_linear(cachek1, cachek2, cachev1, cachev2, context1, context2, block_tables1, block_tables2,
                            scalek1, scalek2, scalev1, scalev2):
    # Merge two linear (one-block-per-sequence) KV caches into a dense float
    # cache of shape [batch, num_head, len1 + len2, head_size] per sequence:
    # cache1's first context1[i] tokens followed by cache2's first context2[i].
    # Returns (new_cache_k, new_cache_v, new_block_table, new_context) on MLU.
    # NOTE(review): the scales are multiplied in place, so the caller's
    # cachek*/cachev* tensors are dequantized as a side effect of this call.
    if scalek1 is not None:
        if scalek1.dim() == 2: # per_channel: [kv_head_num, head_size]
            scalek1 = scalek1.reshape(1, scalek1.shape[0], 1, scalek1.shape[1])
            scalev1 = scalev1.reshape(1, scalev1.shape[0], 1, scalev1.shape[1])
        elif scalek1.dim() == 3: # per_token: [num_blocks, k_head_num, block_size]
            scalek1 = scalek1.reshape(*scalek1.shape, 1)
            scalev1 = scalev1.reshape(*scalev1.shape, 1)
        cachek1 *= scalek1
        cachev1 *= scalev1
    if scalek2 is not None:
        if scalek2.dim() == 2: # per_channel: [kv_head_num, head_size]
            scalek2 = scalek2.reshape(1, scalek2.shape[0], 1, scalek2.shape[1])
            scalev2 = scalev2.reshape(1, scalev2.shape[0], 1, scalev2.shape[1])
        elif scalek2.dim() == 3: # per_token: [num_blocks, k_head_num, block_size]
            scalek2 = scalek2.reshape(*scalek2.shape, 1)
            scalev2 = scalev2.reshape(*scalev2.shape, 1)
        cachek2 *= scalek2
        cachev2 *= scalev2
    new_context = context1 + context2
    seq_len1 = cachek1.shape[2]
    seq_len2 = cachek2.shape[2]
    new_max_context = seq_len1 + seq_len2
    batch = cachek1.shape[0]
    num_head = cachek1.shape[1]
    head_size = cachek1.shape[3]
    # Positions past each sequence's len1 + len2 keep these random values;
    # they are outside the valid context and should never be read.
    new_cache_k = torch.randn(batch, num_head, new_max_context, head_size, dtype=torch.float32)
    new_cache_v = torch.randn(batch, num_head, new_max_context, head_size, dtype=torch.float32)
    # Linear layout: exactly one block per sequence, so the table is [0..batch)
    new_block_table = torch.arange(0, batch)
    new_block_table = new_block_table.view(batch, 1)
    for i in range(batch):
        len1 = context1[i]
        len2 = context2[i]
        block_id1 = block_tables1[i]
        block_id2 = block_tables2[i]
        new_cache_k[i, :, :len1, :] = cachek1[block_id1, :, :len1, :]
        new_cache_v[i, :, :len1, :] = cachev1[block_id1, :, :len1, :]
        new_cache_k[i, :, len1:len1 + len2, :] = cachek2[block_id2, :, :len2, :]
        new_cache_v[i, :, len1:len1 + len2, :] = cachev2[block_id2, :, :len2, :]
    return new_cache_k.mlu(), new_cache_v.mlu(), new_block_table.mlu(), new_context.mlu()
# cache1 and cache2 are float
def concat_cache_paged(cachek1, cachek2, cachev1, cachev2, context1, context2, block_tables1, block_tables2,
                        scalek1, scalek2, scalev1, scalev2):
    """Concatenate two paged KV caches into one dequantized paged cache.

    Scales (if provided) are multiplied in place into the incoming caches,
    cache2's blocks are appended after cache1's, and when a sequence's cache1
    length is not block-aligned its cache2 tokens are shifted left so the
    sequence stays contiguous across its block list.

    Returns (new_cache_k, new_cache_v, new_block_table, new_context) on MLU.
    """
    batch = context1.shape[0]
    block_size = cachek1.shape[2]
    if scalek1 is not None:
        if scalek1.dim() == 2: # per_channel: [kv_head_num, head_size]
            scalek1 = scalek1.reshape(1, scalek1.shape[0], 1, scalek1.shape[1])
            scalev1 = scalev1.reshape(1, scalev1.shape[0], 1, scalev1.shape[1])
        elif scalek1.dim() == 3: # per_token: [num_blocks, k_head_num, block_size]
            scalek1 = scalek1.reshape(*scalek1.shape, 1)
            scalev1 = scalev1.reshape(*scalev1.shape, 1)
        cachek1 *= scalek1
        cachev1 *= scalev1
    if scalek2 is not None:
        if scalek2.dim() == 2: # per_channel: [kv_head_num, head_size]
            scalek2 = scalek2.reshape(1, scalek2.shape[0], 1, scalek2.shape[1])
            scalev2 = scalev2.reshape(1, scalev2.shape[0], 1, scalev2.shape[1])
        elif scalek2.dim() == 3: # per_token: [num_blocks, k_head_num, block_size]
            scalek2 = scalek2.reshape(*scalek2.shape, 1)
            scalev2 = scalev2.reshape(*scalev2.shape, 1)
        cachek2 *= scalek2
        cachev2 *= scalev2
    new_context = context1 + context2
    # Fix: the merged table must be wide enough for BOTH tables' blocks
    # (previously block_tables1.shape[1] was added to itself).
    max_num_blocks_per_seq = block_tables1.shape[1] + block_tables2.shape[1]
    new_cache_k = torch.concat((cachek1, cachek2), dim = 0)
    new_cache_v = torch.concat((cachev1, cachev2), dim = 0)
    num_block1 = cachek1.shape[0]
    new_block_table = torch.zeros(batch, max_num_blocks_per_seq, dtype=torch.int32)
    for i in range(batch):
        len1 = context1[i]
        len2 = context2[i]
        block_num1 = (len1 + block_size - 1) // block_size
        block_num2 = (len2 + block_size - 1) // block_size
        len1_pad = block_num1 * block_size
        block1 = block_tables1[i]
        block2 = block_tables2[i]
        # cache2's block ids are offset by cache1's block count after concat
        new_block_table[i, :block_num1] = block1[:block_num1]
        new_block_table[i, block_num1:block_num1 + block_num2] = block2[:block_num2] + num_block1
        if len1 != len1_pad:
            # cache1's last block is partially filled: slide cache2's tokens
            # left by pad_len so the sequence remains contiguous.
            reg_block_id = new_block_table[i, block_num1 - 1] # last block of cache1
            cat_block_id = new_block_table[i, block_num1] # first block of cache2
            reg_len = len1 % block_size
            pad_len = len1_pad - len1
            new_cache_k[reg_block_id, :, reg_len:, :] = new_cache_k[cat_block_id, :, :pad_len, :]
            new_cache_v[reg_block_id, :, reg_len:, :] = new_cache_v[cat_block_id, :, :pad_len, :]
            for j in range(block_num2-1):
                block_id1 = new_block_table[i, block_num1 + j] # current
                block_id2 = new_block_table[i, block_num1 + j + 1] # next
                new_cache_k[block_id1, :, :reg_len, :] = new_cache_k[block_id1, :, pad_len:, :]
                new_cache_k[block_id1, :, reg_len:, :] = new_cache_k[block_id2, :, :pad_len, :]
                new_cache_v[block_id1, :, :reg_len, :] = new_cache_v[block_id1, :, pad_len:, :]
                new_cache_v[block_id1, :, reg_len:, :] = new_cache_v[block_id2, :, :pad_len, :]
            # the final block only shifts its own tail forward
            block_id = new_block_table[i, block_num1 + block_num2 - 1]
            new_cache_k[block_id, :, :reg_len, :] = new_cache_k[block_id, :, pad_len:, :]
            new_cache_v[block_id, :, :reg_len, :] = new_cache_v[block_id, :, pad_len:, :]
    return new_cache_k.mlu(), new_cache_v.mlu(), new_block_table.mlu(), new_context.mlu()
class TestSingleQueryMixedKVAttnOp(BtTestCase):
    def run_gen_case(self, dic):
        # Replay a generated/dumped case description. When 'dump_data' is set
        # the dict values are already launch-ready and are forwarded directly;
        # otherwise each entry is materialized (tensors via
        # create_tensor_from_dic, scalars via the nested 'data' field) before
        # launching. The positional order below must match launch()/the op.
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            q = create_tensor_from_dic(dic['q'])
            k_cache_lp = create_tensor_from_dic(dic['k_cache_lp'])
            v_cache_lp = create_tensor_from_dic(dic['v_cache_lp'])
            k_cache_hp = create_tensor_from_dic(dic['k_cache_hp'])
            v_cache_hp = create_tensor_from_dic(dic['v_cache_hp'])
            out = create_tensor_from_dic(dic['out'])
            block_tables_lp = dic['block_tables_lp']['data']
            block_tables_hp = dic['block_tables_hp']['data']
            context_lens_lp = dic['context_lens_lp']['data']
            context_lens_hp = dic['context_lens_hp']['data']
            k_cache_quant_scale_lp = create_tensor_from_dic(dic['k_cache_quant_scale_lp'])
            v_cache_quant_scale_lp = create_tensor_from_dic(dic['v_cache_quant_scale_lp'])
            k_cache_quant_scale_hp = create_tensor_from_dic(dic['k_cache_quant_scale_hp'])
            v_cache_quant_scale_hp = create_tensor_from_dic(dic['v_cache_quant_scale_hp'])
            alibi_slopes = create_tensor_from_dic(dic['alibi_slopes'])
            max_contxt_len_lp = dic['max_contxt_len_lp']['data']
            max_contxt_len_hp = dic['max_contxt_len_hp']['data']
            softmax_scale = dic['softmax_scale']['data']
            return_lse = dic['return_lse']['data']
            kv_cache_quant_bit_size_lp = dic['kv_cache_quant_bit_size_lp']['data']
            kv_cache_quant_bit_size_hp = dic['kv_cache_quant_bit_size_hp']['data']
            self.launch(q, k_cache_lp, v_cache_lp, k_cache_hp, v_cache_hp, out, block_tables_lp,
                        block_tables_hp, context_lens_lp, context_lens_hp, k_cache_quant_scale_lp,
                        v_cache_quant_scale_lp, k_cache_quant_scale_hp, v_cache_quant_scale_hp,
                        alibi_slopes, max_contxt_len_lp, max_contxt_len_hp, softmax_scale,
                        return_lse, kv_cache_quant_bit_size_lp, kv_cache_quant_bit_size_hp)
def launch(self, *args):
torch_output, torch_lse = self.op_impl_base(*args)
tmo_output, tmo_lse = ops.single_query_mixed_cached_kv_attn(*args)
self.assertTensorsEqual(torch_lse.cpu().float(), tmo_lse.cpu().float(),
0.0003, use_MSE=True)
self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
0.003, use_MSE=True)
    def op_impl_base(self, *args):
        # Torch reference for the mixed-cache op: run single-cache attention
        # over the low-precision (lp) and high-precision (hp) caches
        # independently, then merge the two partial results with their lse via
        # update_out_and_lse_torch. int4 lp caches are unpacked to int8 first.
        # Window sizes are fixed to (-1, -1) here, i.e. no sliding window.
        q, k_cache_lp, v_cache_lp, k_cache_hp, v_cache_hp, out, block_tables_lp, block_tables_hp, \
        context_lens_lp, context_lens_hp, k_cache_quant_scale_lp, v_cache_quant_scale_lp, k_cache_quant_scale_hp, \
        v_cache_quant_scale_hp, alibi_slopes, max_contxt_len_lp, max_contxt_len_hp, softmax_scale, return_lse, \
        kv_cache_quant_bit_size_lp, kv_cache_quant_bit_size_hp = args
        if kv_cache_quant_bit_size_lp == 4:
            # packed v has block_size/2 along dim 2; k is packed along head_size
            num_blocks, num_kv_heads, block_size_lp, head_size = v_cache_lp.size()
            block_size = block_size_lp * 2
            k_cache_lp = UnpackInt4(k_cache_lp).reshape(num_blocks, num_kv_heads, block_size, head_size)
            v_cache_lp = UnpackInt4(v_cache_lp.transpose(2,3)).reshape(num_blocks, num_kv_heads, head_size, block_size).transpose(2,3)
        torch_output_lp, torch_lse_lp = single_query_cached_kv_attn(q.contiguous().float(), k_cache_lp.float(), v_cache_lp.float(),
            block_tables_lp, context_lens_lp, k_cache_quant_scale_lp, v_cache_quant_scale_lp, alibi_slopes, -1, -1, softmax_scale, return_lse)
        torch_output_hp, torch_lse_hp = single_query_cached_kv_attn(q.contiguous().float(), k_cache_hp.float(), v_cache_hp.float(),
            block_tables_hp, context_lens_hp, k_cache_quant_scale_hp, v_cache_quant_scale_hp, alibi_slopes, -1, -1, softmax_scale, return_lse)
        torch_output, torch_lse = update_out_and_lse_torch(torch_output_lp, torch_lse_lp, torch_output_hp, torch_lse_hp, None, None, None)
        return (torch_output, torch_lse) if return_lse else torch_output
    def test_single_query_mixedkv_attention(self):
        # Mixed-precision KV: a low-precision (lp) cache and a high-precision
        # (hp) cache kept separate, combined by the op; sweeps quantization
        # bit pairs, paged layout, alibi and empty-context (is_normal=False).
        head_num = 16
        batch = 12
        num_kv_heads = 4
        seq_q = 1
        head_size = 128
        seq_len_lp = 512
        seq_len_hp = 128
        is_pagedattn_list = [True, False]
        has_alibi_list = [True, False]
        is_normal_list = [True, False] # if false, lp_k/v_len of some batch is 0
        quant_bit_list = [(-1, -1), (4, 8), (4, -1), (8, -1), (8, 8)]
        data_type_list = [torch.float, torch.float16]
        if torch_mlu.mlu.is_bf16_supported():
            data_type_list.append(torch.bfloat16)
        args = product(is_pagedattn_list, has_alibi_list, data_type_list, quant_bit_list, is_normal_list)
        for is_pagedattn, has_alibi, data_type, quant_bit, is_normal in args:
            quant_bit_lp, quant_bit_hp = quant_bit
            print("test separate:{} + {}: batch:{}, seq_len_lp:{}, seq_len_hp:{}, head_size:{}, is_pagedattn:{}, has_alibi:{}, data_type:{}, is_normal:{} ...".format(
                quant_bit_lp, quant_bit_hp, batch, seq_len_lp, seq_len_hp, head_size, is_pagedattn, has_alibi, data_type, is_normal))
            if is_pagedattn:
                mlu_name = torch.mlu.get_device_name()
                if "MLU3" in mlu_name:
                    print("pagedattn is not implement on mlu370, skip it")
                    continue
            # fixed seed so each configuration sees the same random data
            torch.manual_seed(1)
            input_qkv = torch.randn((batch, 3 * seq_q, 3 * head_num, head_size), dtype=data_type).mlu()
            input_q = input_qkv[:, 0:seq_q, 0:head_num,:].view(batch, seq_q, head_num, head_size)
            #gen k/v cache
            params_lp = gen_cache(batch, seq_q, num_kv_heads, head_size, is_pagedattn, seq_len_lp, data_type, quant_bit_lp,
                                  "per_token", is_normal)
            params_hp = gen_cache(batch, seq_q, num_kv_heads, head_size, is_pagedattn, seq_len_hp, data_type, quant_bit_hp,
                                  "per_channel")
            # int4 lp caches carry two extra entries (the unpacked int8 caches)
            if quant_bit_lp == 4:
                key_cache_lp, value_cache_lp, key_scale_lp, value_scale_lp, context_lens_lp, block_tables_lp, _, _ = params_lp
            else:
                key_cache_lp, value_cache_lp, key_scale_lp, value_scale_lp, context_lens_lp, block_tables_lp = params_lp
            key_cache_hp, value_cache_hp, key_scale_hp, value_scale_hp, context_lens_hp, block_tables_hp = params_hp
            max_context_len_lp = context_lens_lp.max().item()
            max_context_len_hp = context_lens_hp.max().item()
            alibi_slopes = None
            if has_alibi:
                alibi_slopes = torch.zeros((batch, head_num), dtype=torch.float32).mlu()
                alibi_slopes.uniform_(0, 0.125)
            softmax_scale = 1 / math.sqrt(head_size)
            torch_output, torch_lse = self.op_impl_base(input_q,
                key_cache_lp, value_cache_lp,
                key_cache_hp, value_cache_hp,
                None, #output
                block_tables_lp, block_tables_hp,
                context_lens_lp, context_lens_hp,
                key_scale_lp, value_scale_lp,
                key_scale_hp, value_scale_hp,
                alibi_slopes,
                max_context_len_lp, max_context_len_hp,
                softmax_scale, True,
                quant_bit_lp, quant_bit_hp)
            tmo_output, tmo_lse = ops.single_query_mixed_cached_kv_attn(input_q,
                key_cache_lp, value_cache_lp,
                key_cache_hp, value_cache_hp,
                None, #output
                block_tables_lp, block_tables_hp,
                context_lens_lp, context_lens_hp,
                key_scale_lp, value_scale_lp,
                key_scale_hp, value_scale_hp,
                alibi_slopes,
                max_context_len_lp, max_context_len_hp,
                softmax_scale, True,
                quant_bit_lp, quant_bit_hp)
            self.assertTensorsEqual(torch_lse.cpu().float(), tmo_lse.cpu().float(),
                                    0.0003, use_MSE=True)
            self.assertTensorsEqual(torch_output.cpu().float(), tmo_output.cpu().float(),
                                    0.003, use_MSE=True)
def test_single_query_mixedkv_attention_concate(self):
head_num = 16
batch = 12
num_kv_heads = 4
seq_q = 1
head_size = 128
seq_len_lp = 512
seq_len_hp = 128
is_pagedattn_list = [True, False]
has_alibi = False #only support without alibi
is_normal_list = [True, False] # if false, lp_k/v_len of some batch is 0
quant_bit_list = [(-1, -1), (4, 8), (4, -1), (8, -1), (8, 8)]
data_type_list = [torch.float, torch.float16]
if torch_mlu.mlu.is_bf16_supported():
data_type_list.append(torch.bfloat16)
args = product(is_pagedattn_list, data_type_list, quant_bit_list, is_normal_list)
for is_pagedattn, data_type, quant_bit, is_normal in args:
quant_bit_lp, quant_bit_hp = quant_bit
print("test concate {} + {}: seq_len_lp: {}, seq_len_hp: {}, is_pagedattn: {}, data_type {}, is_normal {} ...".format(
quant_bit_lp, quant_bit_hp, batch, seq_len_lp, seq_len_hp, head_size, is_pagedattn, data_type, is_normal))
if is_pagedattn:
mlu_name = torch.mlu.get_device_name()
if "MLU3" in mlu_name:
print("pagedattn is not implement on mlu370, skip it")
continue
torch.manual_seed(1)
input_qkv = torch.randn((batch, 3 * seq_q, 3 * head_num, head_size), dtype=data_type).mlu()
input_q = input_qkv[:, 0:seq_q, 0:head_num,:].view(batch, seq_q, head_num, head_size)
#gen k/v cache
params_lp = gen_cache(batch, seq_q, num_kv_heads, head_size, is_pagedattn, seq_len_lp, data_type, quant_bit_lp,
"per_token", is_normal)
params_hp = gen_cache(batch, seq_q, num_kv_heads, head_size, is_pagedattn, seq_len_hp, data_type, quant_bit_hp,
"per_channel")
if quant_bit_lp == 4:
key_cache_lp, value_cache_lp, key_scale_lp, value_scale_lp, context_lens_lp, block_tables_lp, key_cache_lp_torch, value_cache_lp_torch = params_lp
else:
key_cache_lp, value_cache_lp, key_scale_lp, value_scale_lp, context_lens_lp, block_tables_lp = params_lp
key_cache_lp_torch, value_cache_lp_torch = key_cache_lp, value_cache_lp
key_cache_hp, value_cache_hp, key_scale_hp, value_scale_hp, context_lens_hp, block_tables_hp = params_hp
max_context_len_lp = context_lens_lp.max().item()
max_context_len_hp = context_lens_hp.max().item()
alibi_slopes = None
if has_alibi:
alibi_slopes = torch.zeros((batch, head_num), dtype=torch.float32).mlu()
alibi_slopes.uniform_(0, 0.125)
softmax_scale = 1 / math.sqrt(head_size)
# concate cache
if is_pagedattn:
concat_cache_k, concat_cache_v, concat_block_tables, concat_context = concat_cache_paged(
key_cache_lp_torch.float().cpu(), key_cache_hp.float().cpu(),
value_cache_lp_torch.float().cpu(), value_cache_hp.float().cpu(),
context_lens_lp.cpu(), context_lens_hp.cpu(),
block_tables_lp.cpu(), block_tables_hp.cpu(),
key_scale_lp.cpu() if key_scale_lp is not None else None,
key_scale_hp.cpu() if key_scale_hp is not None else None,
value_scale_lp.cpu() if value_scale_lp is not None else None,
value_scale_hp.cpu() if value_scale_hp is not None else None)
else:
concat_cache_k, concat_cache_v, concat_block_tables, concat_context = concate_cache_linear(
key_cache_lp_torch.float().cpu(), key_cache_hp.float().cpu(),
value_cache_lp_torch.float().cpu(), value_cache_hp.float().cpu(),
context_lens_lp.cpu(), context_lens_hp.cpu(),
block_tables_lp.cpu(), block_tables_hp.cpu(),
key_scale_lp.cpu() if key_scale_lp is not None else None,
key_scale_hp.cpu() if key_scale_hp is not None else None,
value_scale_lp.cpu() if value_scale_lp is not None else None,
value_scale_hp.cpu() if value_scale_hp is not None else None)
torch_output_concat, torch_lse_concat = single_query_cached_kv_attn(input_q.contiguous().float(),
concat_cache_k.float(), concat_cache_v.float(), concat_block_tables,
concat_context, None, None, alibi_slopes, -1, -1, softmax_scale, True)
tmo_output, tmo_lse = ops.single_query_mixed_cached_kv_attn(input_q,
key_cache_lp, value_cache_lp,
key_cache_hp, value_cache_hp,
None, #output
block_tables_lp, block_tables_hp,
context_lens_lp, context_lens_hp,
key_scale_lp, value_scale_lp,
key_scale_hp, value_scale_hp,
alibi_slopes,
max_context_len_lp, max_context_len_hp,
softmax_scale, True,
quant_bit_lp, quant_bit_hp)
if is_normal: # only compare lse when context_len_lp is normal, seq=0 case nan-value in lse
self.assertTensorsEqual(torch_lse_concat.cpu().float(), tmo_lse.cpu().float(),
0.0003, use_MSE=True)
self.assertTensorsEqual(torch_output_concat.cpu().float(), tmo_output.cpu().float(),
0.003, use_MSE=True)
    def test_inductor(self):
        # Delegate to the base-class inductor opcheck; kept explicit so the
        # test is discoverable and overridable in this subclass.
        return super().test_inductor()
# Entry point: run the suite via the project's unittest driver and propagate
# its status code to the shell.
if __name__ == '__main__':
    exit(run_unittest(TestSingleQueryMixedKVAttnOp))

View File

@@ -0,0 +1,555 @@
import torch
import torch_mlu
import unittest
import torch_mlu_ops as tmo
from common_utils import *
from itertools import product
import random
import os
def generate_token_count(num_expert,
                         total_token_count):
    """Randomly split ``total_token_count`` tokens across ``num_expert`` experts.

    Draws a random weight per expert, rescales the weights so they sum to
    roughly ``total_token_count``, then repairs integer-rounding drift through
    the cumulative sum so the split is exact.

    Returns:
        (token_count, cusum_token_count): int32 tensors of shape
        ``(num_expert,)`` and ``(num_expert + 1,)``. ``cusum_token_count`` is
        non-decreasing, starts at 0 and ends exactly at ``total_token_count``;
        ``token_count`` is its first difference.
    """
    token_count = torch.randint(low=1, high=1024, size=(num_expert, ), \
                                dtype=torch.int32).to(dtype=torch.float32)
    # Fix: do not shadow the builtin ``sum`` with a tensor.
    total_weight = torch.sum(token_count, dim=-1) * 1.0
    token_count *= total_token_count / total_weight.item()
    token_count = token_count.to(dtype=torch.int32)
    cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32)
    end_expert_token_count = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count
    cusum_token_count[1:] = torch.cumsum(token_count, dim=-1)
    # Clip any overshoot past the total, then pin the last entry so the
    # partition covers exactly total_token_count tokens.
    cusum_token_count = torch.minimum(cusum_token_count, end_expert_token_count)
    cusum_token_count[-1] = total_token_count
    # Re-derive per-expert counts from the repaired cumulative sum.
    token_count = cusum_token_count[1:] - cusum_token_count[0:-1]
    return token_count, cusum_token_count
def gen_case(num_tokens,
             topk,
             hidden_size,
             multi_scale,
             need_gather,
             num_expert,
             expert_size,
             start_expert_id,
             dtype,
             device):
    """Build the input bundle for a moe_quantize test case.

    Args:
        num_tokens, topk, hidden_size: token/routing/feature dimensions.
        multi_scale: True -> one smooth scale row per expert plus per-expert
            token counts; False -> a single shared smooth vector.
        need_gather: True -> rows are addressed through a gather index of
            length ``num_tokens * topk``.
        num_expert, expert_size, start_expert_id: full expert count and the
            contiguous sub-range actually handled (sliced when
            ``expert_size < num_expert`` and ``multi_scale``).
        dtype, device: input tensor dtype and placement.

    Returns:
        (input, scale, gather_ids, token_count, cusum_token_count,
         gather_index_start_position, effective_count).
    """
    input = torch.randn((num_tokens, hidden_size), dtype=dtype, device=device)
    effective_count = num_tokens * topk if need_gather else num_tokens
    if multi_scale:
        token_count, cusum_token_count = generate_token_count(num_expert, effective_count)
        token_count = token_count.to(device=device)
        cusum_token_count = cusum_token_count.to(device=device)
        scale = torch.randn((num_expert, hidden_size), dtype=torch.float32, device=device)
    else:
        token_count = None
        cusum_token_count = None
        scale = torch.randn((hidden_size), dtype=torch.float32, device=device)
    if need_gather:
        # Each source token appears exactly topk times, in random order.
        gather_ids = torch.randperm(num_tokens * topk, dtype=torch.int32) // topk
        # Fix: honor the ``device`` argument (previously hard-coded ``.mlu()``,
        # which contradicted every other tensor in this function).
        gather_ids = gather_ids.to(device=device)
    else:
        gather_ids = None
    if expert_size < num_expert and multi_scale:
        # Keep only the expert sub-range starting at start_expert_id; the
        # first retained cumulative entry becomes the gather start offset.
        token_count = token_count[start_expert_id:]
        cusum_token_count = cusum_token_count[start_expert_id:]
        gather_index_start_position = cusum_token_count[0:1]
        scale = scale[start_expert_id:]
        effective_count = cusum_token_count[-1] - cusum_token_count[0]
    else:
        gather_index_start_position = None
    if not need_gather:
        gather_index_start_position = None
    return input, scale, gather_ids, token_count, cusum_token_count, gather_index_start_position, effective_count
def per_token_smooth_quantize_base(x: torch.Tensor,
                                   smooth: torch.Tensor,
                                   zero: torch.Tensor = None,
                                   token_count: torch.Tensor = None):
    """Reference per-token smooth quantization.

    Multiplies each row by the smooth vector, int8-quantizes row-wise via
    QuantByRow, then restores the caller's original shapes. ``zero`` and
    ``token_count`` are accepted for signature parity but unused here.
    """
    full_shape = x.size()
    row_shape = full_shape[0:-1]
    flattened = x.flatten(0, -2) * smooth
    quantized, row_scale = QuantByRow(flattened, 8)
    return quantized.reshape(full_shape), row_scale.reshape(row_shape)
def quantize_base(x: torch.Tensor,
                  scale: torch.Tensor,
                  zero: torch.Tensor = None
                  ) -> torch.Tensor:
    """Reference static quantization: scale, round to nearest, saturate to the
    int8 range. ``zero`` is accepted for signature parity but unused."""
    scaled = x * scale
    saturated = scaled.round().clamp(min=-128.0, max=127.0)
    return saturated.to(torch.int8)
class TestSmoothQuantOp(BtTestCase):
    """Tests for tmo.moe_quantize, tmo.quantize and tmo.per_token_smooth_quantize.

    The CPU reference (``op_impl_base``) mirrors the kernel: optionally gather
    rows via ``gather_index``, apply per-expert smooth scales (when
    ``token_count`` is given) or one shared smooth vector, then either
    statically quantize to int8 or dynamically quantize per row.
    """
    def run_gen_case(self, dic):
        # Replay a dumped case: on the dump path the raw dict values are passed
        # positionally; otherwise tensors are rebuilt from their descriptors.
        # NOTE(review): the dump path relies on dict insertion order matching
        # launch()'s positional signature — confirm against the dump format.
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            x = create_tensor_from_dic(dic['x'])
            smooth = create_tensor_from_dic(dic['smooth'])
            zero = None if dic['zero']['data'] is None else create_tensor_from_dic(dic['zero'])
            token_count = None if dic['token_count']['data'] is None else dic['token_count']['data']
            gather_index = None if dic['gather_index']['data'] is None else dic['gather_index']['data']
            gather_index_start_position = None if dic['gather_index_start_position']['data'] is None else dic['gather_index_start_position']['data']
            output = None if dic['output']['data'] is None else create_tensor_from_dic(dic['output'])
            output_scale = None if dic['output_scale']['data'] is None else create_tensor_from_dic(dic['output_scale'])
            dynamic_quant = dic['dynamic_quant']['data']
            self.launch(x, smooth, zero, token_count, gather_index, gather_index_start_position, output, output_scale, dynamic_quant)
    def launch(self, *args):
        """Run kernel and CPU reference on identical moe_quantize args and
        compare. args[6]/args[7] are the optional output/output_scale buffers;
        args[-1] is the dynamic_quant flag."""
        tmo_out = tmo.moe_quantize(*args)
        # Clone the in/out buffers so the reference does not see kernel writes.
        out_base = None if args[6] is None else args[6].clone()
        scale_base = None if args[7] is None else args[7].clone()
        args = list(args)
        args[6] = out_base
        args[7] = scale_base
        torch_out = self.op_impl_base(*args)
        if args[-1]:
            # Dynamic quant: compare quantized values and per-row scales.
            self.assertTensorsEqual(torch_out[0].cpu().reshape(-1).float(),
                                    tmo_out[0].cpu().reshape(-1).float(),
                                    0.01, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(torch_out[1].cpu().reshape(-1).float(),
                                    tmo_out[1].cpu().reshape(-1).float(),
                                    0.003, use_MSE=True, use_RAE=True)
        else:
            # NOTE(review): op_impl_base returns a tuple in every path, so this
            # branch's torch_out.cpu() appears to assume dumps never use
            # dynamic_quant=False — confirm against the dump generator.
            self.assertTensorsEqual(torch_out.cpu().reshape(-1).float(),
                                    tmo_out.cpu().reshape(-1).float(),
                                    0.01, use_MSE=True, use_RAE=True)
    def op_impl_base(self, *args):
        """CPU reference for moe_quantize.

        Gathers rows (honoring an optional start offset), scales them by the
        smooth vector(s), then quantizes: statically when ``dynamic_quant`` is
        False, otherwise per-row via QuantByRow. Writes into ``output`` /
        ``output_scale`` when provided, else allocates the result.
        """
        x, smooth, zero, token_count, gather_index, gather_index_start_position, \
            output, output_scale, dynamic_quant = args
        input = x.to(dtype=torch.float32).cpu()
        input_scale = smooth.cpu()
        cusum_token_count = None
        if token_count is not None:
            token_count = token_count.cpu()
            cusum_token_count = torch.zeros(token_count.shape[0] + 1, dtype=torch.int32)
            cusum_token_count[1:] = torch.cumsum(token_count, dim=-1)
        if gather_index_start_position is not None:
            gather_index_start_position = gather_index_start_position.cpu()
        gather_index_start = 0
        if gather_index_start_position is not None:
            gather_index_start = gather_index_start_position[0]
        # Determine how many gathered rows are valid.
        if cusum_token_count is not None:
            gather_index_end = cusum_token_count[-1] + gather_index_start
        elif gather_index is not None:
            gather_index_end = gather_index.numel()
        else:
            gather_index_end = input.numel()
        if gather_index is not None:
            gather_index = gather_index.cpu()
            gathered_input = input[gather_index[gather_index_start : gather_index_end]]
        else:
            gathered_input = input[gather_index_start : gather_index_end]
        # Apply per-expert smooth scales segment by segment, or one shared vector.
        if cusum_token_count is not None:
            for i in range(token_count.shape[0]):
                gathered_input[cusum_token_count[i] : cusum_token_count[i+1]] *= input_scale[i]
        else:
            gathered_input *= input_scale
        if output is None:
            if not dynamic_quant:
                return gathered_input.round().clamp(-128.0, 127.0).to(torch.int8), None
            else:
                return QuantByRow(gathered_input, 8)
        else:
            if not dynamic_quant:
                output.copy_(gathered_input.round().clamp(-128.0, 127.0).to(torch.int8))
                output_scale = None
            else:
                # Only the first out.numel() elements of the caller's buffers
                # are defined; the tail keeps whatever was there before.
                out, scale = QuantByRow(gathered_input, 8)
                output_fl = output.flatten()
                output_fl[:out.numel()].copy_(out.flatten())
                # output = output_fl
                output_scale_fl = output_scale.flatten()
                output_scale_fl[:scale.numel()].copy_(scale.flatten())
                # output_scale = output_scale_fl
            return (output, output_scale) if dynamic_quant else (output,)
    def test_random_case(self):
        """Fuzz moe_quantize over random shapes, expert sub-ranges, gather,
        stride, dynamic/static quant and dtypes against the CPU reference."""
        torch.manual_seed(333)
        test_cases = 100
        num_tokens_list = torch.randint(low=1, high=4096, size=(test_cases, ), dtype=torch.int32)
        topk_list = torch.randint(low=1, high=16, size=(test_cases, ), dtype=torch.int32)
        hidden_size_list = torch.randint(low=128, high=8193, size=(test_cases, ), dtype=torch.int32)
        num_expert_list = torch.randint(low=1, high=129, size=(test_cases, ), dtype=torch.int32)
        expert_size_list = torch.randint(low=1, high=129, size=(test_cases, ), dtype=torch.int32)
        expert_size_list = torch.minimum(expert_size_list, num_expert_list)
        # Clamp so start_expert_id + expert_size never exceeds num_expert.
        start_expert_id_list = torch.randint(low=0, high=129, size=(test_cases, ), dtype=torch.int32)
        start_expert_id_list = torch.minimum(start_expert_id_list, num_expert_list - expert_size_list)
        start_expert_id_list = torch.maximum(start_expert_id_list, torch.zeros(test_cases, dtype=torch.int32))
        multi_scale_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
        need_gather_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
        input_with_stride_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
        dynamic_quant_list = torch.randint(low=0, high=2, size=(test_cases, ), dtype=torch.int32)
        # NOTE(review): this assignment is dead — overwritten by random.choices below.
        dtype_list = torch.randint(low=0, high=10, size=(test_cases, ), dtype=torch.int32)
        # half/bf16 are listed twice so they are drawn more often than float32.
        dtypes = [torch.half, torch.half, torch.float32]
        if torch_mlu.mlu.is_bf16_supported():
            dtypes += [torch.bfloat16, torch.bfloat16]
        dtype_list = random.choices(dtypes, k=test_cases)
        device = "mlu"
        for i in range(test_cases):
            num_tokens = num_tokens_list[i].item()
            topk = topk_list[i].item()
            hidden_size = hidden_size_list[i].item()
            num_expert = num_expert_list[i].item()
            expert_size = expert_size_list[i].item()
            start_expert_id = start_expert_id_list[i].item()
            multi_scale = multi_scale_list[i].item() == 1
            need_gather = need_gather_list[i].item() == 1
            input_with_stride = input_with_stride_list[i].item() == 1
            dynamic_quant = dynamic_quant_list[i].item() == 1
            dtype = dtype_list[i]
            # Gather requires multi-scale and (presumably) a bf16-capable
            # hardware generation — TODO confirm the hardware constraint.
            if not multi_scale or not torch_mlu.mlu.is_bf16_supported():
                need_gather = False
            inputs = gen_case(num_tokens,
                              topk,
                              hidden_size,
                              multi_scale,
                              need_gather,
                              num_expert,
                              expert_size,
                              start_expert_id,
                              dtype,
                              device)
            input = inputs[0]
            input_scale = inputs[1]
            gather_ids = inputs[2]
            token_count = inputs[3]
            cusum_token_count = inputs[4]
            gather_index_start_position = inputs[5]
            effective_count = inputs[6]
            if input_with_stride:
                # Keep only the last 64 columns so the input has a row stride
                # larger than its logical width.
                hidden_size = hidden_size - 64
                input = input[..., hidden_size : ]
                input_scale = input_scale[..., hidden_size : ].contiguous()
            print("num_tokens={}, topk={}, hidden_size={}, num_expert={}, expert_size={}, "
                  "start_expert_id={}, multi_scale={}, need_gather={}, input_with_stride={}, "
                  "dynamic_quant={}, dtype={}, testing...".format(
                  num_tokens, topk, hidden_size, num_expert, expert_size, start_expert_id, \
                  multi_scale, need_gather, input_with_stride, dynamic_quant, dtype))
            torch_quant, torch_output_scale = self.op_impl_base(input,
                                                                input_scale,
                                                                None,
                                                                token_count,
                                                                gather_ids,
                                                                gather_index_start_position,
                                                                None,
                                                                None,
                                                                dynamic_quant)
            tmo_output_scale = None
            if dynamic_quant:
                tmo_output, tmo_output_scale = \
                    tmo.moe_quantize(input, input_scale, None, token_count,
                                     gather_ids, gather_index_start_position,
                                     None, None, dynamic_quant)
            else:
                tmo_output, = \
                    tmo.moe_quantize(input, input_scale, None, token_count,
                                     gather_ids, gather_index_start_position,
                                     None, None, dynamic_quant)
            # Rows past effective_count are undefined padding; drop them.
            tmo_output = tmo_output[:effective_count]
            if tmo_output_scale is not None:
                tmo_output_scale = tmo_output_scale[:effective_count]
            self.assertTensorsEqual(torch_quant.cpu().reshape(-1).float(),
                                    tmo_output.cpu().reshape(-1).float(),
                                    0.01, use_MSE=True, use_RAE=True)
            if dynamic_quant:
                self.assertTensorsEqual(torch_output_scale.cpu().reshape(-1).float(),
                                        tmo_output_scale.cpu().reshape(-1).float(),
                                        0.003, use_MSE=True, use_RAE=True)
    def test_interface(self):
        """Smoke-test the public entry points (per_token_smooth_quantize,
        quantize, and in-place moe_quantize) against the references."""
        channel = 16
        dtype_list = [torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        for dtype in dtype_list:
            input = torch.randn((2, 3, 2, channel), dtype=dtype, device="mlu")
            input_scale = torch.randn((channel)).float().mlu()
            print("test tmo.per_token_smooth_quantize...")
            torch_quant, torch_scale = per_token_smooth_quantize_base(input, input_scale)
            tmo_quant, tmo_scale = tmo.per_token_smooth_quantize(input, input_scale)
            self.assertTensorsEqual(torch_quant.cpu().float(), tmo_quant.cpu().float(),
                                    0.01, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(torch_scale.cpu().float(), tmo_scale.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
            print("test tmo.quantize ...")
            # Inflate the scale so rounded values exercise the int8 clamp.
            input_scale = input_scale * 100.0
            torch_quant = quantize_base(input, input_scale)
            tmo_quant = tmo.quantize(input, input_scale)
            self.assertTensorsEqual(torch_quant.cpu().float(), tmo_quant.cpu().float(),
                                    0.003, use_MSE=True, use_RAE=True)
        print("test tmo.moe_quantize inplace ...")
        num_tokens, topk, hidden_size, multi_scale, need_gather, num_expert, expert_size, start_expert_id, dtype, dynamic_quant = \
            1024, 5, 512, True, True, 32, 4, 4, torch.half, True
        inputs = gen_case(num_tokens, topk, hidden_size, multi_scale, need_gather, num_expert,
                          expert_size, start_expert_id, dtype, 'mlu')
        input = inputs[0]
        input_scale = inputs[1]
        gather_ids = inputs[2]
        token_count = inputs[3]
        cusum_token_count = inputs[4]
        gather_index_start_position = inputs[5]
        effective_count = inputs[6]
        tmo_output = torch.empty(num_tokens*topk, hidden_size, dtype=torch.int8, device='mlu')
        tmo_output_scale = torch.empty(num_tokens*topk, dtype=torch.float, device='mlu')
        # The in-place path is presumably unsupported on MLU370 — TODO confirm.
        if 'MLU370' not in torch_mlu.mlu.get_device_name():
            tmo.moe_quantize(input, input_scale, None, token_count,
                             gather_ids, gather_index_start_position,
                             tmo_output, tmo_output_scale, dynamic_quant)
            tmo_output = tmo_output[:effective_count]
            tmo_output_scale = tmo_output_scale[:effective_count]
            torch_output = torch.empty_like(tmo_output)
            torch_output_scale = torch.empty_like(tmo_output_scale)
            self.op_impl_base(input, input_scale,
                              None, token_count, gather_ids, gather_index_start_position,
                              torch_output, torch_output_scale, dynamic_quant)
            self.assertTensorsEqual(torch_output.cpu().reshape(-1).float(),
                                    tmo_output.cpu().reshape(-1).float(),
                                    0.01, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(torch_output_scale.cpu().reshape(-1).float(),
                                    tmo_output_scale.cpu().reshape(-1).float(),
                                    0.003, use_MSE=True, use_RAE=True)
            # Second scenario: multi-dim input with a single scalar-row scale.
            input = input.reshape(2, 4, 8, 16, -1)
            input_scale = input_scale[0]
            tmo_output = torch.empty(input.size(), dtype=torch.int8, device='mlu')
            tmo_output_scale = torch.empty(input.size()[:-1], dtype=torch.float, device='mlu')
            tmo.moe_quantize(input, input_scale, None, None, None, None,
                             tmo_output, tmo_output_scale, dynamic_quant)
            torch_quant = torch.empty_like(tmo_output)
            torch_output_scale = torch.empty_like(tmo_output_scale)
            self.op_impl_base(input, input_scale,
                              None, None, None, None, torch_quant, torch_output_scale, dynamic_quant)
            self.assertTensorsEqual(torch_quant.cpu().reshape(-1).float(),
                                    tmo_output.cpu().reshape(-1).float(),
                                    0.01, use_MSE=True, use_RAE=True)
            self.assertTensorsEqual(torch_output_scale.cpu().reshape(-1).float(),
                                    tmo_output_scale.cpu().reshape(-1).float(),
                                    0.003, use_MSE=True, use_RAE=True)
    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_prevent(self):
        """Verify moe_quantize rejects invalid argument combinations with the
        expected error messages (None message = any error accepted)."""
        num_tokens, topk, hidden_size, multi_scale, need_gather, num_expert, expert_size, start_expert_id, dtype, dynamic_quant = \
            1024, 5, 512, True, True, 32, 4, 4, torch.half, True
        inputs = gen_case(num_tokens, topk, hidden_size, multi_scale, need_gather, num_expert,
                          expert_size, start_expert_id, dtype, 'mlu')
        input = inputs[0]
        input_scale = inputs[1]
        gather_ids = inputs[2]
        token_count = inputs[3]
        cusum_token_count = inputs[4]
        gather_index_start_position = inputs[5]
        effective_count = inputs[6]
        tmo_output = torch.empty(input.size(), dtype=torch.int8, device='mlu')
        tmo_output_scale = torch.empty(input.size()[:-1], dtype=torch.float, device='mlu')
        func = tmo.moe_quantize
        self.assertException("input must be mlu tensor.",
                             func, input.cpu(), input_scale, None, None, None, None,
                             None, None, dynamic_quant)
        self.assertException(None,
                             func, input, input_scale, None, token_count, gather_ids,
                             gather_index_start_position.cpu(), None, None, dynamic_quant)
        self.assertException("not support output_scale if dynamic_quant = false",
                             func, input, input_scale, None, token_count, gather_ids,
                             gather_index_start_position, None, tmo_output_scale, False)
        self.assertException("input.dim() == 2 if has gather_index or token_count",
                             func, input.reshape(32, 32, -1), input_scale, None, None, gather_ids,
                             gather_index_start_position, None, None, dynamic_quant)
        self.assertException("input.dim() >= 2",
                             func, input.reshape(-1), input_scale, None, None, None, None,
                             None, None, dynamic_quant)
        self.assertException("output.dim() >= 2",
                             func, input, input_scale, None, token_count, gather_ids,
                             gather_index_start_position, tmo_output.reshape(-1), tmo_output_scale, True)
        self.assertException("input and output must have the same shape",
                             func, input.reshape(2, 512, -1), input_scale, None, None, None, None,
                             tmo_output.reshape(32, 32, -1), None, dynamic_quant)
        self.assertException("output_scale_shape must be equal to input_shape[0:-1]",
                             func, input.reshape(2, 512, -1), input_scale[0], None, None, None, None,
                             None, tmo_output_scale.reshape(32, 32), dynamic_quant)
        self.assertException("gather_index must exist if gather_index_start_position has value",
                             func, input, input_scale, None, token_count, None,
                             gather_index_start_position, None, None, True)
        self.assertException("gather_index.dim() == 1",
                             func, input, input_scale, None, token_count, gather_ids.reshape(1, -1),
                             gather_index_start_position, None, None, True)
    def test_perf_case(self):
        """Time moe_quantize over representative shapes (hardware-event based)
        and still validate outputs against the CPU reference."""
        num_tokens_list = [1, 72, 512]
        topk = 5
        hidden_size_list = [2048, 4096, 5120, 8192]
        # [num_expert, start_expert_id, expert_size]
        expert_options_list = [[8, 0, 8], [32, 24, 8]]
        multi_scale_list = [True, False]
        need_gather_list = [True, False]
        dynamic_quant_list = [True]
        dtype_list = [torch.half, torch.bfloat16]
        device = 'mlu'
        args = product(num_tokens_list, hidden_size_list, expert_options_list,\
                       multi_scale_list, need_gather_list, dynamic_quant_list, dtype_list)
        for num_tokens, hidden_size, expert_options, multi_scale, need_gather, dynamic_quant, dtype in args:
            num_expert = expert_options[0]
            start_expert_id = expert_options[1]
            expert_size = expert_options[2]
            if not multi_scale or not torch_mlu.mlu.is_bf16_supported():
                need_gather = False
            # bf16 in dtype_list: skip entirely on hardware without bf16.
            if not torch_mlu.mlu.is_bf16_supported():
                continue
            torch.manual_seed(444)
            inputs = gen_case(num_tokens,
                              topk,
                              hidden_size,
                              multi_scale,
                              need_gather,
                              num_expert,
                              expert_size,
                              start_expert_id,
                              dtype,
                              device)
            input = inputs[0]
            input_scale = inputs[1]
            gather_ids = inputs[2]
            token_count = inputs[3]
            cusum_token_count = inputs[4]
            gather_index_start_position = inputs[5]
            effective_count = inputs[6]
            print("num_tokens={}, hidden_size={}, num_expert={}, expert_size={}, "
                  "start_expert_id={}, multi_scale={}, need_gather={}, input_with_stride={}, "
                  "dynamic_quant={}, dtype={}, testing...".format(
                  num_tokens, hidden_size, num_expert, expert_size, start_expert_id, \
                  multi_scale, need_gather, False, dynamic_quant, dtype))
            torch_quant, torch_output_scale = self.op_impl_base(input,
                                                                input_scale,
                                                                None,
                                                                token_count,
                                                                gather_ids,
                                                                gather_index_start_position,
                                                                None,
                                                                None,
                                                                dynamic_quant)
            notify_start = torch.mlu.Event(enable_timing=True)
            notify_end = torch.mlu.Event(enable_timing=True)
            notify_start.record()
            loop = 10
            for _ in range(loop):
                if dynamic_quant:
                    tmo_output, tmo_output_scale = \
                        tmo.moe_quantize(input, input_scale, None, token_count,
                                         gather_ids, gather_index_start_position,
                                         None, None, dynamic_quant)
                else:
                    tmo_output, = \
                        tmo.moe_quantize(input, input_scale, None, token_count,
                                         gather_ids, gather_index_start_position,
                                         None, None, dynamic_quant)
            notify_end.record()
            notify_end.synchronize()
            # Average device time per launch in microseconds.
            time = notify_start.hardware_time(notify_end) / loop
            tmo_output = tmo_output[:effective_count]
            if tmo_output_scale is not None:
                tmo_output_scale = tmo_output_scale[:effective_count]
            print("time is: {:.1f}us".format(time))
            self.assertTensorsEqual(torch_quant.cpu().reshape(-1).float(),
                                    tmo_output.cpu().reshape(-1).float(),
                                    0.01, use_MSE=True, use_RAE=True)
            if dynamic_quant:
                self.assertTensorsEqual(torch_output_scale.cpu().reshape(-1).float(),
                                        tmo_output_scale.cpu().reshape(-1).float(),
                                        0.003, use_MSE=True, use_RAE=True)
    def test_inductor(self):
        """opcheck coverage for the underlying torch_mlu_ops.smooth_quant op in
        its quantize / per_token_smooth_quantize / moe_quantize configurations."""
        multi_scale_list = [True, False]
        need_gather_list = [True, False] if 'MLU370' not in torch_mlu.mlu.get_device_name() else [False]
        dynamic_quant_list = [True, False]
        num_tokens, hidden_size, num_expert, start_expert_id, expert_size, topk, \
            dtype, device = 1, 2048, 32, 24, 8, 5, torch.float16, 'mlu'
        params = product(multi_scale_list, need_gather_list, dynamic_quant_list)
        # quantize
        print(f"check ops.quantize...")
        input, input_scale, gather_ids, token_count, _, _, _ = gen_case(num_tokens,
                                                                        topk,
                                                                        hidden_size,
                                                                        False,
                                                                        False,
                                                                        num_expert,
                                                                        expert_size,
                                                                        start_expert_id,
                                                                        dtype,
                                                                        device)
        output = torch.empty(input.size(), dtype=torch.int8, device=device)
        args = (input, input_scale, output, torch.Tensor(), None, None, None, None, 'per_token', False)
        self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args)
        # per_token_smooth_quantize
        print(f"check ops.per_token_smooth_quantize...")
        output_scale = torch.empty(input.size()[:-1], dtype=input_scale.dtype, device=device)
        args = (input, input_scale, output, output_scale, None, token_count, None, None, 'per_token', True)
        self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args)
        # moe_quantize
        for multi_scale, need_gather, dynamic_quant in params:
            if not multi_scale and need_gather:
                continue
            print(f"check ops.moe_quantize multi_scale: {multi_scale}, need_gather: {need_gather}, dynamic_quant: {dynamic_quant} ...")
            input, input_scale, gather_ids, token_count, _, \
                gather_index_start_position, _ = gen_case(num_tokens,
                                                          topk,
                                                          hidden_size,
                                                          multi_scale,
                                                          need_gather,
                                                          num_expert,
                                                          expert_size,
                                                          start_expert_id,
                                                          dtype,
                                                          device)
            # When gathering, the output holds one row per (token, expert) pair.
            output_shape = list(input.size())
            output_scale_shape = list(input.size()[:-1])
            if gather_ids is not None:
                output_tokens = gather_ids.size(0)
                output_shape[0] = output_tokens
                output_scale_shape[0] = output_tokens
            output = torch.empty(output_shape, dtype=torch.int8, device=device)
            output_scale = torch.empty(output_scale_shape, dtype=input_scale.dtype, device=device) if dynamic_quant else None
            args = (input, input_scale,
                    output, output_scale, None,
                    token_count, gather_ids, gather_index_start_position,
                    'per_token', dynamic_quant)
            self.base_opcheck(torch.ops.torch_mlu_ops.smooth_quant, args)
# Entry point: run the suite via the project's unittest driver and propagate
# its status code to the shell.
if __name__ == '__main__':
    exit(run_unittest(TestSmoothQuantOp))

View File

@@ -0,0 +1,217 @@
import torch
from torch_mlu import mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from itertools import product
from torch.nn import functional as F
def gen_mix_w4w8_param(bc, seq, k, n, experts_num, topk, data_type, has_bias, quant_wise):
    """Build random arguments for a mixed w4/w8 smooth-quant group GEMM case.

    Each expert's weight is split into ``k // quant_wise`` groups; every group
    is randomly tagged 4-bit or 8-bit (``quant_flag``), and the weight buffer
    ``b`` is sized accordingly: sum(bits)//4 * (quant_wise//2) * n int8
    elements, i.e. k*n bytes for an all-8-bit expert and k*n/2 for all-4-bit.
    Returns the positional tuple consumed by smooth_quant_group_gemm:
    (a, b, token_count, expand_idx=None, c, alpha=None, beta=None,
     a_scale, b_scale, dtype, max_m, bias, quant_flag).
    """
    bs = bc * seq
    token_topk = bs * topk
    # Random token→expert routing; tokens are sorted by expert id so rows for
    # one expert are contiguous.
    expert_id = torch.randint(experts_num, (token_topk,), device="mlu")
    sorted_expert_id, indices = expert_id.sort()
    gather_idx = indices // topk
    token_count = torch.bincount(sorted_expert_id, minlength=experts_num).to(torch.int32)
    quant_group = k // quant_wise
    # NOTE(review): relies on ``random`` being in scope via
    # ``from common_utils import *`` — this file has no direct import; confirm.
    quant_flag = random.choices([4,8], k=experts_num * quant_group)
    b_count = (sum(quant_flag) // 4) * (quant_wise // 2) * n
    b = torch.randint(-128, 127, (b_count,), dtype=torch.int32, device="mlu").to(torch.int8)
    b_scale = torch.normal(0, 0.01, (quant_group, experts_num, n), device="mlu", dtype=torch.float32)
    a = torch.randint(-128, 127, (bs, k), device="mlu", dtype=torch.int32).to(torch.int8)
    # Expand activations and their per-row scales to one row per routed token.
    a = a[gather_idx]
    a_scale = torch.normal(0, 0.01, (bs,), device="mlu", dtype=torch.float32)
    a_scale = a_scale[gather_idx]
    c = torch.randn(token_topk, n, device="mlu", dtype=data_type)
    bias = torch.randn(experts_num, n, device="mlu", dtype=data_type) if has_bias else None
    return a, b, token_count, None, c, None, None, a_scale, b_scale, data_type, bs, bias, quant_flag
def gen_tensor(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias=False, quant_bit = 8, quant_group = 1):
    """Build random arguments for a uniform-bit smooth-quant group GEMM case.

    ``idx_mode`` True keeps the activations un-gathered and returns the gather
    index for the kernel to apply; False pre-gathers rows on the host. Weights
    are quantized row-wise (optionally in ``quant_group`` groups) and, for
    4-bit, packed two values per int8 via PairlyPackInt8. Returns the same
    positional tuple shape as gen_mix_w4w8_param, with quant_flag=None.
    """
    bs = batch * seq
    token_topk = bs * topk
    # Random token→expert routing, sorted so each expert's rows are contiguous.
    expert_id = torch.randint(experts_num, (token_topk,), device="mlu")
    sorted_expert_id, indices = expert_id.sort()
    gather_idx = indices // topk
    gather_idx = gather_idx.to(torch.int32)
    token_count = torch.bincount(sorted_expert_id, minlength=experts_num).to(torch.int32)
    a = torch.randn(bs, k, device="mlu", dtype=data_type)
    if not idx_mode:
        # Pre-gather on the host; the kernel then receives expand_idx=None.
        a = a[gather_idx]
    b = torch.randn(experts_num, n, k, device="mlu", dtype=data_type)
    c = torch.randn(token_topk, n, device="mlu", dtype=data_type)
    a, a_scale = QuantByRow(a, 8)
    if idx_mode:
        # Scales are expanded per routed token even when a itself is not.
        a_scale = a_scale[gather_idx]
    b_shape = b.shape
    b, b_scale = QuantByRow(b.view(-1, b.shape[-1]), quant_bit, quant_group)
    b = b.view(b_shape)
    if quant_bit == 4:
        b = PairlyPackInt8(b)
    # Grouped scales are laid out (quant_group, experts, n); flat otherwise.
    b_scale = b_scale.view(experts_num, -1) if quant_group == 1 else b_scale.view(experts_num, -1, quant_group).permute(2, 0, 1).contiguous()
    alpha = None
    beta = None
    bias = torch.randn(experts_num, n, device="mlu", dtype=data_type) if has_bias else None
    gather_idx_ = gather_idx if idx_mode else None
    quant_flag = None
    return a, b, token_count, gather_idx_, c, alpha, beta, a_scale, b_scale, data_type, bs, bias, quant_flag
class TestSmoothQuantGroupGemmOp(BtTestCase):
    def run_gen_case(self, dic):
        # Replay a dumped group-gemm case: on the dump path the raw dict values
        # are passed positionally; otherwise tensors are rebuilt from their
        # descriptors. NOTE(review): the dump path relies on dict insertion
        # order matching launch()'s positional signature — confirm.
        dump_data = dic.pop('dump_data')
        if dump_data:
            self.launch(*dic.values())
        else:
            a = create_tensor_from_dic(dic['a'])
            b = create_tensor_from_dic(dic['b'])
            m_list = dic['m_list']['data']
            expand_idx = dic['expand_idx']['data']
            c = None if dic['c']['data'] is None else create_tensor_from_dic(dic['c'])
            alpha = None if dic['alpha']['data'] is None else create_tensor_from_dic(dic['alpha'])
            beta = None if dic['beta']['data'] is None else create_tensor_from_dic(dic['beta'])
            # Scales are regenerated with a small normal distribution rather
            # than restored verbatim.
            a_scale = create_tensor_from_dic(dic['a_scale'], 0, 0.01)
            b_scale = create_tensor_from_dic(dic['b_scale'], 0, 0.01)
            dtype = dic['dtype']['data']
            max_m = dic['max_m']['data']
            bias = None if dic['bias']['data'] is None else create_tensor_from_dic(dic['bias'])
            quant_flag = dic['quant_flag']['data']
            self.launch(a, b, m_list, expand_idx, c, alpha, beta, a_scale, b_scale, dtype, max_m, bias, quant_flag)
def launch(self, *args):
total_m = args[2].sum().item()
torch_out = self.op_impl_base(*args)
tmo_out = ops.smooth_quant_group_gemm(*args)
self.assertTensorsEqual(tmo_out.cpu().float()[0:total_m], torch_out.cpu().float()[0:total_m], 0.006, use_MSE=True)
def op_impl_base(self, *args):
a, b, m_list, expand_idx, c, alpha, beta, a_scale, b_scale, dtype, max_m, bias, quant_flag = args
a = a.reshape(-1, a.size(-1))
if expand_idx is not None:
a = a[expand_idx]
total_m = m_list.sum().item()
a_list = a[:total_m].split(tuple(m_list))
c_list = []
if c is not None:
c = c.reshape(-1, c.size(-1))
c_list = c[:total_m].split(tuple(m_list))
if a_scale is not None:
a_scale_list = a_scale[:total_m].split(tuple(m_list))
k = a.shape[1]
n = b.size(1) if quant_flag is None else b_scale.shape[2]
if b_scale is not None and b_scale.dim() == 3: # for quant_grouped
b_scale = b_scale.transpose(0, 1).contiguous()
if quant_flag is not None:
quant_group = b_scale.shape[1]
group_wise = k // quant_group
quant_flag = torch.tensor(quant_flag).view(-1, quant_group)
b_offset_cu = torch.cumsum(quant_flag.sum(dim=1), dim=0) // 4 * (group_wise // 2) * n
b_offset_cu = torch.nn.functional.pad(b_offset_cu, (1,0), "constant", 0)
output_list = []
experts = b.size(0) if quant_flag is None else b_scale.size(0)
for i in range(experts):
if (a_list[i].size(0) > 0):
if a_scale is not None and b_scale is not None:
if quant_flag is None:
gemm_out = smooth_quant_matmul(a_list[i], a_scale_list[i], b[i], b_scale[i], dtype)
else:
gemm_out = smooth_quant_matmul_w4w8_mixed(a_list[i], a_scale_list[i],
b[b_offset_cu[i]:b_offset_cu[i+1]],
b_scale[i], dtype, quant_flag = quant_flag[i])
else:
gemm_out = F.linear(a_list[i], b[i])
if bias is not None:
gemm_out += bias[i]
if alpha is not None:
gemm_out *= alpha[i]
if beta is not None and c_list != []:
gemm_out += c_list[i] * beta[i]
output_list.append(gemm_out)
real_res = torch.cat(output_list, dim=0)
output = torch.empty(a.shape[0], n, device=real_res.device).to(real_res.dtype)
output[:total_m] = real_res
return output
def test_smooth_quant_group_gemm(self):
bs_list = [1, 3]
seq_list = [5, 8]
k_list = [512, 1024]
n_list = [512, 768, 2048]
expert_list = [8, 32]
topk_list = [2, 5]
dtype_list = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half]
idx_list = [False] if mlu.get_device_name() == 'MLU370' else [False, True]
has_bias_list = [True, False]
args = product(bs_list, seq_list, k_list, n_list, expert_list, topk_list, dtype_list, idx_list, has_bias_list)
for batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias in args:
print(f"bs: {batch}, seq_len: {seq}, k: {k}, n: {n}, experts_num: {experts_num}, topk: {topk}, \
dtype: {data_type}, idx_mode: {idx_mode}, has_bias: {has_bias} testing...", flush=True)
param = gen_tensor(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias)
torch_out = self.op_impl_base(*param)
tmo_out = ops.smooth_quant_group_gemm(*param)
self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
def test_sq_group_gemm_quant_group(self):
bs_list = [1, 3]
seq_list = [5, 8]
k_list = [512, 1024]
n_list = [512, 768, 2048]
expert_list = [8, 32]
topk_list = [2, 5]
dtype_list = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half]
idx_list = [False]
quant_bit_list = [4, 8]
quant_group_size_list = [128, 256]
has_bias_list = [True, False]
args = product(bs_list, seq_list, k_list, n_list, expert_list, topk_list, dtype_list, idx_list, has_bias_list, quant_bit_list, quant_group_size_list)
for batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias, quant_bit, quant_group_size in args:
print(f"bs: {batch}, seq_len: {seq}, k: {k}, n: {n}, experts_num: {experts_num}, \
topk: {topk}, dtype: {data_type}, idx_mode: {idx_mode}, has_bias:{has_bias}, quant_bit: {quant_bit}, quant_group_size: {quant_group_size} testing...", flush=True)
quant_group = k // quant_group_size
param = gen_tensor(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias, quant_bit, quant_group)
torch_out = self.op_impl_base(*param)
tmo_out = ops.smooth_quant_group_gemm(*param)
self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
def test_sq_group_gemm_w4w8_mixed(self):
bs_l = [1, 3]
seq_l = [5, 8]
k_l = [1024, 2048, 3072]
n_l = [512, 768, 2048]
expert_l = [8, 32]
topk_l = [2, 5]
dtype_l = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half]
has_bias_l = [True, False]
group_wise_l = [128, 256, 512]
args = product(bs_l, seq_l, k_l, n_l, expert_l, topk_l, dtype_l, has_bias_l, group_wise_l)
for bc, seq, k, n, experts, topk, data_type, has_bias, group_wise in args:
print(f"bs: {bc}, seq_len: {seq}, k: {k}, n: {n}, experts: {experts}, \
topk: {topk}, dtype: {data_type}, has_bias: {has_bias}, group_wise: {group_wise}, testing...", flush=True)
param = gen_mix_w4w8_param(bc, seq, k, n, experts, topk, data_type, has_bias, group_wise)
torch_out = self.op_impl_base(*param)
tmo_out = ops.smooth_quant_group_gemm(*param)
self.assertTensorsEqual(tmo_out.cpu().float(), torch_out.cpu().float(), 0.006, use_MSE=True)
def test_inductor(self):
batch, seq, k, n, experts_num, topk = 1, 8, 1024, 2048, 8, 5
dtype_list = [torch.half, torch.bfloat16] if mlu.is_bf16_supported() else [torch.half]
idx_list = [False] if mlu.get_device_name() == 'MLU370' else [False, True]
has_bias_list = [True, False]
args = product( dtype_list, idx_list, has_bias_list)
for data_type, idx_mode, has_bias in args:
args = gen_tensor(batch, seq, k, n, experts_num, topk, data_type, idx_mode, has_bias)
new_args = list(args)[:9]
new_args.extend([args[-2], "half" if data_type == torch.half else "bfloat16", args[-1], None, args[-3]])
self.base_opcheck(torch.ops.torch_mlu_ops.group_gemm, new_args)
if __name__ == '__main__':
    # Run the suite and propagate its status code to the shell (exit() raises
    # SystemExit with the given code; raising it directly is equivalent).
    raise SystemExit(run_unittest(TestSmoothQuantGroupGemmOp))

View File

@@ -0,0 +1,101 @@
import torch
import torch_mlu
import torch_mlu_ops as ops
from common_utils import *
from itertools import product
# torch dtype -> string spelling expected by the custom-op schema.
dtype_dict = {torch.half: "half", torch.bfloat16: "bfloat16"}
# Dtypes under test; bfloat16 joins the sweep only on hardware that has it.
dtype_list = [torch.half] + ([torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [])
# M<=INT32_MAX, nk<=INT32_MAX, k>=16, n>=16
class TestSmoothQuantMatmulOp(BtTestCase):
    """Checks ops.smooth_quant_matmul against a dequantize-then-matmul torch
    reference across shapes, bias/residual options, activations and dtypes."""

    def op_impl_base(self, *args):
        """Reference: per-row dequantize a and b, matmul, then the epilogue
        (bias add, alpha scale, beta-weighted residual, optional activation)."""
        (lhs, lhs_scale, rhs, rhs_scale, dtype, bias, residual_in,
         act_mode, alpha, beta, use_hp_active) = args
        if lhs_scale is not None:
            lhs = (lhs * lhs_scale.unsqueeze(-1)).to(dtype)
        if rhs_scale is not None:
            rhs = (rhs * rhs_scale.unsqueeze(-1)).to(dtype)
        result = torch.matmul(lhs, rhs.transpose(0, 1))
        if bias is not None:
            result += bias
        result = result * alpha
        if residual_in is not None:
            result = result + residual_in * beta
        if act_mode != "none":
            result = act_mode_dict[act_mode](result)
        return result

    def test_smooth_quant_matmul(self):
        """Sweep (m, n, k) x bias x residual x activation x dtype and compare
        the MLU op against the torch reference."""
        cases = product(
            [32, 64, 128],             # m
            [64, 128, 256],            # n
            [128, 256, 512],           # k
            [True, False],             # has_bias
            [True, False],             # has_c (residual input)
            ["none", "silu", "gelu"],  # activation
            dtype_list,
            [True, False],             # use_hp_active
        )
        for m, n, k, has_bias, has_c, act_mode, dtype, use_hp_active in cases:
            # Residual add is not combined with an activation.
            if has_c and act_mode != "none":
                continue
            activation = torch.randn(m, k, device="mlu", dtype=dtype)
            weight = torch.randn(n, k, device="mlu", dtype=dtype)
            bias = torch.randn(n, device="mlu", dtype=dtype) if has_bias else None
            c = torch.randn(m, n, device="mlu", dtype=dtype) if has_c else None
            # Smooth-quant: migrate scale from activation to weight, then
            # quantize both to int8 per row.
            smooth = torch.randn(k, device="mlu", dtype=torch.float).abs() + 0.1
            q_act, act_scale = QuantByRow(activation * smooth, 8)
            q_w, w_scale = QuantByRow(weight / smooth, 8)
            call = (q_act, act_scale, q_w, w_scale, dtype, bias, c,
                    act_mode, 1.0, 1.0, use_hp_active)
            expected = self.op_impl_base(*call)
            actual = ops.smooth_quant_matmul(*call)
            self.assertTensorsEqual(actual.cpu().float(), expected.cpu().float(), 0.006, use_MSE=True)

    def test_inductor(self):
        """Opcheck of the registered quant_matmul custom op for inductor."""
        for has_c, has_bias, act_mode, dtype in product(
                [True, False], [True, False], ["none", "silu", "gelu"], dtype_list):
            if has_c and act_mode != "none":
                continue
            print(f"===has_c: {has_c}, has_bias: {has_bias}, act_mode: {act_mode}, dtype: {dtype}===")
            M, K, N = 2, 16, 32
            quant_bit_size = 8
            use_hp_active = True
            act_coef = 1.
            alpha, beta = 0.8, 0.3
            trans_a, trans_b = False, True
            lhs = torch.randint(0, 10, (M, K), dtype=torch.int8).mlu()
            rhs = torch.randint(0, 10, (N, K), dtype=torch.int8).mlu()
            residual = torch.randn(M, N, device="mlu", dtype=dtype) if has_c else None
            bias = torch.randn(N, device="mlu", dtype=dtype) if has_bias else None
            lhs_scale = torch.randn(M, device="mlu", dtype=torch.float)
            rhs_scale = torch.randn(N, device="mlu", dtype=torch.float)
            op_args = [
                lhs, lhs_scale, None,        # a, a_scale, a_zero
                rhs, rhs_scale, None,        # b, b_scale, b_zero
                bias, residual, None, None,  # bias, c, c_scale, c_zero
                None, None,                  # gemm_output_scale, gemm_output_zero
                dtype_dict[dtype], None, "smooth_quant",
                "quantize_per_token", "quantize_per_channel",
                quant_bit_size, act_mode, use_hp_active, act_coef,
                alpha, beta, trans_a, trans_b,
            ]
            self.base_opcheck(torch.ops.torch_mlu_ops.quant_matmul, op_args)
if __name__ == '__main__':
    # Seed both RNGs so generated tensors are reproducible across runs, then
    # propagate the suite's status code via SystemExit (what exit() raises).
    random.seed(0)
    torch.manual_seed(0)
    raise SystemExit(run_unittest(TestSmoothQuantMatmulOp))

View File

@@ -0,0 +1,88 @@
import random
import torch
import torch_mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from itertools import product
import os
def gen_args(num_blocks, num_heads, block_size, head_size, dtype, cpy, num_pair):
    """Build (dst, src, src_to_dst) inputs for the swap_blocks tests.

    Args:
        num_blocks/num_heads/block_size/head_size: dimensions of the block
            cache tensors, shape (num_blocks, num_heads, block_size, head_size).
        dtype: element dtype; integer dtypes are filled with randint over the
            full dtype range, floating dtypes with randn.
        cpy: copy direction, one of "mlu to mlu", "mlu to cpu", "cpu to mlu".
        num_pair: number of (src block -> dst block) mapping entries.

    Returns:
        dst, src: independently randomized tensors on the devices implied by
            `cpy` (src device first in the string, dst device second).
        src_to_dst: dict mapping src block index 0..num_pair-1 to a distinct
            random dst block index (a shuffled permutation of the same range).
    """
    # (src_device, dst_device) per supported direction; replaces the original
    # six near-identical generation branches.
    directions = {
        "mlu to mlu": ("mlu", "mlu"),
        "mlu to cpu": ("mlu", "cpu"),
        "cpu to mlu": ("cpu", "mlu"),
    }
    if cpy not in directions:
        # Fixed typo in the original message ("unkown"); same fatal behavior.
        print("unknown copy direction.", flush=True)
        exit(1)
    shape = (num_blocks, num_heads, block_size, head_size)
    int_dtypes = {torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}

    def _rand_blocks():
        # One fresh random tensor per call so src and dst are independent.
        if dtype in int_dtypes:
            info = torch.iinfo(dtype)
            return torch.randint(info.min, info.max, size=shape, dtype=dtype)
        return torch.randn(size=shape, dtype=dtype)

    src_dev, dst_dev = directions[cpy]
    # Preserve the original creation order: src first, then dst.
    src = _rand_blocks().mlu() if src_dev == "mlu" else _rand_blocks().cpu()
    dst = _rand_blocks().mlu() if dst_dev == "mlu" else _rand_blocks().cpu()
    # Shuffled permutation: each src block id maps to a distinct dst block id.
    values = list(range(num_pair))
    random.shuffle(values)
    src_to_dst = {key: value for key, value in zip(range(num_pair), values)}
    return dst, src, src_to_dst
class TestSwapBlocksOp(BtTestCase):
    """Validates ops.swap_blocks against a dict-driven reference copy."""

    def op_impl_base(self, *args):
        """Reference: copy src[src_idx] into dst[dst_idx] for every mapping
        entry; returns the mutated dst."""
        dst, src, block_mapping = args
        for src_idx, dst_idx in block_mapping.items():
            dst[dst_idx] = src[src_idx]
        return dst

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test swap blocks due to ASan issues")
    def test_swap_blocks(self):
        """Sweep dtypes, copy directions and pair counts, comparing the op's
        result (and src's immutability) against the reference copy."""
        elem_types = [torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, torch.half, torch.float]
        if torch_mlu.mlu.is_bf16_supported():
            elem_types.append(torch.bfloat16)
        directions = ["mlu to mlu", "mlu to cpu", "cpu to mlu"]
        cases = product([3600], [8], [16], [64, 128], elem_types, directions, [6, 512])
        for num_blocks, num_heads, block_size, head_size, dtype, cpy, num_pair in cases:
            print("num_blocks: {}, num_heads: {}, block_size: {}, head_size: {}, dtype: {}, dir: {}, num_pairs: {} testing..."
                  .format(num_blocks, num_heads, block_size, head_size, dtype, cpy, num_pair), flush=True)
            dst, src, src_to_dst = gen_args(num_blocks, num_heads, block_size, head_size, dtype, cpy, num_pair)
            ref_src = src.clone()
            ref_dst = dst.clone()
            self.op_impl_base(ref_dst, ref_src, src_to_dst)  # reference copy
            ops.swap_blocks(dst, src, src_to_dst)            # op under test
            # src must be untouched and dst must match the reference exactly.
            self.assertTensorsEqual(src.cpu().float(), ref_src.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True)
            self.assertTensorsEqual(dst.cpu().float(), ref_dst.cpu().float(), 0, use_MSE=True, use_RAE=True, use_RMA=True)

    @unittest.skipIf('TMO_MEM_CHECK' in os.environ, "Skipping test_prevent due to ASan issues")
    def test_inductor(self):
        """Opcheck of the registered swap_blocks custom op."""
        dst, src, block_mapping = gen_args(3600, 8, 16, 64, torch.half, "mlu to mlu", 512)
        self.base_opcheck(torch.ops.torch_mlu_ops.swap_blocks, (dst, src, block_mapping))
if __name__ == '__main__':
    # Run the suite; SystemExit carries the status code (same as exit()).
    raise SystemExit(run_unittest(TestSwapBlocksOp))

View File

@@ -0,0 +1,150 @@
import torch
from torch_mlu import mlu
import unittest
import torch_mlu_ops as ops
from common_utils import *
from itertools import product
from torch.nn import functional as F
import time
def gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack):
    """Create (out, lse, block_out, block_lse, seq_offset, cu_seqs,
    block_cu_seqs) inputs for update_out_and_lse tests.

    pack=False: dense batch layout; cu_seqs/block_cu_seqs are None.
    pack=True: varlen layout packed by cumulative sequence lengths.
    (Function name spelling kept as-is — existing callers use it.)
    """
    if not pack:
        # Dense layout: every batch entry spans max_seq_len; one random
        # in-window offset per entry.
        out = torch.randn(batch, max_seq_len, head_num, head_size, device="mlu", dtype=dtype)
        lse = torch.randn(batch, head_num, max_seq_len, device="mlu", dtype=torch.float32)
        block_out = torch.randn(batch, block_seq_len, head_num, head_size, device="mlu", dtype=dtype)
        block_lse = torch.randn(batch, head_num, block_seq_len, device="mlu", dtype=torch.float32)
        seq_offset = torch.randint(low=0, high=(max_seq_len - block_seq_len + 1), size=(batch, ), dtype=torch.int32, device="mlu")
        return (out, lse, block_out, block_lse, seq_offset, None, None)
    # Packed layout: random per-batch lengths; block lengths are clamped so a
    # block never exceeds its sequence.
    seq_lens = torch.randint(low=1, high=(max_seq_len + 1), size=(batch, ), dtype=torch.int32)
    block_seq_lens = torch.minimum(seq_lens, torch.randint(low=1, high=(block_seq_len + 1), size=(batch, ), dtype=torch.int32))
    seq_offset = torch.zeros_like(seq_lens)
    for i in range(batch):
        # Keep each offset inside row i's valid window.
        seq_offset[i] = torch.randint(low=0, high=seq_lens[i] - block_seq_lens[i] + 1, size=(1,), dtype=torch.int32)
    seq_offset = seq_offset.mlu()
    # Exclusive-prefix cumulative lengths, int32, on device.
    cu_seqs = torch.cat((torch.tensor([0]), torch.cumsum(seq_lens, dim=0))).to(torch.int32).mlu()
    block_cu_seqs = torch.cat((torch.tensor([0]), torch.cumsum(block_seq_lens, dim=0))).to(torch.int32).mlu()
    out = torch.randn(torch.sum(seq_lens), head_num, head_size, device="mlu", dtype=dtype)
    lse = torch.randn(batch, head_num, max_seq_len, device="mlu", dtype=torch.float32)
    block_out = torch.randn(torch.sum(block_seq_lens), head_num, head_size, device="mlu", dtype=dtype)
    block_lse = torch.randn(batch, head_num, block_seq_len, device="mlu", dtype=torch.float32)
    return (out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)
class TestUpdateOutAndLse(BtTestCase):
    """Tests ops.update_out_and_lse (in-place combine of partial attention
    outputs with their log-sum-exp terms) against the torch reference."""

    def op_impl_base(self, *args):
        """Torch reference; returns (combined_out, combined_lse)."""
        out, lse, block_out, block_lse, seq_offsets, cu_seqs, block_cu_seqs = args
        return update_out_and_lse_torch(out, lse, block_out, block_lse, seq_offsets, cu_seqs, block_cu_seqs)

    def test_update_out_and_lse(self):
        """Randomized sweep over shapes, dtypes and packed/dense layouts."""
        test_case_num = 10
        dtype_choice = [torch.float16, torch.float32]
        # Bug fix: the original appended bfloat16 inside the sampling loop, so
        # every iteration added another bfloat16 entry and random.choice grew
        # increasingly biased toward bfloat16.  Append it exactly once.
        bf16_ok = torch_mlu.mlu.is_bf16_supported()
        if bf16_ok:
            dtype_choice.append(torch.bfloat16)
        random.seed(time.time())
        count = test_case_num
        while count > 0:
            batch = random.randint(1, 16 + 1)
            head_num = random.randint(1, 16 + 1)
            head_size = random.randint(1, 512 + 1)
            block_seq_len = random.choice([1, random.randint(2, 2048 + 1)])
            max_seq_len = 1 if block_seq_len == 1 else max(random.randint(2, 2048 + 1), block_seq_len)
            pack = random.choice([True, False])
            # Skip shapes that could exhaust MLU device memory.
            if batch * head_num * head_size * max_seq_len > 10 * 1024 * 1024 * 1024:
                continue
            # Platforms without bf16 support get a tighter memory budget.
            if not bf16_ok and batch * head_num * head_size * max_seq_len > 1 * 1024 * 1024 * 1024:
                continue
            dtype = random.choice(dtype_choice)
            # Fixed log tag: the original was missing the opening bracket.
            print(f"[test_update_out_and_lse] {batch}, {head_num}, {head_size}, {block_seq_len}, {max_seq_len}, {dtype}, {pack}")
            out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \
                gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack)
            out_torch, lse_torch = self.op_impl_base(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)
            # The MLU op updates out/lse in place.
            ops.update_out_and_lse(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)
            self.assertTensorsEqual(lse.cpu().float(), lse_torch.cpu().float(), 0.005, use_MSE=True)
            self.assertTensorsEqual(out.cpu().float(), out_torch.cpu().float(), 0.005, use_MSE=True)
            count -= 1
        print(f"[test_update_out_and_lse] {test_case_num} cases test pass")

    def test_combine_ring_attn(self):
        """Large packed case resembling ring-attention output combination."""
        batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack =\
            1, 8, 128, 8192, 8192, torch.bfloat16, True
        if not torch_mlu.mlu.is_bf16_supported():
            dtype = torch.float16
        print(f"[test_combine_ring_attn] {batch}, {head_num}, {head_size}, {block_seq_len}, {max_seq_len}, {dtype}, {pack}")
        out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \
            gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack)
        # total_seq / block_total_seq are the sums of the real per-batch
        # sequence lengths (packed varlen layout):
        # out: [sum(out_seqs), 64, 128]  block_out: [sum(block_out_seqs), 64, 128]
        # lse: [1, 8, 8192]  block_lse: [1, 8, 8192]
        # seq_offset: [128]
        # cu_seqs / block_cu_seqs: [128 + 1]
        out_torch, lse_torch = self.op_impl_base(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)
        ops.update_out_and_lse(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)
        self.assertTensorsEqual(lse.cpu().float(), lse_torch.cpu().float(), 0.005, use_MSE=True)
        self.assertTensorsEqual(out.cpu().float(), out_torch.cpu().float(), 0.005, use_MSE=True)
        print("[test_combine_ring_attn] pass")

    def test_combine_decoder_attn(self):
        """Dense decode-time case: single-token sequences (seq_len == 1)."""
        batch_list = [16, 128]
        head_num_list = [64]
        head_size_list = [128]
        block_seq_len_list = [1]
        max_seq_len_list = [1]
        dtype_list = [torch.float16]
        pack_list = [False]
        if torch_mlu.mlu.is_bf16_supported():
            dtype_list.append(torch.bfloat16)
        args = product(batch_list, head_num_list, head_size_list, block_seq_len_list, max_seq_len_list,
                       dtype_list, pack_list)
        for batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack in args:
            print(f"[test_combine_decoder_attn] {batch}, {head_num}, {head_size}, {block_seq_len}, {max_seq_len}, {dtype}, {pack}")
            out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \
                gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack)
            # out: [128, 1, 64, 128]  block_out: [128, 1, 64, 128]
            # lse: [128, 64, 1]  block_lse: [128, 1]
            # seq_offset / cu_seqs / block_cu_seqs: None (dense layout)
            out_torch, lse_torch = self.op_impl_base(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)
            ops.update_out_and_lse(out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs)
            self.assertTensorsEqual(lse.cpu().float(), lse_torch.cpu().float(), 0.005, use_MSE=True)
            self.assertTensorsEqual(out.cpu().float(), out_torch.cpu().float(), 0.005, use_MSE=True)
        print("[test_combine_decoder_attn] pass")

    def test_inductor(self):
        """Opcheck of the registered custom op in both layouts."""
        pack_list = [True, False]
        for pack in pack_list:
            batch, head_num, head_size, dtype = 16, 8, 128, torch.float16
            if pack:
                block_seq_len, max_seq_len = 1024, 2048
            else:
                block_seq_len, max_seq_len = 1, 1
            out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs = \
                gen_tenosr(batch, head_num, head_size, block_seq_len, max_seq_len, dtype, pack)
            self.base_opcheck(torch.ops.torch_mlu_ops.update_out_and_lse, (out, lse, block_out, block_lse, seq_offset, cu_seqs, block_cu_seqs))
        print("[test_update_out_and_lse] test_inductor check pass")
if __name__ == '__main__':
    # Run the suite; SystemExit carries the status code (same as exit()).
    raise SystemExit(run_unittest(TestUpdateOutAndLse))

View File

@@ -0,0 +1,113 @@
import torch
import torch_mlu
import torch_mlu_ops as ops
from common_utils import *
from itertools import product
# torch dtype -> string name used by the custom-op schema.
dtype_dict = {torch.half: "half", torch.bfloat16: "bfloat16"}
# Sweep dtypes; bfloat16 is included only when the device supports it.
dtype_list = [torch.half] + ([torch.bfloat16] if torch_mlu.mlu.is_bf16_supported() else [])
# M<=INT32_MAX, nk<=INT32_MAX, k>=16, n>=16
class TestWeightOnlyQuantMatmulOp(BtTestCase):
    """Tests ops.weight_only_quant_matmul: floating activations times int8/int4
    quantized weights with per-channel or group-wise scales."""

    def op_impl_base(self, *args):
        """Torch reference: unpack/dequantize the weight, matmul, then the
        epilogue (bias add, alpha scale, beta-weighted residual, activation).
        """
        a, b, scale, zero, bias, c, act_mode, quant_bit_size, alpha, beta, use_hp_active = args
        if quant_bit_size == 4:
            # int4 weights arrive packed two-per-byte; unpack back to one int
            # value per element, restoring n rows.
            n = b.shape[0]
            b = UnpackInt4(b).view(n, -1)
        if scale is not None:
            if scale.dim() == 2:
                # Group-wise scales: one scale per (row, group).  Repeat each
                # scale across its k-group so it broadcasts elementwise over b.
                group_size = b.size(1) // scale.size(1)
                scale_bd = scale.unsqueeze(-1).repeat(1, 1, group_size).reshape(b.shape)
            else:
                # Per-channel scales: one scale per output row.
                scale_bd = scale.unsqueeze(-1)
            b = torch.mul(b, scale_bd).to(a.dtype)
        output = torch.matmul(a, b.permute(1,0))
        if bias is not None:
            output += bias
        output = torch.mul(output, alpha)
        if c is not None:
            residual = torch.mul(c, beta)
            output = torch.add(output, residual)
        if act_mode != "none":
            act = act_mode_dict[act_mode]
            output = act(output)
        return output

    def test_weight_only_quant_matmul(self):
        """Sweep shapes, group counts, quant bits, bias/residual/activation and
        dtypes, comparing the MLU op against the torch reference."""
        m_list = [32, 64, 128]
        n_list = [128, 256, 512]
        group_list = [1, 2]
        k_list = [128, 256, 512]
        quant_bit_list = [8, 4]
        has_bias_list = [True, False]
        has_c_list = [True, False]
        act_mode_list = ["none", "silu", "gelu"]
        use_hp_active_list = [True, False]
        args = product(m_list, n_list, k_list, group_list, quant_bit_list, has_bias_list, has_c_list, act_mode_list, dtype_list, use_hp_active_list)
        for m, n, k, group, quant_bit, has_bias, has_c, act_mode, dtype, use_hp_active in args:
            # Residual add is not combined with an activation.
            if has_c and act_mode != "none":
                continue
            a = torch.randn(m, k, device="mlu", dtype=dtype)
            b = torch.randn(n, k, device="mlu", dtype=dtype)
            bias, c = None, None
            zero = None
            if has_bias:
                bias = torch.randn(n, device="mlu", dtype=dtype)
            if has_c:
                c = torch.randn(m, n, device="mlu", dtype=dtype)
            quant_weight_int8, weight_scale = QuantByRow(b, quant_bit, group)
            if group != 1:
                # Group-wise quantization skips activation cases and keeps the
                # scale in the activation dtype.
                if act_mode != "none":
                    continue
                weight_scale = weight_scale.to(a.dtype)
            if quant_bit == 4:
                quant_weight_int4 = PairlyPackInt8(quant_weight_int8)
            args = (a, quant_weight_int4 if quant_bit == 4 else quant_weight_int8,
                    weight_scale, zero, bias, c, act_mode, quant_bit, 1.0, 1.0, use_hp_active)
            torch_output = self.op_impl_base(*args)
            tmo_output = ops.weight_only_quant_matmul(*args)
            self.assertTensorsEqual(tmo_output.cpu().float(), torch_output.cpu().float(), 0.004, use_MSE=True)

    def test_inductor(self):
        """Opcheck of the registered quant_matmul custom op in weight-only
        mode, covering residual and group-quant variants."""
        M, K, N, group_num = 2, 256, 32, 4
        quant_bit_size, act_mode, use_hp_active, act_coef, alpha, beta, trans_a, trans_b = 8, 'none', True, 1., 0.8, 0.3, False, True
        has_res_list = [True, False]
        group_quant_list = [True, False]
        args = product(has_res_list, group_quant_list, dtype_list)
        for has_res, group_quant, dtype in args:
            print(f"==has_res: {has_res}, group_quant: {group_quant}, dtype: {dtype}==")
            a = torch.randn((M, K), dtype=dtype).mlu()
            b = torch.randint(-128, 127, (N, K), dtype=torch.int8).mlu()
            c = torch.randn(M, N, device="mlu", dtype=dtype) if has_res else None
            bias = torch.randn(N, device="mlu", dtype=dtype)
            group_wise_scale = torch.randn((N, group_num), device="mlu", dtype=dtype)
            b_quant_layout = "quantize_group_wise" if group_quant else "quantize_per_channel"
            # Group-quant uses the (N, group_num) scale tensor; per-channel
            # instead provides a per-row output scale.
            b_scale = group_wise_scale if group_quant else None
            gemm_output_scale = None if group_quant else torch.randn(N, device="mlu", dtype=torch.float)
            a_scale, a_zero, b_zero, c_zero = None, None, None, None
            c_scale, gemm_output_zero = None, None
            quant_algo, a_quant_layout = "weight_only", "quantize_none"
            dtype_str = dtype_dict[dtype]
            args = [a, a_scale, a_zero,
                    b, b_scale, b_zero,
                    bias, c, c_scale, c_zero,
                    gemm_output_scale, gemm_output_zero,
                    dtype_str, None, quant_algo,
                    a_quant_layout, b_quant_layout,
                    quant_bit_size, act_mode, use_hp_active, act_coef,
                    alpha, beta, trans_a, trans_b,]
            self.base_opcheck(torch.ops.torch_mlu_ops.quant_matmul, args)
if __name__ == '__main__':
    # Deterministic RNG for reproducible tensors; SystemExit carries the
    # suite's status code to the shell (equivalent to exit()).
    random.seed(0)
    torch.manual_seed(0)
    raise SystemExit(run_unittest(TestWeightOnlyQuantMatmulOp))