OpenCL: add fused group_norm/norm, mul, add (#15314)
* add fused group_norm/norm, mul, add * fix spacing * revert rms_norm logic * fix trailing whitespace
This commit is contained in:
@@ -70,3 +70,52 @@ kernel void kernel_group_norm(
|
||||
dst[j] *= scale;
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// group_norm_mul_add
|
||||
//------------------------------------------------------------------------------
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_32
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_group_norm_mul_add(
|
||||
global float * src0, ulong offset0,
|
||||
global float * src1, ulong offset1,
|
||||
global float * src2, ulong offset2,
|
||||
global float * dst, ulong offsetd,
|
||||
int ne,
|
||||
int group_size,
|
||||
float eps
|
||||
) {
|
||||
src0 = (global float *)((global char *)src0 + offset0);
|
||||
src1 = (global float *)((global char *)src1 + offset1);
|
||||
src2 = (global float *)((global char *)src2 + offset2);
|
||||
dst = (global float *)((global char *)dst + offsetd);
|
||||
|
||||
int start = get_group_id(0) * group_size;
|
||||
int end = start + group_size;
|
||||
if (end > ne) {
|
||||
end = ne;
|
||||
}
|
||||
|
||||
float sum = 0.0f;
|
||||
float sum_sq = 0.0f;
|
||||
|
||||
for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) {
|
||||
float val = src0[j];
|
||||
sum += val;
|
||||
sum_sq += val*val;
|
||||
}
|
||||
|
||||
sum = sub_group_reduce_add(sum);
|
||||
sum_sq = sub_group_reduce_add(sum_sq);
|
||||
|
||||
const float mean = sum / group_size;
|
||||
const float var = sum_sq / group_size - mean * mean;
|
||||
const float scale = rsqrt(var + eps);
|
||||
|
||||
for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) {
|
||||
dst[j] = ((src0[j] - mean) * scale) * src1[j] + src2[j];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,3 +79,83 @@ kernel void kernel_norm(
|
||||
y[i00] = y[i00] * scale;
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// norm_mul_add
|
||||
//------------------------------------------------------------------------------
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_32
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_norm_mul_add(
|
||||
global char * src0_ptr, ulong src0_offset,
|
||||
global char * src1_ptr, ulong src1_offset,
|
||||
global char * src2_ptr, ulong src2_offset,
|
||||
global char * dst_ptr, ulong dst_offset,
|
||||
int ne00, int ne01, int ne02, int ne03,
|
||||
ulong nb01, ulong nb02, ulong nb03,
|
||||
int ne10, int ne11, int ne12, int ne13,
|
||||
ulong nb11, ulong nb12, ulong nb13,
|
||||
int ne20, int ne21, int ne22, int ne23,
|
||||
ulong nb21, ulong nb22, ulong nb23,
|
||||
ulong nbd1, ulong nbd2, ulong nbd3,
|
||||
float eps,
|
||||
local float2 * sums
|
||||
) {
|
||||
const int i03 = get_group_id(2);
|
||||
const int i02 = get_group_id(1);
|
||||
const int i01 = get_group_id(0);
|
||||
|
||||
global float4 * x = (global float4 *)(src0_ptr + src0_offset + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
global float4 * w = (global float4 *)(src1_ptr + src1_offset + (i01%ne11)*nb11 + (i02%ne12)*nb12 + (i03%ne13)*nb13);
|
||||
global float4 * b = (global float4 *)(src2_ptr + src2_offset + (i01%ne21)*nb21 + (i02%ne22)*nb22 + (i03%ne23)*nb23);
|
||||
global float4 * y = (global float4 *)(dst_ptr + dst_offset + i01*nbd1 + i02*nbd2 + i03*nbd3);
|
||||
|
||||
float p_sum = 0.0f;
|
||||
float p_sum_sq = 0.0f;
|
||||
|
||||
const int n_chunks = ne00 / 4;
|
||||
for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
|
||||
float4 val = x[i00];
|
||||
p_sum += val.x + val.y + val.z + val.w;
|
||||
p_sum_sq += dot(val, val);
|
||||
}
|
||||
|
||||
p_sum = sub_group_reduce_add(p_sum);
|
||||
p_sum_sq = sub_group_reduce_add(p_sum_sq);
|
||||
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
sums[get_sub_group_id()] = (float2)(p_sum, p_sum_sq);
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if (get_local_id(0) == 0) {
|
||||
float sum = 0.0f;
|
||||
float sum_sq = 0.0f;
|
||||
for (uint i = 0; i < get_num_sub_groups(); ++i) {
|
||||
float2 s = sums[i];
|
||||
sum += s.x;
|
||||
sum_sq += s.y;
|
||||
}
|
||||
|
||||
const float inv_ne00 = 1.0f / (float)ne00;
|
||||
const float mean = sum * inv_ne00;
|
||||
const float variance = mad(-mean, mean, sum_sq * inv_ne00);
|
||||
|
||||
sums[0] = (float2)(mean, rsqrt(variance + eps));
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
const float2 mean_scale = sums[0];
|
||||
const float mean = mean_scale.x;
|
||||
const float scale = mean_scale.y;
|
||||
const float neg_mean_scale = -mean * scale;
|
||||
|
||||
for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
|
||||
const int w_idx = ne10 > 1 ? i00 : 0;
|
||||
const int b_idx = ne20 > 1 ? i00 : 0;
|
||||
const float4 norm_x = mad(x[i00], (float4)scale, (float4)neg_mean_scale);
|
||||
y[i00] = mad(norm_x, w[w_idx], b[b_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user