llama : add gpt-oss (#15091)
* oai moe * compat with new checkpoint * add attn sink impl * add rope scaling yarn * logits match with latest transformers code * wip chat template * rm trailing space * use ggml_scale_bias * rm redundant is_swa_all * convert interleaved gate_up * graph : fix activation function to match reference (#7) * vocab : handle o200k_harmony special tokens * ggml : add attention sinks support (#1) * llama : add attn sinks * ggml : add attn sinks * cuda : add attn sinks * vulkan : add support for sinks in softmax remove unnecessary return * ggml : add fused swiglu_oai op (#11) * ggml : add fused swiglu_oai op * Update ggml/src/ggml-cpu/ops.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * update CUDA impl * cont : metal impl * add vulkan impl * test-backend-ops : more test cases, clean up * llama : remove unfused impl * remove extra lines --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: slaren <slarengh@gmail.com> * repack mxfp4 upon conversion * clean up a bit * enable thinking * add quick hack to render only some special tokens * fix bf16 conversion * remove vocab hack * webui ok * support chat parsing for gpt-oss * fix webui * direct mapping mxfp4, FINALLY * force using mxfp4 * properly use lazy tensor * ggml : add mxfp4 ggml : use e8m0 conversion instead of powf Co-authored-by: Diego Devesa <slarengh@gmail.com> change kvalues_mxfp4 table to match e2m1 (#6) metal : remove quantization for now (not used) cuda : fix disabled CUDA graphs due to ffn moe bias vulkan : add support for mxfp4 cont : add cm2 dequant * ggml : add ggml_add_id (#13) * ggml : add ggml_add_id * add cuda impl * llama : add weight support check for add_id * perf opt * add vulkan impl * rename cuda files * add metal impl * allow in-place ggml_add_id * llama : keep biases on CPU with --cpu-moe * llama : fix compile error ggml-ci * cuda : add fallback for __nv_cvt_e8m0_to_bf16raw ggml-ci * cleanup ggml-ci * sycl : fix supports_op for MXFP4 ggml-ci * fix 
Unknown reasoning format * ggml-cpu : fix AVX build ggml-ci * fix hip build ggml-ci * cuda : add mxfp4 dequantization support for cuBLAS ggml-ci * ggml-cpu : fix mxfp4 fallback definitions for some architectures ggml-ci * cuda : fix version required for __nv_cvt_e8m0_to_bf16raw --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co> Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
58
ggml/src/ggml-cuda/add-id.cu
Normal file
58
ggml/src/ggml-cuda/add-id.cu
Normal file
@@ -0,0 +1,58 @@
|
||||
#include "add-id.cuh"
|
||||
|
||||
// Adds to every row of src0 the row of src1 selected by the I32 index tensor src2:
//   dst[i2][i1][:] = src0[i2][i1][:] + src1[src2[i2][i1]][:]
//
// Launch layout used by the body: blockIdx.x -> i1, blockIdx.y -> i2; the threads of a block
// stride over the ne0 floats of one row.
// nb01/nb02 are the byte strides of src0, nb11 is the row stride of src1 and nb21 the row stride
// of src2 (all in bytes). dst is addressed as densely packed floats: its strides are rebuilt
// from ne0 and ne1.
// The pointers carry no __restrict__ qualifier. NOTE(review): the commit log mentions in-place
// add_id, so dst presumably may alias src0 - confirm before adding one.
static __global__ void add_id_kernel(
        const float * src0, const float * src1, const int32_t * src2, float * dst,
        int64_t ne0, int64_t ne1,
        size_t nb01, size_t nb02,
        size_t nb11,
        size_t nb21
    ) {
    const int64_t row   = blockIdx.x;
    const int64_t plane = blockIdx.y;

    // id of the src1 row that belongs to this (row, plane) pair
    const char * id_ptr = (const char *) src2 + row*sizeof(int32_t) + plane*nb21;
    const int    id     = *(const int32_t *) id_ptr;

    // byte strides of the packed dst
    const size_t dst_row_bytes   = ne0 * sizeof(float);
    const size_t dst_plane_bytes = ne1 * dst_row_bytes;

    float       * out = (float *)       ((char *)       dst  + row*dst_row_bytes + plane*dst_plane_bytes);
    const float * a   = (const float *) ((const char *) src0 + row*nb01 + plane*nb02);
    const float * b   = (const float *) ((const char *) src1 + id*nb11);

    for (int64_t col = threadIdx.x; col < ne0; col += blockDim.x) {
        out[col] = a[col] + b[col];
    }
}
|
||||
|
||||
// Host-side launcher for GGML_OP_ADD_ID.
//
// Reads the three operands from dst->src[0..2] (F32, F32, I32), checks that their innermost
// dimension is tightly packed and launches add_id_kernel on ctx.stream() with one block per
// (ne01, ne02) pair and at most 768 threads per block.
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];

    GGML_TENSOR_TERNARY_OP_LOCALS

    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(src2->type == GGML_TYPE_I32);

    GGML_ASSERT(nb00 == sizeof(float));
    GGML_ASSERT(nb10 == sizeof(float));
    GGML_ASSERT(nb20 == sizeof(int32_t));

    // The kernel does not receive the dst strides: it rebuilds them from ne0/ne1, so dst has to be
    // densely packed in the dimensions that are actually iterated.
    GGML_ASSERT(dst->nb[0] == sizeof(float));
    GGML_ASSERT(ne1 == 1        || dst->nb[1] == ne0*sizeof(float));
    GGML_ASSERT(dst->ne[2] == 1 || dst->nb[2] == ne0*ne1*sizeof(float));

    // Nothing to compute; a launch with a zero-sized grid or block is an invalid configuration.
    if (ne00 <= 0 || ne01 <= 0 || ne02 <= 0) {
        return;
    }

    const float   * src0_d = (const float   *) src0->data;
    const float   * src1_d = (const float   *) src1->data;
    const int32_t * src2_d = (const int32_t *) src2->data;
    float         * dst_d  = (float         *) dst->data;

    // Clamp in 64 bit first and narrow afterwards, so a very large ne00 cannot be truncated
    // into a bogus (possibly negative) thread count.
    const int  threads = (int) std::min<int64_t>(ne00, 768); // cols
    const dim3 blocks(ne01, ne02); // n_experts_used, n_tokens

    add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
        src0_d, src1_d, src2_d, dst_d,
        ne0, ne1,
        nb01, nb02,
        nb11,
        nb21
    );
}
|
||||
3
ggml/src/ggml-cuda/add-id.cuh
Normal file
3
ggml/src/ggml-cuda/add-id.cuh
Normal file
@@ -0,0 +1,3 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-cuda.h"
|
||||
|
||||
#include <cstdint>
|
||||
@@ -549,6 +550,24 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
|
||||
#endif // defined(GGML_USE_HIP)
|
||||
}
|
||||
|
||||
// Converts an 8-bit E8M0 scale (a bare biased exponent) to fp32.
//
// With CUDA runtime 12.8 or newer the toolkit conversion __nv_cvt_e8m0_to_bf16raw is used and the
// bf16 result is widened to float. Older toolkits take the bit-level fallback below.
static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#if CUDART_VERSION >= 12080
    const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
    return (float) e;
#else
    uint32_t bits;
    if (x == 0) {
        // An all-zero exponent field cannot hold a normal number: 0x00400000 is the
        // subnormal 2^-127.
        bits = 0x00400000;
    } else {
        // Place x directly into the fp32 exponent field (sign and mantissa are zero).
        bits = (uint32_t) x << 23;
    }

    // memcpy is the well-defined way to reinterpret the bit pattern as a float
    float result;
    memcpy(&result, &bits, sizeof(float));
    return result;
#endif // CUDART_VERSION >= 12080
}
|
||||
|
||||
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
|
||||
|
||||
static __device__ __forceinline__ float get_alibi_slope(
|
||||
@@ -607,6 +626,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
|
||||
static constexpr int qi = QI8_0;
|
||||
};
|
||||
|
||||
// ggml_cuda_type_traits specialization for GGML_TYPE_MXFP4: exposes the QK_MXFP4 / QR_MXFP4 /
// QI_MXFP4 constants under the generic qk / qr / qi member names.
template<>
struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
    static constexpr int qk = QK_MXFP4;
    static constexpr int qr = QR_MXFP4;
    static constexpr int qi = QI_MXFP4;
};
|
||||
|
||||
template<>
|
||||
struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
|
||||
static constexpr int qk = QK_K;
|
||||
|
||||
@@ -465,6 +465,24 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
|
||||
}
|
||||
}
|
||||
|
||||
// Dequantizes MXFP4 data into dst_t values; each thread block produces one span of QK_K outputs.
//
// The index math is written for 32 threads per block (tid/8 in 0...3, tid%8 in 0...7): every
// thread expands 4 packed bytes of one block_mxfp4 into 8 outputs. The low nibble of byte j goes
// to position j and the high nibble to position j+16 of that thread's output window; both are
// scaled by the block's E8M0 scale and by 0.5f.
template<typename dst_t>
static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int64_t iblock = blockIdx.x;
    const int64_t tid    = threadIdx.x;

    const int64_t part = tid/8; // 0...3
    const int64_t sub  = tid%8; // 0...7

    // block_mxfp4 handled by this thread: QK_K/QK_MXFP4 of them per thread block, `sub` picks one
    const block_mxfp4 * blk = (const block_mxfp4 *) vx + iblock*(QK_K/QK_MXFP4) + sub;

    const uint8_t * packed = blk->qs + 4*part;
    const float     scale  = ggml_cuda_e8m0_to_fp32(blk->e);

    dst_t * out = yy + iblock*QK_K + 32*sub + 4*part;

    for (int j = 0; j < 4; ++j) {
        const uint8_t q = packed[j];
        out[j+ 0] = scale * kvalues_mxfp4[q & 0xf]*0.5f;
        out[j+16] = scale * kvalues_mxfp4[q >>  4]*0.5f;
    }
}
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static void dequantize_block_cuda(const void * vx, dst_t * y,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
@@ -588,6 +606,12 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
|
||||
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
// Enqueues dequantize_block_mxfp4 on `stream` for k values: ceil(k / QK_K) blocks of 32 threads.
// NOTE(review): the block count is rounded up while every block of the kernel writes a full
// QK_K span, so k is presumably expected to be a multiple of QK_K - confirm against the callers.
template<typename dst_t>
static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int num_blocks = (k + QK_K - 1) / QK_K;
    dequantize_block_mxfp4<<<num_blocks, 32, 0, stream>>>(vx, y);
}
|
||||
|
||||
template <typename src_t, typename dst_t>
|
||||
static __global__ void convert_unary(
|
||||
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||||
@@ -677,6 +701,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
||||
return dequantize_row_iq4_xs_cuda;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
return dequantize_row_iq3_s_cuda;
|
||||
case GGML_TYPE_MXFP4:
|
||||
return dequantize_row_mxfp4_cuda;
|
||||
case GGML_TYPE_F32:
|
||||
return convert_unary_cont_cuda<float>;
|
||||
case GGML_TYPE_BF16:
|
||||
@@ -726,6 +752,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||||
return dequantize_row_iq4_xs_cuda;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
return dequantize_row_iq3_s_cuda;
|
||||
case GGML_TYPE_MXFP4:
|
||||
return dequantize_row_mxfp4_cuda;
|
||||
case GGML_TYPE_F16:
|
||||
return convert_unary_cont_cuda<half>;
|
||||
case GGML_TYPE_BF16:
|
||||
|
||||
@@ -15,6 +15,7 @@ typedef void (* fattn_kernel_t)(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int * __restrict__ KV_max,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
@@ -736,7 +737,8 @@ void launch_fattn(
|
||||
|
||||
GGML_ASSERT(V || is_mla);
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
|
||||
ggml_tensor * KQV = dst;
|
||||
|
||||
@@ -940,6 +942,7 @@ void launch_fattn(
|
||||
K_data,
|
||||
V_data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
sinks ? ((const char *) sinks->data) : nullptr,
|
||||
KV_max.ptr,
|
||||
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
|
||||
@@ -1206,6 +1206,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int * __restrict__ KV_max,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
@@ -1267,6 +1268,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
// kb0 == k start index when in the output tile.
|
||||
int kb0_start = kbc % iter_k;
|
||||
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
|
||||
|
||||
while (kbc < kbc_stop && kb0_stop == iter_k) {
|
||||
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
||||
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
||||
@@ -1340,7 +1342,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
|
||||
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
|
||||
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int * __restrict__ KV_max,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
@@ -272,7 +273,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
|
||||
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int * __restrict__ KV_max,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
@@ -37,7 +38,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
return;
|
||||
#endif // FP16_MMA_AVAILABLE
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
|
||||
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
|
||||
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int * __restrict__ KV_max,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
@@ -61,7 +62,8 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
K += nb13*sequence + nb12*(head / gqa_ratio);
|
||||
V += nb23*sequence + nb22*(head / gqa_ratio);
|
||||
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
@@ -75,11 +77,12 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
half2 * KQ2 = (half2 *) KQ;
|
||||
|
||||
half kqmax[ncols];
|
||||
half kqsum[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
kqmax[j] = -HALF_MAX_HALF;
|
||||
kqsum[j] = 0.0f;
|
||||
}
|
||||
half kqsum[ncols] = {0.0f};
|
||||
|
||||
__shared__ half kqmax_shared[ncols][WARP_SIZE];
|
||||
__shared__ half kqsum_shared[ncols][WARP_SIZE];
|
||||
@@ -283,6 +286,39 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (sinksf && blockIdx.y == 0) {
|
||||
const half sink = __float2half(sinksf[head]);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (threadIdx.x == 0) {
|
||||
kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
half kqmax_new_j = kqmax_shared[j][threadIdx.x];
|
||||
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
||||
|
||||
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
|
||||
kqmax[j] = kqmax_new_j;
|
||||
|
||||
const half val = hexp(sink - kqmax[j]);
|
||||
kqsum[j] = kqsum[j]*KQ_max_scale;
|
||||
|
||||
if (tid == 0) {
|
||||
kqsum[j] += val;
|
||||
}
|
||||
|
||||
VKQ[j] *= __half2half2(KQ_max_scale);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
kqsum[j] = warp_reduce_sum((float)kqsum[j]);
|
||||
@@ -313,7 +349,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
|
||||
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
|
||||
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int * __restrict__ KV_max,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
@@ -72,7 +73,8 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
K += nb13*sequence + nb12*(head / gqa_ratio);
|
||||
V += nb23*sequence + nb22*(head / gqa_ratio);
|
||||
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
@@ -88,11 +90,12 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
}
|
||||
|
||||
float kqmax[ncols];
|
||||
float kqsum[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
kqmax[j] = -FLT_MAX/2.0f;
|
||||
kqsum[j] = 0.0f;
|
||||
}
|
||||
float kqsum[ncols] = {0.0f};
|
||||
|
||||
__shared__ float kqmax_shared[ncols][WARP_SIZE];
|
||||
__shared__ float kqsum_shared[ncols][WARP_SIZE];
|
||||
@@ -279,6 +282,39 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (sinksf && blockIdx.y == 0) {
|
||||
const float sink = sinksf[head];
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (threadIdx.x == 0) {
|
||||
kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
float kqmax_new_j = kqmax_shared[j][threadIdx.x];
|
||||
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
||||
|
||||
const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
|
||||
kqmax[j] = kqmax_new_j;
|
||||
|
||||
const float val = expf(sink - kqmax[j]);
|
||||
kqsum[j] = kqsum[j]*KQ_max_scale;
|
||||
|
||||
if (tid == 0) {
|
||||
kqsum[j] += val;
|
||||
}
|
||||
|
||||
VKQ[j] *= KQ_max_scale;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
kqsum[j] = warp_reduce_sum(kqsum[j]);
|
||||
|
||||
@@ -29,6 +29,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int * __restrict__ KV_max,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
@@ -423,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
dst_meta[j_dst_unrolled] = dst_meta_val;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
|
||||
@@ -269,17 +269,28 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||
|
||||
// TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
|
||||
if (sinks) {
|
||||
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
#if defined(GGML_HIP_ROCWMMA_FATTN)
|
||||
if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "ggml-cuda/common.cuh"
|
||||
#include "ggml-cuda/acc.cuh"
|
||||
#include "ggml-cuda/add-id.cuh"
|
||||
#include "ggml-cuda/arange.cuh"
|
||||
#include "ggml-cuda/argmax.cuh"
|
||||
#include "ggml-cuda/argsort.cuh"
|
||||
@@ -2259,6 +2260,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_ADD1: // TODO: more efficient implementation
|
||||
ggml_cuda_op_add(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ADD_ID:
|
||||
ggml_cuda_op_add_id(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SUB:
|
||||
ggml_cuda_op_sub(ctx, dst);
|
||||
break;
|
||||
@@ -2333,6 +2337,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
ggml_cuda_op_swiglu(ctx, dst);
|
||||
break;
|
||||
case GGML_GLU_OP_SWIGLU_OAI:
|
||||
ggml_cuda_op_swiglu_oai(ctx, dst);
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
ggml_cuda_op_geglu_erf(ctx, dst);
|
||||
break;
|
||||
@@ -2607,6 +2614,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
||||
|
||||
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
|
||||
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
|
||||
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
|
||||
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
|
||||
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
@@ -2629,7 +2639,13 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
||||
#endif
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true)) {
|
||||
if (node->op == GGML_OP_ADD &&
|
||||
node->src[1] && node->src[1]->ne[1] > 1 &&
|
||||
(node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
|
||||
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
|
||||
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
|
||||
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
|
||||
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
|
||||
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
|
||||
// by means of matching node names. See
|
||||
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
|
||||
@@ -3227,6 +3243,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_GLU_OP_REGLU:
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_SWIGLU_OAI:
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
return ggml_is_contiguous_1(op->src[0]);
|
||||
@@ -3277,6 +3294,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
@@ -3423,6 +3441,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ADD_ID:
|
||||
case GGML_OP_ADD1:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
@@ -3503,6 +3522,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
|
||||
return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
|
||||
}
|
||||
// TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
|
||||
if (op->src[4] && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { // currently only sinks for head_size 64 and 128 are supported
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 192) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
#include "im2col.cuh"
|
||||
|
||||
#define MIN(a, b) (a) < (b) ? (a) : (b)
|
||||
|
||||
#define MAX_GRIDDIM_Z 65535
|
||||
|
||||
template <typename T>
|
||||
@@ -38,6 +36,9 @@ static __global__ void im2col_kernel(
|
||||
dst[offset_dst] = x[offset_src + iih * IW + iiw];
|
||||
}
|
||||
}
|
||||
|
||||
GGML_UNUSED(IC);
|
||||
GGML_UNUSED(KH);
|
||||
}
|
||||
|
||||
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
|
||||
|
||||
@@ -20,6 +20,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
|
||||
case GGML_TYPE_Q8_0:
|
||||
mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
|
||||
break;
|
||||
case GGML_TYPE_MXFP4:
|
||||
mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
|
||||
break;
|
||||
@@ -282,6 +285,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
|
||||
@@ -58,6 +58,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
|
||||
return MMQ_Q8_1_DS_LAYOUT_DS4;
|
||||
case GGML_TYPE_Q8_0:
|
||||
return MMQ_Q8_1_DS_LAYOUT_D4;
|
||||
case GGML_TYPE_MXFP4:
|
||||
return MMQ_Q8_1_DS_LAYOUT_D4;
|
||||
case GGML_TYPE_Q2_K:
|
||||
return MMQ_Q8_1_DS_LAYOUT_D2S6;
|
||||
case GGML_TYPE_Q3_K:
|
||||
@@ -170,6 +172,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
||||
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
|
||||
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
|
||||
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
|
||||
case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
|
||||
case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
|
||||
@@ -206,6 +209,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
||||
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
|
||||
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
|
||||
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
|
||||
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
|
||||
@@ -692,6 +696,71 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||
}
|
||||
}
|
||||
|
||||
// Fills the MMQ x tile from MXFP4 data.
//
// x_tile is used as an int region (x_qs) followed by a float region (x_df). Where x_df starts and
// which row stride is used depends on whether an MMA path (AMD_MFMA_AVAILABLE or
// NEW_MMA_AVAILABLE) or the dp4a path is compiled.
//   pass 1: per tile row, fetch packed 4-bit codes with get_int_b1, map them through
//           kvalues_mxfp4 with get_int_from_table_16 and store the two resulting ints
//           QI_MXFP4 entries apart in x_qs
//   pass 2: store the per-block scale ggml_cuda_e8m0_to_fp32(e)*0.5f in x_df
// kbx0 (start offset) and stride (distance between rows) are both counted in block_mxfp4 units.
// When need_check is set, row indices are clamped to i_max.
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    int * x_qs = (int *) x_tile;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
    constexpr int qs_row_stride = MMQ_MMA_TILE_X_K_Q8_1;
    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
    constexpr tile_x_sizes txs  = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
    constexpr int qs_row_stride = 2*MMQ_TILE_NE_K + 1;
    float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)

    // ---- pass 1: quantized values ----
    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
    constexpr int nrows           = warp_size / threads_per_row;

    // position of this thread within the group of threads that share one tile row
    const int lane = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
    const int kbx  = lane / QI_MXFP4; // block within the row
    const int kqsx = lane % QI_MXFP4; // int within the block

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

        if (need_check) {
            i = min(i, i_max);
        }

        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;

        const int  aux_q4 = get_int_b1(bxi->qs, kqsx);
        const int2 v      = get_int_from_table_16(aux_q4, kvalues_mxfp4);
        const int  k0     = kbx * (2 * QI_MXFP4) + kqsx;

        int * qs_out = x_qs + i*qs_row_stride + k0;
        qs_out[0]        = v.x;
        qs_out[QI_MXFP4] = v.y;
    }

    // ---- pass 2: per-block scales ----
    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
    constexpr int rows_per_warp         = warp_size / blocks_per_tile_x_row;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;

        const float scale = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;

#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
        x_df[i*MMQ_MMA_TILE_X_K_Q8_1                 + kbxd] = scale;
#else
        x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = scale;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
    }
}
|
||||
|
||||
template <int mmq_x, int mmq_y>
|
||||
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||
@@ -2268,7 +2337,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||
const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
|
||||
|
||||
const int aux_q4 = get_int_b2(bxi->qs, kqsx);
|
||||
const int2 v = get_int_from_table_16(aux_q4);
|
||||
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
|
||||
const int k0 = kbx * (2 * QI4_NL) + kqsx;
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
||||
@@ -2707,7 +2776,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||
const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
|
||||
|
||||
const int aux_q4 = get_int_b4(bxi->qs, kqsx);
|
||||
const int2 v = get_int_from_table_16(aux_q4);
|
||||
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
|
||||
const int k0 = 8 * (kqsx / 4) + kqsx % 4;
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
||||
@@ -2863,6 +2932,14 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
|
||||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
||||
};
|
||||
|
||||
// MMQ type traits for MXFP4 weights.
// Tiles are loaded by load_tiles_mxfp4; the dot products reuse the q8_0 x q8_1
// implementations (MMA variant instantiated with the D4 q8_1 scale layout).
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
    // vec dot ratio used by the MMQ kernels for this type
    static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;

    // loader that fills the x tile for this type
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;

    // tile dot product, tensor core (MMA) path
    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;

    // tile dot product, dp4a path
    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
|
||||
|
||||
template <int mmq_x, int mmq_y, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
|
||||
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
|
||||
@@ -3642,6 +3719,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
|
||||
|
||||
@@ -13,6 +13,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
|
||||
case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
|
||||
case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
|
||||
case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
|
||||
case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
|
||||
case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
|
||||
case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
|
||||
case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
|
||||
@@ -38,6 +39,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
|
||||
case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
|
||||
case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
|
||||
case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
|
||||
case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
|
||||
case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
|
||||
case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
|
||||
case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
|
||||
@@ -384,6 +386,13 @@ static void mul_mat_vec_q_switch_type(
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_MXFP4:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
|
||||
@@ -45,7 +45,7 @@ struct soft_max_params {
|
||||
#endif // __clang__
|
||||
template <bool use_shared, int ncols_template, int block_size_template, typename T>
|
||||
static __global__ void soft_max_f32(
|
||||
const float * x, const T * mask, float * dst, const soft_max_params p) {
|
||||
const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
|
||||
const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
@@ -77,7 +77,7 @@ static __global__ void soft_max_f32(
|
||||
// shared memory buffer to cache values between iterations:
|
||||
float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
|
||||
|
||||
float max_val = -INFINITY;
|
||||
float max_val = sinks ? sinks[i02] : -INFINITY;
|
||||
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
@@ -143,6 +143,10 @@ static __global__ void soft_max_f32(
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
}
|
||||
|
||||
if (sinks) {
|
||||
tmp += expf(sinks[i02] - max_val);
|
||||
}
|
||||
|
||||
const float inv_sum = 1.0f / tmp;
|
||||
|
||||
#pragma unroll
|
||||
@@ -183,7 +187,7 @@ static __global__ void soft_max_back_f32(
|
||||
}
|
||||
|
||||
template<int... Ns, typename T>
|
||||
static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
|
||||
static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
|
||||
const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
|
||||
{
|
||||
const int id = ggml_cuda_get_device();
|
||||
@@ -196,7 +200,7 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst
|
||||
if (p.ncols == ncols) {
|
||||
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
|
||||
soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, mask, dst, p);
|
||||
(x, mask, sinks, dst, p);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
@@ -209,12 +213,12 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst
|
||||
|
||||
//default case
|
||||
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
|
||||
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
|
||||
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
|
||||
static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
|
||||
int nth = WARP_SIZE;
|
||||
const int64_t ncols_x = params.ncols;
|
||||
|
||||
@@ -230,10 +234,10 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
|
||||
|
||||
|
||||
if (nbytes_shared <= smpbo) {
|
||||
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
|
||||
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
|
||||
} else {
|
||||
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
|
||||
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
|
||||
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -249,9 +253,11 @@ static void soft_max_back_f32_cuda(
|
||||
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const void * src1_d = src1 ? (const void *) src1->data : nullptr;
|
||||
const void * src2_d = src2 ? (const void *) src2->data : nullptr;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
@@ -309,9 +315,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
params.m1 = m1;
|
||||
|
||||
if (use_f16) {
|
||||
soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream);
|
||||
soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
|
||||
} else {
|
||||
soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
|
||||
soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmq.cuh"
|
||||
|
||||
DECL_MMQ_CASE(GGML_TYPE_MXFP4);
|
||||
@@ -300,6 +300,81 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
|
||||
}
|
||||
|
||||
// swiglu_oai
|
||||
|
||||
// Element-wise kernel for the fused swiglu_oai op.
//
// Launch layout: 1D grid of 1D blocks, one thread per output element; threads whose
// flat index is >= k do nothing. No shared memory is used.
//
//   x, g   : input and gate values; output element (row, col) reads x[row*o0 + col] and g[row*o1 + col]
//   dst    : output, written at the flat element index
//   k      : total number of output elements
//   n      : number of output elements per row
//   o0, o1 : row strides of x and g, in elements
//   alpha  : factor applied to x inside the exponential
//   limit  : x is clamped from above to limit, g is clamped to [-limit, limit]
template <typename T>
static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
    const int64_t idx = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;

    if (idx < k) {
        // split the flat output index into (row, col) and apply the per-input row strides
        const int64_t row = idx / n;
        const int64_t col = idx % n;

        const int64_t ix = row * o0 + col;
        // when both strides match, the gate index is the same as the input index
        const int64_t ig = o0 == o1 ? ix : row * o1 + col;

        float x_val = x[ix];
        float g_val = g[ig];

        x_val = fminf(x_val, limit);                // upper clamp only
        g_val = fmaxf(fminf(g_val, limit), -limit); // symmetric clamp

        // x * sigmoid(alpha * x), then scaled by (1 + gate)
        float result = x_val / (1.0f + expf(-x_val * alpha));
        result = result * (1.0f + g_val);

        dst[idx] = result;
    }
}
|
||||
|
||||
// Host-side launcher for swiglu_oai_kernel.
// Uses blocks of CUDA_GLU_BLOCK_SIZE threads and enough blocks to cover all k output
// elements (ceiling division); the kernel is enqueued on `stream` without dynamic shared memory.
template <typename T>
static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
    const int64_t n_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;

    const dim3 grid_dims((unsigned int) n_blocks, 1, 1);
    const dim3 block_dims(CUDA_GLU_BLOCK_SIZE, 1, 1);

    swiglu_oai_kernel<<<grid_dims, block_dims, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
}
|
||||
|
||||
// CUDA entry point for the fused swiglu_oai op (F32 only).
//
// Two input modes:
//   - src1 present: src0 holds the input values and src1 the gate values, both with nc columns.
//   - src1 absent : src0 holds both halves side by side (ne[0] == 2*nc); op param 1 ("swapped")
//                   selects which half is the input and which is the gate.
// Op params 2 and 3 carry alpha and limit, which are forwarded to the kernel.
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    // number of output columns per row
    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;

    GGML_ASSERT(ggml_is_contiguous_1(src0));
    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
    GGML_ASSERT(ggml_is_contiguous(dst));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->type == dst->type);
    GGML_ASSERT(dst->ne[0] == nc);
    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));

    if (src1) {
        GGML_ASSERT(ggml_is_contiguous_1(src1));
        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
        GGML_ASSERT(src1->ne[0] == nc);
        GGML_ASSERT(src0->type == src1->type);
    }

    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
    const float   alpha   = ggml_get_op_params_f32(dst, 2);
    const float   limit   = ggml_get_op_params_f32(dst, 3);

    // row strides in bytes; the gate falls back to src0 when there is no separate gate tensor
    const int64_t x_row_bytes = src0->nb[1];
    const int64_t g_row_bytes = src1 ? src1->nb[1] : src0->nb[1];

    float * x_ptr = (float *) src0->data;
    float * g_ptr = (float *) (src1 ? src1->data : src0->data);

    if (!src1) {
        // single-tensor mode: advance one of the two pointers to the second half of each row
        if (swapped) {
            x_ptr += nc;
        } else {
            g_ptr += nc;
        }
    }

    cudaStream_t stream = ctx.stream();

    swiglu_oai_cuda(x_ptr, g_ptr, (float *) dst->data, ggml_nelements(dst), nc,
        x_row_bytes / sizeof(float), g_row_bytes / sizeof(float), alpha, limit, stream);
}
|
||||
|
||||
/* silu_back */
|
||||
|
||||
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
|
||||
|
||||
@@ -67,6 +67,8 @@ void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -1,8 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
// Reads the i32-th 32-bit integer from x using single-byte loads only.
// The four bytes at offsets 4*i32 .. 4*i32 + 3 are combined with byte 0 in the lowest
// 8 bits and byte 3 in the highest 8 bits.
static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
    const uint8_t * bytes = (const uint8_t *) x;

    int packed = 0;
#pragma unroll
    for (int b = 0; b < 4; ++b) {
        packed |= bytes[4*i32 + b] << (8*b);
    }

    return packed;
}
|
||||
|
||||
static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
|
||||
const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
|
||||
|
||||
@@ -16,6 +28,20 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
|
||||
return ((const int *) x)[i32]; // assume at least 4 byte alignment
|
||||
}
|
||||
|
||||
// Translates eight packed 4-bit indices into eight int8 values taken from `table`.
//   q4    : 32-bit word holding two 4-bit indices per byte
//   table : lookup table indexed by the 4-bit values (0..15)
// Returns an int2 where .x packs the table values for the low nibbles of the four bytes
// and .y packs the table values for the high nibbles, each in byte order.
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
    // isolate the low and the high nibble of every byte
    const int lo_nibbles = (q4 >> 0) & 0x0F0F0F0F;
    const int hi_nibbles = (q4 >> 4) & 0x0F0F0F0F;

    const int8_t * lo_idx = (const int8_t *) &lo_nibbles;
    const int8_t * hi_idx = (const int8_t *) &hi_nibbles;

    const char4 lo_vals = make_char4(
        table[lo_idx[0]], table[lo_idx[1]], table[lo_idx[2]], table[lo_idx[3]]);
    const char4 hi_vals = make_char4(
        table[hi_idx[0]], table[hi_idx[1]], table[hi_idx[2]], table[hi_idx[3]]);

    // reinterpret each group of four int8 values as one 32-bit integer
    int2 result;
    result.x = *((const int *) &lo_vals);
    result.y = *((const int *) &hi_vals);
    return result;
}
|
||||
|
||||
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
||||
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
|
||||
|
||||
@@ -211,6 +237,30 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_
|
||||
return d8_1*sumf;
|
||||
}
|
||||
|
||||
#define VDR_MXFP4_Q8_1_MMVQ 2
|
||||
#define VDR_MXFP4_Q8_1_MMQ 4
|
||||
|
||||
// Partial dot product between one MXFP4 block and one q8_1 block.
//   vbq   : base pointer to the block_mxfp4 array
//   bq8_1 : q8_1 block holding the int8 activations and their scale
//   kbx   : index of the MXFP4 block to use
//   iqs   : index of the first 32-bit word of quants handled by this call
// Processes VDR_MXFP4_Q8_1_MMVQ packed words. For each word the low-nibble values are
// multiplied with q8 word l and the high-nibble values with q8 word l + 4.
// Returns the integer sum scaled by the e8m0 block scale, 0.5f and the q8_1 scale.
static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_mxfp4 * blk  = (const block_mxfp4 *) vbq + kbx;
    const int         * y_qs = (const int *) bq8_1->qs + iqs;

    int acc = 0;
#pragma unroll
    for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
        // byte-wise load of the packed nibbles, then table lookup to int8 values
        const int   packed = get_int_b1(blk->qs, iqs + l);
        const int2  vals   = get_int_from_table_16(packed, kvalues_mxfp4);

        acc = ggml_cuda_dp4a(vals.x, y_qs[l + 0], acc);
        acc = ggml_cuda_dp4a(vals.y, y_qs[l + 4], acc);
    }

    // NOTE(review): the 0.5f presumably compensates for kvalues_mxfp4 storing doubled
    // e2m1 values - confirm against the table definition.
    const float scale_x = ggml_cuda_e8m0_to_fp32(blk->e) * 0.5f;
    const float scale   = scale_x * __low2float(bq8_1->ds);

    return scale * acc;
}
|
||||
|
||||
#define VDR_Q2_K_Q8_1_MMVQ 1
|
||||
#define VDR_Q2_K_Q8_1_MMQ 4
|
||||
|
||||
@@ -1068,20 +1118,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
|
||||
return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
|
||||
}
|
||||
|
||||
// Convenience overload: translates eight packed 4-bit indices using the kvalues_iq4nl table.
// Delegates to the table-parameterized overload so the nibble extraction and lookup
// logic exists in one place only; the result is identical to the previous inline version
// (.x = values for the low nibbles, .y = values for the high nibbles).
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
    return get_int_from_table_16(q4, kvalues_iq4nl);
}
|
||||
|
||||
#define VDR_IQ4_NL_Q8_1_MMVQ 2
|
||||
#define VDR_IQ4_NL_Q8_1_MMQ 4
|
||||
|
||||
@@ -1096,7 +1132,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
|
||||
const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
|
||||
const int2 v = get_int_from_table_16(aux_q4);
|
||||
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
|
||||
|
||||
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
|
||||
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
|
||||
@@ -1118,7 +1154,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
|
||||
const int2 v = get_int_from_table_16(aux_q4);
|
||||
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
|
||||
|
||||
const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
|
||||
const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
|
||||
|
||||
4
ggml/src/ggml-cuda/vendors/cuda.h
vendored
4
ggml/src/ggml-cuda/vendors/cuda.h
vendored
@@ -6,6 +6,10 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#if CUDART_VERSION >= 12050
|
||||
#include <cuda_fp8.h>
|
||||
#endif // CUDART_VERSION >= 12050
|
||||
|
||||
#if CUDART_VERSION < 11020
|
||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
||||
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
|
||||
|
||||
Reference in New Issue
Block a user