[triton] Remove the zero initialization of qk_acc by directly writing the result (#1288)
This commit is contained in:
@@ -127,8 +127,7 @@ def _fwd_kernel(
|
||||
)
|
||||
k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
|
||||
|
||||
-qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-qk += tl.dot(q.to(k.dtype), k)
+qk = tl.dot(q.to(k.dtype), k)
|
||||
if BLOCK_DPE > 0:
|
||||
offs_kpe = (
|
||||
offs_kv_loc[None, :] * stride_buf_kbs
|
||||
@@ -179,9 +178,7 @@ def _fwd_kernel(
|
||||
)
|
||||
k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0)
|
||||
|
||||
-qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-qk += tl.dot(q, k)
-
+qk = tl.dot(q, k, out_dtype=tl.float32)
|
||||
if BLOCK_DPE > 0:
|
||||
offs_kpe = (
|
||||
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
|
||||
|
||||
Reference in New Issue
Block a user