/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "cnnl.h"
#include "cnrt.h"

#include "fused_rope.mluh"

// clang-format off
#include <algorithm>
#include <cmath>
#include <iostream>
#include <type_traits>
// clang-format on

using bf16 = bfloat16_t;

namespace tmo {
namespace kernels {

#ifndef PAD_UP
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
#endif

#if __BANG_ARCH__ > 500
#include
template using bang_cycle_fusor = bang::experimental::cycle_fusor;
#endif

#define NRAM_BUFFER_SIZE (480 * 1024)
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__nram__ float nram_mask[256];
__nram__ int nram_rope_offsets[256];
__nram__ float nram_zeros[1024] = {0.f};

template <typename T>
__mlu_func__ void toFloat(float *dst, T *src, int num) {
  if (std::is_same<T, half>::value) {
    __bang_half2float(dst, (half *)src, num);
  } else if (std::is_same<T, bf16>::value) {
#if __BANG_ARCH__ > 500
    __bang_bfloat162float(dst, (bf16 *)src, num);
#endif
  }
}

template <typename T>
__mlu_func__ void floatTo(T *dst, float *src, int num) {
  if (std::is_same<T, half>::value) {
    __bang_float2half_rn((half *)dst, src, num);
  } else if (std::is_same<T, bf16>::value) {
    __bang_float2bfloat16_rn((bf16 *)dst, src, num);
  }
}
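// Worked example for the paged-cache offset arithmetic below (hypothetical sizes, not taken
// from any real configuration): with block_size = 16, head_size = 128, head_num_k = 8 and
// slot_mapping[i] = 37, the code computes block_idx = 2 and seq_idx = 5, so the key row of
// head j starts at element ((2 * 8 + j) * 16 + 5) * 128. That element offset is then scaled
// by kv_out_size (bytes per cache element) and divided by (mixed_cache + 1), because the int4
// low-precision cache packs two elements per byte.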
__mlu_func__ void genScatterOffsetMask(
    int *cache_bs_id_begin, int *cache_seq_offsets_begin, int *slot_mapping_begin,
    int *nram_k_cache_offsets, int *nram_v_cache_offsets, int *nram_v_onchip_offsets,
    int *nram_kv_scale_offsets, float *nram_cache_mask, float *nram_zeros, float *nram_temp,
    int task_deal_batch, int task_begin_batch, int head_num_k, int head_size,
    int max_decode_len, int block_size, int kv_out_size, int group_num, bool discrete_batch,
    bool paged_cache, bool mixed_cache) {
  // For now the offsets are computed with scalar code, which is easier to follow
  // (no impact on performance).
  int bh = task_deal_batch * head_num_k;
  if (paged_cache) {
    int cache_seq_stride = head_size;
    int cache_head_stride = block_size * head_size;
    int cache_scale_head_stride = block_size * group_num;
    int cache_block_stride = head_num_k * cache_head_stride;
    int cache_scale_block_stride = head_num_k * cache_scale_head_stride;
    int *nram_slot_mapping = (int *)nram_mask;
    __memcpy(nram_slot_mapping, slot_mapping_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    for (int i = 0; i < task_deal_batch; i++) {
      int mapping_idx = __load_nram(nram_slot_mapping + i);
      if (mapping_idx < 0) {
        __bang_write_value(nram_k_cache_offsets + i * head_num_k, head_num_k, (int)-1);
        continue;
      }
      int block_idx = mapping_idx / block_size;
      int seq_idx = mapping_idx % block_size;
      int k_seq_offset = block_idx * cache_block_stride + seq_idx * cache_seq_stride;
      int v_seq_offset = block_idx * cache_block_stride / 2 + seq_idx / 2 * cache_seq_stride;
      int scale_seq_offset = block_idx * cache_scale_block_stride + seq_idx * group_num;
      int onchip_offset = i * head_num_k * head_size + seq_idx % 2 * bh * head_size;
      for (int j = 0; j < head_num_k; j++) {
        __store_nram(nram_k_cache_offsets + i * head_num_k + j,
                     (int)((k_seq_offset + j * cache_head_stride) * kv_out_size /
                           (mixed_cache + 1)));
        if (mixed_cache) {
          __store_nram(nram_v_cache_offsets + i * head_num_k + j,
                       (int)(v_seq_offset + j * cache_head_stride / 2));
          __store_nram(nram_kv_scale_offsets + i * head_num_k + j,
                       (int)((scale_seq_offset + j * cache_scale_head_stride) * sizeof(float)));
          __store_nram(nram_v_onchip_offsets + i * head_num_k + j,
                       onchip_offset + j * head_size);
        }
      }
    }
  } else {
    int *nram_seq_offsets = (int *)nram_mask;
    int *nram_bs_id = nram_seq_offsets + 32;
    int cache_seq_stride = head_size;
    int cache_head_stride = max_decode_len * head_size;
    int cache_scale_head_stride = max_decode_len * group_num;
    int cache_bs_stride = head_num_k * cache_head_stride;
    int cache_scale_bs_stride = head_num_k * cache_scale_head_stride;
    __memcpy(nram_seq_offsets, cache_seq_offsets_begin, task_deal_batch * sizeof(int),
             GDRAM2NRAM);
    if (discrete_batch) {
      __memcpy(nram_bs_id, cache_bs_id_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    }
    for (int i = 0; i < task_deal_batch; i++) {
      int bs_idx = __load_nram(nram_bs_id + i);
      int seq_idx = __load_nram(nram_seq_offsets + i);
      int temp_bs_idx = discrete_batch ? bs_idx : task_begin_batch + i;
      int temp_seq_idx = seq_idx;
      bool masked = temp_bs_idx < 0 || temp_seq_idx < 0;
      if (masked) {
        __bang_write_value(nram_k_cache_offsets + i * head_num_k, head_num_k, (int)-1);
        continue;
      }
      int k_seq_offset = temp_bs_idx * cache_bs_stride + temp_seq_idx * cache_seq_stride;
      int scale_seq_offset = temp_bs_idx * cache_scale_bs_stride + temp_seq_idx * group_num;
      int v_seq_offset = temp_bs_idx * cache_bs_stride / 2 + temp_seq_idx / 2 * cache_seq_stride;
      int onchip_offset = i * head_num_k * head_size + temp_seq_idx % 2 * bh * head_size;
      for (int j = 0; j < head_num_k; j++) {
        __store_nram(nram_k_cache_offsets + i * head_num_k + j,
                     (int)((k_seq_offset + j * cache_head_stride) * kv_out_size /
                           (mixed_cache + 1)));
        if (mixed_cache) {
          __store_nram(nram_v_cache_offsets + i * head_num_k + j,
                       (int)(v_seq_offset + j * cache_head_stride / 2));
          __store_nram(nram_kv_scale_offsets + i * head_num_k + j,
                       (int)((scale_seq_offset + j * cache_scale_head_stride) * sizeof(float)));
          __store_nram(nram_v_onchip_offsets + i * head_num_k + j,
                       onchip_offset + j * head_size);
        }
      }
    }
  }
  // Build the mask for the scatter instructions: entries whose batch offset or sequence
  // offset is negative must be masked out.
  __bang_int322float(nram_temp, nram_k_cache_offsets, bh, 0);
  __bang_ge_bitindex(nram_cache_mask, nram_temp, nram_zeros, PAD_UP(bh, 8));
}

__mlu_func__ void layernormImpl(float *nram_k, float *norm_params, int task_deal_batch,
                                int k_hidden, float eps) {
#if __BANG_ARCH__ > 500
  float *buffer = nram_k + task_deal_batch * k_hidden;
  for (int i = 0; i < task_deal_batch; i++) {
    float *k_ = nram_k + i * k_hidden;
    __bang_mul(buffer, k_, k_, k_hidden);
    // mean = E[x], var = E[x^2] - E[x]^2; guard against a slightly negative variance.
    float mean = __bang_sum(k_, k_hidden);
    mean = mean / k_hidden;
    float rstd = __bang_sum(buffer, k_hidden);
    rstd = rstd / k_hidden - mean * mean;
    rstd = rstd < 0 ? eps : rstd + eps;
    rstd = 1.f / std::sqrt(rstd);
    // k = (k - mean) * rstd
    __bang_fusion(FUSION_FSM, k_, k_, mean, rstd, k_hidden);
  }
  // Apply gamma and beta cyclically per row: k = k_norm * gamma + beta.
  __bang_fusion(FUSION_FMA, nram_k, nram_k, norm_params, norm_params + k_hidden,
                task_deal_batch * k_hidden, k_hidden);
#endif
}
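// The next helper applies rotary embedding in the rotate-half form. The caller loads cos into
// the first task_deal_batch * head_size floats of nram_table and sin right after it, builds
// nram_qk_rot by swapping the two halves of each head, and fills nram_mask with -1 for the
// first half and +1 for the second half. The fold therefore computes
//   out_lo = qk_lo * cos - qk_hi * sin
//   out_hi = qk_hi * cos + qk_lo * sin
// i.e. qk * cos + mask * swap(qk) * sin.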
__mlu_func__ void foldRotaryImpl(float *nram_qk, float *nram_qk_rot, float *nram_table,
                                 int task_deal_batch, int head_num_qk, int head_size) {
  int rotary_low_dim = task_deal_batch * head_size;
  __bang_cycle_mul(nram_qk, nram_qk, nram_table, head_num_qk * rotary_low_dim, rotary_low_dim);
  __bang_cycle_mul(nram_qk_rot, nram_qk_rot, nram_mask, head_num_qk * rotary_low_dim,
                   head_size);
  __bang_cycle_mul(nram_qk_rot, nram_qk_rot, nram_table + task_deal_batch * head_size,
                   head_num_qk * rotary_low_dim, rotary_low_dim);
  __bang_add(nram_qk, nram_qk, nram_qk_rot, head_num_qk * rotary_low_dim);
}

template <typename T>
__mlu_func__ void quantify(T *input, float *float_input, void *output_hp, void *output_lp,
                           float *nram_trans, float *scale_hp, float *scale_lp,
                           float *scale_lp_temp, int batch, int head_num, int head_size,
                           int group_num, int group_size, bool quant_kv_hp, bool mixed_cache) {
  if (quant_kv_hp) {
    int hidden = head_num * head_size;
    int bh = batch * head_num;
    toFloat(float_input, input, batch * hidden);
    __bang_recip(scale_hp, scale_hp, hidden);
#if __BANG_ARCH__ > 500
    __asm__ __volatile__(
        "fuse.nram.crn.s8.f32 "
        "[%[dst]], %[num_long], %[num_short], [%[src0]], .mul.cycle([%[src1]]), .dstpos(%[pos])"
        ";\n\t" ::[dst] "r"(output_hp),
        [num_long] "r"(batch * hidden), [num_short] "r"(hidden), [src0] "r"(float_input),
        [src1] "r"(scale_hp), [pos] "i"(0));
#endif
    if (mixed_cache) {
      __bang_transpose(nram_trans, float_input, bh * group_num, group_size);
      __bang_abs(float_input, nram_trans, bh * head_size);
      __bang_maxpool(scale_lp_temp, float_input, bh * group_num, group_size, 1, group_size, 1,
                     1, 1);
      __bang_mul_scalar(scale_lp, scale_lp_temp, 1 / 7.f, bh * group_num);
      __bang_recip(scale_lp_temp, scale_lp, bh * group_num);
      __bang_cycle_mul(nram_trans, nram_trans, scale_lp_temp, bh * group_num * group_size,
                       bh * group_num);
      __bang_float2int8_rn((int8_t *)nram_trans, nram_trans, bh * head_size, 0);
      __bang_transpose((int8_t *)output_lp, (int8_t *)nram_trans, group_size, bh * group_num);
    }
  }
}
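// Quantization summary for the helper above: the hp path multiplies k/v by the reciprocal
// per-channel scale and converts to int8 in a single fused instruction. The lp path (mixed
// cache) regroups the data to [batch * head_num * group_num, group_size], takes the per-group
// scale as max|x| / 7 so that rounded values stay within the symmetric int4 range [-7, 7],
// divides by the scale, rounds to int8, and transposes back to the original layout; the int8
// result is packed into int4 pairs later by the caller.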
template <typename T>
__mlu_func__ void fuseRopeImpl(
    T *input, void *key_cache_hp, void *value_cache_hp, void *key_cache_lp, void *value_cache_lp,
    T *sin_table, T *cos_table, int *rope_offsets, T *gamma, T *beta, float *key_scale_hp,
    float *value_scale_hp, float *key_scale_lp, float *value_scale_lp, int *cache_bs_id_hp,
    int *cache_seq_offsets_hp, int *cache_bs_id_lp, int *cache_seq_offsets_lp,
    int *slot_mapping_hp, int *slot_mapping_lp, int rotary_stride, int task_deal_batch,
    int task_begin_batch, int head_num_q, int head_num_k, int head_size, int max_decode_len_hp,
    int max_decode_len_lp, int block_size_hp, int block_size_lp, int group_size, int batch_cap,
    float eps) {
#if __BANG_ARCH__ > 500
  /* Because mixed cache must be supported, the kernel could combine cache features in many
     ways. Only the following configurations are allowed:
     1. hp cache only (detected by the lp tensors being null). The cache supports bf16 and
        fp16, and offline per-channel int8 when quantized; both linear and paged layouts are
        supported. The key and value caches share the same shape, and key/value_scale_hp has
        shape [head_num, head_size].
     2. mixed cache. The hp cache supports offline per-channel int8 quantization, linear and
        paged layouts, and identical key/value shapes; key/value_scale_hp has shape
        [head_num, head_size]. The lp cache uses online per-token int4 group quantization:
        key_cache has shape [batch, head_num_k, max_decode_len_lp, head_size / 2] (the paged
        layout also uses head_size / 2); value_cache has shape
        [batch, head_num_k, max_decode_len_lp / 2, head_size], and its paged form is
        [num_blocks, head_num_k, block_size / 2, head_size]. key/value_scale_lp has shape
        [batch, head_num_k, max_decode_len_lp, group_num], and its paged form is
        [num_blocks, head_num_k, block_size, group_num]. */
  bool mixed_cache = key_cache_lp != nullptr && value_cache_lp != nullptr;
  bool quant_kv_hp = key_scale_hp != nullptr && value_scale_hp != nullptr;
  bool discrete_batch_hp = cache_bs_id_hp != nullptr;
  bool discrete_batch_lp = cache_bs_id_lp != nullptr;
  bool paged_cache_hp = slot_mapping_hp != nullptr;
  bool paged_cache_lp = slot_mapping_lp != nullptr;
  int head_num_qk = head_num_q + head_num_k;
  int head_num_qkv = head_num_q + head_num_k * 2;
  int qkv_hidden = head_num_qkv * head_size;
  int qk_hidden = head_num_qk * head_size;
  int q_hidden = head_num_q * head_size;
  int k_hidden = head_num_k * head_size;
  int float_size = sizeof(float);
  int dtype_size = sizeof(T);
  int kv_size_hp = quant_kv_hp ? sizeof(int8_t) : dtype_size;
  int group_num = mixed_cache ? head_size / group_size : 1;
  // task ddr offset
  T *input_begin = input + task_begin_batch * qkv_hidden;
  int *cache_bs_id_begin_hp = cache_bs_id_hp + task_begin_batch;
  int *cache_seq_offsets_begin_hp = cache_seq_offsets_hp + task_begin_batch;
  int *slot_mapping_begin_hp = slot_mapping_hp + task_begin_batch;
  // nram_buffer
  float *nram_qk = (float *)nram_buffer;
  float *nram_qk_rot = nram_qk + batch_cap * qk_hidden;
  float *nram_v = nram_qk_rot + batch_cap * qk_hidden;
  float *nram_kv_trans = nram_v + batch_cap * k_hidden;
  float *nram_table = nram_kv_trans + (int)mixed_cache * batch_cap * k_hidden;
  float *norm_params = nram_table + 2 * batch_cap * head_size;
  float *nram_k_scale_hp = norm_params + 2 * head_size;
  float *nram_v_scale_hp = nram_k_scale_hp + (int)quant_kv_hp * k_hidden;
  float *nram_k_scale_lp = nram_v_scale_hp + (int)quant_kv_hp * k_hidden;
  float *nram_v_scale_lp =
      nram_k_scale_lp + (int)mixed_cache * batch_cap * head_num_k * group_num;
  int8_t *nram_kv_hp =
      (int8_t *)(nram_v_scale_lp + (int)mixed_cache * batch_cap * head_num_k * group_num);
  int8_t *nram_kv_lp = nram_kv_hp + (int)quant_kv_hp * batch_cap * k_hidden;
  int8_t *nram_cache_v = nram_kv_lp + (int)mixed_cache * batch_cap * k_hidden;
  int *nram_kv_cache_offsets_hp =
      (int *)(nram_cache_v + (int)mixed_cache * batch_cap * k_hidden * 2);
  int *nram_k_cache_offsets_lp = nram_kv_cache_offsets_hp + batch_cap * head_num_k;
  int *nram_v_cache_offsets_lp =
      nram_k_cache_offsets_lp + (int)mixed_cache * batch_cap * head_num_k;
  int *nram_kv_scale_offsets =
      nram_v_cache_offsets_lp + (int)mixed_cache * batch_cap * head_num_k;
  int *nram_v_onchip_offsets =
      nram_kv_scale_offsets + (int)mixed_cache * batch_cap * head_num_k;
  float *cache_mask_hp =
      (float *)(nram_v_onchip_offsets + (int)mixed_cache * batch_cap * head_num_k);
  float *cache_mask_lp =
      (float *)((int8_t *)cache_mask_hp + PAD_UP(batch_cap * head_num_k, 8) / 8);
  // qk and qk_rot are laid out together so the widening (dtype -> float) conversion can be
  // done in one instruction, saving instructions; the same holds for the sin/cos table and
  // the layernorm gamma/beta.
  T *qk_in = (T *)nram_qk_rot;
  T *qk_rot_in =
      (T *)((int8_t *)nram_qk_rot + (float_size - dtype_size) * batch_cap * qk_hidden);
  T *v_in = (T *)((int8_t *)nram_v +
                  (int)quant_kv_hp * (float_size - dtype_size) * batch_cap * k_hidden);
  T *norm_params_in = (T *)((int8_t *)norm_params + (float_size - dtype_size) * 2 * head_size);
  T *table_in =
      (T *)((int8_t *)nram_table + (float_size - dtype_size) * 2 * batch_cap * head_size);
  int8_t *nram_cache_v_in = nram_cache_v + batch_cap * k_hidden * sizeof(int8_t);
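  // Note on the *_in staging pointers: the raw half/bf16 data is placed
  // (float_size - dtype_size) * count bytes into the float destination region, i.e. in its
  // upper part, so that the in-place widening done by toFloat can write float results from
  // the start of the same region without overwriting input that has not been read yet.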
  // Generate the kv-cache offsets and mask used to scatter k/v into the kv cache.
  genScatterOffsetMask(cache_bs_id_begin_hp, cache_seq_offsets_begin_hp, slot_mapping_begin_hp,
                       nram_kv_cache_offsets_hp, nullptr, nullptr, nullptr, cache_mask_hp,
                       nram_zeros, nram_qk, task_deal_batch, task_begin_batch, head_num_k,
                       head_size, max_decode_len_hp, block_size_hp, kv_size_hp, 1,
                       discrete_batch_hp, paged_cache_hp, false);
  if (mixed_cache) {
    int *cache_bs_id_begin_lp = cache_bs_id_lp + task_begin_batch;
    int *cache_seq_offsets_begin_lp = cache_seq_offsets_lp + task_begin_batch;
    int *slot_mapping_begin_lp = slot_mapping_lp + task_begin_batch;
    genScatterOffsetMask(cache_bs_id_begin_lp, cache_seq_offsets_begin_lp,
                         slot_mapping_begin_lp, nram_k_cache_offsets_lp,
                         nram_v_cache_offsets_lp, nram_v_onchip_offsets, nram_kv_scale_offsets,
                         cache_mask_lp, nram_zeros, nram_qk, task_deal_batch, task_begin_batch,
                         head_num_k, head_size, max_decode_len_lp, block_size_lp, 1, group_num,
                         discrete_batch_lp, paged_cache_lp, mixed_cache);
  }
  /* Pipeline (IO stage | compute stage):
     -----------------------
     load v       |
     -----------------------
     load qk      | quant v
     -----------------------
     store v      | rope qk
     -----------------------
     store q      | layernorm k
                  | quant k
     -----------------------
     store k      |
     ----------------------- */
  // prepare v, v_scale, cache_v, rope_offset
  __memcpy_async(v_in, input_begin + qk_hidden, k_hidden * dtype_size, GDRAM2NRAM,
                 k_hidden * dtype_size, qkv_hidden * dtype_size, task_deal_batch - 1);
  if (quant_kv_hp) {
    __memcpy_async(nram_k_scale_hp, key_scale_hp, k_hidden * float_size, GDRAM2NRAM);
    __memcpy_async(nram_v_scale_hp, value_scale_hp, k_hidden * float_size, GDRAM2NRAM);
  }
  __memcpy_async(nram_rope_offsets, rope_offsets + task_begin_batch,
                 task_deal_batch * sizeof(int), GDRAM2NRAM);
  __sync_io();
  if (mixed_cache) {
    __gather(nram_cache_v_in, value_cache_lp, (uint32_t *)nram_v_cache_offsets_lp,
             cache_mask_lp, head_size * sizeof(int8_t), GDRAM2NRAM, head_size * sizeof(int8_t),
             task_deal_batch * head_num_k);
  }
  __bang_mul_scalar(nram_rope_offsets, nram_rope_offsets, rotary_stride * dtype_size,
                    task_deal_batch);
  __sync_compute();
  /*==============================================================================================*/
  // load qk, rope_table | quant v
  __memcpy_async(qk_in, input_begin, head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size,
                 task_deal_batch - 1, task_deal_batch * head_size * dtype_size, head_num_qk - 1,
                 qkv_hidden * dtype_size, task_deal_batch - 1, head_size * dtype_size,
                 head_num_qk - 1);
  __gather_async(table_in, cos_table, (uint32_t *)nram_rope_offsets, head_size * dtype_size,
                 GDRAM2NRAM, head_size * dtype_size, task_deal_batch);
  __gather_async(table_in + task_deal_batch * head_size, sin_table,
                 (uint32_t *)nram_rope_offsets, head_size * dtype_size, GDRAM2NRAM,
                 head_size * dtype_size, task_deal_batch);
  __memcpy_async(norm_params_in, gamma, head_size * dtype_size, GDRAM2NRAM);
  __memcpy_async(norm_params_in + head_size, beta, head_size * dtype_size, GDRAM2NRAM);
  int8_t *nram_temp = (int8_t *)nram_qk;
  if (mixed_cache) {
    __bang_int42int8(nram_cache_v, (int4x2_t *)nram_cache_v_in, task_deal_batch * k_hidden * 2,
                     0, 0);
    __bang_transpose(nram_temp, nram_cache_v, task_deal_batch * k_hidden, 2);
  }
  quantify(v_in, nram_v, nram_kv_hp, nram_kv_lp, nram_kv_trans, nram_v_scale_hp,
           nram_v_scale_lp, nram_k_scale_lp /*lp_scale temp*/, task_deal_batch, head_num_k,
           head_size, group_num, group_size, quant_kv_hp, mixed_cache);
  if (mixed_cache) {
    __scatter(nram_temp, nram_kv_lp, (uint32_t *)nram_v_onchip_offsets, cache_mask_lp,
              head_size * sizeof(int8_t), NRAM2NRAM, head_size * sizeof(int8_t),
              task_deal_batch * head_num_k);
    __bang_transpose(nram_cache_v, nram_temp, 2, task_deal_batch * k_hidden);
    __bang_int82int4_rn((int4x2_t *)nram_cache_v, nram_cache_v, task_deal_batch * k_hidden * 2,
                        0, 0);
  }
  __sync_io_move_compute();
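  // The mixed-cache block above is a read-modify-write of the int4 value cache: each int4x2
  // byte holds the values of two consecutive tokens, so the packed rows are gathered,
  // unpacked to int8, split into even/odd token planes by the transpose, the freshly
  // quantized v rows are scattered into their (seq_idx % 2) plane on chip, and the planes are
  // re-interleaved and repacked to int4 before being written back to GDRAM in the next stage.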
  /*==============================================================================================*/
  // rope | store v
  // Swap the left and right halves of qk to build qk_rot.
  __memcpy(qk_rot_in, qk_in + head_size / 2, head_size / 2 * dtype_size, NRAM2NRAM,
           head_size * dtype_size, head_size * dtype_size, task_deal_batch * head_num_qk - 1);
  __memcpy(qk_rot_in + head_size / 2, qk_in, head_size / 2 * dtype_size, NRAM2NRAM,
           head_size * dtype_size, head_size * dtype_size, task_deal_batch * head_num_qk - 1);
  toFloat(nram_qk, qk_in, 2 * batch_cap * qk_hidden);
  toFloat(nram_table, table_in, 2 * task_deal_batch * head_size);
  toFloat(norm_params, norm_params_in, 2 * head_size);
  __bang_write_value(nram_mask, head_size / 2, (float)-1);
  __bang_write_value(nram_mask + head_size / 2, head_size / 2, (float)1);
  foldRotaryImpl(nram_qk, nram_qk_rot, nram_table, task_deal_batch, head_num_qk, head_size);
  floatTo((T *)nram_qk, nram_qk, task_deal_batch * q_hidden);
  int8_t *scatter_v_src = quant_kv_hp ? nram_kv_hp : (int8_t *)nram_v;
  __scatter_async(value_cache_hp, scatter_v_src, (uint32_t *)nram_kv_cache_offsets_hp,
                  cache_mask_hp, head_size * kv_size_hp, NRAM2GDRAM, head_size * kv_size_hp,
                  head_num_k * task_deal_batch);
  if (mixed_cache) {
    __scatter_async(value_cache_lp, nram_cache_v, (uint32_t *)nram_v_cache_offsets_lp,
                    cache_mask_lp, head_size * sizeof(int8_t), NRAM2GDRAM,
                    head_size * sizeof(int8_t), head_num_k * task_deal_batch);
    __scatter_async(value_scale_lp, nram_v_scale_lp, (uint32_t *)nram_kv_scale_offsets,
                    cache_mask_lp, group_num * sizeof(float), NRAM2GDRAM,
                    group_num * sizeof(float), head_num_k * task_deal_batch);
  }
  __sync_io_move_compute();
  /*==============================================================================================*/
  // layernorm k, quant k | store q
  // Extract k from the qk nram buffer for layernorm and quantization.
  float *nram_k = nram_qk_rot;
  __memcpy(nram_k, nram_qk + task_deal_batch * q_hidden, head_size * float_size, NRAM2NRAM,
           head_size * float_size, head_num_k - 1, k_hidden * float_size, task_deal_batch - 1,
           task_deal_batch * head_size * float_size, head_num_k - 1, head_size * float_size,
           task_deal_batch - 1);
  layernormImpl(nram_k, norm_params, task_deal_batch * head_num_k, head_size, eps);
  quantify(nram_k, nram_k, nram_kv_hp, nram_kv_lp, nram_kv_trans, nram_k_scale_hp,
           nram_k_scale_lp, nram_v_scale_lp /*lp_scale temp*/, task_deal_batch, head_num_k,
           head_size, group_num, group_size, quant_kv_hp, mixed_cache);
  if (mixed_cache) {
    __bang_int82int4_rn((int4x2_t *)nram_kv_lp, nram_kv_lp, task_deal_batch * k_hidden, 0, 0);
  }
  if (!quant_kv_hp) {
    floatTo((T *)nram_k, nram_k, task_deal_batch * k_hidden);
  }
  // store q
  __memcpy_async(input_begin, nram_qk, head_size * dtype_size, NRAM2GDRAM,
                 qkv_hidden * dtype_size, task_deal_batch - 1, head_size * dtype_size,
                 head_num_q - 1, head_size * dtype_size, task_deal_batch - 1,
                 task_deal_batch * head_size * dtype_size, head_num_q - 1);
  // ===============================================================================================
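  // Final stage (store k): scatter the layernormed (and, when enabled, quantized) k rows into
  // key_cache_hp, and in the mixed-cache case scatter the int4-packed k rows and their
  // per-group scales into the lp cache. Rows whose offset is negative are skipped via the bit
  // mask produced in genScatterOffsetMask.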
  int8_t *scatter_k_src = quant_kv_hp ? nram_kv_hp : (int8_t *)nram_k;
  __scatter(key_cache_hp, scatter_k_src, (uint32_t *)nram_kv_cache_offsets_hp, cache_mask_hp,
            head_size * kv_size_hp, NRAM2GDRAM, head_size * kv_size_hp,
            head_num_k * task_deal_batch);
  if (mixed_cache) {
    __scatter(key_cache_lp, nram_kv_lp, (uint32_t *)nram_k_cache_offsets_lp, cache_mask_lp,
              head_size / 2 * sizeof(int8_t), NRAM2GDRAM, head_size / 2 * sizeof(int8_t),
              head_num_k * task_deal_batch);
    __scatter(key_scale_lp, nram_k_scale_lp, (uint32_t *)nram_kv_scale_offsets, cache_mask_lp,
              group_num * sizeof(float), NRAM2GDRAM, group_num * sizeof(float),
              head_num_k * task_deal_batch);
  }
  __sync_io_move_compute();
#endif
}

template <typename T>
__mlu_global__ void MLUFuseRope(
    T *input, void *key_cache_hp, void *value_cache_hp, void *key_cache_lp, void *value_cache_lp,
    T *sin_table, T *cos_table, int *rope_offsets, T *gamma, T *beta, float *key_scale_hp,
    float *value_scale_hp, float *key_scale_lp, float *value_scale_lp, int *cache_bs_id_hp,
    int *cache_seq_offsets_hp, int *cache_bs_id_lp, int *cache_seq_offsets_lp,
    int *slot_mapping_hp, int *slot_mapping_lp, int rotary_stride, int batch, int head_num_q,
    int head_num_k, int head_size, int max_decode_len_hp, int max_decode_len_lp,
    int block_size_hp, int block_size_lp, int group_size, int batch_cap, int task_avg_batch,
    float eps) {
  int task_begin_batch = taskId * task_avg_batch;
  int task_deal_batch = std::min(batch - task_begin_batch, task_avg_batch);
  if (task_deal_batch <= 0 || __is_mpu()) {
    return;
  }
  int task_loop = (task_deal_batch + batch_cap - 1) / batch_cap;
  int once_batch = (task_deal_batch + task_loop - 1) / task_loop;
  for (int i = 0; i < task_loop; i++) {
    int cur_batch = std::min(task_deal_batch - i * once_batch, once_batch);
    int batch_offset = task_begin_batch + once_batch * i;
    fuseRopeImpl(input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp, sin_table,
                 cos_table, rope_offsets, gamma, beta, key_scale_hp, value_scale_hp,
                 key_scale_lp, value_scale_lp, cache_bs_id_hp, cache_seq_offsets_hp,
                 cache_bs_id_lp, cache_seq_offsets_lp, slot_mapping_hp, slot_mapping_lp,
                 rotary_stride, cur_batch, batch_offset, head_num_q, head_num_k, head_size,
                 max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size,
                 once_batch, eps);
  }
}

}  // namespace kernels
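// A hypothetical host-side invocation for the hp-only, linear (non-paged) cache case. All
// device buffers, the queue, and the scalar values below are assumptions for illustration and
// are not defined in this file; group_size is set to head_size so that group_num stays valid
// even though lp quantization is disabled.
//
//   tmo::KernelStatus st = tmo::invokeFusedRope(
//       queue, qkv_dev, key_cache_dev, value_cache_dev,
//       /*key_cache_lp=*/nullptr, /*value_cache_lp=*/nullptr,
//       sin_table_dev, cos_table_dev, rope_offsets_dev, gamma_dev, beta_dev,
//       /*key_scale_hp=*/nullptr, /*value_scale_hp=*/nullptr,
//       /*key_scale_lp=*/nullptr, /*value_scale_lp=*/nullptr,
//       /*cache_bs_id_hp=*/nullptr, cache_seq_offsets_dev,
//       /*cache_bs_id_lp=*/nullptr, /*cache_seq_offsets_lp=*/nullptr,
//       /*slot_mapping_hp=*/nullptr, /*slot_mapping_lp=*/nullptr,
//       /*rotary_stride=*/head_size, batch, head_num_q, head_num_kv, head_size,
//       max_decode_len, /*max_decode_len_lp=*/0, /*block_size_hp=*/0, /*block_size_lp=*/0,
//       /*group_size=*/head_size, CNNL_DTYPE_HALF, /*eps=*/1e-6f);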
KernelStatus invokeFusedRope(
    cnrtQueue_t queue, void *input, void *key_cache_hp, void *value_cache_hp,
    void *key_cache_lp, void *value_cache_lp, const void *sin_table, const void *cos_table,
    const void *rope_offsets, const void *gamma, const void *beta, const void *key_scale_hp,
    const void *value_scale_hp, void *key_scale_lp, void *value_scale_lp,
    const void *cache_bs_id_hp, const void *cache_seq_offsets_hp, const void *cache_bs_id_lp,
    const void *cache_seq_offsets_lp, const void *slot_mapping_hp, const void *slot_mapping_lp,
    int rotary_stride, int batch_size, int head_num_q, int head_num_kv, int head_size,
    int max_decode_len_hp, int max_decode_len_lp, int block_size_hp, int block_size_lp,
    int group_size, cnnlDataType_t dtype, float eps) {
  if (is_arch300()) {
    std::cerr << "[invokeFusedRope]: kernel does not support MLU300 devices." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  uint32_t taskdimx = cluster_num * core_num;
  int task_avg_batch = (batch_size + taskdimx - 1) / taskdimx;
  int float_size = sizeof(float);
  int group_num = head_size / group_size;
  bool quant_kv_hp = key_scale_hp != nullptr && value_scale_hp != nullptr;
  bool mixed_cache = key_cache_lp != nullptr && value_cache_lp != nullptr;
  int nram_available_bytes = 480 * 1024;
  int task_max_batch = 32;
  int mask_bytes = PAD_UP(task_max_batch * head_num_kv, 8) / 8 * (mixed_cache + 1);
  int nram_params_bytes = 2 * head_size * float_size;
  int nram_kv_hp_scale_bytes = 2 * (int)quant_kv_hp * head_num_kv * head_size * float_size;
  int nram_remain_bytes =
      nram_available_bytes - nram_params_bytes - nram_kv_hp_scale_bytes - mask_bytes;
  int nram_qk_bytes = (head_num_q + head_num_kv) * head_size * float_size * 2;
  int nram_v_bytes = head_num_kv * head_size * float_size * (mixed_cache + 1);
  int nram_table_bytes = 2 * head_size * float_size;
  int nram_kv_lp_scale_bytes = 2 * (int)mixed_cache * head_num_kv * group_num * float_size;
  int nram_kv_hp_bytes = (int)quant_kv_hp * head_num_kv * head_size;
  int nram_kv_lp_bytes = (int)mixed_cache * head_num_kv * head_size;
  int nram_cache_v_bytes = (int)mixed_cache * head_num_kv * head_size * 2;
  int nram_cache_offsets_hp = head_num_kv * sizeof(int);
  int nram_cache_offsets_lp = (int)mixed_cache * head_num_kv * 3 * sizeof(int);
  int batch_cap = nram_remain_bytes /
                  (nram_qk_bytes + nram_v_bytes + nram_table_bytes + nram_kv_lp_scale_bytes +
                   nram_kv_hp_bytes + nram_kv_lp_bytes + nram_cache_v_bytes +
                   nram_cache_offsets_hp + nram_cache_offsets_lp);
  batch_cap = batch_cap < task_avg_batch ? std::min(task_max_batch, batch_cap) : task_avg_batch;
  cnrtDim3_t dim{taskdimx, 1, 1};
  if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUFuseRope<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (half *)input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp,
        (half *)sin_table, (half *)cos_table, (int *)rope_offsets, (half *)gamma, (half *)beta,
        (float *)key_scale_hp, (float *)value_scale_hp, (float *)key_scale_lp,
        (float *)value_scale_lp, (int *)cache_bs_id_hp, (int *)cache_seq_offsets_hp,
        (int *)cache_bs_id_lp, (int *)cache_seq_offsets_lp, (int *)slot_mapping_hp,
        (int *)slot_mapping_lp, rotary_stride, batch_size, head_num_q, head_num_kv, head_size,
        max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size,
        batch_cap, task_avg_batch, eps);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    kernels::MLUFuseRope<<<dim, cnrtFuncTypeUnion1, queue>>>(
        (bf16 *)input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp,
        (bf16 *)sin_table, (bf16 *)cos_table, (int *)rope_offsets, (bf16 *)gamma, (bf16 *)beta,
        (float *)key_scale_hp, (float *)value_scale_hp, (float *)key_scale_lp,
        (float *)value_scale_lp, (int *)cache_bs_id_hp, (int *)cache_seq_offsets_hp,
        (int *)cache_bs_id_lp, (int *)cache_seq_offsets_lp, (int *)slot_mapping_hp,
        (int *)slot_mapping_lp, rotary_stride, batch_size, head_num_q, head_num_kv, head_size,
        max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size,
        batch_cap, task_avg_batch, eps);
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}

}  // namespace tmo