ggml: add ops for WAN video model (cuda && cpu) (#15669)
* add conv3d support * add ggml_pad_ext for cpu & cuda backend * cuda/cpu: add im2col_3d support * cuda: make im2col a little faster * fix cuda pad/scale/im2col3d * make im2col_3d faster * gguf: support loading tensors which n_dims > GGML_MAX_DIMS * fix cuda get_rows * avoid ggml_conv_3d conflict * correct GGML_OP_COUNT assertion * avoid build failure * avoid build failure on MacOS * cuda: remove unnecessary MIN define * fix cpu im2col_3d * adjust the code style * cuda: use simpler loop in get_rows * add test_im2col_3d to test-backend-ops * test-backend-ops.cpp: remove trailing whitespace * cpu: im2col_3d support non continuous src Co-authored-by: Jeff Bolz <jbolz@nvidia.com> * fix test_im2col_3d * remove unused variables * cuda: get_rows: dfloat2 -> float2 * add test_pad_ext to test-backend-ops.cpp * add gguf_init_from_file_ext impl * Revert "gguf: support loading tensors which n_dims > GGML_MAX_DIMS" This reverts commit d8377a0a37f314bd3713fe043b4333ad661610c1. * Revert "add gguf_init_from_file_ext impl" This reverts commit d9f1d13208c68ef83b3538201ac7f31614fb1994. * update ggml_backend_vk_device_supports_op * fix ggml_backend_vk_device_supports_op * update other backend supports op for ggml_pad_ext * metal/opencl/sycl/vulkan: fix GGML_OP_PAD check in supports_op --------- Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
This commit is contained in:
@@ -112,3 +112,132 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
|
||||
}
|
||||
}
|
||||
|
||||
// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
|
||||
template <typename T>
|
||||
static __global__ void im2col_3d_kernel(
|
||||
const float * src, T * dst,
|
||||
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
|
||||
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
|
||||
int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,
|
||||
int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,
|
||||
int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,
|
||||
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
|
||||
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (i >= IC_KD_KH_KW) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t iic = i / KD_KH_KW;
|
||||
const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;
|
||||
const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
|
||||
const int64_t ikw = i % KW;
|
||||
|
||||
const int64_t iow = blockIdx.y;
|
||||
for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) {
|
||||
const int64_t in = iz / OD_OH;
|
||||
const int64_t iod = (iz - in*OD_OH) / OH;
|
||||
const int64_t ioh = iz % OH;
|
||||
|
||||
const int64_t iiw = iow * s0 + ikw * d0 - p0;
|
||||
const int64_t iih = ioh * s1 + ikh * d1 - p1;
|
||||
const int64_t iid = iod * s2 + ikd * d2 - p2;
|
||||
|
||||
const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
|
||||
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
|
||||
dst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
const int64_t offset_src = in*IC_ID_IH_IW + iic*ID_IH_IW + iid*IH_IW + iih*IW + iiw;
|
||||
dst[offset_dst] = src[offset_src];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
|
||||
template <typename T>
|
||||
static void im2col_3d_cuda(const float * src, T* dst,
|
||||
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
|
||||
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
|
||||
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
|
||||
const int64_t OH_OW = OH*OW;
|
||||
const int64_t KD_KH_KW = KD*KH*KW;
|
||||
const int64_t ID_IH_IW = ID*IH*IW;
|
||||
const int64_t KH_KW = KH*KW;
|
||||
const int64_t IH_IW = IH*IW;
|
||||
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
|
||||
const int64_t OW_KD_KH_KW = OW*KD*KH*KW;
|
||||
const int64_t N_OD_OH = N*OD*OH;
|
||||
const int64_t OD_OH = OD*OH;
|
||||
const int64_t IC_ID_IH_IW = IC*ID*IH*IW;
|
||||
const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
|
||||
const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
|
||||
const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
|
||||
const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
|
||||
dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z));
|
||||
im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
|
||||
OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
|
||||
IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
|
||||
OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,
|
||||
s0, s1, s2, p0, p1, p2, d0, d1, d2);
|
||||
}
|
||||
|
||||
static void im2col_3d_cuda_f16(const float * src, half * dst,
|
||||
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
|
||||
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
|
||||
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
|
||||
|
||||
im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
|
||||
}
|
||||
|
||||
static void im2col_3d_cuda_f32(const float * src, float * dst,
|
||||
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
|
||||
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
|
||||
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
|
||||
|
||||
im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const float * src1_d = (const float *)src1->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
||||
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
||||
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
|
||||
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
|
||||
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
|
||||
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
|
||||
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
|
||||
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
|
||||
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
|
||||
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
|
||||
|
||||
const int64_t N = ne13 / IC;
|
||||
const int64_t ID = ne12;
|
||||
const int64_t IH = ne11;
|
||||
const int64_t IW = ne10;
|
||||
|
||||
const int64_t OC = ne03 / IC;
|
||||
const int64_t KD = ne02;
|
||||
const int64_t KH = ne01;
|
||||
const int64_t KW = ne00;
|
||||
|
||||
const int64_t OD = ne3 / N;
|
||||
const int64_t OH = ne2;
|
||||
const int64_t OW = ne1;
|
||||
|
||||
if(dst->type == GGML_TYPE_F16) {
|
||||
im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
|
||||
} else {
|
||||
im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user