2025-01-03 10:18:53 +02:00
# include "llama-batch.h"
2025-06-13 13:47:55 +03:00
# include "llama-impl.h"
# include "llama-vocab.h"
2025-06-15 09:18:37 +03:00
# include "llama-memory.h"
2025-06-13 13:47:55 +03:00
2025-05-21 15:11:13 +03:00
# include <cassert>
2025-01-03 10:18:53 +02:00
# include <cstring>
# include <algorithm>
2025-06-13 18:35:00 +03:00
# include <sstream>
2025-01-03 10:18:53 +02:00
2025-06-20 10:14:14 +03:00
// Construct the allocator: read the LLAMA_BATCH_DEBUG env var for the debug
// level and pre-size the per-sequence bookkeeping containers.
llama_batch_allocr::llama_batch_allocr(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {
    const char * env_debug = getenv("LLAMA_BATCH_DEBUG");

    debug = env_debug ? atoi(env_debug) : 0;

    seq_pos.resize(LLAMA_MAX_SEQ);
    seq_cpl.resize(LLAMA_MAX_SEQ);

    // the coupling matrix is LLAMA_MAX_SEQ x LLAMA_MAX_SEQ
    for (size_t k = 0; k < seq_cpl.size(); ++k) {
        seq_cpl[k].resize(LLAMA_MAX_SEQ);
    }

    // -1 marks "sequence not present in the batch"
    seq_idx.resize(LLAMA_MAX_SEQ, -1);
}
2025-06-13 13:47:55 +03:00
2025-06-15 09:18:37 +03:00
// Prepare and validate an input batch for processing:
//  - validates token ids against the vocab and sequence ids against n_seq_max
//  - auto-generates the missing optional fields (n_seq_id, seq_id, pos, logits)
//  - computes per-sequence stats (positions, couplings, unique sequence ids)
//  - performs consistency checks against the memory module (KV cache), if provided
// Returns false (with an error log) if the batch is invalid.
bool llama_batch_allocr::init(
        const llama_batch & batch_inp,
        const llama_vocab & vocab,
        const llama_memory_i * memory,
        uint32_t n_embd,
        uint32_t n_seq_max,
        bool output_all) {
    clear();

    // shallow copy - missing fields are filled in below with pointers into the owned vectors
    batch = batch_inp;

    this->vocab = &vocab;

    GGML_ASSERT(batch.n_tokens > 0);

    //
    // validate input batch
    //

    if (n_seq_max > LLAMA_MAX_SEQ) {
        LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
        return false;
    }

    if (batch.token) {
        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                return false;
            }
        }
    }

    if (batch.seq_id) {
        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
                // note: the inner batch.seq_id check is redundant - the enclosing `if` already guarantees it
                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d >= %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
                    return false;
                }
            }
        }
    }

    //
    // auto-generate missing fields
    //

    if (!batch.n_seq_id) {
        // default: every token belongs to the default sequence set (seq_id_0)
        n_seq_id.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            n_seq_id[i] = seq_id_0.size();
        }
        batch.n_seq_id = n_seq_id.data();
    }

    if (!batch.seq_id) {
        // one extra NULL entry acts as a terminator for the array of per-token pointers
        seq_id.resize(batch.n_tokens + 1);
        seq_id[batch.n_tokens] = NULL;
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            seq_id[i] = seq_id_0.data();
        }
        batch.seq_id = seq_id.data();
    }

    if (!batch.pos) {
        pos.resize(batch.n_tokens);

        // initialize the starting position for each sequence based on the positions in the memory
        llama_pos p0[LLAMA_MAX_SEQ];
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            if (!memory) {
                // if no memory -> start from 0
                p0[s] = 0;
            } else {
                p0[s] = memory->seq_pos_max(s) + 1;
            }
        }

        for (int32_t i = 0; i < batch.n_tokens; i++) {
            // the token's position is taken from its first assigned sequence
            const llama_seq_id seq_id = batch.seq_id[i][0];

            pos[i] = p0[seq_id];

            // update the starting position for all sequences that are assigned to the this token
            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
                const llama_seq_id seq_id = batch.seq_id[i][s];

                p0[seq_id] = pos[i] + 1;
            }
        }

        batch.pos = pos.data();
    }

    if (!batch.logits) {
        if (output_all) {
            // return the output for all tokens
            output.resize(batch.n_tokens, true);
        } else {
            // return the output only for the last token
            output.resize(batch.n_tokens, false);
            output[output.size() - 1] = true;
        }

        batch.logits = output.data();
    } else if (output_all) {
        // caller requires outputs for all tokens - warn and override if some were not marked
        bool warn = false;

        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            if (batch.logits[i] == 0) {
                warn = true;
            }
        }

        if (warn) {
            LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);

            output.resize(batch.n_tokens, true);
            batch.logits = output.data();
        }
    }

    //
    // compute stats
    //

    this->n_embd    = n_embd;
    this->n_seq_max = n_seq_max;

    // count the outputs in this batch
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        n_outputs += batch.logits[i] != 0;
    }

    has_cpl = false;

    // determine coupled sequences
    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        const llama_seq_id s0 = batch.seq_id[i][0];

        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
            const llama_seq_id s1 = batch.seq_id[i][s];

            seq_pos[s1].insert(batch.pos[i]);

            if (s > 0) {
                // mark that sequence s1 is coupled to s0
                seq_cpl[s1][s0] = true;

                // note: tracking the other way around is not necessary for now
                //seq_cpl[s0][s1] = true;

                has_cpl = true;
            }
        }
    }

    // precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
    {
        seq_set_t seq_set_unq;

        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            seq_set_t cur;
            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
                const llama_seq_id seq_id = batch.seq_id[i][s];

                cur        .set(seq_id);
                seq_set_unq.set(seq_id);
            }

            seq_set.push_back(cur);
            seq_set_map[cur].push_back(i);
        }

        // map each participating sequence id to a consecutive index
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            if (seq_set_unq.test(s)) {
                seq_idx[s] = seq_id_unq.size();
                seq_id_unq.push_back(s);
            }
        }
    }

    if (debug > 0) {
        LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);

        // build a transient ubatch view over the input batch just for printing
        llama_ubatch ubatch {
            /*.b_equal_seqs =*/ false,
            /*.n_tokens     =*/ (uint32_t) batch.n_tokens,
            /*.n_seq_tokens =*/ (uint32_t) 1,
            /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
            /*.n_seqs_unq   =*/ (uint32_t) this->seq_id_unq.size(),
            /*.token        =*/ batch.token,
            /*.embd         =*/ batch.embd,
            /*.pos          =*/ batch.pos,
            /*.n_seq_id     =*/ batch.n_seq_id,
            /*.seq_id       =*/ batch.seq_id,
            /*.seq_id_unq   =*/ this->seq_id_unq.data(),
            /*.seq_idx      =*/ this->seq_idx.data(),
            /*.output       =*/ batch.logits,
            /*.data         =*/ {},
        };

        ubatch_print(ubatch, debug);

        LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
        for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
            if (seq_pos[s0].empty()) {
                continue;
            }

            std::stringstream ss;
            for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
                if (seq_cpl[s0][s1]) {
                    ss << s1 << " ";
                }
            }

            LLAMA_LOG_DEBUG("%s:   %4d: pos = [%4d, %4d], cpl = %s\n",
                    __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
        }
        LLAMA_LOG_DEBUG("%s: ]\n", __func__);
    }

    //
    // consistency checks
    //

    for (uint32_t s = 0; s < n_seq_max; ++s) {
        if (seq_pos[s].empty()) {
            continue;
        }

        const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;

        if (p0 >= 0) {
            bool ok = true;

            if (batch.token) {
                // token input: positions must continue exactly where the memory left off
                if (seq_pos_min(s) != p0 + 1) {
                    ok = false;
                }
            } else {
                assert(batch.embd);

                // for embeddings (typically used as vision input), we allow them to have repeating positions
                // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
                if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
                    ok = false;
                }
            }

            if (!ok) {
                LLAMA_LOG_ERROR(
                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
                        " it is required that the sequence positions remain consecutive: Y = X + 1\n",
                        __func__, s, s, p0, s, seq_pos_min(s));

                return false;
            }
        }

        // the set of positions of a sequence must form a contiguous range
        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
            return false;
        }
    }

    if (memory) {
        // coupled sequences must have identical position ranges in the memory
        for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
            for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
                if (seq_cpl[s0][s1]) {
                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
                        return false;
                    }
                }
            }
        }
    }

    // disallow partial sequence sub-sets:
    //
    //                               invalid:          x
    //                      i: 0   1   2    ...
    // ---------------------------------------
    // seq_id[i][0]: 0   0   1
    // seq_id[i][1]: 1   1   2
    // seq_id[i][2]: 2
    //
    // disallow decreasing sequence positions:
    //
    //                               invalid:                  x
    //                      i: 0   1   2   3   4   5   6  ...
    // ---------------------------------------
    //                 pos[i]: 4   5   0   1   6   2   3
    //           seq_id[i][0]: 0   0   1   1   0   1   0
    //
    {
        // start with all bits set; intersect with every token's sequence set to
        // verify that each sequence always participates with the same sequence set
        seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            cur_seq_set[s].set();
        }

        // last position seen per sequence, used to detect decreasing positions
        llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            cur_seq_pos[s] = -1;
        }

        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            const llama_pos pos = batch.pos[i];

            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
                const llama_seq_id seq_id = batch.seq_id[i][s];

                cur_seq_set[seq_id] &= seq_set[i];

                if (cur_seq_set[seq_id].none()) {
                    LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
                    return false;
                }

                if (pos < cur_seq_pos[seq_id]) {
                    LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
                    return false;
                }
            }
        }
    }

    split_reset();

    return true;
}
2025-06-20 10:14:14 +03:00
// Create a dummy ubatch of the given dimensions, used to reserve compute buffers.
// The payload vectors are allocated and zero-initialized (default-constructed) -
// their contents are not meaningful.
llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
    const uint32_t n_tokens = n_seq_tokens*n_seqs;

    clear();
    split_reset();

    // the ubatch keeps its backing storage alive through this shared payload
    auto udata = std::make_shared<llama_ubatch::data_t>();

    udata->token     .resize(n_tokens);
    udata->embd      .clear();
    udata->pos       .resize(n_tokens);
    udata->n_seq_id  .resize(n_tokens);
    udata->seq_id    .resize(n_tokens);
    udata->seq_id_unq.resize(0);
    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
    udata->output    .resize(n_tokens);

    // each of the n_seqs sequences maps to itself
    for (uint32_t s = 0; s < n_seqs; ++s) {
        udata->seq_idx[s] = s;
        udata->seq_id_unq.push_back(s);
    }

    llama_ubatch res {
        /*.b_equal_seqs =*/ true,
        /*.n_tokens     =*/ n_tokens,
        /*.n_seq_tokens =*/ n_seq_tokens,
        /*.n_seqs       =*/ n_seqs,
        /*.n_seqs_unq   =*/ n_seqs,

        /*.token        =*/ udata->token.data(),
        /*.embd         =*/ nullptr,
        /*.pos          =*/ udata->pos.data(),
        /*.n_seq_id     =*/ udata->n_seq_id.data(),
        /*.seq_id       =*/ udata->seq_id.data(),
        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
        /*.seq_idx      =*/ udata->seq_idx.data(),
        /*.output       =*/ udata->output.data(),
        /*.data         =*/ std::move(udata),
    };

    return res;
}
2025-06-13 13:47:55 +03:00
// Access the (possibly augmented) input batch.
const llama_batch & llama_batch_allocr::get_batch() const {
    return this->batch;
}
2025-06-20 10:14:14 +03:00
// Number of tokens in the input batch.
uint32_t llama_batch_allocr::get_n_tokens() const {
    return this->batch.n_tokens;
}
2025-06-13 13:47:55 +03:00
// Number of tokens marked for output in the input batch.
uint32_t llama_batch_allocr::get_n_outputs() const {
    return this->n_outputs;
}
2025-07-04 09:04:59 +03:00
// Number of tokens consumed by splits so far.
uint32_t llama_batch_allocr::get_n_used() const {
    return this->n_used;
}
2025-06-20 10:14:14 +03:00
std : : vector < int32_t > & llama_batch_allocr : : get_out_ids ( ) {
return out_ids ;
}
2025-06-15 09:18:37 +03:00
// Smallest position of seq_id in the input batch, or -1 if the sequence is absent.
llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
    const auto & ps = seq_pos[seq_id];
    if (ps.empty()) {
        return -1;
    }
    return *ps.begin();
}
// Largest position of seq_id in the input batch, or -1 if the sequence is absent.
llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
    const auto & ps = seq_pos[seq_id];
    if (ps.empty()) {
        return -1;
    }
    return *ps.rbegin();
}
2025-06-20 10:14:14 +03:00
void llama_batch_allocr : : split_reset ( ) {
out_ids . clear ( ) ;
2025-07-04 09:04:59 +03:00
n_used = 0 ;
2025-06-20 10:14:14 +03:00
used . clear ( ) ;
used . resize ( get_n_tokens ( ) , false ) ;
}
// Take up to n_ubatch consecutive unused tokens, in batch order, and form a ubatch.
// Returns an empty ubatch when all tokens have been consumed.
llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
    // skip tokens already consumed by previous splits
    uint32_t idx = 0;
    while (idx < used.size() && used[idx]) {
        idx++;
    }

    if (idx >= used.size()) {
        // nothing left to split
        return {};
    }

    std::vector<int32_t> idxs;

    for (;;) {
        idxs.push_back(idx);

        used[idx] = true;
        ++n_used;

        ++idx;

        if (idx >= used.size() || idxs.size() >= n_ubatch) {
            break;
        }
    }

    return ubatch_add(idxs, idxs.size(), false);
}
2025-07-04 09:08:59 +03:00
// Form a ubatch in which several non-overlapping sequence sets each contribute an
// equal number of tokens. With `sequential == true`, only sequence sets whose first
// sequence ids are consecutive and increasing are accepted; this mode is
// incompatible with coupled sequences.
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
    if (sequential && has_cpl) {
        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch (you may need to use the -kvu flag)\n", __func__);

        return {};
    }

    std::vector<seq_set_t> cur_seq_set;

    llama_seq_id last_seq_id = -1;

    // determine the non-overlapping sequence sets participating in this ubatch
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        if (used[i]) {
            continue;
        }

        bool add = true;

        for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
            // no overlap with existing sequence sets:
            if (!(cur_seq_set[s] & seq_set[i]).none()) {
                add = false;
                break;
            }
        }

        // accept only increasing sequence ids
        if (sequential) {
            add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
        }

        if (add) {
            cur_seq_set.push_back(seq_set[i]);

            last_seq_id = batch.seq_id[i][0];

            if (cur_seq_set.size() > n_ubatch) {
                break;
            }
        }
    }

    const uint32_t n_seqs = cur_seq_set.size();

    // we are done
    if (n_seqs == 0) {
        return {};
    }

    // the current batch index of each sequence set
    std::vector<int32_t> cur_idx(n_seqs, 0);

    // skip over the indices already consumed by previous splits
    for (uint32_t s = 0; s < n_seqs; ++s) {
        while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
            ++cur_idx[s];
        }
    }

    // the list of batch indices for each sequence set
    // at the end we will concat these to get the final ubatch
    std::vector<idx_vec_t> idxs_per_seq(n_seqs);

    while (true) {
        // we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
        // if we haven't reached n_ubatch
        bool can_expand = true;

        for (uint32_t s = 0; s < n_seqs; ++s) {
            if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
                can_expand = false;
                break;
            }
        }

        if (!can_expand) {
            break;
        }

        // take one more token from each sequence set
        for (uint32_t s = 0; s < n_seqs; ++s) {
            const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];

            idxs_per_seq[s].push_back(idx);

            used[idx] = true;
            ++n_used;

            ++cur_idx[s];
        }

        // stop if the next expansion would exceed n_ubatch
        if ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
            break;
        }
    }

    // concat the per-sequence-set lists
    std::vector<int32_t> idxs;

    for (uint32_t s = 0; s < n_seqs; ++s) {
        idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
    }

    return ubatch_add(idxs, n_seqs, true);
}
// Form a ubatch from a single chain of tokens: starting from the first unused token,
// only tokens whose sequence set is a subset of the current one are appended.
llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
    // find the first unused token
    uint32_t cur_idx = 0;
    while (cur_idx < used.size() && used[cur_idx]) {
        ++cur_idx;
    }

    // we are done
    if (cur_idx >= used.size()) {
        return {};
    }

    // this is the starting sequence set
    // we allow adding tokens only if their sequence set is a subset of the current sequence set
    auto cur_seq_set = seq_set[cur_idx];

    std::vector<int32_t> idxs;

    while (true) {
        idxs.push_back(cur_idx);

        used[cur_idx] = true;
        ++n_used;

        if (idxs.size() >= n_ubatch) {
            break;
        }

        // advance to the next unused token whose sequence set is a subset of the current one
        do {
            ++cur_idx;
        } while (cur_idx < get_n_tokens() && (used[cur_idx] || ((cur_seq_set & seq_set[cur_idx]) != seq_set[cur_idx])));

        if (cur_idx == get_n_tokens()) {
            break;
        }

        // narrow the current sequence set to the accepted token's set
        cur_seq_set = seq_set[cur_idx];
    }

    return ubatch_add(idxs, 1, true);
}
2025-06-13 13:47:55 +03:00
void llama_batch_allocr : : clear ( ) {
n_outputs = 0 ;
batch = { } ;
2025-06-20 10:14:14 +03:00
pos . clear ( ) ;
n_seq_id . clear ( ) ;
seq_id . clear ( ) ;
seq_id_unq . clear ( ) ;
output . clear ( ) ;
2025-06-15 09:18:37 +03:00
for ( auto & cur : seq_pos ) {
cur . clear ( ) ;
}
for ( auto & cur : seq_cpl ) {
std : : fill ( cur . begin ( ) , cur . end ( ) , false ) ;
}
2025-06-20 10:14:14 +03:00
seq_set . clear ( ) ;
seq_set_map . clear ( ) ;
std : : fill ( seq_idx . begin ( ) , seq_idx . end ( ) , - 1 ) ;
}
// Construct a llama_ubatch from the given batch indices:
//  - idxs:       indices into the input batch selecting the tokens of this ubatch
//  - n_seqs:     number of sequences (n_tokens must be divisible by n_seqs)
//  - equal_seqs: whether each sequence contributes an equal number of tokens
// The ubatch owns copies of the token/embd/pos/output data via the shared
// `data_t` payload; the per-token seq_id pointers reference the original batch.
llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
    const uint32_t n_tokens = idxs.size();

    assert(n_tokens%n_seqs == 0);

    auto udata = std::make_shared<llama_ubatch::data_t>();

    // embedding inputs may carry multiple positions per token (n_pos_per_embd); tokens carry one
    const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;

    const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
    const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_cur;

    udata->token     .resize(n_tokens);
    udata->embd      .resize(n_embd_all);
    udata->pos       .resize(n_pos_all);
    udata->n_seq_id  .resize(n_tokens);
    udata->seq_id    .resize(n_tokens);
    udata->seq_id_unq.resize(0);
    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
    udata->output    .resize(n_tokens);

    seq_set_t seq_set_unq;

    for (size_t i = 0; i < idxs.size(); ++i) {
        if (batch.token) {
            udata->token[i] = batch.token[idxs[i]];
        }

        if (batch.embd) {
            memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
        }

        // positions are stored planar: index = [pos dim]*n_tokens + [token]
        for (int j = 0; j < n_pos_cur; ++j) {
            udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
        }

        udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
        udata->seq_id[i]   = batch.seq_id[idxs[i]];
        udata->output[i]   = batch.logits[idxs[i]];

        for (int s = 0; s < udata->n_seq_id[i]; ++s) {
            seq_set_unq.set(udata->seq_id[i][s]);
        }

        if (udata->output[i]) {
            // remember the batch index of each output token
            out_ids.push_back(idxs[i]);
        }
    }

    // map each participating sequence id to a consecutive index
    for (uint32_t s = 0; s < n_seq_max; ++s) {
        if (seq_set_unq.test(s)) {
            udata->seq_idx[s] = udata->seq_id_unq.size();
            udata->seq_id_unq.push_back(s);
        }
    }

    llama_ubatch res {
        /*.b_equal_seqs =*/ equal_seqs,
        /*.n_tokens     =*/ n_tokens,
        /*.n_seq_tokens =*/ n_tokens/n_seqs,
        /*.n_seqs       =*/ n_seqs,
        /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),

        /*.token        =*/ batch.token ? udata->token.data() : nullptr,
        /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
        /*.pos          =*/ udata->pos.data(),
        /*.n_seq_id     =*/ udata->n_seq_id.data(),
        /*.seq_id       =*/ udata->seq_id.data(),
        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
        /*.seq_idx      =*/ udata->seq_idx.data(),
        /*.output       =*/ udata->output.data(),
        /*.data         =*/ std::move(udata),
    };

    if (debug > 0) {
        LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);

        ubatch_print(res, debug);
    }

    return res;
}
void llama_batch_allocr : : ubatch_print ( const llama_ubatch & ubatch , int debug ) {
if ( debug > 0 ) {
2025-07-17 19:08:33 +03:00
LLAMA_LOG_DEBUG ( " %s: equal_seqs = %d \n " , __func__ , ubatch . equal_seqs ( ) ) ;
2025-06-20 10:14:14 +03:00
LLAMA_LOG_DEBUG ( " %s: n_tokens = %d \n " , __func__ , ubatch . n_tokens ) ;
LLAMA_LOG_DEBUG ( " %s: n_seq_tokens = %d \n " , __func__ , ubatch . n_seq_tokens ) ;
LLAMA_LOG_DEBUG ( " %s: n_seqs = %d \n " , __func__ , ubatch . n_seqs ) ;
LLAMA_LOG_DEBUG ( " %s: n_seqs_unq = %d \n " , __func__ , ubatch . n_seqs_unq ) ;
std : : stringstream ss_seq_id_unq ;
std : : stringstream ss_seq_idx ;
ss_seq_id_unq < < " [ " ;
ss_seq_idx < < " [ " ;
for ( uint32_t s = 0 ; s < ubatch . n_seqs_unq ; + + s ) {
ss_seq_id_unq < < ubatch . seq_id_unq [ s ] < < " " ;
}
for ( uint32_t s = 0 ; s < LLAMA_MAX_SEQ ; + + s ) {
if ( ubatch . seq_idx [ s ] > = 0 ) {
ss_seq_idx < < ubatch . seq_idx [ s ] % 10 ;
} else {
ss_seq_idx < < " . " ;
}
}
ss_seq_id_unq < < " ] " ;
ss_seq_idx < < " ] " ;
LLAMA_LOG_DEBUG ( " %s: token = %p \n " , __func__ , ( void * ) ubatch . token ) ;
LLAMA_LOG_DEBUG ( " %s: embd = %p \n " , __func__ , ( void * ) ubatch . embd ) ;
LLAMA_LOG_DEBUG ( " %s: pos = %p \n " , __func__ , ( void * ) ubatch . pos ) ;
LLAMA_LOG_DEBUG ( " %s: n_seq_id = %p \n " , __func__ , ( void * ) ubatch . n_seq_id ) ;
LLAMA_LOG_DEBUG ( " %s: seq_id = %p \n " , __func__ , ( void * ) ubatch . seq_id ) ;
LLAMA_LOG_DEBUG ( " %s: seq_id_unq = %s \n " , __func__ , ss_seq_id_unq . str ( ) . c_str ( ) ) ;
LLAMA_LOG_DEBUG ( " %s: seq_idx = %s \n " , __func__ , ss_seq_idx . str ( ) . c_str ( ) ) ;
LLAMA_LOG_DEBUG ( " %s: output = %p \n " , __func__ , ( void * ) ubatch . output ) ;
LLAMA_LOG_DEBUG ( " %s: n_outputs = %d \n " , __func__ , n_outputs ) ;
if ( debug > 1 ) {
int seq_id_max = 0 ;
for ( uint32_t i = 0 ; i < ubatch . n_tokens ; + + i ) {
for ( int s = 0 ; s < ubatch . n_seq_id [ i ] ; + + s ) {
for ( int s = 0 ; s < ubatch . n_seq_id [ i ] ; + + s ) {
seq_id_max = std : : max ( seq_id_max , ubatch . seq_id [ i ] [ s ] ) ;
}
}
}
+ + seq_id_max ;
LLAMA_LOG_DEBUG ( " %s: token = [ \n " , __func__ ) ;
for ( uint32_t i = 0 ; i < ubatch . n_tokens ; + + i ) {
std : : vector < int8_t > seq_id ( seq_id_max ) ;
for ( int s = 0 ; s < ubatch . n_seq_id [ i ] ; + + s ) {
seq_id [ ubatch . seq_id [ i ] [ s ] ] = 1 ;
}
std : : stringstream ss ;
for ( int s = 0 ; s < seq_id_max ; + + s ) {
if ( seq_id [ s ] ) {
ss < < s % 10 ;
} else {
ss < < " . " ;
}
}
if ( ubatch . token ) {
LLAMA_LOG_DEBUG ( " %s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d \n " ,
__func__ , i , ubatch . token [ i ] , vocab - > token_to_piece ( ubatch . token [ i ] ) . c_str ( ) ,
ubatch . pos [ i ] , ubatch . n_seq_id [ i ] , ss . str ( ) . c_str ( ) , ubatch . output [ i ] ) ;
} else {
LLAMA_LOG_DEBUG ( " %s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d \n " ,
__func__ , i , ubatch . pos [ i ] , ubatch . n_seq_id [ i ] , ss . str ( ) . c_str ( ) , ubatch . output [ i ] ) ;
}
}
LLAMA_LOG_DEBUG ( " %s: ] \n " , __func__ ) ;
}
}
2025-01-03 10:18:53 +02:00
}
//
// interface implementation
//
// Build a minimal batch that wraps an externally-owned token array.
// All optional fields (pos, seq_id, logits, ...) are left null, so they will be
// auto-generated by llama_batch_allocr::init().
struct llama_batch llama_batch_get_one(
             llama_token * tokens,
                 int32_t   n_tokens) {
    llama_batch batch = {};

    batch.n_tokens = n_tokens;
    batch.token    = tokens;

    return batch;
}
struct llama_batch llama_batch_init ( int32_t n_tokens_alloc , int32_t embd , int32_t n_seq_max ) {
llama_batch batch = {
2025-06-20 10:14:14 +03:00
/*n_tokens =*/ 0 ,
/*tokens =*/ nullptr ,
/*embd =*/ nullptr ,
/*pos =*/ nullptr ,
/*n_seq_id =*/ nullptr ,
/*seq_id =*/ nullptr ,
/*logits =*/ nullptr ,
2025-01-03 10:18:53 +02:00
} ;
if ( embd ) {
batch . embd = ( float * ) malloc ( sizeof ( float ) * n_tokens_alloc * embd ) ;
} else {
batch . token = ( llama_token * ) malloc ( sizeof ( llama_token ) * n_tokens_alloc ) ;
}
batch . pos = ( llama_pos * ) malloc ( sizeof ( llama_pos ) * n_tokens_alloc ) ;
batch . n_seq_id = ( int32_t * ) malloc ( sizeof ( int32_t ) * n_tokens_alloc ) ;
batch . seq_id = ( llama_seq_id * * ) malloc ( sizeof ( llama_seq_id * ) * ( n_tokens_alloc + 1 ) ) ;
for ( int i = 0 ; i < n_tokens_alloc ; + + i ) {
batch . seq_id [ i ] = ( llama_seq_id * ) malloc ( sizeof ( llama_seq_id ) * n_seq_max ) ;
}
batch . seq_id [ n_tokens_alloc ] = nullptr ;
batch . logits = ( int8_t * ) malloc ( sizeof ( int8_t ) * n_tokens_alloc ) ;
return batch ;
}
// Release all memory owned by a batch created with llama_batch_init().
// Safe to call on a batch with any subset of fields allocated.
void llama_batch_free(struct llama_batch batch) {
    // free(NULL) is a no-op, so per-pointer null guards are unnecessary
    free(batch.token);
    free(batch.embd);
    free(batch.pos);
    free(batch.n_seq_id);

    if (batch.seq_id) {
        // the per-token arrays are terminated by a trailing nullptr (see llama_batch_init)
        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
            free(batch.seq_id[i]);
        }
        free(batch.seq_id);
    }

    free(batch.logits);
}