Scale kkt after reduction (#10604)
This commit is contained in:
@@ -74,8 +74,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
|
||||
(1, 0),
|
||||
)
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_kb = b_k * b_beta[:, None]
|
||||
b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
|
||||
b_A += tl.dot(b_k, tl.trans(b_k))
|
||||
|
||||
if USE_G:
|
||||
p_g = tl.make_block_ptr(
|
||||
@@ -85,6 +84,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
|
||||
b_g_diff = b_g[:, None] - b_g[None, :]
|
||||
b_A = b_A * safe_exp(b_g_diff)
|
||||
|
||||
b_A *= b_beta[:, None]
|
||||
b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
|
||||
p_A = tl.make_block_ptr(
|
||||
A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
|
||||
|
||||
Reference in New Issue
Block a user