Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/moe/combine_result.mlu
2026-02-04 17:39:32 +08:00

761 lines
36 KiB
Plaintext

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <iostream>
#include "cnrt.h"
#include "combine_result.mluh"
// clang-format off
#include <bang_device_functions_extra.h>
#include <mlu.h>
// clang-format on
#if __BANG_ARCH__ >= 592
#include <bang_fusor.h>
template <typename SrcT>
using bang_fusor = bang::experimental::fusor<SrcT>;
#endif
namespace tmo {
namespace kernels {
// Reserve 32 KiB of NRAM (for compiler/runtime use); everything else in the
// per-core NRAM is handed to the kernels as one scratch buffer.
#define NRAM_REMAIN_SIZE (32 * 1024)
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)
// Single statically allocated NRAM scratch area; the kernel below carves all
// of its on-chip buffers out of this by pointer arithmetic.
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
// Exchange two pointers of the same type (used to flip ping/pong NRAM buffers).
template <typename T>
__mlu_func__ void swap(T *&ping, T *&pong) {
  T *held = pong;
  pong = ping;
  ping = held;
}
// Async gather of `token_count` rows (each `transfer_size` bytes) from GDRAM
// into contiguous NRAM, using per-row byte offsets held in NRAM; `offset_type`
// selects the u32/u64 offset width. Expands inside gatherTokensAsync(), which
// supplies dst / src_gdram / nram_offset / transfer_size / token_count from
// its own scope. Issued on IO pipe 0.
#define GATHER_ASYNC_IO0(offset_type) \
__asm__ __volatile__( \
"gather.vector.async.nram.gdram.nram." #offset_type \
".io0 [%[dst]], [%[src]], [%[offset]], " \
"%[transfer_size], %[transfer_num], %[stride];\n\t" ::[dst] "r"(dst), \
[src] "r"(src_gdram), [offset] "r"(nram_offset), [transfer_size] "r"(transfer_size), \
[transfer_num] "r"(token_count), [stride] "r"(transfer_size))
// Fused dst = (dst_dtype)(nram_input_buffer * expert_coeff): fp32 scalar
// multiply with conversion to `dst_dtype` in one instruction. Expands inside
// mulScalarCvt(), which supplies dst / nram_input_buffer / expert_coeff / size.
#define FUSE_MUL_CVT(dst_dtype) \
__asm__ __volatile__("mult.scalar.nram.crn." #dst_dtype \
".f32 [%[dst]], [%[src0]], %[src1]," \
" %[size];\n\t" ::[dst] "r"(dst), \
[src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size));
// Fused dst = (dst_dtype)(nram_input_buffer * expert_coeff + dst): scalar FMA
// accumulating into dst, with conversion to `dst_dtype`. Expands inside
// mulAddCvt(), same implicit operands as FUSE_MUL_CVT.
#define FUSE_MULADD_CVT(dst_dtype) \
__asm__ __volatile__("muladd.nram.crn." #dst_dtype \
".f32 [%[dst]], [%[src0]], %[src1], [%[dst]]," \
" %[size], %[size];\n\t" ::[dst] "r"(dst), \
[src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size));
// Widen `count` elements of src (half / bfloat16 / float) into fp32 at dst.
// For T == float this degenerates to a vectorized copy (src and dst may be
// distinct buffers), implemented as add-0 since there is no dedicated copy op.
template <typename T>
__mlu_func__ void toFloat(float *dst, T *src, int count) {
if (std::is_same<T, half>::value) {
__bang_half2float(dst, (half *)src, count);
} else if (std::is_same<T, bfloat16_t>::value) {
__bang_bfloat162float(dst, (bfloat16_t *)src, count);
} else if (std::is_same<T, float>::value) {
// add-0 acts as copy when src != dst
__bang_add_scalar((float *)dst, (float *)src, (float)0, count);
}
}
// Inverse of toFloat: narrow `count` fp32 elements at src into T at dst using
// round-to-nearest conversions. For T == float it is a vectorized copy
// (add-0), so dst may alias src harmlessly.
template <typename T>
__mlu_func__ void floatTo(T *dst, float *src, int count) {
if (std::is_same<T, half>::value) {
__bang_float2half_rn((half *)dst, src, count);
} else if (std::is_same<T, bfloat16_t>::value) {
__bang_float2bfloat16_rn((bfloat16_t *)dst, src, count);
} else if (std::is_same<T, float>::value) {
// add-0 acts as copy when src != dst
__bang_add_scalar((float *)dst, (float *)src, (float)0, count);
}
}
// Async strided 2-D load GDRAM -> NRAM: copies `seg_num + 1` segments of
// `size` bytes with independent destination/source strides (strides semantics
// follow __memcpy_async — NOTE(review): confirm seg_num is "extra segments"
// against the BANG docs). On MLU500+ uses a raw ld.async on IO pipe 0;
// otherwise falls back to __memcpy_async. Caller must sync the IO pipe before
// consuming the data.
__mlu_func__ void loadAsync2d(void *dst,
void *src,
int size,
int dststride,
int srcstride,
int seg_num) {
#if __BANG_ARCH__ > 500
__asm__ __volatile__(
"ld.async.stride.nram.gdram.io0 [%[dst]], [%[src]],"
" %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst),
[src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride),
[segnum] "r"(seg_num));
#else
__memcpy_async(dst, src, size, GDRAM2NRAM, dststride, srcstride, seg_num);
#endif
}
// Async strided 2-D store NRAM -> GDRAM, mirror of loadAsync2d but issued on
// IO pipe 1 (st.async ... .io1) so loads and stores can overlap. Same stride /
// seg_num semantics as __memcpy_async; caller must sync before reusing the
// NRAM source buffer.
__mlu_func__ void storeAsync2d(void *dst,
void *src,
int size,
int dststride,
int srcstride,
int seg_num) {
#if __BANG_ARCH__ > 500
__asm__ __volatile__(
"st.async.stride.gdram.nram.io1 [%[dst]], [%[src]],"
" %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst),
[src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride),
[segnum] "r"(seg_num));
#else
__memcpy_async(dst, src, size, NRAM2GDRAM, dststride, srcstride, seg_num);
#endif
}
// Async gather of `token_count` rows of `transfer_size` bytes from src_gdram
// into contiguous NRAM at dst. nram_offset holds one byte offset per row
// (uint32 or uint64, selected by T_IDX). On MLU500+ a single hardware gather
// instruction is used (GATHER_ASYNC_IO0 reads this function's locals);
// older archs loop over per-row async memcpys. No-op for empty input or
// null source (e.g. bias == nullptr).
template <typename T_IDX>
__mlu_func__ void gatherTokensAsync(void *dst,
void *src_gdram,
T_IDX *nram_offset,
int transfer_size,
int token_count) {
if (token_count <= 0 || src_gdram == nullptr) return;
#if __BANG_ARCH__ > 500
if (std::is_same<T_IDX, uint32_t>::value) {
GATHER_ASYNC_IO0(u32);
} else {
GATHER_ASYNC_IO0(u64);
}
#else
for (int k = 0; k < token_count; k++) {
__memcpy_async((int8_t *)dst + k * transfer_size,
(int8_t *)src_gdram + __load_nram(nram_offset + k), transfer_size, GDRAM2NRAM);
}
#endif
}
// Under expert parallelism, builds a 0/1 mask (as float, left in nram_mask)
// selecting tokens whose scatter index falls inside this rank's range
// [begin_expert_acc_tokens, end_expert_acc_tokens), compacts nram_token_idx
// in place so only the selected indices remain, and returns their count.
// Without expert parallelism every token is active and nothing is touched.
// nram_mask_char is unused here (it is a separate buffer the caller fills
// from nram_mask later); nram_mask_buffer is scratch.
__mlu_func__ int getMaskAndActiveTokenCount(int *nram_token_idx,
int *nram_mask,
uint8_t *nram_mask_char,
int *nram_mask_buffer,
int begin_expert_acc_tokens,
int end_expert_acc_tokens,
int token_count,
bool expert_parallelism) {
if (!expert_parallelism) {
return token_count;
}
// mask = (idx < end) && (idx >= begin), then converted to float 0.0/1.0
// because __bang_filter/__bang_count below operate on float masks.
__bang_lt_scalar(nram_mask_buffer, nram_token_idx, end_expert_acc_tokens, token_count);
#if __BANG_ARCH__ >= 592
// Fused ge/and/convert on newer archs.
bang_fusor<int32_t>(nram_mask, nram_token_idx, token_count)
.ge(begin_expert_acc_tokens)
.land(nram_mask_buffer)
.cvt<float>(0);
#else
__bang_ge_scalar(nram_mask, nram_token_idx, begin_expert_acc_tokens, token_count);
__bang_and(nram_mask, nram_mask, nram_mask_buffer, token_count);
__bang_int322float((float *)nram_mask, (int *)nram_mask, token_count, 0);
#endif
// Compact the surviving indices to the front of nram_token_idx (the int
// indices are filtered through a float view; bit patterns are preserved).
__bang_filter((float *)nram_token_idx, (float *)nram_token_idx, (float *)nram_mask, token_count);
int active_token_count = __bang_count((float *)nram_mask, token_count);
return active_token_count;
}
// 64-bit offset variant: nram_offset[i] = (int64)nram_idx[i] * mul_scalar
// + add_scalar, for `token_count` elements. Used for tensors whose byte size
// exceeds the uint32 range. The widening convert, multiply and add are done
// as three separate vector ops (no 64-bit fused FMA available).
__mlu_func__ void computeOffset0(uint64_t *nram_offset,
int *nram_idx,
uint64_t mul_scalar,
int64_t add_scalar,
uint32_t token_count) {
#if __BANG_ARCH__ > 592
// Newer arch takes two extra mode arguments for the int32->int64 convert.
__bang_int322int64((int64_t *)nram_offset, nram_idx, token_count, 0, 0);
#else
__bang_int322int64((int64_t *)nram_offset, nram_idx, token_count);
#endif
__bang_mul_scalar(nram_offset, nram_offset, mul_scalar, token_count);
__bang_add_scalar((int64_t *)nram_offset, (int64_t *)nram_offset, add_scalar, token_count);
}
// 32-bit offset variant: nram_offset[i] = nram_idx[i] * mul_scalar +
// add_scalar via a single fused FMA. Caller guarantees the result fits in
// uint32 (see is_large_tensor dispatch on the host side); add_scalar is
// truncated to int32 here.
__mlu_func__ void computeOffset0(uint32_t *nram_offset,
int *nram_idx,
uint32_t mul_scalar,
int64_t add_scalar,
uint32_t token_count) {
__bang_fusion(FUSION_FMA, nram_offset, (uint32_t *)nram_idx, mul_scalar, (int32_t)add_scalar,
token_count);
}
// Converts the compacted token indices into GDRAM byte offsets for the gather:
//  - nram_token_offset[i] = token_idx[i] * hidden_size * dtype_size
//        + (local_hidden_begin - begin_expert_acc_tokens * hidden_size)
//        * dtype_size
//    i.e. an offset into `input`, which holds only this rank's rows (hence
//    the begin_expert_acc_tokens rebase), sliced at local_hidden_begin.
//  - When has_bias: for each token, counts how many local expert boundaries
//    (nram_expert_tables[start+1 .. start+expert_size-1]) its index has
//    passed — that sum is the token's local expert row — then scales it into
//    a byte offset into `bias` ([expert, hidden_size] layout assumed —
//    NOTE(review): bias path is rejected by the host wrapper, so it is
//    currently untested).
// CAUTION: nram_bias_offset_temp aliases nram_token_offset, so the bias
// offsets must be fully computed before the token offsets overwrite it;
// statement order here matters.
template <typename T_IDX>
__mlu_func__ void computeOffset(T_IDX *nram_token_offset,
T_IDX *nram_bias_offset,
int *nram_token_idx,
int *nram_expert_tables,
int expert_num,
int token_count,
int active_token_count,
int hidden_size,
int local_hidden_begin,
int dtype_size,
int start_expert_id,
int expert_size,
int begin_expert_acc_tokens,
bool has_bias) {
// for large tensor, convert int322int64 then do multiply and add seperately.
if (active_token_count <= 0) return;
if (has_bias) {
int *nram_bias_offset_temp = (int *)nram_token_offset;
__bang_write_zero(nram_bias_offset, active_token_count);
for (int i = start_expert_id + 1; i < start_expert_id + expert_size; i++) {
// accumulate (token_idx >= cusum_tokens[i]) -> local expert index
__bang_ge_scalar(nram_bias_offset_temp, nram_token_idx, nram_expert_tables[i],
active_token_count);
__bang_add((int *)nram_bias_offset, (int *)nram_bias_offset, nram_bias_offset_temp,
active_token_count);
}
// copy the int expert indices into the temp buffer before scaling
__bang_add_scalar(nram_bias_offset_temp, (int *)nram_bias_offset, 0, active_token_count);
computeOffset0(nram_bias_offset, nram_bias_offset_temp, (T_IDX)hidden_size * dtype_size,
(T_IDX)local_hidden_begin * dtype_size, active_token_count);
}
// Rebase global scatter indices to this rank's local row numbering, then
// scale to byte offsets; computed in int64 first to avoid overflow.
int64_t offset =
((int64_t)local_hidden_begin - (int64_t)begin_expert_acc_tokens * hidden_size) * dtype_size;
computeOffset0(nram_token_offset, nram_token_idx, (T_IDX)(hidden_size * dtype_size), offset,
active_token_count);
}
// dst = (T)(nram_input_buffer * expert_coeff) over `size` fp32 elements.
// On MLU500+ the multiply and the fp32->half/bf16 narrowing are fused into a
// single instruction (FUSE_MUL_CVT reads this function's locals by name);
// older archs multiply in fp32 and convert in place.
template <typename T>
__mlu_func__ void mulScalarCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) {
#if __BANG_ARCH__ > 500
if (std::is_same<T, bfloat16_t>::value) {
FUSE_MUL_CVT(bf16);
} else if (std::is_same<T, half>::value) {
FUSE_MUL_CVT(f16);
} else if (std::is_same<T, float>::value) {
__bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size);
}
#else
// fp32 multiply then in-place narrowing convert (no-op for T == float)
__bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size);
floatTo((T *)dst, (float *)dst, size);
#endif
}
// dst = (T)(nram_input_buffer * expert_coeff + dst) over `size` elements —
// the final accumulation step of the weighted reduce. On MLU500+ the FMA and
// the fp32->half/bf16 narrowing are fused (FUSE_MULADD_CVT reads this
// function's locals by name); older archs FMA in fp32 then convert in place.
template <typename T>
__mlu_func__ void mulAddCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) {
#if __BANG_ARCH__ > 500
if (std::is_same<T, bfloat16_t>::value) {
FUSE_MULADD_CVT(bf16);
} else if (std::is_same<T, half>::value) {
FUSE_MULADD_CVT(f16);
} else if (std::is_same<T, float>::value) {
__bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size,
size);
}
#else
__bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size,
size);
floatTo((T *)dst, (float *)dst, size);
#endif
}
// weightedReduceSum with EP split
// input [token_count, k, hidden_size], weight [token_count, k]
// 1. input * weight
// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add
//
// Expert-parallel variant: `input` holds only the gathered rows belonging to
// this rank's experts, packed densely; og_mask is a per-(token, k) byte flag
// saying whether that expert slot is local. token_use_count walks the packed
// rows in order. Per token: the first active slot initializes the fp32
// accumulator, middle slots FMA into it, and the last k-slot (if active)
// finishes with a fused FMA + convert to T; a token with no active slots is
// zeroed. If hidden_size divides unevenly a token may legitimately mix paths.
//
// Ping-pong subtlety for half/bf16: the fp32 accumulator needs 2x the bytes
// of T, so on alternate calls (is_ping) the accumulator is laid out at a
// shifted base, and the trailing strided __memcpy repacks the T rows into
// their final dense layout. NOTE(review): the negative-stride repack is an
// intricate layout trick — verify carefully before touching.
template <typename T,
bool expert_parallelism,
typename std::enable_if<expert_parallelism == true, void *>::type = nullptr>
__mlu_func__ void weightedReduceSum(T *output,
T *input,
float *weight,
T *input_buffer,
int8_t *og_mask,
int topk,
int hidden_size,
int token_count,
bool &is_ping) {
// fp32 staging buffer; offset by hidden_size halves on the pong phase so the
// two phases don't overlap.
float *nram_input_buffer =
(float *)((half *)input_buffer +
((std::is_same<T, float>::value || !is_ping) ? 0 : hidden_size));
T *output_base = output - ((std::is_same<T, float>::value || is_ping) ? 0 : hidden_size);
// og_mask flags are read 4 bytes at a time into int registers; weights are
// preloaded into registers to avoid NRAM loads inside the inner loops.
int32_t index[32];
float reg_weight[128];
int8_t *index_ = (int8_t *)index;
int topk_divide_4 = PAD_UP(topk, 4) / 4;
int token_use_count = 0;
for (int t_i = 0; t_i < token_count; t_i++) {
float *output_begin = (float *)(output_base + t_i * hidden_size);
// Preload this token's topk mask bytes and weights.
for (int i = 0; i < topk_divide_4; i++) {
index[i] = __load_nram((int32_t *)(og_mask + t_i * topk) + i);
float *weight_begin = weight + t_i * topk + i * 4;
reg_weight[i * 4] = __load_nram(weight_begin);
if (i * 4 + 1 < topk) {
reg_weight[i * 4 + 1] = __load_nram(weight_begin + 1);
}
if (i * 4 + 2 < topk) {
reg_weight[i * 4 + 2] = __load_nram(weight_begin + 2);
}
if (i * 4 + 3 < topk) {
reg_weight[i * 4 + 3] = __load_nram(weight_begin + 3);
}
}
// First active slot initializes the fp32 accumulator (output = row * w).
int first_in_expert = 0;
float expert_coeff;
for (; first_in_expert < topk - 1; first_in_expert++) {
bool in_expert_range = index_[first_in_expert];
if (!in_expert_range) continue;
expert_coeff = reg_weight[first_in_expert];
toFloat<T>(output_begin, input + token_use_count * hidden_size, hidden_size);
__bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size);
token_use_count++;
break;
}
if (first_in_expert == topk - 1) {
// No active slot among the first topk-1: either the last slot alone
// contributes (fused mul+convert), or the token is fully remote -> zero.
if (index_[topk - 1]) {
expert_coeff = reg_weight[topk - 1];
toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
token_use_count++;
mulScalarCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
} else {
__bang_write_zero((T *)output_begin, hidden_size);
}
} else {
// Middle active slots accumulate via fp32 FMA.
for (int j = first_in_expert + 1; j < topk - 1; j++) {
bool in_expert_range = index_[j];
if (!in_expert_range) continue;
expert_coeff = reg_weight[j];
toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
token_use_count++;
__bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin,
hidden_size, hidden_size);
}
// Last slot fuses the final FMA with the fp32->T convert; otherwise just
// narrow the accumulator.
if (index_[topk - 1]) {
expert_coeff = reg_weight[topk - 1];
toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
token_use_count++;
mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
} else {
floatTo((T *)output_begin, (float *)output_begin, hidden_size);
}
}
}
// Pong phase with narrow T: repack rows from the fp32-strided layout into
// the dense T layout expected by the store (negative strides walk backwards).
if (!is_ping && sizeof(T) < sizeof(float)) {
__memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size,
hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1,
token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1,
token_count * hidden_size * sizeof(T), 0);
}
is_ping = !is_ping;
}
// weightedReduceSum without EP split
// input [token_count, k, hidden_size], weight [token_count, k]
// 1. input * weight
// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add
//
// Non-EP variant: every (token, k) row is present, so og_mask is unused and
// rows are consumed in order. topk == 1 degenerates to a scaled copy. For
// topk >= 2 the loop is software-pipelined: the weight load and fp32 widen of
// row k_i overlap the FMA of row k_i-1; the final row uses mulAddCvt to fuse
// the last FMA with the fp32->T narrowing. The ping-pong base-shift and the
// trailing negative-stride repack for half/bf16 are identical to the EP
// variant (fp32 accumulator needs 2x the bytes of T) — NOTE(review): verify
// the repack strides carefully before touching.
template <typename T,
bool expert_parallelism,
typename std::enable_if<expert_parallelism == false, void *>::type = nullptr>
__mlu_func__ void weightedReduceSum(T *output,
T *input,
float *weight,
T *input_buffer,
int8_t *og_mask,
int topk,
int hidden_size,
int token_count,
bool &is_ping) {
float *nram_input_buffer =
(float *)((half *)input_buffer +
((std::is_same<T, float>::value || !is_ping) ? 0 : hidden_size));
T *output_base = output - ((std::is_same<T, float>::value || is_ping) ? 0 : hidden_size);
if (topk == 1) {
// Single expert per token: output = input * weight, no accumulation and no
// ping-pong bookkeeping needed (note: early return skips the is_ping flip).
for (int i = 0; i < token_count; i++) {
float expert_coeff = __load_nram(weight + i);
toFloat<T>(nram_input_buffer, input + i * hidden_size, hidden_size);
mulScalarCvt(output + i * hidden_size, nram_input_buffer, expert_coeff, hidden_size);
}
return;
}
for (int t_i = 0; t_i < token_count; t_i++) {
float *output_begin = (float *)(output_base + t_i * hidden_size);
// Initialize accumulator with slot 0 while prefetching slot 1.
float expert_coeff = __load_nram(weight + t_i * topk);
toFloat<T>(output_begin, input + t_i * topk * hidden_size, hidden_size);
toFloat<T>(nram_input_buffer, input + (t_i * topk + 1) * hidden_size, hidden_size);
__bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size);
expert_coeff = __load_nram(weight + t_i * topk + 1);
// Pipelined FMA: accumulate slot k_i-1 while widening slot k_i.
for (int k_i = 2; k_i < topk; k_i++) {
__bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin,
hidden_size, hidden_size);
expert_coeff = __load_nram(weight + t_i * topk + k_i);
toFloat<T>(nram_input_buffer, input + (t_i * topk + k_i) * hidden_size, hidden_size);
}
// Last slot: fused FMA + fp32->T narrowing.
mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
}
// Pong-phase repack for narrow T, same trick as the EP variant.
if (!is_ping && sizeof(T) < sizeof(float)) {
__memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size,
hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1,
token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1,
token_count * hidden_size * sizeof(T), 0);
}
is_ping = !is_ping;
}
// MoE combine-result kernel: for each token, gathers its topk expert outputs
// from `input` (rows addressed via gather_idx), optionally adds bias, computes
// the weighted sum with reduce_weight, optionally adds residual, and writes
// output[num_token, hidden_size].
//
// Work split: taskIdX tiles the hidden dimension in HIDDEN_BLOCK columns,
// taskIdY partitions tokens. T_IDX (uint32/uint64) is the gather byte-offset
// width, chosen by the host wrapper based on total tensor size.
//
// Structure: a 3-stage software pipeline over TOKEN_BLOCK-sized chunks —
// iteration i loads chunk i+1 (and the indices for chunk i+2), computes chunk
// i, and stores chunk i-1 — with ping/pong NRAM buffers swapped each lap.
// Under expert parallelism (expert_size < num_expert) only tokens whose
// scatter index falls in this rank's [begin, end) token range are gathered.
template <typename T, typename T_IDX>
__mlu_global__ void MLUCombineMoeResultKernel(T *output,
T *input,
T *bias,
T *residual,
float *reduce_weight,
int *cusum_token_count,
int *gather_idx,
int num_token,
int topk,
int num_expert,
int hidden_size,
int start_expert_id,
int expert_size,
int HIDDEN_BLOCK,
int TOKEN_BLOCK) {
// Compute cores only; the MPU core has nothing to do here.
if (__is_mpu()) {
return;
}
// Per-task slices: hidden columns from taskIdX, token range from taskIdY
// (remainder tokens spread one-each over the first taskIdY ranks).
int local_hidden_begin = taskIdX * HIDDEN_BLOCK;
int local_hidden_size = std::min(HIDDEN_BLOCK, hidden_size - local_hidden_begin);
int task_avg_tokens = num_token / taskDimY;
int task_remain_tokens = num_token % taskDimY;
int task_tokens = task_avg_tokens + (int)(taskIdY < task_remain_tokens);
int task_token_begin = taskIdY * task_avg_tokens + std::min(taskIdY, task_remain_tokens);
if (local_hidden_size <= 0) return;
if (task_tokens <= 0) return;
constexpr int int32_dtype_size = (int)sizeof(int);
constexpr int fp32_dtype_size = (int)sizeof(float);
int pad_num_expert = PAD_UP(num_expert + 1, 32);
bool has_bias = bias != nullptr;
bool has_residual = residual != nullptr;
bool using_acc_sum = cusum_token_count != nullptr;
bool expert_parallelism = expert_size < num_expert;
int block_size = TOKEN_BLOCK * topk;
int pad_block_size = PAD_UP(block_size, 64);
// Carve all on-chip buffers out of the static nram_buffer. Layout (in
// order): expert cumsum table, token indices, gather offsets, optional bias
// offsets, mask, then the double-buffered input/bias/residual/weight/output
// regions. nram_mask_buffer deliberately aliases nram_token_offset (the
// mask scratch is dead by the time offsets are written). Sizing must match
// the host-side max_input_size computation — change both together.
int *nram_expert_tables = (int *)nram_buffer;
int *nram_token_idx = nram_expert_tables + pad_num_expert;
T_IDX *nram_token_offset = (T_IDX *)(nram_token_idx + pad_block_size);
T_IDX *nram_bias_offset = (T_IDX *)(nram_token_offset + pad_block_size);
int *nram_mask = (int *)(nram_bias_offset + (int)has_bias * pad_block_size);
T *nram_input_ping = (T *)(nram_mask + pad_block_size);
T *nram_input_pong = nram_input_ping + block_size * HIDDEN_BLOCK;
T *nram_bias_ping = nram_input_pong + block_size * HIDDEN_BLOCK;
T *nram_bias_pong = nram_bias_ping + (int)has_bias * block_size * HIDDEN_BLOCK;
T *nram_residual_ping = nram_bias_pong + (int)has_bias * block_size * HIDDEN_BLOCK;
T *nram_residual_pong = nram_residual_ping + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK;
float *nram_weight_ping =
(float *)(nram_residual_pong + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK);
float *nram_weight_pong = nram_weight_ping + pad_block_size;
// Extra half-sized conversion scratch: 3 blocks for 2-byte T (fp32 staging
// needs 2x), 2 for fp32 — mirrors `convert_buffer` on the host side.
int buffer_block_num = sizeof(T) > 2 ? 2 : 3;
T *nram_output_ping = (T *)(nram_weight_pong + pad_block_size);
T *nram_input_buffer = nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK;
T *nram_output_pong = (T *)((char *)nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK * sizeof(T) +
buffer_block_num * HIDDEN_BLOCK * sizeof(half));
int *nram_mask_buffer = (int *)nram_token_offset;
uint8_t *nram_mask_char = (uint8_t *)(nram_output_pong + TOKEN_BLOCK * HIDDEN_BLOCK);
// ---- Prologue: stage the first chunk's indices/masks/offsets. ----
int init_token_count = std::min(TOKEN_BLOCK, task_tokens) * topk;
int begin_expert_acc_tokens = 0;
int end_expert_acc_tokens = num_token * topk;
if (using_acc_sum) {
__memcpy_async(nram_expert_tables, cusum_token_count, (num_expert + 1) * int32_dtype_size,
GDRAM2NRAM);
}
__memcpy_async(nram_token_idx, gather_idx + task_token_begin * topk,
init_token_count * sizeof(int), GDRAM2NRAM);
__sync_io();
if (expert_parallelism) {
// This rank's token-index range: cumsum at its first expert / one past its
// last expert.
begin_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id);
end_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id + expert_size);
}
int active_token_count = getMaskAndActiveTokenCount(
nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens,
end_expert_acc_tokens, init_token_count, expert_parallelism);
computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables, num_expert,
init_token_count, active_token_count, hidden_size, local_hidden_begin,
(int)sizeof(T), start_expert_id, expert_size, begin_expert_acc_tokens, has_bias);
// Pipe-selective barriers before entering the pipeline (argument meanings
// per the BANG sync API — NOTE(review): confirm flag order against docs).
__sync_io_move_compute(true, false, false, false, false, true);
__sync_io_move_compute(false, false, true, true, false, false);
int next_active_token_count = active_token_count;
int previous_global_token_begin = 0;
int previous_token_count = 0;
bool is_ping = false;
// ---- Pipeline: iteration i loads chunk i+1 (+ indices of i+2), computes
// chunk i, stores chunk i-1. task_begin starts at -1 so the first lap only
// loads. ----
for (int task_begin = -1; task_begin * TOKEN_BLOCK < task_tokens; task_begin++) {
int next_token_begin = (task_begin + 1) * TOKEN_BLOCK;
int next_next_token_begin = (task_begin + 2) * TOKEN_BLOCK;
bool is_last_loop = next_token_begin >= task_tokens;
bool is_last_2_loop = next_next_token_begin >= task_tokens;
int current_token_begin = task_begin * TOKEN_BLOCK;
int current_token_count = std::min(TOKEN_BLOCK, task_tokens - current_token_begin);
int next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_token_begin);
int next_next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_next_token_begin);
int current_global_token_begin = task_token_begin + current_token_begin;
int next_global_token_begin = task_token_begin + next_token_begin;
int next_next_global_token_begin = task_token_begin + next_next_token_begin;
// Stage 1: async loads for the NEXT chunk (indices two chunks ahead).
if (!is_last_loop) {
if (!is_last_2_loop) {
loadAsync2d(nram_token_idx, gather_idx + next_next_global_token_begin * topk,
next_next_token_count * topk * sizeof(int), 0, 0, 0);
}
loadAsync2d(nram_weight_ping, reduce_weight + next_global_token_begin * topk,
next_token_count * topk * fp32_dtype_size, 0, 0, 0);
if (has_residual) {
loadAsync2d(nram_residual_ping,
residual + next_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
local_hidden_size * sizeof(T), local_hidden_size * sizeof(T),
hidden_size * sizeof(T), next_token_count - 1);
}
gatherTokensAsync<T_IDX>(nram_input_ping, input, nram_token_offset,
local_hidden_size * sizeof(T), next_active_token_count);
gatherTokensAsync<T_IDX>(nram_bias_ping, bias, nram_bias_offset,
local_hidden_size * sizeof(T), next_active_token_count);
}
// Stage 3 (issued early, runs on IO pipe 1): store the PREVIOUS chunk.
if (task_begin >= 1) {
storeAsync2d(
output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T),
local_hidden_size * sizeof(T), previous_token_count - 1);
}
// Stage 2: compute the CURRENT chunk from the pong buffers.
if (task_begin >= 0) {
if (has_bias && active_token_count) {
__bang_add(nram_input_pong, nram_input_pong, nram_bias_pong,
active_token_count * local_hidden_size);
}
if (expert_parallelism) {
weightedReduceSum<T, true>(nram_output_ping, nram_input_pong, nram_weight_pong,
nram_input_buffer, (int8_t *)nram_mask_char, topk,
local_hidden_size, current_token_count, is_ping);
} else {
weightedReduceSum<T, false>(nram_output_ping, nram_input_pong, nram_weight_pong,
nram_input_buffer, (int8_t *)nram_mask_char, topk,
local_hidden_size, current_token_count, is_ping);
}
if (has_residual) {
__bang_add((T *)nram_output_ping, (T *)nram_output_ping, nram_residual_pong,
current_token_count * local_hidden_size);
}
}
// Drain all pipes before preparing the next chunk's metadata and swapping.
__sync_io_move_compute();
active_token_count = next_active_token_count;
if (expert_parallelism && !is_last_loop) {
// Snapshot the float mask into bytes for weightedReduceSum's og_mask.
__bang_float2uchar_tz((uint8_t *)nram_mask_char, (float *)nram_mask, next_token_count * topk);
}
if (!is_last_2_loop) {
next_active_token_count = getMaskAndActiveTokenCount(
nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens,
end_expert_acc_tokens, next_next_token_count * topk, expert_parallelism);
computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables,
num_expert, next_next_token_count * topk, next_active_token_count, hidden_size,
local_hidden_begin, (int)sizeof(T), start_expert_id, expert_size,
begin_expert_acc_tokens, has_bias);
}
swap(nram_input_ping, nram_input_pong);
swap(nram_bias_ping, nram_bias_pong);
swap(nram_residual_ping, nram_residual_pong);
swap(nram_weight_ping, nram_weight_pong);
swap(nram_output_ping, nram_output_pong);
previous_global_token_begin = current_global_token_begin;
previous_token_count = current_token_count;
}
// Epilogue: store the final chunk (sits in the pong buffer after the swap).
storeAsync2d(output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T),
local_hidden_size * sizeof(T), previous_token_count - 1);
}
// Pre-MLU500 architectures lack bfloat16 support (the host wrapper rejects
// bf16 via isBf16Supported()), but the templated launches above still need
// these instantiations to link — so provide empty stub specializations.
#if __BANG_ARCH__ < 500
template <>
__mlu_global__ void MLUCombineMoeResultKernel<bfloat16_t, uint32_t>(bfloat16_t *output,
bfloat16_t *input,
bfloat16_t *bias,
bfloat16_t *residual,
float *reduce_weight,
int *cusum_token_count,
int *gather_ids,
int num_token,
int topk,
int num_expert,
int hidden_size,
int start_expert_id,
int expert_size,
int HIDDEN_BLOCK,
int TOKEN_BLOCK) {}
template <>
__mlu_global__ void MLUCombineMoeResultKernel<bfloat16_t, uint64_t>(bfloat16_t *output,
bfloat16_t *input,
bfloat16_t *bias,
bfloat16_t *residual,
float *reduce_weight,
int *cusum_token_count,
int *gather_ids,
int num_token,
int topk,
int num_expert,
int hidden_size,
int start_expert_id,
int expert_size,
int HIDDEN_BLOCK,
int TOKEN_BLOCK) {}
#endif
} // namespace kernels
// Host-side launcher for the MoE combine-result kernel.
//
// Validates arguments, sizes HIDDEN_BLOCK/TOKEN_BLOCK tiles to the on-chip
// NRAM budget, derives the task grid (x tiles the hidden dim, y partitions
// tokens), and dispatches the templated kernel by dtype (float/half/bf16) and
// by gather-offset width (uint32 vs uint64 when the input's byte size can
// exceed UINT32_MAX).
//
// Fix vs. original: `std::min(max_cluster_num, num_token)` mixed float and
// int arguments, which is ill-formed for std::min's single-type template;
// num_token is now cast to float explicitly (same numeric result).
//
// Returns KERNEL_STATUS_SUCCESS after enqueuing the kernel, or
// KERNEL_STATUS_FAILED on unsupported arguments (topk > 128,
// num_expert > 1024, hidden_size < 256, bias provided, missing
// cusum_token_count under expert parallelism, or unsupported dtype).
KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue,
                                          void *output,
                                          const void *input,
                                          const void *bias,
                                          const void *residual,
                                          const float *reduce_weight,
                                          const int *cusum_token_count,
                                          const int *gather_idx,
                                          int num_token,
                                          int topk,
                                          int num_expert,
                                          int hidden_size,
                                          int start_expert_id,
                                          int expert_size,
                                          cnnlDataType_t dtype) {
  if (topk > 128 || num_expert > 1024 || hidden_size < 256) {
    std::cerr << "[invokeMoeCombineResultKernel]: "
              << "currently only support topk <= 128, num_expert <= 1024 and hidden_size >= 256.";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (bias != nullptr) {
    std::cerr << "[invokeMoeCombineResultKernel]: currently does not support bias.";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  // NOTE: bias is rejected above, so only the expert-parallelism half of this
  // condition is currently reachable; kept for when bias support lands.
  if ((bias != nullptr || num_expert > expert_size) && cusum_token_count == nullptr) {
    std::cerr << "[invokeMoeCombineResultKernel]: if has bias or expert parallelism, "
              << "cusum_token_count can not be nullptr.";
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  size_t data_bytes = 0;
  cnnlGetSizeOfDataType(dtype, &data_bytes);
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  // 480KB nram size, 48KB for token idx, token/bias offset and weight. 432KB for buffer.
  // TOKEN_BLOCK * topk <= 1024 in case 32KB is enough for idx and offset.
  int convert_buffer = data_bytes == 2
                           ? 3 * hidden_size * data_bytes
                           : 2 * hidden_size * data_bytes;  // buffer for convert bf16/fp16->fp32
  // Elements of (TOKEN_BLOCK * HIDDEN_BLOCK) that fit the NRAM data budget;
  // the denominator mirrors the kernel's buffer layout.
  int max_input_size = (432 * 1024 - convert_buffer) /
                       (2 * topk * data_bytes +                     /*input size, double buffer*/
                        (bias != nullptr) * 2 * topk * data_bytes + /*bias size, double buffer*/
                        (residual != nullptr) * 2 * data_bytes +    /*residual size, double buffer*/
                        2 * data_bytes);                            /*output size, one buffer*/
  int TOKEN_BLOCK = 1;
  int HIDDEN_BLOCK = 1;
  // Largest 64-aligned per-core slab (in elements).
  int HIDDEN_BLOCK_X_TOKEN_BLOCK = (max_input_size / 64) * 64;
  if (HIDDEN_BLOCK_X_TOKEN_BLOCK < hidden_size) {
    HIDDEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK;
    TOKEN_BLOCK = 1;
  } else {
    HIDDEN_BLOCK = hidden_size;
  }
  // for latency case, hidden_size is large but token is small.
  if (HIDDEN_BLOCK == hidden_size && hidden_size >= 4096 && num_token <= core_num * cluster_num) {
    HIDDEN_BLOCK = (hidden_size + core_num - 1) / core_num;
  }
  HIDDEN_BLOCK = std::min(HIDDEN_BLOCK, 8 * 1024);
  uint32_t task_dim_x = (hidden_size + HIDDEN_BLOCK - 1) / HIDDEN_BLOCK;
  task_dim_x =
      (task_dim_x < core_num) ? task_dim_x : ((task_dim_x + core_num - 1) / core_num * core_num);
  // Bump task_dim_x (in steps of core_num) to the nearest divisor of the full
  // machine width, if one exists, so clusters are evenly occupied.
  uint32_t pad_dim_x = task_dim_x;
  while (pad_dim_x <= cluster_num * core_num) {
    if ((cluster_num * core_num % pad_dim_x == 0)) {
      task_dim_x = pad_dim_x;
      break;
    }
    pad_dim_x += core_num;
  }
  // Recompute HIDDEN_BLOCK from the final task_dim_x, rounded up to 64.
  HIDDEN_BLOCK = (hidden_size + task_dim_x - 1) / task_dim_x;
  HIDDEN_BLOCK = (HIDDEN_BLOCK + 63) / 64 * 64;
  if (HIDDEN_BLOCK_X_TOKEN_BLOCK >= hidden_size) {
    TOKEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK / HIDDEN_BLOCK;
  }
  // Cap so TOKEN_BLOCK * topk <= 1024 (idx/offset NRAM budget, see above).
  TOKEN_BLOCK = std::min(TOKEN_BLOCK, 1024 / topk);
  float max_cluster_num = core_num * cluster_num / task_dim_x;
  // Explicit float cast: std::min requires both arguments to have the same
  // type (the original float/int mix is ill-formed).
  uint32_t task_dim_y = std::min(max_cluster_num, (float)num_token);
  task_dim_y = task_dim_y < 1 ? 1 : task_dim_y;
  cnrtDim3_t dim{.x = task_dim_x, .y = task_dim_y, .z = 1};
  // Gather offsets are byte offsets; fall back to 64-bit offsets when the
  // input tensor's byte size can overflow uint32 (data_bytes is size_t, so
  // the product is evaluated in 64 bits).
  bool is_large_tensor = data_bytes * num_token * topk * hidden_size > UINT32_MAX;
  if (dtype == CNNL_DTYPE_FLOAT) {
    if (!is_large_tensor) {
      kernels::MLUCombineMoeResultKernel<float, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (float *)output, (float *)input, (float *)bias, (float *)residual, (float *)reduce_weight,
          (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
          start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    } else {
      kernels::MLUCombineMoeResultKernel<float, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (float *)output, (float *)input, (float *)bias, (float *)residual, (float *)reduce_weight,
          (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
          start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    }
  } else if (dtype == CNNL_DTYPE_HALF) {
    if (!is_large_tensor) {
      kernels::MLUCombineMoeResultKernel<half, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (half *)output, (half *)input, (half *)bias, (half *)residual, (float *)reduce_weight,
          (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
          start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    } else {
      kernels::MLUCombineMoeResultKernel<half, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (half *)output, (half *)input, (half *)bias, (half *)residual, (float *)reduce_weight,
          (int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
          start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    }
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeMoeCombineResultKernel]: MLU300 devices do not support bfloat16."
                << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    if (!is_large_tensor) {
      kernels::MLUCombineMoeResultKernel<bfloat16_t, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (bfloat16_t *)output, (bfloat16_t *)input, (bfloat16_t *)bias, (bfloat16_t *)residual,
          (float *)reduce_weight, (int *)cusum_token_count, (int *)gather_idx, num_token, topk,
          num_expert, hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    } else {
      kernels::MLUCombineMoeResultKernel<bfloat16_t, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
          (bfloat16_t *)output, (bfloat16_t *)input, (bfloat16_t *)bias, (bfloat16_t *)residual,
          (float *)reduce_weight, (int *)cusum_token_count, (int *)gather_idx, num_token, topk,
          num_expert, hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
    }
  } else {
    std::cerr << "[invokeMoeCombineResultKernel]: the current supported dtype is "
              << "among float/half/bfloat16." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo