#include #include #include #include #include #include #include "cnnl.h" #include "cnrt.h" #include "generate_alibi_slope.mluh" // clang-format off #include // clang-format on namespace tmo { namespace kernels { #define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024) __nram__ int8_t nram_buffer[NRAM_SIZE]; __nram__ float range_1[64] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64}; __nram__ float range_2[64] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77, 79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 103, 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127}; __mlu_func__ void genRange(float *range_nram, float *range_base_nram, int fill_num, int base_num, int offset = 0) { int loop = (fill_num + base_num - 1) / base_num; for (int i = 0; i < loop; i++) { int num = std::min((fill_num - i * base_num), base_num); float *fill_nram = range_nram + i * base_num; __bang_move(fill_nram, range_base_nram, num * sizeof(float)); __bang_add_scalar(fill_nram, fill_nram, i * base_num + offset, num); } } __mlu_global__ void MLUAlibiSlopeKernel(float *alibi_slopes, int *true_seq_lens, int batch_num, int head_start, int head_num, int head_num_total, int max_sequence_length, bool use_dynamic, int closest_power_of_2, int farthest_power_of_2, float base, float extra_base) { float *range_nram = (float *)nram_buffer; float *base_nram = range_nram + head_num; float *slope_nram = base_nram + head_num; float scale = 1.0; float dynamic_base = base; if (use_dynamic) { float a0 = 1.0; float a = a0 * true_seq_lens[taskIdX] / max_sequence_length; a = std::max(a, 1.0f); scale = powf(a, (1.0 / (head_num_total - 1))); dynamic_base = base / scale; } int close_head_num = 0; if (head_start >= closest_power_of_2) { close_head_num = 0; } else if (head_start + head_num <= closest_power_of_2) { close_head_num = head_num; } else { close_head_num = closest_power_of_2 - head_start; } int far_head_num = head_num - close_head_num; // fill range: 1, 2..., n1, 1, 3, (n - n1) * 2 - 1 if (close_head_num) { genRange(range_nram, range_1, close_head_num, 64, head_start); __bang_write_value(base_nram, close_head_num, dynamic_base); } if (far_head_num) { genRange(range_nram + close_head_num, range_2, far_head_num, 64, (head_start + close_head_num - closest_power_of_2) * 2); __bang_write_value(base_nram + close_head_num, far_head_num, extra_base); } // base_nram ** range_nram __bang_log(base_nram, base_nram, head_num); __bang_mul(slope_nram, base_nram, range_nram, head_num); __bang_pow2(slope_nram, slope_nram, head_num); if (use_dynamic) { __bang_mul_scalar(slope_nram, slope_nram, scale, close_head_num); } __memcpy(alibi_slopes + taskIdX * head_num, slope_nram, head_num * sizeof(float), NRAM2GDRAM); } } // namespace kernels KernelStatus invokeGenerateAlibiSlope(cnrtQueue_t queue, void *alibi_slopes, void *true_seq_lens, int batch_num, int head_start, int head_num, int head_num_total, int max_sequence_length, bool use_dynamic) { cnrtDim3_t dim{.x = (uint32_t)batch_num, .y = 1, .z = 1}; int closest_power_of_2 = pow(2, floor(log2(head_num_total))); int farthest_power_of_2 = closest_power_of_2 * 2; float base = pow(2, (-pow(2, -(log2(closest_power_of_2) - 3)))); float extra_base = pow(2, (-pow(2, -(log2(2 * closest_power_of_2) - 3)))); kernels::MLUAlibiSlopeKernel<<>>( (float *)alibi_slopes, (int *)true_seq_lens, batch_num, head_start, head_num, head_num_total, max_sequence_length, use_dynamic, closest_power_of_2, farthest_power_of_2, base, extra_base); return KernelStatus::KERNEL_STATUS_SUCCESS; } } // namespace tmo