Optimize triton swa kernel by skipping computation (#8860)
This commit is contained in:
@@ -134,38 +134,6 @@ def _fwd_kernel(
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
mask_n = (start_n + offs_n) < cur_seq_len_prefix
|
||||
|
||||
offs_kv_loc = tl.load(
|
||||
kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
|
||||
)
|
||||
|
||||
# load k in transposed way
|
||||
offs_buf_k = (
|
||||
offs_kv_loc[None, :] * stride_buf_kbs
|
||||
+ cur_kv_head * stride_buf_kh
|
||||
+ offs_d[:, None]
|
||||
)
|
||||
k = tl.load(
|
||||
K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
|
||||
)
|
||||
|
||||
qk = tl.dot(q.to(k.dtype), k)
|
||||
if BLOCK_DPE > 0:
|
||||
offs_kpe = (
|
||||
offs_kv_loc[None, :] * stride_buf_kbs
|
||||
+ cur_kv_head * stride_buf_kh
|
||||
+ offs_dpe[:, None]
|
||||
)
|
||||
kpe = tl.load(
|
||||
K_Buffer + offs_kpe,
|
||||
mask=mask_n[None, :],
|
||||
other=0.0,
|
||||
)
|
||||
qk += tl.dot(qpe.to(kpe.dtype), kpe)
|
||||
qk *= sm_scale
|
||||
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
final_mask = mask_m[:, None] & mask_n[None, :]
|
||||
if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
|
||||
custom_mask = tl.load(
|
||||
@@ -185,28 +153,72 @@ def _fwd_kernel(
|
||||
cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
|
||||
) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
|
||||
final_mask &= window_mask
|
||||
qk = tl.where(final_mask, qk, float("-inf"))
|
||||
|
||||
row_max = tl.max(qk, 1)
|
||||
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
|
||||
n_e_max = tl.maximum(row_max_fixed, e_max)
|
||||
SKIP_TILE = False
|
||||
if (USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK) or SLIDING_WINDOW_SIZE > 0:
|
||||
SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0
|
||||
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max[:, None])
|
||||
deno = deno * re_scale + tl.sum(p, 1)
|
||||
if not SKIP_TILE:
|
||||
offs_kv_loc = tl.load(
|
||||
kv_indices + cur_seq_kv_start_idx + start_n + offs_n,
|
||||
mask=mask_n,
|
||||
other=0,
|
||||
)
|
||||
|
||||
offs_buf_v = (
|
||||
offs_kv_loc[:, None] * stride_buf_vbs
|
||||
+ cur_kv_head * stride_buf_vh
|
||||
+ offs_dv[None, :]
|
||||
)
|
||||
v = tl.load(
|
||||
V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
|
||||
)
|
||||
p = p.to(v.dtype)
|
||||
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
||||
# load k in transposed way
|
||||
offs_buf_k = (
|
||||
offs_kv_loc[None, :] * stride_buf_kbs
|
||||
+ cur_kv_head * stride_buf_kh
|
||||
+ offs_d[:, None]
|
||||
)
|
||||
k = tl.load(
|
||||
K_Buffer + offs_buf_k,
|
||||
mask=(mask_n[None, :]) & (mask_d[:, None]),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
e_max = n_e_max
|
||||
qk = tl.dot(q.to(k.dtype), k)
|
||||
if BLOCK_DPE > 0:
|
||||
offs_kpe = (
|
||||
offs_kv_loc[None, :] * stride_buf_kbs
|
||||
+ cur_kv_head * stride_buf_kh
|
||||
+ offs_dpe[:, None]
|
||||
)
|
||||
kpe = tl.load(
|
||||
K_Buffer + offs_kpe,
|
||||
mask=mask_n[None, :],
|
||||
other=0.0,
|
||||
)
|
||||
qk += tl.dot(qpe.to(kpe.dtype), kpe)
|
||||
qk *= sm_scale
|
||||
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
qk = tl.where(final_mask, qk, float("-inf"))
|
||||
|
||||
row_max = tl.max(qk, 1)
|
||||
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
|
||||
n_e_max = tl.maximum(row_max_fixed, e_max)
|
||||
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max[:, None])
|
||||
deno = deno * re_scale + tl.sum(p, 1)
|
||||
|
||||
offs_buf_v = (
|
||||
offs_kv_loc[:, None] * stride_buf_vbs
|
||||
+ cur_kv_head * stride_buf_vh
|
||||
+ offs_dv[None, :]
|
||||
)
|
||||
v = tl.load(
|
||||
V_Buffer + offs_buf_v,
|
||||
mask=mask_n[:, None] & mask_dv[None, :],
|
||||
other=0.0,
|
||||
)
|
||||
p = p.to(v.dtype)
|
||||
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
||||
|
||||
e_max = n_e_max
|
||||
|
||||
# stage 2: compute the triangle part
|
||||
|
||||
@@ -219,35 +231,6 @@ def _fwd_kernel(
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
mask_n = (start_n + offs_n) < cur_block_m_end
|
||||
|
||||
# load k in transposed way
|
||||
offs_k = (
|
||||
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
|
||||
+ cur_kv_head * stride_kh
|
||||
+ offs_d[:, None]
|
||||
)
|
||||
k = tl.load(
|
||||
K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
|
||||
)
|
||||
|
||||
qk = tl.dot(q, k, out_dtype=tl.float32)
|
||||
if BLOCK_DPE > 0:
|
||||
offs_kpe = (
|
||||
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
|
||||
+ cur_kv_head * stride_kh
|
||||
+ offs_dpe[:, None]
|
||||
)
|
||||
kpe = tl.load(
|
||||
K_Extend + offs_kpe,
|
||||
mask=mask_n[None, :],
|
||||
other=0.0,
|
||||
)
|
||||
qk += tl.dot(qpe, kpe)
|
||||
|
||||
qk *= sm_scale
|
||||
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
final_mask = mask_m[:, None] & mask_n[None, :]
|
||||
if USE_CUSTOM_MASK:
|
||||
custom_mask = tl.load(
|
||||
@@ -279,28 +262,62 @@ def _fwd_kernel(
|
||||
)
|
||||
final_mask &= window_mask
|
||||
|
||||
qk = tl.where(final_mask, qk, float("-inf"))
|
||||
SKIP_TILE = False
|
||||
if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0:
|
||||
SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0
|
||||
|
||||
row_max = tl.max(qk, 1)
|
||||
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
|
||||
n_e_max = tl.maximum(row_max_fixed, e_max)
|
||||
if not SKIP_TILE:
|
||||
# load k in transposed way
|
||||
offs_k = (
|
||||
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
|
||||
+ cur_kv_head * stride_kh
|
||||
+ offs_d[:, None]
|
||||
)
|
||||
k = tl.load(
|
||||
K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
|
||||
)
|
||||
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max[:, None])
|
||||
deno = deno * re_scale + tl.sum(p, 1)
|
||||
qk = tl.dot(q, k, out_dtype=tl.float32)
|
||||
if BLOCK_DPE > 0:
|
||||
offs_kpe = (
|
||||
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
|
||||
+ cur_kv_head * stride_kh
|
||||
+ offs_dpe[:, None]
|
||||
)
|
||||
kpe = tl.load(
|
||||
K_Extend + offs_kpe,
|
||||
mask=mask_n[None, :],
|
||||
other=0.0,
|
||||
)
|
||||
qk += tl.dot(qpe, kpe)
|
||||
|
||||
offs_v = (
|
||||
(cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
|
||||
+ cur_kv_head * stride_vh
|
||||
+ offs_dv[None, :]
|
||||
)
|
||||
v = tl.load(
|
||||
V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
|
||||
)
|
||||
p = p.to(v.dtype)
|
||||
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
||||
qk *= sm_scale
|
||||
|
||||
e_max = n_e_max
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
qk = tl.where(final_mask, qk, float("-inf"))
|
||||
|
||||
row_max = tl.max(qk, 1)
|
||||
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
|
||||
n_e_max = tl.maximum(row_max_fixed, e_max)
|
||||
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max[:, None])
|
||||
deno = deno * re_scale + tl.sum(p, 1)
|
||||
|
||||
offs_v = (
|
||||
(cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
|
||||
+ cur_kv_head * stride_vh
|
||||
+ offs_dv[None, :]
|
||||
)
|
||||
v = tl.load(
|
||||
V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
|
||||
)
|
||||
p = p.to(v.dtype)
|
||||
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
||||
|
||||
e_max = n_e_max
|
||||
|
||||
if HAS_SINK:
|
||||
cur_sink = tl.load(sink_ptr + cur_head)
|
||||
|
||||
Reference in New Issue
Block a user