/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <type_traits>

#include "rotary_embedding.mluh"
// clang-format off
#include <mlu.h>
// clang-format on

namespace tmo {
|
||
namespace kernels {
|
||
// Bytes of NRAM kept free for the runtime/stack; kernels only use what is left.
#define NRAM_REMAIN_SIZE (32 * 1024)
// Usable NRAM scratch: __MLU_NRAM_SIZE__ is in KiB, minus the reserved region.
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)

// Main on-chip scratch buffer shared by the rotary-embedding device functions.
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
// 32-element 1,0,1,0,... template; initMask() replicates it to build the
// even-lane mask for the interleaved (cross) rotary layout.
__nram__ float nram_meta_mask[32] = {1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f,
                                     0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f,
                                     1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f, 1.f, 0.f};
// Rotation sign/select mask, filled by initMask() before any compute.
__nram__ float nram_mask[1024];
// Per-sequence byte offsets into the sin/cos tables, used for gather loads.
__nram__ int nram_offsets[1024];
|
||
// Asynchronously load `seq_block` rows (each `rotary_dim * dtype_size` bytes)
// of a sin/cos table from GDRAM into NRAM.
//
// - !discrete: one strided DMA. In decoder mode the source stride is 0, i.e.
//   the same table row is broadcast to every sequence position.
// - discrete: rows are gathered at the byte offsets precomputed in
//   `nram_offset`; on >= MLU590 this uses the hardware __gather_async,
//   otherwise it falls back to one __memcpy_async per row.
//
// NOTE(review): `seq_begin` is unused here — offsets are already rebased by
// the callers before this is invoked.
__mlu_func__ void loadTableAsync(void *nram_table,
                                 void *gdram_table,
                                 int *nram_offset,
                                 int rotary_dim,
                                 int rotary_stride,
                                 int seq_block,
                                 int seq_begin,
                                 int dtype_size,
                                 bool discrete,
                                 bool decoder_mode) {
  if (!discrete) {
    // src_stride == 0 re-reads the same row (decoder: all positions share it).
    int src_stride = decoder_mode ? 0 : rotary_stride * dtype_size;
    __memcpy_async(nram_table, gdram_table, rotary_dim * dtype_size, GDRAM2NRAM,
                   rotary_dim * dtype_size, src_stride, seq_block - 1);
  } else {
#if __BANG_ARCH__ >= 592
    __gather_async(nram_table, gdram_table, (uint32_t *)nram_offset, rotary_dim * dtype_size,
                   GDRAM2NRAM, rotary_dim * dtype_size, seq_block);
#else
    // Older architectures: emulate the gather with per-row async copies.
    for (int i = 0; i < seq_block; i++) {
      __memcpy_async((int8_t *)nram_table + i * rotary_dim * dtype_size,
                     (int8_t *)gdram_table + nram_offset[i], rotary_dim * dtype_size, GDRAM2NRAM);
    }
#endif
  }
}
// Widen `count` elements of T (half or bfloat16) from `src` into float `dst`.
// For T == float nothing is done: the callers lay the buffers out so that the
// float and T views alias in that case (see the q_1 / q_1_ pointer setup).
template <typename T>
__mlu_func__ void toFloat(float *dst, T *src, int count) {
  constexpr bool from_half = std::is_same<T, half>::value;
  constexpr bool from_bf16 = std::is_same<T, bfloat16_t>::value;
  if (from_half) {
    __bang_half2float(dst, (half *)src, count);
  }
  if (from_bf16) {
    __bang_bfloat162float(dst, (bfloat16_t *)src, count);
  }
}
// Narrow `count` floats from `src` into T at `dst`, rounding to nearest even.
// For T == float nothing is done (callers alias the two views in that case).
template <typename T>
__mlu_func__ void floatTo(T *dst, float *src, int count) {
  constexpr bool to_bf16 = std::is_same<T, bfloat16_t>::value;
  if (to_bf16) {
    __bang_float2bfloat16_rn((bfloat16_t *)dst, src, count);
  } else if (std::is_same<T, half>::value) {
    __bang_float2half_rn((half *)dst, src, count);
  }
}
// Build the rotation mask(s) in `mask` (the global nram_mask) for one head of
// `rotary_dim` elements.
//
// interleaved (cross layout):
//   mask0 = mask           : 1,0,1,0,...   keeps even lanes
//   mask1 = mask + 512 flts: 0,-1,0,-1,... keeps odd lanes, sign-flipped
//   (mask1 is derived as mask0 - 1 elementwise: 1 -> 0, 0 -> -1)
// !interleaved (fold layout):
//   first rotary_dim/2 elements = -1, second half = +1 — the sign pattern
//   applied to the half-swapped copy of the head.
template <typename T>
__mlu_func__ void initMask(float *mask, int rotary_dim, bool interleaved) {
  if (interleaved) {
    T *mask0 = (T *)mask;
    T *mask1 = (T *)(mask + 512);
    // Number of 32-element tiles needed to cover rotary_dim (round up).
    int seg = (rotary_dim + 31) / 32;
    // Replicate the 32-float 1/0 template across rotary_dim, then narrow to T
    // in place (no-op for T == float).
    __memcpy(mask0, nram_meta_mask, 32 * sizeof(float), NRAM2NRAM, 32 * sizeof(float), 0, seg - 1);
    floatTo((T *)mask0, (float *)mask0, rotary_dim);
    // mask1[i] = mask0[i] - 1  ==>  alternating 0,-1,0,-1,...
    __bang_add_scalar(mask1, mask0, (T)-1, rotary_dim);
  } else {
    __bang_write_value((T *)mask, rotary_dim / 2, (T)-1);
    __bang_write_value((T *)mask + rotary_dim / 2, rotary_dim / 2, (T)1);
  }
}
/*
 * Interleaved ("cross", GPT-J style) rotary embedding for one tile of
 * [head_num, seq_block, rotary_dim] elements.
 *
 * NRAM layout (all regions block_head + 2 elements, the 2-element slack
 * absorbs the shifted add used to build the rotated vector):
 *   half/bf16: mask, in, sl, sr        (narrow views packed at region ends)
 *   float:     sl, _, sr, _, sin, cos
 *
 * Computes out = in * cos + rotate(in) * sin, where for adjacent pairs
 * (x0, x1) the rotated vector is (-x1, x0, -x3, x2, ...), built from the
 * even/odd lane masks (mask0/mask1) plus a 2-element shifted add and a
 * 1-float shifted cycle_mul below.
 */
template <typename T>
__mlu_func__ void crossRotaryEmbedding(T *output,
                                       T *input,
                                       T *sin_table,
                                       T *cos_table,
                                       int *seq_offsets,
                                       int head_num,
                                       int seq_block,
                                       int head_size,
                                       int rotary_dim,
                                       int rotary_stride,
                                       size_t input_head_stride,
                                       size_t input_seq_stride,
                                       size_t output_head_stride,
                                       size_t output_seq_stride,
                                       int seq_begin,
                                       bool discrete,
                                       bool decoder_mode = false) {
  int float_size = sizeof(float);
  int dtype_size = sizeof(T);
  int seq_rotary = seq_block * rotary_dim;
  int block_head = head_num * seq_rotary;
  // Carve the shared NRAM scratch into float working regions (+2 slack each).
  float *q_1 = (float *)nram_buffer;
  float *sincos = q_1 + block_head + 2;
  float *q_2 = sincos + block_head + 2;
  T *temp = (T *)q_2 + block_head + 2;

  // Gather paths need per-position byte offsets into the sin/cos tables.
  if (seq_offsets != nullptr && (discrete || decoder_mode)) {
    __memcpy(nram_offsets, seq_offsets + seq_begin, seq_block * sizeof(int), GDRAM2NRAM);
    __bang_mul_scalar(nram_offsets, nram_offsets, rotary_stride * dtype_size, seq_block);
  }
  bool gather_table = (seq_offsets != nullptr && decoder_mode) || discrete;

  // mask0: even-lane select; mask1: odd-lane select with sign flip (initMask).
  T *mask0 = (T *)nram_mask;
  T *mask1 = (T *)(nram_mask + 512);

  // Narrow (T) views are packed at the *end* of each float region so that an
  // in-place widen (toFloat) never overwrites unread source data. For
  // T == float the offset is 0 and the views alias exactly.
  T *q_1_ = (T *)((int8_t *)q_1 + (float_size - dtype_size) * (block_head + 2));
  T *sincos_ = (T *)((int8_t *)sincos + (float_size - dtype_size) * (block_head + 2));
  T *q_2_ = (T *)((int8_t *)q_2 + (float_size - dtype_size) * (block_head + 2));
  // if dtype is float, temp points to a separate buffer and temp_ is temp;
  // if dtype is half/bfloat16, temp_ reuses the front of the q_2 region.
  T *temp_ = dtype_size == 4 ? temp : (T *)q_2;

  // Load the input tile: rows of rotary_dim, strided over seq then head.
  __memcpy_async(q_1_, input, rotary_dim * dtype_size, GDRAM2NRAM, rotary_dim * dtype_size,
                 seq_block - 1, seq_rotary * dtype_size, head_num - 1,
                 input_seq_stride * dtype_size, seq_block - 1, input_head_stride * dtype_size,
                 head_num - 1);
  // Zero the 2-element slack that the shifted add below writes into.
  __bang_write_zero(q_2_ + block_head, 2);
  __sync();

  // q_2_ = copy of input; temp_ = input with odd lanes zeroed (mask0).
  __memcpy_async(q_2_, q_1_, block_head * dtype_size, NRAM2NRAM);
  __bang_cycle_mul(temp_, q_1_, mask0, block_head, rotary_dim);
  __sync();

  // Start fetching cos while computing the rotated vector.
  loadTableAsync(sincos_, cos_table, nram_offsets, rotary_dim, rotary_stride, seq_block, seq_begin,
                 dtype_size, gather_table, decoder_mode);

  // q_2_ = input with even lanes zeroed and odd lanes negated (mask1).
  __bang_cycle_mul(q_2_, q_2_, mask1, block_head, rotary_dim);

  // Shifted add: q_2_[i+2] += temp_[i]  ==>  q_2_ = (0, -x1, x0, -x3, x2, ...)
  __bang_add(q_2_ + 2, temp_, q_2_ + 2, block_head);

  toFloat(q_1, q_1_, block_head);
  __sync();

  toFloat(sincos, sincos_, block_head);

  // q_1 = input * cos
  __bang_cycle_mul(q_1, q_1, sincos, block_head, seq_rotary);
  __sync();

  toFloat(q_2, q_2_, block_head + 2);

  // Fetch sin, reusing the sincos region.
  loadTableAsync(sincos_, sin_table, nram_offsets, rotary_dim, rotary_stride, seq_block, seq_begin,
                 dtype_size, gather_table, decoder_mode);
  __sync();

  toFloat(sincos, sincos_, block_head);

  // q_2[i] = q_2[i + 1] * sin[i]: the +1 shift completes the pair rotation,
  // yielding rotate(input) * sin.
  __bang_cycle_mul(q_2, q_2 + 1, sincos, block_head, seq_rotary);

  // out = input * cos + rotate(input) * sin
  __bang_add(q_1, q_1, q_2, block_head);

  floatTo((T *)q_1, q_1, block_head);

  // Elements past rotary_dim are not rotated: pass them through directly.
  if ((head_size - rotary_dim) > 0) {
    __memcpy_async(output + rotary_dim, input + rotary_dim, (head_size - rotary_dim) * dtype_size,
                   GDRAM2GDRAM, output_seq_stride * dtype_size, seq_block - 1,
                   output_head_stride * dtype_size, head_num - 1, input_seq_stride * dtype_size,
                   seq_block - 1, input_head_stride * dtype_size, head_num - 1);
  }

  // Write the rotated portion back out (synchronous: result must be complete).
  __memcpy(output, q_1, rotary_dim * dtype_size, NRAM2GDRAM, output_seq_stride * dtype_size,
           seq_block - 1, output_head_stride * dtype_size, head_num - 1, rotary_dim * dtype_size,
           seq_block - 1, seq_rotary * dtype_size, head_num - 1);
}
// Non-interleaved ("fold", GPT-NeoX style) rotary embedding.
//
// The head is split into two halves; rotate(x) = (-x[d/2:], x[:d/2]) is built
// by swapping the halves and multiplying by nram_mask (-1 first half, +1
// second half), then out = x * cos + rotate(x) * sin.
//
// When loop_head is true the heads are processed in chunks of once_head_num
// through a 3-stage software pipeline over a double buffer:
//   iteration i:  store chunk i-2  |  load chunk i  |  compute chunk i-1
// with the sin/cos tables loaded once at i == 1. When loop_head is false all
// heads fit in one chunk and the pipeline degenerates to load/compute/store.
template <typename T>
__mlu_func__ void foldRotaryEmbedding(T *output,
                                      T *input,
                                      T *sin_table,
                                      T *cos_table,
                                      int *seq_offsets,
                                      int head_num,
                                      int seq_block,
                                      int head_size,
                                      int rotary_dim,
                                      int rotary_stride,
                                      size_t input_head_stride,
                                      size_t input_seq_stride,
                                      size_t output_head_stride,
                                      size_t output_seq_stride,
                                      int seq_begin,
                                      bool discrete,
                                      bool decoder_mode,
                                      bool loop_head,
                                      int once_head_num) {
  once_head_num = loop_head ? once_head_num : head_num;
  int loop_num = (head_num + once_head_num - 1) / once_head_num;
  // int head_per_loop = loop_head ? 1 : head_num;
  int seq_rotary = seq_block * rotary_dim;
  int block_head = once_head_num * seq_rotary;
  // Double-buffer the I/O region only when pipelining over head chunks.
  int buffer_blocks = loop_head ? 2 : 1;

  float *buffer = (float *)nram_buffer;
  float *q_2 = buffer + block_head * buffer_blocks;
  float *sin = q_2 + block_head;
  float *cos = sin + seq_rotary;

  int float_size = sizeof(float);
  int dtype_size = sizeof(T);
  // Narrow (T) views packed at region ends so in-place widening is safe;
  // for T == float they alias the float views exactly.
  T *sincos_ = (T *)((int8_t *)sin + (float_size - dtype_size) * seq_rotary * 2);
  T *q_2_ = (T *)((int8_t *)q_2 + (float_size - dtype_size) * block_head);
  if (seq_offsets != nullptr && (discrete || decoder_mode)) {
    // Convert per-position table indices to byte offsets for gather loads.
    __memcpy(nram_offsets, seq_offsets + seq_begin, seq_block * sizeof(int), GDRAM2NRAM);
    __bang_mul_scalar(nram_offsets, nram_offsets, rotary_stride * dtype_size, seq_block);
    __sync_io_move_compute();
  }
  bool gather_table = (seq_offsets != nullptr && decoder_mode) || discrete;

  int load_head_num = 0;
  int calc_head_num = 0;
  int store_head_num = 0;
  // loop_num + 2 iterations drain the 3-stage pipeline.
  for (int i = 0; i < loop_num + 2; i++) {
    // Stage 3: store chunk i-2 (computed last iteration).
    if (i > 1) {
      store_head_num = std::min(once_head_num, head_num - (i - 2) * once_head_num);
      // Pass the non-rotated tail of each head straight through GDRAM->GDRAM.
      if ((head_size - rotary_dim) > 0) {
        __memcpy_async(output + (i - 2) * once_head_num * output_head_stride + rotary_dim,
                       input + (i - 2) * once_head_num * input_head_stride + rotary_dim,
                       (head_size - rotary_dim) * dtype_size, GDRAM2GDRAM,
                       output_seq_stride * dtype_size, seq_block - 1,
                       output_head_stride * dtype_size, store_head_num - 1,
                       input_seq_stride * dtype_size, seq_block - 1, input_head_stride * dtype_size,
                       store_head_num - 1);
      }
      float *nram_store = buffer + (i % 2) * block_head;
      __memcpy_async(output + (i - 2) * once_head_num * output_head_stride, nram_store,
                     rotary_dim * dtype_size, NRAM2GDRAM, output_seq_stride * dtype_size,
                     seq_block - 1, output_head_stride * dtype_size, store_head_num - 1,
                     rotary_dim * dtype_size, seq_block - 1, seq_block * rotary_dim * dtype_size,
                     store_head_num - 1);
    }
    // Stage 1: load chunk i into the other half of the double buffer.
    float *temp_load = buffer + (i % 2) * block_head;
    T *nram_load = (T *)((int8_t *)temp_load + (float_size - dtype_size) * block_head);
    if (i < loop_num) {
      load_head_num = std::min(once_head_num, head_num - i * once_head_num);
      __memcpy_async(nram_load, input + i * once_head_num * input_head_stride,
                     rotary_dim * dtype_size, GDRAM2NRAM, rotary_dim * dtype_size, seq_block - 1,
                     seq_block * rotary_dim * dtype_size, load_head_num - 1,
                     input_seq_stride * dtype_size, seq_block - 1, input_head_stride * dtype_size,
                     load_head_num - 1);
    }
    // Tables are loaded once, overlapping the first compute iteration.
    if (i == 1) {
      loadTableAsync(sincos_, sin_table, nram_offsets, rotary_dim, rotary_stride, seq_block,
                     seq_begin, dtype_size, gather_table, decoder_mode);
      loadTableAsync(sincos_ + seq_rotary, cos_table, nram_offsets, rotary_dim, rotary_stride,
                     seq_block, seq_begin, dtype_size, gather_table, decoder_mode);
    }
    // Stage 2: compute chunk i-1 (loaded last iteration).
    if (i > 0 && i < loop_num + 1) {
      float *q_1 = buffer + ((i + 1) % 2) * block_head;
      T *q_1_ = (T *)((int8_t *)q_1 + (float_size - dtype_size) * block_head);
      calc_head_num = std::min(once_head_num, head_num - (i - 1) * once_head_num);
      // Swap the two halves of every rotary_dim row into q_2_.
      __memcpy_async(q_2_, q_1_ + rotary_dim / 2, rotary_dim / 2 * dtype_size, NRAM2NRAM,
                     rotary_dim * dtype_size, rotary_dim * dtype_size,
                     calc_head_num * seq_block - 1);
      __memcpy_async(q_2_ + rotary_dim / 2, q_1_, rotary_dim / 2 * dtype_size, NRAM2NRAM,
                     rotary_dim * dtype_size, rotary_dim * dtype_size,
                     calc_head_num * seq_block - 1);
      __sync_move();

      toFloat(q_1, q_1_, block_head);
      // Apply the -1/+1 fold mask: q_2_ = rotate(input).
      __bang_cycle_mul(q_2_, q_2_, (T *)nram_mask, block_head, rotary_dim);
      toFloat(q_2, q_2_, block_head);
      if (i == 1) {
        // Wait for the table DMA issued above, then widen sin and cos together
        // (they are contiguous: sin immediately followed by cos).
        __sync_io();
        toFloat(sin, sincos_, seq_rotary * 2);
      }
      // out = input * cos + rotate(input) * sin
      __bang_cycle_mul(q_1, q_1, cos, block_head, seq_rotary);
      __bang_cycle_mul(q_2, q_2, sin, block_head, seq_rotary);
      __bang_add(q_1, q_1, q_2, block_head);
      floatTo((T *)q_1, q_1, block_head);
    }
    // Barrier between pipeline iterations: all DMA and compute must finish
    // before buffers are reused.
    __sync_io_move_compute();
  }
}
// Rotary-embedding entry kernel. Task grid: X = head split (1 task = all heads
// unless taskDimX > 1), Y = batch (or core index in decoder mode),
// Z = sequence chunks of seq_once. [bs, seq_block]
// NOTE: the kernel name misspells "Embedding"; kept as-is since host code
// references it by this name.
template <typename T, bool interleaved>
__mlu_global__ void MluRotaryEmebdding(void *output,
                                       const void *input,
                                       const void *sin_table,
                                       const void *cos_table,
                                       const int *seq_offsets,
                                       const int *cu_seq_lens,
                                       int batch,
                                       int max_seq_len,
                                       int head_num,
                                       int head_size,
                                       int rotary_seq_len,
                                       int rotary_dim,
                                       int seq_once,
                                       int rotary_stride,
                                       size_t input_seq_stride,
                                       size_t input_head_stride,
                                       size_t output_seq_stride,
                                       size_t output_head_stride,
                                       bool discrete,
                                       bool dynamic_ntk,
                                       bool decoder_mode,
                                       bool loop_head,
                                       int once_head_num) {
  // Build the cross/fold rotation mask once per task.
  initMask<T>(nram_mask, rotary_dim, interleaved);

  int head_begin = taskIdX;
  int head_per_task = taskDimX == 1 ? head_num : 1;
  // decode mode little diff: no loop — each task handles one batch chunk of
  // seq_once "positions" (here: batch entries, each with seq_len == 1).
  if (decoder_mode) {
    int task_begin_seq = taskIdY * seq_once;
    int seq_block = std::min(batch - task_begin_seq, seq_once);
    if (seq_block <= 0 || __is_mpu()) {
      return;
    }
    size_t input_offset = task_begin_seq * input_seq_stride + head_begin * input_head_stride;
    size_t output_offset = task_begin_seq * output_seq_stride + head_begin * output_head_stride;
    T *input_begin = (T *)input + input_offset;
    T *output_begin = (T *)output + output_offset;
    if (interleaved) {
      crossRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table, (T *)cos_table,
                           (int *)seq_offsets, head_per_task, seq_block, head_size, rotary_dim,
                           rotary_stride, input_head_stride, input_seq_stride, output_head_stride,
                           output_seq_stride, task_begin_seq, discrete, decoder_mode);
    } else {
      foldRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table, (T *)cos_table,
                          (int *)seq_offsets, head_per_task, seq_block, head_size, rotary_dim,
                          rotary_stride, input_head_stride, input_seq_stride, output_head_stride,
                          output_seq_stride, task_begin_seq, discrete, decoder_mode, loop_head,
                          once_head_num);
    }
    return;
  }
  // Prefill path: per-batch sequence range comes either from a uniform
  // max_seq_len or from the cumulative lengths array (packed/varlen input).
  int seq_begin = cu_seq_lens == nullptr ? taskIdY * max_seq_len : cu_seq_lens[taskIdY];
  int seq_len = cu_seq_lens == nullptr ? max_seq_len : cu_seq_lens[taskIdY + 1] - seq_begin;

  // Sequence chunks are distributed round-robin over the Z task dimension.
  for (int i = taskIdZ * seq_once; i < seq_len; i += taskDimZ * seq_once) {
    int seq_block = std::min(seq_once, seq_len - i);
    int global_seq_begin = seq_begin + i;
    int seq_block_begin = i;
    size_t input_offset = global_seq_begin * input_seq_stride + head_begin * input_head_stride;
    size_t output_offset = global_seq_begin * output_seq_stride + head_begin * output_head_stride;
    // dynamic NTK keeps a separate sin/cos table per batch entry.
    size_t bs_table_offset = dynamic_ntk ? (size_t)taskIdY * rotary_seq_len * rotary_stride : 0;
    T *input_begin = (T *)input + input_offset;
    T *output_begin = (T *)output + output_offset;
    T *sin_table_begin = (T *)sin_table + bs_table_offset + (size_t)seq_block_begin * rotary_stride;
    T *cos_table_begin = (T *)cos_table + bs_table_offset + (size_t)seq_block_begin * rotary_stride;
    if (seq_offsets != nullptr && !discrete) {
      // Contiguous positions with a per-batch starting offset into the table.
      sin_table_begin += seq_offsets[taskIdY] * (size_t)rotary_stride;
      cos_table_begin += seq_offsets[taskIdY] * (size_t)rotary_stride;
    } else if (seq_offsets != nullptr && discrete) {
      // Discrete positions: rows are gathered per position, so start at the
      // table base (offsets are applied inside the embedding routines).
      sin_table_begin = (T *)sin_table + bs_table_offset;
      cos_table_begin = (T *)cos_table + bs_table_offset;
    }
    if (interleaved) {
      crossRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table_begin,
                           (T *)cos_table_begin, (int *)seq_offsets, head_per_task, seq_block,
                           head_size, rotary_dim, rotary_stride, input_head_stride,
                           input_seq_stride, output_head_stride, output_seq_stride,
                           global_seq_begin, discrete, decoder_mode);
      // crossRotaryEmbedding leaves async copies in flight; drain them before
      // the next chunk reuses the NRAM scratch.
      __sync_io_move_compute();
    } else {
      foldRotaryEmbedding((T *)output_begin, (T *)input_begin, (T *)sin_table_begin,
                          (T *)cos_table_begin, (int *)seq_offsets, head_per_task, seq_block,
                          head_size, rotary_dim, rotary_stride, input_head_stride, input_seq_stride,
                          output_head_stride, output_seq_stride, global_seq_begin, discrete,
                          decoder_mode, loop_head, once_head_num);
    }
  }
}
#if __BANG_ARCH__ < 592
// On architectures older than MLU590 the bfloat16 instantiations are compiled
// out as empty stubs — the host side rejects bf16 on unsupported devices (see
// the isBf16Supported() check in invokeRotaryEmbedding), so these are never
// expected to do real work; they only exist to satisfy the kernel table.
template <>
__mlu_global__ void MluRotaryEmebdding<bfloat16_t, true>(void *output,
                                                         const void *input,
                                                         const void *sin_table,
                                                         const void *cos_table,
                                                         const int *seq_offsets,
                                                         const int *cu_seq_lens,
                                                         int batch,
                                                         int max_seq_len,
                                                         int head_num,
                                                         int head_size,
                                                         int rotary_seq_len,
                                                         int rotary_dim,
                                                         int seq_once,
                                                         int rotary_stride,
                                                         size_t input_seq_stride,
                                                         size_t input_head_stride,
                                                         size_t output_seq_stride,
                                                         size_t output_head_stride,
                                                         bool discrete,
                                                         bool dynamic_ntk,
                                                         bool decoder_mode,
                                                         bool loop_head,
                                                         int once_head_num) {}

// Empty stub: see the note on the interleaved specialization above.
template <>
__mlu_global__ void MluRotaryEmebdding<bfloat16_t, false>(void *output,
                                                          const void *input,
                                                          const void *sin_table,
                                                          const void *cos_table,
                                                          const int *seq_offsets,
                                                          const int *cu_seq_lens,
                                                          int batch,
                                                          int max_seq_len,
                                                          int head_num,
                                                          int head_size,
                                                          int rotary_seq_len,
                                                          int rotary_dim,
                                                          int seq_once,
                                                          int rotary_stride,
                                                          size_t input_seq_stride,
                                                          size_t input_head_stride,
                                                          size_t output_seq_stride,
                                                          size_t output_head_stride,
                                                          bool discrete,
                                                          bool dynamic_ntk,
                                                          bool decoder_mode,
                                                          bool loop_head,
                                                          int once_head_num) {}
#endif
} // namespace kernels
|
||
|
||
// Host-side launcher for the rotary-embedding kernels.
//
// Picks the kernel instantiation from (data_type, interleaved), derives the
// launch configuration (task grid, seq_once chunking, decoder-mode and
// head-looping heuristics), and enqueues the kernel on `queue`.
//
// Fix: input validation (head_size cap, bfloat16 support) now happens up
// front, before querying device attributes and computing the launch
// configuration — previously the bf16 check ran last, doing all the tuning
// work before failing.
//
// Returns KERNEL_STATUS_FAILED for unsupported head_size (> 256) or bf16 on
// devices without bf16 support; KERNEL_STATUS_SUCCESS after enqueueing.
KernelStatus invokeRotaryEmbedding(cnrtQueue_t queue,
                                   void *output,
                                   const void *input,
                                   const void *sin_table,
                                   const void *cos_table,
                                   const int *seq_offsets,
                                   const int *cu_seq_lens,
                                   int batch,
                                   int max_seq_len,
                                   int head_num,
                                   int head_size,
                                   int rotary_seq_len,
                                   int rotary_dim,
                                   int rotary_stride,
                                   size_t input_seq_stride,
                                   size_t input_head_stride,
                                   size_t output_seq_stride,
                                   size_t output_head_stride,
                                   bool interleaved,
                                   bool discrete,
                                   bool dynamic_ntk,
                                   cnnlDataType_t data_type) {
  // Kernel table indexed as (dtype block) * 2 + (interleaved ? 0 : 1).
  void (*rotary_embedding_kernels[])(void *, /* output */
                                     const void *, /* input */
                                     const void *, /* sin_table */
                                     const void *, /* cos_table */
                                     const int *, /* seq_offsets */
                                     const int *, /* cu_seq_lens */
                                     int, /* batch */
                                     int, /* max_seq_len */
                                     int, /* head_num */
                                     int, /* head_size */
                                     int, /* rotary_seq_len */
                                     int, /* rotary_dim */
                                     int, /* seq_once */
                                     int, /* rotary_stride */
                                     size_t, /* input_seq_stride */
                                     size_t, /* input_head_stride */
                                     size_t, /* output_seq_stride */
                                     size_t, /* output_head_stride */
                                     bool, /* discrete, */
                                     bool, /* dynamic_ntk */
                                     bool, /* decoder_mode */
                                     bool, /* loop_head */
                                     int) /* once_head_num */
      = {kernels::MluRotaryEmebdding<half, true>,
         kernels::MluRotaryEmebdding<half, false>,
         kernels::MluRotaryEmebdding<bfloat16_t, true>,
         kernels::MluRotaryEmebdding<bfloat16_t, false>,
         kernels::MluRotaryEmebdding<float, true>,
         kernels::MluRotaryEmebdding<float, false>};

  int kernel_index = 0;
  if (data_type == CNNL_DTYPE_HALF) {
    kernel_index = interleaved ? 0 : 1;
  } else if (data_type == CNNL_DTYPE_BFLOAT16) {
    kernel_index = interleaved ? 2 : 3;
  } else if (data_type == CNNL_DTYPE_FLOAT) {
    kernel_index = interleaved ? 4 : 5;
  }
  // Validate before touching the device: fail fast on unsupported shapes.
  if (head_size > 256) {
    std::cerr << "[invokeRotaryEmbedding]: only supported head_size <= 256, currently head_size = "
              << head_size << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // Moved up from just before the launch: no point computing the whole launch
  // configuration when the dtype is unsupported on this device.
  if (data_type == CNNL_DTYPE_BFLOAT16 && !isBf16Supported()) {
    std::cerr << "[invokeRotaryEmbedding]: MLU300 devices do not support bfloat16." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  int total_core_num = cluster_num * core_num;

  // Sequence positions processed per NRAM tile; float needs twice the space.
  uint32_t seq_once = data_type == CNNL_DTYPE_FLOAT ? (rotary_dim > 128 ? 64 : 128)
                                                    : (rotary_dim > 128 ? 128 : 256);
  // Decoder case: check that NRAM capacity suffices. The fold layout caps each
  // core at 64 batch entries; the cross layout needs batch * head_num to fit
  // within seq_once.
  int batch_per_core = (batch + total_core_num - 1) / total_core_num;
  int batch_per_core_cap = 64;
  bool batch_limit = interleaved ? (batch_per_core * head_num <= seq_once)
                                 : (batch_per_core <= batch_per_core_cap);
  bool decoder_mode = batch_limit && max_seq_len == 1 && dynamic_ntk == false;

  bool do_one_head_per_task = (head_num > 32 && max_seq_len > 2048) || head_num > seq_once;
  seq_once = do_one_head_per_task ? seq_once : seq_once / head_num;

  // The fold path is software-pipelined, so its work split differs.
  bool loop_head = true;
  int once_head_num = 1;
  if (!interleaved) {
    seq_once = rotary_dim > 128 ? 64 : 128;
    // For short sequences there is not enough work to split; shrink seq_once.
    if (batch * (max_seq_len + seq_once - 1) / seq_once < total_core_num) {
      seq_once = std::max(1, max_seq_len / (total_core_num / batch));
    }
    do_one_head_per_task = false;
    // Check whether the decoder case can process all heads in one pass.
    if (decoder_mode) {
      loop_head = false;
      int nram_buffer_size = 480 * 1024;
      int nram_input_size = batch_per_core * head_num * rotary_dim * sizeof(float);
      int nram_q2_size = batch_per_core * head_num * rotary_dim * sizeof(float);
      int nram_table_size = batch_per_core * rotary_dim * sizeof(float) * 2;
      int total_nram_size = nram_input_size + nram_q2_size + nram_table_size;
      loop_head = total_nram_size > nram_buffer_size;
      if (loop_head) {
        // If looping is required, recompute how many heads fit per iteration
        // (3x: double-buffered input plus the rotated copy).
        once_head_num = (nram_buffer_size - nram_table_size) /
                        (batch_per_core * rotary_dim * sizeof(float) * 3);
      }
      // Rebalance so every loop iteration handles a similar head count.
      int loop_num = (head_num + once_head_num - 1) / once_head_num;
      once_head_num = (head_num + loop_num - 1) / loop_num;
    }
  }

  uint32_t seq_segments = ((uint32_t)max_seq_len + seq_once - 1) / seq_once;
  uint32_t task_dimx = do_one_head_per_task ? head_num : 1;
  uint32_t task_dimz = total_core_num > seq_segments ? seq_segments : total_core_num;
  // Decoder mode spreads the batch over Y; otherwise Y indexes the batch.
  uint32_t task_dimy =
      decoder_mode && !do_one_head_per_task ? (uint32_t)total_core_num : (uint32_t)batch;
  seq_once = decoder_mode ? (batch + task_dimy - 1) / task_dimy : seq_once;
  cnrtDim3_t dim = {task_dimx, task_dimy, task_dimz};

  rotary_embedding_kernels[kernel_index]<<<dim, cnrtFuncTypeBlock, queue>>>(
      output, input, sin_table, cos_table, seq_offsets, cu_seq_lens, batch, max_seq_len, head_num,
      head_size, rotary_seq_len, rotary_dim, seq_once, rotary_stride, input_seq_stride,
      input_head_stride, output_seq_stride, output_head_stride, discrete, dynamic_ntk, decoder_mode,
      loop_head, once_head_num);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
// GLM-6B variant: the head is split into two halves of head_size / 2, each
// rotated with its own position offsets (the second half uses the offsets at
// seq_offsets + total_seq_len). Implemented as two discrete-mode
// invokeRotaryEmbedding launches on the same queue.
//
// Fix: the statuses of the two inner launches were previously discarded and
// the function always reported success; they are now propagated to the
// caller.
KernelStatus invokeGlm6BRotaryEmbedding(cnrtQueue_t queue,
                                        void *output,
                                        const void *input,
                                        const void *sin_table,
                                        const void *cos_table,
                                        const int *seq_offsets,
                                        const int *cu_seq_lens,
                                        int batch,
                                        int max_seq_len,
                                        int total_seq_len,
                                        int head_num,
                                        int head_size,
                                        int rotary_seq_len,
                                        int rotary_stride,
                                        size_t input_seq_stride,
                                        size_t input_head_stride,
                                        size_t output_seq_stride,
                                        size_t output_head_stride,
                                        bool interleaved,
                                        cnnlDataType_t data_type) {
  size_t type_size = 0;
  cnnlGetSizeOfDataType(data_type, &type_size);
  // First half of every head, using the first block of position offsets.
  KernelStatus status = invokeRotaryEmbedding(
      queue, output, input, sin_table, cos_table, seq_offsets, cu_seq_lens, batch, max_seq_len,
      head_num, head_size / 2, rotary_seq_len, head_size / 2, rotary_stride, input_seq_stride,
      input_head_stride, output_seq_stride, output_head_stride, interleaved, true, false,
      data_type);
  if (status != KernelStatus::KERNEL_STATUS_SUCCESS) {
    return status;
  }
  // Second half of every head (byte offset head_size/2 * type_size), using the
  // second block of position offsets.
  return invokeRotaryEmbedding(queue, (int8_t *)output + head_size / 2 * type_size,
                               (int8_t *)input + head_size / 2 * type_size, sin_table, cos_table,
                               seq_offsets + total_seq_len, cu_seq_lens, batch, max_seq_len,
                               head_num, head_size / 2, rotary_seq_len, head_size / 2,
                               rotary_stride, input_seq_stride, input_head_stride,
                               output_seq_stride, output_head_stride, interleaved, true, false,
                               data_type);
}
} // namespace tmo
|