llama : initial Mamba-2 support (#9126)
* llama : initial Mamba-2 support
* ggml : SIMD ggml_ssm_scan for Mamba-2
* ggml : improve ggml_mul speed when masking recurrent states
* llama : support running Mamba-Codestral-7B-v0.1
* llama : fix Mamba-2 conv state saving
* ggml : make the ggml_mul fast broadcast path more consistently formatted
* llama : remove unused variable
* llama : add missing break
* convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present

  The tokenizer.json of Mamba-Codestral-7B-v0.1 otherwise requires workarounds to work correctly.

* llama : avoid redundant state copy for Mamba 1 and 2
* metal : attempt to adapt SSM_SCAN for Mamba-2
* metal : fix SSM_SCAN pipeline scope
* metal : use log and exp instead of log1pf and expf in SSM_SCAN
* metal : remove unused arguments for SSM_SCAN

  The max index is 31, so trimming the arguments is necessary.

* metal : add back n_seqs to SSM_SCAN args

  Whoops, this is needed for the offset in the concatenated output.

* metal : fix SSM_SCAN state head offset
* metal : fix wrong number of tokens per sequence in SSM_SCAN
* ggml : remove unused fast broadcast path in GGML_MUL

  This was initially added because states were masked with ggml_mul, but this is no longer done, so this "optimisation" is no longer necessary, or at least not worth the additional code complexity.

* ggml : avoid multiply by D in GGML_OP_SSM_SCAN

  This makes the weight buft detection in src/llama.cpp simpler.

* convert : transpose Mamba-2 A, D and reshape SSM_NORM

  This breaks existing conversions of Mamba-2 models to avoid some reshapes. Not sure if it's a good idea, but it makes the graph slightly cleaner.

* llama : more appropriate SSM_SCAN and SSM_CONV buft support checks
* convert : fix flake8 lint
* metal : fix confusion between ; and ,
* metal : add missing args for nb references in ssm_scan_f32_group
* metal : single-user mamba2 inference works
* kv-cache : remove const_cast when setting inputs for s_copy

  This also fixes multi-user inference for recurrent models, by using cell_id instead of i as the kv cell index when populating s_copy.

* convert : avoid AutoConfig for Mamba and Mamba2 hparams
* kv-cache : allow context shift for recurrent models
* graph : fix recurrent state copies when avoiding copies

  Works, but using lambda functions might not be that clean.

* ggml : fix mamba2 ssm scan when compiled with SVE
* ggml-cpu : reorder SVE FMA for consistency with other SIMD arches
* cuda : implement ssm scan for Mamba2

  There is still room for improvement, but it works!

* cuda : adapt Mamba1 ssm scan to shape changes from Mamba2
* mamba : fix mismatched new and delete size for llm_build_mamba

  Subclasses of llm_graph_context cannot have extra fields, because the destructor that gets called is not the one from the subclass. This would otherwise cause problems when running Mamba-(1|2) inference when compiled with -DGGML_SANITIZE_ADDRESS=ON.

* cuda : graceful fallback for Mamba-1 models with weird embd size
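The "mismatched new and delete size" fix above hinges on a subtle C++ rule that is easy to trip over. The following is a minimal sketch of the pitfall, with made-up type names standing in for llm_graph_context and its subclasses; it is illustrative only, not llama.cpp code:

    // base and derived are hypothetical stand-ins, not from the codebase.
    struct base {
        int x = 0;
        ~base() {} // non-virtual, as the commit notes for llm_graph_context
    };

    struct derived : base {
        int extra_field = 0; // extra field => sizeof(derived) > sizeof(base)
    };

    int main() {
        base * b = new derived(); // allocates sizeof(derived) bytes
        // Undefined behavior: with a non-virtual ~base(), only ~base() runs,
        // and with sized deallocation operator delete is handed sizeof(base)
        // rather than sizeof(derived). AddressSanitizer reports this as a
        // new/delete size mismatch.
        delete b;
        return 0;
    }

The two usual remedies are making the base destructor virtual or keeping subclasses free of extra fields so the sizes agree; per the note above, the commit takes the second route for llm_graph_context.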
@@ -513,26 +513,25 @@ typedef struct {
 typedef struct {
     int64_t  d_state;
     int64_t  d_inner;
+    int64_t  n_head;
+    int64_t  n_group;
     int64_t  n_seq_tokens;
     int64_t  n_seqs;
-    uint64_t nb00;
     uint64_t nb01;
     uint64_t nb02;
-    uint64_t nb10;
+    uint64_t nb03;
     uint64_t nb11;
     uint64_t nb12;
     uint64_t nb13;
-    uint64_t nb20;
     uint64_t nb21;
     uint64_t nb22;
-    uint64_t nb30;
     uint64_t nb31;
-    uint64_t nb40;
     uint64_t nb41;
     uint64_t nb42;
-    uint64_t nb50;
+    uint64_t nb43;
     uint64_t nb51;
     uint64_t nb52;
+    uint64_t nb53;
 } ggml_metal_kargs_ssm_scan;
 
 typedef struct {

@@ -217,6 +217,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_NORM,
     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
+    GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
     GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
     GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,

@@ -1196,6 +1197,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);

@@ -2809,71 +2811,91 @@ static bool ggml_metal_encode_node(
                 struct ggml_tensor * src3 = node->src[3];
                 struct ggml_tensor * src4 = node->src[4];
                 struct ggml_tensor * src5 = node->src[5];
+                struct ggml_tensor * src6 = node->src[6];
 
                 GGML_ASSERT(src3);
                 GGML_ASSERT(src4);
                 GGML_ASSERT(src5);
+                GGML_ASSERT(src6);
 
                 size_t offs_src3 = 0;
                 size_t offs_src4 = 0;
                 size_t offs_src5 = 0;
+                size_t offs_src6 = 0;
 
                 id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
                 id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
                 id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
+                id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
 
-                const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
+                const int64_t ne30 = src3->ne[0];
                 const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
 
-                const uint64_t nb30 = src3->nb[0];
+                const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
                 const uint64_t nb31 = src3->nb[1];
 
                 const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
-                const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
+                const int64_t ne41 = src4->ne[1];
                 const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
+                const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
 
-                const uint64_t nb40 = src4->nb[0];
+                const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
                 const uint64_t nb41 = src4->nb[1];
                 const uint64_t nb42 = src4->nb[2];
+                const uint64_t nb43 = src4->nb[3];
 
                 const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
                 const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
                 const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
+                const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
 
-                const uint64_t nb50 = src5->nb[0];
+                const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
                 const uint64_t nb51 = src5->nb[1];
                 const uint64_t nb52 = src5->nb[2];
+                const uint64_t nb53 = src5->nb[3];
+
+                const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
+
+                const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
 
                 const int64_t d_state = ne00;
                 const int64_t d_inner = ne01;
-                const int64_t n_seq_tokens = ne11;
-                const int64_t n_seqs = ne02;
+                const int64_t n_head = ne02;
+                const int64_t n_group = ne41;
+                const int64_t n_seq_tokens = ne12;
+                const int64_t n_seqs = ne13;
 
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+                id<MTLComputePipelineState> pipeline = nil;
+
+                if (ne30 == 1) {
+                    // Mamba-2
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
+                } else {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+                }
 
                 ggml_metal_kargs_ssm_scan args = {
-                    /*.d_state      =*/ d_state,
-                    /*.d_inner      =*/ d_inner,
+                    /*.d_state      =*/ d_state,
+                    /*.d_inner      =*/ d_inner,
+                    /*.n_head       =*/ n_head,
+                    /*.n_group      =*/ n_group,
                     /*.n_seq_tokens =*/ n_seq_tokens,
-                    /*.n_seqs       =*/ n_seqs,
-                    /*.nb00         =*/ nb00,
-                    /*.nb01         =*/ nb01,
-                    /*.nb02         =*/ nb02,
-                    /*.nb10         =*/ nb10,
-                    /*.nb11         =*/ nb11,
-                    /*.nb12         =*/ nb12,
-                    /*.nb13         =*/ nb13,
-                    /*.nb20         =*/ nb20,
-                    /*.nb21         =*/ nb21,
-                    /*.nb22         =*/ nb22,
-                    /*.nb30         =*/ nb30,
-                    /*.nb31         =*/ nb31,
-                    /*.nb40         =*/ nb40,
-                    /*.nb41         =*/ nb41,
-                    /*.nb42         =*/ nb42,
-                    /*.nb50         =*/ nb50,
-                    /*.nb51         =*/ nb51,
-                    /*.nb52         =*/ nb52,
+                    /*.n_seqs       =*/ n_seqs,
+                    /*.nb01         =*/ nb01,
+                    /*.nb02         =*/ nb02,
+                    /*.nb03         =*/ nb03,
+                    /*.nb11         =*/ nb11,
+                    /*.nb12         =*/ nb12,
+                    /*.nb13         =*/ nb13,
+                    /*.nb21         =*/ nb21,
+                    /*.nb22         =*/ nb22,
+                    /*.nb31         =*/ nb31,
+                    /*.nb41         =*/ nb41,
+                    /*.nb42         =*/ nb42,
+                    /*.nb43         =*/ nb43,
+                    /*.nb51         =*/ nb51,
+                    /*.nb52         =*/ nb52,
+                    /*.nb53         =*/ nb53,
                 };
 
                 [encoder setComputePipelineState:pipeline];

@@ -2883,10 +2905,17 @@ static bool ggml_metal_encode_node(
                 [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
                 [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
                 [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
-                [encoder setBytes:&args length:sizeof(args) atIndex:7];
+                [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
+                [encoder setBytes:&args length:sizeof(args) atIndex:8];
 
-                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                if (ne30 == 1) {
+                    // Mamba-2
+                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } else {
+                    GGML_ASSERT(d_inner == 1);
+                    [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                }
             } break;
         case GGML_OP_RWKV_WKV6:
             {
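The remaining hunks are the Metal kernels themselves. Both implement the same per-token recurrence; the following minimal C++ sketch of one scan step may help orientation. It is hypothetical (the function name and signature are made up, not from this commit), uses a scalar A per head as in the Mamba-2 path, and leaves the D*x skip term outside the op, as the "avoid multiply by D in GGML_OP_SSM_SCAN" note above describes:

    #include <cmath>

    // One token's scan step for a single (row, head) pair; h has d_state
    // elements and persists across tokens. Illustrative helper only.
    static float ssm_scan_step(float * h, const float * B, const float * C,
                               float A, float dt, float x, int d_state) {
        // softplus(dt), clamped: for dt > 20, log(1 + exp(dt)) ~= dt and
        // exp(dt) would overflow; this mirrors the kernels' `dt <= 20.0f` check
        const float dt_sp = dt <= 20.0f ? std::log(1.0f + std::exp(dt)) : dt;
        const float x_dt  = x * dt_sp;
        const float dA    = std::exp(dt_sp * A); // per-head decay (Mamba-2)

        float y = 0.0f;
        for (int i = 0; i < d_state; ++i) {
            h[i] = h[i] * dA + B[i] * x_dt; // decay old state, mix in new input
            y   += h[i] * C[i];             // contract state with C
        }
        return y; // the D*x skip connection is added outside GGML_OP_SSM_SCAN
    }

    int main() {
        float h[4] = {0.0f, 0.0f, 0.0f, 0.0f};
        const float B[4] = {0.1f, 0.2f, 0.3f, 0.4f};
        const float C[4] = {1.0f, 1.0f, 1.0f, 1.0f};
        const float y = ssm_scan_step(h, B, C, /*A=*/-1.0f, /*dt=*/0.5f, /*x=*/1.0f, 4);
        (void)y;
        return 0;
    }

The Mamba-1 kernel differs mainly in that A is a full {d_state, n_head} matrix, hence the exp(dt_soft_plus * A[i0]) inside its inner loop. The dispatch above gives each threadgroup one (row, head, sequence) triple, MTLSizeMake(d_inner, n_head, n_seqs) in the Mamba-2 case, while tokens stay sequential inside the kernel because the state carries across them. With threadsPerThreadgroup at (1, 1, 1) each threadgroup is a single thread, which the TODO in the grouped kernel below flags as an optimization target.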
@@ -1596,7 +1596,7 @@ kernel void kernel_ssm_conv_f32(
     x[0] = sumf;
 }
 
-// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
 kernel void kernel_ssm_scan_f32(
         device const void * src0,
         device const void * src1,

@@ -1604,46 +1604,119 @@ kernel void kernel_ssm_scan_f32(
         device const void * src3,
         device const void * src4,
         device const void * src5,
+        device const void * src6,
         device float * dst,
         constant ggml_metal_kargs_ssm_scan & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {
-    const int64_t ir = tgpig.x;
-    const int64_t i3 = tgpig.y;
+    const int64_t i1 = 0;
+    const int64_t ir = tgpig.x; // current head
+    const int64_t i3 = tgpig.y; // current seq
+
+    const uint64_t nb00 = sizeof(float);
+    const uint64_t nb10 = sizeof(float);
+    const uint64_t nb20 = sizeof(float);
 
     const int64_t nc = args.d_state;
-    // const int64_t nr = args.d_inner;
+    const int64_t nr = args.d_inner;
+    const int64_t nh = args.n_head;
+    const int64_t ng = args.n_group;
     const int64_t n_t = args.n_seq_tokens;
     // const int64_t n_s = args.n_seqs;
 
+    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+
+    device const int32_t * ids = (device const int32_t *) src6;
+
+    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
+
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
-        device const float * x  = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31);
-        device const float * B  = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
-        device const float * C  = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
-        device       float * y  = (device       float *) ((device       char *) dst  + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
-        device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb01 + i3*args.nb02 + args.nb13);
-
-        if (i2 > 0) {
-            s0 = s;
-        }
-
-        float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
-        float x_dt = x[0] * dt_soft_plus;
+        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
+        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+
+        // i1 == 0
+        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+        const float x_dt = x[0] * dt_soft_plus;
         float sumf = 0.0f;
 
         for (int64_t i0 = 0; i0 < nc; ++i0) {
-            int64_t i = i0;
-            float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+            const int64_t i = i0 + i1*nc;
+            const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
             sumf += state * C[i0];
             s[i] = state;
         }
 
         y[0] = sumf;
+
+        // recurse
+        s0 = s;
     }
 }
 
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
+// TODO: optimize (e.g. by parallelizing over d_state)
+kernel void kernel_ssm_scan_f32_group(
+        device const void * src0,
+        device const void * src1,
+        device const void * src2,
+        device const void * src3,
+        device const void * src4,
+        device const void * src5,
+        device const void * src6,
+        device float * dst,
+        constant ggml_metal_kargs_ssm_scan & args,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i1 = tgpig.x;
+    const int64_t ir = tgpig.y; // current head
+    const int64_t i3 = tgpig.z; // current seq
+
+    const uint64_t nb00 = sizeof(float);
+    const uint64_t nb10 = sizeof(float);
+    const uint64_t nb20 = sizeof(float);
+
+    const int64_t nc = args.d_state;
+    const int64_t nr = args.d_inner;
+    const int64_t nh = args.n_head;
+    const int64_t ng = args.n_group;
+    const int64_t n_t = args.n_seq_tokens;
+
+    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+
+    device const int32_t * ids = (device const int32_t *) src6;
+
+    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
+
+    for (int64_t i2 = 0; i2 < n_t; ++i2) {
+        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+
+        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+        const float x_dt = x[0] * dt_soft_plus;
+        const float dA = exp(dt_soft_plus * A[0]);
+        float sumf = 0.0f;
+
+        for (int64_t i0 = 0; i0 < nc; ++i0) {
+            const int64_t i = i0 + i1*nc;
+            const float state = (s0[i] * dA) + (B[i0] * x_dt);
+            sumf += state * C[i0];
+            s[i] = state;
+        }
+
+        y[0] = sumf;
+
+        // recurse
+        s0 = s;
    }
}
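One detail worth calling out in both kernels: the head-to-group mapping for B and C uses (ir & (ng - 1)), which equals ir % ng only when the group count is a power of two, as typical Mamba-2 configurations use. A small self-contained check (illustrative, not from the codebase):

    #include <cassert>
    #include <cstdio>

    // The kernels index B and C with `ir & (ng - 1)`; this matches `ir % ng`
    // only for power-of-two ng.
    static int head_to_group(int ir, int ng) {
        assert((ng & (ng - 1)) == 0 && "n_group must be a power of two");
        return ir & (ng - 1);
    }

    int main() {
        for (int ir = 0; ir < 8; ++ir) {
            // with ng = 4: heads 0..3 map to groups 0..3, heads 4..7 wrap around
            std::printf("head %d -> group %d\n", ir, head_to_group(ir, 4));
        }
        return 0;
    }

The bitwise form saves an integer modulo per pointer computation in the hot loop; a non-power-of-two group count would need the plain ir % ng instead.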