vulkan: optimize rms_norm, and allow the work to spread across multiple SMs (#15281)

* vulkan: optimize rms_norm, and allow the work to spread across multiple SMs

There are really two parts to this change:
(1) Some optimizations similar to what we have in soft_max, to unroll with
different numbers of iterations.
(2) A fusion optimization where we detect add followed by rms_norm, and make
the add shader atomically accumulate the values^2 into memory. Then the
rms_norm shader can just load that sum. This allows the rms_norm to be
parallelized across multiple workgroups, it just becomes a simple per-element
multiply.

The fusion optimization is currently only applied when the rms_norm is on a
single vector. This previously always ran on a single SM. It could apply more
broadly, but when there are other dimensions the work can already spread across
SMs, and there would be some complexity to tracking multiple atomic sums.

* Change add+rms_norm optimization to write out an array of partial sums
rather than using atomic add, to make it deterministic. The rms_norm
shader fetches a subgroup's worth in parallel and uses subgroupAdd to
add them up.

* complete rebase against fused adds - multi_add shader can also compute partial sums

* fix validation errors

* disable add_rms_fusion for Intel due to possible driver bug

* resolve against #15489, sync after clearing partial sums
This commit is contained in:
Jeff Bolz
2025-08-23 13:16:17 -05:00
committed by GitHub
parent b1afcab804
commit 611f419cff
7 changed files with 379 additions and 50 deletions

View File

@@ -102,9 +102,9 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
struct ggml_backend_vk_context;
#define MAX_PARAMETER_COUNT 8
#define MAX_PARAMETER_COUNT 12
// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 2)
#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3)
struct vk_pipeline_struct {
std::string name;
@@ -381,6 +381,9 @@ struct vk_device_struct {
bool subgroup_shuffle;
bool multi_add;
bool add_rms_fusion;
uint32_t partials_binding_alignment;
bool integer_dot_product;
bool subgroup_size_control;
@@ -460,9 +463,12 @@ struct vk_device_struct {
vk_pipeline pipeline_mul_norepeat[2][2][2];
vk_pipeline pipeline_div[2][2][2];
vk_pipeline pipeline_div_norepeat[2][2][2];
vk_pipeline pipeline_add_rms[2][2][2];
vk_pipeline pipeline_add_rms_norepeat[2][2][2];
// indexed by num_additional_fused_ops == num_adds - 1
vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS];
vk_pipeline pipeline_add_id_f32;
@@ -486,6 +492,8 @@ struct vk_device_struct {
vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32;
vk_pipeline pipeline_rms_norm_mul_f32;
vk_pipeline pipeline_rms_norm_partials_f32;
vk_pipeline pipeline_rms_norm_mul_partials_f32;
vk_pipeline pipeline_rms_norm_back_f32;
vk_pipeline pipeline_l2_norm_f32;
@@ -823,8 +831,13 @@ struct vk_op_multi_add_push_constants {
uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;
// strides for srcs+dst
uint32_t nb[8][4];
uint32_t nb[MAX_PARAMETER_COUNT][4];
uint32_t rms_partials;
};
// update multi_add.comp if this changes
static_assert(MAX_PARAMETER_COUNT == 12);
static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
struct vk_op_add_id_push_constants {
uint32_t ne0;
@@ -1208,6 +1221,12 @@ class vk_perf_logger {
timings[name].push_back(time);
return;
}
if (node->op == GGML_OP_RMS_NORM) {
std::string name = ggml_op_name(node->op);
name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
timings[name].push_back(time);
return;
}
timings[ggml_op_name(node->op)].push_back(time);
}
private:
@@ -1222,10 +1241,13 @@ struct ggml_backend_vk_context {
size_t semaphore_idx, event_idx;
ggml_vk_garbage_collector gc;
size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset;
vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials;
vk::Fence fence, almost_ready_fence;
bool almost_ready_fence_pending {};
// Set before op_add and unset after op_rms_norm to indicate that the add should
// write partial sums to accumulate the square of the vector components
bool do_add_rms_partials;
// Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.
vk_pipeline_struct * prealloc_y_last_pipeline_used {};
@@ -2987,8 +3009,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -3058,25 +3084,28 @@ static void ggml_vk_load_shaders(vk_device& device) {
};
bool rte = device->float_controls_rte_fp16;
#define CREATE_BINARY(name, namemod, spec) \
#define CREATE_BINARY(name, namemod, spec, bindings) \
for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
"main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
"main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
CREATE_BINARY(add, , {0})
CREATE_BINARY(add, _norepeat, {1})
CREATE_BINARY(sub, , {0})
CREATE_BINARY(sub, _norepeat, {1})
CREATE_BINARY(mul, , {0})
CREATE_BINARY(mul, _norepeat, {1})
CREATE_BINARY(div, , {0})
CREATE_BINARY(div, _norepeat, {1})
CREATE_BINARY(add, , {0}, 4)
CREATE_BINARY(add, _norepeat, {1}, 4)
CREATE_BINARY(sub, , {0}, 3)
CREATE_BINARY(sub, _norepeat, {1}, 3)
CREATE_BINARY(mul, , {0}, 3)
CREATE_BINARY(mul, _norepeat, {1}, 3)
CREATE_BINARY(div, , {0}, 3)
CREATE_BINARY(div, _norepeat, {1}, 3)
CREATE_BINARY(add_rms, , {0}, 4)
CREATE_BINARY(add_rms, _norepeat, {1}, 4)
#undef CREATE_BINARY
if (device->multi_add) {
for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {
ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
}
}
@@ -3944,6 +3973,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
device->add_rms_fusion = !device->disable_fusion &&
device->subgroup_add &&
device->vendor_id != VK_VENDOR_ID_INTEL;
device->partials_binding_alignment =
std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);
return device;
}
@@ -7080,7 +7115,7 @@ static std::array<uint32_t, 3> ggml_vk_get_conv_elements(const ggml_tensor *dst)
return elements;
}
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
switch (op) {
case GGML_OP_GET_ROWS:
GGML_ASSERT(src1->type == GGML_TYPE_I32);
@@ -7109,10 +7144,19 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
case GGML_OP_ADD:
{
if (ctx->num_additional_fused_ops > 0) {
return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
if (ctx->do_add_rms_partials) {
return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops];
} else {
return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
}
}
if (ctx->do_add_rms_partials) {
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
} else {
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
}
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
}
case GGML_OP_SUB:
{
@@ -7235,7 +7279,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_RMS_NORM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
if (ctx->do_add_rms_partials) {
return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32;
} else {
return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
}
}
return nullptr;
case GGML_OP_RMS_NORM_BACK:
@@ -7748,7 +7796,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
} break;
case GGML_OP_RMS_NORM:
elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
if (ctx->do_add_rms_partials) {
// Run one element per thread, 128 threads per workgroup
elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
} else {
elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
}
break;
case GGML_OP_SUM:
@@ -7897,7 +7950,16 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
}
if (op == GGML_OP_GLU) {
if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
vk_buffer d_A = ctx->do_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X;
size_t a_buf_offset = ctx->do_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0;
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{ vk_subbuffer{ d_X, x_buf_offset, x_sz },
vk_subbuffer{ d_Y, y_buf_offset, y_sz },
vk_subbuffer{ d_D, d_buf_offset, d_sz },
vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE },
}, pc, elements);
} else if (op == GGML_OP_GLU) {
// Empty src1 is possible in glu, but the shader needs a buffer
vk_subbuffer subbuf_y;
if (use_src1) {
@@ -7998,7 +8060,7 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
const ggml_tensor *tensors[MAX_PARAMETER_COUNT];
uint32_t num_srcs = ctx->num_additional_fused_ops + 2;
uint32_t num_tensors = num_srcs + 1;
GGML_ASSERT(num_tensors <= MAX_PARAMETER_COUNT);
GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT);
tensors[0] = first_node->src[0];
tensors[1] = first_node->src[1];
@@ -8025,8 +8087,9 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float);
pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float);
}
pc.rms_partials = ctx->do_add_rms_partials;
vk_pipeline pipeline = ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, tensors[0], tensors[1], nullptr, dst, dst->op);
if (pipeline == nullptr) {
std::cerr << "ggml_vulkan: Error: Missing multi_add";
@@ -8064,6 +8127,10 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
buf[i] = buf[0];
offset[i] = 0;
}
if (ctx->do_add_rms_partials) {
buf[num_tensors] = ctx->prealloc_add_rms_partials;
offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset;
}
std::array<uint32_t, 3> elements;
@@ -8076,6 +8143,7 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
elements = { ne, 1, 1 };
}
static_assert(MAX_PARAMETER_COUNT == 12);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE },
@@ -8086,6 +8154,10 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE },
vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE },
vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE },
vk_subbuffer{ buf[8], offset[8], VK_WHOLE_SIZE },
vk_subbuffer{ buf[9], offset[9], VK_WHOLE_SIZE },
vk_subbuffer{ buf[10], offset[10], VK_WHOLE_SIZE },
vk_subbuffer{ buf[11], offset[11], VK_WHOLE_SIZE },
}, pc, elements);
}
@@ -8100,7 +8172,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f, 0,
0.0f, 0.0f, ctx->do_add_rms_partials,
}, dryrun);
}
@@ -8558,19 +8630,39 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
}
static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
const uint32_t ne = (uint32_t)node->ne[0];
const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0];
const uint32_t num_partials = CEIL_DIV(ne, denom);
return num_partials;
}
static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node);
const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment);
return num_bytes;
}
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
op_params[0], 0.0f, 0,
op_params[0], 0.0f, (int32_t)param3,
}, dryrun);
if (ctx->do_add_rms_partials) {
ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
ctx->do_add_rms_partials = false;
}
}
static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9848,6 +9940,14 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
}
ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
}
if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) {
VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_add_rms_partials << ")");
// Resize buffer
if (ctx->prealloc_add_rms_partials != nullptr) {
ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
}
ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials);
}
}
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
@@ -9904,10 +10004,23 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
return false;
}
break;
case GGML_OP_ADD:
{
int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
if (next_node_idx < cgraph->n_nodes &&
cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
ctx->device->add_rms_fusion) {
if (dryrun) {
ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
}
ctx->do_add_rms_partials = true;
}
} break;
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_GET_ROWS:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ACC:
case GGML_OP_SUB:
@@ -10029,6 +10142,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
// do the only thing needed for the dryrun.
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
if (node->op == GGML_OP_RMS_NORM) {
ctx->do_add_rms_partials = false;
}
return false;
}
default:
@@ -11098,6 +11214,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
}
ctx->prealloc_size_add_rms_partials = 0;
ctx->prealloc_size_add_rms_partials_offset = 0;
ctx->do_add_rms_partials = false;
uint64_t total_mat_mul_bytes = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (!ctx->device->disable_fusion) {
@@ -11166,6 +11286,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->prealloc_y_last_pipeline_used = nullptr;
ctx->prealloc_y_last_tensor_used = nullptr;
if (ctx->prealloc_size_add_rms_partials) {
if (ctx->compute_ctx.expired()) {
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->compute_ctx = compute_ctx;
ggml_vk_ctx_begin(ctx->device, compute_ctx);
} else {
compute_ctx = ctx->compute_ctx.lock();
}
// initialize partial sums to zero.
ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
ggml_vk_sync_buffers(ctx, compute_ctx);
}
// Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
// Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
// (and scaled down based on model size, so smaller models submit earlier).