opencl: support pad_ext (#15888)
This commit is contained in:
@@ -1,30 +1,39 @@
|
||||
kernel void kernel_pad(
|
||||
global const void * src0_ptr,
|
||||
ulong src0_offset,
|
||||
global void * dst_ptr,
|
||||
ulong dst_offset,
|
||||
int s_ne0, int s_ne1, int s_ne2,
|
||||
int d_ne0, int d_ne1, int d_ne2
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global void * dst,
|
||||
ulong offsetd,
|
||||
int ne00, int ne01, int ne02, int ne03,
|
||||
ulong nb00, ulong nb01, ulong nb02, ulong nb03,
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
ulong nb0, ulong nb1, ulong nb2, ulong nb3,
|
||||
int lp0, int rp0,
|
||||
int lp1, int rp1,
|
||||
int lp2, int rp2,
|
||||
int lp3, int rp3
|
||||
) {
|
||||
global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset);
|
||||
global float * dst = (global float *)((global char *)dst_ptr + dst_offset);
|
||||
src0 = (global float*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int nidx = get_global_id(0);
|
||||
int idx_d1 = get_group_id(1);
|
||||
int idx_d2 = get_group_id(2);
|
||||
int i0 = get_global_id(0);
|
||||
int i1 = get_group_id(1);
|
||||
int i2 = get_group_id(2) % ne2;
|
||||
int i3 = get_group_id(2) / ne2;
|
||||
|
||||
if (nidx >= d_ne0) {
|
||||
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1;
|
||||
uint src0_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
|
||||
uint dst_idx = i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
|
||||
|
||||
bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2);
|
||||
global float * src0_ptr = (global float *)((global char *)src0 + src0_idx);
|
||||
global float * dst_ptr = (global float *)((global char *)dst + dst_idx);
|
||||
|
||||
if (in_src_bounds) {
|
||||
int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * s_ne1;
|
||||
dst[dst_el_offset] = src0[src_el_offset];
|
||||
} else {
|
||||
dst[dst_el_offset] = 0.0f;
|
||||
}
|
||||
bool in_src_bounds = (i0 >= lp0 && i0 < ne0 - rp0) &&
|
||||
(i1 >= lp1 && i1 < ne1 - rp1) &&
|
||||
(i2 >= lp2 && i2 < ne2 - rp2) &&
|
||||
(i3 >= lp3 && i3 < ne3 - rp3);
|
||||
|
||||
*dst_ptr = in_src_bounds ? *src0_ptr : 0.0f;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user