[sgl-kernel] Support FlashInfer top_k_top_p_sampling_from_logits (#9060)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
128
sgl-kernel/benchmark/bench_top_k_top_p_sampling.py
Normal file
128
sgl-kernel/benchmark/bench_top_k_top_p_sampling.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import itertools
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
|
||||
|
||||
def torch_top_k_top_p_joint_sampling_from_probs(
    normalized_prob, top_k, top_p, eps=1e-4
):
    """Pure-PyTorch reference for joint top-k/top-p sampling.

    For each row the top-p and top-k keep-masks are built independently,
    intersected, and one token is drawn from the renormalized masked
    distribution with ``torch.multinomial``.
    """
    batch_size, vocab_size = normalized_prob.shape
    samples = torch.empty(batch_size, dtype=torch.int64, device=normalized_prob.device)

    for row in range(batch_size):
        probs_row = normalized_prob[row]
        p_threshold = top_p[row].item()
        k_threshold = top_k[row].item()

        # Top-p keep-mask: sort ascending and keep tokens whose cumulative
        # mass (from the smallest up) exceeds 1 - p, i.e. the nucleus.
        asc_probs, asc_idx = torch.sort(probs_row, descending=False)
        suffix_cdf = torch.cumsum(asc_probs, dim=-1)
        keep_top_p = torch.zeros(
            vocab_size, dtype=torch.int32, device=probs_row.device
        )
        keep_top_p.scatter_add_(
            0, asc_idx, (suffix_cdf > (1 - p_threshold) - eps).int()
        )

        # Top-k keep-mask: tokens at or above the k-th largest probability.
        desc_probs, _ = torch.sort(probs_row, descending=True)
        kth_value = desc_probs[k_threshold - 1]
        keep_top_k = (probs_row >= kth_value).int()

        # Joint filter = intersection of both masks.
        keep = torch.minimum(keep_top_p, keep_top_k).bool()

        # Renormalize over survivors and draw a single token.
        filtered = probs_row * keep
        filtered = filtered / filtered.sum()
        samples[row] = torch.multinomial(filtered, 1)

    return samples
|
||||
|
||||
|
||||
def calculate_diff(batch_size, vocab_size, p):
    """Sanity-check the SGLang kernel against the Torch reference.

    Both samplers draw random tokens, so element-wise equality is not a
    meaningful check; instead we verify that both return valid token ids in
    ``[0, vocab_size)`` and report the configuration checked.

    Raises:
        ValueError: if ``p`` is not one of the benchmarked values.
        AssertionError: if either sampler produces an out-of-range token id.
    """
    torch.manual_seed(42)
    # k is tied to p so each config exercises a different top-k/top-p mix.
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")

    device = torch.device("cuda")
    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)

    top_p_tensor = torch.full((batch_size,), p, device=device)
    top_k_tensor = torch.full((batch_size,), k, device=device)

    torch_samples = torch_top_k_top_p_joint_sampling_from_probs(
        normalized_prob, top_k_tensor, top_p_tensor
    )
    sglang_samples = sgl_kernel.top_k_top_p_sampling_from_probs(
        normalized_prob, top_k_tensor, top_p_tensor, filter_apply_order="joint"
    )

    # Previously both results were computed and silently discarded; validate
    # that each sampler emitted in-range token ids for every batch row.
    for name, samples in (("torch", torch_samples), ("sglang", sglang_samples)):
        ids = samples.to(torch.int64)
        assert bool(((ids >= 0) & (ids < vocab_size)).all()), (
            f"{name} produced out-of-range token ids"
        )
    print(f"batch_size={batch_size} vocab_size={vocab_size} p={p} k={k}: OK")
|
||||
|
||||
|
||||
# parameter space
batch_size_range = [16, 64, 128]
vocab_size_range = [111, 32000]
p_range = [0.1, 0.5]
# Every (batch_size, vocab_size, p) combination exercised by the benchmark,
# in the same order itertools.product would produce (p varies fastest).
configs = [
    (batch_size, vocab_size, p)
    for batch_size in batch_size_range
    for vocab_size in vocab_size_range
    for p in p_range
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "vocab_size", "p"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "sglang"],
        line_names=["Torch Reference", "SGL Kernel"],
        styles=[("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="top-k-top-p-joint-sampling-performance",
        args={},
    )
)
def benchmark_sampling(batch_size, vocab_size, p, provider):
    """Benchmark one (batch_size, vocab_size, p) config for the given provider.

    Returns median, max, and min latency in microseconds; ``do_bench`` with
    quantiles ``[0.5, 0.2, 0.8]`` reports milliseconds at those quantiles.

    Raises:
        ValueError: if ``p`` or ``provider`` is not a recognized value.
    """
    torch.manual_seed(42)
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")

    device = torch.device("cuda")
    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    top_p_tensor = torch.full((batch_size,), p, device=device)
    top_k_tensor = torch.full((batch_size,), k, device=device)

    if provider == "torch":
        fn = lambda: torch_top_k_top_p_joint_sampling_from_probs(
            normalized_prob.clone(), top_k_tensor, top_p_tensor
        )
    elif provider == "sglang":
        fn = lambda: sgl_kernel.top_k_top_p_sampling_from_probs(
            normalized_prob.clone(),
            top_k_tensor,
            top_p_tensor,
            filter_apply_order="joint",
        )
    else:
        # Previously an unrecognized provider fell through to an unbound `fn`
        # (NameError); fail with a clear message instead.
        raise ValueError(f"provider not recognized: {provider}")

    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
    # Convert ms -> us; perf_report expects (median, max-ish, min-ish).
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Correctness check across the full parameter grid before timing anything.
    for batch_size, vocab_size, p in configs:
        calculate_diff(batch_size, vocab_size, p)

    print("\n" + "=" * 60)
    print("Starting performance benchmark...")
    benchmark_sampling.run(print_data=True)
|
||||
@@ -345,15 +345,19 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.def("top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val) -> ()");
|
||||
m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
|
||||
|
||||
m.def(
|
||||
"top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
|
||||
"maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
|
||||
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
|
||||
|
||||
m.def(
|
||||
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_top_k_arr, "
|
||||
"float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
|
||||
m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
|
||||
|
||||
m.def(
|
||||
"top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
|
||||
"maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
|
||||
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
|
||||
m.def("top_k_mask_logits(Tensor logits, Tensor mask_logits, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
|
||||
m.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
|
||||
|
||||
m.def(
|
||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
|
||||
|
||||
@@ -593,6 +593,10 @@ void top_p_sampling_from_probs(
|
||||
double top_p_val,
|
||||
bool deterministic,
|
||||
std::optional<at::Generator> gen);
|
||||
|
||||
void top_k_mask_logits(
|
||||
at::Tensor logits, at::Tensor mask_logits, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a,
|
||||
std::optional<torch::Tensor> const& c_or_none,
|
||||
|
||||
@@ -85,7 +85,9 @@ from sgl_kernel.moe import (
|
||||
)
|
||||
from sgl_kernel.sampling import (
|
||||
min_p_sampling_from_probs,
|
||||
top_k_mask_logits,
|
||||
top_k_renorm_prob,
|
||||
top_k_top_p_sampling_from_logits,
|
||||
top_k_top_p_sampling_from_probs,
|
||||
top_p_renorm_prob,
|
||||
top_p_sampling_from_probs,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import _to_tensor_scalar_tuple
|
||||
@@ -383,3 +383,161 @@ def min_p_sampling_from_probs(
|
||||
return _min_p_sampling_from_probs_internal(
|
||||
probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator
|
||||
)
|
||||
|
||||
|
||||
def _top_k_mask_logits_internal(
    logits: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    """Invoke the CUDA top-k logit-masking op and return the masked logits.

    The kernel expects float32 logits and an int32 per-request k array (or a
    scalar ``top_k_val`` when ``maybe_top_k_arr`` is ``None``); it writes the
    result into a freshly allocated output tensor.
    """
    logits = logits.float()
    if maybe_top_k_arr is not None:
        maybe_top_k_arr = maybe_top_k_arr.int()
    masked = torch.empty_like(logits)
    torch.ops.sgl_kernel.top_k_mask_logits.default(
        logits, masked, maybe_top_k_arr, top_k_val
    )
    return masked
|
||||
|
||||
|
||||
def top_k_mask_logits(
    logits: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
    Fused GPU kernel for masking logits by top-k thresholding.

    Parameters
    ----------
    logits: torch.Tensor
        Logits before softmax, shape ``(batch_size, num_classes)``.
    top_k: Union[torch.Tensor, int]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the
        top-k threshold for masking logits, should be in ``(0, num_classes)``.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
        We keep the top-k logits, set the rest to negative infinity.

    Returns
    -------
    masked_logits: torch.Tensor
        Masked logits, shape ``(batch_size, num_classes)``.

    Examples
    --------

    >>> import torch
    >>> import flashinfer
    >>> torch.manual_seed(42)
    >>> batch_size = 4
    >>> vocab_size = 5
    >>> top_k = 3
    >>> logits = torch.randn(batch_size, vocab_size).to(0)
    >>> logits
    tensor([[ 1.9269,  1.4873,  0.9007, -2.1055, -0.7581],
            [ 1.0783,  0.8008,  1.6806,  0.3559, -0.6866],
            [-0.4934,  0.2415, -0.2316,  0.0418, -0.2516],
            [ 0.8599, -0.3097, -0.3957,  0.8034, -0.6216]], device='cuda:0')
    >>> masked_logits = flashinfer.sampling.top_k_mask_logits(logits, top_k)
    >>> masked_logits
    tensor([[ 1.9269,  1.4873,  0.9007,    -inf,    -inf],
            [ 1.0783,  0.8008,  1.6806,    -inf,    -inf],
            [   -inf,  0.2415, -0.2316,  0.0418,    -inf],
            [ 0.8599, -0.3097,    -inf,  0.8034,    -inf]], device='cuda:0')

    Note
    ----
    The combination of ``top_k_mask_logits`` and ``softmax`` should be equivalent to ``top_k_renorm_probs``.

    See Also
    --------
    top_k_renorm_probs
    """
    # _to_tensor_scalar_tuple splits `top_k` into (tensor-or-None, scalar) as
    # expected by the internal kernel wrapper.
    return _top_k_mask_logits_internal(logits, *_to_tensor_scalar_tuple(top_k))
|
||||
|
||||
|
||||
def top_k_top_p_sampling_from_logits(
    logits: torch.Tensor,
    top_k: Union[torch.Tensor, int],
    top_p: Union[torch.Tensor, float],
    indices: Optional[torch.Tensor] = None,
    filter_apply_order: str = "top_k_first",
    deterministic: bool = True,
    generator: Optional[torch.Generator] = None,
    check_nan: bool = False,
) -> torch.Tensor:
    r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
    Fused GPU kernel for top-k and top-p sampling from pre-softmax logits,

    this operator implements GPU-based rejection sampling without explicit sorting.
    Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.

    The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
    which is more efficient than the naive implementation that launches a series of kernels.

    Parameters
    ----------
    logits: torch.Tensor
        Pre-softmax logits for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
        and the i-th output will be sampled from the i-th row of logits. When indices is provided,
        shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
        probability distributions.
    top_k: Union[torch.Tensor, int]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
    top_p: Union[torch.Tensor, float]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
    indices: Optional[torch.Tensor]
        Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
        For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
        This allows reusing the same probability distribution for multiple outputs.
        If indices is not provided, the i-th output will be sampled from the i-th row of probs.
    filter_apply_order: str
        The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``.
        If ``"top_k_first"``, we first apply top-k filter, then apply top-p sampling on the top-k results.
        If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``.
    deterministic: bool
        Whether to use deterministic kernel implementation, default is ``True``.
    generator: Optional[torch.Generator]
        A random number generator for the operation.
    check_nan: bool
        Whether to check nan in :attr:`probs`, default is ``False``.

    Returns
    -------
    samples: torch.Tensor
        Sampled categories, shape ``(batch_size,)``.

    Note
    ----
    This function expects float32 inputs, and the output is int32.

    """
    if filter_apply_order == "top_k_first":
        # Mask logits to the top-k, renormalize via softmax, then top-p sample.
        masked_logits = top_k_mask_logits(logits, top_k)
        probs = torch.softmax(masked_logits, dim=-1)
        return top_p_sampling_from_probs(
            probs,
            top_p,
            indices,
            deterministic,
            check_nan=check_nan,
            generator=generator,
        )
    elif filter_apply_order == "joint":
        # Joint mode applies both filters inside the fused rejection-sampling
        # kernel, so only the softmax happens here.
        probs = torch.softmax(logits, dim=-1)
        if check_nan:
            if torch.any(torch.isnan(probs)):
                raise ValueError("Input probs contains NaN.")
        return _top_k_top_p_sampling_from_probs_internal(
            probs,
            indices,
            *_to_tensor_scalar_tuple(top_k),
            *_to_tensor_scalar_tuple(top_p),
            deterministic,
            generator,
        )
    else:
        raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
|
||||
|
||||
@@ -5,6 +5,54 @@ import sgl_kernel
|
||||
import torch
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [100])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_sampling_from_probs_logits_top_k_first_alignment(
    batch_size, vocab_size, k, p
):
    """from_logits and from_probs must agree under top_k_first given identical RNG state."""
    torch.manual_seed(42)
    logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5
    gen_from_logits = torch.Generator("cuda:0")
    gen_from_probs = gen_from_logits.clone_state()

    from_logits = sgl_kernel.sampling.top_k_top_p_sampling_from_logits(
        logits, k, p, filter_apply_order="top_k_first", generator=gen_from_logits
    )
    from_probs = sgl_kernel.sampling.top_k_top_p_sampling_from_probs(
        torch.softmax(logits, dim=-1),
        k,
        p,
        filter_apply_order="top_k_first",
        generator=gen_from_probs,
    )
    assert torch.all(from_logits == from_probs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [100])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_sampling_from_probs_logits_joint_alignment(
    batch_size, vocab_size, k, p
):
    """from_logits and from_probs must agree under joint filtering given identical RNG state."""
    torch.manual_seed(42)
    logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5
    gen_from_logits = torch.Generator("cuda:0")
    gen_from_probs = gen_from_logits.clone_state()

    from_logits = sgl_kernel.sampling.top_k_top_p_sampling_from_logits(
        logits, k, p, filter_apply_order="joint", generator=gen_from_logits
    )
    from_probs = sgl_kernel.sampling.top_k_top_p_sampling_from_probs(
        torch.softmax(logits, dim=-1),
        k,
        p,
        filter_apply_order="joint",
        generator=gen_from_probs,
    )
    assert torch.all(from_logits == from_probs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 99, 989])
|
||||
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
|
||||
@pytest.mark.parametrize("p", [0.1, 0.5])
|
||||
|
||||
Reference in New Issue
Block a user