/************************************************************************* * Copyright (C) [2023-2024] by Cambricon, Inc. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ #include #include #include #include "rotary_embedding.mluh" // clang-format off #include // clang-format on namespace tmo { namespace kernels { #define NRAM_REMAIN_SIZE (32 * 1024) #define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE) __nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE]; __nram__ float nram_meta_mask[32] = {1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f}; __nram__ float nram_mask[1024]; __nram__ int nram_offsets[1024]; __mlu_func__ void loadTableAsync(void *nram_table, void *gdram_table, int *nram_offset, int rotary_dim, int rotary_stride, int seq_block, int seq_begin, int dtype_size, bool discrete, bool decoder_mode) { if (!discrete) { int src_stride = decoder_mode ? 0 : rotary_stride * dtype_size; __memcpy_async(nram_table, gdram_table, rotary_dim * dtype_size, GDRAM2NRAM, rotary_dim * dtype_size, src_stride, seq_block - 1); } else { #if __BANG_ARCH__ >= 592 __gather_async(nram_table, gdram_table, (uint32_t *)nram_offset, rotary_dim * dtype_size, GDRAM2NRAM, rotary_dim * dtype_size, seq_block); #else for (int i = 0; i < seq_block; i++) { __memcpy_async((int8_t *)nram_table + i * rotary_dim * dtype_size, (int8_t *)gdram_table + nram_offset[i], rotary_dim * dtype_size, GDRAM2NRAM); } #endif } } template __mlu_func__ void toFloat(float *dst, T *src, int count) { if (std::is_same::value) { __bang_half2float(dst, (half *)src, count); } else if (std::is_same::value) { __bang_bfloat162float(dst, (bfloat16_t *)src, count); } } template __mlu_func__ void floatTo(T *dst, float *src, int count) { if (std::is_same::value) { __bang_float2half_rn((half *)dst, src, count); } else if (std::is_same::value) { __bang_float2bfloat16_rn((bfloat16_t *)dst, src, count); } } template __mlu_func__ void initMask(float *mask, int rotary_dim, bool interleaved) { if (interleaved) { T *mask0 = (T *)mask; T *mask1 = (T *)(mask + 512); int seg = (rotary_dim + 31) / 32; __memcpy(mask0, nram_meta_mask, 32 * sizeof(float), NRAM2NRAM, 32 * sizeof(float), 0, seg - 1); floatTo((T *)mask0, (float *)mask0, rotary_dim); __bang_add_scalar(mask1, mask0, (T)-1, rotary_dim); } else { __bang_write_value((T *)mask, rotary_dim / 2, (T)-1); __bang_write_value((T *)mask + rotary_dim / 2, rotary_dim / 2, (T)1); } } /* * half: mask, in, sl, sr * float: sl, , sr, , sin, cos */ template __mlu_func__ void crossRotaryEmbedding(T *output, T *input, T *sin_table, T *cos_table, int *seq_offsets, int head_num, int seq_block, int head_size, int rotary_dim, int rotary_stride, size_t input_head_stride, size_t input_seq_stride, size_t output_head_stride, size_t output_seq_stride, int seq_begin, bool discrete, bool decoder_mode = false) { int float_size = sizeof(float); int dtype_size = sizeof(T); int seq_rotary = seq_block * rotary_dim; int block_head = head_num * seq_rotary; float *q_1 = (float *)nram_buffer; float *sincos = q_1 + block_head + 2; float *q_2 = sincos + block_head + 2; T *temp = (T *)q_2 + block_head + 2; if (seq_offsets != nullptr && (discrete || decoder_mode)) { __memcpy(nram_offsets, seq_offsets + seq_begin, seq_block * sizeof(int), GDRAM2NRAM); __bang_mul_scalar(nram_offsets, nram_offsets, rotary_stride * dtype_size, seq_block); } bool gather_table = (seq_offsets != nullptr && decoder_mode) || discrete; T *mask0 = (T *)nram_mask; T *mask1 = (T *)(nram_mask + 512); T *q_1_ = (T *)((int8_t *)q_1 + (float_size - dtype_size) * (block_head + 2)); T *sincos_ = (T *)((int8_t *)sincos + (float_size - dtype_size) * (block_head + 2)); T *q_2_ = (T *)((int8_t *)q_2 + (float_size - dtype_size) * (block_head + 2)); // if dtype is float, temp point to a new buffer, and temp_ is temp; // if dtype is half/bfloat16, temp is q_2_, and temp_ is (T*)q_2; T *temp_ = dtype_size == 4 ? temp : (T *)q_2; // load input __memcpy_async(q_1_, input, rotary_dim * dtype_size, GDRAM2NRAM, rotary_dim * dtype_size, seq_block - 1, seq_rotary * dtype_size, head_num - 1, input_seq_stride * dtype_size, seq_block - 1, input_head_stride * dtype_size, head_num - 1); __bang_write_zero(q_2_ + block_head, 2); __sync(); // copy input __memcpy_async(q_2_, q_1_, block_head * dtype_size, NRAM2NRAM); __bang_cycle_mul(temp_, q_1_, mask0, block_head, rotary_dim); __sync(); // load cos loadTableAsync(sincos_, cos_table, nram_offsets, rotary_dim, rotary_stride, seq_block, seq_begin, dtype_size, gather_table, decoder_mode); __bang_cycle_mul(q_2_, q_2_, mask1, block_head, rotary_dim); // rotary_input __bang_add(q_2_ + 2, temp_, q_2_ + 2, block_head); toFloat(q_1, q_1_, block_head); __sync(); toFloat(sincos, sincos_, block_head); // input * cos __bang_cycle_mul(q_1, q_1, sincos, block_head, seq_rotary); __sync(); toFloat(q_2, q_2_, block_head + 2); // load sin loadTableAsync(sincos_, sin_table, nram_offsets, rotary_dim, rotary_stride, seq_block, seq_begin, dtype_size, gather_table, decoder_mode); __sync(); toFloat(sincos, sincos_, block_head); // rotary_input * sin __bang_cycle_mul(q_2, q_2 + 1, sincos, block_head, seq_rotary); // input_cos + rotary_input_sin __bang_add(q_1, q_1, q_2, block_head); floatTo((T *)q_1, q_1, block_head); if ((head_size - rotary_dim) > 0) { __memcpy_async(output + rotary_dim, input + rotary_dim, (head_size - rotary_dim) * dtype_size, GDRAM2GDRAM, output_seq_stride * dtype_size, seq_block - 1, output_head_stride * dtype_size, head_num - 1, input_seq_stride * dtype_size, seq_block - 1, input_head_stride * dtype_size, head_num - 1); } // copy out __memcpy(output, q_1, rotary_dim * dtype_size, NRAM2GDRAM, output_seq_stride * dtype_size, seq_block - 1, output_head_stride * dtype_size, head_num - 1, rotary_dim * dtype_size, seq_block - 1, seq_rotary * dtype_size, head_num - 1); } template __mlu_func__ void foldRotaryEmbedding(T *output, T *input, T *sin_table, T *cos_table, int *seq_offsets, int head_num, int seq_block, int head_size, int rotary_dim, int rotary_stride, size_t input_head_stride, size_t input_seq_stride, size_t output_head_stride, size_t output_seq_stride, int seq_begin, bool discrete, bool decoder_mode, bool loop_head, int once_head_num) { once_head_num = loop_head ? once_head_num : head_num; int loop_num = (head_num + once_head_num - 1) / once_head_num; // int head_per_loop = loop_head ? 1 : head_num; int seq_rotary = seq_block * rotary_dim; int block_head = once_head_num * seq_rotary; int buffer_blocks = loop_head ? 2 : 1; float *buffer = (float *)nram_buffer; float *q_2 = buffer + block_head * buffer_blocks; float *sin = q_2 + block_head; float *cos = sin + seq_rotary; int float_size = sizeof(float); int dtype_size = sizeof(T); T *sincos_ = (T *)((int8_t *)sin + (float_size - dtype_size) * seq_rotary * 2); T *q_2_ = (T *)((int8_t *)q_2 + (float_size - dtype_size) * block_head); if (seq_offsets != nullptr && (discrete || decoder_mode)) { __memcpy(nram_offsets, seq_offsets + seq_begin, seq_block * sizeof(int), GDRAM2NRAM); __bang_mul_scalar(nram_offsets, nram_offsets, rotary_stride * dtype_size, seq_block); __sync_io_move_compute(); } bool gather_table = (seq_offsets != nullptr && decoder_mode) || discrete; int load_head_num = 0; int calc_head_num = 0; int store_head_num = 0; for (int i = 0; i < loop_num + 2; i++) { // store if (i > 1) { store_head_num = std::min(once_head_num, head_num - (i - 2) * once_head_num); if ((head_size - rotary_dim) > 0) { __memcpy_async(output + (i - 2) * once_head_num * output_head_stride + rotary_dim, input + (i - 2) * once_head_num * input_head_stride + rotary_dim, (head_size - rotary_dim) * dtype_size, GDRAM2GDRAM, output_seq_stride * dtype_size, seq_block - 1, output_head_stride * dtype_size, store_head_num - 1, input_seq_stride * dtype_size, seq_block - 1, input_head_stride * dtype_size, store_head_num - 1); } float *nram_store = buffer + (i % 2) * block_head; __memcpy_async(output + (i - 2) * once_head_num * output_head_stride, nram_store, rotary_dim * dtype_size, NRAM2GDRAM, output_seq_stride * dtype_size, seq_block - 1, output_head_stride * dtype_size, store_head_num - 1, rotary_dim * dtype_size, seq_block - 1, seq_block * rotary_dim * dtype_size, store_head_num - 1); } // load float *temp_load = buffer + (i % 2) * block_head; T *nram_load = (T *)((int8_t *)temp_load + (float_size - dtype_size) * block_head); if (i < loop_num) { load_head_num = std::min(once_head_num, head_num - i * once_head_num); __memcpy_async(nram_load, input + i * once_head_num * input_head_stride, rotary_dim * dtype_size, GDRAM2NRAM, rotary_dim * dtype_size, seq_block - 1, seq_block * rotary_dim * dtype_size, load_head_num - 1, input_seq_stride * dtype_size, seq_block - 1, input_head_stride * dtype_size, load_head_num - 1); } if (i == 1) { loadTableAsync(sincos_, sin_table, nram_offsets, rotary_dim, rotary_stride, seq_block, seq_begin, dtype_size, gather_table, decoder_mode); loadTableAsync(sincos_ + seq_rotary, cos_table, nram_offsets, rotary_dim, rotary_stride, seq_block, seq_begin, dtype_size, gather_table, decoder_mode); } // compute if (i > 0 && i < loop_num + 1) { float *q_1 = buffer + ((i + 1) % 2) * block_head; T *q_1_ = (T *)((int8_t *)q_1 + (float_size - dtype_size) * block_head); calc_head_num = std::min(once_head_num, head_num - (i - 1) * once_head_num); __memcpy_async(q_2_, q_1_ + rotary_dim / 2, rotary_dim / 2 * dtype_size, NRAM2NRAM, rotary_dim * dtype_size, rotary_dim * dtype_size, calc_head_num * seq_block - 1); __memcpy_async(q_2_ + rotary_dim / 2, q_1_, rotary_dim / 2 * dtype_size, NRAM2NRAM, rotary_dim * dtype_size, rotary_dim * dtype_size, calc_head_num * seq_block - 1); __sync_move(); toFloat(q_1, q_1_, block_head); __bang_cycle_mul(q_2_, q_2_, (T *)nram_mask, block_head, rotary_dim); toFloat(q_2, q_2_, block_head); if (i == 1) { __sync_io(); toFloat(sin, sincos_, seq_rotary * 2); } __bang_cycle_mul(q_1, q_1, cos, block_head, seq_rotary); __bang_cycle_mul(q_2, q_2, sin, block_head, seq_rotary); __bang_add(q_1, q_1, q_2, block_head); floatTo((T *)q_1, q_1, block_head); } __sync_io_move_compute(); } } // [bs, seq_block] template __mlu_global__ void MluRotaryEmebdding(void *output, const void *input, const void *sin_table, const void *cos_table, const int *seq_offsets, const int *cu_seq_lens, int batch, int max_seq_len, int head_num, int head_size, int rotary_seq_len, int rotary_dim, int seq_once, int rotary_stride, size_t input_seq_stride, size_t input_head_stride, size_t output_seq_stride, size_t output_head_stride, bool discrete, bool dynamic_ntk, bool decoder_mode, bool loop_head, int once_head_num) { initMask(nram_mask, rotary_dim, interleaved); int head_begin = taskIdX; int head_per_task = taskDimX == 1 ? head_num : 1; // decode mode little diff: no loop if (decoder_mode) { int task_begin_seq = taskIdY * seq_once; int seq_block = std::min(batch - task_begin_seq, seq_once); if (seq_block <= 0 || __is_mpu()) { return; } size_t input_offset = task_begin_seq * input_seq_stride + head_begin * input_head_stride; size_t output_offset = task_begin_seq * output_seq_stride + head_begin * output_head_stride; T *input_begin = (T *)input + input_offset; T *output_begin = (T *)output + output_offset; if (interleaved) { crossRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table, (T *)cos_table, (int *)seq_offsets, head_per_task, seq_block, head_size, rotary_dim, rotary_stride, input_head_stride, input_seq_stride, output_head_stride, output_seq_stride, task_begin_seq, discrete, decoder_mode); } else { foldRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table, (T *)cos_table, (int *)seq_offsets, head_per_task, seq_block, head_size, rotary_dim, rotary_stride, input_head_stride, input_seq_stride, output_head_stride, output_seq_stride, task_begin_seq, discrete, decoder_mode, loop_head, once_head_num); } return; } int seq_begin = cu_seq_lens == nullptr ? taskIdY * max_seq_len : cu_seq_lens[taskIdY]; int seq_len = cu_seq_lens == nullptr ? max_seq_len : cu_seq_lens[taskIdY + 1] - seq_begin; for (int i = taskIdZ * seq_once; i < seq_len; i += taskDimZ * seq_once) { int seq_block = std::min(seq_once, seq_len - i); int global_seq_begin = seq_begin + i; int seq_block_begin = i; size_t input_offset = global_seq_begin * input_seq_stride + head_begin * input_head_stride; size_t output_offset = global_seq_begin * output_seq_stride + head_begin * output_head_stride; size_t bs_table_offset = dynamic_ntk ? (size_t)taskIdY * rotary_seq_len * rotary_stride : 0; T *input_begin = (T *)input + input_offset; T *output_begin = (T *)output + output_offset; T *sin_table_begin = (T *)sin_table + bs_table_offset + (size_t)seq_block_begin * rotary_stride; T *cos_table_begin = (T *)cos_table + bs_table_offset + (size_t)seq_block_begin * rotary_stride; if (seq_offsets != nullptr && !discrete) { sin_table_begin += seq_offsets[taskIdY] * (size_t)rotary_stride; cos_table_begin += seq_offsets[taskIdY] * (size_t)rotary_stride; } else if (seq_offsets != nullptr && discrete) { sin_table_begin = (T *)sin_table + bs_table_offset; cos_table_begin = (T *)cos_table + bs_table_offset; } if (interleaved) { crossRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table_begin, (T *)cos_table_begin, (int *)seq_offsets, head_per_task, seq_block, head_size, rotary_dim, rotary_stride, input_head_stride, input_seq_stride, output_head_stride, output_seq_stride, global_seq_begin, discrete, decoder_mode); __sync_io_move_compute(); } else { foldRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table_begin, (T *)cos_table_begin, (int *)seq_offsets, head_per_task, seq_block, head_size, rotary_dim, rotary_stride, input_head_stride, input_seq_stride, output_head_stride, output_seq_stride, global_seq_begin, discrete, decoder_mode, loop_head, once_head_num); } } } #if __BANG_ARCH__ < 592 template <> __mlu_global__ void MluRotaryEmebdding(void *output, const void *input, const void *sin_table, const void *cos_table, const int *seq_offsets, const int *cu_seq_lens, int batch, int max_seq_len, int head_num, int head_size, int rotary_seq_len, int rotary_dim, int seq_once, int rotary_stride, size_t input_seq_stride, size_t input_head_stride, size_t output_seq_stride, size_t output_head_stride, bool discrete, bool dynamic_ntk, bool decoder_mode, bool loop_head, int once_head_num) {} template <> __mlu_global__ void MluRotaryEmebdding(void *output, const void *input, const void *sin_table, const void *cos_table, const int *seq_offsets, const int *cu_seq_lens, int batch, int max_seq_len, int head_num, int head_size, int rotary_seq_len, int rotary_dim, int seq_once, int rotary_stride, size_t input_seq_stride, size_t input_head_stride, size_t output_seq_stride, size_t output_head_stride, bool discrete, bool dynamic_ntk, bool decoder_mode, bool loop_head, int once_head_num) {} #endif } // namespace kernels KernelStatus invokeRotaryEmbedding(cnrtQueue_t queue, void *output, const void *input, const void *sin_table, const void *cos_table, const int *seq_offsets, const int *cu_seq_lens, int batch, int max_seq_len, int head_num, int head_size, int rotary_seq_len, int rotary_dim, int rotary_stride, size_t input_seq_stride, size_t input_head_stride, size_t output_seq_stride, size_t output_head_stride, bool interleaved, bool discrete, bool dynamic_ntk, cnnlDataType_t data_type) { void (*rotary_embedding_kernels[])(void *, /* output */ const void *, /* input */ const void *, /* sin_table */ const void *, /* cos_table */ const int *, /* seq_offsets */ const int *, /* cu_seq_lens */ int, /* batch */ int, /* max_seq_len */ int, /* head_num */ int, /* head_size */ int, /* rotary_seq_len */ int, /* rotary_dim */ int, /* seq_once */ int, /* rotary_stride */ size_t, /* input_seq_stride */ size_t, /* input_head_stride */ size_t, /* output_seq_stride */ size_t, /* output_head_stride */ bool, /* discrete, */ bool, /* dynamic_ntk */ bool, /* decoder_mode */ bool, /* loop_head */ int) /* once_head_num */ = {kernels::MluRotaryEmebdding, kernels::MluRotaryEmebdding, kernels::MluRotaryEmebdding, kernels::MluRotaryEmebdding, kernels::MluRotaryEmebdding, kernels::MluRotaryEmebdding}; int kernel_index = 0; if (data_type == CNNL_DTYPE_HALF) { kernel_index = interleaved ? 0 : 1; } else if (data_type == CNNL_DTYPE_BFLOAT16) { kernel_index = interleaved ? 2 : 3; } else if (data_type == CNNL_DTYPE_FLOAT) { kernel_index = interleaved ? 4 : 5; } if (head_size > 256) { std::cerr << "[invokeRotaryEmbedding]: only supported head_size <= 256, currently head_size = " << head_size << std::endl; return KernelStatus::KERNEL_STATUS_FAILED; } CNdev dev; cnCtxGetDevice(&dev); int cluster_num; CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); int core_num; CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev)); int total_core_num = cluster_num * core_num; uint32_t seq_once = data_type == CNNL_DTYPE_FLOAT ? (rotary_dim > 128 ? 64 : 128) : (rotary_dim > 128 ? 128 : 256); // decode场景,需要判断空间是否够,fold场景下最大限制为每个ipu处理64,cross限制为batch*head小于等于sq_once int batch_per_core = (batch + total_core_num - 1) / total_core_num; int batch_per_core_cap = 64; bool batch_limit = interleaved ? (batch_per_core * head_num <= seq_once) : (batch_per_core <= batch_per_core_cap); bool decoder_mode = batch_limit && max_seq_len == 1 && dynamic_ntk == false; bool do_one_head_per_task = (head_num > 32 && max_seq_len > 2048) || head_num > seq_once; seq_once = do_one_head_per_task ? seq_once : seq_once / head_num; // fold rotary做了流水,拆分有所不同。 bool loop_head = true; int once_head_num = 1; if (!interleaved) { seq_once = rotary_dim > 128 ? 64 : 128; // 小seq情况下,不够拆,需要减小seq_once if (batch * (max_seq_len + seq_once - 1) / seq_once < total_core_num) { seq_once = std::max(1, max_seq_len / (total_core_num / batch)); } do_one_head_per_task = false; // 判断decode场景能否一次性处理完所有head if (decoder_mode) { loop_head = false; int nram_buffer_size = 480 * 1024; int nram_input_size = batch_per_core * head_num * rotary_dim * sizeof(float); int nram_q2_size = batch_per_core * head_num * rotary_dim * sizeof(float); int nram_table_size = batch_per_core * rotary_dim * sizeof(float) * 2; int total_nram_size = nram_input_size + nram_q2_size + nram_table_size; loop_head = total_nram_size > nram_buffer_size; if (loop_head) { // 如果需要循环,则重新计算每次处理多少头 once_head_num = (nram_buffer_size - nram_table_size) / (batch_per_core * rotary_dim * sizeof(float) * 3); } // rebalance int loop_num = (head_num + once_head_num - 1) / once_head_num; once_head_num = (head_num + loop_num - 1) / loop_num; } } uint32_t seq_segments = ((uint32_t)max_seq_len + seq_once - 1) / seq_once; uint32_t task_dimx = do_one_head_per_task ? head_num : 1; uint32_t task_dimz = total_core_num > seq_segments ? seq_segments : total_core_num; uint32_t task_dimy = decoder_mode && !do_one_head_per_task ? (uint32_t)total_core_num : (uint32_t)batch; seq_once = decoder_mode ? (batch + task_dimy - 1) / task_dimy : seq_once; cnrtDim3_t dim = {task_dimx, task_dimy, task_dimz}; if (data_type == CNNL_DTYPE_BFLOAT16 && !isBf16Supported()) { std::cerr << "[invokeRotaryEmbedding]: MLU300 devices do not support bfloat16." << std::endl; return KernelStatus::KERNEL_STATUS_FAILED; } rotary_embedding_kernels[kernel_index]<<>>( output, input, sin_table, cos_table, seq_offsets, cu_seq_lens, batch, max_seq_len, head_num, head_size, rotary_seq_len, rotary_dim, seq_once, rotary_stride, input_seq_stride, input_head_stride, output_seq_stride, output_head_stride, discrete, dynamic_ntk, decoder_mode, loop_head, once_head_num); return KernelStatus::KERNEL_STATUS_SUCCESS; } KernelStatus invokeGlm6BRotaryEmbedding(cnrtQueue_t queue, void *output, const void *input, const void *sin_table, const void *cos_table, const int *seq_offsets, const int *cu_seq_lens, int batch, int max_seq_len, int total_seq_len, int head_num, int head_size, int rotary_seq_len, int rotary_stride, size_t input_seq_stride, size_t input_head_stride, size_t output_seq_stride, size_t output_head_stride, bool interleaved, cnnlDataType_t data_type) { size_t type_size = 0; cnnlGetSizeOfDataType(data_type, &type_size); invokeRotaryEmbedding(queue, output, input, sin_table, cos_table, seq_offsets, cu_seq_lens, batch, max_seq_len, head_num, head_size / 2, rotary_seq_len, head_size / 2, rotary_stride, input_seq_stride, input_head_stride, output_seq_stride, output_head_stride, interleaved, true, false, data_type); invokeRotaryEmbedding(queue, (int8_t *)output + head_size / 2 * type_size, (int8_t *)input + head_size / 2 * type_size, sin_table, cos_table, seq_offsets + total_seq_len, cu_seq_lens, batch, max_seq_len, head_num, head_size / 2, rotary_seq_len, head_size / 2, rotary_stride, input_seq_stride, input_head_stride, output_seq_stride, output_head_stride, interleaved, true, false, data_type); return KernelStatus::KERNEL_STATUS_SUCCESS; } } // namespace tmo