forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
521
torch_mlu_ops-v1.3.2/csrc/kernels/moe/add_bias_activation.mlu
Normal file
521
torch_mlu_ops-v1.3.2/csrc/kernels/moe/add_bias_activation.mlu
Normal file
@@ -0,0 +1,521 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <ostream>
|
||||
#include "add_bias_activation.mluh"
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
|
||||
namespace kernels {
|
||||
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
|
||||
#define USE_NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 20 * 1024)
|
||||
#define USE_SRAM_SIZE (__MLU_SRAM_SIZE__ * 1024 - 20 * 1024)
|
||||
|
||||
__nram__ int8_t nram_buffer[USE_NRAM_SIZE];
|
||||
__mlu_shared__ int8_t sram_buffer[USE_SRAM_SIZE];
|
||||
|
||||
__mlu_func__ void get_expert_info(int *nram_count,
|
||||
int *count_sram,
|
||||
uint32_t *gather_offset,
|
||||
int real_inner,
|
||||
int tokens_start,
|
||||
int tokens_end,
|
||||
int tokens_load,
|
||||
int expert_deal_start,
|
||||
int expert_deal_end,
|
||||
uint32_t &expert_start,
|
||||
uint32_t &expert_end,
|
||||
uint32_t &tokens_deal_first,
|
||||
int dtype_size) {
|
||||
bool record_start = false;
|
||||
// loop expert to find first and last deal expert in current core
|
||||
for (int expert_id = expert_deal_start; expert_id <= expert_deal_end; expert_id++) {
|
||||
if (__load_nram(nram_count + expert_id + 1) > tokens_start && !record_start) {
|
||||
expert_start = expert_id;
|
||||
tokens_deal_first =
|
||||
std::min(__load_nram(nram_count + expert_id + 1) - 1, tokens_end) - tokens_start + 1;
|
||||
record_start = true;
|
||||
}
|
||||
if (__load_nram(nram_count + expert_id + 1) > tokens_end) {
|
||||
expert_end = expert_id;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// record expert offset to gather bias
|
||||
__bang_write_zero(gather_offset, tokens_load);
|
||||
int tokens_load_total = 0;
|
||||
for (int expert_id = expert_start; expert_id <= expert_end; expert_id++) {
|
||||
int tokens_expand = __load_sram((int *)count_sram + expert_id);
|
||||
if (expert_id == expert_start) {
|
||||
tokens_expand = tokens_deal_first;
|
||||
} else if (expert_id == expert_end) {
|
||||
tokens_expand = tokens_load - tokens_load_total;
|
||||
}
|
||||
if (tokens_expand == 0) {
|
||||
continue;
|
||||
}
|
||||
__bang_write_value(gather_offset + tokens_load_total, tokens_expand,
|
||||
(int)((expert_id - expert_deal_start) * real_inner * dtype_size));
|
||||
tokens_load_total += tokens_expand;
|
||||
}
|
||||
}
|
||||
|
||||
/*************** functions for compute basic operation ***************/
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void add_bias(T *dst_src, T *bias, int number) {
|
||||
// cycle add bias
|
||||
__bang_add((T *)dst_src, (T *)dst_src, (T *)bias, number);
|
||||
}
|
||||
|
||||
template <>
|
||||
__mlu_func__ void add_bias(bfloat16_t *dst_src, bfloat16_t *bias, int number) {
|
||||
#if __BANG_ARCH__ > 500
|
||||
__bang_add((bfloat16_t *)dst_src, (bfloat16_t *)dst_src, (bfloat16_t *)bias, number);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void mul_left_right(T *left, T *right, int number) {
|
||||
__bang_mul((T *)left, (T *)left, (T *)right, number);
|
||||
}
|
||||
|
||||
template <>
|
||||
__mlu_func__ void mul_left_right(bfloat16_t *left, bfloat16_t *right, int number) {
|
||||
#if __BANG_ARCH__ > 500
|
||||
__bang_mul((bfloat16_t *)left, (bfloat16_t *)left, (bfloat16_t *)right, number);
|
||||
#endif
|
||||
}
|
||||
|
||||
__mlu_func__ void do_activation(float *input_left,
|
||||
float *act_space,
|
||||
int number,
|
||||
float active_coef,
|
||||
cnnlActivationMode_t act_type) {
|
||||
if (act_type == CNNL_ACTIVATION_GELU) {
|
||||
__bang_active_gelu((float *)input_left, (float *)input_left, number);
|
||||
} else if (act_type == CNNL_ACTIVATION_SWISH) {
|
||||
float *tmp = input_left;
|
||||
if (active_coef != 1.0f) {
|
||||
__bang_mul_scalar(act_space, input_left, active_coef, number);
|
||||
tmp = act_space;
|
||||
}
|
||||
__bang_active_sigmoid((float *)act_space, (float *)tmp, number);
|
||||
__bang_mul((float *)input_left, (float *)input_left, (float *)act_space, number);
|
||||
}
|
||||
}
|
||||
|
||||
/*************** functions for steps of each loop ***************/
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void gather_bias(T *bias_sram,
|
||||
T *bias_nram,
|
||||
uint32_t *gather_offset,
|
||||
int expert_start,
|
||||
int expert_end,
|
||||
int expert_deal_start,
|
||||
int tokens_deal_first,
|
||||
int tokens_deal,
|
||||
int inner_size,
|
||||
bool is_gated) {
|
||||
#if __BANG_ARCH__ > 500
|
||||
if (is_gated) {
|
||||
__gather_async((T *)bias_nram, (T *)bias_sram, gather_offset, inner_size * sizeof(T), SRAM2NRAM,
|
||||
inner_size * sizeof(T), tokens_deal);
|
||||
__gather_async((T *)bias_nram + tokens_deal * inner_size, (T *)bias_sram + inner_size,
|
||||
gather_offset, inner_size * sizeof(T), SRAM2NRAM, inner_size * sizeof(T),
|
||||
tokens_deal);
|
||||
} else {
|
||||
__gather_async((T *)bias_nram, (T *)bias_sram, gather_offset, inner_size * sizeof(T), SRAM2NRAM,
|
||||
inner_size * sizeof(T), tokens_deal);
|
||||
}
|
||||
#else
|
||||
for (int i = 0; i < tokens_deal; i++) {
|
||||
if (is_gated) {
|
||||
__memcpy_async((T *)bias_nram + i * inner_size, (int8_t *)bias_sram + gather_offset[i],
|
||||
inner_size * sizeof(T), SRAM2NRAM);
|
||||
__memcpy_async((T *)bias_nram + (tokens_deal * inner_size + i * inner_size),
|
||||
(int8_t *)bias_sram + inner_size * sizeof(T) + gather_offset[i],
|
||||
inner_size * sizeof(T), SRAM2NRAM);
|
||||
} else {
|
||||
__memcpy_async((T *)bias_nram + i * inner_size, (int8_t *)bias_sram + gather_offset[i],
|
||||
inner_size * sizeof(T), SRAM2NRAM);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void loadBiasInput(T *input,
|
||||
T *left,
|
||||
T *right,
|
||||
T *bias_nram,
|
||||
T *bias_sram,
|
||||
uint32_t *gather_offset,
|
||||
size_t input_offset,
|
||||
int tokens_deal,
|
||||
int inner_size,
|
||||
uint32_t expert_start,
|
||||
uint32_t expert_end,
|
||||
int expert_deal_start,
|
||||
uint32_t tokens_deal_first,
|
||||
bool is_gated,
|
||||
bool has_bias) {
|
||||
if (is_gated) {
|
||||
// if gated, stride io load input, left/right inner to input_left/right
|
||||
__memcpy_async((T *)left, (T *)input + input_offset, inner_size * sizeof(T), GDRAM2NRAM,
|
||||
inner_size * sizeof(T), inner_size * 2 * sizeof(T), tokens_deal - 1);
|
||||
__memcpy_async((T *)right, (T *)input + input_offset + inner_size, inner_size * sizeof(T),
|
||||
GDRAM2NRAM, inner_size * sizeof(T), inner_size * 2 * sizeof(T), tokens_deal - 1);
|
||||
} else {
|
||||
// if not gated, load input to input_left total
|
||||
__memcpy_async((T *)left, (T *)input + input_offset, tokens_deal * inner_size * sizeof(T),
|
||||
GDRAM2NRAM);
|
||||
}
|
||||
if (has_bias) {
|
||||
__sync_compute();
|
||||
gather_bias((T *)bias_sram, (T *)bias_nram, gather_offset, expert_start, expert_end,
|
||||
expert_deal_start, tokens_deal_first, tokens_deal, inner_size, is_gated);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void computeAddActivation(T *bias_nram,
|
||||
T *left_dst,
|
||||
T *input_right,
|
||||
float *input_left,
|
||||
float *act_space,
|
||||
int tokens_deal,
|
||||
int inner_size,
|
||||
bool is_gated,
|
||||
bool has_bias,
|
||||
float active_coef,
|
||||
cnnlActivationMode_t act_type) {
|
||||
int number = tokens_deal * inner_size;
|
||||
if (has_bias) {
|
||||
add_bias((T *)left_dst, (T *)bias_nram, number);
|
||||
if (is_gated) {
|
||||
add_bias((T *)input_right, (T *)bias_nram + tokens_deal * inner_size, number);
|
||||
}
|
||||
}
|
||||
|
||||
// cast half/bfloat16 to float to acvication, if float, left_dst is same as input_left
|
||||
if (std::is_same<T, half>::value) {
|
||||
__bang_half2float((float *)input_left, (half *)left_dst, number);
|
||||
}
|
||||
#if __BANG_ARCH__ > 500
|
||||
if (std::is_same<T, bfloat16_t>::value) {
|
||||
__bang_bfloat162float((float *)input_left, (bfloat16_t *)left_dst, number);
|
||||
}
|
||||
#endif
|
||||
|
||||
// activation
|
||||
do_activation(input_left, act_space, number, active_coef, act_type);
|
||||
|
||||
// if half/bfloat16, cast float to T to mul
|
||||
if (std::is_same<T, half>::value) {
|
||||
__bang_float2half((half *)input_left, (float *)input_left, number);
|
||||
}
|
||||
#if __BANG_ARCH__ > 500
|
||||
if (std::is_same<T, bfloat16_t>::value) {
|
||||
__bang_float2bfloat16((bfloat16_t *)input_left, (float *)input_left, number);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (is_gated) {
|
||||
mul_left_right((T *)input_left, (T *)input_right, tokens_deal * inner_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void storeOutput(T *output,
|
||||
T *output_nram,
|
||||
size_t output_offset,
|
||||
int output_stride,
|
||||
int tokens_deal,
|
||||
int inner_size) {
|
||||
__memcpy_async((T *)output + output_offset, (T *)output_nram, inner_size * sizeof(T), NRAM2GDRAM,
|
||||
output_stride * sizeof(T), inner_size * sizeof(T), tokens_deal - 1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_global__ void MLUAddBiasActivationKernel(T *output,
|
||||
const T *input,
|
||||
const T *bias,
|
||||
const int *cusum_token_count,
|
||||
int num_expert,
|
||||
int total_tokens,
|
||||
int inner_size,
|
||||
int output_stride,
|
||||
bool is_gated,
|
||||
cnnlActivationMode_t act_type,
|
||||
int start_expert_id,
|
||||
int expert_size,
|
||||
float active_coef) {
|
||||
// if bias and token_count is nullptr, not add bias, only activation and gated mul.
|
||||
bool has_bias = (bias != nullptr);
|
||||
|
||||
// 1. distrubute nram space
|
||||
/* if not gated
|
||||
---------------------------- ----------------------------------
|
||||
| nram_token_count | bias_ping/pong |
|
||||
| num_expert * sizeof(int) | 2 * x * inner_size * sizeof(T) |
|
||||
---------------------------- ----------------------------------
|
||||
--------------------------------------
|
||||
| input_ping/pong |
|
||||
| 2 * x * inner_size * sizeof(float) |
|
||||
--------------------------------------
|
||||
------------------------------------------------- -------------------
|
||||
| act_space | gather_offset |
|
||||
| gelu: 0; silu: x * inner_size * sizeof(float) | x * sizeof(int) |
|
||||
------------------------------------------------- -------------------
|
||||
*/
|
||||
|
||||
/* if gated
|
||||
---------------------------- ----------------------------------
|
||||
| nram_token_count | bias_ping/pong |
|
||||
| num_expert * sizeof(int) | 2 * x * real_inner * sizeof(T) |
|
||||
---------------------------- ----------------------------------
|
||||
-------------------------------------- ----------------------------------
|
||||
| input_left_ping/pong | input_right_ping/pong |
|
||||
| 2 * x * inner_size * sizeof(float) | 2 * x * inner_size * sizeof(T) |
|
||||
-------------------------------------- ----------------------------------
|
||||
------------------------------------------------- -------------------
|
||||
| act_space | gather_offset |
|
||||
| gelu: 0; silu: x * inner_size * sizeof(float) | x * sizeof(int) |
|
||||
------------------------------------------------- -------------------
|
||||
*/
|
||||
|
||||
// distribute sram
|
||||
int8_t *count_sram = (int8_t *)sram_buffer;
|
||||
int8_t *bias_sram = (int8_t *)count_sram + num_expert * sizeof(int);
|
||||
|
||||
// distrubute nram
|
||||
int real_inner = (is_gated) ? inner_size * 2 : inner_size;
|
||||
int bias_nram_size = (has_bias) ? real_inner * sizeof(T) : 0;
|
||||
int act_space_size = (act_type == CNNL_ACTIVATION_GELU) ? 0 : inner_size * sizeof(float);
|
||||
int gated_ext_size = is_gated ? sizeof(T) : 0;
|
||||
int max_token_deal = (USE_NRAM_SIZE - (num_expert + 1) * sizeof(int)) /
|
||||
(2 * inner_size * (sizeof(float) + gated_ext_size) + act_space_size +
|
||||
2 * bias_nram_size + sizeof(int));
|
||||
|
||||
int8_t *nram_count = (int8_t *)nram_buffer;
|
||||
int8_t *bias_nram = (int8_t *)nram_count + (num_expert + 1) * sizeof(int);
|
||||
int8_t *input_left = (int8_t *)bias_nram + 2 * max_token_deal * bias_nram_size;
|
||||
int8_t *input_right =
|
||||
(int8_t *)input_left + 2 * ((is_gated) ? max_token_deal * inner_size * sizeof(float) : 0);
|
||||
int8_t *act_space = (int8_t *)input_right +
|
||||
2 * max_token_deal * inner_size * (is_gated ? sizeof(T) : sizeof(float));
|
||||
int8_t *gather_offset = (int8_t *)act_space + max_token_deal * act_space_size;
|
||||
|
||||
// 2. cusum_token_count load to nram, because need to reuse in load bias.
|
||||
if (has_bias) {
|
||||
__memcpy((int *)nram_count, (int *)cusum_token_count, (num_expert + 1) * sizeof(int),
|
||||
GDRAM2NRAM);
|
||||
if (taskIdX == 0) {
|
||||
__bang_sub((int *)bias_nram, (int *)nram_count + 1, (int *)nram_count, num_expert);
|
||||
__sync();
|
||||
__memcpy((int *)count_sram, (int *)bias_nram, num_expert * sizeof(int), NRAM2SRAM);
|
||||
}
|
||||
__sync_cluster();
|
||||
}
|
||||
|
||||
// 3. sram loop to compute
|
||||
// compute once load bias to sram due to sram_limit
|
||||
int max_expert_deal = (USE_SRAM_SIZE - num_expert * sizeof(int)) / (real_inner * sizeof(T));
|
||||
int real_expert = cusum_token_count == nullptr ? num_expert : expert_size;
|
||||
int sram_loop_rem = real_expert % max_expert_deal;
|
||||
int sram_loop = real_expert / max_expert_deal + (int)(sram_loop_rem != 0);
|
||||
if (!has_bias) {
|
||||
max_expert_deal = real_expert;
|
||||
sram_loop = 1;
|
||||
sram_loop_rem = 0;
|
||||
}
|
||||
for (int deal_loop = 0; deal_loop < sram_loop; deal_loop++) {
|
||||
// load current bias, compute each core deal number
|
||||
int expert_deal =
|
||||
(deal_loop == (sram_loop - 1) && sram_loop_rem != 0) ? sram_loop_rem : max_expert_deal;
|
||||
int expert_deal_start = deal_loop * max_expert_deal + start_expert_id;
|
||||
int expert_deal_end = expert_deal_start + expert_deal - 1;
|
||||
__sync_all();
|
||||
if (has_bias && __is_mpu()) {
|
||||
__memcpy((T *)bias_sram, (T *)bias + deal_loop * max_expert_deal * real_inner,
|
||||
expert_deal * real_inner * sizeof(T), GDRAM2SRAM);
|
||||
}
|
||||
__sync_all();
|
||||
|
||||
// get tokens info of each core
|
||||
int tokens_total_cur = total_tokens;
|
||||
if (has_bias) {
|
||||
tokens_total_cur =
|
||||
__load_nram((int *)nram_count + expert_deal_end + 1) -
|
||||
(start_expert_id == 0 ? 0 : __load_nram((int *)nram_count + expert_deal_start));
|
||||
} else if (cusum_token_count != nullptr) {
|
||||
tokens_total_cur = __load_gdram(cusum_token_count + expert_deal_end + 1) -
|
||||
__load_gdram(cusum_token_count + expert_deal_start);
|
||||
}
|
||||
|
||||
if (sram_loop != 1) {
|
||||
tokens_total_cur = ((int *)nram_count)[expert_deal_end + 1] -
|
||||
((deal_loop == 0) ? 0 : ((int *)nram_count)[expert_deal_start]);
|
||||
if (deal_loop == 0 && start_expert_id != 0) {
|
||||
tokens_total_cur -= __load_nram((int *)nram_count + expert_deal_start);
|
||||
}
|
||||
}
|
||||
|
||||
int tokens_core_rem = tokens_total_cur % taskDim;
|
||||
int tokens_cur_core = tokens_total_cur / taskDim + (taskId < tokens_core_rem);
|
||||
if (tokens_cur_core == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// if ep, input start in current token, have a real start in total network
|
||||
int real_start =
|
||||
cusum_token_count != nullptr ? __load_gdram((int *)cusum_token_count + start_expert_id) : 0;
|
||||
int tokens_core_start = tokens_cur_core * taskId +
|
||||
(taskId < tokens_core_rem ? 0 : tokens_core_rem) +
|
||||
((deal_loop == 0) ? 0 : ((int *)nram_count)[expert_deal_start]);
|
||||
if (deal_loop != 0) {
|
||||
tokens_core_start -= real_start;
|
||||
}
|
||||
|
||||
uint32_t expert_start = 0;
|
||||
uint32_t expert_end = 0;
|
||||
uint32_t tokens_deal_first = 0;
|
||||
|
||||
// 4. nram loop compute
|
||||
int nram_loop_rem = tokens_cur_core % max_token_deal;
|
||||
int nram_loop = tokens_cur_core / max_token_deal + (int)(nram_loop_rem != 0);
|
||||
int tokens_load = max_token_deal;
|
||||
int tokens_compute = max_token_deal;
|
||||
int tokens_store = max_token_deal;
|
||||
|
||||
for (int loop = 0; loop < nram_loop + 2; loop++) {
|
||||
int inner_io_offset = (loop % 2) * max_token_deal * inner_size;
|
||||
int inner_com_offset = ((loop + 1) % 2) * max_token_deal * inner_size;
|
||||
int real_io_offset = (loop % 2) * max_token_deal * real_inner;
|
||||
int real_com_offset = ((loop + 1) % 2) * max_token_deal * real_inner;
|
||||
if (nram_loop_rem != 0) {
|
||||
if (loop > 1 && (loop - 2) == (nram_loop - 1)) {
|
||||
tokens_store = nram_loop_rem;
|
||||
}
|
||||
if (loop > 0 && (loop - 1) == (nram_loop - 1)) {
|
||||
tokens_compute = nram_loop_rem;
|
||||
}
|
||||
if (loop == (nram_loop - 1)) {
|
||||
tokens_load = nram_loop_rem;
|
||||
}
|
||||
}
|
||||
int tokens_cur_start = tokens_core_start + loop * max_token_deal;
|
||||
int tokens_cur_end = tokens_cur_start + tokens_load - 1;
|
||||
// get current load info
|
||||
if (loop < nram_loop && has_bias) {
|
||||
get_expert_info((int *)nram_count, (int *)count_sram, (uint32_t *)gather_offset, real_inner,
|
||||
tokens_cur_start + real_start, tokens_cur_end + real_start, tokens_load,
|
||||
expert_deal_start, expert_deal_end, expert_start, expert_end,
|
||||
tokens_deal_first, sizeof(T));
|
||||
}
|
||||
|
||||
// store
|
||||
if (loop > 1) {
|
||||
size_t output_offset = (tokens_core_start + (loop - 2) * max_token_deal) * output_stride;
|
||||
storeOutput((T *)output, (T *)((float *)input_left + inner_io_offset), output_offset,
|
||||
output_stride, tokens_store, inner_size);
|
||||
}
|
||||
|
||||
// compute
|
||||
if (loop > 0 && loop <= nram_loop) {
|
||||
T *left_dst = (T *)((float *)input_left + inner_com_offset) +
|
||||
((std::is_same<T, float>::value) ? 0 : tokens_compute * inner_size);
|
||||
computeAddActivation((T *)bias_nram + real_com_offset, (T *)left_dst,
|
||||
(T *)input_right + inner_com_offset,
|
||||
(float *)input_left + inner_com_offset, (float *)act_space,
|
||||
tokens_compute, inner_size, is_gated, has_bias, active_coef, act_type);
|
||||
}
|
||||
|
||||
// load
|
||||
if (loop < nram_loop) {
|
||||
T *left_dst = (T *)((float *)input_left + inner_io_offset) +
|
||||
((std::is_same<T, float>::value) ? 0 : tokens_load * inner_size);
|
||||
size_t input_offset = tokens_cur_start * real_inner;
|
||||
loadBiasInput((T *)input, (T *)left_dst, (T *)input_right + inner_io_offset,
|
||||
(T *)bias_nram + real_io_offset, (T *)bias_sram, (uint32_t *)gather_offset,
|
||||
input_offset, tokens_load, inner_size, expert_start, expert_end,
|
||||
expert_deal_start, tokens_deal_first, is_gated, has_bias);
|
||||
}
|
||||
__sync();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
KernelStatus invokeGroupAddBiasActivationKernel(cnrtQueue_t queue,
|
||||
void *output,
|
||||
const void *input,
|
||||
const void *bias,
|
||||
const int *cusum_token_count,
|
||||
int num_expert,
|
||||
int total_tokens,
|
||||
int inner_size,
|
||||
int output_stride,
|
||||
cnnlDataType_t dtype,
|
||||
bool is_gated,
|
||||
cnnlActivationMode_t act_type,
|
||||
int start_expert_id,
|
||||
int expert_size,
|
||||
float active_coef) {
|
||||
if (bias != NULL && cusum_token_count == NULL) {
|
||||
std::cerr << "[invokeGroupAddBiasActivationKernel]: "
|
||||
<< "when have bias, cusum_token_count can not be nullptr.";
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
if (act_type != CNNL_ACTIVATION_GELU && act_type != CNNL_ACTIVATION_SWISH) {
|
||||
std::cerr << "[invokeGroupAddBiasActivationKernel]: "
|
||||
<< "activation mode only supports gelu and swish now.";
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
CNdev dev;
|
||||
cnCtxGetDevice(&dev);
|
||||
int cluster_num;
|
||||
int core_num;
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
|
||||
cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
|
||||
|
||||
if (dtype == CNNL_DTYPE_FLOAT) {
|
||||
kernels::MLUAddBiasActivationKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
|
||||
(float *)output, (const float *)input, (const float *)bias, cusum_token_count, num_expert,
|
||||
total_tokens, inner_size, output_stride, is_gated, act_type, start_expert_id, expert_size,
|
||||
active_coef);
|
||||
} else if (dtype == CNNL_DTYPE_HALF) {
|
||||
kernels::MLUAddBiasActivationKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
|
||||
(half *)output, (const half *)(input), (const half *)bias, cusum_token_count, num_expert,
|
||||
total_tokens, inner_size, output_stride, is_gated, act_type, start_expert_id, expert_size,
|
||||
active_coef);
|
||||
} else if (dtype == CNNL_DTYPE_BFLOAT16) {
|
||||
if (!isBf16Supported()) {
|
||||
std::cerr << "[invokeGroupAddBiasActivationKernel]: MLU300 devices do not support bfloat16."
|
||||
<< std::endl;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
kernels::MLUAddBiasActivationKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
|
||||
(bfloat16_t *)output, (const bfloat16_t *)input, (const bfloat16_t *)bias,
|
||||
cusum_token_count, num_expert, total_tokens, inner_size, output_stride, is_gated, act_type,
|
||||
start_expert_id, expert_size, active_coef);
|
||||
} else {
|
||||
std::cerr << "[invokeGroupAddBiasActivationKernel]: add_bias_activation data_type not support, "
|
||||
<< "only support float/half/bfloat16." << std::endl;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
return KernelStatus::KERNEL_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,84 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_
|
||||
#define CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_
|
||||
#include "../kernel_utils.h"
|
||||
#include "cnnl.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Add bias and activate to all tokens. Different expert with different bias.
|
||||
* If in gated mode, add bias to all input. But only activate in the
|
||||
* input of [:, :inner_size]. Then multiply it to the input of [:, inner_size:].
|
||||
* Else, add bias and activate to all input, has no multiply process.
|
||||
* @example
|
||||
* is_gated = true, num_expert = 4, total_tokens = 6, inner_size = 2
|
||||
* input: (6, 4) = [[2, 4, 5, 6], [1, 4, 5, 3],
|
||||
* [3, 5, 7, 8], [6, 8, 5, 3],
|
||||
* [2, 3, 4 ,5], [2, 9, 2, 3]]
|
||||
* bias: (4, 4) = [[1, 0, 1, 0], [0, 1, 2, 2], [2, 3, 2, 3], [1, 2, 3, 4]]
|
||||
* token_count = [2, 2, 1, 1]
|
||||
* first step: add bias
|
||||
* [[2+1, 4+0, 5+1, 6+0], [1+1, 4+0, 5+1, 3+0],
|
||||
* [3+0, 5+1, 7+2, 8+2], [6+0, 8+1, 5+2, 3+2],
|
||||
* [2+2, 3+3, 4+2, 5+3], [2+1, 9+2, 2+3, 3+4]]
|
||||
* second step: act and mul
|
||||
* output: (6, 2) = [[act(3)*6, act(4)*6], [act(2)*6, act(4)*3],
|
||||
* [act(3)*9, act(6)*10], [act(6)*7, act(9)*5],
|
||||
* [act(4)*6, act(6)*8], [act(3)*5, act(11)*7]]
|
||||
* @param queue: The queue for mlu.
|
||||
* @param output: Output. Pointer to the MLU memory that stores the result.
|
||||
* When is_gated is true, The shape is [total_tokens, input_size / 2].
|
||||
* In this case, the input_size must be even. Otherwise the shape is [total_tokens,
|
||||
* input_size]. The memory can be discontinuous in total_tokens dim. The stride is output_stride.
|
||||
* @param input: Input. Pointer to the MLU memory that stores the input tokens.
|
||||
* The shape is [total_tokens, input_size].
|
||||
* When is_gated is true, the shape is [total_tokens, 2 * inner_size].
|
||||
* Otherwise the shape is [total_tokens, inner_size].
|
||||
* @param bias: Input. Pointer to the MLU memory that stores the bias. The memory must be
|
||||
* continuous. When is_gated is true, the shape is [num_expert, 2 * inner_size]. Otherwise the shape
|
||||
* is [num_expert, inner_size]. Bias can be nullptr. If bias is nullptr, has no add bias process.
|
||||
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the prefix sum of token
|
||||
* counts. The shape is [num_expert + 1]. If cusum_token_count
|
||||
* is not nullptr, cusum_token_count, start_expert_id and
|
||||
* expert_size together determine which tokens to process.
|
||||
* If cusum_token_count is nullptr, process all tokens,
|
||||
* the number of which is total_tokens. When bias is not nullptr,
|
||||
* cusum_token_count must also not be nullptr.
|
||||
* @param num_expert: The number of expert.
|
||||
* @param total_tokens: The total number of tokens.
|
||||
* @param inner_size: The inner size of output.
|
||||
* @param output_stride: The stride of output, must be greater than or equal to inner_size.
|
||||
* @param dtype: Data type.
|
||||
* @param is_gated: Gated or not.
|
||||
* @param act_type: The type of activation. Support gelu and swish.
|
||||
* @param start_expert_id: The index of the start expert.
|
||||
* @param expert_size: The number of experts to process.
|
||||
* @param active_coef: The coefficient used in the swish activation.
|
||||
*/
|
||||
KernelStatus invokeGroupAddBiasActivationKernel(cnrtQueue_t queue,
|
||||
void *output,
|
||||
const void *input,
|
||||
const void *bias,
|
||||
const int *cusum_token_count,
|
||||
int num_expert,
|
||||
int total_tokens,
|
||||
int inner_size,
|
||||
int output_stride,
|
||||
cnnlDataType_t dtype,
|
||||
bool is_gated,
|
||||
cnnlActivationMode_t act_type,
|
||||
int start_expert_id,
|
||||
int expert_size,
|
||||
float active_coef);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_
|
||||
646
torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mlu
Normal file
646
torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mlu
Normal file
@@ -0,0 +1,646 @@
|
||||
#include <stdint.h>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include "cast_gating.mluh"
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
namespace tmo {
|
||||
#define DIV_UP(x, y) ((x) % (y) > 0 ? ((x) / (y) + 1) : ((x) / (y)))
|
||||
|
||||
#define NRAM_BUFFER_SIZE (496 * 1024)
|
||||
#define WRAM_BUFFER_SIZE (512 * 1024)
|
||||
#define SRAM_BUFFER_SIZE (2032 * 1024)
|
||||
|
||||
#ifndef ONE_LINE
|
||||
#define ONE_LINE 64
|
||||
#endif
|
||||
|
||||
#ifndef LT_NUM
|
||||
#define LT_NUM 64
|
||||
#endif
|
||||
|
||||
struct castGatingTileInfo {
|
||||
int32_t block = 64;
|
||||
int32_t split_k_num = 8;
|
||||
int32_t block_k = 256;
|
||||
};
|
||||
|
||||
namespace kernels {
|
||||
#pragma bang walign(16)
|
||||
#ifndef ROW_PER_LT
|
||||
#define ROW_PER_LT 4
|
||||
#endif
|
||||
|
||||
#ifndef LT_SIZE
|
||||
#define LT_SIZE 16
|
||||
#endif
|
||||
|
||||
#ifndef WRAM_LT_MAP16_STRIDE
|
||||
#define WRAM_LT_MAP16_STRIDE (WRAM_BUFFER_SIZE / 16)
|
||||
#endif
|
||||
|
||||
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
|
||||
__wram__ int8_t wram_buffer[WRAM_BUFFER_SIZE];
|
||||
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
|
||||
|
||||
#define SRAM2NRAM_CONVERT_IMPL(dst, src, size, dst_dsize, src_dsize, convert_type) \
|
||||
do { \
|
||||
uint32_t align_num = 64 / src_dsize; \
|
||||
uint32_t n = PAD_DOWN(size / src_dsize, align_num); \
|
||||
uint32_t rem = size % 64; \
|
||||
if (n) { \
|
||||
__asm__ __volatile__( \
|
||||
"move.tiling.async.nram.sram.b16" \
|
||||
" [%[dst_addr]], [%[src_addr]], " \
|
||||
"%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], " \
|
||||
"%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], " \
|
||||
"%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], " \
|
||||
"%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4]," \
|
||||
"%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst), \
|
||||
[src_addr] "r"(src), [src_n0] "i"(64), [src_n1] "i"(1), [src_s1] "i"(0), \
|
||||
[src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "r"(n / align_num), \
|
||||
[src_s3] "r"(align_num * src_dsize), [src_n4] "i"(1), [src_s4] "i"(0), [src_n5] "i"(1), \
|
||||
[src_s5] "i"(0), [dst_n0] "i"(64), [dst_n1] "i"(1), [dst_s1] "i"(0), [dst_n2] "i"(1), \
|
||||
[dst_s2] "i"(0), [dst_n3] "r"(n / align_num), [dst_s3] "r"(align_num * dst_dsize), \
|
||||
[dst_n4] "i"(1), [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0)); \
|
||||
} \
|
||||
\
|
||||
if (rem) { \
|
||||
__asm__ __volatile__( \
|
||||
"move.tiling.async.nram.sram.b16" \
|
||||
" [%[dst_addr]], [%[src_addr]], " \
|
||||
"%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], " \
|
||||
"%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], " \
|
||||
"%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], " \
|
||||
"%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4]," \
|
||||
"%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst + n), \
|
||||
[src_addr] "r"(src + n), [src_n0] "r"(rem), [src_n1] "i"(1), [src_s1] "i"(0), \
|
||||
[src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "i"(1), [src_s3] "i"(0), [src_n4] "i"(1), \
|
||||
[src_s4] "i"(0), [src_n5] "i"(1), [src_s5] "i"(0), [dst_n0] "r"(rem), [dst_n1] "i"(1), \
|
||||
[dst_s1] "i"(0), [dst_n2] "i"(1), [dst_s2] "i"(0), [dst_n3] "i"(1), [dst_s3] "i"(0), \
|
||||
[dst_n4] "i"(1), [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0)); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
__mlu_func__ void warp_prompt_input(float *dst, half *src, int32_t size) {
|
||||
#if __BANG_ARCH__ >= 500
|
||||
SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(half), ".cvt.rn.f32.f16()");
|
||||
#endif
|
||||
}
|
||||
|
||||
__mlu_func__ void warp_prompt_input(float *dst, bfloat16_t *src, int32_t size) {
|
||||
#if __BANG_ARCH__ >= 500
|
||||
SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(bfloat16_t), ".cvt.rn.f32.bf16()");
|
||||
#endif
|
||||
}
|
||||
|
||||
__mlu_func__ void warp_prompt_input(float *dst, float *src, int32_t size) {
|
||||
__memcpy_async((float *)dst, (float *)src, size, SRAM2NRAM);
|
||||
}
|
||||
|
||||
#define SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, n, k, total_k, dst_dsize, src_dsize, \
|
||||
convert_type) \
|
||||
int align_n = PAD_DOWN(n, LT_NUM); \
|
||||
int sn0 = ONE_LINE; \
|
||||
int size_sn0 = sn0 / src_dsize; \
|
||||
int sn1 = ONE_LINE / src_dsize; \
|
||||
int ss1 = total_k * src_dsize; \
|
||||
int sn3 = k / size_sn0; \
|
||||
int sn4 = align_n / sn1; \
|
||||
int ss4 = sn1 * ss1; \
|
||||
int dn0 = sn0; \
|
||||
int dn1 = ROW_PER_LT; \
|
||||
int dst_k = PAD_UP(k, ONE_LINE / dst_dsize); \
|
||||
int ds1 = dst_k * dst_dsize; \
|
||||
int dn2 = sn1 / ROW_PER_LT; \
|
||||
int ds2 = WRAM_LT_MAP16_STRIDE; \
|
||||
int ds3 = sn0 * dst_dsize / src_dsize; \
|
||||
int dn4 = LT_SIZE / dn2; \
|
||||
int ds4 = dn2 * WRAM_LT_MAP16_STRIDE; \
|
||||
int dn5 = align_n / LT_NUM; \
|
||||
int ds5 = ROW_PER_LT * dst_k * dst_dsize; \
|
||||
int rem_k = k % size_sn0; \
|
||||
int8_t *sram_src2 = (int8_t *)sram_src + sn3 * size_sn0 * src_dsize; \
|
||||
int8_t *wram_dst2 = (int8_t *)wram_dst + sn3 * size_sn0 * dst_dsize; \
|
||||
if (align_n > 0 && sn3 > 0) { \
|
||||
__asm__ __volatile__( \
|
||||
"move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
|
||||
"%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], " \
|
||||
"%[src_n4], %[src_s4], 1, 0, %[dst_n0], " \
|
||||
"%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
|
||||
"%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type \
|
||||
";\n\t" ::[dst_addr] "r"(wram_dst), \
|
||||
[src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1), \
|
||||
[src_n3] "r"(sn3), [src_s3] "r"(sn0), [src_n4] "r"(sn4), [src_s4] "r"(ss4), \
|
||||
[dst_n0] "r"(dn0), [dst_n1] "r"(dn1), [dst_s1] "r"(ds1), [dst_n2] "r"(dn2), \
|
||||
[dst_s2] "r"(ds2), [dst_n3] "r"(sn3), [dst_s3] "r"(ds3), [dst_n4] "r"(dn4), \
|
||||
[dst_s4] "r"(ds4), [dst_n5] "r"(dn5), [dst_s5] "r"(ds5)); \
|
||||
sram_src += align_n * total_k; \
|
||||
wram_dst += align_n / LT_SIZE * dst_k; \
|
||||
} \
|
||||
align_n = PAD_UP(n % LT_NUM, ROW_PER_LT); \
|
||||
if (align_n > 0 && sn3 > 0) { \
|
||||
sn1 = align_n; \
|
||||
dn2 = (sn1 + ROW_PER_LT - 1) / ROW_PER_LT; \
|
||||
__asm__ __volatile__( \
|
||||
"move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
|
||||
"%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], 1, 0, 1, 0, " \
|
||||
"%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
|
||||
"1, 0, 1, 0, " convert_type ";\n\t" ::[dst_addr] "r"(wram_dst), \
|
||||
[src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1), \
|
||||
[src_n3] "r"(sn3), [src_s3] "r"(sn0), [dst_n0] "r"(dn0), [dst_n1] "r"(dn1), \
|
||||
[dst_s1] "r"(ds1), [dst_n2] "r"(dn2), [dst_s2] "r"(ds2), [dst_n3] "r"(sn3), \
|
||||
[dst_s3] "r"(ds3)); \
|
||||
sram_src += align_n * total_k; \
|
||||
wram_dst += align_n / ROW_PER_LT * WRAM_LT_MAP16_STRIDE / dst_dsize; \
|
||||
} \
|
||||
if (rem_k > 0) { \
|
||||
align_n = PAD_UP(n, LT_NUM); \
|
||||
sn0 = rem_k * src_dsize; \
|
||||
dn0 = sn0; \
|
||||
__asm__ __volatile__( \
|
||||
"move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
|
||||
"%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], " \
|
||||
"%[src_n4], %[src_s4], 1, 0, %[dst_n0], " \
|
||||
"%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
|
||||
"%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type \
|
||||
";\n\t" ::[dst_addr] "r"(wram_dst2), \
|
||||
[src_addr] "r"(sram_src2), [src_n0] "r"(sn0), [src_n1] "r"(ROW_PER_LT), [src_s1] "r"(ss1), \
|
||||
[src_n3] "r"(LT_NUM / ROW_PER_LT), [src_s3] "r"(ROW_PER_LT * ss1), \
|
||||
[src_n4] "r"(align_n / LT_NUM), [src_s4] "r"(LT_NUM * ss1), [dst_n0] "r"(dn0), \
|
||||
[dst_n1] "r"(ROW_PER_LT), [dst_s1] "r"(ds1), [dst_n2] "r"(1), [dst_s2] "r"(0), \
|
||||
[dst_n3] "r"(LT_NUM / ROW_PER_LT), [dst_s3] "r"(WRAM_LT_MAP16_STRIDE), \
|
||||
[dst_n4] "r"(align_n / LT_NUM), [dst_s4] "r"(ROW_PER_LT * ds1), [dst_n5] "r"(1), \
|
||||
[dst_s5] "r"(0)); \
|
||||
}
|
||||
|
||||
__mlu_func__ void warp_prompt_weight(float *wram_dst,
|
||||
half *sram_src,
|
||||
int32_t warp_n,
|
||||
int32_t len_k,
|
||||
int32_t total_k) {
|
||||
#if __BANG_ARCH__ >= 500
|
||||
SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float), sizeof(half),
|
||||
".cvt.rn.f32.f16()");
|
||||
#endif
|
||||
}
|
||||
|
||||
__mlu_func__ void warp_prompt_weight(float *wram_dst,
|
||||
bfloat16_t *sram_src,
|
||||
int32_t warp_n,
|
||||
int32_t len_k,
|
||||
int32_t total_k) {
|
||||
#if __BANG_ARCH__ >= 500
|
||||
SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float),
|
||||
sizeof(bfloat16_t), ".cvt.rn.f32.bf16()");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void warp_prompt_weight(T *wram_dst,
|
||||
T *sram_src,
|
||||
int32_t n,
|
||||
int32_t len_k,
|
||||
int32_t total_k) {
|
||||
int32_t type_size = sizeof(T);
|
||||
int32_t data_size = len_k * type_size;
|
||||
int32_t ds0 = PAD_UP(data_size, ONE_LINE);
|
||||
int32_t ss0 = total_k * type_size;
|
||||
int32_t count = n / LT_NUM;
|
||||
for (int32_t i = 0; i < count; ++i) {
|
||||
__memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1,
|
||||
WRAM_LT_MAP16_STRIDE, LT_SIZE - 1, ss0, LT_NUM - 1, 0, 0);
|
||||
wram_dst = (T *)((int8_t *)wram_dst + ROW_PER_LT * ds0);
|
||||
sram_src = (T *)((int8_t *)sram_src + LT_NUM * ss0);
|
||||
}
|
||||
count = n % LT_NUM / ROW_PER_LT;
|
||||
if (count > 0) {
|
||||
__memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1,
|
||||
WRAM_LT_MAP16_STRIDE, count - 1, ss0, count * ROW_PER_LT - 1, 0, 0);
|
||||
wram_dst = (T *)((int8_t *)wram_dst + count * WRAM_LT_MAP16_STRIDE);
|
||||
sram_src = (T *)((int8_t *)sram_src + count * ROW_PER_LT * ss0);
|
||||
}
|
||||
count = n % ROW_PER_LT;
|
||||
if (count > 0) {
|
||||
__memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ss0, count - 1);
|
||||
}
|
||||
}
|
||||
|
||||
__mlu_func__ void assignTaskEvenly(const int32_t num_total_task,
|
||||
const int32_t &taskid,
|
||||
const int32_t &taskdim,
|
||||
int32_t &task_offset,
|
||||
int32_t &num_cur_task) {
|
||||
int32_t num_per_task = num_total_task / taskdim;
|
||||
int32_t rem_idx = num_total_task % taskdim;
|
||||
if (taskid < rem_idx) {
|
||||
task_offset = taskid * (num_per_task + 1);
|
||||
num_cur_task = num_per_task + 1;
|
||||
} else {
|
||||
task_offset = taskid * num_per_task + rem_idx;
|
||||
num_cur_task = num_per_task;
|
||||
}
|
||||
}
|
||||
|
||||
__mlu_func__ void bidirectionBarrierOp() {
|
||||
int32_t bcnt = coreDim + 1;
|
||||
if (__is_ipu()) {
|
||||
__asm__ __volatile__("barrier.arrive.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt));
|
||||
__asm__ __volatile__("barrier.sync.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt));
|
||||
} else {
|
||||
__asm__ __volatile__("barrier.sync.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt));
|
||||
__asm__ __volatile__("barrier.arrive.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt));
|
||||
}
|
||||
}
|
||||
|
||||
__mlu_func__ void __wmma(float *c, float *a, float *b, int32_t m, int32_t n, int32_t k) {
|
||||
__bang_conv_partial((float *)c, (float *)a, (float *)b, (float *)c, k, m, 1, 1, 1, 1, 1, n);
|
||||
}
|
||||
|
||||
__mlu_func__ void warp_store(void *ddr_dst,
|
||||
void *nram_src,
|
||||
const int32_t data_num,
|
||||
const int32_t dst_stride,
|
||||
const int32_t src_stride,
|
||||
const int32_t count,
|
||||
const int32_t dt_size) {
|
||||
if (src_stride == data_num && dst_stride == data_num) {
|
||||
__memcpy_async(ddr_dst, nram_src, count * data_num * dt_size, NRAM2GDRAM);
|
||||
} else {
|
||||
__memcpy_async(ddr_dst, nram_src, data_num * dt_size, NRAM2GDRAM, (size_t)dst_stride * dt_size,
|
||||
src_stride * dt_size, count - 1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Tc, typename Tcc>
|
||||
__mlu_func__ void splitKReduce(Tcc *workspace,
|
||||
Tc *output,
|
||||
int32_t M,
|
||||
int32_t N,
|
||||
int32_t split_k_num,
|
||||
int32_t ldc) {
|
||||
int32_t offset_m, cta_m;
|
||||
assignTaskEvenly(M, taskId, taskDim, offset_m, cta_m);
|
||||
if (cta_m <= 0) return;
|
||||
int32_t block_m = NRAM_BUFFER_SIZE / split_k_num / N / sizeof(Tcc);
|
||||
int32_t repeat = cta_m / block_m + int32_t(cta_m % block_m != 0);
|
||||
int32_t rem_m = cta_m % block_m != 0 ? cta_m % block_m : block_m;
|
||||
Tcc *workspace_ddr = (Tcc *)workspace + offset_m * N;
|
||||
Tc *output_ddr = (Tc *)output + offset_m * ldc;
|
||||
for (int32_t i = 0; i < repeat; i++) {
|
||||
int32_t current_m = i == repeat - 1 ? rem_m : block_m;
|
||||
int32_t data_size = N * sizeof(Tc);
|
||||
int32_t data_num = current_m - 1;
|
||||
if (ldc == N) {
|
||||
data_size = current_m * N * sizeof(Tc);
|
||||
data_num = 0;
|
||||
}
|
||||
__memcpy((Tcc *)nram_buffer, (Tcc *)workspace_ddr, current_m * N * sizeof(Tcc), GDRAM2NRAM,
|
||||
current_m * N * sizeof(Tcc), M * N * sizeof(Tcc), split_k_num - 1);
|
||||
__bang_sumpool((Tcc *)nram_buffer, (Tcc *)nram_buffer, current_m * N, split_k_num, 1,
|
||||
split_k_num, 1, 1, 1);
|
||||
__memcpy((Tc *)output_ddr, (Tc *)nram_buffer, data_size, NRAM2GDRAM, ldc * sizeof(Tc),
|
||||
N * sizeof(Tc), data_num);
|
||||
workspace_ddr = workspace_ddr + block_m * N;
|
||||
output_ddr = output_ddr + block_m * ldc;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Ta,
|
||||
typename Tac,
|
||||
typename Tb,
|
||||
typename Tbc,
|
||||
typename Tc,
|
||||
typename Tcc,
|
||||
bool EXCHANGE_AB>
|
||||
__mlu_global__ void MLUCastGating(Ta *A,
|
||||
Tb *B,
|
||||
Tc *C,
|
||||
Tcc *workspace,
|
||||
int32_t M,
|
||||
int32_t N,
|
||||
int32_t K,
|
||||
int32_t lda,
|
||||
int32_t ldb,
|
||||
int32_t ldc,
|
||||
castGatingTileInfo split_info) {
|
||||
#if __BANG_ARCH__ >= 500
|
||||
int32_t block_k = split_info.block_k;
|
||||
int32_t grid_dimx = split_info.split_k_num;
|
||||
int32_t block = split_info.block;
|
||||
int32_t grid_idx = clusterId % grid_dimx;
|
||||
int32_t grid_idy = clusterId / grid_dimx;
|
||||
int32_t offset_k = 0, problem_k = 0;
|
||||
assignTaskEvenly(K, grid_idx, grid_dimx, offset_k, problem_k);
|
||||
int32_t rem_k = problem_k % block_k > 0 ? problem_k % block_k : block_k;
|
||||
int32_t k_loop = problem_k / block_k + (int32_t)(problem_k % block_k > 0);
|
||||
int32_t cta_k = k_loop == 1 ? rem_k : block_k;
|
||||
int32_t cta_m = M, offset_m = 0, cta_n = N, offset_n = 0;
|
||||
int32_t warp_m = cta_m, warp_offset_m = 0;
|
||||
int32_t warp_n = cta_n, warp_offset_n = 0;
|
||||
int32_t outer_loop = 0, outer_rem = 0;
|
||||
if (EXCHANGE_AB) {
|
||||
assignTaskEvenly(N, grid_idy, clusterDim / grid_dimx, offset_n, cta_n);
|
||||
assignTaskEvenly(block, coreId, coreDim, warp_offset_n, warp_n);
|
||||
if (cta_n > block && cta_n % block != 0) {
|
||||
int32_t block_tmp = PAD_UP((cta_n + cta_n / block) / (cta_n / block + 1), coreDim * LT_NUM);
|
||||
if (block_tmp < block) block = block_tmp;
|
||||
}
|
||||
outer_loop = (cta_n + block - 1) / block;
|
||||
outer_rem = cta_n % block == 0 ? block : cta_n % block;
|
||||
} else {
|
||||
assignTaskEvenly(M, grid_idy, clusterDim / grid_dimx, offset_m, cta_m);
|
||||
assignTaskEvenly(block, coreId, coreDim, warp_offset_m, warp_m);
|
||||
if (cta_m > block && cta_m % block != 0) {
|
||||
int32_t block_tmp = PAD_UP((cta_m + cta_m / block) / (cta_m / block + 1), coreDim);
|
||||
if (block_tmp < block) block = block_tmp;
|
||||
}
|
||||
outer_loop = (cta_m + block - 1) / block;
|
||||
outer_rem = cta_m % block == 0 ? block : cta_m % block;
|
||||
}
|
||||
|
||||
int32_t size_nram_buf =
|
||||
NRAM_BUFFER_SIZE - warp_m * warp_n * sizeof(Tcc) * (1 + int32_t(EXCHANGE_AB));
|
||||
int32_t pong_a_nram = size_nram_buf / 2 / sizeof(Tac);
|
||||
Tac *nbuf_a = (Tac *)nram_buffer;
|
||||
Tcc *nbuf_c = (Tcc *)(nram_buffer + size_nram_buf);
|
||||
Tcc *nbuf_out = EXCHANGE_AB ? (Tcc *)nbuf_c + warp_m * warp_n : nbuf_c;
|
||||
|
||||
int32_t size_sram_buf = SRAM_BUFFER_SIZE;
|
||||
int32_t pong_sram_a = size_sram_buf / 2 / sizeof(Ta);
|
||||
int32_t pong_sram_b = size_sram_buf / 2 / sizeof(Tb);
|
||||
Ta *sbuf_a = (Ta *)sram_buffer;
|
||||
Tb *sbuf_b = (Tb *)((Ta *)sram_buffer + (EXCHANGE_AB ? M * block_k : block * block_k));
|
||||
|
||||
int32_t pong_b_wram = WRAM_LT_MAP16_STRIDE / 2 / sizeof(Tbc);
|
||||
Tbc *wbuf_b = (Tbc *)wram_buffer;
|
||||
|
||||
int32_t a_dsize = sizeof(Ta);
|
||||
int32_t b_dsize = sizeof(Tb);
|
||||
int32_t k_loop_count = 0;
|
||||
for (int32_t j = 0; j < outer_loop; j++) {
|
||||
Ta *a_ddr = (Ta *)A + offset_k + ((size_t)offset_m + j * block) * lda * int(!EXCHANGE_AB);
|
||||
Tb *b_ddr = (Tb *)B + offset_k + ((size_t)offset_n + j * block) * ldb * int(EXCHANGE_AB);
|
||||
int32_t current_block = j == outer_loop - 1 ? outer_rem : block;
|
||||
if (EXCHANGE_AB) {
|
||||
assignTaskEvenly(current_block, coreId, coreDim, warp_offset_n, warp_n);
|
||||
} else {
|
||||
assignTaskEvenly(current_block, coreId, coreDim, warp_offset_m, warp_m);
|
||||
}
|
||||
int32_t compute_total = warp_m * warp_n;
|
||||
if (compute_total > 0 && __is_ipu()) {
|
||||
if (!EXCHANGE_AB) {
|
||||
__sync_io_move_compute(true, false, false, false, false, true);
|
||||
}
|
||||
__bang_write_zero((Tcc *)nbuf_c, compute_total);
|
||||
}
|
||||
int32_t i = 0;
|
||||
for (; i < k_loop; i++) {
|
||||
Ta *sram_a = (Ta *)sbuf_a + k_loop_count % 2 * pong_sram_a;
|
||||
Tb *sram_b = (Tb *)sbuf_b + k_loop_count % 2 * pong_sram_b;
|
||||
cta_k = i == k_loop - 1 ? rem_k : block_k;
|
||||
if (__is_mpu()) {
|
||||
if (EXCHANGE_AB) {
|
||||
__memcpy_async(sram_b, b_ddr, cta_k * b_dsize, GDRAM2SRAM, cta_k * b_dsize, ldb * b_dsize,
|
||||
current_block - 1);
|
||||
__asm__ volatile(
|
||||
"ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], "
|
||||
"%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_a),
|
||||
[src] "r"(a_ddr), [size] "r"(cta_k * a_dsize), [dst_strd] "r"(cta_k * a_dsize),
|
||||
[src_strd] "r"(lda * a_dsize), [segnum] "r"(M - 1));
|
||||
} else {
|
||||
__memcpy_async(sram_a, a_ddr, cta_k * a_dsize, GDRAM2SRAM, cta_k * a_dsize, lda * a_dsize,
|
||||
current_block - 1);
|
||||
__asm__ volatile(
|
||||
"ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], "
|
||||
"%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_b),
|
||||
[src] "r"(b_ddr), [size] "r"(cta_k * b_dsize), [dst_strd] "r"(cta_k * b_dsize),
|
||||
[src_strd] "r"(ldb * b_dsize), [segnum] "r"(N - 1));
|
||||
}
|
||||
a_ddr = (Ta *)a_ddr + block_k;
|
||||
b_ddr = (Tb *)b_ddr + block_k;
|
||||
}
|
||||
bidirectionBarrierOp();
|
||||
if (__is_ipu() && compute_total > 0) {
|
||||
__sync_io_move_compute(false, true, false, false, false, true);
|
||||
__sync_io_move_compute(false, false, true, false, true, false);
|
||||
if (i >= 1) {
|
||||
__wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram,
|
||||
(Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, block_k);
|
||||
}
|
||||
warp_prompt_input((Tac *)nbuf_a + k_loop_count % 2 * pong_a_nram,
|
||||
sram_a + cta_k * warp_offset_m, warp_m * cta_k * sizeof(Ta));
|
||||
// mvdma bound for EXCHANGE_AB when n==32
|
||||
warp_prompt_weight((Tbc *)wbuf_b + k_loop_count % 2 * pong_b_wram,
|
||||
(Tb *)sram_b + cta_k * warp_offset_n, warp_n, cta_k, cta_k);
|
||||
}
|
||||
k_loop_count += 1;
|
||||
}
|
||||
if (compute_total > 0) {
|
||||
__sync_io_move_compute(false, true, false, false, false, true);
|
||||
__wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram,
|
||||
(Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, rem_k);
|
||||
if (EXCHANGE_AB) {
|
||||
__sync_io_move_compute(true, false, false, false, false, true);
|
||||
__bang_transpose((Tcc *)nbuf_out, (Tcc *)nbuf_c, warp_m, warp_n);
|
||||
}
|
||||
int32_t total_offset =
|
||||
grid_idx * M * N + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * M
|
||||
: (offset_m + warp_offset_m + block * j) * N);
|
||||
Tcc *wks = (Tcc *)workspace + total_offset;
|
||||
int32_t store_c_size = sizeof(Tcc);
|
||||
int8_t *store_ddr = (int8_t *)wks;
|
||||
int32_t dst_str = EXCHANGE_AB ? M : N;
|
||||
if (grid_dimx == 1) {
|
||||
// convert Tcc to Tc
|
||||
dst_str = ldc;
|
||||
store_ddr =
|
||||
(int8_t *)((Tc *)C + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * ldc
|
||||
: (offset_m + warp_offset_m + block * j) * ldc));
|
||||
}
|
||||
__asm__ volatile("sync.psimd.cio;\n\t");
|
||||
if (EXCHANGE_AB) {
|
||||
warp_store(store_ddr, (Tcc *)nbuf_out, warp_m, dst_str, warp_m, warp_n, store_c_size);
|
||||
} else {
|
||||
warp_store(store_ddr, (Tcc *)nbuf_out, warp_n, dst_str, warp_n, warp_m, store_c_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (grid_dimx != 1) {
|
||||
__sync_all();
|
||||
splitKReduce((Tcc *)workspace, (Tc *)C, EXCHANGE_AB ? N : M, EXCHANGE_AB ? M : N,
|
||||
split_info.split_k_num, ldc);
|
||||
}
|
||||
#endif // __BANG_ARCH__ >= 500
|
||||
}
|
||||
} // namespace kernels
|
||||
|
||||
int32_t getBlock(int32_t m,
|
||||
int32_t n,
|
||||
int32_t core_num,
|
||||
int32_t block_k,
|
||||
int32_t a_dtype_size,
|
||||
int32_t b_dtype_size,
|
||||
int32_t compute_dtype_size,
|
||||
bool EXCHANGE_AB) {
|
||||
int32_t block = 0;
|
||||
if (EXCHANGE_AB) {
|
||||
int32_t block_m = n;
|
||||
int32_t nram_block_n = (NRAM_BUFFER_SIZE - block_m * block_k * compute_dtype_size * 2) /
|
||||
(2 * block_m * compute_dtype_size) * core_num;
|
||||
int32_t wram_block_n =
|
||||
WRAM_BUFFER_SIZE / 2 / PAD_UP(block_k * compute_dtype_size, 64) * core_num;
|
||||
int32_t sram_block_n =
|
||||
(SRAM_BUFFER_SIZE - block_m * block_k * a_dtype_size * 2) / (block_k * b_dtype_size * 2);
|
||||
int32_t block_n_tmp = std::min(std::min(nram_block_n, wram_block_n), sram_block_n);
|
||||
int32_t block_n = PAD_DOWN(block_n_tmp, core_num * LT_NUM);
|
||||
return block_n > 0 ? block_n : block_n_tmp;
|
||||
} else {
|
||||
int32_t block_n = n;
|
||||
int32_t nram_block_m =
|
||||
NRAM_BUFFER_SIZE / (block_n * compute_dtype_size + block_k * compute_dtype_size * 2);
|
||||
int32_t sram_block_m =
|
||||
(SRAM_BUFFER_SIZE - block_n * block_k * b_dtype_size * 2) / (block_k * a_dtype_size * 2);
|
||||
block = std::min(nram_block_m * core_num, PAD_DOWN(sram_block_m, core_num));
|
||||
return block;
|
||||
}
|
||||
}
|
||||
|
||||
void gatingTiling(int32_t m,
|
||||
int32_t n,
|
||||
int32_t k,
|
||||
size_t a_dtype_size,
|
||||
size_t b_dtype_size,
|
||||
size_t compute_dtype_size,
|
||||
size_t workspace_size,
|
||||
int32_t union_number,
|
||||
int32_t core_num,
|
||||
int32_t &block,
|
||||
int32_t &split_k_num,
|
||||
int32_t &block_k,
|
||||
bool &EXCHANGE_AB) {
|
||||
block_k = std::min(k, int32_t(512 / a_dtype_size));
|
||||
split_k_num = 1;
|
||||
// swap A and B to reduce computing waste caused by LT_NUM-align of co dimensian
|
||||
if (m >= core_num * LT_NUM &&
|
||||
float(m) / float(PAD_UP((size_t)m, LT_NUM)) > float(n) / float(PAD_UP(n, LT_NUM))) {
|
||||
EXCHANGE_AB = true;
|
||||
}
|
||||
int32_t tmp_block = getBlock(m, n, core_num, block_k, a_dtype_size, b_dtype_size,
|
||||
compute_dtype_size, EXCHANGE_AB);
|
||||
int32_t total_blocks = DIV_UP((size_t)m, tmp_block);
|
||||
block = tmp_block;
|
||||
if (total_blocks < union_number && (size_t)k * a_dtype_size > 512 * union_number) {
|
||||
for (int32_t i = total_blocks; i <= union_number; i++) {
|
||||
if (union_number % i == 0) {
|
||||
int32_t tmp_split_k = union_number / i;
|
||||
size_t workspace_size_need = (size_t)tmp_split_k * m * n * compute_dtype_size;
|
||||
if (workspace_size >= workspace_size_need) {
|
||||
split_k_num = tmp_split_k;
|
||||
block = std::min(((size_t)m + total_blocks - 1) / total_blocks, (size_t)tmp_block);
|
||||
if (EXCHANGE_AB && block > LT_NUM * core_num) {
|
||||
block = PAD_DOWN(block, LT_NUM * core_num);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void getContxtInfo(int32_t *union_number, int32_t *core_num) {
|
||||
CNdev dev;
|
||||
cnCtxGetDevice(&dev);
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(union_number, cnrtAttrMaxClusterPerUnionLimitTask, dev));
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(core_num, cnrtAttrMcorePerCluster, dev));
|
||||
}
|
||||
|
||||
KernelStatus invokeCastGating(cnrtQueue_t queue,
|
||||
void *input,
|
||||
void *filter,
|
||||
void *output,
|
||||
int input_row,
|
||||
int expert_num,
|
||||
int hidden_size,
|
||||
cnnlDataType_t a_dtype,
|
||||
void *workspace,
|
||||
size_t workspace_size_bytes) {
|
||||
if (is_arch300()) {
|
||||
std::cerr << "[invokeCastGating]: kernel does not support MLU300 devices." << std::endl;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
|
||||
if (expert_num > 128) {
|
||||
std::cerr << "[invokeCastGating]: expert_num should NOT be greater than 128." << std::endl;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
|
||||
if (workspace != NULL && workspace_size_bytes < 16 * 1024 * 1024) {
|
||||
std::cerr
|
||||
<< "[invokeCastGating]: workspace_size_bytes should NOT be smaller than 16 * 1024 * 1024."
|
||||
<< std::endl;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
|
||||
if (workspace_size_bytes > 0 && workspace == NULL) {
|
||||
std::cerr << "[invokeCastGating]: workspace should NOT be NULL when workspace_size_bytes is "
|
||||
"greater than 0."
|
||||
<< std::endl;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
|
||||
int32_t union_number, core_num;
|
||||
getContxtInfo(&union_number, &core_num);
|
||||
cnrtFunctionType_t func_type = cnrtFunctionType_t(union_number * core_num);
|
||||
cnrtDim3_t dim;
|
||||
dim.x = (int32_t)func_type;
|
||||
dim.y = 1;
|
||||
dim.z = 1;
|
||||
|
||||
cnnlDataType_t b_dtype = CNNL_DTYPE_FLOAT;
|
||||
cnnlDataType_t compute_dtype = CNNL_DTYPE_FLOAT;
|
||||
size_t a_dtype_size = 0, b_dtype_size = 0, compute_dtype_size = 0;
|
||||
cnnlGetSizeOfDataType(a_dtype, &a_dtype_size);
|
||||
cnnlGetSizeOfDataType(b_dtype, &b_dtype_size);
|
||||
cnnlGetSizeOfDataType(compute_dtype, &compute_dtype_size);
|
||||
castGatingTileInfo split_info;
|
||||
bool EXCHANGE_AB = false;
|
||||
gatingTiling(input_row, expert_num, hidden_size, a_dtype_size, b_dtype_size, compute_dtype_size,
|
||||
workspace_size_bytes, union_number, core_num, split_info.block,
|
||||
split_info.split_k_num, split_info.block_k, EXCHANGE_AB);
|
||||
if (a_dtype == CNNL_DTYPE_BFLOAT16) {
|
||||
if (EXCHANGE_AB) {
|
||||
kernels::MLUCastGating<float, float, bfloat16_t, float, float, float, true>
|
||||
<<<dim, func_type, queue>>>((float *)filter, (bfloat16_t *)input, (float *)output,
|
||||
(float *)workspace, expert_num, input_row, hidden_size,
|
||||
hidden_size, hidden_size, expert_num, split_info);
|
||||
} else {
|
||||
kernels::MLUCastGating<bfloat16_t, float, float, float, float, float, false>
|
||||
<<<dim, func_type, queue>>>((bfloat16_t *)input, (float *)filter, (float *)output,
|
||||
(float *)workspace, input_row, expert_num, hidden_size,
|
||||
hidden_size, hidden_size, expert_num, split_info);
|
||||
}
|
||||
} else if (a_dtype == CNNL_DTYPE_HALF) {
|
||||
if (EXCHANGE_AB) {
|
||||
kernels::MLUCastGating<float, float, half, float, float, float, true>
|
||||
<<<dim, func_type, queue>>>((float *)filter, (half *)input, (float *)output,
|
||||
(float *)workspace, expert_num, input_row, hidden_size,
|
||||
hidden_size, hidden_size, expert_num, split_info);
|
||||
} else {
|
||||
kernels::MLUCastGating<half, float, float, float, float, float, false>
|
||||
<<<dim, func_type, queue>>>((half *)input, (float *)filter, (float *)output,
|
||||
(float *)workspace, input_row, expert_num, hidden_size,
|
||||
hidden_size, hidden_size, expert_num, split_info);
|
||||
}
|
||||
} else {
|
||||
std::cerr << "[invokeCastGating]: kernel does not support this data-type." << std::endl;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
return KernelStatus::KERNEL_STATUS_SUCCESS;
|
||||
}
|
||||
} // namespace tmo
|
||||
50
torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mluh
Normal file
50
torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mluh
Normal file
@@ -0,0 +1,50 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_CAST_GATING_MLUH_
|
||||
#define CSRC_KERNELS_CAST_GATING_MLUH_
|
||||
|
||||
#include "../kernel_utils.h"
|
||||
#include "cnnl.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Convert input to float32 and do gating operation.
|
||||
* @param queue: The queue for mlu.
|
||||
* @param input: Input. Pointer to the MLU memory that stores the input,
|
||||
* the shape must be [input_row, hidden_size].
|
||||
* @param filter: Input. Pointer to the MLU memory that stores the weight,
|
||||
* the shape must be [expert_num, hidden_size].
|
||||
* @param output: Output. Pointer to the MLU memory that stores the output,
|
||||
* the shape must be [input_row, expert_num].
|
||||
* @param input_row: Input.
|
||||
* @param expert_num: Input.
|
||||
* @param hidden_size: Input.
|
||||
* @param a_dtype: Input. The data-type of input.
|
||||
* @param workspace: Input. Pointer to the MLU workspace.
|
||||
* @param workspace_size_bytes: Input. The size of workspace in bytes.
|
||||
* @note: a_dtype must be CNNL_DTYPE_BFLOAT16 or CNNL_DTYPE_HALF.
|
||||
* expert_num must be in range [1, 128].
|
||||
* If workspace is NOT NULL, workspace_size_bytes must NOT be smaller than 16 * 1024 * 1024.
|
||||
* The data-type of filter and output must be float.
|
||||
* cast_gating only supports MLU500 device or higher.
|
||||
*/
|
||||
KernelStatus invokeCastGating(cnrtQueue_t queue,
|
||||
void *input,
|
||||
void *filter,
|
||||
void *output,
|
||||
int input_row,
|
||||
int expert_num,
|
||||
int hidden_size,
|
||||
cnnlDataType_t a_dtype,
|
||||
void *workspace,
|
||||
size_t workspace_size_bytes);
|
||||
} // namespace tmo
|
||||
#endif // CSRC_KERNELS_CAST_GATING_MLUH_
|
||||
760
torch_mlu_ops-v1.3.2/csrc/kernels/moe/combine_result.mlu
Normal file
760
torch_mlu_ops-v1.3.2/csrc/kernels/moe/combine_result.mlu
Normal file
@@ -0,0 +1,760 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "cnrt.h"
|
||||
#include "combine_result.mluh"
|
||||
// clang-format off
|
||||
#include <bang_device_functions_extra.h>
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
#if __BANG_ARCH__ >= 592
|
||||
#include <bang_fusor.h>
|
||||
template <typename SrcT>
|
||||
using bang_fusor = bang::experimental::fusor<SrcT>;
|
||||
#endif
|
||||
|
||||
namespace tmo {
|
||||
namespace kernels {
|
||||
#define NRAM_REMAIN_SIZE (32 * 1024)
|
||||
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)
|
||||
|
||||
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void swap(T *&ping, T *&pong) {
|
||||
T *temp = ping;
|
||||
ping = pong;
|
||||
pong = temp;
|
||||
}
|
||||
|
||||
#define GATHER_ASYNC_IO0(offset_type) \
|
||||
__asm__ __volatile__( \
|
||||
"gather.vector.async.nram.gdram.nram." #offset_type \
|
||||
".io0 [%[dst]], [%[src]], [%[offset]], " \
|
||||
"%[transfer_size], %[transfer_num], %[stride];\n\t" ::[dst] "r"(dst), \
|
||||
[src] "r"(src_gdram), [offset] "r"(nram_offset), [transfer_size] "r"(transfer_size), \
|
||||
[transfer_num] "r"(token_count), [stride] "r"(transfer_size))
|
||||
|
||||
#define FUSE_MUL_CVT(dst_dtype) \
|
||||
__asm__ __volatile__("mult.scalar.nram.crn." #dst_dtype \
|
||||
".f32 [%[dst]], [%[src0]], %[src1]," \
|
||||
" %[size];\n\t" ::[dst] "r"(dst), \
|
||||
[src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size));
|
||||
|
||||
#define FUSE_MULADD_CVT(dst_dtype) \
|
||||
__asm__ __volatile__("muladd.nram.crn." #dst_dtype \
|
||||
".f32 [%[dst]], [%[src0]], %[src1], [%[dst]]," \
|
||||
" %[size], %[size];\n\t" ::[dst] "r"(dst), \
|
||||
[src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size));
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void toFloat(float *dst, T *src, int count) {
|
||||
if (std::is_same<T, half>::value) {
|
||||
__bang_half2float(dst, (half *)src, count);
|
||||
} else if (std::is_same<T, bfloat16_t>::value) {
|
||||
__bang_bfloat162float(dst, (bfloat16_t *)src, count);
|
||||
} else if (std::is_same<T, float>::value) {
|
||||
__bang_add_scalar((float *)dst, (float *)src, (float)0, count);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void floatTo(T *dst, float *src, int count) {
|
||||
if (std::is_same<T, half>::value) {
|
||||
__bang_float2half_rn((half *)dst, src, count);
|
||||
} else if (std::is_same<T, bfloat16_t>::value) {
|
||||
__bang_float2bfloat16_rn((bfloat16_t *)dst, src, count);
|
||||
} else if (std::is_same<T, float>::value) {
|
||||
__bang_add_scalar((float *)dst, (float *)src, (float)0, count);
|
||||
}
|
||||
}
|
||||
|
||||
__mlu_func__ void loadAsync2d(void *dst,
|
||||
void *src,
|
||||
int size,
|
||||
int dststride,
|
||||
int srcstride,
|
||||
int seg_num) {
|
||||
#if __BANG_ARCH__ > 500
|
||||
__asm__ __volatile__(
|
||||
"ld.async.stride.nram.gdram.io0 [%[dst]], [%[src]],"
|
||||
" %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst),
|
||||
[src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride),
|
||||
[segnum] "r"(seg_num));
|
||||
#else
|
||||
__memcpy_async(dst, src, size, GDRAM2NRAM, dststride, srcstride, seg_num);
|
||||
#endif
|
||||
}
|
||||
|
||||
__mlu_func__ void storeAsync2d(void *dst,
|
||||
void *src,
|
||||
int size,
|
||||
int dststride,
|
||||
int srcstride,
|
||||
int seg_num) {
|
||||
#if __BANG_ARCH__ > 500
|
||||
__asm__ __volatile__(
|
||||
"st.async.stride.gdram.nram.io1 [%[dst]], [%[src]],"
|
||||
" %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst),
|
||||
[src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride),
|
||||
[segnum] "r"(seg_num));
|
||||
#else
|
||||
__memcpy_async(dst, src, size, NRAM2GDRAM, dststride, srcstride, seg_num);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T_IDX>
|
||||
__mlu_func__ void gatherTokensAsync(void *dst,
|
||||
void *src_gdram,
|
||||
T_IDX *nram_offset,
|
||||
int transfer_size,
|
||||
int token_count) {
|
||||
if (token_count <= 0 || src_gdram == nullptr) return;
|
||||
#if __BANG_ARCH__ > 500
|
||||
if (std::is_same<T_IDX, uint32_t>::value) {
|
||||
GATHER_ASYNC_IO0(u32);
|
||||
} else {
|
||||
GATHER_ASYNC_IO0(u64);
|
||||
}
|
||||
#else
|
||||
for (int k = 0; k < token_count; k++) {
|
||||
__memcpy_async((int8_t *)dst + k * transfer_size,
|
||||
(int8_t *)src_gdram + __load_nram(nram_offset + k), transfer_size, GDRAM2NRAM);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__mlu_func__ int getMaskAndActiveTokenCount(int *nram_token_idx,
|
||||
int *nram_mask,
|
||||
uint8_t *nram_mask_char,
|
||||
int *nram_mask_buffer,
|
||||
int begin_expert_acc_tokens,
|
||||
int end_expert_acc_tokens,
|
||||
int token_count,
|
||||
bool expert_parallelism) {
|
||||
if (!expert_parallelism) {
|
||||
return token_count;
|
||||
}
|
||||
|
||||
__bang_lt_scalar(nram_mask_buffer, nram_token_idx, end_expert_acc_tokens, token_count);
|
||||
#if __BANG_ARCH__ >= 592
|
||||
bang_fusor<int32_t>(nram_mask, nram_token_idx, token_count)
|
||||
.ge(begin_expert_acc_tokens)
|
||||
.land(nram_mask_buffer)
|
||||
.cvt<float>(0);
|
||||
#else
|
||||
__bang_ge_scalar(nram_mask, nram_token_idx, begin_expert_acc_tokens, token_count);
|
||||
__bang_and(nram_mask, nram_mask, nram_mask_buffer, token_count);
|
||||
__bang_int322float((float *)nram_mask, (int *)nram_mask, token_count, 0);
|
||||
#endif
|
||||
__bang_filter((float *)nram_token_idx, (float *)nram_token_idx, (float *)nram_mask, token_count);
|
||||
int active_token_count = __bang_count((float *)nram_mask, token_count);
|
||||
return active_token_count;
|
||||
}
|
||||
|
||||
__mlu_func__ void computeOffset0(uint64_t *nram_offset,
|
||||
int *nram_idx,
|
||||
uint64_t mul_scalar,
|
||||
int64_t add_scalar,
|
||||
uint32_t token_count) {
|
||||
#if __BANG_ARCH__ > 592
|
||||
__bang_int322int64((int64_t *)nram_offset, nram_idx, token_count, 0, 0);
|
||||
#else
|
||||
__bang_int322int64((int64_t *)nram_offset, nram_idx, token_count);
|
||||
#endif
|
||||
__bang_mul_scalar(nram_offset, nram_offset, mul_scalar, token_count);
|
||||
__bang_add_scalar((int64_t *)nram_offset, (int64_t *)nram_offset, add_scalar, token_count);
|
||||
}
|
||||
|
||||
__mlu_func__ void computeOffset0(uint32_t *nram_offset,
|
||||
int *nram_idx,
|
||||
uint32_t mul_scalar,
|
||||
int64_t add_scalar,
|
||||
uint32_t token_count) {
|
||||
__bang_fusion(FUSION_FMA, nram_offset, (uint32_t *)nram_idx, mul_scalar, (int32_t)add_scalar,
|
||||
token_count);
|
||||
}
|
||||
|
||||
template <typename T_IDX>
|
||||
__mlu_func__ void computeOffset(T_IDX *nram_token_offset,
|
||||
T_IDX *nram_bias_offset,
|
||||
int *nram_token_idx,
|
||||
int *nram_expert_tables,
|
||||
int expert_num,
|
||||
int token_count,
|
||||
int active_token_count,
|
||||
int hidden_size,
|
||||
int local_hidden_begin,
|
||||
int dtype_size,
|
||||
int start_expert_id,
|
||||
int expert_size,
|
||||
int begin_expert_acc_tokens,
|
||||
bool has_bias) {
|
||||
// for large tensor, convert int322int64 then do multiply and add seperately.
|
||||
if (active_token_count <= 0) return;
|
||||
if (has_bias) {
|
||||
int *nram_bias_offset_temp = (int *)nram_token_offset;
|
||||
__bang_write_zero(nram_bias_offset, active_token_count);
|
||||
for (int i = start_expert_id + 1; i < start_expert_id + expert_size; i++) {
|
||||
__bang_ge_scalar(nram_bias_offset_temp, nram_token_idx, nram_expert_tables[i],
|
||||
active_token_count);
|
||||
__bang_add((int *)nram_bias_offset, (int *)nram_bias_offset, nram_bias_offset_temp,
|
||||
active_token_count);
|
||||
}
|
||||
__bang_add_scalar(nram_bias_offset_temp, (int *)nram_bias_offset, 0, active_token_count);
|
||||
computeOffset0(nram_bias_offset, nram_bias_offset_temp, (T_IDX)hidden_size * dtype_size,
|
||||
(T_IDX)local_hidden_begin * dtype_size, active_token_count);
|
||||
}
|
||||
|
||||
int64_t offset =
|
||||
((int64_t)local_hidden_begin - (int64_t)begin_expert_acc_tokens * hidden_size) * dtype_size;
|
||||
computeOffset0(nram_token_offset, nram_token_idx, (T_IDX)(hidden_size * dtype_size), offset,
|
||||
active_token_count);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void mulScalarCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) {
|
||||
#if __BANG_ARCH__ > 500
|
||||
if (std::is_same<T, bfloat16_t>::value) {
|
||||
FUSE_MUL_CVT(bf16);
|
||||
} else if (std::is_same<T, half>::value) {
|
||||
FUSE_MUL_CVT(f16);
|
||||
} else if (std::is_same<T, float>::value) {
|
||||
__bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size);
|
||||
}
|
||||
#else
|
||||
__bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size);
|
||||
floatTo((T *)dst, (float *)dst, size);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void mulAddCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) {
|
||||
#if __BANG_ARCH__ > 500
|
||||
if (std::is_same<T, bfloat16_t>::value) {
|
||||
FUSE_MULADD_CVT(bf16);
|
||||
} else if (std::is_same<T, half>::value) {
|
||||
FUSE_MULADD_CVT(f16);
|
||||
} else if (std::is_same<T, float>::value) {
|
||||
__bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size,
|
||||
size);
|
||||
}
|
||||
#else
|
||||
__bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size,
|
||||
size);
|
||||
floatTo((T *)dst, (float *)dst, size);
|
||||
#endif
|
||||
}
|
||||
|
||||
// weightedReduceSum with EP split
|
||||
// input [token_count, k, hidden_size], weight [token_count, k]
|
||||
// 1. input * weight
|
||||
// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add
|
||||
template <typename T,
|
||||
bool expert_parallelism,
|
||||
typename std::enable_if<expert_parallelism == true, void *>::type = nullptr>
|
||||
__mlu_func__ void weightedReduceSum(T *output,
|
||||
T *input,
|
||||
float *weight,
|
||||
T *input_buffer,
|
||||
int8_t *og_mask,
|
||||
int topk,
|
||||
int hidden_size,
|
||||
int token_count,
|
||||
bool &is_ping) {
|
||||
float *nram_input_buffer =
|
||||
(float *)((half *)input_buffer +
|
||||
((std::is_same<T, float>::value || !is_ping) ? 0 : hidden_size));
|
||||
T *output_base = output - ((std::is_same<T, float>::value || is_ping) ? 0 : hidden_size);
|
||||
int32_t index[32];
|
||||
float reg_weight[128];
|
||||
int8_t *index_ = (int8_t *)index;
|
||||
int topk_divide_4 = PAD_UP(topk, 4) / 4;
|
||||
int token_use_count = 0;
|
||||
for (int t_i = 0; t_i < token_count; t_i++) {
|
||||
float *output_begin = (float *)(output_base + t_i * hidden_size);
|
||||
for (int i = 0; i < topk_divide_4; i++) {
|
||||
index[i] = __load_nram((int32_t *)(og_mask + t_i * topk) + i);
|
||||
float *weight_begin = weight + t_i * topk + i * 4;
|
||||
reg_weight[i * 4] = __load_nram(weight_begin);
|
||||
if (i * 4 + 1 < topk) {
|
||||
reg_weight[i * 4 + 1] = __load_nram(weight_begin + 1);
|
||||
}
|
||||
if (i * 4 + 2 < topk) {
|
||||
reg_weight[i * 4 + 2] = __load_nram(weight_begin + 2);
|
||||
}
|
||||
if (i * 4 + 3 < topk) {
|
||||
reg_weight[i * 4 + 3] = __load_nram(weight_begin + 3);
|
||||
}
|
||||
}
|
||||
|
||||
int first_in_expert = 0;
|
||||
float expert_coeff;
|
||||
for (; first_in_expert < topk - 1; first_in_expert++) {
|
||||
bool in_expert_range = index_[first_in_expert];
|
||||
if (!in_expert_range) continue;
|
||||
|
||||
expert_coeff = reg_weight[first_in_expert];
|
||||
toFloat<T>(output_begin, input + token_use_count * hidden_size, hidden_size);
|
||||
__bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size);
|
||||
token_use_count++;
|
||||
break;
|
||||
}
|
||||
if (first_in_expert == topk - 1) {
|
||||
if (index_[topk - 1]) {
|
||||
expert_coeff = reg_weight[topk - 1];
|
||||
toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
|
||||
token_use_count++;
|
||||
mulScalarCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
|
||||
} else {
|
||||
__bang_write_zero((T *)output_begin, hidden_size);
|
||||
}
|
||||
} else {
|
||||
for (int j = first_in_expert + 1; j < topk - 1; j++) {
|
||||
bool in_expert_range = index_[j];
|
||||
if (!in_expert_range) continue;
|
||||
|
||||
expert_coeff = reg_weight[j];
|
||||
toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
|
||||
token_use_count++;
|
||||
__bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin,
|
||||
hidden_size, hidden_size);
|
||||
}
|
||||
if (index_[topk - 1]) {
|
||||
expert_coeff = reg_weight[topk - 1];
|
||||
toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
|
||||
token_use_count++;
|
||||
mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
|
||||
} else {
|
||||
floatTo((T *)output_begin, (float *)output_begin, hidden_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!is_ping && sizeof(T) < sizeof(float)) {
|
||||
__memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size,
|
||||
hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1,
|
||||
token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1,
|
||||
token_count * hidden_size * sizeof(T), 0);
|
||||
}
|
||||
is_ping = !is_ping;
|
||||
}
|
||||
|
||||
// weightedReduceSum without EP split
|
||||
// input [token_count, k, hidden_size], weight [token_count, k]
|
||||
// 1. input * weight
|
||||
// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add
|
||||
template <typename T,
|
||||
bool expert_parallelism,
|
||||
typename std::enable_if<expert_parallelism == false, void *>::type = nullptr>
|
||||
__mlu_func__ void weightedReduceSum(T *output,
|
||||
T *input,
|
||||
float *weight,
|
||||
T *input_buffer,
|
||||
int8_t *og_mask,
|
||||
int topk,
|
||||
int hidden_size,
|
||||
int token_count,
|
||||
bool &is_ping) {
|
||||
float *nram_input_buffer =
|
||||
(float *)((half *)input_buffer +
|
||||
((std::is_same<T, float>::value || !is_ping) ? 0 : hidden_size));
|
||||
T *output_base = output - ((std::is_same<T, float>::value || is_ping) ? 0 : hidden_size);
|
||||
if (topk == 1) {
|
||||
for (int i = 0; i < token_count; i++) {
|
||||
float expert_coeff = __load_nram(weight + i);
|
||||
toFloat<T>(nram_input_buffer, input + i * hidden_size, hidden_size);
|
||||
mulScalarCvt(output + i * hidden_size, nram_input_buffer, expert_coeff, hidden_size);
|
||||
}
|
||||
return;
|
||||
}
|
||||
for (int t_i = 0; t_i < token_count; t_i++) {
|
||||
float *output_begin = (float *)(output_base + t_i * hidden_size);
|
||||
float expert_coeff = __load_nram(weight + t_i * topk);
|
||||
toFloat<T>(output_begin, input + t_i * topk * hidden_size, hidden_size);
|
||||
toFloat<T>(nram_input_buffer, input + (t_i * topk + 1) * hidden_size, hidden_size);
|
||||
__bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size);
|
||||
expert_coeff = __load_nram(weight + t_i * topk + 1);
|
||||
for (int k_i = 2; k_i < topk; k_i++) {
|
||||
__bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin,
|
||||
hidden_size, hidden_size);
|
||||
expert_coeff = __load_nram(weight + t_i * topk + k_i);
|
||||
toFloat<T>(nram_input_buffer, input + (t_i * topk + k_i) * hidden_size, hidden_size);
|
||||
}
|
||||
mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
|
||||
}
|
||||
if (!is_ping && sizeof(T) < sizeof(float)) {
|
||||
__memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size,
|
||||
hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1,
|
||||
token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1,
|
||||
token_count * hidden_size * sizeof(T), 0);
|
||||
}
|
||||
is_ping = !is_ping;
|
||||
}
|
||||
|
||||
template <typename T, typename T_IDX>
|
||||
__mlu_global__ void MLUCombineMoeResultKernel(T *output,
|
||||
T *input,
|
||||
T *bias,
|
||||
T *residual,
|
||||
float *reduce_weight,
|
||||
int *cusum_token_count,
|
||||
int *gather_idx,
|
||||
int num_token,
|
||||
int topk,
|
||||
int num_expert,
|
||||
int hidden_size,
|
||||
int start_expert_id,
|
||||
int expert_size,
|
||||
int HIDDEN_BLOCK,
|
||||
int TOKEN_BLOCK) {
|
||||
if (__is_mpu()) {
|
||||
return;
|
||||
}
|
||||
int local_hidden_begin = taskIdX * HIDDEN_BLOCK;
|
||||
int local_hidden_size = std::min(HIDDEN_BLOCK, hidden_size - local_hidden_begin);
|
||||
int task_avg_tokens = num_token / taskDimY;
|
||||
int task_remain_tokens = num_token % taskDimY;
|
||||
int task_tokens = task_avg_tokens + (int)(taskIdY < task_remain_tokens);
|
||||
int task_token_begin = taskIdY * task_avg_tokens + std::min(taskIdY, task_remain_tokens);
|
||||
if (local_hidden_size <= 0) return;
|
||||
if (task_tokens <= 0) return;
|
||||
|
||||
constexpr int int32_dtype_size = (int)sizeof(int);
|
||||
constexpr int fp32_dtype_size = (int)sizeof(float);
|
||||
int pad_num_expert = PAD_UP(num_expert + 1, 32);
|
||||
bool has_bias = bias != nullptr;
|
||||
bool has_residual = residual != nullptr;
|
||||
bool using_acc_sum = cusum_token_count != nullptr;
|
||||
bool expert_parallelism = expert_size < num_expert;
|
||||
|
||||
int block_size = TOKEN_BLOCK * topk;
|
||||
int pad_block_size = PAD_UP(block_size, 64);
|
||||
|
||||
int *nram_expert_tables = (int *)nram_buffer;
|
||||
int *nram_token_idx = nram_expert_tables + pad_num_expert;
|
||||
T_IDX *nram_token_offset = (T_IDX *)(nram_token_idx + pad_block_size);
|
||||
T_IDX *nram_bias_offset = (T_IDX *)(nram_token_offset + pad_block_size);
|
||||
int *nram_mask = (int *)(nram_bias_offset + (int)has_bias * pad_block_size);
|
||||
T *nram_input_ping = (T *)(nram_mask + pad_block_size);
|
||||
T *nram_input_pong = nram_input_ping + block_size * HIDDEN_BLOCK;
|
||||
T *nram_bias_ping = nram_input_pong + block_size * HIDDEN_BLOCK;
|
||||
T *nram_bias_pong = nram_bias_ping + (int)has_bias * block_size * HIDDEN_BLOCK;
|
||||
T *nram_residual_ping = nram_bias_pong + (int)has_bias * block_size * HIDDEN_BLOCK;
|
||||
T *nram_residual_pong = nram_residual_ping + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK;
|
||||
float *nram_weight_ping =
|
||||
(float *)(nram_residual_pong + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK);
|
||||
float *nram_weight_pong = nram_weight_ping + pad_block_size;
|
||||
int buffer_block_num = sizeof(T) > 2 ? 2 : 3;
|
||||
T *nram_output_ping = (T *)(nram_weight_pong + pad_block_size);
|
||||
T *nram_input_buffer = nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK;
|
||||
T *nram_output_pong = (T *)((char *)nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK * sizeof(T) +
|
||||
buffer_block_num * HIDDEN_BLOCK * sizeof(half));
|
||||
int *nram_mask_buffer = (int *)nram_token_offset;
|
||||
uint8_t *nram_mask_char = (uint8_t *)(nram_output_pong + TOKEN_BLOCK * HIDDEN_BLOCK);
|
||||
|
||||
int init_token_count = std::min(TOKEN_BLOCK, task_tokens) * topk;
|
||||
int begin_expert_acc_tokens = 0;
|
||||
int end_expert_acc_tokens = num_token * topk;
|
||||
if (using_acc_sum) {
|
||||
__memcpy_async(nram_expert_tables, cusum_token_count, (num_expert + 1) * int32_dtype_size,
|
||||
GDRAM2NRAM);
|
||||
}
|
||||
__memcpy_async(nram_token_idx, gather_idx + task_token_begin * topk,
|
||||
init_token_count * sizeof(int), GDRAM2NRAM);
|
||||
__sync_io();
|
||||
|
||||
if (expert_parallelism) {
|
||||
begin_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id);
|
||||
end_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id + expert_size);
|
||||
}
|
||||
|
||||
int active_token_count = getMaskAndActiveTokenCount(
|
||||
nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens,
|
||||
end_expert_acc_tokens, init_token_count, expert_parallelism);
|
||||
|
||||
computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables, num_expert,
|
||||
init_token_count, active_token_count, hidden_size, local_hidden_begin,
|
||||
(int)sizeof(T), start_expert_id, expert_size, begin_expert_acc_tokens, has_bias);
|
||||
__sync_io_move_compute(true, false, false, false, false, true);
|
||||
__sync_io_move_compute(false, false, true, true, false, false);
|
||||
|
||||
int next_active_token_count = active_token_count;
|
||||
int previous_global_token_begin = 0;
|
||||
int previous_token_count = 0;
|
||||
bool is_ping = false;
|
||||
for (int task_begin = -1; task_begin * TOKEN_BLOCK < task_tokens; task_begin++) {
|
||||
int next_token_begin = (task_begin + 1) * TOKEN_BLOCK;
|
||||
int next_next_token_begin = (task_begin + 2) * TOKEN_BLOCK;
|
||||
bool is_last_loop = next_token_begin >= task_tokens;
|
||||
bool is_last_2_loop = next_next_token_begin >= task_tokens;
|
||||
int current_token_begin = task_begin * TOKEN_BLOCK;
|
||||
int current_token_count = std::min(TOKEN_BLOCK, task_tokens - current_token_begin);
|
||||
int next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_token_begin);
|
||||
int next_next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_next_token_begin);
|
||||
int current_global_token_begin = task_token_begin + current_token_begin;
|
||||
int next_global_token_begin = task_token_begin + next_token_begin;
|
||||
int next_next_global_token_begin = task_token_begin + next_next_token_begin;
|
||||
|
||||
if (!is_last_loop) {
|
||||
if (!is_last_2_loop) {
|
||||
loadAsync2d(nram_token_idx, gather_idx + next_next_global_token_begin * topk,
|
||||
next_next_token_count * topk * sizeof(int), 0, 0, 0);
|
||||
}
|
||||
loadAsync2d(nram_weight_ping, reduce_weight + next_global_token_begin * topk,
|
||||
next_token_count * topk * fp32_dtype_size, 0, 0, 0);
|
||||
if (has_residual) {
|
||||
loadAsync2d(nram_residual_ping,
|
||||
residual + next_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
|
||||
local_hidden_size * sizeof(T), local_hidden_size * sizeof(T),
|
||||
hidden_size * sizeof(T), next_token_count - 1);
|
||||
}
|
||||
gatherTokensAsync<T_IDX>(nram_input_ping, input, nram_token_offset,
|
||||
local_hidden_size * sizeof(T), next_active_token_count);
|
||||
gatherTokensAsync<T_IDX>(nram_bias_ping, bias, nram_bias_offset,
|
||||
local_hidden_size * sizeof(T), next_active_token_count);
|
||||
}
|
||||
|
||||
if (task_begin >= 1) {
|
||||
storeAsync2d(
|
||||
output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
|
||||
nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T),
|
||||
local_hidden_size * sizeof(T), previous_token_count - 1);
|
||||
}
|
||||
|
||||
if (task_begin >= 0) {
|
||||
if (has_bias && active_token_count) {
|
||||
__bang_add(nram_input_pong, nram_input_pong, nram_bias_pong,
|
||||
active_token_count * local_hidden_size);
|
||||
}
|
||||
if (expert_parallelism) {
|
||||
weightedReduceSum<T, true>(nram_output_ping, nram_input_pong, nram_weight_pong,
|
||||
nram_input_buffer, (int8_t *)nram_mask_char, topk,
|
||||
local_hidden_size, current_token_count, is_ping);
|
||||
} else {
|
||||
weightedReduceSum<T, false>(nram_output_ping, nram_input_pong, nram_weight_pong,
|
||||
nram_input_buffer, (int8_t *)nram_mask_char, topk,
|
||||
local_hidden_size, current_token_count, is_ping);
|
||||
}
|
||||
if (has_residual) {
|
||||
__bang_add((T *)nram_output_ping, (T *)nram_output_ping, nram_residual_pong,
|
||||
current_token_count * local_hidden_size);
|
||||
}
|
||||
}
|
||||
|
||||
__sync_io_move_compute();
|
||||
active_token_count = next_active_token_count;
|
||||
if (expert_parallelism && !is_last_loop) {
|
||||
__bang_float2uchar_tz((uint8_t *)nram_mask_char, (float *)nram_mask, next_token_count * topk);
|
||||
}
|
||||
if (!is_last_2_loop) {
|
||||
next_active_token_count = getMaskAndActiveTokenCount(
|
||||
nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens,
|
||||
end_expert_acc_tokens, next_next_token_count * topk, expert_parallelism);
|
||||
computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables,
|
||||
num_expert, next_next_token_count * topk, next_active_token_count, hidden_size,
|
||||
local_hidden_begin, (int)sizeof(T), start_expert_id, expert_size,
|
||||
begin_expert_acc_tokens, has_bias);
|
||||
}
|
||||
|
||||
swap(nram_input_ping, nram_input_pong);
|
||||
swap(nram_bias_ping, nram_bias_pong);
|
||||
swap(nram_residual_ping, nram_residual_pong);
|
||||
swap(nram_weight_ping, nram_weight_pong);
|
||||
swap(nram_output_ping, nram_output_pong);
|
||||
previous_global_token_begin = current_global_token_begin;
|
||||
previous_token_count = current_token_count;
|
||||
}
|
||||
storeAsync2d(output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
|
||||
nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T),
|
||||
local_hidden_size * sizeof(T), previous_token_count - 1);
|
||||
}
|
||||
|
||||
#if __BANG_ARCH__ < 500
|
||||
template <>
|
||||
__mlu_global__ void MLUCombineMoeResultKernel<bfloat16_t, uint32_t>(bfloat16_t *output,
|
||||
bfloat16_t *input,
|
||||
bfloat16_t *bias,
|
||||
bfloat16_t *residual,
|
||||
float *reduce_weight,
|
||||
int *cusum_token_count,
|
||||
int *gather_ids,
|
||||
int num_token,
|
||||
int topk,
|
||||
int num_expert,
|
||||
int hidden_size,
|
||||
int start_expert_id,
|
||||
int expert_size,
|
||||
int HIDDEN_BLOCK,
|
||||
int TOKEN_BLOCK) {}
|
||||
|
||||
template <>
|
||||
__mlu_global__ void MLUCombineMoeResultKernel<bfloat16_t, uint64_t>(bfloat16_t *output,
|
||||
bfloat16_t *input,
|
||||
bfloat16_t *bias,
|
||||
bfloat16_t *residual,
|
||||
float *reduce_weight,
|
||||
int *cusum_token_count,
|
||||
int *gather_ids,
|
||||
int num_token,
|
||||
int topk,
|
||||
int num_expert,
|
||||
int hidden_size,
|
||||
int start_expert_id,
|
||||
int expert_size,
|
||||
int HIDDEN_BLOCK,
|
||||
int TOKEN_BLOCK) {}
|
||||
#endif
|
||||
|
||||
} // namespace kernels
|
||||
KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue,
|
||||
void *output,
|
||||
const void *input,
|
||||
const void *bias,
|
||||
const void *residual,
|
||||
const float *reduce_weight,
|
||||
const int *cusum_token_count,
|
||||
const int *gather_idx,
|
||||
int num_token,
|
||||
int topk,
|
||||
int num_expert,
|
||||
int hidden_size,
|
||||
int start_expert_id,
|
||||
int expert_size,
|
||||
cnnlDataType_t dtype) {
|
||||
if (topk > 128 || num_expert > 1024 || hidden_size < 256) {
|
||||
std::cerr << "[invokeMoeCombineResultKernel]: "
|
||||
<< "currently only support topk <= 128, num_expert <= 1024 and hidden_size >= 256.";
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
|
||||
if (bias != nullptr) {
|
||||
std::cerr << "[invokeMoeCombineResultKernel]: currently does not support bias.";
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
|
||||
if ((bias != nullptr || num_expert > expert_size) && cusum_token_count == nullptr) {
|
||||
std::cerr << "[invokeMoeCombineResultKernel]: if has bias or expert parallelism, "
|
||||
<< "cusum_token_count can not be nullptr.";
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
|
||||
size_t data_bytes = 0;
|
||||
cnnlGetSizeOfDataType(dtype, &data_bytes);
|
||||
CNdev dev;
|
||||
cnCtxGetDevice(&dev);
|
||||
int cluster_num;
|
||||
int core_num;
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
|
||||
|
||||
// 480KB nram size, 48KB for token idx, token/bias offset and weight. 432KB for buffer.
|
||||
// TOKEN_BLOCK * topk <= 1024 in case 32KB is enough for idx and offset.
|
||||
int convert_buffer = data_bytes == 2
|
||||
? 3 * hidden_size * data_bytes
|
||||
: 2 * hidden_size * data_bytes; // buffer for convert bf16/fp16->fp32
|
||||
int max_input_size = (432 * 1024 - convert_buffer) /
|
||||
(2 * topk * data_bytes + /*input size, double buffer*/
|
||||
(bias != nullptr) * 2 * topk * data_bytes + /*bias size, double buffer*/
|
||||
(residual != nullptr) * 2 * data_bytes + /*residual size, double buffer*/
|
||||
2 * data_bytes); /*output size, one buffer*/
|
||||
|
||||
int TOKEN_BLOCK = 1;
|
||||
int HIDDEN_BLOCK = 1;
|
||||
int HIDDEN_BLOCK_X_TOKEN_BLOCK = (max_input_size / 64) * 64;
|
||||
if (HIDDEN_BLOCK_X_TOKEN_BLOCK < hidden_size) {
|
||||
HIDDEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK;
|
||||
TOKEN_BLOCK = 1;
|
||||
} else {
|
||||
HIDDEN_BLOCK = hidden_size;
|
||||
}
|
||||
// for latency case, hidden_size is large but token is small.
|
||||
if (HIDDEN_BLOCK == hidden_size && hidden_size >= 4096 && num_token <= core_num * cluster_num) {
|
||||
HIDDEN_BLOCK = (hidden_size + core_num - 1) / core_num;
|
||||
}
|
||||
|
||||
HIDDEN_BLOCK = std::min(HIDDEN_BLOCK, 8 * 1024);
|
||||
uint32_t task_dim_x = (hidden_size + HIDDEN_BLOCK - 1) / HIDDEN_BLOCK;
|
||||
task_dim_x =
|
||||
(task_dim_x < core_num) ? task_dim_x : ((task_dim_x + core_num - 1) / core_num * core_num);
|
||||
uint32_t pad_dim_x = task_dim_x;
|
||||
while (pad_dim_x <= cluster_num * core_num) {
|
||||
if ((cluster_num * core_num % pad_dim_x == 0)) {
|
||||
task_dim_x = pad_dim_x;
|
||||
break;
|
||||
}
|
||||
pad_dim_x += core_num;
|
||||
}
|
||||
HIDDEN_BLOCK = (hidden_size + task_dim_x - 1) / task_dim_x;
|
||||
HIDDEN_BLOCK = (HIDDEN_BLOCK + 63) / 64 * 64;
|
||||
if (HIDDEN_BLOCK_X_TOKEN_BLOCK >= hidden_size) {
|
||||
TOKEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK / HIDDEN_BLOCK;
|
||||
}
|
||||
TOKEN_BLOCK = std::min(TOKEN_BLOCK, 1024 / topk);
|
||||
|
||||
float max_cluster_num = core_num * cluster_num / task_dim_x;
|
||||
uint32_t task_dim_y = std::min(max_cluster_num, num_token);
|
||||
task_dim_y = task_dim_y < 1 ? 1 : task_dim_y;
|
||||
cnrtDim3_t dim{.x = task_dim_x, .y = task_dim_y, .z = 1};
|
||||
|
||||
bool is_large_tensor = data_bytes * num_token * topk * hidden_size > UINT32_MAX;
|
||||
if (dtype == CNNL_DTYPE_FLOAT) {
|
||||
if (!is_large_tensor) {
|
||||
kernels::MLUCombineMoeResultKernel<float, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
|
||||
(float *)output, (float *)input, (float *)bias, (float *)residual, (float *)reduce_weight,
|
||||
(int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
|
||||
start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
|
||||
} else {
|
||||
kernels::MLUCombineMoeResultKernel<float, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
|
||||
(float *)output, (float *)input, (float *)bias, (float *)residual, (float *)reduce_weight,
|
||||
(int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
|
||||
start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
|
||||
}
|
||||
} else if (dtype == CNNL_DTYPE_HALF) {
|
||||
if (!is_large_tensor) {
|
||||
kernels::MLUCombineMoeResultKernel<half, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
|
||||
(half *)output, (half *)input, (half *)bias, (half *)residual, (float *)reduce_weight,
|
||||
(int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
|
||||
start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
|
||||
} else {
|
||||
kernels::MLUCombineMoeResultKernel<half, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
|
||||
(half *)output, (half *)input, (half *)bias, (half *)residual, (float *)reduce_weight,
|
||||
(int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
|
||||
start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
|
||||
}
|
||||
} else if (dtype == CNNL_DTYPE_BFLOAT16) {
|
||||
if (!isBf16Supported()) {
|
||||
std::cerr << "[invokeMoeCombineResultKernel]: MLU300 devices do not support bfloat16."
|
||||
<< std::endl;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
if (!is_large_tensor) {
|
||||
kernels::MLUCombineMoeResultKernel<bfloat16_t, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
|
||||
(bfloat16_t *)output, (bfloat16_t *)input, (bfloat16_t *)bias, (bfloat16_t *)residual,
|
||||
(float *)reduce_weight, (int *)cusum_token_count, (int *)gather_idx, num_token, topk,
|
||||
num_expert, hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
|
||||
} else {
|
||||
kernels::MLUCombineMoeResultKernel<bfloat16_t, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
|
||||
(bfloat16_t *)output, (bfloat16_t *)input, (bfloat16_t *)bias, (bfloat16_t *)residual,
|
||||
(float *)reduce_weight, (int *)cusum_token_count, (int *)gather_idx, num_token, topk,
|
||||
num_expert, hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
|
||||
}
|
||||
} else {
|
||||
std::cerr << "[invokeMoeCombineResultKernel]: the current supported dtype is "
|
||||
<< "among float/half/bfloat16." << std::endl;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
return KernelStatus::KERNEL_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace tmo
|
||||
85
torch_mlu_ops-v1.3.2/csrc/kernels/moe/combine_result.mluh
Normal file
85
torch_mlu_ops-v1.3.2/csrc/kernels/moe/combine_result.mluh
Normal file
@@ -0,0 +1,85 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
|
||||
#define CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
|
||||
|
||||
#include "../kernel_utils.h"
|
||||
#include "cnnl.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Sort tokens grouped by different experts based on index. Each token
|
||||
* selects the topk hidden vectors, multiplies them by corresponding weights,
|
||||
* and finally reduces the topk vectors for each token. This process involves
|
||||
* bias and residual, calculated as (x + bias) * weight + residual.
|
||||
* @example
|
||||
* input:
|
||||
* [[[1, 2, 1, 1],
|
||||
* [1, 1, 1, 2]],
|
||||
* [[2, 1, 1, 1],
|
||||
* [1, 1, 1, 1]]]
|
||||
* num_token = 2, topk = 2
|
||||
* cusum_token_count = [0, 2, 4]
|
||||
* index:
|
||||
* [0, 1, 2, 3]
|
||||
* weight:
|
||||
* [0, 0, 1, 1]
|
||||
* bias:
|
||||
* [[0, 0, 0, 0],
|
||||
* [1, 1, 1, 1]]
|
||||
* residual:
|
||||
* [[1, 1, 1, 1],
|
||||
* [0, 0, 0, 0]]
|
||||
* output:
|
||||
* [[1, 1, 1, 1],
|
||||
* [5, 4, 4, 4]]
|
||||
* @param queue: The queue for mlu.
|
||||
* @param output: Output. Pointer to the MLU memory that stores the result.
|
||||
* The shape is [num_token, hidden_size].
|
||||
* @param input: Input. Pointer to the MLU memory that stores input tokens.
|
||||
* The shape is [num_token * topk, hidden_size].
|
||||
* @param bias: Input. Pointer to the MLU memory that stores bias.
|
||||
* The shape is [num_expert, hidden_size].
|
||||
* @param residual: Input. Pointer to the MLU memory that stores residual.
|
||||
* The shape is [num_token, hidden_size].
|
||||
* @param reduce_weight: Input. Pointer to the MLU memory that stores reduce_weight.
|
||||
* The shape is [num_token * topk].
|
||||
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the cumulative sum of the
|
||||
* token number of each expert. The shape is [num_expert + 1].
|
||||
* @param gather_idx: Input. Pointer to the MLU memory that stores gather_idx.
|
||||
* The shape is [num_token * topk].
|
||||
* @param num_token: The total number of tokens.
|
||||
* @param topk: The number of expert.
|
||||
* @param num_expert: The number of expert.
|
||||
* @param hidden_size: The size of lowest dimension.
|
||||
* @param start_expert_id: The id of the first processed expert.
|
||||
* @param expert_size: The number of processed experts.
|
||||
* @param dtype: Data type.
|
||||
* @note Currently does not support bias.
|
||||
*/
|
||||
KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue,
|
||||
void *output,
|
||||
const void *input,
|
||||
const void *bias,
|
||||
const void *residual,
|
||||
const float *reduce_weight,
|
||||
const int *cusum_token_count,
|
||||
const int *gather_idx,
|
||||
int num_token,
|
||||
int topk,
|
||||
int num_expert,
|
||||
int hidden_size,
|
||||
int start_expert_id,
|
||||
int expert_size,
|
||||
cnnlDataType_t dtype);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
|
||||
219
torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mlu
Normal file
219
torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mlu
Normal file
@@ -0,0 +1,219 @@
|
||||
#include <bang_device_functions_extra.h>
|
||||
#include <mlu.h>
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
#include "expand_input.mluh"
|
||||
|
||||
namespace tmo {
|
||||
|
||||
namespace kernels {
|
||||
|
||||
#define RESERVED_SIZE (32 * 1024)
|
||||
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - RESERVED_SIZE)
|
||||
#define SRAM_BUFFER_SIZE (__MLU_SRAM_SIZE__ * 1024 - RESERVED_SIZE)
|
||||
#define MEMCPY_BURST_SIZE 128
|
||||
|
||||
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
|
||||
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
|
||||
|
||||
// T_offset: uint32_t or uint64_t
|
||||
template <size_t data_size, typename T_offset>
|
||||
__mlu_func__ void ExpandInputKernel(void *output,
|
||||
void *input,
|
||||
int *index,
|
||||
int num_token,
|
||||
int hidden_size,
|
||||
int num_index) {
|
||||
void *input_ptr = input;
|
||||
uint64_t input_size = data_size * num_token * hidden_size;
|
||||
// whether SRAM_BUFFER_SIZE can hold the input data
|
||||
// sram_enable is true ==> is_nram_output is true
|
||||
bool sram_enable = (hidden_size == 1) && (input_size < SRAM_BUFFER_SIZE);
|
||||
if (sram_enable) {
|
||||
input_ptr = (void *)sram_buffer;
|
||||
__memcpy(input_ptr, input, input_size, GDRAM2SRAM);
|
||||
__sync_cluster();
|
||||
}
|
||||
if (__is_mpu()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Each ipu core processes no less than 128B of data, and the remaining cores can idle
|
||||
int32_t max_task_num = data_size * hidden_size * num_index / MEMCPY_BURST_SIZE;
|
||||
uint32_t maxTaskDim = std::min(taskDim, std::max(max_task_num, 1));
|
||||
uint32_t total_num = num_index;
|
||||
uint32_t base = total_num / maxTaskDim;
|
||||
uint32_t tail = total_num - base * maxTaskDim;
|
||||
if (taskId >= maxTaskDim) {
|
||||
return;
|
||||
}
|
||||
uint32_t batch_per_core = base + (taskId < tail ? 1 : 0);
|
||||
uint32_t batch_step = base * taskId + (taskId < tail ? taskId : tail);
|
||||
|
||||
// nram
|
||||
/*
|
||||
* first: compute offset: index[i] * data_size
|
||||
* second: gather data
|
||||
* -------------------------------------------------
|
||||
* addr || index/offset | output |
|
||||
* type || int32_t/T_offset | T |
|
||||
* num || n | n * hidden_size |
|
||||
* -------------------------------------------------
|
||||
*/
|
||||
|
||||
uint32_t nram_size_per_pixel = sizeof(T_offset) + hidden_size * data_size;
|
||||
// whether nram can hold two pixel: if so, then GDRAM->NRAM->GDRAM, otherwise GDRAM->GDRAM
|
||||
bool is_nram_output = nram_size_per_pixel * 2 <= NRAM_BUFFER_SIZE;
|
||||
uint32_t per_num =
|
||||
is_nram_output ? NRAM_BUFFER_SIZE / nram_size_per_pixel : NRAM_BUFFER_SIZE / sizeof(T_offset);
|
||||
|
||||
int8_t *output_base = (int8_t *)output + (uint64_t)batch_step * hidden_size * data_size;
|
||||
int *index_base = index + batch_step;
|
||||
T_offset *nram_offset = (T_offset *)nram_buffer;
|
||||
int32_t *nram_index;
|
||||
if (std::is_same<T_offset, int64_t>::value) {
|
||||
nram_index = (int32_t *)nram_offset + per_num;
|
||||
} else {
|
||||
nram_index = (int32_t *)nram_offset;
|
||||
}
|
||||
int8_t *nram_output = (int8_t *)(nram_offset + per_num);
|
||||
|
||||
uint32_t repeat = batch_per_core / per_num;
|
||||
uint32_t remain = batch_per_core - repeat * per_num;
|
||||
uint32_t deal_num = per_num;
|
||||
uint32_t is_remain = remain != 0 ? 1 : 0;
|
||||
|
||||
for (int32_t i = 0; i < repeat + is_remain; i++) {
|
||||
if (i == repeat) {
|
||||
deal_num = remain;
|
||||
}
|
||||
int8_t *output_ptr = output_base + (uint64_t)i * per_num * hidden_size * data_size;
|
||||
int32_t *index_ptr = index_base + i * per_num;
|
||||
|
||||
// index -> offset
|
||||
__memcpy((void *)nram_index, (void *)index_ptr, deal_num * sizeof(int32_t), GDRAM2NRAM);
|
||||
if (std::is_same<T_offset, uint64_t>::value) {
|
||||
#if __BANG_ARCH__ > 592
|
||||
__bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num, 0, 0);
|
||||
#else
|
||||
__bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num);
|
||||
#endif
|
||||
}
|
||||
__bang_mul_scalar(nram_offset, nram_offset, (int64_t)data_size * hidden_size, deal_num);
|
||||
|
||||
// copy
|
||||
if (is_nram_output) {
|
||||
__bang_write_zero((int8_t *)nram_output, deal_num * hidden_size);
|
||||
mluMemcpyDirection_t dir = sram_enable ? SRAM2NRAM : GDRAM2NRAM;
|
||||
// GDRAM or SRAM -> NRAM -> GDRAM
|
||||
#if __BANG_ARCH__ >= 592 // gather requires
|
||||
__gather(nram_output, input_ptr, nram_offset, hidden_size * data_size, dir,
|
||||
hidden_size * data_size, deal_num);
|
||||
#else
|
||||
for (int32_t j = 0; j < deal_num; j++) {
|
||||
T_offset offset_value = *(nram_offset + j);
|
||||
int8_t *input_offset = (int8_t *)input_ptr + offset_value;
|
||||
__memcpy(nram_output + j * hidden_size * data_size, input_offset, hidden_size * data_size,
|
||||
dir);
|
||||
}
|
||||
#endif // __BANG_ARCH__
|
||||
__memcpy(output_ptr, nram_output, deal_num * hidden_size * data_size, NRAM2GDRAM);
|
||||
} else {
|
||||
// GDRAM -> GDRAM
|
||||
#if __BANG_ARCH__ >= 592 // gather requires
|
||||
__gather(output_ptr, input, (uint64_t *)nram_offset, hidden_size * data_size, GDRAM2GDRAM,
|
||||
hidden_size * data_size, deal_num);
|
||||
#else
|
||||
for (int32_t j = 0; j < deal_num; j++) {
|
||||
T_offset offset_value = *(nram_offset + j);
|
||||
int8_t *input_offset = (int8_t *)input + offset_value;
|
||||
__memcpy(output_ptr + (T_offset)j * hidden_size * data_size, input_offset,
|
||||
hidden_size * data_size, GDRAM2GDRAM);
|
||||
}
|
||||
#endif // __BANG_ARCH__
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// T_offset: uint32_t or uint64_t
|
||||
template <size_t data_size, typename T_offset>
|
||||
__mlu_global__ void MLUExpandInputKernel(void *expand_hidden_state,
|
||||
void *hidden_state,
|
||||
int *gather_idx,
|
||||
int *cusum_token_count,
|
||||
int num_token,
|
||||
int hidden_size,
|
||||
int topk,
|
||||
int total_expert_num,
|
||||
int start_expert_id,
|
||||
int expert_count) {
|
||||
int32_t num_index = num_token * topk;
|
||||
int *gather_start_idx = (int *)gather_idx;
|
||||
if (cusum_token_count != nullptr) {
|
||||
num_index = *((int *)cusum_token_count + start_expert_id + expert_count) -
|
||||
*((int *)cusum_token_count + start_expert_id);
|
||||
gather_start_idx = (int *)gather_idx + *(cusum_token_count + start_expert_id);
|
||||
}
|
||||
ExpandInputKernel<data_size, T_offset>(expand_hidden_state, hidden_state, gather_start_idx,
|
||||
num_token, hidden_size, num_index);
|
||||
}
|
||||
// instantiate kernels
|
||||
#define INSTANTIATE_ONE(data_size, T_offset) \
|
||||
template __mlu_global__ void MLUExpandInputKernel<data_size, T_offset>( \
|
||||
void *, void *, int *, int *, int, int, int, int, int, int);
|
||||
|
||||
INSTANTIATE_ONE(1, uint32_t)
|
||||
INSTANTIATE_ONE(2, uint32_t)
|
||||
INSTANTIATE_ONE(4, uint32_t)
|
||||
INSTANTIATE_ONE(8, uint32_t)
|
||||
// large tensor
|
||||
INSTANTIATE_ONE(1, uint64_t)
|
||||
INSTANTIATE_ONE(2, uint64_t)
|
||||
INSTANTIATE_ONE(4, uint64_t)
|
||||
INSTANTIATE_ONE(8, uint64_t)
|
||||
} // namespace kernels
|
||||
|
||||
KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue,
|
||||
void *expand_hidden_state,
|
||||
const void *hidden_state,
|
||||
const int *gather_idx,
|
||||
const int *cusum_token_count,
|
||||
int num_token,
|
||||
int hidden_size,
|
||||
int topk,
|
||||
cnnlDataType_t data_type,
|
||||
int total_expert_num,
|
||||
int start_expert_id,
|
||||
int expert_count) {
|
||||
CNdev dev;
|
||||
cnCtxGetDevice(&dev);
|
||||
int cluster_num;
|
||||
int core_num;
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
|
||||
|
||||
size_t type_size = 0;
|
||||
cnnlGetSizeOfDataType(data_type, &type_size);
|
||||
|
||||
int max_cluster_num =
|
||||
(uint64_t)hidden_size * num_token * topk * type_size / (core_num * MEMCPY_BURST_SIZE);
|
||||
cluster_num = std::min(std::max(max_cluster_num, 1), cluster_num);
|
||||
|
||||
cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
|
||||
|
||||
void (*expand_input_kernels[])(void *, void *, int *, int *, int, int, int, int, int, int) = {
|
||||
kernels::MLUExpandInputKernel<1, uint32_t>, kernels::MLUExpandInputKernel<2, uint32_t>,
|
||||
kernels::MLUExpandInputKernel<4, uint32_t>, kernels::MLUExpandInputKernel<8, uint32_t>,
|
||||
kernels::MLUExpandInputKernel<1, uint64_t>, kernels::MLUExpandInputKernel<2, uint64_t>,
|
||||
kernels::MLUExpandInputKernel<4, uint64_t>, kernels::MLUExpandInputKernel<8, uint64_t>};
|
||||
|
||||
bool is_large_tensor = type_size * hidden_size * num_token * topk > INT32_MAX;
|
||||
int kernel_index = (type_size == 8 ? 3 : type_size >> 1) + (is_large_tensor ? 4 : 0);
|
||||
expand_input_kernels[kernel_index]<<<dim, cnrtFuncTypeUnion1, queue>>>(
|
||||
(void *)expand_hidden_state, (void *)hidden_state, (int *)gather_idx,
|
||||
(int *)cusum_token_count, num_token, hidden_size, topk, total_expert_num, start_expert_id,
|
||||
expert_count);
|
||||
return KernelStatus::KERNEL_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace tmo
|
||||
81
torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mluh
Normal file
81
torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mluh
Normal file
@@ -0,0 +1,81 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_
|
||||
#define CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_
|
||||
|
||||
#include "../kernel_utils.h"
|
||||
#include "cnnl.h"
|
||||
|
||||
namespace tmo {
|
||||
|
||||
/**
|
||||
* @brief Gathers slices from hidden_state at axis 1 according to gather_idx and cusum_token_count.
|
||||
* @example
|
||||
* hidden_state:
|
||||
* [[1, 2, 3, 4],
|
||||
* [5, 6, 7, 8],
|
||||
* [9, 10, 11, 12]]
|
||||
* gather_idx:
|
||||
* [[1, 0, 2, 2, 1, 0]]
|
||||
* cusum_token_count: NULL
|
||||
* num_token = 3
|
||||
* hidden_size = 4
|
||||
* topk = 2
|
||||
* expand_hidden_state:
|
||||
* [[5, 6, 7, 8],
|
||||
* [1, 2, 3, 4],
|
||||
* [9, 10, 11, 12],
|
||||
* [9, 10, 11, 12],
|
||||
* [5, 6, 7, 8],
|
||||
* [1, 2, 3, 4]]
|
||||
* @param queue: The queue for mlu.
|
||||
* @param hidden_state: Input. Pointer to the MLU memory that store the input,
|
||||
* the shape must be [num_token, hidden_size].
|
||||
* @param gather_idx: Input. Pointer to the MLU memory that stores the index,
|
||||
* the shape must be [num_token * topk].
|
||||
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the prefix sum of
|
||||
* token_count. If cusum_token_count is not NULL, the shape must be [total_expert_num + 1]. The
|
||||
* gather operation will be performed as follows: if cusum_token_count is not NULL: index =
|
||||
* gather_idx[cusum_token_count[start_expert_id]:cusum_token_count[start_expert_id+expert_count]]
|
||||
* expand_hidden_state = hidden_state[index]
|
||||
* else:
|
||||
* index = gather_idx[:]
|
||||
* expand_hidden_state = hidden_state[index]
|
||||
* @param expand_hidden_state: Output. Pointer to the MLU memory that stores the output,
|
||||
* if cusum_token_count is not NULL, the shape shoule be [num_index * topk ,hidden_size] in
|
||||
* which num_index =
|
||||
* cusum_token_count[start_expert_id+expert_count]-cusum_token_count[start_expert_id]. Otherwise,
|
||||
* the shape should be [num_token * topk, hidden_size].
|
||||
* @param num_token: the number of token.
|
||||
* @param hidden_size: the slice size.
|
||||
* @param topk: the number of topk.
|
||||
* @param data_type: Data type of hidden_state.
|
||||
* @param total_expert_num: the total number of expert.
|
||||
* @param start_expert_id: the first expert id.
|
||||
* @param expert_count: the number of experts currently being processed.
|
||||
*/
|
||||
|
||||
KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue,
|
||||
void *expand_hidden_state,
|
||||
const void *hidden_state,
|
||||
const int *gather_idx,
|
||||
const int *cusum_token_count,
|
||||
int num_token,
|
||||
int hidden_size,
|
||||
int topk,
|
||||
cnnlDataType_t data_type,
|
||||
int total_expert_num,
|
||||
int start_expert_id,
|
||||
int expert_count);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_
|
||||
935
torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mlu
Normal file
935
torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mlu
Normal file
@@ -0,0 +1,935 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include "gen_idx.mluh"
|
||||
// clang-format off
|
||||
#include <mlu.h>
|
||||
// clang-format on
|
||||
|
||||
namespace tmo {
|
||||
namespace kernels {
|
||||
|
||||
#define NRAM_BUFFER_SIZE ((__MLU_NRAM_SIZE__ - 16) * 1024)
|
||||
#define SRAM_BUFFER_SIZE ((__MLU_SRAM_SIZE__ - 8) * 1024)
|
||||
#define ALIGN_16 (16)
|
||||
|
||||
#define EXPERT_AVG_COUNT_TEST (0)
|
||||
|
||||
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
|
||||
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
|
||||
__nram__ const int range[64] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
|
||||
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
|
||||
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
|
||||
|
||||
// Generate integer sequence data from 0 to length-1
|
||||
__mlu_func__ void generateIntSeq(int *dst, int length) {
|
||||
int count = 64;
|
||||
__bang_move(dst, range, std::min(count, length) * sizeof(int));
|
||||
while (count < length) {
|
||||
__bang_add_scalar(dst + count, dst, (int)count, std::min(count, length - count));
|
||||
count *= 2;
|
||||
}
|
||||
}
|
||||
|
||||
// genIdx Block kernel, use only 1 core to process
|
||||
__mlu_global__ void launchMoeGenIdxBlockKernel(int *gather_expand_idx,
|
||||
int *gather_combine_idx,
|
||||
int *token_count,
|
||||
int *cusum_token_count,
|
||||
const void *expert_id,
|
||||
const int num_token,
|
||||
const int num_expert,
|
||||
const int topk) {
|
||||
/* NRAM space */
|
||||
// Total occupy: (4 * token_total_num + 2 * num_expert) * sizeof(int)
|
||||
// --------------------------------------------------------------
|
||||
// | expert_id | sorted_idx |gen_idx_onchip|cur_expert_result|
|
||||
// | combine_idx | expand_idx | | scatter_offset |
|
||||
// |num_token*topk|num_token*topk|num_token*topk| num_token*topk |
|
||||
// --------------------------------------------------------------
|
||||
// ------------------------------
|
||||
// |token_count|token_count_presum|
|
||||
// | | |
|
||||
// | num_expert| num_expert |
|
||||
// ------------------------------
|
||||
|
||||
uint32_t token_total_num = num_token * topk;
|
||||
// num align to 16, size align to 64B
|
||||
uint32_t align_total_num = (token_total_num + ALIGN_16 - 1) >> 4 << 4;
|
||||
|
||||
int8_t *expert_id_onchip = (int8_t *)nram_buffer;
|
||||
int8_t *sorted_idx_onchip = (int8_t *)expert_id_onchip + align_total_num * sizeof(int);
|
||||
int8_t *gen_idx_onchip = (int8_t *)sorted_idx_onchip + align_total_num * sizeof(int);
|
||||
int8_t *cur_expert_result = (int8_t *)gen_idx_onchip + align_total_num * sizeof(int);
|
||||
int8_t *token_count_onchip = (int8_t *)cur_expert_result + align_total_num * sizeof(int);
|
||||
int8_t *token_count_presum_onchip = (int8_t *)token_count_onchip + num_expert * sizeof(int);
|
||||
|
||||
int8_t *scatter_offset = cur_expert_result; // reuse cur_expert space
|
||||
#if __BANG_ARCH__ >= 592
|
||||
int8_t *combine_idx_onchip = expert_id_onchip; // reuse expert_it space
|
||||
#endif
|
||||
int8_t *expand_idx_onchip = sorted_idx_onchip; // reuse sorted_idx space
|
||||
|
||||
// Load current core input expert_id and generate int sequence
|
||||
__memcpy_async((int *)expert_id_onchip, (int *)expert_id, token_total_num * sizeof(int),
|
||||
GDRAM2NRAM);
|
||||
generateIntSeq((int *)gen_idx_onchip, token_total_num);
|
||||
__sync();
|
||||
|
||||
// Initialize sort idx offset
|
||||
uint32_t sorted_idx_offset = 0;
|
||||
// Initialize token count first presum with 0
|
||||
((int *)token_count_presum_onchip)[0] = 0;
|
||||
bool need_cusum_token_count = bool(cusum_token_count != nullptr);
|
||||
|
||||
// Loop on each expert, eq, count, filter index
|
||||
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
|
||||
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert,
|
||||
token_total_num);
|
||||
// Use filter to sort gen_idx, output with sorted_idx_offset
|
||||
uint32_t cur_expert_count =
|
||||
__bang_filter(((float *)sorted_idx_onchip) + sorted_idx_offset, (float *)gen_idx_onchip,
|
||||
(float *)cur_expert_result, token_total_num);
|
||||
|
||||
sorted_idx_offset += cur_expert_count;
|
||||
((int *)token_count_onchip)[cur_expert] = cur_expert_count;
|
||||
|
||||
// Compute cusum token count and store
|
||||
if (need_cusum_token_count) {
|
||||
((int *)token_count_presum_onchip)[cur_expert + 1] = sorted_idx_offset;
|
||||
}
|
||||
}
|
||||
|
||||
#if EXPERT_AVG_COUNT_TEST
|
||||
// NOTE: test avg expert code here:
|
||||
uint32_t token_count_avg = token_total_num / num_expert;
|
||||
uint32_t expert_remain_num = token_total_num % num_expert;
|
||||
|
||||
for (int i = 0; i < num_expert; i++) {
|
||||
((int *)token_count_onchip)[i] =
|
||||
(i < expert_remain_num) ? token_count_avg + 1 : token_count_avg;
|
||||
((int *)token_count_presum_onchip)[i + 1] =
|
||||
((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i];
|
||||
}
|
||||
#endif
|
||||
|
||||
__sync_compute();
|
||||
// Store token_count and cusum token count
|
||||
__memcpy_async((int *)token_count, (int *)token_count_onchip, num_expert * sizeof(int),
|
||||
NRAM2GDRAM);
|
||||
if (need_cusum_token_count) {
|
||||
__memcpy_async((int *)cusum_token_count, (int *)token_count_presum_onchip,
|
||||
(num_expert + 1) * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
|
||||
// Use sorted idx to generate gather idx for expand and combine
|
||||
#if __BANG_ARCH__ >= 592
|
||||
// scatter_offset = sorted_idx mul_scalar sizeof(int);
|
||||
__bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
|
||||
token_total_num);
|
||||
#else
|
||||
// scatter dst GDRAM addr should align to 64B
|
||||
int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
|
||||
int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr);
|
||||
|
||||
__bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
|
||||
combine_idx_align_offset, (int)(sizeof(int)), token_total_num);
|
||||
#endif
|
||||
__sync_compute();
|
||||
|
||||
#if __BANG_ARCH__ >= 592
|
||||
// scatter_async to NRAM
|
||||
__scatter_async((int *)combine_idx_onchip, (int *)gen_idx_onchip, (uint32_t *)scatter_offset,
|
||||
sizeof(int), NRAM2NRAM, sizeof(int), (unsigned short)token_total_num);
|
||||
#endif
|
||||
// expand_idx = sorted_idx div(topk)
|
||||
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, token_total_num);
|
||||
|
||||
// Store expand_idx and combine_idx
|
||||
__sync_compute();
|
||||
__memcpy_async((int *)gather_expand_idx, (int *)expand_idx_onchip, token_total_num * sizeof(int),
|
||||
NRAM2GDRAM);
|
||||
#if __BANG_ARCH__ >= 592
|
||||
__sync_move();
|
||||
__memcpy_async((int *)gather_combine_idx, (int *)combine_idx_onchip,
|
||||
token_total_num * sizeof(int), NRAM2GDRAM);
|
||||
#else
|
||||
// 370 directly scatter to GDRAM
|
||||
__scatter((int *)combine_idx_align_addr, (int *)gen_idx_onchip, (uint32_t *)scatter_offset,
|
||||
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)token_total_num);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Only MLU500 series support NRAM2SRAM scatter direction
|
||||
__mlu_func__ void scatterSeqSram(int *dst, int *src, uint32_t *offset, int length) {
|
||||
#if __BANG_ARCH__ >= 592
|
||||
// When length larger than 65535(maximum segnum in bang_scatter),
|
||||
// and src/offset address should align to 64B
|
||||
int seg_repeat = length / 32768;
|
||||
int seg_remain = length % 32768;
|
||||
int seg_offset = 0;
|
||||
|
||||
for (int seg = 0; seg < seg_repeat; seg++) {
|
||||
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
|
||||
sizeof(int), NRAM2SRAM, sizeof(int), (unsigned short)32768);
|
||||
seg_offset += 32768;
|
||||
}
|
||||
|
||||
if (seg_remain > 0) {
|
||||
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
|
||||
sizeof(int), NRAM2SRAM, sizeof(int), (unsigned short)seg_remain);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Scatter sequence, transfer size is sizeof(int)
|
||||
__mlu_func__ void scatterSeqDram(int *dst, int *src, uint32_t *offset, int length) {
|
||||
// When length larger than 65535(maximum segnum in bang_scatter),
|
||||
// and src/offset address should align to 64B
|
||||
int seg_repeat = length / 32768;
|
||||
int seg_remain = length % 32768;
|
||||
int seg_offset = 0;
|
||||
|
||||
for (int seg = 0; seg < seg_repeat; seg++) {
|
||||
#if __BANG_ARCH__ >= 592
|
||||
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
|
||||
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)32768);
|
||||
#else
|
||||
__scatter((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, sizeof(int),
|
||||
NRAM2GDRAM, sizeof(int), (unsigned short)32768);
|
||||
#endif
|
||||
seg_offset += 32768;
|
||||
}
|
||||
if (seg_remain > 0) {
|
||||
#if __BANG_ARCH__ >= 592
|
||||
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
|
||||
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)seg_remain);
|
||||
#else
|
||||
__scatter((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, sizeof(int),
|
||||
NRAM2GDRAM, sizeof(int), (unsigned short)seg_remain);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Get token count
|
||||
__mlu_func__ void getTokenCount(int *token_count,
|
||||
int *expert_id,
|
||||
int token_cur_core,
|
||||
int cur_token_start,
|
||||
int num_expert) {
|
||||
// 1. Partition on [num_token*topk],
|
||||
// each core for-loop on all expert_id, use eq and count instructions,
|
||||
// use AtomicAdd to accumulate all expert_id token counts, on GDRAM.
|
||||
// And sync for all cores.
|
||||
// NRAM:
|
||||
// ------------------------------------------------------
|
||||
// |expert_id_onchip|cur_expert_result|expert_count_onchip|
|
||||
// | deal_num | deal_num | num_expert |
|
||||
// ------------------------------------------------------
|
||||
|
||||
uint32_t deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 2;
|
||||
int8_t *expert_id_onchip = (int8_t *)nram_buffer;
|
||||
int8_t *cur_expert_result = (int8_t *)expert_id_onchip + deal_num * sizeof(int);
|
||||
int8_t *expert_count_onchip = cur_expert_result + deal_num * sizeof(int);
|
||||
|
||||
// Current core data loop
|
||||
uint32_t repeat = token_cur_core / deal_num;
|
||||
uint32_t remain = token_cur_core % deal_num;
|
||||
uint32_t total_repeat = repeat + (int)(remain > 0);
|
||||
uint32_t token_addr_offset = cur_token_start;
|
||||
|
||||
// Initialize token_count with 0
|
||||
if (taskId == 0) {
|
||||
__gdramset((int *)token_count, num_expert, 0);
|
||||
}
|
||||
// Sync for initialize token_count
|
||||
__sync_all_ipu();
|
||||
|
||||
// Initialize expert count onchip with 0
|
||||
if (token_cur_core > 0) {
|
||||
__bang_write_zero((int *)expert_count_onchip, num_expert);
|
||||
}
|
||||
|
||||
// actual num in loop
|
||||
int cur_deal_num = deal_num;
|
||||
for (int i = 0; i < total_repeat; i++) {
|
||||
if (i == total_repeat - 1 && remain > 0) {
|
||||
cur_deal_num = remain;
|
||||
}
|
||||
// Load current core input expert_id
|
||||
__memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
|
||||
cur_deal_num * sizeof(int), GDRAM2NRAM);
|
||||
token_addr_offset += cur_deal_num;
|
||||
|
||||
// Loop on each expert, eq, count
|
||||
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
|
||||
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, cur_deal_num);
|
||||
// NOTE: __bang_count() only support floating data type
|
||||
uint32_t cur_expert_count = __bang_count((float *)cur_expert_result, cur_deal_num);
|
||||
((int *)expert_count_onchip)[cur_expert] += cur_expert_count;
|
||||
}
|
||||
}
|
||||
|
||||
// AtomicAdd(reduce) all cores token count results
|
||||
if (token_cur_core > 0) {
|
||||
__bang_atomic_reduce_add((int *)token_count, (int *)expert_count_onchip, num_expert);
|
||||
}
|
||||
// Sync for all cores, get accumulate of token_count
|
||||
__sync_all_ipu();
|
||||
}
|
||||
|
||||
// 2. Get token count presum, for each expert index start address after sorting
|
||||
__mlu_func__ void getTokenCountPresum(int *token_count_presum,
|
||||
int *token_count,
|
||||
const int num_expert) {
|
||||
// 2. After first process, already get token_count.
|
||||
// Then use one core to pre-sum on token_count, consider size of int32,
|
||||
// first expert id start address should be zero.
|
||||
// to get each expert id start address after sorting, store to workspace,
|
||||
// token_count_presum.
|
||||
// And sync for all cores.
|
||||
// NRAM:
|
||||
// load token_count to token_count_presum[1~num_expert+1],
|
||||
// for i = 0 to num_expert:
|
||||
// token_count_presum[i+1] += token_count_presum[i]
|
||||
// store token_count_presum[0~num_expert]
|
||||
// -------------------------
|
||||
// |token_count_presum_onchip|
|
||||
// | {0}, num_expert |
|
||||
// -------------------------
|
||||
|
||||
if (taskId == 0) {
|
||||
// Initialize count presum onchip with a first 0
|
||||
int8_t *token_count_presum_onchip = nram_buffer;
|
||||
((int *)token_count_presum_onchip)[0] = 0;
|
||||
// Load token_count with an offset of 1
|
||||
__memcpy(((int *)token_count_presum_onchip) + 1, (int *)token_count, num_expert * sizeof(int),
|
||||
GDRAM2NRAM);
|
||||
|
||||
// Calculate presum of token count by each expert
|
||||
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
|
||||
((int *)token_count_presum_onchip)[cur_expert + 1] +=
|
||||
((int *)token_count_presum_onchip)[cur_expert];
|
||||
}
|
||||
|
||||
// Store token count presum to workspace
|
||||
__memcpy((int *)token_count_presum, (int *)token_count_presum_onchip,
|
||||
(num_expert + 1) * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
// Sync for all cores, get presum of token count
|
||||
__sync_all_ipu();
|
||||
}
|
||||
|
||||
__mlu_func__ void modifyTokenCountAndPresum(int *token_count_presum,
|
||||
int *token_count,
|
||||
const uint32_t token_total_num,
|
||||
const int num_expert) {
|
||||
uint32_t token_count_avg = token_total_num / num_expert;
|
||||
uint32_t expert_remain_num = token_total_num % num_expert;
|
||||
|
||||
int8_t *token_count_onchip = nram_buffer;
|
||||
int8_t *token_count_presum_onchip = token_count_onchip + num_expert * sizeof(int);
|
||||
|
||||
((int *)token_count_presum_onchip)[0] = 0;
|
||||
|
||||
for (int i = 0; i < num_expert; i++) {
|
||||
((int *)token_count_onchip)[i] =
|
||||
(i < expert_remain_num) ? token_count_avg + 1 : token_count_avg;
|
||||
((int *)token_count_presum_onchip)[i + 1] =
|
||||
((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i];
|
||||
}
|
||||
|
||||
__memcpy((int *)token_count, (int *)token_count_onchip, num_expert * sizeof(int), NRAM2GDRAM);
|
||||
__memcpy((int *)token_count_presum, (int *)token_count_presum_onchip,
|
||||
(num_expert + 1) * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
|
||||
// 3. Get expert position index after sorting
|
||||
__mlu_func__ void getSortedIdx(int *sorted_idx,
|
||||
int *expert_id,
|
||||
int *token_count_presum,
|
||||
const int token_total_num,
|
||||
const int num_expert,
|
||||
const int expert_cur_core,
|
||||
const int cur_expert_start,
|
||||
const int cur_expert_end) {
|
||||
// 3. Partition on num_expert, each core generate position index from 0,
|
||||
// and for-loop on all expert_id data, use eq with own each expert_id,
|
||||
// and filter on index, stores to each expert_id start address of
|
||||
// sorted_idx on workspace.
|
||||
// And sync for all cores.
|
||||
// NRAM:
|
||||
// -------------------------------------------------------------------
|
||||
// |expert_id_onchip|cur_expert_result|gen_idx_onchip|filter_idx_onchip|
|
||||
// | deal_num | deal_num | deal_num | deal_num |
|
||||
// -------------------------------------------------------------------
|
||||
// |expert_start_addr|
|
||||
// | num_expert |
|
||||
// -----------------
|
||||
|
||||
// Calculate new deal_num of sorting process
|
||||
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 4;
|
||||
|
||||
// Each core deal with whole token expert_id data
|
||||
int repeat = token_total_num / deal_num;
|
||||
int remain = token_total_num % deal_num;
|
||||
int token_addr_offset = 0;
|
||||
|
||||
int8_t *expert_id_onchip = nram_buffer;
|
||||
int8_t *cur_expert_result = expert_id_onchip + deal_num * sizeof(int);
|
||||
int8_t *gen_idx_onchip = cur_expert_result + deal_num * sizeof(int);
|
||||
int8_t *filter_idx_onchip = gen_idx_onchip + deal_num * sizeof(int);
|
||||
int8_t *expert_start_addr = filter_idx_onchip + deal_num * sizeof(int);
|
||||
|
||||
// When num_expert < taskDim, not all cores need to sort
|
||||
if (expert_cur_core > 0) {
|
||||
// Generate position index from 0
|
||||
if (deal_num <= token_total_num) {
|
||||
generateIntSeq((int *)gen_idx_onchip, deal_num);
|
||||
} else { // only remainder part
|
||||
generateIntSeq((int *)gen_idx_onchip, token_total_num);
|
||||
}
|
||||
|
||||
// Initialize expert start address with presum of token count
|
||||
__memcpy((int *)expert_start_addr, (int *)token_count_presum, num_expert * sizeof(int),
|
||||
GDRAM2NRAM);
|
||||
|
||||
// repeat part
|
||||
for (int i = 0; i < repeat; i++) {
|
||||
// Load current core expert_id
|
||||
__memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
|
||||
deal_num * sizeof(int), GDRAM2NRAM);
|
||||
token_addr_offset += deal_num;
|
||||
|
||||
// Loop for current core expert, eq, filter position index
|
||||
// use filter, store to sorted_idx[expert_start_addr]
|
||||
for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) {
|
||||
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, deal_num);
|
||||
|
||||
int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];
|
||||
|
||||
// NOTE: __bang_filter() only support floating data type
|
||||
uint32_t cur_expert_count =
|
||||
__bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
|
||||
(float *)cur_expert_result, deal_num);
|
||||
|
||||
// Store to the corresponding address of sorted_idx
|
||||
if (cur_expert_count > 0) {
|
||||
__memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip,
|
||||
cur_expert_count * sizeof(int), NRAM2GDRAM);
|
||||
|
||||
// Update address offset of current expert
|
||||
((int *)expert_start_addr)[cur_expert] = cur_expert_offset + cur_expert_count;
|
||||
}
|
||||
}
|
||||
|
||||
// Update position index for each data loop
|
||||
__bang_add_scalar((int *)gen_idx_onchip, (int *)gen_idx_onchip, (int)(deal_num), deal_num);
|
||||
}
|
||||
|
||||
// remainder part
|
||||
if (remain > 0) {
|
||||
__memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
|
||||
remain * sizeof(int), GDRAM2NRAM);
|
||||
|
||||
for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) {
|
||||
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, remain);
|
||||
|
||||
int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];
|
||||
|
||||
// NOTE: __bang_filter() only support floating data type
|
||||
uint32_t cur_expert_count =
|
||||
__bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
|
||||
(float *)cur_expert_result, remain);
|
||||
// Store to the corresponding address of sorted_idx
|
||||
if (cur_expert_count > 0) {
|
||||
__memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip,
|
||||
cur_expert_count * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sync for all cores, get position index after sorting
|
||||
__sync_all_ipu();
|
||||
}
|
||||
|
||||
// 4. Get gather index for expand and combine
|
||||
template <bool is_sram_scatter>
|
||||
__mlu_func__ void getGatherIdx(int *gather_expand_idx,
|
||||
int *gather_combine_idx,
|
||||
int *sorted_idx,
|
||||
const int token_cur_core,
|
||||
const int cur_token_start,
|
||||
const int topk) {
|
||||
// 4. Partition on [num_token*topk],
|
||||
// load sorted_idx onchip,
|
||||
// generate sequence according to position index from 0, add token offset
|
||||
// gather_combine_idx = scatter(seq, sorted_idx)
|
||||
// gather_expand_idx = sorted_idx / topk
|
||||
// update sequence
|
||||
// NRAM:
|
||||
// -------------------------------------------------------------------
|
||||
// |sorted_idx_onchip|expand_idx_onchip|scatter_offset|scatter_sequence|
|
||||
// | deal_num | deal_num | deal_num | deal_num |
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// Calculate new deal_num of generate gather index
|
||||
// NOTE: deal_num should align to 64 Bytes, because bang_scatter() constraints
|
||||
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
|
||||
int repeat = token_cur_core / deal_num;
|
||||
int remain = token_cur_core % deal_num;
|
||||
int token_addr_offset = cur_token_start;
|
||||
|
||||
// scatter dst GDRAM addr should align to 64B
|
||||
int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
|
||||
int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr);
|
||||
|
||||
int8_t *sorted_idx_onchip = nram_buffer;
|
||||
int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int);
|
||||
int8_t *scatter_offset = expand_idx_onchip + deal_num * sizeof(int);
|
||||
int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);
|
||||
|
||||
// Generate position index from 0
|
||||
// Add base offset to sequence according to current core token start address
|
||||
if (token_cur_core > 0) {
|
||||
if (deal_num <= token_cur_core) {
|
||||
generateIntSeq((int *)scatter_sequence, deal_num);
|
||||
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
|
||||
deal_num);
|
||||
} else { // only remainder part
|
||||
generateIntSeq((int *)scatter_sequence, token_cur_core);
|
||||
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
|
||||
token_cur_core);
|
||||
}
|
||||
}
|
||||
|
||||
// repeat part
|
||||
for (int i = 0; i < repeat; i++) {
|
||||
// Load current core sorted_idx
|
||||
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
|
||||
deal_num * sizeof(int), GDRAM2NRAM);
|
||||
|
||||
// offset = sorted_idx * sizeof(int), counted in bytes
|
||||
if (is_sram_scatter) {
|
||||
__bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
|
||||
deal_num);
|
||||
} else {
|
||||
// GDRAM addr should align to 64B
|
||||
__bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
|
||||
combine_idx_align_offset, (int)(sizeof(int)), deal_num);
|
||||
}
|
||||
// Sync for scatter
|
||||
__sync_compute();
|
||||
|
||||
if (is_sram_scatter) {
|
||||
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
|
||||
deal_num);
|
||||
} else {
|
||||
// Scatter to output gather_combine_idx
|
||||
scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
|
||||
(uint32_t *)scatter_offset, deal_num);
|
||||
}
|
||||
|
||||
// expand_idx_onchip = sorted_idx / topk
|
||||
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, deal_num);
|
||||
// Store expand idx
|
||||
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
|
||||
deal_num * sizeof(int), NRAM2GDRAM);
|
||||
if (is_sram_scatter) {
|
||||
// if scatter to SRAM, need to sync compute with mv
|
||||
__sync_move();
|
||||
}
|
||||
// Add offset to sequence and token_address
|
||||
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num);
|
||||
token_addr_offset += deal_num;
|
||||
}
|
||||
|
||||
// remainder part
|
||||
if (remain > 0) {
|
||||
// Load current core sorted_idx
|
||||
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
|
||||
remain * sizeof(int), GDRAM2NRAM);
|
||||
|
||||
// offset = sorted_idx * sizeof(int), counted in bytes
|
||||
if (is_sram_scatter) {
|
||||
__bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
|
||||
remain);
|
||||
} else {
|
||||
// GDRAM addr should align to 64B
|
||||
__bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
|
||||
combine_idx_align_offset, (int)(sizeof(int)), remain);
|
||||
}
|
||||
|
||||
// Sync for scatter
|
||||
__sync_compute();
|
||||
|
||||
if (is_sram_scatter) {
|
||||
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
|
||||
remain);
|
||||
} else {
|
||||
// Scatter to output gather_combine_idx
|
||||
scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
|
||||
(uint32_t *)scatter_offset, remain);
|
||||
}
|
||||
|
||||
// expand_idx_onchip = sorted_idx / topk
|
||||
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, remain);
|
||||
// Store expand idx
|
||||
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
|
||||
remain * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
|
||||
// 4.1 Get gather combine index on SRAM
|
||||
__mlu_func__ void getCombineIdxSram(int *sorted_idx,
|
||||
const int token_cur_core,
|
||||
const int cur_token_start) {
|
||||
// 4.1 Partition on [num_token*topk], with only 1 union
|
||||
// load sorted_idx onchip,
|
||||
// generate sequence according to position index from 0, add token offset
|
||||
// gather_combine_idx = scatter(seq, sorted_idx)
|
||||
// update sequence
|
||||
// NRAM:
|
||||
// -------------------------------
|
||||
// |scatter_offset|scatter_sequence|
|
||||
// | deal_num | deal_num |
|
||||
// -------------------------------
|
||||
|
||||
// Calculate new deal_num of generate gather index
|
||||
// NOTE: deal_num should align to 64 Bytes, because bang_scatter() constraints
|
||||
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
|
||||
int repeat = token_cur_core / deal_num;
|
||||
int remain = token_cur_core % deal_num;
|
||||
int token_addr_offset = cur_token_start;
|
||||
|
||||
int8_t *scatter_offset = nram_buffer;
|
||||
int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);
|
||||
|
||||
// Generate position index from 0
|
||||
// Add base offset to sequence according to current core token start address
|
||||
if (token_cur_core > 0) {
|
||||
if (deal_num <= token_cur_core) {
|
||||
generateIntSeq((int *)scatter_sequence, deal_num);
|
||||
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
|
||||
deal_num);
|
||||
} else { // only remainder part
|
||||
generateIntSeq((int *)scatter_sequence, token_cur_core);
|
||||
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
|
||||
token_cur_core);
|
||||
}
|
||||
}
|
||||
|
||||
// repeat part
|
||||
for (int i = 0; i < repeat; i++) {
|
||||
// Load current core sorted_idx
|
||||
__memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, deal_num * sizeof(int),
|
||||
GDRAM2NRAM);
|
||||
|
||||
// offset = sorted_idx * sizeof(int), counted in bytes
|
||||
__bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), deal_num);
|
||||
// Sync for scatter
|
||||
__sync_compute();
|
||||
|
||||
// Scatter to SRAM
|
||||
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
|
||||
deal_num);
|
||||
__sync_move();
|
||||
|
||||
// Add offset to sequence and token_address
|
||||
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num);
|
||||
token_addr_offset += deal_num;
|
||||
}
|
||||
|
||||
// remainder part
|
||||
if (remain > 0) {
|
||||
// Load current core sorted_idx
|
||||
__memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, remain * sizeof(int),
|
||||
GDRAM2NRAM);
|
||||
|
||||
// offset = sorted_idx * sizeof(int), counted in bytes
|
||||
__bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), remain);
|
||||
// Sync for scatter
|
||||
__sync_compute();
|
||||
|
||||
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset, remain);
|
||||
}
|
||||
}
|
||||
|
||||
// 4.2 Get gather expand index
|
||||
__mlu_func__ void getExpandIdx(int *gather_expand_idx,
|
||||
int *sorted_idx,
|
||||
const int token_cur_core,
|
||||
const int cur_token_start,
|
||||
const int topk) {
|
||||
// 4.2 Partition on [num_token*topk],
|
||||
// load sorted_idx onchip,
|
||||
// gather_expand_idx = sorted_idx / topk
|
||||
// NRAM:
|
||||
// -----------------------------------
|
||||
// |sorted_idx_onchip|expand_idx_onchip|
|
||||
// | deal_num | deal_num |
|
||||
// -----------------------------------
|
||||
|
||||
// Calculate new deal_num of generate gather index
|
||||
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
|
||||
int repeat = token_cur_core / deal_num;
|
||||
int remain = token_cur_core % deal_num;
|
||||
int token_addr_offset = cur_token_start;
|
||||
|
||||
int8_t *sorted_idx_onchip = nram_buffer;
|
||||
int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int);
|
||||
|
||||
// repeat part
|
||||
for (int i = 0; i < repeat; i++) {
|
||||
// Load current core sorted_idx
|
||||
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
|
||||
deal_num * sizeof(int), GDRAM2NRAM);
|
||||
|
||||
// expand_idx_onchip = sorted_idx / topk
|
||||
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, deal_num);
|
||||
// Store expand idx
|
||||
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
|
||||
deal_num * sizeof(int), NRAM2GDRAM);
|
||||
token_addr_offset += deal_num;
|
||||
}
|
||||
|
||||
// remainder part
|
||||
if (remain > 0) {
|
||||
// Load current core sorted_idx
|
||||
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
|
||||
remain * sizeof(int), GDRAM2NRAM);
|
||||
// expand_idx_onchip = sorted_idx / topk
|
||||
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, remain);
|
||||
// Store expand idx
|
||||
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
|
||||
remain * sizeof(int), NRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
|
||||
__mlu_global__ void launchMoeGenIdxKernel(int *gather_expand_idx,
|
||||
int *gather_combine_idx,
|
||||
int *token_count,
|
||||
int *cusum_token_count,
|
||||
void *workspace,
|
||||
const void *expert_id,
|
||||
const int num_token,
|
||||
const int num_expert,
|
||||
const int topk) {
|
||||
// Store token count presum result, shape [num_expert + 1]
|
||||
int *token_count_presum = (cusum_token_count != nullptr) ? cusum_token_count : (int *)workspace;
|
||||
// Store position index after sorting, shape [num_token*topk]
|
||||
int *sorted_idx = ((int *)workspace) + num_expert + 1;
|
||||
|
||||
// Calculate partition information for different processes
|
||||
// Partition on [num_token*topk]
|
||||
uint32_t token_total_num = num_token * topk;
|
||||
uint32_t token_cur_core = token_total_num / taskDim;
|
||||
uint32_t token_remain_num = token_total_num % taskDim;
|
||||
token_cur_core += (uint32_t)(taskId < token_remain_num);
|
||||
// Current core range according to partition on [num_token*topk]
|
||||
uint32_t cur_token_start = (taskId < token_remain_num)
|
||||
? token_cur_core * taskId
|
||||
: token_cur_core * taskId + token_remain_num;
|
||||
|
||||
// Partition on [num_expert]
|
||||
uint32_t expert_cur_core = num_expert / taskDim;
|
||||
uint32_t expert_remain_num = num_expert % taskDim;
|
||||
expert_cur_core += (uint32_t)(taskId < expert_remain_num);
|
||||
// Current core range according to partition on [num_expert]
|
||||
uint32_t cur_expert_start = (taskId < expert_remain_num)
|
||||
? expert_cur_core * taskId
|
||||
: expert_cur_core * taskId + expert_remain_num;
|
||||
uint32_t cur_expert_end = cur_expert_start + expert_cur_core - 1;
|
||||
|
||||
// Use Union1 SRAM to scatter, only MLU500 series support now
|
||||
#if __BANG_ARCH__ >= 592
|
||||
bool is_sram_scatter = token_total_num * sizeof(int) < SRAM_BUFFER_SIZE;
|
||||
#else
|
||||
bool is_sram_scatter = false;
|
||||
#endif
|
||||
|
||||
if (__is_ipu()) {
|
||||
// 1. Get token count
|
||||
getTokenCount((int *)token_count, (int *)expert_id, token_cur_core, cur_token_start,
|
||||
num_expert);
|
||||
// 2. Get presum of token count
|
||||
getTokenCountPresum((int *)token_count_presum, (int *)token_count, num_expert);
|
||||
|
||||
// 3. Get expert position index after sorting
|
||||
getSortedIdx((int *)sorted_idx, (int *)expert_id, (int *)token_count_presum, token_total_num,
|
||||
num_expert, expert_cur_core, cur_expert_start, cur_expert_end);
|
||||
}
|
||||
|
||||
#if EXPERT_AVG_COUNT_TEST
|
||||
// NOTE: test avg expert code here:
|
||||
if (__is_ipu() && taskId == 0) {
|
||||
modifyTokenCountAndPresum((int *)token_count_presum, (int *)token_count, token_total_num,
|
||||
num_expert);
|
||||
}
|
||||
__sync_cluster();
|
||||
#endif
|
||||
|
||||
// 4. Get gather index for expand and combine
|
||||
if (is_sram_scatter) {
|
||||
// Only use Union1 SRAM
|
||||
uint32_t scatter_idx_cur_core = token_total_num / 4;
|
||||
uint32_t scatter_idx_remain_num = token_total_num % 4;
|
||||
scatter_idx_cur_core += (uint32_t)(taskId < scatter_idx_remain_num);
|
||||
uint32_t cur_idx_start = (taskId < scatter_idx_remain_num)
|
||||
? scatter_idx_cur_core * taskId
|
||||
: scatter_idx_cur_core * taskId + scatter_idx_remain_num;
|
||||
|
||||
// Only Union1 task type,
|
||||
// deal once num is same with deal_num in getGatherIdx,
|
||||
// which means only 1 repeat to generate both expand and combine idx on NRAM
|
||||
const int deal_once_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
|
||||
if (taskDim <= 4 || token_total_num < deal_once_num) {
|
||||
if (taskId < 4) {
|
||||
if (__is_ipu()) {
|
||||
getGatherIdx<true>((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx,
|
||||
scatter_idx_cur_core, cur_idx_start, topk);
|
||||
// sync for ipu and mpu
|
||||
__sync_cluster();
|
||||
} else {
|
||||
// sync for ipu and mpu
|
||||
__sync_cluster();
|
||||
__memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
|
||||
token_total_num * sizeof(int), SRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If taskDim > 4, use first union to generate combine idx,
|
||||
// use other union to generate expand idx
|
||||
if (taskId < 4) {
|
||||
if (__is_ipu()) {
|
||||
// Scatter combine idx to SRAM
|
||||
getCombineIdxSram((int *)sorted_idx, scatter_idx_cur_core, cur_idx_start);
|
||||
__sync_cluster();
|
||||
} else {
|
||||
__sync_cluster();
|
||||
__memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
|
||||
token_total_num * sizeof(int), SRAM2GDRAM);
|
||||
}
|
||||
} else {
|
||||
// Other union generate expand idx
|
||||
if (__is_ipu()) {
|
||||
uint32_t expand_dim = taskDim - 4;
|
||||
uint32_t expand_id = taskId - 4;
|
||||
uint32_t expand_token_cur_core = token_total_num / expand_dim;
|
||||
uint32_t expand_token_remain_num = token_total_num % expand_dim;
|
||||
expand_token_cur_core += (uint32_t)(expand_id < expand_token_remain_num);
|
||||
|
||||
uint32_t expand_cur_token_start =
|
||||
(expand_id < expand_token_remain_num)
|
||||
? expand_token_cur_core * expand_id
|
||||
: expand_token_cur_core * expand_id + expand_token_remain_num;
|
||||
getExpandIdx((int *)gather_expand_idx, (int *)sorted_idx, expand_token_cur_core,
|
||||
expand_cur_token_start, topk);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// not use SRAM to generate both expand and combine idx
|
||||
if (__is_ipu()) {
|
||||
getGatherIdx<false>((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx,
|
||||
token_cur_core, cur_token_start, topk);
|
||||
}
|
||||
}
|
||||
|
||||
// step 5 does not need MPU
|
||||
if (__is_mpu()) {
|
||||
return;
|
||||
}
|
||||
} // end of kernel
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
|
||||
int *gather_expand_idx,
|
||||
int *gather_combine_idx,
|
||||
int *token_count,
|
||||
int *cusum_token_count,
|
||||
void *workspace,
|
||||
const void *expert_id,
|
||||
const int num_token,
|
||||
const int num_expert,
|
||||
const int topk) {
|
||||
CNdev dev;
|
||||
cnCtxGetDevice(&dev);
|
||||
int cluster_num;
|
||||
int core_num;
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
|
||||
|
||||
const int token_total_num = num_token * topk;
|
||||
|
||||
// For partition on num_token*topk, single core processes at least 128 num
|
||||
const int single_core_num_limit = 1024;
|
||||
int need_core_num = std::ceil(float(token_total_num) / single_core_num_limit);
|
||||
// When partition on num_expert, each core at least processes one expert
|
||||
need_core_num = std::max(num_expert, need_core_num);
|
||||
|
||||
// When consider UnionX cnrt func type, reset cluster_num
|
||||
if (token_total_num <= 4096) { // Block
|
||||
cnrtFunctionType_t k_type = cnrtFuncTypeBlock;
|
||||
cnrtDim3_t k_dim{1, 1, 1};
|
||||
// Block kernel does not need workspace
|
||||
kernels::launchMoeGenIdxBlockKernel<<<k_dim, k_type, queue>>>(
|
||||
gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, expert_id, num_token,
|
||||
num_expert, topk);
|
||||
|
||||
return KernelStatus::KERNEL_STATUS_SUCCESS;
|
||||
} else if (need_core_num <= 4) { // Union1
|
||||
cluster_num = 1;
|
||||
} else if (need_core_num <= 8) { // Union2
|
||||
cluster_num = std::min(cluster_num, 2);
|
||||
} else if (need_core_num <= 16) { // Union4
|
||||
cluster_num = std::min(cluster_num, 4);
|
||||
} else if (need_core_num <= 32) { // Union8
|
||||
cluster_num = std::min(cluster_num, 8);
|
||||
}
|
||||
|
||||
cnrtFunctionType_t k_type;
|
||||
cnrtDim3_t k_dim{1, 1, 1};
|
||||
|
||||
// Find max UnionX cnrt func type
|
||||
if (cluster_num == 1) {
|
||||
k_type = cnrtFuncTypeUnion1;
|
||||
k_dim.x = 4;
|
||||
} else if (cluster_num < 4) { // cluster num is 2 or 3
|
||||
k_type = cnrtFuncTypeUnion2;
|
||||
k_dim.x = 8;
|
||||
} else if (cluster_num < 8) { // cluster num is 4,5,6,7
|
||||
k_type = cnrtFuncTypeUnion4;
|
||||
k_dim.x = 16;
|
||||
} else { // cluster num larger than 8
|
||||
k_type = cnrtFuncTypeUnion8;
|
||||
k_dim.x = 32;
|
||||
}
|
||||
|
||||
// The expert_id is int data type
|
||||
kernels::launchMoeGenIdxKernel<<<k_dim, k_type, queue>>>(
|
||||
gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, workspace, expert_id,
|
||||
num_token, num_expert, topk);
|
||||
return KernelStatus::KERNEL_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
#undef EXPERT_AVG_COUNT_TEST // undef test macro
|
||||
|
||||
} // namespace tmo
|
||||
58
torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mluh
Normal file
58
torch_mlu_ops-v1.3.2/csrc/kernels/moe/gen_idx.mluh
Normal file
@@ -0,0 +1,58 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_MOE_GEN_IDX_MLUH_
|
||||
#define CSRC_KERNELS_MOE_GEN_IDX_MLUH_
|
||||
|
||||
#include <vector>
|
||||
#include "../kernel_utils.h"
|
||||
#include "cnnl.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Apply generate MOE index operation, which performs the following
|
||||
* tasks:
|
||||
* - 1. Generate gather_expand_idx and gather_combine_idx.
|
||||
* - 2. Output token_count, the token number of each expert.
|
||||
* - 3. Prepare inputs and outputs address for group_gemm.
|
||||
* @param queue: The queue of mlu.
|
||||
* @param gather_expand_idx: Output. Pointer to the MLU memory that stores the
|
||||
* gather index for expand hidden state operation, the shape must be
|
||||
* [num_token * topk].
|
||||
* @param gather_combine_idx: Output. Pointer to the MLU memory that stores the
|
||||
* gather index for combine MOE operation, the shape must be
|
||||
* [num_token * topk].
|
||||
* @param token_count: Output. Pointer to the MLU memory that stores the token
|
||||
* number of each expert, the shape must be [num_expert].
|
||||
* @param cusum_token_count: Output. Pointer to the MLU memory that stores the
|
||||
* cumulative sum of the token number of each expert, the shape must be
|
||||
* [num_expert + 1]. It can be set to nullptr if don't need cusum output.
|
||||
* @param workspace: Input. A pointer to the extra workspace required in the
|
||||
* operation, the size must be larger than
|
||||
* (num_expert + 1 + num_token * topk) multiplied by the size of uint32.
|
||||
* @param expert_id: Input. Pointer to the MLU memory that stores the expert id
|
||||
* of each token, the shape must be [num_token, topk].
|
||||
* @param num_token: The number of tokens.
|
||||
* @param num_expert: The number of experts.
|
||||
* @param topk: The number of expert selected by each token.
|
||||
*/
|
||||
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
|
||||
int *gather_expand_idx,
|
||||
int *gather_combine_idx,
|
||||
int *token_count,
|
||||
int *cusum_token_count,
|
||||
void *workspace,
|
||||
const void *expert_id,
|
||||
const int num_token,
|
||||
const int num_expert,
|
||||
const int topk);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_MOE_GEN_IDX_MLUH_
|
||||
21
torch_mlu_ops-v1.3.2/csrc/kernels/moe/moe.mluh
Normal file
21
torch_mlu_ops-v1.3.2/csrc/kernels/moe/moe.mluh
Normal file
@@ -0,0 +1,21 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_MOE_MOE_MLUH_
|
||||
#define CSRC_KERNELS_MOE_MOE_MLUH_
|
||||
|
||||
#include "add_bias_activation.mluh"
|
||||
#include "combine_result.mluh"
|
||||
#include "expand_input.mluh"
|
||||
#include "gen_idx.mluh"
|
||||
#include "softmax_topk.mluh"
|
||||
|
||||
#endif // CSRC_KERNELS_MOE_MOE_MLUH_
|
||||
602
torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mlu
Normal file
602
torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mlu
Normal file
@@ -0,0 +1,602 @@
|
||||
#include <mlu.h>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <ostream>
|
||||
#include "cnnl.h"
|
||||
#include "cnrt.h"
|
||||
#include "softmax_topk.mluh"
|
||||
|
||||
namespace tmo {
|
||||
|
||||
namespace kernels {
|
||||
#define SCATTER_ALIGN (64) // align for __scatter()
|
||||
|
||||
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
|
||||
#define SRAM_SIZE (__MLU_SRAM_SIZE__ * 1024 - 32 * 1024)
|
||||
#define TILING_ALIGN (64)
|
||||
#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0))
|
||||
__nram__ int8_t nram_buffer[NRAM_SIZE];
|
||||
__mlu_shared__ int8_t sram_buffer[SRAM_SIZE];
|
||||
|
||||
#define __TRANS_TILING(TYPE, CONVERT) \
|
||||
__asm__ volatile("trans.tiling." TYPE \
|
||||
" [%[dst]], [%[src]]," \
|
||||
"%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4]," \
|
||||
"%[is4], %[in5], %[is5]," \
|
||||
"%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4]," \
|
||||
"%[ds4], %[dn5], %[ds5]" CONVERT ::[dst] "r"(dst), \
|
||||
[src] "r"(src), [in0] "r"(in0), [in1] "r"(in1), [is1] "r"(is1), [in2] "r"(in2), \
|
||||
[is2] "r"(is2), [in3] "r"(in3), [is3] "r"(is3), [in4] "r"(in4), [is4] "r"(is4), \
|
||||
[in5] "r"(in5), [is5] "r"(is5), [dn0] "r"(dn0), [dn1] "r"(dn1), [ds1] "r"(ds1), \
|
||||
[dn2] "r"(dn2), [ds2] "r"(ds2), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(dn4), \
|
||||
[ds4] "r"(ds4), [dn5] "r"(dn5), [ds5] "r"(ds5));
|
||||
|
||||
template <typename SRC_DTYPE, typename DST_DTYPE, mluMemcpyDirection_t dir>
|
||||
__mlu_func__ void __mlvm_trans(DST_DTYPE *dst,
|
||||
const SRC_DTYPE *src,
|
||||
const uint32_t in0,
|
||||
const uint32_t in1,
|
||||
const uint32_t is1,
|
||||
const uint32_t in2,
|
||||
const uint32_t is2,
|
||||
const uint32_t in3,
|
||||
const uint32_t is3,
|
||||
const uint32_t in4,
|
||||
const uint32_t is4,
|
||||
const uint32_t in5,
|
||||
const uint32_t is5,
|
||||
const uint32_t dn0,
|
||||
const uint32_t dn1,
|
||||
const uint32_t ds1,
|
||||
const uint32_t dn2,
|
||||
const uint32_t ds2,
|
||||
const uint32_t dn3,
|
||||
const uint32_t ds3,
|
||||
const uint32_t dn4,
|
||||
const uint32_t ds4,
|
||||
const uint32_t dn5,
|
||||
const uint32_t ds5) {
|
||||
if (SRAM2NRAM == dir && std::is_same<DST_DTYPE, float>::value) {
|
||||
if (std::is_same<SRC_DTYPE, float>::value) {
|
||||
__TRANS_TILING("nram.sram.b32", ";")
|
||||
} else if (std::is_same<SRC_DTYPE, half>::value) {
|
||||
__TRANS_TILING("nram.sram.b16", ", .cvt.f32.f16();")
|
||||
#if __BANG_ARCH__ >= 500
|
||||
} else if (std::is_same<SRC_DTYPE, bfloat16_t>::value) {
|
||||
__TRANS_TILING("nram.sram.b16", ", .cvt.f32.bf16();")
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* 将shape为[h,w]的数据转置为[w,h](带转数),分4块分别进行处理。
|
||||
* dst: dst地址
|
||||
* src: src地址
|
||||
* h: h方向大小
|
||||
* w: w方向大小
|
||||
*/
|
||||
template <typename SRC_DTYPE, typename DST_DTYPE, mluMemcpyDirection_t dir>
|
||||
__mlu_func__ void transhw2wh(DST_DTYPE *dst, SRC_DTYPE *src, uint32_t h, uint32_t w) {
|
||||
uint32_t align_num = TILING_ALIGN / sizeof(SRC_DTYPE);
|
||||
uint32_t w_align = w / align_num;
|
||||
uint32_t w_rem = w % align_num;
|
||||
uint32_t h_align = h / align_num;
|
||||
uint32_t h_rem = h % align_num;
|
||||
uint32_t in0 = TILING_ALIGN, dn0 = TILING_ALIGN;
|
||||
uint32_t in1 = align_num, is1 = w * sizeof(SRC_DTYPE);
|
||||
uint32_t in3 = w_align, is3 = TILING_ALIGN;
|
||||
uint32_t in4 = h_align, is4 = w * TILING_ALIGN;
|
||||
uint32_t dn1 = align_num, ds1 = h * sizeof(DST_DTYPE);
|
||||
uint32_t dn3 = in3, ds3 = h * align_num * sizeof(DST_DTYPE);
|
||||
uint32_t dn4 = in4, ds4 = align_num * sizeof(DST_DTYPE);
|
||||
/* 1. h_align * w_align */
|
||||
if (w_align > 0 && h_align > 0) {
|
||||
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst, src, in0, in1, is1, 1, 0, in3, is3, in4, is4, 1, 0,
|
||||
dn0, dn1, ds1, 1, 0, dn3, ds3, dn4, ds4, 1, 0);
|
||||
}
|
||||
/* 2. h_align * w_rem */
|
||||
if (w_rem > 0 && h_align > 0) {
|
||||
SRC_DTYPE *src_temp = src + w_align * align_num;
|
||||
DST_DTYPE *dst_temp = dst + w_align * align_num * h;
|
||||
in0 = w_rem * sizeof(SRC_DTYPE);
|
||||
dn0 = TILING_ALIGN;
|
||||
in1 = align_num;
|
||||
is1 = w * sizeof(SRC_DTYPE);
|
||||
in4 = h_align;
|
||||
is4 = w * TILING_ALIGN;
|
||||
dn1 = w_rem;
|
||||
ds1 = h * sizeof(DST_DTYPE);
|
||||
dn4 = in4;
|
||||
ds4 = align_num * sizeof(DST_DTYPE);
|
||||
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4,
|
||||
1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0);
|
||||
}
|
||||
/* 3. h_rem * w_align */
|
||||
if (w_align > 0 && h_rem > 0) {
|
||||
SRC_DTYPE *src_temp = src + h_align * align_num * w;
|
||||
DST_DTYPE *dst_temp = dst + h_align * align_num;
|
||||
in0 = TILING_ALIGN;
|
||||
dn0 = h_rem * sizeof(SRC_DTYPE);
|
||||
in1 = h_rem;
|
||||
is1 = w * sizeof(SRC_DTYPE);
|
||||
in4 = w_align;
|
||||
is4 = TILING_ALIGN;
|
||||
dn1 = align_num;
|
||||
ds1 = h * sizeof(DST_DTYPE);
|
||||
dn4 = in4;
|
||||
ds4 = h * align_num * sizeof(DST_DTYPE);
|
||||
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4,
|
||||
1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0);
|
||||
}
|
||||
/* 4. h_rem * w_rem */
|
||||
if (w_rem > 0 && h_rem > 0) {
|
||||
SRC_DTYPE *src_temp = src + h_align * align_num * w + w_align * align_num;
|
||||
DST_DTYPE *dst_temp = dst + w_align * align_num * h + h_align * align_num;
|
||||
in0 = w_rem * sizeof(SRC_DTYPE);
|
||||
dn0 = h_rem * sizeof(SRC_DTYPE);
|
||||
in1 = h_rem;
|
||||
is1 = w * sizeof(SRC_DTYPE);
|
||||
dn1 = w_rem;
|
||||
ds1 = h * sizeof(DST_DTYPE);
|
||||
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, 1, 0, 1,
|
||||
0, dn0, dn1, ds1, 1, 0, 1, 0, 1, 0, 1, 0);
|
||||
}
|
||||
}
|
||||
|
||||
__mlu_func__ void getTopk(float *value_buffer,
|
||||
uint32_t *index_buffer,
|
||||
float *src_buffer,
|
||||
float *compute_buffer,
|
||||
float *max_buffer,
|
||||
float *temp_buffer,
|
||||
uint32_t *i_buffer,
|
||||
uint32_t *col_buffer,
|
||||
uint32_t topk,
|
||||
uint32_t num_expert_group,
|
||||
uint32_t col,
|
||||
uint32_t row,
|
||||
uint32_t value_index_stride,
|
||||
uint32_t group_size,
|
||||
bool is_deal_group) {
|
||||
__bang_write_value((float *)temp_buffer, col, -INFINITY); // set -inf vector
|
||||
for (int k = 0; k < topk; k++) {
|
||||
if (is_deal_group) {
|
||||
__bang_maxpool_index((uint32_t *)value_buffer + k * col, max_buffer, col, 1, num_expert_group,
|
||||
1, num_expert_group, 1, 1);
|
||||
__bang_fusion(FUSION_FMA, col_buffer, (uint32_t *)value_buffer + k * col, col, i_buffer, col,
|
||||
col);
|
||||
} else {
|
||||
__bang_maxpool_value_index(value_buffer + k * col, max_buffer, col, 1, row, 1, row, 1, 1,
|
||||
value_index_stride);
|
||||
__bang_fusion(FUSION_FMA, col_buffer, index_buffer + k * col, col, i_buffer, col, col);
|
||||
}
|
||||
#if __BANG_ARCH__ >= 592
|
||||
__bang_mul_scalar(col_buffer, col_buffer, sizeof(float), col); // index in byte
|
||||
__scatter(max_buffer, temp_buffer, col_buffer, sizeof(uint32_t), NRAM2NRAM, sizeof(uint32_t),
|
||||
col); // replace max value with -inf
|
||||
#else
|
||||
for (int i = 0; i < col; i++) {
|
||||
uint32_t index = __load_nram(col_buffer + i);
|
||||
max_buffer[index] = -INFINITY;
|
||||
}
|
||||
#endif
|
||||
#if __BANG_ARCH__ < 500
|
||||
if (is_deal_group) {
|
||||
for (int i = 0; i < col; i++) {
|
||||
uint32_t index = __load_nram((uint32_t *)value_buffer + k * col + i);
|
||||
__memcpy(compute_buffer + i * row + index * group_size,
|
||||
src_buffer + i * row + index * group_size, group_size * sizeof(float), NRAM2NRAM);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#if __BANG_ARCH__ >= 592
|
||||
if (is_deal_group) {
|
||||
__bang_transpose(index_buffer, (uint32_t *)value_buffer, topk, col);
|
||||
__bang_mul_scalar((uint32_t *)value_buffer, i_buffer, row * sizeof(float), col);
|
||||
__bang_move(value_buffer, value_buffer, col * sizeof(uint32_t), col * sizeof(uint32_t), 0,
|
||||
topk - 1);
|
||||
__bang_transpose((uint32_t *)compute_buffer, (uint32_t *)value_buffer, topk, col);
|
||||
__bang_fusion(FUSION_FMA, index_buffer, index_buffer, group_size * sizeof(float),
|
||||
(uint32_t *)compute_buffer, col * topk, col * topk);
|
||||
__gather(compute_buffer, src_buffer, (uint32_t *)index_buffer, group_size * sizeof(float),
|
||||
NRAM2NRAM, group_size * sizeof(float), col * topk);
|
||||
__bang_write_value(src_buffer, row * col, -INFINITY);
|
||||
__scatter(src_buffer, compute_buffer, index_buffer, group_size * sizeof(float), NRAM2NRAM,
|
||||
group_size * sizeof(float), col * topk);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void computeSoftmaxTopk(T *sram_buffer,
|
||||
T *load_buffer,
|
||||
float *src_buffer,
|
||||
float *compute_buffer,
|
||||
float *group_max_buffer,
|
||||
float *nramout_value,
|
||||
uint32_t *nramout_index,
|
||||
uint32_t *i_buffer,
|
||||
uint32_t *col_buffer,
|
||||
float *softmax_buffer,
|
||||
uint32_t row,
|
||||
uint32_t nram_compute_col_num,
|
||||
uint32_t mask_num,
|
||||
uint32_t nram_max_col_num,
|
||||
uint32_t topk,
|
||||
int num_expert_group,
|
||||
uint32_t topk_group,
|
||||
uint32_t top_num,
|
||||
uint32_t nram_col_offset,
|
||||
int normalize_mode,
|
||||
bool valid_mask,
|
||||
bool split_mask) {
|
||||
uint32_t nram_compute_num = nram_compute_col_num * row;
|
||||
// convert to float for half/bf16 datatype
|
||||
if (std::is_same<T, half>::value) {
|
||||
__bang_half2float(src_buffer, (half *)load_buffer, nram_compute_num);
|
||||
} else if (std::is_same<T, bfloat16_t>::value) {
|
||||
__bang_bfloat162float(src_buffer, (bfloat16_t *)load_buffer, nram_compute_num);
|
||||
}
|
||||
// transpose [col, row] to [row, col]. To accelerate max/sum compute with maxpool/sumpool.
|
||||
__bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
|
||||
|
||||
// compute softmax
|
||||
int tmp = 0x3fb8aa3b;
|
||||
float log2e = *(float *)&tmp; // for exp
|
||||
// src_buffer reuse as buffer for max/sum.
|
||||
__bang_maxpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // max
|
||||
__bang_fusion(FUSION_FSM, compute_buffer, compute_buffer, src_buffer, log2e, nram_compute_num,
|
||||
nram_compute_col_num);
|
||||
__bang_pow2(compute_buffer, compute_buffer, nram_compute_num); // exp(input - max)
|
||||
__bang_sumpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // sum
|
||||
__bang_recip(src_buffer, src_buffer, nram_compute_col_num); // 1/sum
|
||||
__bang_cycle_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_num,
|
||||
nram_compute_col_num);
|
||||
__sync_cluster();
|
||||
// move mask and compute
|
||||
if (valid_mask) {
|
||||
if (!split_mask) {
|
||||
__bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
|
||||
if (std::is_same<T, half>::value) {
|
||||
__memcpy((half *)compute_buffer + mask_num * row, sram_buffer, mask_num * row * sizeof(T),
|
||||
SRAM2NRAM);
|
||||
__bang_half2float((float *)compute_buffer, (half *)compute_buffer + mask_num * row,
|
||||
mask_num * row);
|
||||
} else if (std::is_same<T, bfloat16_t>::value) {
|
||||
__memcpy((bfloat16_t *)compute_buffer + mask_num * row, sram_buffer,
|
||||
mask_num * row * sizeof(T), SRAM2NRAM);
|
||||
__bang_bfloat162float((float *)compute_buffer,
|
||||
(bfloat16_t *)compute_buffer + mask_num * row, mask_num * row);
|
||||
} else {
|
||||
__memcpy(compute_buffer, sram_buffer, mask_num * row * sizeof(T), SRAM2NRAM);
|
||||
}
|
||||
__bang_cycle_mul(src_buffer, src_buffer, compute_buffer, nram_compute_col_num * row,
|
||||
mask_num * row);
|
||||
__bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
|
||||
} else {
|
||||
transhw2wh<T, float, SRAM2NRAM>(src_buffer, sram_buffer + nram_col_offset * row,
|
||||
nram_compute_col_num, row);
|
||||
__sync();
|
||||
__bang_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_col_num * row);
|
||||
}
|
||||
}
|
||||
if (normalize_mode == 2) {
|
||||
__bang_sumpool(softmax_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1);
|
||||
}
|
||||
|
||||
if (num_expert_group <= 1) {
|
||||
// num_expert_group <= 1, maintain original topk calculation logic
|
||||
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer,
|
||||
i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
|
||||
nram_max_col_num * topk * sizeof(float), 0, false);
|
||||
} else {
|
||||
// num_expert_group > 1, use grouped_topk calculation logic
|
||||
uint32_t group_size = row / num_expert_group;
|
||||
__bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
|
||||
__bang_maxpool(group_max_buffer, compute_buffer, nram_compute_col_num, num_expert_group,
|
||||
group_size, 1, group_size, 1, 1);
|
||||
__bang_write_value(compute_buffer, row * nram_compute_col_num, -INFINITY);
|
||||
// get topk_group
|
||||
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, group_max_buffer,
|
||||
(float *)nramout_index, i_buffer, col_buffer, topk_group, num_expert_group,
|
||||
nram_compute_col_num, row, nram_max_col_num * topk * sizeof(float), group_size, true);
|
||||
// get topk
|
||||
#if __BANG_ARCH__ < 500
|
||||
__bang_transpose(src_buffer, compute_buffer, nram_compute_col_num, row);
|
||||
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, src_buffer, compute_buffer,
|
||||
i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
|
||||
nram_max_col_num * top_num * sizeof(float), 0, false);
|
||||
#else
|
||||
__bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
|
||||
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer,
|
||||
i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
|
||||
nram_max_col_num * top_num * sizeof(float), 0, false);
|
||||
#endif
|
||||
} // end else
|
||||
|
||||
// normalize result
|
||||
if (normalize_mode == 1) {
|
||||
// compute_buffer reuse as buffer for sum.
|
||||
__bang_sumpool(compute_buffer, nramout_value, nram_compute_col_num, topk, 1, topk, 1, 1, 1);
|
||||
__bang_recip(compute_buffer, compute_buffer, nram_compute_col_num);
|
||||
__bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
|
||||
nram_compute_col_num);
|
||||
} else if (normalize_mode == 2) {
|
||||
__bang_recip(compute_buffer, softmax_buffer, nram_compute_col_num);
|
||||
__bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
|
||||
nram_compute_col_num);
|
||||
}
|
||||
|
||||
// transpose back. src and dst of transpose can not be the same address.
|
||||
__bang_transpose(compute_buffer, nramout_value, topk, nram_compute_col_num);
|
||||
__bang_transpose((uint32_t *)nramout_value, nramout_index, topk, nram_compute_col_num);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_global__ void MLUSoftmaxTopkKernel(T *input,
|
||||
T *mask,
|
||||
int *index_out,
|
||||
float *value_out,
|
||||
int col,
|
||||
int row,
|
||||
int mask_num,
|
||||
int topk,
|
||||
int num_expert_group,
|
||||
int topk_group,
|
||||
int normalize_mode) {
|
||||
bool valid_mask = (mask != nullptr);
|
||||
int top_num = topk >= topk_group ? topk : topk_group;
|
||||
uint32_t nram_low_space =
|
||||
PAD_UP((row * 2 + top_num * 2 + 2 + (normalize_mode == 2) + num_expert_group) * sizeof(float),
|
||||
SCATTER_ALIGN);
|
||||
if (num_expert_group <= 1) {
|
||||
nram_low_space =
|
||||
PAD_UP((row * 2 + topk * 2 + 2 + (normalize_mode == 2)) * sizeof(float), SCATTER_ALIGN);
|
||||
}
|
||||
uint32_t nram_max_col_num = (NRAM_SIZE) / nram_low_space;
|
||||
if (nram_max_col_num > col / taskDim + (col % taskDim > 0)) {
|
||||
nram_max_col_num = col / taskDim + (col % taskDim > 0);
|
||||
}
|
||||
nram_max_col_num = PAD_DOWN(nram_max_col_num, SCATTER_ALIGN / sizeof(float));
|
||||
if (nram_max_col_num <= 0) {
|
||||
nram_max_col_num = SCATTER_ALIGN / sizeof(float);
|
||||
}
|
||||
uint32_t nram_deal_num = nram_max_col_num * row;
|
||||
uint32_t batch = col / mask_num;
|
||||
|
||||
// nram split:
|
||||
// |--------------------------|--------------------------|--------------------|...
|
||||
// | size: nram/2 -col*topk*2 | size: nram/2 -col*topk*2 |col*num_expert_group|...
|
||||
// | src_buffer | compute_buffer | group_max_buffer |...
|
||||
// |--------------------------|--------------------------|--------------------|...
|
||||
|
||||
// |----------------------------------------|---------------|--------------|
|
||||
// | nram_col_num*3 | col*topk | col*topk |
|
||||
// | i_buffer | col_buffer | softmax_buffer | nramout_value | nramout_index|
|
||||
// |----------------------------------------|---------------|--------------|
|
||||
float *src_buffer = (float *)nram_buffer;
|
||||
float *compute_buffer = src_buffer + PAD_UP(nram_deal_num, SCATTER_ALIGN / sizeof(float));
|
||||
float *group_max_buffer = compute_buffer + nram_deal_num;
|
||||
uint32_t *i_buffer = (uint32_t *)group_max_buffer + num_expert_group * nram_max_col_num;
|
||||
if (num_expert_group <= 1) {
|
||||
i_buffer = (uint32_t *)group_max_buffer;
|
||||
}
|
||||
uint32_t *col_buffer = i_buffer + nram_max_col_num;
|
||||
float *softmax_buffer = (float *)col_buffer + nram_max_col_num;
|
||||
if (normalize_mode != 2) {
|
||||
softmax_buffer = (float *)col_buffer;
|
||||
}
|
||||
|
||||
float *nramout_value = softmax_buffer + nram_max_col_num;
|
||||
uint32_t *nramout_index = (uint32_t *)nramout_value + top_num * nram_max_col_num;
|
||||
if (num_expert_group <= 1) {
|
||||
nramout_index = (uint32_t *)nramout_value + topk * nram_max_col_num;
|
||||
}
|
||||
T *load_buffer = (T *)src_buffer;
|
||||
if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
|
||||
load_buffer = load_buffer + nram_deal_num;
|
||||
}
|
||||
|
||||
// set i_buffer
|
||||
for (uint32_t i = 0; i < nram_max_col_num; i++) {
|
||||
i_buffer[i] = i;
|
||||
}
|
||||
|
||||
// input[batch, mask, low], mask[mask, low]
|
||||
if (nram_max_col_num >= mask_num) { // nram can deal complete mask
|
||||
bool split_mask = false;
|
||||
uint32_t batch_seg = nram_max_col_num / mask_num;
|
||||
uint32_t batch_rem = batch % batch_seg;
|
||||
uint32_t batch_seg_num = batch / batch_seg + (batch_rem > 0);
|
||||
int repeat = DIV_UP(batch_seg_num, taskDim);
|
||||
for (int i = 0; i < repeat; i++) {
|
||||
uint32_t seg_id = i * taskDim + taskId;
|
||||
uint32_t sram_load_num = mask_num * row;
|
||||
uint32_t sram_load_offset = 0;
|
||||
uint32_t nram_compute_col_num = (seg_id == batch_seg_num - 1 && batch_rem > 0)
|
||||
? batch_rem * mask_num
|
||||
: batch_seg * mask_num;
|
||||
uint32_t nram_load_num = seg_id < batch_seg_num ? nram_compute_col_num * row : 0;
|
||||
uint32_t nram_store_num = seg_id < batch_seg_num ? nram_compute_col_num * topk : 0;
|
||||
uint32_t nram_load_offset = seg_id * batch_seg * mask_num * row;
|
||||
uint32_t nram_store_offset = seg_id * batch_seg * mask_num * topk;
|
||||
|
||||
// Load
|
||||
if (valid_mask) {
|
||||
__memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM);
|
||||
}
|
||||
if (nram_load_num > 0) {
|
||||
__memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
|
||||
}
|
||||
|
||||
// Compute
|
||||
computeSoftmaxTopk<T>((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
|
||||
group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
|
||||
softmax_buffer, row, nram_compute_col_num, mask_num, nram_max_col_num,
|
||||
topk, num_expert_group, topk_group, top_num, 0, normalize_mode,
|
||||
valid_mask, split_mask);
|
||||
|
||||
// Store
|
||||
if (nram_store_num > 0) {
|
||||
__memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
|
||||
NRAM2GDRAM);
|
||||
__memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
|
||||
NRAM2GDRAM);
|
||||
}
|
||||
__sync_cluster();
|
||||
}
|
||||
} else {
|
||||
bool split_mask = true;
|
||||
uint32_t mask_seg = nram_max_col_num;
|
||||
uint32_t mask_rem = mask_num % mask_seg;
|
||||
uint32_t mask_seg_num = mask_num / mask_seg + (mask_rem > 0);
|
||||
uint32_t sram_mask_seg_num = DIV_UP(mask_seg_num, coreDim);
|
||||
uint32_t sram_mask_rem = mask_num % sram_mask_seg_num;
|
||||
uint32_t sram_average_mask_num = mask_num / sram_mask_seg_num;
|
||||
for (int i = taskIdY; i < sram_mask_seg_num * batch; i += taskDimY) {
|
||||
uint32_t batch_idx = i / sram_mask_seg_num;
|
||||
uint32_t mask_idx = i % sram_mask_seg_num;
|
||||
uint32_t sram_deal_mask_num = sram_average_mask_num + (mask_idx < sram_mask_rem);
|
||||
uint32_t sram_load_num = sram_deal_mask_num * row;
|
||||
uint32_t sram_mask_offset = mask_idx < sram_mask_rem
|
||||
? mask_idx * (sram_average_mask_num + 1)
|
||||
: mask_idx * sram_average_mask_num + sram_mask_rem;
|
||||
uint32_t sram_load_offset = sram_mask_offset * row;
|
||||
uint32_t nram_average_mask_num = sram_deal_mask_num / taskDimX;
|
||||
uint32_t nram_mask_rem = sram_deal_mask_num % taskDimX;
|
||||
uint32_t nram_deal_mask_num = nram_average_mask_num + (taskIdX < nram_mask_rem);
|
||||
uint32_t nram_load_num = nram_deal_mask_num * row;
|
||||
uint32_t nram_col_offset = taskIdX < nram_mask_rem
|
||||
? taskIdX * (nram_average_mask_num + 1)
|
||||
: taskIdX * nram_average_mask_num + nram_mask_rem;
|
||||
uint32_t nram_load_offset = (batch_idx * mask_num + sram_mask_offset + nram_col_offset) * row;
|
||||
uint32_t nram_store_num = nram_deal_mask_num * topk;
|
||||
uint32_t nram_store_offset =
|
||||
(batch_idx * mask_num + sram_mask_offset + nram_col_offset) * topk;
|
||||
// Load
|
||||
if (valid_mask) {
|
||||
__memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM);
|
||||
}
|
||||
if (nram_load_num > 0) {
|
||||
__memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
|
||||
}
|
||||
|
||||
// Compute
|
||||
computeSoftmaxTopk<T>((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
|
||||
group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
|
||||
softmax_buffer, row, nram_deal_mask_num, mask_num, nram_max_col_num,
|
||||
topk, num_expert_group, topk_group, top_num, nram_col_offset,
|
||||
normalize_mode, valid_mask, split_mask);
|
||||
|
||||
// Store
|
||||
if (nram_store_num > 0) {
|
||||
__memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
|
||||
NRAM2GDRAM);
|
||||
__memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
|
||||
NRAM2GDRAM);
|
||||
}
|
||||
__sync_cluster();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue,
|
||||
float *reduce_weight,
|
||||
int *expert_id,
|
||||
const void *input,
|
||||
const void *mask,
|
||||
const int num_token,
|
||||
const int num_expert,
|
||||
const int num_mask,
|
||||
const int topk,
|
||||
const int num_expert_group,
|
||||
const int topk_group,
|
||||
const cnnlDataType_t dtype,
|
||||
const int normalize_mode) {
|
||||
CNdev dev;
|
||||
cnCtxGetDevice(&dev);
|
||||
int cluster_num;
|
||||
int core_num;
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
|
||||
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
|
||||
cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
|
||||
int top_num = topk >= topk_group ? topk : topk_group;
|
||||
if (num_expert_group <= 1) {
|
||||
if (num_expert > (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)) {
|
||||
std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported."
|
||||
<< "Supported max num_expert:"
|
||||
<< (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)
|
||||
<< ". Current num_expert:" << num_expert;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
} else {
|
||||
if (num_expert >
|
||||
(NRAM_SIZE - (top_num * 2 + 2 + num_expert_group) * sizeof(float)) / 2 / sizeof(float)) {
|
||||
std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported."
|
||||
<< "Supported max num_expert:"
|
||||
<< (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)
|
||||
<< ". Current num_expert:" << num_expert;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
}
|
||||
if (topk > num_expert) {
|
||||
std::cerr << "[invokeMoeSoftmaxTopkKernel]: topk is larger than num_expert."
|
||||
<< "topk:" << topk << ". num_expert:" << num_expert;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
if (num_expert_group > 1) {
|
||||
if (mask != nullptr) {
|
||||
std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, mask should be nullptr";
|
||||
}
|
||||
if (num_expert % num_expert_group != 0) {
|
||||
std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, num_expert should be"
|
||||
<< "divisible by num_expert_group, but now num_expert:" << num_expert
|
||||
<< ", num_expert_group:" << num_expert_group;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
if (topk_group <= 0 || topk_group > num_expert_group) {
|
||||
std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk_group should be"
|
||||
<< "larger than 0 and less than or equal to num_expert_group, but now topk_group"
|
||||
<< topk_group << ", num_expert group:" << num_expert_group;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
if (topk > (num_expert / num_expert_group) * topk_group) {
|
||||
std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk should be less"
|
||||
<< "than or equal to (num_expert / num_expert_group) * topk_group, but now"
|
||||
<< "topk :" << topk << ", num_expert:" << num_expert
|
||||
<< ", num_expert_group:" << num_expert_group << ", topk_group:" << topk_group;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
if (dtype == CNNL_DTYPE_FLOAT) {
|
||||
kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
|
||||
(float *)input, (float *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask,
|
||||
topk, num_expert_group, topk_group, normalize_mode);
|
||||
} else if (dtype == CNNL_DTYPE_HALF) {
|
||||
kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
|
||||
(half *)input, (half *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask,
|
||||
topk, num_expert_group, topk_group, normalize_mode);
|
||||
} else if (dtype == CNNL_DTYPE_BFLOAT16) {
|
||||
if (!isBf16Supported()) {
|
||||
std::cerr << "[invokeMoeSoftmaxTopkKernel]: MLU300 devices do not support bfloat16."
|
||||
<< std::endl;
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
|
||||
(bfloat16_t *)input, (bfloat16_t *)mask, expert_id, reduce_weight, num_token, num_expert,
|
||||
num_mask, topk, num_expert_group, topk_group, normalize_mode);
|
||||
} else {
|
||||
std::cerr << "[invokeMoeSoftmaxTopkKernel]: source type not supported ";
|
||||
return KernelStatus::KERNEL_STATUS_FAILED;
|
||||
}
|
||||
|
||||
return KernelStatus::KERNEL_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace tmo
|
||||
66
torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mluh
Normal file
66
torch_mlu_ops-v1.3.2/csrc/kernels/moe/softmax_topk.mluh
Normal file
@@ -0,0 +1,66 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_
|
||||
#define CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_
|
||||
|
||||
#include "../kernel_utils.h"
|
||||
#include "cnnl.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief Execute MOE Softmax Top-K Kernel.
|
||||
*
|
||||
* This function executes the MOE Softmax Top-K Kernel, which computes
|
||||
* the Top-K values along a specified dimension after applying softmax to the input data.
|
||||
* It is specifically designed for reduction along the lowest dimension.
|
||||
*
|
||||
* @param queue CNRT queue used to specify the queue for execution.
|
||||
* @param reduce_weight Pointer to store the Top-K values.
|
||||
* The shape must be [num_token, topk].
|
||||
* @param expert_id Pointer to store the indices of the Top-K values.
|
||||
* The shape must be [num_token, topk].
|
||||
* @param input Pointer to the input data containing the values to be computed.
|
||||
* The shape must be [num_token, num_expert].
|
||||
* @param mask Pointer to the input data containing the mask value to be computed after
|
||||
* computing softmax, Mask can be nullptr, which means no need to compute,
|
||||
* otherwise the shape and datatype of mask should be the same as input.
|
||||
* @param num_token Number of channels in the input data.
|
||||
* @param num_expert Specified dimension. Note that num_expert should not exceed 32768.
|
||||
* @param num_mask Number of channels in the mask data.
|
||||
* @param topk Number of Top-K values to compute. topk should not be larger than num_expert.
|
||||
* @param num_expert_group Group numbers of num_expert. If num_expert_group > 0, num_expert
|
||||
* should be divisible by num_expert_group. Otherwise, num_expert_group and topk_group
|
||||
* is not valid.
|
||||
* @param topk_group Number of Top-K group values to compute. Topk_group should not be larger
|
||||
* than num_expert_group.
|
||||
* @param dtype Data type of the input data, should match the actual data type.
|
||||
* float, half, bfloat16 is supported.
|
||||
* @param normalize_mode Whether and how to normalize the output, if normalize_mode == 0, no
|
||||
normalization is performed; if normalize_mode == 1, the normalized denominator is
|
||||
the sum of topk; if normalize_mode == 2, the normalized denominator is the sum of
|
||||
* the products of softmax_result mask.
|
||||
*/
|
||||
KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue,
|
||||
float *reduce_weight,
|
||||
int *expert_id,
|
||||
const void *input,
|
||||
const void *mask,
|
||||
const int num_token,
|
||||
const int num_expert,
|
||||
const int num_mask,
|
||||
const int topk,
|
||||
const int num_expert_group,
|
||||
const int topk_group,
|
||||
const cnnlDataType_t dtype,
|
||||
const int normalize_mode);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_
|
||||
Reference in New Issue
Block a user