CUDA: fix bug in rms_norm fusion (#15660)
* CUDA: fix bug in rms_norm fusion
* Fix bug for OP_REPEAT
* Fix index for add
This commit is contained in:
@@ -127,6 +127,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
|
||||
const int add_nrows = 0,
|
||||
const int add_nchannels = 0,
|
||||
const int add_nsamples = 0) {
|
||||
|
||||
const int nrows = gridDim.x;
|
||||
const int nchannels = gridDim.y;
|
||||
|
||||
@@ -135,6 +136,8 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
|
||||
const int sample = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");
|
||||
|
||||
x += sample*stride_sample + channel*stride_channel + row*stride_row;
|
||||
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
|
||||
|
||||
@@ -185,9 +188,6 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
|
||||
} else if constexpr (do_multiply) {
|
||||
const int mul_col = col % mul_ncols;
|
||||
dst[col] = scale * x[col] * mul[mul_col];
|
||||
} else if constexpr (do_add) {
|
||||
const int add_col = col % add_ncols;
|
||||
dst[col] += add[add_col];
|
||||
} else {
|
||||
dst[col] = scale * x[col];
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user