[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
@@ -0,0 +1,433 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
def blocksparse_flash_attn_varlen_fwd(
|
||||
q,
|
||||
k,
|
||||
v, # (#tokens, n_heads, head_size)
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q,
|
||||
sm_scale,
|
||||
sparse_layout,
|
||||
*,
|
||||
block_size=64,
|
||||
q_block_size=None,
|
||||
max_seqlen=None):
|
||||
# split q to blocks
|
||||
|
||||
assert isinstance(sparse_layout, (list, tuple))
|
||||
|
||||
_, n_heads, head_size = q.shape
|
||||
batch_size = cu_seqlens_k.size(0) - 1
|
||||
q_block_size = q_block_size or block_size
|
||||
|
||||
assert q.dim() == k.dim() == v.dim() == 3
|
||||
assert q.size(1) % k.size(1) == 0
|
||||
assert q.size(2) == k.size(2)
|
||||
# TODO(linxihui): allow k, v to have different head_size
|
||||
assert k.shape == v.shape
|
||||
assert cu_seqlens_k.dim() == 1
|
||||
|
||||
q_k_ratio = q.size(1) // k.size(1)
|
||||
|
||||
if cu_seqlens_q is None:
|
||||
if q.size(0) == batch_size: # decoding only
|
||||
cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
batch_size + 1,
|
||||
dtype=cu_seqlens_k.dtype,
|
||||
device=cu_seqlens_k.device,
|
||||
)
|
||||
elif q.size(0) == k.size(0):
|
||||
cu_seqlens_q = cu_seqlens_k
|
||||
else:
|
||||
raise ValueError("cu_seqlens_q must be specified\
|
||||
if it mix of prefilling and decoding.")
|
||||
else:
|
||||
assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)
|
||||
|
||||
# switch to use cpu to avoid too many kernel launches when iterated over
|
||||
q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
|
||||
k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()
|
||||
|
||||
assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (
|
||||
"length of q should either be 1 (decoding) or same as k (prefilling).")
|
||||
|
||||
if max_seqlen:
|
||||
assert k_lens.max() <= max_seqlen
|
||||
|
||||
n_blocks = (q_lens + q_block_size - 1) // q_block_size
|
||||
|
||||
q_batch_ids = torch.tensor(
|
||||
[i for i, n in enumerate(n_blocks) for _ in range(n)],
|
||||
dtype=cu_seqlens_q.dtype,
|
||||
device=cu_seqlens_q.device,
|
||||
)
|
||||
q_start_sids = torch.tensor(
|
||||
[i * q_block_size for n in n_blocks for i in range(n)],
|
||||
dtype=cu_seqlens_q.dtype,
|
||||
device=cu_seqlens_q.device,
|
||||
)
|
||||
|
||||
out = q.new_empty(q.shape)
|
||||
cu_seqlens_q = cu_seqlens_q.contiguous()
|
||||
cu_seqlens_k = cu_seqlens_k.contiguous()
|
||||
|
||||
layout_crow_indices, layout_col_indices = sparse_layout
|
||||
block_d = triton.next_power_of_2(head_size)
|
||||
|
||||
decoding_only = (q_lens == 1).all().item()
|
||||
grid = (len(q_start_sids), n_heads, 1)
|
||||
|
||||
_fwd_kernel_batch_inference[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
sm_scale,
|
||||
cu_seqlens_q[:-1],
|
||||
cu_seqlens_q[1:],
|
||||
cu_seqlens_k[:-1],
|
||||
cu_seqlens_k[1:],
|
||||
q_batch_ids,
|
||||
q_start_sids,
|
||||
0,
|
||||
*q.stride(),
|
||||
0,
|
||||
*k.stride(),
|
||||
0,
|
||||
*v.stride(),
|
||||
0,
|
||||
*out.stride(),
|
||||
layout_crow_indices,
|
||||
layout_col_indices,
|
||||
*layout_crow_indices.stride(),
|
||||
*layout_col_indices.stride(),
|
||||
q_k_ratio,
|
||||
HAS_BATCH_DIM=False,
|
||||
D_HEAD=head_size,
|
||||
BLOCK_M=q_block_size,
|
||||
BLOCK_N=block_size,
|
||||
BLOCK_D=block_d,
|
||||
BLOCK_M_LOADING=(16 if decoding_only else
|
||||
q_block_size), # smaller for decoding
|
||||
EVEN_D=block_d == head_size,
|
||||
num_warps=1 if decoding_only else 4,
|
||||
num_stages=3)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
Q,
|
||||
k_block_col_idx,
|
||||
layout_col_ptr,
|
||||
layout_col_stride_h,
|
||||
layout_col_stride_m,
|
||||
k_ptrs,
|
||||
v_ptrs,
|
||||
off_h,
|
||||
offs_m,
|
||||
offs_n,
|
||||
offs_d,
|
||||
stride_kt,
|
||||
stride_vt,
|
||||
sm_scale,
|
||||
k_seqlen,
|
||||
past_len,
|
||||
LAST_K_BLOCK: tl.constexpr,
|
||||
BLOCK_M_LOADING: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
D_HEAD: tl.constexpr,
|
||||
EVEN_D: tl.constexpr,
|
||||
M_LT_N: tl.constexpr,
|
||||
):
|
||||
k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +
|
||||
k_block_col_idx * layout_col_stride_m).to(tl.int32)
|
||||
start_n = k_block_id * BLOCK_N
|
||||
if LAST_K_BLOCK:
|
||||
if EVEN_D:
|
||||
k = tl.load(
|
||||
k_ptrs + start_n * stride_kt,
|
||||
mask=offs_n[None, :] + start_n < k_seqlen,
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
k = tl.load(
|
||||
k_ptrs + start_n * stride_kt,
|
||||
mask=(offs_n[None, :] + start_n < k_seqlen) &
|
||||
(offs_d[:, None] < D_HEAD),
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
if EVEN_D:
|
||||
k = tl.load(k_ptrs + start_n * stride_kt)
|
||||
else:
|
||||
k = tl.load(k_ptrs + start_n * stride_kt,
|
||||
mask=offs_d[:, None] < D_HEAD,
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
|
||||
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
|
||||
if LAST_K_BLOCK | M_LT_N:
|
||||
qk += tl.where(
|
||||
offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),
|
||||
0,
|
||||
float("-inf"),
|
||||
)
|
||||
|
||||
# flash-attn2
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
||||
p = tl.math.exp2(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
alpha = tl.math.exp2(m_i - m_ij)
|
||||
acc = acc * alpha[:, None]
|
||||
# update m_i
|
||||
m_i = m_ij
|
||||
l_i = l_i * alpha + l_ij
|
||||
|
||||
p = p.to(Q.dtype.element_ty)
|
||||
# update acc
|
||||
if LAST_K_BLOCK:
|
||||
if EVEN_D:
|
||||
v = tl.load(
|
||||
v_ptrs + start_n * stride_vt,
|
||||
mask=offs_n[:, None] + start_n < k_seqlen,
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
v = tl.load(
|
||||
v_ptrs + start_n * stride_vt,
|
||||
mask=(offs_n[:, None] + start_n < k_seqlen) &
|
||||
(offs_d[None, :] < D_HEAD),
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
if EVEN_D:
|
||||
v = tl.load(v_ptrs + start_n * stride_vt)
|
||||
else:
|
||||
v = tl.load(v_ptrs + start_n * stride_vt,
|
||||
mask=offs_d[None, :] < D_HEAD,
|
||||
other=0.0)
|
||||
|
||||
acc += tl.dot(p, v)
|
||||
|
||||
return acc, l_i, m_i
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
"M_LT_N":
|
||||
lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
|
||||
})
|
||||
@triton.jit
|
||||
def _fwd_kernel_batch_inference(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
Out,
|
||||
sm_scale,
|
||||
q_batch_starts,
|
||||
q_batch_ends,
|
||||
k_batch_starts,
|
||||
k_batch_ends,
|
||||
q_batch_ids,
|
||||
q_start_sids,
|
||||
stride_qb,
|
||||
stride_qt,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kb,
|
||||
stride_kt,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vb,
|
||||
stride_vt,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_ob,
|
||||
stride_ot,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
layout_crow_ptr,
|
||||
layout_col_ptr,
|
||||
layout_crow_stride_h,
|
||||
layout_crow_stride_m,
|
||||
layout_col_stride_h,
|
||||
layout_col_stride_m,
|
||||
q_k_ratio,
|
||||
HAS_BATCH_DIM: tl.constexpr,
|
||||
D_HEAD: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_D: tl.constexpr,
|
||||
BLOCK_M_LOADING: tl.constexpr,
|
||||
EVEN_D: tl.constexpr,
|
||||
M_LT_N: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
NOTATION:
|
||||
pid: position id
|
||||
sid: storage id
|
||||
sbid: storage block id
|
||||
pbid: position block id
|
||||
offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
|
||||
|
||||
TODO(linxihui):
|
||||
Optimize grouped-attn
|
||||
"""
|
||||
off_zm = tl.program_id(0)
|
||||
off_h = tl.program_id(1)
|
||||
|
||||
off_h_for_kv = off_h // q_k_ratio
|
||||
|
||||
if HAS_BATCH_DIM:
|
||||
off_z = tl.program_id(2)
|
||||
Q += off_z * stride_qb
|
||||
K += off_z * stride_kb
|
||||
V += off_z * stride_vb
|
||||
Out += off_z * stride_ob
|
||||
start_m = off_zm
|
||||
q_start_sid = start_m * BLOCK_M # always 0 for decoding
|
||||
else:
|
||||
off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1]
|
||||
q_start_sid = tl.load(q_start_sids + off_zm)
|
||||
start_m = q_start_sid // BLOCK_M # q_sbid
|
||||
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_D)
|
||||
|
||||
q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
|
||||
q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
|
||||
k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
|
||||
k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
|
||||
past_len = k_seqlen - q_seqlen
|
||||
|
||||
Q += q_cu_start * stride_qt + off_h * stride_qh
|
||||
K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
|
||||
V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
|
||||
Out += q_cu_start * stride_ot + off_h * stride_oh
|
||||
|
||||
q_pbid = (past_len + q_start_sid) // BLOCK_M
|
||||
|
||||
if EVEN_D:
|
||||
q = tl.load(
|
||||
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
||||
mask=offs_m[:, None] < q_seqlen,
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
q = tl.load(
|
||||
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
||||
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
|
||||
q_pbid * layout_crow_stride_m)
|
||||
|
||||
# TODO(linxihui): load at once, with any Triton version
|
||||
# that supports `tl.split`, e.g., Triton 3.0
|
||||
k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
|
||||
k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
|
||||
|
||||
m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
|
||||
|
||||
k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
|
||||
v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
|
||||
|
||||
sm_scale *= (
|
||||
1.44269504 # 1/log2 as we use base2 for exponential and logarithm
|
||||
)
|
||||
|
||||
for k_block_col_idx in range(k_block_start, k_block_end - 1):
|
||||
acc, l_i, m_i = _fwd_kernel_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
Q,
|
||||
k_block_col_idx,
|
||||
layout_col_ptr,
|
||||
layout_col_stride_h,
|
||||
layout_col_stride_m,
|
||||
k_ptrs,
|
||||
v_ptrs,
|
||||
off_h,
|
||||
offs_m,
|
||||
offs_n,
|
||||
offs_d,
|
||||
stride_kt,
|
||||
stride_vt,
|
||||
sm_scale,
|
||||
k_seqlen,
|
||||
past_len,
|
||||
False,
|
||||
BLOCK_M_LOADING,
|
||||
BLOCK_N,
|
||||
D_HEAD,
|
||||
EVEN_D,
|
||||
M_LT_N,
|
||||
)
|
||||
|
||||
acc, l_i, m_i = _fwd_kernel_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
Q,
|
||||
k_block_end - 1,
|
||||
layout_col_ptr,
|
||||
layout_col_stride_h,
|
||||
layout_col_stride_m,
|
||||
k_ptrs,
|
||||
v_ptrs,
|
||||
off_h,
|
||||
offs_m,
|
||||
offs_n,
|
||||
offs_d,
|
||||
stride_kt,
|
||||
stride_vt,
|
||||
sm_scale,
|
||||
k_seqlen,
|
||||
past_len,
|
||||
True,
|
||||
BLOCK_M_LOADING,
|
||||
BLOCK_N,
|
||||
D_HEAD,
|
||||
EVEN_D,
|
||||
M_LT_N,
|
||||
)
|
||||
|
||||
# flash-attn 2
|
||||
m_i += tl.math.log2(l_i)
|
||||
acc = acc / l_i[:, None]
|
||||
|
||||
# write output
|
||||
if EVEN_D:
|
||||
tl.store(
|
||||
Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
|
||||
acc,
|
||||
mask=offs_m[:, None] < q_seqlen,
|
||||
)
|
||||
else:
|
||||
tl.store(
|
||||
Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
|
||||
acc,
|
||||
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
||||
)
|
||||
Reference in New Issue
Block a user