bench: add attention sink op benchmark, triton and trtllm-gen [B200] (#8932)
Co-authored-by: averyhuang <averyh@nvidia.com>
benchmark/bench_attention_sink/bench_attention_sink_triton.py (new file, 250 lines)
@@ -0,0 +1,250 @@
import argparse

import torch
import triton

from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd_grouped,
)
from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd

# gpt oss
head_num = 64
head_dim = 64
head_kv_num = 8
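# Grouped-query attention shape for GPT-OSS: 64 query heads share 8 KV heads
# (64 / 8 = 8 query heads per KV group), each with head_dim 64.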


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["S"],  # sequence length on x-axis
        x_vals=[128, 256, 512, 1024, 2048, 4096],
        x_log=True,
        line_arg="B",  # batch size as different lines
        line_vals=[1, 8, 32, 128],
        line_names=["B=1", "B=8", "B=32", "B=128"],
        styles=[
            ("blue", "-"),
            ("green", "-"),
            ("red", "-"),
            ("cyan", "-"),
        ],
        ylabel="TFLOPS",
        plot_name="attention-sink-triton-decode",
        args={},
    )
)
def benchmark_decode(B, S, H_Q, H_KV, D):
    D_V = D
    dtype = torch.bfloat16
    seq_len = S
    total_tokens = B * seq_len
    device = torch.device("cuda")
    sm_scale = 1.0 / (D**0.5)
    max_kv_splits = 8
    num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
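    # Split-KV decode: each sequence's KV cache is processed in up to
    # `max_kv_splits` chunks (4 requested per batch element here), and the
    # per-chunk partial outputs are combined via their log-sum-exp terms.
    # `attn_logits1` / `attn_lse1` below hold those partial results.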

    # q represents the new token being generated, one per batch
    q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")

    # k_buffer and v_buffer represent all previous tokens
    k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")

    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

    b_seq_len = torch.full((B,), seq_len, device="cuda")

    kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
    kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len, dim=0)
    kv_indices = torch.arange(total_tokens, device="cuda")
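    # CSR-style indexing: kv_indptr[i] : kv_indptr[i + 1] delimits batch
    # element i's slice of kv_indices, which holds its token positions in
    # k_buffer / v_buffer. E.g. B=2, seq_len=3 gives kv_indptr = [0, 3, 6]
    # and kv_indices = [0, 1, 2, 3, 4, 5].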

    attn_logits1 = torch.empty(
        (B, H_Q, max_kv_splits, D_V),
        dtype=torch.float32,
        device="cuda",
    )
    attn_lse1 = torch.empty(
        (B, H_Q, max_kv_splits, D_V),
        dtype=torch.float32,
        device="cuda",
    )
    sink = torch.randn(H_Q, device=device, dtype=torch.float32)
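    # Attention sinks: a learned per-head logit that joins the softmax
    # normalization without a corresponding value vector, letting a head park
    # probability mass on nothing rather than on real tokens. Random values
    # stand in for the learned parameters, which is fine for a speed benchmark.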

    # warmup
    for _ in range(5):
        decode_attention_fwd_grouped(
            q,
            k_buffer,
            v_buffer,
            o,
            kv_indptr,
            kv_indices,
            attn_logits1,
            attn_lse1,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
            logit_cap=0.0,
            sinks=sink,
        )

    # benchmark
    run_step = 500
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(run_step):
        decode_attention_fwd_grouped(
            q,
            k_buffer,
            v_buffer,
            o,
            kv_indptr,
            kv_indices,
            attn_logits1,
            attn_lse1,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
            logit_cap=0.0,
            sinks=sink,
        )
    end_event.record()
    end_event.synchronize()
    torch.cuda.synchronize()
    ms = start_event.elapsed_time(end_event) / run_step
    # One QK^T multiply-accumulate per (batch, head, cached token, dim) tuple,
    # at 2 FLOPs each; the PV matmul is not counted. The 1e-9 factor converts
    # FLOPs per millisecond to TFLOPS.
    tflops = lambda ms: (2 * B * S * H_Q * D) * 1e-9 / ms
    return tflops(ms)
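

# Side note, a minimal sketch: triton.testing.do_bench can time the same call
# with warmup and repetition handled internally (assuming a recent Triton).
# The helper name `time_with_do_bench` is illustrative, not part of sglang;
# call it as time_with_do_bench(lambda: decode_attention_fwd_grouped(...)).
def time_with_do_bench(fn, warmup=25, rep=100):
    return triton.testing.do_bench(fn, warmup=warmup, rep=rep)  # milliseconds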


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["S"],  # sequence length on x-axis
        x_vals=[128, 256, 512, 1024, 2048, 4096],
        x_log=True,
        line_arg="B",  # batch size as different lines
        line_vals=[1, 8, 32, 128],
        line_names=["B=1", "B=8", "B=32", "B=128"],
        styles=[
            ("blue", "-"),
            ("green", "-"),
            ("red", "-"),
            ("cyan", "-"),
        ],
        ylabel="TFLOPS",
        plot_name="attention-sink-triton-extend",
        args={},
    )
)
def benchmark_extend(B, S, H_Q, H_KV, D):
    # S here represents N_CTX from the test
    dtype = torch.bfloat16
    device = "cuda"

    # Split S into prefix and extend lengths
    prefill_len = S // 2  # Similar to test's N_CTX // 2
    extend_len = S // 4  # Make extend length smaller than prefix

    # Calculate total tokens and extend tokens
    total_extend_tokens = B * extend_len
    total_prefix_tokens = B * prefill_len

    # Create query, key, value tensors for extension
    q_extend = torch.randn(total_extend_tokens, H_Q, D, dtype=dtype, device=device)
    k_extend = torch.randn(total_extend_tokens, H_KV, D, dtype=dtype, device=device)
    v_extend = torch.randn(total_extend_tokens, H_KV, D, dtype=dtype, device=device)
    o_extend = torch.empty_like(q_extend)

    # Create key-value buffers for prefix
    k_buffer = torch.randn(total_prefix_tokens, H_KV, D, dtype=dtype, device=device)
    v_buffer = torch.randn(total_prefix_tokens, H_KV, D, dtype=dtype, device=device)

    # Create index pointers
    qo_indptr = torch.arange(0, (B + 1) * extend_len, extend_len, device=device).to(
        torch.int32
    )
    kv_indptr = torch.arange(0, (B + 1) * prefill_len, prefill_len, device=device).to(
        torch.int32
    )
    kv_indices = torch.arange(0, total_prefix_tokens, device=device).to(torch.int32)
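    # E.g. B=2, S=1024: prefill_len=512 and extend_len=256, so
    # qo_indptr = [0, 256, 512] (new tokens per batch element) and
    # kv_indptr = [0, 512, 1024] (cached prefix tokens per batch element).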

    sm_scale = 1.0 / (D**0.5)
    # sliding_window = 128  # From GPT-OSS config, skip for now
    sliding_window = -1  # a non-positive value disables the sliding-window mask

    sink = torch.randn(H_Q, device=device, dtype=torch.float32)

    # warmup
    for _ in range(5):
        extend_attention_fwd(
            q_extend,
            k_extend,
            v_extend,
            o_extend,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask=None,
            is_causal=True,
            mask_indptr=None,
            max_len_extend=extend_len,
            sm_scale=sm_scale,
            sliding_window_size=sliding_window,
            sinks=sink,
        )

    # benchmark
    run_step = 500
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(run_step):
        extend_attention_fwd(
            q_extend,
            k_extend,
            v_extend,
            o_extend,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask=None,
            is_causal=True,
            mask_indptr=None,
            max_len_extend=extend_len,
            sm_scale=sm_scale,
            sliding_window_size=sliding_window,
            sinks=sink,
        )
    end_event.record()
    end_event.synchronize()
    torch.cuda.synchronize()
    ms = start_event.elapsed_time(end_event) / run_step

    # FLOPs: 2 per multiply-accumulate in QK^T; each extend token attends to
    # the full cached prefix plus, under the causal mask, an average of
    # extend_len / 2 extend tokens. The PV matmul is not counted here.
    total_flops = 2 * total_extend_tokens * H_Q * (prefill_len + extend_len / 2) * D
    tflops = lambda ms: total_flops * 1e-12 / (ms * 1e-3)  # ms -> s, FLOPs -> TFLOPs
    return tflops(ms)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--bench", type=str, default="all", help="all, extend, decode")
    args = parser.parse_args()

    kwargs = {
        "H_Q": head_num,
        "H_KV": head_kv_num,
        "D": head_dim,
    }

    if args.bench in ["all", "decode"]:
        benchmark_decode.run(print_data=True, show_plots=False, **kwargs)

    if args.bench in ["all", "extend"]:
        benchmark_extend.run(print_data=True, show_plots=False, **kwargs)

    print("Benchmark finished!")
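
# Example invocation (assumes a CUDA GPU with sglang and triton installed):
#   python bench_attention_sink_triton.py --bench decode
# perf_report prints one table per benchmark; pass save_path to run() to also
# write the plots to disk.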