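"""Benchmark sgl_kernel's joint top-k/top-p sampling against a PyTorch reference."""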
import itertools

import sgl_kernel
import torch
import triton
import triton.testing

def torch_top_k_top_p_joint_sampling_from_probs(
    normalized_prob, top_k, top_p, eps=1e-4
):
    """Reference PyTorch implementation of joint top-k top-p sampling."""
    batch_size, vocab_size = normalized_prob.shape
    samples = torch.empty(batch_size, dtype=torch.int64, device=normalized_prob.device)

    for i in range(batch_size):
        p_val = top_p[i].item()
        k_val = top_k[i].item()

        # top-p mask: with an ascending sort, the cumulative sum at a token is
        # the total mass of all tokens no more likely than it, so cdf > 1 - p
        # keeps the smallest set of most-likely tokens whose mass reaches p
        # (eps absorbs floating-point rounding at the boundary)
        sorted_prob, indices = torch.sort(normalized_prob[i], descending=False)
        cdf = torch.cumsum(sorted_prob, dim=-1)
        mask_top_p = torch.zeros(
            vocab_size, dtype=torch.int32, device=normalized_prob.device
        )
        mask_top_p.scatter_add_(0, indices, (cdf > (1 - p_val) - eps).int())

        # top-k mask: keep every token at least as likely as the k-th largest
        # probability (ties at the pivot may keep more than k tokens)
        sorted_prob_desc, _ = torch.sort(normalized_prob[i], descending=True)
        pivot = sorted_prob_desc[k_val - 1]
        mask_top_k = (normalized_prob[i] >= pivot).int()

        # joint mask: a token survives only if it passes both filters
        mask = torch.minimum(mask_top_p, mask_top_k).bool()

        # renormalize the masked probabilities and sample
        masked_probs = normalized_prob[i] * mask
        masked_probs = masked_probs / masked_probs.sum()
        idx = torch.multinomial(masked_probs, 1)
        samples[i] = idx

    return samples
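
# A worked example of the joint filter, with hypothetical numbers:
# probs = [0.1, 0.2, 0.3, 0.4], p = 0.5, k = 2.
#   top-p: ascending cdf = [0.1, 0.3, 0.6, 1.0]; cdf > 1 - p keeps {0.3, 0.4}.
#   top-k: pivot = 2nd-largest prob = 0.3; probs >= pivot keeps {0.3, 0.4}.
#   joint: {0.3, 0.4}, renormalized to [3/7, 4/7] before multinomial sampling.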

def calculate_diff(batch_size, vocab_size, p):
    """Compare Torch reference and SGLang kernel for correctness."""
    torch.manual_seed(42)
    # pair each p with a complementary k (strict p with loose k, and vice versa)
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")

    device = torch.device("cuda")
    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)

    top_p_tensor = torch.full((batch_size,), p, device=device)
    top_k_tensor = torch.full((batch_size,), k, device=device)

    torch_samples = torch_top_k_top_p_joint_sampling_from_probs(
        normalized_prob, top_k_tensor, top_p_tensor
    )
    sglang_samples = sgl_kernel.top_k_top_p_sampling_from_probs(
        normalized_prob, top_k_tensor, top_p_tensor, filter_apply_order="joint"
    )
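    # Sampling is stochastic, so the two implementations are not expected to
    # return identical tokens. As a minimal sanity check, verify that every
    # drawn token lies inside the joint top-k/top-p support defined by the
    # reference masks (assumes the kernel returns a flat LongTensor of ids).
    for i in range(batch_size):
        row = normalized_prob[i]
        sorted_prob, indices = torch.sort(row, descending=False)
        cdf = torch.cumsum(sorted_prob, dim=-1)
        mask_top_p = torch.zeros(vocab_size, dtype=torch.int32, device=device)
        mask_top_p.scatter_add_(0, indices, (cdf > (1 - p) - 1e-4).int())
        pivot = torch.sort(row, descending=True).values[k - 1]
        support = mask_top_p.bool() & (row >= pivot)
        assert support[torch_samples[i]] and support[sglang_samples[i]]
    print(f"batch={batch_size}, vocab={vocab_size}, p={p}, k={k}: OK")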

# parameter space for both the correctness check and the benchmark
batch_size_range = [16, 64, 128]
vocab_size_range = [111, 32000]
p_range = [0.1, 0.5]
configs = list(itertools.product(batch_size_range, vocab_size_range, p_range))

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "vocab_size", "p"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "sglang"],
        line_names=["Torch Reference", "SGL Kernel"],
        styles=[("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="top-k-top-p-joint-sampling-performance",
        args={},
    )
)
def benchmark_sampling(batch_size, vocab_size, p, provider):
    """Benchmark one (batch_size, vocab_size, p) config for the given provider."""
    torch.manual_seed(42)
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")

    device = torch.device("cuda")
    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    top_p_tensor = torch.full((batch_size,), p, device=device)
    top_k_tensor = torch.full((batch_size,), k, device=device)

    # clone() per call so any in-place filtering inside the op cannot leak
    # across timed iterations
    if provider == "torch":
        fn = lambda: torch_top_k_top_p_joint_sampling_from_probs(
            normalized_prob.clone(), top_k_tensor, top_p_tensor
        )
    elif provider == "sglang":
        fn = lambda: sgl_kernel.top_k_top_p_sampling_from_probs(
            normalized_prob.clone(),
            top_k_tensor,
            top_p_tensor,
            filter_apply_order="joint",
        )

    # do_bench returns the requested quantiles in order (median, p20, p80);
    # convert ms -> us and report (median, low, high) to match ylabel="us"
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms

if __name__ == "__main__":
    # Correctness check across all configs
    for cfg in configs:
        calculate_diff(*cfg)

    print("\n" + "=" * 60)
    print("Starting performance benchmark...")
    benchmark_sampling.run(print_data=True)
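    # triton's perf_report runner can also persist results; e.g. pass
    # save_path="." to benchmark_sampling.run to write the plot and CSV.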