/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include <algorithm>  // std::min
#include <iostream>   // std::cerr

#include "cnrt.h"
#include "combine_result.mluh"

// clang-format off
#include <type_traits>   // std::is_same / std::enable_if
#include "mlu.h"         // assumed: BANG device header (name elided in source)
// clang-format on

#if __BANG_ARCH__ >= 592
#include "bang_fusor.h"  // assumed: header for bang::experimental::fusor (name elided in source)
template <typename... Args>
using bang_fusor = bang::experimental::fusor<Args...>;
#endif

namespace tmo {
namespace kernels {

#define NRAM_REMAIN_SIZE (32 * 1024)
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)

__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];

template <typename T>
__mlu_func__ void swap(T *&ping, T *&pong) {
  T *temp = ping;
  ping = pong;
  pong = temp;
}

#define GATHER_ASYNC_IO0(offset_type)                                                         \
  __asm__ __volatile__(                                                                       \
      "gather.vector.async.nram.gdram.nram." #offset_type                                     \
      ".io0 [%[dst]], [%[src]], [%[offset]], "                                                \
      "%[transfer_size], %[transfer_num], %[stride];\n\t" ::[dst] "r"(dst),                   \
      [src] "r"(src_gdram), [offset] "r"(nram_offset), [transfer_size] "r"(transfer_size),    \
      [transfer_num] "r"(token_count), [stride] "r"(transfer_size))

#define FUSE_MUL_CVT(dst_dtype)                                                               \
  __asm__ __volatile__("mult.scalar.nram.crn." #dst_dtype                                     \
                       ".f32 [%[dst]], [%[src0]], %[src1],"                                   \
                       " %[size];\n\t" ::[dst] "r"(dst),                                      \
                       [src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size));

#define FUSE_MULADD_CVT(dst_dtype)                                                            \
  __asm__ __volatile__("muladd.nram.crn." #dst_dtype                                          \
                       ".f32 [%[dst]], [%[src0]], %[src1], [%[dst]],"                         \
                       " %[size], %[size];\n\t" ::[dst] "r"(dst),                             \
                       [src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size));
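// Semantics of the macros above, sketched as plain C for reference (the real
// instructions run asynchronously / element-wise on NRAM; names such as dst,
// src_gdram, nram_offset, nram_input_buffer, expert_coeff and size are bound
// from the scope that expands the macro):
//
//   GATHER_ASYNC_IO0:   for (k = 0; k < token_count; ++k)
//                         memcpy((int8_t *)dst + k * transfer_size,
//                                (int8_t *)src_gdram + offset[k], transfer_size);
//   FUSE_MUL_CVT(t):    dst[i] = (t)(nram_input_buffer[i] * expert_coeff)
//   FUSE_MULADD_CVT(t): dst[i] = (t)(nram_input_buffer[i] * expert_coeff + (float)dst[i])
//
// These match the portable fallbacks used below for __BANG_ARCH__ <= 500.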
template <typename T>
__mlu_func__ void toFloat(float *dst, T *src, int count) {
  if (std::is_same<T, half>::value) {
    __bang_half2float(dst, (half *)src, count);
  } else if (std::is_same<T, bfloat16_t>::value) {
    __bang_bfloat162float(dst, (bfloat16_t *)src, count);
  } else if (std::is_same<T, float>::value) {
    // Already fp32: copy via add-0.
    __bang_add_scalar((float *)dst, (float *)src, (float)0, count);
  }
}

template <typename T>
__mlu_func__ void floatTo(T *dst, float *src, int count) {
  if (std::is_same<T, half>::value) {
    __bang_float2half_rn((half *)dst, src, count);
  } else if (std::is_same<T, bfloat16_t>::value) {
    __bang_float2bfloat16_rn((bfloat16_t *)dst, src, count);
  } else if (std::is_same<T, float>::value) {
    // Already fp32: copy via add-0.
    __bang_add_scalar((float *)dst, (float *)src, (float)0, count);
  }
}

__mlu_func__ void loadAsync2d(void *dst, void *src, int size, int dststride, int srcstride,
                              int seg_num) {
#if __BANG_ARCH__ > 500
  __asm__ __volatile__(
      "ld.async.stride.nram.gdram.io0 [%[dst]], [%[src]],"
      " %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst),
      [src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride),
      [segnum] "r"(seg_num));
#else
  __memcpy_async(dst, src, size, GDRAM2NRAM, dststride, srcstride, seg_num);
#endif
}

__mlu_func__ void storeAsync2d(void *dst, void *src, int size, int dststride, int srcstride,
                               int seg_num) {
#if __BANG_ARCH__ > 500
  __asm__ __volatile__(
      "st.async.stride.gdram.nram.io1 [%[dst]], [%[src]],"
      " %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst),
      [src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride),
      [segnum] "r"(seg_num));
#else
  __memcpy_async(dst, src, size, NRAM2GDRAM, dststride, srcstride, seg_num);
#endif
}

template <typename T_IDX>
__mlu_func__ void gatherTokensAsync(void *dst, void *src_gdram, T_IDX *nram_offset,
                                    int transfer_size, int token_count) {
  if (token_count <= 0 || src_gdram == nullptr) return;
#if __BANG_ARCH__ > 500
  if (std::is_same<T_IDX, uint32_t>::value) {
    GATHER_ASYNC_IO0(u32);
  } else {
    GATHER_ASYNC_IO0(u64);
  }
#else
  for (int k = 0; k < token_count; k++) {
    __memcpy_async((int8_t *)dst + k * transfer_size,
                   (int8_t *)src_gdram + __load_nram(nram_offset + k), transfer_size,
                   GDRAM2NRAM);
  }
#endif
}

__mlu_func__ int getMaskAndActiveTokenCount(int *nram_token_idx, int *nram_mask,
                                            uint8_t *nram_mask_char, int *nram_mask_buffer,
                                            int begin_expert_acc_tokens,
                                            int end_expert_acc_tokens, int token_count,
                                            bool expert_parallelism) {
  if (!expert_parallelism) {
    return token_count;
  }
  __bang_lt_scalar(nram_mask_buffer, nram_token_idx, end_expert_acc_tokens, token_count);
#if __BANG_ARCH__ >= 592
  bang_fusor(nram_mask, nram_token_idx, token_count)
      .ge(begin_expert_acc_tokens)
      .land(nram_mask_buffer)
      .cvt(0);
#else
  __bang_ge_scalar(nram_mask, nram_token_idx, begin_expert_acc_tokens, token_count);
  __bang_and(nram_mask, nram_mask, nram_mask_buffer, token_count);
  __bang_int322float((float *)nram_mask, (int *)nram_mask, token_count, 0);
#endif
  __bang_filter((float *)nram_token_idx, (float *)nram_token_idx, (float *)nram_mask,
                token_count);
  int active_token_count = __bang_count((float *)nram_mask, token_count);
  return active_token_count;
}

__mlu_func__ void computeOffset0(uint64_t *nram_offset, int *nram_idx, uint64_t mul_scalar,
                                 int64_t add_scalar, uint32_t token_count) {
#if __BANG_ARCH__ > 592
  __bang_int322int64((int64_t *)nram_offset, nram_idx, token_count, 0, 0);
#else
  __bang_int322int64((int64_t *)nram_offset, nram_idx, token_count);
#endif
  __bang_mul_scalar(nram_offset, nram_offset, mul_scalar, token_count);
  __bang_add_scalar((int64_t *)nram_offset, (int64_t *)nram_offset, add_scalar, token_count);
}
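// computeOffset0 evaluates offset[i] = idx[i] * mul_scalar + add_scalar.
// The 64-bit overload above widens to int64 first and then multiplies and adds
// in separate vector ops, so byte offsets beyond 4 GiB do not wrap; the 32-bit
// overload below can use a single fused multiply-add instead. Illustrative
// example: idx[i] = 7, mul_scalar = hidden_size * dtype_size = 4096 * 2 and
// add_scalar = 0 give offset[i] = 57344, the byte offset of the 8th row.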
__mlu_func__ void computeOffset0(uint32_t *nram_offset, int *nram_idx, uint32_t mul_scalar,
                                 int64_t add_scalar, uint32_t token_count) {
  __bang_fusion(FUSION_FMA, nram_offset, (uint32_t *)nram_idx, mul_scalar, (int32_t)add_scalar,
                token_count);
}

template <typename T_IDX>
__mlu_func__ void computeOffset(T_IDX *nram_token_offset, T_IDX *nram_bias_offset,
                                int *nram_token_idx, int *nram_expert_tables, int expert_num,
                                int token_count, int active_token_count, int hidden_size,
                                int local_hidden_begin, int dtype_size, int start_expert_id,
                                int expert_size, int begin_expert_acc_tokens, bool has_bias) {
  // For large tensors, convert int32 to int64 first, then multiply and add separately.
  if (active_token_count <= 0) return;
  if (has_bias) {
    int *nram_bias_offset_temp = (int *)nram_token_offset;
    __bang_write_zero(nram_bias_offset, active_token_count);
    for (int i = start_expert_id + 1; i < start_expert_id + expert_size; i++) {
      __bang_ge_scalar(nram_bias_offset_temp, nram_token_idx, nram_expert_tables[i],
                       active_token_count);
      __bang_add((int *)nram_bias_offset, (int *)nram_bias_offset, nram_bias_offset_temp,
                 active_token_count);
    }
    __bang_add_scalar(nram_bias_offset_temp, (int *)nram_bias_offset, 0, active_token_count);
    computeOffset0(nram_bias_offset, nram_bias_offset_temp, (T_IDX)hidden_size * dtype_size,
                   (T_IDX)local_hidden_begin * dtype_size, active_token_count);
  }
  int64_t offset =
      ((int64_t)local_hidden_begin - (int64_t)begin_expert_acc_tokens * hidden_size) * dtype_size;
  computeOffset0(nram_token_offset, nram_token_idx, (T_IDX)(hidden_size * dtype_size), offset,
                 active_token_count);
}

template <typename T>
__mlu_func__ void mulScalarCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) {
#if __BANG_ARCH__ > 500
  if (std::is_same<T, bfloat16_t>::value) {
    FUSE_MUL_CVT(bf16);
  } else if (std::is_same<T, half>::value) {
    FUSE_MUL_CVT(f16);
  } else if (std::is_same<T, float>::value) {
    __bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size);
  }
#else
  __bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size);
  floatTo((T *)dst, (float *)dst, size);
#endif
}

template <typename T>
__mlu_func__ void mulAddCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) {
#if __BANG_ARCH__ > 500
  if (std::is_same<T, bfloat16_t>::value) {
    FUSE_MULADD_CVT(bf16);
  } else if (std::is_same<T, half>::value) {
    FUSE_MULADD_CVT(f16);
  } else if (std::is_same<T, float>::value) {
    __bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size,
                  size);
  }
#else
  __bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size,
                size);
  floatTo((T *)dst, (float *)dst, size);
#endif
}
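// Reference semantics for the two weightedReduceSum variants below, written as
// scalar pseudocode (for orientation only; the real code is vectorized and, in
// the EP variant, consumes a compacted input holding just the rows whose
// expert lives on this rank, flagged per (token, k) by og_mask):
//
//   for (t = 0; t < token_count; ++t)
//     for (h = 0; h < hidden_size; ++h) {
//       float acc = 0.f;
//       for (k = 0; k < topk; ++k)
//         if (!EP || og_mask[t * topk + k])
//           acc += weight[t * topk + k] * (float)row(t, k)[h];
//       output[t * hidden_size + h] = (T)acc;
//     }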
// weightedReduceSum with EP split
// input [token_count, k, hidden_size], weight [token_count, k]
// 1. input * weight
// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add
template <typename T, bool EP, typename std::enable_if<EP, int *>::type = nullptr>
__mlu_func__ void weightedReduceSum(T *output, T *input, float *weight, T *input_buffer,
                                    int8_t *og_mask, int topk, int hidden_size, int token_count,
                                    bool &is_ping) {
  float *nram_input_buffer =
      (float *)((half *)input_buffer +
                ((std::is_same<T, float>::value || !is_ping) ? 0 : hidden_size));
  T *output_base = output - ((std::is_same<T, float>::value || is_ping) ? 0 : hidden_size);
  int32_t index[32];
  float reg_weight[128];
  int8_t *index_ = (int8_t *)index;
  int topk_divide_4 = PAD_UP(topk, 4) / 4;
  int token_use_count = 0;
  for (int t_i = 0; t_i < token_count; t_i++) {
    float *output_begin = (float *)(output_base + t_i * hidden_size);
    for (int i = 0; i < topk_divide_4; i++) {
      index[i] = __load_nram((int32_t *)(og_mask + t_i * topk) + i);
      float *weight_begin = weight + t_i * topk + i * 4;
      reg_weight[i * 4] = __load_nram(weight_begin);
      if (i * 4 + 1 < topk) {
        reg_weight[i * 4 + 1] = __load_nram(weight_begin + 1);
      }
      if (i * 4 + 2 < topk) {
        reg_weight[i * 4 + 2] = __load_nram(weight_begin + 2);
      }
      if (i * 4 + 3 < topk) {
        reg_weight[i * 4 + 3] = __load_nram(weight_begin + 3);
      }
    }
    int first_in_expert = 0;
    float expert_coeff;
    for (; first_in_expert < topk - 1; first_in_expert++) {
      bool in_expert_range = index_[first_in_expert];
      if (!in_expert_range) continue;
      expert_coeff = reg_weight[first_in_expert];
      toFloat(output_begin, input + token_use_count * hidden_size, hidden_size);
      __bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size);
      token_use_count++;
      break;
    }
    if (first_in_expert == topk - 1) {
      if (index_[topk - 1]) {
        expert_coeff = reg_weight[topk - 1];
        toFloat(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
        token_use_count++;
        mulScalarCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
      } else {
        __bang_write_zero((T *)output_begin, hidden_size);
      }
    } else {
      for (int j = first_in_expert + 1; j < topk - 1; j++) {
        bool in_expert_range = index_[j];
        if (!in_expert_range) continue;
        expert_coeff = reg_weight[j];
        toFloat(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
        token_use_count++;
        __bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin,
                      hidden_size, hidden_size);
      }
      if (index_[topk - 1]) {
        expert_coeff = reg_weight[topk - 1];
        toFloat(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
        token_use_count++;
        mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
      } else {
        floatTo((T *)output_begin, (float *)output_begin, hidden_size);
      }
    }
  }
  if (!is_ping && sizeof(T) < sizeof(float)) {
    __memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size,
             hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1,
             token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1,
             token_count * hidden_size * sizeof(T), 0);
  }
  is_ping = !is_ping;
}
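// A note on the is_ping staggering used by both variants (a reading of the
// pointer math above, not an authoritative spec): for 2-byte dtypes the fp32
// accumulator of a token occupies twice the bytes of its final row, so every
// other call writes rows starting one hidden_size earlier (output_base) and
// converts each row down in place before the next token's accumulator would
// overlap it. The tail __memcpy with negative strides then shifts the rows
// back so row t lands at output + t * hidden_size. For float, output_base ==
// output and the repair copy is skipped.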
// weightedReduceSum without EP split
// input [token_count, k, hidden_size], weight [token_count, k]
// 1. input * weight
// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add
template <typename T, bool EP, typename std::enable_if<!EP, int *>::type = nullptr>
__mlu_func__ void weightedReduceSum(T *output, T *input, float *weight, T *input_buffer,
                                    int8_t *og_mask, int topk, int hidden_size, int token_count,
                                    bool &is_ping) {
  float *nram_input_buffer =
      (float *)((half *)input_buffer +
                ((std::is_same<T, float>::value || !is_ping) ? 0 : hidden_size));
  T *output_base = output - ((std::is_same<T, float>::value || is_ping) ? 0 : hidden_size);
  if (topk == 1) {
    for (int i = 0; i < token_count; i++) {
      float expert_coeff = __load_nram(weight + i);
      toFloat(nram_input_buffer, input + i * hidden_size, hidden_size);
      mulScalarCvt(output + i * hidden_size, nram_input_buffer, expert_coeff, hidden_size);
    }
    return;
  }
  for (int t_i = 0; t_i < token_count; t_i++) {
    float *output_begin = (float *)(output_base + t_i * hidden_size);
    float expert_coeff = __load_nram(weight + t_i * topk);
    toFloat(output_begin, input + t_i * topk * hidden_size, hidden_size);
    toFloat(nram_input_buffer, input + (t_i * topk + 1) * hidden_size, hidden_size);
    __bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size);
    expert_coeff = __load_nram(weight + t_i * topk + 1);
    for (int k_i = 2; k_i < topk; k_i++) {
      __bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin,
                    hidden_size, hidden_size);
      expert_coeff = __load_nram(weight + t_i * topk + k_i);
      toFloat(nram_input_buffer, input + (t_i * topk + k_i) * hidden_size, hidden_size);
    }
    mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
  }
  if (!is_ping && sizeof(T) < sizeof(float)) {
    __memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size,
             hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1,
             token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1,
             token_count * hidden_size * sizeof(T), 0);
  }
  is_ping = !is_ping;
}

template <typename T, typename T_IDX>
__mlu_global__ void MLUCombineMoeResultKernel(T *output, T *input, T *bias, T *residual,
                                              float *reduce_weight, int *cusum_token_count,
                                              int *gather_idx, int num_token, int topk,
                                              int num_expert, int hidden_size,
                                              int start_expert_id, int expert_size,
                                              int HIDDEN_BLOCK, int TOKEN_BLOCK) {
  if (__is_mpu()) {
    return;
  }
  int local_hidden_begin = taskIdX * HIDDEN_BLOCK;
  int local_hidden_size = std::min(HIDDEN_BLOCK, hidden_size - local_hidden_begin);
  int task_avg_tokens = num_token / taskDimY;
  int task_remain_tokens = num_token % taskDimY;
  int task_tokens = task_avg_tokens + (int)(taskIdY < task_remain_tokens);
  int task_token_begin = taskIdY * task_avg_tokens + std::min(taskIdY, task_remain_tokens);
  if (local_hidden_size <= 0) return;
  if (task_tokens <= 0) return;

  constexpr int int32_dtype_size = (int)sizeof(int);
  constexpr int fp32_dtype_size = (int)sizeof(float);
  int pad_num_expert = PAD_UP(num_expert + 1, 32);
  bool has_bias = bias != nullptr;
  bool has_residual = residual != nullptr;
  bool using_acc_sum = cusum_token_count != nullptr;
  bool expert_parallelism = expert_size < num_expert;
  int block_size = TOKEN_BLOCK * topk;
  int pad_block_size = PAD_UP(block_size, 64);

  int *nram_expert_tables = (int *)nram_buffer;
  int *nram_token_idx = nram_expert_tables + pad_num_expert;
  T_IDX *nram_token_offset = (T_IDX *)(nram_token_idx + pad_block_size);
  T_IDX *nram_bias_offset = (T_IDX *)(nram_token_offset + pad_block_size);
  int *nram_mask = (int *)(nram_bias_offset + (int)has_bias * pad_block_size);
  T *nram_input_ping = (T *)(nram_mask + pad_block_size);
  T *nram_input_pong = nram_input_ping + block_size * HIDDEN_BLOCK;
  T *nram_bias_ping = nram_input_pong + block_size * HIDDEN_BLOCK;
  T *nram_bias_pong = nram_bias_ping + (int)has_bias * block_size * HIDDEN_BLOCK;
  T *nram_residual_ping = nram_bias_pong + (int)has_bias * block_size * HIDDEN_BLOCK;
  T *nram_residual_pong = nram_residual_ping + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK;
  float *nram_weight_ping =
      (float *)(nram_residual_pong + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK);
  float *nram_weight_pong = nram_weight_ping + pad_block_size;
  int buffer_block_num = sizeof(T) > 2 ? 2 : 3;
  T *nram_output_ping = (T *)(nram_weight_pong + pad_block_size);
  T *nram_input_buffer = nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK;
  T *nram_output_pong = (T *)((char *)nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK * sizeof(T) +
                              buffer_block_num * HIDDEN_BLOCK * sizeof(half));
  int *nram_mask_buffer = (int *)nram_token_offset;
  uint8_t *nram_mask_char = (uint8_t *)(nram_output_pong + TOKEN_BLOCK * HIDDEN_BLOCK);
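  // NRAM layout recap (a reading aid; the pointers above carve one contiguous
  // region out of nram_buffer, sizes in elements unless stated):
  //   expert_tables | token_idx | token_offset | [bias_offset] | mask
  //   | input ping/pong        block_size * HIDDEN_BLOCK each
  //   | [bias ping/pong]       same shape, present only when has_bias
  //   | [residual ping/pong]   TOKEN_BLOCK * HIDDEN_BLOCK each
  //   | weight ping/pong       pad_block_size floats each
  //   | output ping | fp32 spill (buffer_block_num * HIDDEN_BLOCK halves)
  //   | output pong | mask_char
  // The fp32 spill gap is what lets 2-byte dtypes accumulate in fp32 one row
  // past output ping (see the is_ping staggering note above); nram_mask_buffer
  // aliases nram_token_offset, whose offsets are re-derived per tile after the
  // gather that consumes them has been synchronized.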
  int init_token_count = std::min(TOKEN_BLOCK, task_tokens) * topk;
  int begin_expert_acc_tokens = 0;
  int end_expert_acc_tokens = num_token * topk;
  if (using_acc_sum) {
    __memcpy_async(nram_expert_tables, cusum_token_count, (num_expert + 1) * int32_dtype_size,
                   GDRAM2NRAM);
  }
  __memcpy_async(nram_token_idx, gather_idx + task_token_begin * topk,
                 init_token_count * sizeof(int), GDRAM2NRAM);
  __sync_io();
  if (expert_parallelism) {
    begin_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id);
    end_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id + expert_size);
  }
  int active_token_count = getMaskAndActiveTokenCount(
      nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens,
      end_expert_acc_tokens, init_token_count, expert_parallelism);
  computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables,
                num_expert, init_token_count, active_token_count, hidden_size,
                local_hidden_begin, (int)sizeof(T), start_expert_id, expert_size,
                begin_expert_acc_tokens, has_bias);
  __sync_io_move_compute(true, false, false, false, false, true);
  __sync_io_move_compute(false, false, true, true, false, false);

  int next_active_token_count = active_token_count;
  int previous_global_token_begin = 0;
  int previous_token_count = 0;
  bool is_ping = false;
  for (int task_begin = -1; task_begin * TOKEN_BLOCK < task_tokens; task_begin++) {
    int next_token_begin = (task_begin + 1) * TOKEN_BLOCK;
    int next_next_token_begin = (task_begin + 2) * TOKEN_BLOCK;
    bool is_last_loop = next_token_begin >= task_tokens;
    bool is_last_2_loop = next_next_token_begin >= task_tokens;
    int current_token_begin = task_begin * TOKEN_BLOCK;
    int current_token_count = std::min(TOKEN_BLOCK, task_tokens - current_token_begin);
    int next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_token_begin);
    int next_next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_next_token_begin);
    int current_global_token_begin = task_token_begin + current_token_begin;
    int next_global_token_begin = task_token_begin + next_token_begin;
    int next_next_global_token_begin = task_token_begin + next_next_token_begin;
    if (!is_last_loop) {
      if (!is_last_2_loop) {
        loadAsync2d(nram_token_idx, gather_idx + next_next_global_token_begin * topk,
                    next_next_token_count * topk * sizeof(int), 0, 0, 0);
      }
      loadAsync2d(nram_weight_ping, reduce_weight + next_global_token_begin * topk,
                  next_token_count * topk * fp32_dtype_size, 0, 0, 0);
      if (has_residual) {
        loadAsync2d(nram_residual_ping,
                    residual + next_global_token_begin * (uint64_t)hidden_size +
                        local_hidden_begin,
                    local_hidden_size * sizeof(T), local_hidden_size * sizeof(T),
                    hidden_size * sizeof(T), next_token_count - 1);
      }
      gatherTokensAsync(nram_input_ping, input, nram_token_offset,
                        local_hidden_size * sizeof(T), next_active_token_count);
      gatherTokensAsync(nram_bias_ping, bias, nram_bias_offset, local_hidden_size * sizeof(T),
                        next_active_token_count);
    }
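    // Pipeline sketch: with the ping/pong buffers, iteration i prefetches tile
    // i + 1 (and the gather indices for tile i + 2), stores the finished
    // outputs of tile i - 1, and computes tile i; task_begin starts at -1 so
    // the first iteration only primes the loads.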
    if (task_begin >= 1) {
      storeAsync2d(
          output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
          nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T),
          local_hidden_size * sizeof(T), previous_token_count - 1);
    }
    if (task_begin >= 0) {
      if (has_bias && active_token_count) {
        __bang_add(nram_input_pong, nram_input_pong, nram_bias_pong,
                   active_token_count * local_hidden_size);
      }
      if (expert_parallelism) {
        weightedReduceSum<T, true>(nram_output_ping, nram_input_pong, nram_weight_pong,
                                   nram_input_buffer, (int8_t *)nram_mask_char, topk,
                                   local_hidden_size, current_token_count, is_ping);
      } else {
        weightedReduceSum<T, false>(nram_output_ping, nram_input_pong, nram_weight_pong,
                                    nram_input_buffer, (int8_t *)nram_mask_char, topk,
                                    local_hidden_size, current_token_count, is_ping);
      }
      if (has_residual) {
        __bang_add((T *)nram_output_ping, (T *)nram_output_ping, nram_residual_pong,
                   current_token_count * local_hidden_size);
      }
    }
    __sync_io_move_compute();
    active_token_count = next_active_token_count;
    if (expert_parallelism && !is_last_loop) {
      __bang_float2uchar_tz((uint8_t *)nram_mask_char, (float *)nram_mask,
                            next_token_count * topk);
    }
    if (!is_last_2_loop) {
      next_active_token_count = getMaskAndActiveTokenCount(
          nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens,
          end_expert_acc_tokens, next_next_token_count * topk, expert_parallelism);
      computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables,
                    num_expert, next_next_token_count * topk, next_active_token_count,
                    hidden_size, local_hidden_begin, (int)sizeof(T), start_expert_id,
                    expert_size, begin_expert_acc_tokens, has_bias);
    }
    swap(nram_input_ping, nram_input_pong);
    swap(nram_bias_ping, nram_bias_pong);
    swap(nram_residual_ping, nram_residual_pong);
    swap(nram_weight_ping, nram_weight_pong);
    swap(nram_output_ping, nram_output_pong);
    previous_global_token_begin = current_global_token_begin;
    previous_token_count = current_token_count;
  }
  storeAsync2d(output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
               nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T),
               local_hidden_size * sizeof(T), previous_token_count - 1);
}

#if __BANG_ARCH__ < 500
// bfloat16 is not supported before MLU500: provide empty specializations so
// the host-side dispatch still links.
template <>
__mlu_global__ void MLUCombineMoeResultKernel<bfloat16_t, uint32_t>(
    bfloat16_t *output, bfloat16_t *input, bfloat16_t *bias, bfloat16_t *residual,
    float *reduce_weight, int *cusum_token_count, int *gather_ids, int num_token, int topk,
    int num_expert, int hidden_size, int start_expert_id, int expert_size, int HIDDEN_BLOCK,
    int TOKEN_BLOCK) {}

template <>
__mlu_global__ void MLUCombineMoeResultKernel<bfloat16_t, uint64_t>(
    bfloat16_t *output, bfloat16_t *input, bfloat16_t *bias, bfloat16_t *residual,
    float *reduce_weight, int *cusum_token_count, int *gather_ids, int num_token, int topk,
    int num_expert, int hidden_size, int start_expert_id, int expert_size, int HIDDEN_BLOCK,
    int TOKEN_BLOCK) {}
#endif

}  // namespace kernels
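// Host-side entry point. Combines per-expert FFN outputs back into per-token
// rows: output[t] = sum_k reduce_weight[t][k] * input[row(t, k)] (+ residual[t]
// when given), where row(t, k) is resolved through gather_idx and, under
// expert parallelism (expert_size < num_expert), restricted via
// cusum_token_count to experts [start_expert_id, start_expert_id + expert_size).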
KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue, void *output, const void *input,
                                          const void *bias, const void *residual,
                                          const float *reduce_weight,
                                          const int *cusum_token_count, const int *gather_idx,
                                          int num_token, int topk, int num_expert,
                                          int hidden_size, int start_expert_id, int expert_size,
                                          cnnlDataType_t dtype) {
  if (topk > 128 || num_expert > 1024 || hidden_size < 256) {
    std::cerr << "[invokeMoeCombineResultKernel]: "
              << "currently only support topk <= 128, num_expert <= 1024 and hidden_size >= 256."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (bias != nullptr) {
    std::cerr << "[invokeMoeCombineResultKernel]: currently does not support bias." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if ((bias != nullptr || num_expert > expert_size) && cusum_token_count == nullptr) {
    std::cerr << "[invokeMoeCombineResultKernel]: if has bias or expert parallelism, "
              << "cusum_token_count can not be nullptr." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  size_t data_bytes = 0;
  cnnlGetSizeOfDataType(dtype, &data_bytes);
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  // 480KB NRAM: 48KB for token idx, token/bias offsets and weights, 432KB for buffers.
  // Keep TOKEN_BLOCK * topk <= 1024 so that 32KB is enough for idx and offsets.
  int convert_buffer = data_bytes == 2
                           ? 3 * hidden_size * data_bytes
                           : 2 * hidden_size * data_bytes;  // buffer for converting bf16/fp16 to fp32
  int max_input_size =
      (432 * 1024 - convert_buffer) /
      (2 * topk * data_bytes +                      /* input size, double buffer */
       (bias != nullptr) * 2 * topk * data_bytes +  /* bias size, double buffer */
       (residual != nullptr) * 2 * data_bytes +     /* residual size, double buffer */
       2 * data_bytes);                             /* output size, one buffer */
  int TOKEN_BLOCK = 1;
  int HIDDEN_BLOCK = 1;
  int HIDDEN_BLOCK_X_TOKEN_BLOCK = (max_input_size / 64) * 64;
  if (HIDDEN_BLOCK_X_TOKEN_BLOCK < hidden_size) {
    HIDDEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK;
    TOKEN_BLOCK = 1;
  } else {
    HIDDEN_BLOCK = hidden_size;
  }
  // For the latency-bound case: hidden_size is large but the token count is small.
  if (HIDDEN_BLOCK == hidden_size && hidden_size >= 4096 &&
      num_token <= core_num * cluster_num) {
    HIDDEN_BLOCK = (hidden_size + core_num - 1) / core_num;
  }
  HIDDEN_BLOCK = std::min(HIDDEN_BLOCK, 8 * 1024);
  uint32_t task_dim_x = (hidden_size + HIDDEN_BLOCK - 1) / HIDDEN_BLOCK;
  task_dim_x = (task_dim_x < core_num) ? task_dim_x
                                       : ((task_dim_x + core_num - 1) / core_num * core_num);
  uint32_t pad_dim_x = task_dim_x;
  while (pad_dim_x <= cluster_num * core_num) {
    if ((cluster_num * core_num % pad_dim_x == 0)) {
      task_dim_x = pad_dim_x;
      break;
    }
    pad_dim_x += core_num;
  }
  HIDDEN_BLOCK = (hidden_size + task_dim_x - 1) / task_dim_x;
  HIDDEN_BLOCK = (HIDDEN_BLOCK + 63) / 64 * 64;
  if (HIDDEN_BLOCK_X_TOKEN_BLOCK >= hidden_size) {
    TOKEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK / HIDDEN_BLOCK;
  }
  TOKEN_BLOCK = std::min(TOKEN_BLOCK, 1024 / topk);
  float max_cluster_num = core_num * cluster_num / task_dim_x;
  uint32_t task_dim_y = std::min(max_cluster_num, (float)num_token);
  task_dim_y = task_dim_y < 1 ? 1 : task_dim_y;
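  // Worked example of the tiling above (illustrative numbers, not a measured
  // configuration): core_num = 4, cluster_num = 8, hidden_size = 7168, small
  // num_token. The latency branch sets HIDDEN_BLOCK = ceil(7168 / 4) = 1792,
  // so task_dim_x = ceil(7168 / 1792) = 4, which already divides the 32 cores;
  // HIDDEN_BLOCK is re-derived as 1792 (64-aligned), and task_dim_y =
  // min(32 / 4, num_token) puts up to 8 row slices in flight.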
  cnrtDim3_t dim{.x = task_dim_x, .y = task_dim_y, .z = 1};
  bool is_large_tensor = data_bytes * num_token * topk * hidden_size > UINT32_MAX;
  // The task type below is an assumption (the launch configuration was elided
  // in the source): one Block task per MLU core.
  if (dtype == CNNL_DTYPE_FLOAT) {
    if (!is_large_tensor) {
      kernels::MLUCombineMoeResultKernel<float, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (float *)output, (float *)input, (float *)bias, (float *)residual,
          (float *)reduce_weight, (int *)cusum_token_count, (int *)gather_idx, num_token, topk,
          num_expert, hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    } else {
      kernels::MLUCombineMoeResultKernel<float, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (float *)output, (float *)input, (float *)bias, (float *)residual,
          (float *)reduce_weight, (int *)cusum_token_count, (int *)gather_idx, num_token, topk,
          num_expert, hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    }
  } else if (dtype == CNNL_DTYPE_HALF) {
    if (!is_large_tensor) {
      kernels::MLUCombineMoeResultKernel<half, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (half *)output, (half *)input, (half *)bias, (half *)residual, (float *)reduce_weight,
          (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
          start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    } else {
      kernels::MLUCombineMoeResultKernel<half, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (half *)output, (half *)input, (half *)bias, (half *)residual, (float *)reduce_weight,
          (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
          start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    }
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeMoeCombineResultKernel]: MLU300 devices do not support bfloat16."
                << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    if (!is_large_tensor) {
      kernels::MLUCombineMoeResultKernel<bfloat16_t, uint32_t>
          <<<dim, cnrtFuncTypeBlock, queue>>>(
              (bfloat16_t *)output, (bfloat16_t *)input, (bfloat16_t *)bias,
              (bfloat16_t *)residual, (float *)reduce_weight, (int *)cusum_token_count,
              (int *)gather_idx, num_token, topk, num_expert, hidden_size, start_expert_id,
              expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    } else {
      kernels::MLUCombineMoeResultKernel<bfloat16_t, uint64_t>
          <<<dim, cnrtFuncTypeBlock, queue>>>(
              (bfloat16_t *)output, (bfloat16_t *)input, (bfloat16_t *)bias,
              (bfloat16_t *)residual, (float *)reduce_weight, (int *)cusum_token_count,
              (int *)gather_idx, num_token, topk, num_expert, hidden_size, start_expert_id,
              expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    }
  } else {
    std::cerr << "[invokeMoeCombineResultKernel]: the current supported dtype is "
              << "among float/half/bfloat16." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}

}  // namespace tmo
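// Usage sketch (hypothetical setup; the names d_output, d_input, d_residual,
// d_reduce_weight, d_cusum_token_count and d_gather_idx are illustrative, and
// cnrtQueueCreate/cnrtQueueSync assume a recent CNRT):
//
//   cnrtQueue_t queue;
//   CNRT_CHECK(cnrtQueueCreate(&queue));
//   // d_output: [num_token, hidden_size]; d_input: gathered expert outputs;
//   // d_reduce_weight: [num_token, topk]; d_gather_idx: [num_token * topk].
//   auto status = tmo::invokeMoeCombineResultKernel(
//       queue, d_output, d_input, /*bias=*/nullptr, d_residual, d_reduce_weight,
//       d_cusum_token_count, d_gather_idx, num_token, topk, num_expert,
//       hidden_size, /*start_expert_id=*/0, /*expert_size=*/num_expert,
//       CNNL_DTYPE_HALF);
//   CNRT_CHECK(cnrtQueueSync(queue));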