202 lines
6.3 KiB
Common Lisp
202 lines
6.3 KiB
Common Lisp
|
|
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||
|
|
|
||
|
|
#define GELU_COEF_A 0.044715f
|
||
|
|
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
|
||
|
|
|
||
|
|
//------------------------------------------------------------------------------
|
||
|
|
// geglu
|
||
|
|
//------------------------------------------------------------------------------
|
||
|
|
kernel void kernel_geglu(
|
||
|
|
global char * src0,
|
||
|
|
ulong offset0,
|
||
|
|
global char * src1,
|
||
|
|
ulong offset1,
|
||
|
|
global char * dst,
|
||
|
|
ulong offsetd,
|
||
|
|
ulong nb01,
|
||
|
|
ulong nb11,
|
||
|
|
int ne0,
|
||
|
|
ulong nb1,
|
||
|
|
int ne00_off,
|
||
|
|
int ne10_off
|
||
|
|
) {
|
||
|
|
src0 = (global char*)((global char*)src0 + offset0);
|
||
|
|
src1 = (global char*)((global char*)src1 + offset1);
|
||
|
|
dst = (global char*)((global char*)dst + offsetd);
|
||
|
|
|
||
|
|
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
||
|
|
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
||
|
|
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
|
||
|
|
|
||
|
|
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
||
|
|
const float x0 = src0_row[i0];
|
||
|
|
const float x1 = src1_row[i0];
|
||
|
|
|
||
|
|
const float gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
|
||
|
|
|
||
|
|
dst_row[i0] = gelu*x1;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
kernel void kernel_geglu_f16(
|
||
|
|
global char * src0,
|
||
|
|
ulong offset0,
|
||
|
|
global char * src1,
|
||
|
|
ulong offset1,
|
||
|
|
global char * dst,
|
||
|
|
ulong offsetd,
|
||
|
|
ulong nb01,
|
||
|
|
ulong nb11,
|
||
|
|
int ne0,
|
||
|
|
ulong nb1,
|
||
|
|
int ne00_off,
|
||
|
|
int ne10_off
|
||
|
|
) {
|
||
|
|
src0 = (global char*)((global char*)src0 + offset0);
|
||
|
|
src1 = (global char*)((global char*)src1 + offset1);
|
||
|
|
dst = (global char*)((global char*)dst + offsetd);
|
||
|
|
|
||
|
|
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
||
|
|
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
||
|
|
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
|
||
|
|
|
||
|
|
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
||
|
|
const half x0 = src0_row[i0];
|
||
|
|
const half x1 = src1_row[i0];
|
||
|
|
|
||
|
|
const half gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
|
||
|
|
|
||
|
|
dst_row[i0] = gelu*x1;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
//------------------------------------------------------------------------------
|
||
|
|
// reglu
|
||
|
|
//------------------------------------------------------------------------------
|
||
|
|
kernel void kernel_reglu(
|
||
|
|
global char * src0,
|
||
|
|
ulong offset0,
|
||
|
|
global char * src1,
|
||
|
|
ulong offset1,
|
||
|
|
global char * dst,
|
||
|
|
ulong offsetd,
|
||
|
|
ulong nb01,
|
||
|
|
ulong nb11,
|
||
|
|
int ne0,
|
||
|
|
ulong nb1,
|
||
|
|
int ne00_off,
|
||
|
|
int ne10_off
|
||
|
|
) {
|
||
|
|
src0 = (global char*)((global char*)src0 + offset0);
|
||
|
|
src1 = (global char*)((global char*)src1 + offset1);
|
||
|
|
dst = (global char*)((global char*)dst + offsetd);
|
||
|
|
|
||
|
|
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
||
|
|
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
||
|
|
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
|
||
|
|
|
||
|
|
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
||
|
|
const float x0 = src0_row[i0];
|
||
|
|
const float x1 = src1_row[i0];
|
||
|
|
|
||
|
|
dst_row[i0] = x0*x1*(x0 > 0.0f);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
kernel void kernel_reglu_f16(
|
||
|
|
global char * src0,
|
||
|
|
ulong offset0,
|
||
|
|
global char * src1,
|
||
|
|
ulong offset1,
|
||
|
|
global char * dst,
|
||
|
|
ulong offsetd,
|
||
|
|
ulong nb01,
|
||
|
|
ulong nb11,
|
||
|
|
int ne0,
|
||
|
|
ulong nb1,
|
||
|
|
int ne00_off,
|
||
|
|
int ne10_off
|
||
|
|
) {
|
||
|
|
src0 = (global char*)((global char*)src0 + offset0);
|
||
|
|
src1 = (global char*)((global char*)src1 + offset1);
|
||
|
|
dst = (global char*)((global char*)dst + offsetd);
|
||
|
|
|
||
|
|
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
||
|
|
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
||
|
|
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
|
||
|
|
|
||
|
|
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
||
|
|
const half x0 = src0_row[i0];
|
||
|
|
const half x1 = src1_row[i0];
|
||
|
|
|
||
|
|
dst_row[i0] = x0*x1*(x0 > 0.0f);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
//------------------------------------------------------------------------------
|
||
|
|
// swiglu
|
||
|
|
//------------------------------------------------------------------------------
|
||
|
|
kernel void kernel_swiglu(
|
||
|
|
global char * src0,
|
||
|
|
ulong offset0,
|
||
|
|
global char * src1,
|
||
|
|
ulong offset1,
|
||
|
|
global char * dst,
|
||
|
|
ulong offsetd,
|
||
|
|
ulong nb01,
|
||
|
|
ulong nb11,
|
||
|
|
int ne0,
|
||
|
|
ulong nb1,
|
||
|
|
int ne00_off,
|
||
|
|
int ne10_off
|
||
|
|
) {
|
||
|
|
src0 = (global char*)((global char*)src0 + offset0);
|
||
|
|
src1 = (global char*)((global char*)src1 + offset1);
|
||
|
|
dst = (global char*)((global char*)dst + offsetd);
|
||
|
|
|
||
|
|
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
||
|
|
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
||
|
|
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
|
||
|
|
|
||
|
|
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
||
|
|
const float x0 = src0_row[i0];
|
||
|
|
const float x1 = src1_row[i0];
|
||
|
|
|
||
|
|
const float silu = x0 / (1.0f + exp(-x0));
|
||
|
|
|
||
|
|
dst_row[i0] = silu*x1;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
kernel void kernel_swiglu_f16(
|
||
|
|
global char * src0,
|
||
|
|
ulong offset0,
|
||
|
|
global char * src1,
|
||
|
|
ulong offset1,
|
||
|
|
global char * dst,
|
||
|
|
ulong offsetd,
|
||
|
|
ulong nb01,
|
||
|
|
ulong nb11,
|
||
|
|
int ne0,
|
||
|
|
ulong nb1,
|
||
|
|
int ne00_off,
|
||
|
|
int ne10_off
|
||
|
|
) {
|
||
|
|
src0 = (global char*)((global char*)src0 + offset0);
|
||
|
|
src1 = (global char*)((global char*)src1 + offset1);
|
||
|
|
dst = (global char*)((global char*)dst + offsetd);
|
||
|
|
|
||
|
|
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
||
|
|
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
||
|
|
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
|
||
|
|
|
||
|
|
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
||
|
|
const half x0 = src0_row[i0];
|
||
|
|
const half x1 = src1_row[i0];
|
||
|
|
|
||
|
|
const half silu = x0 / (1.0f + exp(-x0));
|
||
|
|
|
||
|
|
dst_row[i0] = silu*x1;
|
||
|
|
}
|
||
|
|
}
|