CPU: map changes from developing branch in sgl-kernel (#6833)

Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
YanbingJiang
2025-06-10 16:08:15 +08:00
committed by GitHub
parent 81372f3bef
commit fcde67b016
20 changed files with 1321 additions and 321 deletions

View File

@@ -72,6 +72,7 @@ void rmsnorm_kernel_impl(
const scalar_t* __restrict__ weight,
int64_t batch_size,
int64_t hidden_size,
int64_t input_strideN,
float eps = 1e-5) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
@@ -81,7 +82,7 @@ void rmsnorm_kernel_impl(
for (int64_t i = begin; i < end; ++i) {
// local ptrs
scalar_t* __restrict__ out_ptr = output + i * hidden_size;
const scalar_t* __restrict__ input_ptr = input + i * hidden_size;
const scalar_t* __restrict__ input_ptr = input + i * input_strideN;
fVec sum_fvec = fVec(float(0));
float sum_val = float(0);
@@ -140,6 +141,7 @@ void fused_add_rmsnorm_kernel_impl(
float* __restrict__ buffer,
int64_t batch_size,
int64_t hidden_size,
int64_t input_strideN,
float eps = 1e-5) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
@@ -151,7 +153,7 @@ void fused_add_rmsnorm_kernel_impl(
for (int64_t i = begin; i < end; ++i) {
// local ptrs
scalar_t* __restrict__ input_ptr = input + i * hidden_size;
scalar_t* __restrict__ input_ptr = input + i * input_strideN;
scalar_t* __restrict__ residual_ptr = residual + i * hidden_size;
fVec sum_fvec = fVec(float(0));
@@ -242,7 +244,7 @@ at::Tensor l2norm_cpu(at::Tensor& input, double eps) {
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
RECORD_FUNCTION("sgl-kernel::rmsnorm_cpu", std::vector<c10::IValue>({input, weight}));
CHECK_INPUT(input);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
CHECK_INPUT(weight);
CHECK_DIM(2, input);
CHECK_DIM(1, weight);
@@ -250,6 +252,7 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
int64_t batch_size = input.size(0);
int64_t hidden_size = input.size(1);
at::Tensor output = at::empty_like(input);
int64_t input_strideN = input.stride(0);
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "rmsnorm_kernel", [&] {
rmsnorm_kernel_impl<scalar_t>(
@@ -258,6 +261,7 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
weight.data_ptr<scalar_t>(),
batch_size,
hidden_size,
input_strideN,
eps);
});
return output;
@@ -268,7 +272,7 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
// weight : {hidden_size}
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps) {
RECORD_FUNCTION("sgl-kernel::fused_add_rmsnorm_cpu", std::vector<c10::IValue>({input, residual, weight}));
CHECK_INPUT(input);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
CHECK_INPUT(residual);
CHECK_INPUT(weight);
CHECK_DIM(2, input);
@@ -279,6 +283,7 @@ void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor&
CHECK_EQ(input.size(1), weight.size(0));
int64_t batch_size = input.size(0);
int64_t hidden_size = input.size(1);
int64_t input_strideN = input.stride(0);
// allocate temp buffer to store x in float32 per thread
// TODO: implement a singleton for context
@@ -293,6 +298,7 @@ void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor&
buffer.data_ptr<float>(),
batch_size,
hidden_size,
input_strideN,
eps);
});
}