[triton] Remove the zero initialization of qk_acc by directly writing the result (#1288)

This commit is contained in:
Byron Hsu
2024-09-01 03:12:06 -07:00
committed by GitHub
parent 6cb32ef92c
commit 00b19f198f

View File

@@ -127,8 +127,7 @@ def _fwd_kernel(
)
k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q.to(k.dtype), k)
qk = tl.dot(q.to(k.dtype), k)
if BLOCK_DPE > 0:
offs_kpe = (
offs_kv_loc[None, :] * stride_buf_kbs
@@ -179,9 +178,7 @@ def _fwd_kernel(
)
k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.dot(q, k, out_dtype=tl.float32)
if BLOCK_DPE > 0:
offs_kpe = (
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :])