llama : add high-throughput mode (#14363)
* kv-cache : prepare K/V buffers for separation ggml-ci * batched-bench : fix oob write ggml-ci * llama : add "virtual sequences" ggml-ci * llama : use "stream" vs "virtual sequence" ggml-ci * graph : fix stream splitting when KV cache is not used ggml-ci * kv-cache : add multi-stream save/load support ggml-ci * llama : add "--attn-streams" flag ggml-ci * kv-cache : fix handling when find_slot fails ggml-ci * kv-cache : restore find_slot impl ggml-ci * kv-cache : add comments * kv-cache : add bounds checks for sequence id ggml-ci * cont : add n_seq_max to batch allocr ggml-ci * kv-cache : perform stream copies lazily after llama_synchronize ggml-ci * kv-cache : avoid throwing exceptions across the C boundary ggml-ci * CUDA: 4D FlashAttention support (#14628) * CUDA: 4D FlashAttention support * CUDA: fix WMMA FA kernel * llama : rename attn_streams -> kv_unified ggml-ci * common : rename kv_split -> kv_unified ggml-ci --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
@@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
|
||||
const llama_vocab & vocab,
|
||||
const llama_memory_i * memory,
|
||||
uint32_t n_embd,
|
||||
uint32_t n_seq_max,
|
||||
bool output_all) {
|
||||
clear();
|
||||
|
||||
@@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
|
||||
// validate input batch
|
||||
//
|
||||
|
||||
if (n_seq_max > LLAMA_MAX_SEQ) {
|
||||
LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (batch.token) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
|
||||
@@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
|
||||
if (batch.seq_id) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
||||
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
|
||||
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
|
||||
|
||||
// initialize the starting position for each sequence based on the positions in the memory
|
||||
llama_pos p0[LLAMA_MAX_SEQ];
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (!memory) {
|
||||
// if no memory -> start from 0
|
||||
p0[s] = 0;
|
||||
@@ -143,7 +149,8 @@ bool llama_batch_allocr::init(
|
||||
// compute stats
|
||||
//
|
||||
|
||||
this->n_embd = n_embd;
|
||||
this->n_embd = n_embd;
|
||||
this->n_seq_max = n_seq_max;
|
||||
|
||||
// count the outputs in this batch
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
@@ -189,7 +196,7 @@ bool llama_batch_allocr::init(
|
||||
seq_set_map[cur].push_back(i);
|
||||
}
|
||||
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_set_unq.test(s)) {
|
||||
seq_idx[s] = seq_id_unq.size();
|
||||
seq_id_unq.push_back(s);
|
||||
@@ -241,7 +248,7 @@ bool llama_batch_allocr::init(
|
||||
// consistency checks
|
||||
//
|
||||
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_pos[s].empty()) {
|
||||
continue;
|
||||
}
|
||||
@@ -284,8 +291,8 @@ bool llama_batch_allocr::init(
|
||||
}
|
||||
|
||||
if (memory) {
|
||||
for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
|
||||
for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
|
||||
for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
|
||||
for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
|
||||
if (seq_cpl[s0][s1]) {
|
||||
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
|
||||
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
|
||||
@@ -316,12 +323,12 @@ bool llama_batch_allocr::init(
|
||||
//
|
||||
{
|
||||
seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
cur_seq_set[s].set();
|
||||
}
|
||||
|
||||
llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
cur_seq_pos[s] = -1;
|
||||
}
|
||||
|
||||
@@ -692,7 +699,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
||||
}
|
||||
}
|
||||
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_set_unq.test(s)) {
|
||||
ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
|
||||
ubatch.seq_id_unq.push_back(s);
|
||||
|
||||
@@ -48,6 +48,7 @@ public:
|
||||
const llama_vocab & vocab,
|
||||
const llama_memory_i * memory,
|
||||
uint32_t n_embd,
|
||||
uint32_t n_seq_max,
|
||||
bool output_all);
|
||||
|
||||
const llama_batch & get_batch() const;
|
||||
@@ -100,6 +101,7 @@ private:
|
||||
const uint32_t n_pos_per_embd;
|
||||
|
||||
uint32_t n_embd;
|
||||
uint32_t n_seq_max;
|
||||
uint32_t n_outputs;
|
||||
|
||||
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
|
||||
|
||||
@@ -98,10 +98,20 @@ llama_context::llama_context(
|
||||
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|
||||
cparams.n_batch = GGML_KQ_MASK_PAD;
|
||||
}
|
||||
|
||||
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
|
||||
|
||||
cparams.op_offload = params.op_offload;
|
||||
cparams.kv_unified = params.kv_unified;
|
||||
|
||||
{
|
||||
const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
|
||||
const bool supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
|
||||
|
||||
if (!supports_set_rows && !cparams.kv_unified) {
|
||||
LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
|
||||
cparams.kv_unified = true;
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
||||
|
||||
@@ -112,6 +122,7 @@ llama_context::llama_context(
|
||||
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
||||
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
|
||||
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
|
||||
LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
|
||||
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
||||
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
||||
|
||||
@@ -267,7 +278,7 @@ llama_context::llama_context(
|
||||
|
||||
// reserve worst-case graph
|
||||
if (!hparams.vocab_only && memory) {
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||
@@ -300,7 +311,7 @@ llama_context::llama_context(
|
||||
|
||||
// reserve with tg graph to get the number of splits and nodes
|
||||
{
|
||||
auto * gf = graph_reserve(1, 1, 1, mctx.get());
|
||||
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute tg buffers");
|
||||
}
|
||||
@@ -311,6 +322,10 @@ llama_context::llama_context(
|
||||
|
||||
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
||||
{
|
||||
// TODO: not sure if the following graph would be worster case for multi-stream KV caches:
|
||||
//
|
||||
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
|
||||
//
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
@@ -475,7 +490,7 @@ bool llama_context::kv_self_update(bool optimize) {
|
||||
throw std::runtime_error("failed to initialize memory context");
|
||||
}
|
||||
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
@@ -735,13 +750,15 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
const int32_t n_vocab = model.vocab.n_tokens();
|
||||
|
||||
// note: during encode, we always pass the full sequence starting from pos = 0
|
||||
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
|
||||
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
const uint32_t n_tokens = balloc->get_n_tokens();
|
||||
|
||||
// [TAG_NO_CACHE_PAD]
|
||||
// TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
|
||||
const llama_ubatch ubatch = balloc->split_simple(n_tokens);
|
||||
|
||||
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
|
||||
@@ -910,7 +927,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
// when computing embeddings, all tokens are output
|
||||
const bool output_all = cparams.embeddings;
|
||||
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
@@ -2039,7 +2056,7 @@ void llama_context::opt_epoch_iter(
|
||||
batch.logits [pos_batch] = true;
|
||||
}
|
||||
|
||||
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
|
||||
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return;
|
||||
}
|
||||
@@ -2198,6 +2215,7 @@ llama_context_params llama_context_default_params() {
|
||||
/*.no_perf =*/ true,
|
||||
/*.op_offload =*/ true,
|
||||
/*.swa_full =*/ true,
|
||||
/*.kv_unified =*/ false,
|
||||
};
|
||||
|
||||
return result;
|
||||
|
||||
@@ -11,8 +11,8 @@ struct llama_cparams {
|
||||
uint32_t n_batch;
|
||||
uint32_t n_ubatch;
|
||||
uint32_t n_seq_max;
|
||||
int n_threads; // number of threads to use for generation
|
||||
int n_threads_batch; // number of threads to use for batch processing
|
||||
int32_t n_threads; // number of threads to use for generation
|
||||
int32_t n_threads_batch; // number of threads to use for batch processing
|
||||
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
@@ -33,6 +33,7 @@ struct llama_cparams {
|
||||
bool no_perf;
|
||||
bool warmup;
|
||||
bool op_offload;
|
||||
bool kv_unified;
|
||||
|
||||
enum llama_pooling_type pooling_type;
|
||||
|
||||
|
||||
@@ -982,13 +982,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
float kq_scale) const {
|
||||
const bool v_trans = v->nb[1] > v->nb[2];
|
||||
|
||||
// split the batch into streams if needed
|
||||
const auto n_stream = k->ne[3];
|
||||
|
||||
q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
|
||||
|
||||
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
||||
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|
||||
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
||||
|
||||
const auto n_tokens = q->ne[1];
|
||||
const auto n_head = q->ne[2];
|
||||
const auto n_kv = k->ne[1];
|
||||
const auto n_kv = k->ne[1];
|
||||
|
||||
ggml_tensor * cur;
|
||||
|
||||
@@ -1030,7 +1033,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
#endif
|
||||
}
|
||||
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
|
||||
} else {
|
||||
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
|
||||
@@ -1075,7 +1078,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
|
||||
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
||||
// recombine streams
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
|
||||
|
||||
if (!cparams.offload_kqv) {
|
||||
// all nodes between the KV store and the attention output are run on the CPU
|
||||
@@ -1122,6 +1126,10 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
|
||||
// [TAG_NO_CACHE_PAD]
|
||||
// TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams
|
||||
assert(ubatch.equal_seqs == false);
|
||||
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = k_cur;
|
||||
ggml_tensor * v = v_cur;
|
||||
@@ -1156,13 +1164,14 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
||||
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
@@ -1362,13 +1371,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
|
||||
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
{
|
||||
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
||||
|
||||
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
@@ -1382,7 +1393,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
||||
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
|
||||
@@ -255,10 +255,10 @@ public:
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
@@ -289,14 +289,14 @@ public:
|
||||
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
||||
|
||||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
||||
@@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
|
||||
return n_embd_head_v * n_head_kv;
|
||||
}
|
||||
|
||||
bool llama_hparams::is_n_embd_k_gqa_variable() const {
|
||||
const uint32_t val = n_embd_k_gqa();
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
if (val != n_embd_k_gqa(il)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool llama_hparams::is_n_embd_v_gqa_variable() const {
|
||||
const uint32_t val = n_embd_v_gqa();
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
if (val != n_embd_v_gqa(il)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_k_gqa_max() const {
|
||||
uint32_t val = n_embd_k_gqa();
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
val = std::max(val, n_embd_k_gqa(il));
|
||||
}
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_v_gqa_max() const {
|
||||
uint32_t val = n_embd_v_gqa();
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
val = std::max(val, n_embd_v_gqa(il));
|
||||
}
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_r() const {
|
||||
if (wkv_head_size != 0) {
|
||||
// for RWKV models
|
||||
|
||||
@@ -191,6 +191,14 @@ struct llama_hparams {
|
||||
// dimension of value embeddings across all k-v heads
|
||||
uint32_t n_embd_v_gqa(uint32_t il = 0) const;
|
||||
|
||||
// true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
|
||||
bool is_n_embd_k_gqa_variable() const;
|
||||
bool is_n_embd_v_gqa_variable() const;
|
||||
|
||||
// return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
|
||||
uint32_t n_embd_k_gqa_max() const;
|
||||
uint32_t n_embd_v_gqa_max() const;
|
||||
|
||||
// dimension of the rolling state embeddings
|
||||
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
|
||||
uint32_t n_embd_r() const;
|
||||
|
||||
@@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool swa_full,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad) : hparams(model.hparams) {
|
||||
uint32_t n_pad) : hparams(model.hparams), unified(unified) {
|
||||
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
||||
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
||||
|
||||
const uint32_t size_base = kv_size;
|
||||
|
||||
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
|
||||
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
|
||||
|
||||
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
|
||||
if (swa_full) {
|
||||
@@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
||||
|
||||
kv_base = std::make_unique<llama_kv_cache_unified>(
|
||||
model, std::move(filter_base), type_k, type_v,
|
||||
v_trans, offload, size_base, n_seq_max, n_pad,
|
||||
v_trans, offload, unified, size_base, n_seq_max, n_pad,
|
||||
0, LLAMA_SWA_TYPE_NONE);
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
||||
|
||||
kv_swa = std::make_unique<llama_kv_cache_unified>(
|
||||
model, std::move(filter_swa), type_k, type_v,
|
||||
v_trans, offload, size_swa, n_seq_max, n_pad,
|
||||
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
|
||||
hparams.n_swa, hparams.swa_type);
|
||||
}
|
||||
|
||||
@@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
||||
|
||||
// first try simple split
|
||||
do {
|
||||
if (!unified) {
|
||||
// requires equal splits, so we skip the simple split
|
||||
break;
|
||||
}
|
||||
|
||||
balloc.split_reset();
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
@@ -140,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (true) {
|
||||
auto ubatch = balloc.split_equal(n_ubatch, false);
|
||||
auto ubatch = balloc.split_equal(n_ubatch, !unified);
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
|
||||
@@ -20,6 +20,7 @@ public:
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool swa_full,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
@@ -68,6 +69,8 @@ public:
|
||||
private:
|
||||
const llama_hparams & hparams;
|
||||
|
||||
const bool unified;
|
||||
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_base;
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
||||
};
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -35,16 +35,50 @@ public:
|
||||
std::vector<uint32_t> ids;
|
||||
};
|
||||
|
||||
struct stream_copy_info {
|
||||
bool empty() const {
|
||||
assert(ssrc.size() == sdst.size());
|
||||
return ssrc.empty();
|
||||
}
|
||||
|
||||
std::vector<uint32_t> ssrc;
|
||||
std::vector<uint32_t> sdst;
|
||||
};
|
||||
|
||||
// for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
|
||||
// KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
|
||||
struct slot_info {
|
||||
// data for ggml_set_rows
|
||||
using idx_vec_t = std::vector<uint32_t>;
|
||||
|
||||
idx_vec_t idxs;
|
||||
// number of streams: ns = s1 - s0 + 1
|
||||
llama_seq_id s0;
|
||||
llama_seq_id s1;
|
||||
|
||||
std::vector<llama_seq_id> strm; // [ns]
|
||||
std::vector<idx_vec_t> idxs; // [ns]
|
||||
|
||||
uint32_t head() const {
|
||||
return idxs.at(0);
|
||||
GGML_ASSERT(idxs.size() == 1);
|
||||
GGML_ASSERT(!idxs[0].empty());
|
||||
|
||||
return idxs[0][0];
|
||||
}
|
||||
|
||||
void resize(size_t n) {
|
||||
strm.resize(n);
|
||||
idxs.resize(n);
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
GGML_ASSERT(idxs.size() == strm.size());
|
||||
GGML_ASSERT(!idxs.empty());
|
||||
|
||||
return idxs[0].size();
|
||||
}
|
||||
|
||||
size_t n_stream() const {
|
||||
return strm.size();
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
@@ -54,9 +88,6 @@ public:
|
||||
void clear() {
|
||||
idxs.clear();
|
||||
}
|
||||
|
||||
// TODO: implement
|
||||
//std::vector<idx_vec_t> seq_idxs;
|
||||
};
|
||||
|
||||
using slot_info_vec_t = std::vector<slot_info>;
|
||||
@@ -68,6 +99,7 @@ public:
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_pad,
|
||||
@@ -111,7 +143,8 @@ public:
|
||||
// llama_kv_cache_unified specific API
|
||||
//
|
||||
|
||||
uint32_t get_size() const;
|
||||
uint32_t get_size() const;
|
||||
uint32_t get_n_stream() const;
|
||||
|
||||
bool get_has_shift() const;
|
||||
|
||||
@@ -122,8 +155,8 @@ public:
|
||||
uint32_t get_n_kv() const;
|
||||
|
||||
// get views of the current state of the cache
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
|
||||
@@ -137,7 +170,7 @@ public:
|
||||
// return empty vector on failure
|
||||
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
|
||||
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
||||
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
|
||||
|
||||
// find a slot of kv cells that can hold the ubatch
|
||||
// if cont == true, then the slot must be continuous
|
||||
@@ -157,8 +190,9 @@ public:
|
||||
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
|
||||
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
|
||||
|
||||
void set_input_k_shift(ggml_tensor * dst) const;
|
||||
|
||||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_k_shift (ggml_tensor * dst) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
private:
|
||||
@@ -172,15 +206,15 @@ private:
|
||||
|
||||
ggml_tensor * k;
|
||||
ggml_tensor * v;
|
||||
|
||||
std::vector<ggml_tensor *> k_stream;
|
||||
std::vector<ggml_tensor *> v_stream;
|
||||
};
|
||||
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
uint32_t head = 0;
|
||||
|
||||
const uint32_t n_seq_max = 1;
|
||||
const uint32_t n_stream = 1;
|
||||
|
||||
// required padding
|
||||
const uint32_t n_pad = 1;
|
||||
@@ -200,7 +234,17 @@ private:
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
llama_kv_cells_unified cells;
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
std::vector<uint32_t> v_heads;
|
||||
|
||||
std::vector<llama_kv_cells_unified> v_cells;
|
||||
|
||||
// maps from a sequence id to a stream id
|
||||
std::vector<uint32_t> seq_to_stream;
|
||||
|
||||
// pending stream copies that will be applied during the next update
|
||||
stream_copy_info sc_info;
|
||||
|
||||
std::vector<kv_layer> layers;
|
||||
|
||||
@@ -237,18 +281,25 @@ private:
|
||||
ggml_cgraph * gf,
|
||||
const defrag_info & dinfo) const;
|
||||
|
||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||
struct cell_ranges_t {
|
||||
uint32_t strm;
|
||||
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
|
||||
};
|
||||
|
||||
void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
|
||||
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
|
||||
};
|
||||
|
||||
class llama_kv_cache_unified_context : public llama_memory_context_i {
|
||||
public:
|
||||
// some shorthands
|
||||
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||
using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
|
||||
|
||||
// used for errors
|
||||
llama_kv_cache_unified_context(llama_memory_status status);
|
||||
@@ -262,7 +313,8 @@ public:
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_context * lctx,
|
||||
bool do_shift,
|
||||
defrag_info dinfo);
|
||||
defrag_info dinfo,
|
||||
stream_copy_info sc_info);
|
||||
|
||||
// used to create a batch procesing context from a batch
|
||||
llama_kv_cache_unified_context(
|
||||
@@ -320,6 +372,8 @@ private:
|
||||
|
||||
defrag_info dinfo;
|
||||
|
||||
stream_copy_info sc_info;
|
||||
|
||||
//
|
||||
// batch processing context
|
||||
//
|
||||
|
||||
@@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid(
|
||||
offload,
|
||||
kv_size,
|
||||
n_seq_max,
|
||||
1,
|
||||
n_pad,
|
||||
n_swa,
|
||||
swa_type
|
||||
|
||||
@@ -16647,7 +16647,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
} else {
|
||||
const auto padding = llama_kv_cache_unified::get_padding(cparams);
|
||||
|
||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
|
||||
uint32_t n_ctx_per_stream = cparams.n_ctx;
|
||||
|
||||
if (!cparams.kv_unified) {
|
||||
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
|
||||
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
|
||||
|
||||
cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
|
||||
} else {
|
||||
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
|
||||
|
||||
cparams.n_ctx = n_ctx_per_stream;
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
|
||||
|
||||
@@ -16661,7 +16672,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
!cparams.flash_attn,
|
||||
cparams.offload_kqv,
|
||||
params.swa_full,
|
||||
cparams.n_ctx,
|
||||
cparams.kv_unified,
|
||||
n_ctx_per_stream,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_ubatch,
|
||||
padding);
|
||||
@@ -16675,7 +16687,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
params.type_v,
|
||||
!cparams.flash_attn,
|
||||
cparams.offload_kqv,
|
||||
cparams.n_ctx,
|
||||
cparams.kv_unified,
|
||||
n_ctx_per_stream,
|
||||
cparams.n_seq_max,
|
||||
padding,
|
||||
hparams.n_swa,
|
||||
|
||||
Reference in New Issue
Block a user