761 lines
36 KiB
Plaintext
761 lines
36 KiB
Plaintext
/*************************************************************************
|
|
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
*************************************************************************/
|
|
#include <algorithm>
|
|
#include <iostream>
|
|
#include "cnrt.h"
|
|
#include "combine_result.mluh"
|
|
// clang-format off
|
|
#include <bang_device_functions_extra.h>
|
|
#include <mlu.h>
|
|
// clang-format on
|
|
|
|
#if __BANG_ARCH__ >= 592
|
|
#include <bang_fusor.h>
|
|
template <typename SrcT>
|
|
using bang_fusor = bang::experimental::fusor<SrcT>;
|
|
#endif
|
|
|
|
namespace tmo {
namespace kernels {

// Keep 32KB of NRAM free for the toolchain (stack/spill space); everything
// else is treated as one large scratch area that the kernel partitions
// manually into sub-buffers.
#define NRAM_REMAIN_SIZE (32 * 1024)
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)

// Per-core on-chip scratch buffer; carved up in MLUCombineMoeResultKernel.
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
|
|
|
|
template <typename T>
|
|
__mlu_func__ void swap(T *&ping, T *&pong) {
|
|
T *temp = ping;
|
|
ping = pong;
|
|
pong = temp;
|
|
}
|
|
|
|
// Asynchronous indexed gather from GDRAM into NRAM on io-queue 0.
// `offset_type` (u32/u64) selects the width of the byte offsets in
// `nram_offset`. NOTE: this macro captures `dst`, `src_gdram`, `nram_offset`,
// `transfer_size` and `token_count` from the expansion site; the destination
// stride equals `transfer_size`, i.e. gathered rows are packed contiguously.
#define GATHER_ASYNC_IO0(offset_type)                                                    \
  __asm__ __volatile__(                                                                  \
      "gather.vector.async.nram.gdram.nram." #offset_type                                \
      ".io0 [%[dst]], [%[src]], [%[offset]], "                                           \
      "%[transfer_size], %[transfer_num], %[stride];\n\t" ::[dst] "r"(dst),              \
      [src] "r"(src_gdram), [offset] "r"(nram_offset), [transfer_size] "r"(transfer_size), \
      [transfer_num] "r"(token_count), [stride] "r"(transfer_size))
|
|
|
|
// Fused scalar multiply + downcast: dst(dst_dtype) = (f32)nram_input_buffer *
// expert_coeff, converted with round-to-nearest ("crn"). Captures `dst`,
// `nram_input_buffer`, `expert_coeff` and `size` from the expansion site.
#define FUSE_MUL_CVT(dst_dtype)                                                          \
  __asm__ __volatile__("mult.scalar.nram.crn." #dst_dtype                                \
                       ".f32 [%[dst]], [%[src0]], %[src1],"                              \
                       " %[size];\n\t" ::[dst] "r"(dst),                                 \
                       [src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size));
|
|
|
|
// Fused multiply-accumulate + downcast: dst(dst_dtype) =
// (f32)nram_input_buffer * expert_coeff + dst, round-to-nearest. Captures
// `dst`, `nram_input_buffer`, `expert_coeff` and `size` from the expansion
// site; `dst` is both accumulator input and converted output.
#define FUSE_MULADD_CVT(dst_dtype)                                                       \
  __asm__ __volatile__("muladd.nram.crn." #dst_dtype                                     \
                       ".f32 [%[dst]], [%[src0]], %[src1], [%[dst]],"                    \
                       " %[size], %[size];\n\t" ::[dst] "r"(dst),                        \
                       [src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size));
|
|
|
|
template <typename T>
|
|
__mlu_func__ void toFloat(float *dst, T *src, int count) {
|
|
if (std::is_same<T, half>::value) {
|
|
__bang_half2float(dst, (half *)src, count);
|
|
} else if (std::is_same<T, bfloat16_t>::value) {
|
|
__bang_bfloat162float(dst, (bfloat16_t *)src, count);
|
|
} else if (std::is_same<T, float>::value) {
|
|
__bang_add_scalar((float *)dst, (float *)src, (float)0, count);
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
__mlu_func__ void floatTo(T *dst, float *src, int count) {
|
|
if (std::is_same<T, half>::value) {
|
|
__bang_float2half_rn((half *)dst, src, count);
|
|
} else if (std::is_same<T, bfloat16_t>::value) {
|
|
__bang_float2bfloat16_rn((bfloat16_t *)dst, src, count);
|
|
} else if (std::is_same<T, float>::value) {
|
|
__bang_add_scalar((float *)dst, (float *)src, (float)0, count);
|
|
}
|
|
}
|
|
|
|
// Asynchronous strided 2D load GDRAM -> NRAM: copies (seg_num + 1) segments
// of `size` bytes, advancing dst/src by dststride/srcstride bytes per segment.
// On arch > 500 it is issued on io-queue 0 via inline asm; otherwise it falls
// back to __memcpy_async with the same stride semantics.
__mlu_func__ void loadAsync2d(void *dst,
                              void *src,
                              int size,
                              int dststride,
                              int srcstride,
                              int seg_num) {
#if __BANG_ARCH__ > 500
  __asm__ __volatile__(
      "ld.async.stride.nram.gdram.io0 [%[dst]], [%[src]],"
      " %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst),
      [src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride),
      [segnum] "r"(seg_num));
#else
  __memcpy_async(dst, src, size, GDRAM2NRAM, dststride, srcstride, seg_num);
#endif
}
|
|
|
|
// Asynchronous strided 2D store NRAM -> GDRAM: mirror of loadAsync2d, issued
// on io-queue 1 (arch > 500) so loads and stores can overlap; otherwise falls
// back to __memcpy_async with the same stride semantics.
__mlu_func__ void storeAsync2d(void *dst,
                               void *src,
                               int size,
                               int dststride,
                               int srcstride,
                               int seg_num) {
#if __BANG_ARCH__ > 500
  __asm__ __volatile__(
      "st.async.stride.gdram.nram.io1 [%[dst]], [%[src]],"
      " %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst),
      [src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride),
      [segnum] "r"(seg_num));
#else
  __memcpy_async(dst, src, size, NRAM2GDRAM, dststride, srcstride, seg_num);
#endif
}
|
|
|
|
// Gathers `token_count` rows of `transfer_size` bytes each from `src_gdram`
// into contiguous NRAM at `dst`, using the per-row byte offsets stored in
// `nram_offset` (uint32 or uint64, selected by T_IDX). No-op when there is
// nothing to gather or the source is null (e.g. bias == nullptr). On arch
// <= 500 this degrades to one async memcpy per row.
template <typename T_IDX>
__mlu_func__ void gatherTokensAsync(void *dst,
                                    void *src_gdram,
                                    T_IDX *nram_offset,
                                    int transfer_size,
                                    int token_count) {
  if (token_count <= 0 || src_gdram == nullptr) return;
#if __BANG_ARCH__ > 500
  // The asm macro reads dst/src_gdram/nram_offset/transfer_size/token_count
  // directly from this scope — do not rename these locals.
  if (std::is_same<T_IDX, uint32_t>::value) {
    GATHER_ASYNC_IO0(u32);
  } else {
    GATHER_ASYNC_IO0(u64);
  }
#else
  for (int k = 0; k < token_count; k++) {
    __memcpy_async((int8_t *)dst + k * transfer_size,
                   (int8_t *)src_gdram + __load_nram(nram_offset + k), transfer_size, GDRAM2NRAM);
  }
#endif
}
|
|
|
|
// Expert-parallel filtering: marks which entries of nram_token_idx fall into
// this rank's token range [begin_expert_acc_tokens, end_expert_acc_tokens),
// compacts nram_token_idx in place so the active entries come first, and
// returns the number of active entries. Without expert parallelism every
// token is active and the inputs are left untouched.
//
// Side effects: nram_mask ends up holding the 0/1 mask as fp32 (reused later
// by the caller via __bang_float2uchar_tz), and nram_mask_buffer is clobbered
// as scratch. nram_mask_char is unused here — NOTE(review): presumably kept
// for signature symmetry with the caller; confirm before removing.
__mlu_func__ int getMaskAndActiveTokenCount(int *nram_token_idx,
                                            int *nram_mask,
                                            uint8_t *nram_mask_char,
                                            int *nram_mask_buffer,
                                            int begin_expert_acc_tokens,
                                            int end_expert_acc_tokens,
                                            int token_count,
                                            bool expert_parallelism) {
  if (!expert_parallelism) {
    return token_count;
  }

  // mask_buffer = (idx < end)
  __bang_lt_scalar(nram_mask_buffer, nram_token_idx, end_expert_acc_tokens, token_count);
#if __BANG_ARCH__ >= 592
  // Fused: mask = (float)((idx >= begin) && mask_buffer)
  bang_fusor<int32_t>(nram_mask, nram_token_idx, token_count)
      .ge(begin_expert_acc_tokens)
      .land(nram_mask_buffer)
      .cvt<float>(0);
#else
  __bang_ge_scalar(nram_mask, nram_token_idx, begin_expert_acc_tokens, token_count);
  __bang_and(nram_mask, nram_mask, nram_mask_buffer, token_count);
  __bang_int322float((float *)nram_mask, (int *)nram_mask, token_count, 0);
#endif
  // Compact the selected indices to the front (treated bitwise as fp32,
  // which preserves the int payload through the filter).
  __bang_filter((float *)nram_token_idx, (float *)nram_token_idx, (float *)nram_mask, token_count);
  int active_token_count = __bang_count((float *)nram_mask, token_count);
  return active_token_count;
}
|
|
|
|
// 64-bit offset computation: nram_offset[i] = nram_idx[i] * mul_scalar +
// add_scalar. Used for large tensors where byte offsets can exceed UINT32_MAX;
// the int32 indices are widened to int64 first, then multiplied and added in
// separate passes (no fused 64-bit FMA available).
__mlu_func__ void computeOffset0(uint64_t *nram_offset,
                                 int *nram_idx,
                                 uint64_t mul_scalar,
                                 int64_t add_scalar,
                                 uint32_t token_count) {
#if __BANG_ARCH__ > 592
  // Newer arch takes extra mode/position arguments for the widening convert.
  __bang_int322int64((int64_t *)nram_offset, nram_idx, token_count, 0, 0);
#else
  __bang_int322int64((int64_t *)nram_offset, nram_idx, token_count);
#endif
  __bang_mul_scalar(nram_offset, nram_offset, mul_scalar, token_count);
  __bang_add_scalar((int64_t *)nram_offset, (int64_t *)nram_offset, add_scalar, token_count);
}
|
|
|
|
// 32-bit offset computation: nram_offset[i] = nram_idx[i] * mul_scalar +
// add_scalar, as a single fused multiply-add. Only used when the caller has
// established that all resulting byte offsets fit in 32 bits, so truncating
// add_scalar to int32 is safe.
__mlu_func__ void computeOffset0(uint32_t *nram_offset,
                                 int *nram_idx,
                                 uint32_t mul_scalar,
                                 int64_t add_scalar,
                                 uint32_t token_count) {
  __bang_fusion(FUSION_FMA, nram_offset, (uint32_t *)nram_idx, mul_scalar, (int32_t)add_scalar,
                token_count);
}
|
|
|
|
// Computes per-token gather byte offsets into the input tensor
// (nram_token_offset) and, when bias is present, per-token byte offsets into
// the bias tensor (nram_bias_offset). The bias offset is derived from which
// local expert each token belongs to, found by counting how many local expert
// boundaries (cumulative token counts in nram_expert_tables) the token index
// has passed. T_IDX is uint32_t or uint64_t depending on tensor size.
template <typename T_IDX>
__mlu_func__ void computeOffset(T_IDX *nram_token_offset,
                                T_IDX *nram_bias_offset,
                                int *nram_token_idx,
                                int *nram_expert_tables,
                                int expert_num,
                                int token_count,
                                int active_token_count,
                                int hidden_size,
                                int local_hidden_begin,
                                int dtype_size,
                                int start_expert_id,
                                int expert_size,
                                int begin_expert_acc_tokens,
                                bool has_bias) {
  // for large tensor, convert int322int64 then do multiply and add seperately.
  if (active_token_count <= 0) return;
  if (has_bias) {
    // nram_token_offset is free until the bias path finishes, so borrow it
    // as int scratch for the expert-id accumulation.
    int *nram_bias_offset_temp = (int *)nram_token_offset;
    __bang_write_zero(nram_bias_offset, active_token_count);
    // Local expert id of each token = number of cumulative-count thresholds
    // (experts after the first local one) that token_idx has reached.
    for (int i = start_expert_id + 1; i < start_expert_id + expert_size; i++) {
      __bang_ge_scalar(nram_bias_offset_temp, nram_token_idx, nram_expert_tables[i],
                       active_token_count);
      __bang_add((int *)nram_bias_offset, (int *)nram_bias_offset, nram_bias_offset_temp,
                 active_token_count);
    }
    // Copy the int expert ids back into the scratch buffer before the typed
    // offset computation overwrites nram_bias_offset.
    __bang_add_scalar(nram_bias_offset_temp, (int *)nram_bias_offset, 0, active_token_count);
    // bias_offset = expert_id * hidden_size * dtype_size + local column start.
    computeOffset0(nram_bias_offset, nram_bias_offset_temp, (T_IDX)hidden_size * dtype_size,
                   (T_IDX)local_hidden_begin * dtype_size, active_token_count);
  }

  // token_offset = (token_idx - begin_expert_acc_tokens) * row bytes + local
  // column start; the rank-local base is folded into the additive constant.
  int64_t offset =
      ((int64_t)local_hidden_begin - (int64_t)begin_expert_acc_tokens * hidden_size) * dtype_size;
  computeOffset0(nram_token_offset, nram_token_idx, (T_IDX)(hidden_size * dtype_size), offset,
                 active_token_count);
}
|
|
|
|
// dst = nram_input_buffer * expert_coeff, converted from fp32 to T. On arch
// > 500 the multiply and downcast are fused in one instruction for half/bf16;
// otherwise multiply in fp32 then convert in place. NOTE: the FUSE_MUL_CVT
// macro reads dst/nram_input_buffer/expert_coeff/size from this scope — do
// not rename these parameters.
template <typename T>
__mlu_func__ void mulScalarCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) {
#if __BANG_ARCH__ > 500
  if (std::is_same<T, bfloat16_t>::value) {
    FUSE_MUL_CVT(bf16);
  } else if (std::is_same<T, half>::value) {
    FUSE_MUL_CVT(f16);
  } else if (std::is_same<T, float>::value) {
    __bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size);
  }
#else
  __bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size);
  floatTo((T *)dst, (float *)dst, size);
#endif
}
|
|
|
|
// dst = nram_input_buffer * expert_coeff + dst (fp32 accumulate), converted
// to T. On arch > 500 the FMA and downcast are fused for half/bf16; otherwise
// FMA in fp32 then convert in place. NOTE: the FUSE_MULADD_CVT macro reads
// dst/nram_input_buffer/expert_coeff/size from this scope — do not rename
// these parameters.
template <typename T>
__mlu_func__ void mulAddCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) {
#if __BANG_ARCH__ > 500
  if (std::is_same<T, bfloat16_t>::value) {
    FUSE_MULADD_CVT(bf16);
  } else if (std::is_same<T, half>::value) {
    FUSE_MULADD_CVT(f16);
  } else if (std::is_same<T, float>::value) {
    __bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size,
                  size);
  }
#else
  __bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size,
                size);
  floatTo((T *)dst, (float *)dst, size);
#endif
}
|
|
|
|
// weightedReduceSum with EP split
// input [token_count, k, hidden_size], weight [token_count, k]
// 1. input * weight
// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add
//
// Expert-parallel variant: `input` holds only the rows whose topk slot maps
// to a local expert, compacted in gather order, and `og_mask` holds one byte
// per (token, slot) telling whether that slot is local (produced by the
// caller via __bang_float2uchar_tz). `token_use_count` walks the compacted
// input rows while (t_i, slot) walks the logical grid. A token with no local
// slot produces an all-zero output row.
//
// For 16-bit T the fp32 accumulation happens in place over the output buffer
// at fp32 width; the ping/pong offsets on nram_input_buffer/output_base and
// the trailing strided __memcpy re-pack the converted 16-bit rows into dense
// layout. NOTE(review): the exact interleaving depends on the caller's buffer
// layout (nram_input_buffer sized buffer_block_num * HIDDEN_BLOCK halves) —
// verify against MLUCombineMoeResultKernel before changing either side.
template <typename T,
          bool expert_parallelism,
          typename std::enable_if<expert_parallelism == true, void *>::type = nullptr>
__mlu_func__ void weightedReduceSum(T *output,
                                    T *input,
                                    float *weight,
                                    T *input_buffer,
                                    int8_t *og_mask,
                                    int topk,
                                    int hidden_size,
                                    int token_count,
                                    bool &is_ping) {
  // fp32 staging buffer; alternates halves across calls for 16-bit T.
  float *nram_input_buffer =
      (float *)((half *)input_buffer +
                ((std::is_same<T, float>::value || !is_ping) ? 0 : hidden_size));
  T *output_base = output - ((std::is_same<T, float>::value || is_ping) ? 0 : hidden_size);
  int32_t index[32];     // mask bytes for one token, loaded 4-at-a-time as i32
  float reg_weight[128]; // topk routing weights for one token (topk <= 128)
  int8_t *index_ = (int8_t *)index;
  int topk_divide_4 = PAD_UP(topk, 4) / 4;
  int token_use_count = 0; // next compacted input row to consume
  for (int t_i = 0; t_i < token_count; t_i++) {
    float *output_begin = (float *)(output_base + t_i * hidden_size);
    // Preload this token's mask bytes and weights into registers.
    for (int i = 0; i < topk_divide_4; i++) {
      index[i] = __load_nram((int32_t *)(og_mask + t_i * topk) + i);
      float *weight_begin = weight + t_i * topk + i * 4;
      reg_weight[i * 4] = __load_nram(weight_begin);
      if (i * 4 + 1 < topk) {
        reg_weight[i * 4 + 1] = __load_nram(weight_begin + 1);
      }
      if (i * 4 + 2 < topk) {
        reg_weight[i * 4 + 2] = __load_nram(weight_begin + 2);
      }
      if (i * 4 + 3 < topk) {
        reg_weight[i * 4 + 3] = __load_nram(weight_begin + 3);
      }
    }

    // First active slot (among slots 0..topk-2) initializes the fp32
    // accumulator with a plain multiply.
    int first_in_expert = 0;
    float expert_coeff;
    for (; first_in_expert < topk - 1; first_in_expert++) {
      bool in_expert_range = index_[first_in_expert];
      if (!in_expert_range) continue;

      expert_coeff = reg_weight[first_in_expert];
      toFloat<T>(output_begin, input + token_use_count * hidden_size, hidden_size);
      __bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size);
      token_use_count++;
      break;
    }
    if (first_in_expert == topk - 1) {
      // No active slot before the last one: either the last slot alone
      // contributes (fused mul+convert) or the whole row is zero.
      if (index_[topk - 1]) {
        expert_coeff = reg_weight[topk - 1];
        toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
        token_use_count++;
        mulScalarCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
      } else {
        __bang_write_zero((T *)output_begin, hidden_size);
      }
    } else {
      // Accumulate the remaining middle slots with fp32 FMAs.
      for (int j = first_in_expert + 1; j < topk - 1; j++) {
        bool in_expert_range = index_[j];
        if (!in_expert_range) continue;

        expert_coeff = reg_weight[j];
        toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
        token_use_count++;
        __bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin,
                      hidden_size, hidden_size);
      }
      // Last slot: fused FMA+convert to T, or just convert the accumulator.
      if (index_[topk - 1]) {
        expert_coeff = reg_weight[topk - 1];
        toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
        token_use_count++;
        mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
      } else {
        floatTo((T *)output_begin, (float *)output_begin, hidden_size);
      }
    }
  }
  // Re-pack 16-bit output rows (written at fp32 stride) into dense T layout.
  if (!is_ping && sizeof(T) < sizeof(float)) {
    __memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size,
             hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1,
             token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1,
             token_count * hidden_size * sizeof(T), 0);
  }
  is_ping = !is_ping;
}
|
|
|
|
// weightedReduceSum without EP split
// input [token_count, k, hidden_size], weight [token_count, k]
// 1. input * weight
// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add
//
// Non-expert-parallel variant: every (token, slot) row is present, so og_mask
// is unused. The topk >= 2 path is software-pipelined: the toFloat widening
// of slot k runs while the FMA of slot k-1 accumulates, overlapping convert
// and compute. Buffer/ping-pong handling for 16-bit T mirrors the EP variant.
template <typename T,
          bool expert_parallelism,
          typename std::enable_if<expert_parallelism == false, void *>::type = nullptr>
__mlu_func__ void weightedReduceSum(T *output,
                                    T *input,
                                    float *weight,
                                    T *input_buffer,
                                    int8_t *og_mask,
                                    int topk,
                                    int hidden_size,
                                    int token_count,
                                    bool &is_ping) {
  // fp32 staging buffer; alternates halves across calls for 16-bit T.
  float *nram_input_buffer =
      (float *)((half *)input_buffer +
                ((std::is_same<T, float>::value || !is_ping) ? 0 : hidden_size));
  T *output_base = output - ((std::is_same<T, float>::value || is_ping) ? 0 : hidden_size);
  if (topk == 1) {
    // Single expert per token: just scale each row (fused mul+convert),
    // writing directly at dense T stride — no re-pack needed.
    for (int i = 0; i < token_count; i++) {
      float expert_coeff = __load_nram(weight + i);
      toFloat<T>(nram_input_buffer, input + i * hidden_size, hidden_size);
      mulScalarCvt(output + i * hidden_size, nram_input_buffer, expert_coeff, hidden_size);
    }
    return;
  }
  for (int t_i = 0; t_i < token_count; t_i++) {
    float *output_begin = (float *)(output_base + t_i * hidden_size);
    // Slot 0 initializes the accumulator; slot 1 is preloaded so the FMA in
    // the loop below always has its operand ready.
    float expert_coeff = __load_nram(weight + t_i * topk);
    toFloat<T>(output_begin, input + t_i * topk * hidden_size, hidden_size);
    toFloat<T>(nram_input_buffer, input + (t_i * topk + 1) * hidden_size, hidden_size);
    __bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size);
    expert_coeff = __load_nram(weight + t_i * topk + 1);
    for (int k_i = 2; k_i < topk; k_i++) {
      // FMA slot k_i-1 while widening slot k_i for the next iteration.
      __bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin,
                    hidden_size, hidden_size);
      expert_coeff = __load_nram(weight + t_i * topk + k_i);
      toFloat<T>(nram_input_buffer, input + (t_i * topk + k_i) * hidden_size, hidden_size);
    }
    // Final slot: fused FMA + convert from fp32 down to T.
    mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
  }
  // Re-pack 16-bit output rows (written at fp32 stride) into dense T layout.
  if (!is_ping && sizeof(T) < sizeof(float)) {
    __memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size,
             hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1,
             token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1,
             token_count * hidden_size * sizeof(T), 0);
  }
  is_ping = !is_ping;
}
|
|
|
|
// MoE combine-result kernel. Each task processes one (hidden-column block x
// token block) tile: taskIdX splits the hidden dimension into HIDDEN_BLOCK
// columns, taskIdY splits the num_token tokens. Per TOKEN_BLOCK chunk it
// gathers the expert outputs for all topk slots via gather_idx, multiplies by
// reduce_weight, reduces over topk, optionally adds bias and residual, and
// scatters the result to `output`.
//
// The main loop is a 3-stage software pipeline running on ping/pong NRAM
// buffers: iteration i loads chunk i+1 (and the indices of chunk i+2)
// asynchronously, computes chunk i, and stores chunk i-1. The iteration
// counter starts at -1 so the first pass is pure prefetch; the final store
// is drained after the loop.
//
// T is float/half/bfloat16; T_IDX (uint32_t/uint64_t) is the gather-offset
// width chosen by the host launcher based on total tensor size.
template <typename T, typename T_IDX>
__mlu_global__ void MLUCombineMoeResultKernel(T *output,
                                              T *input,
                                              T *bias,
                                              T *residual,
                                              float *reduce_weight,
                                              int *cusum_token_count,
                                              int *gather_idx,
                                              int num_token,
                                              int topk,
                                              int num_expert,
                                              int hidden_size,
                                              int start_expert_id,
                                              int expert_size,
                                              int HIDDEN_BLOCK,
                                              int TOKEN_BLOCK) {
  // All work runs on the compute cores; the MPU core idles.
  if (__is_mpu()) {
    return;
  }
  // Column range handled by this task.
  int local_hidden_begin = taskIdX * HIDDEN_BLOCK;
  int local_hidden_size = std::min(HIDDEN_BLOCK, hidden_size - local_hidden_begin);
  // Token range: remainder tokens go one-each to the lowest taskIdY values.
  int task_avg_tokens = num_token / taskDimY;
  int task_remain_tokens = num_token % taskDimY;
  int task_tokens = task_avg_tokens + (int)(taskIdY < task_remain_tokens);
  int task_token_begin = taskIdY * task_avg_tokens + std::min(taskIdY, task_remain_tokens);
  if (local_hidden_size <= 0) return;
  if (task_tokens <= 0) return;

  constexpr int int32_dtype_size = (int)sizeof(int);
  constexpr int fp32_dtype_size = (int)sizeof(float);
  int pad_num_expert = PAD_UP(num_expert + 1, 32);
  bool has_bias = bias != nullptr;
  bool has_residual = residual != nullptr;
  bool using_acc_sum = cusum_token_count != nullptr;
  // EP mode: this rank only holds expert_size of the num_expert experts.
  bool expert_parallelism = expert_size < num_expert;

  int block_size = TOKEN_BLOCK * topk;
  int pad_block_size = PAD_UP(block_size, 64);

  // ---- Manual partition of nram_buffer (sizes in elements) ----
  // Optional regions collapse to zero bytes when has_bias/has_residual is
  // false (the (int)flag * size terms).
  int *nram_expert_tables = (int *)nram_buffer;
  int *nram_token_idx = nram_expert_tables + pad_num_expert;
  T_IDX *nram_token_offset = (T_IDX *)(nram_token_idx + pad_block_size);
  T_IDX *nram_bias_offset = (T_IDX *)(nram_token_offset + pad_block_size);
  int *nram_mask = (int *)(nram_bias_offset + (int)has_bias * pad_block_size);
  T *nram_input_ping = (T *)(nram_mask + pad_block_size);
  T *nram_input_pong = nram_input_ping + block_size * HIDDEN_BLOCK;
  T *nram_bias_ping = nram_input_pong + block_size * HIDDEN_BLOCK;
  T *nram_bias_pong = nram_bias_ping + (int)has_bias * block_size * HIDDEN_BLOCK;
  T *nram_residual_ping = nram_bias_pong + (int)has_bias * block_size * HIDDEN_BLOCK;
  T *nram_residual_pong = nram_residual_ping + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK;
  float *nram_weight_ping =
      (float *)(nram_residual_pong + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK);
  float *nram_weight_pong = nram_weight_ping + pad_block_size;
  // 16-bit dtypes need an extra fp32 staging block inside the output region.
  int buffer_block_num = sizeof(T) > 2 ? 2 : 3;
  T *nram_output_ping = (T *)(nram_weight_pong + pad_block_size);
  T *nram_input_buffer = nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK;
  T *nram_output_pong = (T *)((char *)nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK * sizeof(T) +
                              buffer_block_num * HIDDEN_BLOCK * sizeof(half));
  // Mask scratch aliases the token-offset buffer — safe because offsets are
  // recomputed from the filtered indices after each masking pass.
  int *nram_mask_buffer = (int *)nram_token_offset;
  uint8_t *nram_mask_char = (uint8_t *)(nram_output_pong + TOKEN_BLOCK * HIDDEN_BLOCK);

  // ---- Prologue: load expert cumulative counts + first chunk's indices ----
  int init_token_count = std::min(TOKEN_BLOCK, task_tokens) * topk;
  int begin_expert_acc_tokens = 0;
  int end_expert_acc_tokens = num_token * topk;
  if (using_acc_sum) {
    __memcpy_async(nram_expert_tables, cusum_token_count, (num_expert + 1) * int32_dtype_size,
                   GDRAM2NRAM);
  }
  __memcpy_async(nram_token_idx, gather_idx + task_token_begin * topk,
                 init_token_count * sizeof(int), GDRAM2NRAM);
  __sync_io();

  if (expert_parallelism) {
    // Range of (token * topk) slots owned by this rank's experts.
    begin_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id);
    end_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id + expert_size);
  }

  // Filter first chunk's indices to local experts and compute gather offsets.
  int active_token_count = getMaskAndActiveTokenCount(
      nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens,
      end_expert_acc_tokens, init_token_count, expert_parallelism);

  computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables, num_expert,
                init_token_count, active_token_count, hidden_size, local_hidden_begin,
                (int)sizeof(T), start_expert_id, expert_size, begin_expert_acc_tokens, has_bias);
  // NOTE(review): the boolean flags select which queues/pipes to wait on;
  // semantics follow the BANG runtime's __sync_io_move_compute overload.
  __sync_io_move_compute(true, false, false, false, false, true);
  __sync_io_move_compute(false, false, true, true, false, false);

  // ---- Pipelined main loop ----
  int next_active_token_count = active_token_count;
  int previous_global_token_begin = 0;
  int previous_token_count = 0;
  bool is_ping = false;
  // task_begin starts at -1: the first iteration only prefetches.
  for (int task_begin = -1; task_begin * TOKEN_BLOCK < task_tokens; task_begin++) {
    int next_token_begin = (task_begin + 1) * TOKEN_BLOCK;
    int next_next_token_begin = (task_begin + 2) * TOKEN_BLOCK;
    bool is_last_loop = next_token_begin >= task_tokens;
    bool is_last_2_loop = next_next_token_begin >= task_tokens;
    int current_token_begin = task_begin * TOKEN_BLOCK;
    int current_token_count = std::min(TOKEN_BLOCK, task_tokens - current_token_begin);
    int next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_token_begin);
    int next_next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_next_token_begin);
    int current_global_token_begin = task_token_begin + current_token_begin;
    int next_global_token_begin = task_token_begin + next_token_begin;
    int next_next_global_token_begin = task_token_begin + next_next_token_begin;

    // Stage 1: async loads for chunk i+1 (and raw indices for chunk i+2).
    if (!is_last_loop) {
      if (!is_last_2_loop) {
        loadAsync2d(nram_token_idx, gather_idx + next_next_global_token_begin * topk,
                    next_next_token_count * topk * sizeof(int), 0, 0, 0);
      }
      loadAsync2d(nram_weight_ping, reduce_weight + next_global_token_begin * topk,
                  next_token_count * topk * fp32_dtype_size, 0, 0, 0);
      if (has_residual) {
        loadAsync2d(nram_residual_ping,
                    residual + next_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
                    local_hidden_size * sizeof(T), local_hidden_size * sizeof(T),
                    hidden_size * sizeof(T), next_token_count - 1);
      }
      gatherTokensAsync<T_IDX>(nram_input_ping, input, nram_token_offset,
                               local_hidden_size * sizeof(T), next_active_token_count);
      gatherTokensAsync<T_IDX>(nram_bias_ping, bias, nram_bias_offset,
                               local_hidden_size * sizeof(T), next_active_token_count);
    }

    // Stage 3: async store of chunk i-1's finished tile.
    if (task_begin >= 1) {
      storeAsync2d(
          output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
          nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T),
          local_hidden_size * sizeof(T), previous_token_count - 1);
    }

    // Stage 2: compute chunk i on the pong buffers.
    if (task_begin >= 0) {
      if (has_bias && active_token_count) {
        __bang_add(nram_input_pong, nram_input_pong, nram_bias_pong,
                   active_token_count * local_hidden_size);
      }
      if (expert_parallelism) {
        weightedReduceSum<T, true>(nram_output_ping, nram_input_pong, nram_weight_pong,
                                   nram_input_buffer, (int8_t *)nram_mask_char, topk,
                                   local_hidden_size, current_token_count, is_ping);
      } else {
        weightedReduceSum<T, false>(nram_output_ping, nram_input_pong, nram_weight_pong,
                                    nram_input_buffer, (int8_t *)nram_mask_char, topk,
                                    local_hidden_size, current_token_count, is_ping);
      }
      if (has_residual) {
        __bang_add((T *)nram_output_ping, (T *)nram_output_ping, nram_residual_pong,
                   current_token_count * local_hidden_size);
      }
    }

    // Barrier: wait for this iteration's IO and compute before swapping.
    __sync_io_move_compute();
    active_token_count = next_active_token_count;
    if (expert_parallelism && !is_last_loop) {
      // Snapshot chunk i+1's fp32 mask as bytes before nram_mask is reused
      // by the next filtering pass; consumed by weightedReduceSum<.., true>.
      __bang_float2uchar_tz((uint8_t *)nram_mask_char, (float *)nram_mask, next_token_count * topk);
    }
    if (!is_last_2_loop) {
      // Prepare chunk i+2: filter its freshly-loaded indices and compute the
      // gather offsets used by the next iteration's gatherTokensAsync.
      next_active_token_count = getMaskAndActiveTokenCount(
          nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens,
          end_expert_acc_tokens, next_next_token_count * topk, expert_parallelism);
      computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables,
                    num_expert, next_next_token_count * topk, next_active_token_count, hidden_size,
                    local_hidden_begin, (int)sizeof(T), start_expert_id, expert_size,
                    begin_expert_acc_tokens, has_bias);
    }

    // Rotate all double buffers for the next iteration.
    swap(nram_input_ping, nram_input_pong);
    swap(nram_bias_ping, nram_bias_pong);
    swap(nram_residual_ping, nram_residual_pong);
    swap(nram_weight_ping, nram_weight_pong);
    swap(nram_output_ping, nram_output_pong);
    previous_global_token_begin = current_global_token_begin;
    previous_token_count = current_token_count;
  }
  // Epilogue: drain the last computed tile.
  storeAsync2d(output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
               nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T),
               local_hidden_size * sizeof(T), previous_token_count - 1);
}
|
|
|
|
#if __BANG_ARCH__ < 500
// bfloat16 is not supported on arch < 500 (MLU300-class devices); these empty
// specializations exist only so the host-side launch code links. The host
// launcher rejects CNNL_DTYPE_BFLOAT16 on such devices before ever reaching
// these stubs (see the isBf16Supported() check in invokeMoeCombineResultKernel).
template <>
__mlu_global__ void MLUCombineMoeResultKernel<bfloat16_t, uint32_t>(bfloat16_t *output,
                                                                    bfloat16_t *input,
                                                                    bfloat16_t *bias,
                                                                    bfloat16_t *residual,
                                                                    float *reduce_weight,
                                                                    int *cusum_token_count,
                                                                    int *gather_ids,
                                                                    int num_token,
                                                                    int topk,
                                                                    int num_expert,
                                                                    int hidden_size,
                                                                    int start_expert_id,
                                                                    int expert_size,
                                                                    int HIDDEN_BLOCK,
                                                                    int TOKEN_BLOCK) {}

template <>
__mlu_global__ void MLUCombineMoeResultKernel<bfloat16_t, uint64_t>(bfloat16_t *output,
                                                                    bfloat16_t *input,
                                                                    bfloat16_t *bias,
                                                                    bfloat16_t *residual,
                                                                    float *reduce_weight,
                                                                    int *cusum_token_count,
                                                                    int *gather_ids,
                                                                    int num_token,
                                                                    int topk,
                                                                    int num_expert,
                                                                    int hidden_size,
                                                                    int start_expert_id,
                                                                    int expert_size,
                                                                    int HIDDEN_BLOCK,
                                                                    int TOKEN_BLOCK) {}
#endif
|
|
|
|
} // namespace kernels
|
|
/**
 * Host-side launcher for the MoE combine-result kernel.
 *
 * Validates the arguments, derives the NRAM tiling (HIDDEN_BLOCK x
 * TOKEN_BLOCK) and the task grid from the device topology, then enqueues the
 * kernel specialization matching `dtype` and the tensor size (uint32 vs
 * uint64 gather offsets) on `queue`.
 *
 * Returns KERNEL_STATUS_FAILED for unsupported shapes/dtypes or missing
 * cusum_token_count in expert-parallel/bias mode; KERNEL_STATUS_SUCCESS once
 * the kernel has been enqueued (the launch itself is asynchronous).
 */
KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue,
                                          void *output,
                                          const void *input,
                                          const void *bias,
                                          const void *residual,
                                          const float *reduce_weight,
                                          const int *cusum_token_count,
                                          const int *gather_idx,
                                          int num_token,
                                          int topk,
                                          int num_expert,
                                          int hidden_size,
                                          int start_expert_id,
                                          int expert_size,
                                          cnnlDataType_t dtype) {
  // reg_weight in weightedReduceSum holds 128 floats and the expert-table
  // buffer is padded from num_expert + 1; small hidden_size starves the
  // tiling heuristics below.
  if (topk > 128 || num_expert > 1024 || hidden_size < 256) {
    std::cerr << "[invokeMoeCombineResultKernel]: "
              << "currently only support topk <= 128, num_expert <= 1024 and hidden_size >= 256."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  if (bias != nullptr) {
    std::cerr << "[invokeMoeCombineResultKernel]: currently does not support bias." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  // Bias offsets and expert-parallel token filtering both need the per-expert
  // cumulative token counts.
  if ((bias != nullptr || num_expert > expert_size) && cusum_token_count == nullptr) {
    std::cerr << "[invokeMoeCombineResultKernel]: if has bias or expert parallelism, "
              << "cusum_token_count can not be nullptr." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  size_t data_bytes = 0;
  cnnlGetSizeOfDataType(dtype, &data_bytes);
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));

  // 480KB nram size, 48KB for token idx, token/bias offset and weight. 432KB for buffer.
  // TOKEN_BLOCK * topk <= 1024 in case 32KB is enough for idx and offset.
  // 16-bit dtypes need one extra hidden-row of fp32 staging (3 vs 2 blocks).
  int convert_buffer = data_bytes == 2
                           ? 3 * hidden_size * data_bytes
                           : 2 * hidden_size * data_bytes; // buffer for convert bf16/fp16->fp32
  // Denominator = bytes of NRAM consumed per gathered element across all
  // live buffers; the bias/residual terms are kept (despite the bias
  // rejection above) so the sizing stays correct if bias support lands.
  int max_input_size = (432 * 1024 - convert_buffer) /
                       (2 * topk * data_bytes + /*input size, double buffer*/
                        (bias != nullptr) * 2 * topk * data_bytes + /*bias size, double buffer*/
                        (residual != nullptr) * 2 * data_bytes + /*residual size, double buffer*/
                        2 * data_bytes); /*output size, one buffer*/

  // Choose the tile: prefer covering the full hidden dim per task, then
  // spend the remaining budget on tokens.
  int TOKEN_BLOCK = 1;
  int HIDDEN_BLOCK = 1;
  int HIDDEN_BLOCK_X_TOKEN_BLOCK = (max_input_size / 64) * 64;
  if (HIDDEN_BLOCK_X_TOKEN_BLOCK < hidden_size) {
    HIDDEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK;
    TOKEN_BLOCK = 1;
  } else {
    HIDDEN_BLOCK = hidden_size;
  }
  // for latency case, hidden_size is large but token is small.
  if (HIDDEN_BLOCK == hidden_size && hidden_size >= 4096 && num_token <= core_num * cluster_num) {
    HIDDEN_BLOCK = (hidden_size + core_num - 1) / core_num;
  }

  HIDDEN_BLOCK = std::min(HIDDEN_BLOCK, 8 * 1024);
  uint32_t task_dim_x = (hidden_size + HIDDEN_BLOCK - 1) / HIDDEN_BLOCK;
  // Round task_dim_x up to a full cluster, then search for the smallest
  // multiple of core_num that divides the machine evenly.
  task_dim_x =
      (task_dim_x < core_num) ? task_dim_x : ((task_dim_x + core_num - 1) / core_num * core_num);
  uint32_t pad_dim_x = task_dim_x;
  while (pad_dim_x <= cluster_num * core_num) {
    if ((cluster_num * core_num % pad_dim_x == 0)) {
      task_dim_x = pad_dim_x;
      break;
    }
    pad_dim_x += core_num;
  }
  // Recompute the hidden tile from the final grid width, 64-element aligned.
  HIDDEN_BLOCK = (hidden_size + task_dim_x - 1) / task_dim_x;
  HIDDEN_BLOCK = (HIDDEN_BLOCK + 63) / 64 * 64;
  if (HIDDEN_BLOCK_X_TOKEN_BLOCK >= hidden_size) {
    TOKEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK / HIDDEN_BLOCK;
  }
  // Keep TOKEN_BLOCK * topk <= 1024 so the idx/offset region fits in 32KB.
  TOKEN_BLOCK = std::min(TOKEN_BLOCK, 1024 / topk);

  float max_cluster_num = core_num * cluster_num / task_dim_x;
  // BUGFIX: std::min cannot deduce a common type from (float, int); convert
  // num_token explicitly before comparing, then clamp to at least one task.
  uint32_t task_dim_y =
      static_cast<uint32_t>(std::min(max_cluster_num, static_cast<float>(num_token)));
  task_dim_y = task_dim_y < 1 ? 1 : task_dim_y;
  cnrtDim3_t dim{.x = task_dim_x, .y = task_dim_y, .z = 1};

  // uint64 gather offsets are only needed when byte offsets can overflow
  // uint32; data_bytes (size_t) keeps the whole product in 64-bit arithmetic.
  bool is_large_tensor = data_bytes * num_token * topk * hidden_size > UINT32_MAX;

// Single launch site for all (T, T_IDX) combinations; argument order and
// casts are identical for every specialization.
#define LAUNCH_COMBINE_MOE_RESULT(T, T_IDX)                                                \
  kernels::MLUCombineMoeResultKernel<T, T_IDX><<<dim, cnrtFuncTypeBlock, queue>>>(         \
      (T *)output, (T *)input, (T *)bias, (T *)residual, (float *)reduce_weight,           \
      (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert,            \
      hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK)

  if (dtype == CNNL_DTYPE_FLOAT) {
    if (!is_large_tensor) {
      LAUNCH_COMBINE_MOE_RESULT(float, uint32_t);
    } else {
      LAUNCH_COMBINE_MOE_RESULT(float, uint64_t);
    }
  } else if (dtype == CNNL_DTYPE_HALF) {
    if (!is_large_tensor) {
      LAUNCH_COMBINE_MOE_RESULT(half, uint32_t);
    } else {
      LAUNCH_COMBINE_MOE_RESULT(half, uint64_t);
    }
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeMoeCombineResultKernel]: MLU300 devices do not support bfloat16."
                << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    if (!is_large_tensor) {
      LAUNCH_COMBINE_MOE_RESULT(bfloat16_t, uint32_t);
    } else {
      LAUNCH_COMBINE_MOE_RESULT(bfloat16_t, uint64_t);
    }
  } else {
    std::cerr << "[invokeMoeCombineResultKernel]: the current supported dtype is "
              << "among float/half/bfloat16." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
#undef LAUNCH_COMBINE_MOE_RESULT
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
|
|
|
} // namespace tmo
|