llama : initial Mamba-2 support (#9126)
* llama : initial Mamba-2 support * ggml : SIMD ggml_ssm_scan for Mamba-2 * ggml : improve ggml_mul speed when masking recurrent states * llama : support running Mamba-Codestral-7B-v0.1 * llama : fix Mamba-2 conv state saving * ggml : make the ggml_mul fast broadcast path more consistently formatted * llama : remove unused variable * llama : add missing break * convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present The tokenzier.json of Mamba-Codestral-7B-v0.1 otherwise requires workarounds to work correctly. * llama : avoid redundant state copy for Mamba 1 and 2 * metal : attempt to adapt SSM_SCAN for Mamba-2 * metal : fix SSM_SCAN pipeline scope * metal : use log and exp instead of log1pf and expf in SSM_SCAN * metal : remove unused arguments for SSM_SCAN The max index is 31, so trimming the arguments is necessary. * metal : add back n_seqs to SSM_SCAN args Whoops, this is needed for the offset in the concatenated output. * metal : fix SSM_SCAN state head offset * metal : fix wrong number of tokens per sequence in SSM_SCAN * ggml : remove unused fast broadcast path in GGML_MUL This was initially added because states were masked with ggml_mul, but this is no longer done and so this "optimisation" is no longer necessary, or at least not worth the additional code complexity. * ggml : avoid multiply by D in GGML_OP_SSM_SCAN This makes the weight buft detection in src/llama.cpp simpler. * convert : transpose Mamba-2 A, D and reshape SSM_NORM This breaks existing conversions of Mamba-2 models to avoid some reshapes. Not sure if it's a good idea, but it makes the graph slightly cleaner. * llama : more appropriate SSM_SCAN and SSM_CONV buft support checks * convert : fix flake8 lint * metal : fix confusion between ; and , * metal : add missing args for nb references in ssm_scan_f32_group * metal : single-user mamba2 inference works * kv-cache : remove const_cast when setting inputs for s_copy And also fix multi-user inference for recurrent models by using cell_id instead of i as the kv cell index when populating s_copy. * convert : avoid AutoConfig for Mamba and Mamba2 hparams * kv-cache : allow context shift for recurrent models * graph : fix recurrent state copies when avoiding copies Works, but using lambda functions might not be that clean. * ggml : fix mamba2 ssm scan when compiled with SVE * ggml-cpu : reorder SVE FMA for consistency with other SIMD arches * cuda : implement ssm scan for Mamba2 There is still room for improvement, but it works! * cuda : adapt Mamba1 ssm scan to shape changes from Mamba2 * mamba : fix mismatched new and delete size for llm_build_mamba Subclasses of llm_graph_context cannot have extra fields, because the called destructor is not the one from the subclass. This otherwise would cause problems when runnning Mamba-(1|2) inference when compiled -DGGML_SANITIZE_ADDRESS=ON * cuda : graceful fallback for Mamba-1 models with weird embd size
This commit is contained in:
@@ -217,6 +217,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
||||
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
||||
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
|
||||
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
||||
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
||||
@@ -1196,6 +1197,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
||||
@@ -2809,71 +2811,91 @@ static bool ggml_metal_encode_node(
|
||||
struct ggml_tensor * src3 = node->src[3];
|
||||
struct ggml_tensor * src4 = node->src[4];
|
||||
struct ggml_tensor * src5 = node->src[5];
|
||||
struct ggml_tensor * src6 = node->src[6];
|
||||
|
||||
GGML_ASSERT(src3);
|
||||
GGML_ASSERT(src4);
|
||||
GGML_ASSERT(src5);
|
||||
GGML_ASSERT(src6);
|
||||
|
||||
size_t offs_src3 = 0;
|
||||
size_t offs_src4 = 0;
|
||||
size_t offs_src5 = 0;
|
||||
size_t offs_src6 = 0;
|
||||
|
||||
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
||||
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
|
||||
id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
|
||||
id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
|
||||
|
||||
const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
|
||||
const int64_t ne30 = src3->ne[0];
|
||||
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
|
||||
|
||||
const uint64_t nb30 = src3->nb[0];
|
||||
const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
|
||||
const uint64_t nb31 = src3->nb[1];
|
||||
|
||||
const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
|
||||
const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
|
||||
const int64_t ne41 = src4->ne[1];
|
||||
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
|
||||
const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
|
||||
|
||||
const uint64_t nb40 = src4->nb[0];
|
||||
const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
|
||||
const uint64_t nb41 = src4->nb[1];
|
||||
const uint64_t nb42 = src4->nb[2];
|
||||
const uint64_t nb43 = src4->nb[3];
|
||||
|
||||
const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
|
||||
const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
|
||||
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
|
||||
const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
|
||||
|
||||
const uint64_t nb50 = src5->nb[0];
|
||||
const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
|
||||
const uint64_t nb51 = src5->nb[1];
|
||||
const uint64_t nb52 = src5->nb[2];
|
||||
const uint64_t nb53 = src5->nb[3];
|
||||
|
||||
const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
|
||||
|
||||
const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
|
||||
|
||||
const int64_t d_state = ne00;
|
||||
const int64_t d_inner = ne01;
|
||||
const int64_t n_seq_tokens = ne11;
|
||||
const int64_t n_seqs = ne02;
|
||||
const int64_t n_head = ne02;
|
||||
const int64_t n_group = ne41;
|
||||
const int64_t n_seq_tokens = ne12;
|
||||
const int64_t n_seqs = ne13;
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
if (ne30 == 1) {
|
||||
// Mamba-2
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
|
||||
}
|
||||
|
||||
ggml_metal_kargs_ssm_scan args = {
|
||||
/*.d_state =*/ d_state,
|
||||
/*.d_inner =*/ d_inner,
|
||||
/*.d_state =*/ d_state,
|
||||
/*.d_inner =*/ d_inner,
|
||||
/*.n_head =*/ n_head,
|
||||
/*.n_group =*/ n_group,
|
||||
/*.n_seq_tokens =*/ n_seq_tokens,
|
||||
/*.n_seqs =*/ n_seqs,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.nb20 =*/ nb20,
|
||||
/*.nb21 =*/ nb21,
|
||||
/*.nb22 =*/ nb22,
|
||||
/*.nb30 =*/ nb30,
|
||||
/*.nb31 =*/ nb31,
|
||||
/*.nb40 =*/ nb40,
|
||||
/*.nb41 =*/ nb41,
|
||||
/*.nb42 =*/ nb42,
|
||||
/*.nb50 =*/ nb50,
|
||||
/*.nb51 =*/ nb51,
|
||||
/*.nb52 =*/ nb52,
|
||||
/*.n_seqs =*/ n_seqs,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.nb21 =*/ nb21,
|
||||
/*.nb22 =*/ nb22,
|
||||
/*.nb31 =*/ nb31,
|
||||
/*.nb41 =*/ nb41,
|
||||
/*.nb42 =*/ nb42,
|
||||
/*.nb43 =*/ nb43,
|
||||
/*.nb51 =*/ nb51,
|
||||
/*.nb52 =*/ nb52,
|
||||
/*.nb53 =*/ nb53,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
@@ -2883,10 +2905,17 @@ static bool ggml_metal_encode_node(
|
||||
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
||||
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
||||
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:7];
|
||||
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:8];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
if (ne30 == 1) {
|
||||
// Mamba-2
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} else {
|
||||
GGML_ASSERT(d_inner == 1);
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user