diff --git a/sgl-kernel/benchmark/bench_top_k_top_p_sampling.py b/sgl-kernel/benchmark/bench_top_k_top_p_sampling.py new file mode 100644 index 000000000..3692b5b39 --- /dev/null +++ b/sgl-kernel/benchmark/bench_top_k_top_p_sampling.py @@ -0,0 +1,128 @@ +import itertools + +import sgl_kernel +import torch +import triton +import triton.testing + + +def torch_top_k_top_p_joint_sampling_from_probs( + normalized_prob, top_k, top_p, eps=1e-4 +): + """Reference PyTorch implementation of joint top-k top-p sampling.""" + batch_size, vocab_size = normalized_prob.shape + samples = torch.empty(batch_size, dtype=torch.int64, device=normalized_prob.device) + + for i in range(batch_size): + p_val = top_p[i].item() + k_val = top_k[i].item() + + # top-p mask + sorted_prob, indices = torch.sort(normalized_prob[i], descending=False) + cdf = torch.cumsum(sorted_prob, dim=-1) + mask_top_p = torch.zeros( + vocab_size, dtype=torch.int32, device=normalized_prob.device + ) + mask_top_p.scatter_add_(0, indices, (cdf > (1 - p_val) - eps).int()) + + # top-k mask + sorted_prob_desc, _ = torch.sort(normalized_prob[i], descending=True) + pivot = sorted_prob_desc[k_val - 1] + mask_top_k = (normalized_prob[i] >= pivot).int() + + # joint mask + mask = torch.minimum(mask_top_p, mask_top_k).bool() + + # sample from masked probs + masked_probs = normalized_prob[i] * mask + masked_probs = masked_probs / masked_probs.sum() + idx = torch.multinomial(masked_probs, 1) + samples[i] = idx + + return samples + + +def calculate_diff(batch_size, vocab_size, p): + """Compare Torch reference and SGLang kernel for correctness.""" + torch.manual_seed(42) + if p == 0.1: + k = int(vocab_size * 0.5) + elif p == 0.5: + k = int(vocab_size * 0.1) + else: + raise ValueError("p not recognized") + + device = torch.device("cuda") + pre_norm_prob = torch.rand(batch_size, vocab_size, device=device) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + top_p_tensor = torch.full((batch_size,), p, 
device=device) + top_k_tensor = torch.full((batch_size,), k, device=device) + + torch_samples = torch_top_k_top_p_joint_sampling_from_probs( + normalized_prob, top_k_tensor, top_p_tensor + ) + sglang_samples = sgl_kernel.top_k_top_p_sampling_from_probs( + normalized_prob, top_k_tensor, top_p_tensor, filter_apply_order="joint" + ) + + +# parameter space +batch_size_range = [16, 64, 128] +vocab_size_range = [111, 32000] +p_range = [0.1, 0.5] +configs = list(itertools.product(batch_size_range, vocab_size_range, p_range)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "vocab_size", "p"], + x_vals=configs, + line_arg="provider", + line_vals=["torch", "sglang"], + line_names=["Torch Reference", "SGL Kernel"], + styles=[("red", "-"), ("green", "-")], + ylabel="us", + plot_name="top-k-top-p-joint-sampling-performance", + args={}, + ) +) +def benchmark_sampling(batch_size, vocab_size, p, provider): + torch.manual_seed(42) + if p == 0.1: + k = int(vocab_size * 0.5) + elif p == 0.5: + k = int(vocab_size * 0.1) + else: + raise ValueError("p not recognized") + + device = torch.device("cuda") + pre_norm_prob = torch.rand(batch_size, vocab_size, device=device) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + top_p_tensor = torch.full((batch_size,), p, device=device) + top_k_tensor = torch.full((batch_size,), k, device=device) + + if provider == "torch": + fn = lambda: torch_top_k_top_p_joint_sampling_from_probs( + normalized_prob.clone(), top_k_tensor, top_p_tensor + ) + elif provider == "sglang": + fn = lambda: sgl_kernel.top_k_top_p_sampling_from_probs( + normalized_prob.clone(), + top_k_tensor, + top_p_tensor, + filter_apply_order="joint", + ) + + ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8]) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + # Correctness check + for cfg in configs: + calculate_diff(*cfg) + + print("\n" + "=" * 60) + 
print("Starting performance benchmark...") + benchmark_sampling.run(print_data=True) diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 5835146a2..1915f176e 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -345,15 +345,19 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val) -> ()"); m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs); + m.def( + "top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? " + "maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()"); + m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs); + m.def( "top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_top_k_arr, " "float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()"); m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs); - m.def( - "top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? " - "maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()"); - m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs); + m.def("top_k_mask_logits(Tensor logits, Tensor mask_logits, Tensor? maybe_top_k_arr, int top_k_val) -> ()"); + m.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits); + m.def( "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," "Tensor! b_q_weight, Tensor! b_scales, Tensor? 
b_zeros_or_none," diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index b9a694ac8..8d268e82b 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -593,6 +593,10 @@ void top_p_sampling_from_probs( double top_p_val, bool deterministic, std::optional gen); + +void top_k_mask_logits( + at::Tensor logits, at::Tensor mask_logits, std::optional maybe_top_k_arr, int64_t top_k_val); + torch::Tensor moe_wna16_marlin_gemm( torch::Tensor& a, std::optional const& c_or_none, diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index ce0023b8a..ee7c36541 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -85,7 +85,9 @@ from sgl_kernel.moe import ( ) from sgl_kernel.sampling import ( min_p_sampling_from_probs, + top_k_mask_logits, top_k_renorm_prob, + top_k_top_p_sampling_from_logits, top_k_top_p_sampling_from_probs, top_p_renorm_prob, top_p_sampling_from_probs, diff --git a/sgl-kernel/python/sgl_kernel/sampling.py b/sgl-kernel/python/sgl_kernel/sampling.py index d4856e52c..489093751 100644 --- a/sgl-kernel/python/sgl_kernel/sampling.py +++ b/sgl-kernel/python/sgl_kernel/sampling.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Tuple, Union import torch from sgl_kernel.utils import _to_tensor_scalar_tuple @@ -383,3 +383,161 @@ def min_p_sampling_from_probs( return _min_p_sampling_from_probs_internal( probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator ) + + +def _top_k_mask_logits_internal( + logits: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, +) -> torch.Tensor: + logits = logits.float() + maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + mask_logits = torch.empty_like(logits) + torch.ops.sgl_kernel.top_k_mask_logits.default( + logits, mask_logits, maybe_top_k_arr, top_k_val + ) + 
return mask_logits
+
+
+def top_k_mask_logits(
+    logits: torch.Tensor,
+    top_k: Union[torch.Tensor, int],
+) -> torch.Tensor:
+    r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
+    Fused GPU kernel for masking logits by top-k thresholding.
+
+    Parameters
+    ----------
+    logits: torch.Tensor
+        Logits before softmax, shape ``(batch_size, num_classes)``.
+    top_k: Union[torch.Tensor, int]
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold
+        for masking logits, should be in ``(0, num_classes)``.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
+        We keep the top-k logits, set the rest to negative infinity.
+
+    Returns
+    -------
+    masked_logits: torch.Tensor
+        Masked logits, shape ``(batch_size, num_classes)``.
+
+    Examples
+    --------
+
+    >>> import torch
+    >>> import sgl_kernel
+    >>> torch.manual_seed(42)
+    >>> batch_size = 4
+    >>> vocab_size = 5
+    >>> top_k = 3
+    >>> logits = torch.randn(batch_size, vocab_size).to(0)
+    >>> logits
+    tensor([[ 1.9269,  1.4873,  0.9007, -2.1055, -0.7581],
+            [ 1.0783,  0.8008,  1.6806,  0.3559, -0.6866],
+            [-0.4934,  0.2415, -0.2316,  0.0418, -0.2516],
+            [ 0.8599, -0.3097, -0.3957,  0.8034, -0.6216]], device='cuda:0')
+    >>> masked_logits = sgl_kernel.top_k_mask_logits(logits, top_k)
+    >>> masked_logits
+    tensor([[ 1.9269,  1.4873,  0.9007,    -inf,    -inf],
+            [ 1.0783,  0.8008,  1.6806,    -inf,    -inf],
+            [   -inf,  0.2415, -0.2316,  0.0418,    -inf],
+            [ 0.8599, -0.3097,    -inf,  0.8034,    -inf]], device='cuda:0')
+
+    Note
+    ----
+    The combination of ``top_k_mask_logits`` and ``softmax`` should be equivalent to ``top_k_renorm_probs``.
+
+    See Also
+    --------
+    top_k_renorm_probs
+    """
+    return _top_k_mask_logits_internal(logits, *_to_tensor_scalar_tuple(top_k))
+
+
+def top_k_top_p_sampling_from_logits(
+    logits: torch.Tensor,
+    top_k: Union[torch.Tensor, int],
+    top_p: Union[torch.Tensor, float],
+    indices: Optional[torch.Tensor] = None,
+    filter_apply_order: str = "top_k_first",
+    deterministic: bool = True,
+    generator: Optional[torch.Generator] = None,
+    check_nan: bool = False,
+) -> torch.Tensor:
+    r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
+    Fused GPU kernel for top-k and top-p sampling from logits,
+
+    this operator implements GPU-based rejection sampling without explicit sorting.
+    Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
+
+    The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
+    which is more efficient than the naive implementation that launches a series of kernels.
+
+    Parameters
+    ----------
+    logits: torch.Tensor
+        Pre-softmax logits for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
+        and the i-th output will be sampled from the i-th row of logits. When indices is provided,
+        shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
+        probability distributions.
+    top_k: Union[torch.Tensor, int]
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
+    top_p: Union[torch.Tensor, float]
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
+    indices: Optional[torch.Tensor]
+        Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
+ For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. + This allows reusing the same probability distribution for multiple outputs. + If indices is not provided, the i-th output will be sampled from the i-th row of probs. + filter_apply_order: str + The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``. + If ``"top_k_first"``, we first apply top-k filter, then apply top-p sampling on the top-k results. + If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``. + deterministic: bool + Whether to use deterministic kernel implementation, default is ``True``. + generator: Optional[torch.Generator] + A random number generator for the operation. + check_nan: bool + Whether to check nan in :attr:`probs`, default is ``False``. + + Returns + ------- + samples: torch.Tensor + Sampled categories, shape ``(batch_size,)``. + + Note + ---- + This function expects float32 inputs, and the output is int32. 
+ + """ + if filter_apply_order == "top_k_first": + masked_logits = top_k_mask_logits(logits, top_k) + probs = torch.softmax(masked_logits, dim=-1) + return top_p_sampling_from_probs( + probs, + top_p, + indices, + deterministic, + check_nan=check_nan, + generator=generator, + ) + elif filter_apply_order == "joint": + probs = torch.softmax(logits, dim=-1) + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") + return _top_k_top_p_sampling_from_probs_internal( + probs, + indices, + *_to_tensor_scalar_tuple(top_k), + *_to_tensor_scalar_tuple(top_p), + deterministic, + generator, + ) + else: + raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}") diff --git a/sgl-kernel/tests/test_sampling.py b/sgl-kernel/tests/test_sampling.py index 14f41e5ef..dc5734cb7 100644 --- a/sgl-kernel/tests/test_sampling.py +++ b/sgl-kernel/tests/test_sampling.py @@ -5,6 +5,54 @@ import sgl_kernel import torch +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) +@pytest.mark.parametrize("k", [100]) +@pytest.mark.parametrize("p", [0.1, 0.5]) +def test_top_k_top_p_sampling_from_probs_logits_top_k_first_alignment( + batch_size, vocab_size, k, p +): + torch.manual_seed(42) + logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5 + generator_logits = torch.Generator("cuda:0") + generator_probs = generator_logits.clone_state() + samples = sgl_kernel.sampling.top_k_top_p_sampling_from_logits( + logits, k, p, filter_apply_order="top_k_first", generator=generator_logits + ) + samples_ref = sgl_kernel.sampling.top_k_top_p_sampling_from_probs( + torch.softmax(logits, dim=-1), + k, + p, + filter_apply_order="top_k_first", + generator=generator_probs, + ) + assert torch.all(samples == samples_ref) + + +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) +@pytest.mark.parametrize("k", [100]) 
+@pytest.mark.parametrize("p", [0.1, 0.5]) +def test_top_k_top_p_sampling_from_probs_logits_joint_alignment( + batch_size, vocab_size, k, p +): + torch.manual_seed(42) + logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5 + generator_logits = torch.Generator("cuda:0") + generator_probs = generator_logits.clone_state() + samples = sgl_kernel.sampling.top_k_top_p_sampling_from_logits( + logits, k, p, filter_apply_order="joint", generator=generator_logits + ) + samples_ref = sgl_kernel.sampling.top_k_top_p_sampling_from_probs( + torch.softmax(logits, dim=-1), + k, + p, + filter_apply_order="joint", + generator=generator_probs, + ) + assert torch.all(samples == samples_ref) + + @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5])