Files
sglang/python/sglang/srt/layers/extend_attention.py
2024-05-24 03:48:53 -07:00

436 lines
13 KiB
Python

import torch
import triton
import triton.language as tl
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
from sglang.srt.utils import wrap_kernel_launcher
CUDA_CAPABILITY = torch.cuda.get_device_capability()
@triton.jit
def tanh(x):
# Tanh is just a scaled sigmoid
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def _fwd_kernel(
Q_Extend,
K_Extend,
V_Extend,
O_Extend,
K_Buffer,
V_Buffer,
Req_to_tokens,
B_req_idx,
B_Seq_Len,
B_Start_Loc_Extend,
B_Seq_Len_Extend,
sm_scale,
kv_group_num,
stride_qbs,
stride_qh,
stride_kbs,
stride_kh,
stride_vbs,
stride_vh,
stride_obs,
stride_oh,
stride_buf_kbs,
stride_buf_kh,
stride_buf_vbs,
stride_buf_vh,
stride_req_to_tokens_b,
BLOCK_DMODEL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
logit_cap: tl.constexpr,
):
cur_seq = tl.program_id(0)
cur_head = tl.program_id(1)
cur_block_m = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
cur_seq_len = tl.load(B_Seq_Len + cur_seq)
cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
cur_seq_prefix_start_in_loc = 0
cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = tl.arange(0, BLOCK_M)
mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
offs_q = (
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :]
)
q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)
# stage1: compute scores with prefix
offs_n = tl.arange(0, BLOCK_N)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
deno = tl.zeros([BLOCK_M], dtype=tl.float32)
e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_seq_len_prefix
offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
cur_seq_prefix_start_in_loc + start_n + offs_n
)
offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
# load k in transposed way
offs_buf_k = (
offs_kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[:, None]
)
k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
offs_buf_v = (
offs_kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_d[None, :]
)
v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v)
e_max = n_e_max
# stage2: compute the trianlge part
cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
for start_n in range(0, cur_block_m_end, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_block_m_end
# load k in transposed way
offs_k = (
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
+ cur_kv_head * stride_kh
+ offs_d[:, None]
)
k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
start_n + offs_n[None, :]
)
mask_causual &= mask_m[:, None] & mask_n[None, :]
qk = tl.where(mask_causual, qk, float("-inf"))
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
offs_v = (
(cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+ cur_kv_head * stride_vh
+ offs_d[None, :]
)
v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v)
e_max = n_e_max
offs_o = (
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_obs
+ cur_head * stride_oh
+ offs_d[None, :]
)
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
cached_kernel = None
def extend_attention_fwd(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
b_start_loc_extend,
b_seq_len_extend,
max_len_in_batch,
max_len_extend,
logit_cap=-1,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
"""
Lq, Lk, Lv, Lo = (
q_extend.shape[-1],
k_extend.shape[-1],
v_extend.shape[-1],
o_extend.shape[-1],
)
assert Lq == Lk and Lk == Lv and Lv == Lo
assert Lq in {16, 32, 64, 128, 256}
if CUDA_CAPABILITY[0] >= 8:
BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
else:
BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
sm_scale = 1.0 / (Lq**0.5)
batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
num_warps = 4 if Lk <= 64 else 8
num_stages = 1
global cached_kernel
if cached_kernel:
cached_kernel(
grid,
num_warps,
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_seq_len,
b_start_loc_extend,
b_seq_len_extend,
sm_scale,
kv_group_num,
q_extend.stride(0),
q_extend.stride(1),
k_extend.stride(0),
k_extend.stride(1),
v_extend.stride(0),
v_extend.stride(1),
o_extend.stride(0),
o_extend.stride(1),
k_buffer.stride(0),
k_buffer.stride(1),
v_buffer.stride(0),
v_buffer.stride(1),
req_to_tokens.stride(0),
)
return
_fwd_kernel[grid](
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_seq_len,
b_start_loc_extend,
b_seq_len_extend,
sm_scale,
kv_group_num,
q_extend.stride(0),
q_extend.stride(1),
k_extend.stride(0),
k_extend.stride(1),
v_extend.stride(0),
v_extend.stride(1),
o_extend.stride(0),
o_extend.stride(1),
k_buffer.stride(0),
k_buffer.stride(1),
v_buffer.stride(0),
v_buffer.stride(1),
req_to_tokens.stride(0),
BLOCK_DMODEL=Lq,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
num_stages=num_stages,
logit_cap=logit_cap,
)
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
def redundant_attention(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
max_len_in_batch,
):
total_token_num = k_buffer.shape[0]
B, H_Q, D = b_req_idx.shape[0], q_extend.shape[-2], q_extend.shape[-1]
q_buffer = torch.empty(
(total_token_num, H_Q, D), dtype=q_extend.dtype, device=q_extend.device
)
pt = 0
for i in range(B):
cur_seq_len_extend = b_seq_len[i] - b_seq_len_prefix[i]
pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
q_buffer[pl:pr] = q_extend[pt : pt + cur_seq_len_extend]
pt += cur_seq_len_extend
o_buffer = torch.empty_like(q_buffer)
context_attention_fwd(
q_buffer, k_buffer, v_buffer, o_buffer, b_start_loc, b_seq_len, max_len_in_batch
)
pt = 0
for i in range(B):
cur_seq_len_extend = b_seq_len[i] - b_seq_len_prefix[i]
pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr]
pt += cur_seq_len_extend
def test():
torch.manual_seed(0)
B, N_CTX, H_Q, H_KV, D = 19, 12331, 12, 4, 128
dtype = torch.float16
b_seq_len_prefix = torch.randint(
1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
)
b_seq_len_extend = torch.randint(
1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
)
b_seq_len = b_seq_len_prefix + b_seq_len_extend
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32, device="cuda")
b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
for i in range(B):
req_to_tokens[i, : b_seq_len[i]] = torch.arange(
b_start_loc[i], b_start_loc[i] + b_seq_len[i]
)
total_token_num = torch.sum(b_seq_len).item()
extend_token_num = torch.sum(b_seq_len_extend).item()
k_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
v_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
for i in range(B):
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
extend_start = b_start_loc_extend[i]
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
k_extend[extend_start:extend_end] = k_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
v_extend[extend_start:extend_end] = v_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
q_extend[extend_start:extend_end] = torch.empty(
(b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
b_seq_len_extend = b_seq_len - b_seq_len_prefix
b_start_loc_extend = torch.zeros_like(b_seq_len)
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
extend_attention_fwd(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
b_start_loc_extend,
b_seq_len_extend,
max_len_in_batch,
max_len_extend,
)
redundant_attention(
q_extend,
k_extend,
v_extend,
o_redundant,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
max_len_in_batch,
)
print("Mean: ", torch.mean(torch.abs(o_extend - o_redundant)))
print("Max: ", torch.max(torch.abs(o_extend - o_redundant)))
assert torch.allclose(o_extend, o_redundant, rtol=1e-2)
if __name__ == "__main__":
test()