update norm cu (#3048)
This commit is contained in:
@@ -91,7 +91,7 @@ ext_modules = [
|
|||||||
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
|
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
|
||||||
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
|
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
|
||||||
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
||||||
"src/sgl-kernel/csrc/norm.cu",
|
"3rdparty/flashinfer/csrc/norm.cu",
|
||||||
],
|
],
|
||||||
include_dirs=include_dirs,
|
include_dirs=include_dirs,
|
||||||
extra_compile_args={
|
extra_compile_args={
|
||||||
|
|||||||
@@ -1,28 +0,0 @@
|
|||||||
#include <cstdint>
|
|
||||||
#include <flashinfer/norm.cuh>
|
|
||||||
|
|
||||||
#include "pytorch_extension_utils.h"
|
|
||||||
|
|
||||||
using namespace flashinfer;
|
|
||||||
|
|
||||||
// Host-side wrapper that validates its tensor arguments and then dispatches to
// flashinfer's norm::RMSNorm on the given CUDA stream.
//
// Parameters:
//   output       destination tensor, must be (batch_size, hidden_size) and match
//                `input` in device and dtype
//   input        source tensor of shape (batch_size, hidden_size)
//   weight       scale tensor of shape (hidden_size), same device and dtype as `input`
//   eps          forwarded unchanged to norm::RMSNorm
//   cuda_stream  a cudaStream_t handle carried as an integer
//
// Every failed precondition, and any non-success status returned by
// norm::RMSNorm, is reported through TORCH_CHECK.
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream) {
  // The raw data pointers of all three tensors are handed to the kernel, so each
  // of them (including the destination) has to pass the same input checks.
  CHECK_INPUT(input);
  CHECK_INPUT(weight);
  CHECK_INPUT(output);

  auto device = input.device();
  CHECK_EQ(weight.device(), device);
  CHECK_EQ(output.device(), device);

  CHECK_DIM(2, input);   // input: (batch_size, hidden_size)
  CHECK_DIM(1, weight);  // weight: (hidden_size)
  CHECK_DIM(2, output);  // output: (batch_size, hidden_size)
  CHECK_EQ(input.size(1), weight.size(0));
  CHECK_EQ(output.size(0), input.size(0));
  CHECK_EQ(output.size(1), input.size(1));

  // All three buffers are reinterpreted as the element type selected from
  // input.scalar_type(), so a dtype mismatch would read or write garbage.
  CHECK_EQ(weight.scalar_type(), input.scalar_type());
  CHECK_EQ(output.scalar_type(), input.scalar_type());

  // The sizes are narrowed to unsigned int below; reject values that would wrap.
  TORCH_CHECK(input.size(0) <= static_cast<int64_t>(UINT32_MAX) && input.size(1) <= static_cast<int64_t>(UINT32_MAX),
              "rmsnorm: input dimensions do not fit in a 32-bit unsigned integer");
  unsigned int batch_size = static_cast<unsigned int>(input.size(0));
  unsigned int hidden_size = static_cast<unsigned int>(input.size(1));

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
    cudaError_t status = norm::RMSNorm(static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
                                       static_cast<c_type*>(output.data_ptr()), batch_size, hidden_size, eps, stream);
    // TORCH_CHECK concatenates its message arguments, so no std::string temporary is needed.
    TORCH_CHECK(status == cudaSuccess, "RMSNorm failed with error code ", cudaGetErrorString(status));
    return true;
  });
}
|
|
||||||
Reference in New Issue
Block a user