From ab31793661956a448bf4f47098f0f3907c4841e1 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Thu, 16 Jan 2025 14:18:29 +0800 Subject: [PATCH] [kernel] MiniMax-Text-01 prefill lightning_attn with triton (#2911) --- .../benchmark_lighting_attention_prefill.py | 601 ++++++++++++++++++ 1 file changed, 601 insertions(+) create mode 100644 benchmark/kernels/minmax-text-01-lighting_attention/benchmark_lighting_attention_prefill.py diff --git a/benchmark/kernels/minmax-text-01-lighting_attention/benchmark_lighting_attention_prefill.py b/benchmark/kernels/minmax-text-01-lighting_attention/benchmark_lighting_attention_prefill.py new file mode 100644 index 000000000..3db4694c7 --- /dev/null +++ b/benchmark/kernels/minmax-text-01-lighting_attention/benchmark_lighting_attention_prefill.py @@ -0,0 +1,601 @@ +import itertools +import math +import os +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl +from einops import rearrange + + +# Adapted from https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py +@triton.jit +def _fwd_kernel( + Q, + K, + V, + Out, + S, # log lambda + b: tl.constexpr, + h: tl.constexpr, + n: tl.constexpr, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK: tl.constexpr, + BLOCK_MODEL: tl.constexpr, +): + ##### get offset + off_bh = tl.program_id(0) + off_h = off_bh % h + off_e = tl.program_id(1) + qk_offset = off_bh * n * d + v_offset = off_bh * n * e + o_offset = off_bh * n * e + # channel offset + e_offset = off_e * BLOCK_MODEL + + ##### get block ptr + Q_block_ptr = Q + qk_offset + tl.arange(0, d)[None, :] + K_trans_block_ptr = K + qk_offset + tl.arange(0, d)[:, None] + V_block_ptr = V + v_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :] + O_block_ptr = Out + o_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :] + S_block_ptr = S + off_h + + ##### init diag decay(Lambda); q, k decay; kv + s = tl.load(S_block_ptr) + # q, k decay + off_block = tl.arange( + 0, BLOCK + ) # Not bug, this is a bit different from algorithm 1, but is mathematically equivalent + q_decay = tl.exp(-s.to(tl.float32) * off_block[:, None]) + k_trans_decay = tl.exp(-s.to(tl.float32) * (BLOCK - off_block[None, :])) + block_decay = tl.exp(-s.to(tl.float32) * BLOCK) + # diag decay + index = off_block[:, None] - off_block[None, :] + s_index = s * index + s_index = tl.where(index >= 0, -s_index, float("-inf")) + diag_decay = tl.exp(s_index) + kv = tl.zeros([d, BLOCK_MODEL], dtype=tl.float32) + + ##### compute + for i in range(NUM_BLOCK): + # load + q = tl.load( + Q_block_ptr + off_block[:, None] * d, mask=off_block[:, None] < n, other=0.0 + ).to(tl.float32) + k_trans = tl.load( + K_trans_block_ptr + off_block[None, :] * d, + mask=off_block[None, :] < n, + other=0.0, + ).to(tl.float32) + v = tl.load( + V_block_ptr + off_block[:, None] * e, mask=off_block[:, None] < n, other=0.0 + ).to(tl.float32) + + # compute + qk = tl.dot(q, k_trans) * diag_decay + o_intra = tl.dot(qk, v) + o_inter = tl.dot(q, kv) * q_decay + o = o_intra + o_inter + + # save and update + tl.store( + O_block_ptr + off_block[:, None] * e, + o.to(O_block_ptr.dtype.element_ty), + mask=off_block[:, None] < n, + ) + kv = block_decay * kv + tl.dot(k_trans * k_trans_decay, v) + off_block += BLOCK + + +def lightning_attn2(q, k, v, s): + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + s = s.contiguous() + + b, h, n, d = 
q.shape + e = v.shape[-1] + + # Pad d to next power of 2 + d_padded = next_power_of_2(d) + if d_padded != d: + q_padded = F.pad(q, (0, d_padded - d)) + k_padded = F.pad(k, (0, d_padded - d)) + else: + q_padded = q + k_padded = k + + # Pad e to next power of 2 + e_padded = next_power_of_2(e) + if e_padded != e: + v_padded = F.pad(v, (0, e_padded - e)) + else: + v_padded = v + + o_padded = torch.empty((b, h, n, e_padded), dtype=q.dtype, device=q.device) + + BLOCK = 64 + NUM_BLOCK = triton.cdiv(q.shape[2], BLOCK) + # parallel over channel + BLOCK_MODEL = min(triton.next_power_of_2(e_padded), 32) + grid = (b * h, triton.cdiv(e_padded, BLOCK_MODEL)) + + _fwd_kernel[grid]( + q_padded, + k_padded, + v_padded, + o_padded, + s, + b, + h, + n, + d_padded, + e_padded, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + BLOCK_MODEL=BLOCK_MODEL, + ) + + # Remove padding from output + if e_padded != e: + o = o_padded[..., :e] + else: + o = o_padded + + return o + + +def is_support(dim): + return 16 % dim + + +def next_power_of_2(n): + return 2 ** (int(math.ceil(math.log(n, 2)))) + + +def lightning_attn_func(q, k, v, s): + b, h, n, d = q.shape + e = v.shape[-1] + assert is_support(d) and is_support(e) + + # pad v's feature dim to power of 2 + e_pad = next_power_of_2(e) + need_pad = e_pad != e + if need_pad: + v = F.pad(v, (0, e_pad - e)) + + if d > 128: + # split over head + if 64 % d: + m = 64 + elif 32 % d: + m = 32 + elif 16 % d: + m = 16 + arr = [m * i for i in range(d // m + 1)] + if arr[-1] != d: + arr.append(d) + n = len(arr) + o = 0 + for i in range(n - 1): + start = arr[i] + end = arr[i + 1] + q1 = q[..., start:end] + k1 = k[..., start:end] + o += lightning_attn2(q1, k1, v, s) + else: + o = lightning_attn2(q, k, v, s) + + if need_pad: + o = o[:, :, :, :e] + + return o + + +debug = eval(os.environ.get("debug", default="False")) + +BLOCK = 256 + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxText01 +class MiniMaxText01RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MiniMaxText01RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py +def get_activation_fn(activation): + if debug: + logger.info(f"activation: {activation}") + if activation == "gelu": + return F.gelu + elif activation == "relu": + return F.relu + elif activation == "elu": + return F.elu + elif activation == "sigmoid": + return F.sigmoid + elif activation == "exp": + + def f(x): + with torch.no_grad(): + x_max = torch.max(x, dim=-1, keepdims=True).values + y = torch.exp(x - x_max) + + return y + + return f + elif activation == "leak": + return F.leaky_relu + elif activation == "1+elu": + + def f(x): + return 1 + F.elu(x) + + return f + elif activation == "2+elu": + + def f(x): + return 2 + F.elu(x) + + return f + elif activation == "silu" or activation == "swish": + return F.silu + elif activation == "sine": + return torch.sin + else: + logger.info(f"activation: does not support {activation}, use Identity!!!") + return lambda x: x + + +# Copied from 
https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py +class MiniMaxText01LightningAttention(nn.Module): + def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs): + super().__init__() + if config is None: + config = type("Config", (), kwargs) + + bias = False + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + + self.out_proj = nn.Linear( + self.head_dim * self.num_heads, self.hidden_size, bias=bias + ) + self.act = get_activation_fn(config.hidden_act) + self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads) + + self.qkv_proj = nn.Linear( + self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias + ) + self.output_gate = nn.Linear( + self.hidden_size, self.head_dim * self.num_heads, bias=bias + ) + + # for inference only + self.offset = 0 + self.layer_idx = layer_idx + + def forward( + self, + hidden_states, + attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m) + output_attentions: bool = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + slope_rate: Optional[torch.Tensor] = None, + **kwargs, + ): + if (not self.training) and (not do_eval): + return self.inference( + hidden_states, + attn_mask, + output_attentions, + past_key_value, + use_cache, + slope_rate, + ) + + def inference( + self, + x, + attn_mask: Optional[torch.Tensor] = None, # (b, n) + output_attentions: bool = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1) + ): + # x: b n d + b, n, d = x.shape + # linear map + qkv = self.act(self.qkv_proj(x)) + new_shape = qkv.size()[:-1] + (self.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if past_key_value is None: + self.offset = q.shape[-2] + else: + self.offset += 1 + + # for align with metaseq + ratio = torch.exp(-slope_rate) + + # only use for the first time + if past_key_value is None: + slope_rate = slope_rate.to(torch.float32) + if attn_mask is not None: + v = v.masked_fill( + (1 - attn_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0 + ) + NUM_BLOCK = (n + BLOCK - 1) // BLOCK + b, h, n, d = q.shape + e = v.shape[-1] + # other + array = torch.arange(BLOCK).to(q) + 1 + q_decay = torch.exp(-slope_rate * array.reshape(-1, 1)) + k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1))) + index = array[:, None] - array[None, :] + s_index = ( + slope_rate + * index[ + None, + None, + ] + ) + s_index = torch.where(index >= 0, -s_index, float("-inf")) + diag_decay = torch.exp(s_index) + + kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device) + output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device) + for i in range(NUM_BLOCK): + si = i * BLOCK + ei = min(si + BLOCK, n) + m = ei - si + qi = q[:, :, si:ei].contiguous() + ki = k[:, :, si:ei].contiguous() + vi = v[:, :, si:ei].contiguous() + qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32) + + # diag + qk = ( + torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32) + * diag_decay[:, :, :m, :m] + ) + qkv_diag = torch.matmul(qk, vi.to(torch.float32)) + block_decay = torch.exp(-slope_rate * m) + output[:, :, si:ei] = qkv_none_diag + qkv_diag + kv = block_decay * kv + torch.matmul( + (ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi 
+ ) + + else: + kv = past_key_value + output = [] + for i in range(n): + kv = ratio * kv + torch.einsum( + "... n d, ... n e -> ... d e", + k[:, :, i : i + 1], + v[:, :, i : i + 1], + ) + qkv = torch.einsum( + "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype) + ) + output.append(qkv) + output = torch.concat(output, dim=-2) + # reshape + output = rearrange(output, "b h n d -> b n (h d)") + # normalize + output = self.norm(output) + # gate + output = F.sigmoid(self.output_gate(x)) * output + # outproj + output = self.out_proj(output) + + attn_weights = None + + return output, attn_weights, kv + + +def _build_slope_tensor(n_attention_heads: int): + def get_slopes(n): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2( + n + ) # In the paper, we only train models that have 2^a heads for some a. This function has + else: # some good properties that only occur when the input is a power of 2. To maintain that even + closest_power_of_2 = 2 ** math.floor( + math.log2(n) + ) # when the number of heads is not a power of 2, we use this workaround. + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + # h, 1, 1 + slopes = torch.tensor(get_slopes(n_attention_heads)).reshape( + n_attention_heads, 1, 1 + ) + + return slopes + + +def test_lightning_attention_implementations(model_params): + torch.manual_seed(42) + + batch_size = 2 + seq_len = 1024 + dtype = torch.bfloat16 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + hidden_states = torch.randn( + batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device + ) + + attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device) + + slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device) + + model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device) + model_attn.eval() + + with torch.no_grad(): + model_output, _, _ = model_attn.inference( + hidden_states, attn_mask=attention_mask, slope_rate=slope_rate + ) + + qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) + new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + lib_output = lightning_attn_func(q, k, v, slope_rate) + lib_output = lib_output.transpose(1, 2).contiguous() + lib_output = lib_output.view(batch_size, seq_len, -1) + lib_output = model_attn.norm(lib_output) + lib_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output + lib_output = model_attn.out_proj(lib_output) + + torch.testing.assert_close( + model_output, + lib_output, + rtol=1e-3, + atol=1e-2, + msg="Lightning attention implementations produce different results", + ) + + +def get_benchmark(): + batch_size_range = [2**i for i in range(0, 7)] # max 64 + seq_length_range = [256, 512, 1024, 2048, 4096] # max 4096 + configs = list(itertools.product(batch_size_range, seq_length_range)) + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["MiniMax-Text-01", "OpenNLPLab"], + line_names=[ + "MiniMax-Text-01 Model Implementation", + "OpenNLPLab Library Implementation", + ], + 
styles=[("blue", "-"), ("green", "-")], + ylabel="us", + plot_name="lightning-attention-prefill-performance", + args={}, + ) + ) + def benchmark(batch_size, seq_len, provider): + dtype = torch.bfloat16 + device = torch.device("cuda") + + params = { + "hidden_size": 6144, + "num_attention_heads": 64, + "head_dim": 96, + "hidden_act": "gelu", + } + + hidden_states = torch.randn( + batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device + ) + + attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device) + + slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device) + model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device) + model_attn.eval() + + quantiles = [0.5, 0.2, 0.8] + if provider == "MiniMax-Text-01": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: model_attn.inference( + hidden_states, attn_mask=attention_mask, slope_rate=slope_rate + ), + quantiles=quantiles, + ) + else: + + def run_lib(): + qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) + new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + lib_output = lightning_attn_func(q, k, v, slope_rate) + lib_output = lib_output.transpose(1, 2).contiguous() + lib_output = lib_output.view(batch_size, seq_len, -1) + lib_output = model_attn.norm(lib_output) + lib_output = ( + torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output + ) + return model_attn.out_proj(lib_output) + + ms, min_ms, max_ms = triton.testing.do_bench( + run_lib, + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_path", + type=str, + default="./configs/benchmark_ops/lightning_attention_prefill/", + help="Path to save lightning attention prefill benchmark results", + ) + args = parser.parse_args() + + # Run correctness test first + # Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json + params = { + "hidden_size": 6144, + "num_attention_heads": 64, + "head_dim": 96, + "hidden_act": "silu", + } + test_lightning_attention_implementations(params) + + # Run performance benchmark + benchmark = get_benchmark() + benchmark.run(print_data=True, save_path=args.save_path)