// Embedding-lookup kernels for Cambricon MLU devices (BANG C++).
#include <cassert>
|
|
#include <iostream>
|
|
#include <map>
|
|
#include <ostream>
|
|
#include "cnnl.h"
|
|
#include "cnrt.h"
|
|
#include "embedding.mluh"
|
|
// clang-format off
|
|
#include <mlu.h>
|
|
// clang-format on
|
|
|
|
namespace tmo {
|
|
|
|
namespace kernels {
|
|
// Largest values representable in unsigned / signed 32-bit integers. Used by
// MLUEmbeddingKernel to decide whether byte offsets into the embedding table
// fit the index/stride data types of the fast copy paths.
#define MAX_UINT32 (4294967295)
#define MAX_SINT32 (2147483647)

// Usable per-core NRAM scratch, in bytes: total NRAM (__MLU_NRAM_SIZE__ is in
// KiB) minus a 32 KiB reserve — presumably for compiler/runtime use; TODO
// confirm the reserve size against the BANG toolchain docs.
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)

// Single static NRAM scratch buffer shared (re-partitioned) by every kernel
// implementation in this file.
__nram__ int8_t nram_buffer[NRAM_SIZE];
|
|
|
|
// Evenly partitions `total` work items across `num` workers. Worker `id`
// receives `every` items starting at index `offset`; the first
// (total % num) workers each take one extra item so the split is balanced.
__mlu_func__ void split(const int total, const int num, const int id, int &every, int &offset) {
  const int quotient = total / num;
  const int remainder = total % num;
  if (id < remainder) {
    // This worker is one of the first `remainder` workers: it gets an extra
    // item, and all workers before it also carried an extra item.
    every = quotient + 1;
    offset = quotient * id + id;
  } else {
    every = quotient;
    offset = quotient * id + remainder;
  }
}
|
|
|
|
// Round x down / up to the nearest multiple of y (integer math; y > 0).
#define PAD_DOWN(x, y) (((x) / (y)) * (y))
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
|
|
|
|
// Embedding lookup using the hardware gather path (BANG arch > 372).
//
// For each of `total_seq` token ids, copies the `input_size`-element row
// filter[id - vocab_offset] to the matching row of `output`. Ids outside
// [vocab_offset, vocab_offset + vocab_size) produce an all-zero row.
// Work is split over cores by `split`. The caller guarantees that
// (total_vocab_size - 1) * input_size * sizeof(T) fits in an unsigned int,
// because __gather byte offsets are 32-bit unsigned (see MLUEmbeddingKernel).
template <typename T>
__mlu_func__ void embeddingImpl_500(T *filter,
                                    int *input_ids,
                                    T *output,
                                    int vocab_offset,
                                    int vocab_size,
                                    int input_size,
                                    int total_seq) {
  // The MPU (control) core does no work in this kernel.
  if (__is_mpu()) {
    return;
  };

  // Rows handled by this core and the first row's index.
  int bs_core = 0;
  int bs_offset = 0;
  split(total_seq, taskDim, taskId, bs_core, bs_offset);
  // Max rows processed per tile, limited by NRAM capacity. Denominator is the
  // per-row footprint: one embedding row + 4 ints (ones/idxs/mask/temp) + one
  // byte (zeros_offset). The numerator reserves one zero row and mask padding.
  // 8 * sizeof(int) left for mask_nram, because __bang_eq_bitindex <elem_count> must be divisible
  // by 8
  int limit = (NRAM_SIZE - input_size * sizeof(T) - 8 * sizeof(int)) /
              (input_size * sizeof(T) + 4 * sizeof(int) + sizeof(int8_t));

  // Inclusive id range covered by this table shard.
  int vocab_start = vocab_offset;
  int vocab_end = vocab_offset + vocab_size - 1;

  // NRAM partition (sizes noted per region):
  T *zeros_nram = (T *)nram_buffer; // input_size * sizeof(T)
  T *emb_nram = zeros_nram + input_size; // limit * input_size * sizeof(T)
  int *ones_nram = (int *)(emb_nram + (size_t)limit * input_size); // limit * sizeof(int)
  int *idxs_nram = ones_nram + limit; // limit * sizeof(int)
  int *mask_nram = idxs_nram + limit; // limit_pad * sizeof(int)
  int *temp_nram = mask_nram + PAD_UP(limit, 8); // limit * sizeof(int)
  uint8_t *zeros_offset_nram = (uint8_t *)(temp_nram + limit); // limit * sizeof(int8_t)
  // zeros_nram: one all-zero embedding row used to blank out-of-range ids.
  // zeros_offset_nram: all-zero byte offsets so every invalid row gathers
  // from the start of zeros_nram.
  __bang_write_zero(zeros_nram, input_size);
  __bang_write_zero(zeros_offset_nram, limit);
  __bang_write_value(ones_nram, limit, 1);

  int repeat = bs_core / limit;
  int remain = bs_core % limit;

  // Process the core's rows in tiles of at most `limit`; the final iteration
  // handles the `remain` tail (skipped when it is empty).
  for (int i = 0; i < repeat + 1; i++) {
    if ((i == repeat) && (remain == 0)) {
      return;
    }
    int num = (i == repeat) ? remain : limit;
    int num_pad = PAD_UP(num, 8); // for __bang_eq_bitindex
    // Stage this tile's token ids in NRAM.
    __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
    __sync();
    // Build the validity mask: mask[n] = (id >= vocab_start) && (id <= vocab_end).
    __bang_ge_scalar(mask_nram, idxs_nram, vocab_start, num);
    __bang_lt_scalar(temp_nram, idxs_nram, vocab_end + 1, num);
    __bang_mul(mask_nram, mask_nram, temp_nram, num);
    // Pack per-element 0/1 words into the bit-index form __gather expects.
    __bang_eq_bitindex((float *)mask_nram, (float *)mask_nram, (float *)ones_nram,
                       num_pad); // gather valid mask
    __bang_bnot((int8_t *)temp_nram, (int8_t *)mask_nram, num); // gather invalid mask
    // Convert ids into byte offsets within this shard's filter.
    __bang_sub_scalar(idxs_nram, idxs_nram, vocab_offset, num); // true index
    __bang_mul_scalar((unsigned int *)idxs_nram, (unsigned int *)idxs_nram,
                      (unsigned int)input_size * sizeof(T), num); // gather offset
    __sync();
    // Valid rows: gather from the table in GDRAM.
    __gather_async(emb_nram, filter, (unsigned int *)idxs_nram, mask_nram, input_size * sizeof(T),
                   GDRAM2NRAM, input_size * sizeof(T), num);
    // Invalid rows: gather the zero row (offset 0) over the same slots.
    __gather_async(emb_nram, zeros_nram, zeros_offset_nram, temp_nram, input_size * sizeof(T),
                   NRAM2NRAM, input_size * sizeof(T), num);
    __sync();
    // Flush the assembled tile to the output tensor.
    __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
                   num * input_size * sizeof(T), NRAM2GDRAM);
    __sync();
  }
}
|
|
|
|
// Fills `elem_count` elements at NRAM pointer `dst` with zero.
template <typename T>
__mlu_func__ void write_zero(T *dst, unsigned int elem_count) {
  __bang_write_zero(dst, elem_count);
}
|
|
|
|
// bfloat16 specialization: the __bang_write_zero bfloat16 overload is only
// available on BANG arch >= 500, so the call is compiled out on older targets
// (making this a no-op there). Safe because the host wrapper rejects bfloat16
// on MLU300-class devices before launching — see invokeEmbedding.
template <>
__mlu_func__ void write_zero(bfloat16_t *dst, unsigned int elem_count) {
#if __BANG_ARCH__ >= 500
  __bang_write_zero(dst, elem_count);
#endif
}
|
|
|
|
// Embedding lookup tuned for MLU300-class devices: rows are fetched two at a
// time, using one strided 2-D __memcpy_async when both ids are valid.
//
// For each token id, copies the `input_size`-element row
// filter[id - vocab_offset] to `output`; ids outside
// [vocab_offset, vocab_offset + vocab_size) produce a zero row. The caller
// guarantees (total_vocab_size - 1) * input_size * sizeof(T) <= MAX_SINT32,
// because the 2-D copy's src stride is a signed int (and may be negative
// when idx2 < idx1) — see MLUEmbeddingKernel.
template <typename T>
__mlu_func__ void embeddingImpl_300(T *filter,
                                    int *input_ids,
                                    T *output,
                                    int vocab_offset,
                                    int vocab_size,
                                    int input_size,
                                    int total_seq) {
  // The MPU (control) core does no work in this kernel.
  if (__is_mpu()) {
    return;
  };

  // Rows handled by this core and the first row's index.
  int bs_core = 0;
  int bs_offset = 0;
  split(total_seq, taskDim, taskId, bs_core, bs_offset);
  // Rows per NRAM tile (64 bytes of slack), rounded down to an even count so
  // the paired-copy loop below always has id pairs to work with.
  int limit = (NRAM_SIZE - 64) / (input_size * sizeof(T) + sizeof(int));
  limit = PAD_DOWN(limit, 2);
  int repeat = bs_core / limit;
  int remain = bs_core % limit;
  // Inclusive id range covered by this table shard.
  int vocab_start = vocab_offset;
  int vocab_end = vocab_offset + vocab_size - 1;

  // NRAM partition: staged embedding rows, then staged token ids.
  T *emb_nram = (T *)nram_buffer; // limit * input_size * sizeof(T)
  int *idxs_nram = (int *)(emb_nram + (size_t)limit * input_size); // limit * sizeof(int)

  for (int i = 0; i < repeat + 1; i++) {
    if ((i == repeat) && (remain == 0)) {
      return;
    }
    int num = (i == repeat) ? remain : limit;
    // Stage this tile's token ids.
    __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
    __sync();

    // Software pipeline: the ids for pair (n, n+1) are read before the copy
    // for the previous pair is issued. NOTE(review): the trailing reads of
    // idxs_nram[n + 2]/[n + 3] can land up to two ints past `num`; this stays
    // inside the 64-byte slack reserved when computing `limit`, and those
    // values are only consumed by the odd-`num` tail where they are in range.
    int idx1 = idxs_nram[0];
    int idx2 = idxs_nram[1];
    bool first = (idx1 >= vocab_start && idx1 <= vocab_end);
    bool second = (idx2 >= vocab_start && idx2 <= vocab_end);
    for (int n = 0; n < num / 2 * 2; n += 2) {
      if (first && second) {
        // Both ids valid: one 2-D copy fetches both rows; the source stride
        // (idx2 - idx1) rows may be negative.
        __memcpy_async(emb_nram + n * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM, input_size * sizeof(T), (idx2 - idx1) * input_size * sizeof(T),
                       1);
      } else if (!first && !second) {
        // Neither id valid: zero both rows.
        write_zero(emb_nram + n * input_size, 2 * input_size);
      } else if (first && !second) {
        write_zero(emb_nram + (n + 1) * input_size, input_size);
        __memcpy_async(emb_nram + n * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      } else {
        write_zero(emb_nram + n * input_size, input_size);
        __memcpy_async(emb_nram + (n + 1) * input_size,
                       filter + (idx2 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      }
      idx1 = idxs_nram[n + 2];
      idx2 = idxs_nram[n + 3];
      first = (idx1 >= vocab_start && idx1 <= vocab_end);
      second = (idx2 >= vocab_start && idx2 <= vocab_end);
    } // copy loop

    // last idx copy: an odd `num` leaves one unpaired row; idx1/first already
    // hold idxs_nram[num - 1] and its validity from the pipeline above.
    if (num % 2 == 1) {
      if (first) {
        __memcpy_async(emb_nram + (num - 1) * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      } else {
        write_zero(emb_nram + (num - 1) * input_size, input_size);
      }
    }
    __sync();

    // Flush the assembled tile to the output tensor.
    __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
                   num * input_size * sizeof(T), NRAM2GDRAM);
    __sync();
  }
}
|
|
|
|
template <typename T>
|
|
__mlu_func__ void embeddingImpl_generic(T *filter,
|
|
int *input_ids,
|
|
T *output,
|
|
int vocab_offset,
|
|
int vocab_size,
|
|
int input_size,
|
|
int total_seq) {
|
|
if (__is_mpu()) {
|
|
return;
|
|
};
|
|
|
|
int bs_core = 0;
|
|
int bs_offset = 0;
|
|
split(total_seq, taskDim, taskId, bs_core, bs_offset);
|
|
int limit = (NRAM_SIZE - 64) / (input_size * sizeof(T) + sizeof(int));
|
|
limit = PAD_DOWN(limit, 2);
|
|
int repeat = bs_core / limit;
|
|
int remain = bs_core % limit;
|
|
int vocab_start = vocab_offset;
|
|
int vocab_end = vocab_offset + vocab_size - 1;
|
|
|
|
T *emb_nram = (T *)nram_buffer; // limit * input_size * sizeof(T)
|
|
int *idxs_nram = (int *)(emb_nram + (size_t)limit * input_size); // limit * sizeof(int)
|
|
|
|
for (int i = 0; i < repeat + 1; i++) {
|
|
if ((i == repeat) && (remain == 0)) {
|
|
return;
|
|
}
|
|
int num = (i == repeat) ? remain : limit;
|
|
__memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
|
|
__sync();
|
|
|
|
int idx = idxs_nram[0];
|
|
bool hit = (idx >= vocab_start && idx <= vocab_end);
|
|
for (int n = 0; n < num; n++) {
|
|
if (hit) {
|
|
__memcpy_async(emb_nram + n * input_size,
|
|
filter + (idx - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
|
|
GDRAM2NRAM);
|
|
} else {
|
|
write_zero(emb_nram + n * input_size, input_size);
|
|
}
|
|
idx = idxs_nram[n + 1];
|
|
hit = (idx >= vocab_start && idx <= vocab_end);
|
|
}
|
|
__sync();
|
|
|
|
__memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
|
|
num * input_size * sizeof(T), NRAM2GDRAM);
|
|
__sync();
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
__mlu_global__ void MLUEmbeddingKernel(T *filter,
|
|
int *input_ids,
|
|
T *output,
|
|
int vocab_offset,
|
|
int vocab_size,
|
|
int total_vocab_size,
|
|
int input_size,
|
|
int total_seq) {
|
|
#if __BANG_ARCH__ > 372
|
|
// __gather index maximum dtype is unsigned int
|
|
if ((size_t)(total_vocab_size - 1) * input_size * sizeof(T) <= (size_t)(MAX_UINT32)) {
|
|
embeddingImpl_500(filter, input_ids, output, vocab_offset, vocab_size, input_size, total_seq);
|
|
} else {
|
|
embeddingImpl_generic(filter, input_ids, output, vocab_offset, vocab_size, input_size,
|
|
total_seq);
|
|
}
|
|
#else
|
|
// __memcpy 2D src_stride dtype is int
|
|
if ((size_t)(total_vocab_size - 1) * input_size * sizeof(T) <= (size_t)(MAX_SINT32)) {
|
|
embeddingImpl_300(filter, input_ids, output, vocab_offset, vocab_size, input_size, total_seq);
|
|
} else {
|
|
embeddingImpl_generic(filter, input_ids, output, vocab_offset, vocab_size, input_size,
|
|
total_seq);
|
|
}
|
|
#endif
|
|
}
|
|
} // namespace kernels
|
|
|
|
// Host-side launcher for the embedding-lookup kernel.
//
// filter/input_ids/output are device (GDRAM) pointers; `dtype` selects the
// element type of filter/output. Supported dtypes: FLOAT, HALF, and BFLOAT16
// (the latter only when the device supports it — see isBf16Supported()).
//
// Returns KERNEL_STATUS_SUCCESS on launch, KERNEL_STATUS_FAILED for an
// unsupported dtype or for bfloat16 on a device without bf16 support.
KernelStatus invokeEmbedding(cnrtQueue_t queue,
                             void *filter,
                             void *input_ids,
                             void *output,
                             const cnnlDataType_t dtype,
                             int vocab_offset,
                             int vocab_size,
                             int total_vocab_size,
                             int input_size,
                             int total_seq) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  // One Union1 task per compute core across all clusters.
  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};

  if (dtype == CNNL_DTYPE_FLOAT) {
    kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<float *>(filter), (int *)input_ids, static_cast<float *>(output), vocab_offset,
        vocab_size, total_vocab_size, input_size, total_seq);
  } else if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<half *>(filter), (int *)input_ids, static_cast<half *>(output), vocab_offset,
        vocab_size, total_vocab_size, input_size, total_seq);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeEmbedding]: MLU300 devices do not support bfloat16." << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<bfloat16_t *>(filter), (int *)input_ids, static_cast<bfloat16_t *>(output),
        vocab_offset, vocab_size, total_vocab_size, input_size, total_seq);
  } else {
    // BUG FIX: previously an unrecognized dtype fell through every branch,
    // launched nothing, and still reported success. Fail loudly instead.
    std::cerr << "[invokeEmbedding]: unsupported dtype (" << (int)dtype
              << "); expected FLOAT, HALF, or BFLOAT16." << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
|
|
|
|
} // namespace tmo
|