Scale kkt after reduction (#10604)
This commit is contained in:
@@ -74,8 +74,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
|
|||||||
(1, 0),
|
(1, 0),
|
||||||
)
|
)
|
||||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||||
b_kb = b_k * b_beta[:, None]
|
b_A += tl.dot(b_k, tl.trans(b_k))
|
||||||
b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
|
|
||||||
|
|
||||||
if USE_G:
|
if USE_G:
|
||||||
p_g = tl.make_block_ptr(
|
p_g = tl.make_block_ptr(
|
||||||
@@ -85,6 +84,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
|
|||||||
b_g_diff = b_g[:, None] - b_g[None, :]
|
b_g_diff = b_g[:, None] - b_g[None, :]
|
||||||
b_A = b_A * safe_exp(b_g_diff)
|
b_A = b_A * safe_exp(b_g_diff)
|
||||||
|
|
||||||
|
b_A *= b_beta[:, None]
|
||||||
b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
|
b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
|
||||||
p_A = tl.make_block_ptr(
|
p_A = tl.make_block_ptr(
|
||||||
A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
|
A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
|
||||||
|
|||||||
Reference in New Issue
Block a user