#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <map>
#include <ostream>

#include "cnnl.h"
#include "cnrt.h"
#include "generate_alibi_slope.mluh"

// clang-format off
#include <mlu.h>
// clang-format on

namespace tmo {

namespace kernels {

// Usable on-chip NRAM bytes per core: the core's total NRAM minus 32 KiB
// reserved headroom (stack/spill space for the compiler and runtime).
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
// Single NRAM scratch buffer; the kernel below carves its working arrays
// out of this region (see MLUAlibiSlopeKernel).
__nram__ int8_t nram_buffer[NRAM_SIZE];

// Base exponent table 1..64 (stride 1): tiled and offset by genRange to
// produce consecutive exponents for heads below the closest power of two.
__nram__ float range_1[64] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
                              17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
                              33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
                              49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64};

// Base exponent table of the first 64 odd numbers (stride 2): used for the
// "extra" heads beyond the closest power of two in the ALiBi construction.
__nram__ float range_2[64] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25,
                              27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51,
                              53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77,
                              79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 103,
                              105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127};
// Fills range_nram[0..fill_num) with an arithmetic sequence by tiling the
// base_num-element pattern in range_base_nram and shifting each repeated
// chunk, then adding `offset` to every element.
//
// range_nram      : destination in NRAM, at least fill_num floats.
// range_base_nram : source pattern in NRAM (e.g. range_1 or range_2 above).
// fill_num        : number of floats to produce.
// base_num        : length of the base pattern.
// offset          : constant added to every produced value (default 0).
__mlu_func__ void genRange(float *range_nram,
                           float *range_base_nram,
                           int fill_num,
                           int base_num,
                           int offset = 0) {
  // Stride of the base pattern (1 for range_1, 2 for range_2). Each repeated
  // chunk must be shifted by the span the pattern covers (base_num * stride),
  // not by base_num: with the old `i * base_num` shift a stride-2 pattern
  // longer than one chunk continued ...125, 127 with 65, 67... instead of
  // 129, 131... Single-chunk fills and stride-1 patterns are unaffected.
  float stride =
      (base_num > 1) ? (range_base_nram[1] - range_base_nram[0]) : 1.0f;
  int loop = (fill_num + base_num - 1) / base_num;
  for (int i = 0; i < loop; i++) {
    // Last chunk may be partial.
    int num = std::min((fill_num - i * base_num), base_num);
    float *fill_nram = range_nram + i * base_num;
    __bang_move(fill_nram, range_base_nram, num * sizeof(float));
    __bang_add_scalar(fill_nram, fill_nram, i * base_num * stride + offset, num);
  }
}
// Computes per-head ALiBi slopes for one batch element per MLU task
// (taskIdX indexes the batch; the host launches batch_num tasks along x).
//
// Heads with index below closest_power_of_2 get slope base^(index+1); heads
// beyond it get slope extra_base^(odd exponent) — the standard ALiBi slope
// construction. This task produces slopes for the head window
// [head_start, head_start + head_num) and writes head_num floats to
// alibi_slopes + taskIdX * head_num.
//
// When use_dynamic is set, the base is shrunk according to the ratio of this
// sample's true sequence length to max_sequence_length, and the resulting
// slopes are rescaled by `scale`.
__mlu_global__ void MLUAlibiSlopeKernel(float *alibi_slopes,
                                        int *true_seq_lens,
                                        int batch_num,
                                        int head_start,
                                        int head_num,
                                        int head_num_total,
                                        int max_sequence_length,
                                        bool use_dynamic,
                                        int closest_power_of_2,
                                        int farthest_power_of_2,
                                        float base,
                                        float extra_base) {
  // Carve three head_num-sized float arrays out of the NRAM scratch buffer:
  // exponents, per-head bases, and the resulting slopes.
  float *range_nram = (float *)nram_buffer;
  float *base_nram = range_nram + head_num;
  float *slope_nram = base_nram + head_num;

  float scale = 1.0;
  float dynamic_base = base;
  if (use_dynamic) {
    float a0 = 1.0;
    // Ratio of this sample's true length to the configured max, clamped so
    // shorter-than-max sequences keep the static base.
    float a = a0 * true_seq_lens[taskIdX] / max_sequence_length;
    a = std::max(a, 1.0f);
    // Distribute the length ratio across heads, then divide it out of the
    // base; the slopes are multiplied back by `scale` at the end.
    scale = powf(a, (1.0 / (head_num_total - 1)));
    dynamic_base = base / scale;
  }

  // Split this task's head window into "close" heads (global index below
  // closest_power_of_2, using `base`) and "far" heads (using `extra_base`).
  int close_head_num = 0;
  if (head_start >= closest_power_of_2) {
    close_head_num = 0;
  } else if (head_start + head_num <= closest_power_of_2) {
    close_head_num = head_num;
  } else {
    close_head_num = closest_power_of_2 - head_start;
  }
  int far_head_num = head_num - close_head_num;

  // fill range: 1, 2..., n1, 1, 3, (n - n1) * 2 - 1
  if (close_head_num) {
    // Consecutive exponents head_start+1 .. head_start+close_head_num.
    genRange(range_nram, range_1, close_head_num, 64, head_start);
    __bang_write_value(base_nram, close_head_num, dynamic_base);
  }
  if (far_head_num) {
    // Odd exponents, offset by this window's position past the power of two.
    genRange(range_nram + close_head_num, range_2, far_head_num, 64,
             (head_start + close_head_num - closest_power_of_2) * 2);
    __bang_write_value(base_nram + close_head_num, far_head_num, extra_base);
  }

  // base_nram ** range_nram, computed elementwise as 2^(range * log2(base)).
  __bang_log(base_nram, base_nram, head_num);
  __bang_mul(slope_nram, base_nram, range_nram, head_num);
  __bang_pow2(slope_nram, slope_nram, head_num);

  if (use_dynamic) {
    // NOTE(review): the dynamic rescale touches only the first
    // close_head_num slopes; far heads keep extra_base unscaled even though
    // only `base` was divided by `scale` above — confirm this asymmetry is
    // intended.
    __bang_mul_scalar(slope_nram, slope_nram, scale, close_head_num);
  }

  __memcpy(alibi_slopes + taskIdX * head_num, slope_nram, head_num * sizeof(float), NRAM2GDRAM);
}
}  // namespace kernels

// Host-side launcher: derives the ALiBi base constants on the CPU and
// enqueues one Block task per batch element on `queue` (asynchronous).
//
// alibi_slopes  : device buffer of batch_num * head_num floats (output).
// true_seq_lens : device buffer of batch_num ints (read when use_dynamic).
// Returns KERNEL_STATUS_SUCCESS unconditionally.
KernelStatus invokeGenerateAlibiSlope(cnrtQueue_t queue,
                                      void *alibi_slopes,
                                      void *true_seq_lens,
                                      int batch_num,
                                      int head_start,
                                      int head_num,
                                      int head_num_total,
                                      int max_sequence_length,
                                      bool use_dynamic) {
  // One task along x per batch entry.
  cnrtDim3_t task_dim{.x = (uint32_t)batch_num, .y = 1, .z = 1};

  // Largest power of two not exceeding the total head count; heads past it
  // take the secondary ("extra") base inside the kernel.
  int pow2_floor = pow(2, floor(log2(head_num_total)));
  int pow2_ceil = pow2_floor * 2;

  // Standard ALiBi base for n heads: 2^(-2^(-(log2(n) - 3))).
  auto alibi_base = [](int n) -> float {
    return pow(2, (-pow(2, -(log2(n) - 3))));
  };
  float primary_base = alibi_base(pow2_floor);
  float secondary_base = alibi_base(2 * pow2_floor);

  kernels::MLUAlibiSlopeKernel<<<task_dim, cnrtFuncTypeBlock, queue>>>(
      (float *)alibi_slopes, (int *)true_seq_lens, batch_num, head_start,
      head_num, head_num_total, max_sequence_length, use_dynamic, pow2_floor,
      pow2_ceil, primary_base, secondary_base);

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
}  // namespace tmo