Optimize triton swa kernel by skipping computation (#8860)
This commit is contained in:
@@ -0,0 +1,283 @@
|
|||||||
|
import itertools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import triton.testing as tt
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd
|
||||||
|
|
||||||
|
|
||||||
|
def extend_attention_fwd_torch(
|
||||||
|
q: torch.Tensor, # [extend_tokens, H_Q, D]
|
||||||
|
k: torch.Tensor, # [extend_tokens, H_KV, D]
|
||||||
|
v: torch.Tensor, # [extend_tokens, H_KV, D]
|
||||||
|
o: torch.Tensor, # [extend_tokens, H_Q, D]
|
||||||
|
k_cache: torch.Tensor, # [total_tokens, H_KV, D]
|
||||||
|
v_cache: torch.Tensor, # [total_tokens, H_KV, D]
|
||||||
|
qo_indptr: torch.Tensor, # [B+1]
|
||||||
|
kv_indptr: torch.Tensor, # [B+1]
|
||||||
|
kv_indices: torch.Tensor, # [prefix_tokens]
|
||||||
|
sliding_window_size: int,
|
||||||
|
):
|
||||||
|
B = qo_indptr.size(0) - 1
|
||||||
|
_, H_Q, D = q.shape
|
||||||
|
_, H_KV, _ = k.shape
|
||||||
|
|
||||||
|
group_size = H_Q // H_KV
|
||||||
|
scale = 1.0 / D**0.5
|
||||||
|
|
||||||
|
for i in range(B):
|
||||||
|
q_start = int(qo_indptr[i].item())
|
||||||
|
q_end = int(qo_indptr[i + 1].item())
|
||||||
|
kv_start = int(kv_indptr[i].item())
|
||||||
|
kv_end = int(kv_indptr[i + 1].item())
|
||||||
|
|
||||||
|
prefix_indices = kv_indices[kv_start:kv_end]
|
||||||
|
k_prefix = k_cache[prefix_indices] # [prefix_len, H_KV, D]
|
||||||
|
v_prefix = v_cache[prefix_indices] # [prefix_len, H_KV, D]
|
||||||
|
|
||||||
|
k_extend = k[q_start:q_end] # [extend_len, H_KV, D]
|
||||||
|
v_extend = v[q_start:q_end] # [extend_len, H_KV, D]
|
||||||
|
q_extend = q[q_start:q_end] # [extend_len, H_Q, D]
|
||||||
|
|
||||||
|
k_full = torch.cat([k_prefix, k_extend], dim=0) # [total_len, H_KV, D]
|
||||||
|
v_full = torch.cat([v_prefix, v_extend], dim=0) # [total_len, H_KV, D]
|
||||||
|
|
||||||
|
if group_size != 1:
|
||||||
|
k_full_hq = k_full.repeat_interleave(
|
||||||
|
group_size, dim=1
|
||||||
|
) # [total_len, H_Q, D]
|
||||||
|
v_full_hq = v_full.repeat_interleave(
|
||||||
|
group_size, dim=1
|
||||||
|
) # [total_len, H_Q, D]
|
||||||
|
else:
|
||||||
|
k_full_hq = k_full
|
||||||
|
v_full_hq = v_full
|
||||||
|
|
||||||
|
prefix_len = k_prefix.size(0)
|
||||||
|
extend_len = k_extend.size(0)
|
||||||
|
total_len = prefix_len + extend_len
|
||||||
|
|
||||||
|
# causal
|
||||||
|
pos_keys = torch.arange(total_len, device=q.device)
|
||||||
|
t = prefix_len + torch.arange(extend_len, device=q.device) # [extend_len]
|
||||||
|
causal_mask = pos_keys.unsqueeze(0) <= t.unsqueeze(1)
|
||||||
|
|
||||||
|
# sliding window
|
||||||
|
if sliding_window_size is not None and sliding_window_size > 0:
|
||||||
|
start = (t - (sliding_window_size)).clamp_min(0) # [extend_len]
|
||||||
|
else:
|
||||||
|
start = torch.zeros_like(t)
|
||||||
|
window_mask = pos_keys.unsqueeze(0) >= start.unsqueeze(1)
|
||||||
|
|
||||||
|
final_mask = causal_mask & window_mask
|
||||||
|
|
||||||
|
attn_scores = (
|
||||||
|
torch.einsum("qhd,khd->qhk", q_extend, k_full_hq) * scale
|
||||||
|
) # [extend_len, H_Q, total_len]
|
||||||
|
attn_scores = attn_scores.masked_fill(~final_mask.unsqueeze(1), float("-inf"))
|
||||||
|
|
||||||
|
attn_weights = F.softmax(attn_scores, dim=-1)
|
||||||
|
o[q_start:q_end] = torch.einsum("qhk,khd->qhd", attn_weights, v_full_hq)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_batch(
|
||||||
|
B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=torch.bfloat16, device="cuda"
|
||||||
|
):
|
||||||
|
b_seq_len_prefix = torch.randint(
|
||||||
|
1, max(2, N_CTX // 2), (B,), dtype=torch.int32, device=device
|
||||||
|
)
|
||||||
|
b_seq_len_extend = torch.randint(
|
||||||
|
1, max(2, N_CTX // 2), (B,), dtype=torch.int32, device=device
|
||||||
|
)
|
||||||
|
b_seq_len = b_seq_len_prefix + b_seq_len_extend
|
||||||
|
|
||||||
|
b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)
|
||||||
|
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
|
||||||
|
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device)
|
||||||
|
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
|
||||||
|
|
||||||
|
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
|
||||||
|
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
|
||||||
|
|
||||||
|
kv_indices = torch.zeros(
|
||||||
|
(int(b_seq_len_prefix.sum().item()),), dtype=torch.int32, device=device
|
||||||
|
)
|
||||||
|
for i in range(B):
|
||||||
|
s = kv_indptr[i].item()
|
||||||
|
e = kv_indptr[i + 1].item()
|
||||||
|
kv_indices[s:e] = torch.arange(
|
||||||
|
b_start_loc[i],
|
||||||
|
b_start_loc[i] + b_seq_len_prefix[i],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
total_token_num = int(torch.sum(b_seq_len).item())
|
||||||
|
extend_token_num = int(torch.sum(b_seq_len_extend).item())
|
||||||
|
|
||||||
|
k_buffer = torch.empty(
|
||||||
|
(total_token_num, H_KV, D), dtype=dtype, device=device
|
||||||
|
).normal_(mean=0.1, std=0.2)
|
||||||
|
v_buffer = torch.empty(
|
||||||
|
(total_token_num, H_KV, D), dtype=dtype, device=device
|
||||||
|
).normal_(mean=0.1, std=0.2)
|
||||||
|
|
||||||
|
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
|
||||||
|
v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
|
||||||
|
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
|
||||||
|
|
||||||
|
for i in range(B):
|
||||||
|
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
|
||||||
|
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
|
||||||
|
extend_start = b_start_loc_extend[i]
|
||||||
|
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
|
||||||
|
|
||||||
|
k_extend[extend_start:extend_end] = k_buffer[
|
||||||
|
extend_start_in_buffer:extend_end_in_buffer
|
||||||
|
]
|
||||||
|
v_extend[extend_start:extend_end] = v_buffer[
|
||||||
|
extend_start_in_buffer:extend_end_in_buffer
|
||||||
|
]
|
||||||
|
q_extend[extend_start:extend_end] = torch.empty(
|
||||||
|
(int(b_seq_len_extend[i].item()), H_Q, D), dtype=dtype, device=device
|
||||||
|
).normal_(mean=0.1, std=0.2)
|
||||||
|
|
||||||
|
o_extend_triton = torch.empty(
|
||||||
|
(extend_token_num, H_Q, D), dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
o_extend_torch = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
|
||||||
|
|
||||||
|
b_seq_len_extend = b_seq_len - b_seq_len_prefix
|
||||||
|
max_len_extend = int(torch.max(b_seq_len_extend, 0)[0].item())
|
||||||
|
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
|
||||||
|
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
|
||||||
|
|
||||||
|
inputs = dict(
|
||||||
|
q_extend=q_extend,
|
||||||
|
k_extend=k_extend,
|
||||||
|
v_extend=v_extend,
|
||||||
|
k_buffer=k_buffer,
|
||||||
|
v_buffer=v_buffer,
|
||||||
|
o_extend_triton=o_extend_triton,
|
||||||
|
o_extend_torch=o_extend_torch,
|
||||||
|
qo_indptr=qo_indptr,
|
||||||
|
kv_indptr=kv_indptr,
|
||||||
|
kv_indices=kv_indices,
|
||||||
|
max_len_extend=max_len_extend,
|
||||||
|
WINDOW_SIZE=WINDOW_SIZE,
|
||||||
|
)
|
||||||
|
meta = dict(
|
||||||
|
B=B, N_CTX=N_CTX, H_Q=H_Q, H_KV=H_KV, D=D, extend_token_num=extend_token_num
|
||||||
|
)
|
||||||
|
return inputs, meta
|
||||||
|
|
||||||
|
|
||||||
|
def _run_triton(inputs):
|
||||||
|
extend_attention_fwd(
|
||||||
|
inputs["q_extend"],
|
||||||
|
inputs["k_extend"],
|
||||||
|
inputs["v_extend"],
|
||||||
|
inputs["o_extend_triton"],
|
||||||
|
inputs["k_buffer"],
|
||||||
|
inputs["v_buffer"],
|
||||||
|
inputs["qo_indptr"],
|
||||||
|
inputs["kv_indptr"],
|
||||||
|
inputs["kv_indices"],
|
||||||
|
custom_mask=None,
|
||||||
|
is_causal=True,
|
||||||
|
mask_indptr=None,
|
||||||
|
max_len_extend=inputs["max_len_extend"],
|
||||||
|
sliding_window_size=inputs["WINDOW_SIZE"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_torch_ref(inputs):
|
||||||
|
extend_attention_fwd_torch(
|
||||||
|
inputs["q_extend"],
|
||||||
|
inputs["k_extend"],
|
||||||
|
inputs["v_extend"],
|
||||||
|
inputs["o_extend_torch"],
|
||||||
|
inputs["k_buffer"],
|
||||||
|
inputs["v_buffer"],
|
||||||
|
inputs["qo_indptr"],
|
||||||
|
inputs["kv_indptr"],
|
||||||
|
inputs["kv_indices"],
|
||||||
|
inputs["WINDOW_SIZE"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
N_CTXS = [1024, 2048, 4096, 8192]
|
||||||
|
WINDOW_SIZES = [-1, 127, 256, 512]
|
||||||
|
|
||||||
|
CONFIGS = list(itertools.product(N_CTXS, WINDOW_SIZES))
|
||||||
|
|
||||||
|
PROVIDERS = ["torch", "triton"]
|
||||||
|
|
||||||
|
|
||||||
|
@tt.perf_report(
|
||||||
|
tt.Benchmark(
|
||||||
|
x_names=["N_CTX", "WINDOW_SIZE"],
|
||||||
|
x_vals=CONFIGS,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=PROVIDERS,
|
||||||
|
line_names=PROVIDERS,
|
||||||
|
ylabel="Runtime (ms)",
|
||||||
|
plot_name="extend_attention_triton_vs_torch",
|
||||||
|
args={
|
||||||
|
"B": 32,
|
||||||
|
"H_Q": 64,
|
||||||
|
"H_KV": 8,
|
||||||
|
"D": 128,
|
||||||
|
"dtype": "bf16",
|
||||||
|
"device": "cuda",
|
||||||
|
"check_correctness": False,
|
||||||
|
"warmup": 25,
|
||||||
|
"rep": 100,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def bench(
|
||||||
|
N_CTX,
|
||||||
|
provider,
|
||||||
|
B,
|
||||||
|
H_Q,
|
||||||
|
H_KV,
|
||||||
|
D,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
WINDOW_SIZE,
|
||||||
|
check_correctness,
|
||||||
|
warmup,
|
||||||
|
rep,
|
||||||
|
):
|
||||||
|
torch.manual_seed(0)
|
||||||
|
torch.cuda.manual_seed(0)
|
||||||
|
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
|
||||||
|
dt = dtype_map[dtype]
|
||||||
|
|
||||||
|
inputs, _ = _build_batch(
|
||||||
|
B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=dt, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
if check_correctness and provider == "triton":
|
||||||
|
_run_triton(inputs)
|
||||||
|
_run_torch_ref(inputs)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
if not torch.allclose(
|
||||||
|
inputs["o_extend_triton"], inputs["o_extend_torch"], rtol=1e-3, atol=1e-3
|
||||||
|
):
|
||||||
|
raise AssertionError("Mismatch between triton and torch reference.")
|
||||||
|
|
||||||
|
if provider == "triton":
|
||||||
|
ms = tt.do_bench(lambda: _run_triton(inputs), warmup=warmup, rep=rep)
|
||||||
|
elif provider == "torch":
|
||||||
|
ms = tt.do_bench(lambda: _run_torch_ref(inputs), warmup=warmup, rep=rep)
|
||||||
|
else:
|
||||||
|
raise ValueError(provider)
|
||||||
|
|
||||||
|
return ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
bench.run(print_data=True, show_plots=False)
|
||||||
@@ -134,38 +134,6 @@ def _fwd_kernel(
|
|||||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||||
mask_n = (start_n + offs_n) < cur_seq_len_prefix
|
mask_n = (start_n + offs_n) < cur_seq_len_prefix
|
||||||
|
|
||||||
offs_kv_loc = tl.load(
|
|
||||||
kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
|
|
||||||
)
|
|
||||||
|
|
||||||
# load k in transposed way
|
|
||||||
offs_buf_k = (
|
|
||||||
offs_kv_loc[None, :] * stride_buf_kbs
|
|
||||||
+ cur_kv_head * stride_buf_kh
|
|
||||||
+ offs_d[:, None]
|
|
||||||
)
|
|
||||||
k = tl.load(
|
|
||||||
K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
qk = tl.dot(q.to(k.dtype), k)
|
|
||||||
if BLOCK_DPE > 0:
|
|
||||||
offs_kpe = (
|
|
||||||
offs_kv_loc[None, :] * stride_buf_kbs
|
|
||||||
+ cur_kv_head * stride_buf_kh
|
|
||||||
+ offs_dpe[:, None]
|
|
||||||
)
|
|
||||||
kpe = tl.load(
|
|
||||||
K_Buffer + offs_kpe,
|
|
||||||
mask=mask_n[None, :],
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
qk += tl.dot(qpe.to(kpe.dtype), kpe)
|
|
||||||
qk *= sm_scale
|
|
||||||
|
|
||||||
if logit_cap > 0:
|
|
||||||
qk = logit_cap * tanh(qk / logit_cap)
|
|
||||||
|
|
||||||
final_mask = mask_m[:, None] & mask_n[None, :]
|
final_mask = mask_m[:, None] & mask_n[None, :]
|
||||||
if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
|
if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
|
||||||
custom_mask = tl.load(
|
custom_mask = tl.load(
|
||||||
@@ -185,28 +153,72 @@ def _fwd_kernel(
|
|||||||
cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
|
cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
|
||||||
) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
|
) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
|
||||||
final_mask &= window_mask
|
final_mask &= window_mask
|
||||||
qk = tl.where(final_mask, qk, float("-inf"))
|
|
||||||
|
|
||||||
row_max = tl.max(qk, 1)
|
SKIP_TILE = False
|
||||||
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
|
if (USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK) or SLIDING_WINDOW_SIZE > 0:
|
||||||
n_e_max = tl.maximum(row_max_fixed, e_max)
|
SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0
|
||||||
|
|
||||||
re_scale = tl.exp(e_max - n_e_max)
|
if not SKIP_TILE:
|
||||||
p = tl.exp(qk - n_e_max[:, None])
|
offs_kv_loc = tl.load(
|
||||||
deno = deno * re_scale + tl.sum(p, 1)
|
kv_indices + cur_seq_kv_start_idx + start_n + offs_n,
|
||||||
|
mask=mask_n,
|
||||||
|
other=0,
|
||||||
|
)
|
||||||
|
|
||||||
offs_buf_v = (
|
# load k in transposed way
|
||||||
offs_kv_loc[:, None] * stride_buf_vbs
|
offs_buf_k = (
|
||||||
+ cur_kv_head * stride_buf_vh
|
offs_kv_loc[None, :] * stride_buf_kbs
|
||||||
+ offs_dv[None, :]
|
+ cur_kv_head * stride_buf_kh
|
||||||
)
|
+ offs_d[:, None]
|
||||||
v = tl.load(
|
)
|
||||||
V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
|
k = tl.load(
|
||||||
)
|
K_Buffer + offs_buf_k,
|
||||||
p = p.to(v.dtype)
|
mask=(mask_n[None, :]) & (mask_d[:, None]),
|
||||||
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
other=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
e_max = n_e_max
|
qk = tl.dot(q.to(k.dtype), k)
|
||||||
|
if BLOCK_DPE > 0:
|
||||||
|
offs_kpe = (
|
||||||
|
offs_kv_loc[None, :] * stride_buf_kbs
|
||||||
|
+ cur_kv_head * stride_buf_kh
|
||||||
|
+ offs_dpe[:, None]
|
||||||
|
)
|
||||||
|
kpe = tl.load(
|
||||||
|
K_Buffer + offs_kpe,
|
||||||
|
mask=mask_n[None, :],
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
qk += tl.dot(qpe.to(kpe.dtype), kpe)
|
||||||
|
qk *= sm_scale
|
||||||
|
|
||||||
|
if logit_cap > 0:
|
||||||
|
qk = logit_cap * tanh(qk / logit_cap)
|
||||||
|
|
||||||
|
qk = tl.where(final_mask, qk, float("-inf"))
|
||||||
|
|
||||||
|
row_max = tl.max(qk, 1)
|
||||||
|
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
|
||||||
|
n_e_max = tl.maximum(row_max_fixed, e_max)
|
||||||
|
|
||||||
|
re_scale = tl.exp(e_max - n_e_max)
|
||||||
|
p = tl.exp(qk - n_e_max[:, None])
|
||||||
|
deno = deno * re_scale + tl.sum(p, 1)
|
||||||
|
|
||||||
|
offs_buf_v = (
|
||||||
|
offs_kv_loc[:, None] * stride_buf_vbs
|
||||||
|
+ cur_kv_head * stride_buf_vh
|
||||||
|
+ offs_dv[None, :]
|
||||||
|
)
|
||||||
|
v = tl.load(
|
||||||
|
V_Buffer + offs_buf_v,
|
||||||
|
mask=mask_n[:, None] & mask_dv[None, :],
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
p = p.to(v.dtype)
|
||||||
|
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
||||||
|
|
||||||
|
e_max = n_e_max
|
||||||
|
|
||||||
# stage 2: compute the triangle part
|
# stage 2: compute the triangle part
|
||||||
|
|
||||||
@@ -219,35 +231,6 @@ def _fwd_kernel(
|
|||||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||||
mask_n = (start_n + offs_n) < cur_block_m_end
|
mask_n = (start_n + offs_n) < cur_block_m_end
|
||||||
|
|
||||||
# load k in transposed way
|
|
||||||
offs_k = (
|
|
||||||
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
|
|
||||||
+ cur_kv_head * stride_kh
|
|
||||||
+ offs_d[:, None]
|
|
||||||
)
|
|
||||||
k = tl.load(
|
|
||||||
K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
qk = tl.dot(q, k, out_dtype=tl.float32)
|
|
||||||
if BLOCK_DPE > 0:
|
|
||||||
offs_kpe = (
|
|
||||||
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
|
|
||||||
+ cur_kv_head * stride_kh
|
|
||||||
+ offs_dpe[:, None]
|
|
||||||
)
|
|
||||||
kpe = tl.load(
|
|
||||||
K_Extend + offs_kpe,
|
|
||||||
mask=mask_n[None, :],
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
qk += tl.dot(qpe, kpe)
|
|
||||||
|
|
||||||
qk *= sm_scale
|
|
||||||
|
|
||||||
if logit_cap > 0:
|
|
||||||
qk = logit_cap * tanh(qk / logit_cap)
|
|
||||||
|
|
||||||
final_mask = mask_m[:, None] & mask_n[None, :]
|
final_mask = mask_m[:, None] & mask_n[None, :]
|
||||||
if USE_CUSTOM_MASK:
|
if USE_CUSTOM_MASK:
|
||||||
custom_mask = tl.load(
|
custom_mask = tl.load(
|
||||||
@@ -279,28 +262,62 @@ def _fwd_kernel(
|
|||||||
)
|
)
|
||||||
final_mask &= window_mask
|
final_mask &= window_mask
|
||||||
|
|
||||||
qk = tl.where(final_mask, qk, float("-inf"))
|
SKIP_TILE = False
|
||||||
|
if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0:
|
||||||
|
SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0
|
||||||
|
|
||||||
row_max = tl.max(qk, 1)
|
if not SKIP_TILE:
|
||||||
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
|
# load k in transposed way
|
||||||
n_e_max = tl.maximum(row_max_fixed, e_max)
|
offs_k = (
|
||||||
|
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
|
||||||
|
+ cur_kv_head * stride_kh
|
||||||
|
+ offs_d[:, None]
|
||||||
|
)
|
||||||
|
k = tl.load(
|
||||||
|
K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
|
||||||
|
)
|
||||||
|
|
||||||
re_scale = tl.exp(e_max - n_e_max)
|
qk = tl.dot(q, k, out_dtype=tl.float32)
|
||||||
p = tl.exp(qk - n_e_max[:, None])
|
if BLOCK_DPE > 0:
|
||||||
deno = deno * re_scale + tl.sum(p, 1)
|
offs_kpe = (
|
||||||
|
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
|
||||||
|
+ cur_kv_head * stride_kh
|
||||||
|
+ offs_dpe[:, None]
|
||||||
|
)
|
||||||
|
kpe = tl.load(
|
||||||
|
K_Extend + offs_kpe,
|
||||||
|
mask=mask_n[None, :],
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
qk += tl.dot(qpe, kpe)
|
||||||
|
|
||||||
offs_v = (
|
qk *= sm_scale
|
||||||
(cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
|
|
||||||
+ cur_kv_head * stride_vh
|
|
||||||
+ offs_dv[None, :]
|
|
||||||
)
|
|
||||||
v = tl.load(
|
|
||||||
V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
|
|
||||||
)
|
|
||||||
p = p.to(v.dtype)
|
|
||||||
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
|
||||||
|
|
||||||
e_max = n_e_max
|
if logit_cap > 0:
|
||||||
|
qk = logit_cap * tanh(qk / logit_cap)
|
||||||
|
|
||||||
|
qk = tl.where(final_mask, qk, float("-inf"))
|
||||||
|
|
||||||
|
row_max = tl.max(qk, 1)
|
||||||
|
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
|
||||||
|
n_e_max = tl.maximum(row_max_fixed, e_max)
|
||||||
|
|
||||||
|
re_scale = tl.exp(e_max - n_e_max)
|
||||||
|
p = tl.exp(qk - n_e_max[:, None])
|
||||||
|
deno = deno * re_scale + tl.sum(p, 1)
|
||||||
|
|
||||||
|
offs_v = (
|
||||||
|
(cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
|
||||||
|
+ cur_kv_head * stride_vh
|
||||||
|
+ offs_dv[None, :]
|
||||||
|
)
|
||||||
|
v = tl.load(
|
||||||
|
V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
|
||||||
|
)
|
||||||
|
p = p.to(v.dtype)
|
||||||
|
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
||||||
|
|
||||||
|
e_max = n_e_max
|
||||||
|
|
||||||
if HAS_SINK:
|
if HAS_SINK:
|
||||||
cur_sink = tl.load(sink_ptr + cur_head)
|
cur_sink = tl.load(sink_ptr + cur_head)
|
||||||
|
|||||||
Reference in New Issue
Block a user