Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/rotary_embedding.mlu
2026-02-04 17:39:32 +08:00

629 lines
31 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <cstddef>
#include <type_traits>
#include "rotary_embedding.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
// Reserve 32 KiB of NRAM (presumably for stack/temporary use by the compiler —
// TODO confirm); everything else backs the shared scratch buffer below.
#define NRAM_REMAIN_SIZE (32 * 1024)
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)
// On-chip scratch buffer, partitioned ad hoc by each kernel below
// (q_1 / sincos / q_2 regions in cross mode, double buffers in fold mode).
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
// 32-element alternating 1/0 seed pattern; initMask() replicates it across
// rotary_dim to build the even-lane selector mask for the interleaved path.
__nram__ float nram_meta_mask[32] = {1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f,
                                     0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f,
                                     1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f};
// Mask storage filled by initMask():
//   interleaved: mask0 at +0 floats, mask1 (= mask0 - 1) at +512 floats;
//   fold:        a single [-1]*(d/2) ++ [1]*(d/2) mask at +0.
__nram__ float nram_mask[1024];
// Per-row byte offsets into the sin/cos tables for the gather paths
// (loaded from seq_offsets and scaled by rotary_stride * dtype_size).
__nram__ int nram_offsets[1024];
// Asynchronously stage one (seq_block x rotary_dim) tile of a sin/cos table
// from GDRAM into NRAM.
//  - !discrete: strided copy of consecutive rows; in decoder_mode the source
//    stride is 0, so the single table row is broadcast to all seq_block rows.
//  - discrete: rows are gathered using per-row BYTE offsets in nram_offset
//    (pre-scaled by the caller to rotary_stride * dtype_size). MLU arch >= 592
//    uses the hardware __gather_async; older archs fall back to one async
//    memcpy per row.
// NOTE(review): seq_begin is unused here — the caller already applies it when
// filling nram_offset; confirm before removing the parameter.
__mlu_func__ void loadTableAsync(void *nram_table,
                                 void *gdram_table,
                                 int *nram_offset,
                                 int rotary_dim,
                                 int rotary_stride,
                                 int seq_block,
                                 int seq_begin,
                                 int dtype_size,
                                 bool discrete,
                                 bool decoder_mode) {
  if (!discrete) {
    // src_stride == 0 re-reads the same source row for every destination row.
    int src_stride = decoder_mode ? 0 : rotary_stride * dtype_size;
    __memcpy_async(nram_table, gdram_table, rotary_dim * dtype_size, GDRAM2NRAM,
                   rotary_dim * dtype_size, src_stride, seq_block - 1);
  } else {
#if __BANG_ARCH__ >= 592
    __gather_async(nram_table, gdram_table, (uint32_t *)nram_offset, rotary_dim * dtype_size,
                   GDRAM2NRAM, rotary_dim * dtype_size, seq_block);
#else
    // No hardware gather on older archs: issue one async copy per row.
    for (int i = 0; i < seq_block; i++) {
      __memcpy_async((int8_t *)nram_table + i * rotary_dim * dtype_size,
                     (int8_t *)gdram_table + nram_offset[i], rotary_dim * dtype_size, GDRAM2NRAM);
    }
#endif
  }
}
// Widen `count` half/bfloat16 elements from `src` into float `dst`.
// Intentionally a no-op when T == float: callers in this file alias dst and
// src in that case (their (float_size - dtype_size) pointer offset collapses
// to zero), so the data is already in place as float.
template <typename T>
__mlu_func__ void toFloat(float *dst, T *src, int count) {
  if (std::is_same<T, bfloat16_t>::value) {
    __bang_bfloat162float(dst, (bfloat16_t *)src, count);
    return;
  }
  if (std::is_same<T, half>::value) {
    __bang_half2float(dst, (half *)src, count);
  }
}
// Narrow `count` floats in `src` back to half/bfloat16 in `dst` using the
// round-to-nearest (_rn) conversions. Intentionally a no-op when T == float:
// callers pass aliased pointers in that case, so the result is already stored.
template <typename T>
__mlu_func__ void floatTo(T *dst, float *src, int count) {
  if (std::is_same<T, bfloat16_t>::value) {
    __bang_float2bfloat16_rn((bfloat16_t *)dst, src, count);
    return;
  }
  if (std::is_same<T, half>::value) {
    __bang_float2half_rn((half *)dst, src, count);
  }
}
// Build the rotary selection masks in `mask` (backing store: nram_mask),
// stored as dtype T.
//  - interleaved: mask0 = [1,0,1,0,...] over rotary_dim elements, replicated
//    from the 32-element nram_meta_mask seed; mask1 = mask0 - 1 =
//    [0,-1,0,-1,...], placed 512 floats past mask0. Together they select the
//    even lanes and negate the odd lanes in crossRotaryEmbedding.
//  - non-interleaved (fold): a single mask [-1]*(d/2) ++ [1]*(d/2) used to
//    negate the swapped upper half in foldRotaryEmbedding.
template <typename T>
__mlu_func__ void initMask(float *mask, int rotary_dim, bool interleaved) {
  if (interleaved) {
    T *mask0 = (T *)mask;
    T *mask1 = (T *)(mask + 512);
    // Number of 32-element seed copies needed to cover rotary_dim.
    int seg = (rotary_dim + 31) / 32;
    __memcpy(mask0, nram_meta_mask, 32 * sizeof(float), NRAM2NRAM, 32 * sizeof(float), 0, seg - 1);
    // Convert the float seed in place to dtype T, then derive mask1 = mask0 - 1.
    floatTo((T *)mask0, (float *)mask0, rotary_dim);
    __bang_add_scalar(mask1, mask0, (T)-1, rotary_dim);
  } else {
    __bang_write_value((T *)mask, rotary_dim / 2, (T)-1);
    __bang_write_value((T *)mask + rotary_dim / 2, rotary_dim / 2, (T)1);
  }
}
/*
 * Interleaved ("cross") rotary embedding for one [head_num, seq_block,
 * rotary_dim] tile. For each adjacent pair (x0, x1) along rotary_dim:
 *   out0 = x0 * cos - x1 * sin
 *   out1 = x1 * cos + x0 * sin
 * i.e. out = x * cos + rotate(x) * sin with rotate(x) = [-x1, x0, -x3, x2, ...].
 *
 * NRAM layout (floats): q_1 | sincos | q_2, each (block_head + 2) elements,
 * followed by `temp` (used only when T == float). For half/bfloat16, raw T
 * data is loaded at the tail of each float region — offset
 * (float_size - dtype_size) * (block_head + 2) bytes — so it can be widened
 * to float in place. Elements beyond rotary_dim (head_size - rotary_dim per
 * row) are passed through unchanged via a GDRAM-to-GDRAM copy.
 */
template <typename T>
__mlu_func__ void crossRotaryEmbedding(T *output,
                                       T *input,
                                       T *sin_table,
                                       T *cos_table,
                                       int *seq_offsets,
                                       int head_num,
                                       int seq_block,
                                       int head_size,
                                       int rotary_dim,
                                       int rotary_stride,
                                       size_t input_head_stride,
                                       size_t input_seq_stride,
                                       size_t output_head_stride,
                                       size_t output_seq_stride,
                                       int seq_begin,
                                       bool discrete,
                                       bool decoder_mode = false) {
  int float_size = sizeof(float);
  int dtype_size = sizeof(T);
  int seq_rotary = seq_block * rotary_dim;
  int block_head = head_num * seq_rotary;
  float *q_1 = (float *)nram_buffer;
  float *sincos = q_1 + block_head + 2;
  float *q_2 = sincos + block_head + 2;
  T *temp = (T *)q_2 + block_head + 2;
  if (seq_offsets != nullptr && (discrete || decoder_mode)) {
    // Convert per-row table indices into byte offsets for the gather path.
    __memcpy(nram_offsets, seq_offsets + seq_begin, seq_block * sizeof(int), GDRAM2NRAM);
    __bang_mul_scalar(nram_offsets, nram_offsets, rotary_stride * dtype_size, seq_block);
  }
  bool gather_table = (seq_offsets != nullptr && decoder_mode) || discrete;
  T *mask0 = (T *)nram_mask;          // [1,0,1,0,...]  even-lane selector
  T *mask1 = (T *)(nram_mask + 512);  // mask0 - 1 = [0,-1,0,-1,...]
  // Dtype-T aliases at the tail of each float region (offset 0 when T == float).
  T *q_1_ = (T *)((int8_t *)q_1 + (float_size - dtype_size) * (block_head + 2));
  T *sincos_ = (T *)((int8_t *)sincos + (float_size - dtype_size) * (block_head + 2));
  T *q_2_ = (T *)((int8_t *)q_2 + (float_size - dtype_size) * (block_head + 2));
  // If dtype is float, temp points to a separate buffer and temp_ is temp;
  // if dtype is half/bfloat16, temp_ aliases (T *)q_2.
  T *temp_ = dtype_size == 4 ? temp : (T *)q_2;
  // Load input tile [head_num, seq_block, rotary_dim] via 4-D strided copy.
  __memcpy_async(q_1_, input, rotary_dim * dtype_size, GDRAM2NRAM, rotary_dim * dtype_size,
                 seq_block - 1, seq_rotary * dtype_size, head_num - 1,
                 input_seq_stride * dtype_size, seq_block - 1, input_head_stride * dtype_size,
                 head_num - 1);
  // Zero the 2-element tail written by the shifted add below.
  __bang_write_zero(q_2_ + block_head, 2);
  __sync();
  // Duplicate the input so it can be masked into the rotated form.
  __memcpy_async(q_2_, q_1_, block_head * dtype_size, NRAM2NRAM);
  // temp_ = x * [1,0,1,0,...] -> even lanes of x, zeros elsewhere.
  __bang_cycle_mul(temp_, q_1_, mask0, block_head, rotary_dim);
  __sync();
  // Load cos table (overlapped with the mask computations).
  loadTableAsync(sincos_, cos_table, nram_offsets, rotary_dim, rotary_stride, seq_block, seq_begin,
                 dtype_size, gather_table, decoder_mode);
  // q_2_ = x * [0,-1,0,-1,...] -> negated odd lanes of x, zeros elsewhere.
  __bang_cycle_mul(q_2_, q_2_, mask1, block_head, rotary_dim);
  // Shifted add: q_2_ becomes [0, -x1, x0, -x3, x2, ...]; combined with the
  // q_2 + 1 read below this realizes rotate(x) without an explicit shuffle.
  __bang_add(q_2_ + 2, temp_, q_2_ + 2, block_head);
  toFloat(q_1, q_1_, block_head);
  __sync();
  toFloat(sincos, sincos_, block_head);
  // q_1 = input * cos (cos broadcast per head via the seq_rotary cycle).
  __bang_cycle_mul(q_1, q_1, sincos, block_head, seq_rotary);
  __sync();
  toFloat(q_2, q_2_, block_head + 2);
  // Load sin table into the now-free sincos staging area.
  loadTableAsync(sincos_, sin_table, nram_offsets, rotary_dim, rotary_stride, seq_block, seq_begin,
                 dtype_size, gather_table, decoder_mode);
  __sync();
  toFloat(sincos, sincos_, block_head);
  // q_2 = rotate(x) * sin — the +1 read drops the leading zero of the shifted
  // vector, yielding [-x1, x0, -x3, x2, ...].
  __bang_cycle_mul(q_2, q_2 + 1, sincos, block_head, seq_rotary);
  // q_1 = input * cos + rotate(x) * sin, then narrow back to dtype T.
  __bang_add(q_1, q_1, q_2, block_head);
  floatTo((T *)q_1, q_1, block_head);
  if ((head_size - rotary_dim) > 0) {
    // Pass-through copy of the non-rotated tail of each head row.
    __memcpy_async(output + rotary_dim, input + rotary_dim, (head_size - rotary_dim) * dtype_size,
                   GDRAM2GDRAM, output_seq_stride * dtype_size, seq_block - 1,
                   output_head_stride * dtype_size, head_num - 1, input_seq_stride * dtype_size,
                   seq_block - 1, input_head_stride * dtype_size, head_num - 1);
  }
  // Store the rotated portion (synchronous copy, also fences the async ops).
  __memcpy(output, q_1, rotary_dim * dtype_size, NRAM2GDRAM, output_seq_stride * dtype_size,
           seq_block - 1, output_head_stride * dtype_size, head_num - 1, rotary_dim * dtype_size,
           seq_block - 1, seq_rotary * dtype_size, head_num - 1);
}
/*
 * Non-interleaved ("fold" / half-split) rotary embedding for one
 * [head_num, seq_block, rotary_dim] tile. Per row x (d = rotary_dim):
 *   out = x * cos + rotate_half(x) * sin
 *   rotate_half(x) = [-x[d/2:], x[:d/2]]
 * using nram_mask = [-1]*(d/2) ++ [1]*(d/2) prepared by initMask<T>(..., false).
 *
 * The loop is a 3-stage software pipeline: iteration i stores the result of
 * head-tile i-2, computes tile i-1, and loads tile i, double-buffering the
 * input/result through `buffer` with (i % 2). The sin/cos tables are loaded
 * once (at i == 1) and reused for every head tile. When loop_head is false,
 * all heads fit in a single tile (once_head_num == head_num, loop_num == 1).
 */
template <typename T>
__mlu_func__ void foldRotaryEmbedding(T *output,
                                      T *input,
                                      T *sin_table,
                                      T *cos_table,
                                      int *seq_offsets,
                                      int head_num,
                                      int seq_block,
                                      int head_size,
                                      int rotary_dim,
                                      int rotary_stride,
                                      size_t input_head_stride,
                                      size_t input_seq_stride,
                                      size_t output_head_stride,
                                      size_t output_seq_stride,
                                      int seq_begin,
                                      bool discrete,
                                      bool decoder_mode,
                                      bool loop_head,
                                      int once_head_num) {
  once_head_num = loop_head ? once_head_num : head_num;
  int loop_num = (head_num + once_head_num - 1) / once_head_num;
  // int head_per_loop = loop_head ? 1 : head_num;
  int seq_rotary = seq_block * rotary_dim;
  int block_head = once_head_num * seq_rotary;
  // Two ping-pong buffers when pipelining across head tiles, one otherwise.
  int buffer_blocks = loop_head ? 2 : 1;
  float *buffer = (float *)nram_buffer;
  float *q_2 = buffer + block_head * buffer_blocks;
  float *sin = q_2 + block_head;
  float *cos = sin + seq_rotary;
  int float_size = sizeof(float);
  int dtype_size = sizeof(T);
  // Dtype-T staging aliases at the tails of the float regions (offset 0 for float).
  T *sincos_ = (T *)((int8_t *)sin + (float_size - dtype_size) * seq_rotary * 2);
  T *q_2_ = (T *)((int8_t *)q_2 + (float_size - dtype_size) * block_head);
  if (seq_offsets != nullptr && (discrete || decoder_mode)) {
    // Convert per-row table indices into byte offsets for the gather path.
    __memcpy(nram_offsets, seq_offsets + seq_begin, seq_block * sizeof(int), GDRAM2NRAM);
    __bang_mul_scalar(nram_offsets, nram_offsets, rotary_stride * dtype_size, seq_block);
    __sync_io_move_compute();
  }
  bool gather_table = (seq_offsets != nullptr && decoder_mode) || discrete;
  int load_head_num = 0;
  int calc_head_num = 0;
  int store_head_num = 0;
  // Pipeline runs loop_num + 2 iterations so the last tiles drain through the
  // compute (i-1) and store (i-2) stages.
  for (int i = 0; i < loop_num + 2; i++) {
    // store stage: write out tile i-2
    if (i > 1) {
      store_head_num = std::min(once_head_num, head_num - (i - 2) * once_head_num);
      if ((head_size - rotary_dim) > 0) {
        // Pass-through copy of the non-rotated tail of each head row.
        __memcpy_async(output + (i - 2) * once_head_num * output_head_stride + rotary_dim,
                       input + (i - 2) * once_head_num * input_head_stride + rotary_dim,
                       (head_size - rotary_dim) * dtype_size, GDRAM2GDRAM,
                       output_seq_stride * dtype_size, seq_block - 1,
                       output_head_stride * dtype_size, store_head_num - 1,
                       input_seq_stride * dtype_size, seq_block - 1, input_head_stride * dtype_size,
                       store_head_num - 1);
      }
      float *nram_store = buffer + (i % 2) * block_head;
      __memcpy_async(output + (i - 2) * once_head_num * output_head_stride, nram_store,
                     rotary_dim * dtype_size, NRAM2GDRAM, output_seq_stride * dtype_size,
                     seq_block - 1, output_head_stride * dtype_size, store_head_num - 1,
                     rotary_dim * dtype_size, seq_block - 1, seq_block * rotary_dim * dtype_size,
                     store_head_num - 1);
    }
    // load stage: fetch tile i into the buffer the store stage just freed
    float *temp_load = buffer + (i % 2) * block_head;
    T *nram_load = (T *)((int8_t *)temp_load + (float_size - dtype_size) * block_head);
    if (i < loop_num) {
      load_head_num = std::min(once_head_num, head_num - i * once_head_num);
      __memcpy_async(nram_load, input + i * once_head_num * input_head_stride,
                     rotary_dim * dtype_size, GDRAM2NRAM, rotary_dim * dtype_size, seq_block - 1,
                     seq_block * rotary_dim * dtype_size, load_head_num - 1,
                     input_seq_stride * dtype_size, seq_block - 1, input_head_stride * dtype_size,
                     load_head_num - 1);
    }
    if (i == 1) {
      // One-time load of the sin and cos tables (cos right after sin).
      loadTableAsync(sincos_, sin_table, nram_offsets, rotary_dim, rotary_stride, seq_block,
                     seq_begin, dtype_size, gather_table, decoder_mode);
      loadTableAsync(sincos_ + seq_rotary, cos_table, nram_offsets, rotary_dim, rotary_stride,
                     seq_block, seq_begin, dtype_size, gather_table, decoder_mode);
    }
    // compute stage: process tile i-1 (in the other ping-pong buffer)
    if (i > 0 && i < loop_num + 1) {
      float *q_1 = buffer + ((i + 1) % 2) * block_head;
      T *q_1_ = (T *)((int8_t *)q_1 + (float_size - dtype_size) * block_head);
      calc_head_num = std::min(once_head_num, head_num - (i - 1) * once_head_num);
      // Swap the two halves of every row: q_2_ = [x_hi, x_lo].
      __memcpy_async(q_2_, q_1_ + rotary_dim / 2, rotary_dim / 2 * dtype_size, NRAM2NRAM,
                     rotary_dim * dtype_size, rotary_dim * dtype_size,
                     calc_head_num * seq_block - 1);
      __memcpy_async(q_2_ + rotary_dim / 2, q_1_, rotary_dim / 2 * dtype_size, NRAM2NRAM,
                     rotary_dim * dtype_size, rotary_dim * dtype_size,
                     calc_head_num * seq_block - 1);
      __sync_move();
      toFloat(q_1, q_1_, block_head);
      // Apply [-1...,1...]: q_2_ = [-x_hi, x_lo] = rotate_half(x).
      __bang_cycle_mul(q_2_, q_2_, (T *)nram_mask, block_head, rotary_dim);
      toFloat(q_2, q_2_, block_head);
      if (i == 1) {
        // Wait for the table loads issued above, then widen sin+cos together.
        __sync_io();
        toFloat(sin, sincos_, seq_rotary * 2);
      }
      // out = x * cos + rotate_half(x) * sin, narrowed back to dtype T.
      __bang_cycle_mul(q_1, q_1, cos, block_head, seq_rotary);
      __bang_cycle_mul(q_2, q_2, sin, block_head, seq_rotary);
      __bang_add(q_1, q_1, q_2, block_head);
      floatTo((T *)q_1, q_1, block_head);
    }
    // Barrier between pipeline iterations (IO + move + compute queues).
    __sync_io_move_compute();
  }
}
/*
 * Entry kernel: builds the rotary mask, then dispatches each task's tile to
 * crossRotaryEmbedding (interleaved) or foldRotaryEmbedding (half-split).
 *
 * Task layout (Block tasks):
 *   taskIdX : head index when launched one-head-per-task (taskDimX > 1);
 *             otherwise taskDimX == 1 and each task processes all heads.
 *   taskIdY : decoder mode -> tile over batch (seq_once batches per task);
 *             prefill     -> batch index (used for cu_seq_lens / seq_offsets /
 *             dynamic-NTK table selection).
 *   taskIdZ : strides over seq_once-sized segments of the sequence.
 *
 * cu_seq_lens (optional) holds cumulative sequence starts for packed varlen
 * input; dynamic_ntk selects a per-batch sin/cos table.
 * NOTE(review): "Emebdding" is an existing typo, preserved so host call sites
 * keep linking.
 */
template <typename T, bool interleaved>
__mlu_global__ void MluRotaryEmebdding(void *output,
                                       const void *input,
                                       const void *sin_table,
                                       const void *cos_table,
                                       const int *seq_offsets,
                                       const int *cu_seq_lens,
                                       int batch,
                                       int max_seq_len,
                                       int head_num,
                                       int head_size,
                                       int rotary_seq_len,
                                       int rotary_dim,
                                       int seq_once,
                                       int rotary_stride,
                                       size_t input_seq_stride,
                                       size_t input_head_stride,
                                       size_t output_seq_stride,
                                       size_t output_head_stride,
                                       bool discrete,
                                       bool dynamic_ntk,
                                       bool decoder_mode,
                                       bool loop_head,
                                       int once_head_num) {
  initMask<T>(nram_mask, rotary_dim, interleaved);
  int head_begin = taskIdX;
  int head_per_task = taskDimX == 1 ? head_num : 1;
  // Decoder mode (one token per sequence): tile the batch instead of looping
  // over sequence segments.
  if (decoder_mode) {
    int task_begin_seq = taskIdY * seq_once;
    int seq_block = std::min(batch - task_begin_seq, seq_once);
    if (seq_block <= 0 || __is_mpu()) {
      return;
    }
    size_t input_offset = task_begin_seq * input_seq_stride + head_begin * input_head_stride;
    size_t output_offset = task_begin_seq * output_seq_stride + head_begin * output_head_stride;
    T *input_begin = (T *)input + input_offset;
    T *output_begin = (T *)output + output_offset;
    if (interleaved) {
      crossRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table, (T *)cos_table,
                           (int *)seq_offsets, head_per_task, seq_block, head_size, rotary_dim,
                           rotary_stride, input_head_stride, input_seq_stride, output_head_stride,
                           output_seq_stride, task_begin_seq, discrete, decoder_mode);
    } else {
      foldRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table, (T *)cos_table,
                          (int *)seq_offsets, head_per_task, seq_block, head_size, rotary_dim,
                          rotary_stride, input_head_stride, input_seq_stride, output_head_stride,
                          output_seq_stride, task_begin_seq, discrete, decoder_mode, loop_head,
                          once_head_num);
    }
    return;
  }
  // Prefill: resolve this batch's sequence span (packed varlen via cu_seq_lens,
  // or uniform max_seq_len rows otherwise).
  int seq_begin = cu_seq_lens == nullptr ? taskIdY * max_seq_len : cu_seq_lens[taskIdY];
  int seq_len = cu_seq_lens == nullptr ? max_seq_len : cu_seq_lens[taskIdY + 1] - seq_begin;
  for (int i = taskIdZ * seq_once; i < seq_len; i += taskDimZ * seq_once) {
    int seq_block = std::min(seq_once, seq_len - i);
    int global_seq_begin = seq_begin + i;
    int seq_block_begin = i;
    size_t input_offset = global_seq_begin * input_seq_stride + head_begin * input_head_stride;
    size_t output_offset = global_seq_begin * output_seq_stride + head_begin * output_head_stride;
    // dynamic_ntk: each batch has its own rotary_seq_len-row table.
    size_t bs_table_offset = dynamic_ntk ? (size_t)taskIdY * rotary_seq_len * rotary_stride : 0;
    T *input_begin = (T *)input + input_offset;
    T *output_begin = (T *)output + output_offset;
    T *sin_table_begin = (T *)sin_table + bs_table_offset + (size_t)seq_block_begin * rotary_stride;
    T *cos_table_begin = (T *)cos_table + bs_table_offset + (size_t)seq_block_begin * rotary_stride;
    if (seq_offsets != nullptr && !discrete) {
      // Uniform per-batch position offset: shift the table base directly.
      sin_table_begin += seq_offsets[taskIdY] * (size_t)rotary_stride;
      cos_table_begin += seq_offsets[taskIdY] * (size_t)rotary_stride;
    } else if (seq_offsets != nullptr && discrete) {
      // Discrete positions: rows are gathered inside the embedding routine,
      // so only the per-batch table base is applied here.
      sin_table_begin = (T *)sin_table + bs_table_offset;
      cos_table_begin = (T *)cos_table + bs_table_offset;
    }
    if (interleaved) {
      crossRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table_begin,
                           (T *)cos_table_begin, (int *)seq_offsets, head_per_task, seq_block,
                           head_size, rotary_dim, rotary_stride, input_head_stride,
                           input_seq_stride, output_head_stride, output_seq_stride,
                           global_seq_begin, discrete, decoder_mode);
      __sync_io_move_compute();
    } else {
      foldRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table_begin,
                          (T *)cos_table_begin, (int *)seq_offsets, head_per_task, seq_block,
                          head_size, rotary_dim, rotary_stride, input_head_stride, input_seq_stride,
                          output_head_stride, output_seq_stride, global_seq_begin, discrete,
                          decoder_mode, loop_head, once_head_num);
    }
  }
}
#if __BANG_ARCH__ < 592
// Architectures older than 592 lack bfloat16 support for the BANG intrinsics
// used above, so the bfloat16 kernel instantiations are stubbed out with empty
// bodies purely to satisfy the kernel function-pointer table in
// invokeRotaryEmbedding. The host launcher rejects bf16 on such devices via
// isBf16Supported() before any launch, so these stubs are never executed.
template <>
__mlu_global__ void MluRotaryEmebdding<bfloat16_t, true>(void *output,
                                                         const void *input,
                                                         const void *sin_table,
                                                         const void *cos_table,
                                                         const int *seq_offsets,
                                                         const int *cu_seq_lens,
                                                         int batch,
                                                         int max_seq_len,
                                                         int head_num,
                                                         int head_size,
                                                         int rotary_seq_len,
                                                         int rotary_dim,
                                                         int seq_once,
                                                         int rotary_stride,
                                                         size_t input_seq_stride,
                                                         size_t input_head_stride,
                                                         size_t output_seq_stride,
                                                         size_t output_head_stride,
                                                         bool discrete,
                                                         bool dynamic_ntk,
                                                         bool decoder_mode,
                                                         bool loop_head,
                                                         int once_head_num) {}
template <>
__mlu_global__ void MluRotaryEmebdding<bfloat16_t, false>(void *output,
                                                          const void *input,
                                                          const void *sin_table,
                                                          const void *cos_table,
                                                          const int *seq_offsets,
                                                          const int *cu_seq_lens,
                                                          int batch,
                                                          int max_seq_len,
                                                          int head_num,
                                                          int head_size,
                                                          int rotary_seq_len,
                                                          int rotary_dim,
                                                          int seq_once,
                                                          int rotary_stride,
                                                          size_t input_seq_stride,
                                                          size_t input_head_stride,
                                                          size_t output_seq_stride,
                                                          size_t output_head_stride,
                                                          bool discrete,
                                                          bool dynamic_ntk,
                                                          bool decoder_mode,
                                                          bool loop_head,
                                                          int once_head_num) {}
#endif
} // namespace kernels
/*
 * Host-side launcher for the rotary-embedding kernels.
 *
 * Selects the kernel instantiation by (data_type, interleaved), derives the
 * launch configuration (task dims, seq_once tile size, decoder-mode tiling,
 * fold-path head looping) from device core counts and NRAM capacity, then
 * enqueues the kernel on `queue`.
 *
 * Returns KERNEL_STATUS_FAILED for head_size > 256 or bf16 on devices without
 * bf16 support; otherwise KERNEL_STATUS_SUCCESS after enqueuing (the launch
 * itself is asynchronous).
 */
KernelStatus invokeRotaryEmbedding(cnrtQueue_t queue,
                                   void *output,
                                   const void *input,
                                   const void *sin_table,
                                   const void *cos_table,
                                   const int *seq_offsets,
                                   const int *cu_seq_lens,
                                   int batch,
                                   int max_seq_len,
                                   int head_num,
                                   int head_size,
                                   int rotary_seq_len,
                                   int rotary_dim,
                                   int rotary_stride,
                                   size_t input_seq_stride,
                                   size_t input_head_stride,
                                   size_t output_seq_stride,
                                   size_t output_head_stride,
                                   bool interleaved,
                                   bool discrete,
                                   bool dynamic_ntk,
                                   cnnlDataType_t data_type) {
  // Kernel table indexed by kernel_index = f(data_type, interleaved) below.
  void (*rotary_embedding_kernels[])(void *,       /* output */
                                     const void *, /* input */
                                     const void *, /* sin_table */
                                     const void *, /* cos_table */
                                     const int *,  /* seq_offsets */
                                     const int *,  /* cu_seq_lens */
                                     int,          /* batch */
                                     int,          /* max_seq_len */
                                     int,          /* head_num */
                                     int,          /* head_size */
                                     int,          /* rotary_seq_len */
                                     int,          /* rotary_dim */
                                     int,          /* seq_once */
                                     int,          /* rotary_stride */
                                     size_t,       /* input_seq_stride */
                                     size_t,       /* input_head_stride */
                                     size_t,       /* output_seq_stride */
                                     size_t,       /* output_head_stride */
                                     bool,         /* discrete, */
                                     bool,         /* dynamic_ntk */
                                     bool,         /* decoder_mode */
                                     bool,         /* loop_head */
                                     int)          /* once_head_num */
      = {kernels::MluRotaryEmebdding<half, true>,
         kernels::MluRotaryEmebdding<half, false>,
         kernels::MluRotaryEmebdding<bfloat16_t, true>,
         kernels::MluRotaryEmebdding<bfloat16_t, false>,
         kernels::MluRotaryEmebdding<float, true>,
         kernels::MluRotaryEmebdding<float, false>};
  int kernel_index = 0;
  if (data_type == CNNL_DTYPE_HALF) {
    kernel_index = interleaved ? 0 : 1;
  } else if (data_type == CNNL_DTYPE_BFLOAT16) {
    kernel_index = interleaved ? 2 : 3;
  } else if (data_type == CNNL_DTYPE_FLOAT) {
    kernel_index = interleaved ? 4 : 5;
  }
  if (head_size > 256) {
    std::cerr << "[invokeRotaryEmbedding]: only supported head_size <= 256, currently head_size = "
              << head_size << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // Query device parallelism to size the launch grid.
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  int total_core_num = cluster_num * core_num;
  // Sequence positions per kernel tile, sized for NRAM (float needs twice the
  // bytes of half/bf16, large rotary_dim halves the tile).
  uint32_t seq_once = data_type == CNNL_DTYPE_FLOAT ? (rotary_dim > 128 ? 64 : 128)
                                                    : (rotary_dim > 128 ? 128 : 256);
  // Decode mode must fit in NRAM: the fold path caps each core at 64 batches;
  // the cross path requires batch * head_num <= seq_once.
  int batch_per_core = (batch + total_core_num - 1) / total_core_num;
  int batch_per_core_cap = 64;
  bool batch_limit = interleaved ? (batch_per_core * head_num <= seq_once)
                                 : (batch_per_core <= batch_per_core_cap);
  bool decoder_mode = batch_limit && max_seq_len == 1 && dynamic_ntk == false;
  bool do_one_head_per_task = (head_num > 32 && max_seq_len > 2048) || head_num > seq_once;
  seq_once = do_one_head_per_task ? seq_once : seq_once / head_num;
  // The fold rotary path is software-pipelined, so its tiling differs.
  bool loop_head = true;
  int once_head_num = 1;
  if (!interleaved) {
    seq_once = rotary_dim > 128 ? 64 : 128;
    // Short sequences: not enough segments to occupy every core, so shrink
    // seq_once. (The branch condition guarantees total_core_num / batch >= 1.)
    if (batch * (max_seq_len + seq_once - 1) / seq_once < total_core_num) {
      seq_once = std::max(1, max_seq_len / (total_core_num / batch));
    }
    do_one_head_per_task = false;
    // Check whether decode mode can process all heads in a single pass.
    if (decoder_mode) {
      loop_head = false;
      int nram_buffer_size = 480 * 1024;
      int nram_input_size = batch_per_core * head_num * rotary_dim * sizeof(float);
      int nram_q2_size = batch_per_core * head_num * rotary_dim * sizeof(float);
      int nram_table_size = batch_per_core * rotary_dim * sizeof(float) * 2;
      int total_nram_size = nram_input_size + nram_q2_size + nram_table_size;
      loop_head = total_nram_size > nram_buffer_size;
      if (loop_head) {
        // Looping required: recompute how many heads fit per iteration
        // (3 float buffers per head tile: input, q2 and the double buffer).
        once_head_num = (nram_buffer_size - nram_table_size) /
                        (batch_per_core * rotary_dim * sizeof(float) * 3);
      }
      // Rebalance so every loop iteration handles a near-equal head count.
      int loop_num = (head_num + once_head_num - 1) / once_head_num;
      once_head_num = (head_num + loop_num - 1) / loop_num;
    }
  }
  uint32_t seq_segments = ((uint32_t)max_seq_len + seq_once - 1) / seq_once;
  uint32_t task_dimx = do_one_head_per_task ? head_num : 1;
  // NOTE(review): signed/unsigned comparison — total_core_num is promoted to
  // uint32_t here; benign since core counts are positive.
  uint32_t task_dimz = total_core_num > seq_segments ? seq_segments : total_core_num;
  uint32_t task_dimy =
      decoder_mode && !do_one_head_per_task ? (uint32_t)total_core_num : (uint32_t)batch;
  // In decoder mode, seq_once is reused as "batches per task" for the Y tiling.
  seq_once = decoder_mode ? (batch + task_dimy - 1) / task_dimy : seq_once;
  cnrtDim3_t dim = {task_dimx, task_dimy, task_dimz};
  if (data_type == CNNL_DTYPE_BFLOAT16 && !isBf16Supported()) {
    std::cerr << "[invokeRotaryEmbedding]: MLU300 devices do not support bfloat16." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  rotary_embedding_kernels[kernel_index]<<<dim, cnrtFuncTypeBlock, queue>>>(
      output, input, sin_table, cos_table, seq_offsets, cu_seq_lens, batch, max_seq_len, head_num,
      head_size, rotary_seq_len, rotary_dim, seq_once, rotary_stride, input_seq_stride,
      input_head_stride, output_seq_stride, output_head_stride, discrete, dynamic_ntk, decoder_mode,
      loop_head, once_head_num);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
/*
 * GLM-6B rotary embedding: applies rotary embedding independently to the two
 * halves of each head (rotary_dim = head_size / 2 for each pass). The second
 * pass starts head_size / 2 elements into each head and uses the second
 * total_seq_len entries of seq_offsets, always in discrete mode.
 *
 * Fix over the previous version: the statuses of the two underlying
 * invokeRotaryEmbedding launches are now propagated instead of being
 * discarded (previously a failed first launch still returned SUCCESS).
 */
KernelStatus invokeGlm6BRotaryEmbedding(cnrtQueue_t queue,
                                        void *output,
                                        const void *input,
                                        const void *sin_table,
                                        const void *cos_table,
                                        const int *seq_offsets,
                                        const int *cu_seq_lens,
                                        int batch,
                                        int max_seq_len,
                                        int total_seq_len,
                                        int head_num,
                                        int head_size,
                                        int rotary_seq_len,
                                        int rotary_stride,
                                        size_t input_seq_stride,
                                        size_t input_head_stride,
                                        size_t output_seq_stride,
                                        size_t output_head_stride,
                                        bool interleaved,
                                        cnnlDataType_t data_type) {
  size_t type_size = 0;
  cnnlGetSizeOfDataType(data_type, &type_size);
  // First half of each head, positions taken from seq_offsets[0 : total_seq_len].
  KernelStatus status =
      invokeRotaryEmbedding(queue, output, input, sin_table, cos_table, seq_offsets, cu_seq_lens,
                            batch, max_seq_len, head_num, head_size / 2, rotary_seq_len,
                            head_size / 2, rotary_stride, input_seq_stride, input_head_stride,
                            output_seq_stride, output_head_stride, interleaved, true, false,
                            data_type);
  if (status != KernelStatus::KERNEL_STATUS_SUCCESS) {
    return status;
  }
  // Second half of each head, positions taken from
  // seq_offsets[total_seq_len : 2 * total_seq_len].
  return invokeRotaryEmbedding(queue, (int8_t *)output + head_size / 2 * type_size,
                               (int8_t *)input + head_size / 2 * type_size, sin_table, cos_table,
                               seq_offsets + total_seq_len, cu_seq_lens, batch, max_seq_len,
                               head_num, head_size / 2, rotary_seq_len, head_size / 2,
                               rotary_stride, input_seq_stride, input_head_stride,
                               output_seq_stride, output_head_stride, interleaved, true, false,
                               data_type);
}
} // namespace tmo