[sgl-kernel] Support FlashInfer top_k_top_p_sampling_from_logits (#9060)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
128
sgl-kernel/benchmark/bench_top_k_top_p_sampling.py
Normal file
128
sgl-kernel/benchmark/bench_top_k_top_p_sampling.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import itertools
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
|
||||
|
||||
def torch_top_k_top_p_joint_sampling_from_probs(
    normalized_prob, top_k, top_p, eps=1e-4
):
    """Reference PyTorch implementation of joint top-k top-p sampling.

    For each row of ``normalized_prob``, applies the top-p and top-k
    filters jointly (a token survives only if both keep it), renormalizes
    the surviving mass, and draws one token index via multinomial sampling.
    Returns an int64 tensor of shape ``(batch_size,)``.
    """
    num_rows, num_tokens = normalized_prob.shape
    out = torch.empty(num_rows, dtype=torch.int64, device=normalized_prob.device)

    for row in range(num_rows):
        probs_row = normalized_prob[row]
        p_thresh = top_p[row].item()
        k_thresh = top_k[row].item()

        # Top-p: sort ascending and keep the suffix whose cumulative mass
        # exceeds (1 - p), with eps tolerance for float round-off.
        asc_probs, asc_idx = torch.sort(probs_row, descending=False)
        cum_mass = torch.cumsum(asc_probs, dim=-1)
        keep_top_p = torch.zeros(
            num_tokens, dtype=torch.int32, device=normalized_prob.device
        )
        keep_top_p.scatter_add_(0, asc_idx, (cum_mass > (1 - p_thresh) - eps).int())

        # Top-k: keep every token whose probability reaches the k-th largest
        # value (ties at the pivot are all kept, so >= k tokens may survive).
        desc_probs, _ = torch.sort(probs_row, descending=True)
        kth_value = desc_probs[k_thresh - 1]
        keep_top_k = (probs_row >= kth_value).int()

        # Joint mask: intersection of the two filters.
        joint_keep = torch.minimum(keep_top_p, keep_top_k).bool()

        # Renormalize the surviving probabilities and draw one sample.
        masked_probs = probs_row * joint_keep
        masked_probs = masked_probs / masked_probs.sum()
        out[row] = torch.multinomial(masked_probs, 1)

    return out
|
||||
|
||||
def calculate_diff(batch_size, vocab_size, p):
    """Sanity-check the SGLang kernel against the Torch reference.

    Exact sample equality is not expected — the two implementations draw
    independently — so instead of comparing samples directly we verify
    that every kernel sample lies inside the joint top-k/top-p support
    set computed with the reference masking logic. The original version
    computed both outputs and silently discarded them, checking nothing.

    Args:
        batch_size: number of independent probability rows.
        vocab_size: number of tokens per row.
        p: top-p threshold; also selects the paired top-k value.

    Raises:
        ValueError: if ``p`` is not one of the benchmarked values.
        AssertionError: if a kernel sample falls outside the support.
    """
    torch.manual_seed(42)
    # The benchmark pairs each top-p with a fixed top-k fraction of the vocab.
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")

    device = torch.device("cuda")
    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)

    top_p_tensor = torch.full((batch_size,), p, device=device)
    top_k_tensor = torch.full((batch_size,), k, device=device)

    torch_samples = torch_top_k_top_p_joint_sampling_from_probs(
        normalized_prob, top_k_tensor, top_p_tensor
    )
    sglang_samples = sgl_kernel.top_k_top_p_sampling_from_probs(
        normalized_prob, top_k_tensor, top_p_tensor, filter_apply_order="joint"
    )

    # Validate: every kernel sample must come from the joint support set,
    # rebuilt here with the same masking rules as the reference sampler.
    eps = 1e-4
    for i in range(batch_size):
        row = normalized_prob[i]
        sorted_prob, indices = torch.sort(row, descending=False)
        cdf = torch.cumsum(sorted_prob, dim=-1)
        mask_top_p = torch.zeros(vocab_size, dtype=torch.int32, device=device)
        mask_top_p.scatter_add_(0, indices, (cdf > (1 - p) - eps).int())
        pivot = torch.sort(row, descending=True)[0][k - 1]
        mask = torch.minimum(mask_top_p, (row >= pivot).int()).bool()
        sample = int(sglang_samples[i].item())
        assert mask[sample], (
            f"sample {sample} outside joint top-k/top-p support "
            f"(batch={batch_size}, vocab={vocab_size}, p={p})"
        )
    print(
        f"batch={batch_size} vocab={vocab_size} p={p} k={k}: "
        f"{batch_size} kernel samples within support "
        f"(torch ref drew {torch_samples.numel()} samples) OK"
    )
|
||||
|
||||
# Parameter space swept by both the correctness check and the benchmark.
batch_size_range = [16, 64, 128]
vocab_size_range = [111, 32000]
p_range = [0.1, 0.5]
# Cartesian product, last axis varying fastest (same order as
# itertools.product(batch_size_range, vocab_size_range, p_range)).
configs = [
    (bs, vs, p)
    for bs in batch_size_range
    for vs in vocab_size_range
    for p in p_range
]
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "vocab_size", "p"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "sglang"],
        line_names=["Torch Reference", "SGL Kernel"],
        styles=[("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="top-k-top-p-joint-sampling-performance",
        args={},
    )
)
def benchmark_sampling(batch_size, vocab_size, p, provider):
    """Benchmark one (batch_size, vocab_size, p) point for *provider*.

    Returns (median, max, min) latency in microseconds, the tuple shape
    triton.testing.perf_report expects for quantiles [0.5, 0.2, 0.8].
    """
    torch.manual_seed(42)
    # Pair each top-p with a fixed top-k fraction of the vocabulary.
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")

    device = torch.device("cuda")
    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    top_p_tensor = torch.full((batch_size,), p, device=device)
    top_k_tensor = torch.full((batch_size,), k, device=device)

    if provider == "torch":
        fn = lambda: torch_top_k_top_p_joint_sampling_from_probs(
            normalized_prob.clone(), top_k_tensor, top_p_tensor
        )
    elif provider == "sglang":
        fn = lambda: sgl_kernel.top_k_top_p_sampling_from_probs(
            normalized_prob.clone(),
            top_k_tensor,
            top_p_tensor,
            filter_apply_order="joint",
        )
    else:
        # Fail fast with a clear message instead of hitting an
        # UnboundLocalError on `fn` below.
        raise ValueError(f"unknown provider: {provider}")

    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
    # Convert ms -> us for the report.
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the correctness sweep over every config before timing anything.
    for batch_size, vocab_size, p in configs:
        calculate_diff(batch_size, vocab_size, p)

    separator = "=" * 60
    print("\n" + separator)
    print("Starting performance benchmark...")
    benchmark_sampling.run(print_data=True)
|
||||
Reference in New Issue
Block a user