// File: enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/embedding.mlu
// Snapshot: 2026-02-04 17:39:32 +08:00 (311 lines, 12 KiB)
#include <cassert>
#include <iostream>
#include <map>
#include <ostream>
#include "cnnl.h"
#include "cnrt.h"
#include "embedding.mluh"
// clang-format off
#include <mlu.h>
// clang-format on
namespace tmo {
namespace kernels {
#define MAX_UINT32 (4294967295)
#define MAX_SINT32 (2147483647)
#define NRAM_SIZE (__MLU_NRAM_SIZE__ * 1024 - 32 * 1024)
__nram__ int8_t nram_buffer[NRAM_SIZE];
// Evenly partitions `total` work items across `num` workers. Worker `id`
// receives `every` items starting at index `offset`; the first
// (total % num) workers each get one extra item so the whole range is
// covered with no gaps or overlap.
__mlu_func__ void split(const int total, const int num, const int id, int &every, int &offset) {
  const int quotient = total / num;
  const int remainder = total % num;
  if (id < remainder) {
    every = quotient + 1;
    offset = (quotient + 1) * id;
  } else {
    every = quotient;
    offset = quotient * id + remainder;
  }
}
#define PAD_DOWN(x, y) (((x) / (y)) * (y))
#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))
// Embedding lookup for post-372 (MLU500-class) architectures using the
// hardware __gather instruction.
// Each id in [vocab_offset, vocab_offset + vocab_size) selects one
// `input_size`-element row of `filter`; ids outside that window produce an
// all-zero output row (supports tensor-parallel vocab sharding).
// Work is split across taskDim IPU cores by sequence position, and each core
// processes up to `limit` ids per batch.
// Precondition (enforced by the caller MLUEmbeddingKernel): the worst-case
// row byte offset fits in uint32, since __gather offsets are unsigned int.
template <typename T>
__mlu_func__ void embeddingImpl_500(T *filter,
                                    int *input_ids,
                                    T *output,
                                    int vocab_offset,
                                    int vocab_size,
                                    int input_size,
                                    int total_seq) {
  // The MPU (cluster) core does no work; only IPU cores run the lookup.
  if (__is_mpu()) {
    return;
  };
  int bs_core = 0;
  int bs_offset = 0;
  // This core handles ids [bs_offset, bs_offset + bs_core).
  split(total_seq, taskDim, taskId, bs_core, bs_offset);
  // Max ids per batch. Per id we need one output row (input_size * sizeof(T)),
  // 4 int scratch words (ones/idxs/mask/temp) and 1 byte of zero offsets.
  // 8 * sizeof(int) left for mask_nram, because __bang_eq_bitindex <elem_count> must be divisible
  // by 8
  int limit = (NRAM_SIZE - input_size * sizeof(T) - 8 * sizeof(int)) /
              (input_size * sizeof(T) + 4 * sizeof(int) + sizeof(int8_t));
  int vocab_start = vocab_offset;
  int vocab_end = vocab_offset + vocab_size - 1;
  // NRAM layout (byte sizes in trailing comments):
  T *zeros_nram = (T *)nram_buffer;                                // input_size * sizeof(T)
  T *emb_nram = zeros_nram + input_size;                           // limit * input_size * sizeof(T)
  int *ones_nram = (int *)(emb_nram + (size_t)limit * input_size); // limit * sizeof(int)
  int *idxs_nram = ones_nram + limit;                              // limit * sizeof(int)
  int *mask_nram = idxs_nram + limit;                              // limit_pad * sizeof(int)
  int *temp_nram = mask_nram + PAD_UP(limit, 8);                   // limit * sizeof(int)
  uint8_t *zeros_offset_nram = (uint8_t *)(temp_nram + limit);     // limit * sizeof(int8_t)
  __bang_write_zero(zeros_nram, input_size);   // zero row substituted for out-of-window ids
  __bang_write_zero(zeros_offset_nram, limit); // all invalid-id gathers read zeros_nram at offset 0
  __bang_write_value(ones_nram, limit, 1);     // comparand for __bang_eq_bitindex
  int repeat = bs_core / limit;
  int remain = bs_core % limit;
  for (int i = 0; i < repeat + 1; i++) {
    if ((i == repeat) && (remain == 0)) {
      return;
    }
    int num = (i == repeat) ? remain : limit;
    int num_pad = PAD_UP(num, 8); // for __bang_eq_bitindex
    __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
    __sync();
    // mask = (id >= vocab_start) && (id <= vocab_end), as 0/1 per element.
    __bang_ge_scalar(mask_nram, idxs_nram, vocab_start, num);
    __bang_lt_scalar(temp_nram, idxs_nram, vocab_end + 1, num);
    __bang_mul(mask_nram, mask_nram, temp_nram, num);
    __bang_eq_bitindex((float *)mask_nram, (float *)mask_nram, (float *)ones_nram,
                       num_pad); // gather valid mask
    __bang_bnot((int8_t *)temp_nram, (int8_t *)mask_nram, num); // gather invalid mask
    __bang_sub_scalar(idxs_nram, idxs_nram, vocab_offset, num); // true index
    // Convert local row indices into byte offsets for __gather.
    __bang_mul_scalar((unsigned int *)idxs_nram, (unsigned int *)idxs_nram,
                      (unsigned int)input_size * sizeof(T), num); // gather offset
    __sync();
    // Valid ids: gather their rows from the filter in GDRAM.
    __gather_async(emb_nram, filter, (unsigned int *)idxs_nram, mask_nram, input_size * sizeof(T),
                   GDRAM2NRAM, input_size * sizeof(T), num);
    // Invalid ids: gather the NRAM zero row (offset 0) into the same slots.
    __gather_async(emb_nram, zeros_nram, zeros_offset_nram, temp_nram, input_size * sizeof(T),
                   NRAM2NRAM, input_size * sizeof(T), num);
    __sync();
    __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
                   num * input_size * sizeof(T), NRAM2GDRAM);
    __sync();
  }
}
// Fills `elem_count` elements of `dst` with zero.
template <typename T>
__mlu_func__ void write_zero(T *dst, unsigned int elem_count) {
  __bang_write_zero(dst, elem_count);
}
// bfloat16 specialization: guarded to arch >= 500, so on older architectures
// this compiles to a no-op — presumably because __bang_write_zero lacks
// bfloat16 support there (TODO confirm against the BANG intrinsics docs).
// The no-op is never reached in practice: the host launcher rejects bfloat16
// on MLU300 devices before any kernel is launched.
template <>
__mlu_func__ void write_zero(bfloat16_t *dst, unsigned int elem_count) {
#if __BANG_ARCH__ >= 500
  __bang_write_zero(dst, elem_count);
#endif
}
// Embedding lookup for MLU300-class architectures (no hardware __gather).
// Each id in [vocab_offset, vocab_offset + vocab_size) selects one
// `input_size`-element row of `filter`; out-of-window ids yield a zero row
// (supports tensor-parallel vocab sharding).
// Ids are handled two per step: when both ids of a pair are in-window, one
// 2D __memcpy_async whose src stride is (idx2 - idx1) rows fetches both
// filter rows; mixed/invalid pairs fall back to single-row copy or zeroing.
// Precondition (enforced by the caller MLUEmbeddingKernel): row byte offsets
// fit in a signed int, since the 2D __memcpy src stride parameter is int.
template <typename T>
__mlu_func__ void embeddingImpl_300(T *filter,
                                    int *input_ids,
                                    T *output,
                                    int vocab_offset,
                                    int vocab_size,
                                    int input_size,
                                    int total_seq) {
  // The MPU (cluster) core does no work; only IPU cores run the lookup.
  if (__is_mpu()) {
    return;
  };
  int bs_core = 0;
  int bs_offset = 0;
  // This core handles ids [bs_offset, bs_offset + bs_core).
  split(total_seq, taskDim, taskId, bs_core, bs_offset);
  // Ids per batch: one output row plus one int of id scratch per element,
  // with 64 bytes of slack; rounded down to an even count for pairing.
  int limit = (NRAM_SIZE - 64) / (input_size * sizeof(T) + sizeof(int));
  limit = PAD_DOWN(limit, 2);
  int repeat = bs_core / limit;
  int remain = bs_core % limit;
  int vocab_start = vocab_offset;
  int vocab_end = vocab_offset + vocab_size - 1;
  T *emb_nram = (T *)nram_buffer;                                  // limit * input_size * sizeof(T)
  int *idxs_nram = (int *)(emb_nram + (size_t)limit * input_size); // limit * sizeof(int)
  for (int i = 0; i < repeat + 1; i++) {
    if ((i == repeat) && (remain == 0)) {
      return;
    }
    int num = (i == repeat) ? remain : limit;
    __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
    __sync();
    // Prefetch the first pair; each loop iteration prefetches the next.
    int idx1 = idxs_nram[0];
    int idx2 = idxs_nram[1];
    bool first = (idx1 >= vocab_start && idx1 <= vocab_end);
    bool second = (idx2 >= vocab_start && idx2 <= vocab_end);
    for (int n = 0; n < num / 2 * 2; n += 2) {
      if (first && second) {
        // Both rows valid: one 2-segment 2D copy; src stride is
        // (idx2 - idx1) rows and may be negative when idx2 < idx1.
        __memcpy_async(emb_nram + n * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM, input_size * sizeof(T), (idx2 - idx1) * input_size * sizeof(T),
                       1);
      } else if (!first && !second) {
        // Neither row valid: zero both output slots.
        write_zero(emb_nram + n * input_size, 2 * input_size);
      } else if (first && !second) {
        write_zero(emb_nram + (n + 1) * input_size, input_size);
        __memcpy_async(emb_nram + n * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      } else {
        write_zero(emb_nram + n * input_size, input_size);
        __memcpy_async(emb_nram + (n + 1) * input_size,
                       filter + (idx2 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      }
      // Prefetch the next pair. On the final iteration these reads go up to
      // two ints past `num` (still inside the 64-byte slack reserved in the
      // `limit` calculation); idx1 is only consumed afterwards when `num` is
      // odd, in which case it indexes the valid element num - 1.
      idx1 = idxs_nram[n + 2];
      idx2 = idxs_nram[n + 3];
      first = (idx1 >= vocab_start && idx1 <= vocab_end);
      second = (idx2 >= vocab_start && idx2 <= vocab_end);
    } // copy loop
    // Odd tail: handle the last unpaired id on its own.
    if (num % 2 == 1) {
      if (first) {
        __memcpy_async(emb_nram + (num - 1) * input_size,
                       filter + (idx1 - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      } else {
        write_zero(emb_nram + (num - 1) * input_size, input_size);
      }
    }
    __sync();
    __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
                   num * input_size * sizeof(T), NRAM2GDRAM);
    __sync();
  }
}
// Generic (arch-independent) embedding lookup: one row at a time via plain
// 1D __memcpy_async, with no restriction on the filter's byte span. Used as
// the fallback when the filter is too large for the offset dtype of the
// faster __gather (500) or 2D-memcpy (300) paths.
// Each id in [vocab_offset, vocab_offset + vocab_size) selects one
// `input_size`-element row of `filter`; out-of-window ids yield a zero row.
template <typename T>
__mlu_func__ void embeddingImpl_generic(T *filter,
                                        int *input_ids,
                                        T *output,
                                        int vocab_offset,
                                        int vocab_size,
                                        int input_size,
                                        int total_seq) {
  // The MPU (cluster) core does no work; only IPU cores run the lookup.
  if (__is_mpu()) {
    return;
  };
  int bs_core = 0;
  int bs_offset = 0;
  // This core handles ids [bs_offset, bs_offset + bs_core).
  split(total_seq, taskDim, taskId, bs_core, bs_offset);
  // Ids per batch: one output row plus one int of id scratch per element,
  // with 64 bytes of slack; rounded down to an even count (matches the 300
  // path's sizing even though pairing is not used here).
  int limit = (NRAM_SIZE - 64) / (input_size * sizeof(T) + sizeof(int));
  limit = PAD_DOWN(limit, 2);
  int repeat = bs_core / limit;
  int remain = bs_core % limit;
  int vocab_start = vocab_offset;
  int vocab_end = vocab_offset + vocab_size - 1;
  T *emb_nram = (T *)nram_buffer;                                  // limit * input_size * sizeof(T)
  int *idxs_nram = (int *)(emb_nram + (size_t)limit * input_size); // limit * sizeof(int)
  for (int i = 0; i < repeat + 1; i++) {
    if ((i == repeat) && (remain == 0)) {
      return;
    }
    int num = (i == repeat) ? remain : limit;
    __memcpy_async(idxs_nram, input_ids + bs_offset + i * limit, num * sizeof(int), GDRAM2NRAM);
    __sync();
    // Prefetch the first id; each iteration prefetches the next one. The
    // final iteration reads one int past `num`; that value is discarded and
    // the read stays within the 64-byte slack reserved in `limit`.
    int idx = idxs_nram[0];
    bool hit = (idx >= vocab_start && idx <= vocab_end);
    for (int n = 0; n < num; n++) {
      if (hit) {
        __memcpy_async(emb_nram + n * input_size,
                       filter + (idx - vocab_offset) * (size_t)input_size, input_size * sizeof(T),
                       GDRAM2NRAM);
      } else {
        write_zero(emb_nram + n * input_size, input_size);
      }
      idx = idxs_nram[n + 1];
      hit = (idx >= vocab_start && idx <= vocab_end);
    }
    __sync();
    __memcpy_async(output + (size_t)(bs_offset + i * limit) * input_size, emb_nram,
                   num * input_size * sizeof(T), NRAM2GDRAM);
    __sync();
  }
}
// Kernel entry point: picks the fastest lookup implementation allowed by
// the compile-time architecture and by the filter's worst-case byte offset.
// `total_vocab_size` is the full (unsharded) vocab size; the checks below
// assume every incoming id lies in [0, total_vocab_size) — TODO confirm
// against the caller's contract.
template <typename T>
__mlu_global__ void MLUEmbeddingKernel(T *filter,
                                       int *input_ids,
                                       T *output,
                                       int vocab_offset,
                                       int vocab_size,
                                       int total_vocab_size,
                                       int input_size,
                                       int total_seq) {
#if __BANG_ARCH__ > 372
  // __gather index maximum dtype is unsigned int: the hardware-gather path
  // is usable only if the largest row byte offset fits in uint32.
  if ((size_t)(total_vocab_size - 1) * input_size * sizeof(T) <= (size_t)(MAX_UINT32)) {
    embeddingImpl_500(filter, input_ids, output, vocab_offset, vocab_size, input_size, total_seq);
  } else {
    embeddingImpl_generic(filter, input_ids, output, vocab_offset, vocab_size, input_size,
                          total_seq);
  }
#else
  // __memcpy 2D src_stride dtype is int: the paired-copy path is usable only
  // if the largest row byte offset fits in a signed int.
  if ((size_t)(total_vocab_size - 1) * input_size * sizeof(T) <= (size_t)(MAX_SINT32)) {
    embeddingImpl_300(filter, input_ids, output, vocab_offset, vocab_size, input_size, total_seq);
  } else {
    embeddingImpl_generic(filter, input_ids, output, vocab_offset, vocab_size, input_size,
                          total_seq);
  }
#endif
}
} // namespace kernels
// Host-side launcher for MLUEmbeddingKernel.
// Queries the current device's topology, launches one Union1 task block
// (x = cores per cluster, y = cluster count), and dispatches on `dtype`.
// Returns KERNEL_STATUS_FAILED for unsupported dtypes or bfloat16 on
// devices without bf16 support; KERNEL_STATUS_SUCCESS after enqueuing.
// Note: the launch is asynchronous on `queue`; the caller owns
// synchronization.
KernelStatus invokeEmbedding(cnrtQueue_t queue,
                             void *filter,
                             void *input_ids,
                             void *output,
                             const cnnlDataType_t dtype,
                             int vocab_offset,
                             int vocab_size,
                             int total_vocab_size,
                             int input_size,
                             int total_seq) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));
  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};
  if (dtype == CNNL_DTYPE_FLOAT) {
    kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<float *>(filter), (int *)input_ids, static_cast<float *>(output), vocab_offset,
        vocab_size, total_vocab_size, input_size, total_seq);
  } else if (dtype == CNNL_DTYPE_HALF) {
    kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<half *>(filter), (int *)input_ids, static_cast<half *>(output), vocab_offset,
        vocab_size, total_vocab_size, input_size, total_seq);
  } else if (dtype == CNNL_DTYPE_BFLOAT16) {
    if (!isBf16Supported()) {
      std::cerr << "[invokeEmbedding]: MLU300 devices do not support bfloat16." << std::endl;
      return KernelStatus::KERNEL_STATUS_FAILED;
    }
    kernels::MLUEmbeddingKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
        static_cast<bfloat16_t *>(filter), (int *)input_ids, static_cast<bfloat16_t *>(output),
        vocab_offset, vocab_size, total_vocab_size, input_size, total_seq);
  } else {
    // Fix: previously any other dtype fell through every branch, launched
    // nothing, and still reported success — a silent no-op. Report failure
    // explicitly instead.
    std::cerr << "[invokeEmbedding]: unsupported dtype, expected float/half/bfloat16."
              << std::endl;
    return KernelStatus::KERNEL_STATUS_FAILED;
  }
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
} // namespace tmo