Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/quantize.mlu
2026-02-04 17:39:32 +08:00

187 lines
8.2 KiB
Plaintext

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <iostream>
#include <type_traits>
#include "cnnl.h"
#include "cnrt.h"
#include "quantize.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
// Per-core NRAM scratch size in bytes (480 KiB). NOTE(review): usable NRAM
// is hardware dependent -- presumably sized for the target MLU core; confirm
// against the device spec before reusing on other hardware.
#define NRAM_BUFFER_SIZE (480 * 1024)
namespace kernels {
// Single per-core on-chip scratch buffer; partitioned at runtime by the
// kernels below (scale buffers first, then ping/temp data buffers).
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
// Symmetric per-row int8 quantization.
// For each of `core_deal_tokens` rows of length `hidden`:
//   scale[i] = max(|row_i|) / 127              (left in nram_scale, float)
//   q[i][j]  = round(row_i[j] * 127 / max_i)   (int8)
// The int8 result is written to nram_dst transposed back to
// [core_deal_tokens, hidden]; the caller copies nram_scale out separately.
//
// Buffers (all NRAM):
//   nram_dst        - int8 output; also reused as TSrc scratch for |x|.
//   nram_src        - input data, clobbered in place.
//   nram_scale_temp - TSrc scratch for the per-row maxima; for TSrc == float
//                     the caller aliases it onto nram_scale.
//   scale_origin    - float scratch for the inverse scale (127 / max).
//
// NOTE(review): assumes nram_src is laid out [hidden, core_deal_tokens]
// with the row (token) index fastest, so that __bang_maxpool with channel
// count core_deal_tokens and a window spanning the whole hidden dim yields
// one max per row -- confirm against the BANG maxpool layout docs.
template <typename TSrc>
__mlu_func__ void quantify(int8_t *nram_dst,
TSrc *nram_src,
TSrc *nram_scale_temp,
float *nram_scale,
float *scale_origin,
int core_deal_tokens,
int hidden) {
// |x| into nram_dst, reused here as TSrc scratch.
__bang_abs((TSrc *)nram_dst, nram_src, core_deal_tokens * hidden);
// Per-row max of |x|: one pooling window covering the whole hidden dim.
__bang_maxpool(nram_scale_temp, (TSrc *)nram_dst, core_deal_tokens, hidden, 1, hidden, 1, 1, 1);
if (std::is_same<half, TSrc>::value) {
// Promote half maxima to float; for TSrc == float the maxima are already
// in nram_scale (the caller aliases nram_scale_temp onto it).
__bang_half2float(nram_scale, (half *)nram_scale_temp, core_deal_tokens);
}
// scale = max / 127.
__bang_mul_scalar(nram_scale, nram_scale, 1 / 127.f, core_deal_tokens);
// inverse scale = 127 / max.
__bang_recip(scale_origin, nram_scale, core_deal_tokens);
if (std::is_same<half, TSrc>::value) {
// Narrow the inverse scale to half so the multiply below matches TSrc.
__bang_float2half_rn((half *)scale_origin, scale_origin, core_deal_tokens);
}
// x *= inv_scale, cycling the per-row inverse scales across the hidden dim
// (period core_deal_tokens matches the token-fastest layout).
__bang_cycle_mul(nram_src, nram_src, (TSrc *)scale_origin, core_deal_tokens * hidden,
core_deal_tokens);
// Round-to-nearest conversion to int8, quantization position 0.
if (std::is_same<half, TSrc>::value) {
__bang_half2int8_rn((int8_t *)nram_src, (half *)nram_src, core_deal_tokens * hidden, 0);
} else if (std::is_same<float, TSrc>::value) {
__bang_float2int8_rn((int8_t *)nram_src, (float *)nram_src, core_deal_tokens * hidden, 0);
}
// Transpose the int8 data back to [core_deal_tokens, hidden].
__bang_transpose(nram_dst, (int8_t *)nram_src, hidden, core_deal_tokens);
}
// Per-head int8 quantization kernel (one launch covers all bs*seq_len tokens).
// Each core takes a contiguous chunk of tokens, gathers every head of those
// tokens into NRAM, quantizes each (token, head) row of head_size values with
// its own symmetric scale, and scatters the int8 data plus the float scales
// back to GDRAM.
// NOTE(review): there is no tiling loop -- the host inflates taskDim so that
// every core's chunk fits within NRAM_BUFFER_SIZE; see invokeMluQuantizePerHead.
template <typename TDst, typename TSrc, typename TScale>
__mlu_global__ void MLUQuantizePerHead(
TDst *dst, // [bs, seq, head_num, head_size], may not be contiguous
TScale *scale, // [bs, seq, head_num]: one scale per (token, head); must be contiguous
const TSrc *src, // [bs, seq, head_num, head_size], may not be contiguous
int bs,
int seq_len,
int head_num,
int head_size,
int src_bs_stride,
int src_seq_stride,
int src_head_stride,
int dst_bs_stride,
int dst_seq_stride,
int dst_head_stride) {
// Flatten (bs, seq) into one token axis and split it evenly across cores.
int total_bs = bs * seq_len;
int hidden = head_num * head_size;
int core_average_tokens = (total_bs + taskDim - 1) / taskDim;
int core_begin_tokens = core_average_tokens * taskId;
int core_deal_tokens = std::min(total_bs - core_begin_tokens, core_average_tokens);
// The MPU (cluster memory) core takes no part in this kernel.
if (__is_mpu()) {
return;
}
// Trailing cores may receive an empty chunk when total_bs < taskDim * avg.
if (core_deal_tokens <= 0) {
return;
}
// NRAM layout: [nram_scale | scale_origin | nram_ping | nram_temp].
// nram_scale_temp is offset by sizeof(TScale)-sizeof(TSrc) bytes per element:
// for TSrc == float that is offset 0 (it aliases nram_scale); for half the
// maxima sit further in so __bang_half2float can widen them into nram_scale
// in place -- NOTE(review): relies on the intrinsic tolerating this partial
// overlap; confirm against the BANG conversion docs.
TScale *nram_scale = (TScale *)nram_buffer;
TScale *scale_origin = nram_scale + core_deal_tokens * head_num;
TSrc *nram_scale_temp =
(TSrc *)(nram_buffer + core_deal_tokens * head_num * (sizeof(TScale) - sizeof(TSrc)));
TSrc *nram_ping = (TSrc *)(scale_origin + core_deal_tokens * head_num);
TSrc *nram_temp = nram_ping + core_deal_tokens * hidden;
const TSrc *src_begin = src + core_begin_tokens * src_seq_stride;
TDst *dst_begin = dst + core_begin_tokens * dst_seq_stride;
TScale *scale_begin = scale + core_begin_tokens * head_num;
// load: strided gather of head_size-element segments -- head_num heads per
// token (GDRAM stride src_head_stride), core_deal_tokens tokens (GDRAM
// stride src_seq_stride) -- packed dense as [token, head, head_size] in NRAM.
__memcpy(nram_ping, src_begin, head_size * sizeof(TSrc), GDRAM2NRAM, head_size * sizeof(TSrc),
head_num - 1, hidden * sizeof(TSrc), core_deal_tokens - 1,
src_head_stride * sizeof(TSrc), head_num - 1, src_seq_stride * sizeof(TSrc),
core_deal_tokens - 1);
// [tokens*head_num, head_size] -> [head_size, tokens*head_num] so the
// (token, head) index becomes the fastest dim, as quantify expects.
__bang_transpose(nram_temp, nram_ping, core_deal_tokens * head_num, head_size);
// One quantization row (and one scale) per (token, head) pair.
quantify((TDst *)nram_ping, nram_temp, nram_scale_temp, nram_scale, scale_origin,
core_deal_tokens * head_num, head_size);
// store scale: scales are dense in GDRAM, plain linear copy.
__memcpy(scale_begin, nram_scale, core_deal_tokens * head_num * sizeof(TScale), NRAM2GDRAM);
// store: strided scatter mirroring the load, dense on the NRAM side.
__memcpy(dst_begin, nram_ping, head_size * sizeof(TDst), NRAM2GDRAM,
dst_head_stride * sizeof(TDst), head_num - 1, dst_seq_stride * sizeof(TDst),
core_deal_tokens - 1, head_size * sizeof(TDst), head_num - 1, hidden * sizeof(TDst),
core_deal_tokens - 1);
}
} // namespace kernels
// Host-side launcher for MLUQuantizePerHead.
//
// Quantizes src ([bs, seq, head_num, head_size], half or float) to int8 in
// dst with one symmetric scale per (token, head); the float scales are
// written densely to `scale` as [bs, seq, head_num].
//
// Layout constraints (validated below): both tensors must be dense over the
// batch dim (bs_stride == seq_len * seq_stride) and dense per head
// (head_stride == head_size); the seq stride may still carry padding.
// NOTE(review): dst_dtype/scale_dtype are accepted but not validated here --
// dst is assumed int8 and scale float; confirm callers guarantee this.
//
// Returns KERNEL_STATUS_SUCCESS after enqueuing the kernel, or
// KERNEL_STATUS_FAILED for unsupported layouts, oversized heads, or dtypes.
KernelStatus invokeMluQuantizePerHead(cnrtQueue_t queue,
                                      void *dst,
                                      void *scale,
                                      const void *src,
                                      cnnlDataType_t dst_dtype,
                                      cnnlDataType_t scale_dtype,
                                      cnnlDataType_t src_dtype,
                                      int bs,
                                      int seq_len,
                                      int head_num,
                                      int head_size,
                                      int dst_bs_stride,
                                      int dst_seq_stride,
                                      int dst_head_stride,
                                      int src_bs_stride,
                                      int src_seq_stride,
                                      int src_head_stride) {
  // bs must be continuous; for pack mode, bs = 1 and seq_len equals the sum
  // of all per-batch sequence lengths.
  // Bug fix: these messages previously said "[invokeMluQuantizePerToken]",
  // pointing debuggers at the wrong entry point.
  if (dst_bs_stride != seq_len * dst_seq_stride) {
    std::cerr
        << "[invokeMluQuantizePerHead]: dst_bs_stride must equal to seq_len * dst_seq_stride."
        << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (dst_head_stride != head_size) {
    std::cerr << "[invokeMluQuantizePerHead]: dst_head_stride must equal to head_size."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (src_bs_stride != seq_len * src_seq_stride) {
    std::cerr
        << "[invokeMluQuantizePerHead]: src_bs_stride must equal to seq_len * src_seq_stride."
        << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (src_head_stride != head_size) {
    std::cerr << "[invokeMluQuantizePerHead]: src_head_stride must equal to head_size."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  int scale_buffer_size = 64 * 1024;  // NRAM bytes reserved for scale buffers
  int dtype_size = (src_dtype == CNNL_DTYPE_HALF || src_dtype == CNNL_DTYPE_BFLOAT16) ? 2 : 4;
  // Max tokens a core can hold: the kernel needs two data buffers
  // (ping + transpose temp) of tokens * head_num * head_size elements each.
  int bs_once = (NRAM_BUFFER_SIZE - scale_buffer_size) / (2 * head_num * head_size * dtype_size);
  // Scale region holds two float buffers of tokens * head_num entries each,
  // so the token bound must divide by head_num as well.
  // Bug fix: the head_num factor was previously omitted, which could
  // overflow the 64 KiB scale reservation for very small head_size.
  int bs_once_ = scale_buffer_size / 2 / sizeof(float) / head_num;
  bs_once = std::min(bs_once, bs_once_);
  // Bug fix: guard the ceil-divide below against bs_once == 0 (previously a
  // division by zero when one token's data exceeded the NRAM budget).
  if (bs_once <= 0) {
    std::cerr << "[invokeMluQuantizePerHead]: head_num * head_size too large to fit in NRAM."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  uint32_t task_dim = std::min(bs * seq_len, cluster_num * core_num);
  // The kernel has no tiling loop, so inflate taskDim until every core's
  // token chunk fits within bs_once.
  task_dim = std::max((uint32_t)(bs * seq_len + bs_once - 1) / bs_once, task_dim);
  cnrtDim3_t dim{task_dim, 1, 1};
  if (src_dtype == CNNL_DTYPE_FLOAT) {
    kernels::MLUQuantizePerHead<<<dim, cnrtFuncTypeBlock, queue>>>(
        (int8_t *)dst, (float *)scale, (const float *)src, bs, seq_len, head_num, head_size,
        src_bs_stride, src_seq_stride, src_head_stride, dst_bs_stride, dst_seq_stride,
        dst_head_stride);
  } else if (src_dtype == CNNL_DTYPE_HALF) {
    kernels::MLUQuantizePerHead<<<dim, cnrtFuncTypeBlock, queue>>>(
        (int8_t *)dst, (float *)scale, (const half *)src, bs, seq_len, head_num, head_size,
        src_bs_stride, src_seq_stride, src_head_stride, dst_bs_stride, dst_seq_stride,
        dst_head_stride);
  } else if (src_dtype == CNNL_DTYPE_BFLOAT16) {
    std::cerr << __func__ << "," << __LINE__ << " :currently does not support bfloat16."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  } else {
    // Bug fix: any other dtype previously fell through every branch and
    // returned SUCCESS without launching the kernel (silent no-op).
    std::cerr << __func__ << "," << __LINE__ << " :unsupported src dtype." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo