/*************************************************************************
|
|
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
*************************************************************************/
|
|
#include <algorithm>
|
|
#include <iostream>
|
|
#include <type_traits>
|
|
#include "cnnl.h"
|
|
#include "cnrt.h"
|
|
#include "quantize.mluh"
|
|
// clang-format off
|
|
#include <mlu.h>
|
|
// clang-format on
|
|
|
|
namespace tmo {
|
|
#define NRAM_BUFFER_SIZE (480 * 1024)
|
|
namespace kernels {
|
|
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
|
|
|
|
// Symmetric per-token int8 quantization of data already resident in NRAM.
//
// Expected input layout (arranged by the caller): nram_src holds
// [hidden][core_deal_tokens] with the token dimension contiguous, i.e. each of
// the `hidden` rows contains one value per token.
//
// Outputs:
//   nram_dst    - quantized int8 data, transposed back to [core_deal_tokens][hidden]
//   nram_scale  - per-token dequantization scale, scale[t] = absmax[t] / 127 (float)
//   scale_origin- per-token quantization factor, 127 / absmax[t], used to scale src
//
// nram_scale_temp is scratch for the maxpool result in TSrc precision.
// NOTE(review): for TSrc == float the caller aliases nram_scale_temp onto
// nram_scale (zero byte offset), so the maxpool output already sits where the
// float pipeline below reads it; for TSrc == half it sits in the upper half of
// the scale buffer and is widened by __bang_half2float. Verify against caller.
template <typename TSrc>
__mlu_func__ void quantify(int8_t *nram_dst,
                           TSrc *nram_src,
                           TSrc *nram_scale_temp,
                           float *nram_scale,
                           float *scale_origin,
                           int core_deal_tokens,
                           int hidden) {
  // |src| into nram_dst (reused as TSrc scratch here).
  __bang_abs((TSrc *)nram_dst, nram_src, core_deal_tokens * hidden);
  // Per-token absolute maximum: maxpool with channel = core_deal_tokens,
  // spatial extent = hidden x 1 and a kernel covering the whole extent, so each
  // token (channel) is reduced across the hidden dimension.
  __bang_maxpool(nram_scale_temp, (TSrc *)nram_dst, core_deal_tokens, hidden, 1, hidden, 1, 1, 1);
  if (std::is_same<half, TSrc>::value) {
    // Widen the half absmax values to float before computing scales.
    __bang_half2float(nram_scale, (half *)nram_scale_temp, core_deal_tokens);
  }

  // scale = absmax / 127 (dequantization scale, returned to the caller).
  __bang_mul_scalar(nram_scale, nram_scale, 1 / 127.f, core_deal_tokens);
  // scale_origin = 127 / absmax (multiplicative quantization factor).
  __bang_recip(scale_origin, nram_scale, core_deal_tokens);
  if (std::is_same<half, TSrc>::value) {
    // Narrow the quantization factor back to half so it matches TSrc below.
    __bang_float2half_rn((half *)scale_origin, scale_origin, core_deal_tokens);
  }
  // Multiply the [hidden][tokens] data by the token-length factor vector,
  // repeated cyclically for each of the `hidden` rows.
  __bang_cycle_mul(nram_src, nram_src, (TSrc *)scale_origin, core_deal_tokens * hidden,
                   core_deal_tokens);

  // Round-to-nearest conversion to int8 (in place over nram_src).
  if (std::is_same<half, TSrc>::value) {
    __bang_half2int8_rn((int8_t *)nram_src, (half *)nram_src, core_deal_tokens * hidden, 0);
  } else if (std::is_same<float, TSrc>::value) {
    __bang_float2int8_rn((int8_t *)nram_src, (float *)nram_src, core_deal_tokens * hidden, 0);
  }
  // Transpose [hidden][tokens] -> [tokens][hidden] for the store path.
  __bang_transpose(nram_dst, (int8_t *)nram_src, hidden, core_deal_tokens);
}
|
|
|
|
// Quantizes src to int8 with one scale per (token, head), writing the
// quantized data to dst and the per-head scales to scale.
//
// Tokens (bs * seq_len) are split evenly across taskDim cores; each core
// gathers its slice into NRAM with a strided copy, quantizes it via
// quantify(), and stores the result with the destination strides.
// The host launcher sizes taskDim so a core's slice fits in NRAM.
template <typename TDst, typename TSrc, typename TScale>
__mlu_global__ void MLUQuantizePerHead(
    TDst *dst, // [bs, seq, head_num, head_size], may not be continuous
    TScale *scale, // [bs, seq], must be continuous
    const TSrc *src, // [bs, seq, head_num, head_size], may not be continuous
    int bs,
    int seq_len,
    int head_num,
    int head_size,
    int src_bs_stride,
    int src_seq_stride,
    int src_head_stride,
    int dst_bs_stride,
    int dst_seq_stride,
    int dst_head_stride) {
  int total_bs = bs * seq_len;              // total number of tokens
  int hidden = head_num * head_size;        // elements per token
  // Ceil-divide tokens across cores; trailing cores may get fewer (or none).
  int core_average_tokens = (total_bs + taskDim - 1) / taskDim;
  int core_begin_tokens = core_average_tokens * taskId;
  int core_deal_tokens = std::min(total_bs - core_begin_tokens, core_average_tokens);
  // MPU (memory) cores take no part in this kernel.
  if (__is_mpu()) {
    return;
  }
  if (core_deal_tokens <= 0) {
    return;
  }

  // NRAM layout (see quantify() for how these buffers interact):
  //   [nram_scale | scale_origin | nram_ping | nram_temp]
  // nram_scale_temp is deliberately aliased inside the scale region: its
  // offset is tokens*head_num*(sizeof(TScale)-sizeof(TSrc)), which is 0 when
  // TSrc == TScale (float) and the upper half when TSrc is half — so the
  // TSrc-precision maxpool output can be widened in place without extra space.
  TScale *nram_scale = (TScale *)nram_buffer;
  TScale *scale_origin = nram_scale + core_deal_tokens * head_num;
  TSrc *nram_scale_temp =
      (TSrc *)(nram_buffer + core_deal_tokens * head_num * (sizeof(TScale) - sizeof(TSrc)));
  TSrc *nram_ping = (TSrc *)(scale_origin + core_deal_tokens * head_num);
  TSrc *nram_temp = nram_ping + core_deal_tokens * hidden;
  // GDRAM base pointers for this core's token slice.
  const TSrc *src_begin = src + core_begin_tokens * src_seq_stride;
  TDst *dst_begin = dst + core_begin_tokens * dst_seq_stride;
  TScale *scale_begin = scale + core_begin_tokens * head_num;

  // load: strided gather of [core_deal_tokens][head_num][head_size] from
  // GDRAM (strides src_seq_stride/src_head_stride) into a dense NRAM block.
  __memcpy(nram_ping, src_begin, head_size * sizeof(TSrc), GDRAM2NRAM, head_size * sizeof(TSrc),
           head_num - 1, hidden * sizeof(TSrc), core_deal_tokens - 1,
           src_head_stride * sizeof(TSrc), head_num - 1, src_seq_stride * sizeof(TSrc),
           core_deal_tokens - 1);

  // [tokens*head_num][head_size] -> [head_size][tokens*head_num]: quantify()
  // wants the (token, head) dimension contiguous, one scale per (token, head).
  __bang_transpose(nram_temp, nram_ping, core_deal_tokens * head_num, head_size);
  quantify((TDst *)nram_ping, nram_temp, nram_scale_temp, nram_scale, scale_origin,
           core_deal_tokens * head_num, head_size);
  // store scale (scale output is continuous, one value per token per head)
  __memcpy(scale_begin, nram_scale, core_deal_tokens * head_num * sizeof(TScale), NRAM2GDRAM);
  // store: scatter the dense quantized block back out with the dst strides.
  __memcpy(dst_begin, nram_ping, head_size * sizeof(TDst), NRAM2GDRAM,
           dst_head_stride * sizeof(TDst), head_num - 1, dst_seq_stride * sizeof(TDst),
           core_deal_tokens - 1, head_size * sizeof(TDst), head_num - 1, hidden * sizeof(TDst),
           core_deal_tokens - 1);
}
|
|
} // namespace kernels
|
|
|
|
// Validates layout constraints, sizes the task grid so each core's token
// slice fits in NRAM, and launches MLUQuantizePerHead on `queue`.
//
// dst/src are [bs, seq, head_num, head_size] (possibly strided); scale is a
// continuous [bs, seq, head_num] float buffer. Only float and half src are
// supported; dst is int8 and scale is float regardless of dst_dtype /
// scale_dtype (those parameters are currently not validated).
//
// Returns KERNEL_STATUS_FAILED on any constraint violation or unsupported
// src_dtype, KERNEL_STATUS_SUCCESS after a successful (asynchronous) launch.
KernelStatus invokeMluQuantizePerHead(cnrtQueue_t queue,
                                      void *dst,
                                      void *scale,
                                      const void *src,
                                      cnnlDataType_t dst_dtype,
                                      cnnlDataType_t scale_dtype,
                                      cnnlDataType_t src_dtype,
                                      int bs,
                                      int seq_len,
                                      int head_num,
                                      int head_size,
                                      int dst_bs_stride,
                                      int dst_seq_stride,
                                      int dst_head_stride,
                                      int src_bs_stride,
                                      int src_seq_stride,
                                      int src_head_stride) {
  // bs must be continuous, for pack mode, bs = 1, seq_len equals to sum of all bs seq_len.
  if (dst_bs_stride != seq_len * dst_seq_stride) {
    std::cerr
        << "[invokeMluQuantizePerHead]: dst_bs_stride must equal to seq_len * dst_seq_stride."
        << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (dst_head_stride != head_size) {
    std::cerr << "[invokeMluQuantizePerHead]: dst_head_stride must equal to head_size."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (src_bs_stride != seq_len * src_seq_stride) {
    std::cerr
        << "[invokeMluQuantizePerHead]: src_bs_stride must equal to seq_len * src_seq_stride."
        << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (src_head_stride != head_size) {
    std::cerr << "[invokeMluQuantizePerHead]: src_head_stride must equal to head_size."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));

  int scale_buffer_size = 64 * 1024; // NRAM bytes reserved for scale buffers
  int dtype_size = (src_dtype == CNNL_DTYPE_HALF || src_dtype == CNNL_DTYPE_BFLOAT16) ? 2 : 4;
  // Tokens one core can hold: the remaining NRAM is split between the ping
  // and temp buffers (factor 2), each holding head_num * head_size elements
  // per token.
  int bs_once = (NRAM_BUFFER_SIZE - scale_buffer_size) / (2 * head_num * head_size * dtype_size);
  // The scale region holds two float vectors (scale and its reciprocal).
  int bs_once_ = scale_buffer_size / 2 / sizeof(float);
  bs_once = std::min(bs_once, bs_once_);
  // Use at most one core per token, but at least enough cores that no core
  // is assigned more than bs_once tokens (NRAM capacity).
  uint32_t task_dim = std::min(bs * seq_len, cluster_num * core_num);
  task_dim = std::max((uint32_t)(bs * seq_len + bs_once - 1) / bs_once, task_dim);
  cnrtDim3_t dim{task_dim, 1, 1};
  if (src_dtype == CNNL_DTYPE_FLOAT) {
    kernels::MLUQuantizePerHead<<<dim, cnrtFuncTypeBlock, queue>>>(
        (int8_t *)dst, (float *)scale, (const float *)src, bs, seq_len, head_num, head_size,
        src_bs_stride, src_seq_stride, src_head_stride, dst_bs_stride, dst_seq_stride,
        dst_head_stride);
  } else if (src_dtype == CNNL_DTYPE_HALF) {
    kernels::MLUQuantizePerHead<<<dim, cnrtFuncTypeBlock, queue>>>(
        (int8_t *)dst, (float *)scale, (const half *)src, bs, seq_len, head_num, head_size,
        src_bs_stride, src_seq_stride, src_head_stride, dst_bs_stride, dst_seq_stride,
        dst_head_stride);
  } else if (src_dtype == CNNL_DTYPE_BFLOAT16) {
    std::cerr << __func__ << "," << __LINE__ << " :currently does not support bfloat16."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  } else {
    // Previously any other dtype fell through silently and reported success
    // without launching a kernel; reject it explicitly instead.
    std::cerr << __func__ << "," << __LINE__ << " :unsupported src_dtype, expect float or half."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
|
} // namespace tmo
|