[triton] Remove the zero initialization of qk_acc by directly writing the result (#1288)
This commit is contained in:
@@ -127,8 +127,7 @@ def _fwd_kernel(
|
|||||||
)
|
)
|
||||||
k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
|
k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
|
||||||
|
|
||||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
qk = tl.dot(q.to(k.dtype), k)
|
||||||
qk += tl.dot(q.to(k.dtype), k)
|
|
||||||
if BLOCK_DPE > 0:
|
if BLOCK_DPE > 0:
|
||||||
offs_kpe = (
|
offs_kpe = (
|
||||||
offs_kv_loc[None, :] * stride_buf_kbs
|
offs_kv_loc[None, :] * stride_buf_kbs
|
||||||
@@ -179,9 +178,7 @@ def _fwd_kernel(
|
|||||||
)
|
)
|
||||||
k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0)
|
k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0)
|
||||||
|
|
||||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
qk = tl.dot(q, k, out_dtype=tl.float32)
|
||||||
qk += tl.dot(q, k)
|
|
||||||
|
|
||||||
if BLOCK_DPE > 0:
|
if BLOCK_DPE > 0:
|
||||||
offs_kpe = (
|
offs_kpe = (
|
||||||
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
|
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
|
||||||
|
|||||||
Reference in New Issue
Block a user