Sync from upstream llama.cpp repository

This commit is contained in:
2026-01-16 10:43:34 +08:00
parent 3bc369a6f7
commit f4ae4cc7da
2053 changed files with 956010 additions and 1 deletions

View File

@@ -0,0 +1,259 @@
cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
find_package(CUDAToolkit)
if (CUDAToolkit_FOUND)
message(STATUS "CUDA Toolkit found")
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
# native == GPUs available at build time
# 50 == Maxwell, lowest CUDA 12 standard
# 60 == P100, FP16 CUDA intrinsics
# 61 == Pascal, __dp4a instruction (per-byte integer dot product)
# 70 == V100, FP16 tensor cores
# 75 == Turing, int8 tensor cores
# 80 == Ampere, asynchronous data loading, faster tensor core instructions
# 86 == RTX 3000, needs CUDA v11.1
# 89 == RTX 4000, needs CUDA v11.8
# 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores
#
# XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
# XX-real == compile CUDA code as device code for this specific architecture
# no suffix == compile as both PTX and device code
#
# The default behavior for a non-native is to build virtual architectures as needed to cover all features needed
# for best performance and to also build real architectures for the most commonly used GPUs.
if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
set(CMAKE_CUDA_ARCHITECTURES "native")
else()
if (CUDAToolkit_VERSION VERSION_LESS "13")
list(APPEND CMAKE_CUDA_ARCHITECTURES 50-virtual 61-virtual 70-virtual)
endif ()
list(APPEND CMAKE_CUDA_ARCHITECTURES 75-virtual 80-virtual 86-real)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real)
endif()
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
# The CUDA architecture 120f-virtual would in principle work for Blackwell support
# but the newly added "f" suffix conflicted with a preexising regex for validating CUDA architectures in CMake.
# So either a recent CMake version or one with the backported fix is needed.
# The following versions should work:
# - CMake >= v3.31.8 && CMake < v4.0.0
# - CMake >= v4.0.2
# This is NOT documented in the CMake release notes,
# check Modules/Internal/CMakeCUDAArchitecturesValidate.cmake in the CMake git repository instead.
# However, the architectures 120a-real and 121a-real should work with basically any CMake version and
# until the release of e.g. Rubin there is no benefit to shipping virtual architectures for Blackwell.
list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real)
endif()
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.9")
list(APPEND CMAKE_CUDA_ARCHITECTURES 121a-real)
endif()
endif()
endif()
enable_language(CUDA)
# TODO: Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit
if (GGML_CUDA_CUB_3DOT2)
include(FetchContent)
FetchContent_Declare(
CCCL
GIT_REPOSITORY https://github.com/nvidia/cccl.git
GIT_TAG v3.2.0-rc2
GIT_SHALLOW TRUE
)
FetchContent_MakeAvailable(CCCL)
endif()
# Replace any plain 12X CUDA architectures with their "architecture-specific" equivalents 12Xa.
# 12X is forwards-compatible, 12Xa is not.
# Notably the Blackwell FP4 tensor core instructions are not forwards compatible and therefore need 12Xa.
# But while 12X vs. 12Xa can be checked in device code there is (to my knowledge) no easy way to do the same check in host code.
# So for now just replace all instances of 12X with 12Xa, this should be fine until Rubin is released.
foreach(ARCHS IN ITEMS CMAKE_CUDA_ARCHITECTURES CMAKE_CUDA_ARCHITECTURES_NATIVE)
set(FIXED_ARCHS "")
foreach(ARCH IN LISTS ${ARCHS})
if (ARCH MATCHES "^12[0-9](-real|-virtual)?$")
string(REGEX REPLACE "^(12[0-9])((-real|-virtual)?)$" "\\1a\\2" FIXED_ARCH ${ARCH})
message(STATUS "Replacing ${ARCH} in ${ARCHS} with ${FIXED_ARCH}")
list(APPEND FIXED_ARCHS "${FIXED_ARCH}")
else()
list(APPEND FIXED_ARCHS "${ARCH}")
endif()
endforeach()
set(${ARCHS} ${FIXED_ARCHS})
endforeach()
# If we try to compile a "native" build it will use the 12X architectures and fail.
# So we should instead use the native architectures as determined by CMake after replacing 12X with 12Xa.
# But if at the time of the build no GPUs are connected at all CMAKE_CUDA_ARCHITECTURES will contain garbage that we should not use.
if (CMAKE_CUDA_ARCHITECTURES STREQUAL "native" AND CMAKE_CUDA_ARCHITECTURES_NATIVE MATCHES "^[0-9]+(a|f)?(-real|-virtual)?(;[0-9]+(a|f)?(-real|-virtual)?|;)*$")
set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES_NATIVE})
endif()
message(STATUS "Using CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} CMAKE_CUDA_ARCHITECTURES_NATIVE=${CMAKE_CUDA_ARCHITECTURES_NATIVE}")
file(GLOB GGML_HEADERS_CUDA "*.cuh")
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
file(GLOB GGML_SOURCES_CUDA "*.cu")
file(GLOB SRCS "template-instances/fattn-tile*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/mmf*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
if (GGML_CUDA_FA_ALL_QUANTS)
file(GLOB SRCS "template-instances/fattn-vec*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
else()
file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
endif()
ggml_add_backend_library(ggml-cuda
${GGML_HEADERS_CUDA}
${GGML_SOURCES_CUDA}
)
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
if (GGML_CUDA_GRAPHS)
add_compile_definitions(GGML_CUDA_USE_GRAPHS)
endif()
if (GGML_CUDA_FORCE_MMQ)
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
endif()
if (GGML_CUDA_FORCE_CUBLAS)
add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
endif()
if (GGML_CUDA_NO_VMM)
add_compile_definitions(GGML_CUDA_NO_VMM)
endif()
if (NOT GGML_CUDA_FA)
add_compile_definitions(GGML_CUDA_NO_FA)
endif()
if (GGML_CUDA_NO_PEER_COPY)
add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
endif()
if (GGML_STATIC)
if (WIN32)
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
else ()
if (GGML_CUDA_CUB_3DOT2)
target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
endif()
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else()
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static)
endif()
endif()
else()
if (GGML_CUDA_CUB_3DOT2)
target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
endif()
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
endif()
if (GGML_CUDA_NO_VMM)
# No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
else()
target_link_libraries(ggml-cuda PRIVATE CUDA::cuda_driver)
endif()
set(CUDA_CXX_FLAGS "")
set(CUDA_FLAGS -use_fast_math -extended-lambda)
if (GGML_CUDA_DEBUG)
list(APPEND CUDA_FLAGS -lineinfo)
add_compile_definitions(GGML_CUDA_DEBUG)
endif()
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
# Options are:
# - none (not recommended)
# - speed (nvcc's default)
# - balance
# - size
list(APPEND CUDA_FLAGS -compress-mode=${GGML_CUDA_COMPRESSION_MODE})
endif()
if (GGML_FATAL_WARNINGS)
list(APPEND CUDA_FLAGS -Werror all-warnings)
endif()
if (GGML_ALL_WARNINGS AND NOT MSVC)
set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c)
if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL "")
list(APPEND NVCC_CMD -ccbin ${CMAKE_CUDA_HOST_COMPILER})
endif()
execute_process(
COMMAND ${NVCC_CMD} -Xcompiler --version
OUTPUT_VARIABLE CUDA_CCFULLVER
ERROR_QUIET
)
if (NOT CUDA_CCFULLVER MATCHES clang)
set(CUDA_CCID "GNU")
execute_process(
COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion"
OUTPUT_VARIABLE CUDA_CCVER
ERROR_QUIET
OUTPUT_STRIP_TRAILING_WHITESPACE
)
else()
if (CUDA_CCFULLVER MATCHES Apple)
set(CUDA_CCID "AppleClang")
else()
set(CUDA_CCID "Clang")
endif()
string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER})
endif()
message(STATUS "CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}")
ggml_get_flags(${CUDA_CCID} ${CUDA_CCVER})
list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later
endif()
if (NOT MSVC)
list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
else()
# CCCL 3.2 onwards will require a cpp-standard-compliant preprocessor for MSVC
# https://github.com/NVIDIA/cccl/pull/6827
list(APPEND CUDA_CXX_FLAGS /Zc:preprocessor)
endif()
list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument
if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL "")
list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED})
endif()
target_compile_options(ggml-cuda PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>")
else()
message(FATAL_ERROR "CUDA Toolkit not found")
endif()

61
ggml/src/ggml-cuda/acc.cu Normal file
View File

@@ -0,0 +1,61 @@
#include "acc.cuh"
static __global__ void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
}
int64_t src1_idx = i - offset;
int64_t tmp = src1_idx;
const int64_t i13 = tmp / s13;
tmp -= i13 * s13;
const int64_t i12 = tmp / s12;
tmp -= i12 * s12;
const int64_t i11 = tmp / s11;
tmp -= i11 * s11;
const int64_t i10 = tmp;
float val = x[i];
if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
}
dst[i] = val;
}
static void acc_f32_cuda(const float * x, const float * y, float * dst, const int64_t n_elements,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const int64_t s1, const int64_t s2, const int64_t s3, const int64_t offset, cudaStream_t stream) {
const int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
}
void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));
GGML_ASSERT(ggml_is_contiguously_allocated(dst));
const int64_t s1 = dst->op_params[0] / sizeof(float);
const int64_t s2 = dst->op_params[1] / sizeof(float);
const int64_t s3 = dst->op_params[2] / sizeof(float);
const int64_t offset = dst->op_params[3] / sizeof(float);
acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], s1, s2, s3, offset, stream);
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_ACC_BLOCK_SIZE 256
void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,58 @@
#include "add-id.cuh"
static __global__ void add_id_kernel(
const float * src0, const float * src1, const int32_t * src2, float * dst,
int64_t ne0, int64_t ne1,
size_t nb01, size_t nb02,
size_t nb11,
size_t nb21
) {
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.y;
const int i11 = *(const int32_t *) ((const char *) src2 + i1*sizeof(int32_t) + i2*nb21);
const size_t nb1 = ne0 * sizeof(float);
const size_t nb2 = ne1 * nb1;
float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);
const float * src0_row = (const float *)((const char *)src0 + i1*nb01 + i2*nb02);
const float * src1_row = (const float *)((const char *)src1 + i11*nb11);
for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
dst_row[i0] = src0_row[i0] + src1_row[i0];
}
}
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
GGML_TENSOR_TERNARY_OP_LOCALS
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_I32);
GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb20 == sizeof(int32_t));
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
const int32_t * src2_d = (const int32_t *)src2->data;
float * dst_d = (float *)dst->data;
int threads = std::min((int)ne00, 768); // cols
dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
src0_d, src1_d, src2_d, dst_d,
ne0, ne1,
nb01, nb02,
nb11,
nb21
);
}

View File

@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,34 @@
#include "arange.cuh"
static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
// blockIDx.x: idx of ne0 / BLOCK_SIZE
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}
dst[nidx] = start + step * nidx;
}
static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE;
arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start, step);
}
void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(dst->type == GGML_TYPE_F32);
float start;
float stop;
float step;
memcpy(&start, (float *)dst->op_params + 0, sizeof(float));
memcpy(&stop, (float *)dst->op_params + 1, sizeof(float));
memcpy(&step, (float *)dst->op_params + 2, sizeof(float));
int64_t steps = (int64_t)ceil((stop - start) / step);
GGML_ASSERT(ggml_nelements(dst) == steps);
arange_f32_cuda(dst_d, dst->ne[0], start, step, stream);
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_ARANGE_BLOCK_SIZE 256
void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,91 @@
#include <algorithm>
#include <cstdint>
#include "argmax.cuh"
#include "common.cuh"
#include "sum.cuh"
static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
const int64_t row = blockIdx.x;
float maxval = -FLT_MAX;
int argmax = -1;
const float * rowx = x + row * ncols;
for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
const float val = rowx[col];
if (val > maxval) {
maxval = val;
argmax = col;
}
}
#pragma unroll
for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
if (val > maxval) {
maxval = val;
argmax = col;
}
}
const int n_warps = blockDim.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = threadIdx.x / WARP_SIZE;
if (n_warps > 1) {
constexpr int max_warps = 1024 / WARP_SIZE;
__shared__ float shared_maxval[max_warps];
__shared__ int shared_argmax[max_warps];
if (lane_id == 0) {
shared_maxval[warp_id] = maxval;
shared_argmax[warp_id] = argmax;
}
__syncthreads();
if (warp_id == 0) {
if (lane_id < n_warps) {
maxval = shared_maxval[lane_id];
argmax = shared_argmax[lane_id];
}
#pragma unroll
for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
if (val > maxval) {
maxval = val;
argmax = col;
}
}
}
}
if (warp_id == 0 && lane_id == 0) {
dst[row] = argmax;
}
}
void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const float * src0_d = (const float *) src0->data;
int32_t * dst_d = (int32_t *) dst->data;
cudaStream_t stream = ctx.stream();
const int64_t num_blocks = nrows;
const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
const dim3 blocks_dim(num_threads, 1, 1);
const dim3 blocks_num(num_blocks, 1, 1);
argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);
}

View File

@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,221 @@
#include "argsort.cuh"
#ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh>
using namespace cub;
#endif // GGML_CUDA_USE_CUB
static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
const int col = blockIdx.x * blockDim.x + threadIdx.x;
const int row = blockIdx.y;
if (col < ncols && row < nrows) {
indices[row * ncols + col] = col;
}
}
static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx <= nrows) {
offsets[idx] = idx * ncols;
}
}
#ifdef GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream) {
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
int * temp_indices = temp_indices_alloc.get();
float * temp_keys = temp_keys_alloc.get();
int * d_offsets = offsets_alloc.get();
static const int block_size = 256;
const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
const dim3 offset_grid((nrows + block_size - 1) / block_size);
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
size_t temp_storage_bytes = 0;
if (order == GGML_SORT_ORDER_ASC) {
if (nrows == 1) {
DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
d_offsets, d_offsets + 1, stream);
}
} else {
if (nrows == 1) {
DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
}
}
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
void * d_temp_storage = temp_storage_alloc.get();
if (order == GGML_SORT_ORDER_ASC) {
if (nrows == 1) {
DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
}
} else {
if (nrows == 1) {
DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
stream);
}
}
}
#endif // GGML_CUDA_USE_CUB
// Bitonic sort implementation
template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
T tmp = a;
a = b;
b = tmp;
}
template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.x;
if (col >= ncols_pad) {
return;
}
const float * x_row = x + row * ncols;
extern __shared__ int dst_row[];
// initialize indices
dst_row[col] = col;
__syncthreads();
for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
} else {
if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
}
}
__syncthreads();
}
}
// copy the result to dst without the padding
if (col < ncols) {
dst[row * ncols + col] = dst_row[col];
}
}
static int next_power_of_2(int x) {
int n = 1;
while (n < x) {
n *= 2;
}
return n;
}
void argsort_f32_i32_cuda_bitonic(const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(nrows, 1, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
if (order == GGML_SORT_ORDER_ASC) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else {
GGML_ABORT("fatal error");
}
}
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
#ifdef GGML_CUDA_USE_CUB
const int ncols_pad = next_power_of_2(ncols);
const size_t shared_mem = ncols_pad * sizeof(int);
const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
if (shared_mem > max_shared_mem || ncols > 1024) {
ggml_cuda_pool & pool = ctx.pool();
argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
} else {
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
}
#else
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
#endif
}

View File

@@ -0,0 +1,19 @@
#include "common.cuh"
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
#ifdef GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream);
#endif // GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_bitonic(const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream);

View File

@@ -0,0 +1,502 @@
#include "binbcast.cuh"
#include <cstdint>
#include <utility>
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
GGML_UNUSED(a);
}
static __device__ __forceinline__ float op_add(const float a, const float b) {
return a + b;
}
static __device__ __forceinline__ float op_sub(const float a, const float b) {
return a - b;
}
static __device__ __forceinline__ float op_mul(const float a, const float b) {
return a * b;
}
static __device__ __forceinline__ float op_div(const float a, const float b) {
return a / b;
}
template <float (*bin_op)(const float, const float),
typename src0_t,
typename src1_t,
typename dst_t,
typename... src1_ptrs>
static __global__ void k_bin_bcast(const src0_t * src0,
const src1_t * src1,
dst_t * dst,
const int ne0,
const int ne1,
const int ne2,
const uint3 ne3,
const uint3 ne10,
const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
/*int s0, */ const int s1,
const int s2,
const int s3,
/*int s00,*/ const int s01,
const int s02,
const int s03,
/*int s10,*/ const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) {
return;
}
const uint32_t i11 = fastmodulo(i1, ne11);
const uint32_t i12 = fastmodulo(i2, ne12);
const uint32_t i13 = fastmodulo(i3, ne13);
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
const uint32_t i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
} else {
result = bin_op(result, (float)src1[i_src1 + i10]);
}
dst_row[i0] = (dst_t) result;
}
}
template <float (*bin_op)(const float, const float),
typename src0_t,
typename src1_t,
typename dst_t,
typename... src1_ptrs>
static __global__ void k_bin_bcast_unravel(const src0_t * src0,
const src1_t * src1,
dst_t * dst,
const uint3 ne0,
const uint3 ne1,
const uint3 ne2,
const uint32_t ne3,
const uint3 prod_012,
const uint3 prod_01,
const uint3 ne10,
const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
/*int s0, */ const int s1,
const int s2,
const int s3,
/*int s00,*/ const int s01,
const int s02,
const int s03,
/*int s10,*/ const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
const uint32_t i3 = fastdiv(i, prod_012);
const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;
if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
return;
}
const int i11 = fastmodulo(i1, ne11);
const int i12 = fastmodulo(i2, ne12);
const int i13 = fastmodulo(i3, ne13);
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
const int i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
} else {
result = bin_op(result, (float)src1[i_src1 + i10]);
}
dst_row[i0] = (dst_t) result;
}
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, size_t... I>
static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
cudaStream_t stream, std::index_sequence<I...>) {
GGML_TENSOR_BINARY_OP_LOCALS
int nr0 = ne10 / ne0;
int nr1 = ne11 / ne1;
int nr2 = ne12 / ne2;
int nr3 = ne13 / ne3;
int nr[4] = { nr0, nr1, nr2, nr3 };
int64_t cne[] = { ne0, ne1, ne2, ne3 };
int64_t cne0[] = { ne00, ne01, ne02, ne03 };
int64_t cne1[] = { ne10, ne11, ne12, ne13 };
size_t cnb[] = { nb0, nb1, nb2, nb3 };
size_t cnb0[] = { nb00, nb01, nb02, nb03 };
size_t cnb1[] = { nb10, nb11, nb12, nb13 };
auto collapse = [](int64_t cne[]) {
cne[0] *= cne[1];
cne[1] = cne[2];
cne[2] = cne[3];
cne[3] = 1;
};
auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
cnb[1] *= cne[1];
cnb[2] *= cne[2];
cnb[3] *= cne[3];
};
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
}
if (i > 0) {
collapse_nb(cnb, cne);
collapse_nb(cnb0, cne0);
collapse_nb(cnb1, cne1);
collapse(cne);
collapse(cne0);
collapse(cne1);
}
}
}
{
int64_t ne0 = cne[0];
int64_t ne1 = cne[1];
int64_t ne2 = cne[2];
int64_t ne3 = cne[3];
//int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
//int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
//int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
//int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
size_t nb0 = cnb[0];
size_t nb1 = cnb[1];
size_t nb2 = cnb[2];
size_t nb3 = cnb[3];
size_t nb00 = cnb0[0];
size_t nb01 = cnb0[1];
size_t nb02 = cnb0[2];
size_t nb03 = cnb0[3];
size_t nb10 = cnb1[0];
size_t nb11 = cnb1[1];
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
size_t s10 = nb10 / sizeof(src1_t);
size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
size_t s00 = nb00 / sizeof(src0_t);
size_t s01 = nb01 / sizeof(src0_t);
size_t s02 = nb02 / sizeof(src0_t);
size_t s03 = nb03 / sizeof(src0_t);
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(s0 == 1);
GGML_ASSERT(s00 == 1);
GGML_ASSERT(s10 == 1);
const int block_size = 128;
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
dim3 block_dims;
block_dims.x = std::min<unsigned int>(hne0, block_size);
block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
(ne2 * ne3 + block_dims.z - 1) / block_dims.z);
const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);
const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);
const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
if (block_nums.z > 65535 || block_nums.y > 65535) {
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
if constexpr (sizeof...(I) > 0) {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13);
}
} else {
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
if constexpr (sizeof...(I) > 0) {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13);
}
}
}
}
template <typename T>
static __global__ void k_repeat_back(
const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const size_t s00, const size_t s01, const size_t s02, const size_t s03,
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
const int64_t tid2 = tid23 % ne2;
const int64_t tid3 = tid23 / ne2;
if (tid0 >= ne0) {
return;
}
T sum = 0;
for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
}
}
}
}
dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
}
template <float (*bin_op)(const float, const float), int n_fuse = 1>
struct bin_bcast_cuda {
template<typename src0_t, typename src1_t, typename dst_t>
void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
cudaStream_t stream) {
launch_bin_bcast_pack<bin_op, src0_t, src1_t, dst_t>(
src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence<n_fuse>{});
}
};
template <typename T>
static void repeat_back_cuda(
const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const size_t s00, const size_t s01, const size_t s02, const size_t s03,
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
(src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
}
template<class op>
static void ggml_cuda_op_bin_bcast(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
} else {
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
GGML_ABORT("fatal error");
}
}
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat, 0>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
}
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
template <float (*op)(const float, const float), int n_fuse>
static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
cudaStream_t stream = ctx.stream();
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,
(const float *) src0->data, (const float *) src1->data, (float *) dst->data,
stream, std::make_index_sequence<n_fuse>{});
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,
(const half *) src0->data, (const half *) src1->data, (half *) dst->data,
stream, std::make_index_sequence<n_fuse>{});
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,
(const half *) src0->data, (const float *) src1->data, (half *) dst->data,
stream, std::make_index_sequence<n_fuse>{});
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,
(const half *) src0->data, (const float *) src1->data, (float *) dst->data,
stream, std::make_index_sequence<n_fuse>{});
} else {
fprintf(stderr,
"%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n",
__func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
GGML_ABORT("fatal error");
}
}
void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);
switch (n_fuse) {
case 2:
ggml_cuda_op_fused_binbcast_impl<op_add, 2>(ctx, dst);
break;
case 3:
ggml_cuda_op_fused_binbcast_impl<op_add, 3>(ctx, dst);
break;
case 4:
ggml_cuda_op_fused_binbcast_impl<op_add, 4>(ctx, dst);
break;
case 5:
ggml_cuda_op_fused_binbcast_impl<op_add, 5>(ctx, dst);
break;
case 6:
ggml_cuda_op_fused_binbcast_impl<op_add, 6>(ctx, dst);
break;
case 7:
ggml_cuda_op_fused_binbcast_impl<op_add, 7>(ctx, dst);
break;
case 8:
ggml_cuda_op_fused_binbcast_impl<op_add, 8>(ctx, dst);
break;
default:
GGML_ASSERT(false && "Unsupported n_fuse value");
}
}
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->type == dst->type);
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_can_repeat(dst, src0));
cudaStream_t stream = ctx.stream();
GGML_TENSOR_UNARY_OP_LOCALS;
GGML_ASSERT(ne2*ne3 <= (1 << 15));
const size_t ts = ggml_type_size(src0->type);
const size_t s00 = nb00 / ts;
const size_t s01 = nb01 / ts;
const size_t s02 = nb02 / ts;
const size_t s03 = nb03 / ts;
switch (dst->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
} break;
default: {
GGML_ASSERT(false);
} break;
}
}

View File

@@ -0,0 +1,11 @@
#include "common.cuh"
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);

View File

@@ -0,0 +1,45 @@
#include "clamp.cuh"
static __device__ __forceinline__ float op_clamp(float x, float min, float max) {
return fminf(fmaxf(x, min), max);
}
template <class T>
static __global__ void op_clamp_kernel(const T * x, T * dst, const T min, const T max, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = (T)op_clamp((float)x[i], (float)min, (float)max);
}
template <class T>
static void clamp_cuda(const T * x, T * dst, const T min, const T max, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
op_clamp_kernel<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
}
void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
float min;
float max;
memcpy(&min, dst->op_params, sizeof(float));
memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
if (src0->type == GGML_TYPE_F16) {
clamp_cuda((const half *)src0_d, (half *)dst_d, (half)min, (half)max, ggml_nelements(src0), stream);
} else {
clamp_cuda((const float *)src0_d, (float *)dst_d, (float)min, (float)max, ggml_nelements(src0), stream);
}
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_CLAMP_BLOCK_SIZE 256
void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,221 @@
#include "concat.cuh"
// contiguous kernels
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (nidx < ne00) { // src0
int offset_src =
nidx +
blockIdx.y * ne00 +
blockIdx.z * ne00 * gridDim.y;
dst[offset_dst] = x[offset_src];
} else {
int offset_src =
(nidx - ne00) +
blockIdx.y * (ne0 - ne00) +
blockIdx.z * (ne0 - ne00) * gridDim.y;
dst[offset_dst] = y[offset_src];
}
}
static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (blockIdx.y < (unsigned)ne01) { // src0
int offset_src =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * ne01;
dst[offset_dst] = x[offset_src];
} else {
int offset_src =
nidx +
(blockIdx.y - ne01) * ne0 +
blockIdx.z * ne0 * (gridDim.y - ne01);
dst[offset_dst] = y[offset_src];
}
}
static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (blockIdx.z < (unsigned)ne02) { // src0
int offset_src =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
dst[offset_dst] = x[offset_src];
} else {
int offset_src =
nidx +
blockIdx.y * ne0 +
(blockIdx.z - ne02) * ne0 * gridDim.y;
dst[offset_dst] = y[offset_src];
}
}
static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2);
if (dim == 0) {
concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
return;
}
if (dim == 1) {
concat_f32_dim1<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne01);
return;
}
concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
}
// non-contiguous kernel (slow)
template <int dim>
static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
concat_f32_non_cont(
const char * src0,
const char * src1,
char * dst,
int64_t ne00,
int64_t ne01,
int64_t ne02,
int64_t ne03,
uint64_t nb00,
uint64_t nb01,
uint64_t nb02,
uint64_t nb03,
int64_t /*ne10*/,
int64_t /*ne11*/,
int64_t /*ne12*/,
int64_t /*ne13*/,
uint64_t nb10,
uint64_t nb11,
uint64_t nb12,
uint64_t nb13,
int64_t ne0,
int64_t /*ne1*/,
int64_t /*ne2*/,
int64_t /*ne3*/,
uint64_t nb0,
uint64_t nb1,
uint64_t nb2,
uint64_t nb3){
static_assert(dim >= 0 && dim <= 3, "dim must be in [0, 3]");
const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const int64_t i1 = blockIdx.x;
const float * x;
for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
} else {
if constexpr (dim == 0) {
x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10);
} else if constexpr (dim == 1) {
x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10);
} else if constexpr (dim == 2) {
x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10);
} else if constexpr (dim == 3) {
x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10);
}
}
float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*y = *x;
}
}
void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
cudaStream_t stream = ctx.stream();
const int32_t dim = ((int32_t *) dst->op_params)[0];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
if (dim != 3) {
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
concat_f32_cuda(
src0_d + i3 * (src0->nb[3] / 4),
src1_d + i3 * (src1->nb[3] / 4),
dst_d + i3 * ( dst->nb[3] / 4),
src0->ne[0], src0->ne[1], src0->ne[2],
dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
}
} else {
const size_t size0 = ggml_nbytes(src0);
const size_t size1 = ggml_nbytes(src1);
CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
}
} else {
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
auto launch_kernel = [&](auto dim) {
concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
(const char *) src0->data, (const char *) src1->data, (char *) dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
};
switch (dim) {
case 0:
launch_kernel(std::integral_constant<int, 0>{});
break;
case 1:
launch_kernel(std::integral_constant<int, 1>{});
break;
case 2:
launch_kernel(std::integral_constant<int, 2>{});
break;
case 3:
launch_kernel(std::integral_constant<int, 3>{});
break;
default:
GGML_ABORT("Invalid dim: %d", dim);
break;
}
}
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_CONCAT_BLOCK_SIZE 256
void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,86 @@
#include "conv-transpose-1d.cuh"
static __global__ void conv_transpose_1d_kernel(
const int s0, const int p0, const int d0, const int output_size,
const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
const float * src0, const float * src1, float * dst) {
int global_index = threadIdx.x + blockIdx.x * blockDim.x;
if (global_index >= output_size) {
return;
}
int out_index = global_index / dst_ne0;
float accumulator = 0;
for (int c = 0; c < src0_ne2; c++) {
int idx = global_index % dst_ne0;
int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
int input_offset = src1_ne0 * c;
for (int i = 0; i < src1_ne0; i++) {
if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
continue;
}
int weight_idx = idx - i*s0;
float kernel_weight = src0[kernel_offset + weight_idx];
float input_value = src1[input_offset+i];
accumulator += kernel_weight * input_value;
}
}
dst[global_index] = accumulator;
GGML_UNUSED_VARS(p0, d0, src0_ne3, src1_ne3, dst_ne3, src1_ne1, dst_ne1, src1_ne2, dst_ne2);
}
static void conv_transpose_1d_f32_f32_cuda(
const int s0, const int p0, const int d0, const int output_size,
const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
const float * src0, const float * src1, float * dst,
cudaStream_t stream) {
const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE;
conv_transpose_1d_kernel<<<num_blocks,CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE, 0, stream>>>(
s0,p0,d0,output_size,
src0_ne0, src0_ne1, src0_ne2, src0_ne3,
src1_ne0, src1_ne1, src1_ne2, src1_ne3,
dst_ne0, dst_ne1, dst_ne2, dst_ne3,
src0,src1, dst);
}
void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
const ggml_tensor * src1 = dst->src[1];
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
const int32_t * opts = (const int32_t *)dst->op_params;
const int s0 = opts[0];
const int p0 = 0;//opts[3];
const int d0 = 1;//opts[4];
const int64_t output_size = ggml_nelements(dst);
conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
src0_d, src1_d, dst_d, stream);
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256
void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,161 @@
#include "conv2d-dw.cuh"
struct conv_params {
int in_w, in_h;
int out_w, out_h;
int kernel_w, kernel_h;
int stride_x, stride_y;
int padding_x, padding_y;
int dilation_x, dilation_y;
int channels, batches;
};
struct kernel_bounds {
int y_min, y_max;
int x_min, x_max;
};
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
kernel_bounds bounds;
bounds.y_min = max(0, (params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
bounds.y_max =
min(params.kernel_h,
(params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
bounds.x_min = max(0, (params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
bounds.x_max =
min(params.kernel_w,
(params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
return bounds;
}
__device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
return out_coord * stride + kern_coord * dilation - padding;
}
struct whcn_layout {
__device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
return n * (params.channels * params.in_w * params.in_h) + c * params.in_w * params.in_h + y * params.in_w + x;
}
__device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
return c * params.kernel_h * params.kernel_w + ky * params.kernel_w + kx;
}
__device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
return n * (params.channels * params.out_w * params.out_h) + c * params.out_w * params.out_h +
y * params.out_w + x;
}
__device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
int & out_x) {
out_x = global_idx % params.out_w;
out_y = (global_idx / params.out_w) % params.out_h;
c = (global_idx / (params.out_w * params.out_h)) % params.channels;
n = global_idx / (params.out_w * params.out_h * params.channels);
}
};
struct cwhn_layout {
__device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
return n * (params.channels * params.in_w * params.in_h) + (y * params.in_w + x) * params.channels + c;
}
__device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
return (ky * params.kernel_w + kx) * params.channels + c;
}
__device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
return n * (params.channels * params.out_w * params.out_h) + y * (params.out_w * params.channels) +
x * params.channels + c;
}
__device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
int & out_x) {
c = global_idx % params.channels;
out_x = (global_idx / params.channels) % params.out_w;
out_y = (global_idx / (params.channels * params.out_w)) % params.out_h;
n = global_idx / (params.channels * params.out_w * params.out_h);
}
};
template <typename T, typename Layout>
__global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
const int in_w, const int in_h, const int out_w, const int out_h,
const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
const int channels, const int batches) {
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int total_elements = batches * channels * out_h * out_w;
if (global_idx >= total_elements) {
return;
}
conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };
int batch_idx, channel_idx, out_y_idx, out_x_idx;
Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
T accumulator = 0;
kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);
for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);
for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);
const T input_val = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];
accumulator += input_val * kernel_val;
}
}
output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
}
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * kernel = dst->src[0];
const ggml_tensor * input = dst->src[1];
GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
const float * w_d = (const float *) kernel->data;
const float * x_d = (const float *) input->data;
float * y_d = (float *) dst->data;
const int32_t * p = (const int32_t *) dst->op_params;
const int stride_x = p[0];
const int stride_y = p[1];
const int padding_x = p[2];
const int padding_y = p[3];
const int dilation_x = p[4];
const int dilation_y = p[5];
const int in_w = input->ne[0];
const int in_h = input->ne[1];
const int kernel_w = kernel->ne[0];
const int kernel_h = kernel->ne[1];
const int out_w = dst->ne[0];
const int out_h = dst->ne[1];
const int channels = dst->ne[2];
const int batches = dst->ne[3];
cudaStream_t st = ctx.stream();
const int total = batches * channels * out_h * out_w;
const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;
if (ggml_is_contiguous(input)) {
conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
dilation_x, dilation_y, channels, batches);
} else if (ggml_is_contiguous_channels(input)) {
conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
dilation_x, dilation_y, channels, batches);
} else {
GGML_ABORT("Unsupported memory layout for conv_2d_dw");
}
}

View File

@@ -0,0 +1,5 @@
#pragma once
#include "common.cuh"
#define CUDA_CONV2D_DW_BLOCK_SIZE 256
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,91 @@
#include <algorithm>
#include "conv2d-transpose.cuh"
#include "ggml.h"
__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
float * __restrict__ output, const int in_w, const int in_h, const int out_w,
const int out_h, const int kernel_w, const int kernel_h, const int stride,
const int c_in, const int c_out, const int batches) {
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int total_elements = out_w * out_h * c_out * batches;
if (global_idx >= total_elements) {
return;
}
const int out_x_idx = global_idx % out_w;
const int out_y_idx = (global_idx / out_w) % out_h;
const int c_idx = (global_idx / (out_w * out_h)) % c_out;
const int n_idx = global_idx / (out_w * out_h * c_out);
float accumulator = 0;
// For each output idx, find the inputs that contribute to it by checking stride alignment and bounds
for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
for (int kh = 0; kh < kernel_h; ++kh) {
int in_y = out_y_idx - kh;
if (in_y < 0 || in_y % stride) continue;
in_y /= stride;
if (in_y >= in_h) continue;
for (int kw = 0; kw < kernel_w; ++kw) {
int in_x = out_x_idx - kw;
if (in_x < 0 || in_x % stride) continue;
in_x /= stride;
if (in_x >= in_w) continue;
const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
const int kernel_idx =
(kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;
float input_val = input[input_idx];
half kern_val = kernel[kernel_idx];
accumulator += input_val * (float) kern_val;
}
}
}
output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + (out_w) *out_y_idx + out_x_idx] = accumulator;
}
//input is (W, H, C_in, N), Kernel is (W, H, C_out, C_in)
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * kernel = dst->src[0];
const ggml_tensor * input = dst->src[1];
GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
const float * input_data = (const float *) input->data;
float * output_data = (float *) dst->data;
const half * kernel_data = (const half *) kernel->data;
const int input_w = input->ne[0];
const int input_h = input->ne[1];
const int output_w = dst->ne[0];
const int output_h = dst->ne[1];
const int channels_in = input->ne[2];
const int channels_out = kernel->ne[2];
const int kernel_w = kernel->ne[0];
const int kernel_h = kernel->ne[1];
const int stride = dst->op_params[0];
const int batches = input->ne[3];
GGML_ASSERT(channels_in == kernel->ne[3]);
GGML_ASSERT(stride > 0);
cudaStream_t st = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(input));
GGML_ASSERT(ggml_is_contiguous(kernel));
GGML_ASSERT(ggml_is_contiguous(dst));
const int total = (output_w * output_h * channels_out * batches);
const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
channels_in, channels_out, batches);
}

View File

@@ -0,0 +1,4 @@
#include "common.cuh"
#define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,166 @@
#include "conv2d.cuh"
#include "convert.cuh"
struct conv_params {
const int64_t IW, IH;
const int64_t OW, OH;
const int64_t KW, KH;
const int64_t ST_X, ST_Y;
const int64_t PD_X, PD_Y;
const int64_t DL_X, DL_Y;
const int64_t IC, OC;
const int64_t B;
const int64_t TOTAL;
};
struct kernel_bounds {
int64_t y_min, y_max;
int64_t x_min, x_max;
};
__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
return (a > b) ? a : b;
}
__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
return (a < b) ? a : b;
}
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
kernel_bounds bounds;
bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
return bounds;
}
__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,
int64_t kern_coord,
int64_t stride,
int64_t dilation,
int64_t padding) {
return out_coord * stride + kern_coord * dilation - padding;
}
struct whcn_layout {
__device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
}
__device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
}
__device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
}
__device__ static void unpack_indices(int64_t global_idx,
const conv_params & P,
int64_t & n,
int64_t & c,
int64_t & out_y,
int64_t & out_x) {
out_x = global_idx % P.OW;
out_y = (global_idx / P.OW) % P.OH;
c = (global_idx / (P.OW * P.OH)) % P.OC;
n = global_idx / (P.OW * P.OH * P.OC);
}
};
template <typename T, typename Layout>
static __global__ void conv2d_kernel(const float * __restrict__ input,
const T * __restrict__ kernel,
float * __restrict__ output,
const conv_params P) {
const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (global_idx >= P.TOTAL) {
return;
}
int64_t n, c_out, out_y, out_x;
Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
float acc = 0.0f;
for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);
for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
acc += (input_val * ggml_cuda_cast<float>(kernel_val));
}
}
}
// [N, OC, OH, OW]
output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc;
}
template <typename T>
static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
conv2d_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
}
static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
conv2d_cuda<half>(X_D, K_D, Y_D, P, st);
}
static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
conv2d_cuda<float>(X_D, K_D, Y_D, P, st);
}
void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * kernel = dst->src[0];
const ggml_tensor * input = dst->src[1];
float * K_D = (float *) kernel->data;
const float * X_D = (const float *) input->data;
float * Y_D = (float *) dst->data;
GGML_ASSERT(ggml_is_contiguous(kernel));
GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
// same number of input channels
GGML_ASSERT(input->ne[2] == kernel->ne[2]);
cudaStream_t st = ctx.stream();
const int32_t * p = (const int32_t *) dst->op_params;
const int ST_X = p[0]; // stride_x
const int ST_Y = p[1]; // stride_y
const int PD_X = p[2]; // padding_x
const int PD_Y = p[3]; // padding_y
const int DL_X = p[4]; // dilation_x
const int DL_Y = p[5]; // dilation_y
// No cwhn
GGML_ASSERT(p[6] == false);
const int IW = input->ne[0]; // input_w
const int IH = input->ne[1]; // input_h
const int OW = dst->ne[0]; // output_w
const int OH = dst->ne[1]; // output_h
const int KW = kernel->ne[0]; // kernel_w
const int KH = kernel->ne[1]; // kernel_h
const int IC = input->ne[2]; // input_channels
const int OC = kernel->ne[3]; // ouptut_chanles
const int B = input->ne[3]; // n_batches
const int64_t total = B * OC * OH * OW;
conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
if (kernel->type == GGML_TYPE_F16) {
conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
} else {
conv2d_cuda_f32(X_D, K_D, Y_D, params, st);
}
}

View File

@@ -0,0 +1,5 @@
#pragma once
#include "common.cuh"
#define CUDA_CONV2D_BLOCK_SIZE 256
void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,825 @@
#include "convert.cuh"
#include "dequantize.cuh"
#include <cstdint>
#define CUDA_Q8_0_NE_ALIGN 2048
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t s01, const int64_t s02, const int64_t s03) {
const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
if (i00 >= ne00) {
return;
}
const int64_t i01 = blockIdx.y;
const int64_t i02 = blockIdx.z % ne02;
const int64_t i03 = blockIdx.z / ne02;
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
const int64_t ib = ibx0 + i00/qk; // block index
const int64_t iqs = (i00%qk)/qr; // quant index
const int64_t iybs = i00 - i00%qk; // y block start index
const int64_t y_offset = qr == 1 ? 1 : qk/2;
// dequantize
float2 v;
dequantize_kernel(vx, ib, iqs, v);
const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
}
template <bool need_check>
static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
#if __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
const int * x0 = ((int *) vx) + blockIdx.x * nint;
half2 * y2 = (half2 *) (y + i0);
__shared__ int vals[nint];
#pragma unroll
for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {
if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {
break;
}
const int ix = ix0 + threadIdx.x;
vals[ix] = x0[ix];
}
__syncthreads();
#pragma unroll
for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
if (need_check && i0 + iy + 2*threadIdx.x >= k) {
return;
}
const half * b0 = ((const half *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);
const half d = *b0;
const char2 qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];
y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
}
#else
GGML_UNUSED_VARS(vx, y, k);
NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
}
template<typename dst_t>
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
const int64_t i = blockIdx.x;
// assume 32 threads
const int64_t tid = threadIdx.x;
const int64_t il = tid/8;
const int64_t ir = tid%8;
const int64_t ib = 8*i + ir;
if (ib >= nb32) {
return;
}
dst_t * y = yy + 256*i + 32*ir + 4*il;
const block_q4_0 * x = (const block_q4_0 *)vx + ib;
const float d = __half2float(x->d);
const float dm = -8*d;
const uint8_t * q = x->qs + 4*il;
for (int l = 0; l < 4; ++l) {
y[l+ 0] = d * (q[l] & 0xF) + dm;
y[l+16] = d * (q[l] >> 4) + dm;
}
}
template<typename dst_t>
static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
const int64_t i = blockIdx.x;
// assume 32 threads
const int64_t tid = threadIdx.x;
const int64_t il = tid/8;
const int64_t ir = tid%8;
const int64_t ib = 8*i + ir;
if (ib >= nb32) {
return;
}
dst_t * y = yy + 256*i + 32*ir + 4*il;
const block_q4_1 * x = (const block_q4_1 *)vx + ib;
const float2 d = __half22float2(x->dm);
const uint8_t * q = x->qs + 4*il;
for (int l = 0; l < 4; ++l) {
y[l+ 0] = d.x * (q[l] & 0xF) + d.y;
y[l+16] = d.x * (q[l] >> 4) + d.y;
}
}
//================================== k-quants
template<typename dst_t>
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_q2_K * x = (const block_q2_K *) vx;
const int64_t tid = threadIdx.x;
const int64_t n = tid/32;
const int64_t l = tid - 32*n;
const int64_t is = 8*n + l/16;
const uint8_t q = x[i].qs[32*n + l];
dst_t * y = yy + i*QK_K + 128*n;
float dall = __low2half(x[i].dm);
float dmin = __high2half(x[i].dm);
y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
}
template<typename dst_t>
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_q3_K * x = (const block_q3_K *) vx;
const int64_t r = threadIdx.x/4;
const int64_t tid = r/2;
const int64_t is0 = r%2;
const int64_t l0 = 16*is0 + 4*(threadIdx.x%4);
const int64_t n = tid / 4;
const int64_t j = tid - 4*n;
uint8_t m = 1 << (4*n + j);
int64_t is = 8*n + 2*j + is0;
int shift = 2*j;
int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
(x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
float d_all = x[i].d;
float dl = d_all * (us - 32);
dst_t * y = yy + i*QK_K + 128*n + 32*j;
const uint8_t * q = x[i].qs + 32*n;
const uint8_t * hm = x[i].hmask;
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
}
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
if (j < 4) {
d = q[j] & 63; m = q[j + 4] & 63;
} else {
d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
}
template<typename dst_t>
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q4_K * x = (const block_q4_K *) vx;
const int64_t i = blockIdx.x;
// assume 32 threads
const int64_t tid = threadIdx.x;
const int64_t il = tid/8;
const int64_t ir = tid%8;
const int64_t is = 2*il;
const int64_t n = 4;
dst_t * y = yy + i*QK_K + 64*il + n*ir;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);
const uint8_t * q = x[i].qs + 32*il + n*ir;
uint8_t sc, m;
get_scale_min_k4(is + 0, x[i].scales, sc, m);
const float d1 = dall * sc; const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[i].scales, sc, m);
const float d2 = dall * sc; const float m2 = dmin * m;
for (int l = 0; l < n; ++l) {
y[l + 0] = d1 * (q[l] & 0xF) - m1;
y[l +32] = d2 * (q[l] >> 4) - m2;
}
}
template<typename dst_t>
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q5_K * x = (const block_q5_K *) vx;
const int64_t i = blockIdx.x;
// assume 64 threads - this is very slightly better than the one below
const int64_t tid = threadIdx.x;
const int64_t il = tid/16; // il is in 0...3
const int64_t ir = tid%16; // ir is in 0...15
const int64_t is = 2*il; // is is in 0...6
dst_t * y = yy + i*QK_K + 64*il + 2*ir;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);
const uint8_t * ql = x[i].qs + 32*il + 2*ir;
const uint8_t * qh = x[i].qh + 2*ir;
uint8_t sc, m;
get_scale_min_k4(is + 0, x[i].scales, sc, m);
const float d1 = dall * sc; const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[i].scales, sc, m);
const float d2 = dall * sc; const float m2 = dmin * m;
uint8_t hm = 1 << (2*il);
y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
hm <<= 1;
y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
}
template<typename dst_t>
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q6_K * x = (const block_q6_K *) vx;
const int64_t i = blockIdx.x;
// assume 64 threads - this is very slightly better than the one below
const int64_t tid = threadIdx.x;
const int64_t ip = tid/32; // ip is 0 or 1
const int64_t il = tid - 32*ip; // 0...32
const int64_t is = 8*ip + il/16;
dst_t * y = yy + i*QK_K + 128*ip + il;
const float d = x[i].d;
const uint8_t * ql = x[i].ql + 64*ip + il;
const uint8_t qh = x[i].qh[32*ip + il];
const int8_t * sc = x[i].scales + is;
y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
}
template<typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint16_t * q2 = x[i].qs + 4*ib;
const uint8_t * aux8 = (const uint8_t *)q2;
const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
const uint32_t aux32 = q2[2] | (q2[3] << 16);
const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
template<typename dst_t>
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_iq2_xs * x = (const block_iq2_xs *) vx;
const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint16_t * q2 = x[i].qs + 4*ib;
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
template<typename dst_t>
static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_iq2_s * x = (const block_iq2_s *) vx;
const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
template<typename dst_t>
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * q3 = x[i].qs + 8*ib;
const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
const uint32_t aux32 = gas[0] | (gas[1] << 16);
const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
for (int j = 0; j < 4; ++j) {
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
}
template<typename dst_t>
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_iq3_s * x = (const block_iq3_s *) vx;
const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * qs = x[i].qs + 8*ib;
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
const uint8_t signs = x[i].signs[4*ib + il];
for (int j = 0; j < 4; ++j) {
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
}
template<typename dst_t>
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_iq1_s * x = (const block_iq1_s *) vx;
const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
y[j] = d * (q[j] + delta);
}
}
template<typename dst_t>
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_iq1_m * x = (const block_iq1_m *) vx;
const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint16_t * sc = (const uint16_t *)x[i].scales;
iq1m_scale_t scale;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
y[j] = d * (q[j] + delta);
}
}
template<typename dst_t>
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
const uint8_t * q4 = x[ib].qs + 4*il;
const float d = (float)x[ib].d;
for (int j = 0; j < 4; ++j) {
y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
}
}
template<typename dst_t>
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_iq4_xs * x = (const block_iq4_xs *)vx;
const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
for (int j = 0; j < 4; ++j) {
y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
}
}
template<typename dst_t>
static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
const uint8_t * q4 = x[ib].qs + 4*il;
const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);
for (int j = 0; j < 4; ++j) {
y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
y[j+16] = d * kvalues_mxfp4[q4[j] >> 4]*0.5f;
}
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
(vx, y, ne00, ne01, ne02, s01, s02, s03);
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cont_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
dequantize_block_cuda<qk, qr, dequantize_kernel, dst_t>(vx, y, k, 1, 1, 1, k/qk, k/qk, k/qk, stream);
}
static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
if (k % CUDA_Q8_0_NE_ALIGN == 0) {
const bool need_check = false;
dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
} else {
const bool need_check = true;
dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
}
}
template<typename dst_t>
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
}
template<typename dst_t>
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
}
template<typename dst_t>
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename src_t, typename dst_t>
static __global__ void convert_unary(
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t s01, const int64_t s02, const int64_t s03) {
const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (i00 >= ne00) {
return;
}
const int64_t i01 = blockIdx.y;
const int64_t i02 = blockIdx.z % ne02;
const int64_t i03 = blockIdx.z / ne02;
const src_t * x = (const src_t *) vx;
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
}
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
(vx, y, ne00, ne01, ne02, s01, s02, s03);
}
template <typename src_t, typename dst_t>
static void convert_unary_cont_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
convert_unary_cuda<src_t>(vx, y, k, 1, 1, 1, k, k, k, stream);
}
to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return convert_unary_cont_cuda<float>;
case GGML_TYPE_F16:
return convert_unary_cont_cuda<half>;
default:
return nullptr;
}
}
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_1:
return dequantize_row_q4_1_cuda;
case GGML_TYPE_Q5_0:
return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1:
return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
return dequantize_block_q8_0_f16_cuda;
}
return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_cuda;
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_cuda;
case GGML_TYPE_Q4_K:
return dequantize_row_q4_K_cuda;
case GGML_TYPE_Q5_K:
return dequantize_row_q5_K_cuda;
case GGML_TYPE_Q6_K:
return dequantize_row_q6_K_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_IQ2_XS:
return dequantize_row_iq2_xs_cuda;
case GGML_TYPE_IQ2_S:
return dequantize_row_iq2_s_cuda;
case GGML_TYPE_IQ3_XXS:
return dequantize_row_iq3_xxs_cuda;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_cuda;
case GGML_TYPE_IQ1_M:
return dequantize_row_iq1_m_cuda;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_cuda;
case GGML_TYPE_IQ4_XS:
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_cuda;
case GGML_TYPE_F32:
return convert_unary_cont_cuda<float>;
case GGML_TYPE_BF16:
return convert_unary_cont_cuda<nv_bfloat16>;
default:
return nullptr;
}
}
to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_1:
return dequantize_row_q4_1_cuda;
case GGML_TYPE_Q5_0:
return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1:
return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_cuda;
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_cuda;
case GGML_TYPE_Q4_K:
return dequantize_row_q4_K_cuda;
case GGML_TYPE_Q5_K:
return dequantize_row_q5_K_cuda;
case GGML_TYPE_Q6_K:
return dequantize_row_q6_K_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_IQ2_XS:
return dequantize_row_iq2_xs_cuda;
case GGML_TYPE_IQ2_S:
return dequantize_row_iq2_s_cuda;
case GGML_TYPE_IQ3_XXS:
return dequantize_row_iq3_xxs_cuda;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_cuda;
case GGML_TYPE_IQ1_M:
return dequantize_row_iq1_m_cuda;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_cuda;
case GGML_TYPE_IQ4_XS:
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_cuda;
case GGML_TYPE_F16:
return convert_unary_cont_cuda<half>;
case GGML_TYPE_BF16:
return convert_unary_cont_cuda<nv_bfloat16>;
default:
return nullptr;
}
}
to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return convert_unary_cuda<float>;
case GGML_TYPE_Q4_0:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
case GGML_TYPE_Q4_1:
return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
case GGML_TYPE_Q5_0:
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1:
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_BF16:
return convert_unary_cuda<nv_bfloat16>;
default:
return nullptr;
}
}
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return convert_unary_cuda<float, nv_bfloat16>;
case GGML_TYPE_Q4_0:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
case GGML_TYPE_Q4_1:
return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
case GGML_TYPE_Q5_0:
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1:
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_F16:
return convert_unary_cuda<half, nv_bfloat16>;
default:
return nullptr;
}
}
to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_F16:
return convert_unary_cuda<half, float>;
case GGML_TYPE_Q4_0:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
case GGML_TYPE_Q4_1:
return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
case GGML_TYPE_Q5_0:
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1:
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_BF16:
return convert_unary_cuda<nv_bfloat16, float>;
default:
return nullptr;
}
}

View File

@@ -0,0 +1,56 @@
#pragma once
#include "common.cuh"
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
template<typename T>
using to_t_cuda_t = void (*)(const void * x, T * y, int64_t k, cudaStream_t stream);
typedef to_t_cuda_t<float> to_fp32_cuda_t;
typedef to_t_cuda_t<half> to_fp16_cuda_t;
typedef to_t_cuda_t<nv_bfloat16> to_bf16_cuda_t;
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);
to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type);
to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);
// TODO more general support for non-contiguous inputs
template<typename T>
using to_t_nc_cuda_t = void (*)(const void * x, T * y,
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
template<typename dst_t, typename src_t>
__host__ __device__ inline dst_t ggml_cuda_cast(src_t x) {
if constexpr (std::is_same_v<dst_t, src_t>) {
return x;
} else if constexpr(std::is_same_v<dst_t, nv_bfloat16>) {
return __float2bfloat16(float(x));
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
return __bfloat162float(x);
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
return __float22half2_rn(x);
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
// bypass compile error on cuda 12.0.1
#ifdef GGML_USE_HIP
return __float22bfloat162_rn(x);
#else
return {x.x, x.y};
#endif // GGML_USE_HIP
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
return int32_t(x);
} else {
return float(x);
}
}

View File

@@ -0,0 +1,64 @@
#include "common.cuh"
#include "count-equal.cuh"
#include <cstdint>
template <typename T>
static __global__ void count_equal(const T * __restrict__ x, const T * __restrict__ y, int64_t * __restrict__ dst, const int64_t dk, const int64_t k) {
const int64_t i0 = (int64_t) blockIdx.x*dk;
const int64_t i1 = min(i0 + dk, k);
int nequal = 0;
for (int64_t i = i0 + threadIdx.x; i < i1; i += WARP_SIZE) {
const T xi = x[i];
const T yi = y[i];
nequal += xi == yi;
}
nequal = warp_reduce_sum(nequal);
if (threadIdx.x != 0) {
return;
}
atomicAdd((int *) dst, nequal);
}
void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == src1->type);
GGML_ASSERT( dst->type == GGML_TYPE_I64);
GGML_ASSERT(ggml_are_same_shape(src0, src1));
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));
int64_t * dst_d = (int64_t *) dst->data;
cudaStream_t stream = ctx.stream();
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
const int64_t dne = GGML_PAD((ne + 4*nsm - 1) / (4*nsm), CUDA_COUNT_EQUAL_CHUNK_SIZE);
CUDA_CHECK(cudaMemsetAsync(dst_d, 0, ggml_nbytes(dst), stream));
const dim3 blocks_dim(WARP_SIZE, 1, 1);
const dim3 blocks_num(std::min((int64_t)4*nsm, (ne + CUDA_COUNT_EQUAL_CHUNK_SIZE - 1)/CUDA_COUNT_EQUAL_CHUNK_SIZE), 1, 1);
switch (src0->type) {
case GGML_TYPE_I32: {
const int * src0_d = (const int *) src0->data;
const int * src1_d = (const int *) src1->data;
count_equal<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_d, dne, ne);
} break;
default:
GGML_ASSERT(false);
break;
}
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_COUNT_EQUAL_CHUNK_SIZE 128
void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,57 @@
// Simplified API for asynchronous data loading.
#include "common.cuh"
static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
#ifdef CP_ASYNC_AVAILABLE
return __cvta_generic_to_shared(generic_ptr);
#else
GGML_UNUSED(generic_ptr);
NO_DEVICE_CODE;
return 0;
#endif // CP_ASYNC_AVAILABLE
}
// Copies data from global to shared memory, cg == cache global.
// Both the src and dst pointers must be aligned to 16 bit.
// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements.
template <int preload>
static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
#ifdef CP_ASYNC_AVAILABLE
#if CUDART_VERSION >= 11040
if (preload == 256) {
asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
: : "r"(dst), "l"(src));
} else if (preload == 128) {
asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
: : "r"(dst), "l"(src));
} else if (preload == 64) {
asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
: : "r"(dst), "l"(src));
} else
#endif // CUDART_VERSION >= 11040
{
asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
: : "r"(dst), "l"(src));
}
#else
GGML_UNUSED(dst);
GGML_UNUSED(src);
NO_DEVICE_CODE;
#endif // CP_ASYNC_AVAILABLE
}
// Makes each thread wait until its asynchronous data copies are done.
// This does NOT provide any additional synchronization.
// In particular, when copying data with multiple warps a call to __syncthreads will be needed.
static __device__ __forceinline__ void cp_async_wait_all() {
#ifdef CP_ASYNC_AVAILABLE
asm volatile("cp.async.wait_all;");
#else
NO_DEVICE_CODE;
#endif // CP_ASYNC_AVAILABLE
}

View File

@@ -0,0 +1,217 @@
#pragma once
#include "ggml-common.h"
#include "convert.cuh"
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_0; ++j) {
const float v = x[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -8;
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = x[0 + j]*id;
const float x1 = x[QK4_0/2 + j]*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
y->qs[j] = xi0;
y->qs[j] |= xi1 << 4;
}
}
static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) {
float vmin = FLT_MAX;
float vmax = -FLT_MAX;
for (int j = 0; j < QK4_1; ++j) {
const float v = x[j];
if (v < vmin) vmin = v;
if (v > vmax) vmax = v;
}
const float d = (vmax - vmin) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
y->dm.x = d;
y->dm.y = vmin;
for (int j = 0; j < QK4_1/2; ++j) {
const float x0 = (x[0 + j] - vmin)*id;
const float x1 = (x[QK4_1/2 + j] - vmin)*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
y->qs[j] = xi0;
y->qs[j] |= xi1 << 4;
}
}
static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK5_0; ++j) {
const float v = x[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -16;
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = x[0 + j]*id;
const float x1 = x[QK5_0/2 + j]*id;
const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
}
memcpy(y->qh, &qh, sizeof(qh));
}
static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) {
float min = x[0];
float max = x[0];
for (int j = 1; j < QK5_1; ++j) {
const float v = x[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
y->dm.x = d;
y->dm.y = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (x[0 + j] - min)*id;
const float x1 = (x[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
memcpy(y->qh, &qh, sizeof(qh));
}
static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) {
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = x[j];
amax = fmaxf(amax, fabsf(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = x[j]*id;
y->qs[j] = roundf(x0);
}
}
static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_NL; ++j) {
const float v = x[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
float d = vmax / kvalues_iq4nl[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = x[0 + j]*id;
const float x1 = x[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
y->qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl[xi0];
const float v1 = kvalues_iq4nl[xi1];
const float w0 = x[0 + j]*x[0 + j];
const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j];
sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
y->d = sumq2 > 0 ? sumqx/sumq2 : d;
}
// Wrapper functions for cpy.cu compatibility
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti);
}
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti);
}
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti);
}
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti);
}
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti);
}
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
}
template<typename src_t, typename dst_t>
static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) {
*(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
}

555
ggml/src/ggml-cuda/cpy.cu Normal file
View File

@@ -0,0 +1,555 @@
#include "cpy.cuh"
#include "dequantize.cuh"
#include "cpy-utils.cuh"
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
#include "ggml-musa/mudnn.cuh"
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
const int CUDA_CPY_TILE_DIM_2D = 32; // 2D tile dimension for transposed blocks
const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available
const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows
template <cpy_kernel_t cpy_1>
static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
const int64_t nb12, const int64_t nb13) {
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
}
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
// then combine those indices with the corresponding byte offsets to get the total offsets
const int64_t i03 = i/(ne00 * ne01 * ne02);
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int64_t i13 = i/(ne10 * ne11 * ne12);
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
cpy_1(cx + x_offset, cdst + dst_offset);
}
template <typename T>
static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
const int64_t nb12, const int64_t nb13) {
const T* src = reinterpret_cast<const T*>(cx);
T* dst = reinterpret_cast<T*>(cdst);
const int64_t nmat = ne / (ne00 * ne01);
const int64_t n = ne00 * ne01;
const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
__shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
#pragma unroll
for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
if (imat >= nmat)
break;
#pragma unroll
for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
if(x < ne01 && y + j < ne00){
const int row = threadIdx.y+j;
const int col = threadIdx.x * sizeof(float)/sizeof(T);
T *tile2 = reinterpret_cast<T*>(tile[row]);
tile2[col] = src[imat*n + (y+j)*ne01 + x];
}
}
__syncthreads();
#pragma unroll
for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
if (ty + j < ne01 && tx < ne00) {
const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T);
const T *tile2 = reinterpret_cast<const T*>(tile[threadIdx.x]);
dst[imat*n + (ty+j)*ne00 + tx] = tile2[col];
}
}
}
GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11,
nb12, nb13);
}
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
float * cdstf = (float *)(cdsti);
#pragma unroll
for (int j = 0; j < QK8_0; j += 2) {
float2 dq;
dequantize_q8_0(cxi, 0, j, dq);
*(cdstf + j) = dq.x;
*(cdstf + j + 1) = dq.y;
}
}
template<dequantize_kernel_t dequant, int qk>
static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
float * cdstf = (float *)(cdsti);
#pragma unroll
for (int j = 0; j < qk/2; j++) {
float2 dq;
dequant(cxi, 0, j, dq);
*(cdstf + j) = dq.x;
*(cdstf + j + qk/2) = dq.y;
}
}
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
const int64_t nb12, const int64_t nb13) {
const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;
if (i >= ne) {
return;
}
const int64_t i03 = i/(ne00 * ne01 * ne02);
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int64_t i13 = i/(ne10 * ne11 * ne12);
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
cpy_blck(cx + x_offset, cdst + dst_offset);
}
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
const int64_t nb12, const int64_t nb13) {
const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;
if (i >= ne) {
return;
}
const int64_t i03 = i/(ne00 * ne01 * ne02);
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int64_t i13 = i/(ne10 * ne11 * ne12);
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
cpy_blck(cx + x_offset, cdst + dst_offset);
}
template<typename src_t, typename dst_t>
static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
}
const src_t * x = (const src_t *) cx;
dst_t * dst = (dst_t *) cdst;
dst[i] = ggml_cuda_cast<dst_t>(x[i]);
}
template<typename src_t, typename dst_t>
static void ggml_cpy_scalar_contiguous_cuda(
const char * cx, char * cdst, const int64_t ne,
cudaStream_t stream) {
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne);
}
template<typename src_t, typename dst_t, bool transposed = false>
static void ggml_cpy_scalar_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
if (transposed) {
GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
int64_t ne00n, ne01n, ne02n;
if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here
ne00n = ne00;
ne01n = ne01;
ne02n = ne02;
} else {
ne00n = ne00;
ne01n = ne01*ne02;
ne02n = 1;
}
int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM;
GGML_ASSERT(grid_x < UINT_MAX);
GGML_ASSERT(grid_y < USHRT_MAX);
GGML_ASSERT(grid_z < USHRT_MAX);
dim3 dimGrid(grid_x, grid_y, grid_z);
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} else {
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
}
static void ggml_cpy_f32_q8_0_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK8_0 == 0);
const int64_t num_blocks = ne / QK8_0;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q8_0_f32_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q4_0_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_0 == 0);
const int64_t num_blocks = ne / QK4_0;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q4_0_f32_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
cudaStream_t stream) {
const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q4_1_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_1 == 0);
const int64_t num_blocks = ne / QK4_1;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q4_1_f32_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
cudaStream_t stream) {
const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q5_0_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_0 == 0);
const int64_t num_blocks = ne / QK5_0;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q5_0_f32_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
cudaStream_t stream) {
const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q5_1_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_1 == 0);
const int64_t num_blocks = ne / QK5_1;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q5_1_f32_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
cudaStream_t stream) {
const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_iq4_nl_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_NL == 0);
const int64_t num_blocks = ne / QK4_NL;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
//GGML_ASSERT(src0->ne[3] == 1);
const int64_t nb00 = src0->nb[0];
const int64_t nb01 = src0->nb[1];
const int64_t nb02 = src0->nb[2];
const int64_t nb03 = src0->nb[3];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const int64_t ne12 = src1->ne[2];
//GGML_ASSERT(src1->ne[3] == 1);
const int64_t nb10 = src1->nb[0];
const int64_t nb11 = src1->nb[1];
const int64_t nb12 = src1->nb[2];
const int64_t nb13 = src1->nb[3];
cudaStream_t main_stream = ctx.stream();
char * src0_ddc = (char *) src0->data;
char * src1_ddc = (char *) src1->data;
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
if (src0->type == src1->type && contiguous_srcs) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
} else
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
{
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
if (can_be_transposed) {
ggml_cpy_scalar_cuda<float, float, true>
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
ggml_cpy_scalar_cuda<float, float>
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
if (contiguous_srcs) {
ggml_cpy_scalar_contiguous_cuda<float, nv_bfloat16>
(src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_scalar_cuda<float, nv_bfloat16>
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
if (contiguous_srcs) {
ggml_cpy_scalar_contiguous_cuda<float, half>
(src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_scalar_cuda<float, half>
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
ggml_cpy_f32_q8_0_cuda
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q8_0_f32_cuda
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
ggml_cpy_f32_q4_0_cuda
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q4_0_f32_cuda
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
ggml_cpy_f32_q4_1_cuda
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q4_1_f32_cuda
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
ggml_cpy_f32_q5_0_cuda
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_0_f32_cuda
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
ggml_cpy_f32_iq4_nl_cuda
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
ggml_cpy_f32_q5_1_cuda
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_1_f32_cuda
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
if (can_be_transposed) {
ggml_cpy_scalar_cuda<half, half, true>
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
ggml_cpy_scalar_cuda<half, half>
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
if (contiguous_srcs) {
ggml_cpy_scalar_contiguous_cuda<half, nv_bfloat16>
(src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_scalar_cuda<half, nv_bfloat16>
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
if (contiguous_srcs) {
ggml_cpy_scalar_contiguous_cuda<half, float>
(src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_scalar_cuda<half, float>
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
if (can_be_transposed) {
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16, true>
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16>
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
if (contiguous_srcs) {
ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, half>
(src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_scalar_cuda<nv_bfloat16, half>
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
if (contiguous_srcs) {
ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, float>
(src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_scalar_cuda<nv_bfloat16, float>
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
if (can_be_transposed) {
ggml_cpy_scalar_cuda<int32_t, int32_t, true>
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
ggml_cpy_scalar_cuda<int32_t, int32_t>
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
if (contiguous_srcs) {
ggml_cpy_scalar_contiguous_cuda<float, int32_t>
(src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_scalar_cuda<float, int32_t>
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
if (contiguous_srcs) {
ggml_cpy_scalar_contiguous_cuda<int32_t, float>
(src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_scalar_cuda<int32_t, float>
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
}
}
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
ggml_cuda_cpy(ctx, src0, dst);
}

View File

@@ -0,0 +1,7 @@
#include "common.cuh"
#define CUDA_CPY_BLOCK_SIZE 64
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,177 @@
#include "common.cuh"
#include "cross-entropy-loss.cuh"
#include "sum.cuh"
#include <cmath>
#include <cstdint>
template <bool use_shared>
static __global__ void cross_entropy_loss_f32(
const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {
extern __shared__ float tmp[];
logits += int64_t(blockIdx.x)*nclasses;
labels += int64_t(blockIdx.x)*nclasses;
// Find maximum for softmax:
float max_logit = -INFINITY;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float val = logits[i];
max_logit = fmaxf(max_logit, val);
if (use_shared) {
tmp[i] = val;
}
}
max_logit = warp_reduce_max(max_logit);
// Calculate log(softmax(logits)) which is just logits - max:
float sum = 0.0f;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float logit_i = use_shared ? tmp[i] : logits[i];
sum += expf(logit_i - max_logit);
}
sum = warp_reduce_sum(sum);
sum = logf(sum);
// log(exp(logits - max) / sum) = (logits - max) - log(sum)
float loss = 0.0f;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float logit_i = use_shared ? tmp[i] : logits[i];
loss += (logit_i - max_logit - sum) * labels[i];
}
loss = -warp_reduce_sum(loss) / (float)k;
if (threadIdx.x != 0) {
return;
}
dst[blockIdx.x] = loss;
}
template <bool use_shared>
static __global__ void cross_entropy_loss_back_f32(
const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels,
float * __restrict__ dst, const int nclasses) {
extern __shared__ float tmp[];
logits += int64_t(blockIdx.x)*nclasses;
labels += int64_t(blockIdx.x)*nclasses;
dst += int64_t(blockIdx.x)*nclasses;
float maxval = -INFINITY;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float val = logits[i];
maxval = fmaxf(maxval, val);
if (use_shared) {
tmp[i] = val;
}
}
maxval = warp_reduce_max(maxval);
float sum = 0.0f;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval);
sum += val;
if (use_shared) {
tmp[i] = val;
} else {
dst[i] = val;
}
}
sum = warp_reduce_sum(sum);
const float sm_scale = 1.0f/sum;
const float d_by_nrows = *grad/gridDim.x;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float val = use_shared ? tmp[i] : dst[i];
dst[i] = (val*sm_scale - labels[i])*d_by_nrows;
}
}
void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t stream = ctx.stream();
const dim3 blocks_dim(WARP_SIZE, 1, 1);
const dim3 blocks_num(nrows, 1, 1);
const size_t nbytes_shared = ne00*sizeof(float);
const int id = ggml_cuda_get_device();
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
if (nbytes_shared <= smpbo) {
CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
} else {
cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
}
CUDA_CHECK(cudaGetLastError());
// Combine results from individual blocks:
sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
}
void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * grad = dst->src[0];
const ggml_tensor * src0f = dst->src[1];
const ggml_tensor * src1f = dst->src[2];
GGML_ASSERT(src0f->type == GGML_TYPE_F32);
GGML_ASSERT(src1f->type == GGML_TYPE_F32);
GGML_ASSERT( grad->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_scalar(grad));
GGML_ASSERT(ggml_is_contiguous(src0f));
GGML_ASSERT(ggml_is_contiguous(src1f));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_are_same_shape(src0f, src1f));
GGML_ASSERT(ggml_are_same_shape(src0f, dst));
const int64_t ne00 = src0f->ne[0];
const int64_t nrows = ggml_nrows(src0f);
const float * grad_d = (const float *) grad->data;
const float * src0f_d = (const float *) src0f->data;
const float * src1f_d = (const float *) src1f->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
const dim3 blocks_dim(WARP_SIZE, 1, 1);
const dim3 blocks_num(nrows, 1, 1);
const size_t nbytes_shared = ne00*sizeof(float);
const int id = ggml_cuda_get_device();
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
if (nbytes_shared <= smpbo) {
CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
} else {
cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
}
}

View File

@@ -0,0 +1,7 @@
#include "common.cuh"
#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256
void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,307 @@
#include <algorithm>
#include "cumsum.cuh"
#include "convert.cuh"
#include "ggml-cuda/common.cuh"
#include "ggml.h"
#ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh>
#endif // GGML_CUDA_USE_CUB
template<typename T, int BLOCK_SIZE>
static __global__ void cumsum_cub_kernel(
const T * __restrict__ src,
T * __restrict__ dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t s1, const int64_t s2, const int64_t s3) {
#ifdef GGML_CUDA_USE_CUB
using BlockScanT = cub::BlockScan<T, BLOCK_SIZE>;
__shared__ typename BlockScanT::TempStorage temp_storage;
__shared__ T block_carry;
const int tid = threadIdx.x;
constexpr int UNROLL_FACTOR = 4;
constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR;
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.y;
const int64_t i3 = blockIdx.z;
if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
return;
}
const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
if (tid == 0) {
block_carry = 0;
}
__syncthreads();
for (int64_t start = 0; start < ne00; start += TILE_SIZE) {
T items[UNROLL_FACTOR];
T thread_sum = T(0);
#pragma unroll
for (int i = 0; i < UNROLL_FACTOR; i++) {
int64_t idx = start + tid * UNROLL_FACTOR + i;
T val = (idx < ne00) ? src_row[idx] : T(0);
thread_sum += val;
items[i] = thread_sum;
}
// Block-wide scan on thread sums
T thread_prefix;
T block_total;
BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total);
__syncthreads();
// Add offset to each item and store
T thread_offset = thread_prefix - thread_sum + block_carry;
#pragma unroll
for (int i = 0; i < UNROLL_FACTOR; i++) {
int64_t idx = start + tid * UNROLL_FACTOR + i;
if (idx < ne00) {
dst_row[idx] = items[i] + thread_offset;
}
}
__syncthreads();
// Update carry for next tile
if (tid == 0) {
block_carry += block_total;
}
}
#else
NO_DEVICE_CODE;
#endif // GGML_CUDA_USE_CUB
}
// Fallback kernel implementation
template<typename T>
static __global__ void cumsum_kernel(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t s0, const int64_t s1, const int64_t s2, const int64_t s3) {
GGML_UNUSED_VARS(s00, s0);
const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int lane = tid % warp_size;
const int warp = tid / warp_size;
const int warps_per_block = blockDim.x / warp_size;
extern __shared__ float smem[];
float * s_vals = smem;
float * s_warp_sums = smem + blockDim.x;
float * s_carry = smem + blockDim.x + warps_per_block;
float * s_chunk_total = s_carry + 1;
// Initialize carry
if (tid == 0) {
*s_carry = 0.0f;
}
__syncthreads();
const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const int64_t i1 = blockIdx.x;
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
// register blocking: process 4 elements per thread to hide latency
// and reduce synchronization overhead
constexpr int num_unroll = 4;
T temp[num_unroll];
for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) {
int64_t idx = i + tid * num_unroll;
// thread local sequential scan
temp[0] = (idx < ne00 ? src_row[idx] : T(0));
#pragma unroll
for (int64_t j = 1; j < num_unroll; j++) {
temp[j] = temp[j - 1];
if (idx + j < ne00) {
temp[j] += src_row[idx + j];
} else {
temp[j] += 0;
}
}
// last emenent is sum of all values assigned to thread
float val = (idx < ne00) ? ggml_cuda_cast<float, T>(temp[num_unroll - 1]) : 0.0f;
// Warp inclusive scan
val = warp_prefix_inclusive_sum<T, warp_size>(val);
s_vals[tid] = val;
if (lane == warp_size - 1) {
s_warp_sums[warp] = val;
}
__syncthreads();
// Exclusive scan of warp sums (warp 0 only)
if (warp == 0) {
float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
if (tid < warps_per_block) {
s_warp_sums[tid] = inc - w; // exclusive sum
}
if (tid == warps_per_block - 1) {
*s_chunk_total = inc; // total sum of this chunk
}
}
__syncthreads();
// write back results
float carry = *s_carry;
// calculate sum offset for this thread
float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];
#pragma unroll
for (int32_t j = 0; j < num_unroll; j++) {
if (idx + j < ne00) {
dst_row[idx + j] = temp[j] + ggml_cuda_cast<T, float>(final_val_offset);
}
}
__syncthreads();
// Update carry for next chunk
if (tid == 0) {
*s_carry += *s_chunk_total;
}
}
}
#ifdef GGML_CUDA_USE_CUB
template <typename T>
static void cumsum_cub(ggml_cuda_pool & pool,
const T * src,
T * dst,
int64_t ne,
cudaStream_t stream) {
size_t tmp_size = 0;
// Query how much temp storage CUDA UnBound (CUB) needs
cub::DeviceScan::InclusiveSum(nullptr, // d_temp_storage (null = just query size)
tmp_size, // reference to size (will be set by CUB)
src, // input pointer
dst, // output pointer
ne, // number of elements
stream // CUDA stream to use
);
ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
// Perform the inclusive scan
cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream);
}
#endif // GGML_CUDA_USE_CUB
template<typename T>
static void cumsum_cuda(
[[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
cudaStream_t stream) {
const size_t type_size = sizeof(T);
bool use_cub = false;
#ifdef GGML_CUDA_USE_CUB
// Check if we can use CUB (data must be contiguous along innermost dimension)
const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size);
if (is_contiguous) {
use_cub = true;
const int64_t nrows = ne01 * ne02 * ne03;
// TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
// Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {
for (int i=0; i<nrows; i++) {
cumsum_cub(ctx.pool(), src + i * ne00, dst + i * ne00, ne00, stream);
}
return;
}
}
#endif // GGML_CUDA_USE_CUB
dim3 grid_dims(ne01, ne02, ne03);
const auto &info = ggml_cuda_info().devices[ggml_cuda_get_device()];
const int warp_size = info.warp_size;
const int num_warps = (ne00 + warp_size - 1) / warp_size;
int block_size = num_warps * warp_size;
block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
dim3 block_dims(block_size, 1, 1);
const int warps_per_block = block_size / warp_size;
const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
if (use_cub && ne00 >= 1024) {
cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb1 / type_size, nb2 / type_size, nb3 / type_size
);
} else {
cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
}
}
void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == dst->type);
switch(src0->type) {
case GGML_TYPE_F32:
{
cumsum_cuda(
ctx, (const float *)src0->data, (float *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
stream
);
} break;
// We do not support those on CPU for now anyway, so comment them out because they cause errors on some CI platforms
/*case GGML_TYPE_F16:
{
cumsum_cuda(
(const half *)src0->data, (half *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
stream
);
} break;
case GGML_TYPE_BF16:
{
cumsum_cuda(
(const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
stream
);
} break;*/
default:
GGML_ABORT("fatal error");
}
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_CUMSUM_BLOCK_SIZE 256
void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,77 @@
#include "common.cuh"
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q4_0 * x = (const block_q4_0 *) vx;
const float d = x[ib].d;
const int vui = x[ib].qs[iqs];
v.x = vui & 0xF;
v.y = vui >> 4;
v.x = (v.x - 8.0f) * d;
v.y = (v.y - 8.0f) * d;
}
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q4_1 * x = (const block_q4_1 *) vx;
const float2 dm = __half22float2(x[ib].dm);
const int vui = x[ib].qs[iqs];
v.x = vui & 0xF;
v.y = vui >> 4;
v.x = (v.x * dm.x) + dm.y;
v.y = (v.y * dm.x) + dm.y;
}
static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q5_0 * x = (const block_q5_0 *) vx;
const float d = x[ib].d;
uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh));
const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
v.x = (v.x - 16.0f) * d;
v.y = (v.y - 16.0f) * d;
}
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q5_1 * x = (const block_q5_1 *) vx;
const float2 dm = __half22float2(x[ib].dm);
uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh));
const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
v.x = (v.x * dm.x) + dm.y;
v.y = (v.y * dm.x) + dm.y;
}
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q8_0 * x = (const block_q8_0 *) vx;
const float d = x[ib].d;
v.x = x[ib].qs[iqs + 0];
v.y = x[ib].qs[iqs + 1];
v.x *= d;
v.y *= d;
}

View File

@@ -0,0 +1,77 @@
#include "convert.cuh"
#include "diag.cuh"
#include "ggml.h"
template <typename T>
static __global__ void diag_kernel(T * __restrict__ dst,
const T * __restrict__ src,
const int64_t ne0,
const int64_t ne1,
const int64_t ne2,
const int64_t ne3,
const int64_t total_elements) {
const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (global_idx >= total_elements) {
return;
}
const int64_t i0 = global_idx % ne0;
const int64_t i1 = (global_idx / ne0) % ne1;
const int64_t i2 = (global_idx / (ne0 * ne1)) % ne2;
const int64_t i3 = global_idx / (ne0 * ne1 * ne2);
const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0;
if (i0 == i1) {
const int64_t batch_idx = i3 * ne2 + i2;
const int64_t src_idx = batch_idx * ne0 + i0;
dst[dst_idx] = src[src_idx];
} else {
dst[dst_idx] = ggml_cuda_cast<T>(0);
}
GGML_UNUSED_VARS(ne3);
}
void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
void * dst_d = dst->data;
const void * src0_d = src0->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t ne3 = dst->ne[3];
GGML_ASSERT(ne00 == ne0);
GGML_ASSERT(ne01 == 1);
GGML_ASSERT(ne02 == ne2);
GGML_ASSERT(ne03 == ne3);
const int64_t n_elems = ggml_nelements(dst);
const int64_t num_blocks = (n_elems + CUDA_DIAG_BLOCK_SIZE - 1) / CUDA_DIAG_BLOCK_SIZE;
switch (dst->type) {
case GGML_TYPE_F32:
diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>((float *) dst_d, (const float *) src0_d, ne0,
ne1, ne2, ne3, n_elems);
break;
case GGML_TYPE_F16:
diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>((half *) dst_d, (const half *) src0_d, ne0,
ne1, ne2, ne3, n_elems);
break;
default:
GGML_ABORT("unsupported type");
}
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_DIAG_BLOCK_SIZE 256
void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,40 @@
#include "diagmask.cuh"
static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
const int col = blockDim.y*blockIdx.y + threadIdx.y;
const int row = blockDim.x*blockIdx.x + threadIdx.x;
if (col >= ncols) {
return;
}
const int i = row*ncols + col;
//dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
//dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
}
static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
const dim3 block_nums(nrows_x, block_num_x, 1);
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}
void ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int nrows0 = ggml_nrows(src0);
const int n_past = ((int32_t *) dst->op_params)[0];
diag_mask_inf_f32_cuda(src0_d, dst_d, ne00, nrows0, ne01, n_past, stream);
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
void ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,49 @@
#include "common.cuh"
#include "fattn-tile.cuh"
#include "fattn-wmma-f16.cuh"
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
switch (K->ne[0]) {
case 40: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 40, 40>(ctx, dst);
} break;
case 64: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst);
} break;
case 72: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 72, 72>(ctx, dst);
} break;
case 80: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst);
} break;
case 96: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst);
} break;
case 112: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst);
} break;
case 128: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);
} break;
case 256: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
} break;
case 576: {
GGML_ASSERT(V->ne[0] == 512);
ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);
} break;
default: {
GGML_ABORT("Unsupported head size");
} break;
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,586 @@
#include "common.cuh"
#include "fattn-common.cuh"
static int ggml_cuda_fattn_vec_get_nthreads_host(const int cc) {
return 128;
GGML_UNUSED(cc);
}
static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() {
return 128;
}
// Currenlty llvm with the amdgcn target does not support unrolling loops
// that contain a break that can not be resolved at compile time.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1)
static __global__ void flash_attn_ext_vec(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
return;
}
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
#ifdef GGML_USE_HIP
#ifdef RDNA
constexpr int nthreads_KQ_q = 2;
#else
constexpr int nthreads_KQ_q = 4;
#endif // RDNA
constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32);
#else
constexpr int nthreads_KQ_q = (D/4 < 32 ? D/4 : 32);
constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32);
#endif // GGML_USE_HIP
constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device();
constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;
static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V;
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
#ifdef V_DOT2_F32_F16_AVAILABLE
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
#else
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
#endif // V_DOT2_F32_F16_AVAILABLE
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
Q += nb03*sequence + nb02* head + nb01*ic0;
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = nthreads / WARP_SIZE;
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
__builtin_assume(tid < nthreads);
constexpr int ne_KQ = ncols*D;
constexpr int ne_combine = nwarps*V_cols_per_iter*D;
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
__shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
#else
float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
__shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
#endif // V_DOT2_F32_F16_AVAILABLE
float KQ_max[ncols];
float KQ_sum[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ_max[j] = -FLT_MAX/2.0f;
KQ_sum[j] = 0.0f;
}
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
#else
__align__(16) float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
#endif // V_DOT2_F32_F16_AVAILABLE
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
if constexpr (Q_q8_1) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j0 + nwarps > ncols && j >= ncols) {
break;
}
// Reuse KQ as temporary storage for converting Q to q8_1:
int * tmp_q_i32 = (int *) &KQ[j*D];
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
// Set memory to zero if out of bounds:
if (ncols > 1 && ic0 + j >= int(ne01.z)) {
#pragma unroll
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
tmp_q_i32[i] = 0;
}
}
if (threadIdx.x < D/QK8_1) {
tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
}
} else {
const float * Q_f = (const float *) (Q + j*nb01);
constexpr int nthreads_quantize = D/sizeof(int) < WARP_SIZE ? D/sizeof(int) : WARP_SIZE;
#pragma unroll
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) {
quantize_q8_1_to_shared<float2, nthreads_quantize>
(Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1);
}
}
}
__syncthreads();
#pragma unroll
for (int j = 0; j < ncols; ++j) {
int * tmp_q_i32 = (int *) &KQ[j*D];
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
#pragma unroll
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) {
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ);
Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i];
Q_ds[j][i0/nthreads_KQ] = tmp_q_ds[i/QI8_1];
}
}
__syncthreads();
} else {
#ifdef V_DOT2_F32_F16_AVAILABLE
const half2 scale_h2 = make_half2(scale, scale);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
const float2 * Q_j = (const float2 *) (Q + j*nb01);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
__align__(16) float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
if (ncols == 1 || ic0 + j < int(ne01.z)) {
ggml_cuda_memcpy_1<cpy_nb>(tmp, &Q_j[i]);
ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
}
#pragma unroll
for (int i1 = 0; i1 < cpy_ne; ++i1) {
Q_reg[j][i0/nthreads_KQ + i1] = make_half2(tmp[i1].x, tmp[i1].y);
}
}
#pragma unroll
for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
Q_reg[j][k] *= scale_h2;
}
}
#else
#pragma unroll
for (int j = 0; j < ncols; ++j) {
const float2 * Q_j = (const float2 *) (Q + j*nb01);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
if (ncols == 1 || ic0 + j < int(ne01.z)) {
ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]);
ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
}
}
#pragma unroll
for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
Q_reg[j][k].x *= scale;
Q_reg[j][k].y *= scale;
}
}
#endif // V_DOT2_F32_F16_AVAILABLE
}
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
K += blockIdx.y*nthreads * nb11;
V += blockIdx.y*nthreads * nb21;
maskh += blockIdx.y*nthreads;
for (int k_VKQ_0 = blockIdx.y*nthreads; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nthreads,
// Increment pointers after each loop:
K += gridDim.y*nthreads*nb11, V += gridDim.y*nthreads*nb21, maskh += gridDim.y*nthreads) {
// Calculate KQ tile and keep track of new maximum KQ values:
float KQ_reg[ncols]; // KQ in registers.
float KQ_max_new[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ_max_new[j] = KQ_max[j];
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) {
const int i_KQ = threadIdx.y*WARP_SIZE + (nthreads_KQ == WARP_SIZE ? 0 : (threadIdx.x & ~(nthreads_KQ-1))) + i_KQ_0;
#pragma unroll
for (int j = 0; j < ncols; ++j) {
float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum<nthreads_KQ>(sum);
if (use_logit_softcap) {
sum = logit_softcap*tanhf(sum);
}
if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) {
sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
}
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum + FATTN_KQ_MAX_OFFSET);
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
KQ_reg[j] = sum;
}
}
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
#pragma unroll
for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) {
KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE));
}
const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]);
KQ_max[j] = KQ_max_new[j];
KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]);
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
KQ[j*nthreads + tid] = KQ_reg[j];
#ifdef V_DOT2_F32_F16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
}
#else
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
}
#endif // V_DOT2_F32_F16_AVAILABLE
}
#ifndef GGML_USE_HIP
__syncwarp();
#endif // GGML_USE_HIP
#pragma unroll
for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 KQ_k[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ_k[j] = __half2half2(KQ[j*nthreads + k]);
}
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
half2 tmp[V_rows_per_thread/2];
dequantize_V(V + k*nb21, tmp,
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
#pragma unroll
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j];
}
}
}
#else
float KQ_k[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ_k[j] = KQ[j*nthreads + k];
}
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
float2 tmp[V_rows_per_thread/2];
dequantize_V(V + k*nb21, tmp,
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
#pragma unroll
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j];
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j];
}
}
}
#endif // V_DOT2_F32_F16_AVAILABLE
}
}
if (sinks && blockIdx.y == 0) {
const float sink = ((const float *) sinks)[head];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j0 + nwarps > ncols && j >= ncols) {
break;
}
const float kqmax_new_j = fmaxf(sink, KQ_max[j]);
const float KQ_max_scale = expf(KQ_max[j] - kqmax_new_j);
KQ_max[j] = kqmax_new_j;
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
#ifdef V_DOT2_F32_F16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
}
#else
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
}
#endif // V_DOT2_F32_F16_AVAILABLE
}
}
__shared__ float KQ_max_shared[ncols][WARP_SIZE];
__shared__ float KQ_sum_shared[ncols][WARP_SIZE];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (threadIdx.y == 0) {
KQ_max_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
KQ_sum_shared[j][threadIdx.x] = 0.0f;
}
}
__syncthreads();
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (threadIdx.x == 0) {
KQ_max_shared[j][threadIdx.y] = KQ_max[j];
}
}
__syncthreads();
#pragma unroll
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) {
break;
}
float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x];
kqmax_new = warp_reduce_max(kqmax_new);
const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
KQ_max[j_VKQ] = kqmax_new;
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
+ (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
const half2 kqmax_scale_h2 = make_half2(kqmax_scale, kqmax_scale);
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2;
}
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2);
ggml_cuda_memcpy_1<V_rows_per_thread*sizeof(half)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
}
#else
float2 * VKQ_tmp = (float2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
+ (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
VKQ[j_VKQ][i_VKQ_0/nthreads_V].x *= kqmax_scale;
VKQ[j_VKQ][i_VKQ_0/nthreads_V].y *= kqmax_scale;
}
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2);
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
}
#endif // V_DOT2_F32_F16_AVAILABLE
KQ_sum[j_VKQ] *= kqmax_scale;
KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
if (threadIdx.x == 0) {
KQ_sum_shared[j_VKQ][threadIdx.y] = KQ_sum[j_VKQ];
}
__syncthreads();
if (nthreads <= D || tid < D) {
KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ][threadIdx.x];
KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
#pragma unroll
for (int i0 = 0; i0 < D; i0 += nthreads) {
float dst_val = 0;
#pragma unroll
for (int w = 0; w < nwarps; ++w) {
#pragma unroll
for (int v = 0; v < V_cols_per_iter; ++v) {
dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]);
}
}
if (gridDim.y == 1) {
dst_val /= KQ_sum[j_VKQ];
}
dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
}
}
if (j_VKQ < ncols-1) {
__syncthreads();
}
}
if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) {
dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
}
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
const int nwarps = nthreads / WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
const bool need_f16_K = type_K == GGML_TYPE_F16;
const bool need_f16_V = type_V == GGML_TYPE_F16;
constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
}
template <int D, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 2;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
}
#define DECL_FATTN_VEC_CASE(D, type_K, type_V) \
template void ggml_cuda_flash_attn_ext_vec_case \
<D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)

View File

@@ -0,0 +1,675 @@
// Old and deprecated WMMA FlashAttention implementation.
// It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing.
// Long-term the WMMA code should be replaced with a dedicated Volta implementation.
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-wmma-f16.cuh"
#ifdef GGML_USE_WMMA_FATTN
#if !defined(GGML_USE_HIP)
#include <mma.h>
#if defined(GGML_USE_MUSA)
namespace wmma = mtmusa::wmma;
#else // GGML_USE_MUSA
namespace wmma = nvcuda::wmma;
#endif // GGML_USE_MUSA
#elif defined(GGML_USE_HIP)
#include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma;
#endif // !defined(GGML_USE_HIP)
#endif // GGML_USE_WMMA_FATTN
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
}
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
constexpr int frag_m = ncols == 8 ? 32 : 16;
constexpr int frag_n = ncols == 8 ? 8 : 16;
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
// Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
constexpr int D_padded = D + 8;
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const half2 * mask2 = (const half2 *) maskh;
const float * sinksf = (const float *) sinks;
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
const half2 slope2 = make_half2(slopef, slopef);
const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
frag_b Q_b[D/16][ncols/frag_n];
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
constexpr int mem_KQ = ncols*kqs_padded*kqar;
constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
__shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
float * KQ_f = (float *) KQ;
half2 * KQ2 = (half2 *) KQ;
float KQ_rowsum_f[ncols/nwarps] = {0.0f};
float KQ_max_f[ncols/nwarps];
float KQ_max_scale_f[ncols/nwarps] = {0.0f};
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
KQ_max_f[j] = -FLT_MAX/2.0f;
}
half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
half2 KQ_max_h2[ncols/nwarps];
half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
}
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
half2 * VKQ2 = (half2 *) VKQ;
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (i0 + warp_size > D/2 && i >= D/2) {
break;
}
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
}
}
// Convert Q to half and apply scale, temporarily store in KQ:
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (i0 + warp_size > D && i >= D) {
break;
}
KQ[j*D_padded + i] = ic0 + j < int(ne01.z) ? Q_f[j*stride_Q + i] * scale : 0.0f;
}
}
__syncthreads();
// Load Q into tensor core fragments/registers since it will be used frequently:
#pragma unroll
for (int i0 = 0; i0 < D; i0 += 16) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
}
}
__syncthreads();
// Iterate over ne11 == previous tokens:
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
frag_c_KQ KQ_c[ncols/frag_n];
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
}
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
frag_a_K K_a;
wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
}
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
}
}
__syncthreads();
// Calculate softmax for each KQ column using the current max. value.
// The divisor is stored in KQ_rowsum and will be applied at the end.
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (std::is_same<KQ_acc_t, float>::value) {
float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];
if (use_logit_softcap) {
KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]);
}
}
float KQ_max_new = KQ_max_f[j0/nwarps];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ?
__half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size] + FATTN_KQ_MAX_OFFSET);
}
KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
KQ_max_scale_f[j0/nwarps] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
KQ_max_scale_f[j0/nwarps] = 0.0f;
}
KQ_max_f[j0/nwarps] = KQ_max_new;
float KQ_rowsum_add = 0.0f;
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
const int k = k0 + threadIdx.x;
const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
KQ_f_tmp[k0/warp_size] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
KQ_f_tmp[k0/warp_size] = 0.0f;
}
KQ_rowsum_add += KQ_f_tmp[k0/warp_size];
KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size];
}
KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
} else {
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
if (use_logit_softcap) {
// There is no dedicated tangens hyperbolicus function for half2.
KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f));
KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f))
/(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f));
KQ2_tmp[k0/warp_size] *= logit_softcap_2;
}
}
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
}
KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
*((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
KQ_max_h2[j0/nwarps] = KQ_max_new;
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
const int k = k0 + threadIdx.x;
const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
KQ2_tmp[k0/warp_size] = h2exp(diff);
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
*((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask;
KQ_rowsum_add += KQ2_tmp[k0/warp_size];
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size];
}
KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
}
}
__syncthreads();
frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
wmma::load_matrix_sync(
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
KQ + j0*(kqar*kqs_padded) + k,
kqar*kqs_padded);
}
}
frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
}
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
frag_a_V v_a;
wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
}
}
}
__syncthreads();
const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
wmma::store_matrix_sync(
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
D_padded, wmma::mem_col_major);
}
}
__syncthreads();
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
half2 VKQ_scale;
if (std::is_same<KQ_acc_t, float>::value) {
VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
} else {
VKQ_scale = KQ_max_scale_h2[j0/nwarps];
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (i0 + warp_size > D/2 && i >= D/2) {
break;
}
half2 VKQ_add = make_half2(0.0f, 0.0f);
#pragma unroll
for (int l = 0; l < VKQ_ratio; ++l) {
VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
}
VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
}
}
__syncthreads();
}
// Apply attention sinks
if (sinksf && blockIdx.y == 0) {
const float sinkf = sinksf[head];
const half sinkh = __float2half(sinkf);
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (std::is_same<KQ_acc_t, float>::value) {
float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);
const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
KQ_max_f[j0/nwarps] = kqmax_new;
KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);
const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (i0 + warp_size > D/2 && i >= D/2) break;
VKQ2[j*(D_padded/2) + i] *= scale_h2;
}
} else {
half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
half kqmax_new = fmaxf(kqmax_old, sinkh);
KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);
const half KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
const half2 KQ_max_scale = __half2half2(KQ_max_scale_h);
KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
const half val = hexp(sinkh - kqmax_new);
KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (i0 + warp_size > D/2 && i >= D/2) break;
VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
}
}
}
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j_VKQ = j0 + threadIdx.y;
if (ic0 + j_VKQ >= int(ne01.z)) {
return;
}
float KQ_rowsum_j;
if (std::is_same<KQ_acc_t, float>::value) {
KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
} else {
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
}
const int j_dst_unrolled = ((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (i0 + warp_size > D && i >= D) {
break;
}
float dst_val = VKQ[j_VKQ*D_padded + i];
if (gridDim.y == 1) {
dst_val /= KQ_rowsum_j;
}
dst[j_dst_unrolled*D + i] = dst_val;
}
if (gridDim.y == 1 || threadIdx.x != 0) {
continue;
}
float2 dst_meta_val;
if (std::is_same<KQ_acc_t, float>::value) {
dst_meta_val.x = KQ_max_f[j0/nwarps];
} else {
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
}
dst_meta_val.y = KQ_rowsum_j;
dst_meta[j_dst_unrolled] = dst_meta_val;
}
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
}
constexpr int get_max_power_of_2(int x) {
return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
}
static_assert(get_max_power_of_2(1) == 1, "Test failed.");
static_assert(get_max_power_of_2(2) == 2, "Test failed.");
static_assert(get_max_power_of_2(4) == 4, "Test failed.");
static_assert(get_max_power_of_2(6) == 2, "Test failed.");
// Number of VKQ rows calculated in parallel:
constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
}
static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
template <int D, int cols_per_block, typename KQ_acc_t>
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
constexpr int nwarps = 4;
constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
}
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
if (prec != GGML_PREC_DEFAULT) {
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
constexpr int cols_per_block = 16;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
} else {
constexpr int cols_per_block = 32;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
break;
// case 256:
// ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
// break;
default:
GGML_ABORT("fatal error");
break;
}
}
return;
}
#if !defined(GGML_USE_HIP)
if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) {
constexpr int cols_per_block = 8;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
return;
}
#endif // !defined(GGML_USE_HIP)
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 16;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
return;
}
constexpr int cols_per_block = 32;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
}

View File

@@ -0,0 +1,51 @@
#pragma once
#include "common.cuh"
#if defined(GGML_USE_MUSA)
#define GGML_USE_WMMA_FATTN
#endif // defined(GGML_USE_MUSA)
#if defined(GGML_HIP_ROCWMMA_FATTN)
#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
#define GGML_USE_WMMA_FATTN
#elif defined(CDNA)
#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance"
#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
#if defined(RDNA3)
#define GGML_USE_WMMA_FATTN
#endif // defined(RDNA3)
#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
#define GGML_USE_WMMA_FATTN
#elif defined(RDNA4)
#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance"
#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
// WMMA flash attention requires FP16 matrix instructions to be available for ggml code.
static bool ggml_cuda_should_use_wmma_fattn(const int cc) {
#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
return false;
#else
if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) ||
GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
return true;
} else if (GGML_CUDA_CC_IS_CDNA(cc)){
#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
return true;
#else
return false;
#endif // defined(GGML_HIP_ROCWMMA_FATTN) (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
} else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
return true;
#else
return false;
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
} else {
return false;
}
#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
}
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

415
ggml/src/ggml-cuda/fattn.cu Normal file
View File

@@ -0,0 +1,415 @@
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-mma-f16.cuh"
#include "fattn-tile.cuh"
#include "fattn-vec.cuh"
#include "fattn-wmma-f16.cuh"
#include "fattn.cuh"
template <int DKQ, int DV, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const ggml_tensor * Q = dst->src[0];
if constexpr (ncols2 <= 8) {
if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
return;
}
}
if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
return;
}
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
return;
}
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
}
template <int DKQ, int DV>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
// Edge cases like no mask, ALiBi, unpadded K/V, or misaligned addresses for large data transfers
// are put into the template specialization without GQA optimizations.
bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
for (const ggml_tensor * t : {Q, K, V, mask}) {
if (t == nullptr) {
continue;
}
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
if (t->nb[i] % 16 != 0) {
use_gqa_opt = false;
break;
}
}
}
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
if (use_gqa_opt && gqa_ratio % 8 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 4 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 2 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
return;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
}
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
switch (Q->ne[0]) {
case 64:
GGML_ASSERT(V->ne[0] == 64);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst);
break;
case 80:
GGML_ASSERT(V->ne[0] == 80);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst);
break;
case 96:
GGML_ASSERT(V->ne[0] == 96);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst);
break;
case 112:
GGML_ASSERT(V->ne[0] == 112);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
break;
case 128:
GGML_ASSERT(V->ne[0] == 128);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
break;
case 256:
GGML_ASSERT(V->ne[0] == 256);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
break;
case 576: {
// For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
GGML_ASSERT(V->ne[0] == 512);
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
const bool use_gqa_opt = mask && max_bias == 0.0f;
GGML_ASSERT(use_gqa_opt);
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
GGML_ASSERT(gqa_ratio % 16 == 0);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
} break;
default:
GGML_ABORT("fatal error");
break;
}
}
#define FATTN_VEC_CASE(D, type_K, type_V) \
{ \
const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
return; \
} \
} \
#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
FATTN_VEC_CASE( 64, type_K, type_V) \
FATTN_VEC_CASE(128, type_K, type_V) \
FATTN_VEC_CASE(256, type_K, type_V) \
static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_tensor * Q = dst->src[0];
ggml_tensor * K = dst->src[1];
ggml_tensor * V = dst->src[2];
#ifdef GGML_CUDA_FA_ALL_QUANTS
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#else
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#endif // GGML_CUDA_FA_ALL_QUANTS
GGML_ABORT("fatal error");
}
// Best FlashAttention kernel for a specific GPU:
enum best_fattn_kernel {
BEST_FATTN_KERNEL_NONE = 0,
BEST_FATTN_KERNEL_TILE = 200,
BEST_FATTN_KERNEL_VEC = 100,
BEST_FATTN_KERNEL_WMMA_F16 = 300,
BEST_FATTN_KERNEL_MMA_F16 = 400,
};
static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
#ifndef FLASH_ATTN_AVAILABLE
GGML_UNUSED(device); GGML_UNUSED(dst);
return BEST_FATTN_KERNEL_NONE;
#endif// FLASH_ATTN_AVAILABLE
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
const int gqa_ratio = Q->ne[2] / K->ne[2];
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
// The effective batch size for the kernel can be increased by gqa_ratio.
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
for (const ggml_tensor * t : {Q, K, V, mask}) {
if (t == nullptr) {
continue;
}
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
if (t->nb[i] % 16 != 0) {
gqa_opt_applies = false;
break;
}
}
}
const int cc = ggml_cuda_info().devices[device].cc;
switch (K->ne[0]) {
case 40:
case 64:
case 72:
case 80:
case 96:
case 128:
case 112:
case 256:
if (V->ne[0] != K->ne[0]) {
return BEST_FATTN_KERNEL_NONE;
}
break;
case 576:
if (V->ne[0] != 512) {
return BEST_FATTN_KERNEL_NONE;
}
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
return BEST_FATTN_KERNEL_NONE;
}
break;
default:
return BEST_FATTN_KERNEL_NONE;
}
#ifndef GGML_CUDA_FA_ALL_QUANTS
if (K->type != V->type) {
return BEST_FATTN_KERNEL_NONE;
}
#endif // GGML_CUDA_FA_ALL_QUANTS
switch (K->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
break;
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
#ifndef GGML_CUDA_FA_ALL_QUANTS
return BEST_FATTN_KERNEL_NONE;
#endif // GGML_CUDA_FA_ALL_QUANTS
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
break;
default:
return BEST_FATTN_KERNEL_NONE;
}
if (mask && mask->ne[2] != 1) {
return BEST_FATTN_KERNEL_NONE;
}
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
// If Turing tensor cores are available, use them:
if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
if (can_use_vector_kernel) {
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
return BEST_FATTN_KERNEL_VEC;
}
} else {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
if (Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
} else {
if (Q->ne[1] == 1) {
return BEST_FATTN_KERNEL_VEC;
}
}
}
if (!gqa_opt_applies && Q->ne[1] == 1) {
return BEST_FATTN_KERNEL_VEC;
}
}
return BEST_FATTN_KERNEL_MMA_F16;
}
if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
int gqa_ratio_eff = 1;
const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
gqa_ratio_eff *= 2;
}
if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
if (Q->ne[1] * gqa_ratio_eff <= 16) {
return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices.
}
return BEST_FATTN_KERNEL_MMA_F16;
}
// Use the WMMA kernel if possible:
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
if (can_use_vector_kernel && Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
return BEST_FATTN_KERNEL_WMMA_F16;
}
if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) {
if (can_use_vector_kernel) {
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
if (Q->ne[1] == 1) {
if (!gqa_opt_applies) {
return BEST_FATTN_KERNEL_VEC;
}
}
} else {
if (Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
}
}
int gqa_ratio_eff = 1;
const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
gqa_ratio_eff *= 2;
}
if (Q->ne[1] * gqa_ratio_eff <= 8) {
return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized.
}
return BEST_FATTN_KERNEL_MMA_F16;
}
// If there are no tensor cores available, use the generic tile kernel:
if (can_use_vector_kernel) {
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
if (Q->ne[1] == 1) {
if (!gqa_opt_applies) {
return BEST_FATTN_KERNEL_VEC;
}
}
} else {
if (Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
}
}
return BEST_FATTN_KERNEL_TILE;
}
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_set_device(ctx.device);
switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
case BEST_FATTN_KERNEL_NONE:
GGML_ABORT("fatal error");
case BEST_FATTN_KERNEL_TILE:
ggml_cuda_flash_attn_ext_tile(ctx, dst);
break;
case BEST_FATTN_KERNEL_VEC:
ggml_cuda_flash_attn_ext_vec(ctx, dst);
break;
case BEST_FATTN_KERNEL_WMMA_F16:
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
break;
case BEST_FATTN_KERNEL_MMA_F16:
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
break;
}
}
bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst);

View File

@@ -0,0 +1,37 @@
#include "fill.cuh"
#include "convert.cuh"
#define CUDA_FILL_BLOCK_SIZE 256
template <typename T>
static __global__ void fill_kernel(T * dst, const int64_t k, const T value) {
const int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = value;
}
void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(dst));
float value;
memcpy(&value, dst->op_params, sizeof(float));
const int64_t k = ggml_nelements(dst);
const int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE;
switch (dst->type) {
case GGML_TYPE_F32:
fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((float *)dst_d, k, value);
break;
case GGML_TYPE_F16:
fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((half *)dst_d, k, ggml_cuda_cast<half>(value));
break;
default:
GGML_ABORT("unsupported type");
}
}

View File

@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,286 @@
#include "getrows.cuh"
#include "dequantize.cuh"
#include "convert.cuh"
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows(
const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
/*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
const int i10 = blockIdx.x;
const int i11 = z / ne12; // TODO fastdiv
const int i12 = z % ne12;
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
const int ib = i00/qk; // block index
const int iqs = (i00%qk)/qr; // quant index
const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
float2 v;
dequantize_kernel(src0_row, ib, iqs, v);
dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
}
}
}
template<typename src0_t, typename dst_t>
static __global__ void k_get_rows_float(
const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
/*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
const int i10 = blockIdx.x;
const int i11 = z / ne12; // TODO fastdiv
const int i12 = z % ne12;
if (i00 >= ne00) {
return;
}
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
}
}
}
template<typename grad_t, typename dst_t>
static __global__ void k_get_rows_back_float(
const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) {
const int col = blockIdx.x*blockDim.x + threadIdx.x;
if (col >= ncols) {
return;
}
const int dst_row = blockIdx.y*blockDim.y + threadIdx.y;
float sum = 0.0f;
for (int64_t i = 0; i < nrows_grad; ++i) {
if (rows[i] != dst_row) {
continue;
}
sum += grad[i*ncols + col];
}
dst[dst_row*ncols + col] = sum;
}
template<int qk, int qr, dequantize_kernel_t dq, typename dst_t>
static void get_rows_cuda_q(
const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
const size_t nb1, const size_t nb2, const size_t nb3,
cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));
// strides in elements
// const size_t s0 = nb0 / sizeof(dst_t);
const size_t s1 = nb1 / sizeof(dst_t);
const size_t s2 = nb2 / sizeof(dst_t);
const size_t s3 = nb3 / sizeof(dst_t);
const size_t s10 = nb10 / sizeof(int32_t);
const size_t s11 = nb11 / sizeof(int32_t);
const size_t s12 = nb12 / sizeof(int32_t);
// const size_t s13 = nb13 / sizeof(int32_t);
GGML_ASSERT(ne00 % 2 == 0);
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
src0_d, src1_d, dst_d,
ne00, /*ne01, ne02, ne03,*/
/*ne10,*/ ne11, ne12, /*ne13,*/
/* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/);
}
template<typename src0_t, typename dst_t>
static void get_rows_cuda_float(
const src0_t * src0_d, const int32_t * src1_d, dst_t * dst_d,
const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
const size_t nb1, const size_t nb2, const size_t nb3,
cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));
// strides in elements
// const size_t s0 = nb0 / sizeof(dst_t);
const size_t s1 = nb1 / sizeof(dst_t);
const size_t s2 = nb2 / sizeof(dst_t);
const size_t s3 = nb3 / sizeof(dst_t);
const size_t s10 = nb10 / sizeof(int32_t);
const size_t s11 = nb11 / sizeof(int32_t);
const size_t s12 = nb12 / sizeof(int32_t);
// const size_t s13 = nb13 / sizeof(int32_t);
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
src0_d, src1_d, dst_d,
ne00, /*ne01, ne02, ne03,*/
/*ne10,*/ ne11, ne12, /*ne13,*/
/* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/);
}
template <typename dst_t>
static void ggml_cuda_get_rows_switch_src0_type(
const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
const size_t nb1, const size_t nb2, const size_t nb3,
cudaStream_t stream) {
switch (src0_type) {
case GGML_TYPE_F16:
get_rows_cuda_float((const half *) src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_F32:
get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_I32:
get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_BF16:
get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q4_0:
get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q4_1:
get_rows_cuda_q<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q5_0:
get_rows_cuda_q<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q5_1:
get_rows_cuda_q<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q8_0:
get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
default:
// TODO: k-quants
GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
break;
}
}
void get_rows_cuda(
const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type,
int64_t ne00, size_t nb01, size_t nb02, size_t nb03,
int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12,
size_t nb1, size_t nb2, size_t nb3,
cudaStream_t stream) {
switch (dst_type) {
case GGML_TYPE_F32:
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_I32:
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_F16:
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_BF16:
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (nv_bfloat16 *) dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
default:
GGML_ABORT("%s: unsupported dst type: %s\n", __func__, ggml_type_name(dst_type));
break;
}
}
void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
cudaStream_t stream = ctx.stream();
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(ne13 == 1);
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
get_rows_cuda(src0->data, src0->type, (const int32_t *) src1->data, dst->data, dst->type,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
}
void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass
GGML_TENSOR_BINARY_OP_LOCALS
const float * src0_d = (const float *) src0->data;
const int32_t * src1_d = (const int32_t *) src1->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ne02*ne03 == 1);
GGML_ASSERT(ne12*ne13 == 1);
GGML_ASSERT(ne2*ne3 == 1);
const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1);
const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE;
const dim3 block_nums(block_num_x, ne1, 1);
k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10);
}

View File

@@ -0,0 +1,15 @@
#include "common.cuh"
#define CUDA_GET_ROWS_BLOCK_SIZE 256
#define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256
void get_rows_cuda(
const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type,
int64_t ne00, size_t nb01, size_t nb02, size_t nb03,
int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12,
size_t nb1, size_t nb2, size_t nb3,
cudaStream_t stream);
void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

File diff suppressed because it is too large Load Diff

93
ggml/src/ggml-cuda/gla.cu Normal file
View File

@@ -0,0 +1,93 @@
#include "common.cuh"
#include "gla.cuh"
template<int HEAD_SIZE>
static __global__ void gated_linear_attn_f32(const int B, const int T, const int C, const int H, const float scale,
const float * k, const float * v, const float * r, const float * td, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int head_size = HEAD_SIZE;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
float state[head_size];
__shared__ float _k[head_size], _r[head_size], _td[head_size];
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
__syncthreads();
const float _v = v[t];
float y = 0;
for (int j = 0; j < head_size; j += 4) {
const float4 & k = (float4 &)(_k[j]);
const float4 & r = (float4 &)(_r[j]);
const float4 & td = (float4 &)(_td[j]);
float4 & s = (float4 &)(state[j]);
float4 kv;
kv.x = k.x * _v;
kv.y = k.y * _v;
kv.z = k.z * _v;
kv.w = k.w * _v;
s.x = s.x * td.x + kv.x;
s.y = s.y * td.y + kv.y;
s.z = s.z * td.z + kv.z;
s.w = s.w * td.w + kv.w;
y += r.x * s.x;
y += r.y * s.y;
y += r.z * s.z;
y += r.w * s.w;
}
dst[t] = y * scale;
}
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
}
void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * k_d = (const float *)dst->src[0]->data;
const float * v_d = (const float *)dst->src[1]->data;
const float * r_d = (const float *)dst->src[2]->data;
const float * td_d = (const float *)dst->src[3]->data;
const float * s_d = (const float *)dst->src[4]->data;
const int64_t B = dst->src[4]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
float scale;
memcpy(&scale, (float*)dst->op_params, sizeof(float));
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == 64 || C / H == 128);
if (C / H == 64) {
gated_linear_attn_f32<64><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
} else {
gated_linear_attn_f32<128><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
}
}

View File

@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,264 @@
#include "im2col.cuh"
#define MAX_GRIDDIM_Z 65535
template <typename T>
static __global__ void im2col_kernel(
const float * x, T * dst,
int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= IC_KH_KW) {
return;
}
const int64_t iic = i / (KH_KW);
const int64_t rem = i - iic * KH_KW;
const int64_t ikh = rem / KW;
const int64_t ikw = rem - ikh * KW;
const int64_t iow = blockIdx.y;
for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) {
const int64_t in = iz / OH;
const int64_t ioh = iz - in * OH;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t offset_dst =
((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
GGML_UNUSED(IC);
GGML_UNUSED(KH);
}
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
template <typename T>
static void im2col_cuda(const float * x, T* dst,
int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
const int64_t IC_KH_KW = IC * KH * KW;
const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
const int64_t N_OH = N * OH;
const int64_t KH_KW = KW*KH;
dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z));
im2col_kernel<<<block_nums, MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH,
IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW,
s0, s1, p0, p1, d0, d1);
}
static void im2col_cuda_f16(const float * x, half * dst,
int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
}
static void im2col_cuda_f32(const float * x, float * dst,
int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
}
void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
const int64_t IC = src1->ne[is_2D ? 2 : 1];
const int64_t IH = is_2D ? src1->ne[1] : 1;
const int64_t IW = src1->ne[0];
const int64_t KH = is_2D ? src0->ne[1] : 1;
const int64_t KW = src0->ne[0];
const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];
const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const int64_t N = src1->ne[is_2D ? 3 : 2];
const int64_t IH_IW = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
if(dst->type == GGML_TYPE_F16) {
im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
} else {
im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
}
}
// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
template <typename T>
static __global__ void im2col_3d_kernel(
const float * src, T * dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,
int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,
int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,
int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= IC_KD_KH_KW) {
return;
}
GGML_UNUSED(N); GGML_UNUSED(OC); GGML_UNUSED(OH_OW); GGML_UNUSED(OD); GGML_UNUSED(OW); GGML_UNUSED(KD); GGML_UNUSED(KH);
GGML_UNUSED(ID_IH_IW); GGML_UNUSED(IH_IW); GGML_UNUSED(IC_ID_IH_IW); GGML_UNUSED(OW_KD_KH_KW);
const int64_t iic = i / KD_KH_KW;
const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;
const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
const int64_t ikw = i % KW;
const int64_t iow = blockIdx.y;
for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) {
const int64_t in = iz / OD_OH;
const int64_t iod = (iz - in*OD_OH) / OH;
const int64_t ioh = iz % OH;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t iid = iod * s2 + ikd * d2 - p2;
const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);
dst[offset_dst] = src[offset_src];
}
}
}
// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
template <typename T>
static void im2col_3d_cuda(const float * src, T* dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
const int64_t OH_OW = OH*OW;
const int64_t KD_KH_KW = KD*KH*KW;
const int64_t ID_IH_IW = ID*IH*IW;
const int64_t KH_KW = KH*KW;
const int64_t IH_IW = IH*IW;
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
const int64_t OW_KD_KH_KW = OW*KD*KH*KW;
const int64_t N_OD_OH = N*OD*OH;
const int64_t OD_OH = OD*OH;
const int64_t IC_ID_IH_IW = IC*ID*IH*IW;
const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z));
im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,
stride_q, stride_z, stride_y, stride_x,
s0, s1, s2, p0, p1, p2, d0, d1, d2);
}
static void im2col_3d_cuda_f16(const float * src, half * dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
stride_q, stride_z, stride_y, stride_x,
s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
}
static void im2col_3d_cuda_f32(const float * src, float * dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
stride_q, stride_z, stride_y, stride_x,
s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
}
void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
const int64_t N = ne13 / IC;
const int64_t ID = ne12;
const int64_t IH = ne11;
const int64_t IW = ne10;
const int64_t OC = ne03 / IC;
const int64_t KD = ne02;
const int64_t KH = ne01;
const int64_t KW = ne00;
const int64_t OD = ne3 / N;
const int64_t OH = ne2;
const int64_t OW = ne1;
const size_t es = ggml_element_size(src1);
const int64_t stride_x = src1->nb[0] / es;
const int64_t stride_y = src1->nb[1] / es;
const int64_t stride_z = src1->nb[2] / es;
const int64_t stride_q = src1->nb[3] / es;
if(dst->type == GGML_TYPE_F16) {
im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
stride_q, stride_z, stride_y, stride_x,
s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
} else {
im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
stride_q, stride_z, stride_y, stride_x,
s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
}
}

View File

@@ -0,0 +1,6 @@
#include "common.cuh"
#define CUDA_IM2COL_BLOCK_SIZE 256
void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,74 @@
#include "mean.cuh"
#include "reduce_rows.cuh"
#ifdef GGML_CUDA_USE_CUB
#include <cub/cub.cuh>
using namespace cub;
#endif // GGML_CUDA_USE_CUB
template <typename T> __global__ void divide_by_count(T * result, size_t count) {
*result /= static_cast<T>(count);
}
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
// Special case for reducing vectors
#ifdef GGML_CUDA_USE_CUB
#ifdef USE_CUDA_GRAPH
cudaStreamCaptureStatus iscapturing;
CUDA_CHECK(cudaStreamIsCapturing(stream, &iscapturing));
#endif // USE_CUDA_GRAPH
if ((nrows == 1) &&
#ifdef USE_CUDA_GRAPH
// CUDA_GRAPHS_DISABLED
((ncols > 65536) &&
((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.cuda_graph->is_enabled())) ||
// CUDA_GRAPHS ENABLED
((ncols > 32768) &&
!((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.cuda_graph->is_enabled()))) {
#else
(ncols > 65536)) {
#endif // USE_CUDA_GRAPH
// Single row - use device-wide reduction
size_t tmp_size = 0;
ggml_cuda_pool & pool = ctx.pool();
DeviceReduce::Sum(nullptr, tmp_size, src0_d, dst_d, ncols, stream);
ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, src0_d, dst_d, ncols, stream);
// Divide by ncols
divide_by_count<float><<<1, 1, 0, stream>>>(dst_d, ncols);
return;
}
#endif // GGML_CUDA_USE_CUB
const dim3 block_nums(nrows, 1, 1);
const int id = ggml_cuda_get_device();
const int nsm = ggml_cuda_info().devices[id].nsm;
// Heuristic for block size selection to optimize occupancy.
// See discussion in: https://github.com/ggml-org/llama.cpp/pull/15132
if ((nrows / nsm) < 2) {
const dim3 block_dims(512, 1, 1);
reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
} else {
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}
}

View File

@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

1287
ggml/src/ggml-cuda/mma.cuh Normal file

File diff suppressed because it is too large Load Diff

171
ggml/src/ggml-cuda/mmf.cu Normal file
View File

@@ -0,0 +1,171 @@
#include "ggml.h"
#include "mmf.cuh"
#include "mmid.cuh"
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS;
const size_t ts_src0 = ggml_type_size(src0->type);
const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type);
GGML_ASSERT(ne13 == ne3);
GGML_ASSERT( nb00 == ts_src0);
GGML_ASSERT( nb10 == ts_src1);
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
GGML_ASSERT( nb0 == ts_dst);
const float * src1_d = (const float *) src1->data;
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
float * dst_d = (float *) dst->data;
const int64_t s01 = src0->nb[1] / ts_src0;
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s1 = dst->nb[1] / ts_dst;
const int64_t s02 = src0->nb[2] / ts_src0;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s2 = dst->nb[2] / ts_dst;
const int64_t s03 = src0->nb[3] / ts_src0;
const int64_t s13 = src1->nb[3] / ts_src1;
const int64_t s3 = dst->nb[3] / ts_dst;
const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
mmf_ids_data ids_info{};
mmf_ids_data * ids_info_ptr = nullptr;
ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_dst = ids ? ne1 : ne2;
const int64_t stride_col_dst = ids ? s2 : s1;
const int64_t stride_col_y = ids ? s12 : s11;
const int64_t stride_channel_dst = ids ? s1 : s2;
int64_t stride_channel_y = ids ? s11 : s12;
int64_t nchannels_y = ids ? ne11 : ne12;
//mul_mat_id: handle broadcast
if (ids && nchannels_y == 1) {
stride_channel_y = 0;
nchannels_y = ids->ne[0];
}
if (ids && ncols_dst > 16) {
const int64_t n_expert_used = ids->ne[0];
const int64_t n_experts = ne02;
const int64_t n_tokens = ne12;
const int64_t ne_get_rows = n_tokens * n_expert_used;
ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);
const int si1 = static_cast<int>(ids_s1);
const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);
GGML_ASSERT(sis1 > 0);
ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
CUDA_CHECK(cudaGetLastError());
ids_info.ids_src_compact = ids_src_compact_dev.get();
ids_info.ids_dst_compact = ids_dst_compact_dev.get();
ids_info.expert_bounds_dev = expert_bounds_dev.get();
ids_info.n_experts = static_cast<int>(n_experts);
ids_info.sis1 = sis1;
ids_info_ptr = &ids_info;
}
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
constexpr int vals_per_T = 1;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break;
case GGML_TYPE_F16: {
const half2 * src0_d = (const half2 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
}
}
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne,
const size_t * src0_nb, const int src1_ncols, bool mul_mat_id) {
if (ggml_is_quantized(type)) {
return false;
}
const size_t ts = ggml_type_size(type);
if (src0_ne[0] % (warp_size * (4/ts)) != 0) {
return false;
}
if (src0_nb[0] != ts) {
return false;
}
// Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
if (src0_nb[i] % (2*ts) != 0) {
return false;
}
}
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
return false;
}
if (mul_mat_id) {
if (src0_ne[1] <= 1024 && src1_ncols > 512) {
return false;
} else if(src0_ne[1] > 1024 && src1_ncols > 128) {
return false;
}
} else {
if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) {
return false;
} else if (src1_ncols > 16) {
return false;
}
}
switch (type) {
case GGML_TYPE_F32:
return ampere_mma_available(cc);
case GGML_TYPE_F16:
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
case GGML_TYPE_BF16:
return ampere_mma_available(cc) || amd_wmma_available(cc);
default:
return false;
}
}

835
ggml/src/ggml-cuda/mmf.cuh Normal file
View File

@@ -0,0 +1,835 @@
#pragma once
#include "mma.cuh"
#include "common.cuh"
#include "convert.cuh"
using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32
struct mmf_ids_data {
const int32_t * ids_src_compact = nullptr;
const int32_t * ids_dst_compact = nullptr;
const int32_t * expert_bounds_dev = nullptr;
int n_experts = 0;
int sis1 = 0;
};
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const size_t * src0_nb, const int src1_ncols, bool mul_mat_id);
template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int stride_col_id, const int stride_row_id,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
constexpr bool is_tf32 = std::is_same_v<T, float>;
constexpr int tile_B_I = is_tf32 ? 8 : 16;
constexpr int tile_C_J = is_tf32 ? 8 : 16;
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
typedef tile<16, 8, T, ab_layout> tile_A;
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
#else
typedef tile<16, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
#endif // VOLTA_MMA_AVAILABLE
#endif // defined(AMD_WMMA_AVAILABLE)
if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
NO_DEVICE_CODE;
return;
}
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
const int row0 = blockIdx.x * rows_per_block;
int expert_idx = 0;
int col_base = 0;
const int channel_dst = has_ids ? 0 : blockIdx.y;
if constexpr (has_ids) {
// experts + tiles of ncols_dst are packed in the y dimension
int col_tiles = (ncols_dst_total + cols_per_block - 1) / cols_per_block;
const int nchannels_x = gridDim.y / col_tiles;
const int tile_idx = blockIdx.y / nchannels_x;
expert_idx = blockIdx.y - tile_idx * nchannels_x;
col_base = tile_idx * cols_per_block;
}
const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
const int channel_y = channel_dst;
const int sample_dst = blockIdx.z;
const int sample_x = sample_dst / sample_ratio;
const int sample_y = sample_dst;
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y);
dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
if constexpr (has_ids) {
constexpr int y_stride_scale = std::is_same_v<T, float> ? 1 : 2;
const int64_t col_offset = col_base;
y += col_offset * stride_col_y * y_stride_scale;
dst += col_offset * stride_col_dst;
ids += col_offset * stride_row_id;
}
const float2 * y2 = (const float2 *) y;
extern __shared__ char data_mmv[];
char * shmem_base = data_mmv;
int * slot_map = (int *) shmem_base;
char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base;
tile_C C[ntA][ntB];
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
if constexpr (has_ids) {
int found = 0;
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (threadIdx.x == 0) {
slot_map[j] = -1;
}
if (col_base + j >= ncols_dst_total) {
continue;
}
const int32_t * __restrict__ id_row = ids + j*stride_row_id;
for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
int match = id_row[k*stride_col_id] == expert_idx;
if (match) {
slot_map[j] = k;
found = 1;
break;
}
}
}
if (!__syncthreads_or(found)) {
return;
}
}
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int i = 0; i < tile_A::I; ++i) {
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
}
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
if constexpr (std::is_same_v<T, float>) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + itB*tile_B::I;
if constexpr (!has_ids) {
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
} else {
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
tile_xy[j0*tile_k_padded + threadIdx.x] = valid ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
}
}
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + itB*tile_B::I;
if constexpr (!has_ids) {
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
} else {
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
}
}
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
}
}
}
}
float * buf_iw = (float *) compute_base;
constexpr int kiw = nwarps*rows_per_block + 4;
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
const int j = itB*tile_C::J + tile_C::get_j(l);
buf_iw[j*kiw + i] = C[itA][itB].x[l];
}
}
}
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
sum += buf_iw[j*kiw + i];
}
if constexpr (!has_ids) {
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
} else {
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
if (slot >= 0 && (col_base + j) < ncols_dst_total) {
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
}
}
}
#ifdef VOLTA_MMA_AVAILABLE
}
#endif //VOLTA_MMA_AVAILABLE
#else
GGML_UNUSED_VARS(x, y, ids, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
NO_DEVICE_CODE;
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
}
//This kernel is for larger batch sizes of mul_mat_id
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f_ids(
const T * __restrict__ x, const float * __restrict__ y,
const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const uint3 sis1_fd, const uint3 nch_fd) {
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
constexpr bool is_tf32 = std::is_same_v<T, float>;
constexpr int tile_B_I = is_tf32 ? 8 : 16;
constexpr int tile_C_J = is_tf32 ? 8 : 16;
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
typedef tile<16, 8, T, ab_layout> tile_A;
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
#else
typedef tile<16, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
#endif // VOLTA_MMA_AVAILABLE
#endif // defined(AMD_WMMA_AVAILABLE)
if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
NO_DEVICE_CODE;
return;
}
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
const int row0 = blockIdx.x * rows_per_block;
const int expert_idx = blockIdx.y;
const int expert_start = expert_bounds[expert_idx];
const int expert_end = expert_bounds[expert_idx + 1];
const int ncols_expert = expert_end - expert_start;
const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;
const int tile_idx = blockIdx.z;
if (tile_idx >= tiles_for_expert) {
return;
}
const int col_base = tile_idx * cols_per_block;
GGML_UNUSED(channel_ratio);
const int channel_x = expert_idx;
const int sample_dst = 0;
const int sample_x = sample_dst / sample_ratio;
const int sample_y = sample_dst;
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row;
y += int64_t(sample_y) *stride_sample_y;
dst += int64_t(sample_dst)*stride_sample_dst;
const int32_t * ids_src_expert = ids_src_compact + expert_start;
const int32_t * ids_dst_expert = ids_dst_compact + expert_start;
extern __shared__ char data_mmv[];
char * compute_base = data_mmv;
//const float2 * y2 = (const float2 *) y;
tile_C C[ntA][ntB];
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int i = 0; i < tile_A::I; ++i) {
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
}
}
if constexpr (std::is_same_v<T, float>) {
float vals_buf[2][tile_B::I];
auto gather_tile = [&](int tile_idx_local, float *vals) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + tile_idx_local*tile_B::I;
const int global_j = col_base + j;
float val = 0.0f;
if (j < cols_per_block && global_j < ncols_expert) {
const int src_entry = ids_src_expert[global_j];
const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
const int token = (int) qrm.x;
const int channel = (int) qrm.y;
if (token < ncols_dst_total) {
val = y[channel*stride_channel_y + token*stride_col_y + col];
}
}
vals[j0] = val;
}
};
gather_tile(0, vals_buf[0]);
int curr_buf = 0;
int next_buf = 1;
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
}
if (itB + 1 < ntB) {
gather_tile(itB + 1, vals_buf[next_buf]);
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
}
}
if (itB + 1 < ntB) {
curr_buf ^= 1;
next_buf ^= 1;
}
}
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
float2 vals_buf[2][tile_B::I];
auto gather_tile = [&](int tile_idx_local, float2 *vals) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + tile_idx_local*tile_B::I;
const int global_j = col_base + j;
float2 tmp = make_float2(0.0f, 0.0f);
if (j < cols_per_block && global_j < ncols_expert) {
const int src_entry = ids_src_expert[global_j];
const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
const int token = (int) qrm.x;
const int channel = (int) qrm.y;
if (token < ncols_dst_total) {
tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
}
}
vals[j0] = tmp;
}
};
if (ntB > 0) {
gather_tile(0, vals_buf[0]);
}
int curr_buf = 0;
int next_buf = 1;
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const float2 tmp = vals_buf[curr_buf][j0];
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
}
if (itB + 1 < ntB) {
gather_tile(itB + 1, vals_buf[next_buf]);
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
}
}
if (itB + 1 < ntB) {
curr_buf ^= 1;
next_buf ^= 1;
}
}
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
}
float * buf_iw = (float *) compute_base;
constexpr int kiw = nwarps*rows_per_block + 4;
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
const int j = itB*tile_C::J + tile_C::get_j(l);
buf_iw[j*kiw + i] = C[itA][itB].x[l];
}
}
}
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
sum += buf_iw[j*kiw + i];
}
const int global_j = col_base + j;
if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {
const int dst_entry = ids_dst_expert[global_j];
const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd);
const int token = (int) qrm.x;
if (token < ncols_dst_total) {
const int slot = (int) qrm.y;
dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
}
}
}
#ifdef VOLTA_MMA_AVAILABLE
}
#endif // VOLTA_MMA_AVAILABLE
#else
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
NO_DEVICE_CODE;
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
}
template<typename T, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t stride_col_id, const int64_t stride_row_id,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,
const mmf_ids_data * ids_data) {
const bool has_ids_data = ids_data && ids_data->ids_src_compact;
// Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16)
// we prefer the normal mul_mat_f path with has_ids=true.
if (has_ids_data && ncols_dst > 16) {
const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);
if (max_tiles == 0) {
return;
}
dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
sis1_fd, nch_fd);
} else if (ids) {
const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
dim3 block_nums_ids = block_nums;
block_nums_ids.y *= col_tiles;
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} else {
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}
}
template <typename T, int cols_per_block>
void mul_mat_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t stride_col_id, const int64_t stride_row_id,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream, const mmf_ids_data * ids_data) {
typedef tile<16, 8, T> tile_A_16;
typedef tile<32, 8, T> tile_A_32;
typedef tile<16, 8, T> tile_B_16;
typedef tile< 8, 8, T> tile_B_8;
GGML_ASSERT(ncols_x % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const int64_t sample_ratio = nsamples_dst / nsamples_x;
const int device = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[device].cc;
const int warp_size = ggml_cuda_info().devices[device].warp_size;
int64_t nwarps_best = 1;
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
int64_t max_block_size = 256;
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
if (niter < niter_best) {
niter_best = niter;
nwarps_best = nwarps;
}
}
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
const int64_t grid_y = ids ? nchannels_x : nchannels_dst;
const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
const dim3 block_dims(warp_size, nwarps_best, 1);
switch (nwarps_best) {
case 1: {
mul_mat_f_switch_ids<T, cols_per_block, 1>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 2: {
mul_mat_f_switch_ids<T, cols_per_block, 2>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 3: {
mul_mat_f_switch_ids<T, cols_per_block, 3>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 4: {
mul_mat_f_switch_ids<T, cols_per_block, 4>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 5: {
mul_mat_f_switch_ids<T, cols_per_block, 5>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 6: {
mul_mat_f_switch_ids<T, cols_per_block, 6>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 7: {
mul_mat_f_switch_ids<T, cols_per_block, 7>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 8: {
mul_mat_f_switch_ids<T, cols_per_block, 8>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
GGML_UNUSED_VARS(nchannels_y);
}
template <typename T>
static void mul_mat_f_switch_cols_per_block(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t stride_col_id, const int stride_row_id,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream, const mmf_ids_data * ids_data) {
const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
GGML_ASSERT(ids || ncols_dst <= 16);
switch (ncols_case) {
case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
template void mul_mat_f_cuda<T, ncols_dst>( \
const T * x, const float * y, const int32_t * ids, float * dst, \
const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
const int64_t stride_col_id, const int64_t stride_row_id, \
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
cudaStream_t stream, const mmf_ids_data * ids_data);
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
#define DECL_MMF_CASE(ncols_dst) \
DECL_MMF_CASE_HELPER(float, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
DECL_MMF_CASE_EXTERN(1);
DECL_MMF_CASE_EXTERN(2);
DECL_MMF_CASE_EXTERN(3);
DECL_MMF_CASE_EXTERN(4);
DECL_MMF_CASE_EXTERN(5);
DECL_MMF_CASE_EXTERN(6);
DECL_MMF_CASE_EXTERN(7);
DECL_MMF_CASE_EXTERN(8);
DECL_MMF_CASE_EXTERN(9);
DECL_MMF_CASE_EXTERN(10);
DECL_MMF_CASE_EXTERN(11);
DECL_MMF_CASE_EXTERN(12);
DECL_MMF_CASE_EXTERN(13);
DECL_MMF_CASE_EXTERN(14);
DECL_MMF_CASE_EXTERN(15);
DECL_MMF_CASE_EXTERN(16);
#else
#define DECL_MMF_CASE(ncols_dst)
#endif

164
ggml/src/ggml-cuda/mmid.cu Normal file
View File

@@ -0,0 +1,164 @@
#include "common.cuh"
#include "mmid.cuh"
// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
struct mm_ids_helper_store {
uint32_t data;
__device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
data = (it & 0x003FFFFF) | (iex_used << 22);
}
__device__ uint32_t it() const {
return data & 0x003FFFFF;
}
__device__ uint32_t iex_used() const {
return data >> 22;
}
};
static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store");
// Helper function for mul_mat_id, converts ids to a more convenient format.
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
// ids_dst describes the same mapping but for the dst tensor.
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
template <int n_expert_used_template>
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mm_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
const int expert = blockIdx.x;
extern __shared__ char data_mm_ids_helper[];
mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper;
int nex_prev = 0; // Number of columns for experts with a lower index.
int it_compact = 0; // Running index for the compact slice of this expert.
if constexpr (n_expert_used_template == 0) {
// Generic implementation:
for (int it = 0; it < n_tokens; ++it) {
int iex_used = -1; // The index at which the expert is used, if any.
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
const int expert_used = ids[it*si1 + iex];
nex_prev += expert_used < expert;
if (expert_used == expert) {
iex_used = iex;
}
}
if (iex_used != -1) {
store[it_compact] = mm_ids_helper_store(it, iex_used);
}
if (warp_reduce_any<warp_size>(iex_used != -1)) {
it_compact++;
}
}
} else {
// Implementation optimized for specific numbers of experts used:
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
const int it = it0 + threadIdx.x / neu_padded;
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
ids[it*si1 + iex] : INT_MAX;
const int iex_used = expert_used == expert ? iex : -1;
nex_prev += expert_used < expert;
// Whether the threads at this token position have used the expert:
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
int it_compact_add_lower = 0;
#pragma unroll
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
if (threadIdx.x >= static_cast<unsigned int>(offset)) {
it_compact_add_lower += tmp;
}
}
if (iex_used != -1) {
store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used);
}
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
}
}
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
const mm_ids_helper_store store_it = store[itc];
const int it = store_it.it();
const int iex_used = store_it.iex_used();
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
}
if (threadIdx.x != 0) {
return;
}
expert_bounds[expert] = nex_prev;
if (expert < static_cast<int>(gridDim.x) - 1) {
return;
}
expert_bounds[gridDim.x] = nex_prev + it_compact;
}
template <int n_expert_used_template>
static void launch_mm_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mm_ids_helper_store");
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store");
const int id = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[id].warp_size;
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper<n_expert_used_template>, smpbo);
const dim3 num_blocks(n_experts, 1, 1);
const dim3 block_size(warp_size, 1, 1);
const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store);
GGML_ASSERT(nbytes_shared <= smpbo);
mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}
void ggml_cuda_launch_mm_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
switch (n_expert_used) {
case 2:
launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 4:
launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 6:
launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 8:
launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 16:
launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 32:
launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
default:
launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
}
}

View File

@@ -0,0 +1,5 @@
#pragma once
void ggml_cuda_launch_mm_ids_helper(
const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);

366
ggml/src/ggml-cuda/mmq.cu Normal file
View File

@@ -0,0 +1,366 @@
#include "common.cuh"
#include "mmq.cuh"
#include "quantize.cuh"
#include "mmid.cuh"
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
switch (args.type_x) {
case GGML_TYPE_Q4_0:
mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
break;
case GGML_TYPE_Q4_1:
mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
break;
case GGML_TYPE_Q5_0:
mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
break;
case GGML_TYPE_Q5_1:
mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
break;
case GGML_TYPE_Q8_0:
mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
break;
case GGML_TYPE_MXFP4:
mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
break;
case GGML_TYPE_Q3_K:
mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
break;
case GGML_TYPE_Q4_K:
mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
break;
case GGML_TYPE_Q5_K:
mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
break;
case GGML_TYPE_Q6_K:
mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
break;
case GGML_TYPE_IQ2_XXS:
mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
break;
case GGML_TYPE_IQ2_XS:
mul_mat_q_case<GGML_TYPE_IQ2_XS>(ctx, args, stream);
break;
case GGML_TYPE_IQ2_S:
mul_mat_q_case<GGML_TYPE_IQ2_S>(ctx, args, stream);
break;
case GGML_TYPE_IQ3_XXS:
mul_mat_q_case<GGML_TYPE_IQ3_XXS>(ctx, args, stream);
break;
case GGML_TYPE_IQ3_S:
mul_mat_q_case<GGML_TYPE_IQ3_S>(ctx, args, stream);
break;
case GGML_TYPE_IQ1_S:
mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream);
break;
case GGML_TYPE_IQ4_XS:
mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
break;
case GGML_TYPE_IQ4_NL:
mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
break;
default:
GGML_ABORT("fatal error");
break;
}
}
void ggml_cuda_mul_mat_q(
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
GGML_TENSOR_BINARY_OP_LOCALS;
cudaStream_t stream = ctx.stream();
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const size_t ts_src0 = ggml_type_size(src0->type);
const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type);
GGML_ASSERT( nb00 == ts_src0);
GGML_ASSERT( nb10 == ts_src1);
GGML_ASSERT( nb0 == ts_dst);
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
const char * src0_d = (const char *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
// If src0 is a temporary compute buffer, clear any potential padding.
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
const size_t size_data = ggml_nbytes(src0);
const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
if (size_alloc > size_data) {
GGML_ASSERT(ggml_is_contiguously_allocated(src0));
GGML_ASSERT(!src0->view_src);
CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
}
}
const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
const int64_t s01 = src0->nb[1] / ts_src0;
const int64_t s1 = dst->nb[1] / ts_dst;
const int64_t s02 = src0->nb[2] / ts_src0;
const int64_t s2 = dst->nb[2] / ts_dst;
const int64_t s03 = src0->nb[3] / ts_src0;
const int64_t s3 = dst->nb[3] / ts_dst;
const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| GGML_CUDA_CC_IS_CDNA(cc);
// TODO: tighter pool buffer size vs q8 path
const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
if (!ids) {
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
{
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
if (use_native_mxfp4) {
static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);
} else {
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);
}
CUDA_CHECK(cudaGetLastError());
}
// Stride depends on quantization format
const int64_t s12 = use_native_mxfp4 ?
ne11 * ne10_padded * sizeof(block_fp4_mmq) /
(8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32)
:
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
const int64_t s13 = ne12*s12;
const mmq_args args = {
src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
ne00, ne01, ne1, s01, ne11, s1,
ne02, ne12, s02, s12, s2,
ne03, ne13, s03, s13, s3,
use_stream_k, ne1};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
return;
}
GGML_ASSERT(ne13 == 1);
GGML_ASSERT(nb12 % nb11 == 0);
GGML_ASSERT(nb2 % nb1 == 0);
const int64_t n_expert_used = ids->ne[0];
const int64_t ne_get_rows = ne12 * n_expert_used;
GGML_ASSERT(ne1 == n_expert_used);
ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
{
GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
const int si1 = ids->nb[1] / ggml_element_size(ids);
const int sis1 = nb12 / nb11;
ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
CUDA_CHECK(cudaGetLastError());
}
const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
const int64_t ne11_flat = ne12*n_expert_used;
const int64_t ne12_flat = 1;
const int64_t ne13_flat = 1;
{
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
if (use_native_mxfp4) {
quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
} else {
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
}
CUDA_CHECK(cudaGetLastError());
}
const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
const int64_t s13 = ne12*s12;
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
const mmq_args args = {
src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
ne02, ne02, s02, s12, s2,
ne03, ne13, s03, s13, s3,
use_stream_k, ne12};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
}
void ggml_cuda_op_mul_mat_q(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {
const int64_t ne00 = src0->ne[0];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
GGML_ASSERT(ne10 % QK8_1 == 0);
const int64_t ne0 = dst->ne[0];
const int64_t row_diff = row_high - row_low;
const int64_t stride01 = ne00 / ggml_blck_size(src0->type);
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
// the main device has a larger memory buffer to hold the results from all GPUs
// nrows_dst == nrows of the matrix that the kernel writes into
const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
// The stream-k decomposition is only faster for recent NVIDIA GPUs.
// Also its fixup needs to allocate a temporary buffer in the memory pool.
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| GGML_CUDA_CC_IS_CDNA(cc))
&& src1_ncols == ne11;
const mmq_args args = {
src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
1, 1, 0, 0, 0,
1, 1, 0, 0, 0,
use_stream_k, src1_ncols};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_padded_row_size);
}
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts) {
#ifdef GGML_CUDA_FORCE_CUBLAS
return false;
#endif // GGML_CUDA_FORCE_CUBLAS
bool mmq_supported;
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
mmq_supported = true;
break;
default:
mmq_supported = false;
break;
}
if (!mmq_supported) {
return false;
}
if (turing_mma_available(cc)) {
return true;
}
if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) {
return false;
}
#ifdef GGML_CUDA_FORCE_MMQ
return true;
#endif //GGML_CUDA_FORCE_MMQ
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
if (amd_mfma_available(cc)) {
// As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT)
// performs better but is currently suffering from a crash on this architecture.
// TODO: Revisit when hipblaslt is fixed on CDNA3
if (GGML_CUDA_CC_IS_CDNA3(cc)) {
return true;
}
if (n_experts > 64 || ne11 <= 128) {
return true;
}
if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
return true;
}
if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {
return true;
}
return false;
}
if (amd_wmma_available(cc)) {
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
// High expert counts are almost always better on MMQ due to
// the synchronization overhead in the cuBLAS/hipBLAS path:
// https://github.com/ggml-org/llama.cpp/pull/18202
if (n_experts >= 64) {
return true;
}
// For some quantization types MMQ can have lower peak TOPS than hipBLAS
// so it's only faster for sufficiently small batch sizes:
switch (type) {
case GGML_TYPE_Q2_K:
return ne11 <= 128;
case GGML_TYPE_Q6_K:
return ne11 <= (GGML_CUDA_CC_IS_RDNA3_0(cc) ? 128 : 256);
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
return GGML_CUDA_CC_IS_RDNA3_5(cc) || ne11 <= 128;
default:
return true;
}
}
// For RDNA4 MMQ is consistently faster than dequantization + hipBLAS:
// https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301
return true;
}
return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}

4085
ggml/src/ggml-cuda/mmq.cuh Normal file

File diff suppressed because it is too large Load Diff

802
ggml/src/ggml-cuda/mmvf.cu Normal file
View File

@@ -0,0 +1,802 @@
#include "ggml.h"
#include "common.cuh"
#include "unary.cuh"
#include "mmvf.cuh"
#include "convert.cuh"
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
static __global__ void mul_mat_vec_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
const int row = blockIdx.x;
const int channel_dst = blockIdx.y;
const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
const int sample_dst = blockIdx.z;
const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
const int sample_y = sample_dst;
const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
bool use_gate = false;
bool use_bias = false;
bool use_gate_bias = false;
ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;
const T * gate_x = nullptr;
const float * x_bias = nullptr;
const float * gate_bias = nullptr;
if constexpr (has_fusion) {
use_gate = fusion.gate != nullptr;
use_bias = fusion.x_bias != nullptr;
use_gate_bias = fusion.gate_bias != nullptr;
glu_op = fusion.glu_op;
if (use_gate) {
gate_x = static_cast<const T *>(fusion.gate);
}
if (use_bias) {
x_bias = static_cast<const float *>(fusion.x_bias);
}
if (use_gate_bias) {
gate_bias = static_cast<const float *>(fusion.gate_bias);
use_gate_bias = use_gate;
} else {
use_gate_bias = false;
}
}
if (use_gate) {
gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
}
if constexpr (has_fusion) {
const int channel_bias = ids ? channel_x : channel_dst;
if (use_bias) {
x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
}
if (use_gate_bias) {
gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
}
}
const float2 * y2 = (const float2 *) y;
extern __shared__ char data_mmv[];
float * buf_iw = (float *) data_mmv;
float * buf_iw_gate = nullptr;
if constexpr (has_fusion) {
buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));
}
if (block_size > warp_size) {
if (tid < warp_size) {
buf_iw[tid] = 0.0f;
if constexpr (has_fusion) {
if (use_gate) {
buf_iw_gate[tid] = 0.0f;
}
}
}
__syncthreads();
}
float sumf[ncols_dst] = {0.0f};
float sumf_gate[ncols_dst];
if constexpr (has_fusion) {
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
sumf_gate[j] = 0.0f;
}
}
if constexpr (std::is_same_v<T, float>) {
const float2 * x2 = (const float2 *) x;
const float2 * gate_x2 = nullptr;
if constexpr (has_fusion) {
if (use_gate) {
gate_x2 = (const float2 *) gate_x;
}
}
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = x2[col2];
float2 tmpx_gate = make_float2(0.0f, 0.0f);
if constexpr (has_fusion) {
if (use_gate) {
tmpx_gate = gate_x2[col2];
}
}
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
if constexpr (has_fusion) {
if (use_gate) {
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
}
}
}
}
} else if constexpr (std::is_same_v<T, half>) {
const half2 * x2 = (const half2 *) x;
const half2 * gate_x2 = nullptr;
if constexpr (has_fusion) {
if (use_gate) {
gate_x2 = (const half2 *) gate_x;
}
}
if (std::is_same_v<type_acc, float>) {
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = __half22float2(x2[col2]);
float2 tmpx_gate = make_float2(0.0f, 0.0f);
if constexpr (has_fusion) {
if (use_gate) {
tmpx_gate = __half22float2(gate_x2[col2]);
}
}
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
if constexpr (has_fusion) {
if (use_gate) {
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
}
}
}
}
} else {
#ifdef FP16_AVAILABLE
half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const half2 tmpx = x2[col2];
half2 tmpx_gate = make_half2(0.0f, 0.0f);
if constexpr (has_fusion) {
if (use_gate) {
tmpx_gate = gate_x2[col2];
}
}
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
if constexpr (has_fusion) {
if (use_gate) {
sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y);
}
}
}
}
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
}
if constexpr (has_fusion) {
if (use_gate) {
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]);
}
}
}
#else
NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
}
} else if constexpr (std::is_same_v<T, nv_bfloat16>) {
//TODO: add support for ggml_cuda_mad for hip_bfloat162
#if defined(GGML_USE_HIP)
const int * x2 = (const int *) x;
const int * gate_x2 = nullptr;
if constexpr (has_fusion) {
if (use_gate) {
gate_x2 = (const int *) gate_x;
}
}
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const int tmpx = x2[col2];
int tmpx_gate = 0;
if constexpr (has_fusion) {
if (use_gate) {
tmpx_gate = gate_x2[col2];
}
}
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
if constexpr (has_fusion) {
if (use_gate) {
const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]);
const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]);
ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x);
ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y);
}
}
}
}
#else
const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
const nv_bfloat162 * gate_x2 = nullptr;
if constexpr (has_fusion) {
if (use_gate) {
gate_x2 = (const nv_bfloat162 *) gate_x;
}
}
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const nv_bfloat162 tmpx = x2[col2];
nv_bfloat162 tmpx_gate;
if constexpr (has_fusion) {
if (use_gate) {
tmpx_gate = gate_x2[col2];
}
}
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
if constexpr (has_fusion) {
if (use_gate) {
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
}
}
}
}
#endif
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
if constexpr (has_fusion) {
if (use_gate) {
sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
}
}
if (block_size > warp_size) {
buf_iw[tid/warp_size] = sumf[j];
if constexpr (has_fusion) {
if (use_gate) {
buf_iw_gate[tid/warp_size] = sumf_gate[j];
}
}
__syncthreads();
if (tid < warp_size) {
sumf[j] = buf_iw[tid];
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
if constexpr (has_fusion) {
if (use_gate) {
sumf_gate[j] = buf_iw_gate[tid];
sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
}
}
}
if (j < ncols_dst) {
__syncthreads();
}
}
}
if (tid >= ncols_dst) {
return;
}
float value = sumf[tid];
if constexpr (has_fusion) {
if (use_bias) {
value += x_bias[tid*stride_col_dst + row];
}
if (use_gate) {
float gate_value = sumf_gate[tid];
if (use_gate_bias) {
gate_value += gate_bias[tid*stride_col_dst + row];
}
switch (glu_op) {
case GGML_GLU_OP_SWIGLU:
value *= ggml_cuda_op_silu_single(gate_value);
break;
case GGML_GLU_OP_GEGLU:
value *= ggml_cuda_op_gelu_single(gate_value);
break;
case GGML_GLU_OP_SWIGLU_OAI: {
value = ggml_cuda_op_swiglu_oai_single(gate_value, value);
break;
}
default:
break;
}
}
}
dst[tid*stride_col_dst + row] = value;
if constexpr (!has_fusion) {
GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate);
}
}
template<typename T, typename type_acc, int ncols_dst, int block_size>
static void mul_mat_vec_f_switch_fusion(
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int64_t ncols, const int64_t nrows,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (ncols_dst == 1) {
if (has_fusion) {
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
return;
}
}
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}
template <typename T, typename type_acc, int ncols_dst>
void launch_mul_mat_vec_f_cuda(
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int64_t ncols, const int64_t nrows,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
int64_t block_size_best = warp_size;
int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
int64_t max_block_size = 256;
if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
max_block_size = 128;
}
for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
if (niter < niter_best) {
niter_best = niter;
block_size_best = block_size;
}
}
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {
case 32: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
} break;
case 64: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
} break;
case 96: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
} break;
case 128: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
} break;
case 160: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
} break;
case 192: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
} break;
case 224: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
} break;
case 256: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}
template <typename T, typename type_acc>
static void mul_mat_vec_f_cuda_switch_ncols_dst(
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
switch (ncols_dst) {
case 1:
launch_mul_mat_vec_f_cuda<T, type_acc, 1>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 2:
launch_mul_mat_vec_f_cuda<T, type_acc, 2>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 3:
launch_mul_mat_vec_f_cuda<T, type_acc, 3>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 4:
launch_mul_mat_vec_f_cuda<T, type_acc, 4>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 5:
launch_mul_mat_vec_f_cuda<T, type_acc, 5>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 6:
launch_mul_mat_vec_f_cuda<T, type_acc, 6>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 7:
launch_mul_mat_vec_f_cuda<T, type_acc, 7>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 8:
launch_mul_mat_vec_f_cuda<T, type_acc, 8>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
default:
GGML_ABORT("fatal error");
break;
}
}
template<typename T>
static void mul_mat_vec_f_cuda(
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
enum ggml_prec prec, cudaStream_t stream) {
if constexpr(std::is_same_v<T, half>) {
if (prec == GGML_PREC_DEFAULT) {
mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
return;
}
}
mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
}
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
const ggml_cuda_mm_fusion_args_host * fusion) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS;
const size_t ts_src0 = ggml_type_size(src0->type);
const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type);
GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
GGML_ASSERT(ne13 == ne3);
GGML_ASSERT( nb00 == ts_src0);
GGML_ASSERT( nb10 == ts_src1);
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
GGML_ASSERT( nb0 == ts_dst);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
const float * src1_d = (const float *) src1->data;
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
float * dst_d = (float *) dst->data;
ggml_cuda_mm_fusion_args_device fusion_local{};
if (fusion) {
GGML_ASSERT( !ids || dst->ne[2] == 1);
GGML_ASSERT( ids || dst->ne[1] == 1);
if (fusion->x_bias) {
GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
fusion_local.x_bias = fusion->x_bias->data;
}
if (fusion->gate) {
GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
fusion_local.gate = fusion->gate->data;
}
if (fusion->gate_bias) {
GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
fusion_local.gate_bias = fusion->gate_bias->data;
}
fusion_local.glu_op = fusion->glu_op;
}
const int64_t s01 = src0->nb[1] / ts_src0;
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s1 = dst->nb[1] / ts_dst;
const int64_t s02 = src0->nb[2] / ts_src0;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s2 = dst->nb[2] / ts_dst;
const int64_t s03 = src0->nb[3] / ts_src0;
const int64_t s13 = src1->nb[3] / ts_src1;
const int64_t s3 = dst->nb[3] / ts_dst;
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_y = ids ? ne11 : ne12;
const int64_t nchannels_dst = ids ? ne1 : ne2;
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;
GGML_ASSERT(!ids || ncols_dst == 1);
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0->data;
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
}
}
void ggml_cuda_op_mul_mat_vec_f(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t ne10 = src1->ne[0];
const int64_t ne0 = dst->ne[0];
const int64_t row_diff = row_high - row_low;
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
// ggml_cuda_op provides single, contiguous matrices
const int64_t stride_row = ne00;
const int64_t stride_col_y = ne10;
const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
const int64_t nchannels_x = 1;
const int64_t nchannels_y = 1;
const int64_t nchannels_dst = 1;
const int64_t stride_channel_x = 0;
const int64_t stride_channel_y = 0;
const int64_t stride_channel_dst = 0;
const int64_t nsamples_x = 1;
const int64_t nsamples_dst = 1;
const int64_t stride_sample_x = 0;
const int64_t stride_sample_y = 0;
const int64_t stride_sample_dst = 0;
ggml_cuda_mm_fusion_args_device empty{};
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0_dd_i;
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0_dd_i;
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
}
GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_ncols, src1_padded_row_size);
}
bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11) {
if (src0_ne[0] % 2 != 0) {
return false;
}
const size_t ts = ggml_type_size(type);
if (src0_nb[0] != ts) {
return false;
}
// Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
if (src0_nb[i] % (2*ts) != 0) {
return false;
}
}
switch (type) {
case GGML_TYPE_F32:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
if (ampere_mma_available(cc)) {
return ne11 <= 3;
}
if (cc >= GGML_CUDA_CC_TURING) {
return ne11 <= 4;
}
return ne11 <= 3;
} else if (GGML_CUDA_CC_IS_AMD(cc)) {
if (fp32_mma_hardware_available(cc)) {
return ne11 <= 3;
}
return ne11 <= 8;
}
return ne11 <= 8;
case GGML_TYPE_F16:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
if (ampere_mma_available(cc)) {
return src0_small && ne11 == 1;
}
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return src0_small && ne11 <= 4;
}
if (fp16_mma_hardware_available(cc)) {
return src0_small && ne11 <= 3;
}
return ne11 <= 8;
} else if (GGML_CUDA_CC_IS_AMD(cc)) {
if (fp16_mma_hardware_available(cc)) {
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
return ne11 <= 3;
}
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return ne11 <= 5;
}
return ne11 <= 2;
}
return ne11 <= 8;
}
return ne11 <= 8;
case GGML_TYPE_BF16:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
if (ampere_mma_available(cc)) {
return src0_small && ne11 == 1;
}
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return src0_small && ne11 <= 4;
}
if (bf16_mma_hardware_available(cc)) {
return src0_small && ne11 <= 3;
}
return ne11 <= 8;
} else if (GGML_CUDA_CC_IS_AMD(cc)) {
if (bf16_mma_hardware_available(cc)) {
return ne11 <= 3;
}
return ne11 <= 8;
}
return ne11 <= 8;
default:
return false;
}
}

View File

@@ -0,0 +1,12 @@
#include "common.cuh"
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
void ggml_cuda_op_mul_mat_vec_f(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11);

732
ggml/src/ggml-cuda/mmvq.cu Normal file
View File

@@ -0,0 +1,732 @@
#include "mmvq.cuh"
#include "quantize.cuh"
#include "unary.cuh"
#include "vecdotq.cuh"
#include <cstdint>
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1;
case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1;
case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
case GGML_TYPE_Q5_K: return vec_dot_q5_K_q8_1;
case GGML_TYPE_Q6_K: return vec_dot_q6_K_q8_1;
case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1;
case GGML_TYPE_IQ2_XS: return vec_dot_iq2_xs_q8_1;
case GGML_TYPE_IQ2_S: return vec_dot_iq2_s_q8_1;
case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1;
case GGML_TYPE_IQ1_S: return vec_dot_iq1_s_q8_1;
case GGML_TYPE_IQ1_M: return vec_dot_iq1_m_q8_1;
case GGML_TYPE_IQ4_NL: return vec_dot_iq4_nl_q8_1;
case GGML_TYPE_IQ4_XS: return vec_dot_iq4_xs_q8_1;
case GGML_TYPE_IQ3_S: return vec_dot_iq3_s_q8_1;
default: return nullptr;
}
}
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
case GGML_TYPE_Q5_K: return VDR_Q5_K_Q8_1_MMVQ;
case GGML_TYPE_Q6_K: return VDR_Q6_K_Q8_1_MMVQ;
case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ;
case GGML_TYPE_IQ2_XS: return VDR_IQ2_XS_Q8_1_MMVQ;
case GGML_TYPE_IQ2_S: return VDR_IQ2_S_Q8_1_MMVQ;
case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ;
case GGML_TYPE_IQ3_S: return VDR_IQ3_S_Q8_1_MMVQ;
case GGML_TYPE_IQ4_NL: return VDR_IQ4_NL_Q8_1_MMVQ;
case GGML_TYPE_IQ4_XS: return VDR_IQ4_XS_Q8_1_MMVQ;
default: return 1;
}
}
enum mmvq_parameter_table_id {
MMVQ_PARAMETERS_GENERIC = 0,
MMVQ_PARAMETERS_GCN,
MMVQ_PARAMETERS_RDNA2
};
static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
return MMVQ_PARAMETERS_RDNA2;
#elif defined(GCN) || defined(CDNA)
return MMVQ_PARAMETERS_GCN;
#else
return MMVQ_PARAMETERS_GENERIC;
#endif
}
static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
return MMVQ_PARAMETERS_RDNA2;
}
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
return MMVQ_PARAMETERS_GCN;
}
return MMVQ_PARAMETERS_GENERIC;
}
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC) {
switch (ncols_dst) {
case 1:
case 2:
case 3:
case 4:
return 4;
case 5:
case 6:
case 7:
case 8:
return 2;
default:
return 1;
}
} else if (table_id == MMVQ_PARAMETERS_GCN) {
switch (ncols_dst) {
case 1:
case 2:
case 3:
case 4:
return 2;
case 5:
case 6:
case 7:
case 8:
default:
return 1;
}
}
return 1;
}
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
switch (ncols_dst) {
case 1:
return 1;
case 2:
case 3:
case 4:
case 5:
case 6:
case 7:
case 8:
return 2;
default:
return 1;
}
}
return 1;
}
// tell the compiler to use as many registers as it wants, see nwarps definition below
template <ggml_type type, int ncols_dst, bool has_fusion>
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
const int tid = warp_size*threadIdx.y + threadIdx.x;
const int row0 = rows_per_cuda_block*blockIdx.x;
const int blocks_per_row_x = ncols_x / qk;
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
// The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
const uint32_t channel_dst = blockIdx.y;
const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
const uint32_t sample_dst = blockIdx.z;
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
const uint32_t sample_y = sample_dst;
bool use_gate = false;
bool use_bias = false;
bool use_gate_bias = false;
const void * vgate = nullptr;
const float * x_bias = nullptr;
const float * gate_bias = nullptr;
ggml_glu_op active_glu;
if constexpr (has_fusion) {
use_gate = fusion.gate != nullptr;
use_bias = fusion.x_bias != nullptr;
use_gate_bias = fusion.gate_bias != nullptr && use_gate;
vgate = fusion.gate;
x_bias = (const float *) fusion.x_bias;
gate_bias = (const float *) fusion.gate_bias;
active_glu = fusion.glu_op;
}
const uint32_t channel_bias = ids ? channel_x : channel_dst;
float x_biases[ncols_dst] = { 0.0f };
float gate_biases[ncols_dst] = { 0.0f };
if constexpr (has_fusion) {
if (use_bias) {
x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
// 1. Hide latency by prefetching bias and gate here
// 2. load only on threads that won't die after partial sum calculation
if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
(rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
x_biases[j] = x_bias[j * stride_col_dst + threadIdx.x];
}
}
}
if (use_gate_bias) {
gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
(rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
gate_biases[j] = gate_bias[j * stride_col_dst + threadIdx.x];
}
}
}
}
// partial sum for each thread
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
// x block quant index when casting the quants to int
const int kqs = vdr * (tid % (qi/vdr));
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
tmp[j][i] += vec_dot_q_cuda(
vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
if constexpr (has_fusion) {
if (use_gate) {
tmp_gate[j][i] += vec_dot_q_cuda(
vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
}
}
}
}
}
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
__shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
if constexpr (!has_fusion) {
(void) tmp_shared_gate;
} else if (!use_gate) {
(void) tmp_shared_gate;
}
if (threadIdx.y > 0) {
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
if constexpr (has_fusion) {
if (use_gate) {
tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i];
}
}
}
}
}
__syncthreads();
if (threadIdx.y > 0) {
return;
}
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
// sum up partial sums and write back result
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
#pragma unroll
for (int l = 0; l < nwarps-1; ++l) {
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
if constexpr (has_fusion) {
if (use_gate) {
tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x];
}
}
}
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
if constexpr (has_fusion) {
if (use_gate) {
tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]);
}
}
}
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
float result = tmp[j][threadIdx.x];
if constexpr (has_fusion) {
if (use_bias) {
result += x_biases[j];
}
if (use_gate) {
float gate_value = tmp_gate[j][threadIdx.x];
if (use_gate_bias) {
gate_value += gate_biases[j];
}
switch (active_glu) {
case GGML_GLU_OP_SWIGLU:
result *= ggml_cuda_op_silu_single(gate_value);
break;
case GGML_GLU_OP_GEGLU:
result *= ggml_cuda_op_gelu_single(gate_value);
break;
case GGML_GLU_OP_SWIGLU_OAI: {
result = ggml_cuda_op_swiglu_oai_single(gate_value, result);
break;
}
default:
result = result * gate_value;
break;
}
}
}
dst[j*stride_col_dst + threadIdx.x] = result;
}
}
if constexpr (!has_fusion) {
GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate);
}
}
static std::pair<dim3, dim3> calc_launch_params(
const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y,
const int warp_size, const mmvq_parameter_table_id table_id) {
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
const dim3 block_nums(nblocks, nchannels_y, nsamples_y);
const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
return {block_nums, block_dims};
}
template<ggml_type type, int c_ncols_dst>
static void mul_mat_vec_q_switch_fusion(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (c_ncols_dst == 1) {
if (has_fusion) {
mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
return;
}
}
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}
template <ggml_type type>
static void mul_mat_vec_q_switch_ncols_dst(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int ncols_x, const int nrows_x, const int ncols_dst,
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
cudaStream_t stream) {
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
GGML_ASSERT(!ids || ncols_dst == 1);
switch (ncols_dst) {
case 1: {
constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break;
case 2: {
constexpr int c_ncols_dst = 2;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break;
case 3: {
constexpr int c_ncols_dst = 3;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break;
case 4: {
constexpr int c_ncols_dst = 4;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break;
case 5: {
constexpr int c_ncols_dst = 5;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break;
case 6: {
constexpr int c_ncols_dst = 6;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break;
case 7: {
constexpr int c_ncols_dst = 7;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break;
case 8: {
constexpr int c_ncols_dst = 8;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break;
default:
GGML_ABORT("fatal error");
break;
}
GGML_UNUSED(has_fusion);
}
static void mul_mat_vec_q_switch_type(
const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int ncols_x, const int nrows_x, const int ncols_dst,
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
cudaStream_t stream) {
switch (type_x) {
case GGML_TYPE_Q4_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q4_1:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q5_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q5_1:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q8_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_MXFP4:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q3_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q4_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q5_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q6_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ2_XXS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ2_XS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ2_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ3_XXS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ1_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ1_M:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ4_NL:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ4_XS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ3_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
default:
GGML_ABORT("fatal error");
break;
}
}
void ggml_cuda_mul_mat_vec_q(
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
const ggml_cuda_mm_fusion_args_host * fusion) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
GGML_TENSOR_BINARY_OP_LOCALS;
cudaStream_t stream = ctx.stream();
const size_t ts_src0 = ggml_type_size(src0->type);
const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type);
GGML_ASSERT( nb00 == ts_src0);
GGML_ASSERT( nb10 == ts_src1);
GGML_ASSERT( nb0 == ts_dst);
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
const float * src1_d = (const float *) src1->data;
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
float * dst_d = (float *) dst->data;
ggml_cuda_mm_fusion_args_device fusion_local{};
if (fusion) {
GGML_ASSERT( !ids || dst->ne[2] == 1);
GGML_ASSERT( ids || dst->ne[1] == 1);
if (fusion->x_bias) {
GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
fusion_local.x_bias = fusion->x_bias->data;
}
if (fusion->gate) {
GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
fusion_local.gate = fusion->gate->data;
}
if (fusion->gate_bias) {
GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
fusion_local.gate_bias = fusion->gate_bias->data;
}
fusion_local.glu_op = fusion->glu_op;
}
// If src0 is a temporary compute buffer, clear any potential padding.
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
const size_t size_data = ggml_nbytes(src0);
const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
if (size_alloc > size_data) {
GGML_ASSERT(ggml_is_contiguously_allocated(src0));
GGML_ASSERT(!src0->view_src);
CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
}
}
const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1);
{
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
quantize_row_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
}
const int64_t s01 = src0->nb[1] / ts_src0;
const int64_t s11 = ne10_padded / QK8_1;
const int64_t s1 = dst->nb[1] / ts_dst;
const int64_t s02 = src0->nb[2] / ts_src0;
const int64_t s2 = dst->nb[2] / ts_dst;
const int64_t s03 = src0->nb[3] / ts_src0;
const int64_t s3 = dst->nb[3] / ts_dst;
const int64_t s12 = ne11*s11;
const int64_t s13 = ne12*s12;
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_y = ids ? ne11 : ne12;
const int64_t nchannels_dst = ids ? ne1 : ne2;
const int64_t stride_col_dst = ids ? s2 : s1;
const int64_t stride_col_y = ids ? s12 : s11;
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;
mul_mat_vec_q_switch_type(
src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, stream);
}
void ggml_cuda_op_mul_mat_vec_q(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {
const int64_t ne00 = src0->ne[0];
const int64_t row_diff = row_high - row_low;
const int64_t ne10 = src1->ne[0];
GGML_ASSERT(ne10 % QK8_1 == 0);
const int64_t ne0 = dst->ne[0];
int id = ggml_cuda_get_device();
// the main device has a larger memory buffer to hold the results from all GPUs
// nrows_dst == nrows of the matrix that the kernel writes into
const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
const int stride_row_x = ne00 / ggml_blck_size(src0->type);
const int stride_col_y = src1_padded_row_size / QK8_1;
ggml_cuda_mm_fusion_args_device fusion_local{};
mul_mat_vec_q_switch_type(
src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
}

View File

@@ -0,0 +1,12 @@
#include "common.cuh"
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
void ggml_cuda_op_mul_mat_vec_q(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);

672
ggml/src/ggml-cuda/norm.cu Normal file
View File

@@ -0,0 +1,672 @@
#include "norm.cuh"
#include <cstdint>
template <int block_size>
static __global__ void norm_f32(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
const int row = blockIdx.x;
const int channel = blockIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
float2 mean_var = make_float2(0.0f, 0.0f);
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
mean_var.x += xi;
mean_var.y += xi * xi;
}
// sum up partial sums
extern __shared__ float2 s_sum2[];
mean_var = block_reduce<block_reduce_method::SUM, block_size>(mean_var, s_sum2);
const float mean = mean_var.x / ncols;
const float var = mean_var.y / ncols - mean * mean;
const float inv_std = rsqrtf(var + eps);
for (int col = tid; col < ncols; col += block_size) {
dst[col] = (x[col] - mean) * inv_std;
}
}
template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
// blockIdx.x: num_groups idx
// threadIdx.x: block_size idx
const int start = blockIdx.x*group_size + threadIdx.x;
const int end = min(blockIdx.x*group_size + group_size, ne_elements);
float tmp = 0.0f; // partial sum for thread in warp
for (int j = start; j < end; j += block_size) {
tmp += x[j];
}
extern __shared__ float s_sum[];
tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
const float mean = tmp / group_size;
tmp = 0.0f;
for (int j = start; j < end; j += block_size) {
const float xi = x[j] - mean;
dst[j] = xi;
tmp += xi * xi;
}
tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
const float variance = tmp / group_size;
const float scale = rsqrtf(variance + eps);
for (int j = start; j < end; j += block_size) {
dst[j] *= scale;
}
}
template <int block_size, bool do_multiply = false, bool do_add = false>
static __global__ void rms_norm_f32(const float * x,
float * dst,
const int ncols,
const int64_t stride_row,
const int64_t stride_channel,
const int64_t stride_sample,
const float eps,
const float * mul = nullptr,
const int64_t mul_stride_row = 0,
const int64_t mul_stride_channel = 0,
const int64_t mul_stride_sample = 0,
const uint3 mul_ncols_packed = make_uint3(0, 0, 0),
const uint3 mul_nrows_packed = make_uint3(0, 0, 0),
const uint3 mul_nchannels_packed = make_uint3(0, 0, 0),
const uint3 mul_nsamples_packed = make_uint3(0, 0, 0),
const float * add = nullptr,
const int64_t add_stride_row = 0,
const int64_t add_stride_channel = 0,
const int64_t add_stride_sample = 0,
const uint3 add_ncols_packed = make_uint3(0, 0, 0),
const uint3 add_nrows_packed = make_uint3(0, 0, 0),
const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
const int row = blockIdx.x;
const int channel = blockIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;
static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
if constexpr (do_multiply) {
const uint32_t mul_row = fastmodulo(row, mul_nrows_packed);
const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
const uint32_t mul_sample = fastmodulo(sample, mul_nsamples_packed);
mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
}
if constexpr (do_add) {
const int add_row = fastmodulo(row, add_nrows_packed);
const int add_channel = fastmodulo(channel, add_nchannels_packed);
const int add_sample = fastmodulo(sample, add_nsamples_packed);
add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
}
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
tmp += xi * xi;
}
// sum up partial sums
extern __shared__ float s_sum[];
tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
if constexpr (do_multiply && do_add) {
const int mul_col = fastmodulo(col, mul_ncols_packed);
const int add_col = fastmodulo(col, add_ncols_packed);
dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
} else if constexpr (do_multiply) {
const int mul_col = fastmodulo(col, mul_ncols_packed);
dst[col] = scale * x[col] * mul[mul_col];
} else {
dst[col] = scale * x[col];
}
}
}
template <int block_size>
static __global__ void rms_norm_back_f32(
const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
grad += int64_t(row)*ncols;
xf += int64_t(row)*ncols;
dst += int64_t(row)*ncols;
float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs
for (int col = tid; col < ncols; col += block_size) {
const float xfi = xf[col];
sum_xx += xfi * xfi;
sum_xg += xfi * grad[col];
}
// sum up partial sums
sum_xx = warp_reduce_sum(sum_xx);
sum_xg = warp_reduce_sum(sum_xg);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum_xx[32];
__shared__ float s_sum_xg[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum_xx[warp_id] = sum_xx;
s_sum_xg[warp_id] = sum_xg;
}
__syncthreads();
sum_xx = s_sum_xx[lane_id];
sum_xx = warp_reduce_sum(sum_xx);
sum_xg = s_sum_xg[lane_id];
sum_xg = warp_reduce_sum(sum_xg);
}
const float mean_eps = sum_xx / ncols + eps;
const float sum_eps = sum_xx + ncols*eps;
const float scale_grad = rsqrtf(mean_eps);
const float scale_x = -scale_grad * sum_xg/sum_eps;
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale_grad*grad[col] + scale_x*xf[col];
}
}
// template <int block_size>
// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
// const int row = blockIdx.x*blockDim.y + threadIdx.y;
// const int tid = threadIdx.x;
// float tmp = 0.0f; // partial sum for thread in warp
// for (int col = tid; col < ncols; col += block_size) {
// const float xi = x[row*ncols + col];
// tmp += xi * xi;
// }
// // sum up partial sums
// tmp = warp_reduce_sum(tmp);
// if (block_size > WARP_SIZE) {
// __shared__ float s_sum[32];
// int warp_id = threadIdx.x / WARP_SIZE;
// int lane_id = threadIdx.x % WARP_SIZE;
// if (lane_id == 0) {
// s_sum[warp_id] = tmp;
// }
// __syncthreads();
// tmp = s_sum[lane_id];
// tmp = warp_reduce_sum(tmp);
// }
// // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
// const float scale = rsqrtf(fmaxf(tmp, eps * eps));
// for (int col = tid; col < ncols; col += block_size) {
// dst[row*ncols + col] = scale * x[row*ncols + col];
// }
// }
template <int block_size>
static __global__ void l2_norm_f32(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
const int row = blockIdx.x;
const int channel = blockIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
tmp += xi * xi;
}
// sum up partial sums
extern __shared__ float s_sum[];
tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
}
}
static void norm_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
static void group_norm_f32_cuda(
const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
if (group_size < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
} else {
const dim3 block_dims(1024, 1, 1);
group_norm_f32<1024><<<num_groups, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps);
}
}
static void rms_norm_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
static void rms_norm_mul_f32_cuda(const float * x,
const float * mul,
const float * add,
float * dst,
const int ncols,
const int nrows,
const int nchannels,
const int nsamples,
const int64_t stride_row,
const int64_t stride_channel,
const int64_t stride_sample,
const int64_t mul_stride_row,
const int64_t mul_stride_channel,
const int64_t mul_stride_sample,
const uint32_t mul_ncols,
const uint32_t mul_nrows,
const uint32_t mul_nchannels,
const uint32_t mul_nsamples,
const int64_t add_stride_row,
const int64_t add_stride_channel,
const int64_t add_stride_sample,
const uint32_t add_ncols,
const uint32_t add_nrows,
const uint32_t add_nchannels,
const uint32_t add_nsamples,
const float eps,
cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (mul == nullptr) {
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
return;
}
if (add == nullptr) {
const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
}
} else {
const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
const uint3 add_ncols_packed = init_fastdiv_values(add_ncols);
const uint3 add_nrows_packed = init_fastdiv_values(add_nrows);
const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels);
const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add_nchannels_packed, add_nsamples_packed);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add_nchannels_packed, add_nsamples_packed);
}
}
}
static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_back_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_back_f32<1024><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
}
}
static void l2_norm_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
l2_norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_UNARY_OP_LOCALS;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
int num_groups = dst->op_params[0];
float eps;
memcpy(&eps, dst->op_params + 1, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
}
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_UNARY_OP_LOCALS;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}
void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
float eps = 0.0f;
memcpy(&eps, dst->op_params, sizeof(float));
const float * src0_d = (const float *) rms_norm_src->data;
const float * mul_d = nullptr;
const ggml_tensor * mul_src = nullptr;
if (mul_tensor->src[0] == dst) {
mul_d = (float *) mul_tensor->src[1]->data;
mul_src = mul_tensor->src[1];
} else if(mul_tensor->src[1] == dst) {
mul_d = (float *) mul_tensor->src[0]->data;
mul_src = mul_tensor->src[0];
} else {
GGML_ASSERT(false);
}
float * dst_d = (float *) mul_tensor->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
GGML_ASSERT(eps >= 0.0f);
const int64_t ne00 = rms_norm_src->ne[0];
const int64_t ne01 = rms_norm_src->ne[1];
const int64_t ne02 = rms_norm_src->ne[2];
const int64_t ne03 = rms_norm_src->ne[3];
const size_t ts0 = ggml_type_size(rms_norm_src->type);
GGML_ASSERT(rms_norm_src->nb[0] == ts0);
const int64_t s01 = rms_norm_src->nb[1] / ts0;
const int64_t s02 = rms_norm_src->nb[2] / ts0;
const int64_t s03 = rms_norm_src->nb[3] / ts0;
const size_t ts_mul = ggml_type_size(mul_src->type);
GGML_ASSERT(mul_src->nb[0] == ts_mul);
const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
const int mul_ncols = mul_src->ne[0];
const int mul_nrows = mul_src->ne[1];
const int mul_nchannels = mul_src->ne[2];
const int mul_nsamples = mul_src->ne[3];
rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,
ne00, ne01, ne02, ne03,
/*s00*/ s01, s02, s03,
/*mul_s00*/ mul_s01, mul_s02, mul_s03,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
/*add_s00*/ 0, 0, 0,
0, 0, 0, 0,
eps, stream);
}
void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
ggml_tensor * dst,
ggml_tensor * mul_tensor,
ggml_tensor * add_tensor) {
const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
float eps = 0.0f;
memcpy(&eps, dst->op_params, sizeof(float));
const float * src0_d = (const float *) rms_norm_src->data;
const float * mul_d = nullptr;
const ggml_tensor * mul_src = nullptr;
if (mul_tensor->src[0] == dst) {
mul_d = (float *) mul_tensor->src[1]->data;
mul_src = mul_tensor->src[1];
} else if (mul_tensor->src[1] == dst) {
mul_d = (float *) mul_tensor->src[0]->data;
mul_src = mul_tensor->src[0];
} else {
GGML_ASSERT(false);
}
const float * add_d = nullptr;
const ggml_tensor * add_src = nullptr;
if (add_tensor->src[0] == mul_tensor) {
add_d = (float *) add_tensor->src[1]->data;
add_src = add_tensor->src[1];
} else if (add_tensor->src[1] == mul_tensor) {
add_d = (float *) add_tensor->src[0]->data;
add_src = add_tensor->src[0];
} else {
GGML_ASSERT(false);
}
float * dst_d = (float *) add_tensor->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);
GGML_ASSERT(eps >= 0.0f);
const int64_t ne00 = rms_norm_src->ne[0];
const int64_t ne01 = rms_norm_src->ne[1];
const int64_t ne02 = rms_norm_src->ne[2];
const int64_t ne03 = rms_norm_src->ne[3];
const size_t ts0 = ggml_type_size(rms_norm_src->type);
GGML_ASSERT(rms_norm_src->nb[0] == ts0);
const int64_t s01 = rms_norm_src->nb[1] / ts0;
const int64_t s02 = rms_norm_src->nb[2] / ts0;
const int64_t s03 = rms_norm_src->nb[3] / ts0;
const size_t ts_mul = ggml_type_size(mul_src->type);
GGML_ASSERT(mul_src->nb[0] == ts_mul);
const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
const int mul_ncols = mul_src->ne[0];
const int mul_nrows = mul_src->ne[1];
const int mul_nchannels = mul_src->ne[2];
const int mul_nsamples = mul_src->ne[3];
const size_t ts_add = ggml_type_size(add_src->type);
GGML_ASSERT(add_src->nb[0] == ts_add);
const int64_t add_s01 = add_src->nb[1] / ts_add;
const int64_t add_s02 = add_src->nb[2] / ts_add;
const int64_t add_s03 = add_src->nb[3] / ts_add;
const int add_ncols = add_src->ne[0];
const int add_nrows = add_src->ne[1];
const int add_nchannels = add_src->ne[2];
const int add_nsamples = add_src->ne[3];
rms_norm_mul_f32_cuda(src0_d, mul_d,add_d,dst_d,
ne00,ne01, ne02, ne03,
/*s00*/ s01, s02, s03,
/*mul_s00*/ mul_s01, mul_s02, mul_s03,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
/*add_s00*/ add_s01, add_s02, add_s03,
add_ncols, add_nrows, add_nchannels, add_nsamples,
eps, stream);
}
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * grad = dst->src[0]; // gradients
const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
const float * grad_d = (const float *) grad->data;
const float * src0f_d = (const float *) src0f->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(grad));
GGML_ASSERT( grad->type == GGML_TYPE_F32);
GGML_ASSERT(src0f->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0f->ne[0];
const int64_t nrows = ggml_nrows(src0f);
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
}
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_UNARY_OP_LOCALS;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}

View File

@@ -0,0 +1,18 @@
#include "common.cuh"
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);
void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
ggml_tensor * dst,
ggml_tensor * mul_tensor,
ggml_tensor * add_tensor);
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,78 @@
#include "ggml-impl.h"
#include "opt-step-adamw.cuh"
#include <cstdint>
static __global__ void opt_step_adamw_f32(
float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v,
const float * __restrict__ pars, const int64_t k) {
const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
if (i >= k) {
return;
}
const float alpha = pars[0];
const float beta1 = pars[1];
const float beta2 = pars[2];
const float eps = pars[3];
const float wd = pars[4];
const float beta1h = pars[5];
const float beta2h = pars[6];
const float gi = g[i];
const float gmi = g_m[i]*beta1 + gi*(1.0f - beta1);
const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2);
g_m[i] = gmi;
g_v[i] = gvi;
const float mh = gmi*beta1h;
const float vh = sqrtf(gvi*beta2h) + eps;
x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
}
static void opt_step_adamw_f32_cuda(
float * x, const float * g, float * g_m, float * g_v, const float * pars, const int64_t k, cudaStream_t stream) {
const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
opt_step_adamw_f32<<<block_nums, block_dims, 0, stream>>>(x, g, g_m, g_v, pars, k);
}
void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src0_grad = dst->src[1];
const ggml_tensor * src0_grad_m = dst->src[2];
const ggml_tensor * src0_grad_v = dst->src[3];
const ggml_tensor * adamw_params = dst->src[4];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32);
GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32);
GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src0_grad));
GGML_ASSERT(ggml_is_contiguous(src0_grad_m));
GGML_ASSERT(ggml_is_contiguous(src0_grad_v));
GGML_ASSERT(ggml_is_contiguous(adamw_params));
GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
GGML_ASSERT(ggml_nelements(adamw_params) == 7);
float * src0_d = (float *) src0->data;
const float * src0_grad_d = (const float *) src0_grad->data;
float * src0_grad_m_d = (float *) src0_grad_m->data;
float * src0_grad_v_d = (float *) src0_grad_v->data;
const float * adamw_params_d = (const float *) adamw_params->data;
cudaStream_t stream = ctx.stream();
const int64_t ne = ggml_nelements(src0);
opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, adamw_params_d, ne, stream);
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_OPT_STEP_ADAMW_BLOCK_SIZE 256
void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,49 @@
#include "ggml-impl.h"
#include "opt-step-sgd.cuh"
#include <cstdint>
static __global__ void opt_step_sgd_f32(
float * __restrict__ x, const float * __restrict__ g,
const float * __restrict__ pars, const int64_t k) {
const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
if (i >= k) {
return;
}
x[i] = x[i] * (1.0f - pars[0] * pars[1]) - pars[0] * g[i];
}
static void opt_step_sgd_f32_cuda(
float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) {
const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
opt_step_sgd_f32<<<block_nums, block_dims, 0, stream>>>(x, g, pars, k);
}
void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src0_grad = dst->src[1];
const ggml_tensor * params = dst->src[2];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
GGML_ASSERT(params->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src0_grad));
GGML_ASSERT(ggml_is_contiguous(params));
GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
GGML_ASSERT(ggml_nelements(params) == 2);
float * src0_d = (float *) src0->data;
const float * src0_grad_d = (const float *) src0_grad->data;
const float * params_d = (const float *) params->data;
cudaStream_t stream = ctx.stream();
const int64_t ne = ggml_nelements(src0);
opt_step_sgd_f32_cuda(src0_d, src0_grad_d, params_d, ne, stream);
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256
void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,68 @@
#include "out-prod.cuh"
#include <cstdint>
void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ne01 == ne11);
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne10);
GGML_ASSERT(ne2 % src0->ne[2] == 0);
GGML_ASSERT(ne3 % src0->ne[3] == 0);
GGML_ASSERT(ne2 == src1->ne[2]);
GGML_ASSERT(ne3 == src1->ne[3]);
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
cublasHandle_t handle = ctx.cublas_handle();
const float alpha = 1.0f;
const float beta = 0.0f;
CUBLAS_CHECK(cublasSetStream(handle, stream));
const int64_t lda = nb01 / sizeof(float);
const int64_t ldc = nb1 / sizeof(float);
const bool src1_T = ggml_is_transposed(src1);
const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float));
// data strides in dimensions 2/3
const size_t s02 = nb02 / sizeof(float);
const size_t s03 = nb03 / sizeof(float);
const size_t s12 = nb12 / sizeof(float);
const size_t s13 = nb13 / sizeof(float);
const size_t s2 = nb2 / sizeof(float);
const size_t s3 = nb3 / sizeof(float);
// dps == dst per src0, used for group query attention
const int64_t dps2 = ne2 / ne02;
const int64_t dps3 = ne3 / ne03;
// TODO batched matrix multiplication
for (int64_t i3 = 0; i3 < ne3; ++i3) {
for (int64_t i2 = 0; i2 < ne2; ++i2) {
CUBLAS_CHECK(
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
ne0, ne1, ne01,
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
src1_d + i3 *s13 + i2 *s12, ldb,
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
}
}
}

View File

@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

103
ggml/src/ggml-cuda/pad.cu Normal file
View File

@@ -0,0 +1,103 @@
#include "pad.cuh"
#include <stdint.h>
__device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) {
// + size ensures negatives are handled properly
return (coord + size) % size;
}
static __global__ void pad_f32(const float * src, float * dst,
const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3,
const bool circular) {
// blockIdx.z: i3*ne2+i2
// blockIdx.y: i1
// blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE
// gridDim.y: ne1
int i0 = threadIdx.x + blockIdx.x * blockDim.x;
int i1 = blockIdx.y;
int i2 = blockIdx.z % ne2;
int i3 = blockIdx.z / ne2;
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
const int64_t dst_idx = i3 * (ne0 * ne1 * ne2) + i2 * (ne0 * ne1) + i1 * ne0 + i0;
if (!circular) {
if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && (i2 >= lp2 && i2 < ne2 - rp2) &&
(i3 >= lp3 && i3 < ne3 - rp3)) {
const int64_t i00 = i0 - lp0;
const int64_t i01 = i1 - lp1;
const int64_t i02 = i2 - lp2;
const int64_t i03 = i3 - lp3;
const int64_t ne02 = ne2 - lp2 - rp2;
const int64_t ne01 = ne1 - lp1 - rp1;
const int64_t ne00 = ne0 - lp0 - rp0;
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
dst[dst_idx] = src[src_idx];
} else {
dst[dst_idx] = 0.0f;
}
}
// circular means on a torus, so x and y wrap around
else {
const int64_t ne00 = ne0 - lp0 - rp0;
const int64_t ne01 = ne1 - lp1 - rp1;
const int64_t ne02 = ne2 - lp2 - rp2;
const int64_t ne03 = ne3 - lp3 - rp3;
const int64_t i00 = wrap_around(i0 - lp0, ne00);
const int64_t i01 = wrap_around(i1 - lp1, ne01);
const int64_t i02 = wrap_around(i2 - lp2, ne02);
const int64_t i03 = wrap_around(i3 - lp3, ne03);
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
dst[dst_idx] = src[src_idx];
}
}
static void pad_f32_cuda(const float * src, float * dst,
const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3,
const bool circular, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2 * ne3);
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst,
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
ne0, ne1, ne2, ne3, circular);
}
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int32_t lp0 = ((const int32_t *) (dst->op_params))[0];
const int32_t rp0 = ((const int32_t *) (dst->op_params))[1];
const int32_t lp1 = ((const int32_t *) (dst->op_params))[2];
const int32_t rp1 = ((const int32_t *) (dst->op_params))[3];
const int32_t lp2 = ((const int32_t *) (dst->op_params))[4];
const int32_t rp2 = ((const int32_t *) (dst->op_params))[5];
const int32_t lp3 = ((const int32_t *) (dst->op_params))[6];
const int32_t rp3 = ((const int32_t *) (dst->op_params))[7];
const int32_t circular = ((const int32_t *) (dst->op_params))[8];
pad_f32_cuda(src0_d, dst_d,
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
(bool) circular, stream);
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_PAD_BLOCK_SIZE 256
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,91 @@
#include "pad_reflect_1d.cuh"
static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void
pad_reflect_1d_kernel_f32(
const void * __restrict__ src0,
void * __restrict__ dst,
const int64_t ne0,
const int64_t ne00,
const uint3 ne01,
const int64_t ne02,
const int64_t ne03,
const int64_t nb00,
const int64_t nb01,
const int64_t nb02,
const int64_t nb03,
const int64_t nb0,
const int64_t nb1,
const int64_t nb2,
const int64_t nb3,
const int p0,
const int p1) {
const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const uint2 div_mod_packed = fast_div_modulo(blockIdx.x, ne01);
const int64_t tile1 = div_mod_packed.y; // i1
const int64_t tile0 = div_mod_packed.x; // nth i0 tile
const int64_t i1 = tile1;
const int64_t i0 = threadIdx.x + tile0 * blockDim.x;
// ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh)
if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) {
return;
}
const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
char * dst_ptr = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
const int64_t rel_i0 = i0 - p0; // relative i0 in src0
int64_t src_idx;
if (rel_i0 < 0) {
// Left padding - reflect
src_idx = -rel_i0;
} else if (rel_i0 < ne00) {
// Middle - copy
src_idx = rel_i0;
} else {
// Right padding - reflect
src_idx = 2 * ne00 - 2 - rel_i0;
}
const float value = *(const float *) (src0_ptr + src_idx * nb00);
*(float *) (dst_ptr + i0 * nb0) = value;
GGML_UNUSED(p1);
}
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int32_t * opts = (const int32_t *) dst->op_params;
const int p0 = opts[0];
const int p1 = opts[1];
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const uint3 ne01_packed = init_fastdiv_values(ne01);
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne0 = dst->ne[0];
// sanity: padded length matches
GGML_ASSERT(ne0 == ne00 + p0 + p1);
constexpr int64_t bx = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x)
const int64_t tiles0 = (ne0 + bx - 1) / bx; // number of tiles along i0
// grid.x covers i1 and all tiles of i0: [ne01 * tiles0]
// grid.y covers i2: [ne02]
// grid.z covers i3: [ne03]
const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03);
const dim3 block_dims((unsigned) bx, 1, 1);
pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1);
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_PAD_REFLECT_1D_BLOCK_SIZE 256
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,94 @@
#include "pool2d.cuh"
template <typename Ti, typename To>
static __global__ void pool2d_nchw_kernel(
const int ih, const int iw, const int oh, const int ow,
const int kh, const int kw, const int sh, const int sw,
const int ph, const int pw, const int parallel_elements,
const Ti* src, To* dst, const enum ggml_op_pool op) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= parallel_elements) {
return;
}
const int I_HW = ih * iw;
const int O_HW = oh * ow;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / ow;
const int cur_ow = idx % O_HW % ow;
const Ti* i_ptr = src + nc * I_HW;
To* o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * sh - ph;
const int bh = max(0, start_h);
const int eh = min(ih, start_h + kh);
const int start_w = cur_ow * sw - pw;
const int bw = max(0, start_w);
const int ew = min(iw, start_w + kw);
const To scale = 1. / (kh * kw);
To res = 0;
switch (op) {
case GGML_OP_POOL_AVG: res = 0; break;
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
default: assert(false);
}
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
#if __CUDA_ARCH__ >= 350
Ti cur = __ldg(i_ptr + i * iw + j);
#else
Ti cur = i_ptr[i * iw + j];
#endif
switch (op) {
case GGML_OP_POOL_AVG: res += cur * scale; break;
case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;
default: assert(false);
}
}
}
o_ptr[cur_oh * ow + cur_ow] = res;
}
static void pool2d_nchw_kernel_f32_f32_cuda(
const int ih, const int iw, const int oh, const int ow,
const int kh, const int kw, const int sh, const int sw,
const int ph, const int pw, const int parallel_elements,
const float * src, float * dst, const enum ggml_op_pool op,
cudaStream_t stream) {
const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
dim3 block_nums(num_blocks);
pool2d_nchw_kernel<<<block_nums, CUDA_POOL2D_BLOCK_SIZE, 0, stream>>>(ih, iw, oh, ow, kh, kw, sh, sw, ph, pw, parallel_elements, src, dst, op);
}
void ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int32_t * opts = (const int32_t *)dst->op_params;
enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
const int k0 = opts[1];
const int k1 = opts[2];
const int s0 = opts[3];
const int s1 = opts[4];
const int p0 = opts[5];
const int p1 = opts[6];
const int64_t IH = src0->ne[1];
const int64_t IW = src0->ne[0];
const int64_t N = dst->ne[3];
const int64_t OC = dst->ne[2];
const int64_t OH = dst->ne[1];
const int64_t OW = dst->ne[0];
const int parallel_elements = N * OC * OH * OW;
pool2d_nchw_kernel_f32_f32_cuda(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_d, dst_d, op, stream);
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_POOL2D_BLOCK_SIZE 256
void ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,343 @@
#include "quantize.cuh"
#include <cstdint>
__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1)
static __global__ void quantize_q8_1(
const float * __restrict__ x, void * __restrict__ vy,
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const uint32_t ne1, const uint3 ne2) {
const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= ne0) {
return;
}
const int64_t i3 = fastdiv(blockIdx.z, ne2);
const int64_t i2 = blockIdx.z - i3*ne2.z;
const int64_t i1 = blockIdx.y;
const int64_t & i00 = i0;
const int64_t & i01 = i1;
const int64_t & i02 = i2;
const int64_t & i03 = i3;
const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0;
block_q8_1 * y = (block_q8_1 *) vy;
const int64_t ib = i_cont / QK8_1; // block index
const int64_t iqs = i_cont % QK8_1; // quant index
const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f;
float amax = fabsf(xi);
float sum = xi;
amax = warp_reduce_max<QK8_1>(amax);
sum = warp_reduce_sum<QK8_1>(sum);
const float d = amax / 127.0f;
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
y[ib].qs[iqs] = q;
if (iqs > 0) {
return;
}
y[ib].ds = make_half2(d, sum);
}
__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
if (!(amax > 0.0f)) {
return 0;
}
// FP4 E2M1: max exponent (unbiased) is 2.
constexpr int FP4_E2M1_EMAX = 2;
const float e = log2f(amax);
// "even" -> round-to-nearest integer, ties-to-even
const int e_int = __float2int_rn(e);
const int shared_exp = e_int - FP4_E2M1_EMAX;
int biased = shared_exp + 127;
biased = max(biased, 0);
biased = min(biased, 254);
return static_cast<uint8_t>(biased);
}
// quantize values in the format mxfp4 is stored which is interleaved nibbles
// i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31
static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
const int32_t * __restrict__ ids,
void * __restrict__ vy,
const int64_t ne00,
const int64_t s01,
const int64_t s02,
const int64_t s03,
const int64_t ne0,
const int ne1,
const int ne2) {
constexpr int vals_per_scale = 32;
constexpr int vals_per_warp = 2 * vals_per_scale; // Each warp processes 2 blocks of 32 = 64 values
const int warp_id = threadIdx.y;
const int lane_id_32 = threadIdx.x;
const int nwarps = blockDim.y;
const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;
if (warp_start_offset >= ne0) {
return;
}
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;
const int64_t i01 = ids ? ids[i1] : i1;
const int64_t i02 = i2;
const int64_t i03 = i3;
block_fp4_mmq * y = (block_fp4_mmq *) vy;
const int64_t block_fp4_mmq_size = 8 * QK_MXFP4; // 256 values
const int64_t ib0 = blockIdx.z * ((int64_t) ne1 * (ne0 / block_fp4_mmq_size));
const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
const int64_t quad_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
const int group_id = lane_id_32 / 4;
const int lane_in_group = lane_id_32 % 4;
const int base = group_id * 2;
char2 * yqs2 = (char2 *) y[ib].qs;
const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
uint8_t scales[2];
#pragma unroll
for (int b = 0; b < 2; ++b) {
const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32;
const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f;
float amax = fabsf(xi);
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
}
const uint8_t e = compute_e8m0_scale(amax);
scales[b] = e;
const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));
#if CUDART_VERSION >= 12080
const float scaled_val = xi * inv_s;
const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE);
const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE);
const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE);
const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE);
if (lane_in_group == 0) {
__nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3));
yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = *(char2 *) &fp4_packed;
}
#else
// Fallback: manual FP4 conversion using LUT
const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE);
const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1, WARP_SIZE);
const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE);
const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE);
if (lane_in_group == 0) {
char2 q;
q.x = (q_hi_0 << 4) | q_lo_0;
q.y = (q_hi_1 << 4) | q_lo_1;
yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = q;
}
#endif // CUDART_VERSION >= 12080
}
if (lane_id_32 == 0) {
// Store 2 scales packed into 1 uint32
y[ib].d4[quad_idx_in_block] = (scales[1] << 8) | scales[0];
}
}
template <mmq_q8_1_ds_layout ds_layout>
static __global__ void quantize_mmq_q8_1(
const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int ne1, const int ne2) {
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;
if (i0 >= ne0) {
return;
}
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;
const int64_t i00 = i0;
const int64_t i01 = ids ? ids[i1] : i1;
const int64_t i02 = i2;
const int64_t i03 = i3;
const float4 * x4 = (const float4 *) x;
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel
const int64_t iqs = i0 % (4*QK8_1); // quant index in block
// Load 4 floats per thread and calculate max. abs. value between them:
const float4 xi = i0 < ne00 ? x4[(i03*s03 + i02*s02 + i01*s01 + i00)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
float amax = fabsf(xi.x);
amax = fmaxf(amax, fabsf(xi.y));
amax = fmaxf(amax, fabsf(xi.z));
amax = fmaxf(amax, fabsf(xi.w));
// Exchange max. abs. value between vals_per_scale/4 threads.
#pragma unroll
for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
}
float sum;
if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
sum = xi.x + xi.y + xi.z + xi.w;
// Calculate sums across vals_per_sum/4 threads.
#pragma unroll
for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
}
}
const float d_inv = 127.0f / amax;
char4 q;
q.x = roundf(xi.x*d_inv);
q.y = roundf(xi.y*d_inv);
q.z = roundf(xi.z*d_inv);
q.w = roundf(xi.w*d_inv);
// Write back 4 int8 values as a single 32 bit value for better memroy bandwidth:
char4 * yqs4 = (char4 *) y[ib].qs;
yqs4[iqs/4] = q;
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
if (iqs % 16 != 0 || iqs >= 96) {
return;
}
y[ib].d2s6[2 + iqs/16] = sum;
if (iqs % 64 != 0) {
return;
}
const float d = 1.0f / d_inv;
y[ib].d2s6[iqs/64] = d;
return;
}
if (iqs % 32 != 0) {
return;
}
const float d = 1.0f / d_inv;
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
y[ib].ds4[iqs/32] = make_half2(d, sum);
} else {
y[ib].d4[iqs/32] = d;
}
}
void quantize_row_q8_1_cuda(
const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
GGML_ASSERT(!ids);
GGML_ASSERT(ne0 % QK8_1 == 0);
const uint3 ne2_fastdiv = init_fastdiv_values(ne2);
const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);
GGML_UNUSED(type_src0);
}
void quantize_mmq_q8_1_cuda(
const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
// ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
const dim3 num_blocks(ne1, block_num_y, ne2*ne3);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
switch (mmq_get_q8_1_ds_layout(type_src0)) {
case MMQ_Q8_1_DS_LAYOUT_D4:
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
break;
case MMQ_Q8_1_DS_LAYOUT_DS4:
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
break;
case MMQ_Q8_1_DS_LAYOUT_D2S6:
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
break;
default:
GGML_ABORT("fatal error");
break;
}
}
void quantize_mmq_mxfp4_cuda(const float * x,
const int32_t * ids,
void * vy,
[[maybe_unused]] const ggml_type type_src0,
const int64_t ne00,
const int64_t s01,
const int64_t s02,
const int64_t s03,
const int64_t ne0,
const int64_t ne1,
const int64_t ne2,
const int64_t ne3,
cudaStream_t stream) {
GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
constexpr int nwarps = 8;
constexpr int vals_per_warp = 2 * QK_MXFP4;
constexpr int vals_per_block = nwarps * vals_per_warp;
const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
const dim3 block_size(WARP_SIZE, nwarps, 1);
quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
}

View File

@@ -0,0 +1,41 @@
#pragma once
#include "common.cuh"
#include "mmq.cuh"
#include <cstdint>
#define CUDA_QUANTIZE_BLOCK_SIZE 256
#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access.");
static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
typedef void (*quantize_cuda_t)(
const float * x, const int32_t * ids, void * vy,
ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
void quantize_row_q8_1_cuda(
const float * x, const int32_t * ids, void * vy,
ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
void quantize_mmq_q8_1_cuda(
const float * x, const int32_t * ids, void * vy,
ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
void quantize_mmq_mxfp4_cuda(const float * x,
const int32_t * ids,
void * vy,
ggml_type type_src0,
int64_t ne00,
int64_t s01,
int64_t s02,
int64_t s03,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3,
cudaStream_t stream);

View File

@@ -0,0 +1,39 @@
#include "common.cuh"
// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template <bool norm>
static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
const int num_unroll = 8;
float temp[num_unroll];
float sum_temp[num_unroll] = { 0.0f };
for (int i = col; i < ncols;) {
for (int j = 0; j < num_unroll; ++j) {
if (i < ncols) {
temp[j] = x[row * ncols + i];
} else {
temp[j] = 0;
}
i += blockDim.x;
}
for (int j = 0; j < num_unroll; ++j) {
sum_temp[j] += temp[j];
}
}
for (int j = 0; j < num_unroll; ++j) {
sum += sum_temp[j];
}
// sum up partial sums
__shared__ float shared_vals[32];
sum = block_reduce<block_reduce_method::SUM>(sum, shared_vals);
if (col != 0) {
return;
}
dst[row] = norm ? sum / ncols : sum;
}

View File

@@ -0,0 +1,67 @@
#include "ggml-cuda/common.cuh"
#include "roll.cuh"
static __forceinline__ __device__ int64_t wrap_index(const int64_t idx, const int64_t ne) {
if (idx < 0) {
return idx + ne;
}
if (idx >= ne) {
return idx - ne;
}
return idx;
}
static __global__ void roll_f32_cuda(const float * __restrict__ src,
float * __restrict__ dst,
const int64_t ne00,
const int64_t ne01,
const int64_t ne02,
const int64_t ne03,
const int s0,
const int s1,
const int s2,
const int s3) {
const int64_t idx = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
const int64_t n_elements = ne00 * ne01 * ne02 * ne03;
if (idx >= n_elements) {
return;
}
const int64_t i0 = idx % ne00;
const int64_t i1 = (idx / ne00) % ne01;
const int64_t i2 = (idx / (ne00 * ne01)) % ne02;
const int64_t i3 = (idx / (ne00 * ne01 * ne02)) % ne03;
const int64_t d0 = wrap_index(i0 - s0, ne00);
const int64_t d1 = wrap_index(i1 - s1, ne01);
const int64_t d2 = wrap_index(i2 - s2, ne02);
const int64_t d3 = wrap_index(i3 - s3, ne03);
dst[i3 * (ne00 * ne01 * ne02) + i2 * (ne01 * ne00) + i1 * ne00 + i0] =
src[d3 * (ne00 * ne01 * ne02) + d2 * (ne01 * ne00) + d1 * ne00 + d0];
}
void ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
int s0 = dst->op_params[0];
int s1 = dst->op_params[1];
int s2 = dst->op_params[2];
int s3 = dst->op_params[3];
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) dst->src[0]->data;
float * dst_d = (float *) dst->data;
GGML_TENSOR_UNARY_OP_LOCALS;
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst));
cudaStream_t stream = ctx.stream();
int64_t sz = (ne00 * ne01 * ne02 * ne03);
int64_t num_blocks = (sz + CUDA_ROLL_BLOCK_SIZE - 1) / CUDA_ROLL_BLOCK_SIZE;
roll_f32_cuda<<<num_blocks, CUDA_ROLL_BLOCK_SIZE, 0, stream>>>(
src0_d, dst_d, ne00, ne01, ne02, ne03, s0, s1, s2, s3);
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_ROLL_BLOCK_SIZE 256
void ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

565
ggml/src/ggml-cuda/rope.cu Normal file
View File

@@ -0,0 +1,565 @@
#include "convert.cuh"
#include "ggml-cuda/common.cuh"
#include "ggml.h"
#include "rope.cuh"
struct rope_corr_dims {
float v[2];
};
struct mrope_sections {
int v[4];
};
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
template<bool forward>
static __device__ void rope_yarn(
const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor,
float mscale, float & cos_theta, float & sin_theta) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
cos_theta = cosf(theta) * mscale;
sin_theta = sinf(theta) * mscale;
if (!forward) {
sin_theta *= -1.0f;
}
}
template <bool forward, bool has_ff, typename T, typename D>
static __global__ void rope_norm(const T * x,
D * dst,
const int ne0,
const int ne1,
const int s1,
const int s2,
const int n_dims,
const int32_t * pos,
const float freq_scale,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float theta_scale,
const float * freq_factors,
const int64_t * row_indices,
const int set_rows_stride) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
return;
}
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
int idst = row_dst * ne0 + i0;
const int ix = channel_x*s2 + row_x*s1 + i0;
// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
if (set_rows_stride != 0) {
idst = row_x * ne0 + i0;
idst += row_indices[channel_x] * set_rows_stride;
}
const auto & store_coaelsced = [&](float x0, float x1) {
if constexpr (std::is_same_v<float, D>) {
float2 v = make_float2(x0, x1);
ggml_cuda_memcpy_1<8>(dst + idst, &v);
} else if constexpr (std::is_same_v<half, D>) {
half2 v = make_half2(x0, x1);
ggml_cuda_memcpy_1<4>(dst + idst, &v);
}
};
if (i0 >= n_dims) {
store_coaelsced(x[ix + 0], x[ix + 1]);
return;
}
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[ix + 0];
const float x1 = x[ix + 1];
store_coaelsced(x0 * cos_theta - x1 * sin_theta, x0 * sin_theta + x1 * cos_theta);
}
template <bool forward, bool has_ff, typename T, typename D>
static __global__ void rope_neox(const T * x,
D * dst,
const int ne0,
const int ne1,
const int s1,
const int s2,
const int n_dims,
const int32_t * pos,
const float freq_scale,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float theta_scale,
const float * freq_factors,
const int64_t * row_indices,
const int set_rows_stride) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
return;
}
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
int idst = row_dst * ne0 + i0 / 2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
if (set_rows_stride != 0) {
idst = row_x * ne0 + i0 / 2;
idst += row_indices[channel_x] * set_rows_stride;
}
if (i0 >= n_dims) {
dst[idst + i0 / 2 + 0] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 0]);
dst[idst + i0 / 2 + 1] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 1]);
return;
}
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims/2];
dst[idst + 0] = ggml_cuda_cast<D>(x0 * cos_theta - x1 * sin_theta);
dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta);
}
template<bool forward, bool has_ff, typename T>
static __global__ void rope_multi(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
return;
}
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
if (i0 >= n_dims) {
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
return;
}
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
} else {
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
}
} else {
if (sector < sections.v[0]) {
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
}
}
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims/2];
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
template<bool forward, bool has_ff, typename T>
static __global__ void rope_vision(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * freq_factors, const mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
return;
}
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
const int sect_dims = sections.v[0] + sections.v[1];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < sections.v[0]) {
const int p = sector;
theta_base = pos[channel_x]*powf(theta_scale, p);
}
else if (sector >= sections.v[0] && sector < sec_w) {
const int p = sector - sections.v[0];
theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
}
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims];
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
}
template <bool forward, typename T, typename D>
static void rope_norm_cuda(const T * x,
D * dst,
const int ne0,
const int ne1,
const int s1,
const int s2,
const int n_dims,
const int nr,
const int32_t * pos,
const float freq_scale,
const float freq_base,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float * freq_factors,
const int64_t * row_indices,
const int set_rows_stride,
cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) {
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
} else {
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
}
}
template <bool forward, typename T, typename D>
static void rope_neox_cuda(const T * x,
D * dst,
const int ne0,
const int ne1,
const int s1,
const int s2,
const int n_dims,
const int nr,
const int32_t * pos,
const float freq_scale,
const float freq_base,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float * freq_factors,
const int64_t * row_indices,
const int set_rows_stride,
cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) {
rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
} else {
rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
}
}
template<bool forward, typename T>
static void rope_multi_cuda(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) {
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
} else {
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
}
}
template<bool forward, typename T>
static void rope_vision_cuda(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
// break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
// where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) {
rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
} else {
rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
}
}
template <bool forward>
void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
ggml_tensor * dst,
const ggml_tensor * set_rows = nullptr) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
void * dst_d = dst->data;
const int64_t * row_indices = nullptr;
ggml_type dst_type = dst->type;
int set_rows_stride = 0;
if (set_rows != nullptr) {
GGML_ASSERT(forward);
dst_d = set_rows->data;
row_indices = (const int64_t *) set_rows->src[1]->data;
dst_type = set_rows->type;
set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
}
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
// When not fused, src0 and dst types must match
// When fused (ROPE+VIEW+SET_ROWS), src0 may be F32 and dst may be F16
GGML_ASSERT(src0->type == dst->type || (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
const int64_t ne00 = src0->ne[0]; // head dims
const int64_t ne01 = src0->ne[1]; // num heads
const int64_t ne02 = src0->ne[2]; // num heads
const int64_t nr = ggml_nrows(src0);
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
//const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
mrope_sections sections;
// RoPE alteration for extended context
float freq_base;
float freq_scale;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne00/2);
}
const int32_t * pos = (const int32_t *) src1_d;
const float * freq_factors = nullptr;
if (src2 != nullptr) {
freq_factors = (const float *) src2->data;
}
rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
// compute
if (is_neox) {
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
} else {
GGML_ABORT("fatal error");
}
} else if (is_mrope && !is_vision) {
if (src0->type == GGML_TYPE_F32) {
rope_multi_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_multi_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
} else {
GGML_ABORT("fatal error");
}
} else if (is_vision) {
if (src0->type == GGML_TYPE_F32) {
rope_vision_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_vision_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
} else {
GGML_ABORT("fatal error");
}
} else {
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
} else {
GGML_ABORT("fatal error");
}
}
}
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_rope_impl<true>(ctx, dst);
}
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_rope_impl<false>(ctx, dst);
}
void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) {
ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows);
}

View File

@@ -0,0 +1,9 @@
#include "common.cuh"
#define CUDA_ROPE_BLOCK_SIZE 256
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);

View File

@@ -0,0 +1,34 @@
#include "scale.cuh"
#define MAX_GRIDDIM_X 0x7FFFFFFF
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) {
int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x;
for (int64_t i = tid; i < nelements; i += stride) {
dst[i] = scale * x[i] + bias;
}
}
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) {
const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements);
}
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
float scale;
float bias;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream);
}

View File

@@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_SCALE_BLOCK_SIZE 256
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,330 @@
#include "set-rows.cuh"
#include "cpy-utils.cuh"
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
// Generic quantized set_rows kernel template
template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
static __global__ void k_set_rows_quant(const float * __restrict__ src0,
const idx_t * __restrict__ src1,
block_type * __restrict__ dst,
const int64_t ne_total,
const int64_t ne10,
const int64_t ne11,
const int64_t ne12,
const int64_t ne13,
const int64_t s01,
const int64_t s02,
const int64_t s03,
const int64_t s10,
const int64_t s11,
const int64_t s12,
const int64_t s1,
const int64_t s2,
const int64_t s3,
const uint3 ne00,
const uint3 ne01,
const uint3 ne02,
const uint3 ne11_fd,
const uint3 ne12_fd) {
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
if (i >= ne_total) {
return;
}
const int64_t i_base = i * qk;
uint32_t tmp = (uint32_t) i_base;
uint2 div_mod;
div_mod = fast_div_modulo(tmp, ne00);
const int64_t i00 = div_mod.y;
tmp = div_mod.x;
div_mod = fast_div_modulo(tmp, ne01);
const int64_t i01 = div_mod.y;
tmp = div_mod.x;
div_mod = fast_div_modulo(tmp, ne02);
const int64_t i02 = div_mod.y;
const int64_t i03 = div_mod.x;
const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
const int64_t i10 = i01;
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
block_type * dst_row_ptr = dst + (dst_row*s1 + i02*s2 + i03*s3) / sizeof(block_type);
const float * src_block = src0_row + i00;
block_type * dst_block = dst_row_ptr + i00 / qk;
quantize_func(src_block, dst_block);
GGML_UNUSED(ne10);
GGML_UNUSED(ne11);
GGML_UNUSED(ne12);
GGML_UNUSED(ne13);
}
// Template dispatch function for quantized set_rows
template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static void set_rows_cuda_quant(
const float * src0_d, const idx_t * src1_d, block_type * dst_d,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const size_t nb01, const size_t nb02, const size_t nb03,
const size_t nb10, const size_t nb11, const size_t nb12,
const size_t nb1, const size_t nb2, const size_t nb3,
cudaStream_t stream) {
GGML_ASSERT(ne00 % qk == 0);
const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
const dim3 grid_size(num_blocks);
const int64_t s01 = nb01/sizeof(float);
const int64_t s02 = nb02/sizeof(float);
const int64_t s03 = nb03/sizeof(float);
const int64_t s10 = nb10/sizeof(idx_t);
const int64_t s11 = nb11/sizeof(idx_t);
const int64_t s12 = nb12/sizeof(idx_t);
const int64_t s1 = nb1;
const int64_t s2 = nb2;
const int64_t s3 = nb3;
if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,
ne01_fd, ne02_fd, ne11_fd, ne12_fd);
}
}
template <typename src_t, typename idx_t, typename dst_t>
static __global__ void k_set_rows(const src_t * __restrict__ src0,
const idx_t * __restrict__ src1,
dst_t * __restrict__ dst,
const int64_t ne_total,
const int64_t ne10,
const int64_t ne11,
const int64_t ne12,
const int64_t ne13,
const int64_t s01,
const int64_t s02,
const int64_t s03,
const int64_t s10,
const int64_t s11,
const int64_t s12,
const int64_t s1,
const int64_t s2,
const int64_t s3,
const uint3 ne00,
const uint3 ne01,
const uint3 ne02,
const uint3 ne11_fd,
const uint3 ne12_fd) {
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
if (i >= ne_total) {
return;
}
uint32_t tmp = (uint32_t) i;
uint2 div_mod;
div_mod = fast_div_modulo(tmp, ne00);
const int64_t i00 = div_mod.y;
tmp = div_mod.x;
div_mod = fast_div_modulo(tmp, ne01);
const int64_t i01 = div_mod.y;
tmp = div_mod.x;
div_mod = fast_div_modulo(tmp, ne02);
const int64_t i02 = div_mod.y;
const int64_t i03 = div_mod.x;
const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
const int64_t i10 = i01;
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;
dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
GGML_UNUSED(ne10);
GGML_UNUSED(ne11);
GGML_UNUSED(ne12);
GGML_UNUSED(ne13);
}
template<typename src_t, typename idx_t, typename dst_t>
static void set_rows_cuda(
const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const size_t nb01, const size_t nb02, const size_t nb03,
const size_t nb10, const size_t nb11, const size_t nb12,
const size_t nb1, const size_t nb2, const size_t nb3,
cudaStream_t stream) {
const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
const dim3 grid_size(num_blocks);
const int64_t s01 = nb01/sizeof(src_t);
const int64_t s02 = nb02/sizeof(src_t);
const int64_t s03 = nb03/sizeof(src_t);
const int64_t s10 = nb10/sizeof(idx_t);
const int64_t s11 = nb11/sizeof(idx_t);
const int64_t s12 = nb12/sizeof(idx_t);
const int64_t s1 = nb1/sizeof(dst_t);
const int64_t s2 = nb2/sizeof(dst_t);
const int64_t s3 = nb3/sizeof(dst_t);
if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
ne11_fd, ne12_fd);
}
}
template<typename src_t, typename idx_t>
static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const src_t * src0_d = (const src_t *)src0->data;
const idx_t * src1_d = (const idx_t *)src1->data;
GGML_TENSOR_BINARY_OP_LOCALS
cudaStream_t stream = ctx.stream();
if (dst->type == GGML_TYPE_F32) {
set_rows_cuda(
src0_d, src1_d, (float*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_F16) {
set_rows_cuda(
src0_d, src1_d, (half*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_BF16) {
set_rows_cuda(
src0_d, src1_d, (nv_bfloat16*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_Q4_0) {
set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
src0_d, src1_d, (block_q4_0*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_Q4_1) {
set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
src0_d, src1_d, (block_q4_1*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_Q5_0) {
set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
src0_d, src1_d, (block_q5_0*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_Q5_1) {
set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
src0_d, src1_d, (block_q5_1*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_Q8_0) {
set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
src0_d, src1_d, (block_q8_0*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_IQ4_NL) {
set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
src0_d, src1_d, (block_iq4_nl*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else {
GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
}
}
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);
if (src1->type == GGML_TYPE_I64) {
set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);
} else {
set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);
}
}

View File

@@ -0,0 +1,7 @@
#pragma once
#include "common.cuh"
#define CUDA_SET_ROWS_BLOCK_SIZE 256
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

39
ggml/src/ggml-cuda/set.cu Normal file
View File

@@ -0,0 +1,39 @@
#include "set.cuh"
#include "cpy.cuh"
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32));
GGML_ASSERT(src1->type == src0->type);
GGML_ASSERT(dst ->type == src0->type);
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
const size_t nb1 = ((int32_t *) dst->op_params)[0];
const size_t nb2 = ((int32_t *) dst->op_params)[1];
const size_t nb3 = ((int32_t *) dst->op_params)[2];
const size_t offset = ((int32_t *) dst->op_params)[3];
const bool inplace= (bool) ((int32_t *) dst->op_params)[4];
if (!inplace) {
ggml_cuda_cpy(ctx, src0, dst);
}
ggml_tensor dst_view = *dst;
dst_view.data = (void *)((char *)dst->data + offset);
dst_view.ne[0] = src1->ne[0];
dst_view.ne[1] = src1->ne[1];
dst_view.ne[2] = src1->ne[2];
dst_view.ne[3] = src1->ne[3];
dst_view.nb[0] = ggml_element_size(dst);
dst_view.nb[1] = nb1;
dst_view.nb[2] = nb2;
dst_view.nb[3] = nb3;
ggml_cuda_cpy(ctx, src1, &dst_view);
}

Some files were not shown because too many files have changed in this diff Show More