#include <cstdint>
#include <iostream>

#include "cnnl.h"
#include "cnrt.h"
#include "embedding.mluh"

// clang-format off
#include "cn_api.h"  // CNDrv driver API (CNdev / cnCtxGetDevice); header name assumed
// clang-format on

namespace tmo {
namespace kernels {

#define MAX_UINT32 (4294967295)
#define MAX_SINT32 (2147483647)
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)

__nram__ int8_t nram_buffer[NRAM_SIZE];

// Split `total` items evenly across `num` tasks: task `id` handles `every`
// items starting at `offset`; the first `total % num` tasks take one extra.
__mlu_func__ void split(const int total, const int num, const int id, int &every, int &offset) {
  int base = total / num;
  int tail = total - base * num;
  every = base + (id < tail ? 1 : 0);
  offset = base * id + (id < tail ? id : tail);
}

#define PAD_DOWN(x, y) (((x) / (y)) * (y))
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))

template <typename T>
__mlu_func__ void embeddingImpl_500(T *filter, int *input_ids, T *output, int vocab_offset,
                                    int vocab_size, int input_size, int total_seq) {
  if (__is_mpu()) {
    return;
  }
  int bs_core = 0;
  int bs_offset = 0;
  split(total_seq, taskDim, taskId, bs_core, bs_offset);
  // 8 * sizeof(int) is reserved for mask_nram because the element count passed
  // to __bang_eq_bitindex must be divisible by 8.
  int limit = (NRAM_SIZE - input_size * sizeof(T) - 8 * sizeof(int)) /
              (input_size * sizeof(T) + 4 * sizeof(int) + sizeof(int8_t));
  int vocab_start = vocab_offset;
  int vocab_end = vocab_offset + vocab_size - 1;
  T *zeros_nram = (T *)nram_buffer;                                 // input_size * sizeof(T)
  T *emb_nram = zeros_nram + input_size;                            // limit * input_size * sizeof(T)
  int *ones_nram = (int *)(emb_nram + (size_t)limit * input_size);  // limit * sizeof(int)
  int *idxs_nram = ones_nram + limit;                               // limit * sizeof(int)
  int *mask_nram = idxs_nram + limit;                               // limit_pad * sizeof(int)
  int *temp_nram = mask_nram + PAD_UP(limit, 8);                    // limit * sizeof(int)
  uint8_t *zeros_offset_nram = (uint8_t *)(temp_nram + limit);      // limit * sizeof(int8_t)
  __bang_write_zero(zeros_nram, input_size);
  __bang_write_zero(zeros_offset_nram, limit);
  __bang_write_value(ones_nram, limit, 1);
  int repeat = bs_core / limit;
  int remain = bs_core % limit;
  for (int i = 0; i < repeat + 1; i++) {
    if ((i == repeat) && (remain == 0)) {
      return;
    }
    int num = (i == repeat) ? remain : limit;
    int num_pad = PAD_UP(num, 8);  // for __bang_eq_bitindex
    __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
    __sync();
    __bang_ge_scalar(mask_nram, idxs_nram, vocab_start, num);
    __bang_lt_scalar(temp_nram, idxs_nram, vocab_end + 1, num);
    __bang_mul(mask_nram, mask_nram, temp_nram, num);
    __bang_eq_bitindex((float *)mask_nram, (float *)mask_nram, (float *)ones_nram,
                       num_pad);                                 // gather valid mask
    __bang_bnot((int8_t *)temp_nram, (int8_t *)mask_nram, num);  // gather invalid mask
    __bang_sub_scalar(idxs_nram, idxs_nram, vocab_offset, num);  // true index
    __bang_mul_scalar((unsigned int *)idxs_nram, (unsigned int *)idxs_nram,
                      (unsigned int)input_size * sizeof(T), num);  // gather offset
    __sync();
    __gather_async(emb_nram, filter, (unsigned int *)idxs_nram, mask_nram, input_size * sizeof(T),
                   GDRAM2NRAM, input_size * sizeof(T), num);
    __gather_async(emb_nram, zeros_nram, zeros_offset_nram, temp_nram, input_size * sizeof(T),
                   NRAM2NRAM, input_size * sizeof(T), num);
    __sync();
    __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
                   num * input_size * sizeof(T), NRAM2GDRAM);
    __sync();
  }
}
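// NRAM budget of the gather path above, worked through with illustrative
// numbers (not from the source): with T = half, input_size = 4096 and
// __MLU_NRAM_SIZE__ = 512, each buffered token costs
//   input_size * sizeof(T) + 4 * sizeof(int) + sizeof(int8_t)
//     = 8192 + 16 + 1 = 8209 bytes,
// so limit = (512 * 1024 - 32 * 1024 - 8192 - 32) / 8209 = 58 tokens are
// resident per loop iteration. Out-of-range ids are handled without
// branches: one gather pulls valid rows from the GDRAM filter, and a second
// gather, masked with the bitwise complement, overwrites invalid rows with
// zeros from zeros_nram.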
template <typename T>
__mlu_func__ void write_zero(T *dst, unsigned int elem_count) {
  __bang_write_zero(dst, elem_count);
}

// bfloat16 writes are only available on BANG arch >= 500; elsewhere this
// specialization compiles to a no-op.
template <>
__mlu_func__ void write_zero(bfloat16_t *dst, unsigned int elem_count) {
#if __BANG_ARCH__ >= 500
  __bang_write_zero(dst, elem_count);
#endif
}

template <typename T>
__mlu_func__ void embeddingImpl_300(T *filter, int *input_ids, T *output, int vocab_offset,
                                    int vocab_size, int input_size, int total_seq) {
  if (__is_mpu()) {
    return;
  }
  int bs_core = 0;
  int bs_offset = 0;
  split(total_seq, taskDim, taskId, bs_core, bs_offset);
  int limit = (NRAM_SIZE - 64) / (input_size * sizeof(T) + sizeof(int));
  limit = PAD_DOWN(limit, 2);
  int repeat = bs_core / limit;
  int remain = bs_core % limit;
  int vocab_start = vocab_offset;
  int vocab_end = vocab_offset + vocab_size - 1;
  T *emb_nram = (T *)nram_buffer;                                   // limit * input_size * sizeof(T)
  int *idxs_nram = (int *)(emb_nram + (size_t)limit * input_size);  // limit * sizeof(int)
  for (int i = 0; i < repeat + 1; i++) {
    if ((i == repeat) && (remain == 0)) {
      return;
    }
    int num = (i == repeat) ? remain : limit;
    __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
    __sync();
    int idx1 = idxs_nram[0];
    int idx2 = idxs_nram[1];
    bool first = (idx1 >= vocab_start && idx1 <= vocab_end);
    bool second = (idx2 >= vocab_start && idx2 <= vocab_end);
    for (int n = 0; n < num / 2 * 2; n += 2) {
      if (first && second) {
        __memcpy_async(emb_nram + n * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size,
                       input_size * sizeof(T), GDRAM2NRAM, input_size * sizeof(T),
                       (idx2 - idx1) * input_size * sizeof(T), 1);
      } else if (!first && !second) {
        write_zero(emb_nram + n * input_size, 2 * input_size);
      } else if (first && !second) {
        write_zero(emb_nram + (n + 1) * input_size, input_size);
        __memcpy_async(emb_nram + n * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size,
                       input_size * sizeof(T), GDRAM2NRAM);
      } else {
        write_zero(emb_nram + n * input_size, input_size);
        __memcpy_async(emb_nram + (n + 1) * input_size,
                       filter + (idx2 - vocab_offset) * (size_t)input_size,
                       input_size * sizeof(T), GDRAM2NRAM);
      }
      idx1 = idxs_nram[n + 2];
      idx2 = idxs_nram[n + 3];
      first = (idx1 >= vocab_start && idx1 <= vocab_end);
      second = (idx2 >= vocab_start && idx2 <= vocab_end);
    }  // copy loop
    // last idx copy
    if (num % 2 == 1) {
      if (first) {
        __memcpy_async(emb_nram + (num - 1) * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size,
                       input_size * sizeof(T), GDRAM2NRAM);
      } else {
        write_zero(emb_nram + (num - 1) * input_size, input_size);
      }
    }
    __sync();
    __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
                   num * input_size * sizeof(T), NRAM2GDRAM);
    __sync();
  }
}
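// The MLU300 path above pairs tokens so that two in-range ids cost a single
// 2D __memcpy_async: the copy reads one row of input_size * sizeof(T) bytes,
// then strides (idx2 - idx1) rows forward through the filter for the second
// segment. With illustrative ids idx1 = 7 and idx2 = 3, the src stride is
// negative and the second row is fetched 4 rows *before* the first, which is
// why the dispatch below requires the largest possible row byte offset to
// fit in a signed int.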
template <typename T>
__mlu_func__ void embeddingImpl_generic(T *filter, int *input_ids, T *output, int vocab_offset,
                                        int vocab_size, int input_size, int total_seq) {
  if (__is_mpu()) {
    return;
  }
  int bs_core = 0;
  int bs_offset = 0;
  split(total_seq, taskDim, taskId, bs_core, bs_offset);
  int limit = (NRAM_SIZE - 64) / (input_size * sizeof(T) + sizeof(int));
  limit = PAD_DOWN(limit, 2);
  int repeat = bs_core / limit;
  int remain = bs_core % limit;
  int vocab_start = vocab_offset;
  int vocab_end = vocab_offset + vocab_size - 1;
  T *emb_nram = (T *)nram_buffer;                                   // limit * input_size * sizeof(T)
  int *idxs_nram = (int *)(emb_nram + (size_t)limit * input_size);  // limit * sizeof(int)
  for (int i = 0; i < repeat + 1; i++) {
    if ((i == repeat) && (remain == 0)) {
      return;
    }
    int num = (i == repeat) ? remain : limit;
    __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
    __sync();
    int idx = idxs_nram[0];
    bool hit = (idx >= vocab_start && idx <= vocab_end);
    for (int n = 0; n < num; n++) {
      if (hit) {
        __memcpy_async(emb_nram + n * input_size,
                       filter + (idx - vocab_offset) * (size_t)input_size,
                       input_size * sizeof(T), GDRAM2NRAM);
      } else {
        write_zero(emb_nram + n * input_size, input_size);
      }
      // Look one element ahead; the final read past `num` is never consumed
      // and stays within the 64-byte NRAM slack reserved above.
      idx = idxs_nram[n + 1];
      hit = (idx >= vocab_start && idx <= vocab_end);
    }
    __sync();
    __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
                   num * input_size * sizeof(T), NRAM2GDRAM);
    __sync();
  }
}
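// Dispatch rationale for the kernel below: the gather path encodes each
// row's byte offset into the filter as an unsigned 32-bit value, so the
// largest offset (total_vocab_size - 1) * input_size * sizeof(T) must fit in
// uint32, while the MLU300 2D-copy path takes a signed-int src stride and is
// limited to sint32. A worked example with illustrative numbers (not from
// the source): a 152064-entry vocabulary with input_size = 8192 in half
// precision needs offsets up to 152063 * 8192 * 2 ≈ 2.49e9 bytes, which
// exceeds MAX_SINT32 but not MAX_UINT32 — so it still takes the gather path
// on newer arches, but falls back to embeddingImpl_generic on MLU300.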
template <typename T>
__mlu_global__ void MLUEmbeddingKernel(T *filter, int *input_ids, T *output, int vocab_offset,
                                       int vocab_size, int total_vocab_size, int input_size,
                                       int total_seq) {
#if __BANG_ARCH__ > 372
  // __gather index maximum dtype is unsigned int
  if ((size_t)(total_vocab_size - 1) * input_size * sizeof(T) <= (size_t)(MAX_UINT32)) {
    embeddingImpl_500(filter, input_ids, output, vocab_offset, vocab_size, input_size, total_seq);
  } else {
    embeddingImpl_generic(filter, input_ids, output, vocab_offset, vocab_size, input_size,
                          total_seq);
  }
#else
  // __memcpy 2D src_stride dtype is int
  if ((size_t)(total_vocab_size - 1) * input_size * sizeof(T) <= (size_t)(MAX_SINT32)) {
    embeddingImpl_300(filter, input_ids, output, vocab_offset, vocab_size, input_size, total_seq);
  } else {
    embeddingImpl_generic(filter, input_ids, output, vocab_offset, vocab_size, input_size,
                          total_seq);
  }
#endif
}

}  // namespace kernels

KernelStatus invokeEmbedding(cnrtQueue_t queue, void *filter, void *input_ids, void *output,
                             const cnnlDataType_t dtype, int vocab_offset, int vocab_size,
                             int total_vocab_size, int input_size, int total_seq) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
  // Task type assumed: UNION1, so that each cluster runs one union task and
  // the kernels' __is_mpu() check applies.
  cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;
  if (dtype == CNNL_DTYPE_FLOAT) {
    kernels::MLUEmbeddingKernel<float><<<dim, ktype, queue>>>(
        static_cast<float *>(filter), (int *)input_ids, static_cast<float *>(output),
        vocab_offset, vocab_size, total_vocab_size, input_size, total_seq);
  } else if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUEmbeddingKernel<half><<<dim, ktype, queue>>>(
        static_cast<half *>(filter), (int *)input_ids, static_cast<half *>(output),
        vocab_offset, vocab_size, total_vocab_size, input_size, total_seq);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeEmbedding]: MLU300 devices do not support bfloat16." << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    kernels::MLUEmbeddingKernel<bfloat16_t><<<dim, ktype, queue>>>(
        static_cast<bfloat16_t *>(filter), (int *)input_ids, static_cast<bfloat16_t *>(output),
        vocab_offset, vocab_size, total_vocab_size, input_size, total_seq);
  } else {
    std::cerr << "[invokeEmbedding]: unsupported data type." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}

}  // namespace tmo
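// Example host-side call (illustrative sketch, not part of the kernel API:
// `d_filter`, `d_ids`, and `d_out` are hypothetical device buffers assumed
// to be allocated with cnrtMalloc and already populated, and the vocabulary
// is assumed unsharded, i.e. vocab_offset = 0 and
// vocab_size == total_vocab_size):
//
//   cnrtQueue_t queue;
//   CNRT_CHECK(cnrtQueueCreate(&queue));
//   tmo::KernelStatus st = tmo::invokeEmbedding(
//       queue, d_filter, d_ids, d_out, CNNL_DTYPE_HALF,
//       /*vocab_offset=*/0, /*vocab_size=*/32000, /*total_vocab_size=*/32000,
//       /*input_size=*/4096, /*total_seq=*/1024);
//   CNRT_CHECK(cnrtQueueSync(queue));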