optimize MiniMax-Text-01 lightning_attn_decode triton (#2966)
This commit is contained in:
@@ -23,7 +23,10 @@ def _decode_kernel(
|
||||
h: tl.constexpr,
|
||||
n: tl.constexpr,
|
||||
d: tl.constexpr,
|
||||
d_original: tl.constexpr,
|
||||
e: tl.constexpr,
|
||||
e_original: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr = 32,
|
||||
):
|
||||
off_bh = tl.program_id(0)
|
||||
off_h = off_bh % h
|
||||
@@ -39,21 +42,38 @@ def _decode_kernel(
|
||||
d_idx = tl.arange(0, d)
|
||||
e_idx = tl.arange(0, e)
|
||||
|
||||
q = tl.load(Q + qk_offset + d_idx)
|
||||
k = tl.load(K + qk_offset + d_idx)
|
||||
v = tl.load(V + v_offset + e_idx)
|
||||
# Create masks for original dimensions
|
||||
d_mask = d_idx < d_original
|
||||
e_mask = e_idx < e_original
|
||||
|
||||
kv = tl.load(KV + kv_offset + d_idx[:, None] * e + e_idx[None, :])
|
||||
# Load with masking
|
||||
q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
|
||||
k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
|
||||
v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)
|
||||
|
||||
# Load KV with 2D masking
|
||||
kv = tl.load(
|
||||
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
|
||||
mask=(d_mask[:, None] & e_mask[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
# Compute outer product using element-wise operations
|
||||
k_v_prod = k[:, None] * v[None, :]
|
||||
kv = ratio * kv + k_v_prod
|
||||
|
||||
# Store KV with 2D masking
|
||||
tl.store(
|
||||
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], kv.to(KV.dtype.element_ty)
|
||||
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
|
||||
kv.to(KV.dtype.element_ty),
|
||||
mask=(d_mask[:, None] & e_mask[None, :]),
|
||||
)
|
||||
|
||||
# Compute matrix-vector multiplication using element-wise operations and reduction
|
||||
o = tl.sum(q[:, None] * kv, axis=0)
|
||||
tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty))
|
||||
|
||||
# Store output with masking
|
||||
tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)
|
||||
|
||||
|
||||
def lightning_attn_decode(q, k, v, kv, s):
|
||||
@@ -62,26 +82,27 @@ def lightning_attn_decode(q, k, v, kv, s):
|
||||
e = v.shape[-1]
|
||||
assert n == 1, "Sequence length must be 1 in decode mode"
|
||||
|
||||
# Pad dimensions to power of 2
|
||||
# Get padded dimensions (power of 2)
|
||||
d_padded = next_power_of_2(d)
|
||||
e_padded = next_power_of_2(e)
|
||||
|
||||
# Pad inputs
|
||||
q_padded = F.pad(q, (0, d_padded - d))
|
||||
k_padded = F.pad(k, (0, d_padded - d))
|
||||
v_padded = F.pad(v, (0, e_padded - e))
|
||||
kv_padded = F.pad(kv, (0, e_padded - e, 0, d_padded - d))
|
||||
|
||||
# Ensure inputs are contiguous
|
||||
q_padded = q_padded.contiguous()
|
||||
k_padded = k_padded.contiguous()
|
||||
v_padded = v_padded.contiguous()
|
||||
kv_padded = kv_padded.contiguous().to(torch.float32)
|
||||
s = s.contiguous()
|
||||
|
||||
# Create output tensor (padded)
|
||||
o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
|
||||
|
||||
# Create padded tensors without actually padding the data
|
||||
q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
|
||||
k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
|
||||
v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
|
||||
kv_padded = torch.empty(
|
||||
b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
|
||||
)
|
||||
|
||||
# Copy data to padded tensors
|
||||
q_padded[..., :d] = q
|
||||
k_padded[..., :d] = k
|
||||
v_padded[..., :e] = v
|
||||
kv_padded[..., :d, :e] = kv
|
||||
|
||||
# Launch kernel
|
||||
grid = (b * h, 1)
|
||||
_decode_kernel[grid](
|
||||
@@ -95,10 +116,12 @@ def lightning_attn_decode(q, k, v, kv, s):
|
||||
h=h,
|
||||
n=n,
|
||||
d=d_padded,
|
||||
d_original=d,
|
||||
e=e_padded,
|
||||
e_original=e,
|
||||
)
|
||||
|
||||
# Remove padding
|
||||
# Get unpadded outputs
|
||||
o = o_padded[..., :e]
|
||||
kv_out = kv_padded[..., :d, :e]
|
||||
|
||||
@@ -351,6 +374,8 @@ def test_lightning_attention_implementations(model_params):
|
||||
msg="Lightning attention implementations produce different kv results",
|
||||
)
|
||||
|
||||
print("✅ Two implementations match")
|
||||
|
||||
|
||||
def _build_slope_tensor(n_attention_heads: int):
|
||||
def get_slopes(n):
|
||||
@@ -375,7 +400,7 @@ def _build_slope_tensor(n_attention_heads: int):
|
||||
|
||||
|
||||
def get_benchmark():
|
||||
batch_size_range = [2**i for i in range(0, 12)] # max 2048
|
||||
batch_size_range = [i for i in range(1, 33)] # max 32
|
||||
seq_length_range = [1] # decode mode sequence length is fixed to 1
|
||||
configs = list(itertools.product(batch_size_range, seq_length_range))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user