forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
310	torch_mlu_ops-v1.3.2/csrc/kernels/embedding.mlu	Normal file

@@ -0,0 +1,310 @@
#include <cassert>
#include <iostream>
#include <map>
#include <ostream>
#include "cnnl.h"
#include "cnrt.h"
#include "embedding.mluh"
// clang-format off
#include <mlu.h>
// clang-format on

namespace tmo {

namespace kernels {
#define MAX_UINT32 (4294967295)
#define MAX_SINT32 (2147483647)
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
__nram__ int8_t nram_buffer[NRAM_SIZE];

__mlu_func__ void split(const int total, const int num, const int id, int &every, int &offset) {
  int base = total / num;
  int tail = total - base * num;
  every = base + (id < tail ? 1 : 0);
  offset = base * id + (id < tail ? id : tail);
}
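// Illustrative example (not part of the original source): split(10, 4, id, every, offset)
// yields per-core counts {3, 3, 2, 2} with offsets {0, 3, 6, 8}; the first `tail` cores each
// take one extra element, so the partition stays contiguous and covers all `total` elements.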

#define PAD_DOWN(x, y) (((x) / (y)) * (y))
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))

template <typename T>
__mlu_func__ void embeddingImpl_500(T *filter,
                                    int *input_ids,
                                    T *output,
                                    int vocab_offset,
                                    int vocab_size,
                                    int input_size,
                                    int total_seq) {
  if (__is_mpu()) {
    return;
  }

  int bs_core = 0;
  int bs_offset = 0;
  split(total_seq, taskDim, taskId, bs_core, bs_offset);
  // 8 * sizeof(int) left for mask_nram, because __bang_eq_bitindex <elem_count> must be divisible
  // by 8
  int limit = (NRAM_SIZE - input_size * sizeof(T) - 8 * sizeof(int)) /
              (input_size * sizeof(T) + 4 * sizeof(int) + sizeof(int8_t));
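  // `limit` is the number of token rows per NRAM tile: the zero row plus, per row, one embedding
  // row, four int scratch words (ones/idxs/mask/temp) and one offset byte must all fit in NRAM.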

  int vocab_start = vocab_offset;
  int vocab_end = vocab_offset + vocab_size - 1;

  T *zeros_nram = (T *)nram_buffer;                                 // input_size * sizeof(T)
  T *emb_nram = zeros_nram + input_size;                            // limit * input_size * sizeof(T)
  int *ones_nram = (int *)(emb_nram + (size_t)limit * input_size);  // limit * sizeof(int)
  int *idxs_nram = ones_nram + limit;                               // limit * sizeof(int)
  int *mask_nram = idxs_nram + limit;                               // limit_pad * sizeof(int)
  int *temp_nram = mask_nram + PAD_UP(limit, 8);                    // limit * sizeof(int)
  uint8_t *zeros_offset_nram = (uint8_t *)(temp_nram + limit);      // limit * sizeof(int8_t)
  __bang_write_zero(zeros_nram, input_size);
  __bang_write_zero(zeros_offset_nram, limit);
  __bang_write_value(ones_nram, limit, 1);

  int repeat = bs_core / limit;
  int remain = bs_core % limit;

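  // Per tile: load up to `limit` token ids, build an in-range mask, turn it into a gather
  // bit-index, gather the valid rows from the weight table and zero-fill the invalid rows,
  // then copy the finished tile back to GDRAM.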
  for (int i = 0; i < repeat + 1; i++) {
    if ((i == repeat) && (remain == 0)) {
      return;
    }
    int num = (i == repeat) ? remain : limit;
    int num_pad = PAD_UP(num, 8);  // for __bang_eq_bitindex
    __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
    __sync();
    __bang_ge_scalar(mask_nram, idxs_nram, vocab_start, num);
    __bang_lt_scalar(temp_nram, idxs_nram, vocab_end + 1, num);
    __bang_mul(mask_nram, mask_nram, temp_nram, num);
    __bang_eq_bitindex((float *)mask_nram, (float *)mask_nram, (float *)ones_nram,
                       num_pad);                                 // gather valid mask
    __bang_bnot((int8_t *)temp_nram, (int8_t *)mask_nram, num);  // gather invalid mask
    __bang_sub_scalar(idxs_nram, idxs_nram, vocab_offset, num);  // true index
    __bang_mul_scalar((unsigned int *)idxs_nram, (unsigned int *)idxs_nram,
                      (unsigned int)input_size * sizeof(T), num);  // gather offset
    __sync();
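    // The first gather pulls the in-range rows from the weight table in GDRAM; the second gather,
    // driven by the inverted mask and all-zero offsets, overwrites the out-of-range rows with the
    // zero row kept in NRAM.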
    __gather_async(emb_nram, filter, (unsigned int *)idxs_nram, mask_nram, input_size * sizeof(T),
                   GDRAM2NRAM, input_size * sizeof(T), num);
    __gather_async(emb_nram, zeros_nram, zeros_offset_nram, temp_nram, input_size * sizeof(T),
                   NRAM2NRAM, input_size * sizeof(T), num);
    __sync();
    __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
                   num * input_size * sizeof(T), NRAM2GDRAM);
    __sync();
  }
}

template <typename T>
__mlu_func__ void write_zero(T *dst, unsigned int elem_count) {
  __bang_write_zero(dst, elem_count);
}

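// bfloat16 specialization: the zero fill is compiled only for __BANG_ARCH__ >= 500, so the body
// is empty on older architectures; bf16 inputs are rejected beforehand on the host by the
// isBf16Supported() check in invokeEmbedding.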
template <>
__mlu_func__ void write_zero(bfloat16_t *dst, unsigned int elem_count) {
#if __BANG_ARCH__ >= 500
  __bang_write_zero(dst, elem_count);
#endif
}

template <typename T>
__mlu_func__ void embeddingImpl_300(T *filter,
                                    int *input_ids,
                                    T *output,
                                    int vocab_offset,
                                    int vocab_size,
                                    int input_size,
                                    int total_seq) {
  if (__is_mpu()) {
    return;
  }

  int bs_core = 0;
  int bs_offset = 0;
  split(total_seq, taskDim, taskId, bs_core, bs_offset);
  int limit = (NRAM_SIZE - 64) / (input_size * sizeof(T) + sizeof(int));
  limit = PAD_DOWN(limit, 2);
  int repeat = bs_core / limit;
  int remain = bs_core % limit;
  int vocab_start = vocab_offset;
  int vocab_end = vocab_offset + vocab_size - 1;

  T *emb_nram = (T *)nram_buffer;                                   // limit * input_size * sizeof(T)
  int *idxs_nram = (int *)(emb_nram + (size_t)limit * input_size);  // limit * sizeof(int)

  for (int i = 0; i < repeat + 1; i++) {
    if ((i == repeat) && (remain == 0)) {
      return;
    }
    int num = (i == repeat) ? remain : limit;
    __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
    __sync();

    int idx1 = idxs_nram[0];
    int idx2 = idxs_nram[1];
    bool first = (idx1 >= vocab_start && idx1 <= vocab_end);
    bool second = (idx2 >= vocab_start && idx2 <= vocab_end);
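    // Process ids two at a time: when both rows are in range, one 2D copy fetches both (src_stride
    // is the byte offset from row idx1 to row idx2 in the table); otherwise the in-range row is
    // copied on its own and the out-of-range row is zero-filled.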
    for (int n = 0; n < num / 2 * 2; n += 2) {
      if (first && second) {
        __memcpy_async(emb_nram + n * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM, input_size * sizeof(T), (idx2 - idx1) * input_size * sizeof(T),
                       1);
      } else if (!first && !second) {
        write_zero(emb_nram + n * input_size, 2 * input_size);
      } else if (first && !second) {
        write_zero(emb_nram + (n + 1) * input_size, input_size);
        __memcpy_async(emb_nram + n * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      } else {
        write_zero(emb_nram + n * input_size, input_size);
        __memcpy_async(emb_nram + (n + 1) * input_size,
                       filter + (idx2 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      }
      idx1 = idxs_nram[n + 2];
      idx2 = idxs_nram[n + 3];
      first = (idx1 >= vocab_start && idx1 <= vocab_end);
      second = (idx2 >= vocab_start && idx2 <= vocab_end);
    }  // copy loop

    // last idx copy
    if (num % 2 == 1) {
      if (first) {
        __memcpy_async(emb_nram + (num - 1) * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      } else {
        write_zero(emb_nram + (num - 1) * input_size, input_size);
      }
    }
    __sync();

    __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
                   num * input_size * sizeof(T), NRAM2GDRAM);
    __sync();
  }
}

template <typename T>
__mlu_func__ void embeddingImpl_generic(T *filter,
                                        int *input_ids,
                                        T *output,
                                        int vocab_offset,
                                        int vocab_size,
                                        int input_size,
                                        int total_seq) {
  if (__is_mpu()) {
    return;
  }

  int bs_core = 0;
  int bs_offset = 0;
  split(total_seq, taskDim, taskId, bs_core, bs_offset);
  int limit = (NRAM_SIZE - 64) / (input_size * sizeof(T) + sizeof(int));
  limit = PAD_DOWN(limit, 2);
  int repeat = bs_core / limit;
  int remain = bs_core % limit;
  int vocab_start = vocab_offset;
  int vocab_end = vocab_offset + vocab_size - 1;

  T *emb_nram = (T *)nram_buffer;                                   // limit * input_size * sizeof(T)
  int *idxs_nram = (int *)(emb_nram + (size_t)limit * input_size);  // limit * sizeof(int)

  for (int i = 0; i < repeat + 1; i++) {
    if ((i == repeat) && (remain == 0)) {
      return;
    }
    int num = (i == repeat) ? remain : limit;
    __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
    __sync();

    int idx = idxs_nram[0];
    bool hit = (idx >= vocab_start && idx <= vocab_end);
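    // Fallback path used when byte offsets into the table may overflow the narrower stride or
    // gather-offset types: one row per iteration, copied when the id is in range, zeroed otherwise.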
    for (int n = 0; n < num; n++) {
      if (hit) {
        __memcpy_async(emb_nram + n * input_size,
                       filter + (idx - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      } else {
        write_zero(emb_nram + n * input_size, input_size);
      }
      idx = idxs_nram[n + 1];
      hit = (idx >= vocab_start && idx <= vocab_end);
    }
    __sync();

    __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
                   num * input_size * sizeof(T), NRAM2GDRAM);
    __sync();
  }
}

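// Entry kernel: chooses the implementation by architecture, and falls back to the generic
// per-row path when the largest byte offset into the table does not fit the integer type used
// by the underlying gather / 2D-copy primitive.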
template <typename T>
__mlu_global__ void MLUEmbeddingKernel(T *filter,
                                       int *input_ids,
                                       T *output,
                                       int vocab_offset,
                                       int vocab_size,
                                       int total_vocab_size,
                                       int input_size,
                                       int total_seq) {
#if __BANG_ARCH__ > 372
  // __gather index maximum dtype is unsigned int
  if ((size_t)(total_vocab_size - 1) * input_size * sizeof(T) <= (size_t)(MAX_UINT32)) {
    embeddingImpl_500(filter, input_ids, output, vocab_offset, vocab_size, input_size, total_seq);
  } else {
    embeddingImpl_generic(filter, input_ids, output, vocab_offset, vocab_size, input_size,
                          total_seq);
  }
#else
  // __memcpy 2D src_stride dtype is int
  if ((size_t)(total_vocab_size - 1) * input_size * sizeof(T) <= (size_t)(MAX_SINT32)) {
    embeddingImpl_300(filter, input_ids, output, vocab_offset, vocab_size, input_size, total_seq);
  } else {
    embeddingImpl_generic(filter, input_ids, output, vocab_offset, vocab_size, input_size,
                          total_seq);
  }
#endif
}
}  // namespace kernels

KernelStatus invokeEmbedding(cnrtQueue_t queue,
                             void *filter,
                             void *input_ids,
                             void *output,
                             const cnnlDataType_t dtype,
                             int vocab_offset,
                             int vocab_size,
                             int total_vocab_size,
                             int input_size,
                             int total_seq) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
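  // Launch as a Union1 task sized to cover every IPU core on the device: dim spans
  // core_num * cluster_num tasks, and split() partitions total_seq across taskDim in the kernel.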

  if (dtype == CNNL_DTYPE_FLOAT) {
    kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<float *>(filter), (int *)input_ids, static_cast<float *>(output), vocab_offset,
        vocab_size, total_vocab_size, input_size, total_seq);
  } else if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<half *>(filter), (int *)input_ids, static_cast<half *>(output), vocab_offset,
        vocab_size, total_vocab_size, input_size, total_seq);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeEmbedding]: MLU300 devices do not support bfloat16." << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<bfloat16_t *>(filter), (int *)input_ids, static_cast<bfloat16_t *>(output),
        vocab_offset, vocab_size, total_vocab_size, input_size, total_seq);
  } else {
    // Reject unsupported dtypes instead of silently reporting success without launching a kernel.
    std::cerr << "[invokeEmbedding]: unsupported data type, expected float/half/bfloat16."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }

  return KernelStatus::KERNEL_STATUS_SUCCESS;
}

}  // namespace tmo