* oai moe * compat with new checkpoint * add attn sink impl * add rope scaling yarn * logits match with latest transformers code * wip chat template * rm trailing space * use ggml_scale_bias * rm redundant is_swa_all * convert interleaved gate_up * graph : fix activation function to match reference (#7) * vocab : handle o200k_harmony special tokens * ggml : add attention sinks support (#1) * llama : add attn sinks * ggml : add attn sinks * cuda : add attn sinks * vulkan : add support for sinks in softmax remove unnecessary return * ggml : add fused swiglu_oai op (#11) * ggml : add fused swiglu_oai op * Update ggml/src/ggml-cpu/ops.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * update CUDA impl * cont : metal impl * add vulkan impl * test-backend-ops : more test cases, clean up * llama : remove unfused impl * remove extra lines --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: slaren <slarengh@gmail.com> * repack mxfp4 upon conversion * clean up a bit * enable thinking * add quick hack to render only some special tokens * fix bf16 conversion * remove vocab hack * webui ok * support chat parsing for gpt-oss * fix webui * direct mapping mxfp4, FINALLY * force using mxfp4 * properly use lazy tensor * ggml : add mxfp4 ggml : use e8m0 conversion instead of powf Co-authored-by: Diego Devesa <slarengh@gmail.com> change kvalues_mxfp4 table to match e2m1 (#6) metal : remove quantization for now (not used) cuda : fix disabled CUDA graphs due to ffn moe bias vulkan : add support for mxfp4 cont : add cm2 dequant * ggml : add ggml_add_id (#13) * ggml : add ggml_add_id * add cuda impl * llama : add weight support check for add_id * perf opt * add vulkan impl * rename cuda files * add metal impl * allow in-place ggml_add_id * llama : keep biases on CPU with --cpu-moe * llama : fix compile error ggml-ci * cuda : add fallback for __nv_cvt_e8m0_to_bf16raw ggml-ci * cleanup ggml-ci * sycl : fix supports_op for MXFP4 ggml-ci * fix 
Unknown reasoning format * ggml-cpu : fix AVX build ggml-ci * fix hip build ggml-ci * cuda : add mxfp4 dequantization support for cuBLAS ggml-ci * ggml-cpu : fix mxfp4 fallback definitions for some architectures ggml-ci * cuda : fix version required for __nv_cvt_e8m0_to_bf16raw --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co> Co-authored-by: slaren <slarengh@gmail.com>
689 lines
12 KiB
C
689 lines
12 KiB
C
#ifndef GGML_METAL_IMPL
|
|
#define GGML_METAL_IMPL
|
|
|
|
// kernel parameters for mat-vec threadgroups
|
|
//
|
|
// N_R0: number of src0 rows to process per simdgroup
|
|
// N_SG: number of simdgroups per threadgroup
|
|
//
|
|
// TODO: for optimal performance, these should become a function of the device and the work size
|
|
|
|
// Mat-vec launch geometry, one pair of constants per src0 type:
//   N_R0_<type> - number of src0 rows processed per simdgroup
//   N_SG_<type> - number of simdgroups per threadgroup

// Q4_0 / Q4_1 / Q5_0 / Q5_1 / Q8_0
#define N_R0_Q4_0     4
#define N_SG_Q4_0     2
#define N_R0_Q4_1     4
#define N_SG_Q4_1     2
#define N_R0_Q5_0     4
#define N_SG_Q5_0     2
#define N_R0_Q5_1     4
#define N_SG_Q5_1     2
#define N_R0_Q8_0     4
#define N_SG_Q8_0     2

// MXFP4
#define N_R0_MXFP4    2
#define N_SG_MXFP4    2

// Q2_K .. Q6_K
#define N_R0_Q2_K     4
#define N_SG_Q2_K     2
#define N_R0_Q3_K     2
#define N_SG_Q3_K     2
#define N_R0_Q4_K     4
#define N_SG_Q4_K     2
#define N_R0_Q5_K     2
#define N_SG_Q5_K     2
#define N_R0_Q6_K     1
#define N_SG_Q6_K     2

// IQ1_* .. IQ4_*
#define N_R0_IQ1_S    4
#define N_SG_IQ1_S    2
#define N_R0_IQ1_M    4
#define N_SG_IQ1_M    2
#define N_R0_IQ2_XXS  4
#define N_SG_IQ2_XXS  2
#define N_R0_IQ2_XS   4
#define N_SG_IQ2_XS   2
#define N_R0_IQ2_S    4
#define N_SG_IQ2_S    2
#define N_R0_IQ3_XXS  4
#define N_SG_IQ3_XXS  2
#define N_R0_IQ3_S    4
#define N_SG_IQ3_S    2
#define N_R0_IQ4_NL   2
#define N_SG_IQ4_NL   2
#define N_R0_IQ4_XS   2
#define N_SG_IQ4_XS   2
|
|
|
|
// kernel argument structs
|
|
//
|
|
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
|
|
// however, be careful of int overflows when using them in the kernel implementation
|
|
//
|
|
// - strides (e.g. nb00) use uint64_t
|
|
|
|
// Arguments for the concat kernel.
// ne* are element counts, nb* are strides; the 0x / 1x / unsuffixed groups
// presumably describe src0 / src1 / dst (ggml naming convention - confirm in kernel).
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    int32_t  dim;  // presumably the axis along which the inputs are joined
} ggml_metal_kargs_concat;
|
|
|
|
// Arguments for the binary element-wise kernels.
// ne* are element counts, nb* are strides; the 0x / 1x / unsuffixed groups
// presumably describe src0 / src1 / dst (ggml naming convention - confirm in kernel).
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    uint64_t offs;   // single offset; its meaning is defined by the host/kernel code (not shown here)
    uint64_t o1[8];  // eight 64-bit offsets; semantics defined by the host/kernel code - TODO document
} ggml_metal_kargs_bin;
|
|
|
|
// Arguments for the add_id kernel.
// Strides are declared uint64_t rather than size_t. Every other kargs struct in
// this header uses fixed-width uint64_t strides, and size_t's width depends on
// the compilation target, so a fixed-width type guarantees both sides agree on
// the layout. On 64-bit targets (size_t == 8 bytes) this is layout-identical
// to the previous size_t declaration.
typedef struct {
    int64_t  ne0;
    int64_t  ne1;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb11;
    uint64_t nb21;
} ggml_metal_kargs_add_id;
|
|
|
|
// Arguments for the repeat kernel.
// ne* are element counts, nb* are strides; the 0x / unsuffixed groups
// presumably describe src0 / dst (ggml naming convention - confirm in kernel).
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
} ggml_metal_kargs_repeat;
|
|
|
|
// Arguments for the copy/convert kernels.
// Unlike most kargs structs here, element counts are 64-bit.
// The 0x / unsuffixed groups presumably describe src0 / dst (confirm in kernel).
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
} ggml_metal_kargs_cpy;
|
|
|
|
// Arguments for the set kernel.
// The 1x / unsuffixed groups presumably describe src1 / dst (confirm in kernel).
typedef struct {
    int64_t  ne10, ne11, ne12;
    uint64_t nb10, nb11, nb12, nb13;
    uint64_t nb1,  nb2,  nb3;
    uint64_t offs;     // offset value; meaning defined by the host/kernel code (not shown here)
    bool     inplace;  // flag; presumably selects in-place operation - confirm in kernel
} ggml_metal_kargs_set;
|
|
|
|
// Arguments for the RoPE kernels: tensor geometry followed by the rotary
// parameters. The 0x / unsuffixed groups presumably describe src0 / dst.
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    int32_t  n_past, n_dims, n_ctx_orig;
    float    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
    int32_t  sect_0, sect_1, sect_2, sect_3;  // section sizes, presumably for multi-section RoPE variants
} ggml_metal_kargs_rope;
|
|
|
|
// Arguments for the flash-attention (ext) kernels.
// By ggml naming, the 0x / 1x / 2x / 3x groups presumably describe Q / K / V / mask;
// this is an assumption - confirm against the kernel.
typedef struct {
    int32_t  ne01, ne02, ne03;
    uint64_t nb01, nb02, nb03;
    int32_t  ne11;
    int32_t  ne_12_2, ne_12_3;  // assume K and V are same shape
    uint64_t nb11, nb12, nb13;
    uint64_t nb21, nb22, nb23;
    int32_t  ne32, ne33;
    uint64_t nb31, nb32, nb33;
    int32_t  ne1, ne2;
    float    scale, max_bias, m0, m1;
    int32_t  n_head_log2;
    float    logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
|
|
|
|
// Arguments for the matrix-matrix multiplication kernels.
// The 0x / 1x / unsuffixed groups presumably describe src0 / src1 / dst;
// r2 / r3 are presumably broadcast ratios along dims 2 / 3 (confirm in kernel).
typedef struct {
    int32_t  ne00, ne02;
    uint64_t nb01, nb02, nb03;
    int32_t  ne12;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0, ne1;
    int16_t  r2, r3;
} ggml_metal_kargs_mul_mm;
|
|
|
|
// Arguments for the matrix-vector multiplication kernels.
// The 0x / 1x / unsuffixed groups presumably describe src0 / src1 / dst;
// r2 / r3 are presumably broadcast ratios along dims 2 / 3 (confirm in kernel).
typedef struct {
    int32_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0, ne1;
    int16_t  r2, r3;
} ggml_metal_kargs_mul_mv;
|
|
|
|
// Arguments for the extended matrix-vector kernels: the mul_mv fields plus
// three 16-bit launch-shape parameters (nsg, nxpsg, r1ptg; semantics defined by the kernel).
typedef struct {
    int32_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0, ne1;
    int16_t  r2, r3;
    int16_t  nsg, nxpsg, r1ptg;
} ggml_metal_kargs_mul_mv_ext;
|
|
|
|
// Arguments for the first expert-routing map pass of mul_mm_id.
typedef struct {
    int32_t  ne10;
    int32_t  ne11;   // n_expert_used (bcast)
    uint64_t nb11, nb12;
    int32_t  neh11;  // n_tokens
    uint64_t nbh11;
    int32_t  ne20;   // n_expert_used
    uint64_t nb21;
} ggml_metal_kargs_mul_mm_id_map0;
|
|
|
|
// Arguments for the second expert-routing map pass of mul_mm_id.
typedef struct {
    int32_t  ne20;  // n_expert_used
    int32_t  neh0, neh1;
    uint64_t nbh1, nbh2;
    int32_t  ne0;
    uint64_t nb1, nb2;
} ggml_metal_kargs_mul_mm_id_map1;
|
|
|
|
// Arguments for the indirect (expert-indexed) matrix-matrix kernel.
// The neh* / nbh* fields describe a separate tensor (presumably the one built
// by the map passes - confirm in kernel).
typedef struct {
    int32_t  ne00, ne02;
    uint64_t nb01, nb02, nb03;
    int32_t  neh12;
    uint64_t nbh10, nbh11, nbh12, nbh13;
    int32_t  neh0, neh1;
    int16_t  r2, r3;
} ggml_metal_kargs_mul_mm_id;
|
|
|
|
// Arguments for the indirect (expert-indexed) matrix-vector kernel.
// nei* / nbi1 describe the index tensor; 0x / 1x / unsuffixed presumably src0 / src1 / dst.
typedef struct {
    int32_t  nei0, nei1;
    uint64_t nbi1;
    int32_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02;
    int32_t  ne10, ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12;
    int32_t  ne0, ne1;
    uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;
|
|
|
|
// Arguments for the norm kernel.
// ne00_4 is presumably ne00 / 4 for float4-vectorized access (confirm on host side).
typedef struct {
    int32_t  ne00, ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_norm;
|
|
|
|
// Arguments for the RMS-norm kernel.
// The three-element nef* / nbf* arrays hold per-dimension counts / strides for
// three additional tensors (presumably fused operands - confirm in kernel).
typedef struct {
    int32_t  ne00, ne00_4;
    uint64_t nb1, nb2, nb3;
    float    eps;
    int32_t  nef1[3], nef2[3], nef3[3];
    uint64_t nbf1[3], nbf2[3], nbf3[3];
} ggml_metal_kargs_rms_norm;
|
|
|
|
// Arguments for the L2-norm kernel (same layout as the norm arguments).
typedef struct {
    int32_t  ne00, ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_l2_norm;
|
|
|
|
// Arguments for the group-norm kernel (64-bit element counts).
typedef struct {
    int64_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02;
    int32_t  n_groups;
    float    eps;
} ggml_metal_kargs_group_norm;
|
|
|
|
// Arguments for the 1-D transposed-convolution kernel.
// IC / IL / K / s0 presumably mean input channels, input length, kernel size, stride.
typedef struct {
    int32_t  IC, IL, K, s0;
    uint64_t nb0, nb1;
} ggml_metal_kargs_conv_transpose_1d;
|
|
|
|
// Arguments for the im2col kernel.
// s* / p* / d* are presumably stride / padding / dilation per spatial axis.
typedef struct {
    uint64_t ofs0, ofs1;
    int32_t  IW, IH, CHW;
    int32_t  s0, s1;
    int32_t  p0, p1;
    int32_t  d0, d1;
    int32_t  N, KH, KW;
    int32_t  KHW;  // KH * KW, pre-computed on CPU to save GPU resources
} ggml_metal_kargs_im2col;
|
|
|
|
// Arguments for the gated-linear-unit kernels.
// Each (count, stride) pair describes one tensor; i00 / i10 are offsets whose
// meaning is defined by the kernel, and alpha / limit are activation parameters.
typedef struct {
    int32_t  ne00;  uint64_t nb01;
    int32_t  ne10;  uint64_t nb11;
    int32_t  ne0;   uint64_t nb1;
    int32_t  i00, i10;
    float    alpha, limit;
} ggml_metal_kargs_glu;
|
|
|
|
// Arguments for the sum_rows kernel (64-bit element counts).
// The 0x / 1x / unsuffixed groups presumably describe src0 / src1 / dst.
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne10, ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12, nb13;
    int64_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
} ggml_metal_kargs_sum_rows;
|
|
|
|
// Arguments for the softmax kernels.
// The 0x / 1x / unsuffixed groups presumably describe input / mask / output;
// max_bias, m0, m1 and n_head_log2 are presumably ALiBi slope parameters (confirm in kernel).
typedef struct {
    int32_t  ne00, ne01, ne02;
    uint64_t nb01, nb02, nb03;
    int32_t  ne11, ne12, ne13;
    uint64_t nb11, nb12, nb13;
    uint64_t nb1,  nb2,  nb3;
    float    scale, max_bias, m0, m1;
    int32_t  n_head_log2;
} ggml_metal_kargs_soft_max;
|
|
|
|
// Arguments for the diag_mask_inf kernel.
// n_past uses int32_t instead of plain int, so its width is fixed by the type.
// This matches the other kargs structs. It is layout-identical wherever int is
// 32 bits.
typedef struct {
    int64_t ne00;
    int64_t ne01;
    int32_t n_past;
} ggml_metal_kargs_diag_mask_inf;
|
|
|
|
// Arguments for the SSM convolution kernel (64-bit element counts).
// The 0x / 1x / unsuffixed groups presumably describe src0 / src1 / dst.
typedef struct {
    int64_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02;
    int64_t  ne10, ne11;
    uint64_t nb10, nb11;
    int64_t  ne0, ne1, ne2;
    uint64_t nb0, nb1, nb2;
} ggml_metal_kargs_ssm_conv;
|
|
|
|
// Arguments for the SSM scan kernel: model dimensions followed by strides for
// the kernel's input tensors (nbXY = stride Y of input X, by naming).
typedef struct {
    int64_t  d_state, d_inner, n_head, n_group;
    int64_t  n_seq_tokens, n_seqs;
    int64_t  s_off;
    uint64_t nb01, nb02, nb03;
    uint64_t nb11, nb12, nb13;
    uint64_t nb21, nb22;
    uint64_t nb31;
    uint64_t nb41, nb42, nb43;
    uint64_t nb51, nb52, nb53;
} ggml_metal_kargs_ssm_scan;
|
|
|
|
// Arguments for the get_rows kernels.
// The 0x / 1x / unsuffixed groups presumably describe source / row indices / dst.
typedef struct {
    int64_t  ne00;
    uint64_t nb01, nb02;
    int64_t  ne10;
    uint64_t nb10, nb11;
    uint64_t nb1,  nb2;
} ggml_metal_kargs_get_rows;
|
|
|
|
// Arguments for the set_rows kernels.
// nk0 is a count whose exact unit is defined by the kernel/host (not shown here).
typedef struct {
    int32_t  nk0, ne01;
    uint64_t nb01, nb02, nb03;
    int32_t  ne11, ne12;
    uint64_t nb10, nb11, nb12;
    uint64_t nb1,  nb2,  nb3;
} ggml_metal_kargs_set_rows;
|
|
|
|
// Arguments for the upscale kernel.
// sf0..sf3 are presumably per-dimension scale factors (confirm in kernel).
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    float    sf0,  sf1,  sf2,  sf3;
} ggml_metal_kargs_upscale;
|
|
|
|
// Arguments for the pad kernel (64-bit element counts).
// The 0x / unsuffixed groups presumably describe src0 / dst.
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
} ggml_metal_kargs_pad;
|
|
|
|
// Arguments for the 1-D reflect-padding kernel.
// p0 / p1 are presumably the leading / trailing pad amounts (confirm in kernel).
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0,  ne1,  ne2,  ne3;
    uint64_t nb0,  nb1,  nb2,  nb3;
    int32_t  p0,   p1;
} ggml_metal_kargs_pad_reflect_1d;
|
|
|
|
// Arguments for the timestep-embedding kernel.
// dim and max_period use int32_t instead of plain int, so their widths are
// fixed by the type. This matches the other kargs structs. It is
// layout-identical wherever int is 32 bits.
typedef struct {
    uint64_t nb1;
    int32_t  dim;
    int32_t  max_period;
} ggml_metal_kargs_timestep_embedding;
|
|
|
|
// Arguments for the leaky-ReLU kernel: a single float slope parameter.
typedef struct { float slope; } ggml_metal_kargs_leaky_relu;
|
|
|
|
// Arguments for the argsort kernel.
// ncols_pad is presumably ncols rounded up for the sort network (confirm on host side).
typedef struct {
    int64_t ncols, ncols_pad;
} ggml_metal_kargs_argsort;
|
|
|
|
// Arguments for the arange kernel: output length plus start value and step.
typedef struct {
    int64_t ne0;
    float   start, step;
} ggml_metal_kargs_arange;
|
|
|
|
// Arguments for the 2-D pooling kernels.
// k* / s* / p* are presumably kernel size / stride / padding per axis;
// IH, IW, OH, OW are input and output spatial sizes.
typedef struct {
    int32_t k0, k1;
    int32_t s0, s1;
    int32_t p0, p1;
    int64_t IH, IW;
    int64_t OH, OW;
    int64_t parallel_elements;
} ggml_metal_kargs_pool_2d;
|
|
|
|
#endif // GGML_METAL_IMPL
|