#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>

#include "cast_gating.mluh"
#include "cnnl.h"
#include "cnrt.h"

// clang-format off
#include "cn_api.h"
// clang-format on

namespace tmo {

#define DIV_UP(x, y) ((x) % (y) > 0 ? ((x) / (y) + 1) : ((x) / (y)))

#define NRAM_BUFFER_SIZE (496 * 1024)
#define WRAM_BUFFER_SIZE (512 * 1024)
#define SRAM_BUFFER_SIZE (2032 * 1024)

#ifndef ONE_LINE
#define ONE_LINE 64
#endif

#ifndef LT_NUM
#define LT_NUM 64
#endif

struct castGatingTileInfo {
  int32_t block = 64;
  int32_t split_k_num = 8;
  int32_t block_k = 256;
};

namespace kernels {

#pragma bang walign(16)

#ifndef ROW_PER_LT
#define ROW_PER_LT 4
#endif

#ifndef LT_SIZE
#define LT_SIZE 16
#endif

#ifndef WRAM_LT_MAP16_STRIDE
#define WRAM_LT_MAP16_STRIDE (WRAM_BUFFER_SIZE / 16)
#endif

__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];
__wram__ int8_t wram_buffer[WRAM_BUFFER_SIZE];
__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];

// SRAM2NRAM_CONVERT_IMPL: async SRAM->NRAM tiled move of `size` bytes that can
// widen each element from src_dsize to dst_dsize on the fly (convert_type).
// The 64-byte-aligned bulk and the < 64-byte tail are issued as two moves.
#define SRAM2NRAM_CONVERT_IMPL(dst, src, size, dst_dsize, src_dsize, convert_type)            \
  do {                                                                                        \
    uint32_t align_num = 64 / src_dsize;                                                      \
    uint32_t n = PAD_DOWN(size / src_dsize, align_num);                                       \
    uint32_t rem = size % 64;                                                                 \
    if (n) {                                                                                  \
      __asm__ __volatile__(                                                                   \
          "move.tiling.async.nram.sram.b16"                                                   \
          " [%[dst_addr]], [%[src_addr]], "                                                   \
          "%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], "                           \
          "%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], "                \
          "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], "                           \
          "%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4],"                                       \
          "%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst),                 \
          [src_addr] "r"(src), [src_n0] "i"(64), [src_n1] "i"(1), [src_s1] "i"(0),            \
          [src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "r"(n / align_num),                      \
          [src_s3] "r"(align_num * src_dsize), [src_n4] "i"(1), [src_s4] "i"(0),              \
          [src_n5] "i"(1), [src_s5] "i"(0), [dst_n0] "i"(64), [dst_n1] "i"(1),                \
          [dst_s1] "i"(0), [dst_n2] "i"(1), [dst_s2] "i"(0), [dst_n3] "r"(n / align_num),     \
          [dst_s3] "r"(align_num * dst_dsize), [dst_n4] "i"(1), [dst_s4] "i"(0),              \
          [dst_n5] "i"(1), [dst_s5] "i"(0));                                                  \
    }                                                                                         \
                                                                                              \
    if (rem) {                                                                                \
      __asm__ __volatile__(                                                                   \
          "move.tiling.async.nram.sram.b16"                                                   \
          " [%[dst_addr]], [%[src_addr]], "                                                   \
          "%[src_n0], %[src_n1], %[src_s1], %[src_n2], %[src_s2], "                           \
          "%[src_n3], %[src_s3], %[src_n4], %[src_s4], %[src_n5], %[src_s5], "                \
          "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], "                           \
          "%[dst_n3], %[dst_s3], %[dst_n4], %[dst_s4],"                                       \
          "%[dst_n5], %[dst_s5]," convert_type ";\n\t" ::[dst_addr] "r"(dst + n),             \
          [src_addr] "r"(src + n), [src_n0] "r"(rem), [src_n1] "i"(1), [src_s1] "i"(0),       \
          [src_n2] "i"(1), [src_s2] "i"(0), [src_n3] "i"(1), [src_s3] "i"(0),                 \
          [src_n4] "i"(1), [src_s4] "i"(0), [src_n5] "i"(1), [src_s5] "i"(0),                 \
          [dst_n0] "r"(rem), [dst_n1] "i"(1), [dst_s1] "i"(0), [dst_n2] "i"(1),               \
          [dst_s2] "i"(0), [dst_n3] "i"(1), [dst_s3] "i"(0), [dst_n4] "i"(1),                 \
          [dst_s4] "i"(0), [dst_n5] "i"(1), [dst_s5] "i"(0));                                 \
    }                                                                                         \
  } while (false)

__mlu_func__ void warp_prompt_input(float *dst, half *src, int32_t size) {
#if __BANG_ARCH__ >= 500
  SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(half), ".cvt.rn.f32.f16()");
#endif
}

__mlu_func__ void warp_prompt_input(float *dst, bfloat16_t *src, int32_t size) {
#if __BANG_ARCH__ >= 500
  SRAM2NRAM_CONVERT_IMPL(dst, src, size, sizeof(float), sizeof(bfloat16_t),
                         ".cvt.rn.f32.bf16()");
#endif
}

__mlu_func__ void warp_prompt_input(float *dst, float *src, int32_t size) {
  __memcpy_async((float *)dst, (float *)src, size, SRAM2NRAM);
}
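// SRAM2WRAM_CONVERT_IMPL scatters an [n, k] row-major weight tile from SRAM
// into the WRAM lookup-table (LT) layout consumed by __wmma/__bang_conv_partial:
// rows are interleaved across LT_NUM logical tables, ROW_PER_LT consecutive
// rows per table, with successive tables WRAM_LT_MAP16_STRIDE bytes apart.
// Three moves handle, in order: the LT_NUM-aligned bulk of n, the ragged tail
// of n (padded up to ROW_PER_LT rows), and the ragged tail of k. With a
// convert_type of .cvt.rn.f32.f16/.bf16, elements are widened in flight.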
#define SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, n, k, total_k, dst_dsize, src_dsize,       \
                               convert_type)                                                  \
  int align_n = PAD_DOWN(n, LT_NUM);                                                          \
  int sn0 = ONE_LINE;                                                                         \
  int size_sn0 = sn0 / src_dsize;                                                             \
  int sn1 = ONE_LINE / src_dsize;                                                             \
  int ss1 = total_k * src_dsize;                                                              \
  int sn3 = k / size_sn0;                                                                     \
  int sn4 = align_n / sn1;                                                                    \
  int ss4 = sn1 * ss1;                                                                        \
  int dn0 = sn0;                                                                              \
  int dn1 = ROW_PER_LT;                                                                       \
  int dst_k = PAD_UP(k, ONE_LINE / dst_dsize);                                                \
  int ds1 = dst_k * dst_dsize;                                                                \
  int dn2 = sn1 / ROW_PER_LT;                                                                 \
  int ds2 = WRAM_LT_MAP16_STRIDE;                                                             \
  int ds3 = sn0 * dst_dsize / src_dsize;                                                      \
  int dn4 = LT_SIZE / dn2;                                                                    \
  int ds4 = dn2 * WRAM_LT_MAP16_STRIDE;                                                       \
  int dn5 = align_n / LT_NUM;                                                                 \
  int ds5 = ROW_PER_LT * dst_k * dst_dsize;                                                   \
  int rem_k = k % size_sn0;                                                                   \
  int8_t *sram_src2 = (int8_t *)sram_src + sn3 * size_sn0 * src_dsize;                        \
  int8_t *wram_dst2 = (int8_t *)wram_dst + sn3 * size_sn0 * dst_dsize;                        \
  if (align_n > 0 && sn3 > 0) {                                                               \
    __asm__ __volatile__(                                                                     \
        "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], "           \
        "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], "                                  \
        "%[src_n4], %[src_s4], 1, 0, %[dst_n0], "                                             \
        "%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], "                  \
        "%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type                           \
        ";\n\t" ::[dst_addr] "r"(wram_dst),                                                   \
        [src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1),    \
        [src_n3] "r"(sn3), [src_s3] "r"(sn0), [src_n4] "r"(sn4), [src_s4] "r"(ss4),           \
        [dst_n0] "r"(dn0), [dst_n1] "r"(dn1), [dst_s1] "r"(ds1), [dst_n2] "r"(dn2),           \
        [dst_s2] "r"(ds2), [dst_n3] "r"(sn3), [dst_s3] "r"(ds3), [dst_n4] "r"(dn4),           \
        [dst_s4] "r"(ds4), [dst_n5] "r"(dn5), [dst_s5] "r"(ds5));                             \
    sram_src += align_n * total_k;                                                            \
    wram_dst += align_n / LT_SIZE * dst_k;                                                    \
  }                                                                                           \
  align_n = PAD_UP(n % LT_NUM, ROW_PER_LT);                                                   \
  if (align_n > 0 && sn3 > 0) {                                                               \
    sn1 = align_n;                                                                            \
    dn2 = (sn1 + ROW_PER_LT - 1) / ROW_PER_LT;                                                \
    __asm__ __volatile__(                                                                     \
        "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], "           \
        "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], 1, 0, 1, 0, "                      \
        "%[dst_n0], %[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], "       \
        "1, 0, 1, 0, " convert_type ";\n\t" ::[dst_addr] "r"(wram_dst),                       \
        [src_addr] "r"(sram_src), [src_n0] "r"(sn0), [src_n1] "r"(sn1), [src_s1] "r"(ss1),    \
        [src_n3] "r"(sn3), [src_s3] "r"(sn0), [dst_n0] "r"(dn0), [dst_n1] "r"(dn1),           \
        [dst_s1] "r"(ds1), [dst_n2] "r"(dn2), [dst_s2] "r"(ds2), [dst_n3] "r"(sn3),           \
        [dst_s3] "r"(ds3));                                                                   \
    sram_src += align_n * total_k;                                                            \
    wram_dst += align_n / ROW_PER_LT * WRAM_LT_MAP16_STRIDE / dst_dsize;                      \
  }                                                                                           \
  if (rem_k > 0) {                                                                            \
    align_n = PAD_UP(n, LT_NUM);                                                              \
    sn0 = rem_k * src_dsize;                                                                  \
    dn0 = sn0;                                                                                \
    __asm__ __volatile__(                                                                     \
        "move.tiling.async.wram.sram.b16 [%[dst_addr]], [%[src_addr]], %[src_n0], "           \
        "%[src_n1], %[src_s1], 1, 0, %[src_n3], %[src_s3], "                                  \
        "%[src_n4], %[src_s4], 1, 0, %[dst_n0], "                                             \
        "%[dst_n1], %[dst_s1], %[dst_n2], %[dst_s2], %[dst_n3], %[dst_s3], "                  \
        "%[dst_n4], %[dst_s4], %[dst_n5], %[dst_s5], " convert_type                           \
        ";\n\t" ::[dst_addr] "r"(wram_dst2),                                                  \
        [src_addr] "r"(sram_src2), [src_n0] "r"(sn0), [src_n1] "r"(ROW_PER_LT),               \
        [src_s1] "r"(ss1), [src_n3] "r"(LT_NUM / ROW_PER_LT), [src_s3] "r"(ROW_PER_LT * ss1), \
        [src_n4] "r"(align_n / LT_NUM), [src_s4] "r"(LT_NUM * ss1), [dst_n0] "r"(dn0),        \
        [dst_n1] "r"(ROW_PER_LT), [dst_s1] "r"(ds1), [dst_n2] "r"(1), [dst_s2] "r"(0),        \
        [dst_n3] "r"(LT_NUM / ROW_PER_LT), [dst_s3] "r"(WRAM_LT_MAP16_STRIDE),                \
        [dst_n4] "r"(align_n / LT_NUM), [dst_s4] "r"(ROW_PER_LT * ds1), [dst_n5] "r"(1),      \
        [dst_s5] "r"(0));                                                                     \
  }
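// warp_prompt_weight stages one core's [warp_n, len_k] weight slice from SRAM
// into WRAM. The half/bfloat16 overloads widen to float while copying (only
// compiled for __BANG_ARCH__ >= 500); the generic same-type overload below
// performs the equivalent strided __memcpy_async without conversion: whole
// LT_NUM-row groups first, then whole ROW_PER_LT groups, then the remainder.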
__mlu_func__ void warp_prompt_weight(float *wram_dst, half *sram_src, int32_t warp_n,
                                     int32_t len_k, int32_t total_k) {
#if __BANG_ARCH__ >= 500
  SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float),
                         sizeof(half), ".cvt.rn.f32.f16()");
#endif
}

__mlu_func__ void warp_prompt_weight(float *wram_dst, bfloat16_t *sram_src, int32_t warp_n,
                                     int32_t len_k, int32_t total_k) {
#if __BANG_ARCH__ >= 500
  SRAM2WRAM_CONVERT_IMPL(wram_dst, sram_src, warp_n, len_k, total_k, sizeof(float),
                         sizeof(bfloat16_t), ".cvt.rn.f32.bf16()");
#endif
}

template <typename T>
__mlu_func__ void warp_prompt_weight(T *wram_dst, T *sram_src, int32_t n, int32_t len_k,
                                     int32_t total_k) {
  int32_t type_size = sizeof(T);
  int32_t data_size = len_k * type_size;
  int32_t ds0 = PAD_UP(data_size, ONE_LINE);
  int32_t ss0 = total_k * type_size;
  int32_t count = n / LT_NUM;
  for (int32_t i = 0; i < count; ++i) {
    __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1,
                   WRAM_LT_MAP16_STRIDE, LT_SIZE - 1, ss0, LT_NUM - 1, 0, 0);
    wram_dst = (T *)((int8_t *)wram_dst + ROW_PER_LT * ds0);
    sram_src = (T *)((int8_t *)sram_src + LT_NUM * ss0);
  }
  count = n % LT_NUM / ROW_PER_LT;
  if (count > 0) {
    __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ROW_PER_LT - 1,
                   WRAM_LT_MAP16_STRIDE, count - 1, ss0, count * ROW_PER_LT - 1, 0, 0);
    wram_dst = (T *)((int8_t *)wram_dst + count * WRAM_LT_MAP16_STRIDE);
    sram_src = (T *)((int8_t *)sram_src + count * ROW_PER_LT * ss0);
  }
  count = n % ROW_PER_LT;
  if (count > 0) {
    __memcpy_async(wram_dst, sram_src, data_size, SRAM2WRAM, ds0, ss0, count - 1);
  }
}

// Distribute num_total_task items over taskdim workers as evenly as possible;
// e.g. 10 items over 4 workers yields per-worker counts 3, 3, 2, 2.
__mlu_func__ void assignTaskEvenly(const int32_t num_total_task, const int32_t &taskid,
                                   const int32_t &taskdim, int32_t &task_offset,
                                   int32_t &num_cur_task) {
  int32_t num_per_task = num_total_task / taskdim;
  int32_t rem_idx = num_total_task % taskdim;
  if (taskid < rem_idx) {
    task_offset = taskid * (num_per_task + 1);
    num_cur_task = num_per_task + 1;
  } else {
    task_offset = taskid * num_per_task + rem_idx;
    num_cur_task = num_per_task;
  }
}

// Paired barrier between the MPU (IO core) and the cluster's IPUs: IPUs arrive
// on one barrier and wait on the other, while the MPU does the reverse.
__mlu_func__ void bidirectionBarrierOp() {
  int32_t bcnt = coreDim + 1;
  if (__is_ipu()) {
    __asm__ __volatile__("barrier.arrive.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt));
    __asm__ __volatile__("barrier.sync.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt));
  } else {
    __asm__ __volatile__("barrier.sync.local.pmv.cio 5, %[cnt];\n\t" ::[cnt] "r"(bcnt));
    __asm__ __volatile__("barrier.arrive.local.pio.cmv 3, %[cnt];\n\t" ::[cnt] "r"(bcnt));
  }
}

// Matrix multiply-accumulate over k, expressed as a 1x1 convolution.
__mlu_func__ void __wmma(float *c, float *a, float *b, int32_t m, int32_t n, int32_t k) {
  __bang_conv_partial((float *)c, (float *)a, (float *)b, (float *)c, k, m, 1, 1, 1, 1, 1, n);
}

// Store `count` rows of `data_num` elements from NRAM to GDRAM, collapsing to
// a single contiguous copy when both sides are densely packed.
__mlu_func__ void warp_store(void *ddr_dst, void *nram_src, const int32_t data_num,
                             const int32_t dst_stride, const int32_t src_stride,
                             const int32_t count, const int32_t dt_size) {
  if (src_stride == data_num && dst_stride == data_num) {
    __memcpy_async(ddr_dst, nram_src, count * data_num * dt_size, NRAM2GDRAM);
  } else {
    __memcpy_async(ddr_dst, nram_src, data_num * dt_size, NRAM2GDRAM,
                   (size_t)dst_stride * dt_size, src_stride * dt_size, count - 1);
  }
}
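// splitKReduce: when K was split across split_k_num clusters, each cluster has
// written a partial [M, N] product into workspace (partials are M * N elements
// apart). Each task gathers its row range from every partial into NRAM, sums
// them with __bang_sumpool (data viewed as height = split_k_num, channels =
// current_m * N), and stores the reduced rows to the output with stride ldc.
// Worked example: split_k_num = 4, current_m = 2 -> four strided chunks of
// 2 * N floats are stacked in NRAM and collapsed into a single 2 * N result.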
template <typename Tcc, typename Tc>
__mlu_func__ void splitKReduce(Tcc *workspace, Tc *output, int32_t M, int32_t N,
                               int32_t split_k_num, int32_t ldc) {
  int32_t offset_m, cta_m;
  assignTaskEvenly(M, taskId, taskDim, offset_m, cta_m);
  if (cta_m <= 0) return;
  int32_t block_m = NRAM_BUFFER_SIZE / split_k_num / N / sizeof(Tcc);
  int32_t repeat = cta_m / block_m + int32_t(cta_m % block_m != 0);
  int32_t rem_m = cta_m % block_m != 0 ? cta_m % block_m : block_m;
  Tcc *workspace_ddr = (Tcc *)workspace + offset_m * N;
  Tc *output_ddr = (Tc *)output + offset_m * ldc;
  for (int32_t i = 0; i < repeat; i++) {
    int32_t current_m = i == repeat - 1 ? rem_m : block_m;
    int32_t data_size = N * sizeof(Tc);
    int32_t data_num = current_m - 1;
    if (ldc == N) {
      data_size = current_m * N * sizeof(Tc);
      data_num = 0;
    }
    __memcpy((Tcc *)nram_buffer, (Tcc *)workspace_ddr, current_m * N * sizeof(Tcc), GDRAM2NRAM,
             current_m * N * sizeof(Tcc), M * N * sizeof(Tcc), split_k_num - 1);
    __bang_sumpool((Tcc *)nram_buffer, (Tcc *)nram_buffer, current_m * N, split_k_num, 1,
                   split_k_num, 1, 1, 1);
    __memcpy((Tc *)output_ddr, (Tc *)nram_buffer, data_size, NRAM2GDRAM, ldc * sizeof(Tc),
             N * sizeof(Tc), data_num);
    workspace_ddr = workspace_ddr + block_m * N;
    output_ddr = output_ddr + block_m * ldc;
  }
}
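// MLUCastGating computes the blocked GEMM C[M, N] = A[M, K] * B[N, K]^T,
// widening the low-precision operand on chip. K is first split across
// grid_dimx clusters (split-K; partials are reduced by splitKReduce whenever
// grid_dimx > 1), and the remaining cluster dimension tiles M (or N when
// EXCHANGE_AB, which swaps the operand roles so that the LT_NUM-aligned
// dimension is the larger one). Inside the k loop, a ping-pong scheme
// double-buffers SRAM, NRAM and WRAM so that the MPU's GDRAM->SRAM loads for
// iteration i+1 overlap the IPUs' staging and __wmma compute for iteration i.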
template <typename Ta, typename Tb, typename Tc, typename Tcc, typename Tac, typename Tbc,
          bool EXCHANGE_AB>
__mlu_global__ void MLUCastGating(Ta *A, Tb *B, Tc *C, Tcc *workspace, int32_t M, int32_t N,
                                  int32_t K, int32_t lda, int32_t ldb, int32_t ldc,
                                  castGatingTileInfo split_info) {
#if __BANG_ARCH__ >= 500
  int32_t block_k = split_info.block_k;
  int32_t grid_dimx = split_info.split_k_num;
  int32_t block = split_info.block;
  int32_t grid_idx = clusterId % grid_dimx;
  int32_t grid_idy = clusterId / grid_dimx;
  int32_t offset_k = 0, problem_k = 0;
  assignTaskEvenly(K, grid_idx, grid_dimx, offset_k, problem_k);
  int32_t rem_k = problem_k % block_k > 0 ? problem_k % block_k : block_k;
  int32_t k_loop = problem_k / block_k + (int32_t)(problem_k % block_k > 0);
  int32_t cta_k = k_loop == 1 ? rem_k : block_k;
  int32_t cta_m = M, offset_m = 0, cta_n = N, offset_n = 0;
  int32_t warp_m = cta_m, warp_offset_m = 0;
  int32_t warp_n = cta_n, warp_offset_n = 0;
  int32_t outer_loop = 0, outer_rem = 0;
  if (EXCHANGE_AB) {
    assignTaskEvenly(N, grid_idy, clusterDim / grid_dimx, offset_n, cta_n);
    assignTaskEvenly(block, coreId, coreDim, warp_offset_n, warp_n);
    if (cta_n > block && cta_n % block != 0) {
      int32_t block_tmp =
          PAD_UP((cta_n + cta_n / block) / (cta_n / block + 1), coreDim * LT_NUM);
      if (block_tmp < block) block = block_tmp;
    }
    outer_loop = (cta_n + block - 1) / block;
    outer_rem = cta_n % block == 0 ? block : cta_n % block;
  } else {
    assignTaskEvenly(M, grid_idy, clusterDim / grid_dimx, offset_m, cta_m);
    assignTaskEvenly(block, coreId, coreDim, warp_offset_m, warp_m);
    if (cta_m > block && cta_m % block != 0) {
      int32_t block_tmp = PAD_UP((cta_m + cta_m / block) / (cta_m / block + 1), coreDim);
      if (block_tmp < block) block = block_tmp;
    }
    outer_loop = (cta_m + block - 1) / block;
    outer_rem = cta_m % block == 0 ? block : cta_m % block;
  }
  int32_t size_nram_buf =
      NRAM_BUFFER_SIZE - warp_m * warp_n * sizeof(Tcc) * (1 + int32_t(EXCHANGE_AB));
  int32_t pong_a_nram = size_nram_buf / 2 / sizeof(Tac);
  Tac *nbuf_a = (Tac *)nram_buffer;
  Tcc *nbuf_c = (Tcc *)(nram_buffer + size_nram_buf);
  Tcc *nbuf_out = EXCHANGE_AB ? (Tcc *)nbuf_c + warp_m * warp_n : nbuf_c;
  int32_t size_sram_buf = SRAM_BUFFER_SIZE;
  int32_t pong_sram_a = size_sram_buf / 2 / sizeof(Ta);
  int32_t pong_sram_b = size_sram_buf / 2 / sizeof(Tb);
  Ta *sbuf_a = (Ta *)sram_buffer;
  Tb *sbuf_b = (Tb *)((Ta *)sram_buffer + (EXCHANGE_AB ? M * block_k : block * block_k));
  int32_t pong_b_wram = WRAM_LT_MAP16_STRIDE / 2 / sizeof(Tbc);
  Tbc *wbuf_b = (Tbc *)wram_buffer;
  int32_t a_dsize = sizeof(Ta);
  int32_t b_dsize = sizeof(Tb);
  int32_t k_loop_count = 0;
  for (int32_t j = 0; j < outer_loop; j++) {
    Ta *a_ddr = (Ta *)A + offset_k + ((size_t)offset_m + j * block) * lda * int(!EXCHANGE_AB);
    Tb *b_ddr = (Tb *)B + offset_k + ((size_t)offset_n + j * block) * ldb * int(EXCHANGE_AB);
    int32_t current_block = j == outer_loop - 1 ? outer_rem : block;
    if (EXCHANGE_AB) {
      assignTaskEvenly(current_block, coreId, coreDim, warp_offset_n, warp_n);
    } else {
      assignTaskEvenly(current_block, coreId, coreDim, warp_offset_m, warp_m);
    }
    int32_t compute_total = warp_m * warp_n;
    if (compute_total > 0 && __is_ipu()) {
      if (!EXCHANGE_AB) {
        __sync_io_move_compute(true, false, false, false, false, true);
      }
      __bang_write_zero((Tcc *)nbuf_c, compute_total);
    }
    for (int32_t i = 0; i < k_loop; i++) {
      Ta *sram_a = (Ta *)sbuf_a + k_loop_count % 2 * pong_sram_a;
      Tb *sram_b = (Tb *)sbuf_b + k_loop_count % 2 * pong_sram_b;
      cta_k = i == k_loop - 1 ? rem_k : block_k;
      if (__is_mpu()) {
        if (EXCHANGE_AB) {
          __memcpy_async(sram_b, b_ddr, cta_k * b_dsize, GDRAM2SRAM, cta_k * b_dsize,
                         ldb * b_dsize, current_block - 1);
          __asm__ volatile(
              "ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], "
              "%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_a),
              [src] "r"(a_ddr), [size] "r"(cta_k * a_dsize), [dst_strd] "r"(cta_k * a_dsize),
              [src_strd] "r"(lda * a_dsize), [segnum] "r"(M - 1));
        } else {
          __memcpy_async(sram_a, a_ddr, cta_k * a_dsize, GDRAM2SRAM, cta_k * a_dsize,
                         lda * a_dsize, current_block - 1);
          __asm__ volatile(
              "ld.async.stride.sram.gdram.scmnormal [%[dst]], [%[src]], %[size], %[dst_strd], "
              "%[src_strd], %[segnum];\n\t" ::[dst] "r"(sram_b),
              [src] "r"(b_ddr), [size] "r"(cta_k * b_dsize), [dst_strd] "r"(cta_k * b_dsize),
              [src_strd] "r"(ldb * b_dsize), [segnum] "r"(N - 1));
        }
        a_ddr = (Ta *)a_ddr + block_k;
        b_ddr = (Tb *)b_ddr + block_k;
      }
      bidirectionBarrierOp();
      if (__is_ipu() && compute_total > 0) {
        __sync_io_move_compute(false, true, false, false, false, true);
        __sync_io_move_compute(false, false, true, false, true, false);
        if (i >= 1) {
          __wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram,
                 (Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, block_k);
        }
        warp_prompt_input((Tac *)nbuf_a + k_loop_count % 2 * pong_a_nram,
                          sram_a + cta_k * warp_offset_m, warp_m * cta_k * sizeof(Ta));
        // mvdma bound for EXCHANGE_AB when n == 32
        warp_prompt_weight((Tbc *)wbuf_b + k_loop_count % 2 * pong_b_wram,
                           (Tb *)sram_b + cta_k * warp_offset_n, warp_n, cta_k, cta_k);
      }
      k_loop_count += 1;
    }
    if (compute_total > 0) {
      __sync_io_move_compute(false, true, false, false, false, true);
      __wmma(nbuf_c, (Tac *)nbuf_a + (k_loop_count - 1) % 2 * pong_a_nram,
             (Tbc *)wbuf_b + (k_loop_count - 1) % 2 * pong_b_wram, warp_m, warp_n, rem_k);
      if (EXCHANGE_AB) {
        __sync_io_move_compute(true, false, false, false, false, true);
        __bang_transpose((Tcc *)nbuf_out, (Tcc *)nbuf_c, warp_m, warp_n);
      }
      int32_t total_offset =
          grid_idx * M * N + (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * M
                                          : (offset_m + warp_offset_m + block * j) * N);
      Tcc *wks = (Tcc *)workspace + total_offset;
      int32_t store_c_size = sizeof(Tcc);
      int8_t *store_ddr = (int8_t *)wks;
      int32_t dst_str = EXCHANGE_AB ? M : N;
      if (grid_dimx == 1) {  // convert Tcc to Tc
        dst_str = ldc;
        store_ddr = (int8_t *)((Tc *)C +
                               (EXCHANGE_AB ? (offset_n + warp_offset_n + block * j) * ldc
                                            : (offset_m + warp_offset_m + block * j) * ldc));
      }
      __asm__ volatile("sync.psimd.cio;\n\t");
      if (EXCHANGE_AB) {
        warp_store(store_ddr, (Tcc *)nbuf_out, warp_m, dst_str, warp_m, warp_n, store_c_size);
      } else {
        warp_store(store_ddr, (Tcc *)nbuf_out, warp_n, dst_str, warp_n, warp_m, store_c_size);
      }
    }
  }
  if (grid_dimx != 1) {
    __sync_all();
    splitKReduce((Tcc *)workspace, (Tc *)C, EXCHANGE_AB ? N : M, EXCHANGE_AB ? M : N,
                 split_info.split_k_num, ldc);
  }
#endif  // __BANG_ARCH__ >= 500
}

}  // namespace kernels
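// getBlock: host-side heuristic for the block size along the tiled dimension,
// capped so that one block's double-buffered working set fits NRAM, WRAM and
// SRAM at once. Worked example for the WRAM cap with the default block_k = 256
// and a float compute type: WRAM_BUFFER_SIZE / 2 / PAD_UP(256 * 4, 64) * core_num
// = (512 * 1024) / 2 / 1024 * core_num = 256 * core_num rows.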
int32_t getBlock(int32_t m, int32_t n, int32_t core_num, int32_t block_k, int32_t a_dtype_size,
                 int32_t b_dtype_size, int32_t compute_dtype_size, bool EXCHANGE_AB) {
  int32_t block = 0;
  if (EXCHANGE_AB) {
    int32_t block_m = n;
    int32_t nram_block_n = (NRAM_BUFFER_SIZE - block_m * block_k * compute_dtype_size * 2) /
                           (2 * block_m * compute_dtype_size) * core_num;
    int32_t wram_block_n =
        WRAM_BUFFER_SIZE / 2 / PAD_UP(block_k * compute_dtype_size, 64) * core_num;
    int32_t sram_block_n = (SRAM_BUFFER_SIZE - block_m * block_k * a_dtype_size * 2) /
                           (block_k * b_dtype_size * 2);
    int32_t block_n_tmp = std::min(std::min(nram_block_n, wram_block_n), sram_block_n);
    int32_t block_n = PAD_DOWN(block_n_tmp, core_num * LT_NUM);
    return block_n > 0 ? block_n : block_n_tmp;
  } else {
    int32_t block_n = n;
    int32_t nram_block_m =
        NRAM_BUFFER_SIZE / (block_n * compute_dtype_size + block_k * compute_dtype_size * 2);
    int32_t sram_block_m =
        (SRAM_BUFFER_SIZE - block_n * block_k * b_dtype_size * 2) / (block_k * a_dtype_size * 2);
    block = std::min(nram_block_m * core_num, PAD_DOWN(sram_block_m, core_num));
    return block;
  }
}

void gatingTiling(int32_t m, int32_t n, int32_t k, size_t a_dtype_size, size_t b_dtype_size,
                  size_t compute_dtype_size, size_t workspace_size, int32_t union_number,
                  int32_t core_num, int32_t &block, int32_t &split_k_num, int32_t &block_k,
                  bool &EXCHANGE_AB) {
  block_k = std::min(k, int32_t(512 / a_dtype_size));
  split_k_num = 1;
  // Swap A and B to reduce the compute waste caused by LT_NUM alignment of the co dimension.
  if (m >= core_num * LT_NUM &&
      float(m) / float(PAD_UP((size_t)m, LT_NUM)) > float(n) / float(PAD_UP(n, LT_NUM))) {
    EXCHANGE_AB = true;
  }
  int32_t tmp_block = getBlock(m, n, core_num, block_k, a_dtype_size, b_dtype_size,
                               compute_dtype_size, EXCHANGE_AB);
  int32_t total_blocks = DIV_UP((size_t)m, tmp_block);
  block = tmp_block;
  if (total_blocks < union_number && (size_t)k * a_dtype_size > 512 * union_number) {
    for (int32_t i = total_blocks; i <= union_number; i++) {
      if (union_number % i == 0) {
        int32_t tmp_split_k = union_number / i;
        size_t workspace_size_need = (size_t)tmp_split_k * m * n * compute_dtype_size;
        if (workspace_size >= workspace_size_need) {
          split_k_num = tmp_split_k;
          block = std::min(((size_t)m + total_blocks - 1) / total_blocks, (size_t)tmp_block);
          if (EXCHANGE_AB && block > LT_NUM * core_num) {
            block = PAD_DOWN(block, LT_NUM * core_num);
          }
          break;
        }
      }
    }
  }
}

void getContxtInfo(int32_t *union_number, int32_t *core_num) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  CNRT_CHECK(cnrtDeviceGetAttribute(union_number, cnrtAttrMaxClusterPerUnionLimitTask, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(core_num, cnrtAttrMcorePerCluster, dev));
}
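// invokeCastGating: host entry point for the gating matmul
//   output[input_row, expert_num] =
//       input[input_row, hidden_size] x filter[expert_num, hidden_size]^T
// computed in float, with the half/bfloat16 input widened on chip. When
// gatingTiling selects EXCHANGE_AB, the kernel is launched with filter as the
// A operand and input as the B operand, and the result is transposed on store.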
KernelStatus invokeCastGating(cnrtQueue_t queue, void *input, void *filter, void *output,
                              int input_row, int expert_num, int hidden_size,
                              cnnlDataType_t a_dtype, void *workspace,
                              size_t workspace_size_bytes) {
  if (is_arch300()) {
    std::cerr << "[invokeCastGating]: kernel does not support MLU300 devices." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (expert_num > 128) {
    std::cerr << "[invokeCastGating]: expert_num should NOT be greater than 128." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (workspace != NULL && workspace_size_bytes < 16 * 1024 * 1024) {
    std::cerr << "[invokeCastGating]: workspace_size_bytes should NOT be smaller than "
                 "16 * 1024 * 1024."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  if (workspace_size_bytes > 0 && workspace == NULL) {
    std::cerr << "[invokeCastGating]: workspace should NOT be NULL when workspace_size_bytes "
                 "is greater than 0."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  int32_t union_number, core_num;
  getContxtInfo(&union_number, &core_num);
  cnrtFunctionType_t func_type = cnrtFunctionType_t(union_number * core_num);
  cnrtDim3_t dim;
  dim.x = (int32_t)func_type;
  dim.y = 1;
  dim.z = 1;
  cnnlDataType_t b_dtype = CNNL_DTYPE_FLOAT;
  cnnlDataType_t compute_dtype = CNNL_DTYPE_FLOAT;
  size_t a_dtype_size = 0, b_dtype_size = 0, compute_dtype_size = 0;
  cnnlGetSizeOfDataType(a_dtype, &a_dtype_size);
  cnnlGetSizeOfDataType(b_dtype, &b_dtype_size);
  cnnlGetSizeOfDataType(compute_dtype, &compute_dtype_size);
  castGatingTileInfo split_info;
  bool EXCHANGE_AB = false;
  gatingTiling(input_row, expert_num, hidden_size, a_dtype_size, b_dtype_size,
               compute_dtype_size, workspace_size_bytes, union_number, core_num,
               split_info.block, split_info.split_k_num, split_info.block_k, EXCHANGE_AB);
  if (a_dtype == CNNL_DTYPE_BFLOAT16) {
    if (EXCHANGE_AB) {
      kernels::MLUCastGating<float, bfloat16_t, float, float, float, float, true>
          <<<dim, func_type, queue>>>((float *)filter, (bfloat16_t *)input, (float *)output,
                                      (float *)workspace, expert_num, input_row, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    } else {
      kernels::MLUCastGating<bfloat16_t, float, float, float, float, float, false>
          <<<dim, func_type, queue>>>((bfloat16_t *)input, (float *)filter, (float *)output,
                                      (float *)workspace, input_row, expert_num, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    }
  } else if (a_dtype == CNNL_DTYPE_HALF) {
    if (EXCHANGE_AB) {
      kernels::MLUCastGating<float, half, float, float, float, float, true>
          <<<dim, func_type, queue>>>((float *)filter, (half *)input, (float *)output,
                                      (float *)workspace, expert_num, input_row, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    } else {
      kernels::MLUCastGating<half, float, float, float, float, float, false>
          <<<dim, func_type, queue>>>((half *)input, (float *)filter, (float *)output,
                                      (float *)workspace, input_row, expert_num, hidden_size,
                                      hidden_size, hidden_size, expert_num, split_info);
    }
  } else {
    std::cerr << "[invokeCastGating]: kernel does not support this data-type." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}

}  // namespace tmo