ggml : implement GEGLU_ERF and GEGLU_QUICK ops (#14445)
This commit is contained in:
@@ -456,6 +456,8 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_geglu[2];
|
||||
vk_pipeline pipeline_reglu[2];
|
||||
vk_pipeline pipeline_swiglu[2];
|
||||
vk_pipeline pipeline_geglu_erf[2];
|
||||
vk_pipeline pipeline_geglu_quick[2];
|
||||
|
||||
vk_pipeline pipeline_leaky_relu_f32;
|
||||
vk_pipeline pipeline_silu_back_f32;
|
||||
@@ -2821,6 +2823,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_GLU(geglu)
|
||||
CREATE_GLU(reglu)
|
||||
CREATE_GLU(swiglu)
|
||||
CREATE_GLU(geglu_erf)
|
||||
CREATE_GLU(geglu_quick)
|
||||
#undef CREATE_GLU
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -6575,6 +6579,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -8919,6 +8927,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
case GGML_GLU_OP_REGLU:
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
@@ -9166,6 +9176,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
case GGML_GLU_OP_REGLU:
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
|
||||
break;
|
||||
default:
|
||||
@@ -9384,6 +9396,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
case GGML_GLU_OP_REGLU:
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
buf = tensor->buffer;
|
||||
break;
|
||||
default:
|
||||
@@ -10194,6 +10208,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
case GGML_GLU_OP_REGLU:
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
return ggml_is_contiguous(op->src[0]) &&
|
||||
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
||||
|
||||
27
ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp
Normal file
27
ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp
Normal file
@@ -0,0 +1,27 @@
|
||||
#version 450
|
||||
|
||||
#include "glu_head.comp"
|
||||
|
||||
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
|
||||
// ref: https://www.johndcook.com/blog/python_erf/
|
||||
const float p_erf = 0.3275911f;
|
||||
const float a1_erf = 0.254829592f;
|
||||
const float a2_erf = -0.284496736f;
|
||||
const float a3_erf = 1.421413741f;
|
||||
const float a4_erf = -1.453152027f;
|
||||
const float a5_erf = 1.061405429f;
|
||||
|
||||
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
||||
|
||||
float op(float a, float b) {
|
||||
const float a_div_sqr2 = a * SQRT_2_INV;
|
||||
const float sign_x = sign(a_div_sqr2);
|
||||
const float x = abs(a_div_sqr2);
|
||||
const float t = 1.0f / (1.0f + p_erf * x);
|
||||
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
||||
const float erf_approx = sign_x * y;
|
||||
|
||||
return 0.5f * a * (1.0f + erf_approx) * b;
|
||||
}
|
||||
|
||||
#include "glu_main.comp"
|
||||
11
ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp
Normal file
11
ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp
Normal file
@@ -0,0 +1,11 @@
|
||||
#version 450
|
||||
|
||||
#include "glu_head.comp"
|
||||
|
||||
const float GELU_QUICK_COEF = -1.702f;
|
||||
|
||||
float op(float a, float b) {
|
||||
return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;
|
||||
}
|
||||
|
||||
#include "glu_main.comp"
|
||||
@@ -593,6 +593,10 @@ void process_shaders() {
|
||||
string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
Reference in New Issue
Block a user