metal: SSM_SCAN performance (#14743)
* feat: Add s_off as a parameter in the args struct This may not be necessary, but it more closely mirrors the CUDA kernel Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * perf: Parallelize mamba2 SSM_SCAN metal kernel over d_state This is a first attempt at optimizing the metal kernel. The changes here are: - Launch the kernel with a thread group of size d_state - Use simd groups and shared memory to do the summation for the y computation When tested with G4 tiny preview, this shows roughly a 3x speedup on prefill and 15% speedup on decode. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix: Update logic to correctly do the multi-layer parallel sum Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix: Correctly size the shared memory bufer and assert expected size relationships Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * refactor: Compute block offsets once rather than once per token Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: Use local variable for state recursion Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: Use a secondary simd_sum instead of a for loop Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: Add assertion and comment about relationship between simd size and num simd groups Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: Parallelize of d_state for mamba-1 Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: Parallel sum in SSM_CONV Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * Revert "feat: Parallel sum in SSM_CONV" After discussion with @compilade, the size of the parallelism here is not worth the cost in complexity or overhead of the parallel for. https://github.com/ggml-org/llama.cpp/pull/14743#discussion_r2223395357 This reverts commit 16bc059660c1c59e566628201c0ca2c20c9f4bc3. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * refactor: Simplify shared memory sizing Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> Co-Authored-By: Georgi Gerganov <ggerganov@gmail.com> --------- Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
@@ -3141,6 +3141,7 @@ static int ggml_metal_encode_node(
|
||||
/*.n_group =*/ n_group,
|
||||
/*.n_seq_tokens =*/ n_seq_tokens,
|
||||
/*.n_seqs =*/ n_seqs,
|
||||
/*.s_off =*/ ggml_nelements(src1) * sizeof(float),
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
@@ -3169,12 +3170,22 @@ static int ggml_metal_encode_node(
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:8];
|
||||
|
||||
// One shared memory bucket for each simd group in the threadgroup
|
||||
// NOTE: Metal kernels require the buffer size to be multiple of 16 bytes
|
||||
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
|
||||
if (d_state >= 32) {
|
||||
GGML_ASSERT((int64_t)(d_state / 32) <= 32);
|
||||
const int64_t shmem_size = 32;
|
||||
GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
|
||||
[encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
|
||||
}
|
||||
|
||||
if (ne30 == 1) {
|
||||
// Mamba-2
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
|
||||
} else {
|
||||
GGML_ASSERT(d_inner == 1);
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
|
||||
Reference in New Issue
Block a user