From 9c7e392465d96b2deb88e40e92a1ecc6c05e0e58 Mon Sep 17 00:00:00 2001
From: eigen <52445717+yyihuang@users.noreply.github.com>
Date: Fri, 8 Aug 2025 03:16:23 -0400
Subject: [PATCH] bench: add attention sink op benchmark, triton and trtllm-gen
 [B200] (#8932)

Co-authored-by: averyhuang
---
 .../bench_attention_sink_triton.py            | 250 ++++++++++++++++++
 1 file changed, 250 insertions(+)
 create mode 100644 benchmark/bench_attention_sink/bench_attention_sink_triton.py

diff --git a/benchmark/bench_attention_sink/bench_attention_sink_triton.py b/benchmark/bench_attention_sink/bench_attention_sink_triton.py
new file mode 100644
index 000000000..21bc2a593
--- /dev/null
+++ b/benchmark/bench_attention_sink/bench_attention_sink_triton.py
@@ -0,0 +1,250 @@
+import argparse
+
+import torch
+import triton
+
+from sglang.srt.layers.attention.triton_ops.decode_attention import (
+    decode_attention_fwd_grouped,
+)
+from sglang.srt.layers.attention.triton_ops.extend_attention import (
+    extend_attention_fwd,
+)
+
+# GPT-OSS attention configuration: 64 query heads, 8 KV heads, head_dim 64
+head_num = 64
+head_dim = 64
+head_kv_num = 8
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["S"],  # sequence length on the x-axis
+        x_vals=[128, 256, 512, 1024, 2048, 4096],
+        x_log=True,
+        line_arg="B",  # batch size as different lines
+        line_vals=[1, 8, 32, 128],
+        line_names=["B=1", "B=8", "B=32", "B=128"],
+        styles=[
+            ("blue", "-"),
+            ("green", "-"),
+            ("red", "-"),
+            ("cyan", "-"),
+        ],
+        ylabel="TFLOPS",
+        plot_name="attention-sink-triton-decode",
+        args={},
+    )
+)
+def benchmark_decode(B, S, H_Q, H_KV, D):
+    D_V = D
+    dtype = torch.bfloat16
+    seq_len = S
+    total_tokens = B * seq_len
+    device = torch.device("cuda")
+    sm_scale = 1.0 / (D**0.5)
+    max_kv_splits = 8
+    num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device=device)
+
+    # q represents the new token being generated, one per batch
+    q = torch.randn(B, H_Q, D, dtype=dtype, device=device)
+
+    # k_buffer and v_buffer hold the KV cache for all previous tokens
+    k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
+    v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
+
+    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)
+
+    b_seq_len = torch.full((B,), seq_len, dtype=torch.int32, device=device)
+
+    kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
+    kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len, dim=0)
+    kv_indices = torch.arange(total_tokens, dtype=torch.int32, device=device)
+
+    # Per-split partial outputs and log-sum-exp values; the LSE buffer is
+    # per (batch, head, split) and carries no head_dim axis.
+    attn_logits1 = torch.empty(
+        (B, H_Q, max_kv_splits, D_V),
+        dtype=torch.float32,
+        device=device,
+    )
+    attn_lse1 = torch.empty(
+        (B, H_Q, max_kv_splits),
+        dtype=torch.float32,
+        device=device,
+    )
+    # One learned sink logit per query head
+    sink = torch.randn(H_Q, device=device, dtype=torch.float32)
+
+    # warmup
+    for _ in range(5):
+        decode_attention_fwd_grouped(
+            q,
+            k_buffer,
+            v_buffer,
+            o,
+            kv_indptr,
+            kv_indices,
+            attn_logits1,
+            attn_lse1,
+            num_kv_splits,
+            max_kv_splits,
+            sm_scale,
+            logit_cap=0.0,
+            sinks=sink,
+        )
+
+    # benchmark
+    run_step = 500
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    start_event.record()
+    for _ in range(run_step):
+        decode_attention_fwd_grouped(
+            q,
+            k_buffer,
+            v_buffer,
+            o,
+            kv_indptr,
+            kv_indices,
+            attn_logits1,
+            attn_lse1,
+            num_kv_splits,
+            max_kv_splits,
+            sm_scale,
+            logit_cap=0.0,
+            sinks=sink,
+        )
+    end_event.record()
+    torch.cuda.synchronize()
+    ms = start_event.elapsed_time(end_event) / run_step
+
+    # FLOPS: QK^T and PV each contribute one multiply per (B, S, H_Q, D)
+    # element, so 2 * B * S * H_Q * D in total; * 1e-9 / ms converts to TFLOPS.
+    total_flops = 2 * B * S * H_Q * D
+    return total_flops * 1e-9 / ms
+
+
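+# A minimal, hedged reference for what `sinks` does, assuming the GPT-OSS
+# convention: each query head owns one learned sink logit that joins the
+# softmax normalization but contributes no value. This naive sketch is not
+# called by the benchmark; it is only a small-shape sanity check against
+# the Triton kernels, and its name and layout are illustrative, not the
+# kernels' API.
+def naive_decode_with_sink(q, k, v, sink, sm_scale):
+    # q: (B, H_Q, D); k, v: (B, S, H_KV, D); sink: (H_Q,)
+    B, H_Q, _ = q.shape
+    group = H_Q // k.shape[2]
+    # broadcast KV heads up to the query heads (grouped-query attention)
+    k = k.repeat_interleave(group, dim=2)  # (B, S, H_Q, D)
+    v = v.repeat_interleave(group, dim=2)
+    logits = torch.einsum("bhd,bshd->bhs", q.float(), k.float()) * sm_scale
+    # the sink enters the softmax as one extra logit, then its column is dropped
+    logits = torch.cat([logits, sink.view(1, H_Q, 1).expand(B, -1, -1)], dim=-1)
+    probs = torch.softmax(logits, dim=-1)[..., :-1]  # (B, H_Q, S)
+    return torch.einsum("bhs,bshd->bhd", probs, v.float()).to(q.dtype)
+
+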
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["S"],  # sequence length on the x-axis
+        x_vals=[128, 256, 512, 1024, 2048, 4096],
+        x_log=True,
+        line_arg="B",  # batch size as different lines
+        line_vals=[1, 8, 32, 128],
+        line_names=["B=1", "B=8", "B=32", "B=128"],
+        styles=[
+            ("blue", "-"),
+            ("green", "-"),
+            ("red", "-"),
+            ("cyan", "-"),
+        ],
+        ylabel="TFLOPS",
+        plot_name="attention-sink-triton-extend",
+        args={},
+    )
+)
+def benchmark_extend(B, S, H_Q, H_KV, D):
+    # S plays the role of N_CTX in the corresponding sglang extend test
+    dtype = torch.bfloat16
+    device = "cuda"
+
+    # Split S into a cached prefix and an extend (prefill) segment
+    prefill_len = S // 2  # mirrors the test's N_CTX // 2
+    extend_len = S // 4  # keep the extend segment shorter than the prefix
+
+    # Token counts for the extend segment and the cached prefix
+    total_extend_tokens = B * extend_len
+    total_prefix_tokens = B * prefill_len
+
+    # Query, key, and value tensors for the extend tokens
+    q_extend = torch.randn(total_extend_tokens, H_Q, D, dtype=dtype, device=device)
+    k_extend = torch.randn(total_extend_tokens, H_KV, D, dtype=dtype, device=device)
+    v_extend = torch.randn(total_extend_tokens, H_KV, D, dtype=dtype, device=device)
+    o_extend = torch.empty_like(q_extend)
+
+    # KV cache buffers for the prefix tokens
+    k_buffer = torch.randn(total_prefix_tokens, H_KV, D, dtype=dtype, device=device)
+    v_buffer = torch.randn(total_prefix_tokens, H_KV, D, dtype=dtype, device=device)
+
+    # Index pointers: every request has the same extend and prefix lengths
+    qo_indptr = torch.arange(
+        0, (B + 1) * extend_len, extend_len, dtype=torch.int32, device=device
+    )
+    kv_indptr = torch.arange(
+        0, (B + 1) * prefill_len, prefill_len, dtype=torch.int32, device=device
+    )
+    kv_indices = torch.arange(total_prefix_tokens, dtype=torch.int32, device=device)
+
+    sm_scale = 1.0 / (D**0.5)
+    # GPT-OSS configures sliding_window = 128; disabled here (-1 = full attention)
+    sliding_window = -1
+
+    # One learned sink logit per query head
+    sink = torch.randn(H_Q, device=device, dtype=torch.float32)
+
+    # warmup
+    for _ in range(5):
+        extend_attention_fwd(
+            q_extend,
+            k_extend,
+            v_extend,
+            o_extend,
+            k_buffer,
+            v_buffer,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            custom_mask=None,
+            is_causal=True,
+            mask_indptr=None,
+            max_len_extend=extend_len,
+            sm_scale=sm_scale,
+            sliding_window_size=sliding_window,
+            sinks=sink,
+        )
+
+    # benchmark
+    run_step = 500
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    start_event.record()
+    for _ in range(run_step):
+        extend_attention_fwd(
+            q_extend,
+            k_extend,
+            v_extend,
+            o_extend,
+            k_buffer,
+            v_buffer,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            custom_mask=None,
+            is_causal=True,
+            mask_indptr=None,
+            max_len_extend=extend_len,
+            sm_scale=sm_scale,
+            sliding_window_size=sliding_window,
+            sinks=sink,
+        )
+    end_event.record()
+    torch.cuda.synchronize()
+    ms = start_event.elapsed_time(end_event) / run_step
+
+    # FLOPS: QK^T and PV each contribute one multiply per element. With the
+    # causal mask, each extend token attends to prefill_len + ~extend_len / 2
+    # keys on average.
+    total_flops = 2 * total_extend_tokens * H_Q * (prefill_len + extend_len / 2) * D
+    return total_flops * 1e-12 / (ms * 1e-3)  # TFLOPS
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--bench",
+        type=str,
+        default="all",
+        choices=["all", "extend", "decode"],
+        help="all, extend, decode",
+    )
+    args = parser.parse_args()
+
+    kwargs = {
+        "H_Q": head_num,
+        "H_KV": head_kv_num,
+        "D": head_dim,
+    }
+
+    if args.bench in ["all", "decode"]:
+        benchmark_decode.run(print_data=True, show_plots=False, **kwargs)
+
+    if args.bench in ["all", "extend"]:
+        benchmark_extend.run(print_data=True, show_plots=False, **kwargs)
+
+    print("Benchmark finished!")
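+
+# Example invocations (hedged: assumes the sglang repo root as the working
+# directory and a CUDA GPU with the Triton kernels available):
+#   python benchmark/bench_attention_sink/bench_attention_sink_triton.py --bench decode
+#   python benchmark/bench_attention_sink/bench_attention_sink_triton.py --bench extend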