CUDA: fix bug in rms_norm fusion (#15660)

* CUDA: fix bug in rms_norm fusion

* Fix bug for OP_REPEAT

* Fix index for add
This commit is contained in:
Aman Gupta
2025-08-29 21:30:06 +08:00
committed by GitHub
parent 60e5eee31f
commit 81017865ee
3 changed files with 51 additions and 23 deletions

View File

@@ -127,6 +127,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
const int add_nrows = 0,
const int add_nchannels = 0,
const int add_nsamples = 0) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
@@ -135,6 +136,8 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
const int sample = blockIdx.z;
const int tid = threadIdx.x;
static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
@@ -185,9 +188,6 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
} else if constexpr (do_multiply) {
const int mul_col = col % mul_ncols;
dst[col] = scale * x[col] * mul[mul_col];
} else if constexpr (do_add) {
const int add_col = col % add_ncols;
dst[col] += add[add_col];
} else {
dst[col] = scale * x[col];
}