This commit is contained in:
root
2026-04-09 11:19:36 +08:00
parent 809cecae09
commit 8082d5f4b2
2579 changed files with 3675 additions and 0 deletions

View File

View File

@@ -0,0 +1,401 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Authors:
# - Burkhard Ringlein <ngl@zurich.ibm.com>
# - Jan van Lunteren <jvl@zurich.ibm.com>
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
# - Thomas Parnell <tpa@zurich.ibm.com>
import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from .prefix_prefill import context_attention_fwd
float8_info = torch.finfo(current_platform.fp8_dtype())
@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y
@triton.jit
def kernel_paged_attention_2d(
output_ptr, # [num_tokens, num_query_heads, head_size]
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
sink_ptr, # [num_query_heads]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
scale, # float32
k_scale, # float32
v_scale, # float32
out_scale_inv,
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
num_queries_per_kv_padded: tl.constexpr, # int
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
x: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.int64, # int
stride_k_cache_4: tl.int64, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.int64, # int
filter_by_query_len: tl.constexpr, # bool
query_start_len_ptr, # [num_seqs+1]
USE_SINKS: tl.constexpr, # bool
USE_FP8: tl.constexpr,
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
):
seq_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
if filter_by_query_len:
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
if cur_batch_query_len > 1:
return
else:
cur_batch_in_all_start_index = seq_idx
query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
0, num_queries_per_kv_padded
)
query_offset = (
cur_batch_in_all_start_index * query_stride_0
+ query_head_idx[:, None] * query_stride_1
)
head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
head_mask = head_mask & (query_head_idx < num_query_heads)
dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1)
# Q : (num_queries_per_kv, HEAD_SIZE,)
Q = tl.load(
query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
mask=dim_mask[None, :] & head_mask[:, None],
other=0.0,
)
block_table_offset = seq_idx * block_table_stride
if not USE_SINKS:
M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
else:
M = tl.load(
sink_ptr + query_head_idx,
mask=head_mask,
other=float("-inf"),
).to(dtype=tl.float32)
L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32)
# sequence len for this particular sequence
seq_len = tl.load(seq_lens_ptr + seq_idx)
# alibi slope for this head
if USE_ALIBI_SLOPES:
alibi_slope = tl.load(
alibi_slopes_ptr + query_head_idx, mask=head_mask, other=0.0
)
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
# iterate through tiles
for j in range(0, num_blocks):
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
offs_n = tl.arange(0, BLOCK_SIZE)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
v_offset = (
physical_block_idx * stride_v_cache_0
+ kv_head_idx * stride_v_cache_1
+ offs_d[None, :] * stride_v_cache_2
+ offs_n[:, None] * stride_v_cache_3
)
k_offset = (
physical_block_idx * stride_k_cache_0
+ kv_head_idx * stride_k_cache_1
+ (offs_d[:, None] // x) * stride_k_cache_2
+ offs_n[None, :] * stride_k_cache_3
+ (offs_d[:, None] % x) * stride_k_cache_4
)
# K : (HEAD_SIZE, BLOCK_SIZE)
K_load = tl.load(key_cache_ptr + k_offset, mask=dim_mask[:, None], other=0.0)
if K_load.dtype.is_fp8():
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
else:
K = K_load
# V : (BLOCK_SIZE, HEAD_SIZE)
V_load = tl.load(value_cache_ptr + v_offset, mask=dim_mask[None, :], other=0.0)
if V_load.dtype.is_fp8():
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
else:
V = V_load
seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
seq_mask = seq_offset[None, :] < boundary
# S : (num_queries_per_kv, BLOCK_SIZE,)
S = tl.where(head_mask[:, None] & seq_mask, 0.0, float("-inf")).to(tl.float32)
S += scale * tl.dot(Q, K)
context_len = seq_len - 1
if SLIDING_WINDOW > 0:
S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, -10000)
if USE_ALIBI_SLOPES:
S += alibi_slope[:, None] * (seq_offset - context_len)
# compute running maximum
# m_j : (num_queries_per_kv,)
m_j = tl.maximum(M, tl.max(S, axis=1))
# P : (num_queries_per_kv, BLOCK_SIZE,)
P = tl.exp(S - m_j[:, None])
# l_j : (num_queries_per_kv,)
l_j = tl.sum(P, axis=1)
# alpha : (num_queries_per_kv, )
alpha = tl.exp(M - m_j)
# acc : (num_queries_per_kv, BLOCK_SIZE,)
acc = acc * alpha[:, None]
# update constants
L = L * alpha + l_j
M = m_j
# acc : (num_queries_per_kv, BLOCK_SIZE,)
acc += tl.dot(P.to(V.dtype), V)
# epilogue
acc = acc / L[:, None]
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
output_offset = (
cur_batch_in_all_start_index * output_stride_0
+ query_head_idx * output_stride_1
)
tl.store(
output_ptr + output_offset[:, None] + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
acc,
mask=dim_mask[None, :] & head_mask[:, None],
)
def chunked_prefill_paged_decode(
query,
key,
value,
output,
kv_cache_dtype,
key_cache,
value_cache,
block_table,
query_start_loc,
seq_lens,
max_seq_len,
max_query_len,
k_scale,
v_scale,
alibi_slopes=None,
sliding_window=None,
sm_scale=None,
output_scale=None,
# Optional tensor for sinks
sinks=None,
):
if sm_scale is None:
sm_scale = 1.0 / (query.shape[1] ** 0.5)
use_alibi_slopes = alibi_slopes is not None
if sliding_window is None or sliding_window <= 0:
sliding_window = 0
if max_query_len > 1:
context_attention_fwd(
q=query,
k=key,
v=value,
o=output,
kv_cache_dtype=kv_cache_dtype,
k_cache=key_cache,
v_cache=value_cache,
b_loc=block_table,
b_start_loc=query_start_loc,
b_seq_len=seq_lens,
max_seq_len=max_seq_len,
max_input_len=max_query_len,
k_scale=k_scale,
v_scale=v_scale,
alibi_slopes=alibi_slopes,
sliding_window=sliding_window,
sm_scale=sm_scale,
skip_decode=True,
fp8_out_scale=output_scale,
sinks=sinks,
)
block_size = value_cache.shape[3]
num_seqs = len(seq_lens)
num_query_heads = query.shape[1]
num_kv_heads = key.shape[1]
num_queries_per_kv = query.shape[1] // key.shape[1]
head_size = query.shape[2]
# Conversion of FP8 Tensor from uint8 storage to
# appropriate torch.dtype for interpretation by Triton
if "fp8" in kv_cache_dtype:
assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
target_dtype = current_platform.fp8_dtype()
elif kv_cache_dtype == "fp8_e5m2":
target_dtype = torch.float8_e5m2
else:
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
key_cache = key_cache.view(target_dtype)
value_cache = value_cache.view(target_dtype)
num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), 16)
from vllm.platforms.rocm import use_rocm_custom_paged_attention
use_custom = use_rocm_custom_paged_attention(
query.dtype,
head_size,
block_size,
num_queries_per_kv,
max_seq_len,
sliding_window,
kv_cache_dtype,
alibi_slopes,
sinks,
)
if use_custom:
_PARTITION_SIZE_ROCM = 256
max_num_partitions = (
max_seq_len + _PARTITION_SIZE_ROCM - 1
) // _PARTITION_SIZE_ROCM
assert _PARTITION_SIZE_ROCM % block_size == 0
total_num_seq = block_table.shape[0]
tmp_output = torch.empty(
size=(total_num_seq, num_query_heads, max_num_partitions, head_size),
dtype=query.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(total_num_seq, num_query_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_rocm(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale=sm_scale,
block_tables=block_table,
seq_lens=seq_lens,
query_start_loc=query_start_loc,
block_size=block_size,
max_seq_len=max_seq_len,
alibi_slopes=alibi_slopes,
kv_cache_dtype=kv_cache_dtype,
k_scale=k_scale,
v_scale=v_scale,
fp8_out_scale=output_scale,
)
else:
kernel_paged_attention_2d[
(
num_seqs,
num_kv_heads,
)
](
output_ptr=output,
query_ptr=query,
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seq_lens,
alibi_slopes_ptr=alibi_slopes,
scale=sm_scale,
k_scale=k_scale,
v_scale=v_scale,
out_scale_inv=1.0 / output_scale if output_scale is not None else 1.0,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
num_queries_per_kv_padded=num_queries_per_kv_padded,
block_table_stride=block_table.stride(0),
query_stride_0=query.stride(0),
query_stride_1=query.stride(1),
output_stride_0=output.stride(0),
output_stride_1=output.stride(1),
BLOCK_SIZE=block_size,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
SLIDING_WINDOW=sliding_window,
x=key_cache.shape[4],
stride_k_cache_0=key_cache.stride(0),
stride_k_cache_1=key_cache.stride(1),
stride_k_cache_2=key_cache.stride(2),
stride_k_cache_3=key_cache.stride(3),
stride_k_cache_4=key_cache.stride(4),
stride_v_cache_0=value_cache.stride(0),
stride_v_cache_1=value_cache.stride(1),
stride_v_cache_2=value_cache.stride(2),
stride_v_cache_3=value_cache.stride(3),
filter_by_query_len=True,
query_start_len_ptr=query_start_loc,
USE_SINKS=sinks is not None,
USE_FP8=output_scale is not None,
)

View File

@@ -0,0 +1,414 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.triton_utils import tl, triton
@triton.jit
def _correct_attn_cp_out_kernel(
outputs_ptr,
new_output_ptr,
lses_ptr,
vlse_ptr,
outputs_stride_B,
outputs_stride_H,
outputs_stride_D,
lses_stride_N,
lses_stride_B,
lses_stride_H,
lse_idx,
HEAD_DIM: tl.constexpr,
N_ROUNDED: tl.constexpr,
):
"""
Apply the all-gathered lses to correct each local rank's attention
output. we still need perform a cross-rank reduction to obtain the
final attention output.
Args:
outputs_ptr (triton.PointerType):
Pointer to input tensor of shape [ B, H, D ]
lses_ptr (triton.PointerType):
Pointer to input tensor of shape [ N, B, H ]
new_output_ptr (triton.PointerType):
Pointer to output tensor of shape [ B, H, D ]
vlse_ptr (triton.PointerType):
Pointer to output tensor of shape [ B, H ]
"""
batch_idx = tl.program_id(axis=0).to(tl.int64)
head_idx = tl.program_id(axis=1).to(tl.int64)
d_offsets = tl.arange(0, HEAD_DIM)
num_n_offsets = tl.arange(0, N_ROUNDED)
# shape = [N]
lse_offsets = (
num_n_offsets * lses_stride_N
+ batch_idx * lses_stride_B
+ head_idx * lses_stride_H
)
# calc final lse
lse = tl.load(lses_ptr + lse_offsets)
lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
lse_max = tl.max(lse, axis=0)
lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
lse -= lse_max
lse_exp = tl.exp(lse)
lse_acc = tl.sum(lse_exp, axis=0)
lse = tl.log(lse_acc)
lse += lse_max
lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
tl.store(vlse_ptr + lse_offsets, lse)
# shape = [D]
output_offsets = (
batch_idx * outputs_stride_B
+ head_idx * outputs_stride_H
+ d_offsets * outputs_stride_D
)
# correct output
lse_offset = (
lse_idx * lses_stride_N + batch_idx * lses_stride_B + head_idx * lses_stride_H
)
lse_tmp = tl.load(lses_ptr + lse_offset)
lse_finally = lse_tmp - lse
lse_finally = tl.where(
(lse_finally != lse_finally) | (lse_finally == float("inf")),
-float("inf"),
lse_finally,
)
factor = tl.exp(lse_finally)
output = tl.load(outputs_ptr + output_offsets)
output = output * factor
tl.store(new_output_ptr + output_offsets, output)
class CPTritonContext:
"""The CPTritonContext is used to avoid recompilation of the Triton JIT."""
def __init__(self):
self.inner_kernel = None
def call_kernel(self, kernel, grid, *regular_args, **const_args):
if self.inner_kernel is None:
self.inner_kernel = kernel[grid](*regular_args, **const_args)
else:
self.inner_kernel[grid](*regular_args)
def correct_attn_out(
out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext
) -> tuple[torch.Tensor, torch.Tensor]:
"""Correct the attention output using the all-gathered lses.
Args:
out: Tensor of shape [ B, H, D ]
lses: Tensor of shape [ N, B, H ]
cp_rank: Current rank in the context-parallel group
ctx: Triton context to avoid recompilation
Returns:
Tuple of (out, lse) with corrected attention and final log-sum-exp.
"""
if ctx is None:
ctx = CPTritonContext()
# --- Normalize to 3D views ---
if out.ndim == 4 and out.shape[1] == 1:
out = out.squeeze(1)
assert out.ndim == 3, f"expected out [B,H,D] or [B,1,H,D], got {tuple(out.shape)}"
if lses.ndim == 4 and lses.shape[-1] == 1:
lses = lses.squeeze(-1)
if lses.ndim == 4 and lses.shape[1] == 1:
lses = lses.squeeze(1)
assert lses.ndim == 3, (
f"expected lses [N,B,H] (optionally with a 1-sized extra dim), "
f"got {tuple(lses.shape)}"
)
B, H, D = out.shape
N = lses.shape[0]
# Strides after we normalized shapes to 3-D views. The kernel computes
# offsets for `vlse_ptr` using lses_stride_B/H, so the output buffer must
# have the same B/H stride layout as a slice of `lses`.
o_sB, o_sH, o_sD = out.stride()
l_sN, l_sB, l_sH = lses.stride()
# Allocate LSE with the same B/H strides as `lses` so writes land correctly
# even when `lses` is a non-contiguous view (e.g., 4-D to 3-D squeeze).
lse = torch.empty_strided(
(B, H), (l_sB, l_sH), device=lses.device, dtype=lses.dtype
)
# Kernel launch config
grid = (B, H, 1)
regular_args = (
out,
out,
lses,
lse,
o_sB,
o_sH,
o_sD,
l_sN,
l_sB,
l_sH,
cp_rank,
)
const_args = {"HEAD_DIM": D, "N_ROUNDED": N}
ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
return out, lse
def cp_lse_ag_out_rs(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext = None,
return_lse=False,
):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
if cp_group.world_size == 1:
return cp_attn_out
if ctx is None:
ctx = CPTritonContext()
lses = torch.empty(
(cp_group.world_size,) + cp_attn_lse.shape,
dtype=cp_attn_lse.dtype,
device=cp_attn_lse.device,
)
cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
out = cp_group.reduce_scatter(out, dim=1)
if return_lse:
cp_num_heads = lse.shape[1] // cp_group.world_size
cp_rank = cp_group.rank_in_group
lse = lse[:, cp_num_heads * cp_rank : cp_num_heads * (cp_rank + 1)]
return out, lse
return out
@triton.jit
def _pack_seq_kernel(
x_ptr, # [N, D]
out_ptr, # [B, Lmax, D]
lengths_ptr, # *i32, [B]
N: tl.constexpr,
D: tl.constexpr,
Lmax: tl.constexpr,
PAD_VALUE: tl.constexpr,
BLOCK_T: tl.constexpr, # timesteps per program
BLOCK_D: tl.constexpr, # features per program
):
pid_b = tl.program_id(0) # batch id
pid_t = tl.program_id(1) # block over time dimension
pid_d = tl.program_id(2) # block over feature dimension
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
# Compute start index and sequence length from cumulative lengths
in_start = 0
for i in range(pid_b):
in_start += tl.load(lengths_ptr + i)
seq_len = tl.load(lengths_ptr + pid_b)
# valid time positions for this block
t_mask = off_t < Lmax
# compute input row indices for valid (b, t)
in_row = in_start + off_t
valid_row = (off_t < seq_len) & t_mask
# Pointers
# x_ptr: row-major [N, D]
x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]
# out_ptr: row-major [B, Lmax, D]
out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :]
# Initialize with PAD (cast will occur as needed based on out_ptr dtype)
d_mask = off_d[None, :] < D
pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)
# Load & write only where within seq_len
x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask)
tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask)
def pack_seq_triton(
x: torch.Tensor,
lengths: torch.Tensor,
pad_value: float = -float("inf"),
block_t: int = 64,
block_d: int = 64,
) -> torch.Tensor:
"""
Pack sequences of different lengths into a batched tensor.
Args:
x: [N, ...] - input tensor where N is total number of tokens
lengths: [B] - sequence lengths for each batch
pad_value: value to use for padding
block_t: block size for time dimension
block_d: block size for feature dimension
Returns:
packed: [B, Lmax, ...] - packed tensor
"""
# Handle multi-dimensional input by reshaping to (N, -1)
original_shape = x.shape
if len(original_shape) > 2:
N = original_shape[0]
x_reshaped = x.reshape(N, -1)
D = x_reshaped.shape[1]
else:
N, D = x.shape
x_reshaped = x
B = lengths.numel()
Lmax = int(lengths.max().item())
# Starts are computed inside the kernel from lengths
out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
_pack_seq_kernel[grid](
x_reshaped,
out,
lengths.int(),
N,
D,
Lmax,
PAD_VALUE=float(pad_value),
BLOCK_T=block_t,
BLOCK_D=block_d,
num_warps=4,
num_stages=2,
)
# Reshape output back to original dimensions (except first dimension)
if len(original_shape) > 2:
output_shape = (B, Lmax) + original_shape[1:]
out = out.reshape(output_shape)
return out
@triton.jit
def _unpack_seq_triton_kernel(
packed_ptr, # [B, Lmax, D]
out_ptr, # [N, D]
lengths_ptr, # *i32, [B]
B: tl.constexpr,
Lmax: tl.constexpr,
D: tl.constexpr,
BLOCK_T: tl.constexpr, # timesteps per program
BLOCK_D: tl.constexpr, # features per program
):
pid_b = tl.program_id(0) # batch id
pid_t = tl.program_id(1) # block over time dimension
pid_d = tl.program_id(2) # block over feature dimension
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
# bounds: compute start from cumulative lengths
in_start = 0
for i in range(pid_b):
in_start += tl.load(lengths_ptr + i)
seq_len = tl.load(lengths_ptr + pid_b)
# valid time positions for this block
t_mask = off_t < Lmax
valid_row = (off_t < seq_len) & t_mask
# compute output row indices for valid (b, t)
out_row = in_start + off_t
# Pointers
# packed_ptr: row-major [B, Lmax, D]
packed_row_ptr = packed_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :]
# out_ptr: row-major [N, D]
out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]
# Load from packed tensor and store to output
d_mask = off_d[None, :] < D
packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask)
tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask)
def unpack_seq_triton(
packed_tensor: torch.Tensor,
lengths: torch.Tensor,
block_t: int = 64,
block_d: int = 64,
) -> torch.Tensor:
"""
Unpack a packed decode query tensor back to the original format.
Efficient Triton implementation.
Args:
packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton
lengths: [B] - sequence lengths for each batch
block_t: block size for time dimension
block_d: block size for feature dimension
Returns:
unpacked_tensor: [N, ...] where N = sum(lengths)
"""
# Handle multi-dimensional input by reshaping to (B, Lmax, -1)
original_shape = packed_tensor.shape
if len(original_shape) > 3:
B, Lmax = original_shape[:2]
packed_reshaped = packed_tensor.reshape(B, Lmax, -1)
D = packed_reshaped.shape[2]
else:
B, Lmax, D = packed_tensor.shape
packed_reshaped = packed_tensor
# Calculate total number of elements
N = int(lengths.sum().item())
out = torch.empty((N, D), device=packed_tensor.device, dtype=packed_tensor.dtype)
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
_unpack_seq_triton_kernel[grid](
packed_reshaped,
out,
lengths.int(),
B,
Lmax,
D,
BLOCK_T=block_t,
BLOCK_D=block_d,
num_warps=4,
num_stages=2,
)
# Reshape output back to original dimensions (except first dimension)
if len(original_shape) > 3:
output_shape = (N,) + original_shape[2:]
out = out.reshape(output_shape)
return out

View File

@@ -0,0 +1,252 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm import _custom_ops as ops
logger = init_logger(__name__)
if current_platform.is_cuda():
try:
import vllm._flashmla_C # noqa: F401
_flashmla_C_AVAILABLE = True
except ImportError:
_flashmla_C_AVAILABLE = False
else:
_flashmla_C_AVAILABLE = False
if current_platform.is_cuda():
try:
import vllm._flashmla_extension_C # noqa: F401
_flashmla_extension_C_AVAILABLE = True
except ImportError:
_flashmla_extension_C_AVAILABLE = False
else:
_flashmla_extension_C_AVAILABLE = False
def _is_flashmla_available() -> tuple[bool, str | None]:
if not _flashmla_C_AVAILABLE:
return (
False,
"vllm._flashmla_C is not available, likely was not "
"compiled due to insufficient nvcc version or a supported arch "
"was not in the list of target arches to compile for.",
)
if not _flashmla_extension_C_AVAILABLE:
return (
False,
"vllm._flashmla_extension_C is not available, likely "
"was not compiled due to a build error.",
)
return True, None
def is_flashmla_dense_supported() -> tuple[bool, str | None]:
"""
Return: is_supported_flag, unsupported_reason (optional).
"""
is_availble, maybe_reason = _is_flashmla_available()
if not is_availble:
return False, maybe_reason
if current_platform.get_device_capability()[0] != 9:
return False, "FlashMLA Dense is only supported on Hopper devices."
return True, None
def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
"""
Return: is_supported_flag, unsupported_reason (optional).
"""
is_availble, maybe_reason = _is_flashmla_available()
if not is_availble:
return False, maybe_reason
if current_platform.get_device_capability()[0] not in (9, 10):
return (
False,
"FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
)
return True, None
def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_q_tokens_per_head_k: int,
num_heads_k: int,
num_heads_q: int | None = None,
is_fp8_kvcache: bool = False,
topk: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
- cache_seqlens: (batch_size), dtype torch.int32.
- num_q_tokens_per_head_k:
Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
- num_heads_k: The number of k heads.
- num_heads_q:
The number of q heads.
This argument is optional when sparse attention is not enabled
- is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
- topk: If not None, sparse attention will be enabled,
and only tokens in the `indices` array
passed to `flash_mla_with_kvcache_sm90` will be attended to.
Returns:
- tile_scheduler_metadata:
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
- num_splits: (batch_size + 1), dtype torch.int32.
"""
if is_fp8_kvcache and topk is None:
return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
cache_seqlens,
num_q_tokens_per_head_k,
num_heads_k,
)
return torch.ops._flashmla_C.get_mla_decoding_metadata(
cache_seqlens,
num_q_tokens_per_head_k,
num_heads_k,
num_heads_q,
is_fp8_kvcache,
topk,
)
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: float | None = None,
causal: bool = False,
descale_q: torch.Tensor | None = None,
descale_k: torch.Tensor | None = None,
is_fp8_kvcache: bool = False,
indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
- q: (batch_size, seq_len_q, num_heads_q, head_dim).
- k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
- block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
- cache_seqlens: (batch_size), torch.int32.
- head_dim_v: Head dimension of v.
- tile_scheduler_metadata:
(num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
returned by get_mla_metadata.
- num_splits:
(batch_size + 1), torch.int32, returned by get_mla_metadata.
- softmax_scale: float.
The scale of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
- causal: bool. Whether to apply causal attention mask.
- descale_q: (batch_size),
torch.float32. Descaling factors for Q, used for fp8 quantization.
- descale_k: (batch_size),
torch.float32. Descaling factors for K, used for fp8 quantization.
- is_fp8_kvcache: bool.
Whether the k_cache and v_cache are in fp8 format.
For the format of FP8 KV cache, please refer to README.md
- indices: (batch_size, seq_len_q, topk), torch.int32.
If not None, sparse attention will be enabled,
and only tokens in the `indices` array will be attended to.
Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
For details about how to set up `indices`, please refer to README.md.
Returns:
- out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
- softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if indices is not None:
# NOTE (zyongye): sparse attention is also causal
# since it only attend to the tokens before
# but here `causal` should not be specified
assert not causal, "causal must be `false` if sparse attention is enabled."
assert (descale_q is None) == (descale_k is None), (
"descale_q and descale_k should be both None or both not None"
)
if indices is None and q.element_size() == 1:
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
q,
k_cache,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q,
k_cache,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
is_fp8_kvcache,
indices,
)
return out, softmax_lse
def flash_mla_sparse_prefill(
q: torch.Tensor,
kv: torch.Tensor,
indices: torch.Tensor,
sm_scale: float,
d_v: int = 512,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sparse attention prefill kernel
Args:
- q: [s_q, h_q, d_qk], bfloat16
- kv: [s_kv, h_kv, d_qk], bfloat16
- indices: [s_q, h_kv, topk], int32.
Invalid indices should be set to -1 or numbers >= s_kv
- sm_scale: float
- d_v: The dimension of value vectors. Can only be 512
Returns:
- (output, max_logits, lse)
About the definition of output,
max_logits and lse, please refer to README.md
- output: [s_q, h_q, d_v], bfloat16
- max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, 2-based log-sum-exp
"""
results = ops.sparse_prefill_fwd(q, kv, indices,sm_scale, d_v)
return results
#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
# return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
# return ....
#

View File

@@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.platforms import current_platform
def merge_attn_states(
output: torch.Tensor,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output_lse: torch.Tensor | None = None,
) -> None:
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel
# is not support for FP8 dtype, fallback to use Triton kernel.
def supported_dtypes(o: torch.Tensor) -> bool:
return o.dtype in [torch.float32, torch.half, torch.bfloat16]
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA
# kernel load/store 128b(16 bytes) per memory issue within
# thread. Namely, the headsize(headdim) must be multiple of
# pack_size (float32 -> 4, half/bfloat16 -> 8).
def supported_headdim(o: torch.Tensor) -> bool:
headdim = o.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
if o.dtype == torch.float32:
return headdim % 4 == 0
return headdim % 8 == 0
if (
current_platform.is_cuda()
and supported_dtypes(output)
and supported_headdim(output)
):
from vllm._custom_ops import merge_attn_states
return merge_attn_states(
output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
)
else:
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
return merge_attn_states(
output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
)

View File

@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
if HAS_TRITON:
from vllm.attention.ops.prefix_prefill import context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
@dataclass
class PagedAttentionMetadata:
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: torch.Tensor | None
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
max_decode_seq_len: int
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: torch.Tensor | None
class PagedAttention:
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [32, 64, 80, 96, 112, 120, 128, 192, 256]
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size)
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
@staticmethod
def forward_decode(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: torch.Tensor | None,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> torch.Tensor:
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
# use blocksparse paged attention
block_size = value_cache.size(-1)
assert (
blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0
), (
f"{blocksparse_block_size=} needs to be a multiple of"
f"{block_size=} used in block_tables."
)
output = torch.empty_like(query)
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = max_seq_len <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512
)
if use_v1:
# Run PagedAttention V1.
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
return output
@staticmethod
def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache_dtype: str,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens_tensor: torch.Tensor,
max_query_len: int,
alibi_slopes: torch.Tensor | None,
sliding_window: int | None,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> torch.Tensor:
output = torch.empty_like(query)
max_seq_len = None
context_attention_fwd(
query,
key,
value,
output,
kv_cache_dtype,
key_cache,
value_cache,
block_tables,
# query_start_loc is (batch_size + 1,)
query_start_loc,
seq_lens_tensor,
max_seq_len,
max_query_len,
k_scale,
v_scale,
alibi_slopes,
sliding_window,
)
return output
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: list[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)

View File

@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from vllm.utils.math_utils import cdiv
def _kv_cache_update_kernel(
# Prefetch
slices_ref, # [3, padded_num_slices], list of (kv_cache_start,
# new_kv_start, slice_len)
num_slices_ref, # [1]
# Input
new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
# head_dim]
# Output
_, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
# Scratch
scratch, # [num_slices_per_block, page_size, num_combined_kv_heads,
# head_dim]
sem,
):
async_copies = []
block_idx = pl.program_id(0)
num_slices_per_block = scratch.shape[0]
# Copy from new_kv_hbm_ref to scratch
for i in range(num_slices_per_block):
offset_i = i + block_idx * num_slices_per_block
new_kv_start = jax.lax.select(
offset_i < num_slices_ref[0], slices_ref[1, offset_i], 0
)
length = jax.lax.select(
offset_i < num_slices_ref[0], slices_ref[2, offset_i], 0
)
async_copy = pltpu.make_async_copy(
new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
scratch.at[i, pl.ds(0, length), ...],
sem,
)
async_copy.start()
async_copies.append(async_copy)
for async_copy in async_copies:
async_copy.wait()
# Copy from scratch to kv_cache_hbm_ref
async_copies.clear()
for i in range(num_slices_per_block):
offset_i = i + block_idx * num_slices_per_block
kv_cache_start = jax.lax.select(
offset_i < num_slices_ref[0], slices_ref[0, offset_i], 0
)
length = jax.lax.select(
offset_i < num_slices_ref[0], slices_ref[2, offset_i], 0
)
async_copy = pltpu.make_async_copy(
scratch.at[i, pl.ds(0, length), ...],
kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
sem,
)
async_copy.start()
async_copies.append(async_copy)
for async_copy in async_copies:
async_copy.wait()
@functools.partial(
jax.jit,
static_argnames=["page_size", "num_slices_per_block"],
)
def kv_cache_update(
# [total_num_token, num_combined_kv_heads, head_dim]
new_kv: jax.Array,
# [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
slices: jax.Array,
# [total_num_pages * page_size, num_combined_kv_heads, head_dim]
kv_cache: jax.Array,
# [1]
num_kv_update_slices: jax.Array,
*,
page_size: int = 32,
num_slices_per_block: int = 8,
):
_, num_combined_kv_heads, head_dim = new_kv.shape
assert kv_cache.shape[1] == num_combined_kv_heads
assert kv_cache.shape[2] == head_dim
assert head_dim % 128 == 0
# TODO: Add dynamic check to make sure that the all the slice lengths are
# smaller or equal to page_size
in_specs = [
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
]
out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]
scalar_prefetches = [slices, num_kv_update_slices]
scratch = pltpu.VMEM(
(num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
new_kv.dtype,
)
scratch_shapes = [
scratch,
pltpu.SemaphoreType.DMA,
]
kernel = pl.pallas_call(
_kv_cache_update_kernel,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=len(scalar_prefetches),
in_specs=in_specs,
out_specs=out_specs,
grid=(cdiv(num_kv_update_slices[0], num_slices_per_block),),
scratch_shapes=scratch_shapes,
),
out_shape=out_shape,
input_output_aliases={len(scalar_prefetches) + 1: 0},
)
return kernel(*scalar_prefetches, new_kv, kv_cache)[0]

View File

@@ -0,0 +1,814 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
import torch
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
# Static kernels parameters
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
NUM_WARPS = 4 if current_platform.is_rocm() else 8
# To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5)
float8_info = torch.finfo(current_platform.fp8_dtype())
# Here's an example autotuner config for this kernel. This config does provide
# a performance improvement, but dramatically increases first call latency in
# triton 3.2. Because of this tradeoff, it's currently commented out.
# @triton.autotune(
# configs=[
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \
# "num_unroll_cache": 4, \
# "num_unroll_request": 1 } | \
# ({"kpack": 2, "waves_per_eu": 2} \
# if current_platform.is_rocm() else {}), \
# num_warps=4, \
# num_stages=1)
# ],
# key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"]
# )
@triton.jit
def _fwd_kernel(
Q,
K,
V,
K_cache,
V_cache,
sink_ptr,
B_Loc,
sm_scale,
k_scale,
v_scale,
out_scale_inv,
B_Start_Loc,
B_Seqlen,
x: tl.constexpr,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl: tl.constexpr,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: tl.constexpr,
IN_PRECISION: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DMODEL_PADDED: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_N: tl.constexpr,
SLIDING_WINDOW: tl.constexpr,
num_unroll_cache: tl.constexpr,
num_unroll_request: tl.constexpr,
SKIP_DECODE: tl.constexpr,
USE_SINKS: tl.constexpr,
USE_FP8: tl.constexpr,
MAX_Q_LEN: tl.constexpr = 0,
MAX_CTX_LEN: tl.constexpr = 0,
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
if SKIP_DECODE and cur_batch_query_len == 1:
return
# start position inside of the query
# generally, N goes over kv, while M goes over query_len
block_start_loc = BLOCK_M * start_m
# initialize offsets
# [BLOCK_SIZE]; starts at 0
offs_bs_n = tl.arange(0, BLOCK_SIZE)
# [N]; starts at 0
offs_n = tl.arange(0, BLOCK_N)
# [D]; starts at 0
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
# [M]; starts at current position in query
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# [M,D]
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :] * stride_qd
)
dim_mask = tl.where(tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(
tl.int1
) # [D]
q = tl.load(
Q + off_q,
mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len),
other=0.0,
) # [M,D]
# initialize pointer to m and l
if not USE_SINKS:
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
else:
m_i = tl.load(
sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
mask=(offs_m < cur_batch_query_len),
other=float("-inf"),
).to(dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D]
# compute query against context (no causal mask here)
for start_n in tl.range(
0, cur_batch_ctx_len, BLOCK_SIZE, loop_unroll_factor=num_unroll_cache
):
start_n = tl.multiple_of(start_n, BLOCK_SIZE)
# -- compute qk ----
bn = tl.load(
B_Loc
+ cur_batch * stride_b_loc_b
+ (start_n // BLOCK_SIZE) * stride_b_loc_s
).to(tl.int64)
# [D,BLOCK_SIZE]
off_k = (
bn[None, :] * stride_k_cache_bs
+ cur_kv_head * stride_k_cache_h
+ (offs_d[:, None] // x) * stride_k_cache_d
+ ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl
+ (offs_d[:, None] % x) * stride_k_cache_x
)
# [BLOCK_SIZE,D]
off_v = (
bn[:, None] * stride_v_cache_bs
+ cur_kv_head * stride_v_cache_h
+ offs_d[None, :] * stride_v_cache_d
+ offs_bs_n[:, None] * stride_v_cache_bl
)
if (
start_n + BLOCK_SIZE > cur_batch_ctx_len
or BLOCK_DMODEL != BLOCK_DMODEL_PADDED
):
k_load = tl.load(
K_cache + off_k,
mask=dim_mask[:, None]
& ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len),
other=0.0,
) # [D,N]
else:
k_load = tl.load(K_cache + off_k)
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load
qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N]
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk = tl.where(
(start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")
)
qk *= sm_scale
if SLIDING_WINDOW > 0:
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
# Q entries in sequence
# (start_n + offs_bs_n[None, :]) are the positions of
# KV entries in sequence
# So the condition makes sure each entry in Q only attends
# to KV entries not more than SLIDING_WINDOW away.
#
# We can't use -inf here, because the
# sliding window may lead to the entire row being masked.
# This then makes m_ij contain -inf, which causes NaNs in
# exp().
qk = tl.where(
(cur_batch_ctx_len + offs_m[:, None]) - (start_n + offs_bs_n[None, :])
< SLIDING_WINDOW,
qk,
-10000,
)
# compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, axis=1)
alpha = tl.exp(m_i - m_ij)
acc = acc * alpha[:, None]
# update acc
if (
start_n + BLOCK_SIZE > cur_batch_ctx_len
or BLOCK_DMODEL != BLOCK_DMODEL_PADDED
):
v_load = tl.load(
V_cache + off_v,
mask=dim_mask[None, :]
& ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len),
other=0.0,
) # [N,D]
else:
v_load = tl.load(V_cache + off_v)
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# # update m_i and l_i
l_i = l_i * alpha + l_ij
m_i = m_ij
off_k = (
offs_n[None, :] * stride_kbs
+ cur_kv_head * stride_kh
+ offs_d[:, None] * stride_kd
)
off_v = (
offs_n[:, None] * stride_vbs
+ cur_kv_head * stride_vh
+ offs_d[None, :] * stride_vd
)
k_ptrs = K + off_k
v_ptrs = V + off_v
# block_mask is 0 when we're already past the current query length
block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
# compute query against itself (with causal mask)
for start_n in tl.range(
0,
block_mask * (start_m + 1) * BLOCK_M,
BLOCK_N,
loop_unroll_factor=num_unroll_request,
):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=dim_mask[:, None]
& ((start_n + offs_n[None, :]) < cur_batch_query_len),
other=0.0,
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk *= sm_scale
# apply causal mask
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
if SLIDING_WINDOW > 0:
qk = tl.where(
offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
qk,
-10000,
)
# compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, axis=1)
alpha = tl.exp(m_i - m_ij)
acc = acc * alpha[:, None]
# update acc
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=dim_mask[None, :]
& ((start_n + offs_n[:, None]) < cur_batch_query_len),
other=0.0,
)
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# update m_i and l_i
l_i = l_i * alpha + l_ij
m_i = m_ij
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
tl.store(
out_ptrs, acc, mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)
)
return
@triton.jit
def _fwd_kernel_alibi(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
k_scale,
v_scale,
B_Start_Loc,
B_Seqlen,
Alibi_slopes,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
IN_PRECISION: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, # head size
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
BLOCK_N: tl.constexpr,
SKIP_DECODE: tl.constexpr,
):
# attn_bias[]
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
# cur_batch_seq_len: the length of prompts
# cur_batch_ctx_len: the length of prefix
# cur_batch_in_all_start_index: the start id of the dim=0
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
if SKIP_DECODE and cur_batch_query_len == 1:
return
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :] * stride_qd
)
dim_mask = tl.where(tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(
tl.int1
)
q = tl.load(
Q + off_q,
mask=dim_mask[None, :]
& (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
other=0.0,
)
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = 0
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(
B_Loc
+ cur_batch * stride_b_loc_b
+ ((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0,
).to(tl.int64)
off_k = (
bn[None, :] * stride_k_cache_bs
+ cur_kv_head * stride_k_cache_h
+ (offs_d[:, None] // x) * stride_k_cache_d
+ ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl
+ (offs_d[:, None] % x) * stride_k_cache_x
)
off_v = (
bn[:, None] * stride_v_cache_bs
+ cur_kv_head * stride_v_cache_h
+ offs_d[None, :] * stride_v_cache_d
+ (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl
)
k_load = tl.load(
K_cache + off_k,
mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
other=0.0,
) # [D,N]
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk = tl.where(
(start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")
)
qk *= sm_scale
# load alibi
alibi = (
tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - alibi_start_q[:, None]
) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
alibi,
float("-inf"),
)
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v_load = tl.load(
V_cache + off_v,
mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0,
)
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision="ieee")
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (
offs_n[None, :] * stride_kbs
+ cur_kv_head * stride_kh
+ offs_d[:, None] * stride_kd
)
off_v = (
offs_n[:, None] * stride_vbs
+ cur_kv_head * stride_vh
+ offs_d[None, :] * stride_vd
)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
# init alibi
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = cur_batch_ctx_len
# # init debugger
# offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
# offset_db_k = tl.arange(0, BLOCK_N)
# calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=dim_mask[:, None]
& ((start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len),
other=0.0,
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision="ieee")
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# load alibi
alibi = (
tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - alibi_start_q[:, None]
) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
alibi,
float("-inf"),
)
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=dim_mask[None, :]
& ((start_n + offs_n[:, None]) < cur_batch_seq_len - cur_batch_ctx_len),
other=0.0,
)
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision="ieee")
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
tl.store(
out_ptrs,
acc,
mask=dim_mask[None, :]
& (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
)
return
@torch.inference_mode()
def context_attention_fwd(
q,
k,
v,
o,
kv_cache_dtype: str,
k_cache,
v_cache,
b_loc,
b_start_loc,
b_seq_len,
max_seq_len,
max_input_len,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
alibi_slopes=None,
sliding_window=None,
sm_scale=None,
skip_decode=False,
fp8_out_scale=None,
sinks=None,
):
q_dtype_is_f32 = q.dtype is torch.float32
# Turing does have tensor core for float32 multiplication
# use ieee as fallback for triton kernels work. There is also
# warning on vllm/config.py to inform users this fallback
# implementation
IN_PRECISION = "ieee" if IS_TURING and q_dtype_is_f32 else None
# Conversion of FP8 Tensor from uint8 storage to
# appropriate torch.dtype for interpretation by Triton
if "fp8" in kv_cache_dtype:
assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
target_dtype = current_platform.fp8_dtype()
elif kv_cache_dtype == "fp8_e5m2":
target_dtype = torch.float8_e5m2
else:
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
k_cache = k_cache.view(target_dtype)
v_cache = v_cache.view(target_dtype)
if (
k_cache.dtype == torch.uint8
or v_cache.dtype == torch.uint8
and kv_cache_dtype == "auto"
):
raise ValueError(
"kv_cache_dtype='auto' unsupported for\
FP8 KV Cache prefill kernel"
)
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded = triton.next_power_of_2(Lk)
if sm_scale is None:
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
num_queries_per_kv = q.shape[1] // k.shape[1]
assert batch + 1 == len(b_start_loc)
# 0 means "disable"
if sliding_window is None or sliding_window <= 0:
sliding_window = 0
if alibi_slopes is not None:
assert sinks is None, "Sinks arg is not supported with alibi"
assert fp8_out_scale is None, "FP8 output not supported with alibi"
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
# if q.dtype is torch.float32:
BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
# batch, head,
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
_fwd_kernel_alibi[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
k_scale,
v_scale,
b_start_loc,
b_seq_len,
alibi_slopes,
v_cache.shape[3],
k_cache.shape[4],
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(4), # [num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(3), # [num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv=num_queries_per_kv,
IN_PRECISION=IN_PRECISION,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
SKIP_DECODE=skip_decode,
num_warps=NUM_WARPS,
num_stages=1,
)
return
max_seq_len = 0 if max_seq_len is None else max_seq_len
extra_kargs = {}
if current_platform.is_rocm():
extra_kargs = {"kpack": 1, "waves_per_eu": 2}
grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"]))
_fwd_kernel[grid](
q,
k,
v,
k_cache,
v_cache,
sinks,
b_loc,
sm_scale,
k_scale,
v_scale,
1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0,
b_start_loc,
b_seq_len,
k_cache.shape[4],
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(4), # [num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(3), # [num_blocks, num_kv_heads, head_size, block_size]
BLOCK_SIZE=v_cache.shape[3],
num_queries_per_kv=num_queries_per_kv,
IN_PRECISION=IN_PRECISION,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
SLIDING_WINDOW=sliding_window,
SKIP_DECODE=skip_decode,
USE_FP8=fp8_out_scale is not None,
BLOCK_M=128,
BLOCK_N=64,
num_unroll_cache=4,
num_unroll_request=1,
num_warps=4,
num_stages=1,
USE_SINKS=sinks is not None,
**extra_kargs,
)
return

View File

@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import aiter as rocm_aiter
import torch
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
FP8_DTYPE = current_platform.fp8_dtype()
class AITERPagedAttention(PagedAttention):
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)
else:
kv_cache_torch_dtype = FP8_DTYPE if "fp8" in kv_cache_dtype else torch.int8
key_cache = key_cache.view(kv_cache_torch_dtype)
value_cache = value_cache.view(kv_cache_torch_dtype)
rocm_aiter.reshape_and_cache_with_pertoken_quant(
key,
value,
key_cache,
value_cache,
k_scale,
v_scale,
slot_mapping.flatten(),
True,
)
@staticmethod
def forward_decode(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: torch.Tensor | None,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> torch.Tensor:
if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
return PagedAttention.forward_decode(
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_tables=block_tables,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
kv_cache_dtype=kv_cache_dtype,
num_kv_heads=num_kv_heads,
scale=scale,
alibi_slopes=alibi_slopes,
k_scale=k_scale,
v_scale=v_scale,
tp_rank=tp_rank,
blocksparse_local_blocks=blocksparse_local_blocks,
blocksparse_vert_stride=blocksparse_vert_stride,
blocksparse_block_size=blocksparse_block_size,
blocksparse_head_sliding_step=blocksparse_head_sliding_step,
)
if "fp8" in kv_cache_dtype:
key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype())
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
# use blocksparse paged attention
block_size = value_cache.size(-1)
assert (
blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0
), (
f"{blocksparse_block_size=} needs to be a multiple of"
f"{block_size=} used in block_tables."
)
output = torch.empty_like(query)
block_size = value_cache.shape[3]
max_num_blocks_per_seq = cdiv(max_seq_len, block_size)
rocm_aiter.pa_fwd_asm(
query,
key_cache,
value_cache,
block_tables,
seq_lens,
max_num_blocks_per_seq,
k_scale,
v_scale,
output,
)
return output

View File

@@ -0,0 +1,712 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
# Changes:
# - Add support for page size >= 1.
# Copyright 2025 vLLM Team
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Memory-efficient attention for decoding.
It supports page size >= 1.
"""
import logging
from packaging import version
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
is_hip_ = current_platform.is_rocm()
logger = logging.getLogger(__name__)
# Only print the following warnings when triton version < 3.2.0.
# The issue won't affect performance or accuracy.
if version.parse(triton.__version__) < version.parse("3.2.0"):
logger.warning(
"The following error message 'operation scheduled before its operands' "
"can be ignored."
)
@triton.jit
def tanh(x):
# Tanh is just a scaled sigmoid
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def _fwd_kernel_stage1(
Q,
K_Buffer,
V_Buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
Att_Out,
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
stride_buf_vbs,
stride_buf_vh,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_N: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
split_kv_id = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lk
mask_dv = offs_dv < Lv
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_req_idx = cur_batch
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
q = tl.load(Q + off_q, mask=mask_d, other=0.0)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
e_max = -float("inf")
e_sum = 0.0
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens
+ stride_req_to_tokens_b * cur_batch_req_idx
+ offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end,
other=0,
)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_buf_k = (
kv_loc[:, None] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[None, :]
)
k = tl.load(
K_Buffer + offs_buf_k,
mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
other=0.0,
)
qk = tl.sum(q[None, :] * k, 1)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
offs_buf_v = (
kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Buffer + offs_buf_v,
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
n_e_max = tl.maximum(tl.max(qk, 0), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max)
acc *= re_scale
acc += tl.sum(p[:, None] * v, 0)
e_sum = e_sum * re_scale + tl.sum(p, 0)
e_max = n_e_max
offs_mid_o = (
cur_batch * stride_mid_ob
+ cur_head * stride_mid_oh
+ split_kv_id * stride_mid_os
+ offs_dv
)
tl.store(
Att_Out + offs_mid_o,
acc / e_sum,
mask=(mask_dv),
)
offs_mid_o_1 = (
cur_batch * stride_mid_ob
+ cur_head * stride_mid_oh
+ split_kv_id * stride_mid_os
+ Lv
)
tl.store(
Att_Out + offs_mid_o_1,
e_max + tl.log(e_sum),
)
def _decode_att_m_fwd(
q,
k_buffer,
v_buffer,
att_out,
Req_to_tokens,
B_Seqlen,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
):
BLOCK = 64 if not is_hip_ else 8
NUM_KV_SPLITS = num_kv_splits
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
batch, head_num = q.shape[0], q.shape[1]
grid = (batch, head_num, NUM_KV_SPLITS)
kv_group_num = q.shape[1] // k_buffer.shape[-2]
num_warps = 4
if kv_group_num != 1:
num_warps = 1 if is_hip_ else 2
BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DV = triton.next_power_of_2(Lv)
_fwd_kernel_stage1[grid](
q,
k_buffer,
v_buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
kv_group_num=kv_group_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DV=BLOCK_DV,
BLOCK_N=BLOCK,
NUM_KV_SPLITS=NUM_KV_SPLITS,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
num_warps=num_warps,
num_stages=2,
Lk=Lk,
Lv=Lv,
)
@triton.jit
def _fwd_grouped_kernel_stage1(
Q,
K_Buffer,
V_Buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
Att_Out,
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
stride_buf_vbs,
stride_buf_vh,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_H: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
split_kv_id = tl.program_id(2)
if kv_group_num > BLOCK_H:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lk
mask_dv = offs_dv < Lv
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_req_idx = cur_batch
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens
+ stride_req_to_tokens_b * cur_batch_req_idx
+ offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end,
other=0,
)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_buf_k = (
kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[:, None]
)
k = tl.load(
K_Buffer + offs_buf_k,
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
other=0.0,
)
qk = tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0:
offs_buf_kpe = (
kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
other=0.0,
)
qk += tl.dot(qpe, kpe.to(qpe.dtype))
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
qk = tl.where(
mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
)
offs_buf_v = (
kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Buffer + offs_buf_v,
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *= re_scale[:, None]
acc += tl.dot(p.to(v.dtype), v)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_mid_o = (
cur_batch * stride_mid_ob
+ cur_head[:, None] * stride_mid_oh
+ split_kv_id * stride_mid_os
+ offs_dv[None, :]
)
tl.store(
Att_Out + offs_mid_o,
acc / e_sum[:, None],
mask=(mask_h[:, None]) & (mask_dv[None, :]),
)
offs_mid_o_1 = (
cur_batch * stride_mid_ob
+ cur_head * stride_mid_oh
+ split_kv_id * stride_mid_os
+ Lv
)
tl.store(
Att_Out + offs_mid_o_1,
e_max + tl.log(e_sum),
mask=mask_h,
)
def _decode_grouped_att_m_fwd(
q,
k_buffer,
v_buffer,
att_out,
Req_to_tokens,
B_Seqlen,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
):
BLOCK = 32
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
# [TODO] work around shmem limit on MI3xx
if is_hip_ and Lk >= 576:
BLOCK = 16
if Lk == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
elif Lk == 288:
BLOCK_DMODEL = 256
BLOCK_DPE = 32
else:
BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DPE = 0
BLOCK_DV = triton.next_power_of_2(Lv)
batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2]
BLOCK_H = 16
NUM_KV_SPLITS = num_kv_splits
grid = (
batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
NUM_KV_SPLITS,
)
extra_kargs = {}
num_stages = 2
if is_hip_:
# https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
num_stages = 1
_fwd_grouped_kernel_stage1[grid](
q,
k_buffer,
v_buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV,
BLOCK_N=BLOCK,
BLOCK_H=BLOCK_H,
NUM_KV_SPLITS=NUM_KV_SPLITS,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
num_warps=4,
num_stages=num_stages,
Lk=Lk,
Lv=Lv,
**extra_kargs,
)
@triton.jit
def _fwd_kernel_stage2(
Mid_O,
o,
lse,
B_Seqlen,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_obs,
stride_oh,
stride_lse_bs,
NUM_KV_SPLITS: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
offs_d = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lv
e_sum = 0.0
e_max = -float("inf")
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
for split_kv_id in range(0, NUM_KV_SPLITS):
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
if split_kv_end > split_kv_start:
tv = tl.load(
Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
)
tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
n_e_max = tl.maximum(tlogic, e_max)
old_scale = tl.exp(e_max - n_e_max)
acc *= old_scale
exp_logic = tl.exp(tlogic - n_e_max)
acc += exp_logic * tv
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
tl.store(
o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
acc / e_sum,
mask=mask_d,
)
lse_val = e_max + tl.log(e_sum)
tl.store(
lse + cur_batch * stride_lse_bs + cur_head,
lse_val,
)
def _decode_softmax_reducev_fwd(
logits,
q,
o,
lse,
v_buffer,
b_seq_len,
num_kv_splits,
):
batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1]
BLOCK_DV = triton.next_power_of_2(Lv)
NUM_KV_SPLITS = num_kv_splits
extra_kargs = {}
if is_hip_:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
grid = (batch, head_num)
_fwd_kernel_stage2[grid](
logits,
o,
lse,
b_seq_len,
logits.stride(0),
logits.stride(1),
logits.stride(2),
o.stride(0),
o.stride(1),
lse.stride(0),
NUM_KV_SPLITS=NUM_KV_SPLITS,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
num_warps=4,
num_stages=2,
**extra_kargs,
)
def decode_attention_fwd_normal(
q,
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap=0.0,
):
_decode_att_m_fwd(
q,
k_buffer,
v_buffer,
attn_logits,
req_to_token,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
_decode_softmax_reducev_fwd(
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
)
def decode_attention_fwd_grouped(
q,
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap=0.0,
):
_decode_grouped_att_m_fwd(
q,
k_buffer,
v_buffer,
attn_logits,
req_to_token,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
_decode_softmax_reducev_fwd(
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
)
def decode_attention_fwd(
q,
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size=1,
logit_cap=0.0,
):
assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2]
if kv_group_num == 1:
# MHA
decode_attention_fwd_normal(
q,
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
else:
# GQA/MQA/MLA
decode_attention_fwd_grouped(
q,
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)

View File

@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
# can be used to combine partial attention results (in the split-KV case)
def merge_attn_states(
output: torch.Tensor,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output_lse: torch.Tensor | None = None,
) -> None:
num_tokens = output.shape[0]
num_query_heads = output.shape[1]
head_size = output.shape[2]
padded_head_size = triton.next_power_of_2(head_size)
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
merge_attn_states_kernel[(num_tokens, num_query_heads)](
output,
output_lse,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
head_size,
padded_head_size,
output_lse is not None,
)
@triton.jit
def merge_attn_states_kernel(
output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
output_lse, # [NUM_HEADS, NUM_TOKENS]
prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
HEAD_SIZE: tl.constexpr,
PADDED_HEAD_SIZE: tl.constexpr,
OUTPUT_LSE: tl.constexpr,
):
token_idx = tl.program_id(0)
num_tokens = tl.num_programs(0)
head_idx = tl.program_id(1)
num_heads = tl.num_programs(1)
p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
# FA2 and FA3 have different behavior for when the sum-exp is 0, this namely
# arises with 0 len seqlens. FA3 returns -inf here while FA2 returns inf.
# If we see an inf assume FA2 and convert inf to -inf for consistency
# and correctness. Inf generally doesn't make sense in this context outside
# of undefined-behavior/FA2-case, so I think this a safe assumption.
p_lse = float("-inf") if p_lse == float("inf") else p_lse
s_lse = float("-inf") if s_lse == float("inf") else s_lse
max_lse = tl.maximum(p_lse, s_lse)
p_lse = p_lse - max_lse
s_lse = s_lse - max_lse
# Will reuse precomputed Exp values for scale factor computation.
p_se = tl.exp(p_lse)
s_se = tl.exp(s_lse)
out_se = p_se + s_se
if OUTPUT_LSE:
out_lse = tl.log(out_se) + max_lse
tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse)
head_arange = tl.arange(0, PADDED_HEAD_SIZE)
head_mask = head_arange < HEAD_SIZE
p_out = tl.load(
prefix_output
+ token_idx * num_heads * HEAD_SIZE
+ head_idx * HEAD_SIZE
+ head_arange,
mask=head_mask,
)
s_out = tl.load(
suffix_output
+ token_idx * num_heads * HEAD_SIZE
+ head_idx * HEAD_SIZE
+ head_arange,
mask=head_mask,
)
# NOTE(woosuk): Be careful with the numerical stability.
# We should compute the scale first, and then multiply it with the output.
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
p_scale = p_se / out_se
s_scale = s_se / out_se
out = p_out * p_scale + s_out * s_scale
tl.store(
output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
out,
mask=head_mask,
)

View File

@@ -0,0 +1,184 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
@triton.jit
def reshape_and_cache_kernel_flash(
key_ptr, # [num_tokens, num_heads, head_size]
value_ptr, # [num_tokens, num_heads, head_size]
key_cache_ptr, # [num_blocks, block_size, num_heads, head_size]
value_cache_ptr, # [num_blocks, block_size, num_heads, head_size]
slot_mapping_ptr, # [num_tokens]
k_scale, # float32
v_scale, # float32
# strides
key_stride: tl.int64,
value_stride: tl.int64,
block_stride: tl.int64,
page_stride: tl.int64,
num_heads: tl.constexpr,
head_size: tl.constexpr,
block_size: tl.constexpr,
# FP8 flags
FP8_KV_CACHE: tl.constexpr,
# tune parameters
TILE_SIZE: tl.constexpr,
):
token_idx = tl.program_id(axis=0)
slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
if slot_idx < 0:
# Padding token that should be ignored.
return
tile_i = tl.program_id(axis=1)
tile_offs = tl.arange(0, TILE_SIZE)
tile_pos = tile_i * TILE_SIZE + tile_offs
block_idx = slot_idx // block_size
block_offset = slot_idx % block_size
src_key_idx = token_idx * key_stride
src_value_idx = token_idx * value_stride
tgt_idx = block_idx * block_stride + block_offset * page_stride
# [TILE_SIZE]
key_load = tl.load(
key_ptr + src_key_idx + tile_pos, mask=tile_pos < (num_heads * head_size)
)
if FP8_KV_CACHE:
# tl.store will do the correct implicit cast to fp8,
# based on the key_cache_ptr.dtype.element_ty
key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale)
else:
key_tile = key_load
# [TILE_SIZE]
value_load = tl.load(
value_ptr + src_value_idx + tile_pos, mask=tile_pos < (num_heads * head_size)
)
if FP8_KV_CACHE:
if value_load.dtype.is_fp8():
value_tile = value_load
else:
# tl.store will do the correct implicit cast to fp8,
# based on the value_cache_ptr.dtype.element_ty
value_tile = value_load / tl.load(v_scale)
else:
value_tile = value_load
tl.store(
key_cache_ptr + tgt_idx + tile_pos,
key_tile,
mask=tile_pos < (num_heads * head_size),
)
tl.store(
value_cache_ptr + tgt_idx + tile_pos,
value_tile,
mask=tile_pos < (num_heads * head_size),
)
return
def triton_reshape_and_cache_flash(
key: torch.Tensor, # [num_tokens, num_heads, head_size]
value: torch.Tensor, # [num_tokens, num_heads, head_size]
# [num_blocks, block_size, num_heads, head_size]
key_cache: torch.Tensor,
# [num_blocks, block_size, num_heads, head_size]
value_cache: torch.Tensor,
slot_mapping: torch.Tensor, # [num_tokens]
kv_cache_dtype: str, # "auto", "fp8"
k_scale: torch.Tensor, # float32
v_scale: torch.Tensor, # float32
):
num_heads = key.shape[1]
head_size = key.shape[2]
block_size = key_cache.shape[1]
n = num_heads * head_size
key_stride = key.stride()[0]
value_stride = value.stride()[0]
block_stride = key_cache.stride()[0]
page_stride = key_cache.stride()[1]
head_stride = key_cache.stride()[2]
assert head_stride == head_size, "only continous heads are supported"
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
)
kv_cache_torch_dtype = (
current_platform.fp8_dtype()
if kv_cache_dtype.startswith("fp8")
else key_cache.dtype
)
if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"):
# to avoid erounous implicit cast in triton kernel (tl.store to uint8)
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
key_cache = key_cache.view(kv_cache_torch_dtype)
value_cache = value_cache.view(kv_cache_torch_dtype)
assert kv_cache_dtype != torch.uint8, (
"explicit fp8 cast and store to "
"uint8 is not supported by triton reshape_and_cache_flash"
)
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.uint8,
torch.float8_e4m3fnuz,
], (
"unsupported dtype of KV cache tensor, got "
"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, "
"fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
)
# heuristics instead of autotuning
TILE_SIZE = min(2048, triton.next_power_of_2(n))
if current_platform.is_rocm() or current_platform.is_xpu():
num_stages = 4
num_warps = 8
else: # cuda
num_stages = 10
num_warps = 16
if torch.cuda.get_device_capability(key.device)[0] < 9:
TILE_SIZE = min(512, TILE_SIZE)
# TODO(ngl): maybe replace with static launch grid to avoid overhead if
# using cudagraphs
grid = lambda meta: (
slot_mapping.shape[0],
triton.cdiv(n, meta["TILE_SIZE"]),
)
reshape_and_cache_kernel_flash[grid](
key_ptr=key,
value_ptr=value,
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
slot_mapping_ptr=slot_mapping,
k_scale=k_scale,
v_scale=v_scale,
# strides
key_stride=key_stride,
value_stride=value_stride,
block_stride=block_stride,
page_stride=page_stride,
num_heads=num_heads,
head_size=head_size,
block_size=block_size,
# FP8 flags
FP8_KV_CACHE=FP8_KV_CACHE,
# autotune parameters
TILE_SIZE=TILE_SIZE,
num_warps=num_warps,
num_stages=num_stages,
)

View File

@@ -0,0 +1,941 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Authors:
# - Burkhard Ringlein <ngl@zurich.ibm.com>
# - Jan van Lunteren <jvl@zurich.ibm.com>
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
# - Thomas Parnell <tpa@zurich.ibm.com>
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
logger = init_logger(__name__)
float8_info = torch.finfo(current_platform.fp8_dtype())
@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y
@triton.jit
def apply_softcap(S, x):
Sdiv = S / x
p1 = tl.exp(Sdiv)
p2 = tl.exp(-Sdiv)
return x * (p1 - p2) / (p1 + p2)
@triton.jit
def find_seq_idx(
query_start_len_ptr,
target_idx,
num_seqs,
BLOCK_Q: tl.constexpr,
use_q_block_mode: tl.constexpr,
):
left: tl.int32 = 0
right = num_seqs
while left < right:
mid = (left + right) // 2
val = tl.load(query_start_len_ptr + mid)
mid_val = val // BLOCK_Q + mid if use_q_block_mode else val
if mid_val <= target_idx:
left = mid + 1
else:
right = mid
return left - 1
@triton.jit
def kernel_unified_attention_2d(
output_ptr, # [num_tokens, num_query_heads, head_size]
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
sink_ptr, # [num_query_heads]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
qq_bias_ptr, # [num_query_tokens, num_query_tokens]
scale, # float32
k_scale, # float32
v_scale, # float32
out_scale, # float32
softcap, # float32
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
qq_bias_stride_0: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
TILE_SIZE: tl.constexpr, # int must be power of 2
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_QQ_BIAS: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool
USE_SINKS: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.constexpr, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.constexpr, # int
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int
USE_FP8: tl.constexpr, # bool
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
):
q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
seq_idx = find_seq_idx(
query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
)
q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx
q_block_local_idx = q_block_global_idx - q_block_start_idx
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
return
offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
offs_t = tl.arange(0, TILE_SIZE)
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
query_offset_0 = cur_batch_in_all_start_index + query_pos
query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
query_offset = (
query_offset_0[:, None] * query_stride_0
+ query_offset_1[:, None] * query_stride_1
+ offs_d[None, :]
)
dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)
# Q : (BLOCK_M, HEAD_SIZE_PADDED)
Q = tl.load(
query_ptr + query_offset,
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
other=0.0,
)
block_table_offset = seq_idx * block_table_stride
if not USE_SINKS:
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
else:
M = tl.load(
sink_ptr + query_offset_1,
mask=query_mask_1,
other=float("-inf"),
).to(dtype=tl.float32)
L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
# sequence len for this particular sequence
seq_len = tl.load(seq_lens_ptr + seq_idx)
# context length for this particular sequences
context_len = seq_len - cur_batch_query_len
# alibi slope for this head
if USE_ALIBI_SLOPES:
alibi_slope = tl.load(
alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0
)
# query-query attention bias
if USE_QQ_BIAS:
qq_bias_row_ptrs = (
qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
) # shape: [BLOCK_M]
# compute the length of the longest sequence prefix spanned by any
# query token in the current q_block (q_block_local_idx)
max_seq_prefix_len = (
context_len
+ q_block_local_idx * BLOCK_Q
+ (BLOCK_M - 1) // num_queries_per_kv
+ 1
)
# adjust for potential padding in the last q_block by considering the
# actual sequence length
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
# calculate the number of tiles that need to be processed to
# cover the longest sequence prefix (due to causal masking, tiles beyond
# this prefix can be skipped)
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
# ---- Sliding-window tile pruning --------------------
# Default: keep previous global behavior
tile_start = 0
tile_end = num_tiles
if SLIDING_WINDOW > 0:
# Query rows covered by this Q-block
qpos_lo = q_block_local_idx * BLOCK_Q
qpos_hi = tl.minimum(
qpos_lo + (BLOCK_M - 1) // num_queries_per_kv,
cur_batch_query_len - 1,
)
# For sliding window, each query position q can only attend to
# keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs]
# where q_abs = context_len + q
# The union of allowed key positions for this Q-block is:
# [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi]
first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
last_allowed_key = context_len + qpos_hi
# Convert to tile indices and clamp
tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE)
tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles)
# iterate through tiles (now limited to the sliding window range)
for j in range(tile_start, tile_end):
seq_offset = j * TILE_SIZE + offs_t
tile_mask = seq_offset < max_seq_prefix_len
physical_block_idx = tl.load(
block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
).to(tl.int64)
v_offset = (
physical_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
+ offs_d[None, :] * stride_v_cache_3
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
)
k_offset = (
physical_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
+ offs_d[:, None] * stride_k_cache_3
+ (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
)
# K : (HEAD_SIZE, TILE_SIZE)
K_load = tl.load(
key_cache_ptr + k_offset,
mask=dim_mask[:, None] & tile_mask[None, :],
other=0.0,
)
if K_load.dtype.is_fp8():
if Q.dtype.is_fp8():
K = K_load
else:
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
else:
K = K_load
# V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load(
value_cache_ptr + v_offset,
mask=dim_mask[None, :] & tile_mask[:, None],
other=0.0,
)
if V_load.dtype.is_fp8():
if Q.dtype.is_fp8():
V = V_load
else:
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
else:
V = V_load
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
# S : (BLOCK_M, TILE_SIZE)
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
S += scale * tl.dot(Q, K)
if USE_SOFTCAP:
S = apply_softcap(S, softcap)
S = tl.where(
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
)
if SLIDING_WINDOW > 0:
S = tl.where(
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
S,
float("-inf"),
)
if USE_ALIBI_SLOPES:
S += alibi_slope[:, None] * (seq_offset - context_len)
if USE_QQ_BIAS:
# compute key positions relative to query section
key_rel_pos = seq_offset - context_len # shape: [BLOCK_SIZE]
# load bias only for keys that correspond to queries
is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
qq_bias = tl.load(
qq_bias_row_ptrs + key_rel_pos[None, :],
mask=is_query_key[None, :], # avoid OOB for context keys
other=0.0,
)
S += qq_bias
# compute running maximum
# m_j : (BLOCK_M,)
m_j = tl.maximum(M, tl.max(S, axis=1))
# For sliding window there's a chance the max is -inf due to masking of
# the entire row. In this case we need to set m_j 0 to avoid NaN
m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
# P : (BLOCK_M, TILE_SIZE)
P = tl.exp(S - m_j[:, None])
# l_j : (BLOCK_M,)
l_j = tl.sum(P, axis=1)
# alpha : (BLOCK_M, )
alpha = tl.exp(M - m_j)
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc = acc * alpha[:, None]
# update constants
L = L * alpha + l_j
M = m_j
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc += tl.dot(P.to(V.dtype), V)
# epilogue
acc = acc / L[:, None]
if USE_FP8:
acc = acc * tl.load(out_scale)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
output_offset = (
query_offset_0[:, None] * output_stride_0
+ query_offset_1[:, None] * output_stride_1
+ offs_d[None, :]
)
tl.store(
output_ptr + output_offset,
acc,
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
)
@triton.jit
def kernel_unified_attention_3d(
segm_output_ptr,
# [num_tokens, num_query_heads, num_segments, head_size]
segm_max_ptr, # [num_tokens, num_query_heads, num_segments]
segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments]
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
sink_ptr, # [num_query_heads]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
qq_bias_ptr, # [num_query_tokens, num_query_tokens]
scale, # float32
k_scale, # float32
v_scale, # float32
softcap, # float32
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
qq_bias_stride_0: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
TILE_SIZE: tl.constexpr, # int, must be power of 2
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_QQ_BIAS: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool
USE_SINKS: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.constexpr, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.constexpr, # int
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
):
q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
segm_idx = tl.program_id(2)
seq_idx = find_seq_idx(
query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
)
q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx
q_block_local_idx = q_block_global_idx - q_block_start_idx
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
return
# sequence len for this particular sequence
seq_len = tl.load(seq_lens_ptr + seq_idx)
# number of segments for this particular sequence
num_segments = NUM_SEGMENTS_PER_SEQ
tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len:
return
offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
offs_t = tl.arange(0, TILE_SIZE)
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
query_offset_0 = cur_batch_in_all_start_index + query_pos
query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
query_offset = (
query_offset_0[:, None] * query_stride_0
+ query_offset_1[:, None] * query_stride_1
+ offs_d[None, :]
)
dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)
# Q : (BLOCK_M, HEAD_SIZE_PADDED)
Q = tl.load(
query_ptr + query_offset,
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
other=0.0,
)
block_table_offset = seq_idx * block_table_stride
if USE_SINKS:
if segm_idx == 0:
M = tl.load(
sink_ptr + query_offset_1,
mask=query_mask_1,
other=float("-inf"),
).to(dtype=tl.float32)
else:
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
else:
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
# context length for this particular sequences
context_len = seq_len - cur_batch_query_len
# alibi slope for this head
if USE_ALIBI_SLOPES:
alibi_slope = tl.load(
alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0
)
# query-query attention bias
if USE_QQ_BIAS:
qq_bias_row_ptrs = (
qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
) # shape: [BLOCK_M]
# compute the length of the longest sequence prefix spanned by any
# query token in the current q_block (q_block_local_idx)
max_seq_prefix_len = (
context_len
+ q_block_local_idx * BLOCK_Q
+ (BLOCK_M - 1) // num_queries_per_kv
+ 1
)
# adjust for potential padding in the last q_block by considering the
# actual sequence length
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
# calculate the number of tiles that need to be processed to
# cover the longest sequence prefix (due to causal masking, tiles beyond
# this prefix can be skipped)
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
# iterate through tiles within current segment
for j in range(
segm_idx * tiles_per_segment,
min((segm_idx + 1) * tiles_per_segment, num_tiles),
):
seq_offset = j * TILE_SIZE + offs_t
tile_mask = seq_offset < max_seq_prefix_len
physical_block_idx = tl.load(
block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
).to(tl.int64)
v_offset = (
physical_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
+ offs_d[None, :] * stride_v_cache_3
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
)
k_offset = (
physical_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
+ offs_d[:, None] * stride_k_cache_3
+ (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
)
# K : (HEAD_SIZE, TILE_SIZE)
K_load = tl.load(
key_cache_ptr + k_offset,
mask=dim_mask[:, None] & tile_mask[None, :],
other=0.0,
)
if K_load.dtype.is_fp8():
if Q.dtype.is_fp8():
K = K_load
else:
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
else:
K = K_load
# V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load(
value_cache_ptr + v_offset,
mask=dim_mask[None, :] & tile_mask[:, None],
other=0.0,
)
if V_load.dtype.is_fp8():
if Q.dtype.is_fp8():
V = V_load
else:
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
else:
V = V_load
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
# S : (BLOCK_M, TILE_SIZE)
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
S += scale * tl.dot(Q, K)
if USE_SOFTCAP:
S = apply_softcap(S, softcap)
S = tl.where(
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
)
if SLIDING_WINDOW > 0:
S = tl.where(
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
S,
float("-inf"),
)
if USE_ALIBI_SLOPES:
S += alibi_slope[:, None] * (seq_offset - context_len)
if USE_QQ_BIAS:
# compute key positions relative to query section
key_rel_pos = seq_offset - context_len # shape: [BLOCK_SIZE]
# load bias only for keys that correspond to queries
is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
qq_bias = tl.load(
qq_bias_row_ptrs + key_rel_pos[None, :],
mask=is_query_key[None, :], # avoid OOB for context keys
other=0.0,
)
S += qq_bias
# compute running maximum
# m_j : (BLOCK_M,)
m_j = tl.maximum(M, tl.max(S, axis=1))
# For sliding window there's a chance the max is -inf due to masking of
# the entire row. In this case we need to set m_j 0 to avoid NaN
m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
# P : (BLOCK_M, TILE_SIZE,)
P = tl.exp(S - m_j[:, None])
# l_j : (BLOCK_M,)
l_j = tl.sum(P, axis=1)
# alpha : (BLOCK_M, )
alpha = tl.exp(M - m_j)
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc = acc * alpha[:, None]
# update constants
L = L * alpha + l_j
M = m_j
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc += tl.dot(P.to(V.dtype), V)
segm_output_offset = (
query_offset_0[:, None].to(tl.int64)
* (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
+ query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
+ segm_idx * HEAD_SIZE_PADDED
+ tl.arange(0, HEAD_SIZE_PADDED)[None, :]
)
tl.store(
segm_output_ptr + segm_output_offset,
acc,
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
)
segm_offset = (
query_offset_0.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
+ query_offset_1 * NUM_SEGMENTS_PER_SEQ
+ segm_idx
)
tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1)
tl.store(segm_expsum_ptr + segm_offset, L, mask=query_mask_0 & query_mask_1)
@triton.jit
def reduce_segments(
output_ptr, # [num_tokens, num_query_heads, head_size]
segm_output_ptr,
# [num_tokens, num_query_heads, max_num_segments, head_size]
segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments]
segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments]
seq_lens_ptr, # [num_seqs]
num_seqs, # int
num_query_heads: tl.constexpr, # int
out_scale_inv, # float32
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
block_table_stride: tl.int64, # int
TILE_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int, must be power of 2
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
USE_FP8: tl.constexpr, # bool
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
):
query_token_idx = tl.program_id(0)
query_head_idx = tl.program_id(1)
seq_idx = find_seq_idx(
query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False
)
# sequence len for this particular sequence
seq_len = tl.load(seq_lens_ptr + seq_idx)
# number of segments for this particular sequence
num_segments = NUM_SEGMENTS_PER_SEQ
tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
# create masks for subsequent loads
act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE)
segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full(
[NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32
)
dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1)
# load segment maxima
segm_offset = (
query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
+ query_head_idx * NUM_SEGMENTS_PER_SEQ
+ tl.arange(0, NUM_SEGMENTS_PER_SEQ)
)
segm_max = tl.load(segm_max_ptr + segm_offset, mask=segm_mask, other=float("-inf"))
overall_max = tl.max(segm_max)
# load and rescale segment exp sums
segm_expsum = tl.load(segm_expsum_ptr + segm_offset, mask=segm_mask, other=0.0)
segm_expsum = segm_expsum * tl.exp(segm_max - overall_max)
overall_expsum = tl.sum(segm_expsum)
# load, rescale, and add segment attention outputs
segm_output_offset = (
query_token_idx.to(tl.int64)
* (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
+ query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
+ tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED
+ tl.arange(0, HEAD_SIZE_PADDED)[None, :]
)
segm_output = tl.load(
segm_output_ptr + segm_output_offset,
mask=segm_mask[:, None] & dim_mask[None, :],
other=0.0,
)
segm_output *= tl.exp(segm_max - overall_max)[:, None]
acc_sum = tl.sum(segm_output, axis=0)
# safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
# write result
output_offset = (
query_token_idx * output_stride_0
+ query_head_idx * output_stride_1
+ tl.arange(0, HEAD_SIZE_PADDED)
)
tl.store(output_ptr + output_offset, acc, mask=dim_mask)
def unified_attention(
q,
k,
v,
out,
cu_seqlens_q,
max_seqlen_q,
seqused_k,
max_seqlen_k,
softmax_scale,
causal,
window_size,
block_table,
softcap,
q_descale,
k_descale,
v_descale,
alibi_slopes=None,
output_scale=None,
qq_bias=None,
# Optional tensor for sinks
sinks=None,
):
assert causal, "Only causal attention is supported"
assert q_descale is None, "Q scales not supported"
if sinks is not None:
assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size"
use_alibi_slopes = alibi_slopes is not None
use_qq_bias = qq_bias is not None
block_size = v.shape[1]
num_seqs = len(seqused_k)
num_query_heads = q.shape[1]
num_kv_heads = k.shape[2]
num_queries_per_kv = num_query_heads // num_kv_heads
head_size = q.shape[2]
BLOCK_M = (
16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv)
)
BLOCK_Q = BLOCK_M // num_queries_per_kv
# Ideally we would launch with kernel with:
# \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks.
# However, it is slow to realize the query_lens on cpu.
# Instead we use upper-bound:
# \sum_i[ceil(query_len[i] / BLOCK_Q)]
# <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1]
# = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs
# <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
# Assigning default tile sizes for prefill and decode.
# Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
# and at least 16 for all other data types.
TILE_SIZE_PREFILL = 32
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
# if batch contains a prefill
if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
kernel_unified_attention_2d[
(
total_num_q_blocks,
num_kv_heads,
)
](
output_ptr=out,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
qq_bias_ptr=qq_bias,
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
out_scale=1 / output_scale if output_scale is not None else 1.0,
softcap=softcap,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
block_table_stride=block_table.stride(0),
query_stride_0=q.stride(0),
query_stride_1=q.stride(1),
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_PREFILL,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),
stride_k_cache_2=k.stride(2),
stride_k_cache_3=k.stride(3),
stride_v_cache_0=v.stride(0),
stride_v_cache_1=v.stride(1),
stride_v_cache_2=v.stride(2),
stride_v_cache_3=v.stride(3),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None,
)
else:
# for initial version, NUM_SEGMENTS = 16 is chosen as a default
# value that showed good performance in tests
NUM_SEGMENTS = 16
segm_output = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
triton.next_power_of_2(head_size),
dtype=torch.float32,
device=q.device,
)
segm_max = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
dtype=torch.float32,
device=q.device,
)
segm_expsum = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
dtype=torch.float32,
device=q.device,
)
kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
segm_output_ptr=segm_output,
segm_max_ptr=segm_max,
segm_expsum_ptr=segm_expsum,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
qq_bias_ptr=qq_bias,
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
softcap=softcap,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
block_table_stride=block_table.stride(0),
query_stride_0=q.stride(0),
query_stride_1=q.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_DECODE,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),
stride_k_cache_2=k.stride(2),
stride_k_cache_3=k.stride(3),
stride_v_cache_0=v.stride(0),
stride_v_cache_1=v.stride(1),
stride_v_cache_2=v.stride(2),
stride_v_cache_3=v.stride(3),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
)
reduce_segments[(q.shape[0], num_query_heads)](
output_ptr=out,
segm_output_ptr=segm_output,
segm_max_ptr=segm_max,
segm_expsum_ptr=segm_expsum,
seq_lens_ptr=seqused_k,
num_seqs=num_seqs,
num_query_heads=num_query_heads,
out_scale_inv=1 / output_scale if output_scale is not None else 1.0,
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
block_table_stride=block_table.stride(0),
TILE_SIZE=TILE_SIZE_DECODE,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
USE_FP8=output_scale is not None,
)

View File

@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains ops for ViT attention to be compatible with torch.compile
as there are operations here not supported by torch.compile (for instance,
`to_list` in xformers attn, or `.item()` in flash attention)
Using these ops and wrapping vision blocks with `torch.compile` can speed up
throughput in vision models by ~5% relative on H100, and improve token
latencies by ~7% (see qwen2_5_vl for example usage)
To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0)
"""
import einops
import torch
import torch.nn.functional as F
from vllm.utils.torch_utils import direct_register_custom_op
def xformers_attn_seqlens_wrapper(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device
)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None
)
context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
return context_layer
def xformers_attn_seqlens_wrapper_fake(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
direct_register_custom_op(
op_name="xformers_attn_seqlens_wrapper",
op_func=xformers_attn_seqlens_wrapper,
fake_impl=xformers_attn_seqlens_wrapper_fake,
)
def vit_xformers_attn_wrapper(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens)
def flash_attn_maxseqlen_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor:
if is_rocm_aiter:
from aiter import flash_attn_varlen_func
else:
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen.item(),
max_seqlen_k=max_seqlen.item(),
dropout_p=0.0,
causal=False,
)
context_layer = einops.rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
return context_layer
def flash_attn_maxseqlen_wrapper_fake(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
direct_register_custom_op(
op_name="flash_attn_maxseqlen_wrapper",
op_func=flash_attn_maxseqlen_wrapper,
fake_impl=flash_attn_maxseqlen_wrapper_fake,
)
def vit_flash_attn_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor:
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
)
# TODO: Once we have a torch 2.10, we can use tensor slices
# so we won't need to wrap this in custom ops
def torch_sdpa_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (
einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
return context_layer
def torch_sdpa_wrapper_fake(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
direct_register_custom_op(
op_name="torch_sdpa_wrapper",
op_func=torch_sdpa_wrapper,
fake_impl=torch_sdpa_wrapper_fake,
)
def vit_torch_sdpa_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)