739 lines
26 KiB
Python
739 lines
26 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
# Authors:
|
|
# - Burkhard Ringlein <ngl@zurich.ibm.com>
|
|
# - Jan van Lunteren <jvl@zurich.ibm.com>
|
|
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
|
|
# - Thomas Parnell <tpa@zurich.ibm.com>
|
|
|
|
import torch
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
from vllm.logger import init_logger
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
@triton.jit
def cdiv_fn(x, y):
    # Integer ceiling division, evaluated as floor((x + y - 1) / y).
    # The callers in this file pass lengths and block sizes, i.e. the
    # formula is used with non-negative x and positive y.
    bumped = x + y - 1
    return bumped // y
|
|
|
|
|
|
@triton.jit
def apply_softcap(S, x):
    # Soft-cap the attention scores: returns x * tanh(S / x), which squashes
    # S smoothly into the interval [-x, x].
    #
    # tanh(z) is evaluated as 1 - 2 / (exp(2z) + 1) rather than
    # (exp(z) - exp(-z)) / (exp(z) + exp(-z)). The latter becomes
    # inf / inf = NaN as soon as exp(z) overflows, and a single NaN score
    # poisons the whole softmax row. In the form used here an overflowing
    # exp(2z) yields 2 / inf = 0 (result +x) and an underflowing exp(2z)
    # yields 2 / 1 = 2 (result -x), so the output saturates cleanly.
    Sdiv = S / x
    exp_2z = tl.exp(2.0 * Sdiv)
    return x * (1.0 - 2.0 / (exp_2z + 1.0))
|
|
|
|
|
|
@triton.jit
def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,
                 BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr):
    # Binary search over entries 0 .. num_seqs - 1 of query_start_len_ptr.
    # Returns the index of the last entry whose search key is <= target_idx
    # (i.e. one less than the first index whose key is > target_idx).
    #
    # The search key of entry i is
    #   * query_start_len[i] // BLOCK_Q + i   if use_q_block_mode
    #   * query_start_len[i]                  otherwise
    # so target_idx is a q-block index in the first case and a token index
    # in the second. Assumes the keys are non-decreasing in i.
    lo: tl.int32 = 0
    hi = num_seqs
    while lo < hi:
        probe = (lo + hi) // 2
        start = tl.load(query_start_len_ptr + probe)
        # use_q_block_mode is a constexpr, so this branch is resolved at
        # compile time.
        if use_q_block_mode:
            key = start // BLOCK_Q + probe
        else:
            key = start

        if key <= target_idx:
            lo = probe + 1
        else:
            hi = probe

    return lo - 1
|
|
|
|
|
|
@triton.jit
def kernel_unified_attention_2d(
        output_ptr,  # [num_tokens, num_query_heads, head_size]
        query_ptr,  # [num_tokens, num_query_heads, head_size]
        key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
        value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
        seq_lens_ptr,  # [num_seqs]
        alibi_slopes_ptr,  # [num_query_heads]
        scale,  # float32
        k_scale,  # pointer to the K dequant scale (read via tl.load)
        v_scale,  # pointer to the V dequant scale (read via tl.load)
        softcap,  # float32
        num_query_heads: tl.constexpr,  # int
        num_queries_per_kv: tl.constexpr,  # int
        block_table_stride: tl.int64,  # int
        query_stride_0: tl.int64,  # int
        query_stride_1: tl.int64,  # int, should be equal to head_size
        output_stride_0: tl.int64,  # int
        output_stride_1: tl.int64,  # int, should be equal to head_size
        BLOCK_SIZE: tl.constexpr,  # int
        HEAD_SIZE: tl.constexpr,  # int
        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
        USE_ALIBI_SLOPES: tl.constexpr,  # bool
        USE_SOFTCAP: tl.constexpr,  # bool
        SLIDING_WINDOW: tl.constexpr,  # int
        stride_k_cache_0: tl.int64,  # int
        stride_k_cache_1: tl.int64,  # int
        stride_k_cache_2: tl.int64,  # int
        stride_k_cache_3: tl.constexpr,  # int
        stride_v_cache_0: tl.int64,  # int
        stride_v_cache_1: tl.int64,  # int
        stride_v_cache_2: tl.int64,  # int
        stride_v_cache_3: tl.constexpr,  # int
        query_start_len_ptr,  # [num_seqs+1]
        BLOCK_Q: tl.constexpr,  # int
        num_seqs: tl.int32,
        BLOCK_M: tl.constexpr,  # int
):
    # Causal paged attention with an online softmax.
    #
    # Grid: axis 0 = global q-block index, axis 1 = KV head index.
    # One program handles up to BLOCK_Q consecutive query tokens of a single
    # sequence for all num_queries_per_kv query heads that share one KV
    # head (BLOCK_M rows in total) and walks the KV cache blocks of that
    # sequence through the block table.
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)

    # Map the global q-block index back to (sequence, local q-block).
    seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs,
                           BLOCK_Q, True)

    cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
    cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

    q_block_start_idx = cur_batch_in_all_start_index // BLOCK_Q + seq_idx
    q_block_local_idx = q_block_global_idx - q_block_start_idx

    cur_batch_query_len = cur_batch_in_all_stop_index \
        - cur_batch_in_all_start_index

    # The launch grid is an upper bound on the number of q-blocks, so some
    # programs have no query tokens to process.
    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
        return

    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)
    offs_n = tl.arange(0, BLOCK_SIZE)

    # Row m of the tile is (query token m // num_queries_per_kv,
    # query head m % num_queries_per_kv of this KV head).
    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

    query_offset_0 = cur_batch_in_all_start_index + query_pos
    query_offset_1 = kv_head_idx * num_queries_per_kv + \
        offs_m % num_queries_per_kv
    query_offset = (query_offset_0[:, None] * query_stride_0 +
                    query_offset_1[:, None] * query_stride_1 + offs_d[None, :])

    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
    query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

    # Q : (BLOCK_M, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    # Online-softmax state. The initial value of L is irrelevant: on the
    # first tile alpha = exp(-inf - m_j) = 0 wipes it out.
    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # context length for this particular sequence
    context_len = seq_len - cur_batch_query_len

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1,
                              mask=query_mask_1,
                              other=0.0)

    # Longest key prefix visible to any row of this q-block: the last row
    # sits at query position q_block_local_idx * BLOCK_Q +
    # (BLOCK_M - 1) // num_queries_per_kv and, being causal, sees keys up to
    # and including context_len + that position. Clamp to seq_len because
    # the last q-block of a sequence may be only partially filled.
    max_seq_prefix_len = (context_len + q_block_local_idx * BLOCK_Q +
                          (BLOCK_M - 1) // num_queries_per_kv + 1)
    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)

    # KV blocks past that prefix are masked out for every row, so they
    # contribute nothing to the result and need not be visited at all.
    num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE)

    # iterate through tiles
    for j in range(0, num_blocks):

        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)

        seq_offset = j * BLOCK_SIZE + offs_n

        # Only slots inside the visible prefix are read. Slots past it may
        # hold stale or non-finite data; although their probability is 0,
        # 0 * NaN would still turn the P @ V product into NaN.
        tile_mask = seq_offset < max_seq_prefix_len

        v_offset = (physical_block_idx * stride_v_cache_0 +
                    kv_head_idx * stride_v_cache_2 +
                    offs_d[None, :] * stride_v_cache_3 +
                    offs_n[:, None] * stride_v_cache_1)

        k_offset = (physical_block_idx * stride_k_cache_0 +
                    kv_head_idx * stride_k_cache_2 +
                    offs_d[:, None] * stride_k_cache_3 +
                    offs_n[None, :] * stride_k_cache_1)

        # K : (HEAD_SIZE_PADDED, BLOCK_SIZE)
        K_load = tl.load(key_cache_ptr + k_offset,
                         mask=dim_mask[:, None] & tile_mask[None, :],
                         other=0.0)

        # fp8 cache with a non-fp8 query: dequantize K to the query dtype.
        if K_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                K = K_load
            else:
                K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (BLOCK_SIZE, HEAD_SIZE_PADDED)
        V_load = tl.load(value_cache_ptr + v_offset,
                         mask=tile_mask[:, None] & dim_mask[None, :],
                         other=0.0)

        if V_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                V = V_load
            else:
                V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        # causal mask: key position <= absolute position of the query
        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

        # S : (BLOCK_M, BLOCK_SIZE)
        S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)

        S += scale * tl.dot(Q, K)

        if USE_SOFTCAP:
            S = apply_softcap(S, softcap)

        S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
                     S, float("-inf"))

        if SLIDING_WINDOW > 0:
            S = tl.where((context_len + query_pos[:, None] - seq_offset)
                         < SLIDING_WINDOW, S, float("-inf"))

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        # compute running maximum
        # m_j : (BLOCK_M,)
        m_j = tl.maximum(M, tl.max(S, axis=1))
        # For sliding window there's a chance the max is -inf due to masking of
        # the entire row. In this case we need to set m_j 0 to avoid NaN
        m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

        # P : (BLOCK_M, BLOCK_SIZE)
        P = tl.exp(S - m_j[:, None])

        # l_j : (BLOCK_M,)
        l_j = tl.sum(P, axis=1)

        # alpha : (BLOCK_M, )
        alpha = tl.exp(M - m_j)

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc += tl.dot(P.to(V.dtype), V)

    # epilogue: normalise by the softmax denominator
    acc = acc / L[:, None]

    output_offset = (query_offset_0[:, None] * output_stride_0 +
                     query_offset_1[:, None] * output_stride_1 +
                     offs_d[None, :])

    tl.store(
        output_ptr + output_offset,
        acc,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    )
|
|
|
|
|
|
@triton.jit
def kernel_unified_attention_3d(
        segm_output_ptr,
        # [num_tokens, num_query_heads, num_segments, head_size_padded]
        segm_max_ptr,  # [num_tokens, num_query_heads, num_segments]
        segm_expsum_ptr,  # [num_tokens, num_query_heads, num_segments]
        query_ptr,  # [num_tokens, num_query_heads, head_size]
        key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
        value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
        seq_lens_ptr,  # [num_seqs]
        alibi_slopes_ptr,  # [num_query_heads]
        scale,  # float32
        k_scale,  # pointer to the K dequant scale (read via tl.load)
        v_scale,  # pointer to the V dequant scale (read via tl.load)
        softcap,  # float32
        num_query_heads: tl.constexpr,  # int
        num_queries_per_kv: tl.constexpr,  # int
        block_table_stride: tl.int64,  # int
        query_stride_0: tl.int64,  # int
        query_stride_1: tl.int64,  # int, should be equal to head_size
        BLOCK_SIZE: tl.constexpr,  # int
        HEAD_SIZE: tl.constexpr,  # int
        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
        USE_ALIBI_SLOPES: tl.constexpr,  # bool
        USE_SOFTCAP: tl.constexpr,  # bool
        SLIDING_WINDOW: tl.constexpr,  # int
        stride_k_cache_0: tl.int64,  # int
        stride_k_cache_1: tl.int64,  # int
        stride_k_cache_2: tl.int64,  # int
        stride_k_cache_3: tl.constexpr,  # int
        stride_v_cache_0: tl.int64,  # int
        stride_v_cache_1: tl.int64,  # int
        stride_v_cache_2: tl.int64,  # int
        stride_v_cache_3: tl.constexpr,  # int
        query_start_len_ptr,  # [num_seqs+1]
        BLOCK_Q: tl.constexpr,  # int
        num_seqs: tl.int32,
        BLOCK_M: tl.constexpr,  # int
        NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
):
    # Split-KV variant of the causal paged attention kernel.
    #
    # Grid: axis 0 = global q-block index, axis 1 = KV head index,
    # axis 2 = segment index. The KV blocks of each sequence are divided
    # into NUM_SEGMENTS_PER_SEQ contiguous segments; every program computes
    # the *unnormalised* attention output of one segment together with its
    # running max and exp-sum and stores them to the segm_* buffers, to be
    # merged by a separate reduction step.
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)
    segm_idx = tl.program_id(2)

    # Map the global q-block index back to (sequence, local q-block).
    seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs,
                           BLOCK_Q, True)

    cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
    cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

    q_block_start_idx = cur_batch_in_all_start_index // BLOCK_Q + seq_idx
    q_block_local_idx = q_block_global_idx - q_block_start_idx

    cur_batch_query_len = cur_batch_in_all_stop_index \
        - cur_batch_in_all_start_index

    # The launch grid is an upper bound on the number of q-blocks, so some
    # programs have no query tokens to process.
    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
        return

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # number of segments for this particular sequence
    num_segments = NUM_SEGMENTS_PER_SEQ
    blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)

    # Short sequences do not fill all segments; surplus programs exit
    # without writing anything.
    if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len:
        return

    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)
    offs_n = tl.arange(0, BLOCK_SIZE)

    # Row m of the tile is (query token m // num_queries_per_kv,
    # query head m % num_queries_per_kv of this KV head).
    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

    query_offset_0 = cur_batch_in_all_start_index + query_pos
    query_offset_1 = kv_head_idx * num_queries_per_kv + \
        offs_m % num_queries_per_kv

    query_offset = (query_offset_0[:, None] * query_stride_0 +
                    query_offset_1[:, None] * query_stride_1 + offs_d[None, :])

    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
    query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

    # Q : (BLOCK_M, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    # Online-softmax state. The initial value of L is irrelevant: on the
    # first tile alpha = exp(-inf - m_j) = 0 wipes it out.
    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

    # context length for this particular sequence
    context_len = seq_len - cur_batch_query_len

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1,
                              mask=query_mask_1,
                              other=0.0)

    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)

    # iterate through tiles within current segment
    for j in range(
            segm_idx * blocks_per_segment,
            min((segm_idx + 1) * blocks_per_segment, num_blocks),
    ):
        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)

        seq_offset = j * BLOCK_SIZE + offs_n

        # Only slots that belong to the sequence are read. The tail of the
        # last block may hold stale or non-finite data; although those
        # positions get probability 0, 0 * NaN would still turn the P @ V
        # product into NaN.
        tile_mask = seq_offset < seq_len

        v_offset = (physical_block_idx * stride_v_cache_0 +
                    kv_head_idx * stride_v_cache_2 +
                    offs_d[None, :] * stride_v_cache_3 +
                    offs_n[:, None] * stride_v_cache_1)

        k_offset = (physical_block_idx * stride_k_cache_0 +
                    kv_head_idx * stride_k_cache_2 +
                    offs_d[:, None] * stride_k_cache_3 +
                    offs_n[None, :] * stride_k_cache_1)

        # K : (HEAD_SIZE_PADDED, BLOCK_SIZE)
        K_load = tl.load(key_cache_ptr + k_offset,
                         mask=dim_mask[:, None] & tile_mask[None, :],
                         other=0.0)

        # fp8 cache with a non-fp8 query: dequantize K to the query dtype.
        if K_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                K = K_load
            else:
                K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (BLOCK_SIZE, HEAD_SIZE_PADDED)
        V_load = tl.load(value_cache_ptr + v_offset,
                         mask=tile_mask[:, None] & dim_mask[None, :],
                         other=0.0)

        if V_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                V = V_load
            else:
                V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        # causal mask: key position <= absolute position of the query
        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

        # S : (BLOCK_M, BLOCK_SIZE)
        S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)

        S += scale * tl.dot(Q, K)

        if USE_SOFTCAP:
            S = apply_softcap(S, softcap)

        S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
                     S, float("-inf"))

        if SLIDING_WINDOW > 0:
            S = tl.where((context_len + query_pos[:, None] - seq_offset)
                         < SLIDING_WINDOW, S, float("-inf"))

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        # compute running maximum
        # m_j : (BLOCK_M,)
        m_j = tl.maximum(M, tl.max(S, axis=1))
        # For sliding window there's a chance the max is -inf due to masking of
        # the entire row. In this case we need to set m_j 0 to avoid NaN
        m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

        # P : (BLOCK_M, BLOCK_SIZE)
        P = tl.exp(S - m_j[:, None])

        # l_j : (BLOCK_M,)
        l_j = tl.sum(P, axis=1)

        # alpha : (BLOCK_M, )
        alpha = tl.exp(M - m_j)

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc += tl.dot(P.to(V.dtype), V)

    # Store the per-segment partial results. acc is deliberately left
    # unnormalised; M and L are stored alongside it.
    segm_output_offset = (
        query_offset_0[:, None].to(tl.int64) *
        (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :])
    tl.store(
        segm_output_ptr + segm_output_offset,
        acc,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    )
    segm_offset = (query_offset_0.to(tl.int64) *
                   (num_query_heads * NUM_SEGMENTS_PER_SEQ) +
                   query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx)
    tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1)
    tl.store(segm_expsum_ptr + segm_offset,
             L,
             mask=query_mask_0 & query_mask_1)
|
|
|
|
|
|
@triton.jit
def reduce_segments(
        output_ptr,  # [num_tokens, num_query_heads, head_size]
        segm_output_ptr,
        # [num_tokens, num_query_heads, max_num_segments, head_size_padded]
        segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
        segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
        seq_lens_ptr,  # [num_seqs]
        num_seqs,  # int
        num_query_heads: tl.constexpr,  # int
        output_stride_0: tl.int64,  # int
        output_stride_1: tl.int64,  # int, should be equal to head_size
        block_table_stride: tl.int64,  # int
        BLOCK_SIZE: tl.constexpr,  # int
        HEAD_SIZE: tl.constexpr,  # int
        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
        query_start_len_ptr,  # [num_seqs+1]
        BLOCK_Q: tl.constexpr,  # int
        NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
):
    # Merge per-segment partial attention results into the final output.
    #
    # Grid: axis 0 = query token index, axis 1 = query head index.
    # For one (token, head) pair this reads the per-segment max, exp-sum
    # and unnormalised output, rescales each segment to the global max,
    # sums them up and divides by the combined exp-sum.
    token_idx = tl.program_id(0)
    head_idx = tl.program_id(1)

    # Token-index mode (use_q_block_mode=False): which sequence owns the
    # token.
    seq_idx = find_seq_idx(query_start_len_ptr, token_idx, num_seqs, BLOCK_Q,
                           False)
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # Number of segments that actually cover this sequence; segments beyond
    # that are masked out of every load below.
    blocks_per_segment = cdiv_fn(seq_len, NUM_SEGMENTS_PER_SEQ * BLOCK_SIZE)
    act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE)

    segm_ids = tl.arange(0, NUM_SEGMENTS_PER_SEQ)
    dim_ids = tl.arange(0, HEAD_SIZE_PADDED)
    segm_mask = segm_ids < act_num_segments
    dim_mask = dim_ids < HEAD_SIZE

    # Per-segment maxima; masked segments read as -inf so they neither win
    # the max nor get a non-zero weight.
    segm_offset = (token_idx.to(tl.int64) *
                   (num_query_heads * NUM_SEGMENTS_PER_SEQ) +
                   head_idx * NUM_SEGMENTS_PER_SEQ + segm_ids)
    segm_max = tl.load(segm_max_ptr + segm_offset,
                       mask=segm_mask,
                       other=float("-inf"))
    overall_max = tl.max(segm_max)

    # Factor that brings each segment onto the common (global-max) scale.
    rescale = tl.exp(segm_max - overall_max)

    # Combined softmax denominator.
    segm_expsum = tl.load(segm_expsum_ptr + segm_offset,
                          mask=segm_mask,
                          other=0.0)
    overall_expsum = tl.sum(segm_expsum * rescale)

    # Combined (still unnormalised) numerator.
    segm_output_offset = (
        token_idx.to(tl.int64) *
        (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        segm_ids[:, None] * HEAD_SIZE_PADDED + dim_ids[None, :])
    segm_output = tl.load(
        segm_output_ptr + segm_output_offset,
        mask=segm_mask[:, None] & dim_mask[None, :],
        other=0.0,
    )
    acc_sum = tl.sum(segm_output * rescale[:, None], axis=0)

    # Guarded division: emit 0.0 when the denominator is exactly zero.
    acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)

    # write result
    output_offset = (token_idx * output_stride_0 +
                     head_idx * output_stride_1 + dim_ids)
    tl.store(output_ptr + output_offset, acc, mask=dim_mask)
|
|
|
|
|
|
def unified_attention(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    max_seqlen_q,
    seqused_k,
    max_seqlen_k,
    softmax_scale,
    causal,
    window_size,
    block_table,
    softcap,
    q_descale,
    k_descale,
    v_descale,
    alibi_slopes=None,
):
    """Launch the Triton paged-attention kernels and write the result to out.

    Batches that contain a multi-token query, or whose launch grid is already
    large (more than 128 programs), use the single-pass 2D kernel. Otherwise
    the KV range of every sequence is split into segments that are processed
    in parallel by the 3D kernel and merged by ``reduce_segments``.

    Args:
        q: Query tensor; indexed as [num_tokens, num_query_heads, head_size].
        k: Key cache; indexed as
            [num_blks, blk_size, num_kv_heads, head_size].
        v: Value cache, same layout as ``k``.
        out: Output tensor, written in place; indexed like ``q``.
        cu_seqlens_q: Cumulative query start offsets, [num_seqs + 1].
        max_seqlen_q: Longest query length in the batch (host int).
        seqused_k: Total length of every sequence, [num_seqs].
        max_seqlen_k: Not used by this implementation.
        softmax_scale: Multiplier applied to Q @ K^T.
        causal: Must be true; only causal attention is implemented.
        window_size: Sequence whose first element plus one is the sliding
            window length; a non-positive length disables the window.
        block_table: Logical-to-physical block mapping,
            [num_seqs, max_num_blocks_per_seq].
        softcap: Soft-cap value; applied only when > 0.
        q_descale: Must be None (query scales are not supported).
        k_descale: Tensor holding the K dequantization scale. The kernels
            only dereference it when the cache is fp8 and ``q`` is not.
        v_descale: Same as ``k_descale`` for the value cache.
        alibi_slopes: Optional per-query-head ALiBi slopes,
            [num_query_heads].
    """
    assert causal, "Only causal attention is supported"
    assert q_descale is None, "Q scales not supported"

    block_size = v.shape[1]
    assert q.element_size() >= 2 or block_size >= 32, \
        "Block size must be at least 32 for fp8"

    use_alibi_slopes = alibi_slopes is not None

    num_seqs = len(seqused_k)
    num_query_heads = q.shape[1]
    num_kv_heads = k.shape[2]
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_size = q.shape[2]
    head_size_padded = triton.next_power_of_2(head_size)

    # One tile row per (query token, query head of the KV group). A fixed
    # BLOCK_M of 16 would make BLOCK_Q zero for more than 16 query heads per
    # KV head (division by zero below and no query rows in the kernels), so
    # grow the tile to the next power of two in that case; tl.arange needs a
    # power-of-two extent.
    if num_queries_per_kv <= 16:
        BLOCK_M = 16
    else:
        BLOCK_M = triton.next_power_of_2(num_queries_per_kv)
    BLOCK_Q = BLOCK_M // num_queries_per_kv

    # Ideally we would launch the kernel with:
    # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks.
    # However, it is slow to realize the query_lens on cpu.
    # Instead we use upper-bound:
    # \sum_i[ceil(query_len[i] / BLOCK_Q)]
    #   <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1]
    #    = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs
    #   <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs
    #    = floor(q.shape[0] / BLOCK_Q) + num_seqs
    total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs

    # Arguments shared by the 2D and the 3D attention kernels.
    attn_kwargs = dict(
        query_ptr=q,
        key_cache_ptr=k,
        value_cache_ptr=v,
        block_tables_ptr=block_table,
        seq_lens_ptr=seqused_k,
        alibi_slopes_ptr=alibi_slopes,
        scale=softmax_scale,
        k_scale=k_descale,
        v_scale=v_descale,
        softcap=softcap,
        num_query_heads=num_query_heads,
        num_queries_per_kv=num_queries_per_kv,
        block_table_stride=block_table.stride(0),
        query_stride_0=q.stride(0),
        query_stride_1=q.stride(1),
        BLOCK_SIZE=block_size,
        HEAD_SIZE=head_size,
        HEAD_SIZE_PADDED=head_size_padded,
        USE_ALIBI_SLOPES=use_alibi_slopes,
        USE_SOFTCAP=(softcap > 0),
        SLIDING_WINDOW=(1 + window_size[0]),
        stride_k_cache_0=k.stride(0),
        stride_k_cache_1=k.stride(1),
        stride_k_cache_2=k.stride(2),
        stride_k_cache_3=k.stride(3),
        stride_v_cache_0=v.stride(0),
        stride_v_cache_1=v.stride(1),
        stride_v_cache_2=v.stride(2),
        stride_v_cache_3=v.stride(3),
        query_start_len_ptr=cu_seqlens_q,
        BLOCK_Q=BLOCK_Q,
        num_seqs=num_seqs,
        BLOCK_M=BLOCK_M,
    )

    # Batch contains a prefill (multi-token query), or there are already
    # enough programs to fill the device: single-pass kernel.
    if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
        kernel_unified_attention_2d[(
            total_num_q_blocks,
            num_kv_heads,
        )](
            output_ptr=out,
            output_stride_0=out.stride(0),
            output_stride_1=out.stride(1),
            **attn_kwargs,
        )
        return

    # Small decode-only batch: split the KV range into segments for extra
    # parallelism, then merge the partial results.
    # for initial version, NUM_SEGMENTS = 16 is chosen as a default
    # value that showed good performance in tests
    NUM_SEGMENTS = 16

    # Scratch buffers for the per-segment partial results (float32).
    segm_output = torch.empty(
        q.shape[0],
        num_query_heads,
        NUM_SEGMENTS,
        head_size_padded,
        dtype=torch.float32,
        device=q.device,
    )
    segm_max = torch.empty(
        q.shape[0],
        num_query_heads,
        NUM_SEGMENTS,
        dtype=torch.float32,
        device=q.device,
    )
    segm_expsum = torch.empty(
        q.shape[0],
        num_query_heads,
        NUM_SEGMENTS,
        dtype=torch.float32,
        device=q.device,
    )

    kernel_unified_attention_3d[(
        total_num_q_blocks,
        num_kv_heads,
        NUM_SEGMENTS,
    )](
        segm_output_ptr=segm_output,
        segm_max_ptr=segm_max,
        segm_expsum_ptr=segm_expsum,
        NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
        **attn_kwargs,
    )

    reduce_segments[(q.shape[0], num_query_heads)](
        output_ptr=out,
        segm_output_ptr=segm_output,
        segm_max_ptr=segm_max,
        segm_expsum_ptr=segm_expsum,
        seq_lens_ptr=seqused_k,
        num_seqs=num_seqs,
        num_query_heads=num_query_heads,
        output_stride_0=out.stride(0),
        output_stride_1=out.stride(1),
        block_table_stride=block_table.stride(0),
        BLOCK_SIZE=block_size,
        HEAD_SIZE=head_size,
        HEAD_SIZE_PADDED=head_size_padded,
        query_start_len_ptr=cu_seqlens_q,
        BLOCK_Q=BLOCK_Q,
        NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
    )
|