67 lines
2.0 KiB
Plaintext
67 lines
2.0 KiB
Plaintext
#include <ATen/ATen.h>
|
|
#include <ATen/cuda/CUDAContext.h>
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
#include <torch/all.h>
|
|
|
|
#include <ATen/cuda/Atomic.cuh>
|
|
#include <cub/cub.cuh>
|
|
|
|
#include "utils.h"
|
|
|
|
template <typename scalar_t, int TOPK>
|
|
__global__ void moe_sum_kernel(
|
|
scalar_t* __restrict__ out, // [..., d]
|
|
const scalar_t* __restrict__ input, // [..., topk, d]
|
|
const int d) {
|
|
const int64_t token_idx = blockIdx.x;
|
|
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
|
scalar_t x = 0.0;
|
|
#pragma unroll
|
|
for (int k = 0; k < TOPK; ++k) {
|
|
x += SGLANG_LDG(&input[token_idx * TOPK * d + k * d + idx]);
|
|
}
|
|
out[token_idx * d + idx] = x;
|
|
}
|
|
}
|
|
|
|
void moe_sum(
|
|
torch::Tensor& input, // [num_tokens, topk, hidden_size]
|
|
torch::Tensor& output) // [num_tokens, hidden_size]
|
|
{
|
|
const int hidden_size = input.size(-1);
|
|
const auto num_tokens = output.numel() / hidden_size;
|
|
const int topk = input.size(1);
|
|
|
|
dim3 grid(num_tokens);
|
|
dim3 block(std::min(hidden_size, 1024));
|
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
|
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
|
|
|
switch (topk) {
|
|
case 2:
|
|
DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
|
|
moe_sum_kernel<scalar_t, 2>
|
|
<<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
|
|
});
|
|
break;
|
|
|
|
case 3:
|
|
DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
|
|
moe_sum_kernel<scalar_t, 3>
|
|
<<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
|
|
});
|
|
break;
|
|
|
|
case 4:
|
|
DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
|
|
moe_sum_kernel<scalar_t, 4>
|
|
<<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
|
|
});
|
|
break;
|
|
|
|
default:
|
|
at::sum_out(output, input, 1);
|
|
break;
|
|
}
|
|
}
|