vulkan: 64-bit im2col (#16135)

* vulkan: 64-bit im2col

Add variants of the im2col shaders that use buffer_device_address/buffer_reference,
and use 64-bit address calculations. This is needed for large convolutions used in
stable-diffusion.cpp.

* fix validation error for large im2col
This commit is contained in:
Jeff Bolz
2025-09-28 01:38:37 -05:00
committed by GitHub
parent 6a2c6145a0
commit d8359f5fde
6 changed files with 117 additions and 26 deletions

View File

@@ -408,6 +408,8 @@ struct vk_device_struct {
bool subgroup_ballot;
bool subgroup_clustered;
bool multi_add;
bool shader_int64;
bool buffer_device_address;
bool add_rms_fusion;
uint32_t partials_binding_alignment;
@@ -655,6 +657,7 @@ struct vk_buffer_struct {
vk::MemoryPropertyFlags memory_property_flags;
void * ptr;
size_t size = 0;
vk::DeviceAddress bda_addr {};
vk_device device;
@@ -987,6 +990,7 @@ struct vk_op_argsort_push_constants {
};
struct vk_op_im2col_push_constants {
uint64_t dst_addr;
uint32_t batch_offset; uint32_t offset_delta;
uint32_t IC;
uint32_t IW; uint32_t IH;
@@ -1000,6 +1004,7 @@ struct vk_op_im2col_push_constants {
};
struct vk_op_im2col_3d_push_constants {
uint64_t dst_addr;
uint32_t nb10;
uint32_t nb11;
uint32_t nb12;
@@ -2012,10 +2017,17 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
return buf;
}
vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
vk::MemoryAllocateFlags mem_flags {};
if (device->buffer_device_address) {
usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
}
vk::BufferCreateInfo buffer_create_info{
vk::BufferCreateFlags(),
size,
vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst,
usage_flags,
vk::SharingMode::eExclusive,
0,
nullptr,
@@ -2027,6 +2039,8 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
const auto & req_flags = *it;
@@ -2038,7 +2052,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
buf->memory_property_flags = req_flags;
try {
buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info });
break;
} catch (const vk::SystemError& e) {
// loop and retry
@@ -2066,6 +2080,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
buf->device = device;
buf->size = size;
if (device->buffer_device_address) {
const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
buf->bda_addr = device->device.getBufferAddress(addressInfo);
}
#ifdef GGML_VULKAN_MEMORY_DEBUG
device->memory_logger->log_allocation(buf, size);
#endif
@@ -3532,14 +3551,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
#define IM2COL(bda) \
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
if (device->float_controls_rte_fp16) { \
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
} else { \
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
}
if (device->shader_int64 && device->buffer_device_address) {
IM2COL(_bda)
} else {
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
IM2COL()
}
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
@@ -4017,6 +4042,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->vendor_id != VK_VENDOR_ID_INTEL &&
getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
device->shader_int64 = device_features2.features.shaderInt64;
device->buffer_device_address = vk12_features.bufferDeviceAddress;
if (device->subgroup_size_control) {
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
@@ -8635,6 +8663,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
// buffer device address path doesn't use dst buffer
d_sz = 1;
}
// im2col uses only src1 and dst buffers
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_COUNT_EQUAL) {
@@ -9486,7 +9518,13 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t pelements = OW * KW * KH;
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
dst_addr,
batch_offset, offset_delta,
IC, IW, IH, OW, OH, KW, KH,
pelements,
@@ -9522,8 +9560,14 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx,
const int64_t OH = ne2;
const int64_t OW = ne1;
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
vk_op_im2col_3d_push_constants pc {};
pc.dst_addr = dst_addr;
pc.nb10 = nb10 / ggml_type_size(src1->type);
pc.nb11 = nb11 / ggml_type_size(src1->type);
pc.nb12 = nb12 / ggml_type_size(src1->type);