* implement set_rows with i32 index
* template fix
* test quantized path warnings--
* Apply suggestions from code review (Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>)
* forgotten name change
* deduplicate cuda/sycl and test-fix
* indent++
* vulkan: support set_rows with i32 index type (#16162)
* disable i32 index for webgpu for now

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
190 lines · 5.0 KiB · OpenCL C
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
|
|
|
// Copy f32 rows of src0 into f32 rows of dst, where each destination row
// index is looked up as a 64-bit signed integer in the index buffer src1.
//
// For source row `row` of slice (i2, i3), the destination row is
//   src1[row, i2 % ne11, i3 % ne12]
// so the index buffer wraps along its second and third dimensions when it is
// smaller than src0 there.
//
// Work-item mapping used below:
//   get_group_id(2)                                     -> i3
//   get_group_id(1)                                     -> i2
//   get_group_id(0)*get_local_size(1) + get_local_id(1) -> row (skipped if >= ne01)
//   get_local_id(0), stepping by get_local_size(0)      -> element within the row
// Each active work-item row copies nblk0 elements in total across the
// work-items of local dimension 0.
//
// NOTE(review): the destination row index read from src1 is not range-checked
// here; presumably the host guarantees it is valid for dst — confirm.
kernel void kernel_set_rows_f32_i64(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        int           ne11,
        int           ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        ulong         nb1,
        ulong         nb2,
        ulong         nb3
) {
    const int i3  = get_group_id(2);
    const int i2  = get_group_id(1);
    const int row = get_group_id(0)*get_local_size(1) + get_local_id(1);

    if (row < ne01) {
        // Byte address of this row's destination index inside src1.
        global const char * idx_ptr = src1 + offset1
                                    + row*nb10
                                    + (i2 % ne11)*nb11
                                    + (i3 % ne12)*nb12;
        const long dst_r = *(global const long *) idx_ptr;

        global const float * from = (global const float *)(src0 + offset0 + row*nb01 + i2*nb02 + i3*nb03);
        global       float * to   = (global       float *)(dst  + offsetd + dst_r*nb1 + i2*nb2 + i3*nb3);

        const int step = get_local_size(0);
        for (int k = get_local_id(0); k < nblk0; k += step) {
            to[k] = from[k];
        }
    }
}
|
|
|
|
// Copy f32 rows of src0 into f16 (half) rows of dst, where each destination
// row index is looked up as a 64-bit signed integer in the index buffer src1.
// Each element is converted float -> half on store (implicit conversion,
// requires the cl_khr_fp16 extension to be enabled for this file).
//
// For source row `row` of slice (i2, i3), the destination row is
//   src1[row, i2 % ne11, i3 % ne12]
// so the index buffer wraps along its second and third dimensions when it is
// smaller than src0 there.
//
// Work-item mapping used below:
//   get_group_id(2)                                     -> i3
//   get_group_id(1)                                     -> i2
//   get_group_id(0)*get_local_size(1) + get_local_id(1) -> row (skipped if >= ne01)
//   get_local_id(0), stepping by get_local_size(0)      -> element within the row
// Each active work-item row copies nblk0 elements in total across the
// work-items of local dimension 0.
//
// NOTE(review): the destination row index read from src1 is not range-checked
// here; presumably the host guarantees it is valid for dst — confirm.
kernel void kernel_set_rows_f16_i64(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        int           ne11,
        int           ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        ulong         nb1,
        ulong         nb2,
        ulong         nb3
) {
    const int i3  = get_group_id(2);
    const int i2  = get_group_id(1);
    const int row = get_group_id(0)*get_local_size(1) + get_local_id(1);

    if (row < ne01) {
        // Byte address of this row's destination index inside src1.
        global const char * idx_ptr = src1 + offset1
                                    + row*nb10
                                    + (i2 % ne11)*nb11
                                    + (i3 % ne12)*nb12;
        const long dst_r = *(global const long *) idx_ptr;

        global const float * from = (global const float *)(src0 + offset0 + row*nb01 + i2*nb02 + i3*nb03);
        global       half  * to   = (global       half  *)(dst  + offsetd + dst_r*nb1 + i2*nb2 + i3*nb3);

        const int step = get_local_size(0);
        for (int k = get_local_id(0); k < nblk0; k += step) {
            to[k] = from[k];
        }
    }
}
|
|
|
|
// Copy f32 rows of src0 into f32 rows of dst, where each destination row
// index is looked up as a 32-bit signed integer in the index buffer src1.
//
// For source row `row` of slice (i2, i3), the destination row is
//   src1[row, i2 % ne11, i3 % ne12]
// so the index buffer wraps along its second and third dimensions when it is
// smaller than src0 there.
//
// Work-item mapping used below:
//   get_group_id(2)                                     -> i3
//   get_group_id(1)                                     -> i2
//   get_group_id(0)*get_local_size(1) + get_local_id(1) -> row (skipped if >= ne01)
//   get_local_id(0), stepping by get_local_size(0)      -> element within the row
// Each active work-item row copies nblk0 elements in total across the
// work-items of local dimension 0.
//
// NOTE(review): the destination row index read from src1 is not range-checked
// here; presumably the host guarantees it is valid for dst — confirm.
kernel void kernel_set_rows_f32_i32(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        int           ne11,
        int           ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        ulong         nb1,
        ulong         nb2,
        ulong         nb3
) {
    const int i3  = get_group_id(2);
    const int i2  = get_group_id(1);
    const int row = get_group_id(0)*get_local_size(1) + get_local_id(1);

    if (row < ne01) {
        // Byte address of this row's destination index inside src1.
        global const char * idx_ptr = src1 + offset1
                                    + row*nb10
                                    + (i2 % ne11)*nb11
                                    + (i3 % ne12)*nb12;
        const int dst_r = *(global const int *) idx_ptr;

        global const float * from = (global const float *)(src0 + offset0 + row*nb01 + i2*nb02 + i3*nb03);
        global       float * to   = (global       float *)(dst  + offsetd + dst_r*nb1 + i2*nb2 + i3*nb3);

        const int step = get_local_size(0);
        for (int k = get_local_id(0); k < nblk0; k += step) {
            to[k] = from[k];
        }
    }
}
|
|
|
|
// Copy f32 rows of src0 into f16 (half) rows of dst, where each destination
// row index is looked up as a 32-bit signed integer in the index buffer src1.
// Each element is converted float -> half on store (implicit conversion,
// requires the cl_khr_fp16 extension to be enabled for this file).
//
// For source row `row` of slice (i2, i3), the destination row is
//   src1[row, i2 % ne11, i3 % ne12]
// so the index buffer wraps along its second and third dimensions when it is
// smaller than src0 there.
//
// Work-item mapping used below:
//   get_group_id(2)                                     -> i3
//   get_group_id(1)                                     -> i2
//   get_group_id(0)*get_local_size(1) + get_local_id(1) -> row (skipped if >= ne01)
//   get_local_id(0), stepping by get_local_size(0)      -> element within the row
// Each active work-item row copies nblk0 elements in total across the
// work-items of local dimension 0.
//
// NOTE(review): the destination row index read from src1 is not range-checked
// here; presumably the host guarantees it is valid for dst — confirm.
kernel void kernel_set_rows_f16_i32(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        int           ne11,
        int           ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        ulong         nb1,
        ulong         nb2,
        ulong         nb3
) {
    const int i3  = get_group_id(2);
    const int i2  = get_group_id(1);
    const int row = get_group_id(0)*get_local_size(1) + get_local_id(1);

    if (row < ne01) {
        // Byte address of this row's destination index inside src1.
        global const char * idx_ptr = src1 + offset1
                                    + row*nb10
                                    + (i2 % ne11)*nb11
                                    + (i3 % ne12)*nb12;
        const int dst_r = *(global const int *) idx_ptr;

        global const float * from = (global const float *)(src0 + offset0 + row*nb01 + i2*nb02 + i3*nb03);
        global       half  * to   = (global       half  *)(dst  + offsetd + dst_r*nb1 + i2*nb2 + i3*nb3);

        const int step = get_local_size(0);
        for (int k = get_local_id(0); k < nblk0; k += step) {
            to[k] = from[k];
        }
    }
}
|