vulkan: Add bfloat16 support (#12554)

* vulkan: Add bfloat16 support

This adds bfloat16 matrix multiply support based on VK_KHR_shader_bfloat16.
The extension is required for coopmat multiply support, but matrix-vector
multiply trivially promotes bf16 to fp32 and doesn't require the extension.
The copy/get_rows shaders also don't require the extension.

It's probably possible to fall back to non-coopmat and promote to fp32 when
the extension isn't supported, but this change doesn't do that.

The coopmat support also requires a glslc that supports the extension, which
currently requires a custom build.

* vulkan: Support bf16 tensors without the bf16 extension or coopmat support

Compile a variant of the scalar mul_mm shader that will promote the bf16
values to float, and use that when either the bf16 extension or the coopmat
extensions aren't available.

* vulkan: bfloat16 fixes (really works without bfloat16 support now)

* vulkan: fix spirv-val failure and reenable -O
This commit is contained in:
Jeff Bolz
2025-05-01 13:49:39 -05:00
committed by GitHub
parent fc727bcdd5
commit 79f26e9e12
13 changed files with 368 additions and 67 deletions

View File

@@ -18,7 +18,11 @@ void main() {
// fast path for when all four iterations are in-bounds
if (idx + (num_iter-1)*num_threads < p.ne) {
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
#ifndef OPTIMIZATION_ERROR_WORKAROUND
#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + idx]);
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
@@ -31,7 +35,10 @@ void main() {
continue;
}
#ifndef OPTIMIZATION_ERROR_WORKAROUND
#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + idx]);
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];