sync from b7516

This commit is contained in:
2026-01-16 11:16:14 +08:00
parent f4ae4cc7da
commit 6ee41dd9e3
380 changed files with 18435 additions and 38806 deletions

View File

@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 9)
set(GGML_VERSION_PATCH 5)
set(GGML_VERSION_PATCH 4)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
@@ -430,22 +430,10 @@ if (MSVC)
configure_msvc_target(ggml-cpu-x64)
configure_msvc_target(ggml-cpu-sse42)
configure_msvc_target(ggml-cpu-sandybridge)
# __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
# skipping ggml-cpu-ivybridge
# skipping ggml-cpu-piledriver
configure_msvc_target(ggml-cpu-haswell)
configure_msvc_target(ggml-cpu-skylakex)
configure_msvc_target(ggml-cpu-cannonlake)
configure_msvc_target(ggml-cpu-cascadelake)
configure_msvc_target(ggml-cpu-icelake)
# MSVC 2022 doesn't support BF16 intrinsics without `/arch:AVX10.1` ?!
# https://learn.microsoft.com/en-us/cpp/intrinsics/x64-amd64-intrinsics-list?view=msvc-170
# https://learn.microsoft.com/en-us/cpp/build/reference/arch-x64?view=msvc-170
# skipping ggml-cpu-cooperlake
# skipping ggml-cpu-zen4
configure_msvc_target(ggml-cpu-alderlake)
# MSVC doesn't support AMX
# skipping ggml-cpu-sapphirerapids
if (GGML_BUILD_EXAMPLES)
configure_msvc_target(common-ggml)

View File

@@ -358,7 +358,7 @@ extern "C" {
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
// Compare the output of two backends
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes);
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node);
// Tensor initialization
GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);

View File

@@ -234,11 +234,6 @@
#if UINTPTR_MAX == 0xFFFFFFFF
#define GGML_MEM_ALIGN 4
#elif defined(__EMSCRIPTEN__)
// emscripten uses max_align_t == 8, so we need GGML_MEM_ALIGN == 8 for 64-bit wasm.
// (for 32-bit wasm, the first conditional is true and GGML_MEM_ALIGN stays 4.)
// ref: https://github.com/ggml-org/llama.cpp/pull/18628
#define GGML_MEM_ALIGN 8
#else
#define GGML_MEM_ALIGN 16
#endif

View File

@@ -357,29 +357,15 @@ if (GGML_CPU_ALL_VARIANTS)
endif()
if (GGML_SYSTEM_ARCH STREQUAL "x86")
ggml_add_cpu_backend_variant(x64)
ggml_add_cpu_backend_variant(sse42 SSE42)
ggml_add_cpu_backend_variant(sandybridge SSE42 AVX)
if (NOT MSVC)
# __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
ggml_add_cpu_backend_variant(ivybridge SSE42 AVX F16C)
ggml_add_cpu_backend_variant(piledriver SSE42 AVX F16C FMA)
endif()
ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C FMA AVX2 BMI2)
ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C FMA AVX2 BMI2 AVX512)
ggml_add_cpu_backend_variant(cannonlake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI)
ggml_add_cpu_backend_variant(cascadelake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI)
ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI)
if (NOT MSVC)
# MSVC 2022 doesn't support BF16 intrinsics without `/arch:AVX10.1` ?!
# https://learn.microsoft.com/en-us/cpp/intrinsics/x64-amd64-intrinsics-list?view=msvc-170
# https://learn.microsoft.com/en-us/cpp/build/reference/arch-x64?view=msvc-170
ggml_add_cpu_backend_variant(cooperlake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI AVX512_BF16)
ggml_add_cpu_backend_variant(zen4 SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16)
endif()
ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C FMA AVX2 BMI2 AVX_VNNI)
ggml_add_cpu_backend_variant(sse42 SSE42)
ggml_add_cpu_backend_variant(sandybridge SSE42 AVX)
ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C AVX2 BMI2 FMA)
ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)
if (NOT MSVC)
# MSVC doesn't support AMX
ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
endif()
elseif(GGML_SYSTEM_ARCH STREQUAL "ARM")
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
@@ -401,8 +387,8 @@ if (GGML_CPU_ALL_VARIANTS)
ggml_add_cpu_backend_variant(android_armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC)
ggml_add_cpu_backend_variant(android_armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8)
ggml_add_cpu_backend_variant(android_armv9.0_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE2)
ggml_add_cpu_backend_variant(android_armv9.2_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
ggml_add_cpu_backend_variant(android_armv9.2_2 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SVE2 SME)
ggml_add_cpu_backend_variant(android_armv9.2_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SME)
ggml_add_cpu_backend_variant(android_armv9.2_2 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
elseif (APPLE)
ggml_add_cpu_backend_variant(apple_m1 DOTPROD)
ggml_add_cpu_backend_variant(apple_m2_m3 DOTPROD MATMUL_INT8)

View File

@@ -144,7 +144,7 @@ extern "C" {
// device description: short informative description of the device, could be the model name
const char * (*get_description)(ggml_backend_dev_t dev);
// device memory in bytes: 0 bytes to indicate no memory to report
// device memory in bytes
void (*get_memory)(ggml_backend_dev_t dev, size_t * free, size_t * total);
// device type

View File

@@ -2053,7 +2053,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
ggml_free(copy.ctx_unallocated);
}
bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes) {
bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) {
struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
if (copy.buffer == NULL) {
return false;
@@ -2064,22 +2064,22 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
assert(g1->n_nodes == g2->n_nodes);
if (num_test_nodes != 0) {
GGML_ASSERT(test_nodes);
// Compute the whole graph and only test the output for specific tensors
if (test_node != nullptr) {
// Compute the whole graph and only test the output for a specific tensor
ggml_backend_graph_compute(backend1, g1);
ggml_backend_graph_compute(backend2, g2);
bool verified = false;
int test_node_idx = -1;
for (int i = 0; i < g1->n_nodes; i++) {
for (size_t j = 0; j < num_test_nodes; ++j) {
if (g1->nodes[i] == test_nodes[j]) {
callback(i, g1->nodes[i], g2->nodes[i], user_data);
verified = true;
}
struct ggml_tensor * t1 = g1->nodes[i];
if (t1 == test_node) {
test_node_idx = i;
break;
}
}
GGML_ASSERT(verified);
GGML_ASSERT(test_node_idx != -1);
callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data);
} else {
for (int i = 0; i < g1->n_nodes; i++) {
struct ggml_tensor * t1 = g1->nodes[i];

View File

@@ -32,12 +32,14 @@ if (BLAS_FOUND)
pkg_check_modules(DepBLAS openblas)
endif()
elseif (${GGML_BLAS_VENDOR} MATCHES "FLAME")
add_compile_definitions(GGML_BLAS_USE_BLIS)
pkg_check_modules(DepBLAS blis)
elseif (${GGML_BLAS_VENDOR} MATCHES "ATLAS")
pkg_check_modules(DepBLAS blas-atlas)
elseif (${GGML_BLAS_VENDOR} MATCHES "FlexiBLAS")
pkg_check_modules(DepBLAS flexiblas_api)
elseif (${GGML_BLAS_VENDOR} MATCHES "Intel")
add_compile_definitions(GGML_BLAS_USE_MKL)
# all Intel* libraries share the same include path
pkg_check_modules(DepBLAS mkl-sdl)
elseif (${GGML_BLAS_VENDOR} MATCHES "NVHPC")
@@ -72,26 +74,10 @@ if (BLAS_FOUND)
target_compile_options(ggml-blas PRIVATE ${BLAS_LINKER_FLAGS})
if ("${GGML_BLAS_VENDOR}" STREQUAL "")
message(WARNING "GGML_BLAS_VENDOR is not set; some methods may not link properly.")
endif()
if ("${GGML_BLAS_VENDOR}" MATCHES "Intel" OR ("${BLAS_INCLUDE_DIRS}" MATCHES "mkl" AND "${GGML_BLAS_VENDOR}" MATCHES "Generic"))
if ("${BLAS_INCLUDE_DIRS}" MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel"))
add_compile_definitions(GGML_BLAS_USE_MKL)
endif()
if ("${GGML_BLAS_VENDOR}" MATCHES "OpenBLAS")
add_compile_definitions(GGML_BLAS_USE_OPENBLAS)
endif()
if ("${GGML_BLAS_VENDOR}" MATCHES "FLAME" OR "${GGML_BLAS_VENDOR}" MATCHES "AOCL" OR "${GGML_BLAS_VENDOR}" MATCHES "AOCL_mt")
add_compile_definitions(GGML_BLAS_USE_BLIS)
endif()
if ("${GGML_BLAS_VENDOR}" MATCHES "NVPL")
add_compile_definitions(GGML_BLAS_USE_NVPL)
endif()
target_link_libraries (ggml-blas PRIVATE ${BLAS_LIBRARIES})
target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS})
else()

View File

@@ -115,11 +115,15 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
#endif
}
#if defined(GGML_BLAS_USE_OPENBLAS)
#if defined(OPENBLAS_VERSION)
openblas_set_num_threads(ctx->n_threads);
#elif defined(GGML_BLAS_USE_BLIS)
#endif
#if defined(GGML_BLAS_USE_BLIS)
bli_thread_set_num_threads(ctx->n_threads);
#elif defined(GGML_BLAS_USE_NVPL)
#endif
#if defined(GGML_BLAS_USE_NVPL)
nvpl_blas_set_num_threads(ctx->n_threads);
#endif
@@ -284,7 +288,7 @@ ggml_backend_t ggml_backend_blas_init(void) {
/* .context = */ ctx,
};
#if defined(GGML_BLAS_USE_OPENBLAS) && defined(GGML_USE_OPENMP)
#if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
if (openblas_get_parallel() != OPENBLAS_OPENMP) {
GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
}
@@ -325,7 +329,7 @@ static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t
return "BLIS";
#elif defined(GGML_BLAS_USE_NVPL)
return "NVPL";
#elif defined(GGML_BLAS_USE_OPENBLAS)
#elif defined(OPENBLAS_VERSION)
return "OpenBLAS";
#else
return "BLAS";

View File

@@ -26,7 +26,6 @@
#include "ggml.h"
#include <aclnnop/aclnn_add.h>
#include <aclnnop/aclnn_add_rms_norm.h>
#include <aclnnop/aclnn_addcdiv.h>
#include <aclnnop/aclnn_argmax.h>
#include <aclnnop/aclnn_avgpool2d.h>
@@ -1963,7 +1962,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor *
acl_tensor_ptr acl_weight_tensor;
// Only check env once.
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (weight_to_nz && is_matmul_weight(weight)) {
acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
} else {
@@ -2339,19 +2338,19 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
// Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
// TODO: acl_yarn_ramp_tensor use rope cache.
bool yarn_ramp_tensor_updated = false;
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
acl_tensor_ptr acl_yarn_ramp_tensor;
if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length ||
ctx.rope_cache.freq_scale != freq_scale)) {
yarn_ramp_tensor_updated = true;
if (ctx.rope_cache.yarn_ramp_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
// -rope_yarn_ramp
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
// return MIN(1, MAX(0, y)) - 1;
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor =
ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
@@ -2381,10 +2380,8 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
} else {
acl_yarn_ramp_tensor =
ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
}
// Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
if (ext_factor != 0) {
if (theta_scale_updated || yarn_ramp_tensor_updated) {
@@ -2991,156 +2988,32 @@ void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src.get(), 3, false, acl_dst.get());
}
void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];
// stride
int64_t s0 = ((const int32_t*)(dst->op_params))[0];
int64_t s0 = ((const int32_t *) (dst->op_params))[0];
acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
// get base information of input and kernel
int64_t input_len = *(src1->ne);
int64_t dst_len = *(dst->ne);
int64_t kernel_size = *(src0->ne);
// set the max kernel size for each conv
int64_t max_kernel_size = 255;
// compute the partition of kernel
int64_t part_num = 1;
part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size;
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
int64_t strideVal[1];
strideVal[0] = s0;
acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
int64_t paddingVal[] = {0};
acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
int64_t dilationVal[] = {1};
acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
bool transposed = true;
int64_t groups = 1;
int8_t cubeMathType = 0;
strideVal[0] = s0;
acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
int64_t paddingVal[] = { 0 };
acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
int64_t dilationVal[] = { 1 };
acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
int8_t cubeMathType = 0;
#ifdef ASCEND_310P
cubeMathType = 1;
#endif
auto weight_type = ggml_cann_type_mapping(src0->type);
auto dst_type = ggml_cann_type_mapping(dst->type);
// slice the kernel to make each conv available
int64_t slice_dim = -1;
int64_t slice_start = 0;
int64_t slice_end = max_kernel_size;
int64_t slice_step = 1;
int64_t interval = max_kernel_size;
int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0];
int64_t right_pad_len = 0;
acl_scalar_ptr alpha = nullptr;
float alphaValue = 1.0;
alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);
// set zero to destination
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
for(int k = 0; k < part_num; k++){
// create part kernel tensor and slice from big kernel
slice_start = max_kernel_size * k;
if(k == part_num - 1){
slice_end = kernel_size;
interval = kernel_size - max_kernel_size * k;
}else{
slice_end = max_kernel_size * (k+1);
}
int64_t part_ne[4];
for(int i = 0; i < 4; i++) {
part_ne[i] = *(src0->ne + i);
}
part_ne[0] = interval;
size_t part_nb[4];
part_nb[0] = sizeof(weight_type);
for (int i = 1; i < 4; i++) {
part_nb[i] = part_nb[i - 1] * part_ne[i - 1];
}
ggml_cann_pool_alloc part_kernel_allocator;
part_kernel_allocator.alloc(ctx.pool(), part_nb[3]);
void* part_kernel_buf = part_kernel_allocator.get();
acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type,
ggml_element_size(src0), part_ne, part_nb, 3, ACL_FORMAT_NCL);
GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, part_kernel.get());
// create the part conv result tensor
int64_t part_dst_ne[4];
for(int i = 0; i < 4; i++){
part_dst_ne[i] = *(dst->ne + i);
}
part_dst_ne[0] = (input_len - 1) * strideVal[0] - 2 * paddingVal[0] + dilationVal[0] * (part_ne[0] - 1) + 1;
size_t part_dst_nb[4];
part_dst_nb[0] = sizeof(weight_type);
for (int i = 1; i < 4; i++) {
part_dst_nb[i] = part_dst_nb[i - 1] * part_dst_ne[i - 1];
}
ggml_cann_pool_alloc part_dst_allocator;
part_dst_allocator.alloc(ctx.pool(), part_dst_nb[3]);
void* part_dst_buf = part_dst_allocator.get();
acl_tensor_ptr acl_part_dst = ggml_cann_create_tensor(part_dst_buf, dst_type, ggml_element_size(dst),
part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_part_dst.get());
// compute part conv transpose 1d
GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(),
padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), cubeMathType);
// compute the position of part result in final result
int64_t global_start = slice_start;
int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len);
left_pad_len = global_start;
right_pad_len = dst_len - global_end;
std::vector<int64_t> padDataVal = {left_pad_len,right_pad_len};
acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2);
acl_scalar_ptr pad_value = nullptr;
float pad_valueVal = 0.0;
pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);
int64_t conv_result_ne[4];
for(int i = 0; i < 4; i++){
conv_result_ne[i] = *(dst->ne + i);
}
size_t conv_result_nb[4];
conv_result_nb[0] = sizeof(weight_type);
for (int i = 1; i < 4; i++) {
conv_result_nb[i] = conv_result_nb[i - 1] * conv_result_ne[i - 1];
}
ggml_cann_pool_alloc conv_result_allocator;
conv_result_allocator.alloc(ctx.pool(), conv_result_nb[3]);
void* conv_result_buf = conv_result_allocator.get();
acl_tensor_ptr conv_result = ggml_cann_create_tensor(conv_result_buf, dst_type, ggml_element_size(dst),
conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get());
GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), conv_result.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get());
}
GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), acl_weight.get(), nullptr, stride.get(), padding.get(),
dilation.get(), true, padding.get(), 1, acl_dst.get(), cubeMathType);
}
void ggml_cann_elu(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
@@ -3703,160 +3576,3 @@ void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
break;
}
}
void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; // conv_x
ggml_tensor * src1 = dst->src[1]; // conv1d.weight
// This op is currently defined only for F32 in ggml_cpu
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
// Shapes follow ggml_compute_forward_ssm_conv_f32
const int64_t nc = src1->ne[0]; // d_conv
const int64_t ncs = src0->ne[0]; // d_conv - 1 + n_t
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_s = src0->ne[2]; // n_seqs
const int64_t n_t = dst->ne[1]; // tokens per sequence
GGML_ASSERT(dst->ne[0] == nr); // dst: {d_inner, n_t, n_s}
GGML_ASSERT(src1->ne[1] == nr); // weight: {d_conv, d_inner}
GGML_ASSERT(ncs == nc - 1 + n_t); // conv_x: {d_conv - 1 + n_t, d_inner, n_s}
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
// --- Build CANN tensors ---
// 1) Input: conv_x as NCL
//
// src0->ne = { ncs, nr, n_s, 1 } // {L_in, C, N}
// Passing ACL_FORMAT_NCL here means:
// reversed dims -> [N, C, L_in] = [n_s, nr, ncs]
acl_tensor_ptr acl_x = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);
// 2) Weights: depthwise conv kernel, view src1 as {K, 1, C}
//
// src1 original: ne = { nc, nr, 1, 1 } // [K, C, 1, 1]
// we want a view: ne_w = { nc, 1, nr } // [K, 1, C]
// so that reversed dims -> [C, 1, K] which matches
// [out_channels, in_channels/groups, kernel_size]
int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups]
// Layout: src1 data is [K, C] with
// offset(k, c) = k*nb0 + c*nb1
// We want offset_w(k, 0, c) = k*nb0 + c*nb1,
// so we can reuse nb0 and nb1, and set nb2 = nb1.
size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1
acl_tensor_ptr acl_w = ggml_cann_create_tensor(
src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL);
// 3) Output: dst is { d_inner, n_t, n_s } (CLN)
//
// We need an NCL view of the same buffer:
// desired NCL logical shape: { L_out = n_t, C = nr, N = n_s }
//
// Original CLN layout:
// dst->ne = { nr, n_t, n_s }
// dst->nb[0] = sizeof(float)
// dst->nb[1] = nr * sizeof(float)
// dst->nb[2] = nr * n_t * sizeof(float)
//
// We want offset_new(L, C, N) = offset_orig(C, L, N).
// Choose:
// nb_y[0] = nr * sizeof(float); // step in L
// nb_y[1] = sizeof(float); // step in C
// nb_y[2] = nr * n_t * sizeof(float); // step in N
int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N]
size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), dst->nb[3] }; // [nr, 1, nr * n_t]
acl_tensor_ptr acl_y = ggml_cann_create_tensor(
dst->data, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL);
// --- Conv1d parameters: depthwise, stride 1, no padding ("valid") ---
int64_t strideVal[1] = { 1 };
int64_t paddingVal[1] = { 0 };
int64_t dilationVal[1] = { 1 };
acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
const bool transposed = false;
const int64_t groups = nr; // depthwise: one group per inner dim
int8_t cubeMathType = 0;
#ifdef ASCEND_310P
cubeMathType = 1;
#endif
GGML_CANN_CALL_ACLNN_OP(ctx,
Convolution,
acl_x.get(), // input: N, C, L_in = ncs
acl_w.get(), // weight: [C, 1, K] with groups=nr
nullptr, // bias
stride.get(),
padding.get(),
dilation.get(),
transposed,
padding.get(), // output padding (unused for non-transposed)
groups,
acl_y.get(),
cubeMathType);
}
void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
ggml_tensor * add_node,
ggml_tensor * rms_norm_node) {
// Get the two input tensors for ADD operation
ggml_tensor * x1 = add_node->src[0];
ggml_tensor * x2 = add_node->src[1];
// Create ACL tensors for the two ADD inputs
acl_tensor_ptr acl_x1 = ggml_cann_create_tensor(x1);
acl_tensor_ptr acl_x2 = ggml_cann_create_tensor(x2);
// Get epsilon parameter from rms_norm_tensor
float eps;
memcpy(&eps, rms_norm_node->op_params, sizeof(float));
// Build gamma tensor (RMS normalization scaling factor)
// Gamma should match the normalized dimensions (last dimension of x1)
size_t acl_gamma_nb[GGML_MAX_DIMS];
acl_gamma_nb[0] = ggml_type_size(rms_norm_node->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
acl_gamma_nb[i] = acl_gamma_nb[i - 1] * x1->ne[i - 1];
}
acl_tensor_ptr acl_gamma =
get_cache_acl_tensor(ctx, &ctx.rms_norm_one_tensor_cache.cache, ctx.rms_norm_one_tensor_cache.size, x1->ne,
acl_gamma_nb, rms_norm_node->type,
1, // dims - only the last dimension
1.0f // value
);
// Build rstdOut tensor (output for normalized standard deviation)
// Shape should be the dimensions that are NOT normalized
int64_t acl_rstd_ne[] = { 1, x1->ne[1], x1->ne[2], x1->ne[3] };
size_t acl_rstd_nb[GGML_MAX_DIMS - 1];
acl_rstd_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
}
acl_tensor_ptr acl_rstd =
get_cache_acl_tensor(ctx, &ctx.rms_norm_zero_tensor_cache.cache, ctx.rms_norm_zero_tensor_cache.size,
acl_rstd_ne, acl_rstd_nb, GGML_TYPE_F32, GGML_MAX_DIMS,
0.0f // value
);
acl_tensor_ptr acl_xout = ggml_cann_create_tensor(add_node);
// Create yOut tensor (final output after RMS normalization)
acl_tensor_ptr acl_yout = ggml_cann_create_tensor(rms_norm_node);
// Call fused ADD + RMS_NORM operator
GGML_CANN_CALL_ACLNN_OP(ctx, AddRmsNorm, acl_x1.get(), acl_x2.get(), acl_gamma.get(),
eps, // double type
acl_yout.get(), acl_rstd.get(), acl_xout.get());
}

View File

@@ -47,7 +47,6 @@
#include <aclnnop/aclnn_sign.h>
#include <aclnnop/aclnn_silu.h>
#include <aclnnop/aclnn_sin.h>
#include <aclnnop/aclnn_slice.h>
#include <aclnnop/aclnn_sqrt.h>
#include <aclnnop/aclnn_tanh.h>
@@ -935,20 +934,6 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso
*/
void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Performs fused ADD + RMS_NORM operation using the CANN backend.
*
* This function fuses the ADD and RMS_NORM operations into a single kernel call
* for better performance. It first adds two input tensors (x1 + x2), then applies
* RMS normalization to the result.
*
* @param ctx The context for the CANN backend operations.
* @param dst The ADD operation node, contains the two input tensors to be added.
* @param rms_norm_tensor The RMS_NORM operation node, contains the gamma weights
* and epsilon parameter.
*/
void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, ggml_tensor * add_node, ggml_tensor * rms_norm_node);
/**
* @brief Check whether a tensor is a weight tensor for matrix multiplication.
*
@@ -1047,8 +1032,6 @@ void ggml_cann_op_unary(std::function<void(ggml_backend_cann_context &, aclTenso
ggml_backend_cann_context & ctx,
ggml_tensor * dst);
void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Applies a gated (GLU-style) unary operation using the CANN backend.
*

View File

@@ -103,7 +103,7 @@ const ggml_cann_device_info & ggml_cann_info();
void ggml_cann_set_device(int32_t device);
int32_t ggml_cann_get_device();
std::optional<std::string> get_env_as_lowercase(const std::string & name);
std::optional<std::string> get_env(const std::string & name);
bool parse_bool(const std::string & value);
int parse_integer(const std::string & value);
@@ -229,60 +229,6 @@ struct ggml_graph_node_properties {
// op
ggml_op node_op;
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
/**
* @brief Check if a ggml tensor node matches this property set.
*
* This function compares all relevant fields (address, op type, shape, source inputs, op params)
* to determine whether the current node matches these previously recorded properties.
*
* @param node The current ggml tensor node.
* @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
*/
bool has_matching_properties(ggml_tensor * node) {
if (node->data != this->node_address && node->op != GGML_OP_VIEW) {
return false;
}
if (node->op != this->node_op) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (node->ne[i] != this->ne[i]) {
return false;
}
if (node->nb[i] != this->nb[i]) {
return false;
}
}
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node->src[i]) {
if (node->src[i]->data != this->src_address[i] && node->op != GGML_OP_VIEW) {
return false;
}
for (int d = 0; d < GGML_MAX_DIMS; d++) {
if (node->src[i]->ne[d] != this->src_ne[i][d]) {
return false;
}
if (node->src[i]->nb[d] != this->src_nb[i][d]) {
return false;
}
}
} else {
if (this->src_address[i] != nullptr) {
return false;
}
}
}
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
}
return true;
}
};
struct ggml_cann_graph {
@@ -295,79 +241,6 @@ struct ggml_cann_graph {
aclmdlRI graph = nullptr;
std::vector<ggml_graph_node_properties> ggml_graph_properties;
/**
* @brief Create a new CANN graph from a ggml computation graph.
*
* This function creates a new ggml_cann_graph object and fills its node properties
* (operation type, dimensions, strides, input sources, and operation parameters)
* based on the current ggml computation graph.
*
* Each node in the ggml graph is mapped to a property entry in the new CANN graph:
* - node address
* - operation type
* - shape (ne) and strides (nb)
* - source tensor addresses
* - operation parameters
*
* @param cgraph The current ggml computation graph.
* @return Pointer to the newly created ggml_cann_graph object.
*/
static ggml_cann_graph * create_from_cgraph(ggml_cgraph * cgraph) {
ggml_cann_graph * new_graph = new ggml_cann_graph();
new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
ggml_tensor * node = cgraph->nodes[node_idx];
auto & prop = new_graph->ggml_graph_properties[node_idx];
prop.node_address = node->data;
prop.node_op = node->op;
std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
for (int src = 0; src < GGML_MAX_SRC; ++src) {
if (node->src[src]) {
prop.src_address[src] = node->src[src]->data;
std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
} else {
prop.src_address[src] = nullptr;
std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
}
}
memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
}
return new_graph;
}
/**
* @brief Check whether this CANN graph matches the given ggml computation graph.
*
* This function compares the number of nodes and each node's properties
* (operation type, dimensions, strides, inputs, and operation parameters)
* to determine whether this CANN graph matches the given ggml graph.
*
* @param cgraph The current ggml computation graph.
* @return true if this CANN graph matches the ggml graph; false otherwise.
*/
bool matches_cgraph(ggml_cgraph * cgraph) {
if (this->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
return false;
}
for (int i = 0; i < cgraph->n_nodes; ++i) {
if (!this->ggml_graph_properties[i].has_matching_properties(cgraph->nodes[i])) {
return false;
}
}
return true;
}
};
/**
@@ -399,6 +272,15 @@ struct ggml_cann_graph_lru_cache {
cache_list.push_front(new_node);
}
/**
* @brief Move an existing graph to the front of the cache.
* @param node Pointer to the ggml_cann_graph to move.
*/
void move_to_front(ggml_cann_graph * node) {
cache_list.remove(node);
cache_list.push_front(node);
}
/**
* @brief Clear all graphs from the cache (also frees memory).
*/
@@ -413,28 +295,6 @@ struct ggml_cann_graph_lru_cache {
* @brief Destructor that clears the cache and frees all cached graphs.
*/
~ggml_cann_graph_lru_cache() { clear(); }
/**
* @brief Find a cached CANN graph that matches the given ggml graph and move it to front.
*
* This function iterates through the cached CANN graphs stored in the LRU cache and
* compares them against the given ggml computation graph. If a matching graph is found,
* it is promoted to the front of the LRU cache and returned. Otherwise, the function
* returns nullptr.
*
* @param cgraph The current ggml computation graph.
* @return true if found; false otherwise.
*/
bool find_and_move_to_front(ggml_cgraph * cgraph) {
for (auto & graph_ptr : this->cache_list) {
if (graph_ptr->matches_cgraph(cgraph)) {
cache_list.remove(graph_ptr);
cache_list.push_front(graph_ptr);
return true;
}
}
return false;
}
};
#endif // USE_ACL_GRAPH
@@ -458,9 +318,6 @@ struct ggml_cann_rope_cache {
if (position_select_index_host) {
free(position_select_index_host);
}
if (yarn_ramp_cache) {
ACL_CHECK(aclrtFree(yarn_ramp_cache));
}
}
bool equal(int64_t theta_scale_length,
@@ -513,7 +370,6 @@ struct ggml_cann_rope_cache {
float * theta_scale_exp_host = nullptr;
int * position_select_index_host = nullptr;
void * position_select_index = nullptr;
void * yarn_ramp_cache = nullptr;
// sin/cos cache, used only to accelerate first layer on each device
void * sin_cache = nullptr;
void * cos_cache = nullptr;

View File

@@ -105,10 +105,10 @@ int32_t ggml_cann_get_device() {
}
/**
* @brief Get the value of the specified environment variable (name) as lowercase.
* @brief Get the value of the specified environment variable (name).
* if not empty, return a std::string object
*/
std::optional<std::string> get_env_as_lowercase(const std::string & name) {
std::optional<std::string> get_env(const std::string & name) {
const char * val = std::getenv(name.c_str());
if (!val) {
return std::nullopt;
@@ -122,7 +122,7 @@ std::optional<std::string> get_env_as_lowercase(const std::string & name) {
* @brief Verify whether the environment variable is a valid value.
*/
bool parse_bool(const std::string & value) {
static const std::unordered_set<std::string> valid_values = { "on", "1", "yes", "y", "enable", "true" };
std::unordered_set<std::string> valid_values = { "on", "1", "yes", "y", "enable", "true" };
return valid_values.find(value) != valid_values.end();
}
@@ -259,7 +259,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
* @param device The device ID to associate with this buffer pool.
*/
explicit ggml_cann_pool_buf_prio(int device) : device(device) {
disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
}
/**
@@ -452,7 +452,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
* @param device The device ID to associate with this buffer pool.
*/
explicit ggml_cann_pool_buf(int device) : device(device) {
disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
}
/**
@@ -764,7 +764,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
* @return A unique pointer to the created CANN pool.
*/
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(int device) {
std::string mem_pool_type = get_env_as_lowercase("GGML_CANN_MEM_POOL").value_or("");
std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or("");
if (mem_pool_type == "prio") {
GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);
@@ -1217,7 +1217,7 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
// Why aclrtSynchronizeDevice?
// Only check env once.
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (!need_transform(tensor->type)) {
ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
@@ -1442,7 +1442,7 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t
int64_t ne0 = tensor->ne[0];
// Only check env once.
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
// last line must bigger than 32, because every single op deal at
// least 32 bytes.
@@ -1889,9 +1889,6 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_OP_OUT_PROD:
ggml_cann_out_prod(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_cann_ssm_conv(ctx, dst);
break;
default:
return false;
}
@@ -2078,39 +2075,161 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
}
#ifdef USE_ACL_GRAPH
/**
* @brief Check if CANN backend can fuse the specified operation sequence
* @brief Add a new CANN graph to the LRU cache by populating node properties from the ggml graph.
*
* This function determines whether an operation sequence starting from the specified node
* can be fused into an optimized operation in the CANN backend. Operation fusion can reduce
* memory access overhead and improve computational efficiency.
* This function creates a new ggml_cann_graph object and fills its node properties
* (operation type, dimensions, strides, input sources, and operation parameters)
* based on the current ggml computation graph.
*
* @param cgraph Pointer to the computation graph
* @param node_idx Index of the starting node in the computation graph
* @param ops Sequence of operation types to check for fusion
* @return true if the operations can be fused
* @return false if the operations cannot be fused
* Each node in the ggml graph is mapped to a property entry in the new CANN graph:
* - node address
* - operation type
* - shape (ne) and strides (nb)
* - source tensor addresses
* - operation parameters
*
* After initialization, the new graph is pushed into the LRU cache owned by the
* CANN backend context. The cache takes ownership of the graph and manages its
* lifetime (including deletion upon eviction).
*
* @param cann_ctx The CANN backend context containing the graph cache.
* @param cgraph The current ggml computation graph.
*/
static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph,
int node_idx,
std::initializer_list<enum ggml_op> ops) {
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
static void add_lru_matched_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
// Create a new ggml_cann_graph object on the heap (its lifetime is managed by the cache).
ggml_cann_graph * new_graph = new ggml_cann_graph();
new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
ggml_tensor * node = cgraph->nodes[node_idx];
auto & prop = new_graph->ggml_graph_properties[node_idx];
prop.node_address = node->data;
prop.node_op = node->op;
std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
for (int src = 0; src < GGML_MAX_SRC; ++src) {
if (node->src[src]) {
prop.src_address[src] = node->src[src]->data;
std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
} else {
prop.src_address[src] = nullptr;
std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
}
}
memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
}
// Insert into the LRU cache (cache takes ownership and will delete it when evicted).
cann_ctx->graph_lru_cache.push(new_graph);
}
/**
* @brief Check if a ggml tensor node matches a previously captured CANN graph node.
*
* This function compares all relevant fields (address, op type, shape, source inputs, op params)
* to determine whether the current node matches a previously recorded version.
*
* @param node The current ggml tensor node.
* @param graph_node_properties The stored properties of a CANN graph node.
* @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
*/
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node,
ggml_graph_node_properties * graph_node_properties) {
if (node->data != graph_node_properties->node_address && node->op != GGML_OP_VIEW) {
return false;
}
// CANN backend supports fusing ADD + RMS_NORM operations
if ((ops.size() == 2) && ops.begin()[0] == GGML_OP_ADD && ops.begin()[1] == GGML_OP_RMS_NORM) {
ggml_tensor * add_node = cgraph->nodes[node_idx];
// TODO: support broadcast for ADD + RMS_NORM
if (add_node->src[0]->ne[0] != add_node->src[1]->ne[0] || add_node->src[0]->ne[1] != add_node->src[1]->ne[1] ||
add_node->src[0]->ne[2] != add_node->src[1]->ne[2] || add_node->src[0]->ne[3] != add_node->src[1]->ne[3]) {
if (node->op != graph_node_properties->node_op) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (node->ne[i] != graph_node_properties->ne[i]) {
return false;
}
return true;
if (node->nb[i] != graph_node_properties->nb[i]) {
return false;
}
}
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node->src[i]) {
if (node->src[i]->data != graph_node_properties->src_address[i] && node->op != GGML_OP_VIEW) {
return false;
}
for (int d = 0; d < GGML_MAX_DIMS; d++) {
if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) {
return false;
}
if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) {
return false;
}
}
} else {
if (graph_node_properties->src_address[i] != nullptr) {
return false;
}
}
}
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
}
return true;
}
/**
* @brief Check whether there is a cached CANN graph that matches the current ggml graph.
*
* This function iterates through the cached CANN graphs stored in the LRU cache and
* compares them against the given ggml computation graph. A match requires that the
* number of nodes is the same and that each nodes properties (operation type,
* dimensions, strides, inputs, and operation parameters) are identical.
*
* If a matching graph is found, it is promoted to the front of the LRU cache and the
* function returns true. Otherwise, the function returns false, indicating that a new
* CANN graph needs to be captured.
*
* @param cann_ctx The CANN backend context containing the graph cache.
* @param cgraph The current ggml computation graph.
* @return true if a matching cached graph exists; false otherwise.
*/
static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
ggml_cann_graph_lru_cache & lru_cache = cann_ctx->graph_lru_cache;
for (auto & graph_ptr : lru_cache.cache_list) {
// Skip graphs with a different number of nodes.
if (graph_ptr->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
continue;
}
// Check if all nodes match.
bool all_match = true;
for (int i = 0; i < cgraph->n_nodes; ++i) {
if (!ggml_graph_node_has_matching_properties(cgraph->nodes[i], &graph_ptr->ggml_graph_properties[i])) {
all_match = false;
break;
}
}
if (all_match) {
// update cache_list && renturn graph_ptr
lru_cache.move_to_front(graph_ptr);
return true;
}
}
return false;
}
#endif // USE_ACL_GRAPH
/**
* @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
@@ -2120,34 +2239,25 @@ static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph,
*
* Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
*
* @param cann_ctx The CANN backend context.
* @param cgraph The ggml computation graph.
* @param use_cann_graph Whether to use CANN graph execution.
* @param cann_graph_capture_required Whether graph capture is needed due to graph changes.
* @param cann_ctx The CANN backend context.
* @param cgraph The ggml computation graph.
* @param use_cann_graph Whether to use CANN graph execution.
* @param cann_graph_update_required Whether graph capture is needed due to graph changes.
*/
static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx,
ggml_cgraph * cgraph,
bool use_cann_graph,
bool cann_graph_capture_required) {
bool & use_cann_graph,
bool & cann_graph_update_required) {
#ifdef USE_ACL_GRAPH
if (use_cann_graph && cann_graph_capture_required) { // Begin CANN graph capture
if (use_cann_graph && cann_graph_update_required) { // Begin CANN graph capture
ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
}
#endif // USE_ACL_GRAPH
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
// With the use of CANN graphs, the execution will be performed by the graph launch.
static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));
if (!use_cann_graph || cann_graph_capture_required) {
if (!use_cann_graph || cann_graph_update_required) {
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (opt_fusion) {
if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
i++;
continue;
}
}
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
@@ -2164,10 +2274,9 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
#ifdef USE_ACL_GRAPH
if (use_cann_graph) {
GGML_ASSERT(!cann_ctx->graph_lru_cache.cache_list.empty());
ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
if (cann_graph_capture_required) { // End CANN graph capture
if (cann_graph_update_required) { // End CANN graph capture
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
}
@@ -2197,11 +2306,11 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
// calculate rope cache for fist layer in current device.
cann_ctx->rope_cache.cached = false;
bool graph_capture_required = false;
bool cann_graph_update_required = false;
#ifdef USE_ACL_GRAPH
bool use_cann_graph = true;
static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
if (!prefill_use_graph) {
// Do not use acl_graph for prefill.
for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -2222,17 +2331,16 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
if (use_cann_graph) {
// If no matching graph is found, the graph needs to be recaptured.
graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph);
if (graph_capture_required) {
cann_graph_update_required = !is_matched_graph(cann_ctx, cgraph);
if (cann_graph_update_required) {
// If no matching graph is found, add a new ACL graph.
ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
cann_ctx->graph_lru_cache.push(new_graph);
add_lru_matched_graph_node_properties(cann_ctx, cgraph);
}
}
#else
bool use_cann_graph = false;
#endif // USE_ACL_GRAPH
evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, graph_capture_required);
evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required);
return GGML_STATUS_SUCCESS;
}
@@ -2470,7 +2578,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
}
}
case GGML_OP_CONV_TRANSPOSE_1D:
return true;
// TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255.
return (op->src[0]->ne[0] - 1) <= 255;
case GGML_OP_SCALE:
float bias;
memcpy(&bias, (const float *) (op->op_params) + 1, sizeof(float));
@@ -2517,8 +2626,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
}
return true;
}
case GGML_OP_SSM_CONV:
return true;
default:
return false;
}
@@ -2541,6 +2648,27 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
}
/**
* @brief Determines if a tensor operation should be offloaded to the CANN
* backend.
*
* This function checks if a given tensor operation should be offloaded to the
* CANN backend based on the operation type and the size of the tensor. It
* returns true if the second dimension (ne[1]) of the tensor is greater than or
* equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
*
* @param backend Pointer to the CANN backend.
* @param op Pointer to the tensor operation to check.
* @return bool Returns true if the operation should be offloaded, otherwise
* false.
*/
static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
const int min_batch_size = 32;
GGML_UNUSED(dev);
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
}
/**
* @brief Records an event on the CANN backend stream.
*
@@ -2616,7 +2744,6 @@ struct ggml_backend_cann_device_context {
int device;
std::string name;
std::string description;
int op_offload_min_batch_size;
};
static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
@@ -2693,26 +2820,6 @@ static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(
return ggml_backend_cann_host_buffer_type();
}
/**
* @brief Determines if a tensor operation should be offloaded to the CANN
* backend.
*
* This function checks if a given tensor operation should be offloaded to the
* CANN backend based on the operation type and the size of the tensor. It
* returns true if the second dimension (ne[1]) of the tensor is greater than or
* equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
*
* @param backend Pointer to the CANN backend.
* @param op Pointer to the tensor operation to check.
* @return bool Returns true if the operation should be offloaded, otherwise
* false.
*/
static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
return op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS;
}
/**
* @brief Creates a new event for the CANN backend device.
*
@@ -2829,14 +2936,12 @@ ggml_backend_reg_t ggml_backend_cann_reg() {
if (!initialized) {
aclInit(nullptr);
ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
for (int i = 0; i < ggml_cann_info().device_count; i++) {
ggml_backend_cann_device_context * dev_ctx = new ggml_backend_cann_device_context();
dev_ctx->description = aclrtGetSocName();
dev_ctx->device = i;
dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
dev_ctx->op_offload_min_batch_size = min_batch_size;
ggml_cann_set_device(i);
ggml_backend_dev_t dev = new ggml_backend_device{ /* .iface = */ ggml_backend_cann_device_interface,
/* .reg = */ &reg,

View File

@@ -561,9 +561,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
# Fetch KleidiAI sources:
include(FetchContent)
set(KLEIDIAI_COMMIT_TAG "v1.16.0")
set(KLEIDIAI_COMMIT_TAG "v1.14.0")
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
set(KLEIDIAI_ARCHIVE_MD5 "0a9e9008adb6031f9e8cf70dff4a3321")
set(KLEIDIAI_ARCHIVE_MD5 "45e110675d93f99f82c23a1afcca76bc")
if (POLICY CMP0135)
cmake_policy(SET CMP0135 NEW)
@@ -615,7 +615,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED)
string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
string(FIND "${ARCH_FLAGS_TEMP}" "+sve" SVE_ENABLED)
set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP})
@@ -660,15 +659,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
endif()
if (NOT SVE_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/kai_common_sve_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.c)
endif()
set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES})
endif()

View File

@@ -328,7 +328,7 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <intrin.h>
#elif defined(__SSE__) || defined(__SSE3__) || defined(__SSSE3__) || defined(__AVX__) || defined(__F16C__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX512BF16__)
#elif defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
#include <immintrin.h>
#endif

View File

@@ -18,8 +18,6 @@
#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h"
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
@@ -71,9 +69,9 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
const void* lhs, const void* rhs, void* dst,
size_t dst_stride_row, size_t dst_stride_col,
float clamp_min, float clamp_max) {
const void* lhs, const void* rhs, void* dst,
size_t dst_stride_row, size_t dst_stride_col,
float clamp_min, float clamp_max) {
Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
}
@@ -154,8 +152,8 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>
static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
void* rhs_packed, size_t extra_bytes, const void* params) {
size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
void* rhs_packed, size_t extra_bytes, const void* params) {
Fn(num_groups, n, k, nr, kr, sr,
static_cast<const int8_t*>(rhs),
static_cast<const float*>(bias),
@@ -526,61 +524,6 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
},
#endif
#else
#if defined(__ARM_FEATURE_SVE)
{
/* SVE i8mm GEMM */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
},
/* SVE dotprod GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
},
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
},
/* .rhs_info = */ {
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
},
/* .required_cpu = */ CPU_FEATURE_SVE | CPU_FEATURE_I8MM | CPU_FEATURE_DOTPROD,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
{
/* i8mm GEMM */
@@ -635,7 +578,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif // __ARM_FEATURE_MATMUL_INT8
#endif
#if defined(__ARM_FEATURE_DOTPROD)
{
/* DOTPROD GEMM */
@@ -868,27 +811,26 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
ggml_kleidiai_kernels * kernel = nullptr;
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
#if defined(__ARM_FEATURE_SME) || \
defined(__ARM_FEATURE_DOTPROD) || \
defined(__ARM_FEATURE_MATMUL_INT8) || \
defined(__ARM_FEATURE_SVE)
auto try_table = [&](auto & table) {
for (size_t i = 0; i < NELEMS(table) - 1; ++i) {
if ((cpu_features & table[i].required_cpu) == table[i].required_cpu &&
table[i].lhs_type == tensor->src[1]->type &&
table[i].rhs_type == tensor->src[0]->type &&
table[i].op_type == tensor->type) {
kernel = &table[i];
return true;
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
gemm_gemv_kernels[i].op_type == tensor->type) {
kernel = &gemm_gemv_kernels[i];
break;
}
}
if (!kernel) {
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
gemm_gemv_kernels_q8[i].op_type == tensor->type) {
kernel = &gemm_gemv_kernels_q8[i];
break;
}
}
return false;
};
if (tensor->src[0]->type == GGML_TYPE_Q8_0) {
try_table(gemm_gemv_kernels_q8);
} else {
try_table(gemm_gemv_kernels);
}
#else
GGML_UNUSED(gemm_gemv_kernels);
@@ -903,10 +845,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
ggml_kleidiai_kernels * kernels = nullptr;
#if defined(__ARM_FEATURE_SME) || \
defined(__ARM_FEATURE_DOTPROD) || \
defined(__ARM_FEATURE_MATMUL_INT8) || \
defined(__ARM_FEATURE_SVE)
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
kernels = &gemm_gemv_kernels[i];

View File

@@ -46,20 +46,13 @@ struct ggml_kleidiai_context {
} static ctx = { CPU_FEATURE_NONE, NULL, NULL };
static const char* cpu_feature_to_string(cpu_feature f) {
if (f == CPU_FEATURE_NONE) {
return "NONE";
} else if ((f & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
return "SME";
} else if ((f & CPU_FEATURE_SVE) == CPU_FEATURE_SVE) {
return "SVE";
}
else if ((f & CPU_FEATURE_I8MM) == CPU_FEATURE_I8MM) {
return "I8MM";
} else if ((f & CPU_FEATURE_DOTPROD) == CPU_FEATURE_DOTPROD) {
return "DOTPROD";
}
else {
return "UNKNOWN";
switch (f) {
case CPU_FEATURE_NONE: return "NONE";
case CPU_FEATURE_DOTPROD: return "DOTPROD";
case CPU_FEATURE_I8MM: return "I8MM";
case CPU_FEATURE_SVE: return "SVE";
case CPU_FEATURE_SME: return "SME";
default: return "UNKNOWN";
}
}
@@ -75,7 +68,7 @@ static void init_kleidiai_context(void) {
ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
(ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
(ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
if (env_var) {
sme_enabled = atoi(env_var);

View File

@@ -14,6 +14,10 @@
#include <arm_neon.h>
#endif
#if defined(__F16C__)
#include <immintrin.h>
#endif
#if defined(__riscv_v_intrinsic)
#include <riscv_vector.h>
#endif
@@ -654,14 +658,6 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
vec_extract(x[0], 2) + \
vec_extract(x[0], 3); \
}
#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) \
{ \
vector float v = vec_add(vec_add(s0, s1), \
vec_add(s2, s3)); \
v = vec_add(v, vec_sld(v, v, 8)); \
v = vec_add(v, vec_sld(v, v, 4)); \
res += (ggml_float) vec_extract(v, 0); \
}
#define GGML_F32_VEC GGML_F32x4
#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
@@ -698,29 +694,6 @@ static inline unsigned char ggml_endian_byte(int i) {
r[i - GGML_ENDIAN_BYTE(0)]), \
0, p - GGML_F16_EPR)
//BF16 POWER9
#define GGML_BF16_STEP 16
#define GGML_BF16_EPR 8
#define GGML_BF16x8 vector unsigned short
#define GGML_BF16x8_ZERO vec_splats((unsigned short)0)
#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p))
#define GGML_BF16_VEC GGML_BF16x8
#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO
#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD
#if defined(__LITTLE_ENDIAN__)
#define GGML_BF16_TO_F32_LO(v) ((vector float) vec_mergel(GGML_BF16_VEC_ZERO, (v)))
#define GGML_BF16_TO_F32_HI(v) ((vector float) vec_mergeh(GGML_BF16_VEC_ZERO, (v)))
#else
#define GGML_BF16_TO_F32_LO(v) ((vector float) vec_mergel((v), GGML_BF16_VEC_ZERO))
#define GGML_BF16_TO_F32_HI(v) ((vector float) vec_mergeh((v), GGML_BF16_VEC_ZERO))
#endif
#define GGML_BF16_FMA_LO(acc, x, y) \
(acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y))
#define GGML_BF16_FMA_HI(acc, x, y) \
(acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y))
#elif defined(__wasm_simd128__)
#define GGML_SIMD

View File

@@ -237,24 +237,6 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);
#endif
#if defined(__POWER9_VECTOR__)
const int np = (n & ~(GGML_BF16_STEP - 1));
if (np > 0) {
GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO};
for (; i < np; i += GGML_BF16_STEP) {
GGML_BF16_VEC vx0 = GGML_BF16_VEC_LOAD(x + i);
GGML_BF16_VEC vx1 = GGML_BF16_VEC_LOAD(x + i + 8);
GGML_BF16_VEC vy0 = GGML_BF16_VEC_LOAD(y + i);
GGML_BF16_VEC vy1 = GGML_BF16_VEC_LOAD(y + i + 8);
GGML_BF16_FMA_LO(sum[0], vx0, vy0);
GGML_BF16_FMA_HI(sum[1], vx0, vy0);
GGML_BF16_FMA_LO(sum[2], vx1, vy1);
GGML_BF16_FMA_HI(sum[3], vx1, vy1);
}
GGML_F32x4_REDUCE_4(sumf, sum[0], sum[1], sum[2], sum[3]);
}
#endif
for (; i < n; ++i) {
sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
GGML_BF16_TO_FP32(y[i]));

View File

@@ -15,7 +15,6 @@ if (CUDAToolkit_FOUND)
# 80 == Ampere, asynchronous data loading, faster tensor core instructions
# 86 == RTX 3000, needs CUDA v11.1
# 89 == RTX 4000, needs CUDA v11.8
# 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores
#
# XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
# XX-real == compile CUDA code as device code for this specific architecture
@@ -35,69 +34,12 @@ if (CUDAToolkit_FOUND)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real)
endif()
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
# The CUDA architecture 120f-virtual would in principle work for Blackwell support
# but the newly added "f" suffix conflicted with a preexising regex for validating CUDA architectures in CMake.
# So either a recent CMake version or one with the backported fix is needed.
# The following versions should work:
# - CMake >= v3.31.8 && CMake < v4.0.0
# - CMake >= v4.0.2
# This is NOT documented in the CMake release notes,
# check Modules/Internal/CMakeCUDAArchitecturesValidate.cmake in the CMake git repository instead.
# However, the architectures 120a-real and 121a-real should work with basically any CMake version and
# until the release of e.g. Rubin there is no benefit to shipping virtual architectures for Blackwell.
list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real)
endif()
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.9")
list(APPEND CMAKE_CUDA_ARCHITECTURES 121a-real)
endif()
endif()
endif()
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
enable_language(CUDA)
# TODO: Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit
if (GGML_CUDA_CUB_3DOT2)
include(FetchContent)
FetchContent_Declare(
CCCL
GIT_REPOSITORY https://github.com/nvidia/cccl.git
GIT_TAG v3.2.0-rc2
GIT_SHALLOW TRUE
)
FetchContent_MakeAvailable(CCCL)
endif()
# Replace any plain 12X CUDA architectures with their "architecture-specific" equivalents 12Xa.
# 12X is forwards-compatible, 12Xa is not.
# Notably the Blackwell FP4 tensor core instructions are not forwards compatible and therefore need 12Xa.
# But while 12X vs. 12Xa can be checked in device code there is (to my knowledge) no easy way to do the same check in host code.
# So for now just replace all instances of 12X with 12Xa, this should be fine until Rubin is released.
foreach(ARCHS IN ITEMS CMAKE_CUDA_ARCHITECTURES CMAKE_CUDA_ARCHITECTURES_NATIVE)
set(FIXED_ARCHS "")
foreach(ARCH IN LISTS ${ARCHS})
if (ARCH MATCHES "^12[0-9](-real|-virtual)?$")
string(REGEX REPLACE "^(12[0-9])((-real|-virtual)?)$" "\\1a\\2" FIXED_ARCH ${ARCH})
message(STATUS "Replacing ${ARCH} in ${ARCHS} with ${FIXED_ARCH}")
list(APPEND FIXED_ARCHS "${FIXED_ARCH}")
else()
list(APPEND FIXED_ARCHS "${ARCH}")
endif()
endforeach()
set(${ARCHS} ${FIXED_ARCHS})
endforeach()
# If we try to compile a "native" build it will use the 12X architectures and fail.
# So we should instead use the native architectures as determined by CMake after replacing 12X with 12Xa.
# But if at the time of the build no GPUs are connected at all CMAKE_CUDA_ARCHITECTURES will contain garbage that we should not use.
if (CMAKE_CUDA_ARCHITECTURES STREQUAL "native" AND CMAKE_CUDA_ARCHITECTURES_NATIVE MATCHES "^[0-9]+(a|f)?(-real|-virtual)?(;[0-9]+(a|f)?(-real|-virtual)?|;)*$")
set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES_NATIVE})
endif()
message(STATUS "Using CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} CMAKE_CUDA_ARCHITECTURES_NATIVE=${CMAKE_CUDA_ARCHITECTURES_NATIVE}")
file(GLOB GGML_HEADERS_CUDA "*.cuh")
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
@@ -160,9 +102,6 @@ if (CUDAToolkit_FOUND)
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
else ()
if (GGML_CUDA_CUB_3DOT2)
target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
endif()
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else()
@@ -170,9 +109,6 @@ if (CUDAToolkit_FOUND)
endif()
endif()
else()
if (GGML_CUDA_CUB_3DOT2)
target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
endif()
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
endif()
@@ -241,10 +177,6 @@ if (CUDAToolkit_FOUND)
if (NOT MSVC)
list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
else()
# CCCL 3.2 onwards will require a cpp-standard-compliant preprocessor for MSVC
# https://github.com/NVIDIA/cccl/pull/6827
list(APPEND CUDA_CXX_FLAGS /Zc:preprocessor)
endif()
list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument

View File

@@ -22,13 +22,13 @@ static __global__ void init_offsets(int * offsets, const int ncols, const int nr
}
#ifdef GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream) {
static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream) {
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
@@ -49,49 +49,28 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
size_t temp_storage_bytes = 0;
if (order == GGML_SORT_ORDER_ASC) {
if (nrows == 1) {
DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
d_offsets, d_offsets + 1, stream);
}
DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits
stream);
} else {
if (nrows == 1) {
DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
}
DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
sizeof(float) * 8, stream);
}
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
void * d_temp_storage = temp_storage_alloc.get();
if (order == GGML_SORT_ORDER_ASC) {
if (nrows == 1) {
DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
}
DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
stream);
} else {
if (nrows == 1) {
DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
stream);
}
DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
0, sizeof(float) * 8, stream);
}
}
#endif // GGML_CUDA_USE_CUB
@@ -162,12 +141,12 @@ static int next_power_of_2(int x) {
return n;
}
void argsort_f32_i32_cuda_bitonic(const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream) {
static void argsort_f32_i32_cuda_bitonic(const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);

View File

@@ -1,19 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
#ifdef GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream);
#endif // GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_bitonic(const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream);

View File

@@ -50,10 +50,6 @@
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_ADA_LOVELACE 890
// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see
// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
#define GGML_CUDA_CC_BLACKWELL 1200
#define GGML_CUDA_CC_RUBIN 1300
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
@@ -250,10 +246,6 @@ static const char * cu_get_error_str(CUresult err) {
#define AMPERE_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN
# define BLACKWELL_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@@ -262,10 +254,6 @@ static const char * cu_get_error_str(CUresult err) {
#define FLASH_ATTN_AVAILABLE
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
#if defined(TURING_MMA_AVAILABLE)
#define LDMATRIX_TRANS_AVAILABLE
#endif // defined(TURING_MMA_AVAILABLE)
static bool fp16_available(const int cc) {
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
@@ -328,11 +316,6 @@ static bool cp_async_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}
static bool blackwell_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL &&
ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN;
}
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
return 64;
@@ -530,86 +513,6 @@ static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
#endif // FP16_AVAILABLE
}
enum class block_reduce_method {
MAX,
SUM,
};
template<block_reduce_method method_t, typename T>
struct block_reduce_policy;
template <typename T, typename... Ts>
inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);
template<typename...>
inline constexpr bool ggml_cuda_dependent_false_v = false;
template <typename T> struct block_reduce_policy<block_reduce_method::SUM, T> {
static __device__ T reduce(T val) {
if constexpr(is_any<T, float, float2, half2, int>) {
return warp_reduce_sum(val);
} else {
static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
}
}
static __device__ T sentinel() {
if constexpr (std::is_same_v<T, float>) {
return 0.0f;
} else if constexpr (std::is_same_v<T, float2>) {
return make_float2(0.0f, 0.0f);
} else if constexpr (std::is_same_v<T, half2>) {
return make_half2(0.0f, 0.0f);
} else if constexpr (std::is_same_v<T, int>) {
return 0;
} else {
static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
}
}
};
template <typename T> struct block_reduce_policy<block_reduce_method::MAX, T> {
static __device__ T reduce(T val) {
if constexpr (is_any<T, float, half2>) {
return warp_reduce_max(val);
} else {
static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
}
}
static __device__ T sentinel() {
if constexpr (std::is_same_v<T, float>) {
return -INFINITY;
} else if constexpr (std::is_same_v<T, half2>) {
return make_half2(-INFINITY, -INFINITY);
} else {
static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
}
}
};
template <block_reduce_method reduce_method_t, const unsigned int block_size_template = 0, typename T>
static __device__ T block_reduce(T val, T * shared_vals) {
val = block_reduce_policy<reduce_method_t, T>::reduce(val);
const unsigned int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
if (block_size > WARP_SIZE) {
assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0);
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
shared_vals[warp_id] = val;
}
__syncthreads();
val = block_reduce_policy<reduce_method_t, T>::sentinel();
if (lane_id < (static_cast<int>(block_size) / WARP_SIZE)) {
val = shared_vals[lane_id];
}
return block_reduce_policy<reduce_method_t, T>::reduce(val);
}
return val;
}
static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#ifdef FP16_AVAILABLE
@@ -798,28 +701,6 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#endif // CUDART_VERSION >= 12050
}
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
const uint8_t sign_bit = (x < 0.0f) << 3;
float ax = fabsf(x) * e;
// Positive LUT
static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
int best_i = 0;
float best_err = fabsf(ax - pos_lut[0]);
#pragma unroll
for (int i = 1; i < 8; ++i) {
const float err = fabsf(ax - pos_lut[i]);
if (err < best_err) {
best_err = err;
best_i = i;
}
}
return static_cast<uint8_t>(best_i | sign_bit);
}
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
@@ -1034,16 +915,15 @@ struct ggml_cuda_device_info {
int device_count;
struct cuda_device_info {
int cc; // compute capability
int nsm; // number of streaming multiprocessors
size_t smpb; // max. shared memory per block
size_t smpbo; // max. shared memory per block (with opt-in)
bool integrated; // Device is integrated as opposed to discrete
bool vmm; // virtual memory support
size_t vmm_granularity; // granularity of virtual memory
int cc; // compute capability
int nsm; // number of streaming multiprocessors
size_t smpb; // max. shared memory per block
size_t smpbo; // max. shared memory per block (with opt-in)
bool integrated; // Device is integrated as opposed to discrete
bool vmm; // virtual memory support
size_t vmm_granularity; // granularity of virtual memory
size_t total_vram;
int warp_size; // Number of threads in a dispatch
bool supports_cooperative_launch; // whether cooperative launch is supported
int warp_size; // Number of threads in a dispatch
};
cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
@@ -1120,7 +1000,7 @@ struct ggml_tensor_extra_gpu {
#define USE_CUDA_GRAPH
#endif
struct ggml_cuda_graph_node_properties {
struct ggml_graph_node_properties {
void * node_address;
ggml_op node_op;
int64_t ne[GGML_MAX_DIMS];
@@ -1143,27 +1023,12 @@ struct ggml_cuda_graph {
cudaGraphExec_t instance = nullptr;
size_t num_nodes = 0;
std::vector<cudaGraphNode_t> nodes;
std::vector<cudaKernelNodeParams> params;
bool disable_due_to_gpu_arch = false;
bool disable_due_to_too_many_updates = false;
bool disable_due_to_failed_graph_capture = false;
int number_consecutive_updates = 0;
std::vector<ggml_cuda_graph_node_properties> props;
void record_update(bool use_graph, bool update_required) {
if (use_graph && update_required) {
number_consecutive_updates++;
} else {
number_consecutive_updates = 0;
}
if (number_consecutive_updates >= 4) {
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
disable_due_to_too_many_updates = true;
}
}
bool is_enabled() const {
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates);
}
std::vector<ggml_graph_node_properties> ggml_graph_properties;
#endif
};

View File

@@ -12,11 +12,11 @@ const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available
const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows
template <cpy_kernel_t cpy_1>
static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
const int64_t nb12, const int64_t nb13) {
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) {
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
@@ -40,10 +40,10 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne
}
template <typename T>
static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
const int64_t nb12, const int64_t nb13) {
static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) {
const T* src = reinterpret_cast<const T*>(cx);
T* dst = reinterpret_cast<T*>(cdst);
@@ -117,60 +117,60 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
}
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
const int64_t nb12, const int64_t nb13) {
const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) {
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
if (i >= ne) {
return;
}
const int64_t i03 = i/(ne00 * ne01 * ne02);
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int i03 = i/(ne00 * ne01 * ne02);
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int64_t i13 = i/(ne10 * ne11 * ne12);
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
const int i13 = i/(ne10 * ne11 * ne12);
const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
cpy_blck(cx + x_offset, cdst + dst_offset);
}
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
const int64_t nb12, const int64_t nb13) {
const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) {
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
if (i >= ne) {
return;
}
const int64_t i03 = i/(ne00 * ne01 * ne02);
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int i03 = i/(ne00 * ne01 * ne02);
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int64_t i13 = i/(ne10 * ne11 * ne12);
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
const int i13 = i/(ne10 * ne11 * ne12);
const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
cpy_blck(cx + x_offset, cdst + dst_offset);
}
template<typename src_t, typename dst_t>
static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
@@ -188,20 +188,19 @@ static void ggml_cpy_scalar_contiguous_cuda(
cudaStream_t stream) {
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne);
}
template<typename src_t, typename dst_t, bool transposed = false>
static void ggml_cpy_scalar_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
if (transposed) {
GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
int64_t ne00n, ne01n, ne02n;
int ne00n, ne01n, ne02n;
if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here
ne00n = ne00;
ne01n = ne01;
@@ -212,159 +211,143 @@ static void ggml_cpy_scalar_cuda(
ne02n = 1;
}
int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM;
GGML_ASSERT(grid_x < UINT_MAX);
GGML_ASSERT(grid_y < USHRT_MAX);
GGML_ASSERT(grid_z < USHRT_MAX);
dim3 dimGrid(grid_x, grid_y, grid_z);
dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
(ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
(ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} else {
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks < UINT_MAX);
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
}
static void ggml_cpy_f32_q8_0_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK8_0 == 0);
const int64_t num_blocks = ne / QK8_0;
GGML_ASSERT(num_blocks < UINT_MAX);
const int num_blocks = ne / QK8_0;
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q8_0_f32_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q4_0_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_0 == 0);
const int64_t num_blocks = ne / QK4_0;
GGML_ASSERT(num_blocks < UINT_MAX);
const int num_blocks = ne / QK4_0;
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q4_0_f32_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02,
const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) {
const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q4_1_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_1 == 0);
const int64_t num_blocks = ne / QK4_1;
GGML_ASSERT(num_blocks < UINT_MAX);
const int num_blocks = ne / QK4_1;
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q4_1_f32_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02,
const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) {
const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q5_0_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_0 == 0);
const int64_t num_blocks = ne / QK5_0;
GGML_ASSERT(num_blocks < UINT_MAX);
const int num_blocks = ne / QK5_0;
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q5_0_f32_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02,
const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) {
const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q5_1_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_1 == 0);
const int64_t num_blocks = ne / QK5_1;
GGML_ASSERT(num_blocks < UINT_MAX);
const int num_blocks = ne / QK5_1;
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q5_1_f32_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02,
const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) {
const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_iq4_nl_cuda(
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_NL == 0);
const int64_t num_blocks = ne / QK4_NL;
GGML_ASSERT(num_blocks < UINT_MAX);
const int num_blocks = ne / QK4_NL;
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
@@ -373,6 +356,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];

View File

@@ -5,7 +5,7 @@
#include "ggml.h"
#ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh>
# include <cub/device/device_scan.cuh>
#endif // GGML_CUDA_USE_CUB
template<typename T, int BLOCK_SIZE>
@@ -16,14 +16,12 @@ static __global__ void cumsum_cub_kernel(
const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t s1, const int64_t s2, const int64_t s3) {
#ifdef GGML_CUDA_USE_CUB
using BlockScanT = cub::BlockScan<T, BLOCK_SIZE>;
using BlockScan = cub::BlockScan<T, BLOCK_SIZE>;
__shared__ typename BlockScanT::TempStorage temp_storage;
__shared__ T block_carry;
__shared__ typename BlockScan::TempStorage temp_storage;
__shared__ T block_carry; // carry from previous tile
const int tid = threadIdx.x;
constexpr int UNROLL_FACTOR = 4;
constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR;
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.y;
@@ -41,47 +39,37 @@ static __global__ void cumsum_cub_kernel(
}
__syncthreads();
for (int64_t start = 0; start < ne00; start += TILE_SIZE) {
T items[UNROLL_FACTOR];
T thread_sum = T(0);
for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) {
int64_t idx = start + tid;
T x = (idx < ne00) ? src_row[idx] : T(0);
#pragma unroll
for (int i = 0; i < UNROLL_FACTOR; i++) {
int64_t idx = start + tid * UNROLL_FACTOR + i;
T val = (idx < ne00) ? src_row[idx] : T(0);
thread_sum += val;
items[i] = thread_sum;
}
// Block-wide scan on thread sums
T thread_prefix;
T inclusive;
T block_total;
BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total);
BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total);
__syncthreads();
// Add offset to each item and store
T thread_offset = thread_prefix - thread_sum + block_carry;
#pragma unroll
for (int i = 0; i < UNROLL_FACTOR; i++) {
int64_t idx = start + tid * UNROLL_FACTOR + i;
if (idx < ne00) {
dst_row[idx] = items[i] + thread_offset;
}
T final_val = inclusive + block_carry;
// store result
if (idx < ne00) {
dst_row[idx] = final_val;
}
__syncthreads();
// Update carry for next tile
if (tid == 0) {
block_carry += block_total;
}
__syncthreads();
}
#else
NO_DEVICE_CODE;
#endif // GGML_CUDA_USE_CUB
}
// Fallback kernel implementation
// Fallback kernel implementation (original)
template<typename T>
static __global__ void cumsum_kernel(
const T * src, T * dst,
@@ -98,10 +86,10 @@ static __global__ void cumsum_kernel(
const int warps_per_block = blockDim.x / warp_size;
extern __shared__ float smem[];
float * s_vals = smem;
float * s_warp_sums = smem + blockDim.x;
float * s_carry = smem + blockDim.x + warps_per_block;
float * s_chunk_total = s_carry + 1;
float * s_vals = smem;
float * s_warp_sums = smem + blockDim.x;
float * s_carry = smem + blockDim.x + warps_per_block;
float * s_chunk_total = s_carry + 1;
// Initialize carry
if (tid == 0) {
@@ -119,39 +107,21 @@ static __global__ void cumsum_kernel(
const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
// register blocking: process 4 elements per thread to hide latency
// and reduce synchronization overhead
constexpr int num_unroll = 4;
T temp[num_unroll];
for (int64_t start = 0; start < ne00; start += blockDim.x) {
int64_t idx = start + tid;
float val = (idx < ne00) ? ggml_cuda_cast<float, T>(src_row[idx]) : 0.0f;
for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) {
int64_t idx = i + tid * num_unroll;
// thread local sequential scan
temp[0] = (idx < ne00 ? src_row[idx] : T(0));
#pragma unroll
for (int64_t j = 1; j < num_unroll; j++) {
temp[j] = temp[j - 1];
if (idx + j < ne00) {
temp[j] += src_row[idx + j];
} else {
temp[j] += 0;
}
}
// last emenent is sum of all values assigned to thread
float val = (idx < ne00) ? ggml_cuda_cast<float, T>(temp[num_unroll - 1]) : 0.0f;
// Warp inclusive scan
// 1. Warp inclusive scan
val = warp_prefix_inclusive_sum<T, warp_size>(val);
s_vals[tid] = val;
// Store warp total
if (lane == warp_size - 1) {
s_warp_sums[warp] = val;
}
__syncthreads();
// Exclusive scan of warp sums (warp 0 only)
// 2. Exclusive scan of warp sums (warp 0 only)
if (warp == 0) {
float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
@@ -164,55 +134,24 @@ static __global__ void cumsum_kernel(
}
__syncthreads();
// write back results
float carry = *s_carry;
// calculate sum offset for this thread
float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];
#pragma unroll
for (int32_t j = 0; j < num_unroll; j++) {
if (idx + j < ne00) {
dst_row[idx + j] = temp[j] + ggml_cuda_cast<T, float>(final_val_offset);
}
float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
if (idx < ne00) {
dst_row[idx] = ggml_cuda_cast<T, float>(final_val);
}
__syncthreads();
// Update carry for next chunk
if (tid == 0) {
*s_carry += *s_chunk_total;
}
__syncthreads();
}
}
#ifdef GGML_CUDA_USE_CUB
template <typename T>
static void cumsum_cub(ggml_cuda_pool & pool,
const T * src,
T * dst,
int64_t ne,
cudaStream_t stream) {
size_t tmp_size = 0;
// Query how much temp storage CUDA UnBound (CUB) needs
cub::DeviceScan::InclusiveSum(nullptr, // d_temp_storage (null = just query size)
tmp_size, // reference to size (will be set by CUB)
src, // input pointer
dst, // output pointer
ne, // number of elements
stream // CUDA stream to use
);
ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
// Perform the inclusive scan
cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream);
}
#endif // GGML_CUDA_USE_CUB
template<typename T>
static void cumsum_cuda(
[[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst,
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
@@ -226,15 +165,6 @@ static void cumsum_cuda(
if (is_contiguous) {
use_cub = true;
const int64_t nrows = ne01 * ne02 * ne03;
// TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
// Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {
for (int i=0; i<nrows; i++) {
cumsum_cub(ctx.pool(), src + i * ne00, dst + i * ne00, ne00, stream);
}
return;
}
}
#endif // GGML_CUDA_USE_CUB
dim3 grid_dims(ne01, ne02, ne03);
@@ -247,7 +177,7 @@ static void cumsum_cuda(
const int warps_per_block = block_size / warp_size;
const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
if (use_cub && ne00 >= 1024) {
if (use_cub) {
cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
@@ -273,7 +203,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
case GGML_TYPE_F32:
{
cumsum_cuda(
ctx, (const float *)src0->data, (float *)dst->data,
(const float *)src0->data, (float *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],

View File

@@ -11,12 +11,10 @@
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
// log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable
// by the VKQ accumulators is effectively being shifted up by a factor of 2.
// by the VKQ accumulators is effectively being shifted up by a factor of 8.
// This reduces issues with numerical overflow but also causes larger values to be flushed to zero.
// However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible.
// Still, the value range should be shifted as much as necessary but as little as possible.
// The macro on the following line shifts it by a factor of 2**3=8, as was needed to fix https://github.com/ggml-org/llama.cpp/issues/18606 .
#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f)
#define FATTN_KQ_MAX_OFFSET 0.6931f
typedef void (* fattn_kernel_t)(
const char * __restrict__ Q,
@@ -59,7 +57,7 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
__align__(16) half2 tmp[cpy_ne];
half2 tmp[cpy_ne];
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
@@ -309,7 +307,7 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_
ggml_cuda_memcpy_1<ne*sizeof(half)>(dst, (const half *) vx + i0);
} else if constexpr (std::is_same_v<T, float>) {
static_assert(ne % 2 == 0, "bad ne");
__align__(16) half2 tmp[ne/2];
half2 tmp[ne/2];
ggml_cuda_memcpy_1<ne*sizeof(half)>(tmp, (const half *) vx + i0);
float2 * dst_f2 = (float2 *) dst;
#pragma unroll
@@ -914,15 +912,13 @@ void launch_fattn(
const int nblocks_stream_k = max_blocks;
const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
blocks_num.y = 1;
blocks_num.z = 1;
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
}
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
} else {
const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.

View File

@@ -98,19 +98,6 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
}
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) {
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
// TODO tune specifically for RDNA
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
}
static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
if (ampere_mma_available(cc)) {
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
@@ -118,9 +105,6 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c
if (turing_mma_available(cc)) {
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
}
if (amd_wmma_available(cc)) {
return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
}
GGML_ASSERT(volta_mma_available(cc));
return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
}
@@ -132,8 +116,6 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
#elif defined(VOLTA_MMA_AVAILABLE)
return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
#elif defined(AMD_WMMA_AVAILABLE)
return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
#else
GGML_UNUSED_VARS(DKQ, DV, ncols);
return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
@@ -204,23 +186,6 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ,
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
}
static constexpr __device__ int get_cols_per_thread() {
#if defined(AMD_WMMA_AVAILABLE)
return 1; // RDNA has a single column.
#else
return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
#endif // defined(AMD_WMMA_AVAILABLE)
}
static __host__ int get_cols_per_warp(const int cc) {
if (turing_mma_available(cc) || amd_wmma_available(cc)) {
return 16;
} else {
// Volta
return 32;
}
}
// ------------------------------------------------------------------------------------------------------------------
static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
@@ -428,10 +393,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const int jt,
const int kb0,
const int k_VKQ_sup) {
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int ncols = ncols1 * ncols2;
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = get_cols_per_thread();
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
@@ -448,8 +413,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const int k_VKQ_0 = kb0 * nbatch_fa;
#if defined(TURING_MMA_AVAILABLE)
T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
#elif defined(AMD_WMMA_AVAILABLE)
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
#else // Volta
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
#endif // defined(TURING_MMA_AVAILABLE)
@@ -498,14 +461,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
if constexpr (cols_per_warp == 8) {
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
} else {
// Wide version of KQ_C is column-major
#if defined(AMD_WMMA_AVAILABLE)
// RDNA matrix C is column-major.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
#else
// swap A and B for CUDA.
// Wide version of KQ_C is column-major => swap A and B.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
#endif // defined(AMD_WMMA_AVAILABLE)
}
}
}
@@ -522,14 +479,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
T_A_KQ K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
// Wide version of KQ_C is column-major
#if defined(AMD_WMMA_AVAILABLE)
// RDNA matrix C is column-major.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
#else
// swap A and B for CUDA.
// Wide version of KQ_C is column-major => swap A and B.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
#endif // defined(AMD_WMMA_AVAILABLE)
}
}
}
@@ -580,14 +531,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
#pragma unroll
for (int l = 0; l < T_C_KQ::ne; ++l) {
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
#if defined(AMD_WMMA_AVAILABLE)
constexpr int KQ_idx = 0;
#else
// Turing + Volta:
const int KQ_idx = l % 2;
#endif // defined(AMD_WMMA_AVAILABLE)
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) {
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
}
}
}
@@ -607,14 +552,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#pragma unroll
for (int l = 0; l < T_C_KQ::ne; ++l) {
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
#if defined(AMD_WMMA_AVAILABLE)
constexpr int KQ_idx = 0;
#else
// Turing + Volta:
const int KQ_idx = l % 2;
#endif // defined(AMD_WMMA_AVAILABLE)
KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]);
KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
} else {
KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
}
@@ -644,14 +583,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
#pragma unroll
for (int l = 0; l < T_C_KQ::ne; ++l) {
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
#if defined(AMD_WMMA_AVAILABLE)
constexpr int KQ_idx = 0;
#else
if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) {
// Turing + Volta:
const int KQ_idx = (l/2) % 2;
#endif // defined(AMD_WMMA_AVAILABLE)
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
}
}
}
@@ -662,11 +596,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
// Values per KQ column are spread across 4 threads:
constexpr int offset_first = 2;
constexpr int offset_last = 1;
#elif defined(AMD_WMMA_AVAILABLE)
// Values per KQ column are spread across 2 threads:
constexpr int offset_first = 16;
constexpr int offset_last = 16;
#else // Volta
#else
// Values per KQ column are spread across 2 threads:
constexpr int offset_first = 2;
constexpr int offset_last = 2;
@@ -682,15 +612,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
#pragma unroll
for (int l = 0; l < T_C_KQ::ne; ++l) {
// Turing + Volta:
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
#if defined(AMD_WMMA_AVAILABLE)
constexpr int KQ_idx = 0;
#else
// Turing + Volta:
const int KQ_idx = (l/2) % 2;
#endif // defined(AMD_WMMA_AVAILABLE)
KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]);
KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
} else {
KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
}
@@ -714,7 +639,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#if defined(TURING_MMA_AVAILABLE)
if constexpr (cols_per_warp == 8) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
#pragma unroll
for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
#pragma unroll
@@ -735,16 +660,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
}
#elif defined(AMD_WMMA_AVAILABLE)
const half2 KQ_max_scale_h2 = make_half2(
KQ_max_scale[0], KQ_max_scale[0]);
#pragma unroll
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
#pragma unroll
for (int l = 0; l < T_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
#else // Volta
const half2 KQ_max_scale_h2 = make_half2(
KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
@@ -792,10 +707,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
// Therefore, iterate over V in reverse and re-use the data if possible.
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
T_A_VKQ A_identity;
make_identity_mat(A_identity);
#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
// Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
#pragma unroll
@@ -816,7 +727,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
#if defined(TURING_MMA_AVAILABLE)
constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
#pragma unroll
for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
@@ -826,26 +737,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
#if defined(LDMATRIX_TRANS_AVAILABLE)
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
#else
// TODO: Try to transpose tile_V when loading gmem to smem.
// Use mma to transpose T_A_VKQ for RDNA.
T_A_VKQ A_trans;
load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
mma(A, A_trans, A_identity);
#endif // defined(TURING_MMA_AVAILABLE)
if constexpr (T_B_KQ::I == 8) {
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
} else {
// Wide version of VKQ_C is column-major.
#if defined(AMD_WMMA_AVAILABLE)
// RDNA matrix C is column-major.
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
#else
// swap A and B for CUDA.
// Wide version of VKQ_C is column-major => swap A and B.
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
#endif // defined(AMD_WMMA_AVAILABLE)
}
}
}
@@ -864,7 +761,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
}
}
#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
#endif // defined(TURING_MMA_AVAILABLE)
if constexpr (nstages <= 1) {
__syncthreads(); // Only needed if tile_K == tile_V.
@@ -877,7 +774,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
tile_Q, tile_K, tile_V, tile_mask,
Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
NO_DEVICE_CODE;
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
#if defined(TURING_MMA_AVAILABLE)
@@ -897,15 +794,6 @@ template<> struct mma_tile_sizes<8> {
using T_B_VKQ = tile< 8, 8, half2>; // column-major
using T_C_VKQ = tile<16, 4, half2>; // row-major
};
#elif defined(AMD_WMMA_AVAILABLE)
template<int ncols> struct mma_tile_sizes {
using T_A_KQ = tile<16, 8, half2>; // row-major
using T_B_KQ = tile<16, 8, half2>; // column-major
using T_C_KQ = tile<16, 16, float>; // column-major
using T_A_VKQ = tile<16, 8, half2>; // row-major
using T_B_VKQ = tile<16, 8, half2>; // column-major
using T_C_VKQ = tile<16, 8, half2>; // column-major
};
#else // Volta
template<int ncols> struct mma_tile_sizes {
using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
@@ -940,7 +828,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int jt,
const int kb0_start,
const int kb0_stop) {
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
constexpr int ncols = ncols1 * ncols2;
@@ -952,7 +840,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
using T_C_VKQ = typename mma_tile_sizes<ncols>::T_C_VKQ;
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = get_cols_per_thread();
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
@@ -983,8 +871,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
#if defined(TURING_MMA_AVAILABLE)
T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
#elif defined(AMD_WMMA_AVAILABLE)
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
#else // Volta
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
#endif // defined(TURING_MMA_AVAILABLE)
@@ -1124,10 +1010,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
// The partial sums are spread across 8/4 threads.
constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
#elif defined(AMD_WMMA_AVAILABLE)
// The partial sums are spread across 2 threads.
constexpr int offset_first = 16;
constexpr int offset_last = 16;
#else // Volta
// The partial sums are spread across 2 threads.
constexpr int offset_first = 2;
@@ -1165,7 +1047,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#if defined(TURING_MMA_AVAILABLE)
if constexpr (cols_per_warp == 8) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
#pragma unroll
for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
#pragma unroll
@@ -1186,15 +1068,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
}
}
#elif defined(AMD_WMMA_AVAILABLE)
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
#pragma unroll
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
#pragma unroll
for (int l = 0; l < T_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
#else // Volta
const int col = (threadIdx.x / 2) % 2;
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
@@ -1246,10 +1119,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
#elif defined(AMD_WMMA_AVAILABLE)
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
#else // Volta
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
@@ -1450,7 +1319,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
jt, kb0_start, kb0_stop);
NO_DEVICE_CODE;
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla>
@@ -1477,7 +1346,7 @@ static __global__ void flash_attn_ext_f16(
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1491,13 +1360,6 @@ static __global__ void flash_attn_ext_f16(
}
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
#if defined(AMD_WMMA_AVAILABLE)
if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) {
NO_DEVICE_CODE;
return;
}
#endif // defined(AMD_WMMA_AVAILABLE)
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
constexpr int ncols = ncols1 * ncols2;
@@ -1611,7 +1473,7 @@ static __global__ void flash_attn_ext_f16(
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
}
template <int DKQ, int DV, int ncols1, int ncols2>
@@ -1630,7 +1492,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc);
const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc);
const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32);
const int nwarps = nthreads / WARP_SIZE;
constexpr bool mla = DKQ == 576;
@@ -1650,34 +1512,29 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
#if defined(GGML_USE_HIP)
using fattn_kernel_ptr_t = const void*;
#else
using fattn_kernel_ptr_t = fattn_kernel_t;
#endif // defined(GGML_USE_HIP)
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
#if !defined(GGML_USE_MUSA)
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
shared_memory_limit_raised[id] = true;
}
#endif // !defined(GGML_USE_MUSA)
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
#if !defined(GGML_USE_MUSA)
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
shared_memory_limit_raised[id] = true;
}
#endif // !defined(GGML_USE_MUSA)
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
launch_fattn<DV, ncols1, ncols2>

View File

@@ -343,7 +343,7 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne;
const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}};
const half2 zero[cpy_ne] = {{0.0f, 0.0f}};
ggml_cuda_memcpy_1<cpy_nb>(
tile_KV + i*(J/2 + J_padding) + j,
!oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
@@ -394,11 +394,11 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2);
const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}};
__align__(16) half2 tmp_h2[cpy_ne/2];
half2 tmp_h2[cpy_ne/2];
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
__align__(16) float2 tmp_f2[cpy_ne/2];
float2 tmp_f2[cpy_ne/2];
#pragma unroll
for (int l = 0; l < cpy_ne/2; ++l) {
tmp_f2[l] = __half22float2(tmp_h2[l]);
@@ -445,14 +445,14 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {
__align__(16) half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
__align__(16) half2 Q_k[cpw][cpy_ne];
half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
half2 Q_k[cpw][cpy_ne];
#else
static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K");
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {
__align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
__align__(16) float Q_k[cpw][cpy_ne];
float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
float Q_k[cpw][cpy_ne];
#endif // FAST_FP16_AVAILABLE
#pragma unroll
@@ -602,9 +602,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
#pragma unroll
for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
#ifdef FAST_FP16_AVAILABLE
__align__(16) half tmp[nbatch_fa/(np*warp_size)][KQ_cs];
half tmp[nbatch_fa/(np*warp_size)][KQ_cs];
#else
__align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
#endif // FAST_FP16_AVAILABLE
#pragma unroll
@@ -664,8 +664,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
#ifdef FAST_FP16_AVAILABLE
#pragma unroll
for (int k1 = 0; k1 < nbatch_V; k1 += np) {
__align__(16) half2 V_k[(DVp/2)/warp_size];
__align__(16) half2 KQ_k[cpw];
half2 V_k[(DVp/2)/warp_size];
half2 KQ_k[cpw];
constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
#pragma unroll
@@ -676,7 +676,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);
__align__(16) half tmp[KQ_cs];
half tmp[KQ_cs];
ggml_cuda_memcpy_1<KQ_cs*sizeof(half)>(
&tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
#pragma unroll
@@ -696,8 +696,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
#else
#pragma unroll
for (int k1 = 0; k1 < nbatch_V; k1 += np) {
__align__(16) float2 V_k[(DVp/2)/warp_size];
__align__(16) float KQ_k[cpw];
float2 V_k[(DVp/2)/warp_size];
float KQ_k[cpw];
constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
#pragma unroll
@@ -821,12 +821,12 @@ static __global__ void flash_attn_tile(
__shared__ half2 Q_tmp[ncols * DKQ/2];
__shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV];
__shared__ half KQ[ncols * nbatch_fa];
__align__(16) half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
#else
__shared__ float Q_tmp[ncols * DKQ];
__shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV];
__shared__ float KQ[ncols * nbatch_fa];
__align__(16) float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
#endif // FAST_FP16_AVAILABLE
float KQ_max[cpw];
@@ -849,7 +849,7 @@ static __global__ void flash_attn_tile(
#pragma unroll
for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
__align__(16) float tmp_f[cpy_ne_D] = {0.0f};
float tmp_f[cpy_ne_D] = {0.0f};
ggml_cuda_memcpy_1<sizeof(tmp_f)>
(tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float))
+ i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
@@ -860,7 +860,7 @@ static __global__ void flash_attn_tile(
}
#ifdef FAST_FP16_AVAILABLE
__align__(16) half2 tmp_h2[cpy_ne_D/2];
half2 tmp_h2[cpy_ne_D/2];
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
@@ -959,7 +959,7 @@ static __global__ void flash_attn_tile(
constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
#pragma unroll
for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
__align__(16) half2 tmp[cpy_ne_D];
half2 tmp[cpy_ne_D];
ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]);
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -970,7 +970,7 @@ static __global__ void flash_attn_tile(
constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
#pragma unroll
for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
__align__(16) float tmp[cpy_ne_D];
float tmp[cpy_ne_D];
ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]);
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -1033,7 +1033,7 @@ static __global__ void flash_attn_tile(
constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
#pragma unroll
for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
__align__(16) float2 tmp[cpy_ne_D];
float2 tmp[cpy_ne_D];
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);

View File

@@ -10,7 +10,7 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() {
return 128;
}
// Currenlty llvm with the amdgcn target does not support unrolling loops
// Currenlty llvm with the amdgcn target dose not support unrolling loops
// that contain a break that can not be resolved at compile time.
#ifdef __clang__
#pragma clang diagnostic push
@@ -132,7 +132,7 @@ static __global__ void flash_attn_ext_vec(
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
#else
__align__(16) float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
#endif // V_DOT2_F32_F16_AVAILABLE
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
@@ -200,7 +200,7 @@ static __global__ void flash_attn_ext_vec(
for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
__align__(16) float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
if (ncols == 1 || ic0 + j < int(ne01.z)) {
ggml_cuda_memcpy_1<cpy_nb>(tmp, &Q_j[i]);
ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);

View File

@@ -18,12 +18,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
}
}
if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
return;
}
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
return;
}
@@ -230,18 +230,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
// The effective batch size for the kernel can be increased by gqa_ratio.
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
for (const ggml_tensor * t : {Q, K, V, mask}) {
if (t == nullptr) {
continue;
}
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
if (t->nb[i] % 16 != 0) {
gqa_opt_applies = false;
break;
}
}
}
const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
const int cc = ggml_cuda_info().devices[device].cc;
@@ -348,31 +337,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
return BEST_FATTN_KERNEL_WMMA_F16;
}
if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) {
if (can_use_vector_kernel) {
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
if (Q->ne[1] == 1) {
if (!gqa_opt_applies) {
return BEST_FATTN_KERNEL_VEC;
}
}
} else {
if (Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
}
}
int gqa_ratio_eff = 1;
const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
gqa_ratio_eff *= 2;
}
if (Q->ne[1] * gqa_ratio_eff <= 8) {
return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized.
}
return BEST_FATTN_KERNEL_MMA_F16;
}
// If there are no tensor cores available, use the generic tile kernel:
if (can_use_vector_kernel) {
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {

View File

@@ -19,7 +19,6 @@
#include "ggml-cuda/count-equal.cuh"
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cross-entropy-loss.cuh"
#include "ggml-cuda/cumsum.cuh"
#include "ggml-cuda/diagmask.cuh"
#include "ggml-cuda/diag.cuh"
#include "ggml-cuda/fattn.cuh"
@@ -45,7 +44,6 @@
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/top-k.cuh"
#include "ggml-cuda/mean.cuh"
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/topk-moe.cuh"
@@ -203,6 +201,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
int64_t total_vram = 0;
#ifdef GGML_CUDA_FORCE_MMQ
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
#else
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
#endif // GGML_CUDA_FORCE_MMQ
#ifdef GGML_CUDA_FORCE_CUBLAS
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
#else
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
#endif // GGML_CUDA_FORCE_CUBLAS
GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
std::vector<std::pair<int, std::string>> turing_devices_without_mma;
@@ -233,14 +241,6 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].nsm = prop.multiProcessorCount;
info.devices[id].smpb = prop.sharedMemPerBlock;
info.devices[id].warp_size = prop.warpSize;
#ifndef GGML_USE_MUSA
int supports_coop_launch = 0;
CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id));
info.devices[id].supports_cooperative_launch = !!supports_coop_launch;
#else
info.devices[id].supports_cooperative_launch = false;
#endif // !(GGML_USE_MUSA)
#if defined(GGML_USE_HIP)
info.devices[id].smpbo = prop.sharedMemPerBlock;
@@ -2211,7 +2211,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
const int cc = ggml_cuda_info().devices[id].cc;
const int warp_size = ggml_cuda_info().devices[id].warp_size;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
@@ -2219,7 +2219,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
} else {
const int cc = ggml_cuda_info().devices[ctx.device].cc;
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
@@ -2287,7 +2287,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
return;
}
if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {
if (ggml_cuda_should_use_mmq(src0->type, cc, ne12)) {
ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
return;
}
@@ -2687,9 +2687,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SUM:
ggml_cuda_op_sum(ctx, dst);
break;
case GGML_OP_CUMSUM:
ggml_cuda_op_cumsum(ctx, dst);
break;
case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst);
break;
@@ -2702,9 +2699,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SSM_SCAN:
ggml_cuda_op_ssm_scan(ctx, dst);
break;
case GGML_OP_TOP_K:
ggml_cuda_op_top_k(ctx, dst);
break;
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
break;
@@ -2714,6 +2708,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
case GGML_OP_CUMSUM:
ggml_cuda_op_cumsum(ctx, dst);
break;
case GGML_OP_TRI:
ggml_cuda_op_tri(ctx, dst);
break;
@@ -2853,9 +2850,9 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
}
#ifdef USE_CUDA_GRAPH
static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
bool use_cuda_graph) {
bool use_cuda_graph = true;
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
@@ -2915,41 +2912,41 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
return use_cuda_graph;
}
static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
props->node_address = node->data;
props->node_op = node->op;
static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
graph_node_properties->node_address = node->data;
graph_node_properties->node_op = node->op;
for (int i = 0; i < GGML_MAX_DIMS; i++) {
props->ne[i] = node->ne[i];
props->nb[i] = node->nb[i];
graph_node_properties->ne[i] = node->ne[i];
graph_node_properties->nb[i] = node->nb[i];
}
for (int i = 0; i < GGML_MAX_SRC; i++) {
props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
}
memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
}
static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
if (node->data != props->node_address &&
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
if (node->data != graph_node_properties->node_address &&
node->op != GGML_OP_VIEW) {
return false;
}
if (node->op != props->node_op) {
if (node->op != graph_node_properties->node_op) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (node->ne[i] != props->ne[i]) {
if (node->ne[i] != graph_node_properties->ne[i]) {
return false;
}
if (node->nb[i] != props->nb[i]) {
if (node->nb[i] != graph_node_properties->nb[i]) {
return false;
}
}
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node->src[i] &&
node->src[i]->data != props->src_address[i] &&
node->src[i]->data != graph_node_properties->src_address[i] &&
node->op != GGML_OP_VIEW
) {
return false;
@@ -2957,55 +2954,44 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
}
if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
return false;
}
return true;
}
static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
bool res = false;
bool cuda_graph_update_required = false;
if (cuda_ctx->cuda_graph->instance == nullptr) {
res = true;
cuda_graph_update_required = true;
}
// Check if the graph size has changed
if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
res = true;
cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
cuda_graph_update_required = true;
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
}
// Loop over nodes in GGML graph to determine if CUDA graph update is required
// and store properties to allow this comparison for the next token
for (int i = 0; i < cgraph->n_nodes; i++) {
bool props_match = true;
if (!res) {
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]);
bool has_matching_properties = true;
if (!cuda_graph_update_required) {
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
}
if (!props_match) {
res = true;
if (!has_matching_properties) {
cuda_graph_update_required = true;
}
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]);
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
}
for (int i = 0; i < cgraph->n_leafs; i++) {
bool props_match= true;
if (!res) {
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]);
}
if (!props_match) {
res = true;
}
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
}
return res;
return cuda_graph_update_required;
}
static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) {
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo result_info;
@@ -3236,11 +3222,10 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) {
bool graph_evaluated_or_captured = false;
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
// flag used to determine whether it is an integrated_gpu
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
bool is_concurrent_event_active = false;
@@ -3278,7 +3263,6 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();
}
}
if (should_launch_concurrent_events) {
// Restore original node order within each concurrent region to enable fusion within streams
@@ -3330,8 +3314,6 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);
}
}
} else {
stream_ctx.concurrent_events.clear();
}
for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -3710,7 +3692,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
}
if (cuda_graph_update_required) { // Update graph executable
ggml_cuda_graph_update_executable(cuda_ctx);
update_cuda_graph_executable(cuda_ctx);
}
// Launch graph
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
@@ -3720,48 +3702,60 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
}
}
static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) {
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
ggml_cuda_set_device(cuda_ctx->device);
#ifdef USE_CUDA_GRAPH
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
// Objects required for CUDA Graph
if (cuda_ctx->cuda_graph == nullptr) {
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
}
bool use_cuda_graph = true;
bool cuda_graph_update_required = false;
if (cuda_ctx->cuda_graph->graph == nullptr) {
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
if (!cuda_ctx->cuda_graph->disable_due_to_gpu_arch) {
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
}
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
#endif
}
}
return cuda_ctx->cuda_graph->is_enabled();
#else
GGML_UNUSED(cuda_ctx);
return false;
#endif // USE_CUDA_GRAPH
}
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
ggml_cuda_set_device(cuda_ctx->device);
bool use_cuda_graph = false;
bool cuda_graph_update_required = false;
#ifdef USE_CUDA_GRAPH
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
if (cuda_ctx->cuda_graph->is_enabled()) {
cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required);
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
// or previous graph capture failure.
// Also disable for multi-gpu for now. TO DO investigate
if (disable_cuda_graphs_due_to_env
|| cuda_ctx->cuda_graph->disable_due_to_gpu_arch
|| cuda_ctx->cuda_graph->disable_due_to_too_many_updates
|| cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
use_cuda_graph = false;
}
if (use_cuda_graph) {
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
if (use_cuda_graph && cuda_graph_update_required) {
cuda_ctx->cuda_graph->number_consecutive_updates++;
} else {
cuda_ctx->cuda_graph->number_consecutive_updates = 0;
}
if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
#endif
}
}
#endif // USE_CUDA_GRAPH
if (use_cuda_graph && cuda_graph_update_required) {
// Start CUDA graph capture
@@ -3773,7 +3767,14 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
}
ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required);
#else
bool use_cuda_graph = false;
bool cuda_graph_update_required = false;
#endif // USE_CUDA_GRAPH
bool graph_evaluated_or_captured = false;
evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
return GGML_STATUS_SUCCESS;
}
@@ -3806,10 +3807,8 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
static bool enable_graph_optimization = [] {
const char * env = getenv("GGML_CUDA_GRAPH_OPT");
const char * env = getenv("GGML_CUDA_GRAPH_OPT");
return env != nullptr && atoi(env) == 1;
}();
@@ -3817,13 +3816,12 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
return;
}
GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only supported on single GPU in the CUDA backend");
GGML_LOG_DEBUG("Optimizing CUDA graph %p with %d nodes\n", cgraph->nodes, cgraph->n_nodes);
ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
stream_context.reset();
if (!use_cuda_graph || ggml_backend_cuda_get_device_count() != 1) {
return;
}
// number of out-degrees for a particular node
std::unordered_map<const ggml_tensor *, int> fan_out;
// reverse mapping of node to index in the cgraph
@@ -3884,12 +3882,6 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
if (count >= min_fan_out && count <= max_fan_out) {
const int root_node_idx = node_indices[root_node];
// only optimize for attn_norm
// TODO: make this more generic
if (!strstr(root_node->name, "attn_norm")) {
continue;
}
bool is_part_of_event = false;
for (const auto & [start, end] : concurrent_node_ranges) {
if (root_node_idx >= start && root_node_idx <= end) {
@@ -4125,7 +4117,6 @@ struct ggml_backend_cuda_device_context {
std::string name;
std::string description;
std::string pci_bus_id;
int op_offload_min_batch_size;
};
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -4553,7 +4544,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_L2_NORM:
return true;
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]);
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
@@ -4619,7 +4610,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return true;
case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_TOP_K:
case GGML_OP_ARGSORT:
#ifndef GGML_CUDA_USE_CUB
return op->src[0]->ne[0] <= 1024;
@@ -4680,9 +4670,11 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
}
static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
const int min_batch_size = 32;
return get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size;
return get_op_batch_size(op) >= min_batch_size;
GGML_UNUSED(dev);
}
static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) {
@@ -4793,16 +4785,6 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
features.push_back({ "FA_ALL_QUANTS", "1" });
#endif
{
const auto & info = ggml_cuda_info();
for (int id = 0; id < info.device_count; ++id) {
if (blackwell_mma_available(info.devices[id].cc)) {
features.push_back({ "BLACKWELL_NATIVE_FP4", "1"});
break;
}
}
}
#undef _STRINGIFY
#undef STRINGIFY
@@ -4850,7 +4832,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
for (int i = 0; i < ggml_cuda_info().device_count; i++) {
ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
@@ -4864,7 +4845,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
char pci_bus_id[16] = {};
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
dev_ctx->pci_bus_id = pci_bus_id;
dev_ctx->op_offload_min_batch_size = min_batch_size;
ggml_backend_dev_t dev = new ggml_backend_device {
/* .iface = */ ggml_backend_cuda_device_interface,

View File

@@ -34,11 +34,13 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
// CUDA_GRAPHS_DISABLED
((ncols > 65536) &&
((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.cuda_graph->is_enabled())) ||
ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
ctx.cuda_graph->disable_due_to_failed_graph_capture)) ||
// CUDA_GRAPHS ENABLED
((ncols > 32768) &&
!((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.cuda_graph->is_enabled()))) {
ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
ctx.cuda_graph->disable_due_to_failed_graph_capture))) {
#else
(ncols > 65536)) {
#endif // USE_CUDA_GRAPH

View File

@@ -206,16 +206,10 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
// matrix C
#if defined(RDNA3)
if constexpr (std::is_same_v<T, float> || std::is_same_v<T, int>) {
// matrix C
return 2 * l + (threadIdx.x / 16);
} else {
// matrix A&B
return l;
}
return 2 * l + (threadIdx.x / 16);
#else
// matrix C is the transposed matrix A&B on RDNA4
return ne * (threadIdx.x / 16) + l;
#endif // defined(RDNA3)
} else if constexpr (I == 16 && J == 8) {
@@ -627,21 +621,6 @@ namespace ggml_cuda_mma {
return ret;
}
#elif defined(AMD_WMMA_AVAILABLE)
template <int I, int J>
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
tile<I, J/2, half2> ret;
#pragma unroll
for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
}
return ret;
}
static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
NO_DEVICE_CODE;
return tile<8, 8, half2>{};
}
#else // Volta
template <int I, int J>
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
@@ -660,19 +639,6 @@ namespace ggml_cuda_mma {
}
#endif // defined(TURING_MMA_AVAILABLE)
static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) {
#if defined(RDNA4)
const int row = t.get_i(0);
const int left_right = t.get_j(0) / 4;
const int up_down = row / 8;
const int idx = row % 8;
reinterpret_cast<half*>(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f;
#else
GGML_UNUSED_VARS(t);
NO_DEVICE_CODE;
#endif // defined(RDNA4)
}
template <int I, int J, typename T, data_layout dl>
static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
#if defined(AMD_MFMA_AVAILABLE)
@@ -912,17 +878,6 @@ namespace ggml_cuda_mma {
: "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
halfx8_t& acc_frag = reinterpret_cast<halfx8_t&>(D.x[0]);
const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // defined(RDNA4)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
@@ -945,27 +900,6 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
const tile<16, 8, int> & A,
const tile<8, 8, int> & B,
uint32_t a_scale,
uint32_t b_scale) {
#ifdef BLACKWELL_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
float * Dxi = (float *) D.x;
asm volatile(
"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
"%10, {0, 0}, %11, {0, 0};"
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
#else
GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
#endif // BLACKWELL_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE

View File

@@ -1,4 +1,3 @@
#include "common.cuh"
#include "mmq.cuh"
#include "quantize.cuh"
#include "mmid.cuh"
@@ -115,9 +114,6 @@ void ggml_cuda_mul_mat_q(
const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| GGML_CUDA_CC_IS_CDNA(cc);
// TODO: tighter pool buffer size vs q8 path
const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
if (!ids) {
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
@@ -127,24 +123,12 @@ void ggml_cuda_mul_mat_q(
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
if (use_native_mxfp4) {
static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);
} else {
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);
}
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
CUDA_CHECK(cudaGetLastError());
}
// Stride depends on quantization format
const int64_t s12 = use_native_mxfp4 ?
ne11 * ne10_padded * sizeof(block_fp4_mmq) /
(8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32)
:
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
const int64_t s13 = ne12*s12;
const mmq_args args = {
@@ -190,20 +174,13 @@ void ggml_cuda_mul_mat_q(
{
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
if (use_native_mxfp4) {
quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
} else {
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
}
const int64_t s13 = src1->nb[2] / ts_src1;
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
CUDA_CHECK(cudaGetLastError());
}
const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
const int64_t s13 = ne12*s12;
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
@@ -259,7 +236,7 @@ void ggml_cuda_op_mul_mat_q(
GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_padded_row_size);
}
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts) {
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
#ifdef GGML_CUDA_FORCE_CUBLAS
return false;
#endif // GGML_CUDA_FORCE_CUBLAS
@@ -320,10 +297,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
if (GGML_CUDA_CC_IS_CDNA3(cc)) {
return true;
}
if (n_experts > 64 || ne11 <= 128) {
return true;
}
if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
return true;
}
if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {
@@ -333,31 +307,6 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
}
if (amd_wmma_available(cc)) {
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
// High expert counts are almost always better on MMQ due to
// the synchronization overhead in the cuBLAS/hipBLAS path:
// https://github.com/ggml-org/llama.cpp/pull/18202
if (n_experts >= 64) {
return true;
}
// For some quantization types MMQ can have lower peak TOPS than hipBLAS
// so it's only faster for sufficiently small batch sizes:
switch (type) {
case GGML_TYPE_Q2_K:
return ne11 <= 128;
case GGML_TYPE_Q6_K:
return ne11 <= (GGML_CUDA_CC_IS_RDNA3_0(cc) ? 128 : 256);
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
return GGML_CUDA_CC_IS_RDNA3_5(cc) || ne11 <= 128;
default:
return true;
}
}
// For RDNA4 MMQ is consistently faster than dequantization + hipBLAS:
// https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301
return true;
}

View File

@@ -11,7 +11,6 @@ using namespace ggml_cuda_mma;
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
#define MMQ_ITER_K 256
#define MMQ_ITER_K_MXFP4_FP4 512
#define MMQ_NWARPS 8
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
@@ -45,15 +44,8 @@ struct block_q8_1_mmq {
};
int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
};
struct block_fp4_mmq {
uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
};
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
switch (type_x) {
@@ -137,14 +129,6 @@ static int get_mmq_y_host(const int cc) {
((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
}
static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
#if defined(BLACKWELL_MMA_AVAILABLE)
return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
#else
return MMQ_ITER_K;
#endif // defined(BLACKWELL_MMA_AVAILABLE)
}
static constexpr __device__ int get_mmq_y_device() {
#if defined(GGML_USE_HIP)
#if defined(RDNA1)
@@ -207,7 +191,6 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
}
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
@@ -218,8 +201,6 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
switch (type) {
@@ -228,7 +209,6 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
// tile sizes are the same for Q8_1 and FP4 for blackwell
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
@@ -248,8 +228,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
}
// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1)
#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
static int mmq_get_granularity_host(const int mmq_x, const int cc) {
if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
@@ -782,50 +761,6 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
}
}
template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
int * __restrict__ x_tile,
const int kbx0,
const int i_max,
const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
int * x_qs = (int *) x_tile;
uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
const int txi = threadIdx.x;
constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);
constexpr int threads_per_row = iter_k / QK_MXFP4; // each thread processes 1 block
constexpr int rows_per_warp = warp_size / threads_per_row;
const int kbx = txi % threads_per_row;
const int row_in_warp = txi / threads_per_row;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
if constexpr (need_check) {
i = min(i, i_max);
}
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
// quantize_mxfp4_mmq permutes nibbles to match the quantized format
const int k0 = kbx * 4;
memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
// Load E8M0 scales: pack 2 consecutive scales into one uint32
if (kbx % 2 == 0) {
uint32_t e = bxi->e;
e |= ((bxi + 1)->e << 8);
x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
}
}
}
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -996,78 +931,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
const int * __restrict__ y,
float * __restrict__ sum,
const int k00) {
typedef tile<16, 8, int> tile_A;
typedef tile<8, 8, int> tile_B;
typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
// Match layout from load_tiles_mxfp4_fp4
const int * x_qs = (const int *) x;
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
const int * y_qs = (const int *) y + 4;
const uint32_t * y_sc = (const uint32_t *) y;
// tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
// Block scale
// Each thread has to point to a 4 byte scale value
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
const int k0 = k00 + k01;
load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
MMQ_MMA_TILE_X_K_FP4);
// based on block-scaling document, 2 threads in each quad need to supply to the scale value
const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
scaleA[n][k01 / (2 * QI_MXFP4)] =
*(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
}
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
tile_B B;
uint32_t scaleB; // 2xN scales
load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
tile_C C;
mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
}
}
}
}
}
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -3246,13 +3109,8 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
#ifdef BLACKWELL_MMA_AVAILABLE
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
#else
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
#endif // BLACKWELL_MMA_AVAILABLE
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
@@ -3385,26 +3243,17 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
#if defined(BLACKWELL_MMA_AVAILABLE)
// FP4 tile stores 8 blocks
constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
#else
constexpr int ne_block = 4 * QK8_1;
#endif // defined(BLACKWELL_MMA_AVAILABLE)
constexpr int ITER_K = get_iter_k(type);
constexpr int blocks_per_iter = ITER_K / qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
{
const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
#pragma unroll
for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
int l = l0 + threadIdx.y*warp_size + threadIdx.x;
tile_y[l] = by0[l];
@@ -3418,9 +3267,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
__syncthreads();
{
const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
#pragma unroll
for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
int l = l0 + threadIdx.y*warp_size + threadIdx.x;
tile_y[l] = by0[l];
@@ -3552,10 +3401,8 @@ static __global__ void mul_mat_q(
}
#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
constexpr int ITER_K = get_iter_k(type);
const int64_t blocks_per_ne00 = ncols_x / qk;
constexpr int blocks_per_iter = ITER_K / qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
// kbc == k block continuous, current index in continuous ijk space.
int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
@@ -3616,7 +3463,7 @@ static __global__ void mul_mat_q(
__syncthreads();
}
offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
offset_dst += it*mmq_y;
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -3683,7 +3530,7 @@ static __global__ void mul_mat_q(
__syncthreads();
}
offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
offset_dst += it*mmq_y;
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -3706,9 +3553,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
const int ncols_max) {
constexpr int mmq_y = get_mmq_y_device();
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int ITER_K = get_iter_k(type);
constexpr int blocks_per_iter = ITER_K / qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
const int64_t blocks_per_ne00 = ncols_x / qk;
constexpr int nwarps = mmq_get_nwarps_device();
@@ -3866,7 +3711,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
const size_t nbs_ids = mmq_x*sizeof(int);
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq));
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
}
@@ -4082,4 +3927,4 @@ void ggml_cuda_op_mul_mat_q(
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);

View File

@@ -25,8 +25,19 @@ static __global__ void norm_f32(
}
// sum up partial sums
extern __shared__ float2 s_sum2[];
mean_var = block_reduce<block_reduce_method::SUM, block_size>(mean_var, s_sum2);
mean_var = warp_reduce_sum(mean_var);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float2 s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = mean_var;
}
__syncthreads();
mean_var = s_sum[lane_id];
mean_var = warp_reduce_sum(mean_var);
}
const float mean = mean_var.x / ncols;
const float var = mean_var.y / ncols - mean * mean;
@@ -50,8 +61,19 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
tmp += x[j];
}
extern __shared__ float s_sum[];
tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
const float mean = tmp / group_size;
tmp = 0.0f;
@@ -62,7 +84,18 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
tmp += xi * xi;
}
tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
const float variance = tmp / group_size;
const float scale = rsqrtf(variance + eps);
@@ -130,8 +163,22 @@ static __global__ void rms_norm_f32(const float * x,
}
// sum up partial sums
extern __shared__ float s_sum[];
tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = tid / WARP_SIZE;
const int lane_id = tid % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = 0.0f;
if (lane_id < (block_size / WARP_SIZE)) {
tmp = s_sum[lane_id];
}
tmp = warp_reduce_sum(tmp);
}
const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);
@@ -259,8 +306,19 @@ static __global__ void l2_norm_f32(
}
// sum up partial sums
extern __shared__ float s_sum[];
tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
@@ -279,7 +337,7 @@ static void norm_f32_cuda(
norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
@@ -290,7 +348,7 @@ static void group_norm_f32_cuda(
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
} else {
const dim3 block_dims(1024, 1, 1);
group_norm_f32<1024><<<num_groups, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps);
group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
}
}
@@ -300,10 +358,10 @@ static void rms_norm_f32_cuda(
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
@@ -346,12 +404,12 @@ static void rms_norm_mul_f32_cuda(const float * x,
const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
}
@@ -367,14 +425,14 @@ static void rms_norm_mul_f32_cuda(const float * x,
const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add_nchannels_packed, add_nsamples_packed);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
@@ -402,7 +460,7 @@ static void l2_norm_f32_cuda(
l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
l2_norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}

View File

@@ -47,131 +47,6 @@ static __global__ void quantize_q8_1(
y[ib].ds = make_half2(d, sum);
}
__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
if (!(amax > 0.0f)) {
return 0;
}
// FP4 E2M1: max exponent (unbiased) is 2.
constexpr int FP4_E2M1_EMAX = 2;
const float e = log2f(amax);
// "even" -> round-to-nearest integer, ties-to-even
const int e_int = __float2int_rn(e);
const int shared_exp = e_int - FP4_E2M1_EMAX;
int biased = shared_exp + 127;
biased = max(biased, 0);
biased = min(biased, 254);
return static_cast<uint8_t>(biased);
}
// quantize values in the format mxfp4 is stored which is interleaved nibbles
// i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31
static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
const int32_t * __restrict__ ids,
void * __restrict__ vy,
const int64_t ne00,
const int64_t s01,
const int64_t s02,
const int64_t s03,
const int64_t ne0,
const int ne1,
const int ne2) {
constexpr int vals_per_scale = 32;
constexpr int vals_per_warp = 2 * vals_per_scale; // Each warp processes 2 blocks of 32 = 64 values
const int warp_id = threadIdx.y;
const int lane_id_32 = threadIdx.x;
const int nwarps = blockDim.y;
const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;
if (warp_start_offset >= ne0) {
return;
}
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;
const int64_t i01 = ids ? ids[i1] : i1;
const int64_t i02 = i2;
const int64_t i03 = i3;
block_fp4_mmq * y = (block_fp4_mmq *) vy;
const int64_t block_fp4_mmq_size = 8 * QK_MXFP4; // 256 values
const int64_t ib0 = blockIdx.z * ((int64_t) ne1 * (ne0 / block_fp4_mmq_size));
const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
const int64_t quad_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
const int group_id = lane_id_32 / 4;
const int lane_in_group = lane_id_32 % 4;
const int base = group_id * 2;
char2 * yqs2 = (char2 *) y[ib].qs;
const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
uint8_t scales[2];
#pragma unroll
for (int b = 0; b < 2; ++b) {
const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32;
const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f;
float amax = fabsf(xi);
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
}
const uint8_t e = compute_e8m0_scale(amax);
scales[b] = e;
const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));
#if CUDART_VERSION >= 12080
const float scaled_val = xi * inv_s;
const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE);
const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE);
const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE);
const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE);
if (lane_in_group == 0) {
__nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3));
yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = *(char2 *) &fp4_packed;
}
#else
// Fallback: manual FP4 conversion using LUT
const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE);
const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1, WARP_SIZE);
const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE);
const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE);
if (lane_in_group == 0) {
char2 q;
q.x = (q_hi_0 << 4) | q_lo_0;
q.y = (q_hi_1 << 4) | q_lo_1;
yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = q;
}
#endif // CUDART_VERSION >= 12080
}
if (lane_id_32 == 0) {
// Store 2 scales packed into 1 uint32
y[ib].d4[quad_idx_in_block] = (scales[1] << 8) | scales[0];
}
}
template <mmq_q8_1_ds_layout ds_layout>
static __global__ void quantize_mmq_q8_1(
const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
@@ -315,29 +190,3 @@ void quantize_mmq_q8_1_cuda(
break;
}
}
void quantize_mmq_mxfp4_cuda(const float * x,
const int32_t * ids,
void * vy,
[[maybe_unused]] const ggml_type type_src0,
const int64_t ne00,
const int64_t s01,
const int64_t s02,
const int64_t s03,
const int64_t ne0,
const int64_t ne1,
const int64_t ne2,
const int64_t ne3,
cudaStream_t stream) {
GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
constexpr int nwarps = 8;
constexpr int vals_per_warp = 2 * QK_MXFP4;
constexpr int vals_per_block = nwarps * vals_per_warp;
const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
const dim3 block_size(WARP_SIZE, nwarps, 1);
quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
}

View File

@@ -25,17 +25,3 @@ void quantize_mmq_q8_1_cuda(
const float * x, const int32_t * ids, void * vy,
ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
void quantize_mmq_mxfp4_cuda(const float * x,
const int32_t * ids,
void * vy,
ggml_type type_src0,
int64_t ne00,
int64_t s01,
int64_t s02,
int64_t s03,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3,
cudaStream_t stream);

View File

@@ -28,8 +28,22 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r
}
// sum up partial sums
__shared__ float shared_vals[32];
sum = block_reduce<block_reduce_method::SUM>(sum, shared_vals);
sum = warp_reduce_sum(sum);
if (blockDim.x > WARP_SIZE) {
assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = sum;
}
__syncthreads();
sum = 0.0f;
if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
sum = s_sum[lane_id];
}
sum = warp_reduce_sum(sum);
}
if (col != 0) {
return;

View File

@@ -1,14 +1,6 @@
#include "common.cuh"
#include "ggml.h"
#include "softmax.cuh"
#ifdef GGML_USE_HIP
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#endif // GGML_USE_HIP
#include <cstdint>
#include <utility>
@@ -75,6 +67,9 @@ static __global__ void soft_max_f32(
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
extern __shared__ float data_soft_max_f32[];
@@ -99,7 +94,21 @@ static __global__ void soft_max_f32(
}
// find the max value in the block
max_val = block_reduce<block_reduce_method::MAX, block_size_template>(max_val, buf_iw);
max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf_iw[lane_id] = -INFINITY;
}
__syncthreads();
if (lane_id == 0) {
buf_iw[warp_id] = max_val;
}
__syncthreads();
max_val = buf_iw[lane_id];
max_val = warp_reduce_max(max_val);
}
float tmp = 0.0f; // partial sum
@@ -117,7 +126,22 @@ static __global__ void soft_max_f32(
}
// find the sum of exps in the block
tmp = block_reduce<block_reduce_method::SUM, block_size_template>(tmp, buf_iw);
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__syncthreads();
if (warp_id == 0) {
buf_iw[lane_id] = 0.0f;
}
__syncthreads();
if (lane_id == 0) {
buf_iw[warp_id] = tmp;
}
__syncthreads();
tmp = buf_iw[lane_id];
tmp = warp_reduce_sum(tmp);
}
if (sinks) {
tmp += expf(sinks[i02] - max_val);
@@ -136,113 +160,6 @@ static __global__ void soft_max_f32(
dst[col] = vals[col] * inv_sum;
}
}
// TODO: Template to allow keeping ncols in registers if they fit
static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
float * __restrict__ dst,
float * __restrict__ tmp_maxs,
float * __restrict__ tmp_sums,
const soft_max_params p) {
namespace cg = cooperative_groups;
const cg::grid_group g = cg::this_grid();
const int tid = threadIdx.x;
const int col_start = blockIdx.x * blockDim.x + tid;
const int n_elem_per_thread = 4;
float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
float local_max = -INFINITY;
const int step_size = gridDim.x * blockDim.x;
__shared__ float shared_vals[32];
// Compute thread-local max
for (int col = col_start; col < p.ncols;) {
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
local_max = fmaxf(local_max, local_vals[i]);
}
col += step_size * n_elem_per_thread;
}
// Compute CTA-level max
local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);
// Store CTA-level max to GMEM
if (tid == 0) {
tmp_maxs[blockIdx.x] = local_max;
}
g.sync();
// Compute compute global max from CTA-level maxs
assert(gridDim.x < blockDim.x); // currently we only support this case
if (tid < gridDim.x) {
local_max = tmp_maxs[tid];
} else {
local_max = -INFINITY;
}
local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);
// Compute softmax dividends, accumulate divisor
float tmp_expf = 0.0f;
for (int col = col_start; col < p.ncols;) {
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
if (idx < p.ncols) {
const float tmp = expf(local_vals[i] - local_max);
tmp_expf += tmp;
dst[idx] = tmp;
}
}
col += step_size * n_elem_per_thread;
}
// Reduce divisor within CTA
tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);
// Store CTA-level sum to GMEM
if (tid == 0) {
tmp_sums[blockIdx.x] = tmp_expf;
}
g.sync();
// Compute global sum from CTA-level sums
if (tid < gridDim.x) {
tmp_expf = tmp_sums[tid];
} else {
tmp_expf = 0.0f;
}
tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);
// Divide dividend by global sum + store data
for (int col = col_start; col < p.ncols;) {
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
local_vals[i] = idx < p.ncols ? dst[idx] : -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
if (idx < p.ncols) {
dst[idx] = local_vals[i] / tmp_expf;
}
}
col += step_size * n_elem_per_thread;
}
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
@@ -299,31 +216,9 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
}
__launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,
float * __restrict__ dst,
float * __restrict__ tmp_maxs,
float * __restrict__ tmp_sums,
const soft_max_params p)
// We loop over all instead of parallelizing across gridDim.y as cooperative groups
// currently only support synchronizing the complete grid if not launched as a cluster group
// (which requires CC > 9.0)
// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#grid-synchronization
// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group
{
for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) {
soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_maxs,
tmp_sums, p);
}
}
template <typename T>
static void soft_max_f32_cuda(const float * x,
const T * mask,
const float * sinks,
float * dst,
const soft_max_params & params,
cudaStream_t stream,
[[maybe_unused]] ggml_backend_cuda_context & ctx) {
template<typename T>
static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
int nth = WARP_SIZE;
const int64_t ncols_x = params.ncols;
@@ -341,25 +236,8 @@ static void soft_max_f32_cuda(const float * x,
if (nbytes_shared <= smpbo) {
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
} else {
// Parallelize across SMs for top-p/dist-sampling
// The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and
// Can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution.
if (ggml_cuda_info().devices[id].supports_cooperative_launch &&
ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr &&
params.scale == 1.0f && params.max_bias == 0.0f) {
ggml_cuda_pool_alloc<float> tmp_maxs_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
ggml_cuda_pool_alloc<float> tmp_sums_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs_alloc.ptr,
(void *) &tmp_sums_alloc.ptr, (void *) const_cast<soft_max_params *>(&params) };
CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
dim3(ggml_cuda_info().devices[id].nsm, 1, 1),
dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));
} else {
const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
soft_max_f32<false, 0, 0>
<<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
}
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
}
}
@@ -437,9 +315,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
params.m1 = m1;
if (use_f16) {
soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
} else {
soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
}
}

View File

@@ -114,7 +114,7 @@ __global__ void __launch_bounds__(splitD, 1)
#endif // __clang__
// assumes as many threads as d_state
template <int c_factor, int d_state>
template <int splitH, int d_state>
__global__ void __launch_bounds__(d_state, 1)
ssm_scan_f32_group(
const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
@@ -125,25 +125,20 @@ __global__ void __launch_bounds__(d_state, 1)
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
const int warp = threadIdx.x / WARP_SIZE;
const int lane = threadIdx.x % WARP_SIZE;
const int warp_idx = blockIdx.x * c_factor + warp;
const int head_idx = warp_idx / d_head;
const int head_off = (warp_idx % d_head) * sizeof(float);
const int seq_idx = blockIdx.y;
const int head_idx = (blockIdx.x * splitH) / d_head;
const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
const int seq_idx = blockIdx.y;
const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
// TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase
const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));
const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1);
const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx;
float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1);
const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
// strides across n_seq_tokens
const int stride_x = src1_nb2 / sizeof(float);
@@ -152,42 +147,80 @@ __global__ void __launch_bounds__(d_state, 1)
const int stride_C = src5_nb2 / sizeof(float);
const int stride_y = n_head * d_head;
float state[c_factor];
float state_sum = 0.0f;
float state[splitH];
// for the parallel accumulation
__shared__ float stateC[splitH * d_state];
#pragma unroll
for (int j = 0; j < c_factor; j++) {
state[j] = s0_warp[WARP_SIZE * j + lane];
for (int j = 0; j < splitH; j++) {
state[j] = s0_block[j * d_state + threadIdx.x];
}
for (int64_t i = 0; i < n_tok; i++) {
// NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here.
// Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead.
const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]);
// TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
// TODO: only calculate B and C once per head group
// NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
float dt_soft_plus = dt_block[i * stride_dt];
if (dt_soft_plus <= 20.0f) {
dt_soft_plus = log1pf(expf(dt_soft_plus));
}
const float dA = expf(dt_soft_plus * A_block[0]);
const float B = B_block[i * stride_B + threadIdx.x];
const float C = C_block[i * stride_C + threadIdx.x];
state_sum = 0.0f;
const float dA = expf(dt_soft_plus * A_warp[0]);
const float x_dt = x_warp[i * stride_x] * dt_soft_plus;
// across d_head
#pragma unroll
for (int j = 0; j < c_factor; j++) {
const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane];
const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane];
state[j] = (state[j] * dA) + (B_val * x_dt);
state_sum += state[j] * C_val;
for (int j = 0; j < splitH; j++) {
const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
state[j] = (state[j] * dA) + (B * x_dt);
stateC[j * d_state + threadIdx.x] = state[j] * C;
}
// parallel accumulation for output
state_sum = warp_reduce_sum(state_sum);
__syncthreads();
if (lane == 0) {
y_warp[i * stride_y] = state_sum;
// parallel accumulation for stateC
// TODO: simplify
{
static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
// reduce until w matches the warp size
// TODO: does this work even when the physical warp size is 64?
#pragma unroll
for (int w = d_state; w > WARP_SIZE; w >>= 1) {
// (assuming there are d_state threads)
#pragma unroll
for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
// TODO: check for bank conflicts
const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
stateC[k] += stateC[k + (w >> 1)];
}
__syncthreads();
}
static_assert(splitH >= d_state / WARP_SIZE);
#pragma unroll
for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
y = warp_reduce_sum(y);
// store the above accumulations
if (threadIdx.x % WARP_SIZE == 0) {
const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
y_block[i * stride_y + k] = y;
}
}
}
}
// write back the state
#pragma unroll
for (int j = 0; j < c_factor; j++) {
s_warp[WARP_SIZE * j + lane] = state[j];
for (int j = 0; j < splitH; j++) {
s_block[j * d_state + threadIdx.x] = state[j];
}
}
@@ -198,24 +231,27 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
cudaStream_t stream) {
const int threads = 128;
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
if (src3_nb1 == sizeof(float)) {
// Mamba-2
if (d_state == 128) {
constexpr int threads = 128;
constexpr int num_warps = threads/WARP_SIZE;
const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
ssm_scan_f32_group<128/WARP_SIZE, 128><<<blocks, threads, 0, stream>>>(
GGML_ASSERT(d_state % threads == 0);
// NOTE: can be any power of two between 4 and 64
const int splitH = 16;
GGML_ASSERT(head_dim % splitH == 0);
const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
} else if (d_state == 256) { // Falcon-H1
constexpr int threads = 256;
constexpr int num_warps = threads/WARP_SIZE;
const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
ssm_scan_f32_group<256/WARP_SIZE, 256><<<blocks, threads, 0, stream>>>(
const int threads = 256;
// NOTE: can be any power of two between 8 and 64
const int splitH = 16;
GGML_ASSERT(head_dim % splitH == 0);
const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
@@ -224,7 +260,6 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
}
} else {
// Mamba-1
constexpr int threads = 128;
GGML_ASSERT(n_head % threads == 0);
GGML_ASSERT(head_dim == 1);
GGML_ASSERT(n_group == 1);

View File

@@ -1,96 +0,0 @@
#include "argsort.cuh"
#include "top-k.cuh"
#ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh>
# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
# include <cuda/iterator>
# define CUB_TOP_K_AVAILABLE
using namespace cub;
# endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2
#endif // GGML_CUDA_USE_CUB
#ifdef CUB_TOP_K_AVAILABLE
static void top_k_cub(ggml_cuda_pool & pool,
const float * src,
int * dst,
const int ncols,
const int k,
cudaStream_t stream) {
auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
cuda::execution::output_ordering::unsorted);
auto stream_env = cuda::stream_ref{ stream };
auto env = cuda::std::execution::env{ stream_env, requirements };
auto indexes_in = cuda::make_counting_iterator(0);
size_t temp_storage_bytes = 0;
DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k,
env);
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
void * d_temp_storage = temp_storage_alloc.get();
DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst,
ncols, k, env);
}
#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
static int next_power_of_2(int x) {
int n = 1;
while (n < x) {
n *= 2;
}
return n;
}
#endif // CUB_TOP_K_AVAILABLE
void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
int * dst_d = (int *) dst->data;
cudaStream_t stream = ctx.stream();
// are these asserts truly necessary?
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const int64_t k = dst->ne[0];
ggml_cuda_pool & pool = ctx.pool();
#ifdef CUB_TOP_K_AVAILABLE
// TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented
// https://github.com/NVIDIA/cccl/issues/6391
// TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k
for (int i = 0; i < nrows; i++) {
top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream);
}
#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
// Fall back to argsort + copy
const int ncols_pad = next_power_of_2(ncols);
const size_t shared_mem = ncols_pad * sizeof(int);
const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
int * tmp_dst = temp_dst_alloc.get();
if (shared_mem > max_shared_mem || ncols > 1024) {
argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
} else {
argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
}
CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
cudaMemcpyDeviceToDevice, stream));
#else // GGML_CUDA_USE_CUB
ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
int * tmp_dst = temp_dst_alloc.get();
argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
cudaMemcpyDeviceToDevice, stream));
#endif
}

View File

@@ -1,3 +0,0 @@
#include "common.cuh"
void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -10,10 +10,6 @@
#include <cuda_fp8.h>
#endif // CUDART_VERSION >= 12050
#if CUDART_VERSION >= 12080
#include <cuda_fp4.h>
#endif // CUDART_VERSION >= 12080
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH

View File

@@ -45,11 +45,9 @@
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceGetAttribute hipDeviceGetAttribute
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
@@ -72,7 +70,6 @@
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchCooperativeKernel hipLaunchCooperativeKernel
#define cudaLaunchHostFunc hipLaunchHostFunc
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
@@ -138,8 +135,6 @@
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
#define cudaFuncSetAttribute hipFuncSetAttribute
#define cudaFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED

View File

@@ -61,7 +61,6 @@
#define cudaHostRegisterPortable musaHostRegisterPortable
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
#define cudaHostUnregister musaHostUnregister
#define cudaLaunchCooperativeKernel musaLaunchCooperativeKernel
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost

File diff suppressed because it is too large Load Diff

View File

@@ -8,7 +8,6 @@ extern "C" {
#include <AEEStdErr.h>
#include <inttypes.h>
#include <remote.h>
#include <rpcmem.h>
#include <stdbool.h>
/* Offset to differentiate HLOS and Hexagon error codes.

View File

@@ -17,22 +17,21 @@ add_library(${HTP_LIB} SHARED
main.c
htp_iface_skel.c
worker-pool.c
hex-dma.c
htp-dma.c
hvx-sigmoid.c
hvx-inverse.c
hvx-exp.c
hvx-utils.c
matmul-ops.c
binary-ops.c
unary-ops.c
softmax-ops.c
act-ops.c
rope-ops.c
flash-attn-ops.c
set-rows-ops.c
get-rows-ops.c
cpy-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
build_idl(htp_iface.idl ${HTP_LIB})

View File

@@ -2,20 +2,27 @@
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#ifdef HTP_DEBUG
# define FARF_HIGH 1
#endif
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <HAP_ps.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <qurt_thread.h>
#include <string.h>
#include "hex-dma.h"
#include "hvx-utils.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-dma.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"
#define htp_act_preamble3 \
const uint32_t ne00 = src0->ne[0]; \
@@ -69,7 +76,7 @@
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];
static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
const struct htp_tensor * src1,
struct htp_tensor * dst,
const int32_t * op_params,
@@ -78,16 +85,13 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
uint32_t src0_nrows_per_thread) {
htp_act_preamble3;
size_t src0_row_size = nb01;
size_t src1_row_size = nb11;
size_t dst_row_size = nb1;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
@@ -101,6 +105,12 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
int is_aligned = 1;
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
is_aligned = 0;
FARF(HIGH, "swiglu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
}
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
@@ -117,86 +127,42 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
data_src1 += swapped ? 0 : nc_in_bytes;
}
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
const bool opt_path = ((1 == is_aligned) && !(nb01 & (VLEN - 1)));
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
if (BLOCK == 0) {
FARF(ERROR,
"swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
return;
}
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
dma_queue_push_vtcm_to_ddr(dma_queue,
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
dst_row_size, dst_row_size_aligned, 0);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
src0_row_size_aligned, src0_row_size, block_size);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
src1_row_size_aligned, src1_row_size, block_size);
}
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
for (uint32_t ib = 0; ib < block_size; ib++) {
const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
const float * src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float));
float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
//swiglu(x) = x1 * sigmoid(x0)
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, nc);
hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
(const uint8_t *) src1_spad_ptr, nc);
if (ir + 1 < src0_end_row) {
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
}
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
dst_row_size_aligned, block_size);
if (opt_path) {
hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, nc);
hvx_mul_mul_f32_opt((const uint8_t *) src0, (const uint8_t *) src0_spad_data, (const uint8_t *) src1,
(uint8_t *) dst, nc);
} else {
hvx_exp_f32((const uint8_t *) src0, src0_spad_data, nc, true);
hvx_add_scalar_f32(src0_spad_data, 1.0, src1_spad_data, nc);
hvx_inverse_f32(src1_spad_data, src0_spad_data, nc);
// prefetch N+2 loop iteration if any
const uint32_t pref_block = (ir + BLOCK * 2);
if (pref_block < src0_end_row) {
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
src0_row_size_aligned, src0_row_size, pref_block_size);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
src1_row_size_aligned, src1_row_size, pref_block_size);
hvx_mul_f32((const uint8_t *) src0, src0_spad_data, dst_spad_data, nc);
hvx_mul_f32(dst_spad_data, (const uint8_t *) src1, (uint8_t *) dst, nc);
}
}
dma_queue_flush(dma_queue);
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
FARF(HIGH, "swiglu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
const struct htp_tensor * src1,
struct htp_tensor * dst,
const int32_t * op_params,
@@ -205,16 +171,15 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
uint32_t src0_nrows_per_thread) {
htp_act_preamble3;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
size_t src0_row_size = nb01;
size_t src1_row_size = nb11;
size_t dst_row_size = nb1;
const size_t src0_row_size = nb01;
const size_t src1_row_size = nb11;
const size_t dst_row_size = nb1;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
@@ -226,116 +191,72 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
return;
}
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
FARF(HIGH, "act-f32: unaligned addresses in activations op, possibly slower execution\n");
}
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
const bool src1_valid = src1->ne[0];
const int nc = (src1_valid) ? ne00 : ne00 / 2;
bool src1_valid = src1->ne[0];
if (!src1_valid) {
const int32_t swapped = op_params[1];
data_src1 = data_src0;
src1_row_size = src0_row_size;
const size_t nc_in_bytes = nc * SIZEOF_FP32;
data_src0 += swapped ? nc_in_bytes : 0;
data_src1 += swapped ? 0 : nc_in_bytes;
data_src1 = data_src0;
}
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
const int32_t swapped = op_params[1];
const float alpha = ((const float *) (op_params))[2];
const float limit = ((const float *) (op_params))[3];
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int nc = (src1_valid) ? ne00 : ne00 / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
if (BLOCK == 0) {
FARF(ERROR,
"swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least "
"%zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
return;
}
const float alpha = ((const float *) (op_params))[2];
const float limit = ((const float *) (op_params))[3];
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
dst_row_size, dst_row_size_aligned, 0);
dma_queue_push_ddr_to_vtcm(
dma_queue,
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
src0_row_size_aligned, src0_row_size, block_size);
dma_queue_push_ddr_to_vtcm(
dma_queue,
dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
src1_row_size_aligned, src1_row_size, block_size);
}
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
for (uint32_t ib = 0; ib < block_size; ib++) {
const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
const float * src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float));
float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
// x (src0_spad_data) = std::min(src0_p[k], limit);
hvx_min_scalar_f32((uint8_t *) src0_spad_ptr, (const uint8_t *) src0_spad_ptr, limit, nc);
// y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
hvx_clamp_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, -limit, limit, nc);
// y (src1_spad_data) = y1 + 1.f
hvx_add_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, 1.0, nc);
// x1 (dst_spad_data) = alpha * (x)
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, alpha, nc);
// x2 (dst_spad_data) = sigmoid(x1) = 1/(1+exp(-x1))
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc);
// out = x * sigmoid(alpha * x) * (y + 1.f)
hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
(const uint8_t *) src1_spad_ptr, nc);
if (ir + 1 < src0_end_row) {
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
}
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
dst_row_size_aligned, block_size);
// prefetch N+2 loop iteration if any
const uint32_t pref_block = (ir + BLOCK * 2);
if (pref_block < src0_end_row) {
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
src0_row_size_aligned, src0_row_size, pref_block_size);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
src1_row_size_aligned, src1_row_size, pref_block_size);
if (!src1) {
src0 += swapped ? nc : 0;
src1 += swapped ? 0 : nc;
}
}
dma_queue_flush(dma_queue);
// x (src0_spad_data) = std::min(src0_p[k], limit);
hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc);
// y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
hvx_clamp_scalar_f32((const uint8_t *) src1, -limit, limit, src1_spad_data, nc);
// y (src1_spad_data) = y1 + 1.f
hvx_add_scalar_f32(src1_spad_data, 1.0, src1_spad_data, nc);
// x1 (dst_spad_data) = alpha * (x)
hvx_mul_scalar_f32(src0_spad_data, alpha, dst_spad_data, nc);
// x2 (dst_spad_data) = expf(-x1)
hvx_exp_f32(dst_spad_data, dst_spad_data, nc, true);
// x3 (dst_spad_data) = x2 + 1.f
hvx_add_scalar_f32(dst_spad_data, 1.0, dst_spad_data, nc);
// x4 (dst_spad_data) = 1 / x3
hvx_inverse_f32(dst_spad_data, dst_spad_data, nc);
// out_glu(dst_spad_data) = x * x4
hvx_mul_f32(src0_spad_data, dst_spad_data, dst_spad_data, nc);
// out = out_glu * (y + 1.f);
hvx_mul_f32(dst_spad_data, src1_spad_data, (uint8_t *) dst, nc);
}
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "swiglu-oai-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0],
FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0],
src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2],
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
@@ -351,8 +272,8 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1;
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
const uint32_t src0_nrows = ne01 * ne02 * ne03;
@@ -408,9 +329,9 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
// gelu = x * sigmoid(1.702 * x) // current implementation
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_mul_scalar_f32((const uint8_t *) src0_spad_ptr, (float) 1.702, (uint8_t *) dst_spad_ptr, ne0);
hvx_fast_sigmoid_f32((const uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
hvx_mul_f32_opt((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
}
dma_queue_push_vtcm_to_ddr(dma_queue,
@@ -435,23 +356,22 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void unary_gelu_f32(unsigned int n, unsigned int i, void * data) {
static void unary_gelu_fp32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
unary_gelu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
unary_gelu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
uint32_t src0_nrows_per_thread) {
htp_act_preamble2;
uint64_t t1, t2;
@@ -459,8 +379,6 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1;
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
const uint32_t src0_nrows = ne01 * ne02 * ne03;
@@ -472,94 +390,67 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
return;
}
const uint8_t * data_src0 = (const uint8_t *) src0->data;
uint8_t * data_dst = (uint8_t *) dst->data;
uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
if (BLOCK == 0) {
FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
return;
int is_aligned = 1;
int opt_path = 0;
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
is_aligned = 0;
FARF(HIGH, "silu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
}
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
opt_path = 1;
}
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
dma_queue_push_vtcm_to_ddr(dma_queue,
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
dst_row_size, dst_row_size_aligned, 0);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
src0_row_size_aligned, src0_row_size, block_size);
}
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
float* dst_spad = (float *) dma_queue_pop(dma_queue).src;
float* src0_spad = (float *) dma_queue_pop(dma_queue).dst;
for (uint32_t ib = 0; ib < block_size; ib++) {
const float* src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
// silu = x * sigmoid(x)
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
if (ir + 1 < src0_end_row) {
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
}
dma_queue_push_vtcm_to_ddr(dma_queue,
dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
dst_row_size, dst_row_size_aligned, block_size);
if (1 == opt_path) {
hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, ne0);
hvx_mul_f32_opt((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
} else {
hvx_exp_f32((const uint8_t *) src0, src0_spad_data, ne0, true);
hvx_add_scalar_f32(src0_spad_data, 1.0, dst_spad_data, ne0);
hvx_inverse_f32(dst_spad_data, src0_spad_data, ne0);
// prefetch N+2 loop iteration if any
const uint32_t pref_block = (ir + BLOCK * 2);
if (pref_block < src0_end_row) {
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
src0_row_size_aligned, src0_row_size, pref_block_size);
hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
}
}
dma_queue_flush(dma_queue);
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "silu-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, ne00, ne01, ne02,
FARF(HIGH, "silu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, ne00, ne01, ne02,
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
octx->src0_nrows_per_thread, octx->ctx->dma[i]);
unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
octx->src0_nrows_per_thread);
}
static void glu_swiglu_f32(unsigned int n, unsigned int i, void * data) {
static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
glu_swiglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread);
}
static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
glu_swiglu_oai_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread);
}
static int execute_op_activations_f32(struct htp_ops_context * octx) {
static int execute_op_activations_fp32(struct htp_ops_context * octx) {
int err = HTP_STATUS_OK;
const struct htp_tensor * src0 = &octx->src0;
@@ -576,21 +467,21 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
switch (octx->op) {
case HTP_OP_UNARY_SILU:
act_op_func = unary_silu_f32;
act_op_func = unary_silu_fp32;
op_type = "silu-f32";
break;
case HTP_OP_GLU_SWIGLU:
act_op_func = glu_swiglu_f32;
act_op_func = glu_swiglu_fp32;
op_type = "swiglu-f32";
break;
case HTP_OP_GLU_SWIGLU_OAI:
act_op_func = glu_swiglu_oai_f32;
act_op_func = glu_swiglu_oai_fp32;
op_type = "swiglu-oai-f32";
break;
case HTP_OP_UNARY_GELU:
act_op_func = unary_gelu_f32;
act_op_func = unary_gelu_fp32;
op_type = "gelu-f32";
break;
default:
@@ -610,9 +501,9 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
src1_row_size = src0_row_size;
}
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN);
const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
// VTCM scratchpads for all tensors
// N rows per thread, padded to HVX vector size
@@ -663,7 +554,7 @@ int op_activations(struct htp_ops_context * octx) {
switch (octx->src0.type) {
case HTP_TYPE_F32:
err = execute_op_activations_f32(octx);
err = execute_op_activations_fp32(octx);
break;
default:

View File

@@ -2,25 +2,36 @@
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#ifdef HTP_DEBUG
# define FARF_HIGH 1
#endif
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <HAP_ps.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <qurt_thread.h>
#include <string.h>
#include "hex-dma.h"
#include "hvx-utils.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-dma.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"
typedef void (*hvx_elemwise_f32_func)(uint8_t * data_dst, const uint8_t * src0, const uint8_t * src1, const uint32_t num_elems);
typedef void (*hvx_elemwise_f32_func)(const uint8_t * src0,
const uint8_t * src1,
uint8_t * data_dst,
const int num_elems);
static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 };
static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_aa, hvx_add_f32_aa, hvx_sub_f32_aa };
static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt };
#define htp_binary_preamble \
const struct htp_tensor * src0 = &octx->src0; \
@@ -87,8 +98,9 @@ static void binary_job_f32_per_thread(struct htp_ops_context * octx,
int is_aligned = 1;
int opt_path = 0;
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) ||
(0 == hex_is_aligned((void *) dst->data, VLEN))) {
if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
(0 == htp_is_aligned((void *) dst->data, VLEN))) {
FARF(HIGH, "binary-f32: unaligned addresses in elementwise op, possibly slower execution\n");
is_aligned = 0;
}
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
@@ -118,24 +130,24 @@ static void binary_job_f32_per_thread(struct htp_ops_context * octx,
const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size;
if (ir + 1 < src0_end_row) {
hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1);
htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
if (src1_row_size == src0_row_size) {
hex_l2fetch(src1_ptr, src1_row_size, src1_row_size, 1);
htp_l2fetch(src1_ptr, 1, src1_row_size, src1_row_size);
}
}
const uint32_t nr0 = ne00 / ne10;
if (nr0 > 1) {
if ((1 == is_aligned) && (nr0 == ne00)) {
hvx_splat_f32_a(spad_data_th, *(float *) src1_ptr, nr0);
hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0);
} else {
for (uint32_t r = 0; r < nr0; r++) {
memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11);
}
}
func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, ne00);
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, (uint8_t *) dst_ptr, ne00);
} else {
func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00);
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
}
src0_ptr += src0_row_size;
@@ -173,6 +185,11 @@ static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
(0 == htp_is_aligned((void *) dst->data, VLEN))) {
FARF(HIGH, "add-id-f32: unaligned addresses, possibly slower execution\n");
}
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
@@ -193,9 +210,9 @@ static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11);
if (ir + 1 < src0_end_row) {
hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1);
htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
if (src1_row_size == src0_row_size) {
hex_l2fetch(src1_ptr + ne10, src1_row_size, src1_row_size, 1);
htp_l2fetch(src1_ptr + ne10, 1, src1_row_size, src1_row_size);
}
}
@@ -204,9 +221,9 @@ static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
for (uint32_t r = 0; r < nr0; r++) {
memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10);
}
func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data, ne00);
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data, (uint8_t *) dst_ptr, ne00);
} else {
func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00);
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
}
}
@@ -282,9 +299,9 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
const size_t dst_row_size = dst->nb[1];
// VTCM scratchpads for all tensors
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;

View File

@@ -1,251 +0,0 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#include <HAP_farf.h>
#include <HAP_perf.h>
#include <math.h>
#include <string.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
struct htp_copy_context {
struct htp_ops_context * octx;
uint32_t src0_type_size;
uint32_t src0_block_size;
uint32_t dst_type_size;
uint32_t dst_block_size;
uint32_t src0_blocks_per_row;
uint32_t dst_blocks_per_row;
uint32_t src0_nrows_per_thread;
void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith);
};
#define cpy_preamble \
struct htp_tensor *src0 = &octx->src0; \
struct htp_tensor *dst = &octx->dst; \
\
const uint32_t ne00 = src0->ne[0]; \
const uint32_t ne01 = src0->ne[1]; \
const uint32_t ne02 = src0->ne[2]; \
const uint32_t ne03 = src0->ne[3]; \
\
const uint32_t nb00 = src0->nb[0]; \
const uint32_t nb01 = src0->nb[1]; \
const uint32_t nb02 = src0->nb[2]; \
const uint32_t nb03 = src0->nb[3]; \
\
const uint32_t ne0 = dst->ne[0]; \
const uint32_t ne1 = dst->ne[1]; \
const uint32_t ne2 = dst->ne[2]; \
const uint32_t ne3 = dst->ne[3]; \
\
const uint32_t nb0 = dst->nb[0]; \
const uint32_t nb1 = dst->nb[1]; \
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3]; \
\
const uint32_t nr = ne01;
static void cpy_thread_sametype_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
cpy_preamble;
// parallelize by src0 rows
const uint32_t dr = ct->src0_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
// copy by rows
for (uint32_t i03 = 0; i03 < ne03; i03++) {
for (uint32_t i02 = 0; i02 < ne02; i02++) {
#pragma unroll(2)
for (uint32_t i01 = ir0; i01 < ir1; i01++) {
uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
hex_l2fetch(src0_ptr, ne00 * ct->src0_type_size, nb01, 2);
hvx_copy_uu(dst_ptr, src0_ptr, ne00, ct->src0_type_size);
}
}
}
}
static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith) {
cpy_preamble;
// parallelize by src0 rows
const uint32_t dr = ct->src0_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
// dst counters
int64_t k10 = 0;
int64_t i11 = 0;
int64_t i12 = 0;
int64_t i13 = 0;
// number of blocks in a row
const int64_t nk00 = ct->src0_blocks_per_row;
const int64_t nk0 = ct->dst_blocks_per_row;
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
k10 += nk00 * ir0;
while (k10 >= nk0) {
k10 -= nk0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t k00 = 0; k00 < nk00; k00++) {
const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
memcpy(dst_ptr, src0_ptr, ct->dst_type_size);
if (++k10 == nk0) {
k10 = 0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
k10 += nk00 * (ne01 - ir1);
while (k10 >= nk0) {
k10 -= nk0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
}
static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
cpy_preamble;
// parallelize by src0 rows
const uint32_t dr = ct->src0_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
// copy by rows
for (uint32_t i03 = 0; i03 < ne03; i03++) {
for (uint32_t i02 = 0; i02 < ne02; i02++) {
#pragma unroll(2)
for (uint32_t i01 = ir0; i01 < ir1; i01++) {
uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
hex_l2fetch(src0_ptr, ne00 * sizeof(float), nb01, 2);
hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);
}
}
}
}
static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
cpy_preamble;
// parallelize by src0 rows
const uint32_t dr = ct->src0_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
// copy by rows
for (uint32_t i03 = 0; i03 < ne03; i03++) {
for (uint32_t i02 = 0; i02 < ne02; i02++) {
#pragma unroll(2)
for (uint32_t i01 = ir0; i01 < ir1; i01++) {
uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
hex_l2fetch(src0_ptr, ne00 * sizeof(__fp16), nb01, 2);
hvx_copy_f32_f16_uu(dst_ptr, src0_ptr, ne00);
}
}
}
}
static void cpy_work_func(unsigned int n, unsigned int i, void *data) {
struct htp_copy_context *ct = (struct htp_copy_context *) data;
ct->copy(ct, ct->octx, n, i);
}
int op_cpy(struct htp_ops_context * octx) {
cpy_preamble;
struct htp_copy_context ct;
ct.octx = octx;
switch (src0->type) {
case HTP_TYPE_F32: ct.src0_type_size = 4; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break;
case HTP_TYPE_F16: ct.src0_type_size = 2; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break;
default:
return HTP_STATUS_NO_SUPPORT;
}
switch (dst->type) {
case HTP_TYPE_F32: ct.dst_type_size = 4; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break;
case HTP_TYPE_F16: ct.dst_type_size = 2; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break;
default:
return HTP_STATUS_NO_SUPPORT;
}
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
return HTP_STATUS_OK;
}
const bool sametype = (src0->type == dst->type);
const bool transposed = (nb00 > nb01) || (nb0 > nb1);
const bool sameshape = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3);
const uint32_t n_jobs = MIN(nr, octx->n_threads);
ct.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
if (sametype && sameshape) {
ct.copy = cpy_thread_sametype_sameshape;
} else if (sameshape) {
/**/ if (dst->type == HTP_TYPE_F16 && src0->type == HTP_TYPE_F32)
ct.copy = cpy_thread_f16_f32_sameshape;
else if (dst->type == HTP_TYPE_F32 && src0->type == HTP_TYPE_F16)
ct.copy = cpy_thread_f32_f16_sameshape;
else
return HTP_STATUS_NO_SUPPORT;
} else if (sametype) {
ct.copy = cpy_thread_sametype_reshape;
} else {
return HTP_STATUS_NO_SUPPORT;
}
worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_jobs);
return HTP_STATUS_OK;
}

View File

@@ -1,561 +0,0 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#include <HAP_farf.h>
#include <HAP_perf.h>
#include <math.h>
#include <string.h>
#include "hex-dma.h"
#include "hvx-utils.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
// Dot product of FP32 and FP16 vectors, accumulating to float
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
// Load x (fp16)
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
// Load x (fp16)
HVX_Vector x_hf = vx[i];
// Zero-out unused elements
// Note that we need to clear both x and y because they may contain NANs
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x_hf = Q6_V_vand_QV(bmask, x_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s));
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
hvx_vec_store_u(r, 4, rsum);
}
// Dot product of two F16 vectors, accumulating to float
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_Vector y_hf = vy[i];
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
if (nloe) {
HVX_Vector y_hf = vy[i];
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s));
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
hvx_vec_store_u(r, 4, rsum);
}
// MAD: y (F32) += x (F16) * v (float)
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector S = hvx_vec_splat_f16(s);
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; ++i) {
// Multiply x * s -> pair of F32 vectors
HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
}
if (nloe) {
HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
HVX_Vector xs = Q6_V_lo_W(xs_p);
i = 2 * i; // index for ptr_y
if (nloe >= 32) {
ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p);
}
if (nloe) {
HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
hvx_vec_store_a(&ptr_y[i], nloe * 4, xy);
}
}
}
#define FLASH_ATTN_BLOCK_SIZE 128
static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
const struct htp_tensor * q = &octx->src0;
const struct htp_tensor * k = &octx->src1;
const struct htp_tensor * v = &octx->src2;
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
struct htp_tensor * dst = &octx->dst;
const uint32_t neq0 = q->ne[0];
const uint32_t neq1 = q->ne[1];
const uint32_t neq2 = q->ne[2];
const uint32_t neq3 = q->ne[3];
const uint32_t nek0 = k->ne[0];
const uint32_t nek1 = k->ne[1];
const uint32_t nek2 = k->ne[2];
const uint32_t nek3 = k->ne[3];
const uint32_t nev0 = v->ne[0];
const uint32_t nev1 = v->ne[1];
const uint32_t nev2 = v->ne[2];
const uint32_t nev3 = v->ne[3];
const uint32_t nbq1 = q->nb[1];
const uint32_t nbq2 = q->nb[2];
const uint32_t nbq3 = q->nb[3];
const uint32_t nbk1 = k->nb[1];
const uint32_t nbk2 = k->nb[2];
const uint32_t nbk3 = k->nb[3];
const uint32_t nbv1 = v->nb[1];
const uint32_t nbv2 = v->nb[2];
const uint32_t nbv3 = v->nb[3];
const uint32_t ne1 = dst->ne[1];
const uint32_t ne2 = dst->ne[2];
const uint32_t ne3 = dst->ne[3];
const uint32_t nb1 = dst->nb[1];
const uint32_t nb2 = dst->nb[2];
const uint32_t nb3 = dst->nb[3];
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
if (logit_softcap != 0) {
scale /= logit_softcap;
}
// total rows in q
const uint32_t nr = neq1*neq2*neq3;
const uint32_t dr = (nr + nth - 1) / nth;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = MIN(ir0 + dr, nr);
if (ir0 >= ir1) return;
dma_queue * dma = octx->ctx->dma[ith];
const uint32_t DK = nek0;
const uint32_t DV = nev0;
const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
const size_t size_q_row_padded = hex_round_up(size_q_row, 128);
const size_t size_k_row = DK * sizeof(__fp16);
const size_t size_v_row = DV * sizeof(__fp16);
const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
const size_t size_k_row_padded = hex_round_up(size_k_row, 128);
const size_t size_v_row_padded = hex_round_up(size_v_row, 128);
const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
// Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith;
uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith;
uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
const uint32_t n_head = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
for (uint32_t ir = ir0; ir < ir1; ++ir) {
const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
// Fetch Q row
const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
const uint32_t h = iq2; // head index
const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value
// Clear accumulator
hvx_splat_f32_a(spad_a, 0, DV);
float * VKQ32 = (float *) spad_a;
const __fp16 * mp_base = NULL;
if (mask) {
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
}
const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
// Prefetch first two blocks
for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
// K
const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
// V
const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
// Mask
if (mask) {
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
// Mask is 1D contiguous for this row
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
}
}
const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
// Wait for DMA
uint8_t * k_base = dma_queue_pop(dma).dst; // K
uint8_t * v_base = dma_queue_pop(dma).dst; // V
__fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M
// Inner loop processing the block from VTCM
uint32_t ic = 0;
// Process in blocks of 32 (VLEN_FP32)
for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
// 1. Compute scores
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
for (int j = 0; j < VLEN_FP32; ++j) {
const uint32_t cur_ic = ic + j;
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
if (q->type == HTP_TYPE_F32) {
hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
} else {
hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
}
}
HVX_Vector scores = *(HVX_Vector *) scores_arr;
// 2. Softcap
if (logit_softcap != 0.0f) {
scores = hvx_vec_tanh_f32(scores);
scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap));
scores = Q6_Vsf_equals_Vqf32(scores);
}
// 3. Mask
if (mask) {
const __fp16 * mp = m_base + ic;
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00);
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16);
HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));
HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec);
scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
scores = Q6_Vsf_equals_Vqf32(scores);
}
// 4. Online Softmax Update
HVX_Vector v_max = hvx_vec_reduce_max_f32(scores);
float m_block = hvx_vec_get_f32(v_max);
float M_old = M;
float M_new = (m_block > M) ? m_block : M;
M = M_new;
float ms = expf(M_old - M_new);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
S = S * ms;
HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
HVX_Vector p_sum_vec = hvx_vec_reduce_sum_f32(P);
float p_sum = hvx_vec_get_f32(p_sum_vec);
S += p_sum;
// 5. Accumulate V
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
*(HVX_Vector*)p_arr = P;
for (int j = 0; j < VLEN_FP32; ++j) {
const uint32_t cur_ic = ic + j;
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
}
}
// Leftover
for (; ic < current_block_size; ++ic) {
float s_val;
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
if (q->type == HTP_TYPE_F32) {
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
} else {
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
}
if (logit_softcap != 0.0f) {
s_val = logit_softcap * tanhf(s_val);
}
if (mask) {
const float m_val = m_base[ic];
s_val += slope * m_val;
}
const float Mold = M;
float ms = 1.0f;
float vs = 1.0f;
if (s_val > M) {
M = s_val;
ms = expf(Mold - M);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
} else {
vs = expf(s_val - M);
}
const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
S = S * ms + vs;
}
// Issue DMA for next+1 block (if exists)
if (ib + 2 < n_blocks) {
const uint32_t next_ib = ib + 2;
const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
// K
const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
// V
const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
// Mask
if (mask) {
const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
}
}
}
// sinks
if (sinks) {
const float s = ((float *)((char *) sinks->data))[h];
float ms = 1.0f;
float vs = 1.0f;
if (s > M) {
ms = expf(M - s);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
} else {
vs = expf(s - M);
}
S = S * ms + vs;
}
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv);
// Store result
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
// dst is permuted
uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
if (dst->type == HTP_TYPE_F32) {
hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
} else if (dst->type == HTP_TYPE_F16) {
hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
}
}
}
static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
flash_attn_ext_f16_thread(octx, i, n);
}
int op_flash_attn_ext(struct htp_ops_context * octx) {
const struct htp_tensor * q = &octx->src0;
const struct htp_tensor * k = &octx->src1;
const struct htp_tensor * v = &octx->src2;
const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
struct htp_tensor * dst = &octx->dst;
// Check support
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
k->type != HTP_TYPE_F16 ||
v->type != HTP_TYPE_F16) {
return HTP_STATUS_NO_SUPPORT;
}
octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
octx->src0_div1 = init_fastdiv_values(q->ne[1]);
octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
if (mask) {
octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
}
size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
size_t size_q_block = size_q_row_padded * 1; // single row for now
size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
octx->src0_spad.size_per_thread = size_q_block * 1;
octx->src1_spad.size_per_thread = size_k_block * 2;
octx->src2_spad.size_per_thread = size_v_block * 2;
octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
octx->dst_spad.size_per_thread = size_vkq_acc;
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads;
octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads;
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size;
if (octx->ctx->vtcm_size < total_spad) {
return HTP_STATUS_VTCM_TOO_SMALL;
}
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
}
return HTP_STATUS_OK;
}

View File

@@ -1,106 +0,0 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#include <HAP_farf.h>
#include <HAP_perf.h>
#include <math.h>
#include <string.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#define get_rows_preamble \
const uint32_t ne00 = octx->src0.ne[0]; \
const uint32_t ne01 = octx->src0.ne[1]; \
const uint32_t ne02 = octx->src0.ne[2]; \
const uint32_t ne03 = octx->src0.ne[3]; \
\
const uint32_t ne10 = octx->src1.ne[0]; \
const uint32_t ne11 = octx->src1.ne[1]; \
const uint32_t ne12 = octx->src1.ne[2]; \
\
const uint32_t nb01 = octx->src0.nb[1]; \
const uint32_t nb02 = octx->src0.nb[2]; \
const uint32_t nb03 = octx->src0.nb[3]; \
\
const uint32_t nb10 = octx->src1.nb[0]; \
const uint32_t nb11 = octx->src1.nb[1]; \
const uint32_t nb12 = octx->src1.nb[2]; \
\
const uint32_t nb1 = octx->dst.nb[1]; \
const uint32_t nb2 = octx->dst.nb[2]; \
const uint32_t nb3 = octx->dst.nb[3]; \
\
const uint32_t nr = ne10 * ne11 * ne12;
static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
get_rows_preamble;
// parallelize by src1 elements (which correspond to dst rows)
const uint32_t dr = octx->src1_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
for (uint32_t i = ir0; i < ir1; ++i) {
const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
const uint32_t rem = i - i12 * ne11 * ne10;
const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
const uint32_t i10 = rem - i11 * ne10;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
if (i01 >= ne01) {
// invalid index, skip for now to avoid crash
continue;
}
const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03;
const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3;
hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
}
return HTP_STATUS_OK;
}
static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
}
int op_get_rows(struct htp_ops_context * octx) {
get_rows_preamble;
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->dst.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
return HTP_STATUS_OK;
}
octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
const uint32_t n_jobs = MIN(nr, octx->n_threads);
octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
return HTP_STATUS_OK;
}

View File

@@ -1,77 +0,0 @@
#ifndef HEX_DUMP_H
#define HEX_DUMP_H
#include <HAP_farf.h>
static inline void hex_dump_int8_line(char * pref, const int8_t * x, int n) {
char str[1024], *p = str, *p_end = str + sizeof(str);
p += snprintf(p, p_end - p, "%s: ", pref);
for (int i = 0; i < n && p < p_end; i++) {
p += snprintf(p, p_end - p, "%d, ", x[i]);
}
FARF(HIGH, "%s\n", str);
}
static inline void hex_dump_uint8_line(char * pref, const uint8_t * x, uint32_t n) {
char str[1024], *p = str, *p_end = str + sizeof(str);
p += snprintf(p, p_end - p, "%s: ", pref);
for (int i = 0; i < n && p < p_end; i++) {
p += snprintf(p, p_end - p, "%d, ", x[i]);
}
FARF(HIGH, "%s\n", str);
}
static inline void hex_dump_int32_line(char * pref, const int32_t * x, uint32_t n) {
char str[1024], *p = str, *p_end = str + sizeof(str);
p += snprintf(p, p_end - p, "%s: ", pref);
for (int i = 0; i < n; i++) {
p += snprintf(p, p_end - p, "%d, ", (int) x[i]);
}
FARF(HIGH, "%s\n", str);
}
static inline void hex_dump_f16_line(char * pref, const __fp16 * x, uint32_t n) {
char str[1024], *p = str, *p_end = str + sizeof(str);
p += snprintf(p, p_end - p, "%s: ", pref);
for (int i = 0; i < n; i++) {
p += snprintf(p, p_end - p, "%.6f, ", (float) x[i]);
}
FARF(HIGH, "%s\n", str);
}
static inline void hex_dump_f32_line(char * pref, const float * x, uint32_t n) {
char str[1024], *p = str, *p_end = str + sizeof(str);
p += snprintf(p, p_end - p, "%s: ", pref);
for (int i = 0; i < n; i++) {
p += snprintf(p, p_end - p, "%.6f, ", x[i]);
}
FARF(HIGH, "%s\n", str);
}
static inline void hex_dump_f32(char * pref, const float * x, uint32_t n) {
uint32_t n0 = n / 16;
uint32_t n1 = n % 16;
uint32_t i = 0;
for (; i < n0; i++) {
hex_dump_f32_line(pref, x + (16 * i), 16);
}
if (n1) {
hex_dump_f32_line(pref, x + (16 * i), n1);
}
}
static inline void hex_dump_f16(char * pref, const __fp16 * x, uint32_t n) {
uint32_t n0 = n / 16;
uint32_t n1 = n % 16;
uint32_t i = 0;
for (; i < n0; i++) {
hex_dump_f16_line(pref, x + (16 * i), 16);
}
if (n1) {
hex_dump_f16_line(pref, x + (16 * i), n1);
}
}
#endif /* HEX_DUMP_H */

View File

@@ -1,37 +0,0 @@
#ifndef HEX_FASTDIV_H
#define HEX_FASTDIV_H
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
// and a shift:
//
// n/d = (mulhi(n, mp) + n) >> L;
struct fastdiv_values {
uint32_t mp;
uint32_t l;
};
static inline struct fastdiv_values init_fastdiv_values(uint32_t d) {
struct fastdiv_values result = { 0, 0 };
// compute L = ceil(log2(d));
while (result.l < 32 && ((uint32_t) 1 << result.l) < d) {
++(result.l);
}
result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1);
return result;
}
static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) {
// Compute high 32 bits of n * mp
const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32); // mulhi(n, mp)
// add n, apply bit shift
return (hi + n) >> vals->l;
}
static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) {
return n - fastdiv(n, vals) * d;
}
#endif /* HEX_FASTDIV_H */

View File

@@ -1,51 +0,0 @@
#ifndef HEX_UTILS_H
#define HEX_UTILS_H
#include <stdbool.h>
#include <stdint.h>
#include "hexagon_types.h"
#include "hex-fastdiv.h"
#include "hex-dump.h"
#ifndef MAX
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#endif
#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
static inline uint64_t hex_get_cycles() {
uint64_t cycles = 0;
asm volatile(" %0 = c15:14\n" : "=r"(cycles));
return cycles;
}
static inline uint64_t hex_get_pktcnt() {
uint64_t pktcnt;
asm volatile(" %0 = c19:18\n" : "=r"(pktcnt));
return pktcnt;
}
static inline int32_t hex_is_aligned(void * addr, uint32_t align) {
return ((size_t) addr & (align - 1)) == 0;
}
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
uint32_t left_off = (size_t) addr & (chunk_size - 1);
uint32_t right_off = left_off + n;
return right_off <= chunk_size;
}
static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
return m * ((n + m - 1) / m);
}
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
Q6_l2fetch_AP((void *) p, control);
}
#endif /* HEX_UTILS_H */

View File

@@ -1,7 +1,7 @@
#ifndef HTP_CTX_H
#define HTP_CTX_H
#include "hex-dma.h"
#include "htp-dma.h"
#include "worker-pool.h"
#include <assert.h>
@@ -11,6 +11,11 @@
#define HTP_MAX_NTHREADS 10
// FIXME: move these into matmul-ops
#define HTP_SPAD_SRC0_NROWS 16
#define HTP_SPAD_SRC1_NROWS 16
#define HTP_SPAD_DST_NROWS 2
// Main context for htp DSP backend
struct htp_context {
dspqueue_t queue;

View File

@@ -1,4 +1,4 @@
#include "hex-dma.h"
#include "htp-dma.h"
#include <stdbool.h>
#include <stdlib.h>

View File

@@ -2,6 +2,7 @@
#define HTP_DMA_H
#include <HAP_farf.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <stdbool.h>
#include <stdint.h>

View File

@@ -36,8 +36,6 @@ enum htp_data_type {
HTP_TYPE_F16 = 1,
HTP_TYPE_Q4_0 = 2,
HTP_TYPE_Q8_0 = 8,
HTP_TYPE_I32 = 26,
HTP_TYPE_I64 = 27,
HTP_TYPE_MXFP4 = 39,
HTP_TYPE_COUNT
};
@@ -59,11 +57,6 @@ enum htp_op {
HTP_OP_SOFTMAX = 11,
HTP_OP_ADD_ID = 12,
HTP_OP_ROPE = 13,
HTP_OP_FLASH_ATTN_EXT = 14,
HTP_OP_SET_ROWS = 15,
HTP_OP_SCALE = 16,
HTP_OP_GET_ROWS = 17,
HTP_OP_CPY = 18,
INVALID
};
@@ -144,8 +137,6 @@ struct htp_general_req {
struct htp_tensor src0; // Input0 tensor
struct htp_tensor src1; // Input1 tensor
struct htp_tensor src2; // Input2 tensor
struct htp_tensor src3; // Input3 tensor
struct htp_tensor src4; // Input4 tensor
struct htp_tensor dst; // Output tensor
// should be multiple of 64 bytes (cacheline)
@@ -161,6 +152,6 @@ struct htp_general_rsp {
};
#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req)
#define HTP_MAX_PACKET_BUFFERS 8
#define HTP_MAX_PACKET_BUFFERS 4
#endif /* HTP_MSG_H */

View File

@@ -4,17 +4,15 @@
#include "htp-ctx.h"
#include "htp-msg.h"
#include "worker-pool.h"
#include "ops-utils.h"
#include <assert.h>
#include <stdint.h>
#include <hex-fastdiv.h>
// ggml-common.h must be included prior to this header
struct htp_spad {
uint8_t * data;
size_t stride;
size_t size;
size_t size_per_thread;
};
@@ -28,14 +26,11 @@ struct htp_ops_context {
struct htp_tensor src0;
struct htp_tensor src1;
struct htp_tensor src2;
struct htp_tensor src3;
struct htp_tensor src4;
struct htp_tensor dst;
struct htp_spad src0_spad;
struct htp_spad src1_spad;
struct htp_spad src2_spad;
struct htp_spad src3_spad;
struct htp_spad dst_spad;
worker_pool_context_t * wpool; // worker pool
@@ -54,35 +49,6 @@ struct htp_ops_context {
struct fastdiv_values src1_div3; // fastdiv values for ne3
struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
struct fastdiv_values src3_div1; // fastdiv values for ne1
struct fastdiv_values src3_div2; // fastdiv values for ne2
struct fastdiv_values src3_div3; // fastdiv values for ne3
struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1
struct fastdiv_values broadcast_rk2;
struct fastdiv_values broadcast_rk3;
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
struct fastdiv_values mm_div_ne1; // fastdiv values for ne1
struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02
struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
struct fastdiv_values cpy_div_ne01; // fastdiv values for ne01
struct fastdiv_values cpy_div_ne02; // fastdiv values for ne02
struct fastdiv_values cpy_div_ne03; // fastdiv values for ne03
struct fastdiv_values cpy_rshp_div_n0; // fastdiv values for ne00
struct fastdiv_values cpy_rshp_div_n1n0; // fastdiv values for ne00*ne01
struct fastdiv_values cpy_rshp_div_n2n1n0; // fastdiv values for ne00*ne01*ne02
uint32_t flags;
};
@@ -94,9 +60,5 @@ int op_activations(struct htp_ops_context * octx);
int op_softmax(struct htp_ops_context * octx);
int op_add_id(struct htp_ops_context * octx);
int op_rope(struct htp_ops_context * octx);
int op_flash_attn_ext(struct htp_ops_context * octx);
int op_set_rows(struct htp_ops_context * octx);
int op_get_rows(struct htp_ops_context * octx);
int op_cpy(struct htp_ops_context * octx);
#endif /* HTP_OPS_H */

View File

@@ -1,457 +0,0 @@
#ifndef HVX_ARITH_H
#define HVX_ARITH_H
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <math.h>
#include "hvx-base.h"
#include "hex-utils.h"
//
// Binary operations (add, mul, sub)
//
#define hvx_arith_loop_body(dst_type, src0_type, src1_type, vec_store, vec_op) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src0_type * restrict vsrc0 = (src0_type *) src0; \
src1_type * restrict vsrc1 = (src1_type *) src1; \
\
const uint32_t elem_size = sizeof(float); \
const uint32_t epv = 128 / elem_size; \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
vdst[i] = vec_op(vsrc0[i], vsrc1[i]); \
} \
if (nloe) { \
HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]); \
vec_store((void *) &vdst[i], nloe * elem_size, v); \
} \
} while(0)
#if __HVX_ARCH__ < 79
#define HVX_OP_ADD(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_SUB(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_ADD(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_SUB(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
// ADD variants
static inline void hvx_add_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD);
}
static inline void hvx_add_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD);
}
static inline void hvx_add_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD);
}
static inline void hvx_add_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD);
}
// SUB variants
static inline void hvx_sub_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB);
}
static inline void hvx_sub_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB);
}
static inline void hvx_sub_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB);
}
static inline void hvx_sub_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB);
}
// MUL variants
static inline void hvx_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL);
}
static inline void hvx_mul_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL);
}
static inline void hvx_mul_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL);
}
static inline void hvx_mul_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL);
}
// Dispatchers
static inline void hvx_add_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) {
hvx_add_f32_aa(dst, src0, src1, num_elems);
} else {
hvx_add_f32_au(dst, src0, src1, num_elems);
}
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
hvx_add_f32_ua(dst, src0, src1, num_elems);
} else {
hvx_add_f32_uu(dst, src0, src1, num_elems);
}
}
static inline void hvx_sub_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) {
hvx_sub_f32_aa(dst, src0, src1, num_elems);
} else {
hvx_sub_f32_au(dst, src0, src1, num_elems);
}
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
hvx_sub_f32_ua(dst, src0, src1, num_elems);
} else {
hvx_sub_f32_uu(dst, src0, src1, num_elems);
}
}
static inline void hvx_mul_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) {
hvx_mul_f32_aa(dst, src0, src1, num_elems);
} else {
hvx_mul_f32_au(dst, src0, src1, num_elems);
}
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
hvx_mul_f32_ua(dst, src0, src1, num_elems);
} else {
hvx_mul_f32_uu(dst, src0, src1, num_elems);
}
}
// Mul-Mul Optimized
static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
assert((unsigned long) src2 % 128 == 0);
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
HVX_Vector * restrict vsrc0 = (HVX_Vector *) src0;
HVX_Vector * restrict vsrc1 = (HVX_Vector *) src1;
HVX_Vector * restrict vsrc2 = (HVX_Vector *) src2;
const uint32_t elem_size = sizeof(float);
const uint32_t epv = 128 / elem_size;
const uint32_t nvec = num_elems / epv;
const uint32_t nloe = num_elems % epv;
uint32_t i = 0;
_Pragma("unroll(4)")
for (; i < nvec; i++) {
HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
vdst[i] = HVX_OP_MUL(v1, vsrc2[i]);
}
if (nloe) {
HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
HVX_Vector v2 = HVX_OP_MUL(v1, vsrc2[i]);
hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2);
}
}
// Scalar Operations
#define hvx_scalar_loop_body(dst_type, src_type, vec_store, scalar_op_macro) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const uint32_t elem_size = sizeof(float); \
const uint32_t epv = 128 / elem_size; \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
HVX_Vector v = vsrc[i]; \
vdst[i] = scalar_op_macro(v); \
} \
if (nloe) { \
HVX_Vector v = vsrc[i]; \
v = scalar_op_macro(v); \
vec_store((void *) &vdst[i], nloe * elem_size, v); \
} \
} while(0)
#define HVX_OP_ADD_SCALAR(v) \
({ \
const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \
HVX_Vector out = HVX_OP_ADD(v, val_vec); \
Q6_V_vmux_QVV(pred_inf, inf, out); \
})
#define HVX_OP_MUL_SCALAR(v) HVX_OP_MUL(v, val_vec)
#define HVX_OP_SUB_SCALAR(v) HVX_OP_SUB(v, val_vec)
// Add Scalar Variants
static inline void hvx_add_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
}
static inline void hvx_add_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
assert((unsigned long) dst % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
}
static inline void hvx_add_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
}
static inline void hvx_add_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
static const float kInf = INFINITY;
const HVX_Vector inf = hvx_vec_splat_f32(kInf);
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
}
// Sub Scalar Variants
static inline void hvx_sub_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
}
static inline void hvx_sub_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) dst % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
}
static inline void hvx_sub_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
}
static inline void hvx_sub_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
}
// Mul Scalar Variants
static inline void hvx_mul_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
}
static inline void hvx_mul_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) dst % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
}
static inline void hvx_mul_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
}
static inline void hvx_mul_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
}
static inline void hvx_add_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
hvx_add_scalar_f32_aa(dst, src, val, num_elems);
} else if (hex_is_aligned((void *) dst, 128)) {
hvx_add_scalar_f32_au(dst, src, val, num_elems);
} else if (hex_is_aligned((void *) src, 128)) {
hvx_add_scalar_f32_ua(dst, src, val, num_elems);
} else {
hvx_add_scalar_f32_uu(dst, src, val, num_elems);
}
}
static inline void hvx_mul_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
hvx_mul_scalar_f32_aa(dst, src, val, num_elems);
} else if (hex_is_aligned((void *) dst, 128)) {
hvx_mul_scalar_f32_au(dst, src, val, num_elems);
} else if (hex_is_aligned((void *) src, 128)) {
hvx_mul_scalar_f32_ua(dst, src, val, num_elems);
} else {
hvx_mul_scalar_f32_uu(dst, src, val, num_elems);
}
}
static inline void hvx_sub_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
hvx_sub_scalar_f32_aa(dst, src, val, num_elems);
} else if (hex_is_aligned((void *) dst, 128)) {
hvx_sub_scalar_f32_au(dst, src, val, num_elems);
} else if (hex_is_aligned((void *) src, 128)) {
hvx_sub_scalar_f32_ua(dst, src, val, num_elems);
} else {
hvx_sub_scalar_f32_uu(dst, src, val, num_elems);
}
}
// MIN Scalar variants
#define HVX_OP_MIN_SCALAR(v) Q6_Vsf_vmin_VsfVsf(val_vec, v)
static inline void hvx_min_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
}
static inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) dst % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
}
static inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
}
static inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
}
static inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
hvx_min_scalar_f32_aa(dst, src, val, num_elems);
} else if (hex_is_aligned((void *) dst, 128)) {
hvx_min_scalar_f32_au(dst, src, val, num_elems);
} else if (hex_is_aligned((void *) src, 128)) {
hvx_min_scalar_f32_ua(dst, src, val, num_elems);
} else {
hvx_min_scalar_f32_uu(dst, src, val, num_elems);
}
}
// CLAMP Scalar variants
#define HVX_OP_CLAMP_SCALAR(v) \
({ \
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(v, max_vec); \
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(min_vec, v); \
HVX_Vector tmp = Q6_V_vmux_QVV(pred_cap_right, max_vec, v); \
Q6_V_vmux_QVV(pred_cap_left, min_vec, tmp); \
})
static inline void hvx_clamp_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
}
static inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
assert((unsigned long) dst % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
}
static inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
}
static inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
}
static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, const int num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
hvx_clamp_scalar_f32_aa(dst, src, min, max, num_elems);
} else if (hex_is_aligned((void *) dst, 128)) {
hvx_clamp_scalar_f32_au(dst, src, min, max, num_elems);
} else if (hex_is_aligned((void *) src, 128)) {
hvx_clamp_scalar_f32_ua(dst, src, min, max, num_elems);
} else {
hvx_clamp_scalar_f32_uu(dst, src, min, max, num_elems);
}
}
#undef HVX_OP_ADD
#undef HVX_OP_SUB
#undef HVX_OP_MUL
#undef hvx_arith_loop_body
#undef HVX_OP_ADD_SCALAR
#undef HVX_OP_SUB_SCALAR
#undef HVX_OP_MUL_SCALAR
#undef hvx_scalar_loop_body
#undef HVX_OP_MIN_SCALAR
#undef HVX_OP_CLAMP_SCALAR
#endif // HVX_ARITH_H

View File

@@ -1,167 +0,0 @@
#ifndef HVX_BASE_H
#define HVX_BASE_H
#include <stdbool.h>
#include <stdint.h>
#include "hex-utils.h"
#include "hvx-types.h"
static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) {
// Rotate as needed.
v = Q6_V_vlalign_VVR(v, v, (size_t) dst);
uint32_t left_off = (size_t) dst & 127;
uint32_t right_off = left_off + n;
HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) dst);
HVX_VectorPred qr = Q6_Q_vsetq2_R(right_off);
if (right_off > 128) {
Q6_vmem_QRIV(qr, (HVX_Vector *) dst + 1, v);
// all 1's
qr = Q6_Q_vcmp_eq_VbVb(v, v);
}
ql_not = Q6_Q_or_QQn(ql_not, qr);
Q6_vmem_QnRIV(ql_not, (HVX_Vector *) dst, v);
}
static inline void hvx_vec_store_a(void * restrict dst, uint32_t n, HVX_Vector v) {
assert((unsigned long) dst % 128 == 0);
HVX_VectorPred m = Q6_Q_or_QQn(Q6_Q_vsetq_R((unsigned long) dst), Q6_Q_vsetq2_R(n));
Q6_vmem_QnRIV(m, (HVX_Vector *) dst, v);
}
static inline HVX_Vector hvx_vec_splat_f32(float v) {
union { float f; uint32_t i; } u = { .f = v };
return Q6_V_vsplat_R(u.i);
}
static inline HVX_Vector hvx_vec_splat_f16(float v) {
union { __fp16 f; uint16_t i; } u = { .f = v };
return Q6_Vh_vsplat_R(u.i);
}
static inline HVX_Vector hvx_vec_repl4(HVX_Vector v) {
// vdelta control to replicate first 4 bytes across all elements
static const uint8_t __attribute__((aligned(128))) repl[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};
HVX_Vector ctrl = *(HVX_Vector *) repl;
return Q6_V_vdelta_VV(v, ctrl);
}
static inline float hvx_vec_get_f32(HVX_Vector v) {
float __attribute__((aligned(128))) x;
hvx_vec_store_a(&x, 4, v);
return x;
}
static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
// abs by clearing the fp16 sign bit
HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
return Q6_V_vand_VV(v, mask);
}
static inline HVX_Vector hvx_vec_neg_f16(HVX_Vector v) {
// neg by setting the fp16 sign bit
HVX_Vector mask = Q6_Vh_vsplat_R(0x8000);
return Q6_V_vxor_VV(v, mask);
}
static inline HVX_Vector hvx_vec_abs_f32(HVX_Vector v) {
// abs by clearing the fp32 sign bit
HVX_Vector mask = Q6_V_vsplat_R(0x7fffffff);
return Q6_V_vand_VV(v, mask);
}
static inline HVX_Vector hvx_vec_neg_f32(HVX_Vector v) {
#if __HVX_ARCH__ > 75
return Q6_Vsf_vfneg_Vsf(v);
#else
// neg by setting the fp32 sign bit
HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
return Q6_V_vxor_VV(v, mask);
#endif // __HVX_ARCH__ > 75
}
static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) {
const HVX_Vector vnan_exp = Q6_Vh_vsplat_R(0x7C00);
const HVX_Vector vnan_frac = Q6_Vh_vsplat_R(0x7FFF);
// get pred of which are NaN, i.e., exponent bits all 1s and fraction bits non 0s
HVX_VectorPred p_exp = Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_exp), vnan_exp);
HVX_VectorPred p_frac = Q6_Q_not_Q(Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_frac), vnan_exp));
return Q6_Q_and_QQ(p_exp, p_frac);
}
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero);
HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero);
HVX_Vector v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0)));
#if __HVX_ARCH__ < 79
// replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
const HVX_Vector neg_inf = hvx_vec_splat_f16(-INFINITY);
HVX_VectorPred nan = hvx_vec_is_nan_f16(v);
v = Q6_V_vmux_QVV(nan, neg_inf, v);
#endif
return v;
}
/* Q6_Vsf_equals_Vw is only available on v73+.*/
#if __HVX_ARCH__ < 73
static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)
{
HVX_Vector const vzero = Q6_V_vzero();
HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero);
HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in);
HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift);
HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift);
HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized);
HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp));
return ret;
}
static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
{
return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in));
}
#endif
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
// This looks complicated.
// Ideally should just be Q6_Vh_equals_Vhf(vin)
// but that instruction does not do proper rounding.
// convert to qf32, multiplying by 1.0 in the process.
HVX_VectorPair v32 = Q6_Wqf32_vmpy_VhfVhf(vin, Q6_Vh_vsplat_R(0x3C00));
// 'in-range' values are +/32752.
// add 192K to it, convert to sf
HVX_Vector v192K = Q6_V_vsplat_R(0x48400000);
HVX_Vector vsf_0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(v32), v192K));
HVX_Vector vsf_1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(v32), v192K));
// for in-range cases, result is {163858... 229360} so the exponent is always 144.
// if we extract bits 21..0 as a signed quantity, and round 6 bits off, that will be the answer.
// Start by <<10 to get the final 'sign' bit in bit 15...
vsf_0 = Q6_Vw_vasl_VwR(vsf_0, 10);
vsf_1 = Q6_Vw_vasl_VwR(vsf_1, 10);
// now round down to 16
return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0);
}
#endif /* HVX_BASE_H */

View File

@@ -1,247 +0,0 @@
#ifndef HVX_COPY_H
#define HVX_COPY_H
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include "hvx-base.h"
#define hvx_splat_loop_body(dst_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
\
uint32_t nvec = n / (128 / elem_size); \
uint32_t nloe = n % (128 / elem_size); \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
vdst[i] = src; \
} \
if (nloe) { \
vec_store((void *) &vdst[i], nloe * elem_size, src); \
} \
} while(0)
static inline void hvx_splat_a(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
assert((unsigned long) dst % 128 == 0);
hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_splat_u(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_splat_f32_a(uint8_t * restrict dst, float v, uint32_t n) {
hvx_splat_a(dst, hvx_vec_splat_f32(v), n, sizeof(float));
}
static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) {
hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float));
}
static inline void hvx_splat_f16_a(uint8_t * restrict dst, float v, uint32_t n) {
hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
}
static inline void hvx_splat_f16_u(uint8_t * restrict dst, float v, uint32_t n) {
hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
}
#define hvx_copy_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const uint32_t epv = 128 / elem_size; \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { vdst[i] = vsrc[i]; } \
if (nloe) { \
vec_store((void *) &vdst[i], nloe * elem_size, vsrc[i]); \
} \
} while(0)
// Generic copy routines
static inline void hvx_copy_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_copy_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_copy_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
assert((unsigned long) dst % 128 == 0);
hvx_copy_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_copy_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
assert((unsigned long) src % 128 == 0);
hvx_copy_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_copy_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
hvx_copy_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
// copy n fp16 elements : source and destination are aligned to HVX Vector (128)
static inline void hvx_copy_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_copy_aa(dst, src, n, sizeof(__fp16));
}
// copy n fp16 elements : source is aligned, destination is potentially unaligned
static inline void hvx_copy_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_copy_au(dst, src, n, sizeof(__fp16));
}
// copy n fp16 elements : source is aligned, destination is potentially unaligned
static inline void hvx_copy_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_copy_ua(dst, src, n, sizeof(__fp16));
}
// copy n fp16 elements : source is aligned, destination is potentially unaligned
static inline void hvx_copy_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_copy_uu(dst, src, n, sizeof(__fp16));
}
// copy n fp32 elements : source and destination are aligned to HVX Vector (128)
static inline void hvx_copy_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_copy_aa(dst, src, n, sizeof(float));
}
// copy n fp32 elements : source is aligned, destination is unaligned
static inline void hvx_copy_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_copy_ua(dst, src, n, sizeof(float));
}
// copy n fp32 elements : source is unaligned, destination is aligned
static inline void hvx_copy_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_copy_au(dst, src, n, sizeof(float));
}
// copy n fp32 elements : source is unaligned, destination unaligned
static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_copy_uu(dst, src, n, sizeof(float));
}
//// fp32 -> fp16
#define hvx_copy_f16_f32_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const HVX_Vector zero = Q6_V_vsplat_R(0); \
\
const uint32_t elem_size = sizeof(__fp16); \
const uint32_t epv = 128 / elem_size; \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
vdst[i] = hvx_vec_f32_to_f16(vsrc[i*2+0], vsrc[i*2+1]); \
} \
if (nloe) { \
HVX_Vector v = hvx_vec_f32_to_f16(vsrc[i*2+0], vsrc[i*2+1]); \
vec_store((void *) &vdst[i], nloe * elem_size, v); \
} \
} while(0)
// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is aligned
static inline void hvx_copy_f16_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_copy_f16_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned
static inline void hvx_copy_f16_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
hvx_copy_f16_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned
static inline void hvx_copy_f16_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) src % 128 == 0);
hvx_copy_f16_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned
static inline void hvx_copy_f16_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_copy_f16_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
//// fp16 -> fp32
#define hvx_copy_f32_f16_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const HVX_Vector one = hvx_vec_splat_f16(1.0); \
\
const uint32_t elem_size = sizeof(__fp16); \
const uint32_t epv = 128 / elem_size; \
const uint32_t nvec = n / epv; \
uint32_t nloe = n % epv; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (i = 0; i < nvec; ++i) { \
HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vsrc[i]), one); \
vdst[i*2] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)); \
vdst[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)); \
} \
\
if (nloe) { \
HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vsrc[i]), one); \
\
HVX_Vector vd = Q6_V_lo_W(p); \
i = 2 * i; \
\
if (nloe >= 32) { \
vdst[i] = Q6_Vsf_equals_Vqf32(vd); \
nloe -= 32; ++i; vd = Q6_V_hi_W(p); \
} \
\
if (nloe) { \
vd = Q6_Vsf_equals_Vqf32(vd); \
hvx_vec_store_u(&vdst[i], nloe * sizeof(float), vd); \
} \
} \
} while(0)
// copy/convert n fp16 elements into n fp32 elements : source is aligned, destination is aligned
static inline void hvx_copy_f32_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_copy_f32_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
// copy/convert n fp16 elements into n fp32 elements : source is unaligned, destination is aligned
static inline void hvx_copy_f32_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
hvx_copy_f32_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
// copy/convert n fp16 elements into n fp32 elements : source is aligned, destination is unaligned
static inline void hvx_copy_f32_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) src % 128 == 0);
hvx_copy_f32_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
// copy/convert n fp16 elements into n fp32 elements : source is unaligned, destination is unaligned
static inline void hvx_copy_f32_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_copy_f32_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
#endif // HVX_COPY_H

View File

@@ -1,132 +0,0 @@
#ifndef HVX_DUMP_H
#define HVX_DUMP_H
#include <HAP_farf.h>
#include <stdbool.h>
#include <stdint.h>
#include "hex-utils.h"
#include "hvx-types.h"
static void hvx_vec_dump_f16_n(char * pref, HVX_Vector v, uint32_t n) {
HVX_VectorAlias u = { .v = v };
const uint32_t n0 = n / 16;
const uint32_t n1 = n % 16;
int i = 0;
for (; i < n0; i++) {
hex_dump_f16_line(pref, u.fp16 + (16 * i), 16);
}
if (n1) {
hex_dump_f16_line(pref, u.fp16 + (16 * i), n1);
}
}
static void hvx_vec_dump_f16(char * pref, HVX_Vector v) {
hvx_vec_dump_f16_n(pref, v, 64);
}
static void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) {
union {
HVX_Vector v;
float d[32];
} u = { .v = v };
const uint32_t n0 = n / 16;
const uint32_t n1 = n % 16;
int i = 0;
for (; i < n0; i++) {
hex_dump_f32_line(pref, u.d + (16 * i), 16);
}
if (n1) {
hex_dump_f32_line(pref, u.d + (16 * i), n1);
}
}
static void hvx_vec_dump_f32_hmt(char * pref, HVX_Vector v) {
union {
HVX_Vector v;
float d[32];
} u = { .v = v };
FARF(HIGH, "%s: %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f\n", pref, u.d[0], u.d[1],
u.d[2], u.d[3], u.d[12], u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);
}
static void hvx_vec_dump_f32(char * pref, HVX_Vector v) {
hvx_vec_dump_f32_n(pref, v, 32);
}
static void hvx_vec_dump_int32(char * pref, HVX_Vector v) {
union {
HVX_Vector v;
int32_t d[32];
} u = { .v = v };
for (int i = 0; i < 32 / 16; i++) {
hex_dump_int32_line(pref, u.d + (16 * i), 16);
}
}
static void hvx_vec_dump_int32_hmt(char * pref, HVX_Vector v) {
union {
HVX_Vector v;
int32_t d[32];
} u = { .v = v };
FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[12],
u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);
}
static void hvx_vec_dump_int8_hmt(char * pref, HVX_Vector v) {
union {
HVX_Vector v;
int8_t d[128];
} u = { .v = v };
FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[60],
u.d[61], u.d[62], u.d[63], u.d[124], u.d[125], u.d[126], u.d[127]);
}
static void hvx_vec_dump_int8(char * pref, HVX_Vector v) {
union {
HVX_Vector v;
int8_t d[128];
} u = { .v = v };
for (int i = 0; i < 128 / 16; i++) {
hex_dump_int8_line(pref, u.d + (16 * i), 16);
}
}
static void hvx_vec_dump_uint8(char * pref, HVX_Vector v) {
union {
HVX_Vector v;
uint8_t d[128];
} u = { .v = v };
for (int i = 0; i < 128 / 16; i++) {
hex_dump_uint8_line(pref, u.d + (16 * i), 16);
}
}
static bool hvx_vec_eq(HVX_Vector v0, HVX_Vector v1, size_t n) {
typedef union {
HVX_Vector v;
int8_t d[128];
} U;
U u0 = { .v = v0 };
U u1 = { .v = v1 };
for (int i = 0; i < n; i++) {
if (u0.d[i] != u1.d[i]) {
return false;
}
}
return true;
}
#endif /* HVX_DUMP_H */

View File

@@ -0,0 +1,94 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <string.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-dma.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"
static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) {
const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp);
HVX_Vector out = hvx_vec_exp_fp32(in_vec);
return Q6_V_vmux_QVV(pred0, inf, out);
}
void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
int left_over = num_elems & (VLEN_FP32 - 1);
int num_elems_whole = num_elems - left_over;
int unaligned_addr = 0;
int unaligned_loop = 0;
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
FARF(HIGH, "hvx_exp_f32: unaligned address in hvx op, possibly slower execution\n");
unaligned_addr = 1;
}
// assert((0 == unaligned_addr) || (0 == num_elems_whole));
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
unaligned_loop = 1;
FARF(HIGH, "hvx_exp_f32: unaligned loop in hvx op, possibly slower execution\n");
}
HVX_Vector vec_out = Q6_V_vzero();
static const float kInf = INFINITY;
static const float kMaxExp = 88.02f; // log(INF)
const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
const HVX_Vector inf = hvx_vec_splat_fp32(kInf);
if (0 == unaligned_loop) {
HVX_Vector * p_vec_in1 = (HVX_Vector *) src;
HVX_Vector * p_vec_out = (HVX_Vector *) dst;
#pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
if (true == negate) {
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++);
*p_vec_out++ = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf);
} else {
*p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++, max_exp, inf);
}
}
} else {
#pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
if (true == negate) {
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf);
} else {
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in, max_exp, inf);
}
}
}
if (left_over > 0) {
const float * srcf = (float *) src + num_elems_whole;
float * dstf = (float *) dst + num_elems_whole;
HVX_Vector in = *(HVX_UVector *) srcf;
if (true == negate) {
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
vec_out = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf);
} else {
vec_out = hvx_vec_exp_fp32_guard(in, max_exp, inf);
}
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);
}
}

View File

@@ -1,215 +0,0 @@
#ifndef HVX_EXP_H
#define HVX_EXP_H
#include <stdbool.h>
#include <stdint.h>
#include "hvx-base.h"
#include "hvx-floor.h"
#define EXP_COEFF_5 (0x39506967) // 0.000198757 = 1/(7!)
#define EXP_COEFF_4 (0x3AB743CE) // 0.0013982 = 1/(6!)
#define EXP_COEFF_3 (0x3C088908) // 0.00833345 = 1/(5!)
#define EXP_COEFF_2 (0x3D2AA9C1) // 0.416658 = 1/(4!)
#define EXP_COEFF_1 (0x3E2AAAAA) // 0.16666667 = 1/(3!)
#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!)
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
#define EXP_ONE (0x3f800000) // 1.0
#define EXP_RANGE_R (0x41a00000) // 20.0
#define EXP_RANGE_L (0xc1a00000) // -20.0
static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
HVX_Vector z_qf32_v;
HVX_Vector x_v;
HVX_Vector x_qf32_v;
HVX_Vector y_v;
HVX_Vector k_v;
HVX_Vector f_v;
HVX_Vector epsilon_v;
HVX_Vector log2e = Q6_V_vsplat_R(EXP_LOG2E);
HVX_Vector logn2 = Q6_V_vsplat_R(EXP_LOGN2);
HVX_Vector E_const;
HVX_Vector zero_v = Q6_V_vzero();
// exp(x) is approximated as follows:
// f = floor(x/ln(2)) = floor(x*log2(e))
// epsilon = x - f*ln(2)
// exp(x) = exp(epsilon+f*ln(2))
// = exp(epsilon)*exp(f*ln(2))
// = exp(epsilon)*2^f
//
// Since epsilon is close to zero, it can be approximated with its Taylor series:
// exp(x) ~= 1+x+x^2/2!+x^3/3!+...+x^n/n!+...
// Preserving the first eight elements, we get:
// exp(x) ~= 1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7
// = 1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2
HVX_Vector temp_v = in_vec;
// Clamp inputs to (-20.0, 20.0)
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R));
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v);
in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v);
epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec);
epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v);
// f_v is the floating point result and k_v is the integer result
f_v = hvx_vec_floor_f32(epsilon_v);
k_v = hvx_vec_truncate_f32(f_v);
x_qf32_v = Q6_Vqf32_vadd_VsfVsf(in_vec, zero_v);
// x = x - f_v * logn2;
epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2);
x_qf32_v = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v);
// normalize before every QFloat's vmpy
x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v);
// z = x * x;
z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v);
z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v);
x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
// y = E4 + E5 * x;
E_const = Q6_V_vsplat_R(EXP_COEFF_5);
y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v);
E_const = Q6_V_vsplat_R(EXP_COEFF_4);
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
// y = E3 + y * x;
E_const = Q6_V_vsplat_R(EXP_COEFF_3);
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
// y = E2 + y * x;
E_const = Q6_V_vsplat_R(EXP_COEFF_2);
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
// y = E1 + y * x;
E_const = Q6_V_vsplat_R(EXP_COEFF_1);
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
// y = E0 + y * x;
E_const = Q6_V_vsplat_R(EXP_COEFF_0);
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
// y = x + y * z;
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v);
y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v);
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
// y = y + 1.0;
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, Q6_V_vsplat_R(EXP_ONE));
// insert exponents
// y = ldexpf(y, k);
// y_v += k_v; // qf32
// modify exponent
y_v = Q6_Vsf_equals_Vqf32(y_v);
// add k_v to the exponent of y_v
HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1);
y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, IEEE_VSF_MANTLEN + 1);
y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent);
// exponent cannot be negative; if overflow is detected, result is set to zero
HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent);
y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, IEEE_VSF_MANTLEN);
y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v);
return y_v;
}
static inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) {
const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp);
HVX_Vector out = hvx_vec_exp_f32(in_vec);
return Q6_V_vmux_QVV(pred0, inf, out);
}
static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
int left_over = num_elems & (VLEN_FP32 - 1);
int num_elems_whole = num_elems - left_over;
int unaligned_addr = 0;
int unaligned_loop = 0;
if ((0 == hex_is_aligned((void *) src, VLEN)) || (0 == hex_is_aligned((void *) dst, VLEN))) {
unaligned_addr = 1;
}
// assert((0 == unaligned_addr) || (0 == num_elems_whole));
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
unaligned_loop = 1;
}
HVX_Vector vec_out = Q6_V_vzero();
static const float kInf = INFINITY;
static const float kMaxExp = 88.02f; // log(INF)
const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
const HVX_Vector inf = hvx_vec_splat_f32(kInf);
if (0 == unaligned_loop) {
HVX_Vector * p_vec_in1 = (HVX_Vector *) src;
HVX_Vector * p_vec_out = (HVX_Vector *) dst;
#pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
if (true == negate) {
HVX_Vector neg_vec_in = hvx_vec_neg_f32(*p_vec_in1++);
*p_vec_out++ = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);
} else {
*p_vec_out++ = hvx_vec_exp_f32_guard(*p_vec_in1++, max_exp, inf);
}
}
} else {
#pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
if (true == negate) {
HVX_Vector neg_vec_in = hvx_vec_neg_f32(in);
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);
} else {
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(in, max_exp, inf);
}
}
}
if (left_over > 0) {
const float * srcf = (float *) src + num_elems_whole;
float * dstf = (float *) dst + num_elems_whole;
HVX_Vector in = *(HVX_UVector *) srcf;
if (true == negate) {
HVX_Vector neg_vec_in = hvx_vec_neg_f32(in);
vec_out = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);
} else {
vec_out = hvx_vec_exp_f32_guard(in, max_exp, inf);
}
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);
}
}
#endif /* HVX_EXP_H */

View File

@@ -1,100 +0,0 @@
#ifndef HVX_FLOOR_H
#define HVX_FLOOR_H
#include <stdbool.h>
#include <stdint.h>
#include "hvx-base.h"
#define IEEE_VSF_EXPLEN (8)
#define IEEE_VSF_EXPBIAS (127)
#define IEEE_VSF_EXPMASK (0xFF)
#define IEEE_VSF_MANTLEN (23)
#define IEEE_VSF_MANTMASK (0x7FFFFF)
#define IEEE_VSF_MIMPMASK (0x800000)
static inline HVX_Vector hvx_vec_truncate_f32(HVX_Vector in_vec) {
HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);
HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);
HVX_Vector const_zero_v = Q6_V_vzero();
HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);
HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;
expval_v &= IEEE_VSF_EXPMASK;
expval_v -= IEEE_VSF_EXPBIAS;
// negative exp == fractional value
HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);
HVX_Vector rshift_v = IEEE_VSF_MANTLEN - expval_v; // fractional bits - exp shift
HVX_Vector mant_v = in_vec & mask_mant_v; // obtain mantissa
HVX_Vector vout = Q6_Vw_vadd_VwVw(mant_v, mask_impl_v); // add implicit 1.0
vout = Q6_Vw_vasr_VwVw(vout, rshift_v); // shift to obtain truncated integer
vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout); // expval<0 -> 0
HVX_Vector neg_vout = -vout;
vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout); // handle negatives
return (vout);
}
static inline HVX_Vector hvx_vec_floor_f32(HVX_Vector in_vec) {
HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);
HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);
HVX_Vector const_mnlen_v = Q6_V_vsplat_R(IEEE_VSF_MANTLEN);
HVX_Vector const_zero_v = Q6_V_vzero();
HVX_Vector const_negone_v = Q6_V_vsplat_R(0xbf800000); // -1 IEEE vsf
HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);
HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;
expval_v &= IEEE_VSF_EXPMASK;
expval_v -= IEEE_VSF_EXPBIAS;
HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);
HVX_VectorPred q_expltmn = Q6_Q_vcmp_gt_VwVw(const_mnlen_v, expval_v);
HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVwVw(q_negexp, in_vec, const_zero_v);
HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVwVw(q_negexp, const_zero_v, in_vec);
// if expval < 0 (q_negexp) // <0, floor is 0
// if vin > 0
// floor = 0
// if vin < 0
// floor = -1
// if expval < mant_len (q_expltmn) // >0, but fraction may exist
// get sign (q_negative)
// mask >> expval // fraction bits to mask off
// vout = ~(mask) // apply mask to remove fraction
// if (qneg) // negative floor is one less (more, sign bit for neg)
// vout += ((impl_mask) >> expval)
// if (mask && vin)
// vout = vin
// else // already an integer
// ; // no change
// compute floor
mask_mant_v >>= expval_v;
HVX_Vector neg_addin_v = mask_impl_v >> expval_v;
HVX_Vector vout_neg_addin = Q6_Vw_vadd_VwVw(in_vec, neg_addin_v);
HVX_Vector vout = Q6_V_vmux_QVV(q_negative, vout_neg_addin, in_vec);
HVX_Vector mask_chk_v = Q6_V_vand_VV(in_vec, mask_mant_v); // chk if bits set
HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VwVw(const_zero_v, mask_chk_v);
HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v); // frac bits to clear
HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v); // clear frac bits
vout = in_vec;
vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout); // expval<mant
vout = Q6_V_vmux_QVV(q_integral, in_vec, vout); // integral values
vout = Q6_V_vmux_QVV(q_negexp_pos, const_zero_v, vout); // expval<0 x>0 -> 0
vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout); // expval<0 x<0 -> -1
return vout;
}
#endif /* HVX_FLOOR_H */

View File

@@ -0,0 +1,72 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <string.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-dma.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"
static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) {
HVX_Vector out = hvx_vec_inverse_fp32(v_sf);
HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask);
const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out);
return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out);
}
void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
int left_over = num_elems & (VLEN_FP32 - 1);
int num_elems_whole = num_elems - left_over;
int unaligned_addr = 0;
int unaligned_loop = 0;
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
FARF(HIGH, "hvx_inverse_f32: unaligned address in hvx op, possibly slower execution\n");
unaligned_addr = 1;
}
// assert((0 == unaligned_addr) || (0 == num_elems_whole));
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
unaligned_loop = 1;
FARF(HIGH, "hvx_inverse_f32: unaligned loop in hvx op, possibly slower execution\n");
}
static const uint32_t kNanInfMask = 0x7f800000;
const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(kNanInfMask);
if (0 == unaligned_loop) {
HVX_Vector * p_vec_in = (HVX_Vector *) src;
HVX_Vector * p_vec_out = (HVX_Vector *) dst;
#pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
*p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++, nan_inf_mask);
}
} else {
#pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in, nan_inf_mask);
}
}
if (left_over > 0) {
const float * srcf = (float *) src + num_elems_whole;
float * dstf = (float *) dst + num_elems_whole;
HVX_Vector in = *(HVX_UVector *) srcf;
HVX_Vector out = hvx_vec_inverse_fp32_guard(in, nan_inf_mask);
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out);
}
}

View File

@@ -1,176 +0,0 @@
#ifndef HVX_INVERSE_H
#define HVX_INVERSE_H
#include <HAP_farf.h>
#include <math.h>
#include <string.h>
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include "hvx-base.h"
// ====================================================
// FUNCTION: 1/(x+1) y(0) = 1, y(0.5) = 0.6667, y(1) = 0.5
// Order:3; continuity: True; Ends forced: True
// Mode: unsigned; Result fractional bits: 14
// Peak Error: 1.1295e-04 Rms Error: 2.8410e-05 Mean Error: 1.1370e-05
// 32769 -32706 31252 -10589
// 32590 -30635 22793 -4493
// 32066 -27505 16481 -2348
// 31205 -24054 11849 -1306
static inline HVX_Vector hvx_vec_recip_xp1_O3_unsigned(HVX_Vector vx) {
// input is 0..0xffff representing 0.0 .. 1.0
HVX_Vector p;
p = Q6_Vh_vlut4_VuhPh(vx, 0xFAE6F6D4EE73D6A3ull);
p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x2E49406159097A14ull);
p = Q6_Vh_vmps_VhVhVuhPuh_sat(p, vx, 0x5DF66B7177AB7FC2ull);
p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x79E57D427F4E8001ull);
return p; // signed result, 14 fractional bits
}
// Find reciprocal of fp16.
// (1) first, convert to fp32, multiplying by 1.0; this is done to
// handle denormals. Ignoring sign and zero, result should be at
// least 5.9604645e-08 (32-bit code 0x33800000) and at most 131008 (0x47ffe000)
// (exponent in range [103,143])
// (2) extract the mantissa into 16-bit unsigned; find reciprocal using a fitted poly
// (3) put this, along with '253-exp' (exp from (1)) together to make an qf32
// (4) convert that to fp16
// (5) put sign back in. Also, if the original value (w/o sign) was <0x81, replace
// the result with the max value.
static inline HVX_Vector hvx_vec_inverse_f16(HVX_Vector vals) {
HVX_Vector em_mask = Q6_Vh_vsplat_R(0x7FFF);
HVX_Vector avals = Q6_V_vand_VV(vals, em_mask);
HVX_VectorPred is_neg = Q6_Q_vcmp_gt_VhVh(avals, vals);
// is too small to 1/x ? for 'standard' fp16, this would be 0x101
HVX_VectorPred is_small = Q6_Q_vcmp_gt_VhVh(Q6_Vh_vsplat_R(0x101), avals);
HVX_VectorPair to_qf32 = Q6_Wqf32_vmpy_VhfVhf(avals, Q6_Vh_vsplat_R(0x3C00)); // *1.0
HVX_Vector to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(to_qf32));
HVX_Vector to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(to_qf32));
// bits 22..13 contain the mantissa now (w/o hidden bit); move to bit 14..5 of a 16-bit vector
HVX_Vector mant_u16 = Q6_Vh_vshuffo_VhVh(Q6_Vw_vasl_VwR(to_f32_1, 9), Q6_Vw_vasl_VwR(to_f32_0, 9));
// likewise extract the upper 16 from each, containing the exponents in range 103..142
HVX_Vector exp_u16 = Q6_Vh_vshuffo_VhVh(to_f32_1, to_f32_0);
//Get exponent in IEEE 32-bit representation
exp_u16 = Q6_Vuh_vlsr_VuhR(exp_u16, 7);
// so, mant_u16 contains an unbiased mantissa in upper 10 bits of each u16 lane
// We can consider it to be x-1.0, with 16 fractional bits, where 'x' is in range [1.0,2.0)
// Use poly to transform to 1/x, with 14 fractional bits
//
HVX_Vector rm = hvx_vec_recip_xp1_O3_unsigned(mant_u16);
HVX_Vector vcl0 = Q6_Vuh_vcl0_Vuh(rm); //count leading zeros
// Get mantissa for 16-bit represenation
HVX_Vector mant_recip = Q6_V_vand_VV(Q6_Vh_vasr_VhR(Q6_Vh_vasl_VhVh(rm, vcl0), 5), Q6_Vh_vsplat_R(0x03FF));
//Compute Reciprocal Exponent
HVX_Vector exp_recip =
Q6_Vh_vsub_VhVh(Q6_Vh_vsub_VhVh(Q6_Vh_vsplat_R(254), exp_u16), Q6_Vh_vsub_VhVh(vcl0, Q6_Vh_vsplat_R(1)));
//Convert it for 16-bit representation
exp_recip = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vsub_VhVh(exp_recip, Q6_Vh_vsplat_R(127)), Q6_Vh_vsplat_R(15));
exp_recip = Q6_Vh_vasl_VhR(exp_recip, 10);
//Merge exponent and mantissa for reciprocal
HVX_Vector recip = Q6_V_vor_VV(exp_recip, mant_recip);
// map 'small' inputs to standard largest value 0x7bff
recip = Q6_V_vmux_QVV(is_small, Q6_Vh_vsplat_R(0x7bff), recip);
// add sign back
recip = Q6_V_vandor_VQR(recip, is_neg, 0x80008000);
return recip;
}
static inline HVX_Vector hvx_vec_inverse_f32(HVX_Vector v_sf) {
HVX_Vector inv_aprox_sf = Q6_V_vsplat_R(0x7EEEEBB3);
HVX_Vector two_sf = hvx_vec_splat_f32(2.0);
// First approximation
HVX_Vector i_sf = Q6_Vw_vsub_VwVw(inv_aprox_sf, v_sf);
HVX_Vector r_qf;
// Refine
r_qf = Q6_Vqf32_vmpy_VsfVsf(
i_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(i_sf, v_sf)))));
r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(
r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));
r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(
r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));
return Q6_Vsf_equals_Vqf32(r_qf);
}
static inline HVX_Vector hvx_vec_inverse_f32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) {
HVX_Vector out = hvx_vec_inverse_f32(v_sf);
HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask);
const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out);
return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out);
}
#define hvx_inverse_f32_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
\
const uint32_t nvec = n / VLEN_FP32; \
const uint32_t nloe = n % VLEN_FP32; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
vdst[i] = hvx_vec_inverse_f32_guard(vsrc[i], nan_inf_mask); \
} \
if (nloe) { \
HVX_Vector v = hvx_vec_inverse_f32_guard(vsrc[i], nan_inf_mask); \
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, v); \
} \
} while(0)
static inline void hvx_inverse_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_inverse_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_inverse_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
hvx_inverse_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_inverse_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) src % 128 == 0);
hvx_inverse_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_inverse_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_inverse_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_inverse_f32(uint8_t * restrict dst, uint8_t * restrict src, const int num_elems) {
if ((unsigned long) dst % 128 == 0) {
if ((unsigned long) src % 128 == 0) {
hvx_inverse_f32_aa(dst, src, num_elems);
} else {
hvx_inverse_f32_au(dst, src, num_elems);
}
} else {
if ((unsigned long) src % 128 == 0) {
hvx_inverse_f32_ua(dst, src, num_elems);
} else {
hvx_inverse_f32_uu(dst, src, num_elems);
}
}
}
#endif // HVX_INVERSE_H

View File

@@ -1,225 +0,0 @@
#ifndef HVX_REDUCE_H
#define HVX_REDUCE_H
#include <math.h>
#include <stdbool.h>
#include <stdint.h>
#include <assert.h>
#include "hex-utils.h"
#include "hvx-base.h"
#include "hvx-types.h"
static inline HVX_Vector hvx_vec_reduce_sum_n_i32(HVX_Vector in, unsigned int n) {
unsigned int total = n * 4; // total vec nbytes
unsigned int width = 4; // int32
HVX_Vector sum = in, sum_t;
while (width < total) {
sum_t = Q6_V_vror_VR(sum, width); // rotate right
sum = Q6_Vw_vadd_VwVw(sum_t, sum); // elementwise sum
width = width << 1;
}
return sum;
}
static inline HVX_Vector hvx_vec_reduce_sum_i32(HVX_Vector in) {
return hvx_vec_reduce_sum_n_i32(in, 32);
}
static inline HVX_Vector hvx_vec_reduce_sum_n_qf32(HVX_Vector in, unsigned int n) {
unsigned int total = n * 4; // total vec nbytes
unsigned int width = 4; // fp32 nbytes
HVX_Vector sum = in, sum_t;
while (width < total) {
sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width); // rotate right
sum = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t); // elementwise sum
width = width << 1;
}
return sum;
}
static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) {
return hvx_vec_reduce_sum_n_qf32(in, 32);
}
static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
unsigned int total = n * 4; // total vec nbytes
unsigned int width = 4; // fp32 nbytes
HVX_Vector sum = in, sum_t;
while (width < total) {
sum_t = Q6_V_vror_VR(sum, width); // rotate right
sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum
width = width << 1;
}
return sum;
}
static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) {
return hvx_vec_reduce_sum_n_f32(in, 32);
}
static inline HVX_Vector hvx_vec_reduce_max_f16(HVX_Vector in) {
unsigned total = 128; // total vec nbytes
unsigned width = 2; // fp16 nbytes
HVX_Vector _max = in, _max_t;
while (width < total) {
_max_t = Q6_V_vror_VR(_max, width); // rotate right
_max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max
width = width << 1;
}
return _max;
}
static inline HVX_Vector hvx_vec_reduce_max2_f16(HVX_Vector in, HVX_Vector _max) {
unsigned total = 128; // total vec nbytes
unsigned width = 2; // fp32 nbytes
HVX_Vector _max_t;
_max = Q6_Vhf_vmax_VhfVhf(in, _max);
while (width < total) {
_max_t = Q6_V_vror_VR(_max, width); // rotate right
_max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max
width = width << 1;
}
return _max;
}
static inline HVX_Vector hvx_vec_reduce_max_f32(HVX_Vector in) {
unsigned total = 128; // total vec nbytes
unsigned width = 4; // fp32 nbytes
HVX_Vector _max = in, _max_t;
while (width < total) {
_max_t = Q6_V_vror_VR(_max, width); // rotate right
_max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max
width = width << 1;
}
return _max;
}
static inline HVX_Vector hvx_vec_reduce_max2_f32(HVX_Vector in, HVX_Vector _max) {
unsigned total = 128; // total vec nbytes
unsigned width = 4; // fp32 nbytes
HVX_Vector _max_t;
_max = Q6_Vsf_vmax_VsfVsf(in, _max);
while (width < total) {
_max_t = Q6_V_vror_VR(_max, width); // rotate right
_max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max
width = width << 1;
}
return _max;
}
#define hvx_reduce_loop_body(src_type, init_vec, pad_vec, vec_op, reduce_op, scalar_reduce) \
do { \
src_type * restrict vsrc = (src_type *) src; \
HVX_Vector acc = init_vec; \
\
const uint32_t elem_size = sizeof(float); \
const uint32_t epv = 128 / elem_size; \
const uint32_t nvec = num_elems / epv; \
const uint32_t nloe = num_elems % epv; \
\
uint32_t i = 0; \
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
acc = vec_op(acc, vsrc[i]); \
} \
if (nloe) { \
const float * srcf = (const float *) src + i * epv; \
HVX_Vector in = *(HVX_UVector *) srcf; \
HVX_Vector temp = Q6_V_valign_VVR(in, pad_vec, nloe * elem_size); \
acc = vec_op(acc, temp); \
} \
HVX_Vector v = reduce_op(acc); \
return scalar_reduce(v); \
} while(0)
#define HVX_REDUCE_MAX_OP(acc, val) Q6_Vsf_vmax_VsfVsf(acc, val)
#define HVX_REDUCE_SUM_OP(acc, val) Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(acc), val)
#define HVX_SUM_SQ_OP(acc, val) Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(val, val))
#define HVX_REDUCE_MAX_SCALAR(v) hvx_vec_get_f32(v)
#define HVX_REDUCE_SUM_SCALAR(v) hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(v))
// Max variants
static inline float hvx_reduce_max_f32_a(const uint8_t * restrict src, const int num_elems) {
HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]);
assert((unsigned long) src % 128 == 0);
hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR);
}
static inline float hvx_reduce_max_f32_u(const uint8_t * restrict src, const int num_elems) {
HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]);
hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR);
}
static inline float hvx_reduce_max_f32(const uint8_t * restrict src, const int num_elems) {
if (hex_is_aligned((void *) src, 128)) {
return hvx_reduce_max_f32_a(src, num_elems);
} else {
return hvx_reduce_max_f32_u(src, num_elems);
}
}
// Sum variants
static inline float hvx_reduce_sum_f32_a(const uint8_t * restrict src, const int num_elems) {
HVX_Vector init_vec = Q6_V_vsplat_R(0);
assert((unsigned long) src % 128 == 0);
hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
}
static inline float hvx_reduce_sum_f32_u(const uint8_t * restrict src, const int num_elems) {
HVX_Vector init_vec = Q6_V_vsplat_R(0);
hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
}
static inline float hvx_reduce_sum_f32(const uint8_t * restrict src, const int num_elems) {
if (hex_is_aligned((void *) src, 128)) {
return hvx_reduce_sum_f32_a(src, num_elems);
} else {
return hvx_reduce_sum_f32_u(src, num_elems);
}
}
// Sum of squares variants
static inline float hvx_sum_of_squares_f32_a(const uint8_t * restrict src, const int num_elems) {
HVX_Vector init_vec = Q6_V_vsplat_R(0);
assert((uintptr_t) src % 128 == 0);
hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
}
static inline float hvx_sum_of_squares_f32_u(const uint8_t * restrict src, const int num_elems) {
HVX_Vector init_vec = Q6_V_vsplat_R(0);
hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
}
static inline float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) {
if (hex_is_aligned((void *) src, 128)) {
return hvx_sum_of_squares_f32_a(src, num_elems);
} else {
return hvx_sum_of_squares_f32_u(src, num_elems);
}
}
#undef hvx_reduce_loop_body
#undef HVX_REDUCE_MAX_OP
#undef HVX_REDUCE_SUM_OP
#undef HVX_REDUCE_MAX_SCALAR
#undef HVX_REDUCE_SUM_SCALAR
#undef HVX_SUM_SQ_OP
#endif /* HVX_REDUCE_H */

View File

@@ -1,133 +0,0 @@
#ifndef HVX_SCALE_H
#define HVX_SCALE_H
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include "hvx-base.h"
#define hvx_scale_f32_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
HVX_Vector vs = hvx_vec_splat_f32(scale); \
\
const uint32_t elem_size = sizeof(float); \
const uint32_t epv = 128 / elem_size; \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; ++i) { \
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); \
vdst[i] = Q6_Vsf_equals_Vqf32(v); \
} \
if (nloe) { \
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); \
vec_store((void *) &vdst[i], nloe * elem_size, Q6_Vsf_equals_Vqf32(v)); \
} \
} while(0)
static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
assert((size_t) dst % 128 == 0);
assert((size_t) src % 128 == 0);
hvx_scale_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_scale_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
assert((size_t) dst % 128 == 0);
hvx_scale_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_scale_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
assert((size_t) src % 128 == 0);
hvx_scale_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
hvx_scale_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
if (((size_t) dst & 127) == 0) {
if (((size_t) src & 127) == 0) {
hvx_scale_f32_aa(dst, src, n, scale);
} else {
hvx_scale_f32_au(dst, src, n, scale);
}
} else {
if (((size_t) src & 127) == 0) {
hvx_scale_f32_ua(dst, src, n, scale);
} else {
hvx_scale_f32_uu(dst, src, n, scale);
}
}
}
#define hvx_scale_offset_f32_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
HVX_Vector vs = hvx_vec_splat_f32(scale); \
HVX_Vector vo = hvx_vec_splat_f32(offset); \
\
const uint32_t elem_size = sizeof(float); \
const uint32_t epv = 128 / elem_size; \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; ++i) { \
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); \
vdst[i] = Q6_Vsf_equals_Vqf32(v); \
} \
if (nloe) { \
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); \
vec_store((void *) &vdst[i], nloe * elem_size, Q6_Vsf_equals_Vqf32(v)); \
} \
} while(0)
static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
assert((size_t) dst % 128 == 0);
assert((size_t) src % 128 == 0);
hvx_scale_offset_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_scale_offset_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
assert((size_t) dst % 128 == 0);
hvx_scale_offset_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_scale_offset_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
assert((size_t) src % 128 == 0);
hvx_scale_offset_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
hvx_scale_offset_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
if (((size_t) dst & 127) == 0) {
if (((size_t) src & 127) == 0) {
hvx_scale_offset_f32_aa(dst, src, n, scale, offset);
} else {
hvx_scale_offset_f32_au(dst, src, n, scale, offset);
}
} else {
if (((size_t) src & 127) == 0) {
hvx_scale_offset_f32_ua(dst, src, n, scale, offset);
} else {
hvx_scale_offset_f32_uu(dst, src, n, scale, offset);
}
}
}
#endif // HVX_SCALE_H

View File

@@ -0,0 +1,49 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <string.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-dma.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"
#if 0
// Reference algo used in hvx-utils
static void fast_sigmoid_f32(const float* restrict src, float* restrict dst, const int num_elems)
{
const float c1 = 0.03138777;
const float c2 = 0.276281267;
const float c_log2f = 1.442695022;
int32_t store_ints[32];
float store_floats[3][32];
for (int i = 0; i < num_elems; i++)
{
float v = src0[i];
v *= c_log2f*0.5;
int intPart = (int)v;
float x = (v - intPart);
float xx = x * x;
float v1 = c_log2f + c2 * xx;
float v2 = x + xx * c1 * x;
float v3 = (v2 + v1);
*((int*)&v3) += intPart << 24;
float v4 = v2 - v1;
float v5 = v3 - v4;
float res = v3 / v5;
dst[i] = res;
}
}
#endif

View File

@@ -1,114 +0,0 @@
#ifndef HVX_SIGMOID_H
#define HVX_SIGMOID_H
#include "hvx-base.h"
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
#define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267
#define FAST_SIGMOID_C3 (0x3f000000) // 0.5
static inline HVX_Vector hvx_vec_fast_sigmoid_f32(HVX_Vector v) {
v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3));
HVX_Vector in_int = hvx_vec_truncate_f32(Q6_Vsf_equals_Vqf32(v));
HVX_Vector x = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int));
HVX_Vector xx = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x);
HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2));
v1 = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1));
v2 = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx);
v2 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x);
HVX_Vector v3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1));
HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1);
v3_exponent = Q6_Vuw_vlsr_VuwR(v3_exponent, 24);
v3_exponent = Q6_Vw_vadd_VwVw(in_int, v3_exponent);
v3 = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24);
HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1));
HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4));
HVX_Vector res = hvx_vec_inverse_f32(v5);
res = Q6_Vqf32_vmpy_VsfVsf(v3, res);
return Q6_Vsf_equals_Vqf32(res);
}
static inline HVX_Vector hvx_vec_fast_sigmoid_f32_guard(HVX_Vector v,
HVX_Vector one,
HVX_Vector max_exp,
HVX_Vector min_exp) {
const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v);
const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp);
HVX_Vector out = hvx_vec_fast_sigmoid_f32(v);
out = Q6_V_vmux_QVV(pred_max, out, one);
return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());
}
static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) {
// tanh(x) = 2 * sigmoid(2x) - 1
HVX_Vector two = hvx_vec_splat_f32(2.0f);
HVX_Vector one = hvx_vec_splat_f32(1.0f);
HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two);
HVX_Vector max_exp = hvx_vec_splat_f32(87.f);
HVX_Vector min_exp = hvx_vec_splat_f32(-87.f);
HVX_Vector sig2x = hvx_vec_fast_sigmoid_f32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp);
HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two);
res = Q6_Vqf32_vsub_Vqf32Vsf(res, one);
return Q6_Vsf_equals_Vqf32(res);
}
#define hvx_sigmoid_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const HVX_Vector one = hvx_vec_splat_f32(1.f); \
const HVX_Vector max_exp = hvx_vec_splat_f32(87.f); \
const HVX_Vector min_exp = hvx_vec_splat_f32(-87.f); \
\
const uint32_t epv = 128 / sizeof(float); \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
vdst[i] = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \
} \
if (nloe) { \
HVX_Vector tmp = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \
vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
} \
} while(0)
static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_sigmoid_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_sigmoid_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
hvx_sigmoid_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_sigmoid_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) src % 128 == 0);
hvx_sigmoid_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
#endif /* HVX_SIGMOID_H */

View File

@@ -1,60 +0,0 @@
#ifndef HVX_SQRT_H
#define HVX_SQRT_H
#include <stdbool.h>
#include <stdint.h>
#include "hex-utils.h"
#include "hvx-base.h"
#define RSQRT_CONST 0x5f3759df // Constant for fast inverse square root calculation
#define RSQRT_ONE_HALF 0x3f000000 // 0.5
#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5
static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
//Algorithm :
// x2 = input*0.5
// y = * (long *) &input
// y = 0x5f3759df - (y>>2)
// y = y*(threehalfs - x2*y*y)
HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
HVX_Vector onehalf = Q6_V_vsplat_R(RSQRT_ONE_HALF);
HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES);
HVX_Vector x2, y, ypower2, temp;
x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf);
x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero());
y = Q6_Vw_vasr_VwR(in_vec, 1);
y = Q6_Vw_vsub_VwVw(rsqrtconst, y);
// 1st iteration
ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y);
ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
temp = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp));
// 2nd iteration
y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
// 3rd iteration
y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
return Q6_Vsf_equals_Vqf32(temp);
}
#endif /* HVX_SQRT_H */

View File

@@ -1,36 +0,0 @@
#ifndef HVX_TYPES_H
#define HVX_TYPES_H
#include <stdbool.h>
#include <stdint.h>
#include <hexagon_types.h>
#define SIZEOF_FP32 (4)
#define SIZEOF_FP16 (2)
#define VLEN (128)
#define VLEN_FP32 (VLEN / SIZEOF_FP32)
#define VLEN_FP16 (VLEN / SIZEOF_FP16)
typedef union {
HVX_Vector v;
uint8_t b[VLEN];
uint16_t h[VLEN_FP16];
uint32_t w[VLEN_FP32];
__fp16 fp16[VLEN_FP16];
float fp32[VLEN_FP32];
} __attribute__((aligned(VLEN), packed)) HVX_VectorAlias;
typedef struct {
HVX_Vector v[2];
} HVX_Vector_x2;
typedef struct {
HVX_Vector v[4];
} HVX_Vector_x4;
typedef struct {
HVX_Vector v[8];
} HVX_Vector_x8;
#endif /* HVX_TYPES_H */

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,17 @@
#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
#pragma clang diagnostic ignored "-Wunused-function"
#include <HAP_farf.h>
#include <HAP_perf.h>
#define FARF_ERROR 1
#define FARF_HIGH 1
#define FARF_MEDIUM 0
#define FARF_LOW 0
#include <AEEStdErr.h>
#include <dspqueue.h>
#include <HAP_compute_res.h>
#include <HAP_etm_config.h>
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <HAP_power.h>
#include <HAP_ps.h>
#include <qurt.h>
@@ -15,14 +19,13 @@
#include <remote.h>
#include <string.h>
#include "hex-dma.h"
#include "hex-utils.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-dma.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "ops-utils.h"
#include "worker-pool.h"
AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
@@ -359,14 +362,14 @@ struct profile_data {
static inline void profile_start(struct profile_data * d) {
d->usecs = HAP_perf_get_qtimer_count();
d->cycles = hex_get_cycles();
d->pkts = hex_get_pktcnt();
d->cycles = htp_get_cycles();
d->pkts = htp_get_pktcnt();
}
static inline void profile_stop(struct profile_data * d) {
d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs);
d->cycles = hex_get_cycles() - d->cycles;
d->pkts = hex_get_pktcnt() - d->pkts;
d->cycles = htp_get_cycles() - d->cycles;
d->pkts = htp_get_pktcnt() - d->pkts;
}
static int send_htp_rsp(struct htp_context * c,
@@ -440,82 +443,6 @@ static void proc_matmul_req(struct htp_context * ctx,
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1];
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[0].fd = bufs[1].fd;
rsp_bufs[0].ptr = bufs[1].ptr;
rsp_bufs[0].offset = bufs[1].offset;
rsp_bufs[0].size = bufs[1].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.dst.data = (uint32_t) bufs[1].ptr;
octx.n_threads = ctx->n_threads;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_cpy(&octx);
vtcm_release(ctx);
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1];
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[0].fd = bufs[2].fd;
rsp_bufs[0].ptr = bufs[2].ptr;
rsp_bufs[0].offset = bufs[2].offset;
rsp_bufs[0].size = bufs[2].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0;
octx.src1 = req->src1;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.src1.data = (uint32_t) bufs[1].ptr;
octx.dst.data = (uint32_t) bufs[2].ptr;
octx.n_threads = ctx->n_threads;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_get_rows(&octx);
vtcm_release(ctx);
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_matmul_id_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
@@ -741,7 +668,7 @@ static void proc_rope_req(struct htp_context * ctx,
uint32_t n_bufs) {
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
int write_idx = n_bufs - 1;
int write_idx = (n_bufs == 4) ? 3 : 2;
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[0].fd = bufs[write_idx].fd;
@@ -789,102 +716,6 @@ static void proc_rope_req(struct htp_context * ctx,
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1];
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[0].fd = bufs[2].fd;
rsp_bufs[0].ptr = bufs[2].ptr;
rsp_bufs[0].offset = bufs[2].offset;
rsp_bufs[0].size = bufs[2].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0;
octx.src1 = req->src1;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.src1.data = (uint32_t) bufs[1].ptr;
octx.dst.data = (uint32_t) bufs[2].ptr;
octx.n_threads = ctx->n_threads;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_set_rows(&octx);
vtcm_release(ctx);
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_flash_attn_ext_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
uint32_t n_bufs) {
// Setup Op context
struct htp_ops_context octx;
memset(&octx, 0, sizeof(octx));
octx.ctx = ctx;
octx.n_threads = ctx->n_threads;
octx.src0 = req->src0;
octx.src1 = req->src1;
octx.src2 = req->src2;
octx.src3 = req->src3;
octx.src4 = req->src4;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.src1.data = (uint32_t) bufs[1].ptr;
octx.src2.data = (uint32_t) bufs[2].ptr;
int last_buf = 3;
if (octx.src3.ne[0]) {
octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid
}
if (octx.src4.ne[0]) {
octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid
}
octx.dst.data = (uint32_t) bufs[last_buf].ptr;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_flash_attn_ext(&octx);
vtcm_release(ctx);
}
profile_stop(&prof);
struct dspqueue_buffer rsp_buf = bufs[last_buf];
rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof);
}
static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
struct htp_context * ctx = (struct htp_context *) context;
@@ -959,7 +790,6 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
break;
case HTP_OP_RMS_NORM:
case HTP_OP_SCALE:
if (n_bufs != 2) {
FARF(ERROR, "Bad unary-req buffer list");
continue;
@@ -1003,38 +833,6 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
proc_rope_req(ctx, &req, bufs, n_bufs);
break;
case HTP_OP_FLASH_ATTN_EXT:
if (!(n_bufs >= 4 && n_bufs <= 6)) {
FARF(ERROR, "Bad flash-attn-ext-req buffer list");
continue;
}
proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs);
break;
case HTP_OP_SET_ROWS:
if (n_bufs != 3) {
FARF(ERROR, "Bad set-rows-req buffer list");
continue;
}
proc_set_rows_req(ctx, &req, bufs);
break;
case HTP_OP_GET_ROWS:
if (n_bufs != 3) {
FARF(ERROR, "Bad get-rows-req buffer list");
continue;
}
proc_get_rows_req(ctx, &req, bufs);
break;
case HTP_OP_CPY:
if (n_bufs != 2) {
FARF(ERROR, "Bad cpy-req buffer list");
continue;
}
proc_cpy_req(ctx, &req, bufs);
break;
default:
FARF(ERROR, "Unknown Op %u", req.op);
break;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,149 @@
#ifndef OPS_UTILS_H
#define OPS_UTILS_H
#include "htp-msg.h"
#ifndef MAX
# define MAX(a, b) ((a) > (b) ? (a) : (b))
#endif
#ifndef MIN
# define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
static inline uint64_t htp_get_cycles() {
uint64_t cycles = 0;
asm volatile(" %0 = c15:14\n" : "=r"(cycles));
return cycles;
}
static inline uint64_t htp_get_pktcnt() {
uint64_t pktcnt;
asm volatile(" %0 = c19:18\n" : "=r"(pktcnt));
return pktcnt;
}
static inline int32_t htp_is_aligned(void * addr, uint32_t align) {
return ((size_t) addr & (align - 1)) == 0;
}
static inline uint32_t htp_round_up(uint32_t n, uint32_t m) {
return m * ((n + m - 1) / m);
}
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
// and a shift:
//
// n/d = (mulhi(n, mp) + n) >> L;
struct fastdiv_values {
uint32_t mp;
uint32_t l;
};
static inline struct fastdiv_values init_fastdiv_values(uint32_t d) {
struct fastdiv_values result = { 0, 0 };
// compute L = ceil(log2(d));
while (result.l < 32 && ((uint32_t) 1 << result.l) < d) {
++(result.l);
}
result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1);
return result;
}
static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) {
// Compute high 32 bits of n * mp
const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32); // mulhi(n, mp)
// add n, apply bit shift
return (hi + n) >> vals->l;
}
static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) {
return n - fastdiv(n, vals) * d;
}
static inline void htp_l2fetch(const void * p, uint32_t height, uint32_t width, uint32_t stride) {
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
asm volatile(" l2fetch(%0,%1) " : : "r"(p), "r"(control));
}
static inline int32_t htp_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
uint32_t left_off = (size_t) addr & (chunk_size - 1);
uint32_t right_off = left_off + n;
return right_off <= chunk_size;
}
static inline void htp_dump_int8_line(char * pref, const int8_t * x, int n) {
char str[1024], *p = str, *p_end = str + sizeof(str);
p += snprintf(p, p_end - p, "%s: ", pref);
for (int i = 0; i < n && p < p_end; i++) {
p += snprintf(p, p_end - p, "%d, ", x[i]);
}
FARF(HIGH, "%s\n", str);
}
static inline void htp_dump_uint8_line(char * pref, const uint8_t * x, uint32_t n) {
char str[1024], *p = str, *p_end = str + sizeof(str);
p += snprintf(p, p_end - p, "%s: ", pref);
for (int i = 0; i < n && p < p_end; i++) {
p += snprintf(p, p_end - p, "%d, ", x[i]);
}
FARF(HIGH, "%s\n", str);
}
static inline void htp_dump_int32_line(char * pref, const int32_t * x, uint32_t n) {
char str[1024], *p = str, *p_end = str + sizeof(str);
p += snprintf(p, p_end - p, "%s: ", pref);
for (int i = 0; i < n; i++) {
p += snprintf(p, p_end - p, "%d, ", (int) x[i]);
}
FARF(HIGH, "%s\n", str);
}
static inline void htp_dump_fp16_line(char * pref, const __fp16 * x, uint32_t n) {
char str[1024], *p = str, *p_end = str + sizeof(str);
p += snprintf(p, p_end - p, "%s: ", pref);
for (int i = 0; i < n; i++) {
p += snprintf(p, p_end - p, "%.6f, ", (float) x[i]);
}
FARF(HIGH, "%s\n", str);
}
static inline void htp_dump_fp32_line(char * pref, const float * x, uint32_t n) {
char str[1024], *p = str, *p_end = str + sizeof(str);
p += snprintf(p, p_end - p, "%s: ", pref);
for (int i = 0; i < n; i++) {
p += snprintf(p, p_end - p, "%.6f, ", x[i]);
}
FARF(HIGH, "%s\n", str);
}
static inline void htp_dump_f32(char * pref, const float * x, uint32_t n) {
uint32_t n0 = n / 16;
uint32_t n1 = n % 16;
uint32_t i = 0;
for (; i < n0; i++) {
htp_dump_fp32_line(pref, x + (16 * i), 16);
}
if (n1) {
htp_dump_fp32_line(pref, x + (16 * i), n1);
}
}
static inline void htp_dump_f16(char * pref, const __fp16 * x, uint32_t n) {
uint32_t n0 = n / 16;
uint32_t n1 = n % 16;
uint32_t i = 0;
for (; i < n0; i++) {
htp_dump_fp16_line(pref, x + (16 * i), 16);
}
if (n1) {
htp_dump_fp16_line(pref, x + (16 * i), n1);
}
}
#endif /* OPS_UTILS_H */

View File

@@ -2,20 +2,27 @@
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#ifdef HTP_DEBUG
# define FARF_HIGH 1
#endif
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <HAP_ps.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <qurt_thread.h>
#include <string.h>
#include "hex-dma.h"
#include "hvx-utils.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-dma.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"
// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we cant include ggml.h
#define HTP_ROPE_TYPE_NORMAL 0
@@ -363,8 +370,8 @@ static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int
int is_aligned = 1;
int opt_path = 0;
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) ||
(0 == hex_is_aligned((void *) dst->data, VLEN))) {
if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
(0 == htp_is_aligned((void *) dst->data, VLEN))) {
FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n");
is_aligned = 0;
}
@@ -420,9 +427,9 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
// VTCM scratchpads for all tensors
// N rows per thread, padded to HVX vector size
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;

View File

@@ -1,164 +0,0 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#include <HAP_farf.h>
#include <HAP_perf.h>
#include <math.h>
#include <string.h>
#include "hex-dma.h"
#include "hvx-utils.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
#define set_rows_preamble \
const uint32_t ne00 = octx->src0.ne[0]; \
const uint32_t ne01 = octx->src0.ne[1]; \
const uint32_t ne02 = octx->src0.ne[2]; \
const uint32_t ne03 = octx->src0.ne[3]; \
\
const uint32_t ne10 = octx->src1.ne[0]; \
const uint32_t ne11 = octx->src1.ne[1]; \
const uint32_t ne12 = octx->src1.ne[2]; \
\
const uint32_t nb01 = octx->src0.nb[1]; \
const uint32_t nb02 = octx->src0.nb[2]; \
const uint32_t nb03 = octx->src0.nb[3]; \
\
const uint32_t nb10 = octx->src1.nb[0]; \
const uint32_t nb11 = octx->src1.nb[1]; \
const uint32_t nb12 = octx->src1.nb[2]; \
\
const uint32_t nb1 = octx->dst.nb[1]; \
const uint32_t nb2 = octx->dst.nb[2]; \
const uint32_t nb3 = octx->dst.nb[3]; \
\
const uint32_t ne1 = octx->dst.ne[1]; \
\
const uint32_t nr = ne01;
static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
set_rows_preamble;
// parallelize by rows of src0
const uint32_t dr = octx->src0_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
for (uint32_t i = ir0; i < ir1; ++i) {
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
const uint32_t i10 = i;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
if (i1 >= ne1) {
// ignore invalid indices
continue;
}
const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
// copy row
hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
}
}
}
return HTP_STATUS_OK;
}
static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
set_rows_preamble;
// parallelize by rows of src0
const uint32_t dr = octx->src0_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
for (uint32_t i = ir0; i < ir1; ++i) {
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
const uint32_t i10 = i;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
if (i1 >= ne1) {
// ignore invalid indices
continue;
}
const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);
}
}
}
return HTP_STATUS_OK;
}
static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
}
static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
}
int op_set_rows(struct htp_ops_context * octx) {
set_rows_preamble;
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
return HTP_STATUS_OK;
}
octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
const uint32_t n_jobs = MIN(nr, octx->n_threads);
octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
switch(octx->dst.type) {
case HTP_TYPE_F32:
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
break;
case HTP_TYPE_F16:
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
break;
default:
return HTP_STATUS_NO_SUPPORT;
}
return HTP_STATUS_OK;
}

View File

@@ -2,20 +2,27 @@
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#ifdef HTP_DEBUG
# define FARF_HIGH 1
#endif
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <HAP_ps.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <qurt_thread.h>
#include <string.h>
#include "hex-dma.h"
#include "hvx-utils.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-dma.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"
#define htp_softmax_preamble3 \
const uint32_t ne00 = src0->ne[0]; \
@@ -93,8 +100,8 @@ static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
uint8_t * restrict dst_curr = dst;
const uint8_t * restrict mask_curr = mask;
HVX_Vector scale_vec = hvx_vec_splat_f32(scale);
HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
HVX_Vector scale_vec = hvx_vec_splat_fp32(scale);
HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
int step_of_1 = num_elems >> 5;
@@ -127,9 +134,9 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000);
HVX_Vector max_vec = hvx_vec_splat_f32(((const float *) src)[0]);
HVX_Vector max_vec = hvx_vec_splat_fp32(((const float *) src)[0]);
HVX_Vector zero_v = Q6_V_vzero();
HVX_Vector one_v = hvx_vec_splat_f32(1.0);
HVX_Vector one_v = hvx_vec_splat_fp32(1.0);
int step_of_1 = num_elems >> 5;
@@ -139,7 +146,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1);
}
HVX_Vector v = hvx_vec_reduce_max_f32(max_vec);
HVX_Vector v = hvx_vec_reduce_max_fp32(max_vec);
max_vec = hvx_vec_repl4(v);
#pragma unroll(4)
@@ -147,18 +154,18 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
HVX_Vector v1 = v_src[i];
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, max_vec);
HVX_Vector v3 = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(v2));
HVX_Vector v3 = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(v2));
sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), v3);
v_pad[i] = v3;
}
v = hvx_vec_reduce_sum_qf32(sum_vec);
v = hvx_vec_qf32_reduce_sum(sum_vec);
sum_vec = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(v));
HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec);
HVX_Vector v4 = hvx_vec_inverse_fp32(sum_vec);
HVX_Vector scale_vec = Q6_V_vmux_QVV(pos_sum, v4, one_v);
#pragma unroll(4)
@@ -174,11 +181,11 @@ static float hvx_softmax_f32(const uint8_t * restrict src,
uint8_t * restrict spad,
const int num_elems,
const float max) {
hvx_sub_scalar_f32(spad, src, max, num_elems);
hvx_sub_scalar_f32(src, max, spad, num_elems);
hvx_exp_f32(spad, dst, num_elems, false);
float sum = hvx_reduce_sum_f32(dst, num_elems);
float sum = hvx_self_sum_f32(dst, num_elems);
return sum;
}
@@ -231,7 +238,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
(const uint8_t *) mp_f32, slope);
} else {
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
hvx_scale_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale);
if (mp_f32) {
if (softmax_ctx->use_f16) {
for (int i = 0; i < ne00; ++i) {
@@ -248,10 +255,10 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct
if (1 == opt_path) {
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
} else {
float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
float max = hvx_self_max_f32((const uint8_t *) wp0, ne00);
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
sum = sum > 0.0 ? (1.0 / sum) : 1;
hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
hvx_scale_f32((const uint8_t *) wp2, (uint8_t *) dp, ne00, sum);
}
}
}
@@ -283,7 +290,7 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
int is_aligned = 1;
int opt_path = 0;
if (!hex_is_aligned((void *) src0->data, VLEN) || !hex_is_aligned((void *) dst->data, VLEN)) {
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
is_aligned = 0;
FARF(HIGH, "softmax-f32: unaligned addresses in elementwise op, possibly slower execution\n");
}
@@ -338,9 +345,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
// VTCM scratchpads for all tensors
// N rows per thread, padded to HVX vector size
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;

View File

@@ -2,20 +2,28 @@
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#ifdef HTP_DEBUG
# define FARF_HIGH 1
#endif
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <HAP_ps.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <qurt_thread.h>
#include <string.h>
#include "hex-dma.h"
#include "hvx-utils.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-dma.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"
#define htp_unary_preamble \
const uint32_t ne00 = src->ne[0]; \
@@ -47,7 +55,7 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
HVX_Vector epsilon_v = hvx_vec_splat_fp32(epsilon);
int step_of_1 = num_elems >> 5;
#pragma unroll(4)
@@ -57,15 +65,15 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
}
HVX_Vector reduced_sum = hvx_vec_reduce_sum_qf32(sum_v);
HVX_Vector reduced_sum = hvx_vec_qf32_reduce_sum(sum_v);
sum_v = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum));
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
HVX_Vector t_v = hvx_vec_splat_fp32((float) num_elems);
HVX_Vector denom_v = hvx_vec_inverse_fp32(t_v);
HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);
HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
HVX_Vector scale_v = hvx_vec_rsqrt_fp32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
#pragma unroll(4)
for (int i = 0; i < step_of_1; i++) {
@@ -75,31 +83,6 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
}
}
static void scale_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
float scale = 0.f;
float bias = 0.f;
memcpy(&scale, &op_params[0], sizeof(float));
memcpy(&bias, &op_params[1], sizeof(float));
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
if (ir + 1 < num_rows) {
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
}
hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
}
}
static void rms_norm_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
@@ -116,7 +99,7 @@ static void rms_norm_htp_f32(const float * restrict src,
float * restrict dst_local = dst + (ir * row_elems);
if (ir + 1 < num_rows) {
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
}
if (1 == opt_path) {
@@ -127,7 +110,7 @@ static void rms_norm_htp_f32(const float * restrict src,
const float mean = sum / row_elems;
const float scale = 1.0f / sqrtf(mean + epsilon);
hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
hvx_scale_f32((const uint8_t *) src_local, (uint8_t *) dst_local, row_elems, scale);
}
}
}
@@ -160,8 +143,9 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
int is_aligned = 1;
int opt_path = 0;
if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) {
if ((0 == htp_is_aligned((void *) src->data, VLEN)) || (0 == htp_is_aligned((void *) dst->data, VLEN))) {
is_aligned = 0;
FARF(HIGH, "unary-f32: unaligned addresses in unary op, possibly slower execution\n");
}
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
opt_path = 1;
@@ -178,9 +162,6 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
case HTP_OP_RMS_NORM:
rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
case HTP_OP_SCALE:
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
default:
break;
@@ -214,10 +195,6 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
unary_op_func = unary_job_dispatcher_f32;
op_type = "rmsnorm-f32";
break;
case HTP_OP_SCALE:
unary_op_func = unary_job_dispatcher_f32;
op_type = "scale-f32";
break;
default:
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
@@ -231,8 +208,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
const size_t dst_row_size = dst->nb[1];
// VTCM scratchpads for all tensors
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;

View File

@@ -7,6 +7,10 @@
#include <stdlib.h>
#include <string.h>
#ifdef HTP_DEBUG
# define FARF_HIGH 1
#endif
#include "HAP_farf.h"
#define WORKER_THREAD_STACK_SZ (2 * 16384)

View File

@@ -1,153 +0,0 @@
#ifndef OP_DESC_H
#define OP_DESC_H
#define GGML_COMMON_IMPL_CPP
#include "ggml-backend-impl.h"
#include "ggml-common.h"
#include <string>
#include <stdio.h>
struct op_desc {
char strides[64 * GGML_MAX_SRC];
char dims[64 * GGML_MAX_SRC];
char types[16 * GGML_MAX_SRC];
char buffs[64 * GGML_MAX_SRC];
char names[64 * GGML_MAX_SRC];
int format_tensor_dims(char * str, const struct ggml_tensor * t) {
if (t->ne[2] == 1 && t->ne[3] == 1) {
return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
} else {
return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
}
}
void format_op_dims(char * str, const struct ggml_tensor * t) {
char * p = str;
// append src0 and src1 (if any)
if (t->src[0]) {
p += format_tensor_dims(p, t->src[0]);
for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
p += sprintf(p, " x ");
p += format_tensor_dims(p, t->src[i]);
}
p += sprintf(p, " -> ");
}
// format self dims separately for better visual alignment
char self[64];
format_tensor_dims(self, t);
p += sprintf(p, "%s", self);
}
int format_tensor_strides(char * str, const struct ggml_tensor * t) {
const char * c = ggml_is_contiguous(t) ? "" : "!";
if (t->ne[2] == 1 && t->ne[3] == 1) {
return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c);
} else {
return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c);
}
}
void format_op_strides(char * str, const struct ggml_tensor * t) {
char * p = str;
// append src0 and src1 (if any)
if (t->src[0]) {
p += format_tensor_strides(p, t->src[0]);
for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
p += sprintf(p, " x ");
p += format_tensor_strides(p, t->src[i]);
}
p += sprintf(p, " -> ");
}
// format self dims separately for better visual alignment
char self[64];
format_tensor_strides(self, t);
p += sprintf(p, "%s", self);
}
void format_op_types(char * str, const struct ggml_tensor * t) {
char * p = str;
// append src0 and src1 (if any)
if (t->src[0]) {
p += sprintf(p, "%s", ggml_type_name(t->src[0]->type));
for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
p += sprintf(p, " x ");
p += sprintf(p, "%s", ggml_type_name(t->src[i]->type));
}
p += sprintf(p, " -> ");
}
p += sprintf(p, "%s", ggml_type_name(t->type));
}
const char * tensor_buff_name(const struct ggml_tensor * t) {
if (t->buffer) {
return ggml_backend_buffer_name(t->buffer);
}
return "NONE";
}
void format_op_buffs(char * str, const struct ggml_tensor * t) {
char * p = str;
// append src0 and src1 (if any)
if (t->src[0]) {
p += sprintf(p, "%s", tensor_buff_name(t->src[0]));
for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
p += sprintf(p, " x ");
p += sprintf(p, "%s", tensor_buff_name(t->src[i]));
}
p += sprintf(p, " -> ");
}
p += sprintf(p, "%s", tensor_buff_name(t));
}
void format_op_names(char * str, const struct ggml_tensor * t) {
char * p = str;
// append src0 and src1 (if any)
if (t->src[0]) {
p += sprintf(p, "%s", t->src[0]->name);
for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
p += sprintf(p, " x ");
p += sprintf(p, "%s", t->src[i]->name);
}
p += sprintf(p, " -> ");
}
p += sprintf(p, "%s", t->name);
}
void format(const ggml_tensor * op) {
format_op_dims(dims, op);
format_op_strides(strides, op);
format_op_types(types, op);
format_op_buffs(buffs, op);
format_op_names(names, op);
}
op_desc() {}
op_desc(const ggml_tensor * op) { format(op); }
};
#endif // OP_DESC_H

View File

@@ -24,6 +24,10 @@
#include <arm_neon.h>
#endif
#if defined(__F16C__)
#include <immintrin.h>
#endif
#ifdef __cplusplus
extern "C" {
#endif

View File

@@ -23,6 +23,11 @@ if (GGML_METAL_NDEBUG)
add_compile_definitions(GGML_METAL_NDEBUG)
endif()
# copy metal files to bin directory
configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
if (GGML_METAL_EMBED_LIBRARY)
enable_language(ASM)
@@ -32,12 +37,12 @@ if (GGML_METAL_EMBED_LIBRARY)
set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated")
# merge ggml-common.h and ggml-metal.metal into a single file
set(METALLIB_EMBED_ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
set(METALLIB_SOURCE_EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
add_custom_command(
OUTPUT "${METALLIB_EMBED_ASM}"
@@ -57,11 +62,6 @@ if (GGML_METAL_EMBED_LIBRARY)
target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}")
else()
# copy metal files to bin directory
configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
if (GGML_METAL_SHADER_DEBUG)
# custom command to do the following:
# xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air

View File

@@ -1684,60 +1684,3 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggm
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->type == GGML_TYPE_I64);
char base[256];
char name[256];
snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_COUNT_EQUAL);
GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
GGML_ASSERT(op->src[0]->type == op->src[1]->type);
GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);
GGML_ASSERT(op->type == GGML_TYPE_I64);
// note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));
char base[256];
char name[256];
int nsg = 1;
while (32*nsg < ne00 && nsg < 32) {
nsg *= 2;
}
snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s_nsg=%d", base, nsg);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
res.smem = 32 * sizeof(int32_t);
res.nsg = nsg;
return res;
}

View File

@@ -147,8 +147,6 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
@@ -219,8 +217,6 @@ struct ggml_metal_device_props {
bool use_shared_buffers;
bool supports_gpu_family_apple7;
int op_offload_min_batch_size;
};
ggml_metal_device_t ggml_metal_device_init(void);

View File

@@ -782,8 +782,6 @@ ggml_metal_device_t ggml_metal_device_init(void) {
dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7];
dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
dev->props.max_buffer_size = dev->mtl_device.maxBufferLength;
dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize;
dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength;
@@ -1025,11 +1023,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_COUNT_EQUAL:
return has_simdgroup_reduction &&
op->src[0]->type == GGML_TYPE_I32 &&
op->src[1]->type == GGML_TYPE_I32 &&
op->type == GGML_TYPE_I64;
case GGML_OP_ARGMAX:
return has_simdgroup_reduction;
case GGML_OP_NORM:

View File

@@ -78,7 +78,6 @@
#define FC_MUL_MM 700
#define FC_ROPE 800
#define FC_SSM_CONV 900
#define FC_COUNT_EQUAL 1000
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8
@@ -895,25 +894,6 @@ typedef struct {
float step;
} ggml_metal_kargs_arange;
typedef struct {
int64_t val;
} ggml_metal_kargs_memset;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
} ggml_metal_kargs_count_equal;
typedef struct {
int32_t k0;
int32_t k1;

View File

@@ -448,11 +448,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
} break;
case GGML_OP_COUNT_EQUAL:
{
n_fuse = ggml_metal_op_count_equal(ctx, idx);
} break;
default:
default:
{
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
GGML_ABORT("fatal error");
@@ -2181,11 +2177,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
const bool has_mask = op->src[3] != nullptr;
// note: the non-vec kernel requires more extra memory, so always reserve for it
GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);
//if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
if (false) {
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
// note: always reserve the padding space to avoid graph reallocations
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
const bool has_kvpad = true;
@@ -4098,64 +4090,3 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
{
ggml_metal_kargs_memset args = { /*.val =*/ 0 };
auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
}
ggml_metal_op_concurrency_reset(ctx);
{
ggml_metal_kargs_count_equal args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
};
auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);
const size_t smem = pipeline.smem;
const int nth = 32*pipeline.nsg;
GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
}
return 1;
}

View File

@@ -87,7 +87,6 @@ int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_count_equal (ggml_metal_op_t ctx, int idx);
#ifdef __cplusplus
}

View File

@@ -625,11 +625,14 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
}
static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
const int min_batch_size = 32;
return (op->op == GGML_OP_MUL_MAT ||
op->op == GGML_OP_MUL_MAT_ID) &&
get_op_batch_size(op) >= ggml_metal_device_get_props(ctx_dev)->op_offload_min_batch_size;
get_op_batch_size(op) >= min_batch_size;
GGML_UNUSED(dev);
GGML_UNUSED(op);
}
static ggml_backend_device_i ggml_backend_metal_device_i = {

View File

@@ -1790,7 +1790,6 @@ kernel void kernel_op_sum_f32(
return;
}
// TODO: become function constant
const uint nsg = (ntg.x + 31) / 32;
float sumf = 0;
@@ -9148,7 +9147,6 @@ typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
@@ -9559,6 +9557,9 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
#endif
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@@ -9614,6 +9615,9 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
#endif
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@@ -9916,75 +9920,3 @@ kernel void kernel_opt_step_sgd_f32(
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}
template<typename T>
kernel void kernel_memset(
constant ggml_metal_kargs_fill & args,
device T * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
}
typedef decltype(kernel_memset<int64_t>) kernel_memset_t;
template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];
template<typename T>
kernel void kernel_count_equal(
constant ggml_metal_kargs_count_equal & args,
device const char * src0,
device const char * src1,
device atomic_int * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const short NSG = FC_count_equal_nsg;
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
int sum = 0;
device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
const T v0 = *(device const T *)(base0 + i0*args.nb00);
const T v1 = *(device const T *)(base1 + i0*args.nb10);
sum += (v0 == v1);
}
sum = simd_sum(sum);
if (tiisg == 0) {
shmem_i32[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
float v = 0.0f;
if (tpitg.x < NSG) {
v = shmem_i32[tpitg.x];
}
float total = simd_sum(v);
if (tpitg.x == 0) {
atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);
}
}
}
typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;
template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;

View File

@@ -57,7 +57,6 @@ set(GGML_OPENCL_KERNELS
add
add_id
argsort
fill
clamp
cpy
cvt
@@ -69,7 +68,6 @@ set(GGML_OPENCL_KERNELS
get_rows
glu
group_norm
solve_tri
im2col_f32
im2col_f16
mean
@@ -122,8 +120,6 @@ set(GGML_OPENCL_KERNELS
tsembd
upscale
tanh
expm1
softplus
pad
repeat
mul_mat_f16_f32

View File

@@ -263,32 +263,6 @@ static ggml_cl_compiler_version get_adreno_cl_compiler_version(const char *drive
return { type, major, minor, patch };
}
// cl buffer wrapper
struct ggml_cl_buffer {
cl_mem buffer;
size_t size;
ggml_cl_buffer()
: buffer(nullptr), size(0) {}
~ggml_cl_buffer() {
if (buffer) {
CL_CHECK(clReleaseMemObject(buffer));
}
}
void allocate(cl_context context, size_t new_size) {
if (new_size > size) {
size = new_size;
if (buffer) {
CL_CHECK(clReleaseMemObject(buffer));
}
cl_int err;
CL_CHECK((buffer = clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, &err), err));
}
}
};
// Profiling
struct ProfilingInfo {
std::string op_name;
@@ -402,11 +376,6 @@ struct ggml_backend_opencl_context {
cl_context context;
cl_command_queue queue;
// prealloc buffers for transposing weights and activations
ggml_cl_buffer prealloc_quant_trans;
ggml_cl_buffer prealloc_scales_trans;
ggml_cl_buffer prealloc_act_trans;
cl_program program_add;
cl_program program_add_id;
cl_program program_clamp;
@@ -489,7 +458,6 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
cl_kernel kernel_relu;
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
cl_kernel kernel_fill;
cl_kernel kernel_clamp;
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
@@ -531,7 +499,6 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_mul_mv_q6_K_f32;
cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
cl_kernel kernel_mul_mv_q8_0_f32, kernel_mul_mv_q8_0_f32_flat;
cl_kernel kernel_solve_tri_f32;
cl_kernel kernel_im2col_f32, kernel_im2col_f16;
cl_kernel kernel_argsort_f32_i32;
cl_kernel kernel_sum_rows_f32;
@@ -539,10 +506,6 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_pad;
cl_kernel kernel_tanh_f32_nd;
cl_kernel kernel_tanh_f16_nd;
cl_kernel kernel_expm1_f32_nd;
cl_kernel kernel_expm1_f16_nd;
cl_kernel kernel_softplus_f32_nd;
cl_kernel kernel_softplus_f16_nd;
cl_kernel kernel_upscale;
cl_kernel kernel_upscale_bilinear;
cl_kernel kernel_concat_f32_contiguous;
@@ -675,6 +638,10 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_transpose_16_buf;
cl_kernel kernel_transpose_16_4x1;
cl_mem A_s_d_max; // max scale buffer size for transpose
cl_mem A_q_d_max; // max weight buffer size for transpose
cl_mem B_d_max; // max activation buffer size for transpose
// Gemm and Gemv related programs, kernels, etc
cl_program program_CL_gemm;
cl_program program_CL_gemv_general;
@@ -793,24 +760,6 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// fill
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "fill.cl.h"
};
#else
const std::string kernel_src = read_file("fill.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_fill = clCreateKernel(prog, "kernel_fill_f32", &err), err));
GGML_LOG_CONT(".");
CL_CHECK(clReleaseProgram(prog));
}
// clamp
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -953,23 +902,6 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// solve_tri_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "solve_tri.cl.h"
};
#else
const std::string kernel_src = read_file("solve_tri.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_solve_tri_f32 = clCreateKernel(prog, "kernel_solve_tri_f32", &err), err));
GGML_LOG_CONT(".");
CL_CHECK(clReleaseProgram(prog));
}
// im2col_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1821,56 +1753,6 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
}
}
// expm1
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "expm1.cl.h"
};
#else
const std::string kernel_src = read_file("expm1.cl");
#endif
cl_program prog;
if (!kernel_src.empty()) {
prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_expm1_f32_nd = clCreateKernel(prog, "kernel_expm1_f32_nd", &err), err));
CL_CHECK((backend_ctx->kernel_expm1_f16_nd = clCreateKernel(prog, "kernel_expm1_f16_nd", &err), err));
GGML_LOG_CONT(".");
} else {
GGML_LOG_WARN("ggml_opencl: expm1 kernel source not found or empty. Expm1 operation will not be available.\n");
prog = nullptr;
backend_ctx->kernel_expm1_f32_nd = nullptr;
backend_ctx->kernel_expm1_f16_nd = nullptr;
}
CL_CHECK(clReleaseProgram(prog));
}
// softplus
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "softplus.cl.h"
};
#else
const std::string kernel_src = read_file("softplus.cl");
#endif
cl_program prog;
if (!kernel_src.empty()) {
prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_softplus_f32_nd = clCreateKernel(prog, "kernel_softplus_f32_nd", &err), err));
CL_CHECK((backend_ctx->kernel_softplus_f16_nd = clCreateKernel(prog, "kernel_softplus_f16_nd", &err), err));
GGML_LOG_CONT(".");
} else {
GGML_LOG_WARN("ggml_opencl: softplus kernel source not found or empty. Softplus operation will not be available.\n");
prog = nullptr;
backend_ctx->kernel_softplus_f32_nd = nullptr;
backend_ctx->kernel_softplus_f16_nd = nullptr;
}
CL_CHECK(clReleaseProgram(prog));
}
// upscale
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2718,9 +2600,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
required_B_d_bytes, max_B_d_bytes);
}
backend_ctx->prealloc_quant_trans.allocate(context, max_A_q_d_bytes);
backend_ctx->prealloc_scales_trans.allocate(context, max_A_s_d_bytes);
backend_ctx->prealloc_act_trans.allocate(context, max_B_d_bytes);
CL_CHECK((backend_ctx->A_q_d_max = clCreateBuffer(context, 0, max_A_q_d_bytes, NULL, &err), err));
CL_CHECK((backend_ctx->A_s_d_max = clCreateBuffer(context, 0, max_A_s_d_bytes, NULL, &err), err));
CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
@@ -3180,12 +3062,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_UNARY_OP_TANH:
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
case GGML_UNARY_OP_EXPM1:
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
case GGML_UNARY_OP_SOFTPLUS:
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
default:
return false;
}
@@ -3201,8 +3077,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
default:
return false;
}
case GGML_OP_FILL:
return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SOFT_MAX:
@@ -3284,8 +3158,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
}
return true;
}
case GGML_OP_SOLVE_TRI:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
case GGML_OP_IM2COL:
return true;
case GGML_OP_ARGSORT: {
@@ -3735,35 +3607,32 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
// use sub_buffer of max buffer size instead
size_t q_size_bytes = K * M / 8 * sizeof(float);
backend_ctx->prealloc_quant_trans.allocate(context, q_size_bytes);
cl_buffer_region region;
region.origin = 0;
region.size = q_size_bytes;
cl_mem qT_d = clCreateSubBuffer(
backend_ctx->prealloc_quant_trans.buffer,
backend_ctx->A_q_d_max,
0,
CL_BUFFER_CREATE_TYPE_REGION,
&region,
&err);
// cl_mem qT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, q_size_bytes, NULL, &err);
CL_CHECK(err);
bool K_tile_trans = true;
if ((K / 32) % 4 != 0){
K_tile_trans =false;
}
size_t d_size_bytes = M * (K / 32) * 2;
backend_ctx->prealloc_scales_trans.allocate(context, d_size_bytes);
region.origin = 0;
region.size = d_size_bytes;
cl_mem dT_d = clCreateSubBuffer(
backend_ctx->prealloc_scales_trans.buffer,
backend_ctx->A_s_d_max,
0,
CL_BUFFER_CREATE_TYPE_REGION,
&region,
&err);
// cl_mem dT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, d_size_bytes, NULL, &err);
CL_CHECK(err);
// <----------------------------------------------------------------------------------> //
@@ -4367,8 +4236,8 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_
}
static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
*free = 0;
*total = 0;
*free = 1;
*total = 1;
GGML_UNUSED(dev);
}
@@ -5961,36 +5830,6 @@ static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, co
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
}
static void ggml_cl_fill(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
UNUSED(src0);
UNUSED(src1);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offsetd = extrad->offset + dst->view_offs;
float v = 0.0f;
memcpy(&v, ((int32_t *) dst->op_params), sizeof(float));
const int64_t n = ggml_nelements(dst);
cl_kernel kernel = backend_ctx->kernel_fill;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(float), &v));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(float), &n));
size_t local_work_size[1] = { 256 };
size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] };
backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst);
}
static void ggml_cl_clamp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@@ -6544,210 +6383,6 @@ static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
}
static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
UNUSED(src1);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0_abs = extra0->offset + src0->view_offs;
cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
cl_kernel kernel;
if (dst->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_expm1_f32_nd;
} else if (dst->type == GGML_TYPE_F16) {
kernel = backend_ctx->kernel_expm1_f16_nd;
} else {
GGML_ASSERT(false && "Unsupported type for ggml_cl_expm1");
}
GGML_ASSERT(kernel != nullptr);
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const cl_ulong nb00 = src0->nb[0];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const int ne10 = dst->ne[0];
const int ne11 = dst->ne[1];
const int ne12 = dst->ne[2];
const int ne13 = dst->ne[3];
const cl_ulong nb10 = dst->nb[0];
const cl_ulong nb11 = dst->nb[1];
const cl_ulong nb12 = dst->nb[2];
const cl_ulong nb13 = dst->nb[3];
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
size_t global_work_size[3];
if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
return;
}
global_work_size[0] = (size_t)ne10;
global_work_size[1] = (size_t)ne11;
global_work_size[2] = (size_t)ne12;
size_t lws0 = 16, lws1 = 4, lws2 = 1;
if (ne10 < 16) lws0 = ne10;
if (ne11 < 4) lws1 = ne11;
if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
size_t local_work_size[] = {lws0, lws1, lws2};
size_t* local_work_size_ptr = local_work_size;
if (!backend_ctx->non_uniform_workgroups) {
if (global_work_size[0] % local_work_size[0] != 0 ||
global_work_size[1] % local_work_size[1] != 0 ||
global_work_size[2] % local_work_size[2] != 0) {
local_work_size_ptr = NULL;
}
}
if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
}
static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
UNUSED(src1);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0_abs = extra0->offset + src0->view_offs;
cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
cl_kernel kernel;
if (dst->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_softplus_f32_nd;
} else if (dst->type == GGML_TYPE_F16) {
kernel = backend_ctx->kernel_softplus_f16_nd;
} else {
GGML_ASSERT(false && "Unsupported type for ggml_cl_softplus");
}
GGML_ASSERT(kernel != nullptr);
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const cl_ulong nb00 = src0->nb[0];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const int ne10 = dst->ne[0];
const int ne11 = dst->ne[1];
const int ne12 = dst->ne[2];
const int ne13 = dst->ne[3];
const cl_ulong nb10 = dst->nb[0];
const cl_ulong nb11 = dst->nb[1];
const cl_ulong nb12 = dst->nb[2];
const cl_ulong nb13 = dst->nb[3];
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
size_t global_work_size[3];
if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
return;
}
global_work_size[0] = (size_t)ne10;
global_work_size[1] = (size_t)ne11;
global_work_size[2] = (size_t)ne12;
size_t lws0 = 16, lws1 = 4, lws2 = 1;
if (ne10 < 16) lws0 = ne10;
if (ne11 < 4) lws1 = ne11;
if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
size_t local_work_size[] = {lws0, lws1, lws2};
size_t* local_work_size_ptr = local_work_size;
if (!backend_ctx->non_uniform_workgroups) {
if (global_work_size[0] % local_work_size[0] != 0 ||
global_work_size[1] % local_work_size[1] != 0 ||
global_work_size[2] % local_work_size[2] != 0) {
local_work_size_ptr = NULL;
}
}
if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
}
static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@@ -7760,10 +7395,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
region.origin = 0;
// Specify the size of the sub-buffer (divide by 2 for FP16)
region.size = K * (N + padding) * sizeof(float)/2;
backend_ctx->prealloc_act_trans.allocate(context, region.size);
B_d = clCreateSubBuffer(
backend_ctx->prealloc_act_trans.buffer,
backend_ctx->B_d_max,
0,
CL_BUFFER_CREATE_TYPE_REGION,
&region,
@@ -9494,72 +9127,6 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
static void ggml_cl_solve_tri(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(src1);
GGML_ASSERT(src1->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_kernel kernel = backend_ctx->kernel_solve_tri_f32;
GGML_ASSERT(kernel != nullptr);
const int n = src0->ne[0];
const int k = src1->ne[0];
const cl_ulong nb00 = src0->nb[0];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const cl_ulong nb10 = src1->nb[0];
const cl_ulong nb11 = src1->nb[1];
const cl_ulong nb12 = src1->nb[2];
const cl_ulong nb13 = src1->nb[3];
const cl_ulong nb0 = dst->nb[0];
const cl_ulong nb1 = dst->nb[1];
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &n));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &k));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong),&nb10));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong),&nb11));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong),&nb12));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong),&nb13));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb0));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb1));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb2));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb3));
size_t global_work_size[3]= { (size_t)k, (size_t)dst->ne[2], (size_t)dst->ne[3]};
size_t local_work_size[] = {16, 4, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
static void ggml_cl_im2col(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src1);
@@ -9987,18 +9554,6 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_tanh;
break;
case GGML_UNARY_OP_EXPM1:
if (!any_on_device) {
return false;
}
func = ggml_cl_expm1;
break;
case GGML_UNARY_OP_SOFTPLUS:
if (!any_on_device) {
return false;
}
func = ggml_cl_softplus;
break;
default:
return false;
} break;
@@ -10008,12 +9563,6 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_glu;
break;
case GGML_OP_FILL:
if (!any_on_device) {
return false;
}
func = ggml_cl_fill;
break;
case GGML_OP_CLAMP:
if (!any_on_device) {
return false;
@@ -10125,12 +9674,6 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_rope;
break;
case GGML_OP_SOLVE_TRI:
if (!any_on_device) {
return false;
}
func = ggml_cl_solve_tri;
break;
case GGML_OP_IM2COL:
if (!any_on_device) {
return false;

View File

@@ -1,82 +0,0 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// expm1
//------------------------------------------------------------------------------
kernel void kernel_expm1_f32_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
*dst_val_ptr = exp(*src_val_ptr) - 1;
}
}
}
kernel void kernel_expm1_f16_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
*dst_val_ptr = exp(*src_val_ptr) - 1;
}
}
}

Some files were not shown because too many files have changed in this diff Show More