[SYCL] fix UT fault cases: count-equal, argsort, pad OPs (#16521)
* fix/refactor OP argsort, pad * fix count-equal op * update SYCL OP list * fix format issue --------- Co-authored-by: Zhang Jianyu <zhang.jianyu@outlook.com>
This commit is contained in:
@@ -328,26 +328,6 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
|
||||
dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2);
|
||||
if (nidx >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// operation
|
||||
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
||||
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
||||
if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) {
|
||||
int offset_src = nidx + item_ct1.get_group(1) * ne00 +
|
||||
item_ct1.get_group(0) * ne00 * ne01;
|
||||
dst[offset_dst] = x[offset_src];
|
||||
} else {
|
||||
dst[offset_dst] = static_cast<T>(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void clamp(const T * x, T * dst, const float min, const float max, const int k,
|
||||
const sycl::nd_item<1> &item_ct1) {
|
||||
@@ -431,18 +411,6 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void pad_sycl(const T *x, T *dst, const int ne00,
|
||||
const int ne01, const int ne02, const int ne0,
|
||||
const int ne1, const int ne2, queue_ptr stream) {
|
||||
int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
|
||||
sycl::range<3> gridDim(ne2, ne1, num_blocks);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
@@ -596,40 +564,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
|
||||
}
|
||||
}
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
#else
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
#endif
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
switch (dst->type) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
auto data_pts = cast_data<sycl::half>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
|
||||
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
auto data_pts = cast_data<float>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
|
||||
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ggml_sycl_detail
|
||||
|
||||
@@ -919,14 +853,6 @@ static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_te
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2,
|
||||
queue_ptr stream) {
|
||||
ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
float min_val;
|
||||
float max_val;
|
||||
@@ -1119,10 +1045,6 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_upscale(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_pad(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
|
||||
Reference in New Issue
Block a user