# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np
import torch
from neuronxcc import nki
from neuronxcc.nki.language import par_dim


def ceil_div(a, b):
    return (a + b - 1) // b


def is_power_of_2(x):
    return x > 0 and (x & (x - 1)) == 0


@nki.jit
def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
    """
    Load block tables from HBM into SRAM.

    `block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`.
    If `num_tiles > B_P_SIZE`, we further tile the `num_tiles` dimension.
    """
    B_P_SIZE = 128

    # reshape as `(num_tiles, num_blocks_per_tile)`
    assert len(block_tables_hbm.shape) == 1
    (num_total_blocks, ) = block_tables_hbm.shape
    assert num_blocks_per_tile * num_tiles == num_total_blocks
    block_tables_hbm = block_tables_hbm.reshape(
        (num_tiles, num_blocks_per_tile))

    block_tables_sbuf = nl.zeros(
        (ceil_div(num_tiles,
                  B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
        dtype=nl.int32,
    )
    for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)):
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(num_blocks_per_tile)[None, :]
        block_tables_sbuf[i, i_p, i_f] = nl.load(
            block_tables_hbm[i_p + i * B_P_SIZE, i_f],
            dtype=nl.int32,
            mask=(i_p + i * B_P_SIZE < num_tiles),
        )
    return block_tables_sbuf


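# Illustrative sketch (not used by the kernel): how `load_block_tables` lays a
# flat block table out as (num_partitions, B_P_SIZE, num_blocks_per_tile),
# zero-padding the partition dimension when num_tiles < B_P_SIZE. This
# `_example_` helper is our own addition and uses plain numpy so it runs
# without Neuron hardware.
def _example_block_table_layout(num_tiles=4, num_blocks_per_tile=16):
    B_P_SIZE = 128
    flat = np.arange(num_tiles * num_blocks_per_tile, dtype=np.int32)
    table = flat.reshape(num_tiles, num_blocks_per_tile)
    num_partitions = ceil_div(num_tiles, B_P_SIZE)
    sbuf = np.zeros((num_partitions, B_P_SIZE, num_blocks_per_tile), np.int32)
    for i in range(num_partitions):
        rows = min(B_P_SIZE, num_tiles - i * B_P_SIZE)
        sbuf[i, :rows, :] = table[i * B_P_SIZE:i * B_P_SIZE + rows, :]
    return sbuf  # (1, 128, 16) for the defaults above

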
@nki.jit
def transform_block_tables_for_indirect_load(
    block_tables,
    block_size_tiling_factor,
    num_head,
    head_id,
):
    """
    This function does two things:
    1. calculate a new `block_tables` for a given `head_id` after flattening
       the `num_block`, `num_head`, and `block_size_tiling_factor` dimensions
    2. transpose the result so that the `block_table` for each tile is mapped
       to the SBUF partition dimension for vectorized DMA

    Tiling trick to further improve DMA performance:
    Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M
    blocks of a given `head_id` from HBM, the load `cache[block_tables,
    head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may
    not fully utilize hardware parallelization. The solution is to tile
    `block_size` into `(block_size_tiling_factor, tiled_block_size)` s.t.
    `M * block_size_tiling_factor = B_P_SIZE`. After tiling, the KV cache has
    shape `(num_block, num_head, block_size_tiling_factor, tiled_block_size,
    D)`.

    Note:
    We don't further tile the D dimension, as a small DMA size also hurts
    performance.
    """
    B_P_SIZE = 128
    num_partitions, num_tiles_per_partition, num_blocks_per_tile = (
        block_tables.shape)
    assert num_tiles_per_partition == B_P_SIZE
    assert is_power_of_2(
        num_blocks_per_tile), f"{num_blocks_per_tile=} is not a power of 2"

    num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE)
    block_tables_transposed = nl.ndarray(
        (
            num_loads,
            par_dim(B_P_SIZE),
            num_partitions * num_tiles_per_partition,
        ),
        dtype=nl.int32,
    )

    # prepare iota ahead of time to avoid repeatedly using Gpsimd
    if num_head > 1:
        head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1))
        head_id = nl.transpose(
            head_id.broadcast_to((1, num_tiles_per_partition)))
        if num_blocks_per_tile > 1:
            head_id = head_id.broadcast_to(
                (num_tiles_per_partition, num_blocks_per_tile))

    if block_size_tiling_factor > 1:
        broadcast_shape = (
            num_tiles_per_partition,
            num_blocks_per_tile,
            block_size_tiling_factor,
        )
        offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :],
                           dtype=nl.int32).broadcast_to(broadcast_shape)

    for partition_id in nl.affine_range(num_partitions):
        block_tables_partition = block_tables[partition_id]
        if num_head > 1:
            # fuse num_block and num_head dimensions
            block_tables_partition = block_tables_partition * num_head + head_id

        if block_size_tiling_factor > 1:
            # need to apply the block size tiling trick
            assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE
            block_tables_partition = ((block_tables_partition *
                                       block_size_tiling_factor).reshape(
                                           (num_tiles_per_partition,
                                            num_blocks_per_tile,
                                            1)).broadcast_to(broadcast_shape))
            new_block_tables = block_tables_partition + offset
            new_block_tables = new_block_tables.reshape(
                (num_tiles_per_partition, B_P_SIZE))
        else:
            new_block_tables = block_tables_partition

        # transpose the block table so that it can be used by vector DGE
        for i in nl.affine_range(num_loads):
            i_p = nl.arange(B_P_SIZE)[:, None]
            i_f = (partition_id * num_tiles_per_partition +
                   nl.arange(num_tiles_per_partition)[None, :])
            block_tables_transposed[i, i_p, i_f] = nl.transpose(
                new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)])
    return block_tables_transposed


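# Illustrative sketch (not used by the kernel): the flattened row index that
# the transform above produces for one KV block. With the cache reshaped to
# (2, num_blocks * num_head * block_size_tiling_factor, tiled_block_size * D),
# block `b` of head `h` with intra-block tile `t` lands at row
# (b * num_head + h) * block_size_tiling_factor + t. This helper is our own
# addition and only demonstrates the index arithmetic.
def _example_flattened_block_row(block_id, head_id, tile_offset, num_head,
                                 block_size_tiling_factor):
    return (block_id * num_head +
            head_id) * block_size_tiling_factor + tile_offset


# e.g. with num_head=8 and block_size_tiling_factor=4, block 5 of head 2 with
# tile offset 3 maps to row (5 * 8 + 2) * 4 + 3 == 171.

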
@nki.jit
def load_kv_tile_from_cache(
    cur_k_tile,
    cur_v_tile,
    kv_cache,
    block_tables,
    large_k_tile_idx,
    num_blocks_per_large_tile,
    tiled_block_size,
    B_P_SIZE,
    B_D_SIZE,
):
    """
    Load the KV cache and transform Key and Value into the layout required by
    matmul.

    Vectorized DMA load layout:
        Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)

    Layout used by the attention matmuls:
        Key: (par_dim(B_D_SIZE), seqlen_kv)
        Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE),
            equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
    """
    # load key cache
    num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
    for load_idx in nl.affine_range(num_loads):
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
        loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p,
                                                  large_k_tile_idx], i_f])
        if cur_k_tile.dtype != loaded.dtype:
            loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
        # Transpose SBUF tensor using PE
        for tb_i in nl.affine_range(tiled_block_size):
            cur_k_tile[
                :,
                nl.ds(
                    load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE,
                    B_P_SIZE,
                ),
            ] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)])

    # load value cache
    for load_idx in nl.affine_range(num_loads):
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
        loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p,
                                                  large_k_tile_idx], i_f])
        if cur_v_tile.dtype != loaded.dtype:
            loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
        cur_v_tile[
            :,
            nl.ds(
                load_idx * tiled_block_size * B_D_SIZE,
                tiled_block_size * B_D_SIZE,
            ),
        ] = loaded


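# Illustrative sketch (not used by the kernel): how the key-cache transpose
# above rearranges one DMA-loaded tile. A load of shape
# (B_P_SIZE, tiled_block_size * B_D_SIZE) is split into `tiled_block_size`
# chunks of (B_P_SIZE, B_D_SIZE); each chunk is transposed and placed side by
# side, so the key tile ends up as (B_D_SIZE, B_P_SIZE * tiled_block_size).
# This `_example_` helper is our own addition and uses plain numpy.
def _example_key_tile_transpose(loaded, tiled_block_size, B_D_SIZE):
    # loaded: (B_P_SIZE, tiled_block_size * B_D_SIZE), as produced by one DMA
    assert loaded.shape[1] == tiled_block_size * B_D_SIZE
    chunks = [
        loaded[:, tb_i * B_D_SIZE:(tb_i + 1) * B_D_SIZE].T
        for tb_i in range(tiled_block_size)
    ]
    # result: (B_D_SIZE, B_P_SIZE * tiled_block_size)
    return np.concatenate(chunks, axis=1)

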
@nki.jit
def transpose_p_local(p_local_transposed,
                      p_local,
                      LARGE_TILE_SZ,
                      B_F_SIZE=512):
    for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
        if nisa.get_nc_version() == nisa.nc_version.gen3:
            p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
                                       buffer=nl.sbuf,
                                       dtype=p_local.dtype)
        else:
            p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
                                       buffer=nl.psum,
                                       dtype=np.float32)

        for j in nl.affine_range(B_F_SIZE // 128):
            j_128_slice = nl.ds(j * 128, 128)
            i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128)

            if nisa.get_nc_version() == nisa.nc_version.gen3:
                p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
                    p_local[:, i_j_128_slice])
            else:
                p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
                    p_local[:, i_j_128_slice])

        p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy(
            p_local_t_tmp, dtype=p_local_transposed.dtype)


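# Illustrative sketch (not used by the kernel): `transpose_p_local` transposes
# the softmax tile one 128x128 block at a time, i.e. every 128-wide column
# block is transposed in place within its B_F_SIZE stripe. This `_example_`
# helper is our own addition and reproduces the data movement with numpy.
def _example_blockwise_transpose(p_local, B_F_SIZE=512):
    B_P_SIZE, LARGE_TILE_SZ = p_local.shape
    assert B_P_SIZE == 128 and LARGE_TILE_SZ % B_F_SIZE == 0
    out = np.empty_like(p_local)
    for i in range(LARGE_TILE_SZ // B_F_SIZE):
        for j in range(B_F_SIZE // B_P_SIZE):
            col = slice(i * B_F_SIZE + j * B_P_SIZE,
                        i * B_F_SIZE + (j + 1) * B_P_SIZE)
            out[:, col] = p_local[:, col].T
    return out

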
@nki.jit
def _flash_attention_core(
    q_local_tile,
    k,
    v,
    o_buffer,
    l_buffer,
    m_buffer,
    kernel_dtype,
    acc_type,
    tile_mask,
    use_causal_mask,
    q_tile_idx=None,
    initialize=False,
    LARGE_TILE_SZ=2048,
    B_P_SIZE=128,
    B_F_SIZE=512,
    B_D_SIZE=128,
    qk_res_buffer=None,
):
    """
    The flash attention core function to calculate self attention between a
    tile of q and a block of K and V.
    q_local_tile has shape (B_P_SIZE, B_D_SIZE).
    K has shape (B_D_SIZE, LARGE_TILE_SZ) and V has shape
    (B_P_SIZE, LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE); the free dimension of K
    is split into tiles of size B_F_SIZE.

    The results are stored in the following three buffers:
        o_buffer: (B_P_SIZE, d)
        l_buffer: (B_P_SIZE, 1)
        m_buffer: (B_P_SIZE, 1)

    All IO buffers are in SBUF.
    """
    num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE

    qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                            buffer=nl.sbuf,
                            dtype=acc_type)
    max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile),
                           dtype=acc_type)
    for k_i in nl.affine_range(num_k_tile_per_large_tile):
        k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)

        if use_causal_mask:
            # the mask is used to apply computation only to the lower half of
            # the matrix, which reduces the arithmetic intensity by up to 50%
            multiplication_required_selection = (q_tile_idx * B_P_SIZE
                                                 >= k_i * B_F_SIZE)
        else:
            multiplication_required_selection = True

        if multiplication_required_selection:
            qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
                                 dtype=np.float32,
                                 buffer=nl.psum)  # (128, 512)
            qk_psum[:, :] = nl.matmul(q_local_tile,
                                      k[:, k_i_b_f_slice],
                                      transpose_x=True)  # (p(128), 512)
            qk_res_buf[:, k_i_b_f_slice] = nl.where(
                tile_mask[:, k_i_b_f_slice],
                qk_psum[:, nl.ds(0, B_F_SIZE)],
                -9984.0,
                dtype=acc_type,
            )
        else:
            qk_res_buf[:, k_i_b_f_slice] = -9984.0

        # Calculate max of the current tile
        max_local[:, k_i] = nisa.tensor_reduce(
            np.max,
            qk_res_buf[:, k_i_b_f_slice],
            axis=(1, ),
            dtype=acc_type,
            negate=False,
        )

    if qk_res_buffer is not None:
        qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :])

    max_ = nisa.tensor_reduce(
        np.max,
        max_local[:, :],
        axis=(1, ),
        dtype=acc_type,
        negate=False,
    )

    o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
                                   dtype=o_buffer.dtype)

    if initialize:
        m_buffer[:, 0] = nl.copy(max_)
        m_current = max_
    else:
        m_previous = nl.copy(m_buffer[:, 0])
        m_buffer[:, 0] = nl.maximum(m_previous, max_)  # (128, 1)

        m_current = m_buffer[:, 0]
        # Compute scaling factor
        alpha = nisa.activation(
            np.exp,
            m_previous,
            bias=-1 * m_current,
            scale=1.0,
        )
        o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)

    p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                         dtype=kernel_dtype)
    REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)

    p_partial_sum = nl.ndarray(
        (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE),
        dtype=acc_type,
    )

    for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
        k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)

        # compute exp(qk - max)
        # compute the partial row-tile sum of exp(qk - max)
        # FIXME: use activation accumulate to accumulate over the k_r_i loop?
        p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce(
            np.exp,
            qk_res_buf[:, k_r_i_reduce_slice],
            bias=-1 * m_current,
            scale=1.0,
            reduce_op=nl.add,
            reduce_res=p_partial_sum[:, k_r_i],
            dtype=kernel_dtype,
        )

    ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)

    p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                                    dtype=kernel_dtype)
    transpose_p_local(
        p_local_transposed=p_local_transposed,
        p_local=p_local,
        LARGE_TILE_SZ=LARGE_TILE_SZ,
        B_F_SIZE=B_F_SIZE,
    )

    pv_psum = nl.zeros(
        (par_dim(B_P_SIZE), B_D_SIZE),
        dtype=np.float32,
        buffer=nl.psum,
    )
    for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
        pv_psum[:, :] += nl.matmul(
            p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
            v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)],
            transpose_x=True,
        )  # (128, 128) (p(Br), d)

    if initialize:
        o_buffer[:, :] = nl.copy(pv_psum[:, :])
        l_buffer[:, 0] = nl.add(nl.log(ps), max_)
    else:
        o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)

        l_prev = l_buffer[:, 0]
        l_exp = nl.add(
            nl.exp(nl.subtract(l_prev, m_current)),
            ps,
        )
        l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp))


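# Reference sketch (not used by the kernel): the online-softmax update that
# `_flash_attention_core` performs per KV tile, written with plain numpy.
# `o` is the unnormalized accumulator, `m` the running row max and `l` the
# running log-sum-exp; after the last tile the normalized output is
# o * exp(m - l), matching the store at the end of `flash_paged_attention`.
# This helper is our own addition for clarity.
def _reference_online_softmax_update(o, l, m, s_tile, v_tile, initialize):
    m_tile = s_tile.max(axis=-1, keepdims=True)
    if initialize:
        m_new = m_tile
        p = np.exp(s_tile - m_new)
        o_new = p @ v_tile
        l_new = m_new + np.log(p.sum(axis=-1, keepdims=True))
    else:
        m_new = np.maximum(m, m_tile)
        p = np.exp(s_tile - m_new)
        # rescale the old accumulator to the new max, then add this tile
        o_new = o * np.exp(m - m_new) + p @ v_tile
        l_new = m_new + np.log(
            np.exp(l - m_new) + p.sum(axis=-1, keepdims=True))
    return o_new, l_new, m_new

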
@nki.jit
def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ):
    B_P_SIZE = 128
    B_D_SIZE = v_hbm_tile.shape[-1]
    loaded = nl.load(v_hbm_tile[
        nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE),
        :,
    ])
    if cur_v_tile.dtype != loaded.dtype:
        loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
    cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded


@nki.jit
def flash_paged_attention(
    query,
    key,
    value,
    kv_cache,
    block_tables,
    mask,
    softmax_scale=None,
    mixed_precision=True,
    LARGE_TILE_SZ=2048,
    return_debug_tensors=False,
):
    """
    Flash PagedAttention Forward Kernel.

    IO tensor layouts:
      - query: shape (1, n_heads, d, seq_q)
      - key: shape (1, n_kv_heads, d, seq_k)
      - value: shape (1, n_kv_heads, seq_v, d)
      - kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
      - block_tables: (num_active_blocks, )
      - mask: (seq_q, num_active_blocks * block_size + seq_q)
      - o: shape (1, n_heads, seq_q, d)

      - This kernel requires seq_k == seq_v
      - We use continuous batching by default, so the batch dimension is
        always 1, and different requests are concatenated along the sequence
        dimension.
      - We use paged cache blocks (kv_cache) to store the KV cache.

    IO tensor dtypes:
      - This kernel assumes all IO tensors have the same dtype except for
        block_tables (int32) and mask (int32)
      - If mixed_precision is True, then all Tensor Engine operations will be
        performed in bfloat16 and accumulation will be performed in float32.
        Otherwise the intermediates will be in the same type as the inputs.

    Compile-time constants:
      - softmax_scale: scaling for softmax; if None, defaults to
        `1.0 / (d ** 0.5)`
      - mixed_precision: flag to run non-matmul ops in fp32 precision;
        defaults to True. If False, we use the same precision as the input
        types.
      - LARGE_TILE_SZ: default 2048, size of the KV tile used for the
        attention reduction.

    GQA support notes:
      The spmd launch grid should be over kv_heads instead of n_heads.

    Example usage:
      MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
        usage: `flash_fwd[b, h](q, k, v, ...)`
      GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
        usage: `flash_fwd[b, kv_h](q, k, v, ...)`
    """
    B_F_SIZE = 512
    B_P_SIZE = 128
    b, h, d, seqlen_q = query.shape
    B_D_SIZE = d
    n_tile_q = seqlen_q // B_P_SIZE  # since q will be loaded on tensor engine
    _, num_blocks, k_h, block_size, _ = kv_cache.shape
    q_h_per_k_h = h // k_h
    assert b == 1, f"invalid batch size {b=}"
    assert d <= 128, f"we do not support head_dim > 128, got head dim {d=}"
    cache_shape = (2, num_blocks, k_h, block_size, d)
    assert (tuple(kv_cache.shape) == cache_shape
            ), f"{kv_cache.shape=} mismatch, expect {cache_shape}"
    assert key is None or tuple(key.shape) == (
        1,
        k_h,
        d,
        seqlen_q,
    ), f"key shape {key.shape} mismatch!"
    assert value is None or tuple(value.shape) == (
        1,
        k_h,
        seqlen_q,
        d,
    ), f"value shape {value.shape} mismatch!"

    assert (
        nl.program_ndim() == 2
    ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
    batch_id = nl.program_id(axis=0)
    head_id = nl.program_id(axis=1)

    (num_active_blocks, ) = block_tables.shape
    context_kv_len = num_active_blocks * block_size
    assert (
        LARGE_TILE_SZ % B_F_SIZE == 0
    ), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p"
    assert (context_kv_len % LARGE_TILE_SZ == 0
            ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"

    num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
    assert is_power_of_2(
        num_blocks_per_large_tile
    ), f"{num_blocks_per_large_tile=} is expected to be a power of 2"
    if seqlen_q > B_F_SIZE:
        MAX_REDUCTION_TILE = 2048
        if seqlen_q // 2 > MAX_REDUCTION_TILE:
            assert (
                seqlen_q % MAX_REDUCTION_TILE == 0
            ), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}"
        else:
            assert (seqlen_q % B_F_SIZE == 0
                    ), f"{seqlen_q=} should be divisible by {B_F_SIZE=}"

    kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype
    acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype
    softmax_scale = softmax_scale or (1.0 / (d**0.5))
    num_large_k_tile = context_kv_len // LARGE_TILE_SZ

    o = nl.ndarray((b, h, seqlen_q, d),
                   dtype=query.dtype,
                   buffer=nl.shared_hbm)
    hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = (
        None,
        None,
        None,
        None,
    )
    if return_debug_tensors:
        hbm_l_buffer = nl.ndarray((b, h, seqlen_q),
                                  dtype=acc_type,
                                  buffer=nl.shared_hbm)
        hbm_m_buffer = nl.ndarray((b, h, seqlen_q),
                                  dtype=acc_type,
                                  buffer=nl.shared_hbm)
        hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q),
                                dtype=acc_type,
                                buffer=nl.shared_hbm)
        qk_res_buffer = nl.zeros(
            (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q),
            dtype=acc_type,
            buffer=nl.sbuf,
            lazy_initialization=True,
        )
    block_tables_sbuf = load_block_tables(
        block_tables_hbm=block_tables,
        num_tiles=num_large_k_tile,
        num_blocks_per_tile=num_blocks_per_large_tile,
    )

    # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
    if num_blocks_per_large_tile < B_P_SIZE:
        # we checked num_blocks_per_tile is a power of 2
        assert B_P_SIZE % num_blocks_per_large_tile == 0
        block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile
        # We assume block_size >= block_size_tiling_factor
        assert block_size % block_size_tiling_factor == 0
    else:
        block_size_tiling_factor = 1
    tiled_block_size = block_size // block_size_tiling_factor

    # Indirect DMA load must be placed along Partition Dimension
    block_tables_sbuf = transform_block_tables_for_indirect_load(
        block_tables_sbuf,
        block_size_tiling_factor=block_size_tiling_factor,
        num_head=k_h,
        head_id=head_id,
    )

    # Flatten KV cache to be 3D for loading into SBUF
    new_cache_shape = (
        2,
        num_blocks * k_h * block_size_tiling_factor,
        tiled_block_size * d,
    )
    kv_cache = kv_cache.reshape(new_cache_shape)

    # Global Flash Attention accumulators
    o_buffer = nl.zeros(
        (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d),
        dtype=acc_type,
        buffer=nl.sbuf,
        lazy_initialization=True,
    )
    l_buffer = nl.zeros(
        (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
        dtype=acc_type,
        buffer=nl.sbuf,
        lazy_initialization=True,
    )
    m_buffer = nl.zeros(
        (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
        dtype=acc_type,
        buffer=nl.sbuf,
        lazy_initialization=True,
    )

    for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
        num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
        cur_k_tile = nl.ndarray(
            (par_dim(B_D_SIZE), LARGE_TILE_SZ),
            dtype=kernel_dtype,
        )
        cur_v_tile = nl.ndarray(
            (par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE),
            dtype=kernel_dtype,
        )
        load_kv_tile_from_cache(
            cur_k_tile=cur_k_tile,
            cur_v_tile=cur_v_tile,
            kv_cache=kv_cache,
            block_tables=block_tables_sbuf,
            large_k_tile_idx=large_k_tile_idx,
            num_blocks_per_large_tile=num_blocks_per_large_tile,
            tiled_block_size=tiled_block_size,
            B_P_SIZE=B_P_SIZE,
            B_D_SIZE=B_D_SIZE,
        )

        for i in nl.affine_range(n_tile_q):
            cur_mask = nl.load(mask[
                nl.ds(i * B_P_SIZE, B_P_SIZE),
                nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ),
            ])
            for i_q_h in nl.affine_range(q_h_per_k_h):
                q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
                q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
                q_sbuf_tile = nl.load(
                    q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)])
                if q_sbuf_tile.dtype != kernel_dtype:
                    q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
                q_tile[:, :] = q_sbuf_tile * softmax_scale

                _flash_attention_core(
                    q_local_tile=q_tile,
                    k=cur_k_tile,
                    v=cur_v_tile,
                    o_buffer=o_buffer[i, i_q_h],
                    l_buffer=l_buffer[i, i_q_h],
                    m_buffer=m_buffer[i, i_q_h],
                    kernel_dtype=kernel_dtype,
                    acc_type=acc_type,
                    tile_mask=cur_mask,
                    use_causal_mask=False,
                    q_tile_idx=i,
                    initialize=large_k_tile_idx == 0,
                    LARGE_TILE_SZ=LARGE_TILE_SZ,
                    B_P_SIZE=B_P_SIZE,
                    B_F_SIZE=B_F_SIZE,
                    B_D_SIZE=B_D_SIZE,
                )

    # compute attention between input query, key and value
    if key is not None and value is not None:
        B_F_SIZE = min(seqlen_q, B_F_SIZE)
        LARGE_TILE_SZ = seqlen_q

        cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
                                dtype=kernel_dtype)
        cur_v_tile = nl.ndarray(
            (par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE),
            dtype=kernel_dtype,
        )

        loaded = nl.load(key[batch_id, head_id, :, :])
        if loaded.dtype != kernel_dtype:
            loaded = nl.copy(loaded, dtype=kernel_dtype)
        cur_k_tile[:, :] = loaded

        v_hbm_tile = value[batch_id, head_id]
        for v_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
            load_v_tile(
                v_hbm_tile=v_hbm_tile,
                cur_v_tile=cur_v_tile,
                large_tile_idx=0,
                v_i=v_i,
                LARGE_TILE_SZ=LARGE_TILE_SZ,
            )

        for i in nl.affine_range(n_tile_q):
            cur_mask = nl.load(mask[
                nl.ds(i * B_P_SIZE, B_P_SIZE),
                nl.ds(context_kv_len, LARGE_TILE_SZ),
            ])
            for i_q_h in nl.affine_range(q_h_per_k_h):

                q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
                q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
                q_sbuf_tile = nl.load(
                    q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)])
                if q_sbuf_tile.dtype != kernel_dtype:
                    q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
                q_tile[:, :] = q_sbuf_tile * softmax_scale
                _flash_attention_core(
                    q_local_tile=q_tile,
                    k=cur_k_tile,
                    v=cur_v_tile,
                    o_buffer=o_buffer[i, i_q_h],
                    l_buffer=l_buffer[i, i_q_h],
                    m_buffer=m_buffer[i, i_q_h],
                    kernel_dtype=kernel_dtype,
                    acc_type=acc_type,
                    tile_mask=cur_mask,
                    use_causal_mask=True,
                    q_tile_idx=i,
                    initialize=False,
                    LARGE_TILE_SZ=LARGE_TILE_SZ,
                    B_P_SIZE=B_P_SIZE,
                    B_F_SIZE=B_F_SIZE,
                    B_D_SIZE=B_D_SIZE,
                    qk_res_buffer=(qk_res_buffer[i, i_q_h]
                                   if qk_res_buffer is not None else None),
                )

    # -------- write output to buffer on HBM -------- #
    for i_q_h in nl.affine_range(q_h_per_k_h):
        for i in nl.affine_range(n_tile_q):
            out = nl.multiply(
                o_buffer[i, i_q_h],
                nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]),
                dtype=kernel_dtype,
            )

            nl.store(
                o[
                    batch_id,
                    head_id * q_h_per_k_h + i_q_h,
                    nl.ds(i * B_P_SIZE, B_P_SIZE),
                    :,
                ],
                out,
            )
            # maximum and summation statistics
            if return_debug_tensors:
                nl.store(
                    hbm_m_buffer[
                        batch_id,
                        head_id * q_h_per_k_h + i_q_h,
                        nl.ds(i * B_P_SIZE, B_P_SIZE),
                    ],
                    m_buffer[i, i_q_h, :, :],
                )
                nl.store(
                    hbm_l_buffer[
                        batch_id,
                        head_id * q_h_per_k_h + i_q_h,
                        nl.ds(i * B_P_SIZE, B_P_SIZE),
                    ],
                    l_buffer[i, i_q_h],
                )
                nl.store(
                    hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :],
                    qk_res_buffer[batch_id, i_q_h, :, :],
                )

    if return_debug_tensors:
        return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res
    return o


def reorder_context_mask(mask, LARGE_TILE_SZ, block_size):
    """
    Reorder the mask to make it compatible with the flash attention kernel.

    We vectorize KV cache reads to improve DMA utilization. However, the
    layout that maximizes DMA bandwidth changes the order in which tokens are
    consumed.

    The token layout (inner 2 dimensions) after a vectorized load is
    (B_P_SIZE, tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size`
    tokens, and at each step the engine consumes a column (rather than a row)
    of B_P_SIZE tokens. Therefore, the tokens are visited in a strided order.

    To make sure the mask matches the order in which tokens are consumed, we
    need to transpose it accordingly.
    """
    total_query_len, total_seq_len = mask.shape
    context_kv_len = total_seq_len - total_query_len

    B_P_SIZE = 128
    assert (LARGE_TILE_SZ
            >= B_P_SIZE), f"{LARGE_TILE_SZ=} must be at least {B_P_SIZE=}"
    num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size)
    tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks
    if tiled_block_size > 1:
        # Mask reordering is needed when tiled_block_size > 1
        device = mask.device
        mask = mask.cpu()
        context_mask = mask[:, :context_kv_len]
        context_mask = context_mask.view(
            total_query_len,
            context_kv_len // LARGE_TILE_SZ,
            num_tiled_blocks // B_P_SIZE,
            B_P_SIZE,
            tiled_block_size,
        )
        context_mask = context_mask.transpose(3, 4).reshape(
            total_query_len, context_kv_len)
        new_mask = mask[:, context_kv_len:]
        return torch.concat([context_mask, new_mask], dim=1).to(device)
    else:
        return mask


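# Usage sketch (our own addition, not called anywhere in this module): build a
# toy mask and reorder it for the kernel. With block_size=32 and
# LARGE_TILE_SZ=2048 we get num_tiled_blocks=128 and tiled_block_size=16, so
# the context part of the mask is permuted while the trailing seq_q columns
# for the new tokens are left untouched. Runs on CPU with plain torch.
def _example_reorder_context_mask():
    seq_q, context_kv_len = 128, 2048
    mask = torch.ones(seq_q, context_kv_len + seq_q, dtype=torch.int32)
    reordered = reorder_context_mask(mask, LARGE_TILE_SZ=2048, block_size=32)
    assert reordered.shape == mask.shape
    return reordered

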
def flash_attn_varlen_nkifunc(
    query,
    key,
    value,
    kv_cache,
    block_table,
    attn_mask,
    n_kv_head=None,
    head_size=None,
    LARGE_TILE_SZ=2048,
    mixed_precision=True,
):
    """
    Compute flash paged attention for variable length sequences.

    This function is a wrapper around the flash attention NKI kernel. It takes
    the following arguments:
      - query: (1, n_heads, d, seq_q)
      - key: (1, n_kv_heads, d, seq_k)
      - value: (1, n_kv_heads, seq_v, d)
      - kv_cache: (2, n_blocks, n_kv_heads, block_size, d)
      - block_table: (n_active_blocks, )
      - attn_mask: (seq_q, n_active_blocks * block_size + seq_q)

    Notes:
      - attn_mask must be reordered beforehand using `reorder_context_mask`
      - The key/value cache layout must be (n_blocks, n_kv_heads, block_size,
        d) for better DMA throughput
    """
    if n_kv_head is None:
        n_kv_head = kv_cache.shape[2]
    assert kv_cache.shape[0] == 2
    assert kv_cache.shape[2] == n_kv_head
    if head_size is None:
        head_size = kv_cache.shape[-1]

    kwargs = dict(
        query=query,
        key=key,
        value=value,
        kv_cache=kv_cache,
        block_tables=block_table,
        mask=attn_mask,
        softmax_scale=1.0 / (head_size**0.5),
        mixed_precision=mixed_precision,
        LARGE_TILE_SZ=LARGE_TILE_SZ,
    )

    o = flash_paged_attention[1, n_kv_head](**kwargs)
    return o


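# Usage sketch (our own addition): illustrative shapes for the wrapper above.
# The sizes below satisfy the kernel's divisibility constraints (seq_q is a
# multiple of 128, n_active_blocks * block_size is a multiple of
# LARGE_TILE_SZ), but actually invoking the kernel requires the Neuron
# compiler and a NeuronCore, so this helper only constructs the inputs and
# shows the call it would make.
def _example_flash_attn_varlen_inputs(n_heads=8, n_kv_heads=1, d=128,
                                      seq_q=128, n_blocks=512, block_size=32,
                                      n_active_blocks=64):
    dtype = torch.bfloat16
    query = torch.zeros(1, n_heads, d, seq_q, dtype=dtype)
    key = torch.zeros(1, n_kv_heads, d, seq_q, dtype=dtype)
    value = torch.zeros(1, n_kv_heads, seq_q, d, dtype=dtype)
    kv_cache = torch.zeros(2, n_blocks, n_kv_heads, block_size, d,
                           dtype=dtype)
    block_table = torch.arange(n_active_blocks, dtype=torch.int32)
    mask = torch.ones(seq_q, n_active_blocks * block_size + seq_q,
                      dtype=torch.int32)
    mask = reorder_context_mask(mask, LARGE_TILE_SZ=2048,
                                block_size=block_size)
    # On Neuron hardware, the call would be:
    # out = flash_attn_varlen_nkifunc(query, key, value, kv_cache,
    #                                 block_table, mask)
    return query, key, value, kv_cache, block_table, mask

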
def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """
    Writes key-value pairs to the KV cache at specified positions.

    Args:
        key (torch.Tensor): Key tensor with shape
            (num_tokens, n_kv_head, d_head)
        value (torch.Tensor): Value tensor with shape
            (num_tokens, n_kv_head, d_head)
        kv_cache (torch.Tensor): Key/value cache tensor with shape
            (2, num_blocks, n_kv_head, block_size, d_head)
        slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
            with shape (num_tokens, )

    Returns:
        None: Updates the kv_cache tensor in-place
    """
    block_size = kv_cache.size(3)
    n_kv_head = key.size(1)

    # Calculate indices with explicit floor division
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_offsets = slot_mapping % block_size

    # Create the head indices tensor
    head_indices = torch.arange(n_kv_head, device=key.device)

    # Update caches using index_put_
    kv_cache.index_put_(
        (torch.tensor([0], device=key.device), block_indices[:, None],
         head_indices[None, :], block_offsets[:, None]), key)

    kv_cache.index_put_(
        (torch.tensor([1], device=key.device), block_indices[:, None],
         head_indices[None, :], block_offsets[:, None]), value)
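

# Usage sketch (our own addition): write three tokens into a toy cache with
# `reshape_and_cache` and check that one of them landed in the expected slot.
# Runs on CPU with plain torch; the sizes are arbitrary.
def _example_reshape_and_cache():
    num_tokens, n_kv_head, d_head = 3, 2, 4
    num_blocks, block_size = 8, 16
    key = torch.randn(num_tokens, n_kv_head, d_head)
    value = torch.randn(num_tokens, n_kv_head, d_head)
    kv_cache = torch.zeros(2, num_blocks, n_kv_head, block_size, d_head)
    # token i goes to slot slot_mapping[i] = block * block_size + offset
    slot_mapping = torch.tensor([5, 16, 33])
    reshape_and_cache(key, value, kv_cache, slot_mapping)
    # slot 33 -> block 2, offset 1
    assert torch.equal(kv_cache[0, 2, :, 1, :], key[2])
    return kv_cache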