batch : rework llama_batch_allocr (#14153)
* batch : rework llama_batch_allocr ggml-ci * cont : move validation inside class ggml-ci * cont : move output counting to class ggml-ci * cont : minor ggml-ci * batch : add TODOs ggml-ci
This commit is contained in:
@@ -18,8 +18,8 @@ struct llama_ubatch {
|
||||
llama_token * token; // [n_tokens]
|
||||
float * embd; // [n_embd, n_tokens]
|
||||
llama_pos * pos; // [n_tokens]
|
||||
int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
|
||||
llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id;
|
||||
int32_t * n_seq_id; // [n_seqs]
|
||||
llama_seq_id ** seq_id; // [n_seqs]
|
||||
int8_t * output; // [n_tokens]
|
||||
};
|
||||
|
||||
@@ -78,15 +78,28 @@ struct llama_sbatch {
|
||||
};
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
struct llama_batch_allocr {
|
||||
struct llama_batch batch;
|
||||
class llama_batch_allocr {
|
||||
public:
|
||||
llama_batch_allocr();
|
||||
|
||||
// optionally fulfill the batch returned by llama_batch_get_one
|
||||
bool init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0);
|
||||
|
||||
const llama_batch & get_batch() const;
|
||||
|
||||
uint32_t get_n_outputs() const;
|
||||
|
||||
private:
|
||||
void clear();
|
||||
|
||||
llama_batch batch;
|
||||
|
||||
uint32_t n_outputs;
|
||||
|
||||
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
|
||||
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id;
|
||||
std::vector<int8_t> output;
|
||||
|
||||
// optionally fulfill the batch returned by llama_batch_get_one
|
||||
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user