CPU: map changes from developing branch in sgl-kernel (#6833)
Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
@@ -72,6 +72,7 @@ void rmsnorm_kernel_impl(
|
||||
const scalar_t* __restrict__ weight,
|
||||
int64_t batch_size,
|
||||
int64_t hidden_size,
|
||||
+ int64_t input_strideN,
|
||||
float eps = 1e-5) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
@@ -81,7 +82,7 @@ void rmsnorm_kernel_impl(
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
// local ptrs
|
||||
scalar_t* __restrict__ out_ptr = output + i * hidden_size;
|
||||
- const scalar_t* __restrict__ input_ptr = input + i * hidden_size;
|
||||
+ const scalar_t* __restrict__ input_ptr = input + i * input_strideN;
|
||||
|
||||
fVec sum_fvec = fVec(float(0));
|
||||
float sum_val = float(0);
|
||||
@@ -140,6 +141,7 @@ void fused_add_rmsnorm_kernel_impl(
|
||||
float* __restrict__ buffer,
|
||||
int64_t batch_size,
|
||||
int64_t hidden_size,
|
||||
+ int64_t input_strideN,
|
||||
float eps = 1e-5) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
@@ -151,7 +153,7 @@ void fused_add_rmsnorm_kernel_impl(
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
// local ptrs
|
||||
- scalar_t* __restrict__ input_ptr = input + i * hidden_size;
|
||||
+ scalar_t* __restrict__ input_ptr = input + i * input_strideN;
|
||||
scalar_t* __restrict__ residual_ptr = residual + i * hidden_size;
|
||||
|
||||
fVec sum_fvec = fVec(float(0));
|
||||
@@ -242,7 +244,7 @@ at::Tensor l2norm_cpu(at::Tensor& input, double eps) {
|
||||
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
|
||||
RECORD_FUNCTION("sgl-kernel::rmsnorm_cpu", std::vector<c10::IValue>({input, weight}));
|
||||
|
||||
- CHECK_INPUT(input);
|
||||
+ CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
|
||||
CHECK_INPUT(weight);
|
||||
CHECK_DIM(2, input);
|
||||
CHECK_DIM(1, weight);
|
||||
@@ -250,6 +252,7 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
|
||||
int64_t batch_size = input.size(0);
|
||||
int64_t hidden_size = input.size(1);
|
||||
at::Tensor output = at::empty_like(input);
|
||||
+ int64_t input_strideN = input.stride(0);
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "rmsnorm_kernel", [&] {
|
||||
rmsnorm_kernel_impl<scalar_t>(
|
||||
@@ -258,6 +261,7 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
|
||||
weight.data_ptr<scalar_t>(),
|
||||
batch_size,
|
||||
hidden_size,
|
||||
+ input_strideN,
|
||||
eps);
|
||||
});
|
||||
return output;
|
||||
@@ -268,7 +272,7 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
|
||||
// weight : {hidden_size}
|
||||
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps) {
|
||||
RECORD_FUNCTION("sgl-kernel::fused_add_rmsnorm_cpu", std::vector<c10::IValue>({input, residual, weight}));
|
||||
- CHECK_INPUT(input);
|
||||
+ CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
|
||||
CHECK_INPUT(residual);
|
||||
CHECK_INPUT(weight);
|
||||
CHECK_DIM(2, input);
|
||||
@@ -279,6 +283,7 @@ void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor&
|
||||
CHECK_EQ(input.size(1), weight.size(0));
|
||||
int64_t batch_size = input.size(0);
|
||||
int64_t hidden_size = input.size(1);
|
||||
+ int64_t input_strideN = input.stride(0);
|
||||
|
||||
// allocate temp buffer to store x in float32 per thread
|
||||
// TODO: implement a singleton for context
|
||||
@@ -293,6 +298,7 @@ void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor&
|
||||
buffer.data_ptr<float>(),
|
||||
batch_size,
|
||||
hidden_size,
|
||||
+ input_strideN,
|
||||
eps);
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user