[7/n] decouple quantization impl from vllm dependency - gguf kernel (#11019)
This commit is contained in:
66
sgl-kernel/csrc/moe/moe_sum.cu
Normal file
66
sgl-kernel/csrc/moe/moe_sum.cu
Normal file
@@ -0,0 +1,66 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/Atomic.cuh>
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
template <typename scalar_t, int TOPK>
|
||||
__global__ void moe_sum_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., topk, d]
|
||||
const int d) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
scalar_t x = 0.0;
|
||||
#pragma unroll
|
||||
for (int k = 0; k < TOPK; ++k) {
|
||||
x += SGLANG_LDG(&input[token_idx * TOPK * d + k * d + idx]);
|
||||
}
|
||||
out[token_idx * d + idx] = x;
|
||||
}
|
||||
}
|
||||
|
||||
void moe_sum(
|
||||
torch::Tensor& input, // [num_tokens, topk, hidden_size]
|
||||
torch::Tensor& output) // [num_tokens, hidden_size]
|
||||
{
|
||||
const int hidden_size = input.size(-1);
|
||||
const auto num_tokens = output.numel() / hidden_size;
|
||||
const int topk = input.size(1);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
switch (topk) {
|
||||
case 2:
|
||||
DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
|
||||
moe_sum_kernel<scalar_t, 2>
|
||||
<<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
|
||||
});
|
||||
break;
|
||||
|
||||
case 3:
|
||||
DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
|
||||
moe_sum_kernel<scalar_t, 3>
|
||||
<<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
|
||||
});
|
||||
break;
|
||||
|
||||
case 4:
|
||||
DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
|
||||
moe_sum_kernel<scalar_t, 4>
|
||||
<<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
|
||||
});
|
||||
break;
|
||||
|
||||
default:
|
||||
at::sum_out(output, input, 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user