CUDA: add fused rms norm (#14800)

This commit is contained in:
Aman Gupta
2025-07-23 09:25:42 +08:00
committed by GitHub
parent acd6cb1c41
commit 8c988fa41d
4 changed files with 144 additions and 9 deletions

View File

@@ -104,10 +104,12 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
}
}
template <int block_size>
template <int block_size, bool do_multiply = false>
static __global__ void rms_norm_f32(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
@@ -119,6 +121,13 @@ static __global__ void rms_norm_f32(
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
if constexpr (do_multiply) {
const int mul_row = row % mul_nrows;
const int mul_channel = channel % mul_nchannels;
const int mul_sample = sample % mul_nsamples;
mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
}
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
@@ -145,7 +154,12 @@ static __global__ void rms_norm_f32(
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
if constexpr (do_multiply) {
const int mul_col = col % mul_ncols;
dst[col] = scale * x[col] * mul[mul_col];
} else {
dst[col] = scale * x[col];
}
}
}
@@ -310,10 +324,30 @@ static void rms_norm_f32_cuda(
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
static void rms_norm_mul_f32_cuda(
const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (mul == nullptr) {
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
return;
}
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
}
}
@@ -407,6 +441,59 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}
void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
float eps = 0.0f;
memcpy(&eps, dst->op_params, sizeof(float));
const float * src0_d = (const float *) rms_norm_src->data;
const float * mul_d = nullptr;
const ggml_tensor * mul_src = nullptr;
if (mul_tensor->src[0] == dst) {
mul_d = (float *) mul_tensor->src[1]->data;
mul_src = mul_tensor->src[1];
} else if(mul_tensor->src[1] == dst) {
mul_d = (float *) mul_tensor->src[0]->data;
mul_src = mul_tensor->src[0];
} else {
GGML_ASSERT(false);
}
float * dst_d = (float *) mul_tensor->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
GGML_ASSERT(eps >= 0.0f);
const int64_t ne00 = rms_norm_src->ne[0];
const int64_t ne01 = rms_norm_src->ne[1];
const int64_t ne02 = rms_norm_src->ne[2];
const int64_t ne03 = rms_norm_src->ne[3];
const size_t ts0 = ggml_type_size(rms_norm_src->type);
GGML_ASSERT(rms_norm_src->nb[0] == ts0);
const int64_t s01 = rms_norm_src->nb[1] / ts0;
const int64_t s02 = rms_norm_src->nb[2] / ts0;
const int64_t s03 = rms_norm_src->nb[3] / ts0;
const size_t ts_mul = ggml_type_size(mul_src->type);
GGML_ASSERT(mul_src->nb[0] == ts_mul);
const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
const int mul_ncols = mul_src->ne[0];
const int mul_nrows = mul_src->ne[1];
const int mul_nchannels = mul_src->ne[2];
const int mul_nsamples = mul_src->ne[3];
rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
}
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * grad = dst->src[0]; // gradients
const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass