Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/fused_rope.mlu

659 lines
33 KiB
Plaintext
Raw Normal View History

2026-02-04 17:39:32 +08:00
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "cnnl.h"
#include "cnrt.h"
#include "fused_rope.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
using bf16 = bfloat16_t;
namespace tmo {
namespace kernels {
#ifndef PAD_UP
// Round x up to the nearest multiple of y (y > 0).
// BUGFIX: the quotient-plus-carry sum must be parenthesized before the
// multiply. The previous form, (x)/(y) + (int)((x)%(y) > 0) * (y), only
// multiplied the carry by y, so PAD_UP(10, 8) evaluated to 9 (not 16) and
// PAD_UP(8, 8) to 1 (not 8) — under-sizing the scatter-mask byte counts
// computed as PAD_UP(n, 8) / 8 below.
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
#endif
#if __BANG_ARCH__ > 500
#include <bang_fusor.h>
template <typename T>
using bang_cycle_fusor = bang::experimental::cycle_fusor<T>;
#endif
// Per-core NRAM scratch shared by the kernels in this file.
#define NRAM_BUFFER_SIZE (480 * 1024)
// Main working buffer; fuseRopeImpl carves all of its tensors out of this.
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
// +/-1 sign pattern for the rotate-half trick (written in fuseRopeImpl,
// read in foldRotaryImpl); also reused as scalar int scratch inside
// genScatterOffsetMask.
__nram__ float nram_mask[256];
// Per-batch rope position ids, gathered from gdram and scaled to byte
// offsets into the sin/cos tables (capacity bounds the chunk batch size).
__nram__ int nram_rope_offsets[256];
// All-zero operand used by the >= comparison when building scatter masks.
__nram__ float nram_zeros[1024] = {0.f};
// Widen `num` elements of half/bf16 `src` into float `dst`.
// The bf16 path exists only on __BANG_ARCH__ > 500; on older arches a bf16
// instantiation is a no-op.
template <typename T>
__mlu_func__ void toFloat(float *dst, T *src, int num) {
  constexpr bool src_is_half = std::is_same<T, half>::value;
  if (src_is_half) {
    __bang_half2float(dst, (half *)src, num);
    return;
  }
  if (std::is_same<T, bf16>::value) {
#if __BANG_ARCH__ > 500
    __bang_bfloat162float(dst, (bf16 *)src, num);
#endif
  }
}
// Narrow `num` floats in `src` to half/bf16 in `dst`, round-to-nearest.
// T is expected to be half or bf16; any other T makes this a no-op.
template <typename T>
__mlu_func__ void floatTo(T *dst, float *src, int num) {
  constexpr bool dst_is_bf16 = std::is_same<T, bf16>::value;
  if (dst_is_bf16) {
    __bang_float2bfloat16_rn((bf16 *)dst, src, num);
  } else if (std::is_same<T, half>::value) {
    __bang_float2half_rn((half *)dst, src, num);
  }
}
// Builds, for each (batch, kv-head) pair of this task, the destination
// offsets used to scatter K/V rows into the KV cache, plus the bit mask
// consumed by the scatter/gather instructions (invalid rows get offset -1
// and a cleared mask bit).
//   paged_cache : offsets derive from slot_mapping (block id + in-block
//                 position); a negative slot marks an invalid/padded row.
//   linear cache: offsets derive from cache_seq_offsets (plus cache_bs_id
//                 when discrete_batch); negative bs/seq ids are masked.
//   mixed_cache : additionally emits lp value-cache offsets (the int4
//                 value cache packs two consecutive seq positions per row,
//                 hence the /2 arithmetic), per-group scale offsets, and
//                 on-chip staging offsets keyed by seq parity.
// K offsets are in BYTES: scaled by kv_out_size (element size) for the hp
// cache, and halved via /(mixed_cache + 1) for the int4 lp key cache.
__mlu_func__ void genScatterOffsetMask(int *cache_bs_id_begin,
                                       int *cache_seq_offsets_begin,
                                       int *slot_mapping_begin,
                                       int *nram_k_cache_offsets,
                                       int *nram_v_cache_offsets,
                                       int *nram_v_onchip_offsets,
                                       int *nram_kv_scale_offsets,
                                       float *nram_cache_mask,
                                       float *nram_zeros,
                                       float *nram_temp,
                                       int task_deal_batch,
                                       int task_begin_batch,
                                       int head_num_k,
                                       int head_size,
                                       int max_decode_len,
                                       int block_size,
                                       int kv_out_size,
                                       int group_num,
                                       bool discrete_batch,
                                       bool paged_cache,
                                       bool mixed_cache) {
  // Offsets are computed with scalar code for clarity (no measurable
  // performance impact).
  int bh = task_deal_batch * head_num_k;
  if (paged_cache) {
    // Paged layout strides, in elements: cache is
    // [num_blocks, head_num_k, block_size, head_size].
    int cache_seq_stride = head_size;
    int cache_head_stride = block_size * head_size;
    int cache_scale_head_stride = block_size * group_num;
    int cache_block_stride = head_num_k * cache_head_stride;
    int cache_scale_block_stride = head_num_k * cache_scale_head_stride;
    // Reuse the global nram_mask area as scratch for the slot table.
    int *nram_slot_mapping = (int *)nram_mask;
    __memcpy(nram_slot_mapping, slot_mapping_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    for (int i = 0; i < task_deal_batch; i++) {
      int mapping_idx = __load_nram(nram_slot_mapping + i);
      if (mapping_idx < 0) {
        // Invalid slot: mark all heads of this batch with -1 so the mask
        // built below clears their bits.
        __bang_write_value(nram_k_cache_offsets + i * head_num_k, head_num_k, (int)-1);
        continue;
      }
      int block_idx = mapping_idx / block_size;
      int seq_idx = mapping_idx % block_size;
      int k_seq_offset = block_idx * cache_block_stride + seq_idx * cache_seq_stride;
      // lp value cache packs 2 seq positions per row -> half-stride math.
      int v_seq_offset = block_idx * cache_block_stride / 2 + seq_idx / 2 * cache_seq_stride;
      int scale_seq_offset = block_idx * cache_scale_block_stride + seq_idx * group_num;
      // On-chip staging slot: plane selected by seq parity (seq_idx % 2).
      int onchip_offset = i * head_num_k * head_size + seq_idx % 2 * bh * head_size;
      for (int j = 0; j < head_num_k; j++) {
        __store_nram(
            nram_k_cache_offsets + i * head_num_k + j,
            (int)((k_seq_offset + j * cache_head_stride) * kv_out_size / (mixed_cache + 1)));
        if (mixed_cache) {
          __store_nram(nram_v_cache_offsets + i * head_num_k + j,
                       (int)(v_seq_offset + j * cache_head_stride / 2));
          __store_nram(nram_kv_scale_offsets + i * head_num_k + j,
                       (int)((scale_seq_offset + j * cache_scale_head_stride) * sizeof(float)));
          __store_nram(nram_v_onchip_offsets + i * head_num_k + j, onchip_offset + j * head_size);
        }
      }
    }
  } else {
    // Linear layout strides, in elements: cache is
    // [batch, head_num_k, max_decode_len, head_size].
    int *nram_seq_offsets = (int *)nram_mask;
    int *nram_bs_id = nram_seq_offsets + 32;
    int cache_seq_stride = head_size;
    int cache_head_stride = max_decode_len * head_size;
    int cache_scale_head_stride = max_decode_len * group_num;
    int cache_bs_stride = head_num_k * cache_head_stride;
    int cache_scale_bs_stride = head_num_k * cache_scale_head_stride;
    __memcpy(nram_seq_offsets, cache_seq_offsets_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    if (discrete_batch) {
      __memcpy(nram_bs_id, cache_bs_id_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    }
    for (int i = 0; i < task_deal_batch; i++) {
      // NOTE(review): nram_bs_id is read even when discrete_batch is false
      // (the buffer was not loaded); the value is discarded by the ternary
      // below, but the read touches uninitialized nram — confirm harmless.
      int bs_idx = __load_nram(nram_bs_id + i);
      int seq_idx = __load_nram(nram_seq_offsets + i);
      int temp_bs_idx = discrete_batch ? bs_idx : task_begin_batch + i;
      int temp_seq_idx = seq_idx;
      bool masked = temp_bs_idx < 0 || temp_seq_idx < 0;
      if (masked) {
        __bang_write_value(nram_k_cache_offsets + i * head_num_k, head_num_k, (int)-1);
        continue;
      }
      int k_seq_offset = temp_bs_idx * cache_bs_stride + temp_seq_idx * cache_seq_stride;
      int scale_seq_offset = temp_bs_idx * cache_scale_bs_stride + temp_seq_idx * group_num;
      int v_seq_offset = temp_bs_idx * cache_bs_stride / 2 + temp_seq_idx / 2 * cache_seq_stride;
      int onchip_offset = i * head_num_k * head_size + temp_seq_idx % 2 * bh * head_size;
      for (int j = 0; j < head_num_k; j++) {
        __store_nram(
            nram_k_cache_offsets + i * head_num_k + j,
            (int)((k_seq_offset + j * cache_head_stride) * kv_out_size / (mixed_cache + 1)));
        if (mixed_cache) {
          __store_nram(nram_v_cache_offsets + i * head_num_k + j,
                       (int)(v_seq_offset + j * cache_head_stride / 2));
          __store_nram(nram_kv_scale_offsets + i * head_num_k + j,
                       (int)((scale_seq_offset + j * cache_scale_head_stride) * sizeof(float)));
          __store_nram(nram_v_onchip_offsets + i * head_num_k + j, onchip_offset + j * head_size);
        }
      }
    }
  }
  // Build the bit mask for the scatter instructions: entries whose k offset
  // is negative (bs/seq/slot id < 0) get a cleared bit and are skipped.
  __bang_int322float(nram_temp, nram_k_cache_offsets, bh, 0);
  __bang_ge_bitindex(nram_cache_mask, nram_temp, nram_zeros, PAD_UP(bh, 8));
}
// In-place LayerNorm over the trailing `k_hidden` elements of each of
// `task_deal_batch` rows of nram_k, followed by the affine transform
// gamma * x_hat + beta. norm_params holds gamma then beta (k_hidden floats
// each). Scratch: the k_hidden floats immediately after the input rows
// (nram_k + task_deal_batch * k_hidden) are clobbered.
__mlu_func__ void layernormImpl(float *nram_k,
                                float *norm_params,
                                int task_deal_batch,
                                int k_hidden,
                                float eps) {
#if __BANG_ARCH__ > 500
  float *buffer = nram_k + task_deal_batch * k_hidden;
  for (int i = 0; i < task_deal_batch; i++) {
    float *k_ = nram_k + i * k_hidden;
    // buffer = x^2, computed before the row is normalized (for E[x^2]).
    __bang_mul(buffer, k_, k_, k_hidden);
    float mean = __bang_sum(k_, k_hidden);
    mean = mean / k_hidden;
    // Var(x) = E[x^2] - E[x]^2.
    float rstd = __bang_sum(buffer, k_hidden);
    rstd = rstd / k_hidden - mean * mean;
    // Guard against a tiny negative variance from floating-point
    // cancellation; otherwise the usual variance + eps.
    rstd = rstd < 0 ? eps : rstd + eps;
    rstd = 1.f / std::sqrt(rstd);
    // Fused (x - mean) * rstd (FSM: fused sub-mul with scalar operands).
    __bang_fusion(FUSION_FSM, k_, k_, mean, rstd, k_hidden);
  }
  // Fused x_hat * gamma + beta over all rows, gamma/beta cycled per row.
  __bang_fusion(FUSION_FMA, nram_k, nram_k, norm_params, norm_params + k_hidden,
                task_deal_batch * k_hidden, k_hidden);
#endif
}
// Applies the rotary fold: qk = qk * cos + sign_mask * rotate_half(qk) * sin.
//   nram_qk     : activations laid out so each head's task_deal_batch *
//                 head_size elements are contiguous (the cos/sin tables are
//                 cycled with that period).
//   nram_qk_rot : qk with its two head halves pre-swapped by the caller;
//                 multiplying by the global nram_mask (-1 on the first
//                 half, +1 on the second, written by the caller) yields the
//                 [-x2, x1] rotation.
//   nram_table  : cos rows followed by sin rows, task_deal_batch*head_size
//                 floats each.
// Result is accumulated into nram_qk; nram_qk_rot is clobbered.
__mlu_func__ void foldRotaryImpl(float *nram_qk,
                                 float *nram_qk_rot,
                                 float *nram_table,
                                 int task_deal_batch,
                                 int head_num_qk,
                                 int head_size) {
  int rotary_low_dim = task_deal_batch * head_size;
  // qk *= cos (cos block cycled once per head).
  __bang_cycle_mul(nram_qk, nram_qk, nram_table, head_num_qk * rotary_low_dim, rotary_low_dim);
  // Apply the +/-1 sign pattern (cycled per head_size) to the swapped halves.
  __bang_cycle_mul(nram_qk_rot, nram_qk_rot, nram_mask, head_num_qk * rotary_low_dim, head_size);
  // qk_rot *= sin.
  __bang_cycle_mul(nram_qk_rot, nram_qk_rot, nram_table + task_deal_batch * head_size,
                   head_num_qk * rotary_low_dim, rotary_low_dim);
  __bang_add(nram_qk, nram_qk, nram_qk_rot, head_num_qk * rotary_low_dim);
}
// Quantizes one K or V chunk for the cache(s).
//   quant_kv_hp: offline per-channel int8 for the hp cache. `scale_hp`
//     (the per-channel scales, hidden floats) is overwritten in place with
//     its reciprocal; the fused nram instruction then converts
//     float -> int8 round-to-nearest while cyclically multiplying by that
//     reciprocal, writing the result to output_hp.
//   mixed_cache: additionally builds the lp copy with online per-token
//     group quantization: the data is transposed so each quantization
//     group becomes a pooling window, per-group scale = absmax / 7 (int4
//     symmetric range), values are scaled to [-7, 7] and rounded to int8
//     (packed to int4 later by the caller), then transposed back into
//     output_lp. scale_lp receives the per-group scales; scale_lp_temp is
//     scratch for their reciprocals.
// When quant_kv_hp is false this function only leaves float_input
// untouched — the caller stores unquantized values directly.
template <typename T>
__mlu_func__ void quantify(T *input,
                           float *float_input,
                           void *output_hp,
                           void *output_lp,
                           float *nram_trans,
                           float *scale_hp,
                           float *scale_lp,
                           float *scale_lp_temp,
                           int batch,
                           int head_num,
                           int head_size,
                           int group_num,
                           int group_size,
                           bool quant_kv_hp,
                           bool mixed_cache) {
  if (quant_kv_hp) {
    int hidden = head_num * head_size;
    int bh = batch * head_num;
    toFloat<T>(float_input, input, batch * hidden);
    // In-place reciprocal so the fused op below can multiply instead of
    // divide. NOTE: scale_hp is clobbered — it must be reloaded per chunk.
    __bang_recip(scale_hp, scale_hp, hidden);
#if __BANG_ARCH__ > 500
    // Fused convert+scale: dst[i] = s8(round(src0[i] * src1[i % hidden])).
    __asm__ __volatile__(
        "fuse.nram.crn.s8.f32 "
        "[%[dst]], %[num_long], %[num_short], [%[src0]], .mul.cycle([%[src1]]), .dstpos(%[pos])"
        ";\n\t" ::[dst] "r"(output_hp),
        [num_long] "r"(batch * hidden), [num_short] "r"(hidden), [src0] "r"(float_input),
        [src1] "r"(scale_hp), [pos] "i"(0));
#endif
    if (mixed_cache) {
      // Transpose so the group_size elements of each group are gathered
      // per column, making the per-group absmax a pooling reduction.
      __bang_transpose(nram_trans, float_input, bh * group_num, group_size);
      __bang_abs(float_input, nram_trans, bh * head_size);
      __bang_maxpool(scale_lp_temp, float_input, bh * group_num, group_size, 1, group_size, 1, 1,
                     1);
      // int4 symmetric range is [-7, 7] -> scale = absmax / 7.
      __bang_mul_scalar(scale_lp, scale_lp_temp, 1 / 7.f, bh * group_num);
      __bang_recip(scale_lp_temp, scale_lp, bh * group_num);
      __bang_cycle_mul(nram_trans, nram_trans, scale_lp_temp, bh * group_num * group_size,
                       bh * group_num);
      __bang_float2int8_rn((int8_t *)nram_trans, nram_trans, bh * head_size, 0);
      // Transpose back to the original [bh, head_size] element order.
      __bang_transpose((int8_t *)output_lp, (int8_t *)nram_trans, group_size, bh * group_num);
    }
  }
}
// Core pipeline for one chunk of `task_deal_batch` batches:
//   input : packed qkv activations [batch, head_num_q + 2*head_num_k, head_size]
//   q     : rotary-embedded and written back in place to `input`
//   k     : rotary-embedded, layernormed (gamma/beta), optionally
//           quantized, scattered into key_cache_hp (+ int4 key_cache_lp)
//   v     : optionally quantized, scattered into value_cache_hp (+ lp)
// All nram tensors are carved out of the global `nram_buffer` using
// `batch_cap` (>= task_deal_batch) so the layout is stable across chunks.
// Requires __BANG_ARCH__ > 500; compiles to a no-op otherwise.
template <typename T>
__mlu_func__ void fuseRopeImpl(T *input,
                               void *key_cache_hp,
                               void *value_cache_hp,
                               void *key_cache_lp,
                               void *value_cache_lp,
                               T *sin_table,
                               T *cos_table,
                               int *rope_offsets,
                               T *gamma,
                               T *beta,
                               float *key_scale_hp,
                               float *value_scale_hp,
                               float *key_scale_lp,
                               float *value_scale_lp,
                               int *cache_bs_id_hp,
                               int *cache_seq_offsets_hp,
                               int *cache_bs_id_lp,
                               int *cache_seq_offsets_lp,
                               int *slot_mapping_hp,
                               int *slot_mapping_lp,
                               int rotary_stride,
                               int task_deal_batch,
                               int task_begin_batch,
                               int head_num_q,
                               int head_num_k,
                               int head_size,
                               int max_decode_len_hp,
                               int max_decode_len_lp,
                               int block_size_hp,
                               int block_size_lp,
                               int group_size,
                               int batch_cap,
                               float eps) {
#if __BANG_ARCH__ > 500
  /*
  Because mixed cache must be supported, the kernel would otherwise face a
  combinatorial number of cache feature combinations; only the following
  configurations are allowed:
  1. Only the hp cache exists (detected by the lp tensors being null).
     The cache holds bf16/fp16, or offline per-channel int8 when quantized.
     Linear and paged layouts are supported; key and value caches share the
     same shape; key/value_scale_hp has shape [head_num, head_size].
  2. Mixed cache: hp supports offline per-channel int8, linear and paged,
     key/value caches share the same shape, key/value_scale_hp is
     [head_num, head_size]. lp is int4 with online per-token group
     quantization: key_cache is [batch, head_num_k, max_decode_len_lp,
     head_size / 2] (paged rows are also head_size / 2); value_cache is
     [batch, head_num_k, max_decode_len_lp / 2, head_size] (paged:
     [num_blocks, head_num_k, block_size / 2, head_size]);
     key/value_scale_lp is [batch, head_num_k, max_decode_len_lp, group_num]
     (paged: [num_blocks, head_num_k, block_size, group_num]).
  */
  // Feature flags are inferred from which optional tensors are non-null.
  bool mixed_cache = key_cache_lp != nullptr && value_cache_lp != nullptr;
  bool quant_kv_hp = key_scale_hp != nullptr && value_scale_hp != nullptr;
  bool discrete_batch_hp = cache_bs_id_hp != nullptr;
  bool discrete_batch_lp = cache_bs_id_lp != nullptr;
  bool paged_cache_hp = slot_mapping_hp != nullptr;
  bool paged_cache_lp = slot_mapping_lp != nullptr;
  int head_num_qk = head_num_q + head_num_k;
  int head_num_qkv = head_num_q + head_num_k * 2;
  int qkv_hidden = head_num_qkv * head_size;
  int qk_hidden = head_num_qk * head_size;
  int q_hidden = head_num_q * head_size;
  int k_hidden = head_num_k * head_size;
  int float_size = sizeof(float);
  int dtype_size = sizeof(T);
  // hp cache element size: int8 when offline-quantized, else the io dtype.
  int kv_size_hp = quant_kv_hp ? sizeof(int8_t) : dtype_size;
  int group_num = mixed_cache ? head_size / group_size : 1;
  // task ddr offset
  // NOTE(review): offsets are also formed on pointers that may be nullptr
  // (never dereferenced in that case, but technically UB) — confirm this
  // is acceptable for this toolchain.
  T *input_begin = input + task_begin_batch * qkv_hidden;
  int *cache_bs_id_begin_hp = cache_bs_id_hp + task_begin_batch;
  int *cache_seq_offsets_begin_hp = cache_seq_offsets_hp + task_begin_batch;
  int *slot_mapping_begin_hp = slot_mapping_hp + task_begin_batch;
  // nram_buffer carving; buffers for disabled features collapse to zero
  // length via the (int)flag multipliers.
  float *nram_qk = (float *)nram_buffer;
  float *nram_qk_rot = nram_qk + batch_cap * qk_hidden;
  float *nram_v = nram_qk_rot + batch_cap * qk_hidden;
  float *nram_kv_trans = nram_v + batch_cap * k_hidden;
  float *nram_table = nram_kv_trans + (int)mixed_cache * batch_cap * k_hidden;
  float *norm_params = nram_table + 2 * batch_cap * head_size;
  float *nram_k_scale_hp = norm_params + 2 * head_size;
  float *nram_v_scale_hp = nram_k_scale_hp + (int)quant_kv_hp * k_hidden;
  float *nram_k_scale_lp = nram_v_scale_hp + (int)quant_kv_hp * k_hidden;
  float *nram_v_scale_lp = nram_k_scale_lp + (int)mixed_cache * batch_cap * head_num_k * group_num;
  int8_t *nram_kv_hp =
      (int8_t *)(nram_v_scale_lp + (int)mixed_cache * batch_cap * head_num_k * group_num);
  int8_t *nram_kv_lp = nram_kv_hp + (int)quant_kv_hp * batch_cap * k_hidden;
  int8_t *nram_cache_v = nram_kv_lp + (int)mixed_cache * batch_cap * k_hidden;
  int *nram_kv_cache_offsets_hp =
      (int *)(nram_cache_v + (int)mixed_cache * batch_cap * k_hidden * 2);
  int *nram_k_cache_offsets_lp = nram_kv_cache_offsets_hp + batch_cap * head_num_k;
  int *nram_v_cache_offsets_lp =
      nram_k_cache_offsets_lp + (int)mixed_cache * batch_cap * head_num_k;
  int *nram_kv_scale_offsets = nram_v_cache_offsets_lp + (int)mixed_cache * batch_cap * head_num_k;
  int *nram_v_onchip_offsets = nram_kv_scale_offsets + (int)mixed_cache * batch_cap * head_num_k;
  float *cache_mask_hp =
      (float *)(nram_v_onchip_offsets + (int)mixed_cache * batch_cap * head_num_k);
  float *cache_mask_lp = (float *)((int8_t *)cache_mask_hp + PAD_UP(batch_cap * head_num_k, 8) / 8);
  // qk and qk_rot share one region so the half->float widening can be done
  // with a single instruction; likewise the sin/cos table and the layernorm
  // gamma/beta. The low-precision inputs are staged at the tail of their
  // float destination buffers so the in-place widening is safe.
  T *qk_in = (T *)nram_qk_rot;
  T *qk_rot_in = (T *)((int8_t *)nram_qk_rot + (float_size - dtype_size) * batch_cap * qk_hidden);
  T *v_in =
      (T *)((int8_t *)nram_v + (int)quant_kv_hp * (float_size - dtype_size) * batch_cap * k_hidden);
  T *norm_params_in = (T *)((int8_t *)norm_params + (float_size - dtype_size) * 2 * head_size);
  T *table_in = (T *)((int8_t *)nram_table + (float_size - dtype_size) * 2 * batch_cap * head_size);
  int8_t *nram_cache_v_in = nram_cache_v + batch_cap * k_hidden * sizeof(int8_t);
  // Generate the kv-cache offsets and the masks used to scatter k/v into
  // the caches.
  genScatterOffsetMask(cache_bs_id_begin_hp, cache_seq_offsets_begin_hp, slot_mapping_begin_hp,
                       nram_kv_cache_offsets_hp, nullptr, nullptr, nullptr, cache_mask_hp,
                       nram_zeros, nram_qk, task_deal_batch, task_begin_batch, head_num_k,
                       head_size, max_decode_len_hp, block_size_hp, kv_size_hp, 1,
                       discrete_batch_hp, paged_cache_hp, false);
  if (mixed_cache) {
    int *cache_bs_id_begin_lp = cache_bs_id_lp + task_begin_batch;
    int *cache_seq_offsets_begin_lp = cache_seq_offsets_lp + task_begin_batch;
    int *slot_mapping_begin_lp = slot_mapping_lp + task_begin_batch;
    genScatterOffsetMask(cache_bs_id_begin_lp, cache_seq_offsets_begin_lp, slot_mapping_begin_lp,
                         nram_k_cache_offsets_lp, nram_v_cache_offsets_lp, nram_v_onchip_offsets,
                         nram_kv_scale_offsets, cache_mask_lp, nram_zeros, nram_qk, task_deal_batch,
                         task_begin_batch, head_num_k, head_size, max_decode_len_lp, block_size_lp,
                         1, group_num, discrete_batch_lp, paged_cache_lp, mixed_cache);
  }
  /*
  Software pipeline over this chunk (io | compute per stage):
  -----------------------
  load v          |
  -----------------------
  load qk         | quant v
  -----------------------
  store v         | rope qk
  -----------------------
  store_q         | layernorm k
                  | quant k
  -----------------------
  store k         |
  */
  // prepare v v_scale cache_v rope_offset
  __memcpy_async(v_in, input_begin + qk_hidden, k_hidden * dtype_size, GDRAM2NRAM,
                 k_hidden * dtype_size, qkv_hidden * dtype_size, task_deal_batch - 1);
  if (quant_kv_hp) {
    __memcpy_async(nram_k_scale_hp, key_scale_hp, k_hidden * float_size, GDRAM2NRAM);
    __memcpy_async(nram_v_scale_hp, value_scale_hp, k_hidden * float_size, GDRAM2NRAM);
  }
  __memcpy_async(nram_rope_offsets, rope_offsets + task_begin_batch, task_deal_batch * sizeof(int),
                 GDRAM2NRAM);
  __sync_io();
  if (mixed_cache) {
    // Pre-load the currently stored int4 value rows: the lp value cache
    // packs two consecutive seq positions per row, so the partner
    // position's data must be merged back when this position is written.
    __gather(nram_cache_v_in, value_cache_lp, (uint32_t *)nram_v_cache_offsets_lp, cache_mask_lp,
             head_size * sizeof(int8_t), GDRAM2NRAM, head_size * sizeof(int8_t),
             task_deal_batch * head_num_k);
  }
  // Turn the per-batch rope position ids into byte offsets into the tables.
  __bang_mul_scalar(nram_rope_offsets, nram_rope_offsets, rotary_stride * dtype_size,
                    task_deal_batch);
  __sync_compute();
  /*==============================================================================================*/
  // load_qk,rope_table | quant v
  // Strided copy lays q and k heads out as [head, batch, head_size].
  __memcpy_async(qk_in, input_begin, head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size,
                 task_deal_batch - 1, task_deal_batch * head_size * dtype_size, head_num_qk - 1,
                 qkv_hidden * dtype_size, task_deal_batch - 1, head_size * dtype_size,
                 head_num_qk - 1);
  __gather_async(table_in, cos_table, (uint32_t *)nram_rope_offsets, head_size * dtype_size,
                 GDRAM2NRAM, head_size * dtype_size, task_deal_batch);
  __gather_async(table_in + task_deal_batch * head_size, sin_table, (uint32_t *)nram_rope_offsets,
                 head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size, task_deal_batch);
  __memcpy_async(norm_params_in, gamma, head_size * dtype_size, GDRAM2NRAM);
  __memcpy_async(norm_params_in + head_size, beta, head_size * dtype_size, GDRAM2NRAM);
  // nram_qk is still free here (qk lands in the nram_qk_rot region first),
  // so reuse it as staging scratch for the int4 parity merge below.
  int8_t *nram_temp = (int8_t *)nram_qk;
  if (mixed_cache) {
    // Unpack the gathered int4 rows to int8 and transpose so the two seq
    // parities become two contiguous planes of bh * head_size bytes.
    __bang_int42int8(nram_cache_v, (int4x2_t *)nram_cache_v_in, task_deal_batch * k_hidden * 2, 0,
                     0);
    __bang_transpose(nram_temp, nram_cache_v, task_deal_batch * k_hidden, 2);
  }
  quantify<T>(v_in, nram_v, nram_kv_hp, nram_kv_lp, nram_kv_trans, nram_v_scale_hp, nram_v_scale_lp,
              nram_k_scale_lp /*lp_scale temp*/, task_deal_batch, head_num_k, head_size, group_num,
              group_size, quant_kv_hp, mixed_cache);
  if (mixed_cache) {
    // Write the freshly quantized v into its parity slot, then repack both
    // parities back into int4 row layout for the lp value cache.
    __scatter(nram_temp, nram_kv_lp, (uint32_t *)nram_v_onchip_offsets, cache_mask_lp,
              head_size * sizeof(int8_t), NRAM2NRAM, head_size * sizeof(int8_t),
              task_deal_batch * head_num_k);
    __bang_transpose(nram_cache_v, nram_temp, 2, task_deal_batch * k_hidden);
    __bang_int82int4_rn((int4x2_t *)nram_cache_v, nram_cache_v, task_deal_batch * k_hidden * 2, 0,
                        0);
  }
  __sync_io_move_compute();
  /*==============================================================================================*/
  // rope | store v
  // Swap the left/right halves of qk to build qk_rot (rotate-half input).
  __memcpy(qk_rot_in, qk_in + head_size / 2, head_size / 2 * dtype_size, NRAM2NRAM,
           head_size * dtype_size, head_size * dtype_size, task_deal_batch * head_num_qk - 1);
  __memcpy(qk_rot_in + head_size / 2, qk_in, head_size / 2 * dtype_size, NRAM2NRAM,
           head_size * dtype_size, head_size * dtype_size, task_deal_batch * head_num_qk - 1);
  // Widen qk and qk_rot (contiguous, one call), then the table and params.
  toFloat<T>(nram_qk, qk_in, 2 * batch_cap * qk_hidden);
  toFloat<T>(nram_table, table_in, 2 * task_deal_batch * head_size);
  toFloat<T>(norm_params, norm_params_in, 2 * head_size);
  // Sign pattern for rotate-half: -1 on the first half, +1 on the second.
  __bang_write_value(nram_mask, head_size / 2, (float)-1);
  __bang_write_value(nram_mask + head_size / 2, head_size / 2, (float)1);
  foldRotaryImpl(nram_qk, nram_qk_rot, nram_table, task_deal_batch, head_num_qk, head_size);
  // Narrow q back to the io dtype in place for the store-q stage below.
  floatTo<T>((T *)nram_qk, nram_qk, task_deal_batch * q_hidden);
  int8_t *scatter_v_src = quant_kv_hp ? nram_kv_hp : (int8_t *)nram_v;
  __scatter_async(value_cache_hp, scatter_v_src, (uint32_t *)nram_kv_cache_offsets_hp,
                  cache_mask_hp, head_size * kv_size_hp, NRAM2GDRAM, head_size * kv_size_hp,
                  head_num_k * task_deal_batch);
  if (mixed_cache) {
    __scatter_async(value_cache_lp, nram_cache_v, (uint32_t *)nram_v_cache_offsets_lp,
                    cache_mask_lp, head_size * sizeof(int8_t), NRAM2GDRAM,
                    head_size * sizeof(int8_t), head_num_k * task_deal_batch);
    __scatter_async(value_scale_lp, nram_v_scale_lp, (uint32_t *)nram_kv_scale_offsets,
                    cache_mask_lp, group_num * sizeof(float), NRAM2GDRAM, group_num * sizeof(float),
                    head_num_k * task_deal_batch);
  }
  __sync_io_move_compute();
  /*==============================================================================================*/
  // layernorm k, quant k | store q
  // Extract k from the qk buffer, converting [head, batch, head_size] back
  // to [batch, head, head_size] for per-token layernorm / quantization.
  float *nram_k = nram_qk_rot;
  __memcpy(nram_k, nram_qk + task_deal_batch * q_hidden, head_size * float_size, NRAM2NRAM,
           head_size * float_size, head_num_k - 1, k_hidden * float_size, task_deal_batch - 1,
           task_deal_batch * head_size * float_size, head_num_k - 1, head_size * float_size,
           task_deal_batch - 1);
  layernormImpl(nram_k, norm_params, task_deal_batch * head_num_k, head_size, eps);
  quantify<float>(nram_k, nram_k, nram_kv_hp, nram_kv_lp, nram_kv_trans, nram_k_scale_hp,
                  nram_k_scale_lp, nram_v_scale_lp /*lp_scale temp*/, task_deal_batch, head_num_k,
                  head_size, group_num, group_size, quant_kv_hp, mixed_cache);
  if (mixed_cache) {
    // lp key cache stores packed int4 pairs: head_size / 2 bytes per row.
    __bang_int82int4_rn((int4x2_t *)nram_kv_lp, nram_kv_lp, task_deal_batch * k_hidden, 0, 0);
  }
  if (!quant_kv_hp) {
    floatTo<T>((T *)nram_k, nram_k, task_deal_batch * k_hidden);
  }
  // store q
  __memcpy_async(input_begin, nram_qk, head_size * dtype_size, NRAM2GDRAM, qkv_hidden * dtype_size,
                 task_deal_batch - 1, head_size * dtype_size, head_num_q - 1,
                 head_size * dtype_size, task_deal_batch - 1,
                 task_deal_batch * head_size * dtype_size, head_num_q - 1);
  // ===============================================================================================
  // store k
  int8_t *scatter_k_src = quant_kv_hp ? nram_kv_hp : (int8_t *)nram_k;
  __scatter(key_cache_hp, scatter_k_src, (uint32_t *)nram_kv_cache_offsets_hp, cache_mask_hp,
            head_size * kv_size_hp, NRAM2GDRAM, head_size * kv_size_hp,
            head_num_k * task_deal_batch);
  if (mixed_cache) {
    __scatter(key_cache_lp, nram_kv_lp, (uint32_t *)nram_k_cache_offsets_lp, cache_mask_lp,
              head_size / 2 * sizeof(int8_t), NRAM2GDRAM, head_size / 2 * sizeof(int8_t),
              head_num_k * task_deal_batch);
    __scatter(key_scale_lp, nram_k_scale_lp, (uint32_t *)nram_kv_scale_offsets, cache_mask_lp,
              group_num * sizeof(float), NRAM2GDRAM, group_num * sizeof(float),
              head_num_k * task_deal_batch);
  }
  __sync_io_move_compute();
#endif
}
// Kernel entry point: partitions `batch` evenly across tasks, then walks
// each task's slice in NRAM-sized chunks, delegating the actual work to
// fuseRopeImpl. `batch_cap` is the maximum chunk size that fits the NRAM
// layout; `task_avg_batch` is the per-task slice size chosen by the host.
template <typename T>
__mlu_global__ void MLUFuseRope(T *input,
                                void *key_cache_hp,
                                void *value_cache_hp,
                                void *key_cache_lp,
                                void *value_cache_lp,
                                T *sin_table,
                                T *cos_table,
                                int *rope_offsets,
                                T *gamma,
                                T *beta,
                                float *key_scale_hp,
                                float *value_scale_hp,
                                float *key_scale_lp,
                                float *value_scale_lp,
                                int *cache_bs_id_hp,
                                int *cache_seq_offsets_hp,
                                int *cache_bs_id_lp,
                                int *cache_seq_offsets_lp,
                                int *slot_mapping_hp,
                                int *slot_mapping_lp,
                                int rotary_stride,
                                int batch,
                                int head_num_q,
                                int head_num_k,
                                int head_size,
                                int max_decode_len_hp,
                                int max_decode_len_lp,
                                int block_size_hp,
                                int block_size_lp,
                                int group_size,
                                int batch_cap,
                                int task_avg_batch,
                                float eps) {
  const int slice_begin = taskId * task_avg_batch;
  const int slice_size = std::min(batch - slice_begin, task_avg_batch);
  // MPU cores take no part; tasks past the end of the batch exit early.
  if (slice_size <= 0 || __is_mpu()) {
    return;
  }
  // Split the slice into nearly equal chunks of at most batch_cap batches;
  // `chunk` is the rounded-up chunk size and is passed to fuseRopeImpl as
  // the layout capacity so the nram carving stays identical per iteration.
  const int num_chunks = (slice_size + batch_cap - 1) / batch_cap;
  const int chunk = (slice_size + num_chunks - 1) / num_chunks;
  for (int c = 0; c < num_chunks; c++) {
    const int chunk_batch = std::min(slice_size - c * chunk, chunk);
    fuseRopeImpl<T>(input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp, sin_table,
                    cos_table, rope_offsets, gamma, beta, key_scale_hp, value_scale_hp,
                    key_scale_lp, value_scale_lp, cache_bs_id_hp, cache_seq_offsets_hp,
                    cache_bs_id_lp, cache_seq_offsets_lp, slot_mapping_hp, slot_mapping_lp,
                    rotary_stride, chunk_batch, slice_begin + chunk * c, head_num_q, head_num_k,
                    head_size, max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp,
                    group_size, chunk, eps);
  }
}
} // namespace kernels
// Host-side launcher for the fused decode RoPE + KV-cache-update kernel.
// Partitions `batch_size` across all IPU cores, derives the per-chunk NRAM
// batch capacity by mirroring the buffer carving done in fuseRopeImpl, and
// launches MLUFuseRope on `queue`.
// Supported dtypes: CNNL_DTYPE_HALF, CNNL_DTYPE_BFLOAT16.
// Returns KERNEL_STATUS_FAILED on MLU300 devices, on unsupported dtypes,
// or when a single batch does not fit in NRAM; SUCCESS otherwise.
KernelStatus invokeFusedRope(cnrtQueue_t queue,
                             void *input,
                             void *key_cache_hp,
                             void *value_cache_hp,
                             void *key_cache_lp,
                             void *value_cache_lp,
                             const void *sin_table,
                             const void *cos_table,
                             const void *rope_offsets,
                             const void *gamma,
                             const void *beta,
                             const void *key_scale_hp,
                             const void *value_scale_hp,
                             void *key_scale_lp,
                             void *value_scale_lp,
                             const void *cache_bs_id_hp,
                             const void *cache_seq_offsets_hp,
                             const void *cache_bs_id_lp,
                             const void *cache_seq_offsets_lp,
                             const void *slot_mapping_hp,
                             const void *slot_mapping_lp,
                             int rotary_stride,
                             int batch_size,
                             int head_num_q,
                             int head_num_kv,
                             int head_size,
                             int max_decode_len_hp,
                             int max_decode_len_lp,
                             int block_size_hp,
                             int block_size_lp,
                             int group_size,
                             cnnlDataType_t dtype,
                             float eps) {
  if (is_arch300()) {
    std::cerr << "[invokeFusedRope]: kernel does not support MLU300 devices." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  uint32_t taskdimx = cluster_num * core_num;
  // Batches per task (ceil-div so every batch is covered).
  int task_avg_batch = (batch_size + taskdimx - 1) / taskdimx;
  int float_size = sizeof(float);
  int group_num = head_size / group_size;
  bool quant_kv_hp = key_scale_hp != nullptr && value_scale_hp != nullptr;
  bool mixed_cache = key_cache_lp != nullptr && value_cache_lp != nullptr;
  // NRAM budget bookkeeping: these per-batch byte counts must mirror the
  // pointer carving at the top of fuseRopeImpl (float working precision).
  int nram_avalible_bytes = 480 * 1024;
  int task_max_batch = 32;
  // NOTE(review): the mask/offset accounting and the 256-entry global
  // nram_mask / nram_rope_offsets buffers assume the per-chunk batch stays
  // within task_max_batch-ish bounds; batch_cap below can be set to
  // task_avg_batch, which is not clamped to task_max_batch — confirm the
  // caller bounds batch_size accordingly.
  int mask_bytes = PAD_UP(task_max_batch * head_num_kv, 8) / 8 * (mixed_cache + 1);
  int nram_params_bytes = 2 * head_size * float_size;
  int nram_kv_hp_scale_bytes = 2 * (int)quant_kv_hp * head_num_kv * head_size * float_size;
  int nram_remain_bytes =
      nram_avalible_bytes - nram_params_bytes - nram_kv_hp_scale_bytes - mask_bytes;
  // Per-batch byte costs of each carved buffer:
  int nram_qk_bytes = (head_num_q + head_num_kv) * head_size * float_size * 2;  // qk + qk_rot
  int nram_v_bytes = head_num_kv * head_size * float_size * (mixed_cache + 1);  // v (+ transpose)
  int nram_table_bytes = 2 * head_size * float_size;                            // cos + sin rows
  int nram_kv_lp_scale_bytes = 2 * (int)mixed_cache * head_num_kv * group_num * float_size;
  int nram_kv_hp_bytes = (int)quant_kv_hp * head_num_kv * head_size;
  int nram_kv_lp_bytes = (int)mixed_cache * head_num_kv * head_size;
  int nram_cache_v_bytes = (int)mixed_cache * head_num_kv * head_size * 2;
  int nram_cache_offsets_hp = head_num_kv * sizeof(int);
  int nram_cache_offsets_lp = (int)mixed_cache * head_num_kv * 3 * sizeof(int);
  int batch_cap =
      nram_remain_bytes /
      (nram_qk_bytes + nram_v_bytes + nram_table_bytes + nram_kv_lp_scale_bytes + nram_kv_hp_bytes +
       nram_kv_lp_bytes + nram_cache_v_bytes + nram_cache_offsets_hp + nram_cache_offsets_lp);
  batch_cap = batch_cap < task_avg_batch ? std::min(task_max_batch, batch_cap) : task_avg_batch;
  // Guard: if not even one batch fits, MLUFuseRope would divide by zero
  // when computing its chunk count.
  if (batch_cap < 1) {
    std::cerr << "[invokeFusedRope]: head configuration does not fit in NRAM." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  cnrtDim3_t dim{taskdimx, 1, 1};
  if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUFuseRope<<<dim, cnrtFuncTypeBlock, queue>>>(
        (half *)input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp,
        (half *)sin_table, (half *)cos_table, (int *)rope_offsets, (half *)gamma, (half *)beta,
        (float *)key_scale_hp, (float *)value_scale_hp, (float *)key_scale_lp,
        (float *)value_scale_lp, (int *)cache_bs_id_hp, (int *)cache_seq_offsets_hp,
        (int *)cache_bs_id_lp, (int *)cache_seq_offsets_lp, (int *)slot_mapping_hp,
        (int *)slot_mapping_lp, rotary_stride, batch_size, head_num_q, head_num_kv, head_size,
        max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size, batch_cap,
        task_avg_batch, eps);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    kernels::MLUFuseRope<<<dim, cnrtFuncTypeBlock, queue>>>(
        (bf16 *)input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp,
        (bf16 *)sin_table, (bf16 *)cos_table, (int *)rope_offsets, (bf16 *)gamma, (bf16 *)beta,
        (float *)key_scale_hp, (float *)value_scale_hp, (float *)key_scale_lp,
        (float *)value_scale_lp, (int *)cache_bs_id_hp, (int *)cache_seq_offsets_hp,
        (int *)cache_bs_id_lp, (int *)cache_seq_offsets_lp, (int *)slot_mapping_hp,
        (int *)slot_mapping_lp, rotary_stride, batch_size, head_num_q, head_num_kv, head_size,
        max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size, batch_cap,
        task_avg_batch, eps);
  } else {
    // BUGFIX: previously an unsupported dtype fell through both branches
    // and returned SUCCESS without launching anything.
    std::cerr << "[invokeFusedRope]: unsupported dtype, only half/bfloat16 are supported."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo