vulkan: fix rms_norm_mul to handle broadcasting dim0 (#14817)

This commit is contained in:
Jeff Bolz
2025-07-22 10:35:21 -05:00
committed by GitHub
parent d4d1522b20
commit 84712b6043
2 changed files with 9 additions and 3 deletions

View File

@@ -10248,7 +10248,7 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
}
// if rms_norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] &&
mul->src[0]->ne[1] != rms_norm->ne[1]) {
!ggml_are_same_shape(mul->src[0], rms_norm)) {
return false;
}
// rms_norm shader assumes contiguous rows