diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py index 31a002f43..6c7686971 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/extend_attention.py @@ -127,8 +127,7 @@ def _fwd_kernel( ) k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q.to(k.dtype), k) + qk = tl.dot(q.to(k.dtype), k) if BLOCK_DPE > 0: offs_kpe = ( offs_kv_loc[None, :] * stride_buf_kbs @@ -179,9 +178,7 @@ def _fwd_kernel( ) k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - + qk = tl.dot(q, k, out_dtype=tl.float32) if BLOCK_DPE > 0: offs_kpe = ( (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])