647 lines
33 KiB
Plaintext
647 lines
33 KiB
Plaintext
#include <stdint.h>
|
|
#include <algorithm>
|
|
#include <cmath>
|
|
#include <iostream>
|
|
#include <vector>
|
|
#include "cast_gating.mluh"
|
|
#include "cnnl.h"
|
|
#include "cnrt.h"
|
|
// clang-format off
|
|
#include <mlu.h>
|
|
// clang-format on
|
|
namespace tmo {
|
|
// Ceiling division without the (x + y - 1) intermediate, avoiding overflow
// for large x.
#define DIV_UP(x, y) ((x) % (y) > 0 ? ((x) / (y) + 1) : ((x) / (y)))

// Usable on-chip buffer capacities in bytes (NRAM: per-core scratch,
// WRAM: weight RAM for the matrix unit, SRAM: cluster-shared memory).
// NOTE(review): values presumably match the target MLU generation
// (__BANG_ARCH__ >= 500 paths below) — confirm against the hardware spec.
#define NRAM_BUFFER_SIZE (496 * 1024)
#define WRAM_BUFFER_SIZE (512 * 1024)
#define SRAM_BUFFER_SIZE (2032 * 1024)

// Width in bytes of one hardware transfer "line"; used to align K-dim tiles.
#ifndef ONE_LINE
#define ONE_LINE 64
#endif

// Number of LT (weight-latch) units; the output-channel dimension is padded
// to a multiple of this for matrix-unit layouts.
#ifndef LT_NUM
#define LT_NUM 64
#endif
|
|
|
|
// Tiling parameters chosen on the host (gatingTiling) and consumed by the
// MLUCastGating kernel.
struct castGatingTileInfo {
  // Rows (or columns when A/B are exchanged) of the output handled per
  // cluster tile before the next outer-loop iteration.
  int32_t block = 64;
  // Number of partitions of the K (reduction) dimension; > 1 enables the
  // workspace-based split-K reduction path.
  int32_t split_k_num = 8;
  // Elements of K consumed per inner pipeline step (double-buffered).
  int32_t block_k = 256;
};
|
|
|
|
namespace kernels {
|
|
// Align subsequent on-chip allocations to 16 bytes.
#pragma bang walign(16)

// Rows of the weight matrix packed into one LT unit by the WRAM layouts below.
#ifndef ROW_PER_LT
#define ROW_PER_LT 4
#endif

// LT units addressed per WRAM stride group (LT_NUM / ROW_PER_LT = 16).
#ifndef LT_SIZE
#define LT_SIZE 16
#endif

// Byte distance in WRAM between consecutive LT-unit mappings
// (WRAM is viewed as 16 equal banks here).
#ifndef WRAM_LT_MAP16_STRIDE
#define WRAM_LT_MAP16_STRIDE (WRAM_BUFFER_SIZE / 16)
#endif

// Whole-buffer static allocations; sub-buffers are carved out of these by
// pointer arithmetic inside the kernel (see MLUCastGating).
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__wram__ int8_t wram_buffer[WRAM_BUFFER_SIZE];
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
|
|
|
|
// Asynchronous SRAM -> NRAM copy with on-the-fly dtype conversion, issued as a
// raw `move.tiling.async.nram.sram.b16` instruction. `convert_type` is the
// conversion suffix string (e.g. ".cvt.rn.f32.f16()"). The transfer is split
// into a 64-byte-aligned main body (first asm, element count `n` padded down
// to a 64-byte multiple of source elements) and a `rem`-byte tail (second asm)
// starting at element offset `n` on both sides.
// NOTE(review): `rem = size % 64` is a byte count while `dst + n` / `src + n`
// advance by elements; this presumably matches the instruction's operand
// convention — confirm against the BANG ISA manual before modifying.
// Comments must stay outside the macro body: `//` inside a line-continued
// macro would swallow the trailing backslash.
#define SRAM2NRAM_CONVERT_IMPL(dst, src, size, dst_dsize, src_dsize, convert_type) \
  do { \
    uint32_t align_num = 64 / src_dsize; \
    uint32_t n = PAD_DOWN(size / src_dsize, align_num); \
    uint32_t rem = size % 64; \
    if (n) { \
      __asm__ __volatile__( \
          "move.tiling.async.nram.sram.b16" \
          " [%[dst_addr]], [%[src_addr]], " \
          "%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], " \
          "%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], " \
          "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], " \
          "%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4]," \
          "%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst), \
          [src_addr] "r"(src), [src_n0] "i"(64), [src_n1] "i"(1), [src_s1] "i"(0), \
          [src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "r"(n / align_num), \
          [src_s3] "r"(align_num * src_dsize), [src_n4] "i"(1), [src_s4] "i"(0), [src_n5] "i"(1), \
          [src_s5] "i"(0), [dst_n0] "i"(64), [dst_n1] "i"(1), [dst_s1] "i"(0), [dst_n2] "i"(1), \
          [dst_s2] "i"(0), [dst_n3] "r"(n / align_num), [dst_s3] "r"(align_num * dst_dsize), \
          [dst_n4] "i"(1), [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0)); \
    } \
    \
    if (rem) { \
      __asm__ __volatile__( \
          "move.tiling.async.nram.sram.b16" \
          " [%[dst_addr]], [%[src_addr]], " \
          "%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], " \
          "%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], " \
          "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], " \
          "%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4]," \
          "%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst + n), \
          [src_addr] "r"(src + n), [src_n0] "r"(rem), [src_n1] "i"(1), [src_s1] "i"(0), \
          [src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "i"(1), [src_s3] "i"(0), [src_n4] "i"(1), \
          [src_s4] "i"(0), [src_n5] "i"(1), [src_s5] "i"(0), [dst_n0] "r"(rem), [dst_n1] "i"(1), \
          [dst_s1] "i"(0), [dst_n2] "i"(1), [dst_s2] "i"(0), [dst_n3] "i"(1), [dst_s3] "i"(0), \
          [dst_n4] "i"(1), [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0)); \
    } \
  } while (false)
|
|
|
|
// Load `size` bytes of fp16 input from SRAM into NRAM, converting to fp32
// (round-to-nearest) in flight. No-op on architectures below 500, where the
// tiling-move instruction used by the macro is unavailable.
__mlu_func__ void warp_prompt_input(float *dst, half *src, int32_t size) {
#if __BANG_ARCH__ >= 500
  SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(half), ".cvt.rn.f32.f16()");
#endif
}
|
|
|
|
// Load `size` bytes of bf16 input from SRAM into NRAM, converting to fp32
// (round-to-nearest) in flight. No-op on architectures below 500.
__mlu_func__ void warp_prompt_input(float *dst, bfloat16_t *src, int32_t size) {
#if __BANG_ARCH__ >= 500
  SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(bfloat16_t), ".cvt.rn.f32.bf16()");
#endif
}
|
|
|
|
// FP32 specialization of the input loader: no dtype conversion is needed, so
// this degenerates to a plain asynchronous SRAM -> NRAM copy of `size` bytes,
// keeping the same contract as the converting fp16/bf16 overloads above.
__mlu_func__ void warp_prompt_input(float *dst, float *src, int32_t size) {
  __memcpy_async(dst, src, size, SRAM2NRAM);
}
|
|
|
|
// Asynchronous SRAM -> WRAM copy with dtype conversion, laying the weight
// matrix out in the LT-unit mapping expected by the matrix engine
// (ROW_PER_LT rows per LT, LT_SIZE LT groups per WRAM stride of
// WRAM_LT_MAP16_STRIDE bytes). The transfer is issued in up to three parts:
//   1) the LT_NUM-aligned portion of n over the ONE_LINE-aligned portion of k,
//   2) the n % LT_NUM remainder rows (padded up to ROW_PER_LT),
//   3) the k % (ONE_LINE / src_dsize) remainder columns for all rows.
// NOTE: unlike SRAM2NRAM_CONVERT_IMPL this macro is NOT wrapped in
// do { } while — it injects locals (align_n, sn*, ds*, ...) and mutates
// wram_dst/sram_src at the expansion site, so expand it at most once per scope.
// Comments must stay outside the macro body: `//` inside a line-continued
// macro would swallow the trailing backslash.
#define SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, n, k, total_k, dst_dsize, src_dsize, \
                               convert_type) \
  int align_n = PAD_DOWN(n, LT_NUM); \
  int sn0 = ONE_LINE; \
  int size_sn0 = sn0 / src_dsize; \
  int sn1 = ONE_LINE / src_dsize; \
  int ss1 = total_k * src_dsize; \
  int sn3 = k / size_sn0; \
  int sn4 = align_n / sn1; \
  int ss4 = sn1 * ss1; \
  int dn0 = sn0; \
  int dn1 = ROW_PER_LT; \
  int dst_k = PAD_UP(k, ONE_LINE / dst_dsize); \
  int ds1 = dst_k * dst_dsize; \
  int dn2 = sn1 / ROW_PER_LT; \
  int ds2 = WRAM_LT_MAP16_STRIDE; \
  int ds3 = sn0 * dst_dsize / src_dsize; \
  int dn4 = LT_SIZE / dn2; \
  int ds4 = dn2 * WRAM_LT_MAP16_STRIDE; \
  int dn5 = align_n / LT_NUM; \
  int ds5 = ROW_PER_LT * dst_k * dst_dsize; \
  int rem_k = k % size_sn0; \
  int8_t *sram_src2 = (int8_t *)sram_src + sn3 * size_sn0 * src_dsize; \
  int8_t *wram_dst2 = (int8_t *)wram_dst + sn3 * size_sn0 * dst_dsize; \
  if (align_n > 0 && sn3 > 0) { \
    __asm__ __volatile__( \
        "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
        "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], " \
        "%[src_n4], %[src_s4], 1, 0, %[dst_n0], " \
        "%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
        "%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type \
        ";\n\t" ::[dst_addr] "r"(wram_dst), \
        [src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1), \
        [src_n3] "r"(sn3), [src_s3] "r"(sn0), [src_n4] "r"(sn4), [src_s4] "r"(ss4), \
        [dst_n0] "r"(dn0), [dst_n1] "r"(dn1), [dst_s1] "r"(ds1), [dst_n2] "r"(dn2), \
        [dst_s2] "r"(ds2), [dst_n3] "r"(sn3), [dst_s3] "r"(ds3), [dst_n4] "r"(dn4), \
        [dst_s4] "r"(ds4), [dst_n5] "r"(dn5), [dst_s5] "r"(ds5)); \
    sram_src += align_n * total_k; \
    wram_dst += align_n / LT_SIZE * dst_k; \
  } \
  align_n = PAD_UP(n % LT_NUM, ROW_PER_LT); \
  if (align_n > 0 && sn3 > 0) { \
    sn1 = align_n; \
    dn2 = (sn1 + ROW_PER_LT - 1) / ROW_PER_LT; \
    __asm__ __volatile__( \
        "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
        "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], 1, 0, 1, 0, " \
        "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
        "1, 0, 1, 0, " convert_type ";\n\t" ::[dst_addr] "r"(wram_dst), \
        [src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1), \
        [src_n3] "r"(sn3), [src_s3] "r"(sn0), [dst_n0] "r"(dn0), [dst_n1] "r"(dn1), \
        [dst_s1] "r"(ds1), [dst_n2] "r"(dn2), [dst_s2] "r"(ds2), [dst_n3] "r"(sn3), \
        [dst_s3] "r"(ds3)); \
    sram_src += align_n * total_k; \
    wram_dst += align_n / ROW_PER_LT * WRAM_LT_MAP16_STRIDE / dst_dsize; \
  } \
  if (rem_k > 0) { \
    align_n = PAD_UP(n, LT_NUM); \
    sn0 = rem_k * src_dsize; \
    dn0 = sn0; \
    __asm__ __volatile__( \
        "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], " \
        "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], " \
        "%[src_n4], %[src_s4], 1, 0, %[dst_n0], " \
        "%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], " \
        "%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type \
        ";\n\t" ::[dst_addr] "r"(wram_dst2), \
        [src_addr] "r"(sram_src2), [src_n0] "r"(sn0), [src_n1] "r"(ROW_PER_LT), [src_s1] "r"(ss1), \
        [src_n3] "r"(LT_NUM / ROW_PER_LT), [src_s3] "r"(ROW_PER_LT * ss1), \
        [src_n4] "r"(align_n / LT_NUM), [src_s4] "r"(LT_NUM * ss1), [dst_n0] "r"(dn0), \
        [dst_n1] "r"(ROW_PER_LT), [dst_s1] "r"(ds1), [dst_n2] "r"(1), [dst_s2] "r"(0), \
        [dst_n3] "r"(LT_NUM / ROW_PER_LT), [dst_s3] "r"(WRAM_LT_MAP16_STRIDE), \
        [dst_n4] "r"(align_n / LT_NUM), [dst_s4] "r"(ROW_PER_LT * ds1), [dst_n5] "r"(1), \
        [dst_s5] "r"(0)); \
  }
|
|
|
|
// Stage an fp16 weight tile (warp_n rows x len_k columns, row stride total_k
// elements) from SRAM into WRAM, converting to fp32 and applying the LT-unit
// layout. No-op below arch 500.
__mlu_func__ void warp_prompt_weight(float *wram_dst,
                                     half *sram_src,
                                     int32_t warp_n,
                                     int32_t len_k,
                                     int32_t total_k) {
#if __BANG_ARCH__ >= 500
  SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float), sizeof(half),
                         ".cvt.rn.f32.f16()");
#endif
}
|
|
|
|
// Stage a bf16 weight tile (warp_n rows x len_k columns, row stride total_k
// elements) from SRAM into WRAM, converting to fp32 and applying the LT-unit
// layout. No-op below arch 500.
__mlu_func__ void warp_prompt_weight(float *wram_dst,
                                     bfloat16_t *sram_src,
                                     int32_t warp_n,
                                     int32_t len_k,
                                     int32_t total_k) {
#if __BANG_ARCH__ >= 500
  SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float),
                         sizeof(bfloat16_t), ".cvt.rn.f32.bf16()");
#endif
}
|
|
|
|
// Same-dtype weight staging: copy an n x len_k tile (row stride total_k
// elements) from SRAM into WRAM in the LT-unit layout, using strided
// __memcpy_async instead of the converting asm path. Rows are placed
// ROW_PER_LT per LT with each row padded to a ONE_LINE multiple (ds0), LT
// groups WRAM_LT_MAP16_STRIDE bytes apart.
template <typename T>
__mlu_func__ void warp_prompt_weight(T *wram_dst,
                                     T *sram_src,
                                     int32_t n,
                                     int32_t len_k,
                                     int32_t total_k) {
  int32_t type_size = sizeof(T);
  int32_t data_size = len_k * type_size;       // bytes per row
  int32_t ds0 = PAD_UP(data_size, ONE_LINE);   // padded row pitch in WRAM
  int32_t ss0 = total_k * type_size;           // source row pitch in SRAM
  // Full LT_NUM-row groups: one 3-level strided copy per group.
  int32_t count = n / LT_NUM;
  for (int32_t i = 0; i < count; ++i) {
    __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1,
                   WRAM_LT_MAP16_STRIDE, LT_SIZE - 1, ss0, LT_NUM - 1, 0, 0);
    wram_dst = (T *)((int8_t *)wram_dst + ROW_PER_LT * ds0);
    sram_src = (T *)((int8_t *)sram_src + LT_NUM * ss0);
  }
  // Remaining whole ROW_PER_LT groups within the last partial LT_NUM block.
  count = n % LT_NUM / ROW_PER_LT;
  if (count > 0) {
    __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1,
                   WRAM_LT_MAP16_STRIDE, count - 1, ss0, count * ROW_PER_LT - 1, 0, 0);
    wram_dst = (T *)((int8_t *)wram_dst + count * WRAM_LT_MAP16_STRIDE);
    sram_src = (T *)((int8_t *)sram_src + count * ROW_PER_LT * ss0);
  }
  // Final < ROW_PER_LT leftover rows: simple 1-level strided copy.
  count = n % ROW_PER_LT;
  if (count > 0) {
    __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ss0, count - 1);
  }
}
|
|
|
|
// Evenly partition num_total_task items across taskdim workers: the first
// (num_total_task % taskdim) workers each take one extra item. Writes this
// worker's starting index to task_offset and its item count to num_cur_task.
__mlu_func__ void assignTaskEvenly(const int32_t num_total_task,
                                   const int32_t &taskid,
                                   const int32_t &taskdim,
                                   int32_t &task_offset,
                                   int32_t &num_cur_task) {
  const int32_t base = num_total_task / taskdim;
  const int32_t extra = num_total_task % taskdim;
  const bool takes_extra = taskid < extra;
  num_cur_task = takes_extra ? base + 1 : base;
  // Workers before us contribute `base` items each, plus one extra item for
  // every worker (or ourselves) that falls inside the remainder range.
  task_offset = taskid * base + (takes_extra ? taskid : extra);
}
|
|
|
|
// Pairwise handshake between the cluster's compute cores (IPUs) and the
// memory core (MPU) using two raw barriers: the IPU side arrives at barrier 5
// and waits on barrier 3, while the MPU side does the reverse. Together they
// let IO (MPU-side GDRAM->SRAM loads) and compute proceed in a producer/
// consumer pipeline. Barrier count is coreDim + 1 (all IPUs plus the MPU).
// NOTE(review): barrier IDs 5 and 3 and the pmv/cio vs pio/cmv scopes are
// ISA-level details — confirm against the BANG ISA manual before changing.
__mlu_func__ void bidirectionBarrierOp() {
  int32_t bcnt = coreDim + 1;
  if (__is_ipu()) {
    __asm__ __volatile__("barrier.arrive.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt));
    __asm__ __volatile__("barrier.sync.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt));
  } else {
    __asm__ __volatile__("barrier.sync.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt));
    __asm__ __volatile__("barrier.arrive.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt));
  }
}
|
|
|
|
// m x n x k matrix-multiply on the matrix unit, expressed as a 1x1 convolution
// via __bang_conv_partial. `c` is passed as both the partial-sum input and the
// output, so this presumably accumulates: c += a * b — confirm against the
// __bang_conv_partial documentation (callers zero nbuf_c first, consistent
// with accumulate semantics).
__mlu_func__ void __wmma(float *c, float *a, float *b, int32_t m, int32_t n, int32_t k) {
  __bang_conv_partial((float *)c, (float *)a, (float *)b, (float *)c, k, m, 1, 1, 1, 1, 1, n);
}
|
|
|
|
// Write `count` rows of `data_num` elements (dt_size bytes each) from NRAM to
// GDRAM. When both sides are densely packed the rows collapse into a single
// flat async copy; otherwise a strided row-wise async copy is issued.
__mlu_func__ void warp_store(void *ddr_dst,
                             void *nram_src,
                             const int32_t data_num,
                             const int32_t dst_stride,
                             const int32_t src_stride,
                             const int32_t count,
                             const int32_t dt_size) {
  const bool contiguous = (src_stride == data_num) && (dst_stride == data_num);
  if (!contiguous) {
    // Strided path: dst stride widened to size_t in bytes, count-1 segments
    // after the first row.
    __memcpy_async(ddr_dst, nram_src, data_num * dt_size, NRAM2GDRAM,
                   (size_t)dst_stride * dt_size, src_stride * dt_size, count - 1);
    return;
  }
  __memcpy_async(ddr_dst, nram_src, count * data_num * dt_size, NRAM2GDRAM);
}
|
|
|
|
// Reduce split-K partial sums: workspace holds split_k_num stacked M x N
// partial-result planes (plane stride M*N); each core sums its share of rows
// across the planes with a sumpool and writes the result to output with row
// stride ldc. Runs after __sync_all() in MLUCastGating.
// NOTE(review): the NRAM buffer is reused in place for load + sumpool, and
// data_size uses sizeof(Tc) on a buffer of Tcc — this assumes
// sizeof(Tc) == sizeof(Tcc) (both float in current instantiations); confirm
// before instantiating with mixed-width types.
template <typename Tc, typename Tcc>
__mlu_func__ void splitKReduce(Tcc *workspace,
                               Tc *output,
                               int32_t M,
                               int32_t N,
                               int32_t split_k_num,
                               int32_t ldc) {
  // Split the M rows across all cores of the launch.
  int32_t offset_m, cta_m;
  assignTaskEvenly(M, taskId, taskDim, offset_m, cta_m);
  if (cta_m <= 0) return;
  // Rows per iteration such that split_k_num copies of a block_m x N slab
  // fit into NRAM simultaneously.
  int32_t block_m = NRAM_BUFFER_SIZE / split_k_num / N / sizeof(Tcc);
  int32_t repeat = cta_m / block_m + int32_t(cta_m % block_m != 0);
  int32_t rem_m = cta_m % block_m != 0 ? cta_m % block_m : block_m;
  Tcc *workspace_ddr = (Tcc *)workspace + offset_m * N;
  Tc *output_ddr = (Tc *)output + offset_m * ldc;
  for (int32_t i = 0; i < repeat; i++) {
    int32_t current_m = i == repeat - 1 ? rem_m : block_m;
    // Output copy defaults to row-by-row; collapses to one flat copy when
    // the destination is densely packed (ldc == N).
    int32_t data_size = N * sizeof(Tc);
    int32_t data_num = current_m - 1;
    if (ldc == N) {
      data_size = current_m * N * sizeof(Tc);
      data_num = 0;
    }
    // Gather the same slab from every K-partition plane (stride M*N elems).
    __memcpy((Tcc *)nram_buffer, (Tcc *)workspace_ddr, current_m * N * sizeof(Tcc), GDRAM2NRAM,
             current_m * N * sizeof(Tcc), M * N * sizeof(Tcc), split_k_num - 1);
    // Sum the split_k_num stacked slabs element-wise (in place).
    __bang_sumpool((Tcc *)nram_buffer, (Tcc *)nram_buffer, current_m * N, split_k_num, 1,
                   split_k_num, 1, 1, 1);
    __memcpy((Tc *)output_ddr, (Tc *)nram_buffer, data_size, NRAM2GDRAM, ldc * sizeof(Tc),
             N * sizeof(Tc), data_num);
    workspace_ddr = workspace_ddr + block_m * N;
    output_ddr = output_ddr + block_m * ldc;
  }
}
|
|
|
|
// Mixed-precision gating GEMM kernel: C(M x N, leading dim ldc) =
// A(M x K, lda) * B(N x K, ldb, row = one output column's weights), casting
// A from Ta to Tac (fp16/bf16 -> fp32) on the fly. When EXCHANGE_AB the
// roles of the row/column tiling swap (B's rows are tiled instead of A's)
// and each core's result tile is transposed before the store.
// Work split: clusterId is factored into a K-partition index (grid_idx,
// split_k_num partitions — partials go to `workspace` and are reduced by
// splitKReduce) and a row/column-partition index (grid_idy). Within a
// cluster the MPU core streams GDRAM -> SRAM (double-buffered) while the
// IPU cores convert SRAM -> NRAM/WRAM and run the matrix unit, handshaking
// via bidirectionBarrierOp. Arch 500+ only.
template <typename Ta,
          typename Tac,
          typename Tb,
          typename Tbc,
          typename Tc,
          typename Tcc,
          bool EXCHANGE_AB>
__mlu_global__ void MLUCastGating(Ta *A,
                                  Tb *B,
                                  Tc *C,
                                  Tcc *workspace,
                                  int32_t M,
                                  int32_t N,
                                  int32_t K,
                                  int32_t lda,
                                  int32_t ldb,
                                  int32_t ldc,
                                  castGatingTileInfo split_info) {
#if __BANG_ARCH__ >= 500
  int32_t block_k = split_info.block_k;
  int32_t grid_dimx = split_info.split_k_num;
  int32_t block = split_info.block;
  // Factor the cluster id into (K-partition, row/col-partition) coordinates.
  int32_t grid_idx = clusterId % grid_dimx;
  int32_t grid_idy = clusterId / grid_dimx;
  int32_t offset_k = 0, problem_k = 0;
  assignTaskEvenly(K, grid_idx, grid_dimx, offset_k, problem_k);
  // K pipeline: k_loop steps of block_k, last step handles rem_k.
  int32_t rem_k = problem_k % block_k > 0 ? problem_k % block_k : block_k;
  int32_t k_loop = problem_k / block_k + (int32_t)(problem_k % block_k > 0);
  int32_t cta_k = k_loop == 1 ? rem_k : block_k;
  int32_t cta_m = M, offset_m = 0, cta_n = N, offset_n = 0;
  int32_t warp_m = cta_m, warp_offset_m = 0;
  int32_t warp_n = cta_n, warp_offset_n = 0;
  int32_t outer_loop = 0, outer_rem = 0;
  if (EXCHANGE_AB) {
    // Tile N across clusters and `block` across this cluster's cores;
    // shrink `block` to balance the tail tile (LT_NUM-aligned per core).
    assignTaskEvenly(N, grid_idy, clusterDim / grid_dimx, offset_n, cta_n);
    assignTaskEvenly(block, coreId, coreDim, warp_offset_n, warp_n);
    if (cta_n > block && cta_n % block != 0) {
      int32_t block_tmp = PAD_UP((cta_n + cta_n / block) / (cta_n / block + 1), coreDim * LT_NUM);
      if (block_tmp < block) block = block_tmp;
    }
    outer_loop = (cta_n + block - 1) / block;
    outer_rem = cta_n % block == 0 ? block : cta_n % block;
  } else {
    // Same scheme over M when A/B are not exchanged.
    assignTaskEvenly(M, grid_idy, clusterDim / grid_dimx, offset_m, cta_m);
    assignTaskEvenly(block, coreId, coreDim, warp_offset_m, warp_m);
    if (cta_m > block && cta_m % block != 0) {
      int32_t block_tmp = PAD_UP((cta_m + cta_m / block) / (cta_m / block + 1), coreDim);
      if (block_tmp < block) block = block_tmp;
    }
    outer_loop = (cta_m + block - 1) / block;
    outer_rem = cta_m % block == 0 ? block : cta_m % block;
  }

  // NRAM layout: [ double-buffered A tile | accumulator nbuf_c
  //                | transpose scratch nbuf_out (EXCHANGE_AB only) ].
  int32_t size_nram_buf =
      NRAM_BUFFER_SIZE - warp_m * warp_n * sizeof(Tcc) * (1 + int32_t(EXCHANGE_AB));
  int32_t pong_a_nram = size_nram_buf / 2 / sizeof(Tac);
  Tac *nbuf_a = (Tac *)nram_buffer;
  Tcc *nbuf_c = (Tcc *)(nram_buffer + size_nram_buf);
  Tcc *nbuf_out = EXCHANGE_AB ? (Tcc *)nbuf_c + warp_m * warp_n : nbuf_c;

  // SRAM layout: A tile region followed by B tile region, each ping-ponged.
  int32_t size_sram_buf = SRAM_BUFFER_SIZE;
  int32_t pong_sram_a = size_sram_buf / 2 / sizeof(Ta);
  int32_t pong_sram_b = size_sram_buf / 2 / sizeof(Tb);
  Ta *sbuf_a = (Ta *)sram_buffer;
  Tb *sbuf_b = (Tb *)((Ta *)sram_buffer + (EXCHANGE_AB ? M * block_k : block * block_k));

  // WRAM: double-buffered weight tile in LT layout.
  int32_t pong_b_wram = WRAM_LT_MAP16_STRIDE / 2 / sizeof(Tbc);
  Tbc *wbuf_b = (Tbc *)wram_buffer;

  int32_t a_dsize = sizeof(Ta);
  int32_t b_dsize = sizeof(Tb);
  int32_t k_loop_count = 0;  // global ping-pong phase across outer iterations
  for (int32_t j = 0; j < outer_loop; j++) {
    // Only the tiled operand advances by j * block rows; the other operand
    // is indexed from its origin (stride multiplied by 0).
    Ta *a_ddr = (Ta *)A + offset_k + ((size_t)offset_m + j * block) * lda * int(!EXCHANGE_AB);
    Tb *b_ddr = (Tb *)B + offset_k + ((size_t)offset_n + j * block) * ldb * int(EXCHANGE_AB);
    int32_t current_block = j == outer_loop - 1 ? outer_rem : block;
    if (EXCHANGE_AB) {
      assignTaskEvenly(current_block, coreId, coreDim, warp_offset_n, warp_n);
    } else {
      assignTaskEvenly(current_block, coreId, coreDim, warp_offset_m, warp_m);
    }
    int32_t compute_total = warp_m * warp_n;
    if (compute_total > 0 && __is_ipu()) {
      if (!EXCHANGE_AB) {
        __sync_io_move_compute(true, false, false, false, false, true);
      }
      // Zero the accumulator; __wmma accumulates into it across K steps.
      __bang_write_zero((Tcc *)nbuf_c, compute_total);
    }
    int32_t i = 0;
    for (; i < k_loop; i++) {
      // Select ping/pong SRAM halves for this K step.
      Ta *sram_a = (Ta *)sbuf_a + k_loop_count % 2 * pong_sram_a;
      Tb *sram_b = (Tb *)sbuf_b + k_loop_count % 2 * pong_sram_b;
      cta_k = i == k_loop - 1 ? rem_k : block_k;
      if (__is_mpu()) {
        // MPU: prefetch this K-step's A and B slices into SRAM. The tiled
        // operand uses strided __memcpy_async; the full-matrix operand uses
        // a raw strided async load over all M (resp. N) rows.
        if (EXCHANGE_AB) {
          __memcpy_async(sram_b, b_ddr, cta_k * b_dsize, GDRAM2SRAM, cta_k * b_dsize, ldb * b_dsize,
                         current_block - 1);
          __asm__ volatile(
              "ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], "
              "%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_a),
              [src] "r"(a_ddr), [size] "r"(cta_k * a_dsize), [dst_strd] "r"(cta_k * a_dsize),
              [src_strd] "r"(lda * a_dsize), [segnum] "r"(M - 1));
        } else {
          __memcpy_async(sram_a, a_ddr, cta_k * a_dsize, GDRAM2SRAM, cta_k * a_dsize, lda * a_dsize,
                         current_block - 1);
          __asm__ volatile(
              "ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], "
              "%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_b),
              [src] "r"(b_ddr), [size] "r"(cta_k * b_dsize), [dst_strd] "r"(cta_k * b_dsize),
              [src_strd] "r"(ldb * b_dsize), [segnum] "r"(N - 1));
        }
        a_ddr = (Ta *)a_ddr + block_k;
        b_ddr = (Tb *)b_ddr + block_k;
      }
      // Hand off: MPU publishes the freshly loaded SRAM buffers to the IPUs.
      bidirectionBarrierOp();
      if (__is_ipu() && compute_total > 0) {
        __sync_io_move_compute(false, true, false, false, false, true);
        __sync_io_move_compute(false, false, true, false, true, false);
        // Software pipeline: multiply the PREVIOUS step's staged tiles while
        // this step's tiles are being converted/staged below.
        if (i >= 1) {
          __wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram,
                 (Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, block_k);
        }
        warp_prompt_input((Tac *)nbuf_a + k_loop_count % 2 * pong_a_nram,
                          sram_a + cta_k * warp_offset_m, warp_m * cta_k * sizeof(Ta));
        // mvdma bound for EXCHANGE_AB when n==32
        warp_prompt_weight((Tbc *)wbuf_b + k_loop_count % 2 * pong_b_wram,
                           (Tb *)sram_b + cta_k * warp_offset_n, warp_n, cta_k, cta_k);
      }
      k_loop_count += 1;
    }
    if (compute_total > 0) {
      // Drain the pipeline: multiply the final staged tiles (K = rem_k).
      __sync_io_move_compute(false, true, false, false, false, true);
      __wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram,
             (Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, rem_k);
      if (EXCHANGE_AB) {
        // Result tile is (warp_m x warp_n) with the output's rows in columns;
        // transpose into nbuf_out before storing.
        __sync_io_move_compute(true, false, false, false, false, true);
        __bang_transpose((Tcc *)nbuf_out, (Tcc *)nbuf_c, warp_m, warp_n);
      }
      // Destination: the K-partition's plane in workspace, or C directly
      // when there is no split-K (grid_dimx == 1).
      int32_t total_offset =
          grid_idx * M * N + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * M
                                          : (offset_m + warp_offset_m + block * j) * N);
      Tcc *wks = (Tcc *)workspace + total_offset;
      int32_t store_c_size = sizeof(Tcc);
      int8_t *store_ddr = (int8_t *)wks;
      int32_t dst_str = EXCHANGE_AB ? M : N;
      if (grid_dimx == 1) {
        // convert Tcc to Tc
        dst_str = ldc;
        store_ddr =
            (int8_t *)((Tc *)C + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * ldc
                                              : (offset_m + warp_offset_m + block * j) * ldc));
      }
      // Wait for outstanding on-chip moves before the NRAM -> GDRAM store.
      __asm__ volatile("sync.psimd.cio;\n\t");
      if (EXCHANGE_AB) {
        warp_store(store_ddr, (Tcc *)nbuf_out, warp_m, dst_str, warp_m, warp_n, store_c_size);
      } else {
        warp_store(store_ddr, (Tcc *)nbuf_out, warp_n, dst_str, warp_n, warp_m, store_c_size);
      }
    }
  }
  // With split-K, all clusters synchronize and then cooperatively reduce the
  // partial planes in workspace into C.
  if (grid_dimx != 1) {
    __sync_all();
    splitKReduce((Tcc *)workspace, (Tc *)C, EXCHANGE_AB ? N : M, EXCHANGE_AB ? M : N,
                 split_info.split_k_num, ldc);
  }
#endif  // __BANG_ARCH__ >= 500
}
|
|
} // namespace kernels
|
|
|
|
int32_t getBlock(int32_t m,
|
|
int32_t n,
|
|
int32_t core_num,
|
|
int32_t block_k,
|
|
int32_t a_dtype_size,
|
|
int32_t b_dtype_size,
|
|
int32_t compute_dtype_size,
|
|
bool EXCHANGE_AB) {
|
|
int32_t block = 0;
|
|
if (EXCHANGE_AB) {
|
|
int32_t block_m = n;
|
|
int32_t nram_block_n = (NRAM_BUFFER_SIZE - block_m * block_k * compute_dtype_size * 2) /
|
|
(2 * block_m * compute_dtype_size) * core_num;
|
|
int32_t wram_block_n =
|
|
WRAM_BUFFER_SIZE / 2 / PAD_UP(block_k * compute_dtype_size, 64) * core_num;
|
|
int32_t sram_block_n =
|
|
(SRAM_BUFFER_SIZE - block_m * block_k * a_dtype_size * 2) / (block_k * b_dtype_size * 2);
|
|
int32_t block_n_tmp = std::min(std::min(nram_block_n, wram_block_n), sram_block_n);
|
|
int32_t block_n = PAD_DOWN(block_n_tmp, core_num * LT_NUM);
|
|
return block_n > 0 ? block_n : block_n_tmp;
|
|
} else {
|
|
int32_t block_n = n;
|
|
int32_t nram_block_m =
|
|
NRAM_BUFFER_SIZE / (block_n * compute_dtype_size + block_k * compute_dtype_size * 2);
|
|
int32_t sram_block_m =
|
|
(SRAM_BUFFER_SIZE - block_n * block_k * b_dtype_size * 2) / (block_k * a_dtype_size * 2);
|
|
block = std::min(nram_block_m * core_num, PAD_DOWN(sram_block_m, core_num));
|
|
return block;
|
|
}
|
|
}
|
|
|
|
// Host-side tiling heuristic for the gating GEMM. Outputs: block (per-cluster
// tile), split_k_num (K partitions), block_k (inner K step), and EXCHANGE_AB
// (whether to swap operand roles). EXCHANGE_AB is only ever set to true here;
// the caller must initialize it to false beforehand (invokeCastGating does).
void gatingTiling(int32_t m,
                  int32_t n,
                  int32_t k,
                  size_t a_dtype_size,
                  size_t b_dtype_size,
                  size_t compute_dtype_size,
                  size_t workspace_size,
                  int32_t union_number,
                  int32_t core_num,
                  int32_t &block,
                  int32_t &split_k_num,
                  int32_t &block_k,
                  bool &EXCHANGE_AB) {
  // Cap the K step at 512 bytes of A per row.
  block_k = std::min(k, int32_t(512 / a_dtype_size));
  split_k_num = 1;
  // swap A and B to reduce computing waste caused by LT_NUM-alignment of the
  // co (output-channel) dimension: exchange when m pads to LT_NUM with less
  // relative waste than n does.
  if (m >= core_num * LT_NUM &&
      float(m) / float(PAD_UP((size_t)m, LT_NUM)) > float(n) / float(PAD_UP(n, LT_NUM))) {
    EXCHANGE_AB = true;
  }
  int32_t tmp_block = getBlock(m, n, core_num, block_k, a_dtype_size, b_dtype_size,
                               compute_dtype_size, EXCHANGE_AB);
  int32_t total_blocks = DIV_UP((size_t)m, tmp_block);
  block = tmp_block;
  // If there are fewer row tiles than clusters and K is large enough, give
  // the idle clusters K partitions instead: pick the largest split_k that
  // divides the launch evenly and whose partial-sum planes fit the workspace.
  if (total_blocks < union_number && (size_t)k * a_dtype_size > 512 * union_number) {
    for (int32_t i = total_blocks; i <= union_number; i++) {
      if (union_number % i == 0) {
        int32_t tmp_split_k = union_number / i;
        size_t workspace_size_need = (size_t)tmp_split_k * m * n * compute_dtype_size;
        if (workspace_size >= workspace_size_need) {
          split_k_num = tmp_split_k;
          block = std::min(((size_t)m + total_blocks - 1) / total_blocks, (size_t)tmp_block);
          if (EXCHANGE_AB && block > LT_NUM * core_num) {
            block = PAD_DOWN(block, LT_NUM * core_num);
          }
          break;
        }
      }
    }
  }
}
|
|
|
|
// Query launch-shape limits of the current device: max clusters per union
// task (-> union_number) and compute cores per cluster (-> core_num).
// NOTE(review): the cnCtxGetDevice driver-API return code is not checked; on
// failure `dev` would be used uninitialized — consider adding a CN_CHECK-style
// guard.
void getContxtInfo(int32_t *union_number, int32_t *core_num) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  CNRT_CHECK(cnrtDeviceGetAttribute(union_number, cnrtAttrMaxClusterPerUnionLimitTask, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(core_num, cnrtAttrMcorePerCluster, dev));
}
|
|
|
|
// Host entry point: gating GEMM output(input_row x expert_num) =
// input(input_row x hidden_size, fp16/bf16) * filter(expert_num x hidden_size,
// fp32), computed in fp32. Validates arguments, derives the tiling, then
// dispatches one of four MLUCastGating instantiations (dtype x EXCHANGE_AB;
// when exchanged, filter becomes operand A and input operand B).
// workspace (>= 16 MiB when given) backs the split-K partial sums.
KernelStatus invokeCastGating(cnrtQueue_t queue,
                              void *input,
                              void *filter,
                              void *output,
                              int input_row,
                              int expert_num,
                              int hidden_size,
                              cnnlDataType_t a_dtype,
                              void *workspace,
                              size_t workspace_size_bytes) {
  // The kernel relies on arch-500+ instructions (see __BANG_ARCH__ guards).
  if (is_arch300()) {
    std::cerr << "[invokeCastGating]: kernel does not support MLU300 devices." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  if (expert_num > 128) {
    std::cerr << "[invokeCastGating]: expert_num should NOT be greater than 128." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  if (workspace != NULL && workspace_size_bytes < 16 * 1024 * 1024) {
    std::cerr
        << "[invokeCastGating]: workspace_size_bytes should NOT be smaller than 16 * 1024 * 1024."
        << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  if (workspace_size_bytes > 0 && workspace == NULL) {
    std::cerr << "[invokeCastGating]: workspace should NOT be NULL when workspace_size_bytes is "
                 "greater than 0."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  // Launch one union task covering all clusters: dim.x = clusters * cores.
  int32_t union_number, core_num;
  getContxtInfo(&union_number, &core_num);
  cnrtFunctionType_t func_type = cnrtFunctionType_t(union_number * core_num);
  cnrtDim3_t dim;
  dim.x = (int32_t)func_type;
  dim.y = 1;
  dim.z = 1;

  // Filter and accumulation are always fp32; only A's dtype varies.
  cnnlDataType_t b_dtype = CNNL_DTYPE_FLOAT;
  cnnlDataType_t compute_dtype = CNNL_DTYPE_FLOAT;
  size_t a_dtype_size = 0, b_dtype_size = 0, compute_dtype_size = 0;
  cnnlGetSizeOfDataType(a_dtype, &a_dtype_size);
  cnnlGetSizeOfDataType(b_dtype, &b_dtype_size);
  cnnlGetSizeOfDataType(compute_dtype, &compute_dtype_size);
  castGatingTileInfo split_info;
  // gatingTiling only sets EXCHANGE_AB to true; it must start false here.
  bool EXCHANGE_AB = false;
  gatingTiling(input_row, expert_num, hidden_size, a_dtype_size, b_dtype_size, compute_dtype_size,
               workspace_size_bytes, union_number, core_num, split_info.block,
               split_info.split_k_num, split_info.block_k, EXCHANGE_AB);
  // Template args: <Ta, Tac, Tb, Tbc, Tc, Tcc, EXCHANGE_AB>. In the exchanged
  // variants filter (fp32) is operand A with M = expert_num, N = input_row.
  if (a_dtype == CNNL_DTYPE_BFLOAT16) {
    if (EXCHANGE_AB) {
      kernels::MLUCastGating<float, float, bfloat16_t, float, float, float, true>
          <<<dim, func_type, queue>>>((float *)filter, (bfloat16_t *)input, (float *)output,
                                      (float *)workspace, expert_num, input_row, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    } else {
      kernels::MLUCastGating<bfloat16_t, float, float, float, float, float, false>
          <<<dim, func_type, queue>>>((bfloat16_t *)input, (float *)filter, (float *)output,
                                      (float *)workspace, input_row, expert_num, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    }
  } else if (a_dtype == CNNL_DTYPE_HALF) {
    if (EXCHANGE_AB) {
      kernels::MLUCastGating<float, float, half, float, float, float, true>
          <<<dim, func_type, queue>>>((float *)filter, (half *)input, (float *)output,
                                      (float *)workspace, expert_num, input_row, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    } else {
      kernels::MLUCastGating<half, float, float, float, float, float, false>
          <<<dim, func_type, queue>>>((half *)input, (float *)filter, (float *)output,
                                      (float *)workspace, input_row, expert_num, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    }
  } else {
    std::cerr << "[invokeCastGating]: kernel does not support this data-type." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
|
} // namespace tmo
|