Optimize the MiniMax-Text-01 lightning_attn_decode Triton kernel (#2966)
This commit is contained in:
@@ -23,7 +23,10 @@ def _decode_kernel(
|
|||||||
h: tl.constexpr,
|
h: tl.constexpr,
|
||||||
n: tl.constexpr,
|
n: tl.constexpr,
|
||||||
d: tl.constexpr,
|
d: tl.constexpr,
|
||||||
|
d_original: tl.constexpr,
|
||||||
e: tl.constexpr,
|
e: tl.constexpr,
|
||||||
|
e_original: tl.constexpr,
|
||||||
|
BLOCK_SIZE: tl.constexpr = 32,
|
||||||
):
|
):
|
||||||
off_bh = tl.program_id(0)
|
off_bh = tl.program_id(0)
|
||||||
off_h = off_bh % h
|
off_h = off_bh % h
|
||||||
@@ -39,21 +42,38 @@ def _decode_kernel(
|
|||||||
d_idx = tl.arange(0, d)
|
d_idx = tl.arange(0, d)
|
||||||
e_idx = tl.arange(0, e)
|
e_idx = tl.arange(0, e)
|
||||||
|
|
||||||
q = tl.load(Q + qk_offset + d_idx)
|
# Create masks for original dimensions
|
||||||
k = tl.load(K + qk_offset + d_idx)
|
d_mask = d_idx < d_original
|
||||||
v = tl.load(V + v_offset + e_idx)
|
e_mask = e_idx < e_original
|
||||||
|
|
||||||
kv = tl.load(KV + kv_offset + d_idx[:, None] * e + e_idx[None, :])
|
# Load with masking
|
||||||
|
q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
|
||||||
|
k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
|
||||||
|
v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)
|
||||||
|
|
||||||
|
# Load KV with 2D masking
|
||||||
|
kv = tl.load(
|
||||||
|
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
|
||||||
|
mask=(d_mask[:, None] & e_mask[None, :]),
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute outer product using element-wise operations
|
||||||
k_v_prod = k[:, None] * v[None, :]
|
k_v_prod = k[:, None] * v[None, :]
|
||||||
kv = ratio * kv + k_v_prod
|
kv = ratio * kv + k_v_prod
|
||||||
|
|
||||||
|
# Store KV with 2D masking
|
||||||
tl.store(
|
tl.store(
|
||||||
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], kv.to(KV.dtype.element_ty)
|
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
|
||||||
|
kv.to(KV.dtype.element_ty),
|
||||||
|
mask=(d_mask[:, None] & e_mask[None, :]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Compute matrix-vector multiplication using element-wise operations and reduction
|
||||||
o = tl.sum(q[:, None] * kv, axis=0)
|
o = tl.sum(q[:, None] * kv, axis=0)
|
||||||
tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty))
|
|
||||||
|
# Store output with masking
|
||||||
|
tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)
|
||||||
|
|
||||||
|
|
||||||
def lightning_attn_decode(q, k, v, kv, s):
|
def lightning_attn_decode(q, k, v, kv, s):
|
||||||
@@ -62,26 +82,27 @@ def lightning_attn_decode(q, k, v, kv, s):
|
|||||||
e = v.shape[-1]
|
e = v.shape[-1]
|
||||||
assert n == 1, "Sequence length must be 1 in decode mode"
|
assert n == 1, "Sequence length must be 1 in decode mode"
|
||||||
|
|
||||||
# Pad dimensions to power of 2
|
# Get padded dimensions (power of 2)
|
||||||
d_padded = next_power_of_2(d)
|
d_padded = next_power_of_2(d)
|
||||||
e_padded = next_power_of_2(e)
|
e_padded = next_power_of_2(e)
|
||||||
|
|
||||||
# Pad inputs
|
|
||||||
q_padded = F.pad(q, (0, d_padded - d))
|
|
||||||
k_padded = F.pad(k, (0, d_padded - d))
|
|
||||||
v_padded = F.pad(v, (0, e_padded - e))
|
|
||||||
kv_padded = F.pad(kv, (0, e_padded - e, 0, d_padded - d))
|
|
||||||
|
|
||||||
# Ensure inputs are contiguous
|
|
||||||
q_padded = q_padded.contiguous()
|
|
||||||
k_padded = k_padded.contiguous()
|
|
||||||
v_padded = v_padded.contiguous()
|
|
||||||
kv_padded = kv_padded.contiguous().to(torch.float32)
|
|
||||||
s = s.contiguous()
|
|
||||||
|
|
||||||
# Create output tensor (padded)
|
# Create output tensor (padded)
|
||||||
o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
|
o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
|
||||||
|
|
||||||
|
# Allocate uninitialized padded buffers (torch.empty, no zero fill); the kernel's masked loads and stores never touch the padding
|
||||||
|
q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
|
||||||
|
k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
|
||||||
|
v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
|
||||||
|
kv_padded = torch.empty(
|
||||||
|
b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy data to padded tensors
|
||||||
|
q_padded[..., :d] = q
|
||||||
|
k_padded[..., :d] = k
|
||||||
|
v_padded[..., :e] = v
|
||||||
|
kv_padded[..., :d, :e] = kv
|
||||||
|
|
||||||
# Launch kernel
|
# Launch kernel
|
||||||
grid = (b * h, 1)
|
grid = (b * h, 1)
|
||||||
_decode_kernel[grid](
|
_decode_kernel[grid](
|
||||||
@@ -95,10 +116,12 @@ def lightning_attn_decode(q, k, v, kv, s):
|
|||||||
h=h,
|
h=h,
|
||||||
n=n,
|
n=n,
|
||||||
d=d_padded,
|
d=d_padded,
|
||||||
|
d_original=d,
|
||||||
e=e_padded,
|
e=e_padded,
|
||||||
|
e_original=e,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove padding
|
# Get unpadded outputs
|
||||||
o = o_padded[..., :e]
|
o = o_padded[..., :e]
|
||||||
kv_out = kv_padded[..., :d, :e]
|
kv_out = kv_padded[..., :d, :e]
|
||||||
|
|
||||||
@@ -351,6 +374,8 @@ def test_lightning_attention_implementations(model_params):
|
|||||||
msg="Lightning attention implementations produce different kv results",
|
msg="Lightning attention implementations produce different kv results",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print("✅ Two implementations match")
|
||||||
|
|
||||||
|
|
||||||
def _build_slope_tensor(n_attention_heads: int):
|
def _build_slope_tensor(n_attention_heads: int):
|
||||||
def get_slopes(n):
|
def get_slopes(n):
|
||||||
@@ -375,7 +400,7 @@ def _build_slope_tensor(n_attention_heads: int):
|
|||||||
|
|
||||||
|
|
||||||
def get_benchmark():
|
def get_benchmark():
|
||||||
batch_size_range = [2**i for i in range(0, 12)] # max 2048
|
batch_size_range = [i for i in range(1, 33)] # max 32
|
||||||
seq_length_range = [1] # decode mode sequence length is fixed to 1
|
seq_length_range = [1] # decode mode sequence length is fixed to 1
|
||||||
configs = list(itertools.product(batch_size_range, seq_length_range))
|
configs = list(itertools.product(batch_size_range, seq_length_range))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user