diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 1aea485ff..1197611d6 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -91,7 +91,7 @@ ext_modules = [ "src/sgl-kernel/csrc/sampling_scaling_penalties.cu", "src/sgl-kernel/csrc/sgl_kernel_ops.cu", "src/sgl-kernel/csrc/rotary_embedding.cu", - "src/sgl-kernel/csrc/norm.cu", + "3rdparty/flashinfer/csrc/norm.cu", ], include_dirs=include_dirs, extra_compile_args={ diff --git a/sgl-kernel/src/sgl-kernel/csrc/norm.cu b/sgl-kernel/src/sgl-kernel/csrc/norm.cu deleted file mode 100644 index ad102a50d..000000000 --- a/sgl-kernel/src/sgl-kernel/csrc/norm.cu +++ /dev/null @@ -1,28 +0,0 @@ -#include -#include - -#include "pytorch_extension_utils.h" - -using namespace flashinfer; - -void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream) { - CHECK_INPUT(input); - CHECK_INPUT(weight); - auto device = input.device(); - CHECK_EQ(weight.device(), device); - CHECK_DIM(2, input); // input: (batch_size, hidden_size) - CHECK_DIM(1, weight); // weight: (hidden_size) - CHECK_EQ(input.size(1), weight.size(0)); - unsigned int batch_size = input.size(0); - unsigned int hidden_size = input.size(1); - CHECK_EQ(output.size(0), batch_size); - CHECK_EQ(output.size(1), hidden_size); - - cudaStream_t stream = reinterpret_cast(cuda_stream); - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { - cudaError_t status = norm::RMSNorm(static_cast(input.data_ptr()), static_cast(weight.data_ptr()), - static_cast(output.data_ptr()), batch_size, hidden_size, eps, stream); - TORCH_CHECK(status == cudaSuccess, "RMSNorm failed with error code " + std::string(cudaGetErrorString(status))); - return true; - }); -}