This commit is contained in:
Chranos
2026-02-04 17:39:32 +08:00
parent 8511fe8530
commit 79dfc69789
299 changed files with 55927 additions and 0 deletions

View File

@@ -0,0 +1,521 @@
#include <algorithm>
#include <cassert>
#include <iostream>
#include <map>
#include <ostream>
#include "add_bias_activation.mluh"
#include "cnnl.h"
#include "cnrt.h"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
#define USE_NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 20 * 1024)
#define USE_SRAM_SIZE (__MLU_SRAM_SIZE__ * 1024 - 20 * 1024)
__nram__ int8_t nram_buffer[USE_NRAM_SIZE];
__mlu_shared__ int8_t sram_buffer[USE_SRAM_SIZE];
__mlu_func__ void get_expert_info(int *nram_count,
int *count_sram,
uint32_t *gather_offset,
int real_inner,
int tokens_start,
int tokens_end,
int tokens_load,
int expert_deal_start,
int expert_deal_end,
uint32_t &expert_start,
uint32_t &expert_end,
uint32_t &tokens_deal_first,
int dtype_size) {
bool record_start = false;
// loop expert to find first and last deal expert in current core
for (int expert_id = expert_deal_start; expert_id <= expert_deal_end; expert_id++) {
if (__load_nram(nram_count + expert_id + 1) > tokens_start && !record_start) {
expert_start = expert_id;
tokens_deal_first =
std::min(__load_nram(nram_count + expert_id + 1) - 1, tokens_end) - tokens_start + 1;
record_start = true;
}
if (__load_nram(nram_count + expert_id + 1) > tokens_end) {
expert_end = expert_id;
break;
}
}
// record expert offset to gather bias
__bang_write_zero(gather_offset, tokens_load);
int tokens_load_total = 0;
for (int expert_id = expert_start; expert_id <= expert_end; expert_id++) {
int tokens_expand = __load_sram((int *)count_sram + expert_id);
if (expert_id == expert_start) {
tokens_expand = tokens_deal_first;
} else if (expert_id == expert_end) {
tokens_expand = tokens_load - tokens_load_total;
}
if (tokens_expand == 0) {
continue;
}
__bang_write_value(gather_offset + tokens_load_total, tokens_expand,
(int)((expert_id - expert_deal_start) * real_inner * dtype_size));
tokens_load_total += tokens_expand;
}
}
/*************** functions for compute basic operation ***************/
template <typename T>
__mlu_func__ void add_bias(T *dst_src, T *bias, int number) {
// cycle add bias
__bang_add((T *)dst_src, (T *)dst_src, (T *)bias, number);
}
template <>
__mlu_func__ void add_bias(bfloat16_t *dst_src, bfloat16_t *bias, int number) {
#if __BANG_ARCH__ > 500
__bang_add((bfloat16_t *)dst_src, (bfloat16_t *)dst_src, (bfloat16_t *)bias, number);
#endif
}
template <typename T>
__mlu_func__ void mul_left_right(T *left, T *right, int number) {
__bang_mul((T *)left, (T *)left, (T *)right, number);
}
template <>
__mlu_func__ void mul_left_right(bfloat16_t *left, bfloat16_t *right, int number) {
#if __BANG_ARCH__ > 500
__bang_mul((bfloat16_t *)left, (bfloat16_t *)left, (bfloat16_t *)right, number);
#endif
}
__mlu_func__ void do_activation(float *input_left,
float *act_space,
int number,
float active_coef,
cnnlActivationMode_t act_type) {
if (act_type == CNNL_ACTIVATION_GELU) {
__bang_active_gelu((float *)input_left, (float *)input_left, number);
} else if (act_type == CNNL_ACTIVATION_SWISH) {
float *tmp = input_left;
if (active_coef != 1.0f) {
__bang_mul_scalar(act_space, input_left, active_coef, number);
tmp = act_space;
}
__bang_active_sigmoid((float *)act_space, (float *)tmp, number);
__bang_mul((float *)input_left, (float *)input_left, (float *)act_space, number);
}
}
/*************** functions for steps of each loop ***************/
template <typename T>
__mlu_func__ void gather_bias(T *bias_sram,
T *bias_nram,
uint32_t *gather_offset,
int expert_start,
int expert_end,
int expert_deal_start,
int tokens_deal_first,
int tokens_deal,
int inner_size,
bool is_gated) {
#if __BANG_ARCH__ > 500
if (is_gated) {
__gather_async((T *)bias_nram, (T *)bias_sram, gather_offset, inner_size * sizeof(T), SRAM2NRAM,
inner_size * sizeof(T), tokens_deal);
__gather_async((T *)bias_nram + tokens_deal * inner_size, (T *)bias_sram + inner_size,
gather_offset, inner_size * sizeof(T), SRAM2NRAM, inner_size * sizeof(T),
tokens_deal);
} else {
__gather_async((T *)bias_nram, (T *)bias_sram, gather_offset, inner_size * sizeof(T), SRAM2NRAM,
inner_size * sizeof(T), tokens_deal);
}
#else
for (int i = 0; i < tokens_deal; i++) {
if (is_gated) {
__memcpy_async((T *)bias_nram + i * inner_size, (int8_t *)bias_sram + gather_offset[i],
inner_size * sizeof(T), SRAM2NRAM);
__memcpy_async((T *)bias_nram + (tokens_deal * inner_size + i * inner_size),
(int8_t *)bias_sram + inner_size * sizeof(T) + gather_offset[i],
inner_size * sizeof(T), SRAM2NRAM);
} else {
__memcpy_async((T *)bias_nram + i * inner_size, (int8_t *)bias_sram + gather_offset[i],
inner_size * sizeof(T), SRAM2NRAM);
}
}
#endif
}
template <typename T>
__mlu_func__ void loadBiasInput(T *input,
T *left,
T *right,
T *bias_nram,
T *bias_sram,
uint32_t *gather_offset,
size_t input_offset,
int tokens_deal,
int inner_size,
uint32_t expert_start,
uint32_t expert_end,
int expert_deal_start,
uint32_t tokens_deal_first,
bool is_gated,
bool has_bias) {
if (is_gated) {
// if gated, stride io load input, left/right inner to input_left/right
__memcpy_async((T *)left, (T *)input + input_offset, inner_size * sizeof(T), GDRAM2NRAM,
inner_size * sizeof(T), inner_size * 2 * sizeof(T), tokens_deal - 1);
__memcpy_async((T *)right, (T *)input + input_offset + inner_size, inner_size * sizeof(T),
GDRAM2NRAM, inner_size * sizeof(T), inner_size * 2 * sizeof(T), tokens_deal - 1);
} else {
// if not gated, load input to input_left total
__memcpy_async((T *)left, (T *)input + input_offset, tokens_deal * inner_size * sizeof(T),
GDRAM2NRAM);
}
if (has_bias) {
__sync_compute();
gather_bias((T *)bias_sram, (T *)bias_nram, gather_offset, expert_start, expert_end,
expert_deal_start, tokens_deal_first, tokens_deal, inner_size, is_gated);
}
}
template <typename T>
__mlu_func__ void computeAddActivation(T *bias_nram,
T *left_dst,
T *input_right,
float *input_left,
float *act_space,
int tokens_deal,
int inner_size,
bool is_gated,
bool has_bias,
float active_coef,
cnnlActivationMode_t act_type) {
int number = tokens_deal * inner_size;
if (has_bias) {
add_bias((T *)left_dst, (T *)bias_nram, number);
if (is_gated) {
add_bias((T *)input_right, (T *)bias_nram + tokens_deal * inner_size, number);
}
}
// cast half/bfloat16 to float to acvication, if float, left_dst is same as input_left
if (std::is_same<T, half>::value) {
__bang_half2float((float *)input_left, (half *)left_dst, number);
}
#if __BANG_ARCH__ > 500
if (std::is_same<T, bfloat16_t>::value) {
__bang_bfloat162float((float *)input_left, (bfloat16_t *)left_dst, number);
}
#endif
// activation
do_activation(input_left, act_space, number, active_coef, act_type);
// if half/bfloat16, cast float to T to mul
if (std::is_same<T, half>::value) {
__bang_float2half((half *)input_left, (float *)input_left, number);
}
#if __BANG_ARCH__ > 500
if (std::is_same<T, bfloat16_t>::value) {
__bang_float2bfloat16((bfloat16_t *)input_left, (float *)input_left, number);
}
#endif
if (is_gated) {
mul_left_right((T *)input_left, (T *)input_right, tokens_deal * inner_size);
}
}
template <typename T>
__mlu_func__ void storeOutput(T *output,
T *output_nram,
size_t output_offset,
int output_stride,
int tokens_deal,
int inner_size) {
__memcpy_async((T *)output + output_offset, (T *)output_nram, inner_size * sizeof(T), NRAM2GDRAM,
output_stride * sizeof(T), inner_size * sizeof(T), tokens_deal - 1);
}
template <typename T>
__mlu_global__ void MLUAddBiasActivationKernel(T *output,
const T *input,
const T *bias,
const int *cusum_token_count,
int num_expert,
int total_tokens,
int inner_size,
int output_stride,
bool is_gated,
cnnlActivationMode_t act_type,
int start_expert_id,
int expert_size,
float active_coef) {
// if bias and token_count is nullptr, not add bias, only activation and gated mul.
bool has_bias = (bias != nullptr);
// 1. distrubute nram space
/* if not gated
---------------------------- ----------------------------------
| nram_token_count | bias_ping/pong |
| num_expert * sizeof(int) | 2 * x * inner_size * sizeof(T) |
---------------------------- ----------------------------------
--------------------------------------
| input_ping/pong |
| 2 * x * inner_size * sizeof(float) |
--------------------------------------
------------------------------------------------- -------------------
| act_space | gather_offset |
| gelu: 0; silu: x * inner_size * sizeof(float) | x * sizeof(int) |
------------------------------------------------- -------------------
*/
/* if gated
---------------------------- ----------------------------------
| nram_token_count | bias_ping/pong |
| num_expert * sizeof(int) | 2 * x * real_inner * sizeof(T) |
---------------------------- ----------------------------------
-------------------------------------- ----------------------------------
| input_left_ping/pong | input_right_ping/pong |
| 2 * x * inner_size * sizeof(float) | 2 * x * inner_size * sizeof(T) |
-------------------------------------- ----------------------------------
------------------------------------------------- -------------------
| act_space | gather_offset |
| gelu: 0; silu: x * inner_size * sizeof(float) | x * sizeof(int) |
------------------------------------------------- -------------------
*/
// distribute sram
int8_t *count_sram = (int8_t *)sram_buffer;
int8_t *bias_sram = (int8_t *)count_sram + num_expert * sizeof(int);
// distrubute nram
int real_inner = (is_gated) ? inner_size * 2 : inner_size;
int bias_nram_size = (has_bias) ? real_inner * sizeof(T) : 0;
int act_space_size = (act_type == CNNL_ACTIVATION_GELU) ? 0 : inner_size * sizeof(float);
int gated_ext_size = is_gated ? sizeof(T) : 0;
int max_token_deal = (USE_NRAM_SIZE - (num_expert + 1) * sizeof(int)) /
(2 * inner_size * (sizeof(float) + gated_ext_size) + act_space_size +
2 * bias_nram_size + sizeof(int));
int8_t *nram_count = (int8_t *)nram_buffer;
int8_t *bias_nram = (int8_t *)nram_count + (num_expert + 1) * sizeof(int);
int8_t *input_left = (int8_t *)bias_nram + 2 * max_token_deal * bias_nram_size;
int8_t *input_right =
(int8_t *)input_left + 2 * ((is_gated) ? max_token_deal * inner_size * sizeof(float) : 0);
int8_t *act_space = (int8_t *)input_right +
2 * max_token_deal * inner_size * (is_gated ? sizeof(T) : sizeof(float));
int8_t *gather_offset = (int8_t *)act_space + max_token_deal * act_space_size;
// 2. cusum_token_count load to nram, because need to reuse in load bias.
if (has_bias) {
__memcpy((int *)nram_count, (int *)cusum_token_count, (num_expert + 1) * sizeof(int),
GDRAM2NRAM);
if (taskIdX == 0) {
__bang_sub((int *)bias_nram, (int *)nram_count + 1, (int *)nram_count, num_expert);
__sync();
__memcpy((int *)count_sram, (int *)bias_nram, num_expert * sizeof(int), NRAM2SRAM);
}
__sync_cluster();
}
// 3. sram loop to compute
// compute once load bias to sram due to sram_limit
int max_expert_deal = (USE_SRAM_SIZE - num_expert * sizeof(int)) / (real_inner * sizeof(T));
int real_expert = cusum_token_count == nullptr ? num_expert : expert_size;
int sram_loop_rem = real_expert % max_expert_deal;
int sram_loop = real_expert / max_expert_deal + (int)(sram_loop_rem != 0);
if (!has_bias) {
max_expert_deal = real_expert;
sram_loop = 1;
sram_loop_rem = 0;
}
for (int deal_loop = 0; deal_loop < sram_loop; deal_loop++) {
// load current bias, compute each core deal number
int expert_deal =
(deal_loop == (sram_loop - 1) && sram_loop_rem != 0) ? sram_loop_rem : max_expert_deal;
int expert_deal_start = deal_loop * max_expert_deal + start_expert_id;
int expert_deal_end = expert_deal_start + expert_deal - 1;
__sync_all();
if (has_bias && __is_mpu()) {
__memcpy((T *)bias_sram, (T *)bias + deal_loop * max_expert_deal * real_inner,
expert_deal * real_inner * sizeof(T), GDRAM2SRAM);
}
__sync_all();
// get tokens info of each core
int tokens_total_cur = total_tokens;
if (has_bias) {
tokens_total_cur =
__load_nram((int *)nram_count + expert_deal_end + 1) -
(start_expert_id == 0 ? 0 : __load_nram((int *)nram_count + expert_deal_start));
} else if (cusum_token_count != nullptr) {
tokens_total_cur = __load_gdram(cusum_token_count + expert_deal_end + 1) -
__load_gdram(cusum_token_count + expert_deal_start);
}
if (sram_loop != 1) {
tokens_total_cur = ((int *)nram_count)[expert_deal_end + 1] -
((deal_loop == 0) ? 0 : ((int *)nram_count)[expert_deal_start]);
if (deal_loop == 0 && start_expert_id != 0) {
tokens_total_cur -= __load_nram((int *)nram_count + expert_deal_start);
}
}
int tokens_core_rem = tokens_total_cur % taskDim;
int tokens_cur_core = tokens_total_cur / taskDim + (taskId < tokens_core_rem);
if (tokens_cur_core == 0) {
continue;
}
// if ep, input start in current token, have a real start in total network
int real_start =
cusum_token_count != nullptr ? __load_gdram((int *)cusum_token_count + start_expert_id) : 0;
int tokens_core_start = tokens_cur_core * taskId +
(taskId < tokens_core_rem ? 0 : tokens_core_rem) +
((deal_loop == 0) ? 0 : ((int *)nram_count)[expert_deal_start]);
if (deal_loop != 0) {
tokens_core_start -= real_start;
}
uint32_t expert_start = 0;
uint32_t expert_end = 0;
uint32_t tokens_deal_first = 0;
// 4. nram loop compute
int nram_loop_rem = tokens_cur_core % max_token_deal;
int nram_loop = tokens_cur_core / max_token_deal + (int)(nram_loop_rem != 0);
int tokens_load = max_token_deal;
int tokens_compute = max_token_deal;
int tokens_store = max_token_deal;
for (int loop = 0; loop < nram_loop + 2; loop++) {
int inner_io_offset = (loop % 2) * max_token_deal * inner_size;
int inner_com_offset = ((loop + 1) % 2) * max_token_deal * inner_size;
int real_io_offset = (loop % 2) * max_token_deal * real_inner;
int real_com_offset = ((loop + 1) % 2) * max_token_deal * real_inner;
if (nram_loop_rem != 0) {
if (loop > 1 && (loop - 2) == (nram_loop - 1)) {
tokens_store = nram_loop_rem;
}
if (loop > 0 && (loop - 1) == (nram_loop - 1)) {
tokens_compute = nram_loop_rem;
}
if (loop == (nram_loop - 1)) {
tokens_load = nram_loop_rem;
}
}
int tokens_cur_start = tokens_core_start + loop * max_token_deal;
int tokens_cur_end = tokens_cur_start + tokens_load - 1;
// get current load info
if (loop < nram_loop && has_bias) {
get_expert_info((int *)nram_count, (int *)count_sram, (uint32_t *)gather_offset, real_inner,
tokens_cur_start + real_start, tokens_cur_end + real_start, tokens_load,
expert_deal_start, expert_deal_end, expert_start, expert_end,
tokens_deal_first, sizeof(T));
}
// store
if (loop > 1) {
size_t output_offset = (tokens_core_start + (loop - 2) * max_token_deal) * output_stride;
storeOutput((T *)output, (T *)((float *)input_left + inner_io_offset), output_offset,
output_stride, tokens_store, inner_size);
}
// compute
if (loop > 0 && loop <= nram_loop) {
T *left_dst = (T *)((float *)input_left + inner_com_offset) +
((std::is_same<T, float>::value) ? 0 : tokens_compute * inner_size);
computeAddActivation((T *)bias_nram + real_com_offset, (T *)left_dst,
(T *)input_right + inner_com_offset,
(float *)input_left + inner_com_offset, (float *)act_space,
tokens_compute, inner_size, is_gated, has_bias, active_coef, act_type);
}
// load
if (loop < nram_loop) {
T *left_dst = (T *)((float *)input_left + inner_io_offset) +
((std::is_same<T, float>::value) ? 0 : tokens_load * inner_size);
size_t input_offset = tokens_cur_start * real_inner;
loadBiasInput((T *)input, (T *)left_dst, (T *)input_right + inner_io_offset,
(T *)bias_nram + real_io_offset, (T *)bias_sram, (uint32_t *)gather_offset,
input_offset, tokens_load, inner_size, expert_start, expert_end,
expert_deal_start, tokens_deal_first, is_gated, has_bias);
}
__sync();
}
}
}
} // namespace kernels
KernelStatus invokeGroupAddBiasActivationKernel(cnrtQueue_t queue,
void *output,
const void *input,
const void *bias,
const int *cusum_token_count,
int num_expert,
int total_tokens,
int inner_size,
int output_stride,
cnnlDataType_t dtype,
bool is_gated,
cnnlActivationMode_t act_type,
int start_expert_id,
int expert_size,
float active_coef) {
if (bias != NULL && cusum_token_count == NULL) {
std::cerr << "[invokeGroupAddBiasActivationKernel]: "
<< "when have bias, cusum_token_count can not be nullptr.";
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (act_type != CNNL_ACTIVATION_GELU && act_type != CNNL_ACTIVATION_SWISH) {
std::cerr << "[invokeGroupAddBiasActivationKernel]: "
<< "activation mode only supports gelu and swish now.";
return KernelStatus::KERNEL_STATUS_FAILED;
}
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
if (dtype == CNNL_DTYPE_FLOAT) {
kernels::MLUAddBiasActivationKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
(float *)output, (const float *)input, (const float *)bias, cusum_token_count, num_expert,
total_tokens, inner_size, output_stride, is_gated, act_type, start_expert_id, expert_size,
active_coef);
} else if (dtype == CNNL_DTYPE_HALF) {
kernels::MLUAddBiasActivationKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
(half *)output, (const half *)(input), (const half *)bias, cusum_token_count, num_expert,
total_tokens, inner_size, output_stride, is_gated, act_type, start_expert_id, expert_size,
active_coef);
} else if (dtype == CNNL_DTYPE_BFLOAT16) {
if (!isBf16Supported()) {
std::cerr << "[invokeGroupAddBiasActivationKernel]: MLU300 devices do not support bfloat16."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
kernels::MLUAddBiasActivationKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
(bfloat16_t *)output, (const bfloat16_t *)input, (const bfloat16_t *)bias,
cusum_token_count, num_expert, total_tokens, inner_size, output_stride, is_gated, act_type,
start_expert_id, expert_size, active_coef);
} else {
std::cerr << "[invokeGroupAddBiasActivationKernel]: add_bias_activation data_type not support, "
<< "only support float/half/bfloat16." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,84 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_
#define CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Add bias and activate to all tokens. Different expert with different bias.
* If in gated mode, add bias to all input. But only activate in the
* input of [:, :inner_size]. Then multiply it to the input of [:, inner_size:].
* Else, add bias and activate to all input, has no multiply process.
* @example
* is_gated = true, num_expert = 4, total_tokens = 6, inner_size = 2
* input: (6, 4) = [[2, 4, 5, 6], [1, 4, 5, 3],
* [3, 5, 7, 8], [6, 8, 5, 3],
* [2, 3, 4 ,5], [2, 9, 2, 3]]
* bias: (4, 4) = [[1, 0, 1, 0], [0, 1, 2, 2], [2, 3, 2, 3], [1, 2, 3, 4]]
* token_count = [2, 2, 1, 1]
* first step: add bias
* [[2+1, 4+0, 5+1, 6+0], [1+1, 4+0, 5+1, 3+0],
* [3+0, 5+1, 7+2, 8+2], [6+0, 8+1, 5+2, 3+2],
* [2+2, 3+3, 4+2, 5+3], [2+1, 9+2, 2+3, 3+4]]
* second step: act and mul
* output: (6, 2) = [[act(3)*6, act(4)*6], [act(2)*6, act(4)*3],
* [act(3)*9, act(6)*10], [act(6)*7, act(9)*5],
* [act(4)*6, act(6)*8], [act(3)*5, act(11)*7]]
* @param queue: The queue for mlu.
* @param output: Output. Pointer to the MLU memory that stores the result.
* When is_gated is true, The shape is [total_tokens, input_size / 2].
* In this case, the input_size must be even. Otherwise the shape is [total_tokens,
* input_size]. The memory can be discontinuous in total_tokens dim. The stride is output_stride.
* @param input: Input. Pointer to the MLU memory that stores the input tokens.
* The shape is [total_tokens, input_size].
* When is_gated is true, the shape is [total_tokens, 2 * inner_size].
* Otherwise the shape is [total_tokens, inner_size].
* @param bias: Input. Pointer to the MLU memory that stores the bias. The memory must be
* continuous. When is_gated is true, the shape is [num_expert, 2 * inner_size]. Otherwise the shape
* is [num_expert, inner_size]. Bias can be nullptr. If bias is nullptr, has no add bias process.
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the prefix sum of token
* counts. The shape is [num_expert + 1]. If cusum_token_count
* is not nullptr, cusum_token_count, start_expert_id and
* expert_size together determine which tokens to process.
* If cusum_token_count is nullptr, process all tokens,
* the number of which is total_tokens. When bias is not nullptr,
* cusum_token_count must also not be nullptr.
* @param num_expert: The number of expert.
* @param total_tokens: The total number of tokens.
* @param inner_size: The inner size of output.
* @param output_stride: The stride of output, must be greater than or equal to inner_size.
* @param dtype: Data type.
* @param is_gated: Gated or not.
* @param act_type: The type of activation. Support gelu and swish.
* @param start_expert_id: The index of the start expert.
* @param expert_size: The number of experts to process.
* @param active_coef: The coefficient used in the swish activation.
*/
KernelStatus invokeGroupAddBiasActivationKernel(cnrtQueue_t queue,
void *output,
const void *input,
const void *bias,
const int *cusum_token_count,
int num_expert,
int total_tokens,
int inner_size,
int output_stride,
cnnlDataType_t dtype,
bool is_gated,
cnnlActivationMode_t act_type,
int start_expert_id,
int expert_size,
float active_coef);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_ADD_BIAS_ACTIVATION_MLUH_

View File

@@ -0,0 +1,646 @@
#include <stdint.h>
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>
#include "cast_gating.mluh"
#include "cnnl.h"
#include "cnrt.h"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
#define DIV_UP(x, y) ((x) % (y) > 0 ? ((x) / (y) + 1) : ((x) / (y)))
#define NRAM_BUFFER_SIZE (496 * 1024)
#define WRAM_BUFFER_SIZE (512 * 1024)
#define SRAM_BUFFER_SIZE (2032 * 1024)
#ifndef ONE_LINE
#define ONE_LINE 64
#endif
#ifndef LT_NUM
#define LT_NUM 64
#endif
struct castGatingTileInfo {
int32_t block = 64;
int32_t split_k_num = 8;
int32_t block_k = 256;
};
namespace kernels {
#pragma bang walign(16)
#ifndef ROW_PER_LT
#define ROW_PER_LT 4
#endif
#ifndef LT_SIZE
#define LT_SIZE 16
#endif
#ifndef WRAM_LT_MAP16_STRIDE
#define WRAM_LT_MAP16_STRIDE (WRAM_BUFFER_SIZE / 16)
#endif
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__wram__ int8_t wram_buffer[WRAM_BUFFER_SIZE];
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
#define SRAM2NRAM_CONVERT_IMPL(dst, src, size, dst_dsize, src_dsize, convert_type) \
do { \
uint32_t align_num = 64 / src_dsize; \
uint32_t n = PAD_DOWN(size / src_dsize, align_num); \
uint32_t rem = size % 64; \
if (n) { \
__asm__ __volatile__( \
"move.tiling.async.nram.sram.b16" \
" [%[dst_addr]], [%[src_addr]], " \
"%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], " \
"%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], " \
"%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], " \
"%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4]," \
"%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst), \
[src_addr] "r"(src), [src_n0] "i"(64), [src_n1] "i"(1), [src_s1] "i"(0), \
[src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "r"(n / align_num), \
[src_s3] "r"(align_num * src_dsize), [src_n4] "i"(1), [src_s4] "i"(0), [src_n5] "i"(1), \
[src_s5] "i"(0), [dst_n0] "i"(64), [dst_n1] "i"(1), [dst_s1] "i"(0), [dst_n2] "i"(1), \
[dst_s2] "i"(0), [dst_n3] "r"(n / align_num), [dst_s3] "r"(align_num * dst_dsize), \
[dst_n4] "i"(1), [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0)); \
} \
\
if (rem) { \
__asm__ __volatile__( \
"move.tiling.async.nram.sram.b16" \
" [%[dst_addr]], [%[src_addr]], " \
"%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], " \
"%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], " \
"%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], " \
"%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4]," \
"%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst + n), \
[src_addr] "r"(src + n), [src_n0] "r"(rem), [src_n1] "i"(1), [src_s1] "i"(0), \
[src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "i"(1), [src_s3] "i"(0), [src_n4] "i"(1), \
[src_s4] "i"(0), [src_n5] "i"(1), [src_s5] "i"(0), [dst_n0] "r"(rem), [dst_n1] "i"(1), \
[dst_s1] "i"(0), [dst_n2] "i"(1), [dst_s2] "i"(0), [dst_n3] "i"(1), [dst_s3] "i"(0), \
[dst_n4] "i"(1), [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0)); \
} \
} while (false)
__mlu_func__ void warp_prompt_input(float *dst, half *src, int32_t size) {
#if __BANG_ARCH__ >= 500
SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(half), ".cvt.rn.f32.f16()");
#endif
}
__mlu_func__ void warp_prompt_input(float *dst, bfloat16_t *src, int32_t size) {
#if __BANG_ARCH__ >= 500
SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(bfloat16_t), ".cvt.rn.f32.bf16()");
#endif
}
__mlu_func__ void warp_prompt_input(float *dst, float *src, int32_t size) {
__memcpy_async((float *)dst, (float *)src, size, SRAM2NRAM);
}
#define SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, n, k, total_k, dst_dsize, src_dsize, \
convert_type) \
int align_n = PAD_DOWN(n, LT_NUM); \
int sn0 = ONE_LINE; \
int size_sn0 = sn0 / src_dsize; \
int sn1 = ONE_LINE / src_dsize; \
int ss1 = total_k * src_dsize; \
int sn3 = k / size_sn0; \
int sn4 = align_n / sn1; \
int ss4 = sn1 * ss1; \
int dn0 = sn0; \
int dn1 = ROW_PER_LT; \
int dst_k = PAD_UP(k, ONE_LINE / dst_dsize); \
int ds1 = dst_k * dst_dsize; \
int dn2 = sn1 / ROW_PER_LT; \
int ds2 = WRAM_LT_MAP16_STRIDE; \
int ds3 = sn0 * dst_dsize / src_dsize; \
int dn4 = LT_SIZE / dn2; \
int ds4 = dn2 * WRAM_LT_MAP16_STRIDE; \
int dn5 = align_n / LT_NUM; \
int ds5 = ROW_PER_LT * dst_k * dst_dsize; \
int rem_k = k % size_sn0; \
int8_t *sram_src2 = (int8_t *)sram_src + sn3 * size_sn0 * src_dsize; \
int8_t *wram_dst2 = (int8_t *)wram_dst + sn3 * size_sn0 * dst_dsize; \
if (align_n > 0 && sn3 > 0) { \
__asm__ __volatile__( \
"move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
"%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], " \
"%[src_n4], %[src_s4], 1, 0, %[dst_n0], " \
"%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
"%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type \
";\n\t" ::[dst_addr] "r"(wram_dst), \
[src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1), \
[src_n3] "r"(sn3), [src_s3] "r"(sn0), [src_n4] "r"(sn4), [src_s4] "r"(ss4), \
[dst_n0] "r"(dn0), [dst_n1] "r"(dn1), [dst_s1] "r"(ds1), [dst_n2] "r"(dn2), \
[dst_s2] "r"(ds2), [dst_n3] "r"(sn3), [dst_s3] "r"(ds3), [dst_n4] "r"(dn4), \
[dst_s4] "r"(ds4), [dst_n5] "r"(dn5), [dst_s5] "r"(ds5)); \
sram_src += align_n * total_k; \
wram_dst += align_n / LT_SIZE * dst_k; \
} \
align_n = PAD_UP(n % LT_NUM, ROW_PER_LT); \
if (align_n > 0 && sn3 > 0) { \
sn1 = align_n; \
dn2 = (sn1 + ROW_PER_LT - 1) / ROW_PER_LT; \
__asm__ __volatile__( \
"move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
"%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], 1, 0, 1, 0, " \
"%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
"1, 0, 1, 0, " convert_type ";\n\t" ::[dst_addr] "r"(wram_dst), \
[src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1), \
[src_n3] "r"(sn3), [src_s3] "r"(sn0), [dst_n0] "r"(dn0), [dst_n1] "r"(dn1), \
[dst_s1] "r"(ds1), [dst_n2] "r"(dn2), [dst_s2] "r"(ds2), [dst_n3] "r"(sn3), \
[dst_s3] "r"(ds3)); \
sram_src += align_n * total_k; \
wram_dst += align_n / ROW_PER_LT * WRAM_LT_MAP16_STRIDE / dst_dsize; \
} \
if (rem_k > 0) { \
align_n = PAD_UP(n, LT_NUM); \
sn0 = rem_k * src_dsize; \
dn0 = sn0; \
__asm__ __volatile__( \
"move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
"%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], " \
"%[src_n4], %[src_s4], 1, 0, %[dst_n0], " \
"%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
"%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type \
";\n\t" ::[dst_addr] "r"(wram_dst2), \
[src_addr] "r"(sram_src2), [src_n0] "r"(sn0), [src_n1] "r"(ROW_PER_LT), [src_s1] "r"(ss1), \
[src_n3] "r"(LT_NUM / ROW_PER_LT), [src_s3] "r"(ROW_PER_LT * ss1), \
[src_n4] "r"(align_n / LT_NUM), [src_s4] "r"(LT_NUM * ss1), [dst_n0] "r"(dn0), \
[dst_n1] "r"(ROW_PER_LT), [dst_s1] "r"(ds1), [dst_n2] "r"(1), [dst_s2] "r"(0), \
[dst_n3] "r"(LT_NUM / ROW_PER_LT), [dst_s3] "r"(WRAM_LT_MAP16_STRIDE), \
[dst_n4] "r"(align_n / LT_NUM), [dst_s4] "r"(ROW_PER_LT * ds1), [dst_n5] "r"(1), \
[dst_s5] "r"(0)); \
}
__mlu_func__ void warp_prompt_weight(float *wram_dst,
half *sram_src,
int32_t warp_n,
int32_t len_k,
int32_t total_k) {
#if __BANG_ARCH__ >= 500
SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float), sizeof(half),
".cvt.rn.f32.f16()");
#endif
}
__mlu_func__ void warp_prompt_weight(float *wram_dst,
bfloat16_t *sram_src,
int32_t warp_n,
int32_t len_k,
int32_t total_k) {
#if __BANG_ARCH__ >= 500
SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float),
sizeof(bfloat16_t), ".cvt.rn.f32.bf16()");
#endif
}
template <typename T>
__mlu_func__ void warp_prompt_weight(T *wram_dst,
T *sram_src,
int32_t n,
int32_t len_k,
int32_t total_k) {
int32_t type_size = sizeof(T);
int32_t data_size = len_k * type_size;
int32_t ds0 = PAD_UP(data_size, ONE_LINE);
int32_t ss0 = total_k * type_size;
int32_t count = n / LT_NUM;
for (int32_t i = 0; i < count; ++i) {
__memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1,
WRAM_LT_MAP16_STRIDE, LT_SIZE - 1, ss0, LT_NUM - 1, 0, 0);
wram_dst = (T *)((int8_t *)wram_dst + ROW_PER_LT * ds0);
sram_src = (T *)((int8_t *)sram_src + LT_NUM * ss0);
}
count = n % LT_NUM / ROW_PER_LT;
if (count > 0) {
__memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1,
WRAM_LT_MAP16_STRIDE, count - 1, ss0, count * ROW_PER_LT - 1, 0, 0);
wram_dst = (T *)((int8_t *)wram_dst + count * WRAM_LT_MAP16_STRIDE);
sram_src = (T *)((int8_t *)sram_src + count * ROW_PER_LT * ss0);
}
count = n % ROW_PER_LT;
if (count > 0) {
__memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ss0, count - 1);
}
}
__mlu_func__ void assignTaskEvenly(const int32_t num_total_task,
const int32_t &taskid,
const int32_t &taskdim,
int32_t &task_offset,
int32_t &num_cur_task) {
int32_t num_per_task = num_total_task / taskdim;
int32_t rem_idx = num_total_task % taskdim;
if (taskid < rem_idx) {
task_offset = taskid * (num_per_task + 1);
num_cur_task = num_per_task + 1;
} else {
task_offset = taskid * num_per_task + rem_idx;
num_cur_task = num_per_task;
}
}
__mlu_func__ void bidirectionBarrierOp() {
int32_t bcnt = coreDim + 1;
if (__is_ipu()) {
__asm__ __volatile__("barrier.arrive.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt));
__asm__ __volatile__("barrier.sync.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt));
} else {
__asm__ __volatile__("barrier.sync.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt));
__asm__ __volatile__("barrier.arrive.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt));
}
}
__mlu_func__ void __wmma(float *c, float *a, float *b, int32_t m, int32_t n, int32_t k) {
__bang_conv_partial((float *)c, (float *)a, (float *)b, (float *)c, k, m, 1, 1, 1, 1, 1, n);
}
__mlu_func__ void warp_store(void *ddr_dst,
void *nram_src,
const int32_t data_num,
const int32_t dst_stride,
const int32_t src_stride,
const int32_t count,
const int32_t dt_size) {
if (src_stride == data_num && dst_stride == data_num) {
__memcpy_async(ddr_dst, nram_src, count * data_num * dt_size, NRAM2GDRAM);
} else {
__memcpy_async(ddr_dst, nram_src, data_num * dt_size, NRAM2GDRAM, (size_t)dst_stride * dt_size,
src_stride * dt_size, count - 1);
}
}
template <typename Tc, typename Tcc>
__mlu_func__ void splitKReduce(Tcc *workspace,
Tc *output,
int32_t M,
int32_t N,
int32_t split_k_num,
int32_t ldc) {
int32_t offset_m, cta_m;
assignTaskEvenly(M, taskId, taskDim, offset_m, cta_m);
if (cta_m <= 0) return;
int32_t block_m = NRAM_BUFFER_SIZE / split_k_num / N / sizeof(Tcc);
int32_t repeat = cta_m / block_m + int32_t(cta_m % block_m != 0);
int32_t rem_m = cta_m % block_m != 0 ? cta_m % block_m : block_m;
Tcc *workspace_ddr = (Tcc *)workspace + offset_m * N;
Tc *output_ddr = (Tc *)output + offset_m * ldc;
for (int32_t i = 0; i < repeat; i++) {
int32_t current_m = i == repeat - 1 ? rem_m : block_m;
int32_t data_size = N * sizeof(Tc);
int32_t data_num = current_m - 1;
if (ldc == N) {
data_size = current_m * N * sizeof(Tc);
data_num = 0;
}
__memcpy((Tcc *)nram_buffer, (Tcc *)workspace_ddr, current_m * N * sizeof(Tcc), GDRAM2NRAM,
current_m * N * sizeof(Tcc), M * N * sizeof(Tcc), split_k_num - 1);
__bang_sumpool((Tcc *)nram_buffer, (Tcc *)nram_buffer, current_m * N, split_k_num, 1,
split_k_num, 1, 1, 1);
__memcpy((Tc *)output_ddr, (Tc *)nram_buffer, data_size, NRAM2GDRAM, ldc * sizeof(Tc),
N * sizeof(Tc), data_num);
workspace_ddr = workspace_ddr + block_m * N;
output_ddr = output_ddr + block_m * ldc;
}
}
template <typename Ta,
typename Tac,
typename Tb,
typename Tbc,
typename Tc,
typename Tcc,
bool EXCHANGE_AB>
__mlu_global__ void MLUCastGating(Ta *A,
Tb *B,
Tc *C,
Tcc *workspace,
int32_t M,
int32_t N,
int32_t K,
int32_t lda,
int32_t ldb,
int32_t ldc,
castGatingTileInfo split_info) {
#if __BANG_ARCH__ >= 500
int32_t block_k = split_info.block_k;
int32_t grid_dimx = split_info.split_k_num;
int32_t block = split_info.block;
int32_t grid_idx = clusterId % grid_dimx;
int32_t grid_idy = clusterId / grid_dimx;
int32_t offset_k = 0, problem_k = 0;
assignTaskEvenly(K, grid_idx, grid_dimx, offset_k, problem_k);
int32_t rem_k = problem_k % block_k > 0 ? problem_k % block_k : block_k;
int32_t k_loop = problem_k / block_k + (int32_t)(problem_k % block_k > 0);
int32_t cta_k = k_loop == 1 ? rem_k : block_k;
int32_t cta_m = M, offset_m = 0, cta_n = N, offset_n = 0;
int32_t warp_m = cta_m, warp_offset_m = 0;
int32_t warp_n = cta_n, warp_offset_n = 0;
int32_t outer_loop = 0, outer_rem = 0;
if (EXCHANGE_AB) {
assignTaskEvenly(N, grid_idy, clusterDim / grid_dimx, offset_n, cta_n);
assignTaskEvenly(block, coreId, coreDim, warp_offset_n, warp_n);
if (cta_n > block && cta_n % block != 0) {
int32_t block_tmp = PAD_UP((cta_n + cta_n / block) / (cta_n / block + 1), coreDim * LT_NUM);
if (block_tmp < block) block = block_tmp;
}
outer_loop = (cta_n + block - 1) / block;
outer_rem = cta_n % block == 0 ? block : cta_n % block;
} else {
assignTaskEvenly(M, grid_idy, clusterDim / grid_dimx, offset_m, cta_m);
assignTaskEvenly(block, coreId, coreDim, warp_offset_m, warp_m);
if (cta_m > block && cta_m % block != 0) {
int32_t block_tmp = PAD_UP((cta_m + cta_m / block) / (cta_m / block + 1), coreDim);
if (block_tmp < block) block = block_tmp;
}
outer_loop = (cta_m + block - 1) / block;
outer_rem = cta_m % block == 0 ? block : cta_m % block;
}
int32_t size_nram_buf =
NRAM_BUFFER_SIZE - warp_m * warp_n * sizeof(Tcc) * (1 + int32_t(EXCHANGE_AB));
int32_t pong_a_nram = size_nram_buf / 2 / sizeof(Tac);
Tac *nbuf_a = (Tac *)nram_buffer;
Tcc *nbuf_c = (Tcc *)(nram_buffer + size_nram_buf);
Tcc *nbuf_out = EXCHANGE_AB ? (Tcc *)nbuf_c + warp_m * warp_n : nbuf_c;
int32_t size_sram_buf = SRAM_BUFFER_SIZE;
int32_t pong_sram_a = size_sram_buf / 2 / sizeof(Ta);
int32_t pong_sram_b = size_sram_buf / 2 / sizeof(Tb);
Ta *sbuf_a = (Ta *)sram_buffer;
Tb *sbuf_b = (Tb *)((Ta *)sram_buffer + (EXCHANGE_AB ? M * block_k : block * block_k));
int32_t pong_b_wram = WRAM_LT_MAP16_STRIDE / 2 / sizeof(Tbc);
Tbc *wbuf_b = (Tbc *)wram_buffer;
int32_t a_dsize = sizeof(Ta);
int32_t b_dsize = sizeof(Tb);
int32_t k_loop_count = 0;
for (int32_t j = 0; j < outer_loop; j++) {
Ta *a_ddr = (Ta *)A + offset_k + ((size_t)offset_m + j * block) * lda * int(!EXCHANGE_AB);
Tb *b_ddr = (Tb *)B + offset_k + ((size_t)offset_n + j * block) * ldb * int(EXCHANGE_AB);
int32_t current_block = j == outer_loop - 1 ? outer_rem : block;
if (EXCHANGE_AB) {
assignTaskEvenly(current_block, coreId, coreDim, warp_offset_n, warp_n);
} else {
assignTaskEvenly(current_block, coreId, coreDim, warp_offset_m, warp_m);
}
int32_t compute_total = warp_m * warp_n;
if (compute_total > 0 && __is_ipu()) {
if (!EXCHANGE_AB) {
__sync_io_move_compute(true, false, false, false, false, true);
}
__bang_write_zero((Tcc *)nbuf_c, compute_total);
}
int32_t i = 0;
for (; i < k_loop; i++) {
Ta *sram_a = (Ta *)sbuf_a + k_loop_count % 2 * pong_sram_a;
Tb *sram_b = (Tb *)sbuf_b + k_loop_count % 2 * pong_sram_b;
cta_k = i == k_loop - 1 ? rem_k : block_k;
if (__is_mpu()) {
if (EXCHANGE_AB) {
__memcpy_async(sram_b, b_ddr, cta_k * b_dsize, GDRAM2SRAM, cta_k * b_dsize, ldb * b_dsize,
current_block - 1);
__asm__ volatile(
"ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], "
"%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_a),
[src] "r"(a_ddr), [size] "r"(cta_k * a_dsize), [dst_strd] "r"(cta_k * a_dsize),
[src_strd] "r"(lda * a_dsize), [segnum] "r"(M - 1));
} else {
__memcpy_async(sram_a, a_ddr, cta_k * a_dsize, GDRAM2SRAM, cta_k * a_dsize, lda * a_dsize,
current_block - 1);
__asm__ volatile(
"ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], "
"%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_b),
[src] "r"(b_ddr), [size] "r"(cta_k * b_dsize), [dst_strd] "r"(cta_k * b_dsize),
[src_strd] "r"(ldb * b_dsize), [segnum] "r"(N - 1));
}
a_ddr = (Ta *)a_ddr + block_k;
b_ddr = (Tb *)b_ddr + block_k;
}
bidirectionBarrierOp();
if (__is_ipu() && compute_total > 0) {
__sync_io_move_compute(false, true, false, false, false, true);
__sync_io_move_compute(false, false, true, false, true, false);
if (i >= 1) {
__wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram,
(Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, block_k);
}
warp_prompt_input((Tac *)nbuf_a + k_loop_count % 2 * pong_a_nram,
sram_a + cta_k * warp_offset_m, warp_m * cta_k * sizeof(Ta));
// mvdma bound for EXCHANGE_AB when n==32
warp_prompt_weight((Tbc *)wbuf_b + k_loop_count % 2 * pong_b_wram,
(Tb *)sram_b + cta_k * warp_offset_n, warp_n, cta_k, cta_k);
}
k_loop_count += 1;
}
if (compute_total > 0) {
__sync_io_move_compute(false, true, false, false, false, true);
__wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram,
(Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, rem_k);
if (EXCHANGE_AB) {
__sync_io_move_compute(true, false, false, false, false, true);
__bang_transpose((Tcc *)nbuf_out, (Tcc *)nbuf_c, warp_m, warp_n);
}
int32_t total_offset =
grid_idx * M * N + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * M
: (offset_m + warp_offset_m + block * j) * N);
Tcc *wks = (Tcc *)workspace + total_offset;
int32_t store_c_size = sizeof(Tcc);
int8_t *store_ddr = (int8_t *)wks;
int32_t dst_str = EXCHANGE_AB ? M : N;
if (grid_dimx == 1) {
// convert Tcc to Tc
dst_str = ldc;
store_ddr =
(int8_t *)((Tc *)C + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * ldc
: (offset_m + warp_offset_m + block * j) * ldc));
}
__asm__ volatile("sync.psimd.cio;\n\t");
if (EXCHANGE_AB) {
warp_store(store_ddr, (Tcc *)nbuf_out, warp_m, dst_str, warp_m, warp_n, store_c_size);
} else {
warp_store(store_ddr, (Tcc *)nbuf_out, warp_n, dst_str, warp_n, warp_m, store_c_size);
}
}
}
if (grid_dimx != 1) {
__sync_all();
splitKReduce((Tcc *)workspace, (Tc *)C, EXCHANGE_AB ? N : M, EXCHANGE_AB ? M : N,
split_info.split_k_num, ldc);
}
#endif // __BANG_ARCH__ >= 500
}
} // namespace kernels
int32_t getBlock(int32_t m,
int32_t n,
int32_t core_num,
int32_t block_k,
int32_t a_dtype_size,
int32_t b_dtype_size,
int32_t compute_dtype_size,
bool EXCHANGE_AB) {
int32_t block = 0;
if (EXCHANGE_AB) {
int32_t block_m = n;
int32_t nram_block_n = (NRAM_BUFFER_SIZE - block_m * block_k * compute_dtype_size * 2) /
(2 * block_m * compute_dtype_size) * core_num;
int32_t wram_block_n =
WRAM_BUFFER_SIZE / 2 / PAD_UP(block_k * compute_dtype_size, 64) * core_num;
int32_t sram_block_n =
(SRAM_BUFFER_SIZE - block_m * block_k * a_dtype_size * 2) / (block_k * b_dtype_size * 2);
int32_t block_n_tmp = std::min(std::min(nram_block_n, wram_block_n), sram_block_n);
int32_t block_n = PAD_DOWN(block_n_tmp, core_num * LT_NUM);
return block_n > 0 ? block_n : block_n_tmp;
} else {
int32_t block_n = n;
int32_t nram_block_m =
NRAM_BUFFER_SIZE / (block_n * compute_dtype_size + block_k * compute_dtype_size * 2);
int32_t sram_block_m =
(SRAM_BUFFER_SIZE - block_n * block_k * b_dtype_size * 2) / (block_k * a_dtype_size * 2);
block = std::min(nram_block_m * core_num, PAD_DOWN(sram_block_m, core_num));
return block;
}
}
void gatingTiling(int32_t m,
int32_t n,
int32_t k,
size_t a_dtype_size,
size_t b_dtype_size,
size_t compute_dtype_size,
size_t workspace_size,
int32_t union_number,
int32_t core_num,
int32_t &block,
int32_t &split_k_num,
int32_t &block_k,
bool &EXCHANGE_AB) {
block_k = std::min(k, int32_t(512 / a_dtype_size));
split_k_num = 1;
// swap A and B to reduce computing waste caused by LT_NUM-align of co dimensian
if (m >= core_num * LT_NUM &&
float(m) / float(PAD_UP((size_t)m, LT_NUM)) > float(n) / float(PAD_UP(n, LT_NUM))) {
EXCHANGE_AB = true;
}
int32_t tmp_block = getBlock(m, n, core_num, block_k, a_dtype_size, b_dtype_size,
compute_dtype_size, EXCHANGE_AB);
int32_t total_blocks = DIV_UP((size_t)m, tmp_block);
block = tmp_block;
if (total_blocks < union_number && (size_t)k * a_dtype_size > 512 * union_number) {
for (int32_t i = total_blocks; i <= union_number; i++) {
if (union_number % i == 0) {
int32_t tmp_split_k = union_number / i;
size_t workspace_size_need = (size_t)tmp_split_k * m * n * compute_dtype_size;
if (workspace_size >= workspace_size_need) {
split_k_num = tmp_split_k;
block = std::min(((size_t)m + total_blocks - 1) / total_blocks, (size_t)tmp_block);
if (EXCHANGE_AB && block > LT_NUM * core_num) {
block = PAD_DOWN(block, LT_NUM * core_num);
}
break;
}
}
}
}
}
void getContxtInfo(int32_t *union_number, int32_t *core_num) {
CNdev dev;
cnCtxGetDevice(&dev);
CNRT_CHECK(cnrtDeviceGetAttribute(union_number, cnrtAttrMaxClusterPerUnionLimitTask, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(core_num, cnrtAttrMcorePerCluster, dev));
}
KernelStatus invokeCastGating(cnrtQueue_t queue,
void *input,
void *filter,
void *output,
int input_row,
int expert_num,
int hidden_size,
cnnlDataType_t a_dtype,
void *workspace,
size_t workspace_size_bytes) {
if (is_arch300()) {
std::cerr << "[invokeCastGating]: kernel does not support MLU300 devices." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (expert_num > 128) {
std::cerr << "[invokeCastGating]: expert_num should NOT be greater than 128." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (workspace != NULL && workspace_size_bytes < 16 * 1024 * 1024) {
std::cerr
<< "[invokeCastGating]: workspace_size_bytes should NOT be smaller than 16 * 1024 * 1024."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (workspace_size_bytes > 0 && workspace == NULL) {
std::cerr << "[invokeCastGating]: workspace should NOT be NULL when workspace_size_bytes is "
"greater than 0."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
int32_t union_number, core_num;
getContxtInfo(&union_number, &core_num);
cnrtFunctionType_t func_type = cnrtFunctionType_t(union_number * core_num);
cnrtDim3_t dim;
dim.x = (int32_t)func_type;
dim.y = 1;
dim.z = 1;
cnnlDataType_t b_dtype = CNNL_DTYPE_FLOAT;
cnnlDataType_t compute_dtype = CNNL_DTYPE_FLOAT;
size_t a_dtype_size = 0, b_dtype_size = 0, compute_dtype_size = 0;
cnnlGetSizeOfDataType(a_dtype, &a_dtype_size);
cnnlGetSizeOfDataType(b_dtype, &b_dtype_size);
cnnlGetSizeOfDataType(compute_dtype, &compute_dtype_size);
castGatingTileInfo split_info;
bool EXCHANGE_AB = false;
gatingTiling(input_row, expert_num, hidden_size, a_dtype_size, b_dtype_size, compute_dtype_size,
workspace_size_bytes, union_number, core_num, split_info.block,
split_info.split_k_num, split_info.block_k, EXCHANGE_AB);
if (a_dtype == CNNL_DTYPE_BFLOAT16) {
if (EXCHANGE_AB) {
kernels::MLUCastGating<float, float, bfloat16_t, float, float, float, true>
<<<dim, func_type, queue>>>((float *)filter, (bfloat16_t *)input, (float *)output,
(float *)workspace, expert_num, input_row, hidden_size,
hidden_size, hidden_size, expert_num, split_info);
} else {
kernels::MLUCastGating<bfloat16_t, float, float, float, float, float, false>
<<<dim, func_type, queue>>>((bfloat16_t *)input, (float *)filter, (float *)output,
(float *)workspace, input_row, expert_num, hidden_size,
hidden_size, hidden_size, expert_num, split_info);
}
} else if (a_dtype == CNNL_DTYPE_HALF) {
if (EXCHANGE_AB) {
kernels::MLUCastGating<float, float, half, float, float, float, true>
<<<dim, func_type, queue>>>((float *)filter, (half *)input, (float *)output,
(float *)workspace, expert_num, input_row, hidden_size,
hidden_size, hidden_size, expert_num, split_info);
} else {
kernels::MLUCastGating<half, float, float, float, float, float, false>
<<<dim, func_type, queue>>>((half *)input, (float *)filter, (float *)output,
(float *)workspace, input_row, expert_num, hidden_size,
hidden_size, hidden_size, expert_num, split_info);
}
} else {
std::cerr << "[invokeCastGating]: kernel does not support this data-type." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,50 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_CAST_GATING_MLUH_
#define CSRC_KERNELS_CAST_GATING_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Convert input to float32 and do gating operation.
* @param queue: The queue for mlu.
* @param input: Input. Pointer to the MLU memory that stores the input,
* the shape must be [input_row, hidden_size].
* @param filter: Input. Pointer to the MLU memory that stores the weight,
* the shape must be [expert_num, hidden_size].
* @param output: Output. Pointer to the MLU memory that stores the output,
* the shape must be [input_row, expert_num].
* @param input_row: Input.
* @param expert_num: Input.
* @param hidden_size: Input.
* @param a_dtype: Input. The data-type of input.
* @param workspace: Input. Pointer to the MLU workspace.
* @param workspace_size_bytes: Input. The size of workspace in bytes.
* @note: a_dtype must be CNNL_DTYPE_BFLOAT16 or CNNL_DTYPE_HALF.
* expert_num must be in range [1, 128].
* If workspace is NOT NULL, workspace_size_bytes must NOT be smaller than 16 * 1024 * 1024.
* The data-type of filter and output must be float.
* cast_gating only supports MLU500 device or higher.
*/
KernelStatus invokeCastGating(cnrtQueue_t queue,
void *input,
void *filter,
void *output,
int input_row,
int expert_num,
int hidden_size,
cnnlDataType_t a_dtype,
void *workspace,
size_t workspace_size_bytes);
} // namespace tmo
#endif // CSRC_KERNELS_CAST_GATING_MLUH_

View File

@@ -0,0 +1,760 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <iostream>
#include "cnrt.h"
#include "combine_result.mluh"
// clang-format off
#include <bang_device_functions_extra.h>
#include <mlu.h>
// clang-format on
#if __BANG_ARCH__ >= 592
#include <bang_fusor.h>
template <typename SrcT>
using bang_fusor = bang::experimental::fusor<SrcT>;
#endif
namespace tmo {
namespace kernels {
#define NRAM_REMAIN_SIZE (32 * 1024)
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE)
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
template <typename T>
__mlu_func__ void swap(T *&ping, T *&pong) {
T *temp = ping;
ping = pong;
pong = temp;
}
#define GATHER_ASYNC_IO0(offset_type) \
__asm__ __volatile__( \
"gather.vector.async.nram.gdram.nram." #offset_type \
".io0 [%[dst]], [%[src]], [%[offset]], " \
"%[transfer_size], %[transfer_num], %[stride];\n\t" ::[dst] "r"(dst), \
[src] "r"(src_gdram), [offset] "r"(nram_offset), [transfer_size] "r"(transfer_size), \
[transfer_num] "r"(token_count), [stride] "r"(transfer_size))
#define FUSE_MUL_CVT(dst_dtype) \
__asm__ __volatile__("mult.scalar.nram.crn." #dst_dtype \
".f32 [%[dst]], [%[src0]], %[src1]," \
" %[size];\n\t" ::[dst] "r"(dst), \
[src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size));
#define FUSE_MULADD_CVT(dst_dtype) \
__asm__ __volatile__("muladd.nram.crn." #dst_dtype \
".f32 [%[dst]], [%[src0]], %[src1], [%[dst]]," \
" %[size], %[size];\n\t" ::[dst] "r"(dst), \
[src0] "r"(nram_input_buffer), [src1] "r"(expert_coeff), [size] "r"(size));
template <typename T>
__mlu_func__ void toFloat(float *dst, T *src, int count) {
if (std::is_same<T, half>::value) {
__bang_half2float(dst, (half *)src, count);
} else if (std::is_same<T, bfloat16_t>::value) {
__bang_bfloat162float(dst, (bfloat16_t *)src, count);
} else if (std::is_same<T, float>::value) {
__bang_add_scalar((float *)dst, (float *)src, (float)0, count);
}
}
template <typename T>
__mlu_func__ void floatTo(T *dst, float *src, int count) {
if (std::is_same<T, half>::value) {
__bang_float2half_rn((half *)dst, src, count);
} else if (std::is_same<T, bfloat16_t>::value) {
__bang_float2bfloat16_rn((bfloat16_t *)dst, src, count);
} else if (std::is_same<T, float>::value) {
__bang_add_scalar((float *)dst, (float *)src, (float)0, count);
}
}
__mlu_func__ void loadAsync2d(void *dst,
void *src,
int size,
int dststride,
int srcstride,
int seg_num) {
#if __BANG_ARCH__ > 500
__asm__ __volatile__(
"ld.async.stride.nram.gdram.io0 [%[dst]], [%[src]],"
" %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst),
[src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride),
[segnum] "r"(seg_num));
#else
__memcpy_async(dst, src, size, GDRAM2NRAM, dststride, srcstride, seg_num);
#endif
}
__mlu_func__ void storeAsync2d(void *dst,
void *src,
int size,
int dststride,
int srcstride,
int seg_num) {
#if __BANG_ARCH__ > 500
__asm__ __volatile__(
"st.async.stride.gdram.nram.io1 [%[dst]], [%[src]],"
" %[size], %[dststride], %[srcstride], %[segnum];\n\t" ::[dst] "r"(dst),
[src] "r"(src), [size] "r"(size), [dststride] "r"(dststride), [srcstride] "r"(srcstride),
[segnum] "r"(seg_num));
#else
__memcpy_async(dst, src, size, NRAM2GDRAM, dststride, srcstride, seg_num);
#endif
}
template <typename T_IDX>
__mlu_func__ void gatherTokensAsync(void *dst,
void *src_gdram,
T_IDX *nram_offset,
int transfer_size,
int token_count) {
if (token_count <= 0 || src_gdram == nullptr) return;
#if __BANG_ARCH__ > 500
if (std::is_same<T_IDX, uint32_t>::value) {
GATHER_ASYNC_IO0(u32);
} else {
GATHER_ASYNC_IO0(u64);
}
#else
for (int k = 0; k < token_count; k++) {
__memcpy_async((int8_t *)dst + k * transfer_size,
(int8_t *)src_gdram + __load_nram(nram_offset + k), transfer_size, GDRAM2NRAM);
}
#endif
}
__mlu_func__ int getMaskAndActiveTokenCount(int *nram_token_idx,
int *nram_mask,
uint8_t *nram_mask_char,
int *nram_mask_buffer,
int begin_expert_acc_tokens,
int end_expert_acc_tokens,
int token_count,
bool expert_parallelism) {
if (!expert_parallelism) {
return token_count;
}
__bang_lt_scalar(nram_mask_buffer, nram_token_idx, end_expert_acc_tokens, token_count);
#if __BANG_ARCH__ >= 592
bang_fusor<int32_t>(nram_mask, nram_token_idx, token_count)
.ge(begin_expert_acc_tokens)
.land(nram_mask_buffer)
.cvt<float>(0);
#else
__bang_ge_scalar(nram_mask, nram_token_idx, begin_expert_acc_tokens, token_count);
__bang_and(nram_mask, nram_mask, nram_mask_buffer, token_count);
__bang_int322float((float *)nram_mask, (int *)nram_mask, token_count, 0);
#endif
__bang_filter((float *)nram_token_idx, (float *)nram_token_idx, (float *)nram_mask, token_count);
int active_token_count = __bang_count((float *)nram_mask, token_count);
return active_token_count;
}
__mlu_func__ void computeOffset0(uint64_t *nram_offset,
int *nram_idx,
uint64_t mul_scalar,
int64_t add_scalar,
uint32_t token_count) {
#if __BANG_ARCH__ > 592
__bang_int322int64((int64_t *)nram_offset, nram_idx, token_count, 0, 0);
#else
__bang_int322int64((int64_t *)nram_offset, nram_idx, token_count);
#endif
__bang_mul_scalar(nram_offset, nram_offset, mul_scalar, token_count);
__bang_add_scalar((int64_t *)nram_offset, (int64_t *)nram_offset, add_scalar, token_count);
}
__mlu_func__ void computeOffset0(uint32_t *nram_offset,
int *nram_idx,
uint32_t mul_scalar,
int64_t add_scalar,
uint32_t token_count) {
__bang_fusion(FUSION_FMA, nram_offset, (uint32_t *)nram_idx, mul_scalar, (int32_t)add_scalar,
token_count);
}
template <typename T_IDX>
__mlu_func__ void computeOffset(T_IDX *nram_token_offset,
T_IDX *nram_bias_offset,
int *nram_token_idx,
int *nram_expert_tables,
int expert_num,
int token_count,
int active_token_count,
int hidden_size,
int local_hidden_begin,
int dtype_size,
int start_expert_id,
int expert_size,
int begin_expert_acc_tokens,
bool has_bias) {
// for large tensor, convert int322int64 then do multiply and add seperately.
if (active_token_count <= 0) return;
if (has_bias) {
int *nram_bias_offset_temp = (int *)nram_token_offset;
__bang_write_zero(nram_bias_offset, active_token_count);
for (int i = start_expert_id + 1; i < start_expert_id + expert_size; i++) {
__bang_ge_scalar(nram_bias_offset_temp, nram_token_idx, nram_expert_tables[i],
active_token_count);
__bang_add((int *)nram_bias_offset, (int *)nram_bias_offset, nram_bias_offset_temp,
active_token_count);
}
__bang_add_scalar(nram_bias_offset_temp, (int *)nram_bias_offset, 0, active_token_count);
computeOffset0(nram_bias_offset, nram_bias_offset_temp, (T_IDX)hidden_size * dtype_size,
(T_IDX)local_hidden_begin * dtype_size, active_token_count);
}
int64_t offset =
((int64_t)local_hidden_begin - (int64_t)begin_expert_acc_tokens * hidden_size) * dtype_size;
computeOffset0(nram_token_offset, nram_token_idx, (T_IDX)(hidden_size * dtype_size), offset,
active_token_count);
}
template <typename T>
__mlu_func__ void mulScalarCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) {
#if __BANG_ARCH__ > 500
if (std::is_same<T, bfloat16_t>::value) {
FUSE_MUL_CVT(bf16);
} else if (std::is_same<T, half>::value) {
FUSE_MUL_CVT(f16);
} else if (std::is_same<T, float>::value) {
__bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size);
}
#else
__bang_mul_scalar((float *)dst, nram_input_buffer, expert_coeff, size);
floatTo((T *)dst, (float *)dst, size);
#endif
}
template <typename T>
__mlu_func__ void mulAddCvt(T *dst, float *nram_input_buffer, float expert_coeff, int size) {
#if __BANG_ARCH__ > 500
if (std::is_same<T, bfloat16_t>::value) {
FUSE_MULADD_CVT(bf16);
} else if (std::is_same<T, half>::value) {
FUSE_MULADD_CVT(f16);
} else if (std::is_same<T, float>::value) {
__bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size,
size);
}
#else
__bang_fusion(FUSION_FMA, (float *)dst, nram_input_buffer, expert_coeff, (float *)dst, size,
size);
floatTo((T *)dst, (float *)dst, size);
#endif
}
// weightedReduceSum with EP split
// input [token_count, k, hidden_size], weight [token_count, k]
// 1. input * weight
// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add
template <typename T,
bool expert_parallelism,
typename std::enable_if<expert_parallelism == true, void *>::type = nullptr>
__mlu_func__ void weightedReduceSum(T *output,
T *input,
float *weight,
T *input_buffer,
int8_t *og_mask,
int topk,
int hidden_size,
int token_count,
bool &is_ping) {
float *nram_input_buffer =
(float *)((half *)input_buffer +
((std::is_same<T, float>::value || !is_ping) ? 0 : hidden_size));
T *output_base = output - ((std::is_same<T, float>::value || is_ping) ? 0 : hidden_size);
int32_t index[32];
float reg_weight[128];
int8_t *index_ = (int8_t *)index;
int topk_divide_4 = PAD_UP(topk, 4) / 4;
int token_use_count = 0;
for (int t_i = 0; t_i < token_count; t_i++) {
float *output_begin = (float *)(output_base + t_i * hidden_size);
for (int i = 0; i < topk_divide_4; i++) {
index[i] = __load_nram((int32_t *)(og_mask + t_i * topk) + i);
float *weight_begin = weight + t_i * topk + i * 4;
reg_weight[i * 4] = __load_nram(weight_begin);
if (i * 4 + 1 < topk) {
reg_weight[i * 4 + 1] = __load_nram(weight_begin + 1);
}
if (i * 4 + 2 < topk) {
reg_weight[i * 4 + 2] = __load_nram(weight_begin + 2);
}
if (i * 4 + 3 < topk) {
reg_weight[i * 4 + 3] = __load_nram(weight_begin + 3);
}
}
int first_in_expert = 0;
float expert_coeff;
for (; first_in_expert < topk - 1; first_in_expert++) {
bool in_expert_range = index_[first_in_expert];
if (!in_expert_range) continue;
expert_coeff = reg_weight[first_in_expert];
toFloat<T>(output_begin, input + token_use_count * hidden_size, hidden_size);
__bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size);
token_use_count++;
break;
}
if (first_in_expert == topk - 1) {
if (index_[topk - 1]) {
expert_coeff = reg_weight[topk - 1];
toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
token_use_count++;
mulScalarCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
} else {
__bang_write_zero((T *)output_begin, hidden_size);
}
} else {
for (int j = first_in_expert + 1; j < topk - 1; j++) {
bool in_expert_range = index_[j];
if (!in_expert_range) continue;
expert_coeff = reg_weight[j];
toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
token_use_count++;
__bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin,
hidden_size, hidden_size);
}
if (index_[topk - 1]) {
expert_coeff = reg_weight[topk - 1];
toFloat<T>(nram_input_buffer, input + token_use_count * hidden_size, hidden_size);
token_use_count++;
mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
} else {
floatTo((T *)output_begin, (float *)output_begin, hidden_size);
}
}
}
if (!is_ping && sizeof(T) < sizeof(float)) {
__memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size,
hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1,
token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1,
token_count * hidden_size * sizeof(T), 0);
}
is_ping = !is_ping;
}
// weightedReduceSum without EP split
// input [token_count, k, hidden_size], weight [token_count, k]
// 1. input * weight
// 2. reduce: [token_count, k, hidden_size] --> [token_count, hidden_size], reduce mode add
template <typename T,
bool expert_parallelism,
typename std::enable_if<expert_parallelism == false, void *>::type = nullptr>
__mlu_func__ void weightedReduceSum(T *output,
T *input,
float *weight,
T *input_buffer,
int8_t *og_mask,
int topk,
int hidden_size,
int token_count,
bool &is_ping) {
float *nram_input_buffer =
(float *)((half *)input_buffer +
((std::is_same<T, float>::value || !is_ping) ? 0 : hidden_size));
T *output_base = output - ((std::is_same<T, float>::value || is_ping) ? 0 : hidden_size);
if (topk == 1) {
for (int i = 0; i < token_count; i++) {
float expert_coeff = __load_nram(weight + i);
toFloat<T>(nram_input_buffer, input + i * hidden_size, hidden_size);
mulScalarCvt(output + i * hidden_size, nram_input_buffer, expert_coeff, hidden_size);
}
return;
}
for (int t_i = 0; t_i < token_count; t_i++) {
float *output_begin = (float *)(output_base + t_i * hidden_size);
float expert_coeff = __load_nram(weight + t_i * topk);
toFloat<T>(output_begin, input + t_i * topk * hidden_size, hidden_size);
toFloat<T>(nram_input_buffer, input + (t_i * topk + 1) * hidden_size, hidden_size);
__bang_mul_scalar(output_begin, output_begin, expert_coeff, hidden_size);
expert_coeff = __load_nram(weight + t_i * topk + 1);
for (int k_i = 2; k_i < topk; k_i++) {
__bang_fusion(FUSION_FMA, output_begin, nram_input_buffer, expert_coeff, output_begin,
hidden_size, hidden_size);
expert_coeff = __load_nram(weight + t_i * topk + k_i);
toFloat<T>(nram_input_buffer, input + (t_i * topk + k_i) * hidden_size, hidden_size);
}
mulAddCvt((T *)output_begin, nram_input_buffer, expert_coeff, hidden_size);
}
if (!is_ping && sizeof(T) < sizeof(float)) {
__memcpy(output + (token_count - 1) * hidden_size, output + (token_count - 2) * hidden_size,
hidden_size * sizeof(T), NRAM2NRAM, -hidden_size * sizeof(T), token_count - 1,
token_count * hidden_size * sizeof(T), 0, -hidden_size * sizeof(T), token_count - 1,
token_count * hidden_size * sizeof(T), 0);
}
is_ping = !is_ping;
}
template <typename T, typename T_IDX>
__mlu_global__ void MLUCombineMoeResultKernel(T *output,
T *input,
T *bias,
T *residual,
float *reduce_weight,
int *cusum_token_count,
int *gather_idx,
int num_token,
int topk,
int num_expert,
int hidden_size,
int start_expert_id,
int expert_size,
int HIDDEN_BLOCK,
int TOKEN_BLOCK) {
if (__is_mpu()) {
return;
}
int local_hidden_begin = taskIdX * HIDDEN_BLOCK;
int local_hidden_size = std::min(HIDDEN_BLOCK, hidden_size - local_hidden_begin);
int task_avg_tokens = num_token / taskDimY;
int task_remain_tokens = num_token % taskDimY;
int task_tokens = task_avg_tokens + (int)(taskIdY < task_remain_tokens);
int task_token_begin = taskIdY * task_avg_tokens + std::min(taskIdY, task_remain_tokens);
if (local_hidden_size <= 0) return;
if (task_tokens <= 0) return;
constexpr int int32_dtype_size = (int)sizeof(int);
constexpr int fp32_dtype_size = (int)sizeof(float);
int pad_num_expert = PAD_UP(num_expert + 1, 32);
bool has_bias = bias != nullptr;
bool has_residual = residual != nullptr;
bool using_acc_sum = cusum_token_count != nullptr;
bool expert_parallelism = expert_size < num_expert;
int block_size = TOKEN_BLOCK * topk;
int pad_block_size = PAD_UP(block_size, 64);
int *nram_expert_tables = (int *)nram_buffer;
int *nram_token_idx = nram_expert_tables + pad_num_expert;
T_IDX *nram_token_offset = (T_IDX *)(nram_token_idx + pad_block_size);
T_IDX *nram_bias_offset = (T_IDX *)(nram_token_offset + pad_block_size);
int *nram_mask = (int *)(nram_bias_offset + (int)has_bias * pad_block_size);
T *nram_input_ping = (T *)(nram_mask + pad_block_size);
T *nram_input_pong = nram_input_ping + block_size * HIDDEN_BLOCK;
T *nram_bias_ping = nram_input_pong + block_size * HIDDEN_BLOCK;
T *nram_bias_pong = nram_bias_ping + (int)has_bias * block_size * HIDDEN_BLOCK;
T *nram_residual_ping = nram_bias_pong + (int)has_bias * block_size * HIDDEN_BLOCK;
T *nram_residual_pong = nram_residual_ping + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK;
float *nram_weight_ping =
(float *)(nram_residual_pong + (int)has_residual * TOKEN_BLOCK * HIDDEN_BLOCK);
float *nram_weight_pong = nram_weight_ping + pad_block_size;
int buffer_block_num = sizeof(T) > 2 ? 2 : 3;
T *nram_output_ping = (T *)(nram_weight_pong + pad_block_size);
T *nram_input_buffer = nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK;
T *nram_output_pong = (T *)((char *)nram_output_ping + TOKEN_BLOCK * HIDDEN_BLOCK * sizeof(T) +
buffer_block_num * HIDDEN_BLOCK * sizeof(half));
int *nram_mask_buffer = (int *)nram_token_offset;
uint8_t *nram_mask_char = (uint8_t *)(nram_output_pong + TOKEN_BLOCK * HIDDEN_BLOCK);
int init_token_count = std::min(TOKEN_BLOCK, task_tokens) * topk;
int begin_expert_acc_tokens = 0;
int end_expert_acc_tokens = num_token * topk;
if (using_acc_sum) {
__memcpy_async(nram_expert_tables, cusum_token_count, (num_expert + 1) * int32_dtype_size,
GDRAM2NRAM);
}
__memcpy_async(nram_token_idx, gather_idx + task_token_begin * topk,
init_token_count * sizeof(int), GDRAM2NRAM);
__sync_io();
if (expert_parallelism) {
begin_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id);
end_expert_acc_tokens = __load_nram(nram_expert_tables + start_expert_id + expert_size);
}
int active_token_count = getMaskAndActiveTokenCount(
nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens,
end_expert_acc_tokens, init_token_count, expert_parallelism);
computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables, num_expert,
init_token_count, active_token_count, hidden_size, local_hidden_begin,
(int)sizeof(T), start_expert_id, expert_size, begin_expert_acc_tokens, has_bias);
__sync_io_move_compute(true, false, false, false, false, true);
__sync_io_move_compute(false, false, true, true, false, false);
int next_active_token_count = active_token_count;
int previous_global_token_begin = 0;
int previous_token_count = 0;
bool is_ping = false;
for (int task_begin = -1; task_begin * TOKEN_BLOCK < task_tokens; task_begin++) {
int next_token_begin = (task_begin + 1) * TOKEN_BLOCK;
int next_next_token_begin = (task_begin + 2) * TOKEN_BLOCK;
bool is_last_loop = next_token_begin >= task_tokens;
bool is_last_2_loop = next_next_token_begin >= task_tokens;
int current_token_begin = task_begin * TOKEN_BLOCK;
int current_token_count = std::min(TOKEN_BLOCK, task_tokens - current_token_begin);
int next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_token_begin);
int next_next_token_count = std::min(TOKEN_BLOCK, task_tokens - next_next_token_begin);
int current_global_token_begin = task_token_begin + current_token_begin;
int next_global_token_begin = task_token_begin + next_token_begin;
int next_next_global_token_begin = task_token_begin + next_next_token_begin;
if (!is_last_loop) {
if (!is_last_2_loop) {
loadAsync2d(nram_token_idx, gather_idx + next_next_global_token_begin * topk,
next_next_token_count * topk * sizeof(int), 0, 0, 0);
}
loadAsync2d(nram_weight_ping, reduce_weight + next_global_token_begin * topk,
next_token_count * topk * fp32_dtype_size, 0, 0, 0);
if (has_residual) {
loadAsync2d(nram_residual_ping,
residual + next_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
local_hidden_size * sizeof(T), local_hidden_size * sizeof(T),
hidden_size * sizeof(T), next_token_count - 1);
}
gatherTokensAsync<T_IDX>(nram_input_ping, input, nram_token_offset,
local_hidden_size * sizeof(T), next_active_token_count);
gatherTokensAsync<T_IDX>(nram_bias_ping, bias, nram_bias_offset,
local_hidden_size * sizeof(T), next_active_token_count);
}
if (task_begin >= 1) {
storeAsync2d(
output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T),
local_hidden_size * sizeof(T), previous_token_count - 1);
}
if (task_begin >= 0) {
if (has_bias && active_token_count) {
__bang_add(nram_input_pong, nram_input_pong, nram_bias_pong,
active_token_count * local_hidden_size);
}
if (expert_parallelism) {
weightedReduceSum<T, true>(nram_output_ping, nram_input_pong, nram_weight_pong,
nram_input_buffer, (int8_t *)nram_mask_char, topk,
local_hidden_size, current_token_count, is_ping);
} else {
weightedReduceSum<T, false>(nram_output_ping, nram_input_pong, nram_weight_pong,
nram_input_buffer, (int8_t *)nram_mask_char, topk,
local_hidden_size, current_token_count, is_ping);
}
if (has_residual) {
__bang_add((T *)nram_output_ping, (T *)nram_output_ping, nram_residual_pong,
current_token_count * local_hidden_size);
}
}
__sync_io_move_compute();
active_token_count = next_active_token_count;
if (expert_parallelism && !is_last_loop) {
__bang_float2uchar_tz((uint8_t *)nram_mask_char, (float *)nram_mask, next_token_count * topk);
}
if (!is_last_2_loop) {
next_active_token_count = getMaskAndActiveTokenCount(
nram_token_idx, nram_mask, nram_mask_char, nram_mask_buffer, begin_expert_acc_tokens,
end_expert_acc_tokens, next_next_token_count * topk, expert_parallelism);
computeOffset(nram_token_offset, nram_bias_offset, nram_token_idx, nram_expert_tables,
num_expert, next_next_token_count * topk, next_active_token_count, hidden_size,
local_hidden_begin, (int)sizeof(T), start_expert_id, expert_size,
begin_expert_acc_tokens, has_bias);
}
swap(nram_input_ping, nram_input_pong);
swap(nram_bias_ping, nram_bias_pong);
swap(nram_residual_ping, nram_residual_pong);
swap(nram_weight_ping, nram_weight_pong);
swap(nram_output_ping, nram_output_pong);
previous_global_token_begin = current_global_token_begin;
previous_token_count = current_token_count;
}
storeAsync2d(output + previous_global_token_begin * (uint64_t)hidden_size + local_hidden_begin,
nram_output_pong, local_hidden_size * sizeof(T), hidden_size * sizeof(T),
local_hidden_size * sizeof(T), previous_token_count - 1);
}
#if __BANG_ARCH__ < 500
template <>
__mlu_global__ void MLUCombineMoeResultKernel<bfloat16_t, uint32_t>(bfloat16_t *output,
bfloat16_t *input,
bfloat16_t *bias,
bfloat16_t *residual,
float *reduce_weight,
int *cusum_token_count,
int *gather_ids,
int num_token,
int topk,
int num_expert,
int hidden_size,
int start_expert_id,
int expert_size,
int HIDDEN_BLOCK,
int TOKEN_BLOCK) {}
template <>
__mlu_global__ void MLUCombineMoeResultKernel<bfloat16_t, uint64_t>(bfloat16_t *output,
bfloat16_t *input,
bfloat16_t *bias,
bfloat16_t *residual,
float *reduce_weight,
int *cusum_token_count,
int *gather_ids,
int num_token,
int topk,
int num_expert,
int hidden_size,
int start_expert_id,
int expert_size,
int HIDDEN_BLOCK,
int TOKEN_BLOCK) {}
#endif
} // namespace kernels
KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue,
void *output,
const void *input,
const void *bias,
const void *residual,
const float *reduce_weight,
const int *cusum_token_count,
const int *gather_idx,
int num_token,
int topk,
int num_expert,
int hidden_size,
int start_expert_id,
int expert_size,
cnnlDataType_t dtype) {
if (topk > 128 || num_expert > 1024 || hidden_size < 256) {
std::cerr << "[invokeMoeCombineResultKernel]: "
<< "currently only support topk <= 128, num_expert <= 1024 and hidden_size >= 256.";
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (bias != nullptr) {
std::cerr << "[invokeMoeCombineResultKernel]: currently does not support bias.";
return KernelStatus::KERNEL_STATUS_FAILED;
}
if ((bias != nullptr || num_expert > expert_size) && cusum_token_count == nullptr) {
std::cerr << "[invokeMoeCombineResultKernel]: if has bias or expert parallelism, "
<< "cusum_token_count can not be nullptr.";
return KernelStatus::KERNEL_STATUS_FAILED;
}
size_t data_bytes = 0;
cnnlGetSizeOfDataType(dtype, &data_bytes);
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
// 480KB nram size, 48KB for token idx, token/bias offset and weight. 432KB for buffer.
// TOKEN_BLOCK * topk <= 1024 in case 32KB is enough for idx and offset.
int convert_buffer = data_bytes == 2
? 3 * hidden_size * data_bytes
: 2 * hidden_size * data_bytes; // buffer for convert bf16/fp16->fp32
int max_input_size = (432 * 1024 - convert_buffer) /
(2 * topk * data_bytes + /*input size, double buffer*/
(bias != nullptr) * 2 * topk * data_bytes + /*bias size, double buffer*/
(residual != nullptr) * 2 * data_bytes + /*residual size, double buffer*/
2 * data_bytes); /*output size, one buffer*/
int TOKEN_BLOCK = 1;
int HIDDEN_BLOCK = 1;
int HIDDEN_BLOCK_X_TOKEN_BLOCK = (max_input_size / 64) * 64;
if (HIDDEN_BLOCK_X_TOKEN_BLOCK < hidden_size) {
HIDDEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK;
TOKEN_BLOCK = 1;
} else {
HIDDEN_BLOCK = hidden_size;
}
// for latency case, hidden_size is large but token is small.
if (HIDDEN_BLOCK == hidden_size && hidden_size >= 4096 && num_token <= core_num * cluster_num) {
HIDDEN_BLOCK = (hidden_size + core_num - 1) / core_num;
}
HIDDEN_BLOCK = std::min(HIDDEN_BLOCK, 8 * 1024);
uint32_t task_dim_x = (hidden_size + HIDDEN_BLOCK - 1) / HIDDEN_BLOCK;
task_dim_x =
(task_dim_x < core_num) ? task_dim_x : ((task_dim_x + core_num - 1) / core_num * core_num);
uint32_t pad_dim_x = task_dim_x;
while (pad_dim_x <= cluster_num * core_num) {
if ((cluster_num * core_num % pad_dim_x == 0)) {
task_dim_x = pad_dim_x;
break;
}
pad_dim_x += core_num;
}
HIDDEN_BLOCK = (hidden_size + task_dim_x - 1) / task_dim_x;
HIDDEN_BLOCK = (HIDDEN_BLOCK + 63) / 64 * 64;
if (HIDDEN_BLOCK_X_TOKEN_BLOCK >= hidden_size) {
TOKEN_BLOCK = HIDDEN_BLOCK_X_TOKEN_BLOCK / HIDDEN_BLOCK;
}
TOKEN_BLOCK = std::min(TOKEN_BLOCK, 1024 / topk);
float max_cluster_num = core_num * cluster_num / task_dim_x;
uint32_t task_dim_y = std::min(max_cluster_num, num_token);
task_dim_y = task_dim_y < 1 ? 1 : task_dim_y;
cnrtDim3_t dim{.x = task_dim_x, .y = task_dim_y, .z = 1};
bool is_large_tensor = data_bytes * num_token * topk * hidden_size > UINT32_MAX;
if (dtype == CNNL_DTYPE_FLOAT) {
if (!is_large_tensor) {
kernels::MLUCombineMoeResultKernel<float, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
(float *)output, (float *)input, (float *)bias, (float *)residual, (float *)reduce_weight,
(int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
} else {
kernels::MLUCombineMoeResultKernel<float, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
(float *)output, (float *)input, (float *)bias, (float *)residual, (float *)reduce_weight,
(int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
}
} else if (dtype == CNNL_DTYPE_HALF) {
if (!is_large_tensor) {
kernels::MLUCombineMoeResultKernel<half, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
(half *)output, (half *)input, (half *)bias, (half *)residual, (float *)reduce_weight,
(int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
} else {
kernels::MLUCombineMoeResultKernel<half, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
(half *)output, (half *)input, (half *)bias, (half *)residual, (float *)reduce_weight,
(int *)cusum_token_count, (int *)gather_idx, num_token, topk, num_expert, hidden_size,
start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
}
} else if (dtype == CNNL_DTYPE_BFLOAT16) {
if (!isBf16Supported()) {
std::cerr << "[invokeMoeCombineResultKernel]: MLU300 devices do not support bfloat16."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (!is_large_tensor) {
kernels::MLUCombineMoeResultKernel<bfloat16_t, uint32_t><<<dim, cnrtFuncTypeBlock, queue>>>(
(bfloat16_t *)output, (bfloat16_t *)input, (bfloat16_t *)bias, (bfloat16_t *)residual,
(float *)reduce_weight, (int *)cusum_token_count, (int *)gather_idx, num_token, topk,
num_expert, hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
} else {
kernels::MLUCombineMoeResultKernel<bfloat16_t, uint64_t><<<dim, cnrtFuncTypeBlock, queue>>>(
(bfloat16_t *)output, (bfloat16_t *)input, (bfloat16_t *)bias, (bfloat16_t *)residual,
(float *)reduce_weight, (int *)cusum_token_count, (int *)gather_idx, num_token, topk,
num_expert, hidden_size, start_expert_id, expert_size, HIDDEN_BLOCK, TOKEN_BLOCK);
}
} else {
std::cerr << "[invokeMoeCombineResultKernel]: the current supported dtype is "
<< "among float/half/bfloat16." << std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,85 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
#define CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Sort tokens grouped by different experts based on index. Each token
* selects the topk hidden vectors, multiplies them by corresponding weights,
* and finally reduces the topk vectors for each token. This process involves
* bias and residual, calculated as (x + bias) * weight + residual.
* @example
* input:
* [[[1, 2, 1, 1],
* [1, 1, 1, 2]],
* [[2, 1, 1, 1],
* [1, 1, 1, 1]]]
* num_token = 2, topk = 2
* cusum_token_count = [0, 2, 4]
* index:
* [0, 1, 2, 3]
* weight:
* [0, 0, 1, 1]
* bias:
* [[0, 0, 0, 0],
* [1, 1, 1, 1]]
* residual:
* [[1, 1, 1, 1],
* [0, 0, 0, 0]]
* output:
* [[1, 1, 1, 1],
* [5, 4, 4, 4]]
* @param queue: The queue for mlu.
* @param output: Output. Pointer to the MLU memory that stores the result.
* The shape is [num_token, hidden_size].
* @param input: Input. Pointer to the MLU memory that stores input tokens.
* The shape is [num_token * topk, hidden_size].
* @param bias: Input. Pointer to the MLU memory that stores bias.
* The shape is [num_expert, hidden_size].
* @param residual: Input. Pointer to the MLU memory that stores residual.
* The shape is [num_token, hidden_size].
* @param reduce_weight: Input. Pointer to the MLU memory that stores reduce_weight.
* The shape is [num_token * topk].
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the cumulative sum of the
* token number of each expert. The shape is [num_expert + 1].
* @param gather_idx: Input. Pointer to the MLU memory that stores gather_idx.
* The shape is [num_token * topk].
* @param num_token: The total number of tokens.
* @param topk: The number of expert.
* @param num_expert: The number of expert.
* @param hidden_size: The size of lowest dimension.
* @param start_expert_id: The id of the first processed expert.
* @param expert_size: The number of processed experts.
* @param dtype: Data type.
* @note Currently does not support bias.
*/
KernelStatus invokeMoeCombineResultKernel(cnrtQueue_t queue,
void *output,
const void *input,
const void *bias,
const void *residual,
const float *reduce_weight,
const int *cusum_token_count,
const int *gather_idx,
int num_token,
int topk,
int num_expert,
int hidden_size,
int start_expert_id,
int expert_size,
cnnlDataType_t dtype);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_COMBINE_RESULT_MLUH_

View File

@@ -0,0 +1,219 @@
#include <bang_device_functions_extra.h>
#include <mlu.h>
#include "cnnl.h"
#include "cnrt.h"
#include "expand_input.mluh"
namespace tmo {
namespace kernels {
#define RESERVED_SIZE (32 * 1024)
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - RESERVED_SIZE)
#define SRAM_BUFFER_SIZE (__MLU_SRAM_SIZE__ * 1024 - RESERVED_SIZE)
#define MEMCPY_BURST_SIZE 128
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
// T_offset: uint32_t or uint64_t
template <size_t data_size, typename T_offset>
__mlu_func__ void ExpandInputKernel(void *output,
void *input,
int *index,
int num_token,
int hidden_size,
int num_index) {
void *input_ptr = input;
uint64_t input_size = data_size * num_token * hidden_size;
// whether SRAM_BUFFER_SIZE can hold the input data
// sram_enable is true ==> is_nram_output is true
bool sram_enable = (hidden_size == 1) && (input_size < SRAM_BUFFER_SIZE);
if (sram_enable) {
input_ptr = (void *)sram_buffer;
__memcpy(input_ptr, input, input_size, GDRAM2SRAM);
__sync_cluster();
}
if (__is_mpu()) {
return;
}
// Each ipu core processes no less than 128B of data, and the remaining cores can idle
int32_t max_task_num = data_size * hidden_size * num_index / MEMCPY_BURST_SIZE;
uint32_t maxTaskDim = std::min(taskDim, std::max(max_task_num, 1));
uint32_t total_num = num_index;
uint32_t base = total_num / maxTaskDim;
uint32_t tail = total_num - base * maxTaskDim;
if (taskId >= maxTaskDim) {
return;
}
uint32_t batch_per_core = base + (taskId < tail ? 1 : 0);
uint32_t batch_step = base * taskId + (taskId < tail ? taskId : tail);
// nram
/*
* first: compute offset: index[i] * data_size
* second: gather data
* -------------------------------------------------
* addr || index/offset | output |
* type || int32_t/T_offset | T |
* num || n | n * hidden_size |
* -------------------------------------------------
*/
uint32_t nram_size_per_pixel = sizeof(T_offset) + hidden_size * data_size;
// whether nram can hold two pixel: if so, then GDRAM->NRAM->GDRAM, otherwise GDRAM->GDRAM
bool is_nram_output = nram_size_per_pixel * 2 <= NRAM_BUFFER_SIZE;
uint32_t per_num =
is_nram_output ? NRAM_BUFFER_SIZE / nram_size_per_pixel : NRAM_BUFFER_SIZE / sizeof(T_offset);
int8_t *output_base = (int8_t *)output + (uint64_t)batch_step * hidden_size * data_size;
int *index_base = index + batch_step;
T_offset *nram_offset = (T_offset *)nram_buffer;
int32_t *nram_index;
if (std::is_same<T_offset, int64_t>::value) {
nram_index = (int32_t *)nram_offset + per_num;
} else {
nram_index = (int32_t *)nram_offset;
}
int8_t *nram_output = (int8_t *)(nram_offset + per_num);
uint32_t repeat = batch_per_core / per_num;
uint32_t remain = batch_per_core - repeat * per_num;
uint32_t deal_num = per_num;
uint32_t is_remain = remain != 0 ? 1 : 0;
for (int32_t i = 0; i < repeat + is_remain; i++) {
if (i == repeat) {
deal_num = remain;
}
int8_t *output_ptr = output_base + (uint64_t)i * per_num * hidden_size * data_size;
int32_t *index_ptr = index_base + i * per_num;
// index -> offset
__memcpy((void *)nram_index, (void *)index_ptr, deal_num * sizeof(int32_t), GDRAM2NRAM);
if (std::is_same<T_offset, uint64_t>::value) {
#if __BANG_ARCH__ > 592
__bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num, 0, 0);
#else
__bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num);
#endif
}
__bang_mul_scalar(nram_offset, nram_offset, (int64_t)data_size * hidden_size, deal_num);
// copy
if (is_nram_output) {
__bang_write_zero((int8_t *)nram_output, deal_num * hidden_size);
mluMemcpyDirection_t dir = sram_enable ? SRAM2NRAM : GDRAM2NRAM;
// GDRAM or SRAM -> NRAM -> GDRAM
#if __BANG_ARCH__ >= 592 // gather requires
__gather(nram_output, input_ptr, nram_offset, hidden_size * data_size, dir,
hidden_size * data_size, deal_num);
#else
for (int32_t j = 0; j < deal_num; j++) {
T_offset offset_value = *(nram_offset + j);
int8_t *input_offset = (int8_t *)input_ptr + offset_value;
__memcpy(nram_output + j * hidden_size * data_size, input_offset, hidden_size * data_size,
dir);
}
#endif // __BANG_ARCH__
__memcpy(output_ptr, nram_output, deal_num * hidden_size * data_size, NRAM2GDRAM);
} else {
// GDRAM -> GDRAM
#if __BANG_ARCH__ >= 592 // gather requires
__gather(output_ptr, input, (uint64_t *)nram_offset, hidden_size * data_size, GDRAM2GDRAM,
hidden_size * data_size, deal_num);
#else
for (int32_t j = 0; j < deal_num; j++) {
T_offset offset_value = *(nram_offset + j);
int8_t *input_offset = (int8_t *)input + offset_value;
__memcpy(output_ptr + (T_offset)j * hidden_size * data_size, input_offset,
hidden_size * data_size, GDRAM2GDRAM);
}
#endif // __BANG_ARCH__
}
}
}
// T_offset: uint32_t or uint64_t
template <size_t data_size, typename T_offset>
__mlu_global__ void MLUExpandInputKernel(void *expand_hidden_state,
void *hidden_state,
int *gather_idx,
int *cusum_token_count,
int num_token,
int hidden_size,
int topk,
int total_expert_num,
int start_expert_id,
int expert_count) {
int32_t num_index = num_token * topk;
int *gather_start_idx = (int *)gather_idx;
if (cusum_token_count != nullptr) {
num_index = *((int *)cusum_token_count + start_expert_id + expert_count) -
*((int *)cusum_token_count + start_expert_id);
gather_start_idx = (int *)gather_idx + *(cusum_token_count + start_expert_id);
}
ExpandInputKernel<data_size, T_offset>(expand_hidden_state, hidden_state, gather_start_idx,
num_token, hidden_size, num_index);
}
// instantiate kernels
#define INSTANTIATE_ONE(data_size, T_offset) \
template __mlu_global__ void MLUExpandInputKernel<data_size, T_offset>( \
void *, void *, int *, int *, int, int, int, int, int, int);
INSTANTIATE_ONE(1, uint32_t)
INSTANTIATE_ONE(2, uint32_t)
INSTANTIATE_ONE(4, uint32_t)
INSTANTIATE_ONE(8, uint32_t)
// large tensor
INSTANTIATE_ONE(1, uint64_t)
INSTANTIATE_ONE(2, uint64_t)
INSTANTIATE_ONE(4, uint64_t)
INSTANTIATE_ONE(8, uint64_t)
} // namespace kernels
KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue,
void *expand_hidden_state,
const void *hidden_state,
const int *gather_idx,
const int *cusum_token_count,
int num_token,
int hidden_size,
int topk,
cnnlDataType_t data_type,
int total_expert_num,
int start_expert_id,
int expert_count) {
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
size_t type_size = 0;
cnnlGetSizeOfDataType(data_type, &type_size);
int max_cluster_num =
(uint64_t)hidden_size * num_token * topk * type_size / (core_num * MEMCPY_BURST_SIZE);
cluster_num = std::min(std::max(max_cluster_num, 1), cluster_num);
cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
void (*expand_input_kernels[])(void *, void *, int *, int *, int, int, int, int, int, int) = {
kernels::MLUExpandInputKernel<1, uint32_t>, kernels::MLUExpandInputKernel<2, uint32_t>,
kernels::MLUExpandInputKernel<4, uint32_t>, kernels::MLUExpandInputKernel<8, uint32_t>,
kernels::MLUExpandInputKernel<1, uint64_t>, kernels::MLUExpandInputKernel<2, uint64_t>,
kernels::MLUExpandInputKernel<4, uint64_t>, kernels::MLUExpandInputKernel<8, uint64_t>};
bool is_large_tensor = type_size * hidden_size * num_token * topk > INT32_MAX;
int kernel_index = (type_size == 8 ? 3 : type_size >> 1) + (is_large_tensor ? 4 : 0);
expand_input_kernels[kernel_index]<<<dim, cnrtFuncTypeUnion1, queue>>>(
(void *)expand_hidden_state, (void *)hidden_state, (int *)gather_idx,
(int *)cusum_token_count, num_token, hidden_size, topk, total_expert_num, start_expert_id,
expert_count);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,81 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_
#define CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Gathers slices from hidden_state at axis 1 according to gather_idx and cusum_token_count.
* @example
* hidden_state:
* [[1, 2, 3, 4],
* [5, 6, 7, 8],
* [9, 10, 11, 12]]
* gather_idx:
* [[1, 0, 2, 2, 1, 0]]
* cusum_token_count: NULL
* num_token = 3
* hidden_size = 4
* topk = 2
* expand_hidden_state:
* [[5, 6, 7, 8],
* [1, 2, 3, 4],
* [9, 10, 11, 12],
* [9, 10, 11, 12],
* [5, 6, 7, 8],
* [1, 2, 3, 4]]
* @param queue: The queue for mlu.
* @param hidden_state: Input. Pointer to the MLU memory that store the input,
* the shape must be [num_token, hidden_size].
* @param gather_idx: Input. Pointer to the MLU memory that stores the index,
* the shape must be [num_token * topk].
* @param cusum_token_count: Input. Pointer to the MLU memory that stores the prefix sum of
* token_count. If cusum_token_count is not NULL, the shape must be [total_expert_num + 1]. The
* gather operation will be performed as follows: if cusum_token_count is not NULL: index =
* gather_idx[cusum_token_count[start_expert_id]:cusum_token_count[start_expert_id+expert_count]]
* expand_hidden_state = hidden_state[index]
* else:
* index = gather_idx[:]
* expand_hidden_state = hidden_state[index]
* @param expand_hidden_state: Output. Pointer to the MLU memory that stores the output,
* if cusum_token_count is not NULL, the shape shoule be [num_index * topk ,hidden_size] in
* which num_index =
* cusum_token_count[start_expert_id+expert_count]-cusum_token_count[start_expert_id]. Otherwise,
* the shape should be [num_token * topk, hidden_size].
* @param num_token: the number of token.
* @param hidden_size: the slice size.
* @param topk: the number of topk.
* @param data_type: Data type of hidden_state.
* @param total_expert_num: the total number of expert.
* @param start_expert_id: the first expert id.
* @param expert_count: the number of experts currently being processed.
*/
KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue,
void *expand_hidden_state,
const void *hidden_state,
const int *gather_idx,
const int *cusum_token_count,
int num_token,
int hidden_size,
int topk,
cnnlDataType_t data_type,
int total_expert_num,
int start_expert_id,
int expert_count);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_EXPAND_INPUT_MLUH_

View File

@@ -0,0 +1,935 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <vector>
#include "gen_idx.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#define NRAM_BUFFER_SIZE ((__MLU_NRAM_SIZE__ - 16) * 1024)
#define SRAM_BUFFER_SIZE ((__MLU_SRAM_SIZE__ - 8) * 1024)
#define ALIGN_16 (16)
#define EXPERT_AVG_COUNT_TEST (0)
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__nram__ const int range[64] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
// Generate integer sequence data from 0 to length-1
__mlu_func__ void generateIntSeq(int *dst, int length) {
int count = 64;
__bang_move(dst, range, std::min(count, length) * sizeof(int));
while (count < length) {
__bang_add_scalar(dst + count, dst, (int)count, std::min(count, length - count));
count *= 2;
}
}
// genIdx Block kernel, use only 1 core to process
__mlu_global__ void launchMoeGenIdxBlockKernel(int *gather_expand_idx,
int *gather_combine_idx,
int *token_count,
int *cusum_token_count,
const void *expert_id,
const int num_token,
const int num_expert,
const int topk) {
/* NRAM space */
// Total occupy: (4 * token_total_num + 2 * num_expert) * sizeof(int)
// --------------------------------------------------------------
// | expert_id | sorted_idx |gen_idx_onchip|cur_expert_result|
// | combine_idx | expand_idx | | scatter_offset |
// |num_token*topk|num_token*topk|num_token*topk| num_token*topk |
// --------------------------------------------------------------
// ------------------------------
// |token_count|token_count_presum|
// | | |
// | num_expert| num_expert |
// ------------------------------
uint32_t token_total_num = num_token * topk;
// num align to 16, size align to 64B
uint32_t align_total_num = (token_total_num + ALIGN_16 - 1) >> 4 << 4;
int8_t *expert_id_onchip = (int8_t *)nram_buffer;
int8_t *sorted_idx_onchip = (int8_t *)expert_id_onchip + align_total_num * sizeof(int);
int8_t *gen_idx_onchip = (int8_t *)sorted_idx_onchip + align_total_num * sizeof(int);
int8_t *cur_expert_result = (int8_t *)gen_idx_onchip + align_total_num * sizeof(int);
int8_t *token_count_onchip = (int8_t *)cur_expert_result + align_total_num * sizeof(int);
int8_t *token_count_presum_onchip = (int8_t *)token_count_onchip + num_expert * sizeof(int);
int8_t *scatter_offset = cur_expert_result; // reuse cur_expert space
#if __BANG_ARCH__ >= 592
int8_t *combine_idx_onchip = expert_id_onchip; // reuse expert_it space
#endif
int8_t *expand_idx_onchip = sorted_idx_onchip; // reuse sorted_idx space
// Load current core input expert_id and generate int sequence
__memcpy_async((int *)expert_id_onchip, (int *)expert_id, token_total_num * sizeof(int),
GDRAM2NRAM);
generateIntSeq((int *)gen_idx_onchip, token_total_num);
__sync();
// Initialize sort idx offset
uint32_t sorted_idx_offset = 0;
// Initialize token count first presum with 0
((int *)token_count_presum_onchip)[0] = 0;
bool need_cusum_token_count = bool(cusum_token_count != nullptr);
// Loop on each expert, eq, count, filter index
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert,
token_total_num);
// Use filter to sort gen_idx, output with sorted_idx_offset
uint32_t cur_expert_count =
__bang_filter(((float *)sorted_idx_onchip) + sorted_idx_offset, (float *)gen_idx_onchip,
(float *)cur_expert_result, token_total_num);
sorted_idx_offset += cur_expert_count;
((int *)token_count_onchip)[cur_expert] = cur_expert_count;
// Compute cusum token count and store
if (need_cusum_token_count) {
((int *)token_count_presum_onchip)[cur_expert + 1] = sorted_idx_offset;
}
}
#if EXPERT_AVG_COUNT_TEST
// NOTE: test avg expert code here:
uint32_t token_count_avg = token_total_num / num_expert;
uint32_t expert_remain_num = token_total_num % num_expert;
for (int i = 0; i < num_expert; i++) {
((int *)token_count_onchip)[i] =
(i < expert_remain_num) ? token_count_avg + 1 : token_count_avg;
((int *)token_count_presum_onchip)[i + 1] =
((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i];
}
#endif
__sync_compute();
// Store token_count and cusum token count
__memcpy_async((int *)token_count, (int *)token_count_onchip, num_expert * sizeof(int),
NRAM2GDRAM);
if (need_cusum_token_count) {
__memcpy_async((int *)cusum_token_count, (int *)token_count_presum_onchip,
(num_expert + 1) * sizeof(int), NRAM2GDRAM);
}
// Use sorted idx to generate gather idx for expand and combine
#if __BANG_ARCH__ >= 592
// scatter_offset = sorted_idx mul_scalar sizeof(int);
__bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
token_total_num);
#else
// scatter dst GDRAM addr should align to 64B
int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr);
__bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
combine_idx_align_offset, (int)(sizeof(int)), token_total_num);
#endif
__sync_compute();
#if __BANG_ARCH__ >= 592
// scatter_async to NRAM
__scatter_async((int *)combine_idx_onchip, (int *)gen_idx_onchip, (uint32_t *)scatter_offset,
sizeof(int), NRAM2NRAM, sizeof(int), (unsigned short)token_total_num);
#endif
// expand_idx = sorted_idx div(topk)
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, token_total_num);
// Store expand_idx and combine_idx
__sync_compute();
__memcpy_async((int *)gather_expand_idx, (int *)expand_idx_onchip, token_total_num * sizeof(int),
NRAM2GDRAM);
#if __BANG_ARCH__ >= 592
__sync_move();
__memcpy_async((int *)gather_combine_idx, (int *)combine_idx_onchip,
token_total_num * sizeof(int), NRAM2GDRAM);
#else
// 370 directly scatter to GDRAM
__scatter((int *)combine_idx_align_addr, (int *)gen_idx_onchip, (uint32_t *)scatter_offset,
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)token_total_num);
#endif
}
// Only MLU500 series support NRAM2SRAM scatter direction
__mlu_func__ void scatterSeqSram(int *dst, int *src, uint32_t *offset, int length) {
#if __BANG_ARCH__ >= 592
// When length larger than 65535(maximum segnum in bang_scatter),
// and src/offset address should align to 64B
int seg_repeat = length / 32768;
int seg_remain = length % 32768;
int seg_offset = 0;
for (int seg = 0; seg < seg_repeat; seg++) {
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
sizeof(int), NRAM2SRAM, sizeof(int), (unsigned short)32768);
seg_offset += 32768;
}
if (seg_remain > 0) {
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
sizeof(int), NRAM2SRAM, sizeof(int), (unsigned short)seg_remain);
}
#endif
}
// Scatter sequence, transfer size is sizeof(int)
__mlu_func__ void scatterSeqDram(int *dst, int *src, uint32_t *offset, int length) {
// When length larger than 65535(maximum segnum in bang_scatter),
// and src/offset address should align to 64B
int seg_repeat = length / 32768;
int seg_remain = length % 32768;
int seg_offset = 0;
for (int seg = 0; seg < seg_repeat; seg++) {
#if __BANG_ARCH__ >= 592
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)32768);
#else
__scatter((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, sizeof(int),
NRAM2GDRAM, sizeof(int), (unsigned short)32768);
#endif
seg_offset += 32768;
}
if (seg_remain > 0) {
#if __BANG_ARCH__ >= 592
__scatter_async((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset,
sizeof(int), NRAM2GDRAM, sizeof(int), (unsigned short)seg_remain);
#else
__scatter((int *)dst, ((int *)src) + seg_offset, ((uint32_t *)offset) + seg_offset, sizeof(int),
NRAM2GDRAM, sizeof(int), (unsigned short)seg_remain);
#endif
}
}
// 1. Get token count
__mlu_func__ void getTokenCount(int *token_count,
int *expert_id,
int token_cur_core,
int cur_token_start,
int num_expert) {
// 1. Partition on [num_token*topk],
// each core for-loop on all expert_id, use eq and count instructions,
// use AtomicAdd to accumulate all expert_id token counts, on GDRAM.
// And sync for all cores.
// NRAM:
// ------------------------------------------------------
// |expert_id_onchip|cur_expert_result|expert_count_onchip|
// | deal_num | deal_num | num_expert |
// ------------------------------------------------------
uint32_t deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 2;
int8_t *expert_id_onchip = (int8_t *)nram_buffer;
int8_t *cur_expert_result = (int8_t *)expert_id_onchip + deal_num * sizeof(int);
int8_t *expert_count_onchip = cur_expert_result + deal_num * sizeof(int);
// Current core data loop
uint32_t repeat = token_cur_core / deal_num;
uint32_t remain = token_cur_core % deal_num;
uint32_t total_repeat = repeat + (int)(remain > 0);
uint32_t token_addr_offset = cur_token_start;
// Initialize token_count with 0
if (taskId == 0) {
__gdramset((int *)token_count, num_expert, 0);
}
// Sync for initialize token_count
__sync_all_ipu();
// Initialize expert count onchip with 0
if (token_cur_core > 0) {
__bang_write_zero((int *)expert_count_onchip, num_expert);
}
// actual num in loop
int cur_deal_num = deal_num;
for (int i = 0; i < total_repeat; i++) {
if (i == total_repeat - 1 && remain > 0) {
cur_deal_num = remain;
}
// Load current core input expert_id
__memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
cur_deal_num * sizeof(int), GDRAM2NRAM);
token_addr_offset += cur_deal_num;
// Loop on each expert, eq, count
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, cur_deal_num);
// NOTE: __bang_count() only support floating data type
uint32_t cur_expert_count = __bang_count((float *)cur_expert_result, cur_deal_num);
((int *)expert_count_onchip)[cur_expert] += cur_expert_count;
}
}
// AtomicAdd(reduce) all cores token count results
if (token_cur_core > 0) {
__bang_atomic_reduce_add((int *)token_count, (int *)expert_count_onchip, num_expert);
}
// Sync for all cores, get accumulate of token_count
__sync_all_ipu();
}
// 2. Get token count presum, for each expert index start address after sorting
__mlu_func__ void getTokenCountPresum(int *token_count_presum,
int *token_count,
const int num_expert) {
// 2. After first process, already get token_count.
// Then use one core to pre-sum on token_count, consider size of int32,
// first expert id start address should be zero.
// to get each expert id start address after sorting, store to workspace,
// token_count_presum.
// And sync for all cores.
// NRAM:
// load token_count to token_count_presum[1~num_expert+1],
// for i = 0 to num_expert:
// token_count_presum[i+1] += token_count_presum[i]
// store token_count_presum[0~num_expert]
// -------------------------
// |token_count_presum_onchip|
// | {0}, num_expert |
// -------------------------
if (taskId == 0) {
// Initialize count presum onchip with a first 0
int8_t *token_count_presum_onchip = nram_buffer;
((int *)token_count_presum_onchip)[0] = 0;
// Load token_count with an offset of 1
__memcpy(((int *)token_count_presum_onchip) + 1, (int *)token_count, num_expert * sizeof(int),
GDRAM2NRAM);
// Calculate presum of token count by each expert
for (int cur_expert = 0; cur_expert < num_expert; cur_expert++) {
((int *)token_count_presum_onchip)[cur_expert + 1] +=
((int *)token_count_presum_onchip)[cur_expert];
}
// Store token count presum to workspace
__memcpy((int *)token_count_presum, (int *)token_count_presum_onchip,
(num_expert + 1) * sizeof(int), NRAM2GDRAM);
}
// Sync for all cores, get presum of token count
__sync_all_ipu();
}
__mlu_func__ void modifyTokenCountAndPresum(int *token_count_presum,
int *token_count,
const uint32_t token_total_num,
const int num_expert) {
uint32_t token_count_avg = token_total_num / num_expert;
uint32_t expert_remain_num = token_total_num % num_expert;
int8_t *token_count_onchip = nram_buffer;
int8_t *token_count_presum_onchip = token_count_onchip + num_expert * sizeof(int);
((int *)token_count_presum_onchip)[0] = 0;
for (int i = 0; i < num_expert; i++) {
((int *)token_count_onchip)[i] =
(i < expert_remain_num) ? token_count_avg + 1 : token_count_avg;
((int *)token_count_presum_onchip)[i + 1] =
((int *)token_count_presum_onchip)[i] + ((int *)token_count_onchip)[i];
}
__memcpy((int *)token_count, (int *)token_count_onchip, num_expert * sizeof(int), NRAM2GDRAM);
__memcpy((int *)token_count_presum, (int *)token_count_presum_onchip,
(num_expert + 1) * sizeof(int), NRAM2GDRAM);
}
// 3. Get expert position index after sorting
__mlu_func__ void getSortedIdx(int *sorted_idx,
int *expert_id,
int *token_count_presum,
const int token_total_num,
const int num_expert,
const int expert_cur_core,
const int cur_expert_start,
const int cur_expert_end) {
// 3. Partition on num_expert, each core generate position index from 0,
// and for-loop on all expert_id data, use eq with own each expert_id,
// and filter on index, stores to each expert_id start address of
// sorted_idx on workspace.
// And sync for all cores.
// NRAM:
// -------------------------------------------------------------------
// |expert_id_onchip|cur_expert_result|gen_idx_onchip|filter_idx_onchip|
// | deal_num | deal_num | deal_num | deal_num |
// -------------------------------------------------------------------
// |expert_start_addr|
// | num_expert |
// -----------------
// Calculate new deal_num of sorting process
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int) - num_expert) / 4;
// Each core deal with whole token expert_id data
int repeat = token_total_num / deal_num;
int remain = token_total_num % deal_num;
int token_addr_offset = 0;
int8_t *expert_id_onchip = nram_buffer;
int8_t *cur_expert_result = expert_id_onchip + deal_num * sizeof(int);
int8_t *gen_idx_onchip = cur_expert_result + deal_num * sizeof(int);
int8_t *filter_idx_onchip = gen_idx_onchip + deal_num * sizeof(int);
int8_t *expert_start_addr = filter_idx_onchip + deal_num * sizeof(int);
// When num_expert < taskDim, not all cores need to sort
if (expert_cur_core > 0) {
// Generate position index from 0
if (deal_num <= token_total_num) {
generateIntSeq((int *)gen_idx_onchip, deal_num);
} else { // only remainder part
generateIntSeq((int *)gen_idx_onchip, token_total_num);
}
// Initialize expert start address with presum of token count
__memcpy((int *)expert_start_addr, (int *)token_count_presum, num_expert * sizeof(int),
GDRAM2NRAM);
// repeat part
for (int i = 0; i < repeat; i++) {
// Load current core expert_id
__memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
deal_num * sizeof(int), GDRAM2NRAM);
token_addr_offset += deal_num;
// Loop for current core expert, eq, filter position index
// use filter, store to sorted_idx[expert_start_addr]
for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) {
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, deal_num);
int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];
// NOTE: __bang_filter() only support floating data type
uint32_t cur_expert_count =
__bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
(float *)cur_expert_result, deal_num);
// Store to the corresponding address of sorted_idx
if (cur_expert_count > 0) {
__memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip,
cur_expert_count * sizeof(int), NRAM2GDRAM);
// Update address offset of current expert
((int *)expert_start_addr)[cur_expert] = cur_expert_offset + cur_expert_count;
}
}
// Update position index for each data loop
__bang_add_scalar((int *)gen_idx_onchip, (int *)gen_idx_onchip, (int)(deal_num), deal_num);
}
// remainder part
if (remain > 0) {
__memcpy((int *)expert_id_onchip, ((int *)expert_id) + token_addr_offset,
remain * sizeof(int), GDRAM2NRAM);
for (int cur_expert = cur_expert_start; cur_expert <= cur_expert_end; cur_expert++) {
__bang_eq_scalar((int *)cur_expert_result, (int *)expert_id_onchip, cur_expert, remain);
int cur_expert_offset = ((int *)expert_start_addr)[cur_expert];
// NOTE: __bang_filter() only support floating data type
uint32_t cur_expert_count =
__bang_filter((float *)filter_idx_onchip, (float *)gen_idx_onchip,
(float *)cur_expert_result, remain);
// Store to the corresponding address of sorted_idx
if (cur_expert_count > 0) {
__memcpy(((int *)sorted_idx) + cur_expert_offset, (int *)filter_idx_onchip,
cur_expert_count * sizeof(int), NRAM2GDRAM);
}
}
}
}
// Sync for all cores, get position index after sorting
__sync_all_ipu();
}
// 4. Get gather index for expand and combine
template <bool is_sram_scatter>
__mlu_func__ void getGatherIdx(int *gather_expand_idx,
int *gather_combine_idx,
int *sorted_idx,
const int token_cur_core,
const int cur_token_start,
const int topk) {
// 4. Partition on [num_token*topk],
// load sorted_idx onchip,
// generate sequence according to position index from 0, add token offset
// gather_combine_idx = scatter(seq, sorted_idx)
// gather_expand_idx = sorted_idx / topk
// update sequence
// NRAM:
// -------------------------------------------------------------------
// |sorted_idx_onchip|expand_idx_onchip|scatter_offset|scatter_sequence|
// | deal_num | deal_num | deal_num | deal_num |
// -------------------------------------------------------------------
// Calculate new deal_num of generate gather index
// NOTE: deal_num should align to 64 Bytes, because bang_scatter() constraints
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
int repeat = token_cur_core / deal_num;
int remain = token_cur_core % deal_num;
int token_addr_offset = cur_token_start;
// scatter dst GDRAM addr should align to 64B
int *combine_idx_align_addr = (int *)((uint64_t)(gather_combine_idx) >> 6 << 6);
int combine_idx_align_offset = (int)(gather_combine_idx - combine_idx_align_addr);
int8_t *sorted_idx_onchip = nram_buffer;
int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int);
int8_t *scatter_offset = expand_idx_onchip + deal_num * sizeof(int);
int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);
// Generate position index from 0
// Add base offset to sequence according to current core token start address
if (token_cur_core > 0) {
if (deal_num <= token_cur_core) {
generateIntSeq((int *)scatter_sequence, deal_num);
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
deal_num);
} else { // only remainder part
generateIntSeq((int *)scatter_sequence, token_cur_core);
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
token_cur_core);
}
}
// repeat part
for (int i = 0; i < repeat; i++) {
// Load current core sorted_idx
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
deal_num * sizeof(int), GDRAM2NRAM);
// offset = sorted_idx * sizeof(int), counted in bytes
if (is_sram_scatter) {
__bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
deal_num);
} else {
// GDRAM addr should align to 64B
__bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
combine_idx_align_offset, (int)(sizeof(int)), deal_num);
}
// Sync for scatter
__sync_compute();
if (is_sram_scatter) {
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
deal_num);
} else {
// Scatter to output gather_combine_idx
scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
(uint32_t *)scatter_offset, deal_num);
}
// expand_idx_onchip = sorted_idx / topk
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, deal_num);
// Store expand idx
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
deal_num * sizeof(int), NRAM2GDRAM);
if (is_sram_scatter) {
// if scatter to SRAM, need to sync compute with mv
__sync_move();
}
// Add offset to sequence and token_address
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num);
token_addr_offset += deal_num;
}
// remainder part
if (remain > 0) {
// Load current core sorted_idx
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
remain * sizeof(int), GDRAM2NRAM);
// offset = sorted_idx * sizeof(int), counted in bytes
if (is_sram_scatter) {
__bang_mul_scalar((int *)scatter_offset, (int *)sorted_idx_onchip, (int)(sizeof(int)),
remain);
} else {
// GDRAM addr should align to 64B
__bang_fusion(FUSION_FAM, (int *)scatter_offset, (int *)sorted_idx_onchip,
combine_idx_align_offset, (int)(sizeof(int)), remain);
}
// Sync for scatter
__sync_compute();
if (is_sram_scatter) {
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
remain);
} else {
// Scatter to output gather_combine_idx
scatterSeqDram((int *)combine_idx_align_addr, (int *)scatter_sequence,
(uint32_t *)scatter_offset, remain);
}
// expand_idx_onchip = sorted_idx / topk
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, remain);
// Store expand idx
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
remain * sizeof(int), NRAM2GDRAM);
}
}
// 4.1 Get gather combine index on SRAM
__mlu_func__ void getCombineIdxSram(int *sorted_idx,
const int token_cur_core,
const int cur_token_start) {
// 4.1 Partition on [num_token*topk], with only 1 union
// load sorted_idx onchip,
// generate sequence according to position index from 0, add token offset
// gather_combine_idx = scatter(seq, sorted_idx)
// update sequence
// NRAM:
// -------------------------------
// |scatter_offset|scatter_sequence|
// | deal_num | deal_num |
// -------------------------------
// Calculate new deal_num of generate gather index
// NOTE: deal_num should align to 64 Bytes, because bang_scatter() constraints
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
int repeat = token_cur_core / deal_num;
int remain = token_cur_core % deal_num;
int token_addr_offset = cur_token_start;
int8_t *scatter_offset = nram_buffer;
int8_t *scatter_sequence = scatter_offset + deal_num * sizeof(int);
// Generate position index from 0
// Add base offset to sequence according to current core token start address
if (token_cur_core > 0) {
if (deal_num <= token_cur_core) {
generateIntSeq((int *)scatter_sequence, deal_num);
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
deal_num);
} else { // only remainder part
generateIntSeq((int *)scatter_sequence, token_cur_core);
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)token_addr_offset,
token_cur_core);
}
}
// repeat part
for (int i = 0; i < repeat; i++) {
// Load current core sorted_idx
__memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, deal_num * sizeof(int),
GDRAM2NRAM);
// offset = sorted_idx * sizeof(int), counted in bytes
__bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), deal_num);
// Sync for scatter
__sync_compute();
// Scatter to SRAM
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset,
deal_num);
__sync_move();
// Add offset to sequence and token_address
__bang_add_scalar((int *)scatter_sequence, (int *)scatter_sequence, (int)deal_num, deal_num);
token_addr_offset += deal_num;
}
// remainder part
if (remain > 0) {
// Load current core sorted_idx
__memcpy((int *)scatter_offset, ((int *)sorted_idx) + token_addr_offset, remain * sizeof(int),
GDRAM2NRAM);
// offset = sorted_idx * sizeof(int), counted in bytes
__bang_mul_scalar((int *)scatter_offset, (int *)scatter_offset, (int)(sizeof(int)), remain);
// Sync for scatter
__sync_compute();
scatterSeqSram((int *)sram_buffer, (int *)scatter_sequence, (uint32_t *)scatter_offset, remain);
}
}
// 4.2 Get gather expand index
__mlu_func__ void getExpandIdx(int *gather_expand_idx,
int *sorted_idx,
const int token_cur_core,
const int cur_token_start,
const int topk) {
// 4.2 Partition on [num_token*topk],
// load sorted_idx onchip,
// gather_expand_idx = sorted_idx / topk
// NRAM:
// -----------------------------------
// |sorted_idx_onchip|expand_idx_onchip|
// | deal_num | deal_num |
// -----------------------------------
// Calculate new deal_num of generate gather index
int deal_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 2;
int repeat = token_cur_core / deal_num;
int remain = token_cur_core % deal_num;
int token_addr_offset = cur_token_start;
int8_t *sorted_idx_onchip = nram_buffer;
int8_t *expand_idx_onchip = sorted_idx_onchip + deal_num * sizeof(int);
// repeat part
for (int i = 0; i < repeat; i++) {
// Load current core sorted_idx
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
deal_num * sizeof(int), GDRAM2NRAM);
// expand_idx_onchip = sorted_idx / topk
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, deal_num);
// Store expand idx
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
deal_num * sizeof(int), NRAM2GDRAM);
token_addr_offset += deal_num;
}
// remainder part
if (remain > 0) {
// Load current core sorted_idx
__memcpy((int *)sorted_idx_onchip, ((int *)sorted_idx) + token_addr_offset,
remain * sizeof(int), GDRAM2NRAM);
// expand_idx_onchip = sorted_idx / topk
__bang_div((int *)expand_idx_onchip, (int *)sorted_idx_onchip, topk, remain);
// Store expand idx
__memcpy(((int *)gather_expand_idx) + token_addr_offset, (int *)expand_idx_onchip,
remain * sizeof(int), NRAM2GDRAM);
}
}
__mlu_global__ void launchMoeGenIdxKernel(int *gather_expand_idx,
int *gather_combine_idx,
int *token_count,
int *cusum_token_count,
void *workspace,
const void *expert_id,
const int num_token,
const int num_expert,
const int topk) {
// Store token count presum result, shape [num_expert + 1]
int *token_count_presum = (cusum_token_count != nullptr) ? cusum_token_count : (int *)workspace;
// Store position index after sorting, shape [num_token*topk]
int *sorted_idx = ((int *)workspace) + num_expert + 1;
// Calculate partition information for different processes
// Partition on [num_token*topk]
uint32_t token_total_num = num_token * topk;
uint32_t token_cur_core = token_total_num / taskDim;
uint32_t token_remain_num = token_total_num % taskDim;
token_cur_core += (uint32_t)(taskId < token_remain_num);
// Current core range according to partition on [num_token*topk]
uint32_t cur_token_start = (taskId < token_remain_num)
? token_cur_core * taskId
: token_cur_core * taskId + token_remain_num;
// Partition on [num_expert]
uint32_t expert_cur_core = num_expert / taskDim;
uint32_t expert_remain_num = num_expert % taskDim;
expert_cur_core += (uint32_t)(taskId < expert_remain_num);
// Current core range according to partition on [num_expert]
uint32_t cur_expert_start = (taskId < expert_remain_num)
? expert_cur_core * taskId
: expert_cur_core * taskId + expert_remain_num;
uint32_t cur_expert_end = cur_expert_start + expert_cur_core - 1;
// Use Union1 SRAM to scatter, only MLU500 series support now
#if __BANG_ARCH__ >= 592
bool is_sram_scatter = token_total_num * sizeof(int) < SRAM_BUFFER_SIZE;
#else
bool is_sram_scatter = false;
#endif
if (__is_ipu()) {
// 1. Get token count
getTokenCount((int *)token_count, (int *)expert_id, token_cur_core, cur_token_start,
num_expert);
// 2. Get presum of token count
getTokenCountPresum((int *)token_count_presum, (int *)token_count, num_expert);
// 3. Get expert position index after sorting
getSortedIdx((int *)sorted_idx, (int *)expert_id, (int *)token_count_presum, token_total_num,
num_expert, expert_cur_core, cur_expert_start, cur_expert_end);
}
#if EXPERT_AVG_COUNT_TEST
// NOTE: test avg expert code here:
if (__is_ipu() && taskId == 0) {
modifyTokenCountAndPresum((int *)token_count_presum, (int *)token_count, token_total_num,
num_expert);
}
__sync_cluster();
#endif
// 4. Get gather index for expand and combine
if (is_sram_scatter) {
// Only use Union1 SRAM
uint32_t scatter_idx_cur_core = token_total_num / 4;
uint32_t scatter_idx_remain_num = token_total_num % 4;
scatter_idx_cur_core += (uint32_t)(taskId < scatter_idx_remain_num);
uint32_t cur_idx_start = (taskId < scatter_idx_remain_num)
? scatter_idx_cur_core * taskId
: scatter_idx_cur_core * taskId + scatter_idx_remain_num;
// Only Union1 task type,
// deal once num is same with deal_num in getGatherIdx,
// which means only 1 repeat to generate both expand and combine idx on NRAM
const int deal_once_num = (NRAM_BUFFER_SIZE / sizeof(int)) / 4;
if (taskDim <= 4 || token_total_num < deal_once_num) {
if (taskId < 4) {
if (__is_ipu()) {
getGatherIdx<true>((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx,
scatter_idx_cur_core, cur_idx_start, topk);
// sync for ipu and mpu
__sync_cluster();
} else {
// sync for ipu and mpu
__sync_cluster();
__memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
token_total_num * sizeof(int), SRAM2GDRAM);
}
}
} else {
// If taskDim > 4, use first union to generate combine idx,
// use other union to generate expand idx
if (taskId < 4) {
if (__is_ipu()) {
// Scatter combine idx to SRAM
getCombineIdxSram((int *)sorted_idx, scatter_idx_cur_core, cur_idx_start);
__sync_cluster();
} else {
__sync_cluster();
__memcpy_async((int *)gather_combine_idx, (int *)sram_buffer,
token_total_num * sizeof(int), SRAM2GDRAM);
}
} else {
// Other union generate expand idx
if (__is_ipu()) {
uint32_t expand_dim = taskDim - 4;
uint32_t expand_id = taskId - 4;
uint32_t expand_token_cur_core = token_total_num / expand_dim;
uint32_t expand_token_remain_num = token_total_num % expand_dim;
expand_token_cur_core += (uint32_t)(expand_id < expand_token_remain_num);
uint32_t expand_cur_token_start =
(expand_id < expand_token_remain_num)
? expand_token_cur_core * expand_id
: expand_token_cur_core * expand_id + expand_token_remain_num;
getExpandIdx((int *)gather_expand_idx, (int *)sorted_idx, expand_token_cur_core,
expand_cur_token_start, topk);
}
}
}
} else {
// not use SRAM to generate both expand and combine idx
if (__is_ipu()) {
getGatherIdx<false>((int *)gather_expand_idx, (int *)gather_combine_idx, (int *)sorted_idx,
token_cur_core, cur_token_start, topk);
}
}
// step 5 does not need MPU
if (__is_mpu()) {
return;
}
} // end of kernel
} // namespace kernels
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
int *gather_expand_idx,
int *gather_combine_idx,
int *token_count,
int *cusum_token_count,
void *workspace,
const void *expert_id,
const int num_token,
const int num_expert,
const int topk) {
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
const int token_total_num = num_token * topk;
// For partition on num_token*topk, single core processes at least 128 num
const int single_core_num_limit = 1024;
int need_core_num = std::ceil(float(token_total_num) / single_core_num_limit);
// When partition on num_expert, each core at least processes one expert
need_core_num = std::max(num_expert, need_core_num);
// When consider UnionX cnrt func type, reset cluster_num
if (token_total_num <= 4096) { // Block
cnrtFunctionType_t k_type = cnrtFuncTypeBlock;
cnrtDim3_t k_dim{1, 1, 1};
// Block kernel does not need workspace
kernels::launchMoeGenIdxBlockKernel<<<k_dim, k_type, queue>>>(
gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, expert_id, num_token,
num_expert, topk);
return KernelStatus::KERNEL_STATUS_SUCCESS;
} else if (need_core_num <= 4) { // Union1
cluster_num = 1;
} else if (need_core_num <= 8) { // Union2
cluster_num = std::min(cluster_num, 2);
} else if (need_core_num <= 16) { // Union4
cluster_num = std::min(cluster_num, 4);
} else if (need_core_num <= 32) { // Union8
cluster_num = std::min(cluster_num, 8);
}
cnrtFunctionType_t k_type;
cnrtDim3_t k_dim{1, 1, 1};
// Find max UnionX cnrt func type
if (cluster_num == 1) {
k_type = cnrtFuncTypeUnion1;
k_dim.x = 4;
} else if (cluster_num < 4) { // cluster num is 2 or 3
k_type = cnrtFuncTypeUnion2;
k_dim.x = 8;
} else if (cluster_num < 8) { // cluster num is 4,5,6,7
k_type = cnrtFuncTypeUnion4;
k_dim.x = 16;
} else { // cluster num larger than 8
k_type = cnrtFuncTypeUnion8;
k_dim.x = 32;
}
// The expert_id is int data type
kernels::launchMoeGenIdxKernel<<<k_dim, k_type, queue>>>(
gather_expand_idx, gather_combine_idx, token_count, cusum_token_count, workspace, expert_id,
num_token, num_expert, topk);
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
#undef EXPERT_AVG_COUNT_TEST // undef test macro
} // namespace tmo

View File

@@ -0,0 +1,58 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_GEN_IDX_MLUH_
#define CSRC_KERNELS_MOE_GEN_IDX_MLUH_
#include <vector>
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Apply generate MOE index operation, which performs the following
* tasks:
* - 1. Generate gather_expand_idx and gather_combine_idx.
* - 2. Output token_count, the token number of each expert.
* - 3. Prepare inputs and outputs address for group_gemm.
* @param queue: The queue of mlu.
* @param gather_expand_idx: Output. Pointer to the MLU memory that stores the
* gather index for expand hidden state operation, the shape must be
* [num_token * topk].
* @param gather_combine_idx: Output. Pointer to the MLU memory that stores the
* gather index for combine MOE operation, the shape must be
* [num_token * topk].
* @param token_count: Output. Pointer to the MLU memory that stores the token
* number of each expert, the shape must be [num_expert].
* @param cusum_token_count: Output. Pointer to the MLU memory that stores the
* cumulative sum of the token number of each expert, the shape must be
* [num_expert + 1]. It can be set to nullptr if don't need cusum output.
* @param workspace: Input. A pointer to the extra workspace required in the
* operation, the size must be larger than
* (num_expert + 1 + num_token * topk) multiplied by the size of uint32.
* @param expert_id: Input. Pointer to the MLU memory that stores the expert id
* of each token, the shape must be [num_token, topk].
* @param num_token: The number of tokens.
* @param num_expert: The number of experts.
* @param topk: The number of expert selected by each token.
*/
KernelStatus invokeMoeGenIdxKernel(cnrtQueue_t queue,
int *gather_expand_idx,
int *gather_combine_idx,
int *token_count,
int *cusum_token_count,
void *workspace,
const void *expert_id,
const int num_token,
const int num_expert,
const int topk);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_GEN_IDX_MLUH_

View File

@@ -0,0 +1,21 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_MOE_MLUH_
#define CSRC_KERNELS_MOE_MOE_MLUH_
#include "add_bias_activation.mluh"
#include "combine_result.mluh"
#include "expand_input.mluh"
#include "gen_idx.mluh"
#include "softmax_topk.mluh"
#endif // CSRC_KERNELS_MOE_MOE_MLUH_

View File

@@ -0,0 +1,602 @@
#include <mlu.h>
#include <cassert>
#include <iostream>
#include <limits>
#include <map>
#include <ostream>
#include "cnnl.h"
#include "cnrt.h"
#include "softmax_topk.mluh"
namespace tmo {
namespace kernels {
#define SCATTER_ALIGN (64) // align for __scatter()
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
#define SRAM_SIZE (__MLU_SRAM_SIZE__ * 1024 - 32 * 1024)
#define TILING_ALIGN (64)
#define DIV_UP(x, y) ((x) / (y) + (int)((x) % (y) > 0))
__nram__ int8_t nram_buffer[NRAM_SIZE];
__mlu_shared__ int8_t sram_buffer[SRAM_SIZE];
#define __TRANS_TILING(TYPE, CONVERT) \
__asm__ volatile("trans.tiling." TYPE \
" [%[dst]], [%[src]]," \
"%[in0], %[in1], %[is1], %[in2], %[is2], %[in3], %[is3], %[in4]," \
"%[is4], %[in5], %[is5]," \
"%[dn0], %[dn1], %[ds1], %[dn2], %[ds2], %[dn3], %[ds3], %[dn4]," \
"%[ds4], %[dn5], %[ds5]" CONVERT ::[dst] "r"(dst), \
[src] "r"(src), [in0] "r"(in0), [in1] "r"(in1), [is1] "r"(is1), [in2] "r"(in2), \
[is2] "r"(is2), [in3] "r"(in3), [is3] "r"(is3), [in4] "r"(in4), [is4] "r"(is4), \
[in5] "r"(in5), [is5] "r"(is5), [dn0] "r"(dn0), [dn1] "r"(dn1), [ds1] "r"(ds1), \
[dn2] "r"(dn2), [ds2] "r"(ds2), [dn3] "r"(dn3), [ds3] "r"(ds3), [dn4] "r"(dn4), \
[ds4] "r"(ds4), [dn5] "r"(dn5), [ds5] "r"(ds5));
template <typename SRC_DTYPE, typename DST_DTYPE, mluMemcpyDirection_t dir>
__mlu_func__ void __mlvm_trans(DST_DTYPE *dst,
const SRC_DTYPE *src,
const uint32_t in0,
const uint32_t in1,
const uint32_t is1,
const uint32_t in2,
const uint32_t is2,
const uint32_t in3,
const uint32_t is3,
const uint32_t in4,
const uint32_t is4,
const uint32_t in5,
const uint32_t is5,
const uint32_t dn0,
const uint32_t dn1,
const uint32_t ds1,
const uint32_t dn2,
const uint32_t ds2,
const uint32_t dn3,
const uint32_t ds3,
const uint32_t dn4,
const uint32_t ds4,
const uint32_t dn5,
const uint32_t ds5) {
if (SRAM2NRAM == dir && std::is_same<DST_DTYPE, float>::value) {
if (std::is_same<SRC_DTYPE, float>::value) {
__TRANS_TILING("nram.sram.b32", ";")
} else if (std::is_same<SRC_DTYPE, half>::value) {
__TRANS_TILING("nram.sram.b16", ", .cvt.f32.f16();")
#if __BANG_ARCH__ >= 500
} else if (std::is_same<SRC_DTYPE, bfloat16_t>::value) {
__TRANS_TILING("nram.sram.b16", ", .cvt.f32.bf16();")
#endif
}
}
}
/* 将shape为[h,w]的数据转置为[w,h](带转数)分4块分别进行处理。
* dst: dst地址
* src: src地址
* h: h方向大小
* w: w方向大小
*/
template <typename SRC_DTYPE, typename DST_DTYPE, mluMemcpyDirection_t dir>
__mlu_func__ void transhw2wh(DST_DTYPE *dst, SRC_DTYPE *src, uint32_t h, uint32_t w) {
uint32_t align_num = TILING_ALIGN / sizeof(SRC_DTYPE);
uint32_t w_align = w / align_num;
uint32_t w_rem = w % align_num;
uint32_t h_align = h / align_num;
uint32_t h_rem = h % align_num;
uint32_t in0 = TILING_ALIGN, dn0 = TILING_ALIGN;
uint32_t in1 = align_num, is1 = w * sizeof(SRC_DTYPE);
uint32_t in3 = w_align, is3 = TILING_ALIGN;
uint32_t in4 = h_align, is4 = w * TILING_ALIGN;
uint32_t dn1 = align_num, ds1 = h * sizeof(DST_DTYPE);
uint32_t dn3 = in3, ds3 = h * align_num * sizeof(DST_DTYPE);
uint32_t dn4 = in4, ds4 = align_num * sizeof(DST_DTYPE);
/* 1. h_align * w_align */
if (w_align > 0 && h_align > 0) {
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst, src, in0, in1, is1, 1, 0, in3, is3, in4, is4, 1, 0,
dn0, dn1, ds1, 1, 0, dn3, ds3, dn4, ds4, 1, 0);
}
/* 2. h_align * w_rem */
if (w_rem > 0 && h_align > 0) {
SRC_DTYPE *src_temp = src + w_align * align_num;
DST_DTYPE *dst_temp = dst + w_align * align_num * h;
in0 = w_rem * sizeof(SRC_DTYPE);
dn0 = TILING_ALIGN;
in1 = align_num;
is1 = w * sizeof(SRC_DTYPE);
in4 = h_align;
is4 = w * TILING_ALIGN;
dn1 = w_rem;
ds1 = h * sizeof(DST_DTYPE);
dn4 = in4;
ds4 = align_num * sizeof(DST_DTYPE);
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4,
1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0);
}
/* 3. h_rem * w_align */
if (w_align > 0 && h_rem > 0) {
SRC_DTYPE *src_temp = src + h_align * align_num * w;
DST_DTYPE *dst_temp = dst + h_align * align_num;
in0 = TILING_ALIGN;
dn0 = h_rem * sizeof(SRC_DTYPE);
in1 = h_rem;
is1 = w * sizeof(SRC_DTYPE);
in4 = w_align;
is4 = TILING_ALIGN;
dn1 = align_num;
ds1 = h * sizeof(DST_DTYPE);
dn4 = in4;
ds4 = h * align_num * sizeof(DST_DTYPE);
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, in4, is4,
1, 0, dn0, dn1, ds1, 1, 0, 1, 0, dn4, ds4, 1, 0);
}
/* 4. h_rem * w_rem */
if (w_rem > 0 && h_rem > 0) {
SRC_DTYPE *src_temp = src + h_align * align_num * w + w_align * align_num;
DST_DTYPE *dst_temp = dst + w_align * align_num * h + h_align * align_num;
in0 = w_rem * sizeof(SRC_DTYPE);
dn0 = h_rem * sizeof(SRC_DTYPE);
in1 = h_rem;
is1 = w * sizeof(SRC_DTYPE);
dn1 = w_rem;
ds1 = h * sizeof(DST_DTYPE);
__mlvm_trans<SRC_DTYPE, DST_DTYPE, dir>(dst_temp, src_temp, in0, in1, is1, 1, 0, 1, 0, 1, 0, 1,
0, dn0, dn1, ds1, 1, 0, 1, 0, 1, 0, 1, 0);
}
}
__mlu_func__ void getTopk(float *value_buffer,
uint32_t *index_buffer,
float *src_buffer,
float *compute_buffer,
float *max_buffer,
float *temp_buffer,
uint32_t *i_buffer,
uint32_t *col_buffer,
uint32_t topk,
uint32_t num_expert_group,
uint32_t col,
uint32_t row,
uint32_t value_index_stride,
uint32_t group_size,
bool is_deal_group) {
__bang_write_value((float *)temp_buffer, col, -INFINITY); // set -inf vector
for (int k = 0; k < topk; k++) {
if (is_deal_group) {
__bang_maxpool_index((uint32_t *)value_buffer + k * col, max_buffer, col, 1, num_expert_group,
1, num_expert_group, 1, 1);
__bang_fusion(FUSION_FMA, col_buffer, (uint32_t *)value_buffer + k * col, col, i_buffer, col,
col);
} else {
__bang_maxpool_value_index(value_buffer + k * col, max_buffer, col, 1, row, 1, row, 1, 1,
value_index_stride);
__bang_fusion(FUSION_FMA, col_buffer, index_buffer + k * col, col, i_buffer, col, col);
}
#if __BANG_ARCH__ >= 592
__bang_mul_scalar(col_buffer, col_buffer, sizeof(float), col); // index in byte
__scatter(max_buffer, temp_buffer, col_buffer, sizeof(uint32_t), NRAM2NRAM, sizeof(uint32_t),
col); // replace max value with -inf
#else
for (int i = 0; i < col; i++) {
uint32_t index = __load_nram(col_buffer + i);
max_buffer[index] = -INFINITY;
}
#endif
#if __BANG_ARCH__ < 500
if (is_deal_group) {
for (int i = 0; i < col; i++) {
uint32_t index = __load_nram((uint32_t *)value_buffer + k * col + i);
__memcpy(compute_buffer + i * row + index * group_size,
src_buffer + i * row + index * group_size, group_size * sizeof(float), NRAM2NRAM);
}
}
#endif
}
#if __BANG_ARCH__ >= 592
if (is_deal_group) {
__bang_transpose(index_buffer, (uint32_t *)value_buffer, topk, col);
__bang_mul_scalar((uint32_t *)value_buffer, i_buffer, row * sizeof(float), col);
__bang_move(value_buffer, value_buffer, col * sizeof(uint32_t), col * sizeof(uint32_t), 0,
topk - 1);
__bang_transpose((uint32_t *)compute_buffer, (uint32_t *)value_buffer, topk, col);
__bang_fusion(FUSION_FMA, index_buffer, index_buffer, group_size * sizeof(float),
(uint32_t *)compute_buffer, col * topk, col * topk);
__gather(compute_buffer, src_buffer, (uint32_t *)index_buffer, group_size * sizeof(float),
NRAM2NRAM, group_size * sizeof(float), col * topk);
__bang_write_value(src_buffer, row * col, -INFINITY);
__scatter(src_buffer, compute_buffer, index_buffer, group_size * sizeof(float), NRAM2NRAM,
group_size * sizeof(float), col * topk);
}
#endif
}
template <typename T>
__mlu_func__ void computeSoftmaxTopk(T *sram_buffer,
T *load_buffer,
float *src_buffer,
float *compute_buffer,
float *group_max_buffer,
float *nramout_value,
uint32_t *nramout_index,
uint32_t *i_buffer,
uint32_t *col_buffer,
float *softmax_buffer,
uint32_t row,
uint32_t nram_compute_col_num,
uint32_t mask_num,
uint32_t nram_max_col_num,
uint32_t topk,
int num_expert_group,
uint32_t topk_group,
uint32_t top_num,
uint32_t nram_col_offset,
int normalize_mode,
bool valid_mask,
bool split_mask) {
uint32_t nram_compute_num = nram_compute_col_num * row;
// convert to float for half/bf16 datatype
if (std::is_same<T, half>::value) {
__bang_half2float(src_buffer, (half *)load_buffer, nram_compute_num);
} else if (std::is_same<T, bfloat16_t>::value) {
__bang_bfloat162float(src_buffer, (bfloat16_t *)load_buffer, nram_compute_num);
}
// transpose [col, row] to [row, col]. To accelerate max/sum compute with maxpool/sumpool.
__bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
// compute softmax
int tmp = 0x3fb8aa3b;
float log2e = *(float *)&tmp; // for exp
// src_buffer reuse as buffer for max/sum.
__bang_maxpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // max
__bang_fusion(FUSION_FSM, compute_buffer, compute_buffer, src_buffer, log2e, nram_compute_num,
nram_compute_col_num);
__bang_pow2(compute_buffer, compute_buffer, nram_compute_num); // exp(input - max)
__bang_sumpool(src_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1); // sum
__bang_recip(src_buffer, src_buffer, nram_compute_col_num); // 1/sum
__bang_cycle_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_num,
nram_compute_col_num);
__sync_cluster();
// move mask and compute
if (valid_mask) {
if (!split_mask) {
__bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
if (std::is_same<T, half>::value) {
__memcpy((half *)compute_buffer + mask_num * row, sram_buffer, mask_num * row * sizeof(T),
SRAM2NRAM);
__bang_half2float((float *)compute_buffer, (half *)compute_buffer + mask_num * row,
mask_num * row);
} else if (std::is_same<T, bfloat16_t>::value) {
__memcpy((bfloat16_t *)compute_buffer + mask_num * row, sram_buffer,
mask_num * row * sizeof(T), SRAM2NRAM);
__bang_bfloat162float((float *)compute_buffer,
(bfloat16_t *)compute_buffer + mask_num * row, mask_num * row);
} else {
__memcpy(compute_buffer, sram_buffer, mask_num * row * sizeof(T), SRAM2NRAM);
}
__bang_cycle_mul(src_buffer, src_buffer, compute_buffer, nram_compute_col_num * row,
mask_num * row);
__bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
} else {
transhw2wh<T, float, SRAM2NRAM>(src_buffer, sram_buffer + nram_col_offset * row,
nram_compute_col_num, row);
__sync();
__bang_mul(compute_buffer, compute_buffer, src_buffer, nram_compute_col_num * row);
}
}
if (normalize_mode == 2) {
__bang_sumpool(softmax_buffer, compute_buffer, nram_compute_col_num, row, 1, row, 1, 1, 1);
}
if (num_expert_group <= 1) {
// num_expert_group <= 1, maintain original topk calculation logic
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer,
i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
nram_max_col_num * topk * sizeof(float), 0, false);
} else {
// num_expert_group > 1, use grouped_topk calculation logic
uint32_t group_size = row / num_expert_group;
__bang_transpose(src_buffer, compute_buffer, row, nram_compute_col_num);
__bang_maxpool(group_max_buffer, compute_buffer, nram_compute_col_num, num_expert_group,
group_size, 1, group_size, 1, 1);
__bang_write_value(compute_buffer, row * nram_compute_col_num, -INFINITY);
// get topk_group
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, group_max_buffer,
(float *)nramout_index, i_buffer, col_buffer, topk_group, num_expert_group,
nram_compute_col_num, row, nram_max_col_num * topk * sizeof(float), group_size, true);
// get topk
#if __BANG_ARCH__ < 500
__bang_transpose(src_buffer, compute_buffer, nram_compute_col_num, row);
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, src_buffer, compute_buffer,
i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
nram_max_col_num * top_num * sizeof(float), 0, false);
#else
__bang_transpose(compute_buffer, src_buffer, nram_compute_col_num, row);
getTopk(nramout_value, nramout_index, src_buffer, compute_buffer, compute_buffer, src_buffer,
i_buffer, col_buffer, topk, num_expert_group, nram_compute_col_num, row,
nram_max_col_num * top_num * sizeof(float), 0, false);
#endif
} // end else
// normalize result
if (normalize_mode == 1) {
// compute_buffer reuse as buffer for sum.
__bang_sumpool(compute_buffer, nramout_value, nram_compute_col_num, topk, 1, topk, 1, 1, 1);
__bang_recip(compute_buffer, compute_buffer, nram_compute_col_num);
__bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
nram_compute_col_num);
} else if (normalize_mode == 2) {
__bang_recip(compute_buffer, softmax_buffer, nram_compute_col_num);
__bang_cycle_mul(nramout_value, nramout_value, compute_buffer, topk * nram_compute_col_num,
nram_compute_col_num);
}
// transpose back. src and dst of transpose can not be the same address.
__bang_transpose(compute_buffer, nramout_value, topk, nram_compute_col_num);
__bang_transpose((uint32_t *)nramout_value, nramout_index, topk, nram_compute_col_num);
}
template <typename T>
__mlu_global__ void MLUSoftmaxTopkKernel(T *input,
T *mask,
int *index_out,
float *value_out,
int col,
int row,
int mask_num,
int topk,
int num_expert_group,
int topk_group,
int normalize_mode) {
bool valid_mask = (mask != nullptr);
int top_num = topk >= topk_group ? topk : topk_group;
uint32_t nram_low_space =
PAD_UP((row * 2 + top_num * 2 + 2 + (normalize_mode == 2) + num_expert_group) * sizeof(float),
SCATTER_ALIGN);
if (num_expert_group <= 1) {
nram_low_space =
PAD_UP((row * 2 + topk * 2 + 2 + (normalize_mode == 2)) * sizeof(float), SCATTER_ALIGN);
}
uint32_t nram_max_col_num = (NRAM_SIZE) / nram_low_space;
if (nram_max_col_num > col / taskDim + (col % taskDim > 0)) {
nram_max_col_num = col / taskDim + (col % taskDim > 0);
}
nram_max_col_num = PAD_DOWN(nram_max_col_num, SCATTER_ALIGN / sizeof(float));
if (nram_max_col_num <= 0) {
nram_max_col_num = SCATTER_ALIGN / sizeof(float);
}
uint32_t nram_deal_num = nram_max_col_num * row;
uint32_t batch = col / mask_num;
// nram split:
// |--------------------------|--------------------------|--------------------|...
// | size: nram/2 -col*topk*2 | size: nram/2 -col*topk*2 |col*num_expert_group|...
// | src_buffer | compute_buffer | group_max_buffer |...
// |--------------------------|--------------------------|--------------------|...
// |----------------------------------------|---------------|--------------|
// | nram_col_num*3 | col*topk | col*topk |
// | i_buffer | col_buffer | softmax_buffer | nramout_value | nramout_index|
// |----------------------------------------|---------------|--------------|
float *src_buffer = (float *)nram_buffer;
float *compute_buffer = src_buffer + PAD_UP(nram_deal_num, SCATTER_ALIGN / sizeof(float));
float *group_max_buffer = compute_buffer + nram_deal_num;
uint32_t *i_buffer = (uint32_t *)group_max_buffer + num_expert_group * nram_max_col_num;
if (num_expert_group <= 1) {
i_buffer = (uint32_t *)group_max_buffer;
}
uint32_t *col_buffer = i_buffer + nram_max_col_num;
float *softmax_buffer = (float *)col_buffer + nram_max_col_num;
if (normalize_mode != 2) {
softmax_buffer = (float *)col_buffer;
}
float *nramout_value = softmax_buffer + nram_max_col_num;
uint32_t *nramout_index = (uint32_t *)nramout_value + top_num * nram_max_col_num;
if (num_expert_group <= 1) {
nramout_index = (uint32_t *)nramout_value + topk * nram_max_col_num;
}
T *load_buffer = (T *)src_buffer;
if (std::is_same<T, half>::value || std::is_same<T, bfloat16_t>::value) {
load_buffer = load_buffer + nram_deal_num;
}
// set i_buffer
for (uint32_t i = 0; i < nram_max_col_num; i++) {
i_buffer[i] = i;
}
// input[batch, mask, low], mask[mask, low]
if (nram_max_col_num >= mask_num) { // nram can deal complete mask
bool split_mask = false;
uint32_t batch_seg = nram_max_col_num / mask_num;
uint32_t batch_rem = batch % batch_seg;
uint32_t batch_seg_num = batch / batch_seg + (batch_rem > 0);
int repeat = DIV_UP(batch_seg_num, taskDim);
for (int i = 0; i < repeat; i++) {
uint32_t seg_id = i * taskDim + taskId;
uint32_t sram_load_num = mask_num * row;
uint32_t sram_load_offset = 0;
uint32_t nram_compute_col_num = (seg_id == batch_seg_num - 1 && batch_rem > 0)
? batch_rem * mask_num
: batch_seg * mask_num;
uint32_t nram_load_num = seg_id < batch_seg_num ? nram_compute_col_num * row : 0;
uint32_t nram_store_num = seg_id < batch_seg_num ? nram_compute_col_num * topk : 0;
uint32_t nram_load_offset = seg_id * batch_seg * mask_num * row;
uint32_t nram_store_offset = seg_id * batch_seg * mask_num * topk;
// Load
if (valid_mask) {
__memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM);
}
if (nram_load_num > 0) {
__memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
}
// Compute
computeSoftmaxTopk<T>((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
softmax_buffer, row, nram_compute_col_num, mask_num, nram_max_col_num,
topk, num_expert_group, topk_group, top_num, 0, normalize_mode,
valid_mask, split_mask);
// Store
if (nram_store_num > 0) {
__memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
NRAM2GDRAM);
__memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
NRAM2GDRAM);
}
__sync_cluster();
}
} else {
bool split_mask = true;
uint32_t mask_seg = nram_max_col_num;
uint32_t mask_rem = mask_num % mask_seg;
uint32_t mask_seg_num = mask_num / mask_seg + (mask_rem > 0);
uint32_t sram_mask_seg_num = DIV_UP(mask_seg_num, coreDim);
uint32_t sram_mask_rem = mask_num % sram_mask_seg_num;
uint32_t sram_average_mask_num = mask_num / sram_mask_seg_num;
for (int i = taskIdY; i < sram_mask_seg_num * batch; i += taskDimY) {
uint32_t batch_idx = i / sram_mask_seg_num;
uint32_t mask_idx = i % sram_mask_seg_num;
uint32_t sram_deal_mask_num = sram_average_mask_num + (mask_idx < sram_mask_rem);
uint32_t sram_load_num = sram_deal_mask_num * row;
uint32_t sram_mask_offset = mask_idx < sram_mask_rem
? mask_idx * (sram_average_mask_num + 1)
: mask_idx * sram_average_mask_num + sram_mask_rem;
uint32_t sram_load_offset = sram_mask_offset * row;
uint32_t nram_average_mask_num = sram_deal_mask_num / taskDimX;
uint32_t nram_mask_rem = sram_deal_mask_num % taskDimX;
uint32_t nram_deal_mask_num = nram_average_mask_num + (taskIdX < nram_mask_rem);
uint32_t nram_load_num = nram_deal_mask_num * row;
uint32_t nram_col_offset = taskIdX < nram_mask_rem
? taskIdX * (nram_average_mask_num + 1)
: taskIdX * nram_average_mask_num + nram_mask_rem;
uint32_t nram_load_offset = (batch_idx * mask_num + sram_mask_offset + nram_col_offset) * row;
uint32_t nram_store_num = nram_deal_mask_num * topk;
uint32_t nram_store_offset =
(batch_idx * mask_num + sram_mask_offset + nram_col_offset) * topk;
// Load
if (valid_mask) {
__memcpy_async(sram_buffer, mask + sram_load_offset, sram_load_num * sizeof(T), GDRAM2SRAM);
}
if (nram_load_num > 0) {
__memcpy(load_buffer, input + nram_load_offset, nram_load_num * sizeof(T), GDRAM2NRAM);
}
// Compute
computeSoftmaxTopk<T>((T *)sram_buffer, load_buffer, src_buffer, compute_buffer,
group_max_buffer, nramout_value, nramout_index, i_buffer, col_buffer,
softmax_buffer, row, nram_deal_mask_num, mask_num, nram_max_col_num,
topk, num_expert_group, topk_group, top_num, nram_col_offset,
normalize_mode, valid_mask, split_mask);
// Store
if (nram_store_num > 0) {
__memcpy(value_out + nram_store_offset, compute_buffer, nram_store_num * sizeof(float),
NRAM2GDRAM);
__memcpy(index_out + nram_store_offset, nramout_value, nram_store_num * sizeof(int),
NRAM2GDRAM);
}
__sync_cluster();
}
}
}
} // namespace kernels
KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue,
float *reduce_weight,
int *expert_id,
const void *input,
const void *mask,
const int num_token,
const int num_expert,
const int num_mask,
const int topk,
const int num_expert_group,
const int topk_group,
const cnnlDataType_t dtype,
const int normalize_mode) {
CNdev dev;
cnCtxGetDevice(&dev);
int cluster_num;
int core_num;
CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
int top_num = topk >= topk_group ? topk : topk_group;
if (num_expert_group <= 1) {
if (num_expert > (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported."
<< "Supported max num_expert:"
<< (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)
<< ". Current num_expert:" << num_expert;
return KernelStatus::KERNEL_STATUS_FAILED;
}
} else {
if (num_expert >
(NRAM_SIZE - (top_num * 2 + 2 + num_expert_group) * sizeof(float)) / 2 / sizeof(float)) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: num_expert is too large, currently not supported."
<< "Supported max num_expert:"
<< (NRAM_SIZE - (topk * 2 + 3) * sizeof(float)) / 2 / sizeof(float)
<< ". Current num_expert:" << num_expert;
return KernelStatus::KERNEL_STATUS_FAILED;
}
}
if (topk > num_expert) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: topk is larger than num_expert."
<< "topk:" << topk << ". num_expert:" << num_expert;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (num_expert_group > 1) {
if (mask != nullptr) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, mask should be nullptr";
}
if (num_expert % num_expert_group != 0) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, num_expert should be"
<< "divisible by num_expert_group, but now num_expert:" << num_expert
<< ", num_expert_group:" << num_expert_group;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (topk_group <= 0 || topk_group > num_expert_group) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk_group should be"
<< "larger than 0 and less than or equal to num_expert_group, but now topk_group"
<< topk_group << ", num_expert group:" << num_expert_group;
return KernelStatus::KERNEL_STATUS_FAILED;
}
if (topk > (num_expert / num_expert_group) * topk_group) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: if num_expert_group > 1, topk should be less"
<< "than or equal to (num_expert / num_expert_group) * topk_group, but now"
<< "topk :" << topk << ", num_expert:" << num_expert
<< ", num_expert_group:" << num_expert_group << ", topk_group:" << topk_group;
return KernelStatus::KERNEL_STATUS_FAILED;
}
}
if (dtype == CNNL_DTYPE_FLOAT) {
kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
(float *)input, (float *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask,
topk, num_expert_group, topk_group, normalize_mode);
} else if (dtype == CNNL_DTYPE_HALF) {
kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
(half *)input, (half *)mask, expert_id, reduce_weight, num_token, num_expert, num_mask,
topk, num_expert_group, topk_group, normalize_mode);
} else if (dtype == CNNL_DTYPE_BFLOAT16) {
if (!isBf16Supported()) {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: MLU300 devices do not support bfloat16."
<< std::endl;
return KernelStatus::KERNEL_STATUS_FAILED;
}
kernels::MLUSoftmaxTopkKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
(bfloat16_t *)input, (bfloat16_t *)mask, expert_id, reduce_weight, num_token, num_expert,
num_mask, topk, num_expert_group, topk_group, normalize_mode);
} else {
std::cerr << "[invokeMoeSoftmaxTopkKernel]: source type not supported ";
return KernelStatus::KERNEL_STATUS_FAILED;
}
return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo

View File

@@ -0,0 +1,66 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_
#define CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_
#include "../kernel_utils.h"
#include "cnnl.h"
namespace tmo {
/**
* @brief Execute MOE Softmax Top-K Kernel.
*
* This function executes the MOE Softmax Top-K Kernel, which computes
* the Top-K values along a specified dimension after applying softmax to the input data.
* It is specifically designed for reduction along the lowest dimension.
*
* @param queue CNRT queue used to specify the queue for execution.
* @param reduce_weight Pointer to store the Top-K values.
* The shape must be [num_token, topk].
* @param expert_id Pointer to store the indices of the Top-K values.
* The shape must be [num_token, topk].
* @param input Pointer to the input data containing the values to be computed.
* The shape must be [num_token, num_expert].
* @param mask Pointer to the input data containing the mask value to be computed after
* computing softmax, Mask can be nullptr, which means no need to compute,
* otherwise the shape and datatype of mask should be the same as input.
* @param num_token Number of channels in the input data.
* @param num_expert Specified dimension. Note that num_expert should not exceed 32768.
* @param num_mask Number of channels in the mask data.
* @param topk Number of Top-K values to compute. topk should not be larger than num_expert.
* @param num_expert_group Group numbers of num_expert. If num_expert_group > 0, num_expert
* should be divisible by num_expert_group. Otherwise, num_expert_group and topk_group
* is not valid.
* @param topk_group Number of Top-K group values to compute. Topk_group should not be larger
* than num_expert_group.
* @param dtype Data type of the input data, should match the actual data type.
* float, half, bfloat16 is supported.
* @param normalize_mode Whether and how to normalize the output, if normalize_mode == 0, no
normalization is performed; if normalize_mode == 1, the normalized denominator is
the sum of topk; if normalize_mode == 2, the normalized denominator is the sum of
* the products of softmax_result mask.
*/
KernelStatus invokeMoeSoftmaxTopkKernel(cnrtQueue_t queue,
float *reduce_weight,
int *expert_id,
const void *input,
const void *mask,
const int num_token,
const int num_expert,
const int num_mask,
const int topk,
const int num_expert_group,
const int topk_group,
const cnnlDataType_t dtype,
const int normalize_mode);
} // namespace tmo
#endif // CSRC_KERNELS_MOE_SOFTMAX_TOPK_MLUH_