Files
2026-02-04 17:39:32 +08:00

659 lines
33 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "cnnl.h"
#include "cnrt.h"
#include "fused_rope.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
using bf16 = bfloat16_t;
namespace tmo {
namespace kernels {
#ifndef PAD_UP
// Round x up to the nearest multiple of y (y > 0).
// NOTE(review): the previous form ((x) / (y) + (int)((x) % (y) > 0) * (y))
// was missing a parenthesis: the round-up bit was multiplied by y and added
// to the quotient, so PAD_UP(32, 8) evaluated to 4 and PAD_UP(10, 8) to 9.
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
#endif
#if __BANG_ARCH__ > 500
#include <bang_fusor.h>
template <typename T>
using bang_cycle_fusor = bang::experimental::cycle_fusor<T>;
#endif
#define NRAM_BUFFER_SIZE (480 * 1024)
// Main on-chip scratch area; carved up manually inside fuseRopeImpl.
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
// Sign vector (-1 / +1 per half of head_size) for the rotary fold; also
// reused as integer scratch inside genScatterOffsetMask.
__nram__ float nram_mask[256];
// Per-batch rotary-table byte offsets gathered from GDRAM rope_offsets.
__nram__ int nram_rope_offsets[256];
// All-zero operand for the >= 0 comparison that builds scatter bit masks.
__nram__ float nram_zeros[1024] = {0.f};
// Widen `num` elements of half/bf16 `src` into float `dst`.
// For any other T this is a no-op; the bf16 conversion intrinsic only
// exists on arch > 500.
template <typename T>
__mlu_func__ void toFloat(float *dst, T *src, int num) {
  if (std::is_same<T, half>::value) {
    __bang_half2float(dst, (half *)src, num);
    return;
  }
  if (std::is_same<T, bf16>::value) {
#if __BANG_ARCH__ > 500
    __bang_bfloat162float(dst, (bf16 *)src, num);
#endif
  }
}
// Narrow `num` float elements from `src` into half/bf16 `dst` with
// round-to-nearest-even. No-op for any other T.
template <typename T>
__mlu_func__ void floatTo(T *dst, float *src, int num) {
  if (std::is_same<T, bf16>::value) {
    __bang_float2bfloat16_rn((bf16 *)dst, src, num);
  } else if (std::is_same<T, half>::value) {
    __bang_float2half_rn((half *)dst, src, num);
  }
}
// Builds per-(batch, kv-head) offsets into the KV caches plus a bit mask
// telling the scatter/gather instructions which entries are valid.
//
// Invalid rows (negative slot mapping, or negative batch-id / seq-offset)
// get their k-cache offset written as -1; the final >= 0 comparison turns
// that into a cleared mask bit. The v/scale/on-chip offsets of masked rows
// are left stale — harmless, since every scatter using them passes the mask.
//
// paged_cache:  offsets address a [num_blocks, head_num_k, block_size,
//               head_size] paged cache through slot_mapping
//               (slot = block_idx * block_size + seq_idx).
// !paged_cache: offsets address a linear [batch, head_num_k, max_decode_len,
//               head_size] cache through cache_seq_offsets, with the batch
//               index coming from cache_bs_id when discrete_batch, otherwise
//               task_begin_batch + i.
// mixed_cache:  additionally emits the low-precision value-cache offsets
//               (int4 packing halves the strides), the per-group scale byte
//               offsets, and the on-chip staging offsets used when merging
//               int4 pairs.
// kv_out_size:  element byte size of the target cache entry; k offsets are
//               scaled by it and halved for mixed_cache (two int4 per byte).
//
// The `nram_zeros` parameter shadows the file-level global of the same name;
// callers pass the global. Offsets are computed with scalar code for clarity
// (original author's note: no measurable performance impact).
__mlu_func__ void genScatterOffsetMask(int *cache_bs_id_begin,
                                       int *cache_seq_offsets_begin,
                                       int *slot_mapping_begin,
                                       int *nram_k_cache_offsets,
                                       int *nram_v_cache_offsets,
                                       int *nram_v_onchip_offsets,
                                       int *nram_kv_scale_offsets,
                                       float *nram_cache_mask,
                                       float *nram_zeros,
                                       float *nram_temp,
                                       int task_deal_batch,
                                       int task_begin_batch,
                                       int head_num_k,
                                       int head_size,
                                       int max_decode_len,
                                       int block_size,
                                       int kv_out_size,
                                       int group_num,
                                       bool discrete_batch,
                                       bool paged_cache,
                                       bool mixed_cache) {
  int bh = task_deal_batch * head_num_k;
  if (paged_cache) {
    int cache_seq_stride = head_size;
    int cache_head_stride = block_size * head_size;
    int cache_scale_head_stride = block_size * group_num;
    int cache_block_stride = head_num_k * cache_head_stride;
    int cache_scale_block_stride = head_num_k * cache_scale_head_stride;
    // Reuse the global mask buffer as scratch for the slot-mapping copy.
    int *nram_slot_mapping = (int *)nram_mask;
    __memcpy(nram_slot_mapping, slot_mapping_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    for (int i = 0; i < task_deal_batch; i++) {
      int mapping_idx = __load_nram(nram_slot_mapping + i);
      if (mapping_idx < 0) {
        // Negative slot: mark every head of this batch row as masked.
        __bang_write_value(nram_k_cache_offsets + i * head_num_k, head_num_k, (int)-1);
        continue;
      }
      int block_idx = mapping_idx / block_size;
      int seq_idx = mapping_idx % block_size;
      int k_seq_offset = block_idx * cache_block_stride + seq_idx * cache_seq_stride;
      // lp value cache packs two seq positions per row (int4), hence the /2.
      int v_seq_offset = block_idx * cache_block_stride / 2 + seq_idx / 2 * cache_seq_stride;
      int scale_seq_offset = block_idx * cache_scale_block_stride + seq_idx * group_num;
      // On-chip staging offset: even/odd seq positions land in alternating
      // halves so int4 pairs can be merged later.
      int onchip_offset = i * head_num_k * head_size + seq_idx % 2 * bh * head_size;
      for (int j = 0; j < head_num_k; j++) {
        __store_nram(
            nram_k_cache_offsets + i * head_num_k + j,
            (int)((k_seq_offset + j * cache_head_stride) * kv_out_size / (mixed_cache + 1)));
        if (mixed_cache) {
          __store_nram(nram_v_cache_offsets + i * head_num_k + j,
                       (int)(v_seq_offset + j * cache_head_stride / 2));
          __store_nram(nram_kv_scale_offsets + i * head_num_k + j,
                       (int)((scale_seq_offset + j * cache_scale_head_stride) * sizeof(float)));
          __store_nram(nram_v_onchip_offsets + i * head_num_k + j, onchip_offset + j * head_size);
        }
      }
    }
  } else {
    int *nram_seq_offsets = (int *)nram_mask;
    int *nram_bs_id = nram_seq_offsets + 32;
    int cache_seq_stride = head_size;
    int cache_head_stride = max_decode_len * head_size;
    int cache_scale_head_stride = max_decode_len * group_num;
    int cache_bs_stride = head_num_k * cache_head_stride;
    int cache_scale_bs_stride = head_num_k * cache_scale_head_stride;
    __memcpy(nram_seq_offsets, cache_seq_offsets_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    if (discrete_batch) {
      __memcpy(nram_bs_id, cache_bs_id_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    }
    for (int i = 0; i < task_deal_batch; i++) {
      int seq_idx = __load_nram(nram_seq_offsets + i);
      // Only read the batch-id table when it was actually populated above.
      // (The previous code loaded it unconditionally, reading uninitialized
      // NRAM in the non-discrete case; the garbage value was then discarded.)
      int temp_bs_idx = discrete_batch ? __load_nram(nram_bs_id + i) : task_begin_batch + i;
      int temp_seq_idx = seq_idx;
      bool masked = temp_bs_idx < 0 || temp_seq_idx < 0;
      if (masked) {
        __bang_write_value(nram_k_cache_offsets + i * head_num_k, head_num_k, (int)-1);
        continue;
      }
      int k_seq_offset = temp_bs_idx * cache_bs_stride + temp_seq_idx * cache_seq_stride;
      int scale_seq_offset = temp_bs_idx * cache_scale_bs_stride + temp_seq_idx * group_num;
      int v_seq_offset = temp_bs_idx * cache_bs_stride / 2 + temp_seq_idx / 2 * cache_seq_stride;
      int onchip_offset = i * head_num_k * head_size + temp_seq_idx % 2 * bh * head_size;
      for (int j = 0; j < head_num_k; j++) {
        __store_nram(
            nram_k_cache_offsets + i * head_num_k + j,
            (int)((k_seq_offset + j * cache_head_stride) * kv_out_size / (mixed_cache + 1)));
        if (mixed_cache) {
          __store_nram(nram_v_cache_offsets + i * head_num_k + j,
                       (int)(v_seq_offset + j * cache_head_stride / 2));
          __store_nram(nram_kv_scale_offsets + i * head_num_k + j,
                       (int)((scale_seq_offset + j * cache_scale_head_stride) * sizeof(float)));
          __store_nram(nram_v_onchip_offsets + i * head_num_k + j, onchip_offset + j * head_size);
        }
      }
    }
  }
  // Entries marked -1 above become negative floats; the >= 0 bit-index
  // comparison produces the per-(batch, head) mask consumed by
  // __gather/__scatter. Rows with a negative bs/seq offset are masked out.
  __bang_int322float(nram_temp, nram_k_cache_offsets, bh, 0);
  __bang_ge_bitindex(nram_cache_mask, nram_temp, nram_zeros, PAD_UP(bh, 8));
}
// Standard layernorm over each row of `nram_k` (task_deal_batch rows of
// k_hidden floats each), followed by one fused affine (gamma/beta) pass over
// all rows. `norm_params` holds [gamma | beta], each k_hidden floats.
// Requires task_deal_batch * k_hidden floats of scratch directly behind
// nram_k — the caller's NRAM layout guarantees this. arch <= 500: no-op.
__mlu_func__ void layernormImpl(float *nram_k,
                                float *norm_params,
                                int task_deal_batch,
                                int k_hidden,
                                float eps) {
#if __BANG_ARCH__ > 500
  // Scratch for x^2 lives immediately after the data being normalized.
  float *buffer = nram_k + task_deal_batch * k_hidden;
  for (int i = 0; i < task_deal_batch; i++) {
    float *k_ = nram_k + i * k_hidden;
    __bang_mul(buffer, k_, k_, k_hidden);  // buffer = x * x
    float mean = __bang_sum(k_, k_hidden);
    mean = mean / k_hidden;
    float rstd = __bang_sum(buffer, k_hidden);
    rstd = rstd / k_hidden - mean * mean;  // variance = E[x^2] - E[x]^2
    // Numerically-negative variance falls back to eps alone.
    rstd = rstd < 0 ? eps : rstd + eps;
    rstd = 1.f / std::sqrt(rstd);
    // Presumably fused (x - mean) * rstd — NOTE(review): confirm FUSION_FSM
    // operand order against the BANG intrinsics reference.
    __bang_fusion(FUSION_FSM, k_, k_, mean, rstd, k_hidden);
  }
  // Fused x * gamma + beta, with gamma/beta cycled every k_hidden elements.
  __bang_fusion(FUSION_FMA, nram_k, nram_k, norm_params, norm_params + k_hidden,
                task_deal_batch * k_hidden, k_hidden);
#endif
}
// Rotary embedding fold: nram_qk = qk * cos + sign_mask * swap(qk) * sin.
// Data layout is head-major: [head_num_qk, task_deal_batch, head_size].
// nram_table holds [cos | sin], each task_deal_batch * head_size floats,
// cycled per head. nram_qk_rot must already contain qk with its two
// head-size halves swapped; the global nram_mask (-1 on the first half,
// +1 on the second) turns (x2, x1) into (-x2, x1), so the final add yields
// the complex rotation (x1*cos - x2*sin, x2*cos + x1*sin).
__mlu_func__ void foldRotaryImpl(float *nram_qk,
                                 float *nram_qk_rot,
                                 float *nram_table,
                                 int task_deal_batch,
                                 int head_num_qk,
                                 int head_size) {
  int rotary_low_dim = task_deal_batch * head_size;
  // qk *= cos (table cycled once per head)
  __bang_cycle_mul(nram_qk, nram_qk, nram_table, head_num_qk * rotary_low_dim, rotary_low_dim);
  // qk_rot *= sign mask (cycled every head_size elements)
  __bang_cycle_mul(nram_qk_rot, nram_qk_rot, nram_mask, head_num_qk * rotary_low_dim, head_size);
  // qk_rot *= sin
  __bang_cycle_mul(nram_qk_rot, nram_qk_rot, nram_table + task_deal_batch * head_size,
                   head_num_qk * rotary_low_dim, rotary_low_dim);
  __bang_add(nram_qk, nram_qk, nram_qk_rot, head_num_qk * rotary_low_dim);
}
// Quantizes one [batch, head_num, head_size] tensor for the KV caches.
// quant_kv_hp path: widens `input` into `float_input`, then converts to int8
//   into `output_hp` using the offline per-channel scales `scale_hp`.
//   WARNING: `scale_hp` is overwritten in place with its reciprocal.
// mixed_cache path (in addition): online per-group quantization to the int4
//   value range — transpose to group-major, per-group |max| / 7 scales into
//   `scale_lp` (reciprocals staged in `scale_lp_temp`), scale + round to
//   int8 codes, transpose back into `output_lp`.
// Does nothing at all when quant_kv_hp is false — NOTE(review): this means
// mixed_cache is assumed to imply quant_kv_hp; confirm against callers.
template <typename T>
__mlu_func__ void quantify(T *input,
                           float *float_input,
                           void *output_hp,
                           void *output_lp,
                           float *nram_trans,
                           float *scale_hp,
                           float *scale_lp,
                           float *scale_lp_temp,
                           int batch,
                           int head_num,
                           int head_size,
                           int group_num,
                           int group_size,
                           bool quant_kv_hp,
                           bool mixed_cache) {
  if (quant_kv_hp) {
    int hidden = head_num * head_size;
    int bh = batch * head_num;
    toFloat<T>(float_input, input, batch * hidden);
    // Invert once so the fused op below can multiply instead of divide.
    __bang_recip(scale_hp, scale_hp, hidden);
#if __BANG_ARCH__ > 500
    // Fused float32 -> int8 round-to-nearest convert with a cyclic
    // per-channel multiply by the (now reciprocal) hp scales; presumably
    // output_hp[i] = s8(round(float_input[i] * scale_hp[i % hidden])) —
    // NOTE(review): confirm against the BANG fuse instruction reference.
    __asm__ __volatile__(
        "fuse.nram.crn.s8.f32 "
        "[%[dst]], %[num_long], %[num_short], [%[src0]], .mul.cycle([%[src1]]), .dstpos(%[pos])"
        ";\n\t" ::[dst] "r"(output_hp),
        [num_long] "r"(batch * hidden), [num_short] "r"(hidden), [src0] "r"(float_input),
        [src1] "r"(scale_hp), [pos] "i"(0));
#endif
    if (mixed_cache) {
      // Group-major layout: [group_size, bh * group_num] after transpose.
      __bang_transpose(nram_trans, float_input, bh * group_num, group_size);
      __bang_abs(float_input, nram_trans, bh * head_size);
      // Per-group absolute max via maxpool over each group column.
      __bang_maxpool(scale_lp_temp, float_input, bh * group_num, group_size, 1, group_size, 1, 1,
                     1);
      // Symmetric int4 range: scale = |max| / 7.
      __bang_mul_scalar(scale_lp, scale_lp_temp, 1 / 7.f, bh * group_num);
      __bang_recip(scale_lp_temp, scale_lp, bh * group_num);
      __bang_cycle_mul(nram_trans, nram_trans, scale_lp_temp, bh * group_num * group_size,
                       bh * group_num);
      __bang_float2int8_rn((int8_t *)nram_trans, nram_trans, bh * head_size, 0);
      // Back to the original element order.
      __bang_transpose((int8_t *)output_lp, (int8_t *)nram_trans, group_size, bh * group_num);
    }
  }
}
// Per-chunk fused-RoPE pipeline. `input` is [batch, head_num_qkv, head_size]
// packed q|k|v. This routine: builds KV-cache scatter offsets/masks, loads v
// and quantizes it, loads q+k and applies RoPE, layernorms + quantizes k,
// writes rotated q back into `input` in place, and scatters k/v (plus lp
// scales for mixed cache) into the caches — overlapping IO and compute per
// the pipeline diagram below. arch <= 500: no-op.
//
// Supported cache combinations (translated from the original author's note):
// 1. hp-only cache (lp tensors are null): cache holds bf16/fp16, or offline
//    per-channel int8 when key/value_scale_hp are provided; linear and paged
//    layouts; key and value caches share one shape; key/value_scale_hp has
//    shape [head_num, head_size].
// 2. mixed cache: hp is offline per-channel int8 (linear or paged, key and
//    value shapes identical, scale_hp as above). lp is online per-token int4
//    group quantization: key_cache_lp is [batch, head_num_k,
//    max_decode_len_lp, head_size / 2] (paged likewise head_size / 2);
//    value_cache_lp is [batch, head_num_k, max_decode_len_lp / 2, head_size]
//    (paged: [num_blocks, head_num_k, block_size / 2, head_size]);
//    key/value_scale_lp is [batch, head_num_k, max_decode_len_lp, group_num]
//    (paged: [num_blocks, head_num_k, block_size, group_num]).
template <typename T>
__mlu_func__ void fuseRopeImpl(T *input,
                               void *key_cache_hp,
                               void *value_cache_hp,
                               void *key_cache_lp,
                               void *value_cache_lp,
                               T *sin_table,
                               T *cos_table,
                               int *rope_offsets,
                               T *gamma,
                               T *beta,
                               float *key_scale_hp,
                               float *value_scale_hp,
                               float *key_scale_lp,
                               float *value_scale_lp,
                               int *cache_bs_id_hp,
                               int *cache_seq_offsets_hp,
                               int *cache_bs_id_lp,
                               int *cache_seq_offsets_lp,
                               int *slot_mapping_hp,
                               int *slot_mapping_lp,
                               int rotary_stride,
                               int task_deal_batch,
                               int task_begin_batch,
                               int head_num_q,
                               int head_num_k,
                               int head_size,
                               int max_decode_len_hp,
                               int max_decode_len_lp,
                               int block_size_hp,
                               int block_size_lp,
                               int group_size,
                               int batch_cap,
                               float eps) {
#if __BANG_ARCH__ > 500
  // Feature flags are inferred from which optional tensors are non-null.
  bool mixed_cache = key_cache_lp != nullptr && value_cache_lp != nullptr;
  bool quant_kv_hp = key_scale_hp != nullptr && value_scale_hp != nullptr;
  bool discrete_batch_hp = cache_bs_id_hp != nullptr;
  bool discrete_batch_lp = cache_bs_id_lp != nullptr;
  bool paged_cache_hp = slot_mapping_hp != nullptr;
  bool paged_cache_lp = slot_mapping_lp != nullptr;
  int head_num_qk = head_num_q + head_num_k;
  int head_num_qkv = head_num_q + head_num_k * 2;
  int qkv_hidden = head_num_qkv * head_size;
  int qk_hidden = head_num_qk * head_size;
  int q_hidden = head_num_q * head_size;
  int k_hidden = head_num_k * head_size;
  int float_size = sizeof(float);
  int dtype_size = sizeof(T);
  int kv_size_hp = quant_kv_hp ? sizeof(int8_t) : dtype_size;
  int group_num = mixed_cache ? head_size / group_size : 1;
  // Task-local GDRAM base pointers.
  // NOTE(review): cache_bs_id_hp / slot_mapping_hp may be nullptr here, so
  // this is pointer arithmetic on nullptr (technically UB) — the results are
  // only dereferenced when the matching flag is set; confirm acceptable.
  T *input_begin = input + task_begin_batch * qkv_hidden;
  int *cache_bs_id_begin_hp = cache_bs_id_hp + task_begin_batch;
  int *cache_seq_offsets_begin_hp = cache_seq_offsets_hp + task_begin_batch;
  int *slot_mapping_begin_hp = slot_mapping_hp + task_begin_batch;
  // Manual carve-up of the single NRAM scratch buffer. Optional regions get
  // zero size via the (int)flag multiplier so later pointers pack tightly.
  float *nram_qk = (float *)nram_buffer;
  float *nram_qk_rot = nram_qk + batch_cap * qk_hidden;
  float *nram_v = nram_qk_rot + batch_cap * qk_hidden;
  float *nram_kv_trans = nram_v + batch_cap * k_hidden;
  float *nram_table = nram_kv_trans + (int)mixed_cache * batch_cap * k_hidden;
  float *norm_params = nram_table + 2 * batch_cap * head_size;
  float *nram_k_scale_hp = norm_params + 2 * head_size;
  float *nram_v_scale_hp = nram_k_scale_hp + (int)quant_kv_hp * k_hidden;
  float *nram_k_scale_lp = nram_v_scale_hp + (int)quant_kv_hp * k_hidden;
  float *nram_v_scale_lp = nram_k_scale_lp + (int)mixed_cache * batch_cap * head_num_k * group_num;
  int8_t *nram_kv_hp =
      (int8_t *)(nram_v_scale_lp + (int)mixed_cache * batch_cap * head_num_k * group_num);
  int8_t *nram_kv_lp = nram_kv_hp + (int)quant_kv_hp * batch_cap * k_hidden;
  int8_t *nram_cache_v = nram_kv_lp + (int)mixed_cache * batch_cap * k_hidden;
  int *nram_kv_cache_offsets_hp =
      (int *)(nram_cache_v + (int)mixed_cache * batch_cap * k_hidden * 2);
  int *nram_k_cache_offsets_lp = nram_kv_cache_offsets_hp + batch_cap * head_num_k;
  int *nram_v_cache_offsets_lp =
      nram_k_cache_offsets_lp + (int)mixed_cache * batch_cap * head_num_k;
  int *nram_kv_scale_offsets = nram_v_cache_offsets_lp + (int)mixed_cache * batch_cap * head_num_k;
  int *nram_v_onchip_offsets = nram_kv_scale_offsets + (int)mixed_cache * batch_cap * head_num_k;
  float *cache_mask_hp =
      (float *)(nram_v_onchip_offsets + (int)mixed_cache * batch_cap * head_num_k);
  float *cache_mask_lp = (float *)((int8_t *)cache_mask_hp + PAD_UP(batch_cap * head_num_k, 8) / 8);
  // qk and qk_rot share one buffer so the width promotion (toFloat) can be
  // done in a single instruction; the same trick is used for the sin/cos
  // table and for the layernorm gamma/beta. The narrow-typed inputs are
  // loaded at the tail of each float region so widening in place is safe.
  T *qk_in = (T *)nram_qk_rot;
  T *qk_rot_in = (T *)((int8_t *)nram_qk_rot + (float_size - dtype_size) * batch_cap * qk_hidden);
  T *v_in =
      (T *)((int8_t *)nram_v + (int)quant_kv_hp * (float_size - dtype_size) * batch_cap * k_hidden);
  T *norm_params_in = (T *)((int8_t *)norm_params + (float_size - dtype_size) * 2 * head_size);
  T *table_in = (T *)((int8_t *)nram_table + (float_size - dtype_size) * 2 * batch_cap * head_size);
  int8_t *nram_cache_v_in = nram_cache_v + batch_cap * k_hidden * sizeof(int8_t);
  // Generate kv-cache offsets and masks for scattering k/v into the caches.
  genScatterOffsetMask(cache_bs_id_begin_hp, cache_seq_offsets_begin_hp, slot_mapping_begin_hp,
                       nram_kv_cache_offsets_hp, nullptr, nullptr, nullptr, cache_mask_hp,
                       nram_zeros, nram_qk, task_deal_batch, task_begin_batch, head_num_k,
                       head_size, max_decode_len_hp, block_size_hp, kv_size_hp, 1,
                       discrete_batch_hp, paged_cache_hp, false);
  if (mixed_cache) {
    int *cache_bs_id_begin_lp = cache_bs_id_lp + task_begin_batch;
    int *cache_seq_offsets_begin_lp = cache_seq_offsets_lp + task_begin_batch;
    int *slot_mapping_begin_lp = slot_mapping_lp + task_begin_batch;
    genScatterOffsetMask(cache_bs_id_begin_lp, cache_seq_offsets_begin_lp, slot_mapping_begin_lp,
                         nram_k_cache_offsets_lp, nram_v_cache_offsets_lp, nram_v_onchip_offsets,
                         nram_kv_scale_offsets, cache_mask_lp, nram_zeros, nram_qk, task_deal_batch,
                         task_begin_batch, head_num_k, head_size, max_decode_len_lp, block_size_lp,
                         1, group_num, discrete_batch_lp, paged_cache_lp, mixed_cache);
  }
  /* Software pipeline: left column = IO stage, right column = compute
     overlapped with it.
  -----------------------
  load v      |
  -----------------------
  load qk     | quant v
  -----------------------
  store v     | rope qk
  -----------------------
  store_q     | layernorm k
              | quant k
  -----------------------
  store k     |
  */
  // Stage 0: prepare v, hp scales, (mixed) stale cache_v rows, rope offsets.
  __memcpy_async(v_in, input_begin + qk_hidden, k_hidden * dtype_size, GDRAM2NRAM,
                 k_hidden * dtype_size, qkv_hidden * dtype_size, task_deal_batch - 1);
  if (quant_kv_hp) {
    __memcpy_async(nram_k_scale_hp, key_scale_hp, k_hidden * float_size, GDRAM2NRAM);
    __memcpy_async(nram_v_scale_hp, value_scale_hp, k_hidden * float_size, GDRAM2NRAM);
  }
  __memcpy_async(nram_rope_offsets, rope_offsets + task_begin_batch, task_deal_batch * sizeof(int),
                 GDRAM2NRAM);
  __sync_io();
  if (mixed_cache) {
    // Fetch the existing lp value rows so the new even/odd int4 half can be
    // merged with what is already in the cache.
    __gather(nram_cache_v_in, value_cache_lp, (uint32_t *)nram_v_cache_offsets_lp, cache_mask_lp,
             head_size * sizeof(int8_t), GDRAM2NRAM, head_size * sizeof(int8_t),
             task_deal_batch * head_num_k);
  }
  // Turn table indices into byte offsets for the gathers below.
  __bang_mul_scalar(nram_rope_offsets, nram_rope_offsets, rotary_stride * dtype_size,
                    task_deal_batch);
  __sync_compute();
  /*==============================================================================================*/
  // Stage 1: load qk + rope tables + norm params | quantize v.
  // qk is loaded transposed to head-major [head_num_qk, batch, head_size].
  __memcpy_async(qk_in, input_begin, head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size,
                 task_deal_batch - 1, task_deal_batch * head_size * dtype_size, head_num_qk - 1,
                 qkv_hidden * dtype_size, task_deal_batch - 1, head_size * dtype_size,
                 head_num_qk - 1);
  __gather_async(table_in, cos_table, (uint32_t *)nram_rope_offsets, head_size * dtype_size,
                 GDRAM2NRAM, head_size * dtype_size, task_deal_batch);
  __gather_async(table_in + task_deal_batch * head_size, sin_table, (uint32_t *)nram_rope_offsets,
                 head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size, task_deal_batch);
  __memcpy_async(norm_params_in, gamma, head_size * dtype_size, GDRAM2NRAM);
  __memcpy_async(norm_params_in + head_size, beta, head_size * dtype_size, GDRAM2NRAM);
  int8_t *nram_temp = (int8_t *)nram_qk;
  if (mixed_cache) {
    // Unpack the gathered int4 pairs to int8 and interleave so the on-chip
    // scatter below can overwrite just this token's half.
    __bang_int42int8(nram_cache_v, (int4x2_t *)nram_cache_v_in, task_deal_batch * k_hidden * 2, 0,
                     0);
    __bang_transpose(nram_temp, nram_cache_v, task_deal_batch * k_hidden, 2);
  }
  quantify<T>(v_in, nram_v, nram_kv_hp, nram_kv_lp, nram_kv_trans, nram_v_scale_hp, nram_v_scale_lp,
              nram_k_scale_lp /*lp_scale temp*/, task_deal_batch, head_num_k, head_size, group_num,
              group_size, quant_kv_hp, mixed_cache);
  if (mixed_cache) {
    // Merge the freshly quantized v into the unpacked cache rows, then
    // re-pack to int4 for the store.
    __scatter(nram_temp, nram_kv_lp, (uint32_t *)nram_v_onchip_offsets, cache_mask_lp,
              head_size * sizeof(int8_t), NRAM2NRAM, head_size * sizeof(int8_t),
              task_deal_batch * head_num_k);
    __bang_transpose(nram_cache_v, nram_temp, 2, task_deal_batch * k_hidden);
    __bang_int82int4_rn((int4x2_t *)nram_cache_v, nram_cache_v, task_deal_batch * k_hidden * 2, 0,
                        0);
  }
  __sync_io_move_compute();
  /*==============================================================================================*/
  // Stage 2: rope qk | store v.
  // Swap the left/right halves of qk to build the rotated operand qk_rot.
  __memcpy(qk_rot_in, qk_in + head_size / 2, head_size / 2 * dtype_size, NRAM2NRAM,
           head_size * dtype_size, head_size * dtype_size, task_deal_batch * head_num_qk - 1);
  __memcpy(qk_rot_in + head_size / 2, qk_in, head_size / 2 * dtype_size, NRAM2NRAM,
           head_size * dtype_size, head_size * dtype_size, task_deal_batch * head_num_qk - 1);
  // One widening pass covers qk and qk_rot (contiguous), one covers cos|sin.
  toFloat<T>(nram_qk, qk_in, 2 * batch_cap * qk_hidden);
  toFloat<T>(nram_table, table_in, 2 * task_deal_batch * head_size);
  toFloat<T>(norm_params, norm_params_in, 2 * head_size);
  // Sign mask: -1 on the first half of each head, +1 on the second.
  __bang_write_value(nram_mask, head_size / 2, (float)-1);
  __bang_write_value(nram_mask + head_size / 2, head_size / 2, (float)1);
  foldRotaryImpl(nram_qk, nram_qk_rot, nram_table, task_deal_batch, head_num_qk, head_size);
  // Narrow q back to T in place for the store in stage 3.
  floatTo<T>((T *)nram_qk, nram_qk, task_deal_batch * q_hidden);
  int8_t *scatter_v_src = quant_kv_hp ? nram_kv_hp : (int8_t *)nram_v;
  __scatter_async(value_cache_hp, scatter_v_src, (uint32_t *)nram_kv_cache_offsets_hp,
                  cache_mask_hp, head_size * kv_size_hp, NRAM2GDRAM, head_size * kv_size_hp,
                  head_num_k * task_deal_batch);
  if (mixed_cache) {
    __scatter_async(value_cache_lp, nram_cache_v, (uint32_t *)nram_v_cache_offsets_lp,
                    cache_mask_lp, head_size * sizeof(int8_t), NRAM2GDRAM,
                    head_size * sizeof(int8_t), head_num_k * task_deal_batch);
    __scatter_async(value_scale_lp, nram_v_scale_lp, (uint32_t *)nram_kv_scale_offsets,
                    cache_mask_lp, group_num * sizeof(float), NRAM2GDRAM, group_num * sizeof(float),
                    head_num_k * task_deal_batch);
  }
  __sync_io_move_compute();
  /*==============================================================================================*/
  // Stage 3: layernorm + quantize k | store q.
  // Extract k out of the qk buffer (un-transposing back to batch-major).
  float *nram_k = nram_qk_rot;
  __memcpy(nram_k, nram_qk + task_deal_batch * q_hidden, head_size * float_size, NRAM2NRAM,
           head_size * float_size, head_num_k - 1, k_hidden * float_size, task_deal_batch - 1,
           task_deal_batch * head_size * float_size, head_num_k - 1, head_size * float_size,
           task_deal_batch - 1);
  layernormImpl(nram_k, norm_params, task_deal_batch * head_num_k, head_size, eps);
  quantify<float>(nram_k, nram_k, nram_kv_hp, nram_kv_lp, nram_kv_trans, nram_k_scale_hp,
                  nram_k_scale_lp, nram_v_scale_lp /*lp_scale temp*/, task_deal_batch, head_num_k,
                  head_size, group_num, group_size, quant_kv_hp, mixed_cache);
  if (mixed_cache) {
    // lp key cache stores packed int4 pairs (head_size / 2 bytes per head).
    __bang_int82int4_rn((int4x2_t *)nram_kv_lp, nram_kv_lp, task_deal_batch * k_hidden, 0, 0);
  }
  if (!quant_kv_hp) {
    floatTo<T>((T *)nram_k, nram_k, task_deal_batch * k_hidden);
  }
  // Store rotated q back into `input` in place (transposing to batch-major).
  __memcpy_async(input_begin, nram_qk, head_size * dtype_size, NRAM2GDRAM, qkv_hidden * dtype_size,
                 task_deal_batch - 1, head_size * dtype_size, head_num_q - 1,
                 head_size * dtype_size, task_deal_batch - 1,
                 task_deal_batch * head_size * dtype_size, head_num_q - 1);
  // ===============================================================================================
  // Stage 4: store k (hp, plus lp codes and scales for mixed cache).
  int8_t *scatter_k_src = quant_kv_hp ? nram_kv_hp : (int8_t *)nram_k;
  __scatter(key_cache_hp, scatter_k_src, (uint32_t *)nram_kv_cache_offsets_hp, cache_mask_hp,
            head_size * kv_size_hp, NRAM2GDRAM, head_size * kv_size_hp,
            head_num_k * task_deal_batch);
  if (mixed_cache) {
    __scatter(key_cache_lp, nram_kv_lp, (uint32_t *)nram_k_cache_offsets_lp, cache_mask_lp,
              head_size / 2 * sizeof(int8_t), NRAM2GDRAM, head_size / 2 * sizeof(int8_t),
              head_num_k * task_deal_batch);
    __scatter(key_scale_lp, nram_k_scale_lp, (uint32_t *)nram_kv_scale_offsets, cache_mask_lp,
              group_num * sizeof(float), NRAM2GDRAM, group_num * sizeof(float),
              head_num_k * task_deal_batch);
  }
  __sync_io_move_compute();
#endif
}
// Kernel entry point: each task (core) owns a contiguous slice of
// `task_avg_batch` batches and processes it in NRAM-sized chunks of at most
// `batch_cap` batches via fuseRopeImpl. MPU cores and tasks whose slice is
// empty return immediately.
template <typename T>
__mlu_global__ void MLUFuseRope(T *input,
                                void *key_cache_hp,
                                void *value_cache_hp,
                                void *key_cache_lp,
                                void *value_cache_lp,
                                T *sin_table,
                                T *cos_table,
                                int *rope_offsets,
                                T *gamma,
                                T *beta,
                                float *key_scale_hp,
                                float *value_scale_hp,
                                float *key_scale_lp,
                                float *value_scale_lp,
                                int *cache_bs_id_hp,
                                int *cache_seq_offsets_hp,
                                int *cache_bs_id_lp,
                                int *cache_seq_offsets_lp,
                                int *slot_mapping_hp,
                                int *slot_mapping_lp,
                                int rotary_stride,
                                int batch,
                                int head_num_q,
                                int head_num_k,
                                int head_size,
                                int max_decode_len_hp,
                                int max_decode_len_lp,
                                int block_size_hp,
                                int block_size_lp,
                                int group_size,
                                int batch_cap,
                                int task_avg_batch,
                                float eps) {
  int begin = taskId * task_avg_batch;
  int remain = std::min(batch - begin, task_avg_batch);
  if (remain <= 0 || __is_mpu()) {
    return;
  }
  // Split this task's slice into `loops` near-equal chunks of <= batch_cap.
  int loops = (remain + batch_cap - 1) / batch_cap;
  int chunk = (remain + loops - 1) / loops;
  for (int step = 0; step < loops; step++) {
    int cur = std::min(remain - step * chunk, chunk);
    fuseRopeImpl<T>(input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp, sin_table,
                    cos_table, rope_offsets, gamma, beta, key_scale_hp, value_scale_hp,
                    key_scale_lp, value_scale_lp, cache_bs_id_hp, cache_seq_offsets_hp,
                    cache_bs_id_lp, cache_seq_offsets_lp, slot_mapping_hp, slot_mapping_lp,
                    rotary_stride, cur, begin + step * chunk, head_num_q, head_num_k, head_size,
                    max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size,
                    chunk, eps);
  }
}
} // namespace kernels
// Host-side launcher for the fused RoPE kernel: derives the task dimension
// from the device topology, computes how many batches fit in NRAM per
// iteration (batch_cap), and dispatches MLUFuseRope on `queue` for half or
// bfloat16 input. Returns KERNEL_STATUS_FAILED on MLU300 devices, when the
// head configuration does not fit in NRAM, or for an unsupported dtype
// (previously an unsupported dtype silently returned SUCCESS without
// launching anything).
KernelStatus invokeFusedRope(cnrtQueue_t queue,
                             void *input,
                             void *key_cache_hp,
                             void *value_cache_hp,
                             void *key_cache_lp,
                             void *value_cache_lp,
                             const void *sin_table,
                             const void *cos_table,
                             const void *rope_offsets,
                             const void *gamma,
                             const void *beta,
                             const void *key_scale_hp,
                             const void *value_scale_hp,
                             void *key_scale_lp,
                             void *value_scale_lp,
                             const void *cache_bs_id_hp,
                             const void *cache_seq_offsets_hp,
                             const void *cache_bs_id_lp,
                             const void *cache_seq_offsets_lp,
                             const void *slot_mapping_hp,
                             const void *slot_mapping_lp,
                             int rotary_stride,
                             int batch_size,
                             int head_num_q,
                             int head_num_kv,
                             int head_size,
                             int max_decode_len_hp,
                             int max_decode_len_lp,
                             int block_size_hp,
                             int block_size_lp,
                             int group_size,
                             cnnlDataType_t dtype,
                             float eps) {
  if (is_arch300()) {
    std::cerr << "[invokeFusedRope]: kernel does not support MLU300 devices." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  uint32_t taskdimx = cluster_num * core_num;
  int task_avg_batch = (batch_size + taskdimx - 1) / taskdimx;
  int float_size = sizeof(float);
  bool quant_kv_hp = key_scale_hp != nullptr && value_scale_hp != nullptr;
  bool mixed_cache = key_cache_lp != nullptr && value_cache_lp != nullptr;
  // group_size only matters for mixed cache; guard the division so an
  // hp-only call with group_size == 0 cannot fault. (Previously this divided
  // unconditionally.)
  int group_num = group_size > 0 ? head_size / group_size : 1;
  // Must stay in sync with NRAM_BUFFER_SIZE in the kernel TU.
  int nram_available_bytes = 480 * 1024;
  int task_max_batch = 32;
  // Per-chunk NRAM budget, mirroring the pointer carve-up in fuseRopeImpl.
  int mask_bytes = PAD_UP(task_max_batch * head_num_kv, 8) / 8 * (mixed_cache + 1);
  int nram_params_bytes = 2 * head_size * float_size;
  int nram_kv_hp_scale_bytes = 2 * (int)quant_kv_hp * head_num_kv * head_size * float_size;
  int nram_remain_bytes =
      nram_available_bytes - nram_params_bytes - nram_kv_hp_scale_bytes - mask_bytes;
  // Everything below scales linearly with the number of batches resident.
  int nram_qk_bytes = (head_num_q + head_num_kv) * head_size * float_size * 2;
  int nram_v_bytes = head_num_kv * head_size * float_size * (mixed_cache + 1);
  int nram_table_bytes = 2 * head_size * float_size;
  int nram_kv_lp_scale_bytes = 2 * (int)mixed_cache * head_num_kv * group_num * float_size;
  int nram_kv_hp_bytes = (int)quant_kv_hp * head_num_kv * head_size;
  int nram_kv_lp_bytes = (int)mixed_cache * head_num_kv * head_size;
  int nram_cache_v_bytes = (int)mixed_cache * head_num_kv * head_size * 2;
  int nram_cache_offsets_hp = head_num_kv * sizeof(int);
  int nram_cache_offsets_lp = (int)mixed_cache * head_num_kv * 3 * sizeof(int);
  int batch_cap =
      nram_remain_bytes /
      (nram_qk_bytes + nram_v_bytes + nram_table_bytes + nram_kv_lp_scale_bytes + nram_kv_hp_bytes +
       nram_kv_lp_bytes + nram_cache_v_bytes + nram_cache_offsets_hp + nram_cache_offsets_lp);
  batch_cap = batch_cap < task_avg_batch ? std::min(task_max_batch, batch_cap) : task_avg_batch;
  // A zero batch_cap would divide by zero inside the kernel's chunking loop.
  if (batch_cap < 1) {
    std::cerr << "[invokeFusedRope]: head configuration exceeds NRAM capacity." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  cnrtDim3_t dim{taskdimx, 1, 1};
  if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUFuseRope<<<dim, cnrtFuncTypeBlock, queue>>>(
        (half *)input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp,
        (half *)sin_table, (half *)cos_table, (int *)rope_offsets, (half *)gamma, (half *)beta,
        (float *)key_scale_hp, (float *)value_scale_hp, (float *)key_scale_lp,
        (float *)value_scale_lp, (int *)cache_bs_id_hp, (int *)cache_seq_offsets_hp,
        (int *)cache_bs_id_lp, (int *)cache_seq_offsets_lp, (int *)slot_mapping_hp,
        (int *)slot_mapping_lp, rotary_stride, batch_size, head_num_q, head_num_kv, head_size,
        max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size, batch_cap,
        task_avg_batch, eps);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    kernels::MLUFuseRope<<<dim, cnrtFuncTypeBlock, queue>>>(
        (bf16 *)input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp,
        (bf16 *)sin_table, (bf16 *)cos_table, (int *)rope_offsets, (bf16 *)gamma, (bf16 *)beta,
        (float *)key_scale_hp, (float *)value_scale_hp, (float *)key_scale_lp,
        (float *)value_scale_lp, (int *)cache_bs_id_hp, (int *)cache_seq_offsets_hp,
        (int *)cache_bs_id_lp, (int *)cache_seq_offsets_lp, (int *)slot_mapping_hp,
        (int *)slot_mapping_lp, rotary_stride, batch_size, head_num_q, head_num_kv, head_size,
        max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size, batch_cap,
        task_avg_batch, eps);
  } else {
    std::cerr << "[invokeFusedRope]: unsupported dtype, only half/bfloat16 are supported."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo