/************************************************************************* * Copyright (C) [2023-2024] by Cambricon, Inc. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ #include #include #include #include "cnnl.h" #include "cnrt.h" #include "quantize.mluh" // clang-format off #include // clang-format on namespace tmo { #define NRAM_BUFFER_SIZE (480 * 1024) namespace kernels { __nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE]; template __mlu_func__ void quantify(int8_t *nram_dst, TSrc *nram_src, TSrc *nram_scale_temp, float *nram_scale, float *scale_origin, int core_deal_tokens, int hidden) { __bang_abs((TSrc *)nram_dst, nram_src, core_deal_tokens * hidden); __bang_maxpool(nram_scale_temp, (TSrc *)nram_dst, core_deal_tokens, hidden, 1, hidden, 1, 1, 1); if (std::is_same::value) { __bang_half2float(nram_scale, (half *)nram_scale_temp, core_deal_tokens); } __bang_mul_scalar(nram_scale, nram_scale, 1 / 127.f, core_deal_tokens); __bang_recip(scale_origin, nram_scale, core_deal_tokens); if (std::is_same::value) { __bang_float2half_rn((half *)scale_origin, scale_origin, core_deal_tokens); } __bang_cycle_mul(nram_src, nram_src, (TSrc *)scale_origin, core_deal_tokens * hidden, core_deal_tokens); if (std::is_same::value) { __bang_half2int8_rn((int8_t *)nram_src, (half *)nram_src, core_deal_tokens * hidden, 0); } else if (std::is_same::value) { __bang_float2int8_rn((int8_t *)nram_src, (float *)nram_src, core_deal_tokens * hidden, 0); } __bang_transpose(nram_dst, (int8_t *)nram_src, hidden, core_deal_tokens); } template __mlu_global__ void MLUQuantizePerHead( TDst *dst, // [bs, seq, head_num, head_size], may not be continuous TScale *scale, // [bs, seq], must becontinuous const TSrc *src, // [bs, seq, head_num, head_size], may not be continuous int bs, int seq_len, int head_num, int head_size, int src_bs_stride, int src_seq_stride, int src_head_stride, int dst_bs_stride, int dst_seq_stride, int dst_head_stride) { int total_bs = bs * seq_len; int hidden = head_num * head_size; int core_average_tokens = (total_bs + taskDim - 1) / taskDim; int core_begin_tokens = core_average_tokens * taskId; int core_deal_tokens = std::min(total_bs - core_begin_tokens, core_average_tokens); if (__is_mpu()) { return; } if (core_deal_tokens <= 0) { return; } TScale *nram_scale = (TScale *)nram_buffer; TScale *scale_origin = nram_scale + core_deal_tokens * head_num; TSrc *nram_scale_temp = (TSrc *)(nram_buffer + core_deal_tokens * head_num * (sizeof(TScale) - sizeof(TSrc))); TSrc *nram_ping = (TSrc *)(scale_origin + core_deal_tokens * head_num); TSrc *nram_temp = nram_ping + core_deal_tokens * hidden; const TSrc *src_begin = src + core_begin_tokens * src_seq_stride; TDst *dst_begin = dst + core_begin_tokens * dst_seq_stride; TScale *scale_begin = scale + core_begin_tokens * head_num; // load __memcpy(nram_ping, src_begin, head_size * sizeof(TSrc), GDRAM2NRAM, head_size * sizeof(TSrc), head_num - 1, hidden * sizeof(TSrc), core_deal_tokens - 1, src_head_stride * sizeof(TSrc), head_num - 1, src_seq_stride * sizeof(TSrc), core_deal_tokens - 1); __bang_transpose(nram_temp, nram_ping, core_deal_tokens * head_num, head_size); quantify((TDst *)nram_ping, nram_temp, nram_scale_temp, nram_scale, scale_origin, core_deal_tokens * head_num, head_size); // store scale __memcpy(scale_begin, nram_scale, core_deal_tokens * head_num * sizeof(TScale), NRAM2GDRAM); // store __memcpy(dst_begin, nram_ping, head_size * sizeof(TDst), NRAM2GDRAM, dst_head_stride * sizeof(TDst), head_num - 1, dst_seq_stride * sizeof(TDst), core_deal_tokens - 1, head_size * sizeof(TDst), head_num - 1, hidden * sizeof(TDst), core_deal_tokens - 1); } } // namespace kernels KernelStatus invokeMluQuantizePerHead(cnrtQueue_t queue, void *dst, void *scale, const void *src, cnnlDataType_t dst_dtype, cnnlDataType_t scale_dtype, cnnlDataType_t src_dtype, int bs, int seq_len, int head_num, int head_size, int dst_bs_stride, int dst_seq_stride, int dst_head_stride, int src_bs_stride, int src_seq_stride, int src_head_stride) { // bs must be continuous, for pack mode, bs = 1, seq_len equals to sum of all bs seq_len. if (dst_bs_stride != seq_len * dst_seq_stride) { std::cerr << "[invokeMluQuantizePerToken]: dst_bs_stride must equal to seq_len * dst_seq_stride." << std::endl; return KernelStatus::KERNEL_STATUS_FAILED; } if (dst_head_stride != head_size) { std::cerr << "[invokeMluQuantizePerToken]: dst_head_stride must equal to head_size." << std::endl; return KernelStatus::KERNEL_STATUS_FAILED; } if (src_bs_stride != seq_len * src_seq_stride) { std::cerr << "[invokeMluQuantizePerToken]: src_bs_stride must equal to seq_len * src_seq_stride." << std::endl; return KernelStatus::KERNEL_STATUS_FAILED; } if (src_head_stride != head_size) { std::cerr << "[invokeMluQuantizePerToken]: src_head_stride must equal to head_size." << std::endl; return KernelStatus::KERNEL_STATUS_FAILED; } CNdev dev; cnCtxGetDevice(&dev); int cluster_num; CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev)); int core_num; CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev)); int scale_buffer_size = 64 * 1024; // scale on nram int dtype_size = (src_dtype == CNNL_DTYPE_HALF || src_dtype == CNNL_DTYPE_BFLOAT16) ? 2 : 4; int bs_once = (NRAM_BUFFER_SIZE - scale_buffer_size) / (2 * head_num * head_size * dtype_size); int bs_once_ = scale_buffer_size / 2 / sizeof(float); bs_once = std::min(bs_once, bs_once_); uint32_t task_dim = std::min(bs * seq_len, cluster_num * core_num); task_dim = std::max((uint32_t)(bs * seq_len + bs_once - 1) / bs_once, task_dim); cnrtDim3_t dim{task_dim, 1, 1}; if (src_dtype == CNNL_DTYPE_FLOAT) { kernels::MLUQuantizePerHead<<>>( (int8_t *)dst, (float *)scale, (const float *)src, bs, seq_len, head_num, head_size, src_bs_stride, src_seq_stride, src_head_stride, dst_bs_stride, dst_seq_stride, dst_head_stride); } else if (src_dtype == CNNL_DTYPE_HALF) { kernels::MLUQuantizePerHead<<>>( (int8_t *)dst, (float *)scale, (const half *)src, bs, seq_len, head_num, head_size, src_bs_stride, src_seq_stride, src_head_stride, dst_bs_stride, dst_seq_stride, dst_head_stride); } else if (src_dtype == CNNL_DTYPE_BFLOAT16) { std::cerr << __func__ << "," << __LINE__ << " :currently does not support bfloat16." << std::endl; return KernelStatus::KERNEL_STATUS_FAILED; } return KernelStatus::KERNEL_STATUS_SUCCESS; } } // namespace tmo