import itertools

import torch
import triton

from sgl_kernel import sampling_scaling_penalties
def sampling_scaling_penalties_naive(logits, scaling_penalties):
    """Reference (pure PyTorch) scaling-penalty op.

    Positive logits are divided by the penalty and non-positive logits are
    multiplied by it, elementwise (repetition-penalty style scaling).
    """
    is_positive = logits > 0
    scaled_down = logits / scaling_penalties
    scaled_up = logits * scaling_penalties
    return torch.where(is_positive, scaled_down, scaled_up)
def sampling_scaling_penalties_kernel(logits, scaling_penalties):
    """Thin wrapper over the fused ``sgl_kernel`` CUDA implementation.

    NOTE(review): assumed to be elementwise-equivalent to
    ``sampling_scaling_penalties_naive`` (this file's ``calculate_diff``
    checks that with allclose) — confirm against the sgl_kernel docs.
    """
    return sampling_scaling_penalties(logits, scaling_penalties)
def test_memory(func, _iter):
    """Return the mean peak GPU memory (MiB) over ``_iter`` runs of ``func``.

    Args:
        func: zero-argument callable whose peak allocation is measured.
        _iter: number of measurement repetitions; must be positive.

    Raises:
        ValueError: if ``_iter`` is not positive (the original code raised
            ZeroDivisionError on the final mean instead).
    """
    if _iter <= 0:
        raise ValueError("_iter must be a positive integer")

    total_mem = []

    for _ in range(_iter):
        torch.cuda.memory.reset_peak_memory_stats()
        func()
        # Ensure all queued (async) CUDA work has run before reading the
        # peak-allocation counter, otherwise the stat can under-report.
        torch.cuda.synchronize()
        mem = torch.cuda.max_memory_allocated() / (2**20)
        total_mem.append(mem)

    return sum(total_mem) / len(total_mem)
def calculate_diff(batch_size, vocab_size):
    """Sanity-check the fused kernel against the naive reference on one
    random (batch_size, vocab_size) batch and print the verdict."""
    device = torch.device("cuda")
    dtype = torch.bfloat16

    logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
    # Penalties drawn from [0.5, 1.5) so both branches (divide / multiply)
    # of the op are exercised with non-trivial factors.
    scaling_penalties = (
        torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
    )

    output_naive = sampling_scaling_penalties_naive(
        logits.clone(), scaling_penalties.clone()
    )
    output_kernel = sampling_scaling_penalties_kernel(
        logits.clone(), scaling_penalties.clone()
    )

    print(f"Naive output={output_naive}")
    print(f"Kernel output={output_kernel}")

    matches = torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2)
    if matches:
        print("✅ Both implementations match")
    else:
        print("❌ Implementations differ")
# Benchmark sweep: batch sizes 1..2048 and vocab sizes 1024..65536,
# both as powers of two; the grid is their full cross product.
batch_size_range = [1 << i for i in range(0, 12)]
vocab_size_range = [1 << i for i in range(10, 17)]
configs = list(itertools.product(batch_size_range, vocab_size_range))
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "vocab_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["naive", "kernel"],
        line_names=["PyTorch Naive", "SGL Kernel"],
        styles=[("blue", "-"), ("red", "-")],
        ylabel="us",
        plot_name="sampling-scaling-penalties-performance",
        args={},
    )
)
def benchmark(batch_size, vocab_size, provider):
    """Time one (batch_size, vocab_size) configuration for ``provider``."""
    dtype = torch.bfloat16
    device = torch.device("cuda")

    logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
    scaling_penalties = (
        torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
    )

    quantiles = [0.5, 0.2, 0.8]

    # Both providers are benchmarked with an identical call shape; only the
    # implementation being dispatched differs.
    impl = (
        sampling_scaling_penalties_naive
        if provider == "naive"
        else sampling_scaling_penalties_kernel
    )
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: impl(logits.clone(), scaling_penalties.clone()),
        quantiles=quantiles,
    )

    # ms -> us; the (median, max, min) return ordering is kept as-is from
    # the original code.
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "vocab_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["naive", "kernel"],
        line_names=["PyTorch Naive", "SGL Kernel"],
        styles=[("blue", "-"), ("red", "-")],
        ylabel="GPU memory usage (MB)",
        plot_name="sampling-scaling-penalties-memory",
        args={},
    )
)
def benchmark_memory(batch_size, vocab_size, provider):
    """Report mean peak GPU memory (MiB) for one config/provider pair."""
    dtype = torch.bfloat16
    device = torch.device("cuda")

    print(
        f"Running memory benchmark with batch_size={batch_size}, vocab_size={vocab_size}, provider={provider}"
    )

    def run_kernel():
        # Inputs are allocated inside the measured closure so their
        # footprint counts toward the peak-memory reading.
        logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
        scaling_penalties = (
            torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
        )
        impl = (
            sampling_scaling_penalties_naive
            if provider == "naive"
            else sampling_scaling_penalties_kernel
        )
        return impl(logits, scaling_penalties)

    return test_memory(run_kernel, _iter=10)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/sampling_scaling_penalties/",
        help="Path to save sampling_scaling_penalties benchmark results",
    )
    args = parser.parse_args()

    # Correctness check first, then the timing sweep, then the memory sweep.
    calculate_diff(batch_size=4, vocab_size=4096)
    benchmark.run(print_data=True, save_path=args.save_path)
    benchmark_memory.run(print_data=True, save_path=args.save_path)