From 00b19f198f198bd2f7182596773d80f5217ab757 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Sun, 1 Sep 2024 03:12:06 -0700 Subject: [PATCH] [triton] Remove the zero initialization of qk_acc by directly writing the result (#1288) --- python/sglang/srt/layers/extend_attention.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py index 31a002f43..6c7686971 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/extend_attention.py @@ -127,8 +127,7 @@ def _fwd_kernel( ) k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q.to(k.dtype), k) + qk = tl.dot(q.to(k.dtype), k) if BLOCK_DPE > 0: offs_kpe = ( offs_kv_loc[None, :] * stride_buf_kbs @@ -179,9 +178,7 @@ def _fwd_kernel( ) k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - + qk = tl.dot(q, k, out_dtype=tl.float32) if BLOCK_DPE > 0: offs_kpe = ( (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])