This commit is contained in:
Chranos
2026-02-04 17:39:32 +08:00
parent 8511fe8530
commit 79dfc69789
299 changed files with 55927 additions and 0 deletions

View File

@@ -0,0 +1,54 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/generate_alibi_slope.mluh"
#include "register_api.h"
#include "utils.h"
namespace tmo {
namespace kernel_test_api {
// Test entry: generates ALiBi slopes shard-by-shard across tensor-parallel
// ranks and gathers them into one (batch, head_num) float32 tensor.
//
// @param true_seq_lens       per-batch sequence lengths (device tensor).
// @param batch               number of sequences.
// @param tp_head_num         heads owned by one tensor-parallel rank.
// @param tp_num              number of tensor-parallel ranks.
// @param use_dynamic         forwarded to the kernel (dynamic slope mode).
// @param max_sequence_length forwarded to the kernel.
// @return (batch, tp_head_num * tp_num) float32 tensor of slopes.
at::Tensor alibi_slope_test(const at::Tensor &true_seq_lens,
                            const int batch,
                            const int tp_head_num,
                            const int tp_num,
                            const bool use_dynamic,
                            const int max_sequence_length) {
  MLU_TENSOR_CHECK_FATAL(true_seq_lens);
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  const int head_num = tp_head_num * tp_num;
  // A plain torch::Device value is enough; no shared_ptr indirection needed.
  const torch::Device dev = true_seq_lens.device();
  torch::Tensor output = torch::zeros(
      {batch, head_num}, torch::dtype(torch::kFloat32).device(dev).requires_grad(false));
  for (int tp_id = 0; tp_id < tp_num; tp_id++) {
    // Slopes of one tensor-parallel shard: (batch, tp_head_num).
    torch::Tensor tp_slopes = torch::zeros(
        {batch, tp_head_num}, torch::dtype(torch::kFloat32).device(dev).requires_grad(false));
    TMO_KERNEL_CHECK_FATAL(invokeGenerateAlibiSlope(
        queue, tp_slopes.data_ptr(), true_seq_lens.data_ptr(), batch, tp_id * tp_head_num,
        tp_head_num, head_num, max_sequence_length, use_dynamic));
    // Scatter the shard into its column slice of the full output, row by row.
    for (int batch_id = 0; batch_id < batch; batch_id++) {
      CNRT_CHECK(
          cnrtMemcpyAsync((float *)output.data_ptr() + batch_id * head_num + tp_id * tp_head_num,
                          (float *)tp_slopes.data_ptr() + batch_id * tp_head_num,
                          tp_head_num * sizeof(float), queue, cnrtMemcpyDevToDev));
    }
  }
  // Drain the queue before tp_slopes buffers go back to the caching allocator
  // and before the caller reads the result: the copies above are async.
  cnrtQueueSync(queue);
  return output;
}
} // namespace kernel_test_api
} // namespace tmo

View File

@@ -0,0 +1,83 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/create_cos_sin_table.mluh"
#include <cnnl.h>
#include <torch/extension.h>
#include "register_api.h"
#include "utils.h"
namespace {
constexpr int DYNAMIC_NTK_SCALING = 2;
}
namespace tmo {
namespace kernel_test_api {
// Test entry: builds the rotary-embedding cos/sin table on device and prints
// kernel timing. With DYNAMIC_NTK_SCALING the table carries a leading batch
// dim: (batch, rotary_seq_len, rotary_stride); otherwise it is shared across
// the batch: (rotary_seq_len, rotary_stride).
//
// @param rotary_emb_alpha_cached per-batch NTK alpha cache (float, device);
//        also determines the output dtype and device.
// @param seq_lens int sequence lengths, size(0) == batch.
// @return the freshly allocated cos/sin table.
at::Tensor create_cos_sin_table_test(at::Tensor &rotary_emb_alpha_cached,
                                     const at::Tensor &seq_lens,
                                     const int max_position_embeddings,
                                     const int batch_stride,
                                     const int rotary_seq_len,
                                     const int rotary_dim,
                                     const int rotary_stride,
                                     const float rotary_scaling,
                                     const int rotary_scaling_type,
                                     const float rotary_base,
                                     const bool interleaved) {
  MLU_TENSOR_CHECK_FATAL(rotary_emb_alpha_cached);
  MLU_TENSOR_CHECK_FATAL(seq_lens);
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  // Hoist the repeated dtype conversion: the table matches the cache's dtype.
  const torch::Dtype scalar_type =
      torch::typeMetaToScalarType(rotary_emb_alpha_cached.dtype());
  const cnnlDataType_t dtype = torch2cnnlDataType(scalar_type);
  // A plain torch::Device value is enough; no shared_ptr indirection needed.
  const torch::Device dev = rotary_emb_alpha_cached.device();
  const int batch = seq_lens.size(0);
  const auto opts = torch::dtype(scalar_type).device(dev).requires_grad(false);
  torch::Tensor cos_sin_table =
      rotary_scaling_type == DYNAMIC_NTK_SCALING
          ? torch::zeros({batch, rotary_seq_len, rotary_stride}, opts)
          : torch::zeros({rotary_seq_len, rotary_stride}, opts);
  size_t io_bytes = 0;  // not estimated here; efficiency printout will be 0
  cnrtNotifier_t time_begin;
  cnrtNotifier_t time_end;
  CNRT_CHECK(cnrtNotifierCreate(&time_begin));
  CNRT_CHECK(cnrtNotifierCreate(&time_end));
  CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
  TMO_KERNEL_CHECK_FATAL(invokeCreateCosSinTable(
      queue, cos_sin_table.data_ptr(),
      reinterpret_cast<float *>(rotary_emb_alpha_cached.data_ptr()),
      reinterpret_cast<int *>(seq_lens.data_ptr()), max_position_embeddings, batch, batch_stride,
      rotary_seq_len, rotary_dim, rotary_stride, rotary_base, rotary_scaling, rotary_scaling_type,
      interleaved, dtype));
  CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
  cnrtQueueSync(queue);
  float usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
  // Fix: the notifiers were leaked on every call.
  CNRT_CHECK(cnrtNotifierDestroy(time_begin));
  CNRT_CHECK(cnrtNotifierDestroy(time_end));
  print_info(usec, io_bytes);
  return cos_sin_table;
}
} // namespace kernel_test_api
} // namespace tmo

View File

@@ -0,0 +1,48 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/embedding.mluh"
#include "register_api.h"
#include "utils.h"
namespace tmo {
namespace kernel_test_api {
// Test entry for the embedding lookup kernel.
// Presumably gathers weight rows for each token id in `input`, honoring the
// vocab partition [vocab_offset, vocab_offset + vocab_part) — exact semantics
// live in kernels/embedding.mluh.
//
// @param input  (batch, seq) token ids on device.
// @param weight (total_vocab_size, hidden_size) embedding table on device.
// @return (batch, seq, hidden_size) tensor with weight's dtype/device.
at::Tensor embedding_test(const at::Tensor &input,
                          const at::Tensor &weight,
                          const int vocab_offset,
                          const int vocab_part) {
  MLU_TENSOR_CHECK_FATAL(input);
  MLU_TENSOR_CHECK_FATAL(weight);
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  const int batch = input.size(0);
  const int seq = input.size(1);
  const int total_vocab_size = weight.size(0);
  const int hidden_size = weight.size(1);
  // Hoist the dtype conversion used both for CNNL and the output options.
  const torch::Dtype scalar_type = torch::typeMetaToScalarType(weight.dtype());
  const cnnlDataType_t dtype = torch2cnnlDataType(scalar_type);
  // A plain torch::Device value is enough; no shared_ptr indirection needed.
  torch::Tensor output = torch::zeros(
      {batch, seq, hidden_size},
      torch::dtype(scalar_type).device(input.device()).requires_grad(false));
  TMO_KERNEL_CHECK_FATAL(invokeEmbedding(queue, weight.data_ptr(), input.data_ptr(),
                                         output.data_ptr(), dtype, vocab_offset, vocab_part,
                                         total_vocab_size, hidden_size, batch * seq));
  return output;
}
} // namespace kernel_test_api
} // namespace tmo

View File

@@ -0,0 +1,80 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/operate_cu_seq_lens.mluh"
#include "register_api.h"
#include "utils.h"
namespace tmo {
namespace kernel_test_api {
// Test entry: times invokeSliceCuSeqlens, which writes a sliced copy of the
// int32 cumulative-seq-length array into `sliced_cu_seq_lens` (in-out).
//
// @param cu_seq_lens        source cumulative lengths (int32, device).
// @param sliced_cu_seq_lens destination buffer, filled by the kernel.
void slice_cu_seq_lens_test(const at::Tensor &cu_seq_lens,
                            at::Tensor &sliced_cu_seq_lens,
                            int batch,
                            int parallel_num) {
  MLU_TENSOR_CHECK_FATAL(cu_seq_lens);
  MLU_TENSOR_CHECK_FATAL(sliced_cu_seq_lens);
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  size_t io_bytes = 0;  // not estimated here; efficiency printout will be 0
  cnrtNotifier_t time_begin;
  cnrtNotifier_t time_end;
  CNRT_CHECK(cnrtNotifierCreate(&time_begin));
  CNRT_CHECK(cnrtNotifierCreate(&time_end));
  CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
  TMO_KERNEL_CHECK_FATAL(invokeSliceCuSeqlens(queue, (int *)cu_seq_lens.data_ptr(),
                                              (int *)sliced_cu_seq_lens.data_ptr(), batch,
                                              parallel_num));
  CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
  cnrtQueueSync(queue);
  float usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
  // Fix: the notifiers were leaked on every call.
  CNRT_CHECK(cnrtNotifierDestroy(time_begin));
  CNRT_CHECK(cnrtNotifierDestroy(time_end));
  print_info(usec, io_bytes);
}
// Test entry: times invokeGenerateCuSeqlens, which fills `gen_cu_seq_lens`
// (int32, device, in-out) with generated cumulative sequence lengths.
//
// @param is_causal_mask / is_kv_seq_len mode flags forwarded to the kernel;
//        semantics live in kernels/operate_cu_seq_lens.mluh.
void generate_cu_seq_lens_test(at::Tensor &gen_cu_seq_lens,
                               int seq_len,
                               int parallel_num,
                               bool is_causal_mask,
                               bool is_kv_seq_len) {
  MLU_TENSOR_CHECK_FATAL(gen_cu_seq_lens);
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  size_t io_bytes = 0;  // not estimated here; efficiency printout will be 0
  cnrtNotifier_t time_begin;
  cnrtNotifier_t time_end;
  CNRT_CHECK(cnrtNotifierCreate(&time_begin));
  CNRT_CHECK(cnrtNotifierCreate(&time_end));
  CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
  TMO_KERNEL_CHECK_FATAL(invokeGenerateCuSeqlens(queue, (int *)gen_cu_seq_lens.data_ptr(), seq_len,
                                                 parallel_num, is_causal_mask, is_kv_seq_len));
  CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
  cnrtQueueSync(queue);
  float usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
  // Fix: the notifiers were leaked on every call.
  CNRT_CHECK(cnrtNotifierDestroy(time_begin));
  CNRT_CHECK(cnrtNotifierDestroy(time_end));
  print_info(usec, io_bytes);
}
} // namespace kernel_test_api
} // namespace tmo

View File

@@ -0,0 +1,67 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/quantize.mluh"
#include "register_api.h"
#include "utils.h"
namespace tmo {
namespace kernel_test_api {
// Test entry: times invokeMluQuantizePerHead, which quantizes `src` into
// `dst` with per-head scales written to `scale` (all in-out device tensors);
// exact semantics live in kernels/quantize.mluh. The stride arguments let
// src/dst be strided views over (bs, seq_len, head_num, head_size).
void per_head_quantize_test(at::Tensor &dst,
                            at::Tensor &scale,
                            const at::Tensor &src,
                            int bs,
                            int seq_len,
                            int head_num,
                            int head_size,
                            int dst_bs_stride,
                            int dst_seq_stride,
                            int dst_head_stride,
                            int src_bs_stride,
                            int src_seq_stride,
                            int src_head_stride) {
  MLU_TENSOR_CHECK_FATAL(dst);
  MLU_TENSOR_CHECK_FATAL(scale);
  MLU_TENSOR_CHECK_FATAL(src);
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  const cnnlDataType_t dst_dtype = torch2cnnlDataType(torch::typeMetaToScalarType(dst.dtype()));
  const cnnlDataType_t scale_dtype =
      torch2cnnlDataType(torch::typeMetaToScalarType(scale.dtype()));
  const cnnlDataType_t src_dtype = torch2cnnlDataType(torch::typeMetaToScalarType(src.dtype()));
  // Fix: widen before multiplying so large shapes don't overflow 32-bit int.
  const size_t elem_count = static_cast<size_t>(bs) * seq_len * head_num * head_size;
  const size_t io_bytes = elem_count * (dtype_size_map[src_dtype] + dtype_size_map[dst_dtype]);
  cnrtNotifier_t time_begin;
  cnrtNotifier_t time_end;
  CNRT_CHECK(cnrtNotifierCreate(&time_begin));
  CNRT_CHECK(cnrtNotifierCreate(&time_end));
  CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
  TMO_KERNEL_CHECK_FATAL(invokeMluQuantizePerHead(
      queue, (void *)dst.data_ptr(), (void *)scale.data_ptr(), (void *)src.data_ptr(), dst_dtype,
      scale_dtype, src_dtype, bs, seq_len, head_num, head_size, dst_bs_stride, dst_seq_stride,
      dst_head_stride, src_bs_stride, src_seq_stride, src_head_stride));
  CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
  cnrtQueueSync(queue);
  float usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
  // Fix: the notifiers were leaked on every call.
  CNRT_CHECK(cnrtNotifierDestroy(time_begin));
  CNRT_CHECK(cnrtNotifierDestroy(time_end));
  print_info(usec, io_bytes);
}
} // namespace kernel_test_api
} // namespace tmo

View File

@@ -0,0 +1,82 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef TEST_KERNELS_PYTEST_REGISTER_API_H_
#define TEST_KERNELS_PYTEST_REGISTER_API_H_
#include <torch/extension.h>
#include <torch/torch.h>
namespace tmo {
namespace kernel_test_api {
// Embedding lookup test; returns (batch, seq, hidden_size) in weight's dtype.
at::Tensor embedding_test(const at::Tensor &input,
                          const at::Tensor &weight,
                          const int vocab_offset,
                          const int vocab_part);
// ALiBi slope generation test; returns (batch, tp_head_num * tp_num) float32.
at::Tensor alibi_slope_test(const at::Tensor &true_seq_lens,
                            const int batch,
                            const int tp_head_num,
                            const int tp_num,
                            const bool use_dynamic,
                            const int max_sequence_length);
// Rotary cos/sin table builder test; output shape depends on
// rotary_scaling_type (dynamic NTK adds a leading batch dim).
at::Tensor create_cos_sin_table_test(at::Tensor &rotary_emb_alpha_cached,
                                     const at::Tensor &seq_lens,
                                     const int max_position_embeddings,
                                     const int batch_stride,
                                     const int rotary_seq_len,
                                     const int rotary_dim,
                                     const int rotary_stride,
                                     const float rotary_scaling,
                                     const int rotary_scaling_type,
                                     const float rotary_base,
                                     const bool interleaved);
// Cumulative-seq-length slicing test; writes into sliced_cu_seq_lens.
void slice_cu_seq_lens_test(const at::Tensor &cu_seq_lens,
                            at::Tensor &sliced_cu_seq_lens,
                            int batch,
                            int parallel_num);
// Cumulative-seq-length generation test; writes into gen_cu_seq_lens.
void generate_cu_seq_lens_test(at::Tensor &gen_cu_seq_lens,
                               int seq_len,
                               int parallel_num,
                               bool is_causal_mask,
                               bool is_kv_seq_len);
// Per-head quantization test; writes quantized data to dst and scales to
// scale. Stride args allow strided (bs, seq, head, head_size) views.
void per_head_quantize_test(at::Tensor &dst,
                            at::Tensor &scale,
                            const at::Tensor &src,
                            int bs,
                            int seq_len,
                            int head_num,
                            int head_size,
                            int dst_bs_stride,
                            int dst_seq_stride,
                            int dst_head_stride,
                            int src_bs_stride,
                            int src_seq_stride,
                            int src_head_stride);
// Rotary embedding test; rotates input in place and returns it. Supports
// pack mode (3-D input + cu_seq_lens) and pad mode (4-D input).
at::Tensor rotary_embedding_test(const at::Tensor &input,
                                 const at::Tensor &sin_table,
                                 const at::Tensor &cos_table,
                                 const c10::optional<at::Tensor> &seq_offsets,
                                 const c10::optional<at::Tensor> &cu_seq_lens,
                                 const bool &interleaved,
                                 const bool &discrete,
                                 const bool &dynamic_ntk,
                                 const bool &rope_2d,
                                 const int &max_seq_len,
                                 const bool &no_offset);
} // namespace kernel_test_api
} // namespace tmo
#endif // TEST_KERNELS_PYTEST_REGISTER_API_H_

View File

@@ -0,0 +1,29 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <torch/extension.h>
#include "register_api.h"
// Python bindings for the MLU kernel unit-test entry points declared in
// register_api.h; builds the `btunittests` extension module.
PYBIND11_MODULE(btunittests, m) {
  m.def("embedding_test", &tmo::kernel_test_api::embedding_test, "Test embedding kernel function.");
  m.def("alibi_slope_test", &tmo::kernel_test_api::alibi_slope_test,
        "Test alibi_slope kernel function.");
  m.def("create_cos_sin_table_test", &tmo::kernel_test_api::create_cos_sin_table_test,
        "Test create_cos_sin_table_test kernel function.");
  m.def("slice_cu_seq_lens_test", &tmo::kernel_test_api::slice_cu_seq_lens_test,
        "Test slice_cu_seq_lens_test kernel function.");
  m.def("generate_cu_seq_lens_test", &tmo::kernel_test_api::generate_cu_seq_lens_test,
        "Test generate_cu_seq_lens_test kernel function.");
  m.def("per_head_quantize_test", &tmo::kernel_test_api::per_head_quantize_test,
        "Test per_head_quantize_test kernel function.");
  m.def("rotary_embedding_test", &tmo::kernel_test_api::rotary_embedding_test,
        "Test rotary_embedding kernel function.");
}

View File

@@ -0,0 +1,202 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cnnl.h>
#include <torch/extension.h>
#include "kernels/rotary_embedding.mluh"
#include "register_api.h"
#include "utils.h"
namespace tmo {
namespace kernel_test_api {
// Test entry for the rotary-embedding kernels. Rotates `input` IN PLACE (the
// returned tensor aliases the input) and prints kernel timing.
//
// Input layouts:
//   pack mode: input is 3-D (total_seq_len, head_num, head_size) and
//              cu_seq_lens must be provided (batch = cu_seq_lens.size(0)-1).
//   pad mode:  input is 4-D (batch, seq_len, head_num, head_size),
//              cu_seq_lens must be absent, max_seq_len == input.size(1).
// Table layouts: (rotary_seq_len, rotary_dim), or with a leading batch dim
// when dynamic_ntk is set.
// rope_2d selects the GLM-6B 2-D rotary kernel.
at::Tensor rotary_embedding_test(const at::Tensor &input,
                                 const at::Tensor &sin_table,
                                 const at::Tensor &cos_table,
                                 const c10::optional<at::Tensor> &seq_offsets,
                                 const c10::optional<at::Tensor> &cu_seq_lens,
                                 const bool &interleaved,
                                 const bool &discrete,
                                 const bool &dynamic_ntk,
                                 const bool &rope_2d,
                                 const int &max_seq_len,
                                 const bool &no_offset) {
  MLU_TENSOR_CHECK_FATAL(input);
  MLU_TENSOR_CHECK_FATAL(sin_table);
  MLU_TENSOR_CHECK_FATAL(cos_table);
  const bool has_seq_offsets = seq_offsets.has_value();
  const bool has_cu_seq_lens = cu_seq_lens.has_value();
  int *seq_offsets_ptr =
      has_seq_offsets ? reinterpret_cast<int *>(seq_offsets.value().data_ptr()) : nullptr;
  int *cu_seq_lens_ptr =
      has_cu_seq_lens ? reinterpret_cast<int *>(cu_seq_lens.value().data_ptr()) : nullptr;
  int batch = 0;
  int total_seq_len = 0;
  const int head_size = input.size(-1);
  if (input.dim() == 3) {  // pack mode
    TORCH_CHECK(has_cu_seq_lens,
                "input has 3 dims: (total_seq_len, head_num, head_size),"
                " which means pack mode, cu_seqlens should not be None");
    total_seq_len = input.size(0);
    batch = cu_seq_lens.value().size(0) - 1;
  } else if (input.dim() == 4) {  // pad mode
    TORCH_CHECK(!has_cu_seq_lens,
                "input has 4 dims: (batch_size, seq_len, head_num, head_size),"
                " which means pad mode, cu_seqlens should be None");
    // Fix: message typo "equtals" -> "equal".
    TORCH_CHECK(max_seq_len == input.size(1),
                "input has 4 dims: (batch_size, seq_len, head_num, head_size),"
                " which means pad mode, max_seqlen must be equal to input.size(1)");
    batch = input.size(0);
    total_seq_len = batch * input.size(1);
  } else {
    TORCH_CHECK(false, "input only support 3 or 4 dims");
  }
  TORCH_CHECK(input.stride(-1) == 1, "input last dim must be contiguous");
  // With dynamic NTK the tables carry a leading batch dim, so the "row"
  // stride is stride(1) rather than stride(0).
  if (dynamic_ntk) {
    TORCH_CHECK(sin_table.stride(1) == cos_table.stride(1),
                "sin_table second stride must be equal to cos_table second stride");
  } else {
    // Fix: message said "cos_table second stride" but checks stride(0).
    TORCH_CHECK(sin_table.stride(0) == cos_table.stride(0),
                "sin_table first stride must be equal to cos_table first stride");
  }
  if (has_seq_offsets) {
    TORCH_CHECK(seq_offsets.value().is_contiguous(), "seq_offsets must be contiguous");
  }
  if (has_cu_seq_lens) {
    TORCH_CHECK(cu_seq_lens.value().is_contiguous(), "cu_seq_lens must be contiguous");
  }
  const int dims = input.dim();
  const int head_num = input.size(dims - 2);  // head_num_qk
  TORCH_CHECK(head_size <= 256, "only support input head_size <= 256");
  const int rotary_seq_len = dynamic_ntk ? sin_table.size(1) : sin_table.size(0);
  const int rotary_dim = dynamic_ntk ? sin_table.size(2) : sin_table.size(1);
  const int rotary_stride = dynamic_ntk ? sin_table.stride(1) : sin_table.stride(0);
  const size_t input_seq_stride = input.stride(dims - 3);
  const size_t input_head_stride = input.stride(dims - 2);
  cnnlHandle_t handle = torch_mlu::getCurrentHandle();
  cnrtQueue_t queue;
  cnnlGetQueue(handle, &queue);
  const cnnlDataType_t dtype = torch2cnnlDataType(torch::typeMetaToScalarType(input.dtype()));
  // The kernel writes into the same storage it reads: output aliases input.
  auto output = input;
  const size_t output_seq_stride = output.stride(dims - 3);
  const size_t output_head_stride = output.stride(dims - 2);
  // Estimate: read+write of the rotated slice plus sin/cos table reads, plus
  // the offsets/lengths array (per token when discrete, else per batch).
  size_t io_bytes =
      (size_t)total_seq_len * head_num * (head_size + rotary_dim) * dtype_size_map[dtype] * 2;
  io_bytes += (discrete && !no_offset) ? total_seq_len * sizeof(int) : batch * sizeof(int);
  cnrtNotifier_t time_begin;
  cnrtNotifier_t time_end;
  CNRT_CHECK(cnrtNotifierCreate(&time_begin));
  CNRT_CHECK(cnrtNotifierCreate(&time_end));
  CNRT_CHECK(cnrtPlaceNotifier(time_begin, queue));
  if (rope_2d) {
    TMO_KERNEL_CHECK_FATAL(invokeGlm6BRotaryEmbedding(
        queue, output.data_ptr(), input.data_ptr(), sin_table.data_ptr(), cos_table.data_ptr(),
        seq_offsets_ptr, cu_seq_lens_ptr, batch, max_seq_len, total_seq_len, head_num, head_size,
        rotary_seq_len, rotary_stride, input_seq_stride, input_head_stride, output_seq_stride,
        output_head_stride, interleaved, dtype));
  } else {
    TMO_KERNEL_CHECK_FATAL(invokeRotaryEmbedding(
        queue, output.data_ptr(), input.data_ptr(), sin_table.data_ptr(), cos_table.data_ptr(),
        seq_offsets_ptr, cu_seq_lens_ptr, batch, max_seq_len, head_num, head_size, rotary_seq_len,
        rotary_dim, rotary_stride, input_seq_stride, input_head_stride, output_seq_stride,
        output_head_stride, interleaved, discrete, dynamic_ntk, dtype));
  }
  CNRT_CHECK(cnrtPlaceNotifier(time_end, queue));
  cnrtQueueSync(queue);
  float usec = 0.0f;
  CNRT_CHECK(cnrtNotifierDuration(time_begin, time_end, &usec));
  // Fix: the notifiers were leaked on every call.
  CNRT_CHECK(cnrtNotifierDestroy(time_begin));
  CNRT_CHECK(cnrtNotifierDestroy(time_end));
  print_info(usec, io_bytes);
  return output;
}
} // namespace kernel_test_api
} // namespace tmo

View File

@@ -0,0 +1,133 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef TEST_KERNELS_PYTEST_UTILS_H_
#define TEST_KERNELS_PYTEST_UTILS_H_
#include <cnrt.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include "aten/cnnl/cnnlHandle.h"
#include "common/utils.h"
#include "framework/core/MLUStream.h"
#include "framework/core/caching_allocator.h"
#include "framework/core/device.h"
#include "framework/core/mlu_guard.h"
#include "kernels/kernel_utils.h"
namespace tmo {
// Maps a torch scalar type onto the matching CNNL data type.
// Throws std::runtime_error for types CNNL has no counterpart for.
static cnnlDataType_t torch2cnnlDataType(torch::Dtype scalar_type) {
  // Floating-point types.
  if (scalar_type == torch::kFloat32) return CNNL_DTYPE_FLOAT;
  if (scalar_type == torch::kFloat16) return CNNL_DTYPE_HALF;
  if (scalar_type == torch::kBFloat16) return CNNL_DTYPE_BFLOAT16;
  if (scalar_type == torch::kFloat64) return CNNL_DTYPE_DOUBLE;
  // Integral and boolean types.
  if (scalar_type == torch::kInt8) return CNNL_DTYPE_INT8;
  if (scalar_type == torch::kInt16) return CNNL_DTYPE_INT16;
  if (scalar_type == torch::kInt32) return CNNL_DTYPE_INT32;
  if (scalar_type == torch::kInt64) return CNNL_DTYPE_INT64;
  if (scalar_type == torch::kUInt8) return CNNL_DTYPE_UINT8;
  if (scalar_type == torch::kBool) return CNNL_DTYPE_BOOL;
  throw std::runtime_error("Unsupported torch::Dtype");
}
// Byte width of each cnnlDataType_t value, used as `dtype_size_map[dtype]`.
// NOTE(review): the original used C99-style designated array initializers
// ([CNNL_DTYPE_HALF] = 2, ...), which are a compiler extension in C++ and do
// not build portably across toolchains. A constexpr lookup object keeps the
// exact same indexing syntax in standard C++.
struct DtypeSizeMap {
  constexpr int operator[](const int dtype) const {
    switch (dtype) {
      case CNNL_DTYPE_INT8:
      case CNNL_DTYPE_UINT8:
      case CNNL_DTYPE_BOOL:
        return 1;
      case CNNL_DTYPE_HALF:
      case CNNL_DTYPE_BFLOAT16:
      case CNNL_DTYPE_INT16:
      case CNNL_DTYPE_UINT16:
        return 2;
      case CNNL_DTYPE_FLOAT:
      case CNNL_DTYPE_INT31:
      case CNNL_DTYPE_INT32:
      case CNNL_DTYPE_UINT32:
      case CNNL_DTYPE_COMPLEX_HALF:
        return 4;
      case CNNL_DTYPE_DOUBLE:
      case CNNL_DTYPE_INT64:
      case CNNL_DTYPE_UINT64:
      case CNNL_DTYPE_COMPLEX_FLOAT:
        return 8;
      default:
        return 0;  // CNNL_DTYPE_INVALID and gaps in the enum (e.g. value 10)
    }
  }
};
static constexpr DtypeSizeMap dtype_size_map{};
// Inverse of torch2cnnlDataType: maps a CNNL data type back onto the torch
// scalar type. Throws std::runtime_error for unmapped CNNL types.
static torch::Dtype cnnl2torchDataType(cnnlDataType_t cnnl_type) {
  // Floating-point types.
  if (cnnl_type == CNNL_DTYPE_FLOAT) return torch::kFloat32;
  if (cnnl_type == CNNL_DTYPE_HALF) return torch::kFloat16;
  if (cnnl_type == CNNL_DTYPE_BFLOAT16) return torch::kBFloat16;
  if (cnnl_type == CNNL_DTYPE_DOUBLE) return torch::kFloat64;
  // Integral and boolean types.
  if (cnnl_type == CNNL_DTYPE_INT8) return torch::kInt8;
  if (cnnl_type == CNNL_DTYPE_INT16) return torch::kInt16;
  if (cnnl_type == CNNL_DTYPE_INT32) return torch::kInt32;
  if (cnnl_type == CNNL_DTYPE_INT64) return torch::kInt64;
  if (cnnl_type == CNNL_DTYPE_UINT8) return torch::kUInt8;
  if (cnnl_type == CNNL_DTYPE_BOOL) return torch::kBool;
  throw std::runtime_error("Unsupported cnnlDataType_t");
}
// Queries the DDR bandwidth (GB/s) of the current MLU device via the CNDEV
// driver API. Aborts the process if any CNDEV call fails.
static float getBandWidth() {
  int card = -1;
  CNRT_CHECK(cnrtGetDevice(&card));
  // cndevInit must succeed before any other CNDEV call.
  if (cndevInit(0) != CNDEV_SUCCESS) {
    abort();
  }
  cndevDDRInfo_t ddrinfo;
  ddrinfo.version = CNDEV_VERSION_5;
  if (cndevGetDDRInfo(&ddrinfo, card) != CNDEV_SUCCESS) {
    abort();
  }
  double band_width = ddrinfo.bandWidth;
  double band_width_decimal = ddrinfo.bandWidthDecimal;
  // bandWidthDecimal presumably stores the fractional digits as an integer
  // (e.g. 25 -> .25) — scale it below 1 before adding; verify against the
  // CNDEV documentation. Note the do/while divides at least once, so an
  // already-fractional value would still be scaled down.
  do {
    band_width_decimal /= 10;
  } while (band_width_decimal > 1);
  return float(band_width + band_width_decimal);
}
// Prints kernel wall time, peak IO bandwidth, and the achieved/peak IO ratio.
// Units check: bytes / (usec * 1000 * GB/s) = bytes / (ns * bytes-per-ns),
// i.e. dimensionless efficiency.
// Fixes: avoid a per-line std::endl flush, and guard the division so a zero
// measured time or unknown bandwidth prints 0 instead of inf/nan.
static void print_info(float time_usec, size_t io_bytes) {
  const float io_bandwidth = getBandWidth();  // peak DDR bandwidth, GB/s
  std::cout << "kernel time: " << time_usec << "us\n";
  std::cout << "io_bandwidth: " << io_bandwidth << "GB/s\n";
  const float denom = time_usec * 1000 * io_bandwidth;
  const float efficiency = denom > 0.0f ? io_bytes / denom : 0.0f;
  std::cout << "IO efficiency: " << efficiency << std::endl;  // flush once
}
// Throws std::runtime_error if `tensor` does not live on the MLU device
// (i.e. it is a CPU or CUDA tensor).
// Fix: wrapped in do/while(0) so the macro expands to exactly one statement —
// the original bare `if` would silently capture a following `else` when used
// inside an if/else chain. Argument parenthesized for expression safety.
#define MLU_TENSOR_CHECK_FATAL(tensor)                                                \
  do {                                                                                \
    if ((tensor).device().type() == c10::kCPU ||                                      \
        (tensor).device().type() == c10::kCUDA) {                                     \
      throw std::runtime_error("Check failed: " #tensor " is not a MLU tensor.");     \
    }                                                                                 \
  } while (0)
} // namespace tmo
#endif // TEST_KERNELS_PYTEST_UTILS_H_