forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
torch_mlu_ops-v1.3.2/csrc/kernels/moe/expand_input.mlu | 219 (new file)
@@ -0,0 +1,219 @@
#include <bang_device_functions_extra.h>
#include <mlu.h>

#include "cnnl.h"
#include "cnrt.h"

#include "expand_input.mluh"

namespace tmo {

namespace kernels {

#define RESERVED_SIZE (32 * 1024)
#define NRAM_BUFFER_SIZE (__MLU_NRAM_SIZE__ * 1024 - RESERVED_SIZE)
#define SRAM_BUFFER_SIZE (__MLU_SRAM_SIZE__ * 1024 - RESERVED_SIZE)
#define MEMCPY_BURST_SIZE 128
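
// Note: __MLU_NRAM_SIZE__ / __MLU_SRAM_SIZE__ are toolchain-provided capacities
// in KB (hence the * 1024); RESERVED_SIZE keeps 32 KB of headroom free, e.g.
// for stack/scratch use (assumption -- the exact consumer is not documented here).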

__mlu_shared__ int8_t sram_buffer[SRAM_BUFFER_SIZE];
__nram__ int8_t nram_buffer[NRAM_BUFFER_SIZE];

// T_offset: uint32_t or uint64_t
template <size_t data_size, typename T_offset>
__mlu_func__ void ExpandInputKernel(void *output,
                                    void *input,
                                    int *index,
                                    int num_token,
                                    int hidden_size,
                                    int num_index) {
  void *input_ptr = input;
  uint64_t input_size = data_size * num_token * hidden_size;
  // whether SRAM_BUFFER_SIZE can hold the input data
  // sram_enable is true ==> is_nram_output is true
  bool sram_enable = (hidden_size == 1) && (input_size < SRAM_BUFFER_SIZE);
  if (sram_enable) {
    input_ptr = (void *)sram_buffer;
    __memcpy(input_ptr, input, input_size, GDRAM2SRAM);
    __sync_cluster();
  }
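  // Staging rationale (inferred from the condition above): sram_enable only
  // triggers when hidden_size == 1, i.e. each gather moves a single element,
  // so serving those scattered reads from cluster-shared SRAM instead of
  // GDRAM is the entire benefit of the copy.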
  if (__is_mpu()) {
    return;
  }

  // Each IPU core should handle at least MEMCPY_BURST_SIZE (128 B) of data;
  // any cores beyond that are left idle.
  int32_t max_task_num = data_size * hidden_size * num_index / MEMCPY_BURST_SIZE;
  uint32_t maxTaskDim = std::min(taskDim, std::max(max_task_num, 1));
  uint32_t total_num = num_index;
  uint32_t base = total_num / maxTaskDim;
  uint32_t tail = total_num - base * maxTaskDim;
  if (taskId >= maxTaskDim) {
    return;
  }
  uint32_t batch_per_core = base + (taskId < tail ? 1 : 0);
  uint32_t batch_step = base * taskId + (taskId < tail ? taskId : tail);
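
  // Partition example (illustrative): num_index = 10 rows over maxTaskDim = 4
  // cores gives base = 2, tail = 2; cores 0..1 take 3 rows each, cores 2..3
  // take 2, and batch_step is each core's starting row: 0, 3, 6, 8.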

  // nram
  /*
   * first: compute offset: index[i] * data_size * hidden_size
   * second: gather data
   * -------------------------------------------------
   * addr || index/offset     | output          |
   * type || int32_t/T_offset | T               |
   * num  || n                | n * hidden_size |
   * -------------------------------------------------
   */

  uint32_t nram_size_per_pixel = sizeof(T_offset) + hidden_size * data_size;
  // whether NRAM can hold two pixels: if so, go GDRAM->NRAM->GDRAM; otherwise GDRAM->GDRAM
  bool is_nram_output = nram_size_per_pixel * 2 <= NRAM_BUFFER_SIZE;
  uint32_t per_num =
      is_nram_output ? NRAM_BUFFER_SIZE / nram_size_per_pixel : NRAM_BUFFER_SIZE / sizeof(T_offset);
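
  // Sizing example (illustrative): with hidden_size = 4096, data_size = 2
  // (fp16) and uint32_t offsets, nram_size_per_pixel = 4 + 8192 bytes, so a
  // ~512 KB NRAM buffer stages ~63 rows per pass. When one row is too large to
  // fit twice, NRAM holds offsets only and the data is copied GDRAM->GDRAM.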

  int8_t *output_base = (int8_t *)output + (uint64_t)batch_step * hidden_size * data_size;
  int *index_base = index + batch_step;
  T_offset *nram_offset = (T_offset *)nram_buffer;
  int32_t *nram_index;
  if (std::is_same<T_offset, uint64_t>::value) {
    nram_index = (int32_t *)nram_offset + per_num;
  } else {
    nram_index = (int32_t *)nram_offset;
  }
  int8_t *nram_output = (int8_t *)(nram_offset + per_num);
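
  // Layout note: for 64-bit offsets the raw int32 indices are staged in the
  // upper half of the offset region, so the front-to-back int32 -> int64
  // widening below never overwrites an element it has not yet read.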

  uint32_t repeat = batch_per_core / per_num;
  uint32_t remain = batch_per_core - repeat * per_num;
  uint32_t deal_num = per_num;
  uint32_t is_remain = remain != 0 ? 1 : 0;

  for (int32_t i = 0; i < repeat + is_remain; i++) {
    if (i == repeat) {
      deal_num = remain;
    }
    int8_t *output_ptr = output_base + (uint64_t)i * per_num * hidden_size * data_size;
    int32_t *index_ptr = index_base + i * per_num;

    // index -> offset
    __memcpy((void *)nram_index, (void *)index_ptr, deal_num * sizeof(int32_t), GDRAM2NRAM);
    if (std::is_same<T_offset, uint64_t>::value) {
#if __BANG_ARCH__ > 592
      __bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num, 0, 0);
#else
      __bang_int322int64((int64_t *)nram_offset, (int32_t *)nram_index, deal_num);
#endif
    }
    __bang_mul_scalar(nram_offset, nram_offset, (int64_t)data_size * hidden_size, deal_num);
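    // nram_offset[j] is now a byte offset into the (possibly SRAM-resident)
    // input: index[j] * hidden_size * data_size.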

    // copy
    if (is_nram_output) {
      __bang_write_zero((int8_t *)nram_output, deal_num * hidden_size * data_size);
      mluMemcpyDirection_t dir = sram_enable ? SRAM2NRAM : GDRAM2NRAM;
      // GDRAM or SRAM -> NRAM -> GDRAM
#if __BANG_ARCH__ >= 592  // __gather requires arch >= 592
      __gather(nram_output, input_ptr, nram_offset, hidden_size * data_size, dir,
               hidden_size * data_size, deal_num);
#else
      for (int32_t j = 0; j < deal_num; j++) {
        T_offset offset_value = *(nram_offset + j);
        int8_t *input_offset = (int8_t *)input_ptr + offset_value;
        __memcpy(nram_output + j * hidden_size * data_size, input_offset, hidden_size * data_size,
                 dir);
      }
#endif  // __BANG_ARCH__
      __memcpy(output_ptr, nram_output, deal_num * hidden_size * data_size, NRAM2GDRAM);
    } else {
      // GDRAM -> GDRAM
#if __BANG_ARCH__ >= 592  // __gather requires arch >= 592
      __gather(output_ptr, input, (uint64_t *)nram_offset, hidden_size * data_size, GDRAM2GDRAM,
               hidden_size * data_size, deal_num);
#else
      for (int32_t j = 0; j < deal_num; j++) {
        T_offset offset_value = *(nram_offset + j);
        int8_t *input_offset = (int8_t *)input + offset_value;
        __memcpy(output_ptr + (T_offset)j * hidden_size * data_size, input_offset,
                 hidden_size * data_size, GDRAM2GDRAM);
      }
#endif  // __BANG_ARCH__
    }
  }
}

// T_offset: uint32_t or uint64_t
template <size_t data_size, typename T_offset>
__mlu_global__ void MLUExpandInputKernel(void *expand_hidden_state,
                                         void *hidden_state,
                                         int *gather_idx,
                                         int *cusum_token_count,
                                         int num_token,
                                         int hidden_size,
                                         int topk,
                                         int total_expert_num,
                                         int start_expert_id,
                                         int expert_count) {
  int32_t num_index = num_token * topk;
  int *gather_start_idx = (int *)gather_idx;
  if (cusum_token_count != nullptr) {
    num_index = *((int *)cusum_token_count + start_expert_id + expert_count) -
                *((int *)cusum_token_count + start_expert_id);
    gather_start_idx = (int *)gather_idx + *(cusum_token_count + start_expert_id);
  }
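  // With a prefix-sum of per-expert token counts, only the slice of gather_idx
  // belonging to experts [start_expert_id, start_expert_id + expert_count) is
  // expanded; otherwise all num_token * topk routed copies are produced.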
  ExpandInputKernel<data_size, T_offset>(expand_hidden_state, hidden_state, gather_start_idx,
                                         num_token, hidden_size, num_index);
}

// instantiate kernels
#define INSTANTIATE_ONE(data_size, T_offset)                               \
  template __mlu_global__ void MLUExpandInputKernel<data_size, T_offset>(  \
      void *, void *, int *, int *, int, int, int, int, int, int);

INSTANTIATE_ONE(1, uint32_t)
INSTANTIATE_ONE(2, uint32_t)
INSTANTIATE_ONE(4, uint32_t)
INSTANTIATE_ONE(8, uint32_t)
// large tensor
INSTANTIATE_ONE(1, uint64_t)
INSTANTIATE_ONE(2, uint64_t)
INSTANTIATE_ONE(4, uint64_t)
INSTANTIATE_ONE(8, uint64_t)
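
// data_size enumerates element widths of 1/2/4/8 bytes (e.g. int8, half/bf16,
// float, and 8-byte types); the uint64_t variants serve tensors whose total
// byte count exceeds INT32_MAX and therefore need 64-bit gather offsets.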
}  // namespace kernels

KernelStatus invokeMoeExpandInputKernel(cnrtQueue_t queue,
                                        void *expand_hidden_state,
                                        const void *hidden_state,
                                        const int *gather_idx,
                                        const int *cusum_token_count,
                                        int num_token,
                                        int hidden_size,
                                        int topk,
                                        cnnlDataType_t data_type,
                                        int total_expert_num,
                                        int start_expert_id,
                                        int expert_count) {
  CNdev dev;
  cnCtxGetDevice(&dev);
  int cluster_num;
  int core_num;
  CNRT_CHECK(cnrtDeviceGetAttribute(&cluster_num, cnrtAttrClusterCount, dev));
  CNRT_CHECK(cnrtDeviceGetAttribute(&core_num, cnrtAttrMcorePerCluster, dev));

  size_t type_size = 0;
  cnnlGetSizeOfDataType(data_type, &type_size);

  int max_cluster_num =
      (uint64_t)hidden_size * num_token * topk * type_size / (core_num * MEMCPY_BURST_SIZE);
  cluster_num = std::min(std::max(max_cluster_num, 1), cluster_num);
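
  // Clamp the launch so every IPU core has at least MEMCPY_BURST_SIZE (128 B)
  // to move; small problems then occupy fewer clusters instead of idling cores.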

  cnrtDim3_t dim{.x = (uint32_t)core_num, .y = (uint32_t)cluster_num, .z = 1};

  void (*expand_input_kernels[])(void *, void *, int *, int *, int, int, int, int, int, int) = {
      kernels::MLUExpandInputKernel<1, uint32_t>, kernels::MLUExpandInputKernel<2, uint32_t>,
      kernels::MLUExpandInputKernel<4, uint32_t>, kernels::MLUExpandInputKernel<8, uint32_t>,
      kernels::MLUExpandInputKernel<1, uint64_t>, kernels::MLUExpandInputKernel<2, uint64_t>,
      kernels::MLUExpandInputKernel<4, uint64_t>, kernels::MLUExpandInputKernel<8, uint64_t>};

  bool is_large_tensor = type_size * hidden_size * num_token * topk > INT32_MAX;
  int kernel_index = (type_size == 8 ? 3 : type_size >> 1) + (is_large_tensor ? 4 : 0);
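  // Index mapping: type_size 1 -> 0, 2 -> 1, 4 -> 2 (via type_size >> 1),
  // 8 -> 3, plus 4 to select the uint64_t-offset variants when the expanded
  // tensor exceeds INT32_MAX bytes.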
  expand_input_kernels[kernel_index]<<<dim, cnrtFuncTypeUnion1, queue>>>(
      (void *)expand_hidden_state, (void *)hidden_state, (int *)gather_idx,
      (int *)cusum_token_count, num_token, hidden_size, topk, total_expert_num, start_expert_id,
      expert_count);
  return KernelStatus::KERNEL_STATUS_SUCCESS;
}
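
// Usage sketch (hypothetical values; assumes half-precision hidden states and
// a gather_idx / cusum_token_count pair produced by the surrounding MoE router):
//
//   tmo::invokeMoeExpandInputKernel(queue,
//                                   d_expanded,      // [num_token * topk, hidden] out
//                                   d_hidden,        // [num_token, hidden] in
//                                   d_gather_idx,    // source token id per routed copy
//                                   d_cusum_counts,  // per-expert prefix sums, or nullptr
//                                   /*num_token=*/128, /*hidden_size=*/4096, /*topk=*/2,
//                                   CNNL_DTYPE_HALF,
//                                   /*total_expert_num=*/64, /*start_expert_id=*/0,
//                                   /*expert_count=*/64);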

}  // namespace tmo