282 lines
7.8 KiB
Python
282 lines
7.8 KiB
Python
"""
|
|
Unit tests comparing TileLang and Triton implementations of activation quantization.
|
|
Tests both accuracy and performance.
|
|
"""
|
|
|
|
import time
|
|
from typing import Tuple
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
|
|
from sglang.srt.layers.attention.nsa.triton_kernel import act_quant as act_quant_triton
|
|
|
|
|
|
def benchmark_kernel(
    fn,
    x: torch.Tensor,
    block_size: int,
    scale_fmt,
    warmup: int = 10,
    repeat: int = 100,
    use_cuda_graph: bool = True,
) -> Tuple[float, torch.Tensor, torch.Tensor]:
    """
    Benchmark a kernel function and return its average latency and outputs.

    ``fn`` is always invoked as ``fn(x, block_size=block_size, scale_fmt=scale_fmt)``
    and is expected to return a pair ``(quantized_output, scales)``.

    Two timing strategies are used:

    * CUDA-graph timing (``x`` is on CUDA and ``use_cuda_graph`` is true): one
      call to ``fn`` is captured into a ``torch.cuda.CUDAGraph`` and the graph
      is replayed ``repeat`` times between two CUDA events.
    * Wall-clock fallback (CPU input or ``use_cuda_graph`` false): ``fn`` is
      called ``repeat`` times between two ``time.perf_counter()`` readings,
      with a device synchronize before and after when ``x`` is on CUDA.

    Args:
        fn: Function to benchmark
        x: Input tensor
        block_size: Block size for quantization
        scale_fmt: Scale format
        warmup: Number of warmup iterations (eager calls; in the CUDA-graph
            path the same number of graph replays is run as well)
        repeat: Number of timed iterations; must be at least 1
        use_cuda_graph: Whether to use CUDA graphs for more accurate timing

    Returns:
        Tuple of (avg_time_ms, quantized_output, scales). In the CUDA-graph
        path the returned tensors are the ones produced inside the capture.

    Raises:
        ValueError: If ``repeat`` is smaller than 1 (the average would be
            undefined and no output would be available to return).
    """
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")

    # Warmup: eager calls so lazy compilation/autotuning happens outside the
    # timed region (and outside graph capture).
    for _ in range(warmup):
        fn(x, block_size=block_size, scale_fmt=scale_fmt)

    if not x.is_cuda or not use_cuda_graph:
        # Fallback to regular wall-clock timing.
        if x.is_cuda:
            # Drain pending GPU work so it is not billed to the timed loop.
            torch.cuda.synchronize()

        start = time.perf_counter()
        for _ in range(repeat):
            y, s = fn(x, block_size=block_size, scale_fmt=scale_fmt)

        if x.is_cuda:
            # Kernel launches are asynchronous; wait for them to finish
            # before reading the clock.
            torch.cuda.synchronize()

        end = time.perf_counter()
        avg_time_ms = (end - start) / repeat * 1000

        return avg_time_ms, y, s

    # Use CUDA graph for more accurate timing
    torch.cuda.synchronize()

    # Capture a single invocation; the outputs created during capture are the
    # tensors handed back to the caller.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        y_cap, s_cap = fn(x, block_size=block_size, scale_fmt=scale_fmt)

    # Warmup with graph
    for _ in range(warmup):
        graph.replay()

    torch.cuda.synchronize()

    # Timing with CUDA graph: events bracket all replays on the GPU timeline.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(repeat):
        graph.replay()
    end_event.record()

    # elapsed_time() requires both events to have completed.
    torch.cuda.synchronize()

    avg_time_ms = start_event.elapsed_time(end_event) / repeat

    return avg_time_ms, y_cap, s_cap
|
|
|
|
|
|
def check_accuracy(
    y_ref: torch.Tensor,
    s_ref: torch.Tensor,
    y_test: torch.Tensor,
    s_test: torch.Tensor,
    rtol: float = 1e-2,
    atol: float = 1e-2,
) -> Tuple[bool, dict]:
    """
    Compare a test (output, scales) pair against a reference pair.

    The quantized outputs are cast to float32 before comparison; the scales
    are compared as given.

    Args:
        y_ref: Reference quantized output
        s_ref: Reference scales
        y_test: Test quantized output
        s_test: Test scales
        rtol: Relative tolerance passed to ``torch.allclose``
        atol: Absolute tolerance passed to ``torch.allclose``

    Returns:
        Tuple of (passed, metrics_dict). ``passed`` is true only when both the
        outputs and the scales are within tolerance. ``metrics_dict`` holds the
        max/mean absolute differences for outputs and scales, the percentage
        of output elements that match exactly, and the two closeness flags.
    """
    # Work in float32 so the quantized outputs can be subtracted and compared.
    ref_vals = y_ref.float()
    test_vals = y_test.float()

    # Element-wise absolute errors.
    out_err = (ref_vals - test_vals).abs()
    scale_err = (s_ref - s_test).abs()

    # Tolerance-based verdicts.
    outputs_ok = torch.allclose(ref_vals, test_vals, rtol=rtol, atol=atol)
    scales_ok = torch.allclose(s_ref, s_test, rtol=rtol, atol=atol)

    # Fraction of output elements that are bit-for-bit equal after the cast.
    exact_fraction = (ref_vals == test_vals).float().mean().item()

    report = {
        "y_max_diff": out_err.max().item(),
        "y_mean_diff": out_err.mean().item(),
        "y_match_pct": exact_fraction * 100,
        "s_max_diff": scale_err.max().item(),
        "s_mean_diff": scale_err.mean().item(),
        "y_close": outputs_ok,
        "s_close": scales_ok,
    }

    return outputs_ok and scales_ok, report
|
|
|
|
|
|
def _run_benchmark_table(
    title, shapes, dtype, device, block_size, scale_fmt, use_cuda_graph
):
    """Print one TileLang-vs-Triton benchmark table, one row per shape."""
    print("\n" + "=" * 100)
    print(title)
    print("=" * 100)
    print(
        f"{'Shape':<20} {'TileLang (ms)':<15} {'Triton (ms)':<15} {'Speedup':<10} {'Status'}"
    )
    print("-" * 100)

    for shape in shapes:
        # Re-seed per shape so every table sees the same input for a shape.
        torch.manual_seed(42)
        x = torch.randn(shape, dtype=dtype, device=device)

        try:
            # TileLang output doubles as the accuracy reference.
            time_tilelang, y_ref, s_ref = benchmark_kernel(
                act_quant,
                x,
                block_size,
                scale_fmt,
                warmup=5,
                repeat=50,
                use_cuda_graph=use_cuda_graph,
            )
            time_triton, y_triton, s_triton = benchmark_kernel(
                act_quant_triton,
                x,
                block_size,
                scale_fmt,
                warmup=5,
                repeat=50,
                use_cuda_graph=use_cuda_graph,
            )

            # Check accuracy
            passed, _ = check_accuracy(y_ref, s_ref, y_triton, s_triton)

            # Ratio of TileLang time to Triton time: > 1 means Triton was
            # faster for this shape.
            speedup = time_tilelang / time_triton if time_triton > 0 else 0
            status = "✓ PASS" if passed else "✗ FAIL"

            print(
                f"{str(shape):<20} {time_tilelang:<15.4f} {time_triton:<15.4f} "
                f"{speedup:<10.2f} {status}"
            )
        except Exception as e:
            # Best-effort: report the failing shape and keep going so the
            # remaining rows are still produced.
            print(f"{str(shape):<20} ERROR: {str(e)}")

    print("=" * 100)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_act_quant_comprehensive_benchmark(scale_fmt=None):
    """
    Comprehensive benchmark across multiple sizes with and without CUDA graphs.

    For each shape a bfloat16 random input is quantized by both the TileLang
    and the Triton ``act_quant`` implementations with ``block_size=128``; the
    timings, their ratio and an accuracy verdict are printed as a table. The
    sweep runs twice: first with CUDA-graph timing, then with plain timing.

    NOTE(review): accuracy mismatches and per-shape exceptions are only
    printed; nothing here asserts, so under pytest this function cannot fail
    on a kernel regression. Confirm whether that is intended.

    Args:
        scale_fmt: Scale format forwarded unchanged to both kernels.
    """
    device = torch.device("cuda")
    dtype = torch.bfloat16
    block_size = 128

    shapes = [
        (128, 512),
        (256, 1024),
        (512, 2048),
        (1024, 4096),
        (2048, 8192),
        (4096, 16384),
    ]

    _run_benchmark_table(
        "Comprehensive Performance Benchmark with CUDA Graphs",
        shapes,
        dtype,
        device,
        block_size,
        scale_fmt,
        use_cuda_graph=True,
    )

    # Also run without CUDA graphs for comparison
    _run_benchmark_table(
        "Performance Benchmark WITHOUT CUDA Graphs (for comparison)",
        shapes,
        dtype,
        device,
        block_size,
        scale_fmt,
        use_cuda_graph=False,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the comprehensive benchmark once per scale
    # format setting, or report that no CUDA device is present.
    if not torch.cuda.is_available():
        print("CUDA not available. Skipping tests.")
    else:
        runs = (
            ("Running Comprehensive Benchmark with scale_fmt=None", None),
            ("Running Comprehensive Benchmark with scale_fmt!=None", "any"),
        )
        for banner, fmt in runs:
            print("\n" + "=" * 80)
            print(banner)
            print("=" * 80)
            test_act_quant_comprehensive_benchmark(scale_fmt=fmt)
|