/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "cnnl.h"
#include "cnrt.h"
#include "fused_rope.mluh"
// clang-format off
#include <mlu.h>
// clang-format on

using bf16 = bfloat16_t;

namespace tmo {

namespace kernels {

#ifndef PAD_UP
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
#endif
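// PAD_UP rounds x up to the nearest multiple of y,
// e.g. PAD_UP(10, 8) == 16 and PAD_UP(16, 8) == 16.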

#if __BANG_ARCH__ > 500
#include <bang_fusor.h>
template <typename T>
using bang_cycle_fusor = bang::experimental::cycle_fusor<T>;
#endif

#define NRAM_BUFFER_SIZE (480 * 1024)
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__nram__ float nram_mask[256];
__nram__ int nram_rope_offsets[256];
__nram__ float nram_zeros[1024] = {0.f};

template <typename T>
__mlu_func__ void toFloat(float *dst, T *src, int num) {
  if (std::is_same<T, half>::value) {
    __bang_half2float(dst, (half *)src, num);
  } else if (std::is_same<T, bf16>::value) {
#if __BANG_ARCH__ > 500
    __bang_bfloat162float(dst, (bf16 *)src, num);
#endif
  }
}

template <typename T>
__mlu_func__ void floatTo(T *dst, float *src, int num) {
  if (std::is_same<T, half>::value) {
    __bang_float2half_rn((half *)dst, src, num);
  } else if (std::is_same<T, bf16>::value) {
    __bang_float2bfloat16_rn((bf16 *)dst, src, num);
  }
}

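// Builds, for every (batch, kv head) pair, the byte offsets used to scatter K/V (and, for the
// mixed int4 cache, the per-token scales) into the paged or linear KV cache, plus a bit-mask so
// that entries with a negative slot / sequence index are skipped by the scatter instructions.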
__mlu_func__ void genScatterOffsetMask(int *cache_bs_id_begin,
                                       int *cache_seq_offsets_begin,
                                       int *slot_mapping_begin,
                                       int *nram_k_cache_offsets,
                                       int *nram_v_cache_offsets,
                                       int *nram_v_onchip_offsets,
                                       int *nram_kv_scale_offsets,
                                       float *nram_cache_mask,
                                       float *nram_zeros,
                                       float *nram_temp,
                                       int task_deal_batch,
                                       int task_begin_batch,
                                       int head_num_k,
                                       int head_size,
                                       int max_decode_len,
                                       int block_size,
                                       int kv_out_size,
                                       int group_num,
                                       bool discrete_batch,
                                       bool paged_cache,
                                       bool mixed_cache) {
  // For now the offsets are computed with scalar code for readability (no performance impact).
  int bh = task_deal_batch * head_num_k;
  if (paged_cache) {
    int cache_seq_stride = head_size;
    int cache_head_stride = block_size * head_size;
    int cache_scale_head_stride = block_size * group_num;
    int cache_block_stride = head_num_k * cache_head_stride;
    int cache_scale_block_stride = head_num_k * cache_scale_head_stride;
    int *nram_slot_mapping = (int *)nram_mask;
    __memcpy(nram_slot_mapping, slot_mapping_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    for (int i = 0; i < task_deal_batch; i++) {
      int mapping_idx = __load_nram(nram_slot_mapping + i);
      if (mapping_idx < 0) {
        __bang_write_value(nram_k_cache_offsets + i * head_num_k, head_num_k, (int)-1);
        continue;
      }
      int block_idx = mapping_idx / block_size;
      int seq_idx = mapping_idx % block_size;
      int k_seq_offset = block_idx * cache_block_stride + seq_idx * cache_seq_stride;
      int v_seq_offset = block_idx * cache_block_stride / 2 + seq_idx / 2 * cache_seq_stride;
      int scale_seq_offset = block_idx * cache_scale_block_stride + seq_idx * group_num;
      int onchip_offset = i * head_num_k * head_size + seq_idx % 2 * bh * head_size;

      for (int j = 0; j < head_num_k; j++) {
        __store_nram(
            nram_k_cache_offsets + i * head_num_k + j,
            (int)((k_seq_offset + j * cache_head_stride) * kv_out_size / (mixed_cache + 1)));
        if (mixed_cache) {
          __store_nram(nram_v_cache_offsets + i * head_num_k + j,
                       (int)(v_seq_offset + j * cache_head_stride / 2));
          __store_nram(nram_kv_scale_offsets + i * head_num_k + j,
                       (int)((scale_seq_offset + j * cache_scale_head_stride) * sizeof(float)));
          __store_nram(nram_v_onchip_offsets + i * head_num_k + j, onchip_offset + j * head_size);
        }
      }
    }
  } else {
    int *nram_seq_offsets = (int *)nram_mask;
    int *nram_bs_id = nram_seq_offsets + 32;
    int cache_seq_stride = head_size;
    int cache_head_stride = max_decode_len * head_size;
    int cache_scale_head_stride = max_decode_len * group_num;
    int cache_bs_stride = head_num_k * cache_head_stride;
    int cache_scale_bs_stride = head_num_k * cache_scale_head_stride;
    __memcpy(nram_seq_offsets, cache_seq_offsets_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    if (discrete_batch) {
      __memcpy(nram_bs_id, cache_bs_id_begin, task_deal_batch * sizeof(int), GDRAM2NRAM);
    }
    for (int i = 0; i < task_deal_batch; i++) {
      int bs_idx = __load_nram(nram_bs_id + i);
      int seq_idx = __load_nram(nram_seq_offsets + i);
      int temp_bs_idx = discrete_batch ? bs_idx : task_begin_batch + i;
      int temp_seq_idx = seq_idx;
      bool masked = temp_bs_idx < 0 || temp_seq_idx < 0;
      if (masked) {
        __bang_write_value(nram_k_cache_offsets + i * head_num_k, head_num_k, (int)-1);
        continue;
      }
      int k_seq_offset = temp_bs_idx * cache_bs_stride + temp_seq_idx * cache_seq_stride;
      int scale_seq_offset = temp_bs_idx * cache_scale_bs_stride + temp_seq_idx * group_num;
      int v_seq_offset = temp_bs_idx * cache_bs_stride / 2 + temp_seq_idx / 2 * cache_seq_stride;
      int onchip_offset = i * head_num_k * head_size + temp_seq_idx % 2 * bh * head_size;

      for (int j = 0; j < head_num_k; j++) {
        __store_nram(
            nram_k_cache_offsets + i * head_num_k + j,
            (int)((k_seq_offset + j * cache_head_stride) * kv_out_size / (mixed_cache + 1)));
        if (mixed_cache) {
          __store_nram(nram_v_cache_offsets + i * head_num_k + j,
                       (int)(v_seq_offset + j * cache_head_stride / 2));
          __store_nram(nram_kv_scale_offsets + i * head_num_k + j,
                       (int)((scale_seq_offset + j * cache_scale_head_stride) * sizeof(float)));
          __store_nram(nram_v_onchip_offsets + i * head_num_k + j, onchip_offset + j * head_size);
        }
      }
    }
  }

  // Build the mask for the scatter instructions: entries whose batch offset or sequence offset is
  // negative (their cache offset was written as -1 above) must be masked out.
  __bang_int322float(nram_temp, nram_k_cache_offsets, bh, 0);
  __bang_ge_bitindex(nram_cache_mask, nram_temp, nram_zeros, PAD_UP(bh, 8));
}

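// Per-row LayerNorm over k_hidden elements: each row is normalized with its own mean and
// rstd = 1 / sqrt(var + eps), then scaled and shifted by gamma and beta, which are packed back to
// back in norm_params (gamma at offset 0, beta at offset k_hidden).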
__mlu_func__ void layernormImpl(float *nram_k,
                                float *norm_params,
                                int task_deal_batch,
                                int k_hidden,
                                float eps) {
#if __BANG_ARCH__ > 500
  float *buffer = nram_k + task_deal_batch * k_hidden;
  for (int i = 0; i < task_deal_batch; i++) {
    float *k_ = nram_k + i * k_hidden;
    __bang_mul(buffer, k_, k_, k_hidden);
    float mean = __bang_sum(k_, k_hidden);
    mean = mean / k_hidden;
    float rstd = __bang_sum(buffer, k_hidden);
    rstd = rstd / k_hidden - mean * mean;
    rstd = rstd < 0 ? eps : rstd + eps;
    rstd = 1.f / std::sqrt(rstd);
    __bang_fusion(FUSION_FSM, k_, k_, mean, rstd, k_hidden);
  }
  __bang_fusion(FUSION_FMA, nram_k, nram_k, norm_params, norm_params + k_hidden,
                task_deal_batch * k_hidden, k_hidden);
#endif
}

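// Folded rotary embedding: nram_qk holds x, nram_qk_rot holds x with its two halves swapped, and
// nram_mask holds [-1, ..., -1, +1, ..., +1] (head_size / 2 each), so the result is
// x * cos + rotate_half(x) * sin with the per-token cos/sin rows gathered into nram_table.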
__mlu_func__ void foldRotaryImpl(float *nram_qk,
                                 float *nram_qk_rot,
                                 float *nram_table,
                                 int task_deal_batch,
                                 int head_num_qk,
                                 int head_size) {
  int rotary_low_dim = task_deal_batch * head_size;
  __bang_cycle_mul(nram_qk, nram_qk, nram_table, head_num_qk * rotary_low_dim, rotary_low_dim);
  __bang_cycle_mul(nram_qk_rot, nram_qk_rot, nram_mask, head_num_qk * rotary_low_dim, head_size);
  __bang_cycle_mul(nram_qk_rot, nram_qk_rot, nram_table + task_deal_batch * head_size,
                   head_num_qk * rotary_low_dim, rotary_low_dim);
  __bang_add(nram_qk, nram_qk, nram_qk_rot, head_num_qk * rotary_low_dim);
}

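// Two quantization paths for K/V before they are scattered into the caches:
//  * quant_kv_hp: offline per-channel symmetric int8; scale_hp holds the per-channel scales, whose
//    reciprocal is applied with a fused cycle-multiply + float->int8 convert instruction.
//  * mixed_cache: online per-token, per-group quantization; the per-group scale is max|x| / 7
//    (int4 range), values are divided by it and written as int8 here, then packed to int4 later.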
template <typename T>
__mlu_func__ void quantify(T *input,
                           float *float_input,
                           void *output_hp,
                           void *output_lp,
                           float *nram_trans,
                           float *scale_hp,
                           float *scale_lp,
                           float *scale_lp_temp,
                           int batch,
                           int head_num,
                           int head_size,
                           int group_num,
                           int group_size,
                           bool quant_kv_hp,
                           bool mixed_cache) {
  if (quant_kv_hp) {
    int hidden = head_num * head_size;
    int bh = batch * head_num;
    toFloat<T>(float_input, input, batch * hidden);
    __bang_recip(scale_hp, scale_hp, hidden);
#if __BANG_ARCH__ > 500
    __asm__ __volatile__(
        "fuse.nram.crn.s8.f32 "
        "[%[dst]], %[num_long], %[num_short], [%[src0]], .mul.cycle([%[src1]]), .dstpos(%[pos])"
        ";\n\t" ::[dst] "r"(output_hp),
        [num_long] "r"(batch * hidden), [num_short] "r"(hidden), [src0] "r"(float_input),
        [src1] "r"(scale_hp), [pos] "i"(0));
#endif
    if (mixed_cache) {
      __bang_transpose(nram_trans, float_input, bh * group_num, group_size);
      __bang_abs(float_input, nram_trans, bh * head_size);
      __bang_maxpool(scale_lp_temp, float_input, bh * group_num, group_size, 1, group_size, 1, 1,
                     1);
      __bang_mul_scalar(scale_lp, scale_lp_temp, 1 / 7.f, bh * group_num);
      __bang_recip(scale_lp_temp, scale_lp, bh * group_num);
      __bang_cycle_mul(nram_trans, nram_trans, scale_lp_temp, bh * group_num * group_size,
                       bh * group_num);
      __bang_float2int8_rn((int8_t *)nram_trans, nram_trans, bh * head_size, 0);
      __bang_transpose((int8_t *)output_lp, (int8_t *)nram_trans, group_size, bh * group_num);
    }
  }
}

template <typename T>
__mlu_func__ void fuseRopeImpl(T *input,
                               void *key_cache_hp,
                               void *value_cache_hp,
                               void *key_cache_lp,
                               void *value_cache_lp,
                               T *sin_table,
                               T *cos_table,
                               int *rope_offsets,
                               T *gamma,
                               T *beta,
                               float *key_scale_hp,
                               float *value_scale_hp,
                               float *key_scale_lp,
                               float *value_scale_lp,
                               int *cache_bs_id_hp,
                               int *cache_seq_offsets_hp,
                               int *cache_bs_id_lp,
                               int *cache_seq_offsets_lp,
                               int *slot_mapping_hp,
                               int *slot_mapping_lp,
                               int rotary_stride,
                               int task_deal_batch,
                               int task_begin_batch,
                               int head_num_q,
                               int head_num_k,
                               int head_size,
                               int max_decode_len_hp,
                               int max_decode_len_lp,
                               int block_size_hp,
                               int block_size_lp,
                               int group_size,
                               int batch_cap,
                               float eps) {
#if __BANG_ARCH__ > 500
  /*
  Because mixed caches must be supported, the kernel could otherwise face many cache feature
  combinations; only the following are allowed:
  1. hp cache only (detected by the lp tensors being null). The cache supports bf16 and fp16, and
     with quantization supports offline per-channel int8. Linear and paged layouts are both
     supported, key and value caches have the same shape, and key/value_scale_hp has shape
     [head_num, head_size].
  2. mixed cache. The hp cache supports offline per-channel int8 quantization, linear and paged
     layouts, key and value caches of the same shape, and key/value_scale_hp of shape
     [head_num, head_size]. The lp cache uses online per-token group int4 quantization: its key
     cache has shape [batch, head_num_k, max_decode_len_lp, head_size / 2] (also head_size / 2 in
     the paged case), its value cache has shape [batch, head_num_k, max_decode_len_lp / 2,
     head_size] (paged: [num_blocks, head_num_k, block_size / 2, head_size]), and
     key/value_scale_lp has shape [batch, head_num_k, max_decode_len_lp, group_num]
     (paged: [num_blocks, head_num_k, block_size, group_num]).
  */
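  // Illustrative example (assumed sizes, not taken from a real model): head_size = 128 and
  // group_size = 64 give group_num = 2; with batch = 4, head_num_k = 8, max_decode_len_lp = 1024,
  // the lp key cache is [4, 8, 1024, 64] bytes (two int4 values per byte), the lp value cache is
  // [4, 8, 512, 128] bytes, and key/value_scale_lp is [4, 8, 1024, 2] floats.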
  bool mixed_cache = key_cache_lp != nullptr && value_cache_lp != nullptr;
  bool quant_kv_hp = key_scale_hp != nullptr && value_scale_hp != nullptr;
  bool discrete_batch_hp = cache_bs_id_hp != nullptr;
  bool discrete_batch_lp = cache_bs_id_lp != nullptr;
  bool paged_cache_hp = slot_mapping_hp != nullptr;
  bool paged_cache_lp = slot_mapping_lp != nullptr;
  int head_num_qk = head_num_q + head_num_k;
  int head_num_qkv = head_num_q + head_num_k * 2;
  int qkv_hidden = head_num_qkv * head_size;
  int qk_hidden = head_num_qk * head_size;
  int q_hidden = head_num_q * head_size;
  int k_hidden = head_num_k * head_size;
  int float_size = sizeof(float);
  int dtype_size = sizeof(T);
  int kv_size_hp = quant_kv_hp ? sizeof(int8_t) : dtype_size;
  int group_num = mixed_cache ? head_size / group_size : 1;

  // task ddr offset
  T *input_begin = input + task_begin_batch * qkv_hidden;
  int *cache_bs_id_begin_hp = cache_bs_id_hp + task_begin_batch;
  int *cache_seq_offsets_begin_hp = cache_seq_offsets_hp + task_begin_batch;
  int *slot_mapping_begin_hp = slot_mapping_hp + task_begin_batch;

  // nram_buffer
  float *nram_qk = (float *)nram_buffer;
  float *nram_qk_rot = nram_qk + batch_cap * qk_hidden;
  float *nram_v = nram_qk_rot + batch_cap * qk_hidden;
  float *nram_kv_trans = nram_v + batch_cap * k_hidden;
  float *nram_table = nram_kv_trans + (int)mixed_cache * batch_cap * k_hidden;
  float *norm_params = nram_table + 2 * batch_cap * head_size;
  float *nram_k_scale_hp = norm_params + 2 * head_size;
  float *nram_v_scale_hp = nram_k_scale_hp + (int)quant_kv_hp * k_hidden;
  float *nram_k_scale_lp = nram_v_scale_hp + (int)quant_kv_hp * k_hidden;
  float *nram_v_scale_lp = nram_k_scale_lp + (int)mixed_cache * batch_cap * head_num_k * group_num;
  int8_t *nram_kv_hp =
      (int8_t *)(nram_v_scale_lp + (int)mixed_cache * batch_cap * head_num_k * group_num);
  int8_t *nram_kv_lp = nram_kv_hp + (int)quant_kv_hp * batch_cap * k_hidden;
  int8_t *nram_cache_v = nram_kv_lp + (int)mixed_cache * batch_cap * k_hidden;
  int *nram_kv_cache_offsets_hp =
      (int *)(nram_cache_v + (int)mixed_cache * batch_cap * k_hidden * 2);
  int *nram_k_cache_offsets_lp = nram_kv_cache_offsets_hp + batch_cap * head_num_k;
  int *nram_v_cache_offsets_lp =
      nram_k_cache_offsets_lp + (int)mixed_cache * batch_cap * head_num_k;
  int *nram_kv_scale_offsets = nram_v_cache_offsets_lp + (int)mixed_cache * batch_cap * head_num_k;
  int *nram_v_onchip_offsets = nram_kv_scale_offsets + (int)mixed_cache * batch_cap * head_num_k;
  float *cache_mask_hp =
      (float *)(nram_v_onchip_offsets + (int)mixed_cache * batch_cap * head_num_k);
  float *cache_mask_lp = (float *)((int8_t *)cache_mask_hp + PAD_UP(batch_cap * head_num_k, 8) / 8);

  // qk and qk_rot are kept adjacent so the width promotion (T -> float) can be done in one shot,
  // saving instructions; the same applies to the sin/cos table and the layernorm gamma/beta.
  T *qk_in = (T *)nram_qk_rot;
  T *qk_rot_in = (T *)((int8_t *)nram_qk_rot + (float_size - dtype_size) * batch_cap * qk_hidden);
  T *v_in =
      (T *)((int8_t *)nram_v + (int)quant_kv_hp * (float_size - dtype_size) * batch_cap * k_hidden);
  T *norm_params_in = (T *)((int8_t *)norm_params + (float_size - dtype_size) * 2 * head_size);
  T *table_in = (T *)((int8_t *)nram_table + (float_size - dtype_size) * 2 * batch_cap * head_size);
  int8_t *nram_cache_v_in = nram_cache_v + batch_cap * k_hidden * sizeof(int8_t);

  // Generate the kv-cache offsets and masks used when scattering k/v into the kv caches.
  genScatterOffsetMask(cache_bs_id_begin_hp, cache_seq_offsets_begin_hp, slot_mapping_begin_hp,
                       nram_kv_cache_offsets_hp, nullptr, nullptr, nullptr, cache_mask_hp,
                       nram_zeros, nram_qk, task_deal_batch, task_begin_batch, head_num_k,
                       head_size, max_decode_len_hp, block_size_hp, kv_size_hp, 1,
                       discrete_batch_hp, paged_cache_hp, false);
  if (mixed_cache) {
    int *cache_bs_id_begin_lp = cache_bs_id_lp + task_begin_batch;
    int *cache_seq_offsets_begin_lp = cache_seq_offsets_lp + task_begin_batch;
    int *slot_mapping_begin_lp = slot_mapping_lp + task_begin_batch;
    genScatterOffsetMask(cache_bs_id_begin_lp, cache_seq_offsets_begin_lp, slot_mapping_begin_lp,
                         nram_k_cache_offsets_lp, nram_v_cache_offsets_lp, nram_v_onchip_offsets,
                         nram_kv_scale_offsets, cache_mask_lp, nram_zeros, nram_qk, task_deal_batch,
                         task_begin_batch, head_num_k, head_size, max_decode_len_lp, block_size_lp,
                         1, group_num, discrete_batch_lp, paged_cache_lp, mixed_cache);
  }

  /*
  -----------------------
  load v      |
  -----------------------
  load qk     | quant v
  -----------------------
  store v     | rope qk
  -----------------------
  store_q     | layernorm k
              | quant k
  -----------------------
  store k     |
  */
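  // The stages above overlap GDRAM<->NRAM traffic (async memcpy / gather / scatter) with on-chip
  // compute; __sync_io / __sync_compute / __sync_io_move_compute mark the stage boundaries.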
  // prepare v v_scale cache_v rope_offset
  __memcpy_async(v_in, input_begin + qk_hidden, k_hidden * dtype_size, GDRAM2NRAM,
                 k_hidden * dtype_size, qkv_hidden * dtype_size, task_deal_batch - 1);
  if (quant_kv_hp) {
    __memcpy_async(nram_k_scale_hp, key_scale_hp, k_hidden * float_size, GDRAM2NRAM);
    __memcpy_async(nram_v_scale_hp, value_scale_hp, k_hidden * float_size, GDRAM2NRAM);
  }
  __memcpy_async(nram_rope_offsets, rope_offsets + task_begin_batch, task_deal_batch * sizeof(int),
                 GDRAM2NRAM);

  __sync_io();
  if (mixed_cache) {
    __gather(nram_cache_v_in, value_cache_lp, (uint32_t *)nram_v_cache_offsets_lp, cache_mask_lp,
             head_size * sizeof(int8_t), GDRAM2NRAM, head_size * sizeof(int8_t),
             task_deal_batch * head_num_k);
  }
  __bang_mul_scalar(nram_rope_offsets, nram_rope_offsets, rotary_stride * dtype_size,
                    task_deal_batch);
  __sync_compute();
  /*==============================================================================================*/

  // load_qk,rope_table | quant v
  __memcpy_async(qk_in, input_begin, head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size,
                 task_deal_batch - 1, task_deal_batch * head_size * dtype_size, head_num_qk - 1,
                 qkv_hidden * dtype_size, task_deal_batch - 1, head_size * dtype_size,
                 head_num_qk - 1);
  __gather_async(table_in, cos_table, (uint32_t *)nram_rope_offsets, head_size * dtype_size,
                 GDRAM2NRAM, head_size * dtype_size, task_deal_batch);
  __gather_async(table_in + task_deal_batch * head_size, sin_table, (uint32_t *)nram_rope_offsets,
                 head_size * dtype_size, GDRAM2NRAM, head_size * dtype_size, task_deal_batch);
  __memcpy_async(norm_params_in, gamma, head_size * dtype_size, GDRAM2NRAM);
  __memcpy_async(norm_params_in + head_size, beta, head_size * dtype_size, GDRAM2NRAM);

  int8_t *nram_temp = (int8_t *)nram_qk;
  if (mixed_cache) {
    __bang_int42int8(nram_cache_v, (int4x2_t *)nram_cache_v_in, task_deal_batch * k_hidden * 2, 0,
                     0);
    __bang_transpose(nram_temp, nram_cache_v, task_deal_batch * k_hidden, 2);
  }
  quantify<T>(v_in, nram_v, nram_kv_hp, nram_kv_lp, nram_kv_trans, nram_v_scale_hp, nram_v_scale_lp,
              nram_k_scale_lp /*lp_scale temp*/, task_deal_batch, head_num_k, head_size, group_num,
              group_size, quant_kv_hp, mixed_cache);
  if (mixed_cache) {
    __scatter(nram_temp, nram_kv_lp, (uint32_t *)nram_v_onchip_offsets, cache_mask_lp,
              head_size * sizeof(int8_t), NRAM2NRAM, head_size * sizeof(int8_t),
              task_deal_batch * head_num_k);
    __bang_transpose(nram_cache_v, nram_temp, 2, task_deal_batch * k_hidden);
    __bang_int82int4_rn((int4x2_t *)nram_cache_v, nram_cache_v, task_deal_batch * k_hidden * 2, 0,
                        0);
  }
  __sync_io_move_compute();
  /*==============================================================================================*/

  // rope | store v
  // Swap the left and right halves of qk to build qk_rot.
  __memcpy(qk_rot_in, qk_in + head_size / 2, head_size / 2 * dtype_size, NRAM2NRAM,
           head_size * dtype_size, head_size * dtype_size, task_deal_batch * head_num_qk - 1);
  __memcpy(qk_rot_in + head_size / 2, qk_in, head_size / 2 * dtype_size, NRAM2NRAM,
           head_size * dtype_size, head_size * dtype_size, task_deal_batch * head_num_qk - 1);
  toFloat<T>(nram_qk, qk_in, 2 * batch_cap * qk_hidden);
  toFloat<T>(nram_table, table_in, 2 * task_deal_batch * head_size);
  toFloat<T>(norm_params, norm_params_in, 2 * head_size);
  __bang_write_value(nram_mask, head_size / 2, (float)-1);
  __bang_write_value(nram_mask + head_size / 2, head_size / 2, (float)1);
  foldRotaryImpl(nram_qk, nram_qk_rot, nram_table, task_deal_batch, head_num_qk, head_size);
  floatTo<T>((T *)nram_qk, nram_qk, task_deal_batch * q_hidden);

  int8_t *scatter_v_src = quant_kv_hp ? nram_kv_hp : (int8_t *)nram_v;
  __scatter_async(value_cache_hp, scatter_v_src, (uint32_t *)nram_kv_cache_offsets_hp,
                  cache_mask_hp, head_size * kv_size_hp, NRAM2GDRAM, head_size * kv_size_hp,
                  head_num_k * task_deal_batch);
  if (mixed_cache) {
    __scatter_async(value_cache_lp, nram_cache_v, (uint32_t *)nram_v_cache_offsets_lp,
                    cache_mask_lp, head_size * sizeof(int8_t), NRAM2GDRAM,
                    head_size * sizeof(int8_t), head_num_k * task_deal_batch);
    __scatter_async(value_scale_lp, nram_v_scale_lp, (uint32_t *)nram_kv_scale_offsets,
                    cache_mask_lp, group_num * sizeof(float), NRAM2GDRAM, group_num * sizeof(float),
                    head_num_k * task_deal_batch);
  }
  __sync_io_move_compute();
  /*==============================================================================================*/
  // layernorm k, quant k | store q
  // Extract k from the qk nram buffer for layernorm and quantization.
  float *nram_k = nram_qk_rot;
  __memcpy(nram_k, nram_qk + task_deal_batch * q_hidden, head_size * float_size, NRAM2NRAM,
           head_size * float_size, head_num_k - 1, k_hidden * float_size, task_deal_batch - 1,
           task_deal_batch * head_size * float_size, head_num_k - 1, head_size * float_size,
           task_deal_batch - 1);
  layernormImpl(nram_k, norm_params, task_deal_batch * head_num_k, head_size, eps);
  quantify<float>(nram_k, nram_k, nram_kv_hp, nram_kv_lp, nram_kv_trans, nram_k_scale_hp,
                  nram_k_scale_lp, nram_v_scale_lp /*lp_scale temp*/, task_deal_batch, head_num_k,
                  head_size, group_num, group_size, quant_kv_hp, mixed_cache);
  if (mixed_cache) {
    __bang_int82int4_rn((int4x2_t *)nram_kv_lp, nram_kv_lp, task_deal_batch * k_hidden, 0, 0);
  }
  if (!quant_kv_hp) {
    floatTo<T>((T *)nram_k, nram_k, task_deal_batch * k_hidden);
  }

  // store q
  __memcpy_async(input_begin, nram_qk, head_size * dtype_size, NRAM2GDRAM, qkv_hidden * dtype_size,
                 task_deal_batch - 1, head_size * dtype_size, head_num_q - 1,
                 head_size * dtype_size, task_deal_batch - 1,
                 task_deal_batch * head_size * dtype_size, head_num_q - 1);
  // ===============================================================================================

  int8_t *scatter_k_src = quant_kv_hp ? nram_kv_hp : (int8_t *)nram_k;
  __scatter(key_cache_hp, scatter_k_src, (uint32_t *)nram_kv_cache_offsets_hp, cache_mask_hp,
            head_size * kv_size_hp, NRAM2GDRAM, head_size * kv_size_hp,
            head_num_k * task_deal_batch);
  if (mixed_cache) {
    __scatter(key_cache_lp, nram_kv_lp, (uint32_t *)nram_k_cache_offsets_lp, cache_mask_lp,
              head_size / 2 * sizeof(int8_t), NRAM2GDRAM, head_size / 2 * sizeof(int8_t),
              head_num_k * task_deal_batch);
    __scatter(key_scale_lp, nram_k_scale_lp, (uint32_t *)nram_kv_scale_offsets, cache_mask_lp,
              group_num * sizeof(float), NRAM2GDRAM, group_num * sizeof(float),
              head_num_k * task_deal_batch);
  }
  __sync_io_move_compute();
#endif
}

template <typename T>
__mlu_global__ void MLUFuseRope(T *input,
                                void *key_cache_hp,
                                void *value_cache_hp,
                                void *key_cache_lp,
                                void *value_cache_lp,
                                T *sin_table,
                                T *cos_table,
                                int *rope_offsets,
                                T *gamma,
                                T *beta,
                                float *key_scale_hp,
                                float *value_scale_hp,
                                float *key_scale_lp,
                                float *value_scale_lp,
                                int *cache_bs_id_hp,
                                int *cache_seq_offsets_hp,
                                int *cache_bs_id_lp,
                                int *cache_seq_offsets_lp,
                                int *slot_mapping_hp,
                                int *slot_mapping_lp,
                                int rotary_stride,
                                int batch,
                                int head_num_q,
                                int head_num_k,
                                int head_size,
                                int max_decode_len_hp,
                                int max_decode_len_lp,
                                int block_size_hp,
                                int block_size_lp,
                                int group_size,
                                int batch_cap,
                                int task_avg_batch,
                                float eps) {
  int task_begin_batch = taskId * task_avg_batch;
  int task_deal_batch = std::min(batch - task_begin_batch, task_avg_batch);
  if (task_deal_batch <= 0 || __is_mpu()) {
    return;
  }

  int task_loop = (task_deal_batch + batch_cap - 1) / batch_cap;
  int once_batch = (task_deal_batch + task_loop - 1) / task_loop;

  for (int i = 0; i < task_loop; i++) {
    int cur_batch = std::min(task_deal_batch - i * once_batch, once_batch);
    int batch_offset = task_begin_batch + once_batch * i;

    fuseRopeImpl<T>(input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp, sin_table,
                    cos_table, rope_offsets, gamma, beta, key_scale_hp, value_scale_hp,
                    key_scale_lp, value_scale_lp, cache_bs_id_hp, cache_seq_offsets_hp,
                    cache_bs_id_lp, cache_seq_offsets_lp, slot_mapping_hp, slot_mapping_lp,
                    rotary_stride, cur_batch, batch_offset, head_num_q, head_num_k, head_size,
                    max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size,
                    once_batch, eps);
  }
}
}  // namespace kernels

KernelStatus invokeFusedRope(cnrtQueue_t queue,
                             void *input,
                             void *key_cache_hp,
                             void *value_cache_hp,
                             void *key_cache_lp,
                             void *value_cache_lp,
                             const void *sin_table,
                             const void *cos_table,
                             const void *rope_offsets,
                             const void *gamma,
                             const void *beta,
                             const void *key_scale_hp,
                             const void *value_scale_hp,
                             void *key_scale_lp,
                             void *value_scale_lp,
                             const void *cache_bs_id_hp,
                             const void *cache_seq_offsets_hp,
                             const void *cache_bs_id_lp,
                             const void *cache_seq_offsets_lp,
                             const void *slot_mapping_hp,
                             const void *slot_mapping_lp,
                             int rotary_stride,
                             int batch_size,
                             int head_num_q,
                             int head_num_kv,
                             int head_size,
                             int max_decode_len_hp,
                             int max_decode_len_lp,
                             int block_size_hp,
                             int block_size_lp,
                             int group_size,
                             cnnlDataType_t dtype,
                             float eps) {
  if (is_arch300()) {
    std::cerr << "[invokeFusedRope]: kernel does not support MLU300 devices." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  uint32_t taskdimx = cluster_num * core_num;

  int task_avg_batch = (batch_size + taskdimx - 1) / taskdimx;
  int float_size = sizeof(float);
  int group_num = head_size / group_size;
  bool quant_kv_hp = key_scale_hp != nullptr && value_scale_hp != nullptr;
  bool mixed_cache = key_cache_lp != nullptr && value_cache_lp != nullptr;

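  // Estimate how many batches fit in NRAM at once: subtract the fixed-size buffers from the
  // 480 KiB budget, then divide the remainder by the per-batch footprint. The result is capped at
  // task_avg_batch; when fewer batches than that fit, it is also limited to task_max_batch.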
  int nram_available_bytes = 480 * 1024;
  int task_max_batch = 32;
  int mask_bytes = PAD_UP(task_max_batch * head_num_kv, 8) / 8 * (mixed_cache + 1);
  int nram_params_bytes = 2 * head_size * float_size;
  int nram_kv_hp_scale_bytes = 2 * (int)quant_kv_hp * head_num_kv * head_size * float_size;
  int nram_remain_bytes =
      nram_available_bytes - nram_params_bytes - nram_kv_hp_scale_bytes - mask_bytes;
  int nram_qk_bytes = (head_num_q + head_num_kv) * head_size * float_size * 2;
  int nram_v_bytes = head_num_kv * head_size * float_size * (mixed_cache + 1);
  int nram_table_bytes = 2 * head_size * float_size;
  int nram_kv_lp_scale_bytes = 2 * (int)mixed_cache * head_num_kv * group_num * float_size;
  int nram_kv_hp_bytes = (int)quant_kv_hp * head_num_kv * head_size;
  int nram_kv_lp_bytes = (int)mixed_cache * head_num_kv * head_size;
  int nram_cache_v_bytes = (int)mixed_cache * head_num_kv * head_size * 2;
  int nram_cache_offsets_hp = head_num_kv * sizeof(int);
  int nram_cache_offsets_lp = (int)mixed_cache * head_num_kv * 3 * sizeof(int);
  int batch_cap =
      nram_remain_bytes /
      (nram_qk_bytes + nram_v_bytes + nram_table_bytes + nram_kv_lp_scale_bytes + nram_kv_hp_bytes +
       nram_kv_lp_bytes + nram_cache_v_bytes + nram_cache_offsets_hp + nram_cache_offsets_lp);
  batch_cap = batch_cap < task_avg_batch ? std::min(task_max_batch, batch_cap) : task_avg_batch;

  cnrtDim3_t dim{taskdimx, 1, 1};
  if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUFuseRope<<<dim, cnrtFuncTypeBlock, queue>>>(
        (half *)input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp,
        (half *)sin_table, (half *)cos_table, (int *)rope_offsets, (half *)gamma, (half *)beta,
        (float *)key_scale_hp, (float *)value_scale_hp, (float *)key_scale_lp,
        (float *)value_scale_lp, (int *)cache_bs_id_hp, (int *)cache_seq_offsets_hp,
        (int *)cache_bs_id_lp, (int *)cache_seq_offsets_lp, (int *)slot_mapping_hp,
        (int *)slot_mapping_lp, rotary_stride, batch_size, head_num_q, head_num_kv, head_size,
        max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size, batch_cap,
        task_avg_batch, eps);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    kernels::MLUFuseRope<<<dim, cnrtFuncTypeBlock, queue>>>(
        (bf16 *)input, key_cache_hp, value_cache_hp, key_cache_lp, value_cache_lp,
        (bf16 *)sin_table, (bf16 *)cos_table, (int *)rope_offsets, (bf16 *)gamma, (bf16 *)beta,
        (float *)key_scale_hp, (float *)value_scale_hp, (float *)key_scale_lp,
        (float *)value_scale_lp, (int *)cache_bs_id_hp, (int *)cache_seq_offsets_hp,
        (int *)cache_bs_id_lp, (int *)cache_seq_offsets_lp, (int *)slot_mapping_hp,
        (int *)slot_mapping_lp, rotary_stride, batch_size, head_num_q, head_num_kv, head_size,
        max_decode_len_hp, max_decode_len_lp, block_size_hp, block_size_lp, group_size, batch_cap,
        task_avg_batch, eps);
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
}  // namespace tmo