[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
0
vllm/attention/ops/__init__.py
Normal file
0
vllm/attention/ops/__init__.py
Normal file
@@ -0,0 +1,433 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
def blocksparse_flash_attn_varlen_fwd(
|
||||
q,
|
||||
k,
|
||||
v, # (#tokens, n_heads, head_size)
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q,
|
||||
sm_scale,
|
||||
sparse_layout,
|
||||
*,
|
||||
block_size=64,
|
||||
q_block_size=None,
|
||||
max_seqlen=None):
|
||||
# split q to blocks
|
||||
|
||||
assert isinstance(sparse_layout, (list, tuple))
|
||||
|
||||
_, n_heads, head_size = q.shape
|
||||
batch_size = cu_seqlens_k.size(0) - 1
|
||||
q_block_size = q_block_size or block_size
|
||||
|
||||
assert q.dim() == k.dim() == v.dim() == 3
|
||||
assert q.size(1) % k.size(1) == 0
|
||||
assert q.size(2) == k.size(2)
|
||||
# TODO(linxihui): allow k, v to have different head_size
|
||||
assert k.shape == v.shape
|
||||
assert cu_seqlens_k.dim() == 1
|
||||
|
||||
q_k_ratio = q.size(1) // k.size(1)
|
||||
|
||||
if cu_seqlens_q is None:
|
||||
if q.size(0) == batch_size: # decoding only
|
||||
cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
batch_size + 1,
|
||||
dtype=cu_seqlens_k.dtype,
|
||||
device=cu_seqlens_k.device,
|
||||
)
|
||||
elif q.size(0) == k.size(0):
|
||||
cu_seqlens_q = cu_seqlens_k
|
||||
else:
|
||||
raise ValueError("cu_seqlens_q must be specified\
|
||||
if it mix of prefilling and decoding.")
|
||||
else:
|
||||
assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)
|
||||
|
||||
# switch to use cpu to avoid too many kernel launches when iterated over
|
||||
q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
|
||||
k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()
|
||||
|
||||
assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (
|
||||
"length of q should either be 1 (decoding) or same as k (prefilling).")
|
||||
|
||||
if max_seqlen:
|
||||
assert k_lens.max() <= max_seqlen
|
||||
|
||||
n_blocks = (q_lens + q_block_size - 1) // q_block_size
|
||||
|
||||
q_batch_ids = torch.tensor(
|
||||
[i for i, n in enumerate(n_blocks) for _ in range(n)],
|
||||
dtype=cu_seqlens_q.dtype,
|
||||
device=cu_seqlens_q.device,
|
||||
)
|
||||
q_start_sids = torch.tensor(
|
||||
[i * q_block_size for n in n_blocks for i in range(n)],
|
||||
dtype=cu_seqlens_q.dtype,
|
||||
device=cu_seqlens_q.device,
|
||||
)
|
||||
|
||||
out = q.new_empty(q.shape)
|
||||
cu_seqlens_q = cu_seqlens_q.contiguous()
|
||||
cu_seqlens_k = cu_seqlens_k.contiguous()
|
||||
|
||||
layout_crow_indices, layout_col_indices = sparse_layout
|
||||
block_d = triton.next_power_of_2(head_size)
|
||||
|
||||
decoding_only = (q_lens == 1).all().item()
|
||||
grid = (len(q_start_sids), n_heads, 1)
|
||||
|
||||
_fwd_kernel_batch_inference[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
sm_scale,
|
||||
cu_seqlens_q[:-1],
|
||||
cu_seqlens_q[1:],
|
||||
cu_seqlens_k[:-1],
|
||||
cu_seqlens_k[1:],
|
||||
q_batch_ids,
|
||||
q_start_sids,
|
||||
0,
|
||||
*q.stride(),
|
||||
0,
|
||||
*k.stride(),
|
||||
0,
|
||||
*v.stride(),
|
||||
0,
|
||||
*out.stride(),
|
||||
layout_crow_indices,
|
||||
layout_col_indices,
|
||||
*layout_crow_indices.stride(),
|
||||
*layout_col_indices.stride(),
|
||||
q_k_ratio,
|
||||
HAS_BATCH_DIM=False,
|
||||
D_HEAD=head_size,
|
||||
BLOCK_M=q_block_size,
|
||||
BLOCK_N=block_size,
|
||||
BLOCK_D=block_d,
|
||||
BLOCK_M_LOADING=(16 if decoding_only else
|
||||
q_block_size), # smaller for decoding
|
||||
EVEN_D=block_d == head_size,
|
||||
num_warps=1 if decoding_only else 4,
|
||||
num_stages=3)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
Q,
|
||||
k_block_col_idx,
|
||||
layout_col_ptr,
|
||||
layout_col_stride_h,
|
||||
layout_col_stride_m,
|
||||
k_ptrs,
|
||||
v_ptrs,
|
||||
off_h,
|
||||
offs_m,
|
||||
offs_n,
|
||||
offs_d,
|
||||
stride_kt,
|
||||
stride_vt,
|
||||
sm_scale,
|
||||
k_seqlen,
|
||||
past_len,
|
||||
LAST_K_BLOCK: tl.constexpr,
|
||||
BLOCK_M_LOADING: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
D_HEAD: tl.constexpr,
|
||||
EVEN_D: tl.constexpr,
|
||||
M_LT_N: tl.constexpr,
|
||||
):
|
||||
k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +
|
||||
k_block_col_idx * layout_col_stride_m).to(tl.int32)
|
||||
start_n = k_block_id * BLOCK_N
|
||||
if LAST_K_BLOCK:
|
||||
if EVEN_D:
|
||||
k = tl.load(
|
||||
k_ptrs + start_n * stride_kt,
|
||||
mask=offs_n[None, :] + start_n < k_seqlen,
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
k = tl.load(
|
||||
k_ptrs + start_n * stride_kt,
|
||||
mask=(offs_n[None, :] + start_n < k_seqlen) &
|
||||
(offs_d[:, None] < D_HEAD),
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
if EVEN_D:
|
||||
k = tl.load(k_ptrs + start_n * stride_kt)
|
||||
else:
|
||||
k = tl.load(k_ptrs + start_n * stride_kt,
|
||||
mask=offs_d[:, None] < D_HEAD,
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
|
||||
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
|
||||
if LAST_K_BLOCK | M_LT_N:
|
||||
qk += tl.where(
|
||||
offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),
|
||||
0,
|
||||
float("-inf"),
|
||||
)
|
||||
|
||||
# flash-attn2
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
||||
p = tl.math.exp2(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
alpha = tl.math.exp2(m_i - m_ij)
|
||||
acc = acc * alpha[:, None]
|
||||
# update m_i
|
||||
m_i = m_ij
|
||||
l_i = l_i * alpha + l_ij
|
||||
|
||||
p = p.to(Q.dtype.element_ty)
|
||||
# update acc
|
||||
if LAST_K_BLOCK:
|
||||
if EVEN_D:
|
||||
v = tl.load(
|
||||
v_ptrs + start_n * stride_vt,
|
||||
mask=offs_n[:, None] + start_n < k_seqlen,
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
v = tl.load(
|
||||
v_ptrs + start_n * stride_vt,
|
||||
mask=(offs_n[:, None] + start_n < k_seqlen) &
|
||||
(offs_d[None, :] < D_HEAD),
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
if EVEN_D:
|
||||
v = tl.load(v_ptrs + start_n * stride_vt)
|
||||
else:
|
||||
v = tl.load(v_ptrs + start_n * stride_vt,
|
||||
mask=offs_d[None, :] < D_HEAD,
|
||||
other=0.0)
|
||||
|
||||
acc += tl.dot(p, v)
|
||||
|
||||
return acc, l_i, m_i
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
"M_LT_N":
|
||||
lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
|
||||
})
|
||||
@triton.jit
|
||||
def _fwd_kernel_batch_inference(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
Out,
|
||||
sm_scale,
|
||||
q_batch_starts,
|
||||
q_batch_ends,
|
||||
k_batch_starts,
|
||||
k_batch_ends,
|
||||
q_batch_ids,
|
||||
q_start_sids,
|
||||
stride_qb,
|
||||
stride_qt,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kb,
|
||||
stride_kt,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vb,
|
||||
stride_vt,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_ob,
|
||||
stride_ot,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
layout_crow_ptr,
|
||||
layout_col_ptr,
|
||||
layout_crow_stride_h,
|
||||
layout_crow_stride_m,
|
||||
layout_col_stride_h,
|
||||
layout_col_stride_m,
|
||||
q_k_ratio,
|
||||
HAS_BATCH_DIM: tl.constexpr,
|
||||
D_HEAD: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_D: tl.constexpr,
|
||||
BLOCK_M_LOADING: tl.constexpr,
|
||||
EVEN_D: tl.constexpr,
|
||||
M_LT_N: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
NOTATION:
|
||||
pid: position id
|
||||
sid: storage id
|
||||
sbid: storage block id
|
||||
pbid: position block id
|
||||
offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
|
||||
|
||||
TODO(linxihui):
|
||||
Optimize grouped-attn
|
||||
"""
|
||||
off_zm = tl.program_id(0)
|
||||
off_h = tl.program_id(1)
|
||||
|
||||
off_h_for_kv = off_h // q_k_ratio
|
||||
|
||||
if HAS_BATCH_DIM:
|
||||
off_z = tl.program_id(2)
|
||||
Q += off_z * stride_qb
|
||||
K += off_z * stride_kb
|
||||
V += off_z * stride_vb
|
||||
Out += off_z * stride_ob
|
||||
start_m = off_zm
|
||||
q_start_sid = start_m * BLOCK_M # always 0 for decoding
|
||||
else:
|
||||
off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1]
|
||||
q_start_sid = tl.load(q_start_sids + off_zm)
|
||||
start_m = q_start_sid // BLOCK_M # q_sbid
|
||||
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_D)
|
||||
|
||||
q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
|
||||
q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
|
||||
k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
|
||||
k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
|
||||
past_len = k_seqlen - q_seqlen
|
||||
|
||||
Q += q_cu_start * stride_qt + off_h * stride_qh
|
||||
K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
|
||||
V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
|
||||
Out += q_cu_start * stride_ot + off_h * stride_oh
|
||||
|
||||
q_pbid = (past_len + q_start_sid) // BLOCK_M
|
||||
|
||||
if EVEN_D:
|
||||
q = tl.load(
|
||||
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
||||
mask=offs_m[:, None] < q_seqlen,
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
q = tl.load(
|
||||
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
||||
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
|
||||
q_pbid * layout_crow_stride_m)
|
||||
|
||||
# TODO(linxihui): load at once, with any Triton version
|
||||
# that supports `tl.split`, e.g., Triton 3.0
|
||||
k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
|
||||
k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
|
||||
|
||||
m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
|
||||
|
||||
k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
|
||||
v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
|
||||
|
||||
sm_scale *= (
|
||||
1.44269504 # 1/log2 as we use base2 for exponential and logarithm
|
||||
)
|
||||
|
||||
for k_block_col_idx in range(k_block_start, k_block_end - 1):
|
||||
acc, l_i, m_i = _fwd_kernel_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
Q,
|
||||
k_block_col_idx,
|
||||
layout_col_ptr,
|
||||
layout_col_stride_h,
|
||||
layout_col_stride_m,
|
||||
k_ptrs,
|
||||
v_ptrs,
|
||||
off_h,
|
||||
offs_m,
|
||||
offs_n,
|
||||
offs_d,
|
||||
stride_kt,
|
||||
stride_vt,
|
||||
sm_scale,
|
||||
k_seqlen,
|
||||
past_len,
|
||||
False,
|
||||
BLOCK_M_LOADING,
|
||||
BLOCK_N,
|
||||
D_HEAD,
|
||||
EVEN_D,
|
||||
M_LT_N,
|
||||
)
|
||||
|
||||
acc, l_i, m_i = _fwd_kernel_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
Q,
|
||||
k_block_end - 1,
|
||||
layout_col_ptr,
|
||||
layout_col_stride_h,
|
||||
layout_col_stride_m,
|
||||
k_ptrs,
|
||||
v_ptrs,
|
||||
off_h,
|
||||
offs_m,
|
||||
offs_n,
|
||||
offs_d,
|
||||
stride_kt,
|
||||
stride_vt,
|
||||
sm_scale,
|
||||
k_seqlen,
|
||||
past_len,
|
||||
True,
|
||||
BLOCK_M_LOADING,
|
||||
BLOCK_N,
|
||||
D_HEAD,
|
||||
EVEN_D,
|
||||
M_LT_N,
|
||||
)
|
||||
|
||||
# flash-attn 2
|
||||
m_i += tl.math.log2(l_i)
|
||||
acc = acc / l_i[:, None]
|
||||
|
||||
# write output
|
||||
if EVEN_D:
|
||||
tl.store(
|
||||
Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
|
||||
acc,
|
||||
mask=offs_m[:, None] < q_seqlen,
|
||||
)
|
||||
else:
|
||||
tl.store(
|
||||
Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
|
||||
acc,
|
||||
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
||||
)
|
||||
239
vllm/attention/ops/blocksparse_attention/interface.py
Normal file
239
vllm/attention/ops/blocksparse_attention/interface.py
Normal file
@@ -0,0 +1,239 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import (dense_to_crow_col, get_head_sliding_step,
|
||||
get_sparse_attn_mask)
|
||||
|
||||
IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)
|
||||
|
||||
if IS_COMPUTE_8_OR_ABOVE:
|
||||
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
|
||||
|
||||
|
||||
class LocalStridedBlockSparseAttn(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_heads,
|
||||
max_seqlen,
|
||||
local_blocks,
|
||||
vert_stride,
|
||||
block_size,
|
||||
device=None,
|
||||
dtype=None,
|
||||
homo_head=False,
|
||||
active_head_range=None,
|
||||
q_block_size=None,
|
||||
use_spda=None,
|
||||
):
|
||||
super().__init__()
|
||||
if use_spda is None:
|
||||
use_spda = current_platform.is_rocm() or \
|
||||
current_platform.is_cpu() or not \
|
||||
IS_COMPUTE_8_OR_ABOVE
|
||||
device = device or (torch.cuda.current_device()
|
||||
if current_platform.is_cuda_alike() else "cpu")
|
||||
device = torch.device(device)
|
||||
# NOTE: vllm CPU backend support BF16 instead of FP16.
|
||||
dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
|
||||
or device.type == "cpu" else torch.half)
|
||||
|
||||
self.n_heads = n_heads
|
||||
self.max_seqlen = max_seqlen
|
||||
self.local_blocks = local_blocks
|
||||
self.vert_stride = vert_stride
|
||||
self.use_spda = use_spda
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.block_size = block_size
|
||||
self.q_block_size = q_block_size
|
||||
self.homo_head = homo_head
|
||||
self.active_head_range = active_head_range
|
||||
self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride,
|
||||
homo_head)
|
||||
|
||||
sparse_layout, sparse_pattern, self.dense_attn_mask = (
|
||||
self.get_attn_pattern(dtype, device))
|
||||
|
||||
if q_block_size is not None and q_block_size != block_size:
|
||||
if q_block_size > block_size:
|
||||
assert q_block_size % block_size == 0
|
||||
blocks_to_merge = q_block_size // block_size
|
||||
shape = sparse_pattern.shape
|
||||
sparse_pattern = sparse_pattern.view(shape[0], -1,
|
||||
blocks_to_merge,
|
||||
shape[-1])
|
||||
sparse_pattern = sparse_pattern.sum(2)
|
||||
sparse_layout = dense_to_crow_col(sparse_pattern)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Does not support smaller q_block_size. It will be slower."
|
||||
)
|
||||
|
||||
self.sparse_layout = sparse_layout
|
||||
|
||||
def get_attn_pattern(self, dtype, device):
|
||||
sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask(
|
||||
self.n_heads,
|
||||
self.max_seqlen,
|
||||
self.max_seqlen,
|
||||
dtype,
|
||||
device,
|
||||
block_size=self.block_size,
|
||||
local_blocks=self.local_blocks,
|
||||
vert_stride=self.vert_stride,
|
||||
homo_head=self.homo_head,
|
||||
return_dense=self.use_spda,
|
||||
dense_mask_type="bias",
|
||||
)
|
||||
if (not self.homo_head) and (self.active_head_range is not None):
|
||||
assert isinstance(self.active_head_range, tuple)
|
||||
assert (len(self.active_head_range) == 2)
|
||||
h_start, h_end = self.active_head_range
|
||||
sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout)
|
||||
if self.use_spda:
|
||||
dense_attn_mask = dense_attn_mask[h_start:h_end]
|
||||
return sparse_layout, sparse_pattern, dense_attn_mask
|
||||
|
||||
def varlen_attn(self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q=None,
|
||||
sm_scale=None):
|
||||
"""
|
||||
q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
|
||||
Support grouped attention, with `q[:, i*r:(i*r + r)]`
|
||||
is correspondent to `k[:, i]`, where `r` is the q/k ratio.
|
||||
cu_seqlens_k: shape=(batch_size + 1,),
|
||||
indicating segment of samples,
|
||||
e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
|
||||
cu_seqlens_q: shape=(batch_size + 1, ).
|
||||
Default None: same as cu_seqlens_k for prefilling or
|
||||
[0, 1, .., batch_size] for decoding.
|
||||
The only case you need to specify is when q is a mix of
|
||||
prefilling and decoding.
|
||||
sm_scale: softmax scale, default to 1/sqrt(head_size).
|
||||
|
||||
return: tensor of shape as q.
|
||||
"""
|
||||
assert (
|
||||
IS_COMPUTE_8_OR_ABOVE
|
||||
), "Requires compute capability of 8 or above (Ampere or newer) to use \
|
||||
Triton kernel."
|
||||
|
||||
sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
|
||||
|
||||
return blocksparse_flash_attn_varlen_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q,
|
||||
sm_scale,
|
||||
self.sparse_layout,
|
||||
block_size=self.block_size,
|
||||
q_block_size=self.q_block_size,
|
||||
max_seqlen=self.max_seqlen,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1):
|
||||
"""
|
||||
:param x: (total_tokens, n_heads, head_size)
|
||||
:return: (batch, n_heads, length, head_size)
|
||||
"""
|
||||
x_padded = x.new_empty(
|
||||
len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2))
|
||||
cu_seqlens = cu_seqlens.cpu()
|
||||
for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
|
||||
x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0,
|
||||
1).unsqueeze(1))
|
||||
return x_padded.flatten(1, 2)
|
||||
|
||||
@staticmethod
|
||||
def transpose_and_unpad(x_padded, cu_seqlens):
|
||||
"""
|
||||
:param x_padded: (batch, n_heads, length, head_size)
|
||||
:return: (total_tokens, n_heads, head_size)
|
||||
"""
|
||||
cu_seqlens = cu_seqlens.cpu()
|
||||
total_n_tokens = cu_seqlens[-1]
|
||||
x = x_padded.new_empty(total_n_tokens, x_padded.size(1),
|
||||
x_padded.size(3))
|
||||
for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
|
||||
x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1))
|
||||
return x
|
||||
|
||||
def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
|
||||
"""For CPU, V100 or other older GPUs.
|
||||
NOTE: torch SPDA supports nested tensor,
|
||||
but seems extremely slow. Choose to pad instead.
|
||||
"""
|
||||
assert (cu_seqlens_q is None or
|
||||
(cu_seqlens_q
|
||||
== cu_seqlens_k).all()), "Can only handle prompt with SPDA."
|
||||
assert q.size(0) == k.size(0), "can only handle prompt with SPDA."
|
||||
|
||||
assert q.size(1) % k.size(1) == 0
|
||||
q_k_ratio = q.size(1) // k.size(1)
|
||||
sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
|
||||
cu_seqlens = cu_seqlens_k.cpu()
|
||||
maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
|
||||
if (self.dense_attn_mask.dtype != q.dtype
|
||||
or self.dense_attn_mask.device != q.device):
|
||||
_, _, self.dense_attn_mask = self.get_attn_pattern(
|
||||
q.dtype, q.device)
|
||||
attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen]
|
||||
|
||||
q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1)
|
||||
k2, v2 = (self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
|
||||
for x in [k, v])
|
||||
spda_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
q2, k2, v2, attn_mask=attn_mask, scale=sm_scale)
|
||||
return self.transpose_and_unpad(spda_output, cu_seqlens)
|
||||
|
||||
def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
|
||||
"""Dispatch to `varlen_attn` (Ampere or newer) or
|
||||
`self.spda`(cpu, Volta, Turing or older)based on
|
||||
the type of device used and cuda compute capability.
|
||||
|
||||
q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
|
||||
Support grouped attention, with `q[:, i*r:(i*r + r)]`
|
||||
is correspondent to `k[:, i]`, where `r` is the q/k ratio.
|
||||
cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples,
|
||||
e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
|
||||
cu_seqlens_q: shape=(batch_size + 1, ).
|
||||
Default None: same as cu_seqlens_k for prefilling or
|
||||
[0, 1, .., batch_size] for decoding.
|
||||
The only case you need to specify
|
||||
is when q is a mix of prefilling
|
||||
and decoding.
|
||||
sm_scale: softmax scale, default to 1/sqrt(head_size).
|
||||
|
||||
return: tensor of shape as q.
|
||||
"""
|
||||
assert k.dim() == 3
|
||||
if self.use_spda:
|
||||
return self.spda(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
sm_scale=sm_scale,
|
||||
)
|
||||
return self.varlen_attn(q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
sm_scale=sm_scale)
|
||||
246
vllm/attention/ops/blocksparse_attention/utils.py
Normal file
246
vllm/attention/ops/blocksparse_attention/utils.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Helper functions for 3D sparse pattern
|
||||
# These function are not optimized and very inefficient.
|
||||
# Avoid calling them too frequent or use a cache mechanism.
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
class csr_matrix:
|
||||
"""Simple implementation of CSR matrix conversion without scipy.
|
||||
This replaced scipy.sparse.csr_matrix() previously used."""
|
||||
|
||||
def __init__(self, input_array):
|
||||
if not isinstance(input_array, np.ndarray):
|
||||
raise ValueError("Input must be a NumPy array")
|
||||
|
||||
self.shape = input_array.shape
|
||||
rows, cols = self.shape
|
||||
data = []
|
||||
indices = []
|
||||
indptr = [0]
|
||||
|
||||
for i in range(rows):
|
||||
for j in range(cols):
|
||||
if input_array[i, j]:
|
||||
data.append(input_array[i, j])
|
||||
indices.append(j)
|
||||
indptr.append(len(indices))
|
||||
|
||||
self.data = np.array(data)
|
||||
self.indices = np.array(indices)
|
||||
self.indptr = np.array(indptr)
|
||||
|
||||
|
||||
def dense_to_crow_col(x: torch.Tensor):
|
||||
"""Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing.
|
||||
NOTE: col_indices padded -1
|
||||
"""
|
||||
device = x.device
|
||||
pad = -1
|
||||
dim = x.dim()
|
||||
assert x.dim() in (2, 3)
|
||||
if x.dim() == 2:
|
||||
x = x[None]
|
||||
x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
|
||||
crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
|
||||
cols = [torch.from_numpy(xi.indices) for xi in x]
|
||||
max_cols = max(len(xi) for xi in cols)
|
||||
cols = [
|
||||
torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])])
|
||||
for xi in cols
|
||||
]
|
||||
cols = torch.vstack(cols)
|
||||
if dim == 2:
|
||||
crows = crows[0]
|
||||
cols = cols[0]
|
||||
return crows.to(device), cols.to(device)
|
||||
|
||||
|
||||
def crow_col_to_dense(crows: torch.Tensor,
|
||||
cols: torch.Tensor,
|
||||
dtype: torch.dtype = torch.float16):
|
||||
dim = crows.dim()
|
||||
if dim == 1:
|
||||
crows = crows[None]
|
||||
cols = cols[None]
|
||||
device = crows.device
|
||||
crows, cols = crows.cpu(), cols.cpu() # faster in cpu
|
||||
shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
|
||||
x = torch.zeros(shape, dtype=dtype)
|
||||
for i in range(shape[0]):
|
||||
for j in range(shape[1]):
|
||||
x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1
|
||||
if dim == 1:
|
||||
x = x[0]
|
||||
return x.to(device)
|
||||
|
||||
|
||||
def dense_to_ccol_row(x: torch.Tensor):
|
||||
"""Similar, but to CSC format"""
|
||||
x = x.transpose(-2, -1)
|
||||
return dense_to_crow_col(x)
|
||||
|
||||
|
||||
def ccol_row_to_dense(ccol: torch.Tensor,
|
||||
rows: torch.Tensor,
|
||||
dtype: torch.dtype = torch.float16):
|
||||
return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
|
||||
|
||||
|
||||
def _get_sparse_attn_mask_homo_head(
|
||||
q_len: int,
|
||||
max_seqlen: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
block_size: int = 128,
|
||||
local_blocks: int = 4,
|
||||
vert_stride: int = 4,
|
||||
return_dense: bool = False,
|
||||
):
|
||||
"""
|
||||
:return: a tuple of 3:
|
||||
- tuple of crow_indices, col_indices representation
|
||||
of CSR format.
|
||||
- block dense mask
|
||||
- all token dense mask (be aware that it can be
|
||||
OOM if it is too big) if `return_dense==True`,
|
||||
otherwise, None
|
||||
"""
|
||||
with torch.no_grad():
|
||||
num_blocks = triton.cdiv(max_seqlen, block_size)
|
||||
q_pos = torch.arange(num_blocks)[:, None]
|
||||
k_pos = torch.arange(num_blocks)[None]
|
||||
mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0
|
||||
block_mask_dense = (((q_pos >= k_pos)
|
||||
& ((q_pos - k_pos < local_blocks)
|
||||
| mask_vert_strided)).to(device).to(dtype))
|
||||
num_blocks_q = triton.cdiv(q_len, block_size)
|
||||
block_mask_dense_output = (dense_to_crow_col(
|
||||
block_mask_dense[-num_blocks_q:].contiguous()))
|
||||
if return_dense:
|
||||
mask_dense = torch.kron(
|
||||
block_mask_dense,
|
||||
block_mask_dense.new_ones((block_size, block_size)),
|
||||
)
|
||||
causal_mask = torch.tril(torch.ones(
|
||||
max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
|
||||
mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask
|
||||
return (
|
||||
block_mask_dense_output,
|
||||
block_mask_dense,
|
||||
mask_dense,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
block_mask_dense_output,
|
||||
block_mask_dense,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def binary_mask_to_bias(mask_dense: torch.Tensor):
|
||||
mask_dense = 1 - mask_dense
|
||||
mask_dense.masked_fill_(mask_dense.bool(), -torch.inf)
|
||||
return mask_dense
|
||||
|
||||
|
||||
def get_head_sliding_step(n_heads: int,
|
||||
vert_stride: int,
|
||||
homo_head: bool = False):
|
||||
if homo_head:
|
||||
return 0
|
||||
return max(1, int(vert_stride / n_heads))
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_sparse_attn_mask(
|
||||
n_heads: int,
|
||||
q_len: int,
|
||||
max_seqlen: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
block_size: int = 64,
|
||||
local_blocks: int = 4,
|
||||
vert_stride: int = 4,
|
||||
homo_head: bool = True,
|
||||
return_dense: bool = False,
|
||||
dense_mask_type: str = "binary",
|
||||
):
|
||||
"""
|
||||
:param dense_mask_type: "binary" (0 for skip token, 1 for others)
|
||||
or "bias" (-inf for skip token, 0 or others)
|
||||
:return: a tuple of 3:
|
||||
- tuple of crow_indices, col_indices representation
|
||||
of CSR format.
|
||||
- block dense mask
|
||||
- all token dense mask (be aware that it can be OOM if it
|
||||
is too big) if `return_dense==True`, otherwise, None
|
||||
"""
|
||||
assert dense_mask_type in ("binary", "bias")
|
||||
if homo_head:
|
||||
with torch.no_grad():
|
||||
(crow, col), block_mask_dense, mask_dense = (
|
||||
_get_sparse_attn_mask_homo_head(
|
||||
q_len,
|
||||
max_seqlen,
|
||||
dtype,
|
||||
device,
|
||||
block_size,
|
||||
local_blocks,
|
||||
vert_stride,
|
||||
return_dense,
|
||||
))
|
||||
crow = crow[None].expand(n_heads, crow.shape[0])
|
||||
col = col[None].expand(n_heads, col.shape[0])
|
||||
if return_dense:
|
||||
mask_dense = mask_dense[None].expand(n_heads,
|
||||
*mask_dense.shape)
|
||||
if dense_mask_type == "bias":
|
||||
mask_dense = binary_mask_to_bias(mask_dense)
|
||||
return (crow, col), block_mask_dense, mask_dense
|
||||
|
||||
with torch.no_grad():
|
||||
num_blocks = triton.cdiv(max_seqlen, block_size)
|
||||
q_pos = torch.arange(num_blocks)[None, :, None]
|
||||
k_pos = torch.arange(num_blocks)[None, None]
|
||||
head_sliding_step = get_head_sliding_step(n_heads, vert_stride)
|
||||
mask_vert_strided = [
|
||||
(torch.arange(num_blocks) + h * head_sliding_step + 1) %
|
||||
vert_stride == 0 for h in range(n_heads)
|
||||
]
|
||||
mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
|
||||
block_mask_dense = (((q_pos >= k_pos)
|
||||
& ((q_pos - k_pos < local_blocks)
|
||||
| mask_vert_strided)).to(device).to(dtype))
|
||||
num_blocks_q = triton.cdiv(q_len, block_size)
|
||||
block_mask_dense_output = block_mask_dense[:, -num_blocks_q:]
|
||||
if return_dense:
|
||||
mask_dense = torch.kron(
|
||||
block_mask_dense,
|
||||
block_mask_dense.new_ones((block_size, block_size)),
|
||||
)
|
||||
causal_mask = torch.tril(torch.ones(
|
||||
max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
|
||||
mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None]
|
||||
if dense_mask_type == "bias":
|
||||
mask_dense = binary_mask_to_bias(mask_dense)
|
||||
|
||||
return (
|
||||
dense_to_crow_col(block_mask_dense_output),
|
||||
block_mask_dense,
|
||||
mask_dense,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
dense_to_crow_col(block_mask_dense_output),
|
||||
block_mask_dense,
|
||||
None,
|
||||
)
|
||||
386
vllm/attention/ops/chunked_prefill_paged_decode.py
Normal file
386
vllm/attention/ops/chunked_prefill_paged_decode.py
Normal file
@@ -0,0 +1,386 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Authors:
|
||||
# - Burkhard Ringlein <ngl@zurich.ibm.com>
|
||||
# - Jan van Lunteren <jvl@zurich.ibm.com>
|
||||
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
|
||||
# - Thomas Parnell <tpa@zurich.ibm.com>
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.rocm import use_rocm_custom_paged_attention
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .prefix_prefill import context_attention_fwd
|
||||
|
||||
|
||||
@triton.jit
|
||||
def cdiv_fn(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_paged_attention_2d(
|
||||
output_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
|
||||
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
|
||||
sink_ptr, # [num_query_heads]
|
||||
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
alibi_slopes_ptr, # [num_query_heads]
|
||||
scale, # float32
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
num_query_heads: tl.constexpr, # int
|
||||
num_queries_per_kv: tl.constexpr, # int
|
||||
num_queries_per_kv_padded: tl.constexpr, # int
|
||||
block_table_stride: tl.int64, # int
|
||||
query_stride_0: tl.int64, # int
|
||||
query_stride_1: tl.int64, # int, should be equal to head_size
|
||||
output_stride_0: tl.int64, # int
|
||||
output_stride_1: tl.int64, # int, should be equal to head_size
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
SLIDING_WINDOW: tl.constexpr, # int
|
||||
x: tl.constexpr, # int
|
||||
stride_k_cache_0: tl.int64, # int
|
||||
stride_k_cache_1: tl.int64, # int
|
||||
stride_k_cache_2: tl.int64, # int
|
||||
stride_k_cache_3: tl.int64, # int
|
||||
stride_k_cache_4: tl.int64, # int
|
||||
stride_v_cache_0: tl.int64, # int
|
||||
stride_v_cache_1: tl.int64, # int
|
||||
stride_v_cache_2: tl.int64, # int
|
||||
stride_v_cache_3: tl.int64, # int
|
||||
filter_by_query_len: tl.constexpr, # bool
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
USE_SINKS: tl.constexpr, # bool
|
||||
):
|
||||
seq_idx = tl.program_id(0)
|
||||
kv_head_idx = tl.program_id(1)
|
||||
|
||||
if filter_by_query_len:
|
||||
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
|
||||
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx +
|
||||
1)
|
||||
cur_batch_query_len = cur_batch_in_all_stop_index \
|
||||
- cur_batch_in_all_start_index
|
||||
if cur_batch_query_len > 1:
|
||||
return
|
||||
else:
|
||||
cur_batch_in_all_start_index = seq_idx
|
||||
|
||||
query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
|
||||
0, num_queries_per_kv_padded)
|
||||
|
||||
query_offset = (cur_batch_in_all_start_index * query_stride_0 +
|
||||
query_head_idx[:, None] * query_stride_1)
|
||||
|
||||
head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
|
||||
head_mask = head_mask & (query_head_idx < num_query_heads)
|
||||
|
||||
dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
|
||||
0).to(tl.int1)
|
||||
|
||||
# Q : (num_queries_per_kv, HEAD_SIZE,)
|
||||
Q = tl.load(
|
||||
query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
|
||||
mask=dim_mask[None, :] & head_mask[:, None],
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
block_table_offset = seq_idx * block_table_stride
|
||||
|
||||
if not USE_SINKS:
|
||||
M = tl.full([num_queries_per_kv_padded],
|
||||
float("-inf"),
|
||||
dtype=tl.float32)
|
||||
else:
|
||||
M = tl.load(
|
||||
sink_ptr + query_head_idx,
|
||||
mask=head_mask,
|
||||
other=float("-inf"),
|
||||
).to(dtype=tl.float32)
|
||||
# M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
|
||||
|
||||
L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
|
||||
dtype=tl.float32)
|
||||
|
||||
# sequence len for this particular sequence
|
||||
seq_len = tl.load(seq_lens_ptr + seq_idx)
|
||||
|
||||
# alibi slope for this head
|
||||
if USE_ALIBI_SLOPES:
|
||||
alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx,
|
||||
mask=head_mask,
|
||||
other=0.0)
|
||||
|
||||
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
|
||||
|
||||
# iterate through tiles
|
||||
for j in range(0, num_blocks):
|
||||
|
||||
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
|
||||
|
||||
offs_n = tl.arange(0, BLOCK_SIZE)
|
||||
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
|
||||
|
||||
v_offset = (physical_block_idx * stride_v_cache_0 +
|
||||
kv_head_idx * stride_v_cache_1 +
|
||||
offs_d[None, :] * stride_v_cache_2 +
|
||||
offs_n[:, None] * stride_v_cache_3)
|
||||
|
||||
k_offset = (physical_block_idx * stride_k_cache_0 +
|
||||
kv_head_idx * stride_k_cache_1 +
|
||||
(offs_d[:, None] // x) * stride_k_cache_2 +
|
||||
offs_n[None, :] * stride_k_cache_3 +
|
||||
(offs_d[:, None] % x) * stride_k_cache_4)
|
||||
|
||||
# K : (HEAD_SIZE, BLOCK_SIZE)
|
||||
K_load = tl.load(key_cache_ptr + k_offset,
|
||||
mask=dim_mask[:, None],
|
||||
other=0.0)
|
||||
|
||||
if K_load.dtype.is_fp8():
|
||||
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
|
||||
else:
|
||||
K = K_load
|
||||
|
||||
# V : (BLOCK_SIZE, HEAD_SIZE)
|
||||
V_load = tl.load(value_cache_ptr + v_offset,
|
||||
mask=dim_mask[None, :],
|
||||
other=0.0)
|
||||
|
||||
if V_load.dtype.is_fp8():
|
||||
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
|
||||
else:
|
||||
V = V_load
|
||||
|
||||
seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
|
||||
seq_mask = seq_offset[None, :] < boundary
|
||||
|
||||
# S : (num_queries_per_kv, BLOCK_SIZE,)
|
||||
S = tl.where(head_mask[:, None] & seq_mask, 0.0,
|
||||
float("-inf")).to(tl.float32)
|
||||
S += scale * tl.dot(Q, K)
|
||||
|
||||
context_len = seq_len - 1
|
||||
|
||||
if SLIDING_WINDOW > 0:
|
||||
S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S,
|
||||
-10000)
|
||||
|
||||
if USE_ALIBI_SLOPES:
|
||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||
|
||||
# compute running maximum
|
||||
# m_j : (num_queries_per_kv,)
|
||||
m_j = tl.maximum(M, tl.max(S, axis=1))
|
||||
|
||||
# P : (num_queries_per_kv, BLOCK_SIZE,)
|
||||
P = tl.exp(S - m_j[:, None])
|
||||
|
||||
# l_j : (num_queries_per_kv,)
|
||||
l_j = tl.sum(P, axis=1)
|
||||
|
||||
# alpha : (num_queries_per_kv, )
|
||||
alpha = tl.exp(M - m_j)
|
||||
|
||||
# acc : (num_queries_per_kv, BLOCK_SIZE,)
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
# update constants
|
||||
L = L * alpha + l_j
|
||||
M = m_j
|
||||
|
||||
# acc : (num_queries_per_kv, BLOCK_SIZE,)
|
||||
acc += tl.dot(P.to(V.dtype), V)
|
||||
|
||||
# epilogue
|
||||
acc = acc / L[:, None]
|
||||
|
||||
output_offset = (cur_batch_in_all_start_index * output_stride_0 +
|
||||
query_head_idx * output_stride_1)
|
||||
|
||||
tl.store(
|
||||
output_ptr + output_offset[:, None] +
|
||||
tl.arange(0, HEAD_SIZE_PADDED)[None, :],
|
||||
acc,
|
||||
mask=dim_mask[None, :] & head_mask[:, None],
|
||||
)
|
||||
|
||||
|
||||
def chunked_prefill_paged_decode(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_table,
|
||||
query_start_loc,
|
||||
seq_lens,
|
||||
max_seq_len,
|
||||
max_query_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
sm_scale=None,
|
||||
# Optional tensor for sinks
|
||||
sinks=None,
|
||||
):
|
||||
|
||||
if sm_scale is None:
|
||||
sm_scale = 1.0 / (query.shape[1]**0.5)
|
||||
|
||||
use_alibi_slopes = alibi_slopes is not None
|
||||
|
||||
if sliding_window is None or sliding_window <= 0:
|
||||
sliding_window = 0
|
||||
|
||||
if max_query_len > 1:
|
||||
context_attention_fwd(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
o=output,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
b_loc=block_table,
|
||||
b_start_loc=query_start_loc,
|
||||
b_seq_len=seq_lens,
|
||||
max_seq_len=max_seq_len,
|
||||
max_input_len=max_query_len,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
alibi_slopes=alibi_slopes,
|
||||
sliding_window=sliding_window,
|
||||
sm_scale=sm_scale,
|
||||
skip_decode=True,
|
||||
sinks=sinks,
|
||||
)
|
||||
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs = len(seq_lens)
|
||||
num_query_heads = query.shape[1]
|
||||
num_kv_heads = key.shape[1]
|
||||
num_queries_per_kv = query.shape[1] // key.shape[1]
|
||||
head_size = query.shape[2]
|
||||
|
||||
# Conversion of FP8 Tensor from uint8 storage to
|
||||
# appropriate torch.dtype for interpretation by Triton
|
||||
if "fp8" in kv_cache_dtype:
|
||||
assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
|
||||
assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
|
||||
|
||||
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
|
||||
target_dtype = current_platform.fp8_dtype()
|
||||
elif kv_cache_dtype == "fp8_e5m2":
|
||||
target_dtype = torch.float8_e5m2
|
||||
else:
|
||||
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
|
||||
|
||||
key_cache = key_cache.view(target_dtype)
|
||||
value_cache = value_cache.view(target_dtype)
|
||||
|
||||
num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
|
||||
16)
|
||||
|
||||
use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
|
||||
block_size,
|
||||
num_queries_per_kv,
|
||||
max_seq_len, sliding_window,
|
||||
kv_cache_dtype, alibi_slopes, sinks,)
|
||||
if use_custom:
|
||||
_PARTITION_SIZE_ROCM = 256
|
||||
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
|
||||
_PARTITION_SIZE_ROCM)
|
||||
assert _PARTITION_SIZE_ROCM % block_size == 0
|
||||
total_num_seq = block_table.shape[0]
|
||||
tmp_output = torch.empty(
|
||||
size=(total_num_seq, num_query_heads, max_num_partitions,
|
||||
head_size),
|
||||
dtype=output.dtype,
|
||||
device=output.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(total_num_seq, num_query_heads, max_num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=output.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
|
||||
ops.paged_attention_rocm(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale=sm_scale,
|
||||
block_tables=block_table,
|
||||
seq_lens=seq_lens,
|
||||
query_start_loc=query_start_loc,
|
||||
block_size=block_size,
|
||||
max_seq_len=max_seq_len,
|
||||
alibi_slopes=alibi_slopes,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
)
|
||||
else:
|
||||
kernel_paged_attention_2d[(
|
||||
num_seqs,
|
||||
num_kv_heads,
|
||||
)](
|
||||
output_ptr=output,
|
||||
query_ptr=query,
|
||||
key_cache_ptr=key_cache,
|
||||
value_cache_ptr=value_cache,
|
||||
sink_ptr=sinks,
|
||||
block_tables_ptr=block_table,
|
||||
seq_lens_ptr=seq_lens,
|
||||
alibi_slopes_ptr=alibi_slopes,
|
||||
scale=sm_scale,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
num_query_heads=num_query_heads,
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
num_queries_per_kv_padded=num_queries_per_kv_padded,
|
||||
block_table_stride=block_table.stride(0),
|
||||
query_stride_0=query.stride(0),
|
||||
query_stride_1=query.stride(1),
|
||||
output_stride_0=output.stride(0),
|
||||
output_stride_1=output.stride(1),
|
||||
BLOCK_SIZE=block_size,
|
||||
HEAD_SIZE=head_size,
|
||||
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
||||
USE_ALIBI_SLOPES=use_alibi_slopes,
|
||||
SLIDING_WINDOW=sliding_window,
|
||||
x=key_cache.shape[4],
|
||||
stride_k_cache_0=key_cache.stride(0),
|
||||
stride_k_cache_1=key_cache.stride(1),
|
||||
stride_k_cache_2=key_cache.stride(2),
|
||||
stride_k_cache_3=key_cache.stride(3),
|
||||
stride_k_cache_4=key_cache.stride(4),
|
||||
stride_v_cache_0=value_cache.stride(0),
|
||||
stride_v_cache_1=value_cache.stride(1),
|
||||
stride_v_cache_2=value_cache.stride(2),
|
||||
stride_v_cache_3=value_cache.stride(3),
|
||||
filter_by_query_len=True,
|
||||
query_start_len_ptr=query_start_loc,
|
||||
USE_SINKS=sinks is not None,
|
||||
)
|
||||
138
vllm/attention/ops/flashmla.py
Normal file
138
vllm/attention/ops/flashmla.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# if current_platform.is_cuda():
|
||||
# try:
|
||||
# import vllm._flashmla_C # noqa: F401
|
||||
# _flashmla_C_AVAILABLE = True
|
||||
# except ImportError:
|
||||
# _flashmla_C_AVAILABLE = False
|
||||
# else:
|
||||
# _flashmla_C_AVAILABLE = False
|
||||
try :
|
||||
import flash_mla
|
||||
_flashmla_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from flash_mla with %r on MACA Platform", e)
|
||||
_flashmla_AVAILABLE = False
|
||||
|
||||
|
||||
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Return: is_supported_flag, unsupported_reason (optional).
|
||||
"""
|
||||
# if not current_platform.is_cuda():
|
||||
# return False, "FlashMLA is only supported on CUDA devices."
|
||||
# if current_platform.get_device_capability()[0] != 9:
|
||||
# return False, "FlashMLA is only supported on Hopper devices."
|
||||
# if not _flashmla_C_AVAILABLE:
|
||||
# return False, "vllm._flashmla_C is not available, likely was not "\
|
||||
# "compiled due to insufficient nvcc version or a supported arch "\
|
||||
# "(only sm90a currently) was not in the list of target arches to "\
|
||||
# "compile for."
|
||||
if not _flashmla_AVAILABLE:
|
||||
return False, "flash_mla is not available"
|
||||
return True, None
|
||||
|
||||
|
||||
def get_mla_metadata(
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_heads_per_head_k: int,
|
||||
num_heads_k: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
cache_seqlens: (batch_size), dtype torch.int32.
|
||||
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
|
||||
num_heads_k: num_heads_k.
|
||||
|
||||
Return:
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
|
||||
dtype torch.int32.
|
||||
num_splits: (batch_size + 1), dtype torch.int32.
|
||||
"""
|
||||
# return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
|
||||
# num_heads_per_head_k,
|
||||
# num_heads_k)
|
||||
return flash_mla.flash_mla_interface.get_mla_metadata(cache_seqlens,
|
||||
num_heads_per_head_k,
|
||||
num_heads_k)
|
||||
|
||||
|
||||
def flash_mla_with_kvcache(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
cache_seqlens: torch.Tensor,
|
||||
head_dim_v: int,
|
||||
tile_scheduler_metadata: torch.Tensor,
|
||||
num_splits: torch.Tensor,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||
cache_seqlens: (batch_size), torch.int32.
|
||||
head_dim_v: Head_dim of v.
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
|
||||
torch.int32, return by get_mla_metadata.
|
||||
num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(head_dim).
|
||||
causal: bool. Whether to apply causal attention mask.
|
||||
|
||||
Return:
|
||||
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||
"""
|
||||
# if softmax_scale is None:
|
||||
# softmax_scale = q.shape[-1]**(-0.5)
|
||||
# out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
|
||||
# q,
|
||||
# k_cache,
|
||||
# None,
|
||||
# head_dim_v,
|
||||
# cache_seqlens,
|
||||
# block_table,
|
||||
# softmax_scale,
|
||||
# causal,
|
||||
# tile_scheduler_metadata,
|
||||
# num_splits,
|
||||
# )
|
||||
out, softmax_lse = flash_mla.flash_mla_interface.flash_mla_with_kvcache(
|
||||
q,
|
||||
k_cache,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
head_dim_v,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
softmax_scale,
|
||||
causal,
|
||||
)
|
||||
return out, softmax_lse
|
||||
|
||||
|
||||
#
|
||||
# TODO: Add fake functions
|
||||
#
|
||||
# @register_fake("_flashmla_C::get_mla_metadata")
|
||||
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# return ....
|
||||
#
|
||||
# @register_fake("_flashmla_C::fwd_kvcache_mla")
|
||||
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# return ....
|
||||
#
|
||||
88
vllm/attention/ops/hpu_paged_attn.py
Normal file
88
vllm/attention/ops/hpu_paged_attn.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
###############################################################################
|
||||
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
|
||||
###############################################################################
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from vllm_hpu_extension import cache_ops, ops
|
||||
|
||||
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
|
||||
@dataclass
|
||||
class HPUPagedAttentionMetadata:
|
||||
"""Metadata for PagedAttention."""
|
||||
block_list: Optional[torch.Tensor]
|
||||
block_mapping: Optional[torch.Tensor]
|
||||
block_usage: Optional[torch.Tensor]
|
||||
block_indices: Optional[torch.Tensor]
|
||||
block_offsets: Optional[torch.Tensor]
|
||||
block_groups: Optional[torch.Tensor]
|
||||
|
||||
|
||||
class HPUPagedAttention:
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [64, 80, 96, 112, 128, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def split_kv_cache(
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
key_cache = kv_cache[0]
|
||||
value_cache = kv_cache[1]
|
||||
return key_cache, value_cache
|
||||
|
||||
@staticmethod
|
||||
def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor, kv_cache_dtype: str,
|
||||
is_prompt: bool) -> None:
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
|
||||
slot_mapping, kv_cache_dtype, is_prompt)
|
||||
|
||||
@staticmethod
|
||||
def forward_decode(**kwargs) -> torch.Tensor:
|
||||
return ops.flat_pa(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||
dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||
src_to_dsts: torch.Tensor,
|
||||
) -> None:
|
||||
src_key_cache = src_kv_cache[0]
|
||||
dst_key_cache = dst_kv_cache[0]
|
||||
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dsts)
|
||||
|
||||
src_value_cache = src_kv_cache[1]
|
||||
dst_value_cache = dst_kv_cache[1]
|
||||
cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dsts)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
src_to_dsts: torch.Tensor,
|
||||
) -> None:
|
||||
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||
cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
|
||||
195
vllm/attention/ops/ipex_attn.py
Normal file
195
vllm/attention/ops/ipex_attn.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch.llm.modules as ipex_modules
|
||||
_use_ipex = True
|
||||
# AttributeError is to handle a bug in ipex https://github.com/intel/intel-extension-for-pytorch/pull/813
|
||||
except (ImportError, AttributeError):
|
||||
_use_ipex = False
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
|
||||
class _PagedAttention:
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [32, 64, 80, 96, 112, 128, 192, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
*args,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size * num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def split_kv_cache(
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
*args,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
x = 16 // kv_cache.element_size()
|
||||
num_blocks = kv_cache.shape[1]
|
||||
|
||||
key_cache = kv_cache[0]
|
||||
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
|
||||
-1, x)
|
||||
value_cache = kv_cache[1]
|
||||
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
|
||||
return key_cache, value_cache
|
||||
|
||||
@staticmethod
|
||||
def write_to_paged_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
*args,
|
||||
) -> None:
|
||||
ops.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping.flatten(),
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def forward_decode(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
max_context_len: int,
|
||||
kv_cache_dtype: str,
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
*args,
|
||||
) -> None:
|
||||
tp_rank: int = 0
|
||||
blocksparse_local_blocks: int = 0
|
||||
blocksparse_vert_stride: int = 0
|
||||
blocksparse_block_size: int = 64
|
||||
blocksparse_head_sliding_step: int = 0
|
||||
block_size = value_cache.shape[3]
|
||||
|
||||
ops.paged_attention_v1(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
tp_rank,
|
||||
blocksparse_local_blocks,
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: Dict[int, List[int]],
|
||||
*args,
|
||||
) -> None:
|
||||
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
|
||||
|
||||
class _IPEXPagedAttention(_PagedAttention):
|
||||
|
||||
@staticmethod
|
||||
def split_kv_cache(
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
*args,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
num_blocks = kv_cache.shape[1]
|
||||
|
||||
key_cache = kv_cache[0]
|
||||
key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
|
||||
value_cache = kv_cache[1]
|
||||
value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size)
|
||||
return key_cache, value_cache
|
||||
|
||||
@staticmethod
|
||||
def write_to_paged_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
*args,
|
||||
) -> None:
|
||||
ipex_modules.PagedAttention.reshape_and_cache(
|
||||
key, value, key_cache, value_cache,
|
||||
slot_mapping.flatten().int())
|
||||
|
||||
@staticmethod
|
||||
def forward_decode(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
max_context_len: int,
|
||||
kv_cache_dtype: str,
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
*args,
|
||||
) -> None:
|
||||
block_size = value_cache.shape[2]
|
||||
head_mapping = torch.arange(
|
||||
0,
|
||||
num_kv_heads,
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
).view(num_kv_heads,
|
||||
1).repeat_interleave(query.size(1) // num_kv_heads).flatten()
|
||||
ipex_modules.PagedAttention.single_query_cached_kv_attention(
|
||||
output, query.contiguous(), key_cache, value_cache, head_mapping,
|
||||
scale, block_tables, context_lens, block_size, max_context_len,
|
||||
alibi_slopes)
|
||||
|
||||
|
||||
PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention
|
||||
43
vllm/attention/ops/merge_attn_states.py
Normal file
43
vllm/attention/ops/merge_attn_states.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def merge_attn_states(
|
||||
output: torch.Tensor,
|
||||
prefix_output: torch.Tensor,
|
||||
prefix_lse: torch.Tensor,
|
||||
suffix_output: torch.Tensor,
|
||||
suffix_lse: torch.Tensor,
|
||||
output_lse: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
|
||||
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel
|
||||
# is not support for FP8 dtype, fallback to use Triton kernel.
|
||||
def supported_dtypes(o: torch.Tensor) -> bool:
|
||||
return o.dtype in [torch.float32, torch.half, torch.bfloat16]
|
||||
|
||||
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA
|
||||
# kernel load/store 128b(16 bytes) per memory issue within
|
||||
# thread. Namely, the headsize(headdim) must be multiple of
|
||||
# pack_size (float32 -> 4, half/bfloat16 -> 8).
|
||||
def supported_headdim(o: torch.Tensor) -> bool:
|
||||
headdim = o.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
if o.dtype == torch.float32:
|
||||
return headdim % 4 == 0
|
||||
return headdim % 8 == 0
|
||||
|
||||
if (current_platform.is_cuda() and supported_dtypes(output)
|
||||
and supported_headdim(output)):
|
||||
from vllm._custom_ops import merge_attn_states
|
||||
return merge_attn_states(output, prefix_output, prefix_lse,
|
||||
suffix_output, suffix_lse, output_lse)
|
||||
else:
|
||||
from vllm.attention.ops.triton_merge_attn_states import (
|
||||
merge_attn_states)
|
||||
return merge_attn_states(output, prefix_output, prefix_lse,
|
||||
suffix_output, suffix_lse, output_lse)
|
||||
906
vllm/attention/ops/nki_flash_attn.py
Normal file
906
vllm/attention/ops/nki_flash_attn.py
Normal file
@@ -0,0 +1,906 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import neuronxcc.nki.isa as nisa
|
||||
import neuronxcc.nki.language as nl
|
||||
import numpy as np
|
||||
import torch
|
||||
from neuronxcc import nki
|
||||
from neuronxcc.nki.language import par_dim
|
||||
|
||||
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
|
||||
def is_power_of_2(x):
|
||||
return x > 0 and (x & (x - 1)) == 0
|
||||
|
||||
|
||||
@nki.jit
|
||||
def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
|
||||
"""
|
||||
Load block tables from HBM into SRAM
|
||||
|
||||
`block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`.
|
||||
In case `num_tiles > B_P_SIZE`, we need further tile `num_tile` dimension.
|
||||
"""
|
||||
B_P_SIZE = 128
|
||||
|
||||
# reshape as `(num_tiles, num_blocks_per_tile)`
|
||||
assert len(block_tables_hbm.shape) == 1
|
||||
(num_total_blocks, ) = block_tables_hbm.shape
|
||||
assert num_blocks_per_tile * num_tiles == num_total_blocks
|
||||
block_tables_hbm = block_tables_hbm.reshape(
|
||||
(num_tiles, num_blocks_per_tile))
|
||||
|
||||
block_tables_sbuf = nl.zeros(
|
||||
(ceil_div(num_tiles,
|
||||
B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
|
||||
dtype=nl.int32,
|
||||
)
|
||||
for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)):
|
||||
i_p = nl.arange(B_P_SIZE)[:, None]
|
||||
i_f = nl.arange(num_blocks_per_tile)[None, :]
|
||||
block_tables_sbuf[i, i_p, i_f] = nl.load(
|
||||
block_tables_hbm[i_p + i * B_P_SIZE, i_f],
|
||||
dtype=nl.int32,
|
||||
mask=(i_p + i * B_P_SIZE < num_tiles),
|
||||
)
|
||||
return block_tables_sbuf
|
||||
|
||||
|
||||
@nki.jit
|
||||
def transform_block_tables_for_indirect_load(
|
||||
block_tables,
|
||||
block_size_tiling_factor,
|
||||
num_head,
|
||||
head_id,
|
||||
):
|
||||
"""
|
||||
This function does two things:
|
||||
1. calculate new `block_tables` for a `head_id` after flattening
|
||||
`num_block`, `num_head`, and `block_size_tiling_factor` dimensions
|
||||
2. transpose the result so that `block_table` for each tile is mapped to
|
||||
SBUF Partition dimension for vectorized DMA
|
||||
|
||||
Tiling trick to further improve DMA performance:
|
||||
Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M
|
||||
blocks of a given `head_id` from HBM, the load `cache[block_tables,
|
||||
head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not
|
||||
fully utilize hardware parallelization. The solution is to tile `block_size`
|
||||
into `(block_size_tiling_factor, tiled_block_size)` s.t. `M *
|
||||
block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape
|
||||
`(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`.
|
||||
|
||||
Note:
|
||||
We don't further tile D dimension as small DMA size also hurts performance.
|
||||
"""
|
||||
B_P_SIZE = 128
|
||||
num_partitions, num_tiles_per_partition, num_blocks_per_tile = (
|
||||
block_tables.shape)
|
||||
assert num_tiles_per_partition == B_P_SIZE
|
||||
assert is_power_of_2(
|
||||
num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2"
|
||||
|
||||
num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE)
|
||||
block_tables_transposed = nl.ndarray(
|
||||
(
|
||||
num_loads,
|
||||
par_dim(B_P_SIZE),
|
||||
num_partitions * num_tiles_per_partition,
|
||||
),
|
||||
dtype=nl.int32,
|
||||
)
|
||||
|
||||
# prepare iota ahead of time to avoid repeatedly using Gpsimd
|
||||
if num_head > 1:
|
||||
head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1))
|
||||
head_id = nl.transpose(
|
||||
head_id.broadcast_to((1, num_tiles_per_partition)))
|
||||
if num_blocks_per_tile > 1:
|
||||
head_id = head_id.broadcast_to(
|
||||
(num_tiles_per_partition, num_blocks_per_tile))
|
||||
|
||||
if block_size_tiling_factor > 1:
|
||||
broadcast_shape = (
|
||||
num_tiles_per_partition,
|
||||
num_blocks_per_tile,
|
||||
block_size_tiling_factor,
|
||||
)
|
||||
offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :],
|
||||
dtype=nl.int32).broadcast_to(broadcast_shape)
|
||||
|
||||
for partition_id in nl.affine_range(num_partitions):
|
||||
block_tables_partition = block_tables[partition_id]
|
||||
if num_head > 1:
|
||||
# fuse num_block and num_head dimension
|
||||
block_tables_partition = block_tables_partition * num_head + head_id
|
||||
|
||||
if block_size_tiling_factor > 1:
|
||||
# need to apply block size tiling trick
|
||||
assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE
|
||||
block_tables_partition = ((block_tables_partition *
|
||||
block_size_tiling_factor).reshape(
|
||||
(num_tiles_per_partition,
|
||||
num_blocks_per_tile,
|
||||
1)).broadcast_to(broadcast_shape))
|
||||
new_block_tables = block_tables_partition + offset
|
||||
new_block_tables = new_block_tables.reshape(
|
||||
(num_tiles_per_partition, B_P_SIZE))
|
||||
else:
|
||||
new_block_tables = block_tables_partition
|
||||
|
||||
# transpose the block table so that it can be used by vector DGE
|
||||
for i in nl.affine_range(num_loads):
|
||||
i_p = nl.arange(B_P_SIZE)[:, None]
|
||||
i_f = (partition_id * num_tiles_per_partition +
|
||||
nl.arange(num_tiles_per_partition)[None, :])
|
||||
block_tables_transposed[i, i_p, i_f] = nl.transpose(
|
||||
new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)])
|
||||
return block_tables_transposed
|
||||
|
||||
|
||||
@nki.jit
|
||||
def load_kv_tile_from_cache(
|
||||
cur_k_tile,
|
||||
cur_v_tile,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
large_k_tile_idx,
|
||||
num_blocks_per_large_tile,
|
||||
tiled_block_size,
|
||||
B_P_SIZE,
|
||||
B_D_SIZE,
|
||||
):
|
||||
"""
|
||||
Load KV cache and transform Key and Value into layout required by Matmul
|
||||
|
||||
Vectorized DMA Load layout:
|
||||
Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
|
||||
|
||||
Layout used by attention matmuls:
|
||||
Key: (par_dim(B_D_SIZE), seqlen_kv)
|
||||
Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE)
|
||||
equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
|
||||
"""
|
||||
# load key cache
|
||||
num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
|
||||
for load_idx in nl.affine_range(num_loads):
|
||||
i_p = nl.arange(B_P_SIZE)[:, None]
|
||||
i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
|
||||
loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p,
|
||||
large_k_tile_idx], i_f])
|
||||
if cur_k_tile.dtype != loaded.dtype:
|
||||
loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
|
||||
# Transpose SBUF tensor using PE
|
||||
for tb_i in nl.affine_range(tiled_block_size):
|
||||
cur_k_tile[
|
||||
:,
|
||||
nl.ds(
|
||||
load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE,
|
||||
B_P_SIZE,
|
||||
),
|
||||
] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)])
|
||||
|
||||
# load value cache
|
||||
for load_idx in nl.affine_range(num_loads):
|
||||
loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p,
|
||||
large_k_tile_idx], i_f])
|
||||
if cur_v_tile.dtype != loaded.dtype:
|
||||
loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
|
||||
i_p = nl.arange(B_P_SIZE)[:, None]
|
||||
i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
|
||||
cur_v_tile[
|
||||
:,
|
||||
nl.ds(
|
||||
load_idx * tiled_block_size * B_D_SIZE,
|
||||
tiled_block_size * B_D_SIZE,
|
||||
),
|
||||
] = loaded
|
||||
|
||||
|
||||
@nki.jit
|
||||
def transpose_p_local(p_local_transposed,
|
||||
p_local,
|
||||
LARGE_TILE_SZ,
|
||||
B_F_SIZE=512):
|
||||
for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
|
||||
if nisa.get_nc_version() == nisa.nc_version.gen3:
|
||||
p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
|
||||
buffer=nl.sbuf,
|
||||
dtype=p_local.dtype)
|
||||
else:
|
||||
p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
|
||||
buffer=nl.psum,
|
||||
dtype=np.float32)
|
||||
|
||||
for j in nl.affine_range(B_F_SIZE // 128):
|
||||
j_128_slice = nl.ds(j * 128, 128)
|
||||
i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128)
|
||||
|
||||
if nisa.get_nc_version() == nisa.nc_version.gen3:
|
||||
p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
|
||||
p_local[:, i_j_128_slice])
|
||||
else:
|
||||
p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
|
||||
p_local[:, i_j_128_slice])
|
||||
|
||||
p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy(
|
||||
p_local_t_tmp, dtype=p_local_transposed.dtype)
|
||||
|
||||
|
||||
@nki.jit
|
||||
def _flash_attention_core(
|
||||
q_local_tile,
|
||||
k,
|
||||
v,
|
||||
o_buffer,
|
||||
l_buffer,
|
||||
m_buffer,
|
||||
kernel_dtype,
|
||||
acc_type,
|
||||
tile_mask,
|
||||
use_causal_mask,
|
||||
q_tile_idx=None,
|
||||
initialize=False,
|
||||
LARGE_TILE_SZ=2048,
|
||||
B_P_SIZE=128,
|
||||
B_F_SIZE=512,
|
||||
B_D_SIZE=128,
|
||||
qk_res_buffer=None,
|
||||
):
|
||||
"""
|
||||
The flash attention core function to calculate self attention between a tile
|
||||
of q and a block of K and V.
|
||||
The q_local_tile has (B_P_SIZE, B_D_SIZE)
|
||||
The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension will
|
||||
be split into size B_F_SIZE tiles
|
||||
|
||||
The results are stored in the following three buffers
|
||||
o_buffer: (B_P_SIZE, d)
|
||||
l_buffer: (B_P_SIZE, 1)
|
||||
m_buffer: (B_P_SIZE, 1)
|
||||
|
||||
All IO buffers are in SBUF.
|
||||
"""
|
||||
num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
|
||||
|
||||
qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||
buffer=nl.sbuf,
|
||||
dtype=acc_type)
|
||||
max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile),
|
||||
dtype=acc_type)
|
||||
for k_i in nl.affine_range(num_k_tile_per_large_tile):
|
||||
k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)
|
||||
|
||||
if use_causal_mask:
|
||||
# mask are used to only apply computation to the lower half of the
|
||||
# matrix, which reduce the arithmetic intensity by up to 50%
|
||||
multiplication_required_selection = (q_tile_idx * B_P_SIZE
|
||||
>= k_i * B_F_SIZE)
|
||||
else:
|
||||
multiplication_required_selection = True
|
||||
|
||||
if multiplication_required_selection:
|
||||
qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
|
||||
dtype=np.float32,
|
||||
buffer=nl.psum) # (128, 512)
|
||||
qk_psum[:, :] = nl.matmul(q_local_tile,
|
||||
k[:, k_i_b_f_slice],
|
||||
transpose_x=True) # (p(128), 512)
|
||||
qk_res_buf[:, k_i_b_f_slice] = nl.where(
|
||||
tile_mask[:, k_i_b_f_slice],
|
||||
qk_psum[:, nl.ds(0, B_F_SIZE)],
|
||||
-9984.0,
|
||||
dtype=acc_type,
|
||||
)
|
||||
else:
|
||||
qk_res_buf[:, k_i_b_f_slice] = -9984.0
|
||||
|
||||
# Calculate max of the current tile
|
||||
max_local[:, k_i] = nisa.tensor_reduce(
|
||||
np.max,
|
||||
qk_res_buf[:, k_i_b_f_slice],
|
||||
axis=(1, ),
|
||||
dtype=acc_type,
|
||||
negate=False,
|
||||
)
|
||||
|
||||
if qk_res_buffer is not None:
|
||||
qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :])
|
||||
|
||||
max_ = nisa.tensor_reduce(
|
||||
np.max,
|
||||
max_local[:, :],
|
||||
axis=(1, ),
|
||||
dtype=acc_type,
|
||||
negate=False,
|
||||
)
|
||||
|
||||
o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
|
||||
dtype=o_buffer.dtype)
|
||||
|
||||
if initialize:
|
||||
m_buffer[:, 0] = nl.copy(max_)
|
||||
m_current = max_
|
||||
else:
|
||||
m_previous = nl.copy(m_buffer[:, 0])
|
||||
m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1)
|
||||
|
||||
m_current = m_buffer[:, 0]
|
||||
# Compute scaling factor
|
||||
alpha = nisa.activation(
|
||||
np.exp,
|
||||
m_previous,
|
||||
bias=-1 * m_current,
|
||||
scale=1.0,
|
||||
)
|
||||
o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)
|
||||
|
||||
p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||
dtype=kernel_dtype)
|
||||
REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)
|
||||
|
||||
p_partial_sum = nl.ndarray(
|
||||
(par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE),
|
||||
dtype=acc_type,
|
||||
)
|
||||
|
||||
for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
|
||||
k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)
|
||||
|
||||
# compute exp(qk - max)
|
||||
# Compute partial row - tile sum of exp(qk - max))
|
||||
# FIXME : Use activation accumulate to accumulate over k_r_i loop ?
|
||||
p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce(
|
||||
np.exp,
|
||||
qk_res_buf[:, k_r_i_reduce_slice],
|
||||
bias=-1 * m_current,
|
||||
scale=1.0,
|
||||
reduce_op=nl.add,
|
||||
reduce_res=p_partial_sum[:, k_r_i],
|
||||
dtype=kernel_dtype,
|
||||
)
|
||||
|
||||
ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)
|
||||
|
||||
p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||
dtype=kernel_dtype)
|
||||
transpose_p_local(
|
||||
p_local_transposed=p_local_transposed,
|
||||
p_local=p_local,
|
||||
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||
B_F_SIZE=B_F_SIZE,
|
||||
)
|
||||
|
||||
pv_psum = nl.zeros(
|
||||
(par_dim(B_P_SIZE), B_D_SIZE),
|
||||
dtype=np.float32,
|
||||
buffer=nl.psum,
|
||||
)
|
||||
for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
|
||||
pv_psum[:, :] += nl.matmul(
|
||||
p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
|
||||
v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)],
|
||||
transpose_x=True,
|
||||
) # (128, 128) (p(Br), d)
|
||||
|
||||
if initialize:
|
||||
o_buffer[:, :] = nl.copy(pv_psum[:, :])
|
||||
l_buffer[:, 0] = nl.add(nl.log(ps), max_)
|
||||
else:
|
||||
o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)
|
||||
|
||||
l_prev = l_buffer[:, 0]
|
||||
l_exp = nl.add(
|
||||
nl.exp(nl.subtract(l_prev, m_current)),
|
||||
ps,
|
||||
)
|
||||
l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp))
|
||||
|
||||
|
||||
@nki.jit
|
||||
def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ):
|
||||
B_P_SIZE = 128
|
||||
B_D_SIZE = v_hbm_tile.shape[-1]
|
||||
loaded = nl.load(v_hbm_tile[
|
||||
nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE),
|
||||
:,
|
||||
])
|
||||
if cur_v_tile.dtype != loaded.dtype:
|
||||
loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
|
||||
cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded
|
||||
|
||||
|
||||
@nki.jit
|
||||
def flash_paged_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
mask,
|
||||
softmax_scale=None,
|
||||
mixed_precision=True,
|
||||
LARGE_TILE_SZ=2048,
|
||||
return_debug_tensors=False,
|
||||
):
|
||||
"""
|
||||
Flash PagedAttention Forward Kernel.
|
||||
|
||||
IO tensor layouts:
|
||||
- query: shape (1, n_heads, d, seq_q)
|
||||
- key: shape (1, n_kv_heads, d, seq_k)
|
||||
- value: shape (1, n_kv_heads, seq_v, d)
|
||||
- kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
|
||||
- block_tables: (num_active_blocks, )
|
||||
- mask: (seq_q, num_active_blocks * block_size + seq_q)
|
||||
- o: shape (1, n_heads, seq_q, d)
|
||||
|
||||
- This kernel requires seq_k == seq_v
|
||||
- We use continuous batching by default, so the batch dimension is
|
||||
always 1, and different requests are concatenated along sequence
|
||||
dimension.
|
||||
- We use paged cache blocks (kv_cache) to store KV cache.
|
||||
|
||||
IO tensor dtypes:
|
||||
- This kernel assumes all IO tensors have the same dtype except for
|
||||
block_tables (int32) and mask (int32)
|
||||
- If mixed_precision is True, then all Tensor Engine operation will be
|
||||
performed in bfloat16 and accumulation will be performed in float32.
|
||||
Otherwise the intermediates will be in the same type as the inputs.
|
||||
|
||||
Compile-time Constants:
|
||||
- softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)`
|
||||
- mixed_precision: flag to set non-matmul ops in fp32 precision, default
|
||||
is set to `true`, if false, we use same precision as input types
|
||||
- LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention
|
||||
computation reduction
|
||||
|
||||
GQA support Notes:
|
||||
the spmd kernel for launching kernel should be on kv_heads instead of
|
||||
nheads
|
||||
|
||||
Example usage:
|
||||
MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
|
||||
usage: `flash_fwd[b, h](q, k, v, ...)`
|
||||
GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
|
||||
usage: `flash_fwd[b, kv_h](q, k, v, ...)`
|
||||
"""
|
||||
B_F_SIZE = 512
|
||||
B_P_SIZE = 128
|
||||
b, h, d, seqlen_q = query.shape
|
||||
B_D_SIZE = d
|
||||
n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine
|
||||
_, num_blocks, k_h, block_size, _ = kv_cache.shape
|
||||
q_h_per_k_h = h // k_h
|
||||
assert b == 1, f"invalid batch size {b=}"
|
||||
assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}"
|
||||
cache_shape = (2, num_blocks, k_h, block_size, d)
|
||||
assert (tuple(kv_cache.shape) == cache_shape
|
||||
), f"{kv_cache.shape=} mismatch, expect {cache_shape}"
|
||||
assert key is None or tuple(key.shape) == (
|
||||
1,
|
||||
k_h,
|
||||
d,
|
||||
seqlen_q,
|
||||
), f"key shape {key.shape} mismatch!"
|
||||
assert value is None or tuple(value.shape) == (
|
||||
1,
|
||||
k_h,
|
||||
seqlen_q,
|
||||
d,
|
||||
), f"value shape {value.shape} mismatch!"
|
||||
|
||||
assert (
|
||||
nl.program_ndim() == 2
|
||||
), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
|
||||
batch_id = nl.program_id(axis=0)
|
||||
head_id = nl.program_id(axis=1)
|
||||
|
||||
(num_active_blocks, ) = block_tables.shape
|
||||
context_kv_len = num_active_blocks * block_size
|
||||
assert (
|
||||
LARGE_TILE_SZ % B_F_SIZE == 0
|
||||
), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p"
|
||||
assert (context_kv_len % LARGE_TILE_SZ == 0
|
||||
), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"
|
||||
|
||||
num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
|
||||
assert is_power_of_2(
|
||||
num_blocks_per_large_tile
|
||||
), f"{num_blocks_per_large_tile=} is expected of be power of 2"
|
||||
if seqlen_q > B_F_SIZE:
|
||||
MAX_REDUCTION_TILE = 2048
|
||||
if seqlen_q // 2 > MAX_REDUCTION_TILE:
|
||||
assert (
|
||||
seqlen_q % MAX_REDUCTION_TILE == 0
|
||||
), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}"
|
||||
else:
|
||||
assert (seqlen_q % B_F_SIZE == 0
|
||||
), f"{seqlen_q=} should be divisible by {B_F_SIZE=})"
|
||||
|
||||
kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype
|
||||
acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype
|
||||
softmax_scale = softmax_scale or (1.0 / (d**0.5))
|
||||
num_large_k_tile = context_kv_len // LARGE_TILE_SZ
|
||||
|
||||
o = nl.ndarray((b, h, seqlen_q, d),
|
||||
dtype=query.dtype,
|
||||
buffer=nl.shared_hbm)
|
||||
hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
if return_debug_tensors:
|
||||
hbm_l_buffer = nl.ndarray((b, h, seqlen_q),
|
||||
dtype=acc_type,
|
||||
buffer=nl.shared_hbm)
|
||||
hbm_m_buffer = nl.ndarray((b, h, seqlen_q),
|
||||
dtype=acc_type,
|
||||
buffer=nl.shared_hbm)
|
||||
hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q),
|
||||
dtype=acc_type,
|
||||
buffer=nl.shared_hbm)
|
||||
qk_res_buffer = nl.zeros(
|
||||
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q),
|
||||
dtype=acc_type,
|
||||
buffer=nl.sbuf,
|
||||
lazy_initialization=True,
|
||||
)
|
||||
block_tables_sbuf = load_block_tables(
|
||||
block_tables_hbm=block_tables,
|
||||
num_tiles=num_large_k_tile,
|
||||
num_blocks_per_tile=num_blocks_per_large_tile,
|
||||
)
|
||||
|
||||
# On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
|
||||
if num_blocks_per_large_tile < B_P_SIZE:
|
||||
# we checked num_blocks_per_tile is a power of 2
|
||||
assert B_P_SIZE % num_blocks_per_large_tile == 0
|
||||
block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile
|
||||
# We assume block_size >= block_size_tiling_factor
|
||||
assert block_size % block_size_tiling_factor == 0
|
||||
else:
|
||||
block_size_tiling_factor = 1
|
||||
tiled_block_size = block_size // block_size_tiling_factor
|
||||
|
||||
# Indirect DMA load must be placed along Partition Dimension
|
||||
block_tables_sbuf = transform_block_tables_for_indirect_load(
|
||||
block_tables_sbuf,
|
||||
block_size_tiling_factor=block_size_tiling_factor,
|
||||
num_head=k_h,
|
||||
head_id=head_id,
|
||||
)
|
||||
|
||||
# Flatten KV cache to be 3D for loading into SBUF
|
||||
new_cache_shape = (
|
||||
2,
|
||||
num_blocks * k_h * block_size_tiling_factor,
|
||||
tiled_block_size * d,
|
||||
)
|
||||
kv_cache = kv_cache.reshape(new_cache_shape)
|
||||
|
||||
# Global Flash Attention accumulators
|
||||
o_buffer = nl.zeros(
|
||||
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d),
|
||||
dtype=acc_type,
|
||||
buffer=nl.sbuf,
|
||||
lazy_initialization=True,
|
||||
)
|
||||
l_buffer = nl.zeros(
|
||||
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
|
||||
dtype=acc_type,
|
||||
buffer=nl.sbuf,
|
||||
lazy_initialization=True,
|
||||
)
|
||||
m_buffer = nl.zeros(
|
||||
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
|
||||
dtype=acc_type,
|
||||
buffer=nl.sbuf,
|
||||
lazy_initialization=True,
|
||||
)
|
||||
|
||||
for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
|
||||
num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
|
||||
cur_k_tile = nl.ndarray(
|
||||
(par_dim(B_D_SIZE), LARGE_TILE_SZ),
|
||||
dtype=kernel_dtype,
|
||||
)
|
||||
cur_v_tile = nl.ndarray(
|
||||
(par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE),
|
||||
dtype=kernel_dtype,
|
||||
)
|
||||
load_kv_tile_from_cache(
|
||||
cur_k_tile=cur_k_tile,
|
||||
cur_v_tile=cur_v_tile,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables_sbuf,
|
||||
large_k_tile_idx=large_k_tile_idx,
|
||||
num_blocks_per_large_tile=num_blocks_per_large_tile,
|
||||
tiled_block_size=tiled_block_size,
|
||||
B_P_SIZE=B_P_SIZE,
|
||||
B_D_SIZE=B_D_SIZE,
|
||||
)
|
||||
|
||||
for i in nl.affine_range(n_tile_q):
|
||||
cur_mask = nl.load(mask[
|
||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||
nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ),
|
||||
])
|
||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
|
||||
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
|
||||
q_sbuf_tile = nl.load(q_hbm_tile[:,
|
||||
nl.ds(i *
|
||||
B_P_SIZE, B_P_SIZE)])
|
||||
if q_sbuf_tile.dtype != kernel_dtype:
|
||||
q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
|
||||
q_tile[:, :] = q_sbuf_tile * softmax_scale
|
||||
|
||||
_flash_attention_core(
|
||||
q_local_tile=q_tile,
|
||||
k=cur_k_tile,
|
||||
v=cur_v_tile,
|
||||
o_buffer=o_buffer[i, i_q_h],
|
||||
l_buffer=l_buffer[i, i_q_h],
|
||||
m_buffer=m_buffer[i, i_q_h],
|
||||
kernel_dtype=kernel_dtype,
|
||||
acc_type=acc_type,
|
||||
tile_mask=cur_mask,
|
||||
use_causal_mask=False,
|
||||
q_tile_idx=i,
|
||||
initialize=large_k_tile_idx == 0,
|
||||
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||
B_P_SIZE=B_P_SIZE,
|
||||
B_F_SIZE=B_F_SIZE,
|
||||
B_D_SIZE=B_D_SIZE,
|
||||
)
|
||||
|
||||
# compute attention between input query, key and value
|
||||
if key is not None and value is not None:
|
||||
B_F_SIZE = min(seqlen_q, B_F_SIZE)
|
||||
LARGE_TILE_SZ = seqlen_q
|
||||
|
||||
cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
|
||||
dtype=kernel_dtype)
|
||||
cur_v_tile = nl.ndarray(
|
||||
(par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE),
|
||||
dtype=kernel_dtype,
|
||||
)
|
||||
|
||||
loaded = nl.load(key[batch_id, head_id, :, :])
|
||||
if loaded.dtype != kernel_dtype:
|
||||
loaded = nl.copy(loaded, dtype=kernel_dtype)
|
||||
cur_k_tile[:, :] = loaded
|
||||
|
||||
v_hbm_tile = value[batch_id, head_id]
|
||||
for v_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
|
||||
load_v_tile(
|
||||
v_hbm_tile=v_hbm_tile,
|
||||
cur_v_tile=cur_v_tile,
|
||||
large_tile_idx=0,
|
||||
v_i=v_i,
|
||||
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||
)
|
||||
|
||||
for i in nl.affine_range(n_tile_q):
|
||||
cur_mask = nl.load(mask[
|
||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||
nl.ds(context_kv_len, LARGE_TILE_SZ),
|
||||
])
|
||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||
|
||||
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
|
||||
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
|
||||
q_sbuf_tile = nl.load(q_hbm_tile[:,
|
||||
nl.ds(i *
|
||||
B_P_SIZE, B_P_SIZE)])
|
||||
if q_sbuf_tile.dtype != kernel_dtype:
|
||||
q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
|
||||
q_tile[:, :] = q_sbuf_tile * softmax_scale
|
||||
_flash_attention_core(
|
||||
q_local_tile=q_tile,
|
||||
k=cur_k_tile,
|
||||
v=cur_v_tile,
|
||||
o_buffer=o_buffer[i, i_q_h],
|
||||
l_buffer=l_buffer[i, i_q_h],
|
||||
m_buffer=m_buffer[i, i_q_h],
|
||||
kernel_dtype=kernel_dtype,
|
||||
acc_type=acc_type,
|
||||
tile_mask=cur_mask,
|
||||
use_causal_mask=True,
|
||||
q_tile_idx=i,
|
||||
initialize=False,
|
||||
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||
B_P_SIZE=B_P_SIZE,
|
||||
B_F_SIZE=B_F_SIZE,
|
||||
B_D_SIZE=B_D_SIZE,
|
||||
qk_res_buffer=(qk_res_buffer[i, i_q_h]
|
||||
if qk_res_buffer is not None else None),
|
||||
)
|
||||
|
||||
# -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- #
|
||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||
for i in nl.affine_range(n_tile_q):
|
||||
out = nl.multiply(
|
||||
o_buffer[i, i_q_h],
|
||||
nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]),
|
||||
dtype=kernel_dtype,
|
||||
)
|
||||
|
||||
nl.store(
|
||||
o[
|
||||
batch_id,
|
||||
head_id * q_h_per_k_h + i_q_h,
|
||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||
:,
|
||||
],
|
||||
out,
|
||||
)
|
||||
# maximum and summation statistics
|
||||
if return_debug_tensors:
|
||||
nl.store(
|
||||
hbm_m_buffer[
|
||||
batch_id,
|
||||
head_id * q_h_per_k_h + i_q_h,
|
||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||
],
|
||||
m_buffer[i, i_q_h, :, :],
|
||||
)
|
||||
nl.store(
|
||||
hbm_l_buffer[
|
||||
batch_id,
|
||||
head_id * q_h_per_k_h + i_q_h,
|
||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||
],
|
||||
l_buffer[i, i_q_h],
|
||||
)
|
||||
nl.store(
|
||||
hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :],
|
||||
qk_res_buffer[batch_id, i_q_h, :, :],
|
||||
)
|
||||
|
||||
if return_debug_tensors:
|
||||
return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res
|
||||
return o
|
||||
|
||||
|
||||
def reorder_context_mask(mask, LARGE_TILE_SZ, block_size):
|
||||
"""
|
||||
Reorder the mask to make it compatible with the flash attention kernel.
|
||||
|
||||
We vectorize KV cache read to improve DMA utilization. However, the layout
|
||||
that maximizes DMA bandwidth changes the order tokens are consumed.
|
||||
|
||||
The token layout (inner 2 dimensions) after vectorized load is (B_P_SIZE,
|
||||
tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size` tokens. And
|
||||
each step the engine consumes a column (rather than a row) of B_P_SIZE
|
||||
tokens. Therefore, the tokens are visited in a strided way.
|
||||
|
||||
To make sure mask matches the order tokens are consumed, we need to properly
|
||||
transpose mask.
|
||||
"""
|
||||
total_query_len, total_seq_len = mask.shape
|
||||
context_kv_len = total_seq_len - total_query_len
|
||||
|
||||
B_P_SIZE = 128
|
||||
assert (LARGE_TILE_SZ
|
||||
>= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}"
|
||||
num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size)
|
||||
tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks
|
||||
if tiled_block_size > 1:
|
||||
# Mask reordering is needed when tiled_block_size > 1
|
||||
device = mask.device
|
||||
mask = mask.cpu()
|
||||
context_mask = mask[:, :context_kv_len]
|
||||
context_mask = context_mask.view(
|
||||
total_query_len,
|
||||
context_kv_len // LARGE_TILE_SZ,
|
||||
num_tiled_blocks // B_P_SIZE,
|
||||
B_P_SIZE,
|
||||
tiled_block_size,
|
||||
)
|
||||
context_mask = context_mask.transpose(3, 4).reshape(
|
||||
total_query_len, context_kv_len)
|
||||
new_mask = mask[:, context_kv_len:]
|
||||
return torch.concat([context_mask, new_mask], dim=1).to(device)
|
||||
else:
|
||||
return mask
|
||||
|
||||
|
||||
def flash_attn_varlen_nkifunc(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
block_table,
|
||||
attn_mask,
|
||||
n_kv_head=None,
|
||||
head_size=None,
|
||||
LARGE_TILE_SZ=2048,
|
||||
mixed_precision=True,
|
||||
):
|
||||
"""
|
||||
Compute flash paged attention for variable length sequences.
|
||||
|
||||
This function is a wrapper around the flash attention NKI kernel. It takes
|
||||
in the following arguments:
|
||||
- query: (1, n_heads, d, seq_q)
|
||||
- key: (1, n_kv_heads, d, seq_k)
|
||||
- value: (1, n_kv_heads, seq_v, d)
|
||||
- kv_cache: (2, n_blocks, n_kv_heads, block_size, d)
|
||||
- block_tables: (n_active_blocks, )
|
||||
- attn_mask: (seq_q, n_active_blocks * block_size + seq_q)
|
||||
|
||||
Notes:
|
||||
- attn_mask must be reordered outside using `reorder_context_mask`
|
||||
- Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d)
|
||||
for better DMA throughput
|
||||
"""
|
||||
if n_kv_head is None:
|
||||
n_kv_head = kv_cache.shape[2]
|
||||
assert kv_cache.shape[0] == 2
|
||||
assert kv_cache.shape[2] == n_kv_head
|
||||
if head_size is None:
|
||||
head_size = kv_cache.shape[-1]
|
||||
|
||||
kwargs = dict(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_table,
|
||||
mask=attn_mask,
|
||||
softmax_scale=1.0 / (head_size**0.5),
|
||||
mixed_precision=mixed_precision,
|
||||
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||
)
|
||||
|
||||
o = flash_paged_attention[1, n_kv_head](**kwargs)
|
||||
return o
|
||||
|
||||
|
||||
def reshape_and_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Writes key-value pairs to the KV cache at specified positions.
|
||||
|
||||
Args:
|
||||
key (torch.Tensor): Key tensor with shape
|
||||
(num_tokens, n_kv_head, d_head)
|
||||
value (torch.Tensor): Value tensor with shape
|
||||
(num_tokens, n_kv_head, d_head)
|
||||
kv_cache (torch.Tensor): Key/value cache tensor with shape
|
||||
(2, num_blocks, n_kv_head, block_size, d_head)
|
||||
slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
|
||||
with shape (num_tokens)
|
||||
|
||||
Returns:
|
||||
None: Updates the kv_cache tensor in-place
|
||||
"""
|
||||
block_size = kv_cache.size(3)
|
||||
n_kv_head = key.size(1)
|
||||
|
||||
# Calculate indices with explicit floor division
|
||||
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
|
||||
block_offsets = slot_mapping % block_size
|
||||
|
||||
# Create the head indices tensor
|
||||
head_indices = torch.arange(n_kv_head, device=key.device)
|
||||
|
||||
# Update caches using index_put_
|
||||
kv_cache.index_put_(
|
||||
(torch.tensor([0], device=key.device), block_indices[:, None],
|
||||
head_indices[None, :], block_offsets[:, None]), key)
|
||||
|
||||
kv_cache.index_put_(
|
||||
(torch.tensor([1], device=key.device), block_indices[:, None],
|
||||
head_indices[None, :], block_offsets[:, None]), value)
|
||||
256
vllm/attention/ops/paged_attn.py
Normal file
256
vllm/attention/ops/paged_attn.py
Normal file
@@ -0,0 +1,256 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
|
||||
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
|
||||
@dataclass
|
||||
class PagedAttentionMetadata:
|
||||
"""Metadata for PagedAttention."""
|
||||
# (batch_size,). The length of sequences (entire tokens seen so far) per
|
||||
# sequence.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
|
||||
|
||||
class PagedAttention:
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [32, 64, 80, 96, 112, 120, 128, 192, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size * num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def split_kv_cache(
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
x = 16 // kv_cache.element_size()
|
||||
num_blocks = kv_cache.shape[1]
|
||||
|
||||
key_cache = kv_cache[0]
|
||||
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
|
||||
-1, x)
|
||||
value_cache = kv_cache[1]
|
||||
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
|
||||
return key_cache, value_cache
|
||||
|
||||
@staticmethod
|
||||
def write_to_paged_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
) -> None:
|
||||
ops.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping.flatten(),
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def forward_decode(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
kv_cache_dtype: str,
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
tp_rank: int = 0,
|
||||
blocksparse_local_blocks: int = 0,
|
||||
blocksparse_vert_stride: int = 0,
|
||||
blocksparse_block_size: int = 64,
|
||||
blocksparse_head_sliding_step: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
|
||||
# use blocksparse paged attention
|
||||
block_size = value_cache.size(-1)
|
||||
assert (blocksparse_block_size > 0 and
|
||||
blocksparse_block_size % block_size == 0), \
|
||||
(f"{blocksparse_block_size=} needs to be a multiple of"
|
||||
f"{block_size=} used in block_tables.")
|
||||
|
||||
output = torch.empty_like(query)
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
|
||||
_PARTITION_SIZE)
|
||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
# sequences or heads is large, we use V1 since there is enough work
|
||||
# to parallelize.
|
||||
# TODO(woosuk): Tune this heuristic.
|
||||
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
|
||||
use_v1 = (max_seq_len <= 8192
|
||||
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
|
||||
|
||||
if use_v1:
|
||||
# Run PagedAttention V1.
|
||||
ops.paged_attention_v1(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
tp_rank,
|
||||
blocksparse_local_blocks,
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
)
|
||||
else:
|
||||
# Run PagedAttention V2.
|
||||
assert _PARTITION_SIZE % block_size == 0
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
device=output.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=output.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
tp_rank,
|
||||
blocksparse_local_blocks,
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def forward_prefix(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
seq_lens_tensor: torch.Tensor,
|
||||
max_query_len: int,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
sliding_window: Optional[int],
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
output = torch.empty_like(query)
|
||||
max_seq_len = None
|
||||
context_attention_fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
# query_start_loc is (batch_size + 1,)
|
||||
query_start_loc,
|
||||
seq_lens_tensor,
|
||||
max_seq_len,
|
||||
max_query_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
src_key_cache = src_kv_cache[0]
|
||||
dst_key_cache = dst_kv_cache[0]
|
||||
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
|
||||
|
||||
src_value_cache = src_kv_cache[1]
|
||||
dst_value_cache = dst_kv_cache[1]
|
||||
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
902
vllm/attention/ops/prefix_prefill.py
Normal file
902
vllm/attention/ops/prefix_prefill.py
Normal file
@@ -0,0 +1,902 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
|
||||
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
# Static kernels parameters
|
||||
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
|
||||
NUM_WARPS = 4 if current_platform.is_rocm() else 8
|
||||
|
||||
# To check compatibility
|
||||
IS_TURING = current_platform.get_device_capability() == (7, 5)
|
||||
|
||||
|
||||
# Here's an example autotuner config for this kernel. This config does provide
|
||||
# a performance improvement, but dramatically increases first call latency in
|
||||
# triton 3.2. Because of this tradeoff, it's currently commented out.
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \
|
||||
# "num_unroll_cache": 4, \
|
||||
# "num_unroll_request": 1 } | \
|
||||
# ({"kpack": 2, "waves_per_eu": 2} \
|
||||
# if current_platform.is_rocm() else {}), \
|
||||
# num_warps=4, \
|
||||
# num_stages=1)
|
||||
# ],
|
||||
# key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"]
|
||||
# )
|
||||
@triton.jit
|
||||
def _fwd_kernel(Q,
|
||||
K,
|
||||
V,
|
||||
K_cache,
|
||||
V_cache,
|
||||
B_Loc,
|
||||
sm_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
x: tl.constexpr,
|
||||
Out,
|
||||
stride_b_loc_b,
|
||||
stride_b_loc_s,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_k_cache_bs,
|
||||
stride_k_cache_h,
|
||||
stride_k_cache_d,
|
||||
stride_k_cache_bl: tl.constexpr,
|
||||
stride_k_cache_x,
|
||||
stride_v_cache_bs,
|
||||
stride_v_cache_h,
|
||||
stride_v_cache_d,
|
||||
stride_v_cache_bl,
|
||||
num_queries_per_kv: tl.constexpr,
|
||||
IN_PRECISION: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_DMODEL_PADDED: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SLIDING_WINDOW: tl.constexpr,
|
||||
num_unroll_cache: tl.constexpr,
|
||||
num_unroll_request: tl.constexpr,
|
||||
SKIP_DECODE: tl.constexpr,
|
||||
MAX_Q_LEN: tl.constexpr = 0,
|
||||
MAX_CTX_LEN: tl.constexpr = 0):
|
||||
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
cur_kv_head = cur_head // num_queries_per_kv
|
||||
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
|
||||
cur_batch_query_len = (cur_batch_in_all_stop_index -
|
||||
cur_batch_in_all_start_index)
|
||||
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
|
||||
|
||||
if SKIP_DECODE and cur_batch_query_len == 1:
|
||||
return
|
||||
|
||||
# start position inside of the query
|
||||
# generally, N goes over kv, while M goes over query_len
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# initialize offsets
|
||||
# [BLOCK_SIZE]; starts at 0
|
||||
offs_bs_n = tl.arange(0, BLOCK_SIZE)
|
||||
# [N]; starts at 0
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
# [D]; starts at 0
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
|
||||
# [M]; starts at current position in query
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
# [M,D]
|
||||
off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
|
||||
cur_head * stride_qh + offs_d[None, :] * stride_qd)
|
||||
|
||||
dim_mask = tl.where(
|
||||
tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
|
||||
0).to(tl.int1) # [D]
|
||||
|
||||
q = tl.load(Q + off_q,
|
||||
mask=dim_mask[None, :] &
|
||||
(offs_m[:, None] < cur_batch_query_len),
|
||||
other=0.0) # [M,D]
|
||||
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D]
|
||||
|
||||
# compute query against context (no causal mask here)
|
||||
for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \
|
||||
loop_unroll_factor=num_unroll_cache):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_SIZE)
|
||||
# -- compute qk ----
|
||||
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
|
||||
(start_n // BLOCK_SIZE) * stride_b_loc_s)
|
||||
# [D,BLOCK_SIZE]
|
||||
off_k = (
|
||||
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
|
||||
(offs_d[:, None] // x) * stride_k_cache_d +
|
||||
((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl +
|
||||
(offs_d[:, None] % x) * stride_k_cache_x)
|
||||
|
||||
# [BLOCK_SIZE,D]
|
||||
off_v = (bn[:, None] * stride_v_cache_bs +
|
||||
cur_kv_head * stride_v_cache_h +
|
||||
offs_d[None, :] * stride_v_cache_d +
|
||||
offs_bs_n[:, None] * stride_v_cache_bl)
|
||||
|
||||
if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
|
||||
BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
|
||||
k_load = tl.load(
|
||||
K_cache + off_k,
|
||||
mask=dim_mask[:, None] &
|
||||
((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len),
|
||||
other=0.0) # [D,N]
|
||||
else:
|
||||
k_load = tl.load(K_cache + off_k)
|
||||
|
||||
if k_load.dtype.is_fp8():
|
||||
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
|
||||
else:
|
||||
k = k_load
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N]
|
||||
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
|
||||
qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk,
|
||||
float("-inf"))
|
||||
qk *= sm_scale
|
||||
if SLIDING_WINDOW > 0:
|
||||
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
|
||||
# Q entries in sequence
|
||||
# (start_n + offs_bs_n[None, :]) are the positions of
|
||||
# KV entries in sequence
|
||||
# So the condition makes sure each entry in Q only attends
|
||||
# to KV entries not more than SLIDING_WINDOW away.
|
||||
#
|
||||
# We can't use -inf here, because the
|
||||
# sliding window may lead to the entire row being masked.
|
||||
# This then makes m_ij contain -inf, which causes NaNs in
|
||||
# exp().
|
||||
qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
|
||||
(start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk,
|
||||
-10000)
|
||||
|
||||
# compute running maximum
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, axis=1)
|
||||
alpha = tl.exp(m_i - m_ij)
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
# update acc
|
||||
if start_n + BLOCK_SIZE > cur_batch_ctx_len or \
|
||||
BLOCK_DMODEL != BLOCK_DMODEL_PADDED:
|
||||
v_load = tl.load(
|
||||
V_cache + off_v,
|
||||
mask=dim_mask[None, :] &
|
||||
((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len),
|
||||
other=0.0) # [N,D]
|
||||
else:
|
||||
v_load = tl.load(V_cache + off_v)
|
||||
|
||||
if v_load.dtype.is_fp8():
|
||||
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
|
||||
else:
|
||||
v = v_load
|
||||
p = p.to(v.dtype)
|
||||
|
||||
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
|
||||
# # update m_i and l_i
|
||||
l_i = l_i * alpha + l_ij
|
||||
m_i = m_ij
|
||||
|
||||
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
|
||||
offs_d[:, None] * stride_kd)
|
||||
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
|
||||
offs_d[None, :] * stride_vd)
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
# block_mask is 0 when we're already past the current query length
|
||||
block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
|
||||
|
||||
# compute query against itself (with causal mask)
|
||||
for start_n in tl.range(0, \
|
||||
block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \
|
||||
loop_unroll_factor=num_unroll_request):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs +
|
||||
(cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=dim_mask[:, None] &
|
||||
((start_n + offs_n[None, :]) < cur_batch_query_len),
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
|
||||
qk *= sm_scale
|
||||
# apply causal mask
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
|
||||
float("-inf"))
|
||||
if SLIDING_WINDOW > 0:
|
||||
qk = tl.where(
|
||||
offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
|
||||
qk, -10000)
|
||||
|
||||
# compute running maximum
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, axis=1)
|
||||
alpha = tl.exp(m_i - m_ij)
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
# update acc
|
||||
v = tl.load(v_ptrs +
|
||||
(cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=dim_mask[None, :] &
|
||||
((start_n + offs_n[:, None]) < cur_batch_query_len),
|
||||
other=0.0)
|
||||
p = p.to(v.dtype)
|
||||
|
||||
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
|
||||
# update m_i and l_i
|
||||
l_i = l_i * alpha + l_ij
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i[:, None]
|
||||
|
||||
# initialize pointers to output
|
||||
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
|
||||
cur_head * stride_oh + offs_d[None, :] * stride_od)
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs,
|
||||
acc,
|
||||
mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
|
||||
return
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_flash_attn_v2(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
K_cache,
|
||||
V_cache,
|
||||
B_Loc,
|
||||
sm_scale,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
B_Ctxlen,
|
||||
block_size,
|
||||
x,
|
||||
Out,
|
||||
stride_b_loc_b,
|
||||
stride_b_loc_s,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_k_cache_bs,
|
||||
stride_k_cache_h,
|
||||
stride_k_cache_d,
|
||||
stride_k_cache_bl,
|
||||
stride_k_cache_x,
|
||||
stride_v_cache_bs,
|
||||
stride_v_cache_h,
|
||||
stride_v_cache_d,
|
||||
stride_v_cache_bl,
|
||||
num_queries_per_kv: int,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
cur_kv_head = cur_head // num_queries_per_kv
|
||||
|
||||
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
|
||||
cur_head * stride_qh + offs_d[None, :] * stride_qd)
|
||||
|
||||
q = tl.load(Q + off_q,
|
||||
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
# # initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
|
||||
((start_n + offs_n) // block_size) * stride_b_loc_s,
|
||||
mask=(start_n + offs_n) < cur_batch_ctx_len,
|
||||
other=0)
|
||||
off_k = (
|
||||
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
|
||||
(offs_d[:, None] // x) * stride_k_cache_d +
|
||||
((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
|
||||
(offs_d[:, None] % x) * stride_k_cache_x)
|
||||
off_v = (bn[:, None] * stride_v_cache_bs +
|
||||
cur_kv_head * stride_v_cache_h +
|
||||
offs_d[None, :] * stride_v_cache_d +
|
||||
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
|
||||
k = tl.load(K_cache + off_k,
|
||||
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
|
||||
float("-inf"))
|
||||
qk *= sm_scale
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
p = tl.math.exp(qk - m_i_new[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
|
||||
alpha = tl.math.exp(m_i - m_i_new)
|
||||
l_i_new = alpha * l_i + l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
# scale acc
|
||||
acc_scale = alpha
|
||||
# acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(V_cache + off_v,
|
||||
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
|
||||
offs_d[:, None] * stride_kd)
|
||||
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
|
||||
offs_d[None, :] * stride_vd)
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
block_mask = tl.where(
|
||||
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
|
||||
|
||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs +
|
||||
(cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=(start_n + offs_n[None, :])
|
||||
< cur_batch_seq_len - cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
|
||||
float("-inf"))
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
p = tl.math.exp(qk - m_i_new[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
|
||||
alpha = tl.math.exp(m_i - m_i_new)
|
||||
l_i_new = alpha * l_i + l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
# scale acc
|
||||
acc_scale = alpha
|
||||
# acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(v_ptrs +
|
||||
(cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=(start_n + offs_n[:, None])
|
||||
< cur_batch_seq_len - cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
# acc /= l_i[:, None]
|
||||
# initialize pointers to output
|
||||
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
|
||||
cur_head * stride_oh + offs_d[None, :] * stride_od)
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs,
|
||||
acc,
|
||||
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
|
||||
return
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_alibi(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
K_cache,
|
||||
V_cache,
|
||||
B_Loc,
|
||||
sm_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
Alibi_slopes,
|
||||
block_size,
|
||||
x,
|
||||
Out,
|
||||
stride_b_loc_b,
|
||||
stride_b_loc_s,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_k_cache_bs,
|
||||
stride_k_cache_h,
|
||||
stride_k_cache_d,
|
||||
stride_k_cache_bl,
|
||||
stride_k_cache_x,
|
||||
stride_v_cache_bs,
|
||||
stride_v_cache_h,
|
||||
stride_v_cache_d,
|
||||
stride_v_cache_bl,
|
||||
num_queries_per_kv: int,
|
||||
IN_PRECISION: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr, # head size
|
||||
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
|
||||
BLOCK_N: tl.constexpr,
|
||||
SKIP_DECODE: tl.constexpr,
|
||||
):
|
||||
# attn_bias[]
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
cur_kv_head = cur_head // num_queries_per_kv
|
||||
|
||||
# cur_batch_seq_len: the length of prompts
|
||||
# cur_batch_ctx_len: the length of prefix
|
||||
# cur_batch_in_all_start_index: the start id of the dim=0
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
|
||||
cur_batch_query_len = (cur_batch_in_all_stop_index -
|
||||
cur_batch_in_all_start_index)
|
||||
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
|
||||
|
||||
if SKIP_DECODE and cur_batch_query_len == 1:
|
||||
return
|
||||
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
|
||||
cur_head * stride_qh + offs_d[None, :] * stride_qd)
|
||||
|
||||
dim_mask = tl.where(
|
||||
tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
|
||||
|
||||
q = tl.load(Q + off_q,
|
||||
mask=dim_mask[None, :] &
|
||||
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
|
||||
other=0.0)
|
||||
|
||||
# # initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
|
||||
|
||||
alibi_slope = tl.load(Alibi_slopes + cur_head)
|
||||
alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
|
||||
alibi_start_k = 0
|
||||
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
|
||||
((start_n + offs_n) // block_size) * stride_b_loc_s,
|
||||
mask=(start_n + offs_n) < cur_batch_ctx_len,
|
||||
other=0)
|
||||
off_k = (
|
||||
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
|
||||
(offs_d[:, None] // x) * stride_k_cache_d +
|
||||
((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl +
|
||||
(offs_d[:, None] % x) * stride_k_cache_x)
|
||||
off_v = (bn[:, None] * stride_v_cache_bs +
|
||||
cur_kv_head * stride_v_cache_h +
|
||||
offs_d[None, :] * stride_v_cache_d +
|
||||
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
|
||||
k_load = tl.load(K_cache + off_k,
|
||||
mask=dim_mask[:, None] &
|
||||
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
|
||||
other=0.0) # [D,N]
|
||||
|
||||
if k_load.dtype.is_fp8():
|
||||
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
|
||||
else:
|
||||
k = k_load
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
|
||||
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
|
||||
float("-inf"))
|
||||
qk *= sm_scale
|
||||
|
||||
# load alibi
|
||||
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
|
||||
alibi_start_q[:, None]) * alibi_slope
|
||||
alibi = tl.where(
|
||||
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
|
||||
float("-inf"))
|
||||
qk += alibi
|
||||
alibi_start_k += BLOCK_N
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
p = tl.math.exp(qk - m_i_new[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
|
||||
alpha = tl.math.exp(m_i - m_i_new)
|
||||
l_i_new = alpha * l_i + l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
# scale acc
|
||||
acc_scale = alpha
|
||||
# acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v_load = tl.load(V_cache + off_v,
|
||||
mask=dim_mask[None, :] &
|
||||
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
|
||||
other=0.0)
|
||||
if v_load.dtype.is_fp8():
|
||||
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
|
||||
else:
|
||||
v = v_load
|
||||
p = p.to(v.dtype)
|
||||
|
||||
acc = tl.dot(p, v, acc=acc, input_precision='ieee')
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
|
||||
offs_d[:, None] * stride_kd)
|
||||
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
|
||||
offs_d[None, :] * stride_vd)
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
block_mask = tl.where(
|
||||
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
|
||||
|
||||
# init alibi
|
||||
alibi_slope = tl.load(Alibi_slopes + cur_head)
|
||||
alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
|
||||
alibi_start_k = cur_batch_ctx_len
|
||||
# # init debugger
|
||||
# offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
|
||||
# offset_db_k = tl.arange(0, BLOCK_N)
|
||||
# calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
|
||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(
|
||||
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=dim_mask[:, None] & ((start_n + offs_n[None, :])
|
||||
< cur_batch_seq_len - cur_batch_ctx_len),
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk = tl.dot(q, k, acc=qk, input_precision='ieee')
|
||||
qk *= sm_scale
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
|
||||
float("-inf"))
|
||||
|
||||
# load alibi
|
||||
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
|
||||
alibi_start_q[:, None]) * alibi_slope
|
||||
alibi = tl.where(
|
||||
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi,
|
||||
float("-inf"))
|
||||
qk += alibi
|
||||
alibi_start_k += BLOCK_N
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
p = tl.math.exp(qk - m_i_new[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
|
||||
alpha = tl.math.exp(m_i - m_i_new)
|
||||
l_i_new = alpha * l_i + l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
# scale acc
|
||||
acc_scale = alpha
|
||||
# acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(
|
||||
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=dim_mask[None, :] & ((start_n + offs_n[:, None])
|
||||
< cur_batch_seq_len - cur_batch_ctx_len),
|
||||
other=0.0)
|
||||
p = p.to(v.dtype)
|
||||
|
||||
acc = tl.dot(p, v, acc=acc, input_precision='ieee')
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
acc = acc / l_i[:, None]
|
||||
|
||||
# initialize pointers to output
|
||||
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
|
||||
cur_head * stride_oh + offs_d[None, :] * stride_od)
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs,
|
||||
acc,
|
||||
mask=dim_mask[None, :] &
|
||||
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
|
||||
return
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def context_attention_fwd(q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
kv_cache_dtype: str,
|
||||
k_cache,
|
||||
v_cache,
|
||||
b_loc,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
max_seq_len,
|
||||
max_input_len,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
sm_scale=None,
|
||||
skip_decode=False):
|
||||
|
||||
q_dtype_is_f32 = q.dtype is torch.float32
|
||||
|
||||
# Turing does have tensor core for float32 multiplication
|
||||
# use ieee as fallback for triton kernels work. There is also
|
||||
# warning on vllm/config.py to inform users this fallback
|
||||
# implementation
|
||||
IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None
|
||||
|
||||
# Conversion of FP8 Tensor from uint8 storage to
|
||||
# appropriate torch.dtype for interpretation by Triton
|
||||
if "fp8" in kv_cache_dtype:
|
||||
assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
|
||||
assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
|
||||
|
||||
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
|
||||
target_dtype = current_platform.fp8_dtype()
|
||||
elif kv_cache_dtype == "fp8_e5m2":
|
||||
target_dtype = torch.float8_e5m2
|
||||
else:
|
||||
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
|
||||
|
||||
k_cache = k_cache.view(target_dtype)
|
||||
v_cache = v_cache.view(target_dtype)
|
||||
|
||||
if (k_cache.dtype == torch.uint8
|
||||
or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"):
|
||||
raise ValueError("kv_cache_dtype='auto' unsupported for\
|
||||
FP8 KV Cache prefill kernel")
|
||||
|
||||
# shape constraints
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
# round up Lk to a power of 2 - this is required for Triton block size
|
||||
Lk_padded = triton.next_power_of_2(Lk)
|
||||
|
||||
if sm_scale is None:
|
||||
sm_scale = 1.0 / (Lq**0.5)
|
||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
||||
num_queries_per_kv = q.shape[1] // k.shape[1]
|
||||
|
||||
assert batch + 1 == len(b_start_loc)
|
||||
|
||||
# 0 means "disable"
|
||||
if sliding_window is None or sliding_window <= 0:
|
||||
sliding_window = 0
|
||||
|
||||
if alibi_slopes is not None:
|
||||
# need to reduce num. blocks when using fp32
|
||||
# due to increased use of GPU shared memory
|
||||
# if q.dtype is torch.float32:
|
||||
BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
|
||||
# batch, head,
|
||||
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
|
||||
_fwd_kernel_alibi[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
b_loc,
|
||||
sm_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
alibi_slopes,
|
||||
v_cache.shape[3],
|
||||
k_cache.shape[4],
|
||||
o,
|
||||
b_loc.stride(0),
|
||||
b_loc.stride(1),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
k_cache.stride(
|
||||
4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
v_cache.stride(0),
|
||||
v_cache.stride(1),
|
||||
v_cache.stride(2),
|
||||
v_cache.stride(
|
||||
3), #[num_blocks, num_kv_heads, head_size, block_size]
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
IN_PRECISION=IN_PRECISION,
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_DMODEL_PADDED=Lk_padded,
|
||||
BLOCK_N=BLOCK,
|
||||
SKIP_DECODE=skip_decode,
|
||||
num_warps=NUM_WARPS,
|
||||
num_stages=1,
|
||||
)
|
||||
return
|
||||
|
||||
max_seq_len = 0 if max_seq_len is None else max_seq_len
|
||||
extra_kargs = {}
|
||||
if current_platform.is_rocm():
|
||||
extra_kargs = {"kpack": 2, "waves_per_eu": 2}
|
||||
|
||||
grid = lambda META: (batch, head,
|
||||
triton.cdiv(max_input_len, META["BLOCK_M"]))
|
||||
_fwd_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
b_loc,
|
||||
sm_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
k_cache.shape[4],
|
||||
o,
|
||||
b_loc.stride(0),
|
||||
b_loc.stride(1),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
k_cache.stride(
|
||||
4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
v_cache.stride(0),
|
||||
v_cache.stride(1),
|
||||
v_cache.stride(2),
|
||||
v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size]
|
||||
BLOCK_SIZE=v_cache.shape[3],
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
IN_PRECISION=IN_PRECISION,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_DMODEL_PADDED=Lk_padded,
|
||||
SLIDING_WINDOW=sliding_window,
|
||||
SKIP_DECODE=skip_decode,
|
||||
BLOCK_M=128,
|
||||
BLOCK_N=64,
|
||||
num_unroll_cache=4,
|
||||
num_unroll_request=1,
|
||||
num_warps=4,
|
||||
num_stages=1,
|
||||
**extra_kargs)
|
||||
return
|
||||
100
vllm/attention/ops/rocm_aiter_mla.py
Normal file
100
vllm/attention/ops/rocm_aiter_mla.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
|
||||
max_block_per_batch: int,
|
||||
device: torch.device) -> tuple[torch.Tensor, ...]:
|
||||
paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
paged_kv_indptr = torch.zeros(max_batch_size + 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
paged_kv_last_page_lens = torch.full((max_batch_size, ),
|
||||
block_size,
|
||||
dtype=torch.int32)
|
||||
qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
|
||||
return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr
|
||||
|
||||
|
||||
def aiter_mla_decode_fwd(
|
||||
q: torch.Tensor,
|
||||
kv_buffer: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
sm_scale: float,
|
||||
qo_indptr: torch.Tensor,
|
||||
max_seqlen_qo: int,
|
||||
kv_indptr: Optional[torch.Tensor] = None,
|
||||
kv_indices: Optional[torch.Tensor] = None,
|
||||
kv_last_page_lens: Optional[torch.Tensor] = None,
|
||||
logit_cap: float = 0.0,
|
||||
):
|
||||
|
||||
torch.ops.vllm.rocm_aiter_mla_decode_fwd(q,
|
||||
kv_buffer.view(
|
||||
-1, 1, 1, q.shape[-1]),
|
||||
o,
|
||||
qo_indptr,
|
||||
max_seqlen_qo,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
sm_scale=sm_scale,
|
||||
logit_cap=logit_cap)
|
||||
|
||||
|
||||
def mla_decode_fwd_impl(
|
||||
q: torch.Tensor,
|
||||
kv_buffer: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
qo_indptr: torch.Tensor,
|
||||
max_seqlen_qo: int,
|
||||
kv_indptr: Optional[torch.Tensor] = None,
|
||||
kv_indices: Optional[torch.Tensor] = None,
|
||||
kv_last_page_lens: Optional[torch.Tensor] = None,
|
||||
sm_scale: float = 1.0,
|
||||
logit_cap: float = 0.0,
|
||||
) -> None:
|
||||
from aiter.mla import mla_decode_fwd
|
||||
|
||||
mla_decode_fwd(q,
|
||||
kv_buffer.view(-1, 1, 1, q.shape[-1]),
|
||||
o,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
max_seqlen_qo,
|
||||
sm_scale=sm_scale,
|
||||
logit_cap=logit_cap)
|
||||
|
||||
|
||||
def mla_decode_fwd_fake(
|
||||
q: torch.Tensor,
|
||||
kv_buffer: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
qo_indptr: torch.Tensor,
|
||||
max_seqlen_qo: int,
|
||||
kv_indptr: Optional[torch.Tensor] = None,
|
||||
kv_indices: Optional[torch.Tensor] = None,
|
||||
kv_last_page_lens: Optional[torch.Tensor] = None,
|
||||
sm_scale: float = 1.0,
|
||||
logit_cap: float = 0.0,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd",
|
||||
op_func=mla_decode_fwd_impl,
|
||||
mutates_args=["o"],
|
||||
fake_impl=mla_decode_fwd_fake,
|
||||
tags=[torch.Tag.needs_fixed_stride_order])
|
||||
102
vllm/attention/ops/rocm_aiter_paged_attn.py
Normal file
102
vllm/attention/ops/rocm_aiter_paged_attn.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import aiter as rocm_aiter
|
||||
import torch
|
||||
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class AITERPagedAttention(PagedAttention):
|
||||
|
||||
@staticmethod
|
||||
def write_to_paged_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
) -> None:
|
||||
if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
|
||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||
value_cache, slot_mapping,
|
||||
kv_cache_dtype, k_scale,
|
||||
v_scale)
|
||||
else:
|
||||
kv_cache_torch_dtype = (FP8_DTYPE
|
||||
if "fp8" in kv_cache_dtype else torch.int8)
|
||||
key_cache = key_cache.view(kv_cache_torch_dtype)
|
||||
value_cache = value_cache.view(kv_cache_torch_dtype)
|
||||
|
||||
rocm_aiter.reshape_and_cache_with_pertoken_quant(
|
||||
key, value, key_cache, value_cache, k_scale, v_scale,
|
||||
slot_mapping.flatten(), True)
|
||||
|
||||
@staticmethod
|
||||
def forward_decode(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
kv_cache_dtype: str,
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
tp_rank: int = 0,
|
||||
blocksparse_local_blocks: int = 0,
|
||||
blocksparse_vert_stride: int = 0,
|
||||
blocksparse_block_size: int = 64,
|
||||
blocksparse_head_sliding_step: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
|
||||
return PagedAttention.forward_decode(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
block_tables=block_tables,
|
||||
seq_lens=seq_lens,
|
||||
max_seq_len=max_seq_len,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
num_kv_heads=num_kv_heads,
|
||||
scale=scale,
|
||||
alibi_slopes=alibi_slopes,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
tp_rank=tp_rank,
|
||||
blocksparse_local_blocks=blocksparse_local_blocks,
|
||||
blocksparse_vert_stride=blocksparse_vert_stride,
|
||||
blocksparse_block_size=blocksparse_block_size,
|
||||
blocksparse_head_sliding_step=blocksparse_head_sliding_step)
|
||||
|
||||
if "fp8" in kv_cache_dtype:
|
||||
key_cache = key_cache.view(torch.float8_e4m3fnuz)
|
||||
value_cache = value_cache.view(torch.float8_e4m3fnuz)
|
||||
|
||||
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
|
||||
# use blocksparse paged attention
|
||||
block_size = value_cache.size(-1)
|
||||
assert (blocksparse_block_size > 0 and
|
||||
blocksparse_block_size % block_size == 0), \
|
||||
(f"{blocksparse_block_size=} needs to be a multiple of"
|
||||
f"{block_size=} used in block_tables.")
|
||||
|
||||
output = torch.empty_like(query)
|
||||
block_size = value_cache.shape[3]
|
||||
max_num_blocks_per_seq = cdiv(max_seq_len, block_size)
|
||||
|
||||
rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables,
|
||||
seq_lens, max_num_blocks_per_seq, k_scale,
|
||||
v_scale, output)
|
||||
return output
|
||||
685
vllm/attention/ops/triton_decode_attention.py
Normal file
685
vllm/attention/ops/triton_decode_attention.py
Normal file
@@ -0,0 +1,685 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
|
||||
# which was originally adapted from
|
||||
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
|
||||
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
|
||||
|
||||
# Changes:
|
||||
# - Add support for page size >= 1.
|
||||
|
||||
# Copyright 2025 vLLM Team
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Memory-efficient attention for decoding.
|
||||
It supports page size >= 1.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
# is_hip_ = current_platform.is_rocm()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Only print the following warnings when triton version < 3.2.0.
|
||||
# The issue won't affect performance or accuracy.
|
||||
if triton.__version__ < '3.2.0':
|
||||
logger.warning(
|
||||
"The following error message 'operation scheduled before its operands' "
|
||||
"can be ignored.")
|
||||
|
||||
|
||||
@triton.jit
|
||||
def tanh(x):
|
||||
# Tanh is just a scaled sigmoid
|
||||
return 2 * tl.sigmoid(2 * x) - 1
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_stage1(
|
||||
Q,
|
||||
K_Buffer,
|
||||
V_Buffer,
|
||||
sm_scale,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
Att_Out,
|
||||
stride_req_to_tokens_b,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_buf_kbs,
|
||||
stride_buf_kh,
|
||||
stride_buf_vbs,
|
||||
stride_buf_vh,
|
||||
stride_mid_ob,
|
||||
stride_mid_oh,
|
||||
stride_mid_os,
|
||||
kv_group_num: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_DV: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
NUM_KV_SPLITS: tl.constexpr,
|
||||
PAGE_SIZE: tl.constexpr,
|
||||
logit_cap: tl.constexpr,
|
||||
Lk: tl.constexpr,
|
||||
Lv: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
split_kv_id = tl.program_id(2)
|
||||
|
||||
cur_kv_head = cur_head // kv_group_num
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_dv = tl.arange(0, BLOCK_DV)
|
||||
mask_d = offs_d < Lk
|
||||
mask_dv = offs_dv < Lv
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_req_idx = cur_batch
|
||||
|
||||
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
|
||||
q = tl.load(Q + off_q, mask=mask_d, other=0.0)
|
||||
|
||||
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
|
||||
split_kv_start = kv_len_per_split * split_kv_id
|
||||
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
|
||||
cur_batch_seq_len)
|
||||
|
||||
e_max = -float("inf")
|
||||
e_sum = 0.0
|
||||
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
|
||||
|
||||
if split_kv_end > split_kv_start:
|
||||
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N)
|
||||
kv_page_number = tl.load(
|
||||
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
|
||||
offs_n // PAGE_SIZE,
|
||||
mask=offs_n < split_kv_end,
|
||||
other=0,
|
||||
)
|
||||
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
|
||||
offs_buf_k = (kv_loc[:, None] * stride_buf_kbs +
|
||||
cur_kv_head * stride_buf_kh + offs_d[None, :])
|
||||
k = tl.load(
|
||||
K_Buffer + offs_buf_k,
|
||||
mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
qk = tl.sum(q[None, :] * k, 1)
|
||||
qk *= sm_scale
|
||||
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
|
||||
|
||||
offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
|
||||
cur_kv_head * stride_buf_vh + offs_dv[None, :])
|
||||
v = tl.load(
|
||||
V_Buffer + offs_buf_v,
|
||||
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
n_e_max = tl.maximum(tl.max(qk, 0), e_max)
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max)
|
||||
acc *= re_scale
|
||||
acc += tl.sum(p[:, None] * v, 0)
|
||||
|
||||
e_sum = e_sum * re_scale + tl.sum(p, 0)
|
||||
e_max = n_e_max
|
||||
|
||||
offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
|
||||
split_kv_id * stride_mid_os + offs_dv)
|
||||
|
||||
tl.store(
|
||||
Att_Out + offs_mid_o,
|
||||
acc / e_sum,
|
||||
mask=(mask_dv),
|
||||
)
|
||||
|
||||
offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
|
||||
split_kv_id * stride_mid_os + Lv)
|
||||
|
||||
tl.store(
|
||||
Att_Out + offs_mid_o_1,
|
||||
e_max + tl.log(e_sum),
|
||||
)
|
||||
|
||||
|
||||
def _decode_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
att_out,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
):
|
||||
# BLOCK = 64 if not is_hip_ else 8
|
||||
BLOCK = 8
|
||||
|
||||
NUM_KV_SPLITS = num_kv_splits
|
||||
Lk = k_buffer.shape[-1]
|
||||
Lv = v_buffer.shape[-1]
|
||||
|
||||
batch, head_num = q.shape[0], q.shape[1]
|
||||
|
||||
grid = (batch, head_num, NUM_KV_SPLITS)
|
||||
kv_group_num = q.shape[1] // k_buffer.shape[-2]
|
||||
|
||||
num_warps = 4
|
||||
if kv_group_num != 1:
|
||||
# num_warps = 1 if is_hip_ else 2
|
||||
num_warps = 1
|
||||
|
||||
BLOCK_DMODEL = triton.next_power_of_2(Lk)
|
||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||
|
||||
_fwd_kernel_stage1[grid](
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
sm_scale,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
att_out,
|
||||
Req_to_tokens.stride(0),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
att_out.stride(0),
|
||||
att_out.stride(1),
|
||||
att_out.stride(2),
|
||||
kv_group_num=kv_group_num,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_DV=BLOCK_DV,
|
||||
BLOCK_N=BLOCK,
|
||||
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
||||
PAGE_SIZE=page_size,
|
||||
logit_cap=logit_cap,
|
||||
num_warps=num_warps,
|
||||
num_stages=2,
|
||||
Lk=Lk,
|
||||
Lv=Lv,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_grouped_kernel_stage1(
|
||||
Q,
|
||||
K_Buffer,
|
||||
V_Buffer,
|
||||
sm_scale,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
Att_Out,
|
||||
stride_req_to_tokens_b,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_buf_kbs,
|
||||
stride_buf_kh,
|
||||
stride_buf_vbs,
|
||||
stride_buf_vh,
|
||||
stride_mid_ob,
|
||||
stride_mid_oh,
|
||||
stride_mid_os,
|
||||
kv_group_num: tl.constexpr,
|
||||
q_head_num: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_DPE: tl.constexpr,
|
||||
BLOCK_DV: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_H: tl.constexpr,
|
||||
NUM_KV_SPLITS: tl.constexpr,
|
||||
PAGE_SIZE: tl.constexpr,
|
||||
logit_cap: tl.constexpr,
|
||||
Lk: tl.constexpr,
|
||||
Lv: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head_id = tl.program_id(1)
|
||||
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
|
||||
split_kv_id = tl.program_id(2)
|
||||
|
||||
if kv_group_num > BLOCK_H:
|
||||
VALID_BLOCK_H: tl.constexpr = BLOCK_H
|
||||
else:
|
||||
VALID_BLOCK_H: tl.constexpr = kv_group_num
|
||||
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
|
||||
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
|
||||
mask_h = mask_h & (cur_head < q_head_num)
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_dv = tl.arange(0, BLOCK_DV)
|
||||
mask_d = offs_d < Lk
|
||||
mask_dv = offs_dv < Lv
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_req_idx = cur_batch
|
||||
|
||||
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[
|
||||
None, :]
|
||||
q = tl.load(Q + offs_q,
|
||||
mask=(mask_h[:, None]) & (mask_d[None, :]),
|
||||
other=0.0)
|
||||
|
||||
if BLOCK_DPE > 0:
|
||||
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
||||
mask_dpe = offs_dpe < Lk
|
||||
off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh +
|
||||
offs_dpe[None, :])
|
||||
qpe = tl.load(Q + off_qpe,
|
||||
mask=(mask_h[:, None]) & (mask_dpe[None, :]),
|
||||
other=0.0)
|
||||
|
||||
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
|
||||
split_kv_start = kv_len_per_split * split_kv_id
|
||||
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
|
||||
cur_batch_seq_len)
|
||||
|
||||
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
|
||||
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
|
||||
|
||||
if split_kv_end > split_kv_start:
|
||||
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N)
|
||||
kv_page_number = tl.load(
|
||||
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
|
||||
offs_n // PAGE_SIZE,
|
||||
mask=offs_n < split_kv_end,
|
||||
other=0,
|
||||
)
|
||||
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
|
||||
offs_buf_k = (kv_loc[None, :] * stride_buf_kbs +
|
||||
cur_kv_head * stride_buf_kh + offs_d[:, None])
|
||||
k = tl.load(
|
||||
K_Buffer + offs_buf_k,
|
||||
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
|
||||
other=0.0,
|
||||
)
|
||||
qk = tl.dot(q, k.to(q.dtype))
|
||||
if BLOCK_DPE > 0:
|
||||
offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs +
|
||||
cur_kv_head * stride_buf_kh +
|
||||
offs_dpe[:, None])
|
||||
kpe = tl.load(
|
||||
K_Buffer + offs_buf_kpe,
|
||||
mask=(offs_n[None, :] < split_kv_end) &
|
||||
(mask_dpe[:, None]),
|
||||
other=0.0,
|
||||
)
|
||||
qk += tl.dot(qpe, kpe.to(qpe.dtype))
|
||||
qk *= sm_scale
|
||||
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end),
|
||||
qk, float("-inf"))
|
||||
|
||||
offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
|
||||
cur_kv_head * stride_buf_vh + offs_dv[None, :])
|
||||
v = tl.load(
|
||||
V_Buffer + offs_buf_v,
|
||||
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max[:, None])
|
||||
acc *= re_scale[:, None]
|
||||
acc += tl.dot(p.to(v.dtype), v)
|
||||
|
||||
e_sum = e_sum * re_scale + tl.sum(p, 1)
|
||||
e_max = n_e_max
|
||||
|
||||
offs_mid_o = (cur_batch * stride_mid_ob +
|
||||
cur_head[:, None] * stride_mid_oh +
|
||||
split_kv_id * stride_mid_os + offs_dv[None, :])
|
||||
|
||||
tl.store(
|
||||
Att_Out + offs_mid_o,
|
||||
acc / e_sum[:, None],
|
||||
mask=(mask_h[:, None]) & (mask_dv[None, :]),
|
||||
)
|
||||
|
||||
offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
|
||||
split_kv_id * stride_mid_os + Lv)
|
||||
|
||||
tl.store(
|
||||
Att_Out + offs_mid_o_1,
|
||||
e_max + tl.log(e_sum),
|
||||
mask=mask_h,
|
||||
)
|
||||
|
||||
|
||||
def _decode_grouped_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
att_out,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
num_kv_splits,
|
||||
num_stages,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
):
|
||||
BLOCK = 16
|
||||
Lk = k_buffer.shape[-1]
|
||||
Lv = v_buffer.shape[-1]
|
||||
|
||||
# [TODO] work around shmem limit on MI3xx
|
||||
# if is_hip_ and Lk >= 576:
|
||||
# BLOCK = 16
|
||||
|
||||
if Lk == 576:
|
||||
BLOCK_DMODEL = 512
|
||||
BLOCK_DPE = 64
|
||||
elif Lk == 288:
|
||||
BLOCK_DMODEL = 256
|
||||
BLOCK_DPE = 32
|
||||
else:
|
||||
BLOCK_DMODEL = triton.next_power_of_2(Lk)
|
||||
BLOCK_DPE = 0
|
||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||
|
||||
batch, head_num = q.shape[0], q.shape[1]
|
||||
kv_group_num = q.shape[1] // k_buffer.shape[-2]
|
||||
|
||||
BLOCK_H = 16
|
||||
NUM_KV_SPLITS = num_kv_splits
|
||||
grid = (
|
||||
batch,
|
||||
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
|
||||
NUM_KV_SPLITS,
|
||||
)
|
||||
|
||||
if num_stages == 1:
|
||||
extra_kargs = {"scenario":"mla"}
|
||||
elif num_stages == 2:
|
||||
extra_kargs = {"scenario" : "mla", "pipeline" : "cpasync"}
|
||||
else:
|
||||
KeyError("num_stages should be 1 or 2")
|
||||
# if is_hip_:
|
||||
# # https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
|
||||
# # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
||||
# extra_kargs = {
|
||||
# "waves_per_eu": 1,
|
||||
# "matrix_instr_nonkdim": 16,
|
||||
# "kpack": 2
|
||||
# }
|
||||
# num_stages = 1
|
||||
|
||||
_fwd_grouped_kernel_stage1[grid](
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
sm_scale,
|
||||
Req_to_tokens,
|
||||
B_Seqlen,
|
||||
att_out,
|
||||
Req_to_tokens.stride(0),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
|
||||
att_out.stride(0),
|
||||
att_out.stride(1),
|
||||
att_out.stride(2),
|
||||
kv_group_num=kv_group_num,
|
||||
q_head_num=head_num,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_DPE=BLOCK_DPE,
|
||||
BLOCK_DV=BLOCK_DV,
|
||||
BLOCK_N=BLOCK,
|
||||
BLOCK_H=BLOCK_H,
|
||||
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
||||
PAGE_SIZE=page_size,
|
||||
logit_cap=logit_cap,
|
||||
num_warps=4,
|
||||
num_stages=num_stages,
|
||||
Lk=Lk,
|
||||
Lv=Lv,
|
||||
**extra_kargs,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_stage2(
|
||||
Mid_O,
|
||||
o,
|
||||
B_Seqlen,
|
||||
stride_mid_ob,
|
||||
stride_mid_oh,
|
||||
stride_mid_os,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
NUM_KV_SPLITS: tl.constexpr,
|
||||
BLOCK_DV: tl.constexpr,
|
||||
Lv: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DV)
|
||||
mask_d = offs_d < Lv
|
||||
|
||||
e_sum = 0.0
|
||||
e_max = -float("inf")
|
||||
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
|
||||
|
||||
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
|
||||
offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
|
||||
|
||||
for split_kv_id in range(0, NUM_KV_SPLITS):
|
||||
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
|
||||
split_kv_start = kv_len_per_split * split_kv_id
|
||||
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
|
||||
cur_batch_seq_len)
|
||||
|
||||
if split_kv_end > split_kv_start:
|
||||
tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os,
|
||||
mask=mask_d,
|
||||
other=0.0)
|
||||
tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
|
||||
n_e_max = tl.maximum(tlogic, e_max)
|
||||
|
||||
old_scale = tl.exp(e_max - n_e_max)
|
||||
acc *= old_scale
|
||||
exp_logic = tl.exp(tlogic - n_e_max)
|
||||
acc += exp_logic * tv
|
||||
|
||||
e_sum = e_sum * old_scale + exp_logic
|
||||
e_max = n_e_max
|
||||
|
||||
tl.store(
|
||||
o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
|
||||
acc / e_sum,
|
||||
mask=mask_d,
|
||||
)
|
||||
|
||||
|
||||
def _decode_softmax_reducev_fwd(
|
||||
logits,
|
||||
q,
|
||||
o,
|
||||
v_buffer,
|
||||
b_seq_len,
|
||||
num_kv_splits,
|
||||
):
|
||||
batch, head_num = q.shape[0], q.shape[1]
|
||||
Lv = v_buffer.shape[-1]
|
||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||
|
||||
NUM_KV_SPLITS = num_kv_splits
|
||||
|
||||
extra_kargs = {}
|
||||
# if is_hip_:
|
||||
# # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
||||
# # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
||||
# extra_kargs = {
|
||||
# "waves_per_eu": 4,
|
||||
# "matrix_instr_nonkdim": 16,
|
||||
# "kpack": 2
|
||||
# }
|
||||
|
||||
grid = (batch, head_num)
|
||||
_fwd_kernel_stage2[grid](
|
||||
logits,
|
||||
o,
|
||||
b_seq_len,
|
||||
logits.stride(0),
|
||||
logits.stride(1),
|
||||
logits.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
||||
BLOCK_DV=BLOCK_DV,
|
||||
Lv=Lv,
|
||||
num_warps=4,
|
||||
num_stages=2,
|
||||
**extra_kargs,
|
||||
)
|
||||
|
||||
|
||||
def decode_attention_fwd_normal(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap=0.0,
|
||||
):
|
||||
_decode_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
attn_logits,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
)
|
||||
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
|
||||
num_kv_splits)
|
||||
|
||||
|
||||
def decode_attention_fwd_grouped(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
num_stages,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap=0.0,
|
||||
):
|
||||
_decode_grouped_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
attn_logits,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
num_kv_splits,
|
||||
num_stages,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
)
|
||||
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
|
||||
num_kv_splits)
|
||||
|
||||
|
||||
def decode_attention_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
num_stages,
|
||||
sm_scale,
|
||||
page_size=1,
|
||||
logit_cap=0.0,
|
||||
):
|
||||
assert num_kv_splits == attn_logits.shape[2]
|
||||
kv_group_num = q.shape[1] // v_buffer.shape[-2]
|
||||
|
||||
if kv_group_num == 1:
|
||||
# MHA
|
||||
decode_attention_fwd_normal(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
)
|
||||
else:
|
||||
# GQA/MQA/MLA
|
||||
decode_attention_fwd_grouped(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
num_stages,
|
||||
sm_scale,
|
||||
page_size,
|
||||
logit_cap,
|
||||
)
|
||||
979
vllm/attention/ops/triton_flash_attention.py
Normal file
979
vllm/attention/ops/triton_flash_attention.py
Normal file
@@ -0,0 +1,979 @@
|
||||
#!/usr/bin/env python
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
|
||||
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
|
||||
(https://tridao.me/publications/flash2/flash2.pdf)
|
||||
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
|
||||
|
||||
Features supported:
|
||||
|
||||
1) Fwd with causal masking
|
||||
2) Any sequence lengths without padding (currently fwd kernel only)
|
||||
3) Support for different sequence lengths for q and k
|
||||
4) Nested tensor API currently does not support dropout or bias.
|
||||
|
||||
Not currently supported:
|
||||
|
||||
1) Non power of two head dims
|
||||
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.rocm import on_gfx1x
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
torch_dtype: tl.constexpr = torch.float16
|
||||
|
||||
|
||||
@triton.jit
|
||||
def cdiv_fn(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
@triton.jit
|
||||
def max_fn(x, y):
|
||||
return tl.math.max(x, y)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
ms = tl.arange(0, m)
|
||||
ns = tl.arange(0, n)
|
||||
return philox_offset + ms[:, None] * stride + ns[None, :]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
|
||||
stride).to(tl.uint32)
|
||||
# TODO: use tl.randint for better performance
|
||||
return tl.rand(philox_seed, rng_offsets)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
|
||||
stride)
|
||||
rng_keep = rng_output > dropout_p
|
||||
return rng_keep
|
||||
|
||||
|
||||
@triton.jit
|
||||
def load_fn(block_ptr, first, second, pad):
|
||||
if first and second:
|
||||
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
|
||||
elif first:
|
||||
tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
|
||||
elif second:
|
||||
tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
|
||||
else:
|
||||
tensor = tl.load(block_ptr)
|
||||
return tensor
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
start_m,
|
||||
actual_seqlen_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
batch_philox_offset,
|
||||
encoded_softmax_block_ptr,
|
||||
block_min,
|
||||
block_max,
|
||||
offs_n_causal,
|
||||
masked_blocks,
|
||||
n_extra_tokens,
|
||||
bias_ptr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
OFFS_M: tl.constexpr,
|
||||
OFFS_N: tl.constexpr,
|
||||
PRE_LOAD_V: tl.constexpr,
|
||||
MASK_STEPS: tl.constexpr,
|
||||
ENABLE_DROPOUT: tl.constexpr,
|
||||
RETURN_ENCODED_SOFTMAX: tl.constexpr,
|
||||
PADDED_HEAD: tl.constexpr,
|
||||
USE_FP8: tl.constexpr,
|
||||
qk_scale,
|
||||
p_descale,
|
||||
):
|
||||
# loop over k, v, and update accumulator
|
||||
for start_n in range(block_min, block_max, BLOCK_N):
|
||||
# For padded blocks, we will overrun the tensor size if
|
||||
# we load all BLOCK_N. For others, the blocks are all within range.
|
||||
k = load_fn(
|
||||
K_block_ptr,
|
||||
PADDED_HEAD,
|
||||
MASK_STEPS and (n_extra_tokens != 0),
|
||||
"zero",
|
||||
)
|
||||
if PRE_LOAD_V:
|
||||
v = load_fn(
|
||||
V_block_ptr,
|
||||
MASK_STEPS and (n_extra_tokens != 0),
|
||||
PADDED_HEAD,
|
||||
"zero",
|
||||
)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
# We start from end of seqlen_k so only the first iteration would need
|
||||
# to be checked for padding if it is not a multiple of block_n
|
||||
# TODO: This can be optimized to only be true for the padded block.
|
||||
if MASK_STEPS: # noqa: SIM102
|
||||
# If this is the last block / iteration, we want to
|
||||
# mask if the sequence length is not a multiple of block size
|
||||
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
|
||||
# if not is_modulo_mn. last step might get wasted but that is okay.
|
||||
# check if this masking works for that case.
|
||||
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
|
||||
boundary_m = tl.full([BLOCK_M],
|
||||
actual_seqlen_k,
|
||||
dtype=tl.int32)
|
||||
size_n = start_n + OFFS_N[None, :]
|
||||
mask = size_n < boundary_m[:, None]
|
||||
qk = tl.where(mask, qk, float("-inf"))
|
||||
if IS_CAUSAL:
|
||||
causal_boundary = start_n + offs_n_causal
|
||||
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
|
||||
qk = tl.where(causal_mask, qk, float("-inf"))
|
||||
# -- compute qk ----
|
||||
qk += tl.dot(q, k)
|
||||
if USE_FP8:
|
||||
qk *= qk_scale
|
||||
if bias_ptr is not None:
|
||||
bias = load_fn(bias_ptr, False, MASK_STEPS
|
||||
and (n_extra_tokens != 0), "zero")
|
||||
# While bias is added after multiplying qk with sm_scale, our
|
||||
# optimization to use 2^x instead of e^x results in an additional
|
||||
# scale factor of log2(e) which we must also multiply the bias with.
|
||||
qk += bias * 1.44269504089
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
||||
qk = qk - m_ij[:, None]
|
||||
p = tl.math.exp2(qk)
|
||||
|
||||
# CAVEAT: Must update l_ij before applying dropout
|
||||
l_ij = tl.sum(p, 1)
|
||||
if ENABLE_DROPOUT:
|
||||
philox_offset = (batch_philox_offset +
|
||||
start_m * BLOCK_M * actual_seqlen_k + start_n -
|
||||
BLOCK_N)
|
||||
keep = dropout_mask(
|
||||
philox_seed,
|
||||
philox_offset,
|
||||
dropout_p,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
actual_seqlen_k,
|
||||
)
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
tl.store(
|
||||
encoded_softmax_block_ptr,
|
||||
tl.where(keep, p,
|
||||
-p).to(encoded_softmax_block_ptr.type.element_ty),
|
||||
)
|
||||
p = tl.where(keep, p, 0.0)
|
||||
elif RETURN_ENCODED_SOFTMAX:
|
||||
tl.store(
|
||||
encoded_softmax_block_ptr,
|
||||
p.to(encoded_softmax_block_ptr.type.element_ty),
|
||||
)
|
||||
# -- update output accumulator --
|
||||
alpha = tl.math.exp2(m_i - m_ij)
|
||||
acc = acc * alpha[:, None]
|
||||
if not PRE_LOAD_V:
|
||||
v = load_fn(
|
||||
V_block_ptr,
|
||||
MASK_STEPS and (n_extra_tokens != 0),
|
||||
PADDED_HEAD,
|
||||
"zero",
|
||||
)
|
||||
# -- update m_i and l_i
|
||||
l_i = l_i * alpha + l_ij
|
||||
# update m_i and l_i
|
||||
m_i = m_ij
|
||||
|
||||
if USE_FP8:
|
||||
p *= p_descale
|
||||
|
||||
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
|
||||
|
||||
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
if bias_ptr is not None:
|
||||
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
|
||||
(0, BLOCK_N))
|
||||
return acc, l_i, m_i
|
||||
|
||||
|
||||
def get_cdna_autotune_configs():
|
||||
return [
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 256,
|
||||
'BLOCK_N': 64,
|
||||
'waves_per_eu': 2,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 128,
|
||||
'BLOCK_N': 128,
|
||||
'waves_per_eu': 2,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 256,
|
||||
'BLOCK_N': 128,
|
||||
'waves_per_eu': 2,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 128,
|
||||
'BLOCK_N': 64,
|
||||
'waves_per_eu': 1,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 128,
|
||||
'BLOCK_N': 64,
|
||||
'waves_per_eu': 3,
|
||||
'PRE_LOAD_V': True
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 128,
|
||||
'BLOCK_N': 64,
|
||||
'waves_per_eu': 3,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 64,
|
||||
'BLOCK_N': 64,
|
||||
'waves_per_eu': 4,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 32,
|
||||
'BLOCK_N': 32,
|
||||
'waves_per_eu': 4,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8),
|
||||
# TODO: This config fails with head_size not pow2 with data mismatches.
|
||||
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
|
||||
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
|
||||
|
||||
# Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
|
||||
# triton.Config(
|
||||
# {
|
||||
# "BLOCK_M": 16,
|
||||
# "BLOCK_N": 16,
|
||||
# "waves_per_eu": 1,
|
||||
# "PRE_LOAD_V": False,
|
||||
# },
|
||||
# num_stages=1,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
|
||||
|
||||
|
||||
def get_rdna_autotune_configs():
|
||||
return [
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 32,
|
||||
'BLOCK_N': 32,
|
||||
'waves_per_eu': 4,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 32,
|
||||
'BLOCK_N': 32,
|
||||
'waves_per_eu': 2,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 32,
|
||||
'BLOCK_N': 16,
|
||||
'waves_per_eu': 4,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_M': 32,
|
||||
'BLOCK_N': 16,
|
||||
'waves_per_eu': 2,
|
||||
'PRE_LOAD_V': False
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=2),
|
||||
# Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
|
||||
# triton.Config(
|
||||
# {
|
||||
# 'BLOCK_M': 16,
|
||||
# 'BLOCK_N': 16,
|
||||
# 'waves_per_eu': 4,
|
||||
# 'PRE_LOAD_V': False
|
||||
# },
|
||||
# num_stages=1,
|
||||
# num_warps=2),
|
||||
# triton.Config(
|
||||
# {
|
||||
# 'BLOCK_M': 16,
|
||||
# 'BLOCK_N': 16,
|
||||
# 'waves_per_eu': 2,
|
||||
# 'PRE_LOAD_V': False
|
||||
# },
|
||||
# num_stages=1,
|
||||
# num_warps=2),
|
||||
# # Fall-back config.
|
||||
# triton.Config(
|
||||
# {
|
||||
# 'BLOCK_M': 16,
|
||||
# 'BLOCK_N': 16,
|
||||
# 'waves_per_eu': 1,
|
||||
# 'PRE_LOAD_V': False
|
||||
# },
|
||||
# num_stages=1,
|
||||
# num_warps=2),
|
||||
], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
|
||||
|
||||
|
||||
def get_autotune_configs():
|
||||
if on_gfx1x():
|
||||
return get_rdna_autotune_configs()
|
||||
else:
|
||||
return get_cdna_autotune_configs()
|
||||
|
||||
|
||||
autotune_configs, autotune_keys = get_autotune_configs()
|
||||
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=autotune_configs,
|
||||
key=autotune_keys,
|
||||
)
|
||||
@triton.jit
|
||||
def attn_fwd(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
bias,
|
||||
sm_scale,
|
||||
q_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
p_scale,
|
||||
p_descale,
|
||||
o_descale,
|
||||
L,
|
||||
Out,
|
||||
stride_qz: tl.int64,
|
||||
stride_qh: tl.int64,
|
||||
stride_qm: tl.int64,
|
||||
stride_qk: tl.int64,
|
||||
stride_kz: tl.int64,
|
||||
stride_kh: tl.int64,
|
||||
stride_kn: tl.int64,
|
||||
stride_kk: tl.int64,
|
||||
stride_vz: tl.int64,
|
||||
stride_vh: tl.int64,
|
||||
stride_vk: tl.int64,
|
||||
stride_vn: tl.int64,
|
||||
stride_oz: tl.int64,
|
||||
stride_oh: tl.int64,
|
||||
stride_om: tl.int64,
|
||||
stride_on: tl.int64,
|
||||
stride_bz: tl.int64,
|
||||
stride_bh: tl.int64,
|
||||
stride_bm: tl.int64,
|
||||
stride_bn: tl.int64,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
philox_offset_base,
|
||||
encoded_softmax,
|
||||
HQ: tl.constexpr,
|
||||
HK: tl.constexpr,
|
||||
ACTUAL_BLOCK_DMODEL: tl.constexpr,
|
||||
MAX_SEQLENS_Q: tl.constexpr,
|
||||
MAX_SEQLENS_K: tl.constexpr,
|
||||
VARLEN: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
USE_FP8: tl.constexpr,
|
||||
USE_FP8_OUT: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
PRE_LOAD_V: tl.constexpr,
|
||||
BIAS_TYPE: tl.constexpr,
|
||||
ENABLE_DROPOUT: tl.constexpr,
|
||||
RETURN_ENCODED_SOFTMAX: tl.constexpr,
|
||||
FP8_MIN: tl.constexpr = float8_info.min,
|
||||
FP8_MAX: tl.constexpr = float8_info.max,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_h_q = tl.program_id(1)
|
||||
off_z = tl.program_id(2)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
if VARLEN:
|
||||
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
|
||||
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
|
||||
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
|
||||
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
|
||||
# small for all start_m so for those we return early.
|
||||
if start_m * BLOCK_M > seqlen_q:
|
||||
return
|
||||
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
|
||||
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
|
||||
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
|
||||
else:
|
||||
cu_seqlens_q_start = 0
|
||||
cu_seqlens_k_start = 0
|
||||
seqlen_q = MAX_SEQLENS_Q
|
||||
seqlen_k = MAX_SEQLENS_K
|
||||
|
||||
# Now we compute whether we need to exit early due to causal masking.
|
||||
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
|
||||
# are completely masked, resulting in 0s written to the output, and
|
||||
# inf written to LSE. We don't need to do any GEMMs in this case.
|
||||
# This block of code determines what N is, and if this WG is operating
|
||||
# on those M rows.
|
||||
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
|
||||
if IS_CAUSAL:
|
||||
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
|
||||
# If seqlen_q != seqlen_k, attn scores are rectangular which means
|
||||
# the causal mask boundary is bottom right aligned, and ends at either
|
||||
# the top edge (seqlen_q < seqlen_k) or left edge.
|
||||
# This captures the decrease in n_blocks if we have a rectangular attn
|
||||
# matrix
|
||||
n_blocks_seqlen = cdiv_fn(
|
||||
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
|
||||
# This is what adjusts the block_max for the current WG, only
|
||||
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
|
||||
n_blocks = min(n_blocks, n_blocks_seqlen)
|
||||
# If we have no blocks after adjusting for seqlen deltas, this WG is
|
||||
# part of the blocks that are all 0. We exit early.
|
||||
if n_blocks <= 0:
|
||||
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
|
||||
off_h_q * stride_oh)
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out + o_offset,
|
||||
shape=(seqlen_q, BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
|
||||
# We still need to write 0s to the result
|
||||
# tl.store(O_block_ptr,
|
||||
# acc.to(Out.type.element_ty), boundary_check=(0,1))
|
||||
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
|
||||
# + offs_m
|
||||
# We store inf to LSE, not -inf because in the bwd pass,
|
||||
# we subtract this
|
||||
# from qk which makes it -inf, such that exp(qk - inf) = 0
|
||||
# for these masked blocks.
|
||||
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
|
||||
# tl.store(l_ptrs, l)
|
||||
# TODO: Should dropout and return encoded softmax be handled here?
|
||||
return
|
||||
|
||||
# If MQA / GQA, set the K and V head offsets appropriately.
|
||||
GROUP_SIZE: tl.constexpr = HQ // HK
|
||||
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
|
||||
|
||||
n_extra_tokens = 0
|
||||
if seqlen_k < BLOCK_N:
|
||||
n_extra_tokens = BLOCK_N - seqlen_k
|
||||
elif seqlen_k % BLOCK_N:
|
||||
n_extra_tokens = seqlen_k % BLOCK_N
|
||||
padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
|
||||
|
||||
# Compute pointers for all the tensors used in this kernel.
|
||||
q_offset = (off_z * stride_qz + off_h_q * stride_qh +
|
||||
cu_seqlens_q_start * stride_qm)
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + q_offset,
|
||||
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
k_offset = (off_z * stride_kz + off_h_k * stride_kh +
|
||||
cu_seqlens_k_start * stride_kn)
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K + k_offset,
|
||||
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1),
|
||||
)
|
||||
v_offset = (off_z * stride_vz + off_h_k * stride_vh +
|
||||
cu_seqlens_k_start * stride_vk)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + v_offset,
|
||||
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
if BIAS_TYPE != 0:
|
||||
bias_ptr = tl.make_block_ptr(
|
||||
base=bias + off_h_q * stride_bh,
|
||||
shape=(seqlen_q, seqlen_k),
|
||||
strides=(stride_bm, stride_bn),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N),
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
bias_ptr = None
|
||||
if ENABLE_DROPOUT:
|
||||
batch_philox_offset = philox_offset_base \
|
||||
+ (off_z * HQ + off_h_q) \
|
||||
* seqlen_q * seqlen_k
|
||||
else:
|
||||
batch_philox_offset = 0
|
||||
# We can ask to return the dropout mask without actually doing any dropout.
|
||||
# In this case, we return an invalid pointer so indicate the mask is not i
|
||||
# valid.
|
||||
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.make_block_ptr(
|
||||
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
|
||||
shape=(seqlen_q, seqlen_k),
|
||||
strides=(seqlen_k, 1),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N),
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
encoded_softmax_block_ptr = 0
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
|
||||
# have native e^x support in HW.
|
||||
qk_scale = sm_scale * 1.44269504089
|
||||
# Q is loaded once at the beginning and shared by all N blocks.
|
||||
q = load_fn(Q_block_ptr, True, padded_head, "zero")
|
||||
if not USE_FP8:
|
||||
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
|
||||
acc_scale = 1.0
|
||||
else:
|
||||
qk_scale *= q_scale * k_scale
|
||||
acc_scale = p_scale * v_scale
|
||||
|
||||
# Here we compute how many full and masked blocks we have.
|
||||
padded_block_k = n_extra_tokens != 0
|
||||
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
|
||||
if IS_CAUSAL:
|
||||
# There are always at least BLOCK_M // BLOCK_N masked blocks.
|
||||
# Additionally there might be one more due to dissimilar seqlens.
|
||||
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
|
||||
else:
|
||||
# Padding on Q does not need to be masked in the FA loop.
|
||||
masked_blocks = padded_block_k
|
||||
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
|
||||
# block. In this case we might exceed n_blocks so pick the min.
|
||||
masked_blocks = min(masked_blocks, n_blocks)
|
||||
n_full_blocks = n_blocks - masked_blocks
|
||||
block_min = 0
|
||||
block_max = n_blocks * BLOCK_N
|
||||
# Compute for full blocks. Here we set causal to false regardless of its
|
||||
# value because there is no masking. Similarly we do not need padding.
|
||||
if n_full_blocks > 0:
|
||||
block_max = (n_blocks - masked_blocks) * BLOCK_N
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
start_m,
|
||||
seqlen_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
batch_philox_offset,
|
||||
encoded_softmax_block_ptr,
|
||||
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
|
||||
block_min,
|
||||
block_max,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
bias_ptr,
|
||||
# IS_CAUSAL, ....
|
||||
False,
|
||||
BLOCK_M,
|
||||
BLOCK_DMODEL,
|
||||
BLOCK_N,
|
||||
offs_m,
|
||||
offs_n,
|
||||
# _, MASK_STEPS, ...
|
||||
PRE_LOAD_V,
|
||||
False,
|
||||
ENABLE_DROPOUT,
|
||||
RETURN_ENCODED_SOFTMAX,
|
||||
padded_head,
|
||||
USE_FP8,
|
||||
qk_scale,
|
||||
p_descale,
|
||||
)
|
||||
block_min = block_max
|
||||
block_max = n_blocks * BLOCK_N
|
||||
|
||||
tl.debug_barrier()
|
||||
# Remaining blocks, if any, are full / not masked.
|
||||
if masked_blocks > 0:
|
||||
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
|
||||
if bias_ptr is not None:
|
||||
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
|
||||
(0, n_full_blocks))
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
start_m,
|
||||
seqlen_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
batch_philox_offset,
|
||||
encoded_softmax_block_ptr,
|
||||
block_min,
|
||||
block_max,
|
||||
offs_n_causal,
|
||||
masked_blocks,
|
||||
n_extra_tokens,
|
||||
bias_ptr,
|
||||
IS_CAUSAL,
|
||||
BLOCK_M,
|
||||
BLOCK_DMODEL,
|
||||
BLOCK_N,
|
||||
offs_m,
|
||||
offs_n,
|
||||
# _, MASK_STEPS, ...
|
||||
PRE_LOAD_V,
|
||||
True,
|
||||
ENABLE_DROPOUT,
|
||||
RETURN_ENCODED_SOFTMAX,
|
||||
padded_head,
|
||||
USE_FP8,
|
||||
qk_scale,
|
||||
p_descale,
|
||||
)
|
||||
# epilogue
|
||||
|
||||
if USE_FP8:
|
||||
acc *= acc_scale
|
||||
acc = acc / l_i[:, None]
|
||||
if ENABLE_DROPOUT:
|
||||
acc = acc / (1 - dropout_p)
|
||||
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
|
||||
# then we have one block with a row of all NaNs which come from computing
|
||||
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
|
||||
# and store 0s where there are NaNs as these rows should've been zeroed out.
|
||||
end_m_idx = (start_m + 1) * BLOCK_M
|
||||
start_m_idx = start_m * BLOCK_M
|
||||
causal_start_idx = seqlen_q - seqlen_k
|
||||
if USE_FP8_OUT:
|
||||
acc *= o_descale
|
||||
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
|
||||
acc = acc.to(Out.type.element_ty)
|
||||
if IS_CAUSAL: # noqa: SIM102
|
||||
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
|
||||
out_mask_boundary = tl.full((BLOCK_DMODEL, ),
|
||||
causal_start_idx,
|
||||
dtype=tl.int32)
|
||||
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
|
||||
out_ptrs_mask = (mask_m_offsets[:, None]
|
||||
>= out_mask_boundary[None, :])
|
||||
z = tl.zeros((1, ), tl.float32)
|
||||
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
|
||||
# write back LSE
|
||||
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
|
||||
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
|
||||
# few rows. This is only true for the last M block. For others,
|
||||
# overflow_size will be -ve
|
||||
# overflow_size = end_m_idx - seqlen_q
|
||||
# if overflow_size > 0:
|
||||
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
|
||||
# # This is a > check because mask being 0 blocks the store.
|
||||
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
|
||||
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
|
||||
# else:
|
||||
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
|
||||
|
||||
# write back O
|
||||
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
|
||||
off_h_q * stride_oh)
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out + o_offset,
|
||||
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
# Need boundary check on this to make sure the padding from the
|
||||
# Q and KV tensors in both dims are not part of what we store back.
|
||||
# TODO: Do the boundary check optionally.
|
||||
tl.store(O_block_ptr, acc, boundary_check=(0, 1))
|
||||
|
||||
|
||||
def check_args(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
varlen=True,
|
||||
max_seqlens=None,
|
||||
cu_seqlens_q=None,
|
||||
cu_seqlens_k=None,
|
||||
):
|
||||
assert q.dim() == k.dim() and q.dim() == v.dim()
|
||||
if varlen:
|
||||
assert q.dim() == 3
|
||||
total_q, nheads_q, head_size = q.shape
|
||||
total_k, nheads_k, _ = k.shape
|
||||
assert cu_seqlens_q is not None
|
||||
assert cu_seqlens_k is not None
|
||||
assert len(cu_seqlens_q) == len(cu_seqlens_k)
|
||||
else:
|
||||
assert q.dim() == 4
|
||||
batch, nheads_q, seqlen_q, head_size = q.shape
|
||||
_, nheads_k, seqlen_k, _ = k.shape
|
||||
assert max_seqlens > 0
|
||||
assert k.shape == v.shape
|
||||
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
|
||||
# TODO: Change assert if we support qkl f8 and v f16
|
||||
assert q.dtype == k.dtype and q.dtype == v.dtype
|
||||
assert head_size <= 256
|
||||
assert o.shape == q.shape
|
||||
assert (nheads_q % nheads_k) == 0
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlens_q,
|
||||
max_seqlens_k,
|
||||
causal=False,
|
||||
sm_scale=1.0,
|
||||
bias=None,
|
||||
fp8_scales=None,
|
||||
fp8_out_scale=None,
|
||||
):
|
||||
if fp8_scales is not None:
|
||||
use_fp8 = True
|
||||
(q_scale, k_scale, v_scale, p_scale) = fp8_scales
|
||||
float8 = current_platform.fp8_dtype()
|
||||
|
||||
def check_and_convert(t, scale):
|
||||
if t.dtype != float8:
|
||||
descale = 1.0 / scale
|
||||
ts = (t * descale).clamp(min=float8_info.min,
|
||||
max=float8_info.max)
|
||||
return ts.to(float8)
|
||||
else:
|
||||
return t
|
||||
|
||||
q = check_and_convert(q, q_scale)
|
||||
k = check_and_convert(k, k_scale)
|
||||
v = check_and_convert(v, v_scale)
|
||||
else:
|
||||
use_fp8 = False
|
||||
q_scale = k_scale = v_scale = p_scale = 1.0
|
||||
|
||||
if o is None:
|
||||
o = torch.empty_like(q, dtype=v.dtype)
|
||||
|
||||
check_args(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
varlen=True,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
)
|
||||
if True: # varlen
|
||||
total_q, nheads_q, head_size = q.shape
|
||||
total_k, nheads_k, _ = k.shape
|
||||
batch = len(cu_seqlens_q) - 1
|
||||
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
|
||||
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
|
||||
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
|
||||
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
|
||||
else:
|
||||
batch, seqlen_q, nheads_q, head_size = q.shape
|
||||
_, seqlen_k, nheads_k, _ = k.shape
|
||||
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
|
||||
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
|
||||
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
|
||||
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
|
||||
|
||||
# Get closest power of 2 over or equal to 32.
|
||||
unpadded_head_dims = {32, 64, 128, 256}
|
||||
if head_size not in unpadded_head_dims:
|
||||
padded_d_model = None
|
||||
for i in unpadded_head_dims:
|
||||
if i > head_size:
|
||||
padded_d_model = i
|
||||
break
|
||||
assert padded_d_model is not None
|
||||
else:
|
||||
padded_d_model = head_size
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
|
||||
nheads_q,
|
||||
batch,
|
||||
)
|
||||
|
||||
encoded_softmax = None
|
||||
|
||||
# Seed the RNG so we get reproducible results for testing.
|
||||
philox_seed = 0x1BF52
|
||||
philox_offset = 0x1D4B42
|
||||
|
||||
if bias is not None:
|
||||
bias_strides = (
|
||||
bias.stride(0),
|
||||
bias.stride(1),
|
||||
bias.stride(2),
|
||||
bias.stride(3),
|
||||
)
|
||||
else:
|
||||
bias_strides = (0, 0, 0, 0)
|
||||
|
||||
p_descale = 1.0 / p_scale
|
||||
o_descale = 1.0 / fp8_out_scale.item(
|
||||
) if fp8_out_scale is not None else 1.0
|
||||
|
||||
arg_max_seqlens_q = 0 if on_gfx1x() else max_seqlens_q
|
||||
arg_max_seqlens_k = 0 if on_gfx1x() else max_seqlens_k
|
||||
|
||||
attn_fwd[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
bias,
|
||||
sm_scale,
|
||||
q_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
p_scale,
|
||||
p_descale,
|
||||
o_descale,
|
||||
None,
|
||||
o,
|
||||
*q_strides,
|
||||
*k_strides,
|
||||
*v_strides,
|
||||
*o_strides,
|
||||
*bias_strides,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
dropout_p=0.0,
|
||||
philox_seed=philox_seed,
|
||||
philox_offset_base=philox_offset,
|
||||
encoded_softmax=encoded_softmax,
|
||||
HQ=nheads_q,
|
||||
HK=nheads_k,
|
||||
ACTUAL_BLOCK_DMODEL=head_size,
|
||||
MAX_SEQLENS_Q=arg_max_seqlens_q,
|
||||
MAX_SEQLENS_K=arg_max_seqlens_k,
|
||||
IS_CAUSAL=causal,
|
||||
VARLEN=True,
|
||||
BLOCK_DMODEL=padded_d_model,
|
||||
BIAS_TYPE=0 if bias is None else 1,
|
||||
ENABLE_DROPOUT=False,
|
||||
RETURN_ENCODED_SOFTMAX=False,
|
||||
USE_FP8=use_fp8,
|
||||
USE_FP8_OUT=fp8_out_scale is not None,
|
||||
)
|
||||
|
||||
ctx.grid = grid
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.BLOCK_DMODEL = head_size
|
||||
ctx.causal = causal
|
||||
ctx.dropout_p = 0.0
|
||||
ctx.philox_seed = philox_seed
|
||||
ctx.philox_offset = philox_offset
|
||||
ctx.encoded_softmax = encoded_softmax
|
||||
ctx.return_encoded_softmax = False
|
||||
return o, encoded_softmax
|
||||
|
||||
|
||||
triton_attention = _attention.apply
|
||||
97
vllm/attention/ops/triton_merge_attn_states.py
Normal file
97
vllm/attention/ops/triton_merge_attn_states.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
# can be used to combine partial attention results (in the split-KV case)
|
||||
def merge_attn_states(
|
||||
output: torch.Tensor,
|
||||
prefix_output: torch.Tensor,
|
||||
prefix_lse: torch.Tensor,
|
||||
suffix_output: torch.Tensor,
|
||||
suffix_lse: torch.Tensor,
|
||||
output_lse: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
num_tokens = output.shape[0]
|
||||
num_query_heads = output.shape[1]
|
||||
head_size = output.shape[2]
|
||||
padded_head_size = triton.next_power_of_2(head_size)
|
||||
|
||||
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
|
||||
merge_attn_states_kernel[(num_tokens, num_query_heads)](
|
||||
output,
|
||||
output_lse,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
head_size,
|
||||
padded_head_size,
|
||||
output_lse is not None,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def merge_attn_states_kernel(
|
||||
output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
output_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
HEAD_SIZE: tl.constexpr,
|
||||
PADDED_HEAD_SIZE: tl.constexpr,
|
||||
OUTPUT_LSE: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(0)
|
||||
num_tokens = tl.num_programs(0)
|
||||
head_idx = tl.program_id(1)
|
||||
num_heads = tl.num_programs(1)
|
||||
|
||||
p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
|
||||
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
|
||||
|
||||
# FA2 and FA3 have different behavior for when the sum-exp is 0, this namely
|
||||
# arises with 0 len seqlens. FA3 returns -inf here while FA2 returns inf.
|
||||
# If we see an inf assume FA2 and convert inf to -inf for consistency
|
||||
# and correctness. Inf generally doesn't make sense in this context outside
|
||||
# of undefined-behavior/FA2-case, so I think this a safe assumption.
|
||||
p_lse = float('-inf') if p_lse == float('inf') else p_lse
|
||||
s_lse = float('-inf') if s_lse == float('inf') else s_lse
|
||||
|
||||
max_lse = tl.maximum(p_lse, s_lse)
|
||||
p_lse = p_lse - max_lse
|
||||
s_lse = s_lse - max_lse
|
||||
# Will reuse precomputed Exp values for scale factor computation.
|
||||
p_se = tl.exp(p_lse)
|
||||
s_se = tl.exp(s_lse)
|
||||
out_se = (p_se + s_se)
|
||||
|
||||
if OUTPUT_LSE:
|
||||
out_lse = tl.log(out_se) + max_lse
|
||||
tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse)
|
||||
|
||||
head_arange = tl.arange(0, PADDED_HEAD_SIZE)
|
||||
head_mask = head_arange < HEAD_SIZE
|
||||
p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE +
|
||||
head_idx * HEAD_SIZE + head_arange,
|
||||
mask=head_mask)
|
||||
s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE +
|
||||
head_idx * HEAD_SIZE + head_arange,
|
||||
mask=head_mask)
|
||||
|
||||
# NOTE(woosuk): Be careful with the numerical stability.
|
||||
# We should compute the scale first, and then multiply it with the output.
|
||||
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
|
||||
p_scale = p_se / out_se
|
||||
s_scale = s_se / out_se
|
||||
out = p_out * p_scale + s_out * s_scale
|
||||
tl.store(output + token_idx * num_heads * HEAD_SIZE +
|
||||
head_idx * HEAD_SIZE + head_arange,
|
||||
out,
|
||||
mask=head_mask)
|
||||
353
vllm/attention/ops/triton_unified_attention.py
Normal file
353
vllm/attention/ops/triton_unified_attention.py
Normal file
@@ -0,0 +1,353 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Authors:
|
||||
# - Burkhard Ringlein <ngl@zurich.ibm.com>
|
||||
# - Jan van Lunteren <jvl@zurich.ibm.com>
|
||||
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
|
||||
# - Thomas Parnell <tpa@zurich.ibm.com>
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def cdiv_fn(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
@triton.jit
|
||||
def apply_softcap(S, x):
|
||||
Sdiv = S / x
|
||||
p1 = tl.exp(Sdiv)
|
||||
p2 = tl.exp(-Sdiv)
|
||||
return x * (p1 - p2) / (p1 + p2)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_unified_attention_2d(
|
||||
output_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
sink_ptr, # [num_query_heads]
|
||||
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
alibi_slopes_ptr, # [num_query_heads]
|
||||
scale, # float32
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
softcap, # float32
|
||||
num_query_heads: tl.constexpr, # int
|
||||
num_queries_per_kv: tl.constexpr, # int
|
||||
block_table_stride: tl.int64, # int
|
||||
query_stride_0: tl.int64, # int
|
||||
query_stride_1: tl.int64, # int, should be equal to head_size
|
||||
output_stride_0: tl.int64, # int
|
||||
output_stride_1: tl.int64, # int, should be equal to head_size
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
USE_SOFTCAP: tl.constexpr, # bool
|
||||
USE_SINKS: tl.constexpr, # bool
|
||||
SLIDING_WINDOW: tl.constexpr, # int
|
||||
stride_k_cache_0: tl.int64, # int
|
||||
stride_k_cache_1: tl.int64, # int
|
||||
stride_k_cache_2: tl.int64, # int
|
||||
stride_k_cache_3: tl.constexpr, # int
|
||||
stride_v_cache_0: tl.int64, # int
|
||||
stride_v_cache_1: tl.int64, # int
|
||||
stride_v_cache_2: tl.int64, # int
|
||||
stride_v_cache_3: tl.constexpr, # int
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
BLOCK_Q: tl.constexpr, # int
|
||||
num_seqs: tl.int32,
|
||||
BLOCK_M: tl.constexpr, # int
|
||||
):
|
||||
|
||||
q_block_global_idx = tl.program_id(0)
|
||||
kv_head_idx = tl.program_id(1)
|
||||
|
||||
left: tl.int32 = 0
|
||||
right = num_seqs
|
||||
while left < right:
|
||||
mid = (left + right) // 2
|
||||
mid_val = tl.load(query_start_len_ptr + mid) // BLOCK_Q + mid
|
||||
if mid_val <= q_block_global_idx:
|
||||
left = mid + 1
|
||||
else:
|
||||
right = mid
|
||||
|
||||
seq_idx = left - 1
|
||||
q_block_start_idx = tl.load(query_start_len_ptr +
|
||||
seq_idx) // BLOCK_Q + seq_idx
|
||||
|
||||
q_block_local_idx = q_block_global_idx - q_block_start_idx
|
||||
|
||||
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
|
||||
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
|
||||
|
||||
cur_batch_query_len = cur_batch_in_all_stop_index \
|
||||
- cur_batch_in_all_start_index
|
||||
|
||||
if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
|
||||
return
|
||||
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
|
||||
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
|
||||
|
||||
query_offset_0 = cur_batch_in_all_start_index + query_pos
|
||||
query_offset_1 = kv_head_idx * num_queries_per_kv + \
|
||||
offs_m % num_queries_per_kv
|
||||
query_offset = (query_offset_0[:, None] * query_stride_0 +
|
||||
query_offset_1[:, None] * query_stride_1 + offs_d[None, :])
|
||||
|
||||
dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
|
||||
query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
|
||||
query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)
|
||||
|
||||
# Q : (BLOCK_M, HEAD_SIZE_PADDED)
|
||||
Q = tl.load(
|
||||
query_ptr + query_offset,
|
||||
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
block_table_offset = seq_idx * block_table_stride
|
||||
|
||||
if not USE_SINKS:
|
||||
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
else:
|
||||
M = tl.load(
|
||||
sink_ptr + query_offset_1,
|
||||
mask=query_mask_1,
|
||||
other=float("-inf"),
|
||||
).to(dtype=tl.float32)
|
||||
# M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
|
||||
L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
|
||||
|
||||
# sequence len for this particular sequence
|
||||
seq_len = tl.load(seq_lens_ptr + seq_idx)
|
||||
|
||||
# context length for this particular sequences
|
||||
context_len = seq_len - cur_batch_query_len
|
||||
|
||||
# alibi slope for this head
|
||||
if USE_ALIBI_SLOPES:
|
||||
alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1,
|
||||
mask=query_mask_1,
|
||||
other=0.0)
|
||||
|
||||
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
|
||||
|
||||
# iterate through tiles
|
||||
for j in range(0, num_blocks):
|
||||
|
||||
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
|
||||
|
||||
offs_n = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
v_offset = (physical_block_idx * stride_v_cache_0 +
|
||||
kv_head_idx * stride_v_cache_2 +
|
||||
offs_d[None, :] * stride_v_cache_3 +
|
||||
offs_n[:, None] * stride_v_cache_1)
|
||||
|
||||
k_offset = (physical_block_idx * stride_k_cache_0 +
|
||||
kv_head_idx * stride_k_cache_2 +
|
||||
offs_d[:, None] * stride_k_cache_3 +
|
||||
offs_n[None, :] * stride_k_cache_1)
|
||||
|
||||
# K : (HEAD_SIZE, BLOCK_SIZE)
|
||||
K_load = tl.load(key_cache_ptr + k_offset,
|
||||
mask=dim_mask[:, None],
|
||||
other=0.0)
|
||||
|
||||
if K_load.dtype.is_fp8():
|
||||
if Q.dtype.is_fp8():
|
||||
K = K_load
|
||||
else:
|
||||
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
|
||||
else:
|
||||
K = K_load
|
||||
|
||||
# V : (BLOCK_SIZE, HEAD_SIZE)
|
||||
V_load = tl.load(value_cache_ptr + v_offset,
|
||||
mask=dim_mask[None, :],
|
||||
other=0.0)
|
||||
|
||||
if V_load.dtype.is_fp8():
|
||||
if Q.dtype.is_fp8():
|
||||
V = V_load
|
||||
else:
|
||||
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
|
||||
else:
|
||||
V = V_load
|
||||
|
||||
seq_offset = j * BLOCK_SIZE + offs_n
|
||||
|
||||
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
|
||||
|
||||
# S : (BLOCK_M, BLOCK_SIZE)
|
||||
S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
|
||||
|
||||
S += scale * tl.dot(Q, K)
|
||||
|
||||
if USE_SOFTCAP:
|
||||
S = apply_softcap(S, softcap)
|
||||
|
||||
S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
|
||||
S, float("-inf"))
|
||||
|
||||
if SLIDING_WINDOW > 0:
|
||||
S = tl.where((context_len + query_pos[:, None] - seq_offset)
|
||||
< SLIDING_WINDOW, S, float("-inf"))
|
||||
|
||||
if USE_ALIBI_SLOPES:
|
||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||
|
||||
# compute running maximum
|
||||
# m_j : (BLOCK_M,)
|
||||
m_j = tl.maximum(M, tl.max(S, axis=1))
|
||||
# For sliding window there's a chance the max is -inf due to masking of
|
||||
# the entire row. In this case we need to set m_j 0 to avoid NaN
|
||||
m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
|
||||
|
||||
# P : (BLOCK_M, BLOCK_SIZE)
|
||||
P = tl.exp(S - m_j[:, None])
|
||||
|
||||
# l_j : (BLOCK_M,)
|
||||
l_j = tl.sum(P, axis=1)
|
||||
|
||||
# alpha : (BLOCK_M, )
|
||||
alpha = tl.exp(M - m_j)
|
||||
|
||||
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
# update constants
|
||||
L = L * alpha + l_j
|
||||
M = m_j
|
||||
|
||||
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
|
||||
acc += tl.dot(P.to(V.dtype), V)
|
||||
|
||||
# epilogue
|
||||
acc = acc / L[:, None]
|
||||
|
||||
output_offset = (query_offset_0[:, None] * output_stride_0 +
|
||||
query_offset_1[:, None] * output_stride_1 +
|
||||
offs_d[None, :])
|
||||
|
||||
tl.store(
|
||||
output_ptr + output_offset,
|
||||
acc,
|
||||
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
|
||||
)
|
||||
|
||||
|
||||
def unified_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
seqused_k,
|
||||
max_seqlen_k,
|
||||
softmax_scale,
|
||||
causal,
|
||||
window_size,
|
||||
block_table,
|
||||
softcap,
|
||||
q_descale,
|
||||
k_descale,
|
||||
v_descale,
|
||||
alibi_slopes=None,
|
||||
# Optional tensor for sinks
|
||||
sinks=None,
|
||||
):
|
||||
assert causal, "Only causal attention is supported"
|
||||
assert q_descale is None, "Q scales not supported"
|
||||
|
||||
block_size = v.shape[1]
|
||||
assert q.element_size() >= 2 or block_size >= 32, \
|
||||
"Block size must be at least 32 for fp8"
|
||||
|
||||
if sinks is not None:
|
||||
assert sinks.shape[0] == q.shape[1], \
|
||||
"Sinks must be num_query_heads size"
|
||||
|
||||
use_alibi_slopes = alibi_slopes is not None
|
||||
|
||||
block_size = v.shape[1]
|
||||
num_seqs = len(seqused_k)
|
||||
num_query_heads = q.shape[1]
|
||||
num_kv_heads = k.shape[2]
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
head_size = q.shape[2]
|
||||
|
||||
BLOCK_M = 16
|
||||
BLOCK_Q = BLOCK_M // num_queries_per_kv
|
||||
|
||||
# Ideally we would launch with kernel with:
|
||||
# \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks.
|
||||
# However, it is slow to realize the query_lens on cpu.
|
||||
# Instead we use upper-bound:
|
||||
# \sum_i[ceil(query_len[i] / BLOCK_Q)]
|
||||
# <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1]
|
||||
# = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs
|
||||
# <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs
|
||||
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
|
||||
total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
|
||||
|
||||
kernel_unified_attention_2d[(
|
||||
total_num_q_blocks,
|
||||
num_kv_heads,
|
||||
)](
|
||||
output_ptr=out,
|
||||
query_ptr=q,
|
||||
key_cache_ptr=k,
|
||||
value_cache_ptr=v,
|
||||
sink_ptr=sinks,
|
||||
block_tables_ptr=block_table,
|
||||
seq_lens_ptr=seqused_k,
|
||||
alibi_slopes_ptr=alibi_slopes,
|
||||
scale=softmax_scale,
|
||||
k_scale=k_descale,
|
||||
v_scale=v_descale,
|
||||
softcap=softcap,
|
||||
num_query_heads=num_query_heads,
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
block_table_stride=block_table.stride(0),
|
||||
query_stride_0=q.stride(0),
|
||||
query_stride_1=q.stride(1),
|
||||
output_stride_0=out.stride(0),
|
||||
output_stride_1=out.stride(1),
|
||||
BLOCK_SIZE=block_size,
|
||||
HEAD_SIZE=head_size,
|
||||
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
||||
USE_ALIBI_SLOPES=use_alibi_slopes,
|
||||
USE_SOFTCAP=(softcap > 0),
|
||||
USE_SINKS=(sinks is not None),
|
||||
SLIDING_WINDOW=(1 + window_size[0]),
|
||||
stride_k_cache_0=k.stride(0),
|
||||
stride_k_cache_1=k.stride(1),
|
||||
stride_k_cache_2=k.stride(2),
|
||||
stride_k_cache_3=k.stride(3),
|
||||
stride_v_cache_0=v.stride(0),
|
||||
stride_v_cache_1=v.stride(1),
|
||||
stride_v_cache_2=v.stride(2),
|
||||
stride_v_cache_3=v.stride(3),
|
||||
query_start_len_ptr=cu_seqlens_q,
|
||||
BLOCK_Q=BLOCK_Q,
|
||||
num_seqs=num_seqs,
|
||||
BLOCK_M=BLOCK_M,
|
||||
)
|
||||
Reference in New Issue
Block a user