From db71c38fcd7356af3cac60eb8639db540d980daa Mon Sep 17 00:00:00 2001 From: Yi Zhang <1109276519@qq.com> Date: Thu, 18 Sep 2025 20:51:40 +0800 Subject: [PATCH] Scale kkt after reduction (#10604) --- .../sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py b/python/sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py index 699350d31..7a25e68c4 100644 --- a/python/sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +++ b/python/sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py @@ -74,8 +74,7 @@ def chunk_scaled_dot_kkt_fwd_kernel( (1, 0), ) b_k = tl.load(p_k, boundary_check=(0, 1)) - b_kb = b_k * b_beta[:, None] - b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k)) + b_A += tl.dot(b_k, tl.trans(b_k)) if USE_G: p_g = tl.make_block_ptr( @@ -85,6 +84,7 @@ def chunk_scaled_dot_kkt_fwd_kernel( b_g_diff = b_g[:, None] - b_g[None, :] b_A = b_A * safe_exp(b_g_diff) + b_A *= b_beta[:, None] b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0) p_A = tl.make_block_ptr( A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)