Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/moe/cast_gating.mlu
2026-02-04 17:39:32 +08:00

647 lines
33 KiB
Plaintext

#include <stdint.h>
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>
#include "cast_gating.mluh"
#include "cnnl.h"
#include "cnrt.h"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
#define DIV_UP(x, y) ((x) % (y) > 0 ? ((x) / (y) + 1) : ((x) / (y)))
#define NRAM_BUFFER_SIZE (496 * 1024)
#define WRAM_BUFFER_SIZE (512 * 1024)
#define SRAM_BUFFER_SIZE (2032 * 1024)
#ifndef ONE_LINE
#define ONE_LINE 64
#endif
#ifndef LT_NUM
#define LT_NUM 64
#endif
// Tiling parameters computed on the host by gatingTiling() and passed by
// value to the MLUCastGating kernel.
struct castGatingTileInfo {
  int32_t block = 64;       // co-dimension rows/cols handled per outer-loop step
  int32_t split_k_num = 8;  // number of K splits; partials reduced by splitKReduce when > 1
  int32_t block_k = 256;    // K elements staged per pipeline iteration
};
namespace kernels {
#pragma bang walign(16)
#ifndef ROW_PER_LT
#define ROW_PER_LT 4
#endif
#ifndef LT_SIZE
#define LT_SIZE 16
#endif
#ifndef WRAM_LT_MAP16_STRIDE
#define WRAM_LT_MAP16_STRIDE (WRAM_BUFFER_SIZE / 16)
#endif
// On-chip scratch memories shared by every kernel in this file; each
// kernel carves its own tiles out of these raw byte buffers manually.
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__wram__ int8_t wram_buffer[WRAM_BUFFER_SIZE];
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
// Asynchronously copies `size` bytes from SRAM to NRAM while converting each
// element with `convert_type` (e.g. ".cvt.rn.f32.f16()"), using the raw
// move.tiling instruction. The aligned bulk (`n` source elements, a multiple
// of 64 / src_dsize) goes through a 64-deep tiling descriptor; the remainder
// is issued as a flat b16 move.
// NOTE(review): `n` is counted in source elements while `rem` is
// `size % 64` in *bytes*, and the tail addresses are `dst + n` / `src + n`
// in each pointer's own element units — confirm these units are intentional
// for the f16/bf16 -> f32 combinations used by the wrappers below.
#define SRAM2NRAM_CONVERT_IMPL(dst, src, size, dst_dsize, src_dsize, convert_type) \
  do { \
    uint32_t align_num = 64 / src_dsize; \
    uint32_t n = PAD_DOWN(size / src_dsize, align_num); \
    uint32_t rem = size % 64; \
    if (n) { \
      __asm__ __volatile__( \
          "move.tiling.async.nram.sram.b16" \
          " [%[dst_addr]], [%[src_addr]], " \
          "%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], " \
          "%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], " \
          "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], " \
          "%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4]," \
          "%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst), \
          [src_addr] "r"(src), [src_n0] "i"(64), [src_n1] "i"(1), [src_s1] "i"(0), \
          [src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "r"(n / align_num), \
          [src_s3] "r"(align_num * src_dsize), [src_n4] "i"(1), [src_s4] "i"(0), [src_n5] "i"(1), \
          [src_s5] "i"(0), [dst_n0] "i"(64), [dst_n1] "i"(1), [dst_s1] "i"(0), [dst_n2] "i"(1), \
          [dst_s2] "i"(0), [dst_n3] "r"(n / align_num), [dst_s3] "r"(align_num * dst_dsize), \
          [dst_n4] "i"(1), [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0)); \
    } \
    \
    if (rem) { \
      __asm__ __volatile__( \
          "move.tiling.async.nram.sram.b16" \
          " [%[dst_addr]], [%[src_addr]], " \
          "%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], " \
          "%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], " \
          "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], " \
          "%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4]," \
          "%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst + n), \
          [src_addr] "r"(src + n), [src_n0] "r"(rem), [src_n1] "i"(1), [src_s1] "i"(0), \
          [src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "i"(1), [src_s3] "i"(0), [src_n4] "i"(1), \
          [src_s4] "i"(0), [src_n5] "i"(1), [src_s5] "i"(0), [dst_n0] "r"(rem), [dst_n1] "i"(1), \
          [dst_s1] "i"(0), [dst_n2] "i"(1), [dst_s2] "i"(0), [dst_n3] "i"(1), [dst_s3] "i"(0), \
          [dst_n4] "i"(1), [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0)); \
    } \
  } while (false)
// Async SRAM->NRAM staging of a half input tile, widened to float on the fly.
// `size` is in bytes of the source (callers pass warp_m * cta_k * sizeof(Ta)).
// Compiled only for __BANG_ARCH__ >= 500; a no-op on older architectures.
__mlu_func__ void warp_prompt_input(float *dst, half *src, int32_t size) {
#if __BANG_ARCH__ >= 500
  SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(half), ".cvt.rn.f32.f16()");
#endif
}
// Async SRAM->NRAM staging of a bfloat16 input tile, widened to float on the
// fly. `size` is in bytes of the source data. MLU5xx-only, like the half
// overload above.
__mlu_func__ void warp_prompt_input(float *dst, bfloat16_t *src, int32_t size) {
#if __BANG_ARCH__ >= 500
  SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(bfloat16_t), ".cvt.rn.f32.bf16()");
#endif
}
// Float inputs need no widening: plain asynchronous SRAM->NRAM copy of
// `size` bytes.
__mlu_func__ void warp_prompt_input(float *dst, float *src, int32_t size) {
  __memcpy_async(dst, src, size, SRAM2NRAM);
}
// Scatters an [n x k] weight tile from SRAM (row pitch `total_k` elements)
// into the WRAM LT layout expected by the matmul unit, converting element
// type on the fly via `convert_type`. Issued as up to three async moves:
//   1) full LT_NUM-row groups over the 64-byte-aligned portion of k,
//   2) the remaining (< LT_NUM) rows, padded up to ROW_PER_LT,
//   3) the unaligned k tail (`rem_k`), addressed from sram_src2/wram_dst2.
// Not wrapped in do/while: it deliberately declares locals (align_n, sn*,
// ds*, ...) in the caller's scope, so invoke it at most once per scope.
#define SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, n, k, total_k, dst_dsize, src_dsize, \
                               convert_type) \
  int align_n = PAD_DOWN(n, LT_NUM); \
  int sn0 = ONE_LINE; \
  int size_sn0 = sn0 / src_dsize; \
  int sn1 = ONE_LINE / src_dsize; \
  int ss1 = total_k * src_dsize; \
  int sn3 = k / size_sn0; \
  int sn4 = align_n / sn1; \
  int ss4 = sn1 * ss1; \
  int dn0 = sn0; \
  int dn1 = ROW_PER_LT; \
  int dst_k = PAD_UP(k, ONE_LINE / dst_dsize); \
  int ds1 = dst_k * dst_dsize; \
  int dn2 = sn1 / ROW_PER_LT; \
  int ds2 = WRAM_LT_MAP16_STRIDE; \
  int ds3 = sn0 * dst_dsize / src_dsize; \
  int dn4 = LT_SIZE / dn2; \
  int ds4 = dn2 * WRAM_LT_MAP16_STRIDE; \
  int dn5 = align_n / LT_NUM; \
  int ds5 = ROW_PER_LT * dst_k * dst_dsize; \
  int rem_k = k % size_sn0; \
  int8_t *sram_src2 = (int8_t *)sram_src + sn3 * size_sn0 * src_dsize; \
  int8_t *wram_dst2 = (int8_t *)wram_dst + sn3 * size_sn0 * dst_dsize; \
  if (align_n > 0 && sn3 > 0) { \
    __asm__ __volatile__( \
        "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
        "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], " \
        "%[src_n4], %[src_s4], 1, 0, %[dst_n0], " \
        "%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
        "%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type \
        ";\n\t" ::[dst_addr] "r"(wram_dst), \
        [src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1), \
        [src_n3] "r"(sn3), [src_s3] "r"(sn0), [src_n4] "r"(sn4), [src_s4] "r"(ss4), \
        [dst_n0] "r"(dn0), [dst_n1] "r"(dn1), [dst_s1] "r"(ds1), [dst_n2] "r"(dn2), \
        [dst_s2] "r"(ds2), [dst_n3] "r"(sn3), [dst_s3] "r"(ds3), [dst_n4] "r"(dn4), \
        [dst_s4] "r"(ds4), [dst_n5] "r"(dn5), [dst_s5] "r"(ds5)); \
    sram_src += align_n * total_k; \
    wram_dst += align_n / LT_SIZE * dst_k; \
  } \
  align_n = PAD_UP(n % LT_NUM, ROW_PER_LT); \
  if (align_n > 0 && sn3 > 0) { \
    sn1 = align_n; \
    dn2 = (sn1 + ROW_PER_LT - 1) / ROW_PER_LT; \
    __asm__ __volatile__( \
        "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
        "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], 1, 0, 1, 0, " \
        "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
        "1, 0, 1, 0, " convert_type ";\n\t" ::[dst_addr] "r"(wram_dst), \
        [src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1), \
        [src_n3] "r"(sn3), [src_s3] "r"(sn0), [dst_n0] "r"(dn0), [dst_n1] "r"(dn1), \
        [dst_s1] "r"(ds1), [dst_n2] "r"(dn2), [dst_s2] "r"(ds2), [dst_n3] "r"(sn3), \
        [dst_s3] "r"(ds3)); \
    sram_src += align_n * total_k; \
    wram_dst += align_n / ROW_PER_LT * WRAM_LT_MAP16_STRIDE / dst_dsize; \
  } \
  if (rem_k > 0) { \
    align_n = PAD_UP(n, LT_NUM); \
    sn0 = rem_k * src_dsize; \
    dn0 = sn0; \
    __asm__ __volatile__( \
        "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
        "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], " \
        "%[src_n4], %[src_s4], 1, 0, %[dst_n0], " \
        "%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
        "%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type \
        ";\n\t" ::[dst_addr] "r"(wram_dst2), \
        [src_addr] "r"(sram_src2), [src_n0] "r"(sn0), [src_n1] "r"(ROW_PER_LT), [src_s1] "r"(ss1), \
        [src_n3] "r"(LT_NUM / ROW_PER_LT), [src_s3] "r"(ROW_PER_LT * ss1), \
        [src_n4] "r"(align_n / LT_NUM), [src_s4] "r"(LT_NUM * ss1), [dst_n0] "r"(dn0), \
        [dst_n1] "r"(ROW_PER_LT), [dst_s1] "r"(ds1), [dst_n2] "r"(1), [dst_s2] "r"(0), \
        [dst_n3] "r"(LT_NUM / ROW_PER_LT), [dst_s3] "r"(WRAM_LT_MAP16_STRIDE), \
        [dst_n4] "r"(align_n / LT_NUM), [dst_s4] "r"(ROW_PER_LT * ds1), [dst_n5] "r"(1), \
        [dst_s5] "r"(0)); \
  }
// SRAM->WRAM weight staging for half weights: converts to float while
// scattering [warp_n x len_k] (SRAM row pitch total_k) into the LT layout.
// MLU5xx only; no-op otherwise.
__mlu_func__ void warp_prompt_weight(float *wram_dst,
                                     half *sram_src,
                                     int32_t warp_n,
                                     int32_t len_k,
                                     int32_t total_k) {
#if __BANG_ARCH__ >= 500
  SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float), sizeof(half),
                         ".cvt.rn.f32.f16()");
#endif
}
// SRAM->WRAM weight staging for bfloat16 weights: converts to float while
// scattering [warp_n x len_k] (SRAM row pitch total_k) into the LT layout.
// MLU5xx only; no-op otherwise.
__mlu_func__ void warp_prompt_weight(float *wram_dst,
                                     bfloat16_t *sram_src,
                                     int32_t warp_n,
                                     int32_t len_k,
                                     int32_t total_k) {
#if __BANG_ARCH__ >= 500
  SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float),
                         sizeof(bfloat16_t), ".cvt.rn.f32.bf16()");
#endif
}
// SRAM->WRAM weight staging without type conversion (selected when the
// stored and compute types already match): copies `n` rows of `len_k`
// elements (SRAM row pitch `total_k` elements) into the WRAM LT layout
// using strided async copies. Three phases: full LT_NUM-row groups, whole
// ROW_PER_LT sub-groups, then the final n % ROW_PER_LT rows.
template <typename T>
__mlu_func__ void warp_prompt_weight(T *wram_dst,
                                     T *sram_src,
                                     int32_t n,
                                     int32_t len_k,
                                     int32_t total_k) {
  int32_t type_size = sizeof(T);
  int32_t data_size = len_k * type_size;      // bytes per logical row
  int32_t ds0 = PAD_UP(data_size, ONE_LINE);  // WRAM row pitch, line-aligned
  int32_t ss0 = total_k * type_size;          // SRAM row pitch in bytes
  int32_t count = n / LT_NUM;
  for (int32_t i = 0; i < count; ++i) {
    // One full group: ROW_PER_LT rows per LT x LT_SIZE LTs <- LT_NUM SRAM rows.
    __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1,
                   WRAM_LT_MAP16_STRIDE, LT_SIZE - 1, ss0, LT_NUM - 1, 0, 0);
    wram_dst = (T *)((int8_t *)wram_dst + ROW_PER_LT * ds0);
    sram_src = (T *)((int8_t *)sram_src + LT_NUM * ss0);
  }
  count = n % LT_NUM / ROW_PER_LT;
  if (count > 0) {
    // Remaining whole ROW_PER_LT groups of the last partial LT_NUM block.
    __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1,
                   WRAM_LT_MAP16_STRIDE, count - 1, ss0, count * ROW_PER_LT - 1, 0, 0);
    wram_dst = (T *)((int8_t *)wram_dst + count * WRAM_LT_MAP16_STRIDE);
    sram_src = (T *)((int8_t *)sram_src + count * ROW_PER_LT * ss0);
  }
  count = n % ROW_PER_LT;
  if (count > 0) {
    // Tail rows that do not fill a ROW_PER_LT group.
    __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ss0, count - 1);
  }
}
// Evenly partitions num_total_task items over taskdim workers: each worker
// gets a contiguous chunk, and the first (num_total_task % taskdim) workers
// receive one extra item. Outputs this worker's starting offset and count.
__mlu_func__ void assignTaskEvenly(const int32_t num_total_task,
                                   const int32_t &taskid,
                                   const int32_t &taskdim,
                                   int32_t &task_offset,
                                   int32_t &num_cur_task) {
  const int32_t base = num_total_task / taskdim;
  const int32_t extra = num_total_task % taskdim;
  const bool takes_extra = taskid < extra;
  num_cur_task = takes_extra ? base + 1 : base;
  // Workers before `extra` start at taskid * (base + 1); later workers are
  // shifted by the `extra` oversize chunks that precede them.
  task_offset = taskid * base + (takes_extra ? taskid : extra);
}
// Two-way hand-off barrier between the cluster's IPU compute cores and the
// MPU core, on two barrier ids (5 and 3) with coreDim + 1 participants:
// IPUs arrive on 5 then wait on 3, while the MPU waits on 5 then arrives
// on 3 — presumably releasing each side to the other so SRAM ping-pong
// buffers can be exchanged between loader and consumers (confirm against
// the barrier ISA semantics).
__mlu_func__ void bidirectionBarrierOp() {
  int32_t bcnt = coreDim + 1;
  if (__is_ipu()) {
    __asm__ __volatile__("barrier.arrive.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt));
    __asm__ __volatile__("barrier.sync.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt));
  } else {
    __asm__ __volatile__("barrier.sync.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt));
    __asm__ __volatile__("barrier.arrive.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt));
  }
}
// Thin matmul wrapper: an [m x k] * [k x n] product expressed as a 1x1
// partial convolution; `c` is passed both as output and as a third input,
// so the result presumably accumulates into c (confirm with the
// __bang_conv_partial reference).
__mlu_func__ void __wmma(float *c, float *a, float *b, int32_t m, int32_t n, int32_t k) {
  __bang_conv_partial((float *)c, (float *)a, (float *)b, (float *)c, k, m, 1, 1, 1, 1, 1, n);
}
// Writes `count` rows of `data_num` elements (element size dt_size) from
// NRAM to GDRAM. When both strides equal the row length, the tile is
// contiguous and goes out as a single flat copy; otherwise a strided
// row-by-row descriptor is used.
__mlu_func__ void warp_store(void *ddr_dst,
                             void *nram_src,
                             const int32_t data_num,
                             const int32_t dst_stride,
                             const int32_t src_stride,
                             const int32_t count,
                             const int32_t dt_size) {
  const bool contiguous = (src_stride == data_num) && (dst_stride == data_num);
  if (contiguous) {
    __memcpy_async(ddr_dst, nram_src, count * data_num * dt_size, NRAM2GDRAM);
    return;
  }
  __memcpy_async(ddr_dst, nram_src, data_num * dt_size, NRAM2GDRAM,
                 (size_t)dst_stride * dt_size, src_stride * dt_size, count - 1);
}
// Final split-K reduction: `workspace` holds split_k_num consecutive
// [M x N] partial products; each task takes a row range, gathers the same
// rows from every partial copy (stride M * N between copies), sums them
// with __bang_sumpool, and stores the result to `output` with row pitch ldc.
// NOTE(review): the NRAM buffer is reused for both Tcc input and Tc output,
// so this assumes sizeof(Tc) == sizeof(Tcc) (both are float at every call
// site in this file) — confirm before instantiating with other types.
// NOTE(review): block_m would be 0 if split_k_num * N * sizeof(Tcc) exceeds
// NRAM_BUFFER_SIZE, dividing by zero below — presumably excluded by the
// tiling (expert_num <= 128, split_k_num <= union_number); verify.
template <typename Tc, typename Tcc>
__mlu_func__ void splitKReduce(Tcc *workspace,
                               Tc *output,
                               int32_t M,
                               int32_t N,
                               int32_t split_k_num,
                               int32_t ldc) {
  int32_t offset_m, cta_m;
  assignTaskEvenly(M, taskId, taskDim, offset_m, cta_m);
  if (cta_m <= 0) return;
  // Rows whose split_k_num partial copies all fit in NRAM at once.
  int32_t block_m = NRAM_BUFFER_SIZE / split_k_num / N / sizeof(Tcc);
  int32_t repeat = cta_m / block_m + int32_t(cta_m % block_m != 0);
  int32_t rem_m = cta_m % block_m != 0 ? cta_m % block_m : block_m;
  Tcc *workspace_ddr = (Tcc *)workspace + offset_m * N;
  Tc *output_ddr = (Tc *)output + offset_m * ldc;
  for (int32_t i = 0; i < repeat; i++) {
    int32_t current_m = i == repeat - 1 ? rem_m : block_m;
    // Store as one flat copy when output rows are contiguous (ldc == N),
    // otherwise one strided copy per row.
    int32_t data_size = N * sizeof(Tc);
    int32_t data_num = current_m - 1;
    if (ldc == N) {
      data_size = current_m * N * sizeof(Tc);
      data_num = 0;
    }
    __memcpy((Tcc *)nram_buffer, (Tcc *)workspace_ddr, current_m * N * sizeof(Tcc), GDRAM2NRAM,
             current_m * N * sizeof(Tcc), M * N * sizeof(Tcc), split_k_num - 1);
    __bang_sumpool((Tcc *)nram_buffer, (Tcc *)nram_buffer, current_m * N, split_k_num, 1,
                   split_k_num, 1, 1, 1);
    __memcpy((Tc *)output_ddr, (Tc *)nram_buffer, data_size, NRAM2GDRAM, ldc * sizeof(Tc),
             N * sizeof(Tc), data_num);
    workspace_ddr = workspace_ddr + block_m * N;
    output_ddr = output_ddr + block_m * ldc;
  }
}
// Split-K cast-gating GEMM kernel (MLU5xx only): C[M, N] = A[M, K] * B[N, K]^T,
// with A cast from Ta to the compute type Tac on the fly.
// Pipeline per K step (double buffered everywhere):
//   MPU : GDRAM -> SRAM loads of the A panel and B panel,
//   IPU : SRAM -> NRAM staging of A (warp_prompt_input, widening cast) and
//         SRAM -> WRAM staging of B (warp_prompt_weight, LT layout),
//         then __wmma accumulation of the previous iteration's buffers.
// EXCHANGE_AB swaps which of M/N is blocked across clusters/cores so the
// LT_NUM-alignment waste lands on the larger dimension; the per-core result
// is then transposed before the store. With split_k_num > 1 each K-split
// cluster writes partials to `workspace`, reduced at the end by splitKReduce.
template <typename Ta,
          typename Tac,
          typename Tb,
          typename Tbc,
          typename Tc,
          typename Tcc,
          bool EXCHANGE_AB>
__mlu_global__ void MLUCastGating(Ta *A,
                                  Tb *B,
                                  Tc *C,
                                  Tcc *workspace,
                                  int32_t M,
                                  int32_t N,
                                  int32_t K,
                                  int32_t lda,
                                  int32_t ldb,
                                  int32_t ldc,
                                  castGatingTileInfo split_info) {
#if __BANG_ARCH__ >= 500
  int32_t block_k = split_info.block_k;
  int32_t grid_dimx = split_info.split_k_num;  // clusters along the K split
  int32_t block = split_info.block;
  int32_t grid_idx = clusterId % grid_dimx;  // this cluster's K-split index
  int32_t grid_idy = clusterId / grid_dimx;  // this cluster's M/N slice index
  int32_t offset_k = 0, problem_k = 0;
  assignTaskEvenly(K, grid_idx, grid_dimx, offset_k, problem_k);
  int32_t rem_k = problem_k % block_k > 0 ? problem_k % block_k : block_k;
  int32_t k_loop = problem_k / block_k + (int32_t)(problem_k % block_k > 0);
  int32_t cta_k = k_loop == 1 ? rem_k : block_k;
  int32_t cta_m = M, offset_m = 0, cta_n = N, offset_n = 0;
  int32_t warp_m = cta_m, warp_offset_m = 0;
  int32_t warp_n = cta_n, warp_offset_n = 0;
  int32_t outer_loop = 0, outer_rem = 0;
  if (EXCHANGE_AB) {
    // N is the blocked co-dimension: split it over clusters, then cores.
    assignTaskEvenly(N, grid_idy, clusterDim / grid_dimx, offset_n, cta_n);
    assignTaskEvenly(block, coreId, coreDim, warp_offset_n, warp_n);
    // Rebalance `block` so outer iterations are even, keeping
    // coreDim * LT_NUM alignment for the LT layout.
    if (cta_n > block && cta_n % block != 0) {
      int32_t block_tmp = PAD_UP((cta_n + cta_n / block) / (cta_n / block + 1), coreDim * LT_NUM);
      if (block_tmp < block) block = block_tmp;
    }
    outer_loop = (cta_n + block - 1) / block;
    outer_rem = cta_n % block == 0 ? block : cta_n % block;
  } else {
    // M is the blocked co-dimension: split it over clusters, then cores.
    assignTaskEvenly(M, grid_idy, clusterDim / grid_dimx, offset_m, cta_m);
    assignTaskEvenly(block, coreId, coreDim, warp_offset_m, warp_m);
    if (cta_m > block && cta_m % block != 0) {
      int32_t block_tmp = PAD_UP((cta_m + cta_m / block) / (cta_m / block + 1), coreDim);
      if (block_tmp < block) block = block_tmp;
    }
    outer_loop = (cta_m + block - 1) / block;
    outer_rem = cta_m % block == 0 ? block : cta_m % block;
  }
  // NRAM split: double-buffered A tiles at the front, the accumulator (and,
  // when EXCHANGE_AB, a transpose scratch of equal size) at the back.
  int32_t size_nram_buf =
      NRAM_BUFFER_SIZE - warp_m * warp_n * sizeof(Tcc) * (1 + int32_t(EXCHANGE_AB));
  int32_t pong_a_nram = size_nram_buf / 2 / sizeof(Tac);
  Tac *nbuf_a = (Tac *)nram_buffer;
  Tcc *nbuf_c = (Tcc *)(nram_buffer + size_nram_buf);
  Tcc *nbuf_out = EXCHANGE_AB ? (Tcc *)nbuf_c + warp_m * warp_n : nbuf_c;
  // SRAM ping-pong halves for the A and B panels.
  int32_t size_sram_buf = SRAM_BUFFER_SIZE;
  int32_t pong_sram_a = size_sram_buf / 2 / sizeof(Ta);
  int32_t pong_sram_b = size_sram_buf / 2 / sizeof(Tb);
  Ta *sbuf_a = (Ta *)sram_buffer;
  Tb *sbuf_b = (Tb *)((Ta *)sram_buffer + (EXCHANGE_AB ? M * block_k : block * block_k));
  int32_t pong_b_wram = WRAM_LT_MAP16_STRIDE / 2 / sizeof(Tbc);
  Tbc *wbuf_b = (Tbc *)wram_buffer;
  int32_t a_dsize = sizeof(Ta);
  int32_t b_dsize = sizeof(Tb);
  int32_t k_loop_count = 0;  // global K iteration counter driving the ping-pong
  for (int32_t j = 0; j < outer_loop; j++) {
    Ta *a_ddr = (Ta *)A + offset_k + ((size_t)offset_m + j * block) * lda * int(!EXCHANGE_AB);
    Tb *b_ddr = (Tb *)B + offset_k + ((size_t)offset_n + j * block) * ldb * int(EXCHANGE_AB);
    int32_t current_block = j == outer_loop - 1 ? outer_rem : block;
    if (EXCHANGE_AB) {
      assignTaskEvenly(current_block, coreId, coreDim, warp_offset_n, warp_n);
    } else {
      assignTaskEvenly(current_block, coreId, coreDim, warp_offset_m, warp_m);
    }
    int32_t compute_total = warp_m * warp_n;
    if (compute_total > 0 && __is_ipu()) {
      if (!EXCHANGE_AB) {
        __sync_io_move_compute(true, false, false, false, false, true);
      }
      // Fresh accumulator for this output block.
      __bang_write_zero((Tcc *)nbuf_c, compute_total);
    }
    int32_t i = 0;
    for (; i < k_loop; i++) {
      Ta *sram_a = (Ta *)sbuf_a + k_loop_count % 2 * pong_sram_a;
      Tb *sram_b = (Tb *)sbuf_b + k_loop_count % 2 * pong_sram_b;
      cta_k = i == k_loop - 1 ? rem_k : block_k;
      if (__is_mpu()) {
        // MPU loads the next K slice: the blocked operand via strided
        // memcpy, the full (non-blocked) operand via the raw strided-load
        // instruction covering all of its rows.
        if (EXCHANGE_AB) {
          __memcpy_async(sram_b, b_ddr, cta_k * b_dsize, GDRAM2SRAM, cta_k * b_dsize, ldb * b_dsize,
                         current_block - 1);
          __asm__ volatile(
              "ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], "
              "%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_a),
              [src] "r"(a_ddr), [size] "r"(cta_k * a_dsize), [dst_strd] "r"(cta_k * a_dsize),
              [src_strd] "r"(lda * a_dsize), [segnum] "r"(M - 1));
        } else {
          __memcpy_async(sram_a, a_ddr, cta_k * a_dsize, GDRAM2SRAM, cta_k * a_dsize, lda * a_dsize,
                         current_block - 1);
          __asm__ volatile(
              "ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], "
              "%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_b),
              [src] "r"(b_ddr), [size] "r"(cta_k * b_dsize), [dst_strd] "r"(cta_k * b_dsize),
              [src_strd] "r"(ldb * b_dsize), [segnum] "r"(N - 1));
        }
        a_ddr = (Ta *)a_ddr + block_k;
        b_ddr = (Tb *)b_ddr + block_k;
      }
      // Exchange the SRAM ping-pong halves between MPU and IPUs.
      bidirectionBarrierOp();
      if (__is_ipu() && compute_total > 0) {
        __sync_io_move_compute(false, true, false, false, false, true);
        __sync_io_move_compute(false, false, true, false, true, false);
        // Consume the previous iteration's staged buffers while the current
        // ones are being filled below.
        if (i >= 1) {
          __wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram,
                 (Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, block_k);
        }
        warp_prompt_input((Tac *)nbuf_a + k_loop_count % 2 * pong_a_nram,
                          sram_a + cta_k * warp_offset_m, warp_m * cta_k * sizeof(Ta));
        // mvdma bound for EXCHANGE_AB when n==32
        warp_prompt_weight((Tbc *)wbuf_b + k_loop_count % 2 * pong_b_wram,
                           (Tb *)sram_b + cta_k * warp_offset_n, warp_n, cta_k, cta_k);
      }
      k_loop_count += 1;
    }
    if (compute_total > 0) {
      // Drain the pipeline: multiply the final staged pair (K tail rem_k).
      __sync_io_move_compute(false, true, false, false, false, true);
      __wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram,
             (Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, rem_k);
      if (EXCHANGE_AB) {
        // Result is computed transposed; flip it back before storing.
        __sync_io_move_compute(true, false, false, false, false, true);
        __bang_transpose((Tcc *)nbuf_out, (Tcc *)nbuf_c, warp_m, warp_n);
      }
      // Destination: this K split's slice of workspace, or C directly when
      // there is no split-K (grid_dimx == 1).
      int32_t total_offset =
          grid_idx * M * N + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * M
                                          : (offset_m + warp_offset_m + block * j) * N);
      Tcc *wks = (Tcc *)workspace + total_offset;
      int32_t store_c_size = sizeof(Tcc);
      int8_t *store_ddr = (int8_t *)wks;
      int32_t dst_str = EXCHANGE_AB ? M : N;
      if (grid_dimx == 1) {
        // convert Tcc to Tc
        dst_str = ldc;
        store_ddr =
            (int8_t *)((Tc *)C + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * ldc
                                              : (offset_m + warp_offset_m + block * j) * ldc));
      }
      __asm__ volatile("sync.psimd.cio;\n\t");
      if (EXCHANGE_AB) {
        warp_store(store_ddr, (Tcc *)nbuf_out, warp_m, dst_str, warp_m, warp_n, store_c_size);
      } else {
        warp_store(store_ddr, (Tcc *)nbuf_out, warp_n, dst_str, warp_n, warp_m, store_c_size);
      }
    }
  }
  if (grid_dimx != 1) {
    // All K splits must land in workspace before the reduction reads it.
    __sync_all();
    splitKReduce((Tcc *)workspace, (Tc *)C, EXCHANGE_AB ? N : M, EXCHANGE_AB ? M : N,
                 split_info.split_k_num, ldc);
  }
#endif  // __BANG_ARCH__ >= 500
}
} // namespace kernels
// Picks the co-dimension block size that fits the on-chip buffers.
// EXCHANGE_AB == true  : the block counts columns of n, rounded down to
//                        whole core_num * LT_NUM granules when possible;
// EXCHANGE_AB == false : the block counts rows of m, rounded down to a
//                        multiple of core_num on the SRAM side.
int32_t getBlock(int32_t m,
                 int32_t n,
                 int32_t core_num,
                 int32_t block_k,
                 int32_t a_dtype_size,
                 int32_t b_dtype_size,
                 int32_t compute_dtype_size,
                 bool EXCHANGE_AB) {
  if (EXCHANGE_AB) {
    const int32_t co_dim = n;  // A/B swapped: n plays the co-dimension role
    // Capacity limits imposed by each memory (the *2 factors mirror the
    // kernel's double buffering of the corresponding tiles).
    const int32_t limit_nram = (NRAM_BUFFER_SIZE - co_dim * block_k * compute_dtype_size * 2) /
                               (2 * co_dim * compute_dtype_size) * core_num;
    const int32_t limit_wram =
        WRAM_BUFFER_SIZE / 2 / PAD_UP(block_k * compute_dtype_size, 64) * core_num;
    const int32_t limit_sram =
        (SRAM_BUFFER_SIZE - co_dim * block_k * a_dtype_size * 2) / (block_k * b_dtype_size * 2);
    const int32_t raw = std::min(limit_nram, std::min(limit_wram, limit_sram));
    // Prefer LT-granule alignment, but fall back to the unaligned value
    // rather than returning zero.
    const int32_t aligned = PAD_DOWN(raw, core_num * LT_NUM);
    return aligned > 0 ? aligned : raw;
  }
  const int32_t rows_nram =
      NRAM_BUFFER_SIZE / (n * compute_dtype_size + block_k * compute_dtype_size * 2);
  const int32_t rows_sram =
      (SRAM_BUFFER_SIZE - n * block_k * b_dtype_size * 2) / (block_k * a_dtype_size * 2);
  return std::min(rows_nram * core_num, PAD_DOWN(rows_sram, core_num));
}
// Host-side tiling for the cast-gating GEMM: chooses the co-dimension
// block, the split-K factor, block_k, and whether to exchange A and B.
// EXCHANGE_AB is only ever promoted to true here — the caller must pass it
// in initialized (invokeCastGating passes false).
void gatingTiling(int32_t m,
                  int32_t n,
                  int32_t k,
                  size_t a_dtype_size,
                  size_t b_dtype_size,
                  size_t compute_dtype_size,
                  size_t workspace_size,
                  int32_t union_number,
                  int32_t core_num,
                  int32_t &block,
                  int32_t &split_k_num,
                  int32_t &block_k,
                  bool &EXCHANGE_AB) {
  // Cap block_k at 512 bytes of A data per K slice.
  block_k = std::min(k, int32_t(512 / a_dtype_size));
  split_k_num = 1;
  // Swap A and B to reduce computing waste caused by LT_NUM-align of the co
  // dimension: pick the orientation whose co-dimension loses less to the
  // pad-up-to-LT_NUM rounding.
  if (m >= core_num * LT_NUM &&
      float(m) / float(PAD_UP((size_t)m, LT_NUM)) > float(n) / float(PAD_UP(n, LT_NUM))) {
    EXCHANGE_AB = true;
  }
  int32_t tmp_block = getBlock(m, n, core_num, block_k, a_dtype_size, b_dtype_size,
                               compute_dtype_size, EXCHANGE_AB);
  int32_t total_blocks = DIV_UP((size_t)m, tmp_block);
  block = tmp_block;
  // If the M blocks cannot keep every union busy and K is large, trade the
  // idle unions for split-K — provided the workspace can hold the partials.
  if (total_blocks < union_number && (size_t)k * a_dtype_size > 512 * union_number) {
    for (int32_t i = total_blocks; i <= union_number; i++) {
      if (union_number % i == 0) {
        int32_t tmp_split_k = union_number / i;
        size_t workspace_size_need = (size_t)tmp_split_k * m * n * compute_dtype_size;
        if (workspace_size >= workspace_size_need) {
          split_k_num = tmp_split_k;
          // NOTE(review): this divides by total_blocks rather than by the
          // chosen union count `i` — verify that is intended.
          block = std::min(((size_t)m + total_blocks - 1) / total_blocks, (size_t)tmp_block);
          if (EXCHANGE_AB && block > LT_NUM * core_num) {
            block = PAD_DOWN(block, LT_NUM * core_num);
          }
          break;
        }
      }
    }
  }
}
// Queries the current context's device for the maximum clusters per union
// task and the number of compute cores per cluster.
// NOTE(review): the cnCtxGetDevice return code is not checked, unlike the
// CNRT calls below — consider verifying it succeeds before using `dev`.
void getContxtInfo(int32_t *union_number, int32_t *core_num) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  CNRT_CHECK(cnrtDeviceGetAttribute(union_number, cnrtAttrMaxClusterPerUnionLimitTask, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(core_num, cnrtAttrMcorePerCluster, dev));
}
// Host entry point: launches the cast-gating GEMM
//   output[input_row, expert_num] = input[input_row, hidden_size]
//                                   x filter[expert_num, hidden_size]^T
// with `input` in half/bf16 (a_dtype) and `filter`/`output` in float.
// Depending on the tiling decision, the operands are exchanged (filter
// becomes the A operand and the template casts apply to B instead).
// Constraints enforced here: not MLU300, expert_num <= 128, and a non-NULL
// workspace must be at least 16 MiB (required by the split-K path).
KernelStatus invokeCastGating(cnrtQueue_t queue,
                              void *input,
                              void *filter,
                              void *output,
                              int input_row,
                              int expert_num,
                              int hidden_size,
                              cnnlDataType_t a_dtype,
                              void *workspace,
                              size_t workspace_size_bytes) {
  if (is_arch300()) {
    std::cerr << "[invokeCastGating]: kernel does not support MLU300 devices." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (expert_num > 128) {
    std::cerr << "[invokeCastGating]: expert_num should NOT be greater than 128." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (workspace != NULL && workspace_size_bytes < 16 * 1024 * 1024) {
    std::cerr
        << "[invokeCastGating]: workspace_size_bytes should NOT be smaller than 16 * 1024 * 1024."
        << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (workspace_size_bytes > 0 && workspace == NULL) {
    std::cerr << "[invokeCastGating]: workspace should NOT be NULL when workspace_size_bytes is "
                 "greater than 0."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  int32_t union_number, core_num;
  getContxtInfo(&union_number, &core_num);
  // Launch one union task spanning all clusters: dim.x = total core count.
  cnrtFunctionType_t func_type = cnrtFunctionType_t(union_number * core_num);
  cnrtDim3_t dim;
  dim.x = (int32_t)func_type;
  dim.y = 1;
  dim.z = 1;
  // The filter and all accumulation are in fp32 regardless of a_dtype.
  cnnlDataType_t b_dtype = CNNL_DTYPE_FLOAT;
  cnnlDataType_t compute_dtype = CNNL_DTYPE_FLOAT;
  size_t a_dtype_size = 0, b_dtype_size = 0, compute_dtype_size = 0;
  cnnlGetSizeOfDataType(a_dtype, &a_dtype_size);
  cnnlGetSizeOfDataType(b_dtype, &b_dtype_size);
  cnnlGetSizeOfDataType(compute_dtype, &compute_dtype_size);
  castGatingTileInfo split_info;
  bool EXCHANGE_AB = false;
  gatingTiling(input_row, expert_num, hidden_size, a_dtype_size, b_dtype_size, compute_dtype_size,
               workspace_size_bytes, union_number, core_num, split_info.block,
               split_info.split_k_num, split_info.block_k, EXCHANGE_AB);
  // Dispatch: when EXCHANGE_AB the kernel sees A = filter (M = expert_num)
  // and B = input (N = input_row), so the cast template parameter moves to
  // the B side; otherwise A = input (M = input_row), B = filter.
  if (a_dtype == CNNL_DTYPE_BFLOAT16) {
    if (EXCHANGE_AB) {
      kernels::MLUCastGating<float, float, bfloat16_t, float, float, float, true>
          <<<dim, func_type, queue>>>((float *)filter, (bfloat16_t *)input, (float *)output,
                                      (float *)workspace, expert_num, input_row, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    } else {
      kernels::MLUCastGating<bfloat16_t, float, float, float, float, float, false>
          <<<dim, func_type, queue>>>((bfloat16_t *)input, (float *)filter, (float *)output,
                                      (float *)workspace, input_row, expert_num, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    }
  } else if (a_dtype == CNNL_DTYPE_HALF) {
    if (EXCHANGE_AB) {
      kernels::MLUCastGating<float, float, half, float, float, float, true>
          <<<dim, func_type, queue>>>((float *)filter, (half *)input, (float *)output,
                                      (float *)workspace, expert_num, input_row, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    } else {
      kernels::MLUCastGating<half, float, float, float, float, float, false>
          <<<dim, func_type, queue>>>((half *)input, (float *)filter, (float *)output,
                                      (float *)workspace, input_row, expert_num, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    }
  } else {
    std::cerr << "[invokeCastGating]: kernel does not support this data-type." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo