# SPDX-License-Identifier: Apache-2.0
"""Flash PagedAttention forward kernel for AWS Neuron, written in NKI.

Contains the NKI device kernels (`@nki.jit`) implementing paged flash
attention over a block-based KV cache, plus host-side torch helpers for
mask reordering and KV-cache writes.
"""
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np
import torch
from neuronxcc import nki
from neuronxcc.nki.language import par_dim


def ceil_div(a: int, b: int) -> int:
    """Return ceil(a / b) for positive integers."""
    return (a + b - 1) // b


def is_power_of_2(x: int) -> bool:
    """Return True iff x is a positive power of two."""
    return x > 0 and (x & (x - 1)) == 0


@nki.jit
def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
    """
    Load block tables from HBM into SRAM

    `block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`.
    In case `num_tiles > B_P_SIZE`, we need further tile `num_tile` dimension.

    Returns an SBUF tensor of shape
    `(ceil(num_tiles / 128), par_dim(128), num_blocks_per_tile)`, int32.
    """
    B_P_SIZE = 128  # SBUF partition dimension size

    # reshape as `(num_tiles, num_blocks_per_tile)`
    assert len(block_tables_hbm.shape) == 1
    (num_total_blocks, ) = block_tables_hbm.shape
    assert num_blocks_per_tile * num_tiles == num_total_blocks
    block_tables_hbm = block_tables_hbm.reshape(
        (num_tiles, num_blocks_per_tile))

    # Zero-init so the tail partition (when num_tiles % 128 != 0) is defined;
    # the masked load below only fills valid rows.
    block_tables_sbuf = nl.zeros(
        (ceil_div(num_tiles, B_P_SIZE), par_dim(B_P_SIZE),
         num_blocks_per_tile),
        dtype=nl.int32,
    )
    for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)):
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(num_blocks_per_tile)[None, :]
        block_tables_sbuf[i, i_p, i_f] = nl.load(
            block_tables_hbm[i_p + i * B_P_SIZE, i_f],
            dtype=nl.int32,
            # mask out-of-range rows of the last 128-row chunk
            mask=(i_p + i * B_P_SIZE < num_tiles),
        )
    return block_tables_sbuf


@nki.jit
def transform_block_tables_for_indirect_load(
    block_tables,
    block_size_tiling_factor,
    num_head,
    head_id,
):
    """
    This function does two things:
    1. calculate new `block_tables` for a `head_id` after flattening
       `num_block`, `num_head`, and `block_size_tiling_factor` dimensions
    2. transpose the result so that `block_table` for each tile is mapped to
       SBUF Partition dimension for vectorized DMA

    Tiling trick to further improve DMA performance:

    Given KV cache shape `(num_block, num_head, block_size, D)`, when loading
    M blocks of a given `head_id` from HBM, the load
    `cache[block_tables, head_id]` has shape `(M, block_size, D)`.
    If M < B_P_SIZE = 128, DMA may not fully utilize hardware
    parallelization. The solution is to tile `block_size` into
    `(block_size_tiling_factor, tiled_block_size)` s.t.
    `M * block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has
    shape `(num_block, num_head, block_size_tiling_factor, tiled_block_size,
    D)`.

    Note: We don't further tile D dimension as small DMA size also hurts
    performance.
    """
    B_P_SIZE = 128
    num_partitions, num_tiles_per_partition, num_blocks_per_tile = (
        block_tables.shape)
    assert num_tiles_per_partition == B_P_SIZE
    assert is_power_of_2(
        num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2"

    num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE)
    # Output layout: (num_loads, par_dim(128), total_num_tiles) so that each
    # tile's block ids live on the partition dimension for indirect DMA.
    block_tables_transposed = nl.ndarray(
        (
            num_loads,
            par_dim(B_P_SIZE),
            num_partitions * num_tiles_per_partition,
        ),
        dtype=nl.int32,
    )

    # prepare iota ahead of time to avoid repeatedly using Gpsimd
    if num_head > 1:
        head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1))
        head_id = nl.transpose(
            head_id.broadcast_to((1, num_tiles_per_partition)))
        if num_blocks_per_tile > 1:
            head_id = head_id.broadcast_to(
                (num_tiles_per_partition, num_blocks_per_tile))
    if block_size_tiling_factor > 1:
        broadcast_shape = (
            num_tiles_per_partition,
            num_blocks_per_tile,
            block_size_tiling_factor,
        )
        # per-sub-block offsets 0..block_size_tiling_factor-1
        offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :],
                           dtype=nl.int32).broadcast_to(broadcast_shape)

    for partition_id in nl.affine_range(num_partitions):
        block_tables_partition = block_tables[partition_id]
        if num_head > 1:
            # fuse num_block and num_head dimension
            block_tables_partition = block_tables_partition * num_head + head_id

        if block_size_tiling_factor > 1:
            # need to apply block size tiling trick
            assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE
            block_tables_partition = ((block_tables_partition *
                                       block_size_tiling_factor).reshape(
                                           (num_tiles_per_partition,
                                            num_blocks_per_tile,
                                            1)).broadcast_to(broadcast_shape))
            new_block_tables = block_tables_partition + offset
            new_block_tables = new_block_tables.reshape(
                (num_tiles_per_partition, B_P_SIZE))
        else:
            new_block_tables = block_tables_partition

        # transpose the block table so that it can be used by vector DGE
        for i in nl.affine_range(num_loads):
            i_p = nl.arange(B_P_SIZE)[:, None]
            i_f = (partition_id * num_tiles_per_partition +
                   nl.arange(num_tiles_per_partition)[None, :])
            block_tables_transposed[i, i_p, i_f] = nl.transpose(
                new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)])
    return block_tables_transposed


@nki.jit
def load_kv_tile_from_cache(
    cur_k_tile,
    cur_v_tile,
    kv_cache,
    block_tables,
    large_k_tile_idx,
    num_blocks_per_large_tile,
    tiled_block_size,
    B_P_SIZE,
    B_D_SIZE,
):
    """
    Load KV cache and transform Key and Value into layout required by Matmul

    Vectorized DMA Load layout:
        Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)

    Layout used by attention matmuls:
        Key: (par_dim(B_D_SIZE), seqlen_kv)
        Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE)
            equivalent to
            (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
    """
    # load key cache
    num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
    for load_idx in nl.affine_range(num_loads):
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
        # indirect (gather) load: block_tables supplies per-partition row ids
        loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p,
                                                  large_k_tile_idx], i_f])
        if cur_k_tile.dtype != loaded.dtype:
            loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
        # Transpose SBUF tensor using PE
        for tb_i in nl.affine_range(tiled_block_size):
            cur_k_tile[
                :,
                nl.ds(
                    load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE,
                    B_P_SIZE,
                ),
            ] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)])

    # load value cache
    for load_idx in nl.affine_range(num_loads):
        # NOTE(review): this nl.load reuses `i_p`/`i_f` leaked from the key
        # loop above (they are redefined with identical values two lines
        # below). Works because both are loop-invariant index vectors, but
        # the definitions would be clearer hoisted before this load.
        loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p,
                                                  large_k_tile_idx], i_f])
        if cur_v_tile.dtype != loaded.dtype:
            loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
        # Value keeps the DMA layout; no transpose needed for the PV matmul.
        cur_v_tile[
            :,
            nl.ds(
                load_idx * tiled_block_size * B_D_SIZE,
                tiled_block_size * B_D_SIZE,
            ),
        ] = loaded


@nki.jit
def transpose_p_local(p_local_transposed,
                      p_local,
                      LARGE_TILE_SZ,
                      B_F_SIZE=512):
    """Transpose softmax probabilities in 128x128 chunks for the PV matmul.

    On gen3 hardware uses DMA transpose (result lands in SBUF); otherwise
    uses the tensor-engine transpose, which writes through PSUM.
    """
    for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
        if nisa.get_nc_version() == nisa.nc_version.gen3:
            p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
                                       buffer=nl.sbuf,
                                       dtype=p_local.dtype)
        else:
            p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
                                       buffer=nl.psum,
                                       dtype=np.float32)

        for j in nl.affine_range(B_F_SIZE // 128):
            j_128_slice = nl.ds(j * 128, 128)
            i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128)

            if nisa.get_nc_version() == nisa.nc_version.gen3:
                p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
                    p_local[:, i_j_128_slice])
            else:
                p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
                    p_local[:, i_j_128_slice])

        p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy(
            p_local_t_tmp, dtype=p_local_transposed.dtype)


@nki.jit
def _flash_attention_core(
    q_local_tile,
    k,
    v,
    o_buffer,
    l_buffer,
    m_buffer,
    kernel_dtype,
    acc_type,
    tile_mask,
    use_causal_mask,
    q_tile_idx=None,
    initialize=False,
    LARGE_TILE_SZ=2048,
    B_P_SIZE=128,
    B_F_SIZE=512,
    B_D_SIZE=128,
    qk_res_buffer=None,
):
    """
    The flash attention core function to calculate self attention between a
    tile of q and a block of K and V.

    The q_local_tile has (B_P_SIZE, B_D_SIZE)
    The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension
    will be split into size B_F_SIZE tiles

    The results are stored in the following three buffers
    o_buffer: (B_P_SIZE, d)
    l_buffer: (B_P_SIZE, 1)   -- running log-sum-exp statistic
    m_buffer: (B_P_SIZE, 1)   -- running row max

    All IO buffers are in SBUF.
    """
    num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE

    qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                            buffer=nl.sbuf,
                            dtype=acc_type)
    max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile),
                           dtype=acc_type)
    for k_i in nl.affine_range(num_k_tile_per_large_tile):
        k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)

        if use_causal_mask:
            # mask are used to only apply computation to the lower half of
            # the matrix, which reduce the arithmetic intensity by up to 50%
            multiplication_required_selection = (q_tile_idx * B_P_SIZE
                                                 >= k_i * B_F_SIZE)
        else:
            multiplication_required_selection = True

        if multiplication_required_selection:
            qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
                                 dtype=np.float32,
                                 buffer=nl.psum)  # (128, 512)
            qk_psum[:, :] = nl.matmul(q_local_tile,
                                      k[:, k_i_b_f_slice],
                                      transpose_x=True)  # (p(128), 512)
            # -9984.0 is the masked-out fill value: large enough in magnitude
            # that exp(x - max) underflows to ~0 for masked positions.
            qk_res_buf[:, k_i_b_f_slice] = nl.where(
                tile_mask[:, k_i_b_f_slice],
                qk_psum[:, nl.ds(0, B_F_SIZE)],
                -9984.0,
                dtype=acc_type,
            )
        else:
            # fully masked tile: skip the matmul entirely
            qk_res_buf[:, k_i_b_f_slice] = -9984.0

        # Calculate max of the current tile
        max_local[:, k_i] = nisa.tensor_reduce(
            np.max,
            qk_res_buf[:, k_i_b_f_slice],
            axis=(1, ),
            dtype=acc_type,
            negate=False,
        )

    if qk_res_buffer is not None:
        # debug output: raw masked QK scores for this tile
        qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :])

    # row max over all B_F_SIZE sub-tiles of this large tile
    max_ = nisa.tensor_reduce(
        np.max,
        max_local[:, :],
        axis=(1, ),
        dtype=acc_type,
        negate=False,
    )

    o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
                                   dtype=o_buffer.dtype)

    if initialize:
        # first large tile: running max is just this tile's max
        m_buffer[:, 0] = nl.copy(max_)
        m_current = max_
    else:
        m_previous = nl.copy(m_buffer[:, 0])
        m_buffer[:, 0] = nl.maximum(m_previous, max_)  # (128,1)

        m_current = m_buffer[:, 0]
        # Compute scaling factor alpha = exp(m_previous - m_current) to
        # rescale the previously accumulated output.
        alpha = nisa.activation(
            np.exp,
            m_previous,
            bias=-1 * m_current,
            scale=1.0,
        )
        o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)

    p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                         dtype=kernel_dtype)
    REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)

    p_partial_sum = nl.ndarray(
        (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE),
        dtype=acc_type,
    )

    for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
        k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)

        # compute exp(qk - max)
        # Compute partial row - tile sum of exp(qk - max))
        # FIXME : Use activation accumulate to accumulate over k_r_i loop ?
        p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce(
            np.exp,
            qk_res_buf[:, k_r_i_reduce_slice],
            bias=-1 * m_current,
            scale=1.0,
            reduce_op=nl.add,
            reduce_res=p_partial_sum[:, k_r_i],
            dtype=kernel_dtype,
        )

    # ps = row sum of exp(qk - max) over the whole large tile
    ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)

    p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                                    dtype=kernel_dtype)
    transpose_p_local(
        p_local_transposed=p_local_transposed,
        p_local=p_local,
        LARGE_TILE_SZ=LARGE_TILE_SZ,
        B_F_SIZE=B_F_SIZE,
    )

    pv_psum = nl.zeros(
        (par_dim(B_P_SIZE), B_D_SIZE),
        dtype=np.float32,
        buffer=nl.psum,
    )
    for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
        pv_psum[:, :] += nl.matmul(
            p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
            v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)],
            transpose_x=True,
        )  # (128, 128) (p(Br), d)

    if initialize:
        o_buffer[:, :] = nl.copy(pv_psum[:, :])
        # l = log(sum(exp(qk - max))) + max  (log-sum-exp)
        l_buffer[:, 0] = nl.add(nl.log(ps), max_)
    else:
        o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)

        # online log-sum-exp update:
        # l_new = m_current + log(exp(l_prev - m_current) + ps)
        l_prev = l_buffer[:, 0]
        l_exp = nl.add(
            nl.exp(nl.subtract(l_prev, m_current)),
            ps,
        )
        l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp))


@nki.jit
def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ):
    """Load one 128-row slice of V from HBM into the SBUF value tile.

    `v_hbm_tile` has shape (seqlen, d); the slice
    `[large_tile_idx * LARGE_TILE_SZ + 128 * v_i : +128, :]` is written to
    `cur_v_tile[:, v_i * d : (v_i + 1) * d]`, cast to the tile's dtype
    when needed.
    """
    B_P_SIZE = 128
    B_D_SIZE = v_hbm_tile.shape[-1]

    loaded = nl.load(v_hbm_tile[
        nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE),
        :,
    ])
    if cur_v_tile.dtype != loaded.dtype:
        loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
    cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded


@nki.jit
def flash_paged_attention(
    query,
    key,
    value,
    kv_cache,
    block_tables,
    mask,
    softmax_scale=None,
    mixed_precision=True,
    LARGE_TILE_SZ=2048,
    return_debug_tensors=False,
):
    """
    Flash PagedAttention Forward Kernel.

    IO tensor layouts:
      - query: shape   (1, n_heads, d, seq_q)
      - key:   shape   (1, n_kv_heads, d, seq_k)
      - value: shape   (1, n_kv_heads, seq_v, d)
      - kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
      - block_tables: (num_active_blocks, )
      - mask: (seq_q, num_active_blocks * block_size + seq_q)
      - o: shape (1, n_heads, seq_q, d)

      - This kernel requires seq_k == seq_v
      - We use continuous batching by default, so the batch dimension is
        always 1, and different requests are concatenated along sequence
        dimension.
      - We use paged cache blocks (kv_cache) to store KV cache.

    IO tensor dtypes:
      - This kernel assumes all IO tensors have the same dtype except for
        block_tables (int32) and mask (int32)
      - If mixed_precision is True, then all Tensor Engine operation will be
        performed in bfloat16 and accumulation will be performed in float32.
        Otherwise the intermediates will be in the same type as the inputs.

    Compile-time Constants:
      - softmax_scale: scaling for softmax; if None, default is
        `1.0/(d**0.5)`
      - mixed_precision: flag to set non-matmul ops in fp32 precision,
        default is set to `true`, if false, we use same precision as input
        types
      - LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention
        computation reduction

    GQA support Notes:
      the spmd kernel for launching kernel should be on kv_heads instead of
      nheads

    Example usage:
      MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
        usage: `flash_fwd[b, h](q, k, v, ...)`
      GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
        usage: `flash_fwd[b, kv_h](q, k, v, ...)`
    """
    B_F_SIZE = 512
    B_P_SIZE = 128
    b, h, d, seqlen_q = query.shape
    B_D_SIZE = d
    n_tile_q = seqlen_q // B_P_SIZE  # since q will be loaded on tensor engine
    _, num_blocks, k_h, block_size, _ = kv_cache.shape
    q_h_per_k_h = h // k_h
    assert b == 1, f"invalid batch size {b=}"
    assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}"
    cache_shape = (2, num_blocks, k_h, block_size, d)
    assert (tuple(kv_cache.shape) == cache_shape
            ), f"{kv_cache.shape=} mismatch, expect {cache_shape}"

    # key/value hold the *new* (uncached) tokens for this step; both may be
    # None when there is nothing beyond the cached context.
    assert key is None or tuple(key.shape) == (
        1,
        k_h,
        d,
        seqlen_q,
    ), f"key shape {key.shape} mismatch!"
    assert value is None or tuple(value.shape) == (
        1,
        k_h,
        seqlen_q,
        d,
    ), f"value shape {value.shape} mismatch!"
    assert (
        nl.program_ndim() == 2
    ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
    batch_id = nl.program_id(axis=0)
    head_id = nl.program_id(axis=1)

    (num_active_blocks, ) = block_tables.shape
    context_kv_len = num_active_blocks * block_size

    assert (
        LARGE_TILE_SZ % B_F_SIZE == 0
    ), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p"
    assert (context_kv_len % LARGE_TILE_SZ == 0
            ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"

    num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
    assert is_power_of_2(
        num_blocks_per_large_tile
    ), f"{num_blocks_per_large_tile=} is expected of be power of 2"

    if seqlen_q > B_F_SIZE:
        MAX_REDUCTION_TILE = 2048
        if seqlen_q // 2 > MAX_REDUCTION_TILE:
            assert (
                seqlen_q % MAX_REDUCTION_TILE == 0
            ), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}"
        else:
            assert (seqlen_q % B_F_SIZE == 0
                    ), f"{seqlen_q=} should be divisible by {B_F_SIZE=})"

    kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype
    acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype
    softmax_scale = softmax_scale or (1.0 / (d**0.5))
    num_large_k_tile = context_kv_len // LARGE_TILE_SZ

    o = nl.ndarray((b, h, seqlen_q, d),
                   dtype=query.dtype,
                   buffer=nl.shared_hbm)
    hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = (
        None,
        None,
        None,
        None,
    )
    if return_debug_tensors:
        hbm_l_buffer = nl.ndarray((b, h, seqlen_q),
                                  dtype=acc_type,
                                  buffer=nl.shared_hbm)
        hbm_m_buffer = nl.ndarray((b, h, seqlen_q),
                                  dtype=acc_type,
                                  buffer=nl.shared_hbm)
        hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q),
                                dtype=acc_type,
                                buffer=nl.shared_hbm)
        qk_res_buffer = nl.zeros(
            (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q),
            dtype=acc_type,
            buffer=nl.sbuf,
            lazy_initialization=True,
        )

    block_tables_sbuf = load_block_tables(
        block_tables_hbm=block_tables,
        num_tiles=num_large_k_tile,
        num_blocks_per_tile=num_blocks_per_large_tile,
    )

    # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
    if num_blocks_per_large_tile < B_P_SIZE:
        # we checked num_blocks_per_tile is a power of 2
        assert B_P_SIZE % num_blocks_per_large_tile == 0
        block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile
        # We assume block_size >= block_size_tiling_factor
        assert block_size % block_size_tiling_factor == 0
    else:
        block_size_tiling_factor = 1
    tiled_block_size = block_size // block_size_tiling_factor

    # Indirect DMA load must be placed along Partition Dimension
    block_tables_sbuf = transform_block_tables_for_indirect_load(
        block_tables_sbuf,
        block_size_tiling_factor=block_size_tiling_factor,
        num_head=k_h,
        head_id=head_id,
    )

    # Flatten KV cache to be 3D for loading into SBUF
    new_cache_shape = (
        2,
        num_blocks * k_h * block_size_tiling_factor,
        tiled_block_size * d,
    )
    kv_cache = kv_cache.reshape(new_cache_shape)

    # Global Flash Attention accumulators
    o_buffer = nl.zeros(
        (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d),
        dtype=acc_type,
        buffer=nl.sbuf,
        lazy_initialization=True,
    )
    l_buffer = nl.zeros(
        (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
        dtype=acc_type,
        buffer=nl.sbuf,
        lazy_initialization=True,
    )
    m_buffer = nl.zeros(
        (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
        dtype=acc_type,
        buffer=nl.sbuf,
        lazy_initialization=True,
    )

    # Phase 1: attend over the cached context, one large KV tile at a time.
    # sequential_range: each iteration updates the running (o, l, m) stats.
    for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
        num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
        cur_k_tile = nl.ndarray(
            (par_dim(B_D_SIZE), LARGE_TILE_SZ),
            dtype=kernel_dtype,
        )
        cur_v_tile = nl.ndarray(
            (par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE),
            dtype=kernel_dtype,
        )
        load_kv_tile_from_cache(
            cur_k_tile=cur_k_tile,
            cur_v_tile=cur_v_tile,
            kv_cache=kv_cache,
            block_tables=block_tables_sbuf,
            large_k_tile_idx=large_k_tile_idx,
            num_blocks_per_large_tile=num_blocks_per_large_tile,
            tiled_block_size=tiled_block_size,
            B_P_SIZE=B_P_SIZE,
            B_D_SIZE=B_D_SIZE,
        )

        for i in nl.affine_range(n_tile_q):
            cur_mask = nl.load(mask[
                nl.ds(i * B_P_SIZE, B_P_SIZE),
                nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ),
            ])
            for i_q_h in nl.affine_range(q_h_per_k_h):
                q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
                q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
                q_sbuf_tile = nl.load(q_hbm_tile[:,
                                                 nl.ds(i * B_P_SIZE,
                                                       B_P_SIZE)])
                if q_sbuf_tile.dtype != kernel_dtype:
                    q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
                # fold softmax scale into Q once, instead of scaling QK
                q_tile[:, :] = q_sbuf_tile * softmax_scale

                _flash_attention_core(
                    q_local_tile=q_tile,
                    k=cur_k_tile,
                    v=cur_v_tile,
                    o_buffer=o_buffer[i, i_q_h],
                    l_buffer=l_buffer[i, i_q_h],
                    m_buffer=m_buffer[i, i_q_h],
                    kernel_dtype=kernel_dtype,
                    acc_type=acc_type,
                    tile_mask=cur_mask,
                    use_causal_mask=False,
                    q_tile_idx=i,
                    initialize=large_k_tile_idx == 0,
                    LARGE_TILE_SZ=LARGE_TILE_SZ,
                    B_P_SIZE=B_P_SIZE,
                    B_F_SIZE=B_F_SIZE,
                    B_D_SIZE=B_D_SIZE,
                )

    # Phase 2: compute attention between input query, key and value
    # (the new tokens of this step, causal within themselves)
    if key is not None and value is not None:
        B_F_SIZE = min(seqlen_q, B_F_SIZE)
        LARGE_TILE_SZ = seqlen_q

        cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
                                dtype=kernel_dtype)
        cur_v_tile = nl.ndarray(
            (par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE),
            dtype=kernel_dtype,
        )

        loaded = nl.load(key[batch_id, head_id, :, :])
        if loaded.dtype != kernel_dtype:
            loaded = nl.copy(loaded, dtype=kernel_dtype)
        cur_k_tile[:, :] = loaded

        v_hbm_tile = value[batch_id, head_id]
        for v_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
            load_v_tile(
                v_hbm_tile=v_hbm_tile,
                cur_v_tile=cur_v_tile,
                large_tile_idx=0,
                v_i=v_i,
                LARGE_TILE_SZ=LARGE_TILE_SZ,
            )

        for i in nl.affine_range(n_tile_q):
            cur_mask = nl.load(mask[
                nl.ds(i * B_P_SIZE, B_P_SIZE),
                nl.ds(context_kv_len, LARGE_TILE_SZ),
            ])
            for i_q_h in nl.affine_range(q_h_per_k_h):
                q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
                q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
                q_sbuf_tile = nl.load(q_hbm_tile[:,
                                                 nl.ds(i * B_P_SIZE,
                                                       B_P_SIZE)])
                if q_sbuf_tile.dtype != kernel_dtype:
                    q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
                q_tile[:, :] = q_sbuf_tile * softmax_scale
                _flash_attention_core(
                    q_local_tile=q_tile,
                    k=cur_k_tile,
                    v=cur_v_tile,
                    o_buffer=o_buffer[i, i_q_h],
                    l_buffer=l_buffer[i, i_q_h],
                    m_buffer=m_buffer[i, i_q_h],
                    kernel_dtype=kernel_dtype,
                    acc_type=acc_type,
                    tile_mask=cur_mask,
                    use_causal_mask=True,
                    q_tile_idx=i,
                    initialize=False,
                    LARGE_TILE_SZ=LARGE_TILE_SZ,
                    B_P_SIZE=B_P_SIZE,
                    B_F_SIZE=B_F_SIZE,
                    B_D_SIZE=B_D_SIZE,
                    qk_res_buffer=(qk_res_buffer[i, i_q_h]
                                   if qk_res_buffer is not None else None),
                )

    # -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- #
    for i_q_h in nl.affine_range(q_h_per_k_h):
        for i in nl.affine_range(n_tile_q):
            # final normalization: o * exp(m - l) == o / sum(exp(qk - m))
            out = nl.multiply(
                o_buffer[i, i_q_h],
                nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]),
                dtype=kernel_dtype,
            )

            nl.store(
                o[
                    batch_id,
                    head_id * q_h_per_k_h + i_q_h,
                    nl.ds(i * B_P_SIZE, B_P_SIZE),
                    :,
                ],
                out,
            )

            # maximum and summation statistics
            if return_debug_tensors:
                nl.store(
                    hbm_m_buffer[
                        batch_id,
                        head_id * q_h_per_k_h + i_q_h,
                        nl.ds(i * B_P_SIZE, B_P_SIZE),
                    ],
                    m_buffer[i, i_q_h, :, :],
                )
                nl.store(
                    hbm_l_buffer[
                        batch_id,
                        head_id * q_h_per_k_h + i_q_h,
                        nl.ds(i * B_P_SIZE, B_P_SIZE),
                    ],
                    l_buffer[i, i_q_h],
                )
                # NOTE(review): indexes qk_res_buffer with batch_id (always 0
                # since b == 1) rather than the q-tile index `i`, and does not
                # vary with `i` — looks like only tile 0's scores are stored;
                # confirm intended for the debug path.
                nl.store(
                    hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :],
                    qk_res_buffer[batch_id, i_q_h, :, :],
                )

    if return_debug_tensors:
        return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res
    return o


def reorder_context_mask(mask: torch.Tensor, LARGE_TILE_SZ: int,
                         block_size: int) -> torch.Tensor:
    """
    Reorder the mask to make it compatible with the flash attention kernel.

    We vectorize KV cache read to improve DMA utilization. However, the
    layout that maximizes DMA bandwidth changes the order tokens are
    consumed.

    The token layout (inner 2 dimensions) after vectorized load is
    (B_P_SIZE, tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size`
    tokens. And each step the engine consumes a column (rather than a row)
    of B_P_SIZE tokens. Therefore, the tokens are visited in a strided way.
    To make sure mask matches the order tokens are consumed, we need to
    properly transpose mask.
    """
    total_query_len, total_seq_len = mask.shape
    context_kv_len = total_seq_len - total_query_len

    B_P_SIZE = 128
    assert (LARGE_TILE_SZ
            >= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}"
    num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size)
    tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks
    if tiled_block_size > 1:
        # Mask reordering is needed when tiled_block_size > 1
        device = mask.device
        mask = mask.cpu()
        # only the context portion is loaded via the tiled DMA layout;
        # the trailing seq_q columns (new tokens) keep their order
        context_mask = mask[:, :context_kv_len]
        context_mask = context_mask.view(
            total_query_len,
            context_kv_len // LARGE_TILE_SZ,
            num_tiled_blocks // B_P_SIZE,
            B_P_SIZE,
            tiled_block_size,
        )
        # swap (B_P_SIZE, tiled_block_size) to match column-major consumption
        context_mask = context_mask.transpose(3, 4).reshape(
            total_query_len, context_kv_len)
        new_mask = mask[:, context_kv_len:]
        return torch.concat([context_mask, new_mask], dim=1).to(device)
    else:
        return mask


def flash_attn_varlen_nkifunc(
    query,
    key,
    value,
    kv_cache,
    block_table,
    attn_mask,
    n_kv_head=None,
    head_size=None,
    LARGE_TILE_SZ=2048,
    mixed_precision=True,
):
    """
    Compute flash paged attention for variable length sequences.

    This function is a wrapper around the flash attention NKI kernel. It
    takes in the following arguments:
      - query: (1, n_heads, d, seq_q)
      - key:   (1, n_kv_heads, d, seq_k)
      - value: (1, n_kv_heads, seq_v, d)
      - kv_cache: (2, n_blocks, n_kv_heads, block_size, d)
      - block_tables: (n_active_blocks, )
      - attn_mask: (seq_q, n_active_blocks * block_size + seq_q)

    Notes:
      - attn_mask must be reordered outside using `reorder_context_mask`
      - Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d)
        for better DMA throughput
    """
    if n_kv_head is None:
        n_kv_head = kv_cache.shape[2]
    assert kv_cache.shape[0] == 2
    assert kv_cache.shape[2] == n_kv_head
    if head_size is None:
        head_size = kv_cache.shape[-1]

    kwargs = dict(
        query=query,
        key=key,
        value=value,
        kv_cache=kv_cache,
        block_tables=block_table,
        mask=attn_mask,
        softmax_scale=1.0 / (head_size**0.5),
        mixed_precision=mixed_precision,
        LARGE_TILE_SZ=LARGE_TILE_SZ,
    )

    # launch the SPMD kernel on a (batch=1, n_kv_head) grid (GQA convention)
    o = flash_paged_attention[1, n_kv_head](**kwargs)
    return o


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """
    Writes key-value pairs to the KV cache at specified positions.

    Args:
        key (torch.Tensor): Key tensor with shape
            (num_tokens, n_kv_head, d_head)
        value (torch.Tensor): Value tensor with shape
            (num_tokens, n_kv_head, d_head)
        kv_cache (torch.Tensor): Key/value cache tensor with shape
            (2, num_blocks, n_kv_head, block_size, d_head)
        slot_mapping (torch.Tensor): Mapping tensor indicating cache
            positions with shape (num_tokens)

    Returns:
        None: Updates the kv_cache tensor in-place
    """
    block_size = kv_cache.size(3)
    n_kv_head = key.size(1)

    # Calculate indices with explicit floor division
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_offsets = slot_mapping % block_size

    # Create the head indices tensor
    head_indices = torch.arange(n_kv_head, device=key.device)

    # Update caches using index_put_ (broadcasts token x head index grids)
    kv_cache.index_put_(
        (torch.tensor([0],
                      device=key.device), block_indices[:, None],
         head_indices[None, :], block_offsets[:, None]), key)
    kv_cache.index_put_(
        (torch.tensor([1],
                      device=key.device), block_indices[:, None],
         head_indices[None, :], block_offsets[:, None]), value)