# sglang/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py


import itertools
import math

import cudnn
import torch
import torch.utils.benchmark as benchmark
import triton
import triton.language as tl
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd,
)
from sglang.srt.utils import should_use_tensor_core
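
# Timing helpers: benchmark_forward() measures a callable with
# torch.utils.benchmark.Timer (optionally under CUDA autocast), and
# time_fwd() reduces that measurement to a mean latency in microseconds.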
def benchmark_forward(
    fn,
    *inputs,
    repeats=10,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    def amp_wrapper(*inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    return t, m


def time_fwd(func, *args, **kwargs):
    time_f = benchmark_forward(func, *args, **kwargs)
    return time_f[1].mean * 1e6
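

# Benchmark SGLang's Triton decode kernel (decode_attention_fwd). Builds the
# flat K/V buffers and the req_to_token indirection the kernel expects, warms
# up, then returns (mean latency in us, attention output).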
def decode_attention_sglang(
    q,
    kv_data,
    batch_size,
    kv_len,
    head_num_q,
    head_num_kv,
    head_dim,
    num_kv_splits,
    warmup=10,
):
    k_buffer = kv_data[0].view(-1, head_num_kv, head_dim)
    v_buffer = kv_data[1].view(-1, head_num_kv, head_dim)
    o = torch.empty_like(q)
    total_tokens = batch_size * kv_len
    req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
    b_req_idx = torch.arange(0, batch_size).to(0).int()
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")
    max_len_in_batch = kv_len
    sm_scale = 1.0 / (head_dim**0.5)
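
    # Intermediate buffer for the split-KV kernel: each of the num_kv_splits
    # partial results stores head_dim output values plus one extra slot for
    # the running log-sum-exp, hence the trailing head_dim + 1.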
    attn_logits = torch.empty(
        (batch_size, head_num_q, num_kv_splits, head_dim + 1),
        dtype=torch.float32,
        device="cuda",
    )

    for _ in range(warmup):
        decode_attention_fwd(
            q,
            k_buffer,
            v_buffer,
            o,
            req_to_token,
            b_req_idx,
            b_seq_len,
            attn_logits,
            num_kv_splits,
            sm_scale,
        )
    f = time_fwd(
        decode_attention_fwd,
        q,
        k_buffer,
        v_buffer,
        o,
        req_to_token,
        b_req_idx,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
    )
    return f, o
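

# Benchmark FlashInfer's paged-KV decode kernel. The wrapper allocates a
# 128 MiB workspace once and is reused across calls; page size is fixed at 1,
# so kv_indices holds one entry per cached token. The kernel is wrapped in a
# torch.autograd.Function so callers invoke it via .apply.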
def decode_attention_flashinfer(dtype, head_num_q, head_num_kv):
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
    use_tensor_cores = should_use_tensor_core(
        kv_cache_dtype=dtype,
        num_attention_heads=head_num_q,
        num_kv_heads=head_num_kv,
    )
    flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
    )

    class FlashinferAttention(torch.autograd.Function):
        @staticmethod
        def forward(
            ctx,
            q,
            kv_data,
            batch_size,
            kv_len,
            head_num_q,
            head_num_kv,
            head_dim,
            dtype,
            warmup=10,
        ):
            total_tokens = batch_size * kv_len
            kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
            kv_indices = torch.arange(0, total_tokens).to(0).int()
            kv_last_page_len = torch.full(
                (batch_size,), 1, dtype=torch.int32, device="cuda"
            )

            flashinfer_decode_wrapper.end_forward()
            flashinfer_decode_wrapper.begin_forward(
                kv_indptr,
                kv_indices,
                kv_last_page_len,
                head_num_q,
                head_num_kv,
                head_dim,
                1,  # page size
                pos_encoding_mode="NONE",
                data_type=dtype,
            )
            for _ in range(warmup):
                o = flashinfer_decode_wrapper.forward(
                    q.contiguous().view(-1, head_num_q, head_dim), kv_data
                )
            f = time_fwd(
                flashinfer_decode_wrapper.forward,
                q.contiguous().view(-1, head_num_q, head_dim),
                kv_data,
            )
            return f, o

    return FlashinferAttention
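

# Map a torch dtype to the corresponding cudnn.data_type enum value.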
def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")
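

# Benchmark cuDNN's paged SDPA through the cudnn frontend pygraph API. The KV
# cache is re-expressed in cuDNN's paged layout: per-token blocks
# (block_size = 1) gathered into container tensors and addressed through page
# tables of block offsets. The graph is built and compiled once, then timed
# with a pre-allocated workspace.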
def decode_attention_cudnn(
    q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype, warmup=10
):
    # Prepare data: contiguous q, k, v
    dims_q = (batch_size, head_num_q, 1, head_dim)
    strides_q = (head_num_q * head_dim, head_dim, head_num_q * head_dim, 1)
    q_gpu = q.as_strided(dims_q, strides_q)
    o_gpu = (
        torch.empty(batch_size * head_num_q * head_dim)
        .half()
        .cuda()
        .as_strided(dims_q, strides_q)
    )

    dims_kv = (batch_size, head_num_kv, kv_len, head_dim)
    strides_kv = (
        kv_len * head_num_kv * head_dim,
        head_dim,
        head_num_kv * head_dim,
        1,
    )
    k_gpu = kv_data[0].as_strided(dims_kv, strides_kv)
    v_gpu = kv_data[1].as_strided(dims_kv, strides_kv)
    seq_len_q_gpu = torch.full((batch_size, 1, 1, 1), 1, device="cuda")
    seq_len_kv_gpu = torch.full((batch_size, 1, 1, 1), kv_len, device="cuda")
    attn_scale = 1.0 / (head_dim**0.5)

    # Prepare data: paged k, v
    block_size = 1
    blocks_per_batch = math.ceil(kv_len / block_size)
    # [num_blocks, head_num_kv, block_size, head_dim], num_blocks = batch_size * blocks_per_batch
    container_k_gpu = torch.cat(k_gpu.chunk(blocks_per_batch, dim=2), dim=0)
    container_v_gpu = torch.cat(v_gpu.chunk(blocks_per_batch, dim=2), dim=0)
    page_table_k_gpu = (
        torch.linspace(
            0,
            batch_size * blocks_per_batch - 1,
            batch_size * blocks_per_batch,
            device="cuda",
            dtype=torch.int32,
        )
        .reshape(blocks_per_batch, 1, batch_size, 1)
        .transpose(0, 2)
    )
    page_table_v_gpu = page_table_k_gpu.clone()
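
    # Build the cuDNN graph: declare virtual tensors that mirror the GPU
    # buffers, add a paged SDPA node, then validate and compile an execution
    # plan.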
    graph = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q = graph.tensor_like(q_gpu)
    container_k = graph.tensor_like(container_k_gpu)
    container_v = graph.tensor_like(container_v_gpu)
    page_table_k = graph.tensor_like(page_table_k_gpu)
    page_table_v = graph.tensor_like(page_table_v_gpu)
    seq_len_q = graph.tensor_like(seq_len_q_gpu)
    seq_len_kv = graph.tensor_like(seq_len_kv_gpu)

    o, _ = graph.sdpa(
        name="sdpa",
        q=q,
        k=container_k,  # Container K: non-contiguous container with K blocks
        v=container_v,  # Container V: non-contiguous container with V blocks
        is_inference=True,
        attn_scale=attn_scale,
        use_causal_mask=False,
        use_padding_mask=True,
        seq_len_q=seq_len_q,
        seq_len_kv=seq_len_kv,
        paged_attention_k_table=page_table_k,  # Page table K: tensor containing offsets to the container with K blocks
        paged_attention_v_table=page_table_v,  # Page table V: tensor containing offsets to the container with V blocks
        paged_attention_max_seq_len_kv=kv_len,  # The maximum sequence length for K caches (optional, but recommended)
    )
    o.set_output(True).set_dim(dims_q).set_stride(strides_q)

    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A])
    graph.check_support()
    graph.build_plans()

    workspace = torch.empty(
        graph.get_workspace_size(), device="cuda", dtype=torch.uint8
    )
    variant_pack = {
        q: q_gpu,
        container_k: container_k_gpu,
        container_v: container_v_gpu,
        page_table_k: page_table_k_gpu,
        page_table_v: page_table_v_gpu,
        seq_len_q: seq_len_q_gpu,
        seq_len_kv: seq_len_kv_gpu,
        o: o_gpu,
    }
    for _ in range(warmup):
        graph.execute(variant_pack, workspace)
    f = time_fwd(
        graph.execute,
        variant_pack,
        workspace,
    )
    return f, o_gpu.squeeze(dim=2)
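

# Single-config correctness check: run all three backends on the same inputs
# and compare outputs with loose fp16 tolerances (atol=rtol=1e-2).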
def calculate_diff():
    dtype = torch.float16
    batch_size = 64
    kv_len = 4096
    head_num_q = 64
    head_num_kv = 8
    head_dim = 128

    q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
    kv_data = (
        torch.randn(
            batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
        ),
        torch.randn(
            batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
        ),
    )

    _, output_sglang = decode_attention_sglang(
        q,
        kv_data,
        batch_size,
        kv_len,
        head_num_q,
        head_num_kv,
        head_dim,
        num_kv_splits=8,
    )

    attn_flashinfer = decode_attention_flashinfer(dtype, head_num_q, head_num_kv).apply
    _, output_flashinfer = attn_flashinfer(
        q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
    )

    _, output_cudnn = decode_attention_cudnn(
        q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
    )

    print(f"SGLang output={output_sglang}")
    print(f"FlashInfer output={output_flashinfer}")
    print(f"cuDNN output={output_cudnn}")

    if torch.allclose(output_sglang, output_flashinfer, atol=1e-2, rtol=1e-2):
        print("✅ SGLang[Triton] and FlashInfer match")
    else:
        print("❌ SGLang[Triton] and FlashInfer differ")
    if torch.allclose(output_sglang, output_cudnn, atol=1e-2, rtol=1e-2):
        print("✅ SGLang[Triton] and cuDNN match")
    else:
        print("❌ SGLang[Triton] and cuDNN differ")
if __name__ == "__main__":
calculate_diff()
head_dim = 128
dtype = torch.float16
batch_size_range = [2**i for i in range(0, 8, 2)]
kv_len_range = [2**i for i in range(6, 13, 1)]
configs = list(itertools.product(batch_size_range, kv_len_range))
for head_num_q, head_num_kv in [[32, 32], [64, 8], [40, 8]]:
attn_flashinfer = decode_attention_flashinfer(
dtype, head_num_q, head_num_kv
).apply
for batch_size, kv_len in configs:
q = torch.randn(
batch_size, head_num_q, head_dim, dtype=dtype, device="cuda"
)
kv_data = (
torch.randn(
batch_size * kv_len,
head_num_kv,
head_dim,
dtype=dtype,
device="cuda",
),
torch.randn(
batch_size * kv_len,
head_num_kv,
head_dim,
dtype=dtype,
device="cuda",
),
)
us_cudnn, output_cudnn = decode_attention_cudnn(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
)
us_sglang, output_sglang = decode_attention_sglang(
q,
kv_data,
batch_size,
kv_len,
head_num_q,
head_num_kv,
head_dim,
num_kv_splits=8,
)
us_flashinfer, _ = attn_flashinfer(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
)
print(
head_num_q,
" ",
head_num_kv,
" ",
batch_size,
" ",
kv_len,
" ",
us_cudnn,
" ",
us_sglang,
" ",
us_flashinfer,
)